In [6]:
# Basic libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import time

# Create tiles
from create_tiles import add_column_tiles

# Deep learning model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler, RandomSampler, SequentialSampler
from warmup_scheduler import GradualWarmupScheduler
from efficientnet_pytorch import model as enet
import albumentations
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score

In [7]:
# Location of the slide images handed to the tiling helper.
image_folder = "./train/train"

# Training metadata; later cells read its `isup_grade` and `data_provider` columns.
df_train = pd.read_csv("train.csv")

# NOTE(review): the helper's return value is discarded, so it presumably adds the
# `tiles` column (read later by PANDADataset) to `df_train` in place — confirm
# against create_tiles.add_column_tiles.
add_column_tiles(df_train, image_folder)

  2%|▉                                          | 7/340 [00:11<09:28,  1.71s/it]


KeyboardInterrupt: 

In [None]:
# --- Experiment configuration ------------------------------------------------
enet_type = "efficientnet-b0"   # backbone name handed to EfficientNet.from_pretrained
fold = 0                        # fold index held out for validation
tile_size = 128                 # not referenced by any other cell visible in this notebook
image_size = 128                # side length (pixels) of one tile inside the mosaic
n_tiles = 36                    # tiles per mosaic; the dataset lays them out on a sqrt(n_tiles) grid
batch_size = 6
out_dim = 5                     # number of logits produced by the model head
init_lr = 3e-4                  # the optimizer is created with init_lr / warmup_factor
warmup_factor = 10              # passed to GradualWarmupScheduler as `multiplier`
# Warm-up length in epochs. The scheduler setup reads this name
# (CosineAnnealingLR horizon is n_epochs - warmup_epo), so it must exist here
# for a fresh-kernel run to work.
# NOTE(review): 1 is a one-epoch warm-up; the value originally used is not
# recorded anywhere in the notebook — confirm.
warmup_epo = 1
n_epochs = 30

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Assign every row to one of 5 folds, stratified on the ISUP grade.
# A fixed random_state keeps the split identical across kernel restarts, so
# per-fold results stay comparable between runs.
skf = StratifiedKFold(5, shuffle=True, random_state=42)

# `split` yields positional indices, so collect the assignment in a plain
# array (positional by construction) instead of writing through the label-based
# `.loc`, which is only correct when df_train has a default RangeIndex.
fold_of_row = np.full(len(df_train), -1, dtype=int)
for i, (train_idx, valid_idx) in enumerate(skf.split(df_train, df_train["isup_grade"])):
    fold_of_row[valid_idx] = i
df_train["fold"] = fold_of_row

In [None]:
class enetv2(nn.Module):
    """Pretrained EfficientNet feature extractor with a separate linear head.

    The backbone's own `_fc` layer is swapped for `nn.Identity`, so calling the
    backbone returns whatever fed `_fc`; `myfc` then maps that to `out_dim` logits.

    Args:
        backbone: model name passed to `EfficientNet.from_pretrained`.
        out_dim: number of outputs of the linear head.
    """

    def __init__(self, backbone, out_dim):
        super().__init__()
        pretrained = enet.EfficientNet.from_pretrained(backbone)
        # Read the feature width before the original classifier is discarded.
        feature_dim = pretrained._fc.in_features
        pretrained._fc = nn.Identity()

        self.enet = pretrained
        self.myfc = nn.Linear(feature_dim, out_dim)

    def extract(self, x):
        """Return the backbone output for `x` (its `_fc` is an identity)."""
        features = self.enet(x)
        return features

    def forward(self, x):
        """Return the head's logits for input batch `x`."""
        return self.myfc(self.extract(x))

