In [1]:
import pydicom
import glob
import matplotlib.pyplot as plt
import pandas as pd
import os
import numpy as np

from torch.utils.data import Dataset

def window_image(img, window_center,window_width, intercept, slope):
    img = (img*slope +intercept)
    img_min = window_center - window_width//2
    img_max = window_center + window_width//2
    img[img<img_min] = img_min
    img[img>img_max] = img_max
    return img 

def dcm2np(dcm_file):
    ds = pydicom.dcmread(dcm_file)
    image = ds.pixel_array
    window_center, window_width, intercept, slope = get_windowing(ds)
    #the below line makes all the difference, allowing to see parenchimal details
    image_windowed = window_image(image, window_center, window_width, intercept, slope)
    image_windowed = image_windowed / image_windowed.max()
    return image_windowed

def get_first_of_dicom_field_as_int(x):
    #get x[0] as in int is x is a 'pydicom.multival.MultiValue', otherwise get int(x)
    if type(x) == pydicom.multival.MultiValue:
        return int(x[0])
    else:
        return int(x)

def get_windowing(data):
    dicom_fields = [data[('0028','1050')].value, #window center
                    data[('0028','1051')].value, #window width
                    data[('0028','1052')].value, #intercept
                    data[('0028','1053')].value] #slope
    return [get_first_of_dicom_field_as_int(x) for x in dicom_fields]


def next_batch(typed_pf_loaders, batch_size=100):
    all_types = ['epidural', 'intraparenchymal', 'intraventricular',
                 'subarachnoid', 'subdural']
    batch = []
    for hemtype in typed_pf_loaders:
        pf_loader_p, pf_loader_n = typed_pf_loaders[hemtype]
        for i in range(batch_size//6):
            final_label = np.zeros(6)
            if i % 2 == 0:
                x = pf_loader_p.__getitem__(pos_neg=1)
                final_label[all_types.index(hemtype)] = 1.
                final_label[-1] = 1. #any also true
            else:
                x = pf_loader_n.__getitem__(pos_neg=0)
                
            if x.shape == (224, 224, 3):
                x = x.transpose(2,0,1)
                batch.append((x,final_label))

    batch_x = np.array([i[0] for i in batch])
    #batch_x = batch_x[:,np.newaxis,:]
    batch_y = np.array([i[1] for i in batch])

    return batch_x, batch_y


class PF_Loader(Dataset):
    def __init__(self, df):
        """Constructor for Loader"""
        self.df = df

    def __len__(self):
        return len(self.df) 

    def __getitem__(self, pos_neg=0):
        """Itemgetter for Loader"""
        data = self.df.sample(1)
        img_name = data.iloc[0].PatientID
        file_name = '../input/rsna-intracranial-hemorrhage-detection/stage_1_train_images/ID_'+img_name+'.dcm'
        np_image = dcm2np(file_name)
        from scipy import ndimage
        zoomed_out = ndimage.zoom(np_image, 0.45)
        zoomed_out = zoomed_out[:224,:224]
        cm = plt.get_cmap('gist_rainbow')
        zoomed_out = cm(zoomed_out)
        zoomed_out = zoomed_out[:, :, :3]
        return zoomed_out

In [2]:
import torch
torch.cuda.is_available()
import os
import pickle
import random
import glob
from copy import deepcopy
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import psutil
import pickle
import os
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.utils import shuffle
from torchvision import datasets, models, transforms

def df2pf_loader(df, subtype='any'):
    df['Sub_type'] = df['ID'].str.split("_", n = 3, expand = True)[2]
    df['PatientID'] = df['ID'].str.split("_", n = 3, expand = True)[1]
    bleed_subtype_df = df.loc[df['Sub_type'] == subtype]

    df_subtype_pos = bleed_subtype_df.loc[bleed_subtype_df['Label'] == 1]
    df_subtype_neg = bleed_subtype_df.loc[bleed_subtype_df['Label'] == 0]

    pf_loader_pos = PF_Loader(df_subtype_pos)
    pf_loader_neg = PF_Loader(df_subtype_neg)
    return pf_loader_pos, pf_loader_neg


df = pd.read_csv('../input/rsna-intracranial-hemorrhage-detection/stage_1_train.csv')
df = df.sample(250000)

msk = np.random.rand(len(df)) < 0.8 #80% for training
train = df[msk]
val_test = df[~msk]

msk_val_test = np.random.rand(len(val_test)) < 0.5
val = val_test[msk_val_test] #10% val
test = val_test[~msk_val_test] #10% test

print('Train size:', len(train))
print('Val size:', len(val))
print('Test size:', len(test))

df = train

import torch.nn.functional as F    
def my_loss(y_pred, y_true,device='cuda'):
    weights = [[1.0, 1.0, 1.0, 1.0, 1.0, 2.0]] * y_pred.shape[0]
    weights = np.array(weights)
    weights = torch.from_numpy(weights).float()
    weights = weights.to(device)
    return F.binary_cross_entropy_with_logits(y_pred, y_true, weights)
    
#Load data
typed_pf_loaders = {}
bleeding_types =  ['epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural']
for hem_type in bleeding_types:
    train_pf_loader_pos, train_pf_loader_neg = df2pf_loader(df.sample(100000), subtype=hem_type) 
    typed_pf_loaders[hem_type] = [train_pf_loader_pos, train_pf_loader_neg]

Train size: 200067
Val size: 24907
Test size: 25026


In [3]:
#Learning and net parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 0.01
n_batches = 900
batch_size = 32
bleed_net = models.resnet18(pretrained=True)
num_ftrs = bleed_net.fc.in_features
bleed_net.fc = nn.Linear(num_ftrs, 6)
bleed_net = bleed_net.to(device)
optimizer = optim.SGD(bleed_net.parameters(), lr=lr)
loss_fn = my_loss

from torch.optim.lr_scheduler import StepLR
stepsize = 100 
lr_gamma = 0.99
scheduler = StepLR(optimizer, step_size=stepsize, gamma=lr_gamma)

#Initialize logs
train_loss_log = []
val_loss_log = []
test_loss_log = []

#TRAIN THE MODEL
for i in range(n_batches):
    bleed_net.train()
    x, y = next_batch(typed_pf_loaders,batch_size=batch_size)
    x_train_tensor = torch.from_numpy(x).float().to(device)
    y_train_tensor = torch.from_numpy(y).float().to(device)
    yhat = bleed_net(x_train_tensor)
    if i % 100 == 0:
        yhat_any_pred = yhat[:,-1].detach().cpu().numpy().flatten().tolist()
        gt_any = y_train_tensor[:,-1].cpu().numpy().flatten().tolist()
        for ypred,yreal in zip(yhat_any_pred,gt_any):
            print(ypred,yreal)
    
    loss = loss_fn(yhat, y_train_tensor)
    loss.backward()    
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    print('Loss: {} | Batch {}/{}'.format(loss.item(),i,n_batches))
    train_loss_log.append(loss.item())
    
print('Training finished.')


#TODO: apply softmax before submission!

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /tmp/.cache/torch/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 112MB/s]


