In [None]:
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision import transforms
from random import randrange
import torch
import os
import csv


class ISICDataset(Dataset):
    """Images from `img_dir` paired with a random rotation label.

    Each item is an image rotated by ``k`` quarter turns, with ``k`` drawn
    uniformly from {0, 1, 2, 3} on every access; ``k`` itself is returned as
    the label. From the label CSV only the first column of each row is used,
    as a file name relative to `img_dir`.
    """

    def __init__(self, img_dir, labels_dir):
        # `labels_dir` is the path to a CSV file (not a directory).
        # NOTE(review): every row is kept, so a header row would be treated as an
        # image file name — confirm the CSV has no header.
        # `with` guarantees the file handle is closed once the rows are read.
        with open(labels_dir, "r", newline="") as file:
            self.img_labels = list(csv.reader(file))

        self.img_dir = img_dir
        # Resize to a square so every rotation keeps the same (C, 224, 224) shape,
        # then normalize with the ImageNet mean/std used by the pretrained backbone.
        # NOTE(review): the 3-value mean/std assumes 3-channel images; a grayscale
        # or RGBA file would fail in Normalize — confirm the dataset is all RGB.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __len__(self):
        return len(self.img_labels)

    @staticmethod
    def _rotate(image, k):
        """Rotate a (C, H, W) tensor by `k` quarter turns in its (H, W) plane."""
        return torch.rot90(image, k, [1, 2])

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx][0])
        label = randrange(4)
        # read_image yields a uint8 (C, H, W) tensor; scale to [0, 1] floats
        # before Normalize.
        image = self.transform(read_image(img_path) / 255)
        return self._rotate(image, label), label

In [None]:
# Image folder and label CSV for each split.
traindir, train_labels_dir = (
    '../input/isic-2018/ISIC2018_Task1-2_Training_Input/ISIC2018_Task1-2_Training_Input',
    '../input/isic-2018/train_labels.csv',
)
validdir, valid_labels_dir = (
    '../input/isic-2018/ISIC2018_Task1-2_Validation_Input/ISIC2018_Task1-2_Validation_Input',
    '../input/isic-2018/valid_labels.csv',
)

In [None]:
from torch.utils.data import random_split, DataLoader

# One dataset per split, built from the paths above.
data = ISICDataset(traindir, train_labels_dir)
valid_data = ISICDataset(validdir, valid_labels_dir)

# Training batches are shuffled and staged in pinned host memory to speed up
# copies to the GPU; validation iterates in file order.
train_dl = DataLoader(dataset=data, batch_size=32, shuffle=True, pin_memory=True)
valid_dl = DataLoader(dataset=valid_data, batch_size=10)

In [None]:
import torchvision
import torch
# ImageNet-pretrained ResNet-34 with its classifier head replaced by a 4-way
# linear layer: one output per rotation label produced by ISICDataset.
# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour
# of `weights=...`; kept for compatibility with older torchvision installs.
model = torchvision.models.resnet34(pretrained = True)
# Read the head's input width from the backbone instead of hard-coding 512,
# so swapping the architecture does not silently break this line.
model.fc = torch.nn.Linear(in_features = model.fc.in_features, out_features = 4, bias = True)

In [None]:
import wandb
# Open a wandb run, record the learning rate on its config, and keep a handle
# to that config for later cells.
wandb.init(project='ISIC2018', entity='tro2vs')
wandb.config.learning_rate = 1e-5
config = wandb.config
# Register the model with wandb so it is tracked during training.
wandb.watch(model)

In [None]:
from sklearn.metrics import f1_score
import time

def norm_pred(pred):
    """Turn per-class scores into predicted class indices (argmax over dim 1)."""
    return torch.argmax(pred, dim=1)

def acc(labels, pred):
    """Fraction of samples whose argmax prediction matches the label.

    Args:
        labels: 1-D tensor of integer class labels, on any device.
        pred: per-class scores, assumed shape (batch, num_classes).

    Returns:
        0-dim float tensor on `pred`'s device.
    """
    # Move labels to pred's device instead of forcing CUDA, so this also works on CPU.
    correct = pred.argmax(dim=1) == labels.to(pred.device)
    return correct.float().mean()

