In [None]:
import os
import warnings
import pickle

import wandb
from PIL import Image
import numpy as np
from skimage import io
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, average_precision_score
import torch
import torch.nn as nn
from torchvision.models import resnet34
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.optim import Adam

In [None]:
warnings.filterwarnings('ignore')

In [None]:
# Start a Weights & Biases run in the 'net_2_branches' project. Later cells
# store hyperparameters on wandb.config and send per-epoch metrics with
# wandb.log / wandb.run.summary.
wandb.init(project='net_2_branches')

In [None]:
# Use the first CUDA device when one is visible, otherwise stay on the CPU.
# `dev` keeps the plain string form, `device` is the torch.device built from it.
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(dev)

In [None]:
# Directory that holds the numbered training images (0.jpg, 1.jpg, ...).
# The original cluster location stays the default, but it can be overridden
# through the NET_2_BRANCHES_IMAGES_DIR environment variable so the notebook
# also runs on other machines. Joining with '' guarantees a trailing
# separator, so plain string concatenation with a file name keeps working.
images_dir = os.path.join(
    os.environ.get('NET_2_BRANCHES_IMAGES_DIR', '/mnt/tank/scratch/esergeenko/net_2_branches/'),
    '',
)
# Pickled label collections, resolved relative to the working directory.
labels_diseases_file = 'net_2_branches_diseases.pickle'
labels_morphology_file = 'net_2_branches_morphology.pickle'

In [None]:
# Read both pickled label collections (disease labels and morphology labels).
# NOTE(review): pickle.load can execute arbitrary code, so these files must
# come from a trusted source.
with open(labels_diseases_file, 'rb') as diseases_fh, \
        open(labels_morphology_file, 'rb') as morphology_fh:
    labels_diseases = pickle.load(diseases_fh)
    labels_morphology = pickle.load(morphology_fh)

In [None]:
class CustomCrop:
    """Centre-crop a sample to a square whose side is its shorter extent.

    The extents compared are ``sample.shape[1]`` and ``sample.shape[2]``,
    i.e. the sample is treated as a channel-first image tensor.
    """

    def __call__(self, sample):
        side = min(sample.shape[1], sample.shape[2])
        return transforms.CenterCrop(side)(sample)

In [None]:
# Image-loading pipeline, applied once per file while the images are preloaded:
# PIL image -> tensor, centre square crop (CustomCrop), resize to 224x224.
composed = transforms.Compose(
    [transforms.PILToTensor(), CustomCrop(), transforms.Resize((224, 224))])

In [None]:
# Preload every image into one uint8 tensor of shape (N, 3, 224, 224).
# N comes from the label list rather than a hard-coded count, so the image
# buffer and the labels can never disagree in length.
num_images = len(labels_diseases)
images = torch.empty((num_images, 3, 224, 224), dtype=torch.uint8)
for i in range(num_images):
    # The context manager closes the file handle as soon as the image is decoded.
    with Image.open(os.path.join(images_dir, f'{i}.jpg')) as image:
        # convert('RGB') forces three channels, so grayscale or CMYK JPEGs
        # still fit the (3, 224, 224) slot. Assigning into the buffer copies
        # the data, so no extra clone is needed.
        images[i] = composed(image.convert('RGB'))

In [None]:
# Pretrained ResNet-34. It is passed to Net, which keeps every child module
# except the last one as the shared backbone.
resnet = resnet34(pretrained=True)

