In [1]:
import torch
import os, os.path
from skimage import io, transform
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import matplotlib.pyplot as plt
import pickle 
import numpy as np
from torchvision.models import resnet34 
from torch.optim import Adam
from torch import nn
from sklearn.metrics import f1_score, recall_score, precision_score, average_precision_score
import wandb

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
resume=False

In [4]:
wandb.init(project='net_morphology')

wandb: Currently logged in as: truffaut (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.10.22 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [5]:
# with open('y_stratified.pickle', 'wb') as f:
#     pickle.dump(y, f)

In [6]:
with open('y_stratified.pickle', 'rb') as f:
    y = pickle.load(f)

In [7]:
if torch.cuda.is_available():
    dev = 'cuda:0'
else:
    dev = 'cpu'
device = torch.device(dev)

In [8]:
pos_weight = torch.zeros(23)
for file in os.listdir('net_morph'):
    ms = file.strip('.jpg').split('_')[1:]
    for m in ms:
        pos_weight[int(m) - 1] += 1 
pos_weight = pos_weight / pos_weight.sum()
pos_weight = pos_weight.to(device)

In [9]:
class CustomCrop:
    
    def __call__(self, sample):
        shape = sample.shape
        min_dimension = min(shape[1], shape[2])
        center_crop = transforms.CenterCrop(min_dimension)
        sample = center_crop(sample)
        return sample

In [10]:
composed = transforms.Compose(
    [transforms.ToTensor(), CustomCrop(), transforms.Resize((224, 224)),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

In [11]:
class MorphDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        
        self.root_dir = root_dir
        self.transform = transform
        
    def __len__(self):
        return len([name for name in os.listdir(self.root_dir) if os.path.isfile(os.path.join(self.root_dir, name))])
    
    
    def __getitem__(self, idx):
        
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
            
        for file in os.listdir(self.root_dir):
            if file.startswith(f'{idx}_'):
                filename = file
                break
        

        labels = filename.strip('.jpg').split('_')[1:]
        labels = [int(l) for l in labels]
        image = io.imread(f'{self.root_dir}\{filename}').copy()
        
        label = torch.zeros(23)
        
        for l in labels:
            label[int(l) - 1] = 1
            
        if self.transform:
            sample = {'image': self.transform(image), 'label': label}
        else:
            sample = {'image': image, 'label': label}
        return sample

In [12]:
dataset = MorphDataset('net_morph', transform=composed)

In [13]:
y = []
for i in range(len(dataset)):
    y.append(max(dataset[i]['label'].tolist()))

In [14]:
train_indexes, test_indexes = train_test_split(np.arange(len(y)), test_size=0.2, shuffle=True, stratify=y)
train_sampler = SubsetRandomSampler(train_indexes)
test_sampler = SubsetRandomSampler(test_indexes)

In [15]:
wandb.config.batch_size = 16

In [16]:
train_loader = DataLoader(dataset, batch_size=wandb.config.batch_size, sampler=train_sampler)
test_loader = DataLoader(dataset, batch_size=wandb.config.batch_size, sampler=test_sampler)

In [17]:
net = resnet34(pretrained=True)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

In [18]:
wandb.config.lr = 0.00001
net.fc = nn.Linear(net.fc.in_features, 23)
net = net.to(device)
optimizer = Adam(net.parameters(), wandb.config.lr)

In [19]:
if resume:
    last = int(input())
    net_checkpoint = torch.load(f'net_{last}.pt')
    opt_checkpoint = torch.load(f'opt_{last}.pt')
    net.load_state_dict(net_checkpoint)
    opt.load_state_dict(opt_checkpoint)

In [20]:
def get_labels(predictions, treshold):
    return (predictions > treshold).astype(int)

In [21]:
mapping = [
    'пятно',
    'бугорок',
    'узел',
    'папула',
    'волдырь',
    'пузырек',
    'пузырь',
    'гнойничок',
    'гиперпигментация',
    'гипопигментация',
    'эрозия',
    'язва',
    'чешуйка',
    'корка',
    'рубец',
    'трещина',
    'экскориация',
    'кератоз',
    'лихенификация',
    'вегетация',
    'дерматосклероз',
    'анетодермия',
    'атрофодермия',
]

In [22]:
def log_epoch(epoch, y_true_train, y_pred_train, y_true_test, y_pred_test, train_loss, test_loss):
    step = {'epoch': epoch, 'train loss': train_loss, 'test loss': test_loss}
    
    map_train = average_precision_score(y_true_train.reshape(-1), y_pred_train.reshape(-1))
    map_test = average_precision_score(y_true_test.reshape(-1), y_pred_test.reshape(-1))

    current_metrics = [map_train, map_test]
    
    step['mAP/train'] = map_train
    step['mAP/test'] = map_test
    
    
    for treshold in np.arange(0.1, 1, 0.1):
        step[f'f1 train/{round(treshold, 1)}'] = f1_score(y_true_train, get_labels(y_pred_train, treshold), average='macro')
        step[f'f1 test/{round(treshold, 1)}'] = f1_score(y_true_test, get_labels(y_pred_test, treshold), average='macro')
        step[f'precision train/{round(treshold, 1)}'] = precision_score(y_true_train, get_labels(y_pred_train, treshold), average='macro')
        step[f'precision test/{round(treshold, 1)}'] = precision_score(y_true_test, get_labels(y_pred_test, treshold), average='macro')
        step[f'recall train/{round(treshold, 1)}'] = recall_score(y_true_train, get_labels(y_pred_train, treshold), average='macro')
        step[f'recall test/{round(treshold, 1)}'] = recall_score(y_true_test, get_labels(y_pred_test, treshold), average='macro')
        
        current_metrics.append(f1_score(y_true_train, get_labels(y_pred_train, treshold), average='macro'))
        current_metrics.append(f1_score(y_true_test, get_labels(y_pred_test, treshold), average='macro'))
        current_metrics.append(precision_score(y_true_train, get_labels(y_pred_train, treshold), average='macro'))
        current_metrics.append(precision_score(y_true_test, get_labels(y_pred_test, treshold), average='macro'))
        current_metrics.append(recall_score(y_true_train, get_labels(y_pred_train, treshold), average='macro'))
        current_metrics.append(recall_score(y_true_test, get_labels(y_pred_test, treshold), average='macro'))
    
    for i in range(23):
        step[f'mAP class train/{mapping[i]}'] = average_precision_score(y_true_train[:, i], y_pred_train[:, i])
        step[f'mAP class test/{mapping[i]}'] = average_precision_score(y_true_test[:, i], y_pred_test[:, i])
        
        current_metrics.append(average_precision_score(y_true_train[:, i], y_pred_train[:, i]))
        current_metrics.append(average_precision_score(y_true_test[:, i], y_pred_test[:, i]))
    
    
    wandb.log(step)
    return current_metrics

In [23]:
wandb.config.epochs = 100

In [None]:
best_metrics = []
current_metrics = []
for epoch in range(wandb.config.epochs):
    net.train()
    running_loss = 0.0
    j = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data['image'], data['label']
        optimizer.zero_grad()
        outputs = net(inputs.to(device))
        loss = criterion(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        j += 1
        if (i + 1) % 100 == 0:
            print(f'Epoch: {epoch + 1}, {i + 1}/{len(train_loader)}, loss: {running_loss / j}')
            running_loss = 0.0
            j = 0
            
    print('val')      
    net.eval()
    
    with torch.no_grad():
        y_true_train = np.empty((1, 23))
        y_pred_train = np.empty((1, 23))
        train_loss = 0.0
        
        for i, data in enumerate(train_loader, 0):
            images, labels = data['image'], data['label']
            outputs = net(images.to(device))
            loss = criterion(outputs, labels.to(device))
            train_loss += loss.item()
            
            predicted = nn.functional.softmax(outputs).cpu().detach().numpy()
            y_true_train = np.concatenate((y_true_train, labels.numpy()))
            y_pred_train = np.concatenate((y_pred_train, predicted))
        
            if (i + 1) % 100 == 0:
                print(f'Epoch: {epoch + 1}, {i + 1}/{len(train_loader)}')

        train_loss = train_loss / len(train_loader)

        y_true_test = np.empty((1, 23))
        y_pred_test = np.empty((1, 23))
        test_loss = 0.0
        
        for i, data in enumerate(test_loader, 0):
            images, labels = data['image'], data['label']
            outputs = net(images.to(device))
            loss = criterion(outputs, labels.to(device))
            test_loss += loss.item()
            
            predicted = nn.functional.softmax(outputs).cpu().detach().numpy()
            y_true_test = np.concatenate((y_true_test, labels.numpy()))
            y_pred_test = np.concatenate((y_pred_test, predicted))
            
            if (i + 1) % 100 == 0:
                print(f'Epoch: {epoch + 1}, {i + 1}/{len(test_loader)}')
            
        test_loss = test_loss / len(test_loader)
    
    y_true_train = y_true_train[1:]
    y_pred_train = y_pred_train[1:]
    y_true_test = y_true_test[1:]
    y_pred_test = y_pred_test[1:]
        
    current_metrics = log_epoch(epoch + 1,
                                y_true_train,
                                y_pred_train,
                                y_true_test,
                                y_pred_test,
                                train_loss,
                                test_loss
    )
    
    if len(best_metrics) == 0:
        best_metrics = current_metrics.copy()
    
    i = 0
    for b, c in zip(best_metrics, current_metrics):
        best_metrics[i] = max(b, c)
        i += 1
        
    torch.save(net.state_dict(), f'net_{epoch}.pt')
    torch.save(optimizer.state_dict(), f'opt_{epoch}.pt')
    
    if os.path.exists(f'net_{epoch - 1}.pt'):
        os.remove(f'net_{epoch - 1}.pt')
        os.remove(f'opt_{epoch - 1}.pt')
        
    
print('Finished')

Epoch: 1, 100/421, loss: 0.4457943019270897
Epoch: 1, 200/421, loss: 0.2926978226006031
Epoch: 1, 300/421, loss: 0.18201131790876388
Epoch: 1, 400/421, loss: 0.11806857846677303
val
Epoch: 1, 100/421
Epoch: 1, 200/421
Epoch: 1, 300/421
Epoch: 1, 400/421
Epoch: 1, 100/106
Epoch: 2, 100/421, loss: 0.08540832564234734
Epoch: 2, 200/421, loss: 0.07103590473532677
Epoch: 2, 300/421, loss: 0.06136499915271997
Epoch: 2, 400/421, loss: 0.055053204223513605
val
Epoch: 2, 100/421
Epoch: 2, 200/421
Epoch: 2, 300/421
Epoch: 2, 400/421
Epoch: 2, 100/106
Epoch: 3, 100/421, loss: 0.048906770311295984
Epoch: 3, 200/421, loss: 0.04559596471488476
Epoch: 3, 300/421, loss: 0.04273132801055908
Epoch: 3, 400/421, loss: 0.04013422004878521
val
Epoch: 3, 100/421
Epoch: 3, 200/421
Epoch: 3, 300/421
Epoch: 3, 400/421
Epoch: 3, 100/106
Epoch: 4, 100/421, loss: 0.03791448283940554
Epoch: 4, 200/421, loss: 0.03650893883779645
Epoch: 4, 300/421, loss: 0.03543317498639226
Epoch: 4, 400/421, loss: 0.0343943078815937

Epoch: 30, 100/421, loss: 0.0047243833192624155
Epoch: 30, 200/421, loss: 0.004825017008697614
Epoch: 30, 300/421, loss: 0.005069098402746022
Epoch: 30, 400/421, loss: 0.0050748637225478885
val
Epoch: 30, 100/421
Epoch: 30, 200/421
Epoch: 30, 300/421
Epoch: 30, 400/421
Epoch: 30, 100/106
Epoch: 31, 100/421, loss: 0.004882832455914468
Epoch: 31, 200/421, loss: 0.005120495883747935
Epoch: 31, 300/421, loss: 0.004815945859299973
Epoch: 31, 400/421, loss: 0.004857943393290043
val
Epoch: 31, 100/421
Epoch: 31, 200/421
Epoch: 31, 300/421
Epoch: 31, 400/421
Epoch: 31, 100/106
Epoch: 32, 100/421, loss: 0.004613061678828672
Epoch: 32, 200/421, loss: 0.004640693372348324
Epoch: 32, 300/421, loss: 0.004854259650455788
Epoch: 32, 400/421, loss: 0.00457751342561096
val
Epoch: 32, 100/421
Epoch: 32, 200/421
Epoch: 32, 300/421
Epoch: 32, 400/421
Epoch: 32, 100/106
Epoch: 33, 100/421, loss: 0.004796802076743916
Epoch: 33, 200/421, loss: 0.004657987850951031
Epoch: 33, 300/421, loss: 0.0044349727453663

Epoch: 58, 400/421, loss: 0.0018266732507618143
val
Epoch: 58, 100/421
Epoch: 58, 200/421
Epoch: 58, 300/421
Epoch: 58, 400/421
Epoch: 58, 100/106
Epoch: 59, 100/421, loss: 0.0016542535368353128
Epoch: 59, 200/421, loss: 0.0016961736246594228
Epoch: 59, 300/421, loss: 0.0018113770242780447
Epoch: 59, 400/421, loss: 0.0018709315644809977
val
Epoch: 59, 100/421
Epoch: 59, 200/421
Epoch: 59, 300/421
Epoch: 59, 400/421
Epoch: 59, 100/106
Epoch: 60, 100/421, loss: 0.0017109216286917217
Epoch: 60, 200/421, loss: 0.0017769121628953145
Epoch: 60, 300/421, loss: 0.0015562990022590384
Epoch: 60, 400/421, loss: 0.0017971111298538745
val
Epoch: 60, 100/421
Epoch: 60, 200/421
Epoch: 60, 300/421
Epoch: 60, 400/421
Epoch: 60, 100/106
Epoch: 61, 100/421, loss: 0.001423242571472656
Epoch: 61, 200/421, loss: 0.0017923717678058892
Epoch: 61, 300/421, loss: 0.001750935512536671
Epoch: 61, 400/421, loss: 0.0019322068692417815
val
Epoch: 61, 100/421
Epoch: 61, 200/421
Epoch: 61, 300/421
Epoch: 61, 400/421
E

In [None]:
wandb.run.summary['mAP/train'] = best_metrics[0]
wandb.run.summary['mAP/test'] = best_metrics[1]
j = 2
for treshold in np.arange(0.1, 1, 0.1):
    wandb.run.summary[f'f1 train/{round(treshold, 1)}'] = best_metrics[j]; j += 1 
    wandb.run.summary[f'f1 test/{round(treshold, 1)}'] = best_metrics[j]; j += 1
    wandb.run.summary[f'precision train/{round(treshold, 1)}'] = best_metrics[j]; j += 1 
    wandb.run.summary[f'precision test/{round(treshold, 1)}'] = best_metrics[j]; j += 1
    wandb.run.summary[f'recall train/{round(treshold, 1)}'] = best_metrics[j]; j += 1
    wandb.run.summary[f'recall test/{round(treshold, 1)}'] = best_metrics[j]; j += 1

for i in range(23):
    wandb.run.summary[f'mAP class train/{mapping[i]}'] = best_metrics[j]; j += 1
    wandb.run.summary[f'mAP class test/{mapping[i]}'] = best_metrics[j]; j += 1