-0.14790363609790802 1.0
0.2123090624809265 0.0
-0.6982162594795227 1.0
-0.49528801441192627 0.0
-0.8293451070785522 1.0
-0.8112273812294006 1.0
-0.11899876594543457 0.0
-0.279494047164917 1.0
-0.07821743190288544 0.0
-0.1453905552625656 1.0
-0.4209209084510803 1.0
-0.8332934975624084 0.0
-0.02719990909099579 1.0
-0.37929636240005493 0.0
-0.33453643321990967 1.0
-0.34086424112319946 1.0
-0.7805141806602478 0.0
0.2364392727613449 1.0
0.23321451246738434 0.0
-0.01272091269493103 1.0
-0.04558075964450836 1.0
0.0936700850725174 0.0
-0.02679884433746338 1.0
-0.6820119619369507 0.0
-0.14240768551826477 1.0
Loss: 0.7104156017303467 | Batch 0/900
Loss: 0.6574836373329163 | Batch 1/900
Loss: 0.6389825344085693 | Batch 2/900
Loss: 0.6283395290374756 | Batch 3/900
Loss: 0.5728054046630859 | Batch 4/900
Loss: 0.5797837972640991 | Batch 5/900
Loss: 0.5388105511665344 | Batch 6/900
Loss: 0.5666735172271729 | Batch 7/900
Loss: 0.5699052810668945 | Batch 8/900
Loss: 0.5297257900238037 | Batch 9/900
Lo

  xa[xa < 0] = -1