In [None]:
class Net(nn.Module):
    """Shared convolutional backbone with two independent linear heads.

    Args:
        resnet: Backbone module. All of its children except the last one
            (the classifier) are reused as the feature extractor.
        out_features_diseases: Number of outputs of the first head.
        out_features_morph: Number of outputs of the second head.

    ``forward`` returns a tuple ``(x1, x2)`` with the raw outputs (logits)
    of the first and the second head.
    """

    def __init__(self, resnet, out_features_diseases, out_features_morph):
        super().__init__()
        # Drop the final child (the classifier) and keep the rest as backbone.
        self.base_model = nn.Sequential(*list(resnet.children())[:-1])
        # The width of the pooled features equals the input width of the
        # dropped classifier, so read it from there when the backbone exposes
        # an `fc` layer. This lets backbones other than ResNet-18/34 work;
        # 512 remains the fallback and matches the previous behaviour.
        feature_dim = getattr(getattr(resnet, 'fc', None), 'in_features', 512)
        self.branch_1 = nn.Linear(feature_dim, out_features_diseases)
        self.branch_2 = nn.Linear(feature_dim, out_features_morph)

    def forward(self, x):
        features = self.base_model(x)
        # Collapse everything after the batch dimension into a flat vector.
        features = torch.flatten(features, 1)
        return self.branch_1(features), self.branch_2(features)

In [None]:
class MorphDisDataset(Dataset):
    """Dataset pairing each image with its disease and morphology labels.

    Args:
        images: Indexable collection of image tensors; its length defines the
            dataset length.
        labels_diseases: Indexable collection of disease labels.
        labels_morphology: Indexable collection of morphology labels.
        transform: Optional callable applied to the float image; a falsy
            value disables it.
    """

    def __init__(self, images, labels_diseases, labels_morphology, transform):
        self.images = images
        self.labels_diseases = labels_diseases
        self.labels_morphology = labels_morphology
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return a dict with keys 'image', 'labels_diseases', 'labels_morphology'."""
        # The stored image is cast to float; no value rescaling happens here.
        picture = self.images[idx].float()
        morphology = self.labels_morphology[idx]
        disease = self.labels_diseases[idx]

        if self.transform:
            picture = self.transform(picture)

        return {
            'image': picture,
            'labels_diseases': disease,
            'labels_morphology': morphology,
        }

In [None]:
# Per-sample normalisation used by the dataset.
# The images are stored as uint8 and MorphDisDataset only casts them with
# .float(), so the tensors reaching this transform still hold raw 0-255
# values. The ImageNet channel statistics are defined for inputs in [0, 1],
# so they are rescaled by 255 here. This is mathematically identical to
# dividing the image by 255 and then applying the usual mean/std.
# NOTE: this rebinds `composed`, which earlier named the image-loading pipeline.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)
composed = transforms.Compose([
    transforms.Normalize(
        [255.0 * channel_mean for channel_mean in imagenet_mean],
        [255.0 * channel_std for channel_std in imagenet_std],
    )
])

In [None]:
# Wrap the preloaded image tensor and both label collections. `composed` is
# applied to each float image when a sample is fetched.
dataset = MorphDisDataset(images, labels_diseases, labels_morphology, composed)

In [None]:
# Five-fold stratified splitter. No shuffling or seed is requested, so
# scikit-learn's defaults apply.
kf = StratifiedKFold(n_splits=5)

In [None]:
# Only the first stratified fold (stratified on the disease labels) is used.
# Taking it with next() avoids generating the other four folds and leaves
# train_indexes / test_indexes bound to the fold the samplers actually use.
train_indexes, test_indexes = next(kf.split(np.arange(len(labels_diseases)), labels_diseases))
train_sampler = SubsetRandomSampler(train_indexes)
test_sampler = SubsetRandomSampler(test_indexes)

In [None]:
# The batch size is recorded on the W&B config so it is stored with the run.
wandb.config.batch_size = 32
# Both loaders read from the same dataset; the samplers restrict them to the
# train / test indices of the selected fold.
train_loader = DataLoader(dataset, batch_size=wandb.config.batch_size, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=wandb.config.batch_size, sampler=test_sampler)

In [None]:
# Per-class weight of the positive term in BCEWithLogitsLoss.
# PyTorch expects pos_weight to be the ratio negatives / positives of each
# class. Normalised positive frequencies would instead make every weight < 1
# and give the rarest classes the smallest weight.
# Samples without any morphology label are excluded from the morphology loss
# during training, so they are excluded from the negative count here as well.
pos_counts = torch.zeros(8)
num_labelled = 0
for i in range(len(labels_morphology)):
    # as_tensor accepts tensors, numpy arrays and plain sequences alike.
    row = torch.as_tensor(labels_morphology[i], dtype=torch.float32)
    pos_counts += row
    if row.sum() != 0:
        num_labelled += 1
neg_counts = num_labelled - pos_counts
# clamp avoids a division by zero for a class that never occurs.
pos_weight = neg_counts / pos_counts.clamp(min=1)
pos_weight = pos_weight.to(device)

In [None]:
# Two-headed network: 20 outputs on the disease head, 8 on the morphology head.
net = Net(resnet, 20, 8)
# Morphology head: BCEWithLogitsLoss weighted per class by pos_weight.
criterion_morphology = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# Disease head: CrossEntropyLoss.
criterion_diseases = nn.CrossEntropyLoss()

In [None]:
# The learning rate is stored on the W&B config so it is recorded with the run.
wandb.config.lr = 0.0001
# Move the model to the target device before handing its parameters to Adam.
net = net.to(device)
optimizer = Adam(params=net.parameters(), lr=wandb.config.lr)

In [None]:
# Display names of the disease head outputs, in output-index order. They are
# used verbatim inside the logged metric keys, so the strings must not change.
# English glosses are approximate.
mapping_diseases = [
    'атопический дерматит',               # atopic dermatitis
    'акне',                               # acne
    'псориаз',                            # psoriasis
    'розацеа',                            # rosacea
    'бородавки',                          # warts
    'герпес',                             # herpes
    'витилиго',                           # vitiligo
    'клп',                                # abbreviation, presumably lichen planus (TODO confirm)
    'аллергический контактный дерматит',  # allergic contact dermatitis
    'экзема',                             # eczema
    'дерматомикозы',                      # dermatomycoses
    'булезный пемфигоид',                 # bullous pemphigoid
    'пузырчатка',                         # pemphigus
    'контагиозный моллюск',               # molluscum contagiosum
    'крапивница',                         # urticaria
    'кератоз',                            # keratosis
    'чесотка',                            # scabies
    'себореный дерматит',                 # seborrheic dermatitis
    'актинический',                       # "actinic", presumably actinic keratosis (TODO confirm)
    'базалиома',                          # basalioma
]
# Display names of the morphology head outputs, in output-index order.
mapping_morphology = [
    'пятно',      # spot / macule
    'бугорок',    # tubercle
    'узел',       # node
    'папула',     # papule
    'волдырь',    # wheal
    'пузырек',    # vesicle
    'пузырь',     # blister / bulla
    'гнойничок',  # pustule
]

In [None]:
def get_labels(predictions, treshold):
    """Binarise scores: 1 where a score is strictly above `treshold`, else 0.

    `predictions` must support elementwise comparison and `.astype`
    (e.g. a numpy array); the result has the same shape and integer dtype.
    """
    above = predictions > treshold
    return above.astype(int)

In [None]:
def log_epoch_diseases(epoch, y_true_train, y_pred_train, y_true_test, y_pred_test, train_loss, test_loss, mapping):
    """Compute the per-epoch metrics of the disease head.

    Args:
        epoch: Epoch number stored under the 'epoch' key.
        y_true_train, y_true_test: 2-D arrays indexed [sample, class] holding
            the one-hot targets.
        y_pred_train, y_pred_test: 2-D arrays of the same layout holding the
            predicted class scores.
        train_loss, test_loss: Mean losses, logged as given.
        mapping: Class names, one per column; its length sets how many
            per-class scores are produced.

    Returns:
        ``(current_metrics, step)``. ``step`` is the dict of named values to
        log. ``current_metrics`` is a flat list of the same metric values
        (losses and epoch excluded) in this fixed order: pooled AP
        train/test, then f1, precision, recall and accuracy (train then test
        each), then per-class AP train/test for every class. Callers index
        this list by position, so the order must not change.
    """
    step = {'epoch': epoch, 'dis loss/train': train_loss, 'dis loss/test': test_loss}
    current_metrics = []

    def record(key, value):
        # Each metric is computed once and stored in both containers.
        step[key] = value
        current_metrics.append(value)

    # Average precision over all (sample, class) pairs pooled together.
    record('dis mAP/train', average_precision_score(y_true_train.reshape(-1), y_pred_train.reshape(-1)))
    record('dis mAP/test', average_precision_score(y_true_test.reshape(-1), y_pred_test.reshape(-1)))

    # Hard class decisions, shared by all single-label metrics below.
    cls_true_train = np.argmax(y_true_train, 1)
    cls_pred_train = np.argmax(y_pred_train, 1)
    cls_true_test = np.argmax(y_true_test, 1)
    cls_pred_test = np.argmax(y_pred_test, 1)

    macro = {'average': 'macro'}
    scorers = (
        ('f1', f1_score, macro),
        ('precision', precision_score, macro),
        ('recall', recall_score, macro),
        ('accuracy', accuracy_score, {}),
    )
    for name, scorer, kwargs in scorers:
        record(f'dis {name}/train', scorer(cls_true_train, cls_pred_train, **kwargs))
        record(f'dis {name}/test', scorer(cls_true_test, cls_pred_test, **kwargs))

    # One-vs-rest average precision for every class named in `mapping`.
    for i in range(len(mapping)):
        record(f'dis mAP class/train {mapping[i]}', average_precision_score(y_true_train[:, i], y_pred_train[:, i]))
        record(f'dis mAP class/test {mapping[i]}', average_precision_score(y_true_test[:, i], y_pred_test[:, i]))

    return current_metrics, step
    
def log_epoch_morphology(step, epoch, y_true_train, y_pred_train, y_true_test, y_pred_test, train_loss, test_loss, mapping):
    """Compute the per-epoch metrics of the morphology head and send `step` to W&B.

    Args:
        step: Dict of values already collected for this epoch; it is extended
            in place and then passed to ``wandb.log``.
        epoch: Unused; kept so the signature stays unchanged.
        y_true_train, y_true_test: 2-D arrays indexed [sample, class] with the
            multi-label targets.
        y_pred_train, y_pred_test: 2-D arrays of the same layout with the
            predicted scores, thresholded at 0.1 ... 0.9.
        train_loss, test_loss: Mean losses, logged as given.
        mapping: Class names, one per column; its length sets how many
            per-class scores are produced.

    Returns:
        Flat list of metric values in this fixed order: pooled AP train/test;
        for every threshold f1, precision, recall (train then test each);
        then per-class AP train/test. Callers index this list by position,
        so the order must not change.
    """
    step['mor loss/train'] = train_loss
    step['mor loss/test'] = test_loss
    current_metrics = []

    def record(key, value):
        # Each metric is computed once and stored in both containers.
        step[key] = value
        current_metrics.append(value)

    # Average precision over all (sample, class) pairs pooled together.
    record('mor mAP/train', average_precision_score(y_true_train.reshape(-1), y_pred_train.reshape(-1)))
    record('mor mAP/test', average_precision_score(y_true_test.reshape(-1), y_pred_test.reshape(-1)))

    for treshold in np.arange(0.1, 1, 0.1):
        # Rounded only for the key text; arange yields values such as 0.30000000000000004.
        tag = round(treshold, 1)
        # Threshold once per split and reuse the labels for all three scores.
        labels_train = get_labels(y_pred_train, treshold)
        labels_test = get_labels(y_pred_test, treshold)
        for name, scorer in (('f1', f1_score), ('precision', precision_score), ('recall', recall_score)):
            record(f'mor {name}/train {tag}', scorer(y_true_train, labels_train, average='macro'))
            record(f'mor {name}/test {tag}', scorer(y_true_test, labels_test, average='macro'))

    # One-vs-rest average precision for every class named in `mapping`.
    for i in range(len(mapping)):
        record(f'mor mAP class/train {mapping[i]}', average_precision_score(y_true_train[:, i], y_pred_train[:, i]))
        record(f'mor mAP class/test {mapping[i]}', average_precision_score(y_true_test[:, i], y_pred_test[:, i]))

    wandb.log(step)

    return current_metrics

In [None]:
wandb.config.epochs = 100
# Element-wise best value seen so far for every metric of each head, plus the
# metrics of the most recent epoch. The order inside the lists is fixed by
# log_epoch_diseases / log_epoch_morphology.
best_metrics_b1 = []
current_metrics_b1 = []
best_metrics_b2 = []
current_metrics_b2 = []


def _branch_losses(o1, o2, batch_diseases, batch_morphology):
    """Compute the loss of each head for one batch.

    Only samples with at least one morphology label take part in the
    morphology loss. Returns ``(loss1, loss2, o2_kept, morphology_kept)``,
    where the last two are the morphology logits and targets of those samples.
    """
    keep = batch_morphology.sum(dim=1) != 0
    morphology_kept = batch_morphology[keep].float().to(device)
    o2_kept = o2[keep.to(o2.device)]

    loss1 = criterion_diseases(o1, batch_diseases.to(device).long())
    if len(morphology_kept) > 0:
        loss2 = criterion_morphology(o2_kept, morphology_kept)
    else:
        # With no labelled sample the criterion would average over zero
        # elements and return NaN; contribute a plain zero instead.
        loss2 = o2.new_zeros(())
    return loss1, loss2, o2_kept, morphology_kept


def _evaluate(loader, epoch, log_every):
    """Run the network over `loader` without gradients.

    Returns ``(y_true_b1, y_pred_b1, y_true_b2, y_pred_b2, loss_b1, loss_b2)``:
    one-hot disease targets, softmax disease scores, morphology targets and
    sigmoid morphology scores (labelled samples only), and the mean loss of
    each head. Progress is printed every `log_every` batches.
    """
    num_diseases = net.branch_1.out_features
    num_morphology = net.branch_2.out_features
    # Collect per-batch chunks and concatenate once at the end. The empty
    # float64 seeds fix the column count and the result dtype.
    true_b1 = [np.empty((0, num_diseases))]
    pred_b1 = [np.empty((0, num_diseases))]
    true_b2 = [np.empty((0, num_morphology))]
    pred_b2 = [np.empty((0, num_morphology))]
    loss_b1 = 0.0
    loss_b2 = 0.0
    morphology_batches = 0

    with torch.no_grad():
        for i, data in enumerate(loader, 0):
            batch_diseases = data['labels_diseases']
            o1, o2 = net(data['image'].to(device))
            loss1, loss2, o2_kept, morphology_kept = _branch_losses(
                o1, o2, batch_diseases, data['labels_morphology'])

            loss_b1 += loss1.item()
            if len(morphology_kept) > 0:
                loss_b2 += loss2.item()
                morphology_batches += 1

            one_hot = np.zeros((len(batch_diseases), num_diseases))
            one_hot[np.arange(len(batch_diseases)), batch_diseases.long().cpu().numpy()] = 1

            true_b1.append(one_hot)
            # Disease head is single-label: softmax over the class dimension.
            pred_b1.append(F.softmax(o1, dim=1).cpu().numpy())
            true_b2.append(morphology_kept.cpu().numpy())
            # Morphology head is multi-label and trained with BCE-with-logits,
            # so its probabilities are independent sigmoids, not a softmax.
            pred_b2.append(torch.sigmoid(o2_kept).cpu().numpy())

            if (i + 1) % log_every == 0:
                print(f'Epoch: {epoch + 1}, {i + 1}/{len(loader)}')

    return (
        np.concatenate(true_b1),
        np.concatenate(pred_b1),
        np.concatenate(true_b2),
        np.concatenate(pred_b2),
        loss_b1 / len(loader),
        # Average only over batches that actually produced a morphology loss.
        loss_b2 / max(morphology_batches, 1),
    )


for epoch in range(wandb.config.epochs):
    net.train()
    print('Training:')

    for i, data in enumerate(train_loader, 0):
        # The batch is read straight from `data`, so the notebook-level
        # labels_diseases / labels_morphology collections are not overwritten.
        optimizer.zero_grad()
        o1, o2 = net(data['image'].to(device))
        loss1, loss2, _, _ = _branch_losses(o1, o2, data['labels_diseases'], data['labels_morphology'])
        loss = loss1 + loss2

        if (i + 1) % 250 == 0:
            print(f'Epoch: {epoch + 1}, {i + 1}/{len(train_loader)}')
        loss.backward()
        optimizer.step()

    net.eval()

    print('Evaluating train:')
    (y_true_train_b1, y_pred_train_b1,
     y_true_train_b2, y_pred_train_b2,
     train_loss_b1, train_loss_b2) = _evaluate(train_loader, epoch, 250)

    print('Evaluating test:')
    (y_true_test_b1, y_pred_test_b1,
     y_true_test_b2, y_pred_test_b2,
     test_loss_b1, test_loss_b2) = _evaluate(test_loader, epoch, 100)

    current_metrics_b1, step = log_epoch_diseases(
        epoch + 1,
        y_true_train_b1,
        y_pred_train_b1,
        y_true_test_b1,
        y_pred_test_b1,
        train_loss_b1,
        test_loss_b1,
        mapping_diseases)
    current_metrics_b2 = log_epoch_morphology(
        step,
        epoch + 1,
        y_true_train_b2,
        y_pred_train_b2,
        y_true_test_b2,
        y_pred_test_b2,
        train_loss_b2,
        test_loss_b2,
        mapping_morphology)

    if len(best_metrics_b1) == 0:
        # First epoch: each head starts from its OWN metrics.
        best_metrics_b1 = current_metrics_b1.copy()
        best_metrics_b2 = current_metrics_b2.copy()
    else:
        best_metrics_b1 = [max(b, c) for b, c in zip(best_metrics_b1, current_metrics_b1)]
        best_metrics_b2 = [max(b, c) for b, c in zip(best_metrics_b2, current_metrics_b2)]

print('Finished.')

In [None]:
# Write the best value of every disease metric into the run summary.
# The key list is built in exactly the order in which log_epoch_diseases
# fills its metric list, so position j of best_metrics_b1 belongs to key j.
summary_keys_b1 = ['dis mAP/train', 'dis mAP/test']
for metric_name in ('f1', 'precision', 'recall', 'accuracy'):
    summary_keys_b1.append(f'dis {metric_name}/train')
    summary_keys_b1.append(f'dis {metric_name}/test')
for i in range(20):
    summary_keys_b1.append(f'dis mAP class/train {mapping_diseases[i]}')
    summary_keys_b1.append(f'dis mAP class/test {mapping_diseases[i]}')

for j, key in enumerate(summary_keys_b1):
    wandb.run.summary[key] = best_metrics_b1[j]

In [None]:
# Write the best value of every morphology metric into the run summary.
# The keys must be spelled exactly like the ones log_epoch_morphology logs
# ('mor f1/train 0.1', 'mor mAP class/train <name>', ...); otherwise the best
# values land under separate summary entries instead of the logged metrics.
# The key list follows the order in which log_epoch_morphology fills its
# metric list, so position j of best_metrics_b2 belongs to key j.
summary_keys_b2 = ['mor mAP/train', 'mor mAP/test']
for treshold in np.arange(0.1, 1, 0.1):
    tag = round(treshold, 1)
    for metric_name in ('f1', 'precision', 'recall'):
        summary_keys_b2.append(f'mor {metric_name}/train {tag}')
        summary_keys_b2.append(f'mor {metric_name}/test {tag}')
for class_name in mapping_morphology:
    summary_keys_b2.append(f'mor mAP class/train {class_name}')
    summary_keys_b2.append(f'mor mAP class/test {class_name}')

for j, key in enumerate(summary_keys_b2):
    wandb.run.summary[key] = best_metrics_b2[j]