In [None]:
class PANDADataset(Dataset):
    """Dataset that stitches a slide's tiles into one square mosaic image.

    Each item is built from `row.tiles`, a sequence whose entries expose the
    tile pixels under the `'img'` key. Tiles are placed on an
    `int(sqrt(n_tiles))` x `int(sqrt(n_tiles))` grid; grid slots with no
    corresponding tile are filled with a blank (all-255) tile. Pixel values are
    inverted (`255 - tile`) before placement.

    Args:
        df: frame with at least the `tiles` and `isup_grade` columns; its index
            is reset so items are addressed positionally.
        image_size: side length in pixels of one tile slot in the mosaic.
            Tiles are assumed to already be `image_size` x `image_size` x 3 —
            TODO confirm against the tiling helper.
        n_tiles: number of tile slots considered per slide. If it is not a
            perfect square, only the first `int(sqrt(n_tiles))**2` slots are used.
        rand: when truthy, the tile order is shuffled per item.
        transform: optional albumentations-style callable, invoked as
            `transform(image=...)['image']` on every tile and once more on the
            assembled mosaic.

    Returns (per item):
        `(image, label)` where `image` is a float32 tensor in channel-first
        layout scaled by 1/255, and `label` is a float32 vector of length 5
        whose first `isup_grade` entries are 1.
    """

    def __init__(self, df, image_size, n_tiles=n_tiles, rand=False, transform=None,):
        self.df = df.reset_index(drop=True)
        self.image_size = image_size
        self.n_tiles = n_tiles
        self.rand = rand
        self.transform = transform

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]
        tiles = row.tiles
        # Always use the size this dataset was constructed with; reading the
        # notebook-level `image_size` global here would silently ignore the
        # constructor argument whenever the two differ.
        size = self.image_size

        if self.rand:
            idxes = np.random.choice(list(range(self.n_tiles)), self.n_tiles, replace=False)
        else:
            idxes = list(range(self.n_tiles))

        n_row_tiles = int(np.sqrt(self.n_tiles))
        images = np.zeros((size * n_row_tiles, size * n_row_tiles, 3))
        for h in range(n_row_tiles):
            for w in range(n_row_tiles):
                i = h * n_row_tiles + w
                if len(tiles) > idxes[i]:
                    this_img = tiles[idxes[i]]['img']
                else:
                    # Slide has fewer tiles than slots: pad with a blank tile.
                    this_img = np.ones((size, size, 3)).astype(np.uint8) * 255
                this_img = 255 - this_img
                if self.transform is not None:
                    this_img = self.transform(image=this_img)['image']
                h1 = h * size
                w1 = w * size
                images[h1:h1+size, w1:w1+size] = this_img

        # Second augmentation pass, this time on the whole mosaic.
        if self.transform is not None:
            images = self.transform(image=images)['image']
        images = images.astype(np.float32)
        images /= 255
        images = images.transpose(2, 0, 1)  # HWC -> CHW

        # Cumulative ("ordinal") target: grade g -> g leading ones.
        # The length is hard-coded to 5 and has to agree with the model's out_dim.
        label = np.zeros(5).astype(np.float32)
        label[:row.isup_grade] = 1.
        return torch.tensor(images), torch.tensor(label)

In [None]:
# Training-time augmentation: random horizontal/vertical flips (p=0.5 each)
# followed by a transpose applied with the library's default probability.
transforms_train = albumentations.Compose(
    [
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.Transpose(),
    ]
)

# Validation uses an empty pipeline, i.e. no augmentation.
transforms_val = albumentations.Compose([])

In [None]:
# Dataset over the full training frame, used only to eyeball augmented samples.
# `rand=0` is falsy, so tiles keep their stored order.
dataset_show = PANDADataset(
    df=df_train,
    image_size=image_size,
    n_tiles=n_tiles,
    rand=0,
    transform=transforms_train,
)

In [None]:
from pylab import rcParams

# Preview: two figures, each showing five randomly drawn (augmented) mosaics.
rcParams['figure.figsize'] = 20,10
for row_no in range(2):
    fig, axes = plt.subplots(1, 5)
    for ax in axes:
        sample_idx = np.random.randint(0, len(dataset_show))
        img, label = dataset_show[sample_idx]
        # The dataset returns channel-first tensors; imshow wants channel-last.
        channel_last = img.permute(1, 2, 0).squeeze()
        # The dataset inverted the pixels, so undo that for display.
        ax.imshow(1. - channel_last)
        # Summing the cumulative label vector recovers the grade.
        ax.set_title(str(sum(label)))