Loss: 0.45279067754745483 | Batch 44/900
Loss: 0.4899994432926178 | Batch 45/900
Loss: 0.4782494306564331 | Batch 46/900
Loss: 0.48284414410591125 | Batch 47/900
Loss: 0.479348361492157 | Batch 48/900
Loss: 0.47949305176734924 | Batch 49/900
Loss: 0.4396929442882538 | Batch 50/900
Loss: 0.44036439061164856 | Batch 51/900
Loss: 0.4833606481552124 | Batch 52/900
Loss: 0.4999653100967407 | Batch 53/900
Loss: 0.4443829357624054 | Batch 54/900
Loss: 0.451754629611969 | Batch 55/900
Loss: 0.47793814539909363 | Batch 56/900
Loss: 0.449314683675766 | Batch 57/900
Loss: 0.41079047322273254 | Batch 58/900
Loss: 0.4408605098724365 | Batch 59/900
Loss: 0.4192027747631073 | Batch 60/900
Loss: 0.3983316123485565 | Batch 61/900
Loss: 0.4531762897968292 | Batch 62/900
Loss: 0.42835235595703125 | Batch 63/900
Loss: 0.4392414391040802 | Batch 64/900
Loss: 0.4472474157810211 | Batch 65/900
Loss: 0.502885103225708 | Batch 66/900
Loss: 0.4427412152290344 | Batch 67/900
Loss: 0.4061276614665985 | Batch 68/9



Loss: 0.4003208875656128 | Batch 95/900
Loss: 0.4750533103942871 | Batch 96/900
Loss: 0.4789159893989563 | Batch 97/900
Loss: 0.41413983702659607 | Batch 98/900
Loss: 0.409820556640625 | Batch 99/900
3.0965867042541504 1.0
-1.512224793434143 0.0
1.4533872604370117 1.0
-2.0504543781280518 0.0
1.402923583984375 1.0
2.066190481185913 1.0
0.5783424973487854 0.0
1.5393232107162476 1.0
-2.161595582962036 0.0
0.9634659290313721 1.0
0.9998089671134949 1.0
-0.989628255367279 0.0
1.3111385107040405 1.0
-1.3725165128707886 0.0
1.9299262762069702 1.0
1.1097040176391602 1.0
-1.2334685325622559 0.0
0.8374959230422974 1.0
0.5794848203659058 0.0
2.3115270137786865 1.0
2.1304619312286377 1.0
-1.8145626783370972 0.0
0.8620997667312622 1.0
-1.6324210166931152 0.0
1.1791465282440186 1.0
Loss: 0.36210110783576965 | Batch 100/900
Loss: 0.41695210337638855 | Batch 101/900
Loss: 0.4100249409675598 | Batch 102/900
Loss: 0.4132751524448395 | Batch 103/900
Loss: 0.3498927652835846 | Batch 104/900
Loss: 0.4194121

Loss: 0.4042816758155823 | Batch 266/900
Loss: 0.3498111665248871 | Batch 267/900
Loss: 0.46213003993034363 | Batch 268/900
Loss: 0.3214319944381714 | Batch 269/900
Loss: 0.3640369772911072 | Batch 270/900
Loss: 0.4489280879497528 | Batch 271/900
Loss: 0.42939916253089905 | Batch 272/900
Loss: 0.41201016306877136 | Batch 273/900
Loss: 0.34075531363487244 | Batch 274/900
Loss: 0.3753644824028015 | Batch 275/900
Loss: 0.38169804215431213 | Batch 276/900
Loss: 0.46825918555259705 | Batch 277/900
Loss: 0.3953167498111725 | Batch 278/900
Loss: 0.4427550435066223 | Batch 279/900
Loss: 0.3422120213508606 | Batch 280/900
Loss: 0.3445453345775604 | Batch 281/900
Loss: 0.35650497674942017 | Batch 282/900
Loss: 0.45864370465278625 | Batch 283/900
Loss: 0.4064382016658783 | Batch 284/900
Loss: 0.3875836431980133 | Batch 285/900
Loss: 0.37403616309165955 | Batch 286/900
Loss: 0.4638702869415283 | Batch 287/900
Loss: 0.34432414174079895 | Batch 288/900
Loss: 0.37936899065971375 | Batch 289/900
Loss:

