In [None]:
import os
import torch
import wandb
import numpy as np
from torchvision.models import resnet34 
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, average_precision_score
from torch.utils.data import SubsetRandomSampler, DataLoader
from torch.optim import Adam
from torchvision import transforms
from torch import nn 

In [None]:
import warnings
# Suppress every Python warning for the rest of the kernel session.
# NOTE: this also hides deprecation and metric warnings raised by the
# libraries used below, so comment it out when debugging odd results.
warnings.filterwarnings('ignore')

In [None]:
# Start a Weights & Biases run in the 'net_diseases' project. Later cells store
# hyperparameters on wandb.config and send per-epoch metrics with wandb.log.
wandb.init(project='net_diseases')

In [None]:
# Use the first CUDA device when PyTorch reports one as available, otherwise
# fall back to the CPU. `dev` keeps the device string, `device` the torch object.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [None]:
class CustomCrop:
    """Callable transform that center-crops its input to a square.

    The side of the square is the smaller of ``sample.shape[1]`` and
    ``sample.shape[2]``. In ``composed`` it runs right after ``ToTensor``, so
    those are presumably the height and width of a (C, H, W) tensor --
    TODO confirm if it is ever used elsewhere.
    """

    def __call__(self, sample):
        """Return ``sample`` center-cropped to ``side x side``."""
        first_extent = sample.shape[1]
        second_extent = sample.shape[2]
        side = first_extent if first_extent < second_extent else second_extent
        return transforms.CenterCrop(side)(sample)

In [None]:
# Per-channel mean / std passed to Normalize (these are the values torchvision
# documents for its ImageNet-pretrained models).
normalize_mean = (0.485, 0.456, 0.406)
normalize_std = (0.229, 0.224, 0.225)

# Preprocessing pipeline handed to ImageFolder: convert to tensor, crop to a
# centered square, resize to 224x224, then normalize each channel.
preprocessing_steps = [
    transforms.ToTensor(),
    CustomCrop(),
    transforms.Resize((224, 224)),
    transforms.Normalize(normalize_mean, normalize_std),
]
composed = transforms.Compose(preprocessing_steps)

In [None]:
# Load images from the relative directory 'root', applying the `composed`
# preprocessing to every sample.
# NOTE(review): 'root' is resolved against the current working directory --
# confirm the notebook is launched from the folder that contains it.
dataset = ImageFolder(root='root', transform=composed)

In [None]:
# Per-sample class indices of the dataset; used below to stratify the split.
targets = dataset.targets

In [None]:
# Stratified 80/20 split over sample indices. The seed is fixed so that every
# Restart & Run All yields the same partition: without it, runs are not
# comparable and a checkpoint reloaded in a new session could be evaluated on
# images it was trained on.
SPLIT_SEED = 42
train_indexes, test_indexes = train_test_split(
    np.arange(len(targets)),
    test_size=0.2,
    shuffle=True,
    stratify=targets,
    random_state=SPLIT_SEED,
)

In [None]:
# Both loaders read from the same `dataset`; these samplers restrict each one
# to its own index subset (drawn in random order on every pass).
train_sampler = SubsetRandomSampler(train_indexes)
test_sampler = SubsetRandomSampler(test_indexes)

In [None]:
# Mini-batch size, stored on the W&B config so it is recorded with the run.
wandb.config.batch_size = 32

In [None]:
# Train and test loaders over the shared dataset; the sampler decides which
# indices each loader sees, the batch size comes from the W&B config.
train_loader = DataLoader(dataset, batch_size=wandb.config.batch_size, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=wandb.config.batch_size, sampler=test_sampler)

In [None]:
# ResNet-34 initialised with pretrained weights; the final fully connected
# layer is replaced by a fresh 20-output layer (20 = number of entries in
# `mapping`), then the whole model is moved to the selected device.
net = resnet34(pretrained=True)
net.fc = nn.Linear(net.fc.in_features, 20)
net = net.to(device)

