In [None]:
# Constants

# --- Data locations -------------------------------------------------------
DATASET_DIR = "../input/stroke-dataset/"
INPUT_DIR = "../input/stroke-dataset/DICOM/"   # "<name>.dcm" inputs are read from here
MASK_DIR = "../input/stroke-dataset/MASK/"     # "<name>.png" masks are read from here
# Despite the *_DIR names, these two are text files with one sample name per line.
TRAIN_DIR = "../input/stroke-dataset/train_oversampled4x_v2.txt"
VALID_DIR = "../input/stroke-dataset/valid_v2.txt"
DRIVE_ID = ""                                  # Google Drive file id used when resuming from a checkpoint

# --- Image handling -------------------------------------------------------
IMAGE_SIZE = (512, 512)                        # size every sample is resized to
WINDOW = dict(level = 45, width = 90)          # intensity window applied to the DICOM pixels
MEAN = 0.449                                   # normalisation mean handed to preprocess_input
STD = 0.226                                    # normalisation std handed to preprocess_input

# --- Optimisation ---------------------------------------------------------
BATCH_SIZE = 8
EPOCH_COUNT = 100
LEARNING_RATE = 2.36e-3
LEARNING_RATE_DECAY = 0.995                    # only referenced by the disabled ExponentialLR scheduler

# --- Logging / checkpointing ----------------------------------------------
PROJECT_NAME = "Stroke Detection"              # wandb project name
TRAIN_BATCH_COUNT = 25                         # log the running training loss every N batches
VALID_BATCH_COUNT = 10                         # log the running validation loss every N batches
VALID_IMAGE_COUNT = 50                         # number of sample predictions logged / plotted
SAVE_PER_BATCH = 2                             # a checkpoint is written every N epochs (not batches)

# Fixed set of sample names whose predictions are logged / plotted after each
# validation pass. The suffix presumably encodes the case label
# (Iskemi / Kanama / InmeYok) -- TODO confirm against the dataset description.
VALID_IMAGE_NAMES = [
    # *_Iskemi
    "15525_Iskemi", "11858_Iskemi", "11689_Iskemi", "15726_Iskemi",
    "16511_Iskemi", "11314_Iskemi", "11655_Iskemi", "10937_Iskemi",
    "15677_Iskemi", "14055_Iskemi", "11159_Iskemi", "15125_Iskemi",
    "13343_Iskemi", "15029_Iskemi", "12878_Iskemi", "10907_Iskemi",
    "15853_Iskemi", "12677_Iskemi", "14575_Iskemi", "14792_Iskemi",
    "14297_Iskemi",
    # *_Kanama
    "10442_Kanama", "10050_Kanama", "17011_Kanama", "10975_Kanama",
    "10876_Kanama", "15455_Kanama", "13385_Kanama", "14903_Kanama",
    "13411_Kanama", "10395_Kanama", "10549_Kanama", "16940_Kanama",
    "11728_Kanama", "15898_Kanama", "12312_Kanama", "15597_Kanama",
    "15338_Kanama", "13543_Kanama", "11049_Kanama", "15559_Kanama",
    "15788_Kanama", "14747_Kanama", "14623_Kanama", "10046_Kanama",
    # *_InmeYok
    "16658_InmeYok", "14596_InmeYok", "13093_InmeYok", "12410_InmeYok",
    "11388_InmeYok",
]

import torch

# Use the first CUDA device when PyTorch reports one, otherwise stay on the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
import numpy as np
import pydicom
from pydicom.pixel_data_handlers.util import apply_modality_lut
import gdcm
import cv2

def resize_image(image, new_size):
    """Return `image` resized with `cv2.resize` using its default interpolation.

    `new_size` is forwarded unchanged as the target size argument of
    `cv2.resize`.
    """
    resized = cv2.resize(image, new_size)
    return resized