In [None]:
# Binary cross-entropy on raw logits, shared by the training and validation
# loops; it takes the model's logits and the dataset's 0/1 label vector.
criterion = torch.nn.BCEWithLogitsLoss()

In [13]:
def train_epoch(loader, optimizer):
    """Run one training pass of the global `model` over `loader`.

    Uses the notebook-level `model`, `criterion` and `device`.

    Args:
        loader: iterable yielding `(data, target)` batches.
        optimizer: optimizer stepped once per batch.

    Returns:
        List with one loss value (0-d numpy array) per batch, in order.
    """
    model.train()
    train_loss = []
    bar = tqdm(loader)
    for (data, target) in bar:
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, target)
        loss.backward()
        optimizer.step()
        # The learning-rate scheduler is deliberately NOT stepped here. Its
        # horizon is expressed in epochs (the cosine phase spans
        # n_epochs - warmup_epo) and the training loop already advances it once
        # per epoch with scheduler.step(epoch - 1); stepping it per batch as
        # well would race through the whole schedule inside a single epoch.

        loss_np = loss.detach().cpu().numpy()
        train_loss.append(loss_np)
        # Running mean over (at most) the last 100 batches for the progress bar.
        smooth_loss = sum(train_loss[-100:]) / min(len(train_loss), 100)
        bar.set_description('loss: %.5f, smth: %.5f' % (loss_np, smooth_loss))
    return train_loss

In [14]:
def _one_vs_rest_auc(targets, preds):
    """Macro-averaged one-vs-rest ROC AUC of hard grade predictions.

    For every grade that occurs in `targets` (and is not the only grade there),
    the indicator `targets == grade` is scored against `preds == grade`; the
    per-grade AUCs are averaged. Using the same grade for both indicators keeps
    the comparison aligned even when the set of predicted grades differs from
    the set of true grades.

    Returns NaN when no grade can be scored (empty input or a single grade).
    """
    targets = np.asarray(targets)
    preds = np.asarray(preds)
    per_grade = []
    for grade in np.unique(targets):
        is_grade = targets == grade
        if is_grade.all():
            # ROC AUC is undefined when only one class is present in y_true.
            continue
        per_grade.append(roc_auc_score(is_grade.astype(int), (preds == grade).astype(int)))
    return float(np.mean(per_grade)) if per_grade else float('nan')


def val_epoch(loader, get_output=False):
    """Evaluate the global `model` on `loader` and print AUC summaries.

    Uses the notebook-level `model`, `criterion`, `device` and `df_valid`.
    The per-provider AUCs index the predictions with masks built from
    `df_valid`, so `loader` must iterate `df_valid`'s rows in order.

    Args:
        loader: iterable yielding `(data, target)` batches.
        get_output: when true, return the raw logits instead of the metrics.

    Returns:
        The stacked logits as a numpy array if `get_output` is true, otherwise
        `(mean validation loss, accuracy in percent)`.
    """
    model.eval()
    val_loss = []
    LOGITS = []
    PREDS = []
    TARGETS = []

    with torch.no_grad():
        for (data, target) in tqdm(loader):
            data, target = data.to(device), target.to(device)
            logits = model(data)

            loss = criterion(logits, target)

            # Predicted grade = rounded sum of the per-threshold probabilities.
            pred = logits.sigmoid().sum(1).detach().round()
            LOGITS.append(logits)
            PREDS.append(pred)
            # True grade = number of ones in the cumulative label vector.
            TARGETS.append(target.sum(1))

            val_loss.append(loss.detach().cpu().numpy())
        val_loss = np.mean(val_loss)

    LOGITS = torch.cat(LOGITS).cpu().numpy()
    PREDS = torch.cat(PREDS).cpu().numpy()
    TARGETS = torch.cat(TARGETS).cpu().numpy()
    acc = (PREDS == TARGETS).mean() * 100.

    # One-hot encoding targets and predictions independently (e.g. with
    # pd.get_dummies on each) yields columns that do not line up whenever the
    # two contain different sets of grades, which silently corrupts the score.
    # The helper scores each grade against itself instead.
    auc = _one_vs_rest_auc(TARGETS, PREDS)

    providers = df_valid['data_provider'].values
    grades = df_valid['isup_grade'].values
    is_karolinska = providers == 'karolinska'
    is_radboud = providers == 'radboud'
    auc_k = _one_vs_rest_auc(grades[is_karolinska], PREDS[is_karolinska])
    auc_r = _one_vs_rest_auc(grades[is_radboud], PREDS[is_radboud])

    print('auc', auc, 'auc_k', auc_k, 'auc_r', auc_r)

    if get_output:
        return LOGITS
    else:
        return val_loss, acc