In [None]:
# Learning rate is stored on the W&B config so it is recorded with the run.
wandb.config.learning_rate = 0.00001
# Cross-entropy over the 20 logits; Adam updates every parameter of the network
# (the pretrained backbone is fine-tuned, not frozen).
criterion = nn.CrossEntropyLoss()
optimizer = Adam(net.parameters(), lr=wandb.config.learning_rate)

In [None]:
# Display names for the 20 classes; mapping[i] labels column i of the
# one-hot / probability arrays in the per-class W&B metric keys.
# NOTE(review): the order must match the class indices ImageFolder assigned
# (see dataset.class_to_idx) -- TODO confirm, otherwise per-class metrics are
# logged under the wrong name.
# These strings are part of the logged metric keys; renaming one starts a new
# series in W&B.
mapping = ['атопический дерматит',
           'акне',
           'псориаз',
           'розацеа',
           'бородавки',
           'герпес',
           'витилиго',
           'клп',
           'аллергический контактный дерматит',
           'экзема',
           'дерматомикозы',
           'булезный пемфигоид', 
           'пузырчатка',
           'контагиозный моллюск',
           'крапивница',
           'кератоз',
           'чесотка',
           'себореный дерматит',
           'актинический',
           'базалиома']

In [None]:
def log_epoch(epoch, y_true_train, y_pred_train, y_true_test, y_pred_test, train_loss, test_loss):
    """Compute the epoch's metrics, log them to W&B and return them as a list.

    Args:
        epoch: Epoch number stored under the 'epoch' key.
        y_true_train, y_true_test: One-hot label arrays, one row per sample and
            one column per class.
        y_pred_train, y_pred_test: Per-class score arrays with the same layout.
        train_loss, test_loss: Loss values logged as-is.

    Returns:
        A flat list of metric values in a fixed order that callers index by
        position: mAP train/test, then f1, precision, recall and accuracy
        (train then test for each), then for every class in ``mapping`` the
        per-class average precision for train and test.
    """
    step = {'epoch': epoch, 'train loss': train_loss, 'test loss': test_loss}
    current_metrics = []

    def record(key, value):
        # Each metric is computed once and stored in both outputs.
        step[key] = value
        current_metrics.append(value)

    splits = (
        ('train', y_true_train, y_pred_train),
        ('test', y_true_test, y_pred_test),
    )

    # Average precision over the flattened (sample, class) matrix.
    for split, y_true, y_pred in splits:
        record(f'mAP/{split}', average_precision_score(y_true.reshape(-1), y_pred.reshape(-1)))

    # Hard class decisions, computed once per split and reused by every scorer.
    hard_labels = {
        split: (np.argmax(y_true, 1), np.argmax(y_pred, 1))
        for split, y_true, y_pred in splits
    }

    scorers = (
        ('f1', lambda t, p: f1_score(t, p, average='macro')),
        ('precision', lambda t, p: precision_score(t, p, average='macro')),
        ('recall', lambda t, p: recall_score(t, p, average='macro')),
        ('accuracy', accuracy_score),
    )
    for metric_name, scorer in scorers:
        for split, _, _ in splits:
            true_idx, pred_idx = hard_labels[split]
            record(f'{metric_name}/{split}', scorer(true_idx, pred_idx))

    # One-vs-rest average precision per class; the class count follows
    # `mapping` instead of a hard-coded 20.
    for class_idx, class_name in enumerate(mapping):
        for split, y_true, y_pred in splits:
            record(
                f'mAP class {split}/{class_name}',
                average_precision_score(y_true[:, class_idx], y_pred[:, class_idx]),
            )

    wandb.log(step)
    return current_metrics

In [None]:
# Number of passes over the training subset, stored on the W&B config.
wandb.config.epochs = 100

