In [None]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.io import read_image
from torchvision import transforms
import torch
import numpy as np
import csv
import os

class ISICDataset(Dataset):
    """Relative-patch-position dataset over ISIC images (self-supervised pretext task).

    Each item is built from a 3x3 grid of ``patch_dim``-sized square patches,
    separated by ``gap`` pixels and placed at a random offset inside the resized
    image. It returns the centre patch, one of the 8 surrounding patches chosen
    at random, and that neighbour's index into ``PATCH_LOCATIONS`` (0-7).

    Args:
        img_dir: directory containing the image files.
        labels_dir: path to a CSV file whose first column is an image file name
            relative to ``img_dir``. Every row is used as a sample, so a header
            row, if present, would be treated as a file name.
        patch_dim: side length in pixels of each square patch.
        gap: pixels between adjacent patches of the grid.
    """

    # 1-based (row, col) cells of the 8 neighbours of the centre cell (2, 2);
    # the index into this tuple is the classification target.
    PATCH_LOCATIONS = ((1, 1), (1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2), (3, 3))

    def __init__(self, img_dir, labels_dir, patch_dim, gap):
        self.patch_dim, self.gap = patch_dim, gap
        # Context manager guarantees the CSV handle is closed once read.
        with open(labels_dir, "r", newline="") as file:
            self.img_labels = list(csv.reader(file))

        self.img_dir = img_dir

        self.transform = transforms.Compose([
            transforms.Resize((700, 700)),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    @staticmethod
    def _crop_cell(image, start_x, start_y, row, col, patch_dim, gap):
        """Return the patch in 1-based grid cell (row, col) of a grid anchored at (start_x, start_y)."""
        stride = patch_dim + gap
        x = start_x + stride * (row - 1)
        y = start_y + stride * (col - 1)
        return image[x:x + patch_dim, y:y + patch_dim]

    def get_patch_from_grid(self, image, patch_dim, gap):
        """Sample the centre patch and one random neighbour patch from a 3x3 grid.

        Args:
            image: array-like indexed as ``image[row, col, ...]``; converted with
                ``np.array``.
            patch_dim: side length in pixels of each square patch.
            gap: pixels between adjacent patches.

        Returns:
            ``(uniform_patch, random_patch, random_patch_label)``: the centre
            patch, the neighbour patch (both numpy arrays sliced from ``image``)
            and the neighbour's index into ``PATCH_LOCATIONS``.

        Raises:
            ValueError: if the image is smaller than the grid in either dimension.
        """
        image = np.array(image)

        grid_size = patch_dim * 3 + gap * 2
        offset_x, offset_y = image.shape[0] - grid_size, image.shape[1] - grid_size
        if offset_x < 0 or offset_y < 0:
            raise ValueError(
                "image of size {}x{} is too small for a 3x3 grid of {}px patches "
                "with {}px gaps".format(image.shape[0], image.shape[1], patch_dim, gap))

        # randint's upper bound is exclusive: +1 lets the grid also sit flush with
        # the bottom/right edge and keeps offset 0 (exact fit) from raising.
        start_grid_x = np.random.randint(0, offset_x + 1)
        start_grid_y = np.random.randint(0, offset_y + 1)

        loc = np.random.randint(len(self.PATCH_LOCATIONS))
        row, col = self.PATCH_LOCATIONS[loc]

        random_patch = self._crop_cell(image, start_grid_x, start_grid_y, row, col, patch_dim, gap)
        uniform_patch = self._crop_cell(image, start_grid_x, start_grid_y, 2, 2, patch_dim, gap)

        return uniform_patch, random_patch, loc

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, index):
        """Load image ``index``, resize + normalise it, and return (centre, neighbour, label).

        Patches are returned channels-first for the CNN.
        """
        path = os.path.join(self.img_dir, self.img_labels[index][0])
        # Scale pixels to [0, 1] before Normalize, then move channels last so the
        # grid can be sliced as image[row, col].
        image = self.transform(read_image(path) / 255).permute(1, 2, 0)

        uniform_patch, random_patch, random_patch_label = self.get_patch_from_grid(image, self.patch_dim, self.gap)
        return (torch.Tensor(uniform_patch).permute(2, 0, 1),
                torch.Tensor(random_patch).permute(2, 0, 1),
                random_patch_label)

In [None]:
# Kaggle input locations: ISIC 2018 Task 1-2 image folders and their label CSVs.
traindir, train_labels_dir = (
    '../input/isic-2018/ISIC2018_Task1-2_Training_Input/ISIC2018_Task1-2_Training_Input',
    '../input/isic-2018/train_labels.csv',
)
validdir, valid_labels_dir = (
    '../input/isic-2018/ISIC2018_Task1-2_Validation_Input/ISIC2018_Task1-2_Validation_Input',
    '../input/isic-2018/valid_labels.csv',
)

In [None]:
# Patch-position datasets: 224 px patches separated by 10 px gaps.
data = ISICDataset(traindir, train_labels_dir, patch_dim=224, gap=10)
valid_data = ISICDataset(validdir, valid_labels_dir, patch_dim=224, gap=10)

# Shuffled, pinned-memory loader for training; plain in-order loader for validation.
train_dl = DataLoader(data, batch_size=32, shuffle=True, pin_memory=True)
valid_dl = DataLoader(valid_data, batch_size=10)

In [None]:
from torchvision.models import resnet34
from torch import nn

class Resnet34Network(nn.Module):
    """Siamese ResNet-34 that classifies the relative position of two patches.

    Both inputs pass through the same weight-shared ResNet-34, whose final
    ``fc`` layer is replaced by a 512 -> 4096 projection. The two 4096-d
    embeddings are concatenated and an MLP maps them to 8 logits.
    """

    def __init__(self):
        super().__init__()
        backbone = resnet34()
        # Replace the default classification head with a 4096-d embedding layer.
        backbone.fc = nn.Linear(in_features=512, out_features=4096, bias=True)
        self.cnn = backbone

        hidden = 4096
        self.fc = nn.Sequential(
            nn.Linear(2 * hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 8),
        )

    def forward(self, input1, input2):
        """Embed both patches with the shared CNN and return the head's logits (8 per row)."""
        embeddings = [self.cnn(patch) for patch in (input1, input2)]
        return self.fc(torch.cat(embeddings, 1))

# Build the Siamese network. resnet34() is called without requesting pretrained
# weights, so the backbone starts from torchvision's default initialisation.
model = Resnet34Network()

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient
# Authenticate with the API key kept in Kaggle's secret store (never hard-coded).
my_key = UserSecretsClient().get_secret("wandb-key")
wandb.login(key=my_key)

wandb.init(project='ISIC2018', entity='tro2vs')
# Record hyperparameters on the run; keep in sync with the optimizer's learning rate.
config = wandb.config
config.update({"learning_rate": 1e-5})

# Attach wandb to the model so it can track it during training.
wandb.watch(model)

In [None]:
from sklearn.metrics import f1_score
import time

def norm_pred(pred):
    """Turn a batch of class scores into predicted class indices.

    Args:
        pred: tensor whose dim 1 indexes classes.

    Returns:
        Tensor of argmax indices taken along dim 1.
    """
    return torch.argmax(pred, dim=1)

def acc(labels, pred):
    """Fraction of rows whose argmax prediction equals the label.

    Args:
        labels: 1-D tensor of integer class labels, one per row of ``pred``.
        pred: tensor of class scores whose dim 1 indexes classes.

    Returns:
        0-dim tensor holding the accuracy.
    """
    # Move labels to wherever the predictions live instead of forcing CUDA,
    # so the metric also works on CPU (a no-op when both are already on the GPU).
    correct = pred.argmax(dim=1) == labels.to(pred.device)
    return correct.sum() / len(labels)

def train(model, train_dl, valid_dl, loss_fn, optimizer, epochs=1):
    start = time.time()
    model.cuda()
    best_acc = 0
    
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train(True)  
                dataloader = train_dl
            else:
                model.train(False)
                dataloader = valid_dl

            running_loss = 0.0
            running_acc = 0.0

            step = 0
            
            batches = 0
            
            f1 = [0 for i in range(8)]
            f1_mic = 0
            f1_mac = 0
            f1_weighted = 0
            
            for x1, x2, y in dataloader:
                x1 = x1.cuda()
                x2 = x2.cuda()
                y = y.cuda()
                step += 1

                if phase == 'train':
                    optimizer.zero_grad()
                    outputs = model(x1,x2)
                    loss = loss_fn(outputs, y)

                    loss.backward()
                    optimizer.step()

                else:
                    with torch.no_grad():
                        outputs = model(x1,x2)
                        loss = loss_fn(outputs, y)
                
                
                f1_0 = f1_score(y.cpu().detach().numpy(),
                                            norm_pred(outputs).cpu().detach().numpy(),
                                  average = None, labels=[0,1,2,3,4,5,6,7])
                
                f1 = [f1_0[i]+f1[i] for i in range(8)]
                
                f1_mic += f1_score(y.cpu().detach().numpy(),
                                            norm_pred(outputs).cpu().detach().numpy(),
                                  average = 'micro', labels=[0,1,2,3,4,5,6,7])
                f1_mac += f1_score(y.cpu().detach().numpy(),
                                            norm_pred(outputs).cpu().detach().numpy(),
                                  average = 'macro', labels=[0,1,2,3,4,5,6,7])
                f1_weighted += f1_score(y.cpu().detach().numpy(),
                                            norm_pred(outputs).cpu().detach().numpy(),
                                  average = 'weighted', labels=[0,1,2,3,4,5,6,7])
                
                running_acc  += acc(y, outputs)*dataloader.batch_size
                running_loss += loss*dataloader.batch_size
                batches += dataloader.batch_size

                if step % 10 == 0:
                    print('Current step: {}  Loss: {}  Acc: {}  AllocMem (Mb): {}'.format(step, loss, acc(y, outputs), torch.cuda.memory_allocated()/1024/1024))
                
          
                
            epoch_loss = running_loss / batches
            epoch_acc = running_acc / batches
            
            print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))
            
            if phase == 'valid':
                wandb.log({"Loss/test": epoch_loss})
                wandb.log({"Accuracy/test": epoch_acc})
                wandb.log({"f1-score/test/micro": f1_mic/step})
                wandb.log({"f1-score/test/macro": f1_mac/step})
                wandb.log({"f1-score/test/weighted": f1_weighted/step})

                for i in range(8):
                    wandb.log({"f1-score/test/class_"+str(i): f1[i]/step})      
                
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    torch.save(model.state_dict(), "best_model.pt")
                    wandb.save('./best_model.pt')
            else:
                torch.save(model.state_dict(), "full_train_model.pt")
                wandb.save('./full_train_model.pt')
                
                wandb.log({"Loss/train": epoch_loss})
                wandb.log({"Accuracy/train": epoch_acc})
                wandb.log({"f1-score/train/micro": f1_mic/step})
                wandb.log({"f1-score/train/macro": f1_mac/step})
                wandb.log({"f1-score/train/weighted": f1_weighted/step})
                
                for i in range(8):
                    wandb.log({"f1-score/train/class_"+str(i): f1[i]/step})
                


    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))    
    #model = torch.load("best_model.pt")

In [None]:
loss_fn = torch.nn.CrossEntropyLoss()
# Read the learning rate back from the wandb config so that the logged
# hyperparameter and the one the optimizer actually uses cannot drift apart.
opt = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
train(model, train_dl, valid_dl, loss_fn, opt, epochs=30)

In [None]:
# Mark the wandb run as finished so its logged data and saved files are uploaded.
wandb.finish()