In [None]:
 import os

import torch
import wandb
import numpy as np
from skimage import io, transform
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.models import resnet34 
from torch.optim import Adam
from torch import nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, recall_score, precision_score, average_precision_score

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Open the W&B run that the later wandb.config / wandb.log / wandb.run calls attach to.
wandb.init(
    project='net_primary_morphology',
)

In [None]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(dev)

In [None]:
# Directory with the images, named `{idx}_{class_id}_{class_id}...jpg`.
# Overridable via the NET_PRIMARY_MORPH_DIR environment variable so the notebook
# is not tied to one machine; the default is the original cluster path.
root_dir = os.environ.get('NET_PRIMARY_MORPH_DIR', '/mnt/tank/scratch/esergeenko/net_primary_morph')

In [None]:
# Per-class `pos_weight` for BCEWithLogitsLoss. PyTorch expects, for each class,
# (#negative samples) / (#positive samples) — e.g. 300 negatives and 100 positives
# give 3.0 — so positives of rarer classes weigh more. (A normalized positive
# frequency, which sums to 1, would instead down-weight every positive, and the
# rarest classes most.)
# Class ids are 1-based and parsed from filenames `{idx}_{c1}_{c2}...jpg`.
num_classes = 8
image_files = [f for f in os.listdir(root_dir) if os.path.isfile(os.path.join(root_dir, f))]
pos_counts = torch.zeros(num_classes)
for file in image_files:
    # splitext drops only the extension (str.strip('.jpg') strips a character set);
    # set() counts each image at most once per class.
    for m in set(os.path.splitext(file)[0].split('_')[1:]):
        pos_counts[int(m) - 1] += 1
neg_counts = len(image_files) - pos_counts
# clamp(min=1) avoids division by zero for a class with no positive images.
pos_weight = (neg_counts / pos_counts.clamp(min=1)).to(device)

In [None]:
class CustomCrop:
    """Center-crop a channels-first image to a square of side min(dim 1, dim 2).

    In this notebook it runs right after ``transforms.ToTensor()``, so the input
    is a (C, H, W) tensor and the output is (C, s, s) with s = min(H, W).
    """

    def __call__(self, sample):
        side = min(sample.shape[1], sample.shape[2])
        return transforms.CenterCrop(side)(sample)