In [None]:
def _evaluate(loader, split_name, epoch):
    """Run `net` over `loader` without gradients and collect labels, scores and loss.

    Args:
        loader: DataLoader yielding (images, labels) batches.
        split_name: Name used in the progress header ('train' or 'test').
        epoch: Zero-based epoch index, used only for progress output.

    Returns:
        (y_true, y_pred, mean_loss): one-hot labels and softmax probabilities,
        both shaped (num_samples, num_classes), and the loss averaged over
        batches.
    """
    # Class count comes from the replaced classification head.
    num_classes = net.fc.out_features
    print(f'Evaluating {split_name}:')

    one_hot_batches = []
    prob_batches = []
    total_loss = 0.0

    with torch.no_grad():
        for i, (images, labels) in enumerate(loader):
            outputs = net(images.to(device))
            total_loss += criterion(outputs, labels.to(device)).item()

            # One-hot encode the integer labels of this batch.
            one_hot = np.zeros((len(labels), num_classes))
            one_hot[np.arange(len(labels)), labels.numpy()] = 1
            one_hot_batches.append(one_hot)

            # dim=1 is the class axis; passing it explicitly avoids the
            # deprecated implicit-dimension behaviour of softmax.
            probs = nn.functional.softmax(outputs, dim=1).cpu().numpy()
            prob_batches.append(probs.astype(np.float64))

            if (i + 1) % 100 == 0:
                print(f'Epoch: {epoch + 1}, {i + 1}/{len(loader)}')

    mean_loss = total_loss / len(loader)
    # Concatenate once at the end instead of re-copying the arrays every batch.
    return np.concatenate(one_hot_batches), np.concatenate(prob_batches), mean_loss


# best_metrics keeps, position by position, the maximum of every value that
# log_epoch has returned so far; current_metrics is the latest epoch's list.
best_metrics = []
current_metrics = []
for epoch in range(wandb.config.epochs):
    print('Training:')
    net.train()
    running_loss = 0.0
    j = 0  # batches accumulated since the last progress print
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0], data[1]
        optimizer.zero_grad()
        outputs = net(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        j += 1
        if (i + 1) % 100 == 0:
            print(f'Epoch: {epoch + 1}, {i + 1}/{len(train_loader)}, loss: {running_loss / j}')
            running_loss = 0.0
            j = 0

    # Score both subsets in eval mode.
    net.eval()
    y_true_train, y_pred_train, train_loss = _evaluate(train_loader, 'train', epoch)
    y_true_test, y_pred_test, test_loss = _evaluate(test_loader, 'test', epoch)

    current_metrics = log_epoch(epoch + 1,
                                y_true_train,
                                y_pred_train,
                                y_true_test,
                                y_pred_test,
                                train_loss,
                                test_loss
    )

    if len(best_metrics) == 0:
        best_metrics = current_metrics.copy()
    else:
        best_metrics = [max(b, c) for b, c in zip(best_metrics, current_metrics)]

    # Write the new checkpoint first, then drop the previous epoch's files so
    # only the latest model/optimizer state stays on disk.
    torch.save(net.state_dict(), f'net_{epoch}.pt')
    torch.save(optimizer.state_dict(), f'opt_{epoch}.pt')

    for prefix in ('net', 'opt'):
        stale_checkpoint = f'{prefix}_{epoch - 1}.pt'
        if os.path.exists(stale_checkpoint):
            os.remove(stale_checkpoint)


print('Finished')

In [None]:
# Copy the best value of every metric into the W&B run summary. The key order
# below mirrors the positional layout of `best_metrics`: overall mAP, then
# f1 / precision / recall / accuracy (train, test), then per-class mAP pairs.
summary_keys = ['mAP/train', 'mAP/test']
for metric_name in ('f1', 'precision', 'recall', 'accuracy'):
    summary_keys += [f'{metric_name}/train', f'{metric_name}/test']
for i in range(20):
    summary_keys += [f'mAP class train/{mapping[i]}', f'mAP class test/{mapping[i]}']

for j, summary_key in enumerate(summary_keys):
    wandb.run.summary[summary_key] = best_metrics[j]