def _run_phase(model, dataloader, loss_fn, optimizer, is_train, device):
    """Run one pass over `dataloader`, updating the weights only when `is_train`.

    Returns:
        (mean_loss, accuracy, y_true, y_pred): per-sample mean loss and accuracy
        over the pass as Python floats, plus the true and predicted labels of
        every sample as 1-D numpy arrays (empty lists if the loader is empty).
    """
    model.train(is_train)
    total_loss = 0.0
    total_correct = 0
    n_samples = 0
    true_chunks, pred_chunks = [], []

    for step, (x, y) in enumerate(dataloader, start=1):
        x = x.to(device)
        y = y.to(device)
        # Real size of this batch: the last batch can be smaller than dataloader.batch_size.
        batch_n = y.size(0)

        # Build the autograd graph only while training.
        with torch.set_grad_enabled(is_train):
            outputs = model(x)
            loss = loss_fn(outputs, y)
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        preds = norm_pred(outputs)
        batch_correct = (preds == y).sum().item()
        # `.item()` yields a plain float; summing the loss tensor itself would keep
        # every batch's autograd graph alive until the end of the epoch.
        # Multiplying by batch_n assumes loss_fn returns the batch mean
        # (the CrossEntropyLoss default).
        total_loss += loss.item() * batch_n
        total_correct += batch_correct
        n_samples += batch_n
        true_chunks.append(y.cpu())
        pred_chunks.append(preds.cpu())

        if step % 10 == 0:
            alloc_mb = torch.cuda.memory_allocated() / 1024 / 1024 if torch.cuda.is_available() else 0.0
            print('Current step: {}  Loss: {}  Acc: {}  AllocMem (Mb): {}'.format(
                step, loss.item(), batch_correct / batch_n, alloc_mb))

    if n_samples == 0:
        return 0.0, 0.0, [], []
    y_true = torch.cat(true_chunks).numpy()
    y_pred = torch.cat(pred_chunks).numpy()
    return total_loss / n_samples, total_correct / n_samples, y_true, y_pred


def _f1_metrics(y_true, y_pred, split, labels=(0, 1, 2, 3)):
    """Return epoch-level F1 scores keyed by their wandb metric names for `split`.

    Scores are computed once over all samples of the pass rather than summed
    per batch, so each stays in [0, 1] and every sample is weighted equally.
    Returns an empty dict when there are no samples.
    """
    labels = list(labels)
    if len(y_true) == 0:
        return {}
    metrics = {}
    for average in ('micro', 'macro', 'weighted'):
        metrics['f1-score/{}/{}'.format(split, average)] = f1_score(
            y_true, y_pred, average=average, labels=labels)
    per_class = f1_score(y_true, y_pred, average=None, labels=labels)
    for cls, score in zip(labels, per_class):
        metrics['f1-score/{}/class_{}'.format(split, cls)] = score
    return metrics


def train(model, train_dl, valid_dl, loss_fn, optimizer, epochs=1, device='cuda'):
    """Train `model` for `epochs` epochs, validating and logging to wandb each epoch.

    Each epoch runs a training pass over `train_dl`, then an evaluation pass
    over `valid_dl`. Loss, accuracy and F1 scores are logged to wandb under
    ".../train" and ".../test" keys. After every training pass the weights are
    saved to "full_train_model.pt"; whenever validation accuracy beats the best
    seen so far they are also saved to "best_model.pt". Both files are passed to
    `wandb.save`.

    NOTE(review): if the validation dataset draws random labels on every access
    (as ISICDataset does), best-model selection is noisy from epoch to epoch.

    Args:
        model: the network to train; it is moved to `device`.
        train_dl: DataLoader for training, yielding (inputs, integer labels).
        valid_dl: DataLoader for validation, yielding (inputs, integer labels).
        loss_fn: criterion called as loss_fn(outputs, labels).
        optimizer: optimizer over the model's parameters.
        epochs: number of epochs to run.
        device: device to run on; defaults to 'cuda' as before.
    """
    start = time.time()
    model.to(device)
    best_acc = 0

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        for phase, dataloader in (('train', train_dl), ('valid', valid_dl)):
            is_train = phase == 'train'
            epoch_loss, epoch_acc, y_true, y_pred = _run_phase(
                model, dataloader, loss_fn, optimizer, is_train, device)
            print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))

            # wandb metric keys use "test" for the validation split.
            split = 'train' if is_train else 'test'
            metrics = {'Loss/' + split: epoch_loss, 'Accuracy/' + split: epoch_acc}
            metrics.update(_f1_metrics(y_true, y_pred, split))
            # One wandb.log call per phase, so all of its metrics share one step.
            wandb.log(metrics)

            if is_train:
                torch.save(model.state_dict(), "full_train_model.pt")
                wandb.save('./full_train_model.pt')
            elif epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), "best_model.pt")
                wandb.save('./best_model.pt')

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

In [None]:
# Cross-entropy over the 4 rotation classes. The learning rate comes from the
# wandb config, so the logged hyperparameter and the optimizer cannot drift apart.
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr = config.learning_rate)
train(model, train_dl, valid_dl, loss_fn, opt, epochs = 30)

In [None]:
# Mark the wandb run as finished and let it finish uploading its data,
# including the checkpoint files registered with wandb.save during training.
wandb.finish()