In [1]:
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
from fastai.vision.all import *
from PIL import Image

In [2]:
# Fetch the dataset repository; later cells read images from
# /content/BreastCancerPrediction/dataset_train/ and /content/BreastCancerPrediction/dataset_test.
# NOTE(review): this clones whatever the default branch currently holds (no pinned
# commit), and re-running the cell reports an error once the directory already exists.
!git clone https://github.com/KevivJaknap/BreastCancerPrediction.git

Cloning into 'BreastCancerPrediction'...
remote: Enumerating objects: 1586, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 1586 (delta 0), reused 2 (delta 0), pack-reused 1583[K
Receiving objects: 100% (1586/1586), 44.20 MiB | 20.03 MiB/s, done.
Resolving deltas: 100% (53/53), done.


In [3]:
def get_labels(file_path):
    """Map an image file name or path to its class label.

    A file whose *basename* starts with 'N' is 'healthy'; anything else is
    'unhealthy' (the same rule split_files applies). Looking at the basename
    means full paths such as '/content/.../N12.png' are labelled correctly;
    indexing ``file_path[0]`` would inspect the leading '/' instead.

    Args:
        file_path: file name or path (str or os.PathLike).

    Returns:
        'healthy' or 'unhealthy'.
    """
    name = os.path.basename(os.fspath(file_path))
    # startswith also copes with an empty name, where name[0] would raise IndexError.
    return 'healthy' if name.startswith('N') else 'unhealthy'

In [4]:
def split_files(beta, base_path='/content/BreastCancerPrediction/dataset_train/', seed=None):
    """Build two overlapping, class-rebalanced image lists from one folder tree.

    Images are labelled by file name: a basename starting with 'N' is
    'healthy', anything else 'unhealthy'. ``first`` holds every healthy image
    plus a random ``beta`` fraction of the unhealthy ones; ``second`` holds
    every unhealthy image plus a random ``beta`` fraction of the healthy ones.
    Both lists are shuffled. With ``beta=1`` each list contains every file once.

    Args:
        beta: fraction in [0, 1] of the other class sampled into each list.
        base_path: directory searched recursively for images (a trailing
            separator is optional).
        seed: optional seed for sampling/shuffling. ``None`` uses the global
            ``random`` state (so a global seed still applies).

    Returns:
        (first, second): lists of ``(path, label)`` tuples.

    Raises:
        ValueError: if ``beta`` is outside [0, 1].
    """
    if not 0 <= beta <= 1:
        raise ValueError(f"beta must be in [0, 1], got {beta!r}")
    rng = random if seed is None else random.Random(seed)

    healthy_files = []
    unhealthy_files = []
    for root, dirs, files in os.walk(base_path):
        dirs.sort()  # fixed traversal order so a given seed reproduces the split
        for file in sorted(files):
            # Join with the directory actually being walked: `base_path + file`
            # breaks for nested folders and for a base_path without a trailing '/'.
            path = os.path.join(root, file)
            if file.startswith('N'):
                healthy_files.append((path, 'healthy'))
            else:
                unhealthy_files.append((path, 'unhealthy'))
    len_healthy = len(healthy_files)
    len_unhealthy = len(unhealthy_files)

    # Split into two sets, each keeping one class in full.
    first = healthy_files + rng.sample(unhealthy_files, int(len_unhealthy * beta))
    second = rng.sample(healthy_files, int(len_healthy * beta)) + unhealthy_files
    rng.shuffle(first)
    rng.shuffle(second)

    print(f"First contains {len_healthy} healthy and {len(first) - len_healthy} unhealthy")
    print(f"Second contains {len(second)-len_unhealthy} healthy and {len_unhealthy} unhealthy")
    return first, second

In [7]:
# Default training-image directory; same value as split_files' default base_path.
base_path = '/content/BreastCancerPrediction/dataset_train/'


class CNN:
    """fastai image classifier trained on a list of ``(path, label)`` pairs.

    Args:
        id: identifier used in the exported model file name.
        files: list of ``(image_path, label)`` tuples, e.g. from ``split_files``.
        bs: batch size for the dataloaders.
        arch: backbone architecture passed to ``vision_learner``.
        size: side length images are resized / augmented to.
        seed: optional seed for the random train/validation split; ``None``
            leaves the split dependent on the global RNG state.
    """

    def __init__(self, id, files, bs=64, arch=resnet34, size=224, seed=None):
        self.id = id
        self.files = files
        self.dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                                get_x=ColReader(0),   # element 0 of each tuple: image path
                                get_y=ColReader(1),   # element 1 of each tuple: class label
                                splitter=RandomSplitter(seed=seed),
                                item_tfms=Resize(size),
                                batch_tfms=aug_transforms(size=size, min_scale=0.75))
        self.dls = self.dblock.dataloaders(self.files, bs=bs)
        self.learn = vision_learner(self.dls, arch, metrics=accuracy)

    def _model_file(self):
        """Relative file name this model is exported to and loaded from."""
        return f'./models/model_{self.id}.pkl'

    def train(self, epochs):
        """Fine-tune the pretrained backbone for ``epochs`` epochs."""
        self.learn.fine_tune(epochs)

    def show_batch(self):
        """Display a sample batch of training images with their labels."""
        self.dls.show_batch()

    def save(self):
        """Export the learner for inference to ``<learn.path>/models/model_<id>.pkl``."""
        fname = self._model_file()
        # Learner.export writes to learn.path/fname without creating missing
        # folders, so make sure the models/ directory exists first.
        (self.learn.path/fname).parent.mkdir(parents=True, exist_ok=True)
        self.learn.export(fname)

    def load(self):
        """Replace ``self.learn`` with the learner previously written by ``save``.

        ``self.dls`` is not touched and still refers to the dataloaders built
        in ``__init__``; use ``self.learn.dls`` for the loaded learner's data config.
        """
        # load_learner unpickles the file: only load model files you trust.
        # Resolve against learn.path so this reads the same file save() wrote.
        self.learn = load_learner(self.learn.path/self._model_file())

    def predict(self, file_path):
        """Predict the class of one image file via ``Learner.predict``.

        Returns whatever ``self.learn.predict`` returns for that item.
        """
        # Pass the path rather than a raw PIL image: fastai then opens the file
        # with the same PILImage.create pipeline (RGB conversion) used for the
        # training items. A pre-opened Image.open() result skipped that conversion
        # and left the file handle open.
        return self.learn.predict(file_path)

    def plot(self):
        """Plot the training and validation loss curves recorded during training."""
        self.learn.recorder.plot_loss()

    def show_results(self):
        """Show validation images alongside targets and predictions."""
        self.learn.show_results()