In [15]:
# Split rows into the training part and the held-out validation fold.
train_idx = np.where((df_train['fold'] != fold))[0]
valid_idx = np.where((df_train['fold'] == fold))[0]

# np.where returns row positions, so select positionally; `.loc` would treat
# them as index labels and is only equivalent for a default RangeIndex.
df_this  = df_train.iloc[train_idx]
df_valid = df_train.iloc[valid_idx]

dataset_train = PANDADataset(df_this , image_size, n_tiles, transform=transforms_train)
dataset_valid = PANDADataset(df_valid, image_size, n_tiles, transform=transforms_val)

# Training batches are drawn in random order; validation keeps row order so
# predictions can be matched back to df_valid.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, sampler=RandomSampler(dataset_train))
valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, sampler=SequentialSampler(dataset_valid))

model = enetv2(enet_type, out_dim=out_dim)
model = model.to(device)

# Start at init_lr / warmup_factor; the warm-up scheduler scales the rate by up
# to `multiplier` over `total_epoch` epochs and then hands over to the cosine
# schedule. `warmup_epo` (warm-up length in epochs) must be defined in the
# configuration cell. `total_epoch` is a required argument of
# GradualWarmupScheduler, so it is passed explicitly.
optimizer = optim.Adam(model.parameters(), lr=init_lr/warmup_factor)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs-warmup_epo)
scheduler = GradualWarmupScheduler(optimizer, multiplier=warmup_factor, total_epoch=warmup_epo, after_scheduler=scheduler_cosine)

print(len(dataset_train), len(dataset_valid))


Loaded pretrained weights for efficientnet-b0
272 68


In [None]:
# Main loop: one scheduler step, one training pass and one validation pass per
# epoch, followed by a one-line summary.
for epoch in range(1, n_epochs+1):
    print(time.ctime(), 'Epoch:', epoch)
    # The scheduler is driven with an explicit 0-based epoch index.
    scheduler.step(epoch-1)

    train_loss = train_epoch(train_loader, optimizer)
    val_loss, acc = val_epoch(valid_loader)

    current_lr = optimizer.param_groups[0]["lr"]
    mean_train_loss = np.mean(train_loss)
    mean_val_loss = np.mean(val_loss)
    summary = (
        f'Epoch {epoch}, lr: {current_lr:.7f}, '
        f'train loss: {mean_train_loss:.5f}, val loss: {mean_val_loss:.5f}, acc: {(acc):.5f}'
    )
    content = time.ctime() + ' ' + summary
    print(content)



Fri Mar 18 17:05:42 2022 Epoch: 1


loss: 0.61657, smth: 0.67015: 100%|█████████████| 46/46 [09:22<00:00, 12.24s/it]
100%|███████████████████████████████████████████| 12/12 [00:37<00:00,  3.16s/it]


auc 0.5 auc_k 0.5 auc_r 0.5
Fri Mar 18 17:15:43 2022 Epoch 1, lr: 0.0000300, train loss: 0.67015, val loss: 0.65784, acc: 10.29412
Fri Mar 18 17:15:43 2022 Epoch: 2


loss: 0.55857, smth: 0.60951:  20%|██▋           | 9/46 [01:50<07:31, 12.19s/it]