In [11]:
from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision.transforms import Resize
from torchvision.io import read_image
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class BreaKHisFlatDataset(Dataset):
    def __init__(self, image_folder, magnitude = "400X", transform = None):
        self.image_paths = [x for x in image_folder.rglob("*.png") if magnitude in x.parent.name]
        self.labels = [x.relative_to(image_folder).parts[0] for x in self.image_paths]
        self.classes = list(set(self.labels))
        self.class_to_label = {x:i for i, x in enumerate(self.classes)}
        self.label_to_class = {v:k for k, v in self.class_to_label.items()}
        self.transform = transform
    
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, index):
        image_path = self.image_paths[index]
        label = self.class_to_label[self.labels[index]]
        image = read_image(str(image_path))
        if self.transform:
            image = self.transform(image)
            
        return image.reshape(-1).float(), label

In [3]:
image_folder = Path("D:/personal/study/introduction to PA/data/BreaKHis_v1/histology_slides/breast")
transforms = Resize((46,70)) # 460 x 700 original 
full_dataset = BreaKHisFlatDataset(image_folder, transform=transforms)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

In [4]:
def binary_acc(y_pred, y_test):
    y_pred_tag = torch.round(y_pred)

    correct_results_sum = (y_pred_tag == y_test).sum().float()
    acc = correct_results_sum/y_test.shape[0]
    acc = torch.round(acc * 100)
    
    return acc

def val_get_cm(net):
    net.eval()
    cm = np.zeros((2,2))
    for i, data in enumerate(test_dataloader, 0):
        inputs, labels = data
        outputs = net(inputs)
        cm += confusion_matrix(torch.round(outputs).detach().numpy(), labels.unsqueeze(1).detach().numpy())
    return cm

def metrics_from_cm(cm):
    tn, fp, fn, tp = cm.ravel()
    precision = tp/(tp+fp)
    recall = tp/(tp+fn)
    f1 = 2*precision*recall/(precision + recall)
    spec = tn / (tn + fp)
    sens = tp / (tp + fn)
    print(f"f1: {f1}, specificity :{spec}, sensitivity {sens}")
    return f1

In [5]:
def train_model(net):
    best_model = None
    best_val_f1 = 0
    
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(net.parameters())#, lr=0.001, momentum=0.9)
    val_f1s = []
    train_f1s = []
    num_epoch = 10
    for epoch in range(num_epoch):
        net.train()
        running_loss = []
        running_acc = []
        cm = np.zeros((2,2))
        for i, data in enumerate(train_dataloader, 0):
            inputs, labels = data
            
            optimizer.zero_grad()
            outputs = net(inputs)
            
            loss = criterion(outputs, labels.unsqueeze(1).float())
            running_acc.append(binary_acc(outputs, labels.unsqueeze(1)))
            cm += confusion_matrix(torch.round(outputs).detach().numpy(), labels.unsqueeze(1).detach().numpy())
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss.append(loss.item())
        print(f"Epoch {epoch}:")
        print("Train metrics:")
        train_f1 = metrics_from_cm(cm)
        train_f1s.append(train_f1)
        print(f"loss: {sum(running_loss)/len(running_loss)}, acc: {sum(running_acc)/len(running_acc)}")
        print("Val metrics:")
        f1 = metrics_from_cm(val_get_cm(net))
        val_f1s.append(f1)
        if f1 > best_val_f1:
            print(f"Best val f1: {f1}")
            best_model = net
            best_val_f1 = f1
        print("\n")
    return best_model, val_f1s, train_f1s
    

In [12]:
# net = nn.Sequential(
#         nn.Linear(46*70*3, 1024),
#         nn.Dropout(0.1),
#         nn.BatchNorm1d(1024),
#         nn.LeakyReLU(),
#         nn.Linear(1024, 124),
#         nn.Dropout(0.1),
#         nn.BatchNorm1d(124),
#         nn.LeakyReLU(),
#         nn.Linear(124, 1),
#         nn.Sigmoid()
#     )
    
net = nn.Sequential(
        nn.Linear(46*70*3, 1024),
        nn.Dropout(0.1),
        nn.BatchNorm1d(1024),
        nn.LeakyReLU(),
        nn.Linear(1024, 512),
        nn.Dropout(0.1),
        nn.BatchNorm1d(512),
        nn.LeakyReLU(),
        nn.Linear(512, 124),
        nn.Dropout(0.1),
        nn.BatchNorm1d(124),
        nn.LeakyReLU(),
        nn.Linear(124, 1),
        nn.Sigmoid()
    )


net.train()
best_model, val_f1s, train_f1s = train_model(net)

Epoch 0:
Train metrics:
f1: 0.7067357512953367, specificity :0.8586171310629515, sensitivity 0.7002053388090349
loss: 0.7022254596585813, acc: 80.56521606445312
Val metrics:
f1: 0.7319587628865979, specificity :0.8607142857142858, sensitivity 0.8452380952380952
Best val f1: 0.7319587628865979


Epoch 1:
Train metrics:
f1: 0.7608200455580865, specificity :0.8636363636363636, sensitivity 0.835
loss: 0.6559532025586003, acc: 85.73912811279297
Val metrics:
f1: 0.5513196480938416, specificity :0.8796992481203008, sensitivity 0.4069264069264069


Epoch 2:
Train metrics:
f1: 0.7755581668625147, specificity :0.863342566943675, sensitivity 0.8847184986595175
loss: 0.6408193862956503, acc: 87.0
Val metrics:
f1: 0.6056338028169014, specificity :0.8736842105263158, sensitivity 0.4942528735632184


Epoch 3:
Train metrics:
f1: 0.7976878612716762, specificity :0.8755846585594013, sensitivity 0.8914728682170543
loss: 0.6297169353650964, acc: 88.17391204833984
Val metrics:
f1: 0.5517241379310345, speci

Best val_f1 = 0.73