In [8]:
SEED = 42       # single seed for every RNG used below
BETA = 0.75     # fraction of the other class sampled into each training set
N_EPOCHS = 4    # fine-tuning epochs per model

# split_files samples with Python's `random`; fastai's random splitter, shuffling
# and head initialisation draw from torch/numpy. Seed them all so a
# Restart & Run All reproduces the same splits and models.
set_seed(SEED, reproducible=True)

files1, files2 = split_files(BETA)
model1 = CNN(id=1, files=files1)
model2 = CNN(id=2, files=files2)
model1.train(N_EPOCHS)
model2.train(N_EPOCHS)

First contains 642 healthy and 480 unhealthy
Second contains 481 healthy and 640 unhealthy


epoch,train_loss,valid_loss,accuracy,time
0,0.849182,0.332803,0.861607,00:08


epoch,train_loss,valid_loss,accuracy,time
0,0.410765,0.312025,0.879464,00:08
1,0.286168,0.233743,0.910714,00:09
2,0.235682,0.040272,0.977679,00:09
3,0.198402,0.021899,0.991071,00:08


epoch,train_loss,valid_loss,accuracy,time
0,0.857101,0.649526,0.75,00:08


epoch,train_loss,valid_loss,accuracy,time
0,0.45528,0.17344,0.928571,00:09
1,0.398415,0.10734,0.964286,00:09
2,0.294491,0.026237,0.991071,00:09
3,0.229959,0.024031,0.995536,00:09


In [None]:
def get_test_accuracy(model, test_path='/content/BreastCancerPrediction/dataset_test'):
    """Return the accuracy of a trained ``CNN`` on every image under ``test_path``.

    Ground-truth labels come from ``split_files`` (a basename starting with 'N'
    means 'healthy'). With ``beta=1`` the first list it returns already holds
    every healthy and every unhealthy file, so only that list is used.

    Args:
        model: a trained ``CNN`` instance.
        test_path: directory containing the held-out test images.

    Returns:
        Fraction (0.0-1.0) of test images whose predicted label matches the
        file-name label.

    Raises:
        ValueError: if no images are found under ``test_path``.
    """
    # Pass the directory with a trailing separator: split_files may build
    # paths by plain string concatenation (base_path + file name).
    test_items, _ = split_files(1, base_path=os.path.join(test_path, ''))
    if not test_items:
        raise ValueError(f"no test images found under {test_path!r}")
    paths = [path for path, _ in test_items]
    labels = [label for _, label in test_items]

    # Use the learner's own dls so this also works after model.load().
    dls = model.learn.dls
    test_dl = dls.test_dl(paths)
    probs = model.learn.get_preds(dl=test_dl)[0]
    pred_labels = [dls.vocab[i] for i in probs.argmax(dim=1).tolist()]
    correct = sum(pred == true for pred, true in zip(pred_labels, labels))
    return correct / len(labels)


In [None]:
from fastai.callback.hook import *

# Inspect the raw output of the last top-level module of a trained model on one
# validation batch. Bind the learner and dataloaders explicitly to a trained
# CNN instead of relying on `learn`/`dls` globals (swap in model2 as needed).
learn, dls = model1.learn, model1.dls

# Define a hook function
def hook_fn(m, i, o):
    """Forward-hook callback: return a detached copy of module ``m``'s output ``o``."""
    return o.detach().clone()

# Get one validation batch; each batch is an (inputs, targets) pair, so a
# single-element unpack (`x, = ...`) would raise.
x, _ = first(dls.valid)

# Use Hook as a context manager so it is removed afterwards instead of firing
# (and storing tensors) on every later forward pass through this model.
with Hook(learn.model[-1], hook_fn, is_forward=True) as hook_output:
    with torch.no_grad():
        outputs = learn.model.eval()(x)

# The hook's result lives in `.stored`; it stays available after removal.
logits = hook_output.stored