In [1]:
import torch
import torch.nn as nn
import torch.utils.data as tdata
import torchvision.transforms as transforms
import torchvision.models as models
import PIL.Image as Image
from pathlib import Path

In [2]:
import pandas as pd
import numpy as np
import random
import os
import pickle
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import cohen_kappa_score

In [3]:
class TileDataset(tdata.Dataset):
    
    def __init__(self, img_path, dataframe, transform=None, normalize_stats=None):
        
        self.img_path = Path(img_path)
        self.df = df_train
        self.img_list = self.df['image_id'].values
        self.transform = transform
        if normalize_stats is not None:
            self.normalize_stats = {}
            for k, v in normalize_stats.items():
                self.normalize_stats[k] = transforms.Normalize(v[0], v[1])
        else:
            self.normalize_stats = None
        
    def __getitem__(self, idx):
        img_id = self.img_list[idx]
        image = Image.open(self.img_path/(img_id + '.png'))
        metadata = self.df.iloc[idx]
        
        if self.transform is not None:
            image = self.transform(image)
        
        if self.normalize_stats is not None:
            provider = metadata['data_provider']
            image = self.normalize_stats[provider](image)
            
        return {'image':image, 'provider':metadata['data_provider'], 
                'isup':metadata['isup_grade'], 'gleason':metadata['gleason_score']}
        
    def __len__(self):
        return len(self.img_list)

In [16]:
class Model(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.base = models.resnet50(pretrained=False)
        self.base.fc = nn.Linear(2048, 6)
        
    def forward(self, x):
        return self.base(x)

In [5]:
TRAIN_PATH = 'G:/Datasets/panda/train_tiles/imgs/'
CSV_PATH = 'G:/Datasets/panda/train.csv'
SEED = 34
BATCH_SIZE = 8

In [6]:
random.seed(SEED)
os.environ['PYTHONHASHSEED'] = str(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

In [7]:
df_train = pd.read_csv(CSV_PATH)

In [8]:
kfold = StratifiedKFold(n_splits=5, random_state=SEED, shuffle=True)
splits = kfold.split(df_train, df_train['isup_grade'])

In [9]:
transform_train = transforms.Compose([transforms.ToTensor()])
transform_test = transforms.Compose([transforms.ToTensor()])

In [10]:
train_idx, val_idx = next(splits)

In [11]:
with open('./stats.pkl', 'rb') as file:
    provider_stats = pickle.load(file)

In [12]:
trainset = TileDataset(TRAIN_PATH, df_train.iloc[train_idx], transform=transform_train, normalize_stats=provider_stats)
valset = TileDataset(TRAIN_PATH, df_train.iloc[val_idx], transform=transform_test, normalize_stats=provider_stats)
train_dl = tdata.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
val_dl = tdata.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)

In [17]:
model = Model().to('cuda')
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

In [18]:
for data in train_dl:
    optimizer.zero_grad()
    imgs = data['image'].to('cuda')
    labels = data['isup'].to('cuda')
    
    preds = model(imgs)
    loss = loss_fn(preds, labels)
    loss.backward()
    optimizer.step()
    
    print(loss.item())

2.1591379642486572
6.841738700866699
7.4817705154418945
4.4923481941223145
1.5214855670928955
2.7481865882873535
2.144104480743408
2.2197721004486084
2.0950417518615723
3.2949533462524414
1.9124219417572021
1.9456874132156372
1.6738011837005615
1.5856982469558716
3.180528163909912
3.442686080932617
2.492560863494873
2.003411293029785
1.8806445598602295
2.0654239654541016
1.7461273670196533
1.7161785364151
1.7558269500732422
1.8333479166030884
1.5472936630249023
1.8800368309020996
1.6235963106155396
1.7412068843841553
2.64563250541687
1.756497859954834
1.5287655591964722
1.355462908744812
1.3664922714233398
3.0305604934692383
2.3059561252593994
1.5654423236846924
2.1744565963745117
1.649634838104248
1.7444010972976685
1.214150309562683
1.6545212268829346
2.258922576904297
1.970077633857727
2.1417622566223145
1.6767898797988892
1.7648800611495972
1.9798171520233154
2.1235544681549072
2.0423357486724854
1.757497787475586
1.8845679759979248
1.73183012008667
1.8806828260421753
1.59361350536

KeyboardInterrupt: 

In [15]:
preds

In [None]:
trainset[3]['image'].mean()

In [None]:
#dataset = TileDataset(TRAIN_PATH, df_train, transform = transforms.ToTensor())
dataset = TileDataset(TRAIN_PATH, df_train)

In [None]:
dataset[0]['image']