In [None]:
# Preprocessing pipeline: image array -> float CHW tensor, square center crop,
# resize to 224x224, then normalize with the ImageNet mean/std expected by the
# pretrained torchvision ResNet weights.
composed = transforms.Compose([
    transforms.ToTensor(),
    CustomCrop(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

In [None]:
class MorphDataset(Dataset):
    """Multi-label image dataset backed by a flat directory of files.

    Each file in ``root_dir`` is expected to be named ``{idx}_{c1}_{c2}...{ext}``.
    ``idx`` is the sample index used by ``__getitem__``, and each ``c`` is a
    1-based class id in ``1..num_classes``. A sample is a dict
    ``{'image': ..., 'label': ...}``. ``label`` is a float multi-hot tensor of
    length ``num_classes``. ``image`` is the ``io.imread`` result, passed through
    ``transform`` when one is given.
    """

    def __init__(self, root_dir, transform=None, num_classes=8):
        self.root_dir = root_dir
        self.transform = transform
        self.num_classes = num_classes
        # Index the directory once; listing it on every __getitem__ makes each
        # lookup O(n) and a full pass over the dataset O(n^2).
        self._file_by_idx = {}
        self._num_files = 0
        for name in os.listdir(root_dir):
            if not os.path.isfile(os.path.join(root_dir, name)):
                continue
            self._num_files += 1
            # '12_3_5.jpg' -> key '12'; the first file listed for an index wins.
            self._file_by_idx.setdefault(name.split('_', 1)[0], name)

    def __len__(self):
        return self._num_files

    @staticmethod
    def _parse_labels(filename, num_classes=8):
        """Return the multi-hot label tensor encoded in ``filename``.

        Raises:
            ValueError: if a class id is not an integer in ``1..num_classes``.
        """
        # splitext removes just the extension; str.strip('.jpg') would strip any
        # of the characters '.', 'j', 'p', 'g' and fail on '.jpeg' / '.JPG'.
        stem = os.path.splitext(filename)[0]
        label = torch.zeros(num_classes)
        for part in stem.split('_')[1:]:
            class_id = int(part)
            if not 1 <= class_id <= num_classes:
                raise ValueError(f'class id {class_id} out of range 1..{num_classes} in {filename!r}')
            label[class_id - 1] = 1
        return label

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        try:
            filename = self._file_by_idx[str(idx)]
        except KeyError:
            raise IndexError(f'no file for index {idx} in {self.root_dir!r}') from None

        label = self._parse_labels(filename, self.num_classes)
        # os.path.join instead of a hard-coded '\' separator, which yields a
        # non-existent path on POSIX systems.
        image = io.imread(os.path.join(self.root_dir, filename)).copy()

        if self.transform:
            image = self.transform(image)
        return {'image': image, 'label': label}

In [None]:
# Full dataset; train/test subsets are selected later through samplers.
dataset = MorphDataset(root_dir=root_dir, transform=composed)

In [None]:
# Stratification key for the train/test split: 1.0 if a sample has at least one
# class id in its filename, else 0.0 — the same value max(dataset[i]['label'])
# yields. The ids are read from one directory listing rather than via
# dataset[i], which would load and preprocess every image just to get its label.
# NOTE(review): if every image carries a label this key is constant and the
# stratification is a no-op — confirm whether e.g. the primary class was intended.
first_file_by_idx = {}
for name in os.listdir(dataset.root_dir):
    # Same match as `startswith(f'{idx}_')`: key is the text before the first '_'.
    first_file_by_idx.setdefault(name.split('_', 1)[0], name)

y = []
for i in range(len(dataset)):
    class_ids = os.path.splitext(first_file_by_idx[str(i)])[0].split('_')[1:]
    y.append(1.0 if class_ids else 0.0)

In [None]:
# Fixed seed so the 80/20 train/test partition is reproducible across runs.
SEED = 42
train_indexes, test_indexes = train_test_split(
    np.arange(len(y)), test_size=0.2, shuffle=True, stratify=y, random_state=SEED
)
# Each sampler yields its subset of dataset indices in a fresh random order per epoch.
train_sampler = SubsetRandomSampler(train_indexes)
test_sampler = SubsetRandomSampler(test_indexes)

In [None]:
# Mini-batch size, recorded in the W&B run config and read by both DataLoaders.
wandb.config['batch_size'] = 16

In [None]:
# One loader per split over the same dataset; the sampler picks the split's indices.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=wandb.config.batch_size, sampler=split_sampler)
    for split_sampler in (train_sampler, test_sampler)
)

In [None]:
# Multi-label objective: an independent sigmoid + BCE term per class, with the
# per-class positive weights computed from the image directory above.
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
# ResNet-34 backbone initialised from torchvision's pretrained weights.
net = resnet34(pretrained=True)

In [None]:
# Adam learning rate, recorded in the W&B run config.
wandb.config['lr'] = 1e-5
# Swap the pretrained classifier head for an 8-logit head (one per class).
net.fc = nn.Linear(in_features=net.fc.in_features, out_features=8)
net = net.to(device)
optimizer = Adam(net.parameters(), lr=wandb.config['lr'])

In [None]:
def get_labels(predictions, treshold):
    """Binarize a score array: 1 where a score is strictly above ``treshold``, else 0.

    Args:
        predictions: numpy array of scores (any shape).
        treshold: scalar cut-off.

    Returns:
        Integer numpy array with the same shape as ``predictions``.
    """
    above = predictions > treshold
    return above.astype(int)

In [None]:
# Russian names of skin morphology elements. Entry i names class id i + 1 (the
# ids encoded in the image filenames), so only the first 8 are used by the
# 8-output model. Those appear to be the primary elements (macule, tubercle,
# nodule, papule, wheal, vesicle, bulla, pustule); the rest look like secondary
# elements.
mapping = """
    пятно бугорок узел папула волдырь пузырек пузырь гнойничок
    гиперпигментация гипопигментация эрозия язва чешуйка корка рубец
    трещина экскориация кератоз лихенификация вегетация дерматосклероз
    анетодермия атрофодермия
""".split()

In [None]:
def log_epoch(epoch, y_true_train, y_pred_train, y_true_test, y_pred_test, train_loss, test_loss):
    """Compute one epoch's metrics, log them to W&B and return them as a flat list.

    Args:
        epoch: epoch number stored under the ``'epoch'`` key.
        y_true_train, y_true_test: binary multi-label targets, one row per
            sample and one column per class (columns are named via ``mapping``).
        y_pred_train, y_pred_test: per-class scores with the same shape as the
            matching targets; thresholded with ``get_labels``.
        train_loss, test_loss: epoch losses, logged as-is.

    Returns:
        List of metric values in a fixed order (the run-summary cell indexes it):
        AP over all (sample, class) pairs for train, test; then for each
        threshold 0.1, 0.2, ..., 0.9 the macro F1, precision and recall, each as
        a train/test pair; then per-class AP as train/test pairs.
    """
    step = {'epoch': epoch, 'train loss': train_loss, 'test loss': test_loss}
    current_metrics = []

    def record(key, value):
        # Each metric feeds both the W&B step and the returned list, so it is
        # computed exactly once.
        step[key] = value
        current_metrics.append(value)

    # "mAP" here is AP over the flattened (sample, class) pairs (micro-averaged).
    record('mAP/train', average_precision_score(y_true_train.reshape(-1), y_pred_train.reshape(-1)))
    record('mAP/test', average_precision_score(y_true_test.reshape(-1), y_pred_test.reshape(-1)))

    macro_metrics = (('f1', f1_score), ('precision', precision_score), ('recall', recall_score))
    for treshold in np.arange(0.1, 1, 0.1):
        tag = round(treshold, 1)
        # Binarize once per threshold rather than once per metric.
        hard_train = get_labels(y_pred_train, treshold)
        hard_test = get_labels(y_pred_test, treshold)
        for name, metric in macro_metrics:
            record(f'{name} train/{tag}', metric(y_true_train, hard_train, average='macro'))
            record(f'{name} test/{tag}', metric(y_true_test, hard_test, average='macro'))

    for i in range(y_true_train.shape[1]):
        record(f'mAP class train/{mapping[i]}', average_precision_score(y_true_train[:, i], y_pred_train[:, i]))
        record(f'mAP class test/{mapping[i]}', average_precision_score(y_true_test[:, i], y_pred_test[:, i]))

    wandb.log(step)
    return current_metrics

In [None]:
# Number of training epochs, recorded in the W&B run config.
wandb.config['epochs'] = 100

In [None]:
def evaluate_loader(loader, epoch):
    """Run ``net`` over every batch of ``loader`` without gradient tracking.

    Returns:
        (y_true, y_prob, mean_loss): stacked label rows, per-class sigmoid
        probabilities of the same shape, and ``criterion`` averaged over batches.
    """
    true_batches, prob_batches = [], []
    total_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(loader):
            images, labels = data['image'], data['label']
            outputs = net(images.to(device))
            total_loss += criterion(outputs, labels.to(device)).item()
            # The head is trained with BCEWithLogitsLoss (independent per-class
            # logits), so probabilities come from a sigmoid. A softmax would
            # couple the classes, forcing scores to sum to 1, and make the
            # 0.1..0.9 thresholds in log_epoch meaningless.
            prob_batches.append(torch.sigmoid(outputs).cpu().numpy())
            true_batches.append(labels.numpy())
            if (i + 1) % 100 == 0:
                print(f'Epoch: {epoch + 1}, {i + 1}/{len(loader)}')
    return np.concatenate(true_batches), np.concatenate(prob_batches), total_loss / len(loader)


best_metrics = []
current_metrics = []
for epoch in range(wandb.config.epochs):
    net.train()
    running_loss = 0.0
    j = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data['image'], data['label']
        optimizer.zero_grad()
        outputs = net(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        j += 1
        # Report the mean loss of the last (up to) 100 batches.
        if (i + 1) % 100 == 0:
            print(f'Epoch: {epoch + 1}, {i + 1}/{len(train_loader)}, loss: {running_loss / j}')
            running_loss = 0.0
            j = 0

    print('val')
    net.eval()
    y_true_train, y_pred_train, train_loss = evaluate_loader(train_loader, epoch)
    y_true_test, y_pred_test, test_loss = evaluate_loader(test_loader, epoch)

    current_metrics = log_epoch(epoch + 1,
                                y_true_train,
                                y_pred_train,
                                y_true_test,
                                y_pred_test,
                                train_loss,
                                test_loss)

    # Track the best value of every metric independently across epochs.
    if not best_metrics:
        best_metrics = current_metrics.copy()
    else:
        best_metrics = [max(b, c) for b, c in zip(best_metrics, current_metrics)]

    # Keep only the most recent model/optimizer checkpoint on disk.
    torch.save(net.state_dict(), f'net_{epoch}.pt')
    torch.save(optimizer.state_dict(), f'opt_{epoch}.pt')
    if os.path.exists(f'net_{epoch - 1}.pt'):
        os.remove(f'net_{epoch - 1}.pt')
        os.remove(f'opt_{epoch - 1}.pt')

print('Finished')

In [None]:
# Publish the best per-epoch value of every metric to the W&B run summary.
# The key order must match the list returned by log_epoch().
summary_keys = ['mAP/train', 'mAP/test']
for treshold in np.arange(0.1, 1, 0.1):
    tag = round(treshold, 1)
    for metric_name in ('f1', 'precision', 'recall'):
        summary_keys.extend([f'{metric_name} train/{tag}', f'{metric_name} test/{tag}'])
summary_keys.extend(
    f'mAP class {split}/{mapping[cls]}' for cls in range(8) for split in ('train', 'test')
)

for j, key in enumerate(summary_keys):
    wandb.run.summary[key] = best_metrics[j]