In [None]:
import os
import warnings
import pickle

import wandb
import numpy as np
from skimage import io
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torchvision.models import resnet34
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.optim import Adam

In [None]:
# Silence every warning category for the rest of the session (keeps cell output clean,
# but also hides deprecation notices from the libraries used below).
warnings.simplefilter('ignore')

In [None]:
# Targets used to stratify the train/test split further down.
# NOTE(review): pickle.load can execute arbitrary code; only open files you produced yourself.
with open('y_stratified.pickle', 'rb') as y_file:
    y = pickle.load(y_file)

In [None]:
# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(dev)

In [None]:
# Directory with the image files; passed to MorphDisDataset, which reads '<root_dir>/<idx>.jpg'.
root_dir = 'net_dis_morph'

In [None]:
# ResNet-34 with pretrained weights; used as the shared backbone of Net below.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version before upgrading.
resnet = resnet34(pretrained=True)

In [None]:
class CustomCrop:
    """Transform that center-crops an image to a square.

    The side of the square is the smaller of the sample's dimensions 1 and 2,
    so the input is expected to be laid out channels-first (it is placed right
    after ``ToTensor`` in the preprocessing pipeline).
    """

    def __call__(self, sample):
        """Return ``sample`` center-cropped to ``min(dim1, dim2)`` on both axes."""
        height, width = sample.shape[1], sample.shape[2]
        # The shorter edge decides the crop size, so nothing is padded.
        return transforms.CenterCrop(min(height, width))(sample)

In [None]:
class Net(nn.Module):
    """Two-headed classifier on top of a shared convolutional backbone.

    The backbone is every child module of ``resnet`` except the last one (for
    torchvision ResNets that drops the final ``fc`` layer and keeps the global
    pooling). Its output is flattened and fed to two independent linear heads.

    Args:
        resnet: model whose children (minus the last) form the backbone. The
            modules are reused, not copied, so they are trained in place.
        out_features_diseases: number of logits produced by the disease head.
        out_features_morph: number of logits produced by the morphology head.
    """

    # Width assumed for the flattened backbone output when ``resnet`` does not
    # expose ``fc.in_features``.
    DEFAULT_FEATURE_DIM = 512

    def __init__(self, resnet, out_features_diseases, out_features_morph):
        super().__init__()
        self.base_model = nn.Sequential(*list(resnet.children())[:-1])
        # Size the heads from the layer that was cut off instead of assuming
        # 512, so backbones with a different feature width also work.
        removed_head = getattr(resnet, 'fc', None)
        feature_dim = getattr(removed_head, 'in_features', self.DEFAULT_FEATURE_DIM)
        self.branch_1 = nn.Linear(feature_dim, out_features_diseases)
        self.branch_2 = nn.Linear(feature_dim, out_features_morph)

    def forward(self, x):
        """Return ``(disease_logits, morphology_logits)`` for a batch ``x``."""
        features = torch.flatten(self.base_model(x), 1)
        return self.branch_1(features), self.branch_2(features)

In [None]:
class MorphDisDataset(Dataset):
    """Image dataset yielding a disease class and a multi-hot morphology vector.

    Sample ``idx`` is read from ``<root_dir>/<idx>.jpg`` and its label record is
    ``labels[idx]`` from the pickled labels file. A record must provide
    ``'disease'`` (a class id, shifted by one to become a zero-based index) and
    ``'morphology'`` (an iterable of class ids, each shifted by one to index the
    multi-hot vector).

    Args:
        root_dir: directory containing the image files.
        transform: callable applied to the loaded image; a falsy value returns
            the image exactly as read.
        labels_file: path to the pickle file holding the label records.
        num_morphology: length of the multi-hot morphology vector (default 8).
    """

    def __init__(self, root_dir, transform, labels_file, num_morphology=8):
        self.root_dir = root_dir
        self.transform = transform
        self.num_morphology = num_morphology
        # NOTE(review): pickle.load can execute arbitrary code; only use trusted label files.
        with open(labels_file, 'rb') as f:
            self.labels = pickle.load(f)

    def __len__(self):
        """Return the number of regular files located directly in ``root_dir``."""
        return sum(
            os.path.isfile(os.path.join(self.root_dir, name))
            for name in os.listdir(self.root_dir)
        )

    def __getitem__(self, idx):
        """Return a dict with keys ``'image'``, ``'labels_diseases'`` and ``'labels_morphology'``."""
        record = self.labels[idx]
        image = io.imread(os.path.join(self.root_dir, f'{idx}.jpg')).copy()

        # Multi-hot encoding: position (id - 1) is set for every listed morphology id.
        labels_morphology = np.zeros(self.num_morphology)
        for morph_id in record['morphology']:
            labels_morphology[morph_id - 1] = 1

        if self.transform:
            image = self.transform(image)
        return {
            'image': image,
            'labels_diseases': record['disease'] - 1,
            'labels_morphology': labels_morphology,
        }

In [None]:
# Preprocessing applied to every image, in order.
composed = transforms.Compose([
    transforms.ToTensor(),            # image array -> tensor
    CustomCrop(),                     # square center crop on the shorter edge
    transforms.Resize((224, 224)),    # fixed network input size
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

In [None]:
# Full dataset (train and test rows together); the split is done later through index samplers.
dataset = MorphDisDataset(root_dir, composed, 'labels.pickle')

In [None]:
# Raw label records, loaded from the same file the dataset reads.
# NOTE(review): pickle.load can execute arbitrary code; only open files you produced yourself.
with open('labels.pickle', 'rb') as labels_file:
    labels = pickle.load(labels_file)

In [None]:
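# One-off helper, kept disabled: rebuilds `y` with one class index per sample.
# NOTE(review): the `== 1` lookup appears written for one-hot disease labels, while
# MorphDisDataset.__getitem__ returns an integer class index — confirm before re-enabling.
# y = []
# for i in range(len(dataset)):
#     y.append(np.where(dataset[i]['labels_diseases'] == 1)[0][0])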

In [None]:
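# One-off helper, kept disabled: writes `y` to the pickle file that is loaded near the top of the notebook.
# with open('y_stratified.pickle', 'wb') as f:
#     pickle.dump(y, f)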

In [None]:
# Stratified 80/20 split over sample indices. The seed is fixed so that the
# train/test partition is identical after every "Restart & Run All".
SPLIT_SEED = 42
train_indexes, test_indexes = train_test_split(
    np.arange(len(y)),
    test_size=0.2,
    shuffle=True,
    stratify=y,
    random_state=SPLIT_SEED,
)
# Samplers restrict each DataLoader to its own subset of the shared dataset.
train_sampler = SubsetRandomSampler(train_indexes)
test_sampler = SubsetRandomSampler(test_indexes)

In [None]:
# Both loaders wrap the same dataset and differ only in the index sampler they draw from.
batch_size = 16
train_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size, sampler=subset_sampler)
    for subset_sampler in (train_sampler, test_sampler)
)

In [None]:
# Per-class positive weight for the morphology BCE loss.
# BCEWithLogitsLoss expects pos_weight[c] to be roughly (#negatives / #positives)
# of class c, so that rare classes get their positive term boosted. Normalised
# positive frequencies must not be used here: they are all below 1 and smallest
# for the rarest class, which shrinks exactly the terms that need the most weight.
num_morph_classes = 8
positive_counts = torch.zeros(num_morph_classes)
# Samples with at least one morphology label; the training loop drops the
# others before computing the morphology loss, so they are not counted.
num_labelled = 0
for i in range(len(dataset.labels)):
    record = dataset.labels[i]
    if len(record) == 0:
        continue
    # A set, because a repeated id still sets a single position of the multi-hot target.
    morph_ids = {int(l) for l in record['morphology']}
    if not morph_ids:
        continue
    num_labelled += 1
    for morph_id in morph_ids:
        positive_counts[morph_id - 1] += 1
negative_counts = num_labelled - positive_counts
# Clamp both counts to >= 1: no division by zero for classes that never occur,
# and no zero weight for a class that occurs in every labelled sample.
pos_weight = negative_counts.clamp(min=1) / positive_counts.clamp(min=1)
pos_weight = pos_weight.to(device)

In [None]:
# Disease head: single-label classification over 20 classes -> cross-entropy.
criterion_diseases = nn.CrossEntropyLoss()
# Morphology head: 8 independent yes/no outputs -> weighted binary cross-entropy on logits.
criterion_morphology = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
net = Net(resnet, out_features_diseases=20, out_features_morph=8)

In [None]:
# Move the model to the target device first, then hand its parameters to Adam.
lr = 1e-5
net = net.to(device)
optimizer = Adam(net.parameters(), lr=lr)

In [None]:
# Joint training of both heads: the total loss is the disease cross-entropy
# plus the morphology BCE, the latter computed only on rows that carry at
# least one morphology label.
epochs = 100
for epoch in range(epochs):
    net.train()
    for batch in train_loader:
        inputs = batch['image'].to(device)
        labels_diseases = batch['labels_diseases'].to(device).long()
        # The dataset builds the morphology targets with np.zeros, i.e. float64.
        # Cast to float32 so targets, logits and pos_weight share one dtype;
        # mixing them can fail inside BCEWithLogitsLoss on recent PyTorch releases.
        labels_morphology = batch['labels_morphology'].to(device).float()

        optimizer.zero_grad()
        o1, o2 = net(inputs)

        loss1 = criterion_diseases(o1, labels_diseases)

        # Rows with an all-zero target have no morphology annotation and are
        # excluded from the morphology loss.
        has_morphology = labels_morphology.sum(dim=1) != 0
        if has_morphology.any():
            loss2 = criterion_morphology(o2[has_morphology], labels_morphology[has_morphology])
        else:
            # A mean over zero rows is NaN and would poison every weight on
            # backward(); contribute nothing for such a batch instead.
            loss2 = torch.zeros((), device=device)

        loss = loss1 + loss2
        loss.backward()
        optimizer.step()