def read_dicom(file_loc):
    """Read the DICOM file at `file_loc` and return its pixels after the modality LUT.

    The result is whatever `apply_modality_lut` produces for the decoded
    pixel array of the file.
    """
    dataset = pydicom.dcmread(file_loc)

    # The header's BitsStored is overridden before the pixel data is decoded.
    # NOTE(review): presumably a workaround for files that declare fewer
    # stored bits than they actually use -- confirm against the dataset.
    dataset.BitsStored = 16

    pixels = dataset.pixel_array
    return apply_modality_lut(pixels, dataset)

def window_dicom(image, window):
    """Map `image` through an intensity window and return it as uint8 in [0, 255].

    `window` is a mapping with the keys 'level' (window centre) and 'width'.
    Values at or below ``level - width / 2`` become 0, values at or above
    ``level + width / 2`` become 255, and everything in between is scaled
    linearly (the fractional part is dropped by the uint8 cast).
    """
    level = window['level']
    width = window['width']

    lower_bound = level - (width / 2)
    scaled = (image - lower_bound) / width
    clipped = np.clip(scaled, 0., 1.0)

    return (clipped * 255).astype('uint8')

def read_image(file_loc):
    """Load the image at `file_loc` as a single-channel (grayscale) array.

    Raises:
        FileNotFoundError: if OpenCV could not read the file. `cv2.imread`
            signals a missing or undecodable file by returning None instead
            of raising, which would otherwise only surface later as an
            unrelated error on the None value.
    """
    image = cv2.imread(file_loc, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError("Could not read image file: " + str(file_loc))
    return image

import matplotlib.pyplot as plt

# Select matplotlib's 'ggplot' style sheet for every figure drawn by the helpers below.
plt.style.use('ggplot')

def show_image(image, color, show_now = True):
    """Draw `image` on the current matplotlib axes with the ticks removed.

    Parameters:
        image: array handed to `plt.imshow`; displayed on a fixed 0..255 scale.
        color: 'gray' draws with the gray colormap, 'rgb' draws as-is. Any
            other value draws nothing (the ticks are still cleared).
        show_now: when true, a new 10x10 figure is opened first and is shown
            and closed at the end. Pass False to draw into a subplot that the
            caller manages.
    """
    if show_now:
        plt.figure(figsize = (10, 10))

    wants_gray = color == 'gray'
    if wants_gray or color == 'rgb':
        extra = {'cmap' : plt.cm.gray} if wants_gray else {}
        plt.imshow(image, vmin = 0, vmax = 255, **extra)

    plt.xticks(())
    plt.yticks(())

    if not show_now:
        return

    plt.show()
    plt.close()

def show_all(images, titles):
    """Plot a grid of images: one column per entry of `images`, one row per sample.

    Parameters:
        images: sequence of columns; `images[c][r]` is the array shown in
            column `c`, row `r`. The row count is taken from `images[0]`.
        titles: one title per column; each cell is titled
            "<title> <row number>" with rows counted from 1.

    2-D arrays are drawn in grayscale and 3-D arrays whose last axis has
    length 3 are drawn as RGB; arrays of any other shape leave their cell
    empty.
    """
    n_cols = len(images)
    n_rows = len(images[0])

    plt.figure(figsize = (n_cols * 10, n_rows * 10))

    for row in range(n_rows):
        for col in range(n_cols):
            picture = images[col][row]
            rank = len(picture.shape)

            if rank == 2:
                mode = 'gray'
            elif rank == 3 and picture.shape[-1] == 3:
                mode = 'rgb'
            else:
                continue

            # Subplot indices are 1-based and run row by row.
            plt.subplot(n_rows, n_cols, row * n_cols + (col + 1))
            show_image(picture, mode, False)
            plt.title(titles[col] + " " + str(row + 1))

    plt.show()
    plt.close()
    
def fix_image_colors(image):
    """Scale `image` by 127.5 so the values 0, 1, 2 become 0, 127.5, 255 for display."""
    scale = 255 / 2
    return image * scale

def get_difference(image, mask):
    """Compare a predicted label map (`image`) with the reference `mask`.

    Returns a dict with:
        'diff'  : float array of shape ``image.shape + (3,)`` colour-coding
                  every pixel (black where both are 0),
        'red'   : count of pixels where `image` is non-zero but differs from `mask`,
        'green' : count of pixels where `image` is non-zero and equals `mask`,
        'blue'  : count of pixels where `image` is 0 but `mask` is not.
    """
    differs = image != mask
    predicted = image != 0

    wrong = differs & predicted                # painted red
    correct = (image == mask) & predicted      # painted green
    missed = differs & (image == 0)            # painted blue

    overlay = np.zeros(image.shape + (3,))
    overlay[wrong] = [255, 0, 0]
    overlay[correct] = [0, 255, 0]
    overlay[missed] = [0, 0, 255]

    return {
        'diff' : overlay,
        'red' : int(wrong.sum()),
        'green' : int(correct.sum()),
        'blue' : int(missed.sum()),
    }

from segmentation_models_pytorch.encoders import preprocess_input

def preprocess(image):
    """Normalise `image` with segmentation_models_pytorch's `preprocess_input`.

    The module-level MEAN and STD are passed as the normalisation statistics
    together with ``input_range = [0, 1]``.
    """
    normalisation = {'mean' : MEAN, 'std' : STD, 'input_range' : [0, 1]}
    return preprocess_input(image, **normalisation)

In [None]:
def _read_names(list_file):
    """Return the sample names listed one per line in `list_file`.

    Blank lines (for example the one produced by a trailing newline) are
    skipped and surrounding whitespace is stripped, so no bogus path such as
    ".../DICOM/.dcm" is generated. The file is closed as soon as it is read.
    """
    with open(list_file, "r") as names_file:
        stripped = (line.strip() for line in names_file)
        return [name for name in stripped if name]


def _paths_for(names):
    """Build the DICOM input paths and PNG mask paths for `names`."""
    return {
        'in' : [INPUT_DIR + name + ".dcm" for name in names],
        'mask' : [MASK_DIR + name + ".png" for name in names],
    }


# TRAIN_DIR / VALID_DIR are text files holding one sample name per line.
train_names = _read_names(TRAIN_DIR)
valid_names = _read_names(VALID_DIR)

# paths[split]['in'] and paths[split]['mask'] are parallel lists.
paths = {
    'train' : _paths_for(train_names),
    'valid' : _paths_for(valid_names),
    'valid_images' : _paths_for(VALID_IMAGE_NAMES),
}

In [None]:
import albumentations as A

# Deterministic pipeline (used for the validation datasets): only resize to IMAGE_SIZE.
resize = A.Compose([A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1], p = 1)])

