In [None]:
import os
from fastai.vision.all import *
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt

In [3]:
def get_labels(file_path):
    """Return the class label encoded in a file name.

    Args:
        file_path: File name as a string. Only its first character is
            inspected, so pass the bare name rather than a path that
            includes directories.

    Returns:
        ``'healthy'`` when the name starts with an upper-case ``'N'``,
        otherwise ``'unhealthy'`` (including for an empty name).
    """
    # startswith() mirrors the test used when the files are partitioned and,
    # unlike indexing with [0], does not raise on an empty string.
    return 'healthy' if file_path.startswith('N') else 'unhealthy'

In [None]:
def split_files(beta, base_path='./dataset_train/', seed=None):
    """Partition the files under ``base_path`` by class and subsample each class.

    Every file found by walking ``base_path`` recursively is classified by its
    name: names starting with ``'N'`` are healthy, all others are unhealthy.
    Only the bare file names are collected; the directory a file was found in
    is not kept.

    Args:
        beta: Fraction in ``[0, 1]`` of each class kept in its subsample. The
            sample size is ``int(beta * class_size)``, i.e. truncated.
        base_path: Root of the directory tree to scan.
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so the samples are reproducible; when ``None`` the
            module-level ``random`` state is used.

    Returns:
        A 2-tuple of ``(healthy, unhealthy)`` pairs of name lists:

        * index 0: ``(healthy_sample, all_unhealthy)``
        * index 1: ``(all_healthy, unhealthy_sample)``

        The full lists are sorted by name.

    Raises:
        ValueError: If ``beta`` is outside ``[0, 1]``.
    """
    # Imported locally so the function does not depend on a star import
    # elsewhere in the file happening to expose ``random``.
    import random

    if not 0 <= beta <= 1:
        raise ValueError(f'beta must be between 0 and 1, got {beta!r}')

    healthy_files = []
    unhealthy_files = []
    for _root, _dirs, files in os.walk(base_path):
        for file in files:
            if file.startswith('N'):
                healthy_files.append(file)
            else:
                unhealthy_files.append(file)

    # os.walk yields entries in an OS-dependent order; sorting makes both the
    # returned lists and any seeded sample independent of the file system.
    healthy_files.sort()
    unhealthy_files.sort()

    rng = random if seed is None else random.Random(seed)

    # Take a random sample of length beta * class size from each class.
    healthy_files_sample = rng.sample(healthy_files, int(beta * len(healthy_files)))
    unhealthy_files_sample = rng.sample(unhealthy_files, int(beta * len(unhealthy_files)))

    return ((healthy_files_sample, unhealthy_files), (healthy_files, unhealthy_files_sample))

In [None]:
# Root directory of the training images; also handed to fastai as the
# DataLoaders path.
base_path = './dataset_train/'
class CNN:
    """A resnet34 image classifier trained on one of the two subsampled splits.

    Attributes set by ``__init__``:
        beta: Subsampling fraction forwarded to ``split_files``.
        id: Which of the two splits to use (0 or 1); also used in the name of
            the exported model file.
        files: The selected ``(healthy_names, unhealthy_names)`` pair exactly
            as returned by ``split_files``.
        fnames: ``files`` flattened into ``Path`` objects under ``base_path``.
        dls: fastai ``ImageDataLoaders`` built from ``fnames``.
        learn: The fastai learner.
    """

    def __init__(self, beta, id):
        """Build the data loaders and learner for split ``id`` at fraction ``beta``."""
        self.beta = beta
        self.id = id
        self.files = split_files(beta, base_path)[id]
        # from_name_func labels each item through its ``.name`` attribute and
        # opens the item itself, so it needs one flat list of path objects
        # rather than a pair of lists of bare names. The names carry no
        # directory part, so files are assumed to sit directly in base_path.
        root = Path(base_path)
        self.fnames = [root / name for group in self.files for name in group]
        self.dls = ImageDataLoaders.from_name_func(
            base_path,
            self.fnames,
            valid_pct=0.2,
            seed=42,
            label_func=get_labels,
            item_tfms=Resize(224)
        )
        # NOTE(review): newer fastai releases name this vision_learner --
        # confirm the installed version before switching.
        self.learn = cnn_learner(self.dls, resnet34, metrics=error_rate)

    def _model_path(self):
        """Return the absolute path of this model's exported ``.pkl`` file."""
        return Path('./models').absolute() / f'model_{self.id}.pkl'

    def train(self, epochs):
        """Fine-tune the learner for ``epochs`` epochs."""
        self.learn.fine_tune(epochs)

    def show_batch(self):
        """Display a batch from the data loaders."""
        self.dls.show_batch()

    def save(self):
        """Export the learner to ``./models/model_<id>.pkl`` under the working directory."""
        target = self._model_path()
        target.parent.mkdir(parents=True, exist_ok=True)
        # An absolute path keeps save() and load() pointing at the same file:
        # Learner.export resolves a relative name against the learner's own
        # path (base_path here), whereas load() reads relative to the working
        # directory.
        self.learn.export(target)

    def load(self):
        """Replace ``self.learn`` with the learner previously written by ``save``."""
        # NOTE: load_learner unpickles the file; only load models you trust.
        self.learn = load_learner(self._model_path())

    def predict(self, file_path):
        """Return the learner's prediction for the image at ``file_path``."""
        # Hand the path to fastai so the image is opened by fastai's own
        # loading code rather than pre-opened with PIL, which also left
        # the file handle open.
        return self.learn.predict(file_path)

    def plot(self):
        """Plot the losses recorded during training."""
        self.learn.recorder.plot_loss()

    def show_results(self):
        """Display sample predictions from the learner."""
        self.learn.show_results()