In [104]:
import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, transforms, Compose
from torchvision.models import mobilenet_v3_small, MobileNet_V3_Small_Weights, vgg16, VGG16_Weights
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

torch.manual_seed(1234)

<torch._C.Generator at 0x206d98be5b0>

In [100]:
df = pd.read_csv('Dr.Eye/datasets/ret_annot.csv')
filenames = df['id_code']
from pandas import DataFrame

def balance_data(df:DataFrame, n_samples:int=200, save:bool=True, dir_path:str=None) -> DataFrame:
    """Balance the dataset with n_samples of each classes samples """
    data_per_labels = {}
    for label in range(n_classes):
        data = df[df['diagnosis'] == label]
        data_per_labels[label] = data[:n_samples]

    data_stats = {i: len(data) for i, data in data_per_labels.items() }
    balanced_data = pd.concat(data_per_labels.values(), ignore_index=True)
    
    if save:
        file_name = f'annot_{n_samples}.csv'
        if dir_path is None:
            dir_path = ''
        file_path = os.path.join(dir_path, file_name)
        
        balanced_data.to_csv(file_path, index=False)
    
    return balanced_data

data = balance_data(df)
print(len(data))

993


In [82]:
from collections import Counter
import numpy as np

counter = Counter(labels.to_list())
data_per_labels = np.array(list(counter.values()))
# plt.bar(range(len(data_per_labels)), data_per_labels)
print(counter)

Counter({0: 1805, 2: 999, 1: 370, 4: 295, 3: 193})


In [121]:
class RetDataset(Dataset):
    """load retinopathy dataset"""
    def __init__(self,d_path:str, annot_path:str=None, transforms=None) -> None:
        self.annots = pd.read_csv(annot_path)
        self.d_path = d_path
        self.transforms = transforms
    
    def __len__(self, ) -> int:
        return len(self.annots)
    
    def __getitem__(self, idx)->torch.Tensor:
        filename, label = self.annots.iloc[idx]
        file_path = os.path.join(self.d_path, filename + '.png')
        img = Image.open(file_path)
        
        if self.transforms:
            img_tensor = self.transforms(img)
        else:
            img_tensor = ToTensor()(img) 
        
        return img_tensor, torch.tensor(label)

In [70]:
from torch.nn import functional as F
import torch.nn as nn

class Net(nn.Module):
    """Custom network for retinopathy detection"""
    def __init__(self, in_c:int, out_c:int, nb_cl:int=5):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(in_c, out_c, 3), 
        nn.ReLU(),
        nn.MaxPool2d(2, 2), 
        nn.Conv2d(out_c, 64, 3),
        nn.ReLU(),
        nn.MaxPool2d(2, 2), )
        
        self.classifier = nn.Sequential(nn.Linear(246016, 1000), nn.Linear(1000, nb_cl))

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 246016)
        x = self.classifier(x)
        return x

In [71]:
model = Net(3, 64)

In [77]:
tmp = model.features[-1]

In [101]:
# configs
dataset_path = 'Dr.Eye/datasets/ret_dataset/'
annot_path = 'annot_200.csv'
n_classes = 5

In [102]:
# splitting the dataset
transforms = Compose([ToTensor(), Resize((256, 256))])
dataset = RetDataset(dataset_path, annot_path, transforms)

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])

train_dataloader = DataLoader(train_dataset, batch_size=16 ,shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16)
dataloader = DataLoader(dataset, batch_size=16 ,shuffle=True) 

In [128]:
# hyperparameters
lr = 0.01
epochs = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def train(model, dataloader, criterion, optimizer, save=True, checkpt_path=None) -> None:
    """train model and save it if save is true"""
    model.train()
    model.to(device)
    checkpt = {}
    for epoch in range(epochs):
        acc = 0
        for x, y in dataloader:
            x.to(device)
            y.to(device)
            y_pred = model(x)
            loss = criterion(y_pred, y.flatten())
            loss.backward()
            acc += (y_pred.argmax(dim=1) == y.flatten()).sum()
            optimizer.zero_grad()
            optimizer.step()
        
        if epoch%100 == 0 and save:
            checkpt['model_st_dict'] = model.state_dict()
            checkpt['acc'] = acc
            checkpt['loss'] = loss
            if checkpt_path is None:
                checkpt_path = ''
            checkpt_file_name =  f'checkpoint_{epoch}.pth'
            checkpt_file_path = os.path.join(checkpt_path, checkpt_file_name)
            torch.save(checkpt, checkpt_file_path)
            
        print(f'{epoch} epochs: {loss.item()} acc = {acc/(len(dataloader)*16)}')

In [221]:
class GenRetSample():
    """ Generate randomly dataset sample: img, target"""
    def __init__(self, dataset):
        self.dataset = dataset
        self.index = 0
    
    def __iter__(self, ):
        return self
    
    def __next__(self):
        if len(self.dataset) > self.index:
            self.index += 1
            return dataset[self.index]
        else :
            raise StopIteration()
    


In [229]:
dataset = RetDataset(dataset_path, annot_path, transforms)
gen = GenRetSample(dataset)
img, lab = next(gen)
x_1 = img.unsqueeze(dim=0)
with torch.no_grad():
    y_1 = model(x_1)
    print(y_1.argmax(dim=1), lab)

In [107]:
# load mobilenet v3
def get_mobilenet_v3(n_classes:int=5, weights=MobileNet_V3_Small_Weights.IMAGENET1K_V1) -> torch.nn.Module:
    """Return mobilenet_v3 model with frozen features layers"""
    model_small = mobilenet_v3_small(weights=weights)
    model_small.classifier[-1] = nn.Linear(in_features=1024, out_features=5)

    # deactivate gradient computation the features layers
    for param in model_small.features.parameters():
        param.requires_grad = False
    
    return model_small

In [35]:
def get_vgg16(n_classes:int=5, weights=VGG16_Weights.IMAGENET1K_V1) -> torch.nn.Module:
    """Return VGG16 model with frozen features layers"""
    vgg_model = vgg16(weights=weights)
    vgg_model.classifier[-1] = torch.nn.Linear(in_features=4096, out_features=n_classes)

    for param in vgg_model.features.parameters():
        param.requires_grad = False
    
    return vgg_model

In [112]:
# Data preparation for mobilenet
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset_pretrained = RetDataset(dataset_path, annot_path)
dataloader_pretrained = DataLoader(dataset_mobilenet, batch_size=16, shuffle=True)

In [117]:
model = get_mobilenet_v3()
x, _ = dataset_pretrained[0]
x = x.unsqueeze(dim=0)
model(x), _

(tensor([[0.0047, 0.1146, 0.0971, 0.0120, 0.1415]], grad_fn=<AddmmBackward0>),
 tensor(0))

In [115]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
train(model, dataloader_pretrained, criterion, optimizer)

In [110]:
@torch.no_grad
def test(model, dataloader_test):
    model.eval()
    acc = 0
    data_size = len(test_dataloader.dataset)
    for x, y in dataloader_test:
        y_pred = model(x) 
        acc += (y_pred.argmax(dim=1) == y).sum()
    print(f"test accuracy: {acc/data_size}")

In [111]:
test(model, test_dataloader)

test accuracy: 0.21717171370983124


In [63]:
len(test_dataloader.dataset)

732

In [None]:
# Save a checkpoint during training
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}

torch.save(checkpoint, 'checkpoint.pth')

# Load a checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In [116]:
def load_model(checkpoint_path=None, arch='vgg'):
    model_checkpt = torch.load(checkpoint_path)
    return model_checkpt