Loss: 0.4238719046115875 | Batch 437/900
Loss: 0.3680320680141449 | Batch 438/900
Loss: 0.41400885581970215 | Batch 439/900
Loss: 0.2970896065235138 | Batch 440/900
Loss: 0.42402005195617676 | Batch 441/900
Loss: 0.45844635367393494 | Batch 442/900
Loss: 0.4066469967365265 | Batch 443/900
Loss: 0.35072195529937744 | Batch 444/900
Loss: 0.3190588653087616 | Batch 445/900
Loss: 0.38244473934173584 | Batch 446/900
Loss: 0.40063726902008057 | Batch 447/900
Loss: 0.3085257411003113 | Batch 448/900
Loss: 0.44675537943840027 | Batch 449/900
Loss: 0.3060677945613861 | Batch 450/900
Loss: 0.3125413954257965 | Batch 451/900
Loss: 0.3417947590351105 | Batch 452/900
Loss: 0.322979211807251 | Batch 453/900
Loss: 0.3370652198791504 | Batch 454/900
Loss: 0.38957828283309937 | Batch 455/900
Loss: 0.4031812250614166 | Batch 456/900
Loss: 0.3107227683067322 | Batch 457/900
Loss: 0.38197943568229675 | Batch 458/900
Loss: 0.348332017660141 | Batch 459/900
Loss: 0.4607774317264557 | Batch 460/900
Loss: 0.3

Loss: 0.32523393630981445 | Batch 608/900
Loss: 0.343631386756897 | Batch 609/900
Loss: 0.33921071887016296 | Batch 610/900
Loss: 0.2803829312324524 | Batch 611/900
Loss: 0.30754604935646057 | Batch 612/900
Loss: 0.2754673361778259 | Batch 613/900
Loss: 0.36563900113105774 | Batch 614/900
Loss: 0.34525781869888306 | Batch 615/900
Loss: 0.32863420248031616 | Batch 616/900
Loss: 0.35500776767730713 | Batch 617/900
Loss: 0.29837119579315186 | Batch 618/900
Loss: 0.3040110766887665 | Batch 619/900
Loss: 0.3422364890575409 | Batch 620/900
Loss: 0.34116047620773315 | Batch 621/900
Loss: 0.3743980824947357 | Batch 622/900
Loss: 0.3085270822048187 | Batch 623/900
Loss: 0.29831400513648987 | Batch 624/900
Loss: 0.4529253840446472 | Batch 625/900
Loss: 0.3451620936393738 | Batch 626/900
Loss: 0.302609384059906 | Batch 627/900
Loss: 0.3011123538017273 | Batch 628/900
Loss: 0.2871696650981903 | Batch 629/900
Loss: 0.32763269543647766 | Batch 630/900
Loss: 0.3739931583404541 | Batch 631/900
Loss: 0

Loss: 0.3338601291179657 | Batch 792/900
Loss: 0.3196426331996918 | Batch 793/900
Loss: 0.37829792499542236 | Batch 794/900
Loss: 0.34789758920669556 | Batch 795/900
Loss: 0.29169684648513794 | Batch 796/900
Loss: 0.28794628381729126 | Batch 797/900
Loss: 0.3932211399078369 | Batch 798/900
Loss: 0.429395467042923 | Batch 799/900
2.204669237136841 1.0
-0.866763174533844 0.0
0.8708613514900208 1.0
-1.3806105852127075 0.0
2.778560161590576 1.0
1.5615696907043457 1.0
-6.175429344177246 0.0
2.6665213108062744 1.0
-0.09605918079614639 0.0
3.2289063930511475 1.0
1.9632753133773804 1.0
-2.5731024742126465 0.0
3.2511730194091797 1.0
-1.2666049003601074 0.0
2.5645945072174072 1.0
2.5352673530578613 1.0
-4.493488788604736 0.0
0.8738111853599548 1.0
-2.963772773742676 0.0
0.5661290287971497 1.0
3.076889991760254 1.0
-2.851388931274414 0.0
-1.285794973373413 1.0
3.8090896606445312 0.0
4.122494220733643 1.0
Loss: 0.3242267370223999 | Batch 800/900
Loss: 0.3479602634906769 | Batch 801/900
Loss: 0.394