# Training augmentation pipeline; the transforms run in the order listed.
#
# Augmentations that were tried and are currently disabled:
#   before the resize:           A.CenterCrop(height = IMAGE_SIZE[0] // 2 + IMAGE_SIZE[0] // 3, width = IMAGE_SIZE[1] // 2 + IMAGE_SIZE[1] // 3, p = 0.2)
#   after the resize:            A.Downscale(scale_min = 0.25, scale_max = 0.5, p = 0.1)
#   after the random sized crop: A.ElasticTransform(p = 0.8), A.GridDistortion(p = 0.9)
transform = A.Compose([
    A.HorizontalFlip(p = 0.5),
    A.Resize(IMAGE_SIZE[0], IMAGE_SIZE[1], p = 1),
    # The first argument is the (min, max) pair handed to RandomSizedCrop:
    # from three quarters of the image height up to the full height.
    A.RandomSizedCrop(
        (IMAGE_SIZE[0] // 2 + IMAGE_SIZE[0] // 4, IMAGE_SIZE[0]),
        height = IMAGE_SIZE[0],
        width = IMAGE_SIZE[1],
        p = 0.8,
    ),
    A.GaussNoise(p = 0.2),
    A.Affine(shear = (-45, 45), p = 0.2),
])

from torch.utils.data import Dataset

class Stroke_DataSet(Dataset):
    """Dataset of (windowed DICOM slice, segmentation mask) pairs.

    Parameters:
        input_paths: paths of the DICOM files.
        mask_paths: paths of the mask images, parallel to `input_paths`.
        window: mapping with 'level' and 'width' passed to `window_dicom`.
        transform: optional callable invoked as
            ``transform(image = ..., mask = ...)`` that returns a mapping with
            the keys 'image' and 'mask' (the albumentations convention);
            None disables it.
    """

    def __init__(self, input_paths, mask_paths, window, transform = None):
        self.input_paths = input_paths
        self.mask_paths = mask_paths
        self.window = window
        self.transform = transform

    def __len__(self):
        """Return the number of samples (the length of `input_paths`)."""
        return len(self.input_paths)

    def __getitem__(self, i):
        """Load sample `i` and return ``(image, mask)`` as NumPy arrays.

        The image is the windowed DICOM slice cast to float32 (values on the
        0..255 scale produced by `window_dicom`); the mask is cast to int64.
        """
        dicom = read_dicom(self.input_paths[i])
        dicom = window_dicom(dicom, self.window)

        mask = read_image(self.mask_paths[i])

        # Identity test: `!= None` would call the transform's own comparison
        # operator instead of simply checking whether one was supplied.
        if self.transform is not None:
            transformed = self.transform(image = dicom, mask = mask)
            dicom = transformed['image']
            mask = transformed['mask']

        dicom = dicom.astype('float32')
        mask = mask.astype('int64')

        return (dicom, mask)
    
# One dataset per split. Only the training split gets the random augmentation
# pipeline; both validation splits are merely resized.
datasets = {
    'train' : Stroke_DataSet(paths['train']['in'], paths['train']['mask'], WINDOW, transform),
    'valid' : Stroke_DataSet(paths['valid']['in'], paths['valid']['mask'], WINDOW, resize),
    'valid_images' : Stroke_DataSet(paths['valid_images']['in'], paths['valid_images']['mask'], WINDOW, resize),
}

from torch.utils.data import DataLoader

# Batched loaders for the training and validation splits.
loaders = {}

# shuffle = True: DataLoader does not shuffle by default, so without it every
# epoch would see exactly the same batches in exactly the same order.
loaders['train'] = DataLoader(datasets['train'], batch_size = BATCH_SIZE, shuffle = True, num_workers = 4, pin_memory = True)
# Validation keeps the file order so that the logged metrics are comparable between epochs.
loaders['valid'] = DataLoader(datasets['valid'], batch_size = BATCH_SIZE, num_workers = 4, pin_memory = True)

In [None]:
import torch.nn as nn
import segmentation_models_pytorch as smp
from google_drive_downloader import GoogleDriveDownloader as gdd

# Build (or restore) the model, the loss functions, the optimizer and the LR
# scheduler used by the training cells further down.
torch.cuda.empty_cache()

# Set is_saved to True to resume from a checkpoint stored on Google Drive
# under the file id DRIVE_ID; otherwise everything is created from scratch.
is_saved = False
drive_id = DRIVE_ID

if(is_saved):
    gdd.download_file_from_google_drive(drive_id, "./checkpoint.pth")
    # The checkpoint stores whole Python objects (see the dict written by the
    # training loop), so torch.load unpickles them.
    # NOTE(review): only load checkpoints from a trusted source; newer PyTorch
    # releases may also need weights_only = False here -- confirm.
    checkpoint = torch.load("./checkpoint.pth", map_location = DEVICE)
    
    model = checkpoint['model']
    
    # For the combined-loss setup (order matches the 'criterions' list that is saved):
    DiceLoss = checkpoint['criterions'][0]
    FocalLoss = checkpoint['criterions'][1]
    NormFocalLoss = checkpoint['criterions'][2]
    
    optimizer = checkpoint['optimizer']
    lr_scheduler = checkpoint['lr_scheduler']
    
    # Continue with the epoch after the one that was saved.
    start_epoch = checkpoint['epoch'] + 1
    ################################
    
    # Disabled: restore path for checkpoints that stored a single criterion.
#     criterion = checkpoint['criterion']
#     criterion = smp.losses.FocalLoss(mode = 'multiclass')
#     optimizer = checkpoint['optimizer']
#     lr_scheduler = checkpoint['lr_scheduler']
#     lr_scheduler = checkpoint['lr_sched']
#     start_epoch = checkpoint['epoch'] + 1

else:
    # Disabled architecture alternatives are kept below for reference.
#     model = smp.Unet(
#         encoder_name = "timm-efficientnet-b0",
#         encoder_weights = "imagenet",
#         in_channels = 1,
#         classes = 3,
#         activation = None,
# #         decoder_attention_type = "scse",
#     )

#     model = smp.DeepLabV3Plus(
#         encoder_name = "resnet34",
#         encoder_weights = "imagenet",
#         in_channels = 1,
#         classes = 3,
#         activation = None,
#     )

    # Active model: DeepLabV3+ with an EfficientNet-B0 encoder, one input
    # channel and three output classes; activation = None is passed, so no
    # output activation is requested.
    model = smp.DeepLabV3Plus(
        encoder_name = "efficientnet-b0",
        encoder_weights = "imagenet",
        in_channels = 1,
        classes = 3,
        activation = None,
    )

#     model = smp.UnetPlusPlus(
#         encoder_name = "resnet34",
#         encoder_weights = "imagenet",
#         in_channels = 1,
#         classes = 3,
#         activation = None,
#         decoder_attention_type = "scse",
#     )

    model.to(DEVICE)
    
#     criterion = smp.losses.FocalLoss(mode = 'multiclass', normalized = True)
#     criterion = smp.losses.FocalLoss(mode = 'multiclass')
#     criterion = smp.losses.DiceLoss(mode = 'multiclass')
    # The three losses that calc_loss combines into the training objective.
    DiceLoss = smp.losses.DiceLoss(mode = 'multiclass').to(DEVICE)
    FocalLoss = smp.losses.FocalLoss(mode = 'multiclass').to(DEVICE)
    NormFocalLoss = smp.losses.FocalLoss(mode = 'multiclass', normalized = True).to(DEVICE)
#     JaccardLoss = smp.losses.JaccardLoss(mode = 'multiclass')
#     optimizer = torch.optim.AdamW(model.parameters(), lr = LEARNING_RATE) # Maybe add weight decay
    optimizer = torch.optim.AdamW(model.parameters(), lr = LEARNING_RATE) # Maybe add weight decay
#     lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer = optimizer, gamma = LEARNING_RATE_DECAY)
    # Stepped with the validation loss once per epoch by the training loop.
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer = optimizer)
#     lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer = optimizer, T_max = 3)

    # NOTE(review): `criterions` is only defined on this branch and is not
    # read anywhere in the visible code.
    criterions = {}
    
#     criterions[0] = smp.losses.FocalLoss(mode = 'multiclass')
#     criterions[1] = smp.losses.DiceLoss(mode = 'multiclass')
    
    start_epoch = 1

# Number of trainable parameters.
print(sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
def calculateIoU(prediction, mask):
    """Return a boundary-tolerant IoU-style score between `prediction` and `mask`.

    Parameters:
        prediction: tensor of per-class scores; the class axis is dim 1
            (argmax is taken over it).
        mask: reference label map with 0 = background and the foreground
            labels 1 and 2. May be a torch tensor or a NumPy array -- items
            read straight from the dataset are NumPy arrays, which the
            previous unconditional ``mask.cpu()`` could not handle.
            NOTE(review): the morphology below treats `mask` as one 2-D
            image; batched masks are presumably not intended -- confirm.

    The score is lenient around object borders: the intersection is counted
    against a 3x3-dilated copy of the mask, the union against a 3x3-eroded
    copy. Returns 1 when the union is empty.
    """
    prediction = torch.argmax(prediction, dim = 1).cpu().detach().numpy()

    if torch.is_tensor(mask):
        mask = mask.cpu().detach().numpy()
    else:
        mask = np.asarray(mask)

    kernel = np.ones((3, 3))

    erodedMask = np.zeros(mask.shape, dtype = np.uint8)
    dilatedMask = np.zeros(mask.shape, dtype = np.uint8)

    # Process label 1 before label 2 so that label 2 wins wherever the
    # dilated regions overlap, exactly as in the unrolled original.
    for label in (1, 2):
        binary = np.where(mask == label, 1, 0).astype(np.uint8)

        erosion = cv2.erode(binary, kernel, iterations = 1)
        dilation = cv2.dilate(binary, kernel, iterations = 1)

        erodedMask[erosion == 1] = label
        dilatedMask[dilation == 1] = label

    # Pixels predicted with the same non-zero label as the dilated mask.
    intersection = np.logical_and(dilatedMask == prediction, dilatedMask != 0)
    intersectionCount = np.count_nonzero(intersection)

    # Pixels that are foreground in the eroded mask or in the prediction.
    union = np.logical_or(erodedMask != 0, prediction != 0)
    unionCount = np.count_nonzero(union)

    if unionCount == 0:
        return 1

    return intersectionCount / unionCount

In [None]:
# Print the model's module structure for inspection.
print(model)

In [None]:
import torchvision
import wandb

def calc_loss(outputs, masks):
    """Return the combined training loss for `outputs` against `masks`.

    The value is ``Dice + (7 * Focal + 7e5 * NormFocal) / 2`` using the
    module-level DiceLoss, FocalLoss and NormFocalLoss criteria.
    """
    dice_term = DiceLoss(outputs, masks)
    focal_term = FocalLoss(outputs, masks) * 7
    norm_focal_term = NormFocalLoss(outputs, masks) * 1e5 * 7

    return dice_term + (focal_term + norm_focal_term) / 2.0

def TrainEpoch(model, loader, run, epoch):
    """Train `model` for one pass over `loader` and log the losses to `run`.

    Parameters:
        model: network to optimise; switched to train mode here.
        loader: iterable yielding ``(inputs, masks)`` batches.
        run: wandb run that receives the metrics.
        epoch: epoch number, used only in the progress output.

    Uses the module-level `optimizer`, `DiceLoss`, `calc_loss`, `preprocess`
    and `DEVICE`. Every TRAIN_BATCH_COUNT batches the mean loss of that
    window is printed and logged; the means over the whole epoch are logged
    at the end.
    """
    print("Training:")

    model.train()

    epoch_loss_total = 0.0
    epoch_dice_total = 0.0
    batches_seen = 0
    window_loss = 0.0
    window_dice = 0.0

    for batch_no, (inputs, masks) in enumerate(loader, 1):
        # Normalise, move to the device and add the channel axis at dim 1.
        inputs = preprocess(inputs).to(DEVICE).unsqueeze(1)
        masks = masks.to(DEVICE)

        optimizer.zero_grad(set_to_none = True)

        outputs = model(inputs)

        loss = calc_loss(outputs, masks)
        # The plain Dice loss is tracked for reporting only; it is not back-propagated.
        dice_loss = DiceLoss(outputs, masks)

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        dice_value = dice_loss.item()

        window_loss += loss_value
        epoch_loss_total += loss_value
        window_dice += dice_value
        epoch_dice_total += dice_value
        batches_seen += 1

        if batch_no % TRAIN_BATCH_COUNT != 0:
            continue

        average_loss = window_loss / TRAIN_BATCH_COUNT
        average_dice_loss = window_dice / TRAIN_BATCH_COUNT
        print("    Epoch: %d, Batch Count: [%d, %d], Average Loss: %.16f, Average Dice Loss %.16f" % (epoch, batch_no, len(loader), average_loss, average_dice_loss))
        run.log({"Training Loss" : average_loss})
        run.log({"Training Dice Loss" : average_dice_loss})
        window_loss = 0
        window_dice = 0

    run.log({"Training Loss Average" : epoch_loss_total / batches_seen})
    run.log({"Training Dice Loss Average" : epoch_dice_total / batches_seen})
        
def ValidEpoch(model, loader, dataset, run, epoch):
    """Evaluate `model` for one epoch; log losses, sample predictions and an IoU score.

    Parameters:
        model: network to evaluate; switched to eval mode here.
        loader: DataLoader over the validation set. Its batches feed the loss
            statistics and its underlying dataset (`loader.dataset`) feeds
            the per-image IoU score.
        dataset: dataset whose first VALID_IMAGE_COUNT items (or fewer, when
            it is shorter) are rendered into the wandb table of predictions.
        run: wandb run that receives the metrics.
        epoch: epoch number, used in the progress output and the table name.

    Returns:
        The mean combined loss (`calc_loss`) over all validation batches.

    Uses the module-level `DiceLoss`, `calc_loss`, `calculateIoU`,
    `preprocess`, `fix_image_colors`, `get_difference` and `DEVICE`.
    """
    print("Validation:")

    model.eval()

    with torch.no_grad():
        valid_loss_sum = 0.0
        dice_loss_sum = 0.0

        loss_count = 0
        cur_loss = 0.0
        cur_dice_loss = 0.0

        for i, data in enumerate(loader, 1):
            inputs, masks = data

            # Normalise, move to the device and add the channel axis at dim 1.
            inputs = preprocess(inputs).to(DEVICE).unsqueeze(1)
            masks = masks.to(DEVICE)

            outputs = model(inputs)

            loss = calc_loss(outputs, masks)
            dice_loss = DiceLoss(outputs, masks)

            cur_loss += loss.item()
            valid_loss_sum += loss.item()

            cur_dice_loss += dice_loss.item()
            dice_loss_sum += dice_loss.item()

            loss_count += 1

            if(i % VALID_BATCH_COUNT == 0):
                average_loss = cur_loss / VALID_BATCH_COUNT
                average_dice_loss = cur_dice_loss / VALID_BATCH_COUNT
                print("    Epoch: %d, Batch Count: [%d, %d], Average Loss: %.16f, Average Dice Loss %.16f" % (epoch, i, len(loader), average_loss, average_dice_loss))
                run.log({"Validation Loss" : average_loss})
                run.log({"Validation Dice Loss" : average_dice_loss})
                cur_loss = 0
                cur_dice_loss = 0

        run.log({"Validation Loss Average" : valid_loss_sum / loss_count})
        run.log({"Validation Dice Loss Average" : dice_loss_sum / loss_count})

        table = wandb.Table(columns = ["id", "image", "mask", "prediction", "difference", "correct (green)", "couldn\'t predict (blue)", "predicted wrongly (red)"])

        # Never index past the end of `dataset` if it holds fewer than
        # VALID_IMAGE_COUNT samples.
        for i in range(min(VALID_IMAGE_COUNT, len(dataset))):
            image, mask = dataset[i]

            # A single image needs both a batch axis and a channel axis.
            tensr = torch.tensor(preprocess(image)).to(DEVICE).unsqueeze(0).unsqueeze(0)

            output = model(tensr)

            output = torch.argmax(output, dim = 1)

            prediction = output.squeeze(0).cpu().detach().numpy()

            image_fixed = image.astype('int64')
            mask_fixed = fix_image_colors(mask).astype('int64')
            prediction_fixed = fix_image_colors(prediction).astype('int64')
            difference = get_difference(prediction, mask)

            diff = difference['diff']
            green = difference['green']
            blue = difference['blue']
            red = difference['red']

            table.add_data(i + 1, wandb.Image(image_fixed), wandb.Image(mask_fixed), wandb.Image(prediction_fixed), wandb.Image(diff), green, blue, red)

        # Score the dataset behind `loader` image by image instead of reaching
        # for the global `datasets` dict.
        iou_dataset = loader.dataset
        IoUSum = 0.0

        for image, mask in iou_dataset:
            tensr = torch.tensor(preprocess(image)).to(DEVICE).unsqueeze(0).unsqueeze(0)

            output = model(tensr)

            # Dataset items are NumPy arrays, but calculateIoU calls tensor
            # methods on the mask, so hand it a tensor.
            IoUSum += calculateIoU(output, torch.as_tensor(mask))

        run.log({"IoU Score" : IoUSum / len(iou_dataset)})

        run.log({("Validation Images %d" % epoch) : table})

        return valid_loss_sum / loss_count

In [None]:
from kaggle_secrets import UserSecretsClient
# Read the Weights & Biases API key stored in Kaggle's secret store under the
# label "wandb_api_key".
user_secrets = UserSecretsClient()
api_key = user_secrets.get_secret("wandb_api_key")

In [None]:
import wandb

# Authenticate through the Python API instead of the `!wandb login $api_key`
# shell magic: the secret is not interpolated into a shell command line, and
# the cell stays valid Python when the notebook is exported as a script.
wandb.login(key = api_key)

In [None]:
# Enable the cuDNN autotuner and switch anomaly detection off.
torch.backends.cudnn.benchmark = True
torch.autograd.set_detect_anomaly(False)
# NOTE(review): the next two calls only construct profiler context managers
# that are never entered, so they presumably have no effect -- confirm.
torch.autograd.profiler.emit_nvtx(False)
torch.autograd.profiler.profile(False)

run = wandb.init(project = PROJECT_NAME)

for epoch in range(start_epoch, start_epoch + EPOCH_COUNT):
    # One training pass followed by one validation pass per epoch.
    TrainEpoch(model, loaders['train'], run, epoch)
    loss = ValidEpoch(model, loaders['valid'], datasets['valid_images'], run, epoch)

    # The scheduler is stepped with the mean validation loss of this epoch.
    lr_scheduler.step(loss)

    # Checkpoints are written every SAVE_PER_BATCH epochs.
    if epoch % SAVE_PER_BATCH != 0:
        continue

    checkpoint_file = "./model_checkpoint" + str(epoch) + ".pth"

    # Whole objects (not state dicts) are stored; the resume branch of the
    # model-setup cell reads these same keys back.
    checkpoint = {
        'epoch': epoch,
        'model': model,
        'criterions' : [DiceLoss, FocalLoss, NormFocalLoss],
        'optimizer': optimizer,
        'lr_scheduler': lr_scheduler
    }

    torch.save(checkpoint, checkpoint_file)
    wandb.save(checkpoint_file)

run.finish()

In [None]:
# Stand-alone evaluation: mean Dice loss over the validation set, one image at a time.
model.eval()

DiceLoss = smp.losses.DiceLoss(mode = 'multiclass')

with torch.no_grad():
    dice_loss_sum = 0.0
    loss_count = 0

    for image, label_map in datasets['valid']:
        # Add the batch and channel axes to the image, and the batch axis to the mask.
        inputs = torch.tensor(preprocess(image)).to(DEVICE).unsqueeze(0).unsqueeze(0)
        masks = torch.tensor(label_map).to(DEVICE).unsqueeze(0)

        outputs = model(inputs)

        dice_loss_sum += DiceLoss(outputs, masks).item()
        loss_count += 1

    print(dice_loss_sum / loss_count)

In [None]:
# Columns of the final plot: input slice, reference mask, prediction, difference overlay.
images = [[], [], [], []]

model.eval()

# Inference only: no_grad avoids building autograd graphs for every image,
# matching the other evaluation cells.
with torch.no_grad():
    # Never index past the end of the dataset if it holds fewer than
    # VALID_IMAGE_COUNT samples.
    for i in range(min(VALID_IMAGE_COUNT, len(datasets['valid_images']))):
        image, mask = datasets['valid_images'][i]

        # A single image needs both a batch axis and a channel axis.
        tensr = torch.tensor(preprocess(image)).to(DEVICE).unsqueeze(0).unsqueeze(0)

        output = model(tensr)

        output = torch.argmax(output, dim = 1)

        prediction = output.squeeze(0).cpu().detach().numpy()

        image_fixed = image.astype('int64')
        mask_fixed = fix_image_colors(mask).astype('int64')
        prediction_fixed = fix_image_colors(prediction).astype('int64')
        difference = get_difference(prediction, mask)

        images[0].append(image_fixed)
        images[1].append(mask_fixed)
        images[2].append(prediction_fixed)
        images[3].append(difference['diff'])

show_all(images, ['Dicom', 'Mask', 'Prediction', 'Difference'])