In [1]:
import torch
import torch.nn as nn
import torch.utils.data as tdata
import torchvision.transforms as transforms
import torchvision.models as models
import PIL.Image as Image
from pathlib import Path

In [2]:
import pandas as pd
import numpy as np
import random
import os
import pickle
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import cohen_kappa_score
import pytorch_lightning as pl
from contribs.ranger import Ranger

In [3]:
class TileDataset(tdata.Dataset):
    
    def __init__(self, img_path, dataframe, transform=None, normalize_stats=None):
        
        self.img_path = Path(img_path)
        self.df = df_train
        self.img_list = self.df['image_id'].values
        self.transform = transform
        if normalize_stats is not None:
            self.normalize_stats = {}
            for k, v in normalize_stats.items():
                self.normalize_stats[k] = transforms.Normalize(v[0], v[1])
        else:
            self.normalize_stats = None
        
    def __getitem__(self, idx):
        img_id = self.img_list[idx]
        
        tiles = self.img_path.glob('**/' + img_id + '*.png')
        metadata = self.df.iloc[idx]
        image_tiles = []
        for tile in tiles:
            image = Image.open(tile)
            
            if self.transform is not None:
                image = self.transform(image)
                
            if self.normalize_stats is not None:
                provider = metadata['data_provider']
                image = self.normalize_stats[provider](image)
                image_tiles.append(image)
                
        image_tiles = torch.stack(image_tiles, dim=0)
        
        return {'image':image_tiles, 'provider':metadata['data_provider'], 
                'isup':metadata['isup_grade'], 'gleason':metadata['gleason_score']}
        
    def __len__(self):
        return len(self.img_list)

In [13]:
class AdaptiveConcatPool2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.max = nn.AdaptiveMaxPool2d(output_size=(1, 1))
    def forward(self, x):
        avg_x = self.avg(x)
        max_x = self.max(x)
        return torch.cat([avg_x, max_x], dim=1)

In [38]:
class Model(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(*list(models.resnet50(pretrained=True).children())[:-2])
        self.pool = AdaptiveConcatPool2d()
        self.fc = nn.Linear(2048 * 2, 6)
        
    def forward(self, x):
        h = x.view(-1, 3, 128, 128)
        h = self.feature_extractor(h)
        bn, c, height, width = h.shape
        h = h.view(-1, 16, c, height, width).permute(0,2,1,3,4).contiguous().view(-1, c, height * 16, width)
        h = self.pool(h)
        h = h.squeeze(2).squeeze(2)
        h = self.fc(h)
        return h

In [69]:
class LightModel(pl.LightningModule):
    
    def __init__(self, train_idx, val_idx, provider_stats):
        super().__init__()
        self.train_idx = train_idx
        self.val_idx = val_idx
        self.model = Model()
        self.provider_stats = provider_stats
        
    def forward(self, batch):
        return self.model(batch['image'])
        
    def prepare_data(self):
    
        transform_train = transforms.Compose([transforms.ToTensor()])
        transform_test = transforms.Compose([transforms.ToTensor()])
        self.trainset = TileDataset(TRAIN_PATH, df_train.iloc[self.train_idx], transform=transform_train, normalize_stats=self.provider_stats)
        self.valset = TileDataset(TRAIN_PATH, df_train.iloc[self.val_idx], transform=transform_test, normalize_stats=self.provider_stats)
    
    def train_dataloader(self):
        train_dl = tdata.DataLoader(self.trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
        return train_dl
    
    def val_dataloader(self):
        val_dl = tdata.DataLoader(self.valset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
        return [val_dl]
    
    def cross_entropy_loss(self, logits, gt):
        loss_fn = nn.CrossEntropyLoss()
        return loss_fn(logits, gt)
    
    def configure_optimizers(self):
        optimizer = Ranger(self.model.parameters())
        return optimizer
    
    def training_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.cross_entropy_loss(logits, batch['isup']).unsqueeze(0)
        return {'loss': loss, 'log': {'train_loss': loss}}
        
    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        loss = self.cross_entropy_loss(logits, batch['isup']).unsqueeze(0)
        preds = logits.argmax(1)
        return {'val_loss': loss, 'preds': preds, 'gt': batch['isup']}
    
    def validation_end(self, outputs):
        avg_loss = torch.cat([out['val_loss'] for out in outputs], dim=0).mean()
        preds = torch.cat([out['preds'] for out in outputs], dim=0)
        gt = torch.cat([out['gt'] for out in outputs], dim=0)
        preds = preds.detach().cpu().numpy()
        gt = gt.detach().cpu().numpy()
        kappa = cohen_kappa_score(preds, gt, weights='quadratic')
        tensorboard_logs = {'val_loss': avg_loss, 'kappa': kappa}
        
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

In [70]:
TRAIN_PATH = 'G:/Datasets/panda/train_tiles/imgs/'
CSV_PATH = 'G:/Datasets/panda/train.csv'
SEED = 34
BATCH_SIZE = 8

In [71]:
random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

In [72]:
df_train = pd.read_csv(CSV_PATH)
df_train = df_train[~(df_train['image_id'].isin(['8d90013d52788c1e2f5f47ad80e65d48']))]

In [73]:
kfold = StratifiedKFold(n_splits=5, random_state=SEED, shuffle=True)
splits = kfold.split(df_train, df_train['isup_grade'])

In [74]:
train_idx, val_idx = next(splits)

In [75]:
with open('./stats.pkl', 'rb') as file:
    provider_stats = pickle.load(file)

In [76]:
model = LightModel(train_idx, val_idx, provider_stats)
trainer = pl.Trainer(gpus=[0])
trainer.fit(model)

Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…

BrokenPipeError: [Errno 32] Broken pipe