# Import modules

In [110]:
import logging
import os
from os.path import splitext
from os import listdir
import sys
import scipy.io
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, ConcatDataset
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from glob import glob
from PIL import Image, ImageFilter, ImageEnhance
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from IPython.display import FileLink

# Unet Model

In [111]:
""" Parts of the U-Net model """

class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages; height and width are preserved.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two stages; any falsy value
            (None or 0) means ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        # Each conv is Xavier-initialised right after it is created, so the
        # random-number consumption order is: conv1, init1, conv2, init2.
        first = nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1)
        nn.init.xavier_uniform_(first.weight)
        second = nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1)
        nn.init.xavier_uniform_(second.weight)

        # The convs are exposed as ``conv1``/``conv2`` and also live inside
        # ``double_conv``; these attribute names form the state_dict keys that
        # the checkpoints loaded with load_state_dict depend on.
        self.conv1 = first
        self.conv2 = second
        self.double_conv = nn.Sequential(
            first,
            #nn.Dropout(p=0.10),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            second,
            #nn.Dropout(p=0.10),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Halve height and width with 2x2 max-pooling, then apply a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # The attribute name and layer order give the state_dict keys
        # (maxpool_conv.1.*) expected by saved checkpoints.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        out = self.maxpool_conv(x)
        return out


class Up(nn.Module):
    """Upsample, concatenate the skip connection, then apply a DoubleConv.

    With ``bilinear=True`` upsampling is a parameter-free bilinear resize and
    the DoubleConv uses ``in_channels // 2`` intermediate channels. Otherwise a
    2x2, stride-2 transposed convolution upsamples and halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels , half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
            return
        # Bilinear upsampling has no weights, so the convs reduce the channels.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Tensors are indexed as (N, C, H, W): pad the upsampled map so its
        # H and W match the skip tensor x2 (odd sizes get the extra pixel
        # on the right/bottom).
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        left, top = dx // 2, dy // 2
        upsampled = F.pad(upsampled, [left, dx - left, top, dy - top])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """1x1 convolution mapping feature channels to ``out_channels`` logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        logits = self.conv(x)
        return logits
    
    
""" Full assembly of the parts to form the complete network """

class UNet(nn.Module):
    """Four-level U-Net: encoder (inc, down1-4), decoder (up1-4) and a 1x1 head.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output logit channels.
        bilinear: use bilinear upsampling in the decoder instead of transposed
            convolutions; the bottleneck and decoder widths are then halved.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear
        factor = 2 if bilinear else 1

        # Submodules are created in this order on purpose: it fixes the
        # random initialisation sequence, and the attribute names are the
        # state_dict keys used by the checkpoints loaded elsewhere.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every level's output as a skip connection.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))
        # Decoder: start from the bottleneck and consume the skips deepest first.
        out = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, skips.pop())
        return self.outc(out)

# Data augmentation (transforms)

In [112]:
from skimage.transform import rescale, rotate
from torchvision.transforms import Compose

def transforms(scale=None, angle=None, flip_prob=None):
    """Build the augmentation pipeline applied to (image, mask) pairs.

    Each argument that is not None enables one stage. Enabled stages run in the
    fixed order scale -> rotate -> horizontal flip.
    """
    stages = [
        (scale, Scale),
        (angle, Rotate),
        (flip_prob, HorizontalFlip),
    ]
    return Compose([make(arg) for arg, make in stages if arg is not None])


class Scale(object):
    """Randomly zoom an (image, mask) pair in or out while keeping its size.

    A zoom factor is drawn uniformly from ``[1 - scale, 1 + scale]``. Zoomed-out
    results are zero-padded back to the original size and zoomed-in results are
    centre-cropped. Both arrays are used as (H, W, C) with H == W, since only
    ``image.shape[0]`` is taken as the target size. The mask is resampled with
    nearest neighbour (``order=0``) so it stays binary.
    """

    def __init__(self, scale):
        self.scale = scale

    @staticmethod
    def _rescale_hwc(array, factor, **kwargs):
        """Rescale the two spatial axes of an HWC array, leaving channels untouched.

        scikit-image 0.19 replaced ``multichannel=True`` with ``channel_axis=-1``
        and later releases no longer accept ``multichannel``. Try the new keyword
        first and fall back for older installations.
        """
        try:
            return rescale(array, (factor, factor), channel_axis=-1, **kwargs)
        except TypeError:
            return rescale(array, (factor, factor), multichannel=True, **kwargs)

    def __call__(self, sample):
        image, mask = sample

        img_size = image.shape[0]

        scale = np.random.uniform(low=1.0 - self.scale, high=1.0 + self.scale)

        image = self._rescale_hwc(
            image,
            scale,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )
        mask = self._rescale_hwc(
            mask,
            scale,
            order=0,
            preserve_range=True,
            mode="constant",
            anti_aliasing=False,
        )

        if scale < 1.0:
            # Zoomed out: pad back to img_size (the extra pixel of an odd
            # difference goes to the bottom/right).
            diff = (img_size - image.shape[0]) / 2.0
            padding = ((int(np.floor(diff)), int(np.ceil(diff))),) * 2 + ((0, 0),)
            image = np.pad(image, padding, mode="constant", constant_values=0)
            mask = np.pad(mask, padding, mode="constant", constant_values=0)
        else:
            # Zoomed in: take the central img_size x img_size window.
            x_min = (image.shape[0] - img_size) // 2
            x_max = x_min + img_size
            image = image[x_min:x_max, x_min:x_max, ...]
            mask = mask[x_min:x_max, x_min:x_max, ...]

        return image, mask


class Rotate(object):
    """Rotate an (image, mask) pair by one shared random angle.

    The angle in degrees is drawn uniformly from ``[-angle, angle]``. The
    canvas size is kept (``resize=False``) and uncovered corners are filled in
    constant mode. The mask uses nearest-neighbour interpolation (``order=0``).
    """

    def __init__(self, angle):
        self.angle = angle

    def __call__(self, sample):
        image, mask = sample
        theta = np.random.uniform(low=-self.angle, high=self.angle)
        common = dict(resize=False, preserve_range=True, mode="constant")
        rotated_image = rotate(image, theta, **common)
        rotated_mask = rotate(mask, theta, order=0, **common)
        return rotated_image, rotated_mask


class HorizontalFlip(object):
    """Mirror an (image, mask) pair left-right with probability ``flip_prob``.

    Flipped outputs are fresh contiguous copies; unflipped inputs are returned
    as the same objects.
    """

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, sample):
        image, mask = sample
        keep = np.random.rand() > self.flip_prob
        if keep:
            return image, mask
        return np.fliplr(image).copy(), np.fliplr(mask).copy()
    

# Dataloader

In [113]:
class BasicDataset(Dataset):
    """Paired image/mask dataset, split into train/validation folds by patient id.

    Samples are listed from ``masks_dir``. For every mask
    ``<id><mask_suffix>.<ext>`` there must be exactly one image ``<id>.<ext>``
    in ``imgs_dir``. The leading characters of each mask file name are compared
    with the zero-padded (``'{:03}'``) ids in ``val_ids`` to decide whether the
    sample is in the validation fold (``isval=True``) or the training fold.

    Args:
        imgs_dir, masks_dir: directory prefixes. They are concatenated with
            file names directly, so they must end with a path separator.
        scale: must be in (0, 1]; kept for API compatibility but not used by
            ``preprocess`` (inputs are always resized to 256x256).
        enh_factor: if not None, strength of random contrast/sharpness changes.
        isval: select the validation fold instead of the training fold.
        val_ids: patient ids of the validation fold (default: 1..10).
        isEDES: ``'ED'`` keeps ids whose characters [3:5] are ``'01'``,
            ``'ES'`` keeps the others, and any other value keeps everything.
        mask_suffix: extra text between the id and the extension of masks.
        transform: optional callable applied to an (HWC image, HWC mask) pair.
    """

    def __init__(self, imgs_dir, masks_dir, scale=1, enh_factor=None, isval=False, val_ids=None, isEDES=None, mask_suffix='', transform=None):
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.scale = scale
        self.enh_factor = enh_factor
        self.mask_suffix = mask_suffix
        self.transform = transform
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'

        # None instead of a shared mutable list default; same default fold.
        if val_ids is None:
            val_ids = range(1, 11)
        val_fid_list = ['{0:03}'.format(ele) for ele in val_ids]
        # All names are compared on the width of the first formatted id.
        # An empty fold matches nothing, so every file goes to training.
        lenval = len(val_fid_list[0]) if val_fid_list else 0
        val_fids = set(val_fid_list)

        want_val = bool(isval)
        self.ids = [splitext(file)[0] for file in listdir(masks_dir)
                    if not file.startswith('.')
                    and (file[:lenval] in val_fids) == want_val]
        if isEDES == 'ED':
            self.ids = [x for x in self.ids if x[3:5] == '01']
        if isEDES == 'ES':
            self.ids = [x for x in self.ids if x[3:5] != '01']

        logging.info(f'Creating dataset with {len(self.ids)} examples')

    def __len__(self):
        return len(self.ids)

    @classmethod
    def preprocess(cls, pil_img, pil_mask, scale, enh_factor=None):
        """Resize, optionally enhance, and convert an image/mask pair to CHW arrays.

        ``scale`` is accepted for API compatibility but ignored: both inputs
        are always resized to 256x256. The image is divided by 255 when its
        maximum exceeds 1. The mask is binarised to 1.0 where a pixel is > 0.

        Returns:
            ``(img, mask)`` as numpy arrays in (C, H, W) layout.
        """
        size = (256, 256)
        pil_img = pil_img.resize(size)
        # Nearest-neighbour keeps label edges sharp. An interpolating filter
        # (Pillow's default for 'L' images) creates small positive values
        # around the object that the ``> 0`` test below would turn into a
        # dilated mask.
        pil_mask = pil_mask.resize(size, Image.NEAREST)

        # image enhancement
        if enh_factor is not None:
            enh = np.random.uniform(low=1.0/(1.0 + enh_factor), high=1.0 + enh_factor)
            pil_img = ImageEnhance.Contrast(pil_img).enhance(enh)
            pil_img = ImageEnhance.Sharpness(pil_img).enhance(enh)

        img_nd = np.array(pil_img)
        mask_nd = np.array(pil_mask)

        # Grayscale inputs get an explicit channel axis.
        if len(img_nd.shape) == 2:
            img_nd = np.expand_dims(img_nd, axis=2)
            mask_nd = np.expand_dims(mask_nd, axis=2)

        # HWC to CHW
        img_trans = img_nd.transpose((2, 0, 1))
        mask_trans = mask_nd.transpose((2, 0, 1))

        if img_trans.max() > 1:
            img_trans = img_trans / 255

        mask_trans = 1. * (mask_trans > 0)

        return img_trans, mask_trans

    def __getitem__(self, i):
        """Return ``{'image': FloatTensor(C,H,W), 'mask': FloatTensor(C,H,W)}`` for sample ``i``."""
        idx = self.ids[i]
        mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
        img_file = glob(self.imgs_dir + idx + '.*')

        assert len(mask_file) == 1, \
            f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
        assert len(img_file) == 1, \
            f'Either no image or multiple images found for the ID {idx}: {img_file}'

        # Close the files deterministically (many DataLoader workers).
        with Image.open(mask_file[0]) as pil_mask, Image.open(img_file[0]) as pil_img:
            assert pil_img.size == pil_mask.size, \
                f'Image and mask {idx} should be the same size, but are {pil_img.size} and {pil_mask.size}'
            img, mask = self.preprocess(pil_img, pil_mask, self.scale, self.enh_factor)

        if self.transform is not None:
            # The augmentations work on HWC arrays.
            img, mask = self.transform((img.transpose((1, 2, 0)), mask.transpose((1, 2, 0))))
            img = img.transpose((2, 0, 1))
            mask = mask.transpose((2, 0, 1))

        return {
            'image': torch.from_numpy(img).type(torch.FloatTensor),
            'mask': torch.from_numpy(mask).type(torch.FloatTensor)
        }

# Dice loss

In [114]:
from torch.autograd import Function

class DiceCoeff(object):
    """Dice coefficient of a single prediction/target pair.

    Call as ``DiceCoeff().forward(input, target)``. This used to subclass a
    legacy ``torch.autograd.Function`` with a non-static ``forward``, but in
    this code it is only instantiated and called directly as a metric.
    Instantiating ``Function`` subclasses is deprecated in current PyTorch, so
    it is a plain class now. The score is built from ordinary tensor ops and
    therefore remains differentiable through regular autograd.

    ``forward`` returns ``(2 * <input, target> + eps) / (sum(input) + sum(target) + eps)``
    with ``eps = 1e-4``. It stores ``inter`` and ``union`` on the instance.
    """

    def forward(self, input, target):
        self.saved_tensors = (input, target)
        eps = 0.0001
        # reshape (unlike view) also accepts non-contiguous tensors.
        self.inter = torch.dot(input.reshape(-1), target.reshape(-1))
        self.union = torch.sum(input) + torch.sum(target) + eps

        t = (2 * self.inter.float() + eps) / self.union.float()
        return t

    def backward(self, grad_output):
        """Gradient of the score w.r.t. ``input`` (``None`` for ``target``).

        Uses the tensors and ``inter``/``union`` stored by the last ``forward``.
        """
        input, target = self.saved_tensors
        grad_input = grad_output * 2 * (target * self.union - self.inter) \
                     / (self.union * self.union)
        return grad_input, None


def dice_coeff(input, target):
    """Mean per-sample Dice coefficient over the first (batch) dimension.

    Each ``(input[i], target[i])`` pair is scored with ``DiceCoeff().forward``
    and the scores are averaged over the pairs produced by ``zip``.

    Returns:
        A 1-element float tensor on the same device as ``input``.

    Raises:
        ValueError: if there are no pairs to score.
    """
    # Accumulate on input's device (not just "the" CUDA device) so that
    # multi-GPU setups work too.
    total = torch.zeros(1, device=input.device)
    count = 0
    for count, (pred_i, target_i) in enumerate(zip(input, target), start=1):
        total = total + DiceCoeff().forward(pred_i, target_i)

    if count == 0:
        raise ValueError('dice_coeff: empty batch')
    return total / count

# Evaluation

In [115]:
def enable_dropout(model):
    """Put every dropout submodule of ``model`` back into training mode.

    This is called after ``model.eval()`` by ``eval_net`` (when requested) and
    by ``predict_img``, so dropout stays active at inference time (presumably
    for Monte-Carlo dropout). A submodule counts as dropout when its class
    name starts with ``'Dropout'``.
    """
    dropout_layers = (m for m in model.modules() if type(m).__name__.startswith('Dropout'))
    for layer in dropout_layers:
        layer.train()

def eval_net(net, loader, device, dropout=None):
    """Score ``net`` on every batch of ``loader`` and return the mean over batches.

    For single-class nets the score is the Dice coefficient of the sigmoid
    prediction thresholded at 0.5. For multi-class nets it is the
    cross-entropy loss. The net runs in eval mode, with dropout layers
    re-enabled when ``dropout`` is not None, and is left in train mode
    afterwards.

    Raises:
        ValueError: if ``loader`` yields no batches (e.g. a validation fold
            smaller than the batch size with ``drop_last=True``), instead of
            failing with a ZeroDivisionError.
    """
    n_val = len(loader)  # the number of batches
    if n_val == 0:
        raise ValueError('eval_net: the validation loader is empty')

    net.eval()
    if dropout is not None:
        enable_dropout(net)

    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    tot = 0

    # No gradients are needed anywhere in the evaluation loop.
    with torch.no_grad():
        for batch in loader:
            imgs = batch['image'].to(device=device, dtype=torch.float32)
            true_masks = batch['mask'].to(device=device, dtype=mask_type)

            mask_pred = net(imgs)

            if net.n_classes > 1:
                tot += F.cross_entropy(mask_pred, true_masks).item()
            else:
                pred = (torch.sigmoid(mask_pred) > 0.5).float()
                tot += dice_coeff(pred, true_masks).item()

    net.train()
    return tot / n_val

# Hausdorff distance

In [116]:
import cv2
from scipy.spatial.distance import directed_hausdorff

# def Hdistance(img1, img2):  
    
#     ctrs1, hierarchy = cv2.findContours(np.uint8(img1), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
#     ctrs2, hierarchy = cv2.findContours(np.uint8(img2), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    
#     u = None
#     v = None
    
#     if ctrs1 :
#         for i in range(len(ctrs1)) :
#             if u is None:
#                 u = np.squeeze(ctrs1[0])
#             else:
#                 u = np.concatenate([u, np.squeeze(ctrs1[0])])
#     if ctrs2 :
#         for i in range(len(ctrs2)) :
#             if v is None:
#                 v = np.squeeze(ctrs2[0])
#             else:
#                 v = np.concatenate([v, np.squeeze(ctrs2[0])])
                
#     distance = directed_hausdorff(u,v)[0]
    
#     return distance

def Hdistance(img1, img2):
    """Symmetric Hausdorff distance between the non-zero pixels of two masks.

    Each mask is turned into an array of the index coordinates of its non-zero
    entries. The result is the larger of the two directed Hausdorff
    distances, in pixel units.
    """
    pts1 = np.argwhere(img1)
    pts2 = np.argwhere(img2)
    forward = directed_hausdorff(pts1, pts2)[0]
    backward = directed_hausdorff(pts2, pts1)[0]
    return max(forward, backward)

# Train-val split by patient number

In [117]:
import random

# Shuffle the patient ids 1..100 with a fixed seed and cut them into five
# disjoint validation folds of 20 patients each (5-fold cross-validation).
ls = list(range(1, 101))
random.seed(a=30)
random.shuffle(ls)
fold_size = 20
val_ids = [ls[k * fold_size:(k + 1) * fold_size] for k in range(5)]
for fold in val_ids:
    print(fold)

[85, 39, 24, 3, 78, 43, 66, 88, 95, 71, 25, 91, 48, 72, 87, 14, 81, 58, 46, 73]
[56, 35, 23, 26, 19, 28, 30, 75, 41, 99, 62, 92, 90, 5, 37, 44, 65, 29, 47, 22]
[31, 53, 42, 34, 12, 59, 86, 94, 2, 50, 8, 96, 76, 16, 63, 6, 93, 17, 61, 57]
[20, 54, 74, 40, 64, 15, 100, 13, 55, 36, 89, 82, 98, 9, 69, 45, 52, 68, 77, 21]
[10, 97, 32, 67, 1, 60, 11, 18, 83, 49, 51, 7, 33, 27, 84, 80, 4, 79, 38, 70]


# Train net

In [118]:
def _ensure_checkpoint_dir(dir_checkpoint):
    """Create ``dir_checkpoint`` (including parents) if it does not exist yet."""
    if not os.path.isdir(dir_checkpoint):
        os.makedirs(dir_checkpoint, exist_ok=True)
        logging.info('Created checkpoint directory')


def train_net(dir_img, dir_mask, dir_checkpoint,
              net, device, epochs=5, batch_size=1, lr=0.001,
              val_percent=0.1, save_cp=True,
              img_scale=1, val_set = 0):
    """Train ``net`` on one cross-validation fold, logging to TensorBoard.

    The training set is the non-validation part of ``dir_img``/``dir_mask``
    concatenated with five randomly augmented copies of it. The augmentations
    are contrast/sharpness enhancement plus scale/rotate/flip. The validation
    fold is the module-level ``val_ids[val_set]``.

    The model is evaluated roughly ten times per epoch. Whenever the
    validation score improves (higher Dice for single-class nets, lower
    cross-entropy for multi-class nets), the weights are written to
    ``<dir_checkpoint>/maxValScore.pth``. With ``save_cp`` the weights are also
    written after each epoch as ``CP_epoch<N>.pth``. The (progress, score)
    history is saved to ``valScores_batchSize<batch_size>.csv``.

    Args:
        val_percent: unused; kept for backward compatibility (the split is by
            patient id through ``val_set``).
    """
    train = BasicDataset(dir_img, dir_mask, img_scale, isval=False, val_ids=val_ids[val_set])
    # Five augmented copies of the training fold: 6x the data per epoch.
    for _ in range(5):
        train_tf = BasicDataset(dir_img, dir_mask, img_scale, isval=False, val_ids=val_ids[val_set], enh_factor=1., transform=transforms(scale=0.1,angle=15, flip_prob=0.5))
        train = ConcatDataset([train,train_tf])

    val = BasicDataset(dir_img, dir_mask, img_scale, isval=True, val_ids=val_ids[val_set])
    n_train = len(train)
    n_val = len(val)

    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    print(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    # Multi-class nets are scored by cross-entropy (lower is better),
    # single-class nets by Dice (higher is better).
    lower_is_better = net.n_classes > 1

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if lower_is_better else 'max', patience=2)
    criterion = nn.CrossEntropyLoss() if lower_is_better else nn.BCEWithLogitsLoss()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long

    # validation history and best score so far
    progresses = []
    valScores = []
    val_score_max = float('inf') if lower_is_better else -1
    # Evaluate about ten times per epoch. max() prevents a modulo-by-zero
    # when the training set holds fewer than 10 batches.
    eval_every = max(1, n_train // (10 * batch_size))

    try:
        for epoch in range(epochs):
            net.train()

            epoch_loss = 0
            with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image']
                    true_masks = batch['mask']
                    assert imgs.shape[1] == net.n_channels, \
                        f'Network has been defined with {net.n_channels} input channels, ' \
                        f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                    imgs = imgs.to(device=device, dtype=torch.float32)
                    true_masks = true_masks.to(device=device, dtype=mask_type)

                    masks_pred = net(imgs)
                    loss = criterion(masks_pred, true_masks)
                    epoch_loss += loss.item()
                    writer.add_scalar('Loss/train', loss.item(), global_step)

                    pbar.set_postfix(**{'loss (batch)': loss.item()})

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(net.parameters(), 0.1)
                    optimizer.step()

                    pbar.update(imgs.shape[0])
                    global_step += 1
                    if global_step % eval_every != 0:
                        continue

                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        # Parameters that took no part in the loss have no grad.
                        if value.grad is not None:
                            writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
                    progress = 1 + global_step * batch_size / n_train
                    progresses.append(progress)
                    valScores.append(val_score)
                    if lower_is_better:
                        print('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        print('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)

                    improved = (val_score < val_score_max) if lower_is_better else (val_score > val_score_max)
                    if improved:
                        val_score_max = val_score
                        print(val_score_max)
                        _ensure_checkpoint_dir(dir_checkpoint)
                        torch.save(net.state_dict(),
                                   os.path.join(dir_checkpoint, 'maxValScore.pth'))
                        logging.info(f'maxValScore{val_score_max} saved !')

            if save_cp:
                _ensure_checkpoint_dir(dir_checkpoint)
                torch.save(net.state_dict(),
                           os.path.join(dir_checkpoint, f'CP_epoch{epoch + 1}.pth'))
                logging.info(f'Checkpoint {epoch + 1} saved !')
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()

    print(progresses,valScores)
    df = pd.DataFrame({'progress':progresses,'val_score': valScores})
    df.to_csv(f'valScores_batchSize{batch_size}.csv')

# Localized cropping

In [119]:
from torchvision import transforms as trans
from PIL import ImageOps
import shutil

def locCrop(img, mask, alpha=1):
    """Crop a square window centred on the foreground of ``mask``.

    The centre is the centre of the mask's bounding box. The half-size is
    interpolated between two values:

    * ``figsize_min``: twice the larger half-extent of the bounding box.
    * ``figsize_max``: the largest half-size that stays inside the array.

    ``alpha=0`` gives the smaller of the two and ``alpha=1`` the larger. The
    result is clamped to ``[0, figsize_max]`` so the window never crosses the
    array border. Without the clamp, ``alpha > 0`` with
    ``figsize_min > figsize_max`` produced negative slice starts and a
    wrapped, off-centre crop.

    Args:
        img: array whose first two axes match ``mask``.
        mask: 2-D array; non-zero entries are foreground.
        alpha: interpolation weight between the two sizes.

    Returns:
        ``(img_cropped, mask_cropped, center, figsize)``. Both crops have side
        ``2 * figsize + 1``, ``center`` is ``[row, col]`` and ``figsize`` is the
        half-size used.

    Raises:
        ValueError: if ``mask`` has no non-zero pixels.
    """
    mask = np.asarray(mask)
    n_rows, n_cols = mask.shape[0], mask.shape[1]

    rows = np.flatnonzero(np.sum(mask, axis=1))
    cols = np.flatnonzero(np.sum(mask, axis=0))
    if rows.size == 0 or cols.size == 0:
        raise ValueError('locCrop: mask has no foreground pixels')
    up, down = int(rows[0]), int(rows[-1])
    left, right = int(cols[0]), int(cols[-1])

    center = [(up + down) // 2, (left + right) // 2]

    figsize_max = min(center[0], n_rows - 1 - center[0], center[1], n_cols - 1 - center[1])
    figsize_min = int(2 * max(abs(down - up) // 2, abs(right - left) // 2))
    figsize = int(min(figsize_min, figsize_max) + alpha * abs(figsize_max - figsize_min))
    figsize = max(0, min(figsize, figsize_max))

    r0, r1 = center[0] - figsize, center[0] + figsize + 1
    c0, c1 = center[1] - figsize, center[1] + figsize + 1
    img_cropped = img[r0:r1, c0:c1]
    mask_cropped = mask[r0:r1, c0:c1]

    return img_cropped, mask_cropped, center, figsize

def preCrop(img):
    """Centre-crop ``img`` to a square along its longer axis.

    Returns ``(cropped, shape)``, where ``shape`` is the original ``img.shape``.
    Square inputs are returned unchanged. Exactly ``min(H, W)`` rows (or
    columns) are kept, starting at ``abs(H - W) // 2``. The earlier
    ``img[delta:-delta]`` slicing left the result one pixel non-square when
    ``H - W`` was odd, and did not crop at all when the difference was 1.
    """
    shape = img.shape
    height, width = shape[0], shape[1]

    if height == width:
        return img, shape

    start = abs(height - width) // 2
    if height > width:
        return img[start:start + width], shape
    return img[:, start:start + height], shape

# Predict

In [120]:
def predict_img(net,
                full_img,
                full_gt,
                device,
                scale_factor=1,  # should be consistent with the scale used in training
                out_threshold=0.5):
    """Predict the mask for a single image.

    The image/ground-truth pair goes through ``BasicDataset.preprocess``
    (resize to 256x256, [0, 1] scaling, mask binarisation) before the forward
    pass. Dropout layers, if any, stay active (``enable_dropout``).

    Args:
        net: segmentation network with an ``n_classes`` attribute.
        full_img: PIL image to segment.
        full_gt: PIL ground-truth mask for ``full_img``.
        device: torch device to run on.
        scale_factor: forwarded to ``BasicDataset.preprocess``.
        out_threshold: probability threshold for foreground.

    Returns:
        ``(mask, dc)``. ``mask`` is a boolean numpy array of the probability map
        resized back to ``full_img``'s (height, width) and thresholded. ``dc``
        is the Dice coefficient (scalar tensor) of the thresholded 256x256
        prediction against the preprocessed ground truth.
    """
    net.eval()
    enable_dropout(net)

    img, gt = BasicDataset.preprocess(full_img, full_gt, scale_factor)
    img = torch.from_numpy(img).unsqueeze(0).to(device=device, dtype=torch.float32)
    gt = torch.from_numpy(gt).unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        dc = DiceCoeff().forward((probs > out_threshold).float(), gt)

        # PIL sizes are (width, height), while torchvision's Resize takes
        # (height, width). A single int would only set the shorter edge,
        # giving the wrong shape for images taller than wide.
        height, width = full_img.size[1], full_img.size[0]
        tf = trans.Compose(
            [
                trans.ToPILImage(),
                trans.Resize((height, width)),
                trans.ToTensor()
            ]
        )

        probs = tf(probs.squeeze(0).cpu())
        full_mask = probs.squeeze().cpu().numpy()

    return full_mask > out_threshold, dc

# Main code starts here


1. Pre-cropping to square, resizing to 256x256, and removal of empty masks

In [121]:
# Experiment configuration. Keys of `target` are the label values compared
# against the mask pixels below; values name the structure.
target = {1: 'RV', 2: 'MYO', 3: 'LV'}
whichclass = 2
val_set = 4
epochs = 4

imgs_dir = '../input/acdc-uncropped/imgs_train/'
imgs_dir_out = 'imgs_precropped/'

# Start from an empty output directory.
if os.path.isdir(imgs_dir_out):
    shutil.rmtree(imgs_dir_out)
os.mkdir(imgs_dir_out)

imgs = [name for name in listdir(imgs_dir) if not name.startswith('.')]
for img in imgs:
    square, shape = preCrop(np.array(Image.open(imgs_dir + img)))
    # The original (H, W) is kept in the file name (presumably so that
    # predictions can be mapped back to the uncropped image).
    suffix = '_{0:03}_{1:03}_.png'.format(shape[0], shape[1])
    resized = Image.fromarray(square).resize((256, 256))
    ImageOps.grayscale(resized).save(imgs_dir_out + img.replace('.png', suffix))
print(f'Total number of images for training is: {len(os.listdir(imgs_dir_out))}')

      

# One binary-mask directory per structure. Only images where the structure
# is present are kept; they are cropped like the images and resized to 256x256.
masks_dir = imgs_dir.replace('imgs', 'masks')
masks = [file for file in listdir(masks_dir) if not file.startswith('.')]
for wc, structure in target.items():
    masks_dir_out = structure + '_precropped/'
    if os.path.isdir(masks_dir_out):
        shutil.rmtree(masks_dir_out)
    os.mkdir(masks_dir_out)
    for mask in masks:
        masktg = np.array(Image.open(masks_dir + mask)) == wc
        if not masktg.any():
            continue
        mask_cropped, shape = preCrop(masktg)
        addname = '_{0:03}_{1:03}_.png'.format(shape[0], shape[1])
        mask_gray = ImageOps.grayscale(Image.fromarray(mask_cropped).resize((256, 256)))
        mask_gray.save(masks_dir_out + mask.replace('.png', addname))
    print(f'Number of nonempty masks for {target[wc]} is: {len(os.listdir(masks_dir_out))}')

Total number of images for training is: 1902
Number of nonempty masks for RV is: 1558
Number of nonempty masks for MYO is: 1828
Number of nonempty masks for LV is: 1808


2. First training using data without localized cropping

In [None]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
#   - For 1 class and background, use n_classes=1
#   - For 2 classes, use n_classes=1
#   - For N > 2 classes, use n_classes=N
net = UNet(n_channels=1, n_classes=1, bilinear=True)
print(f'Network:\n'
      f'\t{net.n_channels} input channels\n'
      f'\t{net.n_classes} output channels (classes)\n'
      f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

# Warm-start from the pretrained uncropped model for the selected structure.
# map_location lets weights saved on a GPU load on a CPU-only machine as well.
net.load_state_dict(torch.load('../input/uncropped-models/uncropped_' + target[whichclass] + '.pth',
                               map_location=device))
#net.load_state_dict(torch.load('checkpoints/maxValScore.pth', map_location=device))
net.to(device=device)

# Train on the precropped images and the masks of the selected structure.
dir_img = imgs_dir_out
dir_mask = dir_img.replace('imgs', target[whichclass])
dir_checkpoint = 'checkpoints/'

train_net(dir_img, dir_mask, dir_checkpoint,
          net=net, device=device,
          lr=0.0005, epochs=epochs,
          batch_size=8, img_scale=1, val_set=val_set)

FileLink(r'checkpoints/maxValScore.pth')
#FileLink(r'checkpoints/uncropped_' + target[whichclass] + '_fd' + str(val_set) + '.pth')
# Plot dice coefficient vs. progress:
# val_scores = pd.read_csv(f'valScores_batchSize8.csv')
# fig, ax = plt.subplots(1, 1, figsize=(4,2))
# ax.plot(val_scores['progress'],val_scores['val_score'])
# ax.set_xlabel('progress')
# ax.set_ylabel('accuracy')
# plt.show()

Using device cuda


Epoch 1/4:   0%|          | 0/8916 [00:00<?, ?img/s]

Network:
	1 input channels
	1 output channels (classes)
	Bilinear upscaling
Starting training:
        Epochs:          4
        Batch size:      8
        Learning rate:   0.0005
        Training size:   8916
        Validation size: 342
        Checkpoints:     True
        Device:          cuda
        Images scaling:  1
    


Epoch 1/4:  10%|▉         | 888/8916 [00:35<04:40, 28.57img/s, loss (batch)=0.0069] 

Validation Dice Coeff: 0.8736981664385114
0.8736981664385114


Epoch 1/4:  20%|█▉        | 1776/8916 [01:14<03:59, 29.87img/s, loss (batch)=0.0105] 

Validation Dice Coeff: 0.8753210944788796
0.8753210944788796


Epoch 1/4:  30%|██▉       | 2664/8916 [01:53<03:35, 29.05img/s, loss (batch)=0.00963]

Validation Dice Coeff: 0.8764803877898625
0.8764803877898625


Epoch 1/4:  40%|███▉      | 3552/8916 [02:34<03:13, 27.67img/s, loss (batch)=0.00965]

Validation Dice Coeff: 0.8692870665164221


Epoch 1/4:  50%|████▉     | 4440/8916 [03:13<02:31, 29.54img/s, loss (batch)=0.00803]

Validation Dice Coeff: 0.8790493579137892
0.8790493579137892


Epoch 1/4:  51%|█████▏    | 4584/8916 [03:26<02:27, 29.27img/s, loss (batch)=0.00798]

3. Predict masks with the model trained above and use them for localized cropping

In [None]:
net = UNet(n_channels=1, n_classes=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# map_location lets GPU-saved weights load on a CPU-only machine.
net.load_state_dict(torch.load('checkpoints/maxValScore.pth', map_location=device))
#net.load_state_dict(torch.load('../input/uncropped-models/uncropped_MYO.pth', map_location=device))
net.to(device=device)

path_imgs_out='imgs_loccropped/'
path_masks_out=path_imgs_out.replace('imgs',target[whichclass])

# Start from empty output directories.
if os.path.isdir(path_imgs_out):
    shutil.rmtree(path_imgs_out)
if os.path.isdir(path_masks_out):
    shutil.rmtree(path_masks_out)

os.mkdir(path_imgs_out)
os.mkdir(path_masks_out)

# dir_mask only holds masks in which the selected structure is present.
imgs = os.listdir(dir_mask)
N = len(imgs)
count = 0
dcs = np.zeros([N,1])
# Report about 10 times; max() avoids a modulo-by-zero when N < 10.
report_every = max(1, N // 10)

for ii in range(N):
    if ii % report_every == 0:
        print(f'progress: {ii}/{N}')
        print(f'{count}/{N} are predicted wrong.')
    try:
        img = Image.open(dir_img + imgs[ii])
        mask = Image.open(dir_mask + imgs[ii])
        pred, dc = predict_img(net=net, full_img=img, full_gt=mask, device=device, scale_factor=1)
        dcs[ii] = dc.item()
        # Crop image and ground truth around the *predicted* mask.
        img, _, _, _ = locCrop(np.array(img), pred, alpha=0)
        mask, _, center, figsize = locCrop(np.array(mask), pred, alpha=0)
        cropinfo = '_{0:03}'.format(center[0]) + '_{0:03}'.format(center[1]) + '_{0:03}_'.format(figsize)
        oldname = imgs[ii].split('_')
        newname = oldname[0]+cropinfo+'.png'
        img_gray = ImageOps.grayscale(Image.fromarray(img).resize((256,256)))
        img_gray.save(path_imgs_out + newname)
        mask_gray = ImageOps.grayscale(Image.fromarray(mask).resize((256,256)))
        mask_gray.save(path_masks_out + newname)
    except Exception:
        # Best effort: e.g. an empty predicted mask makes locCrop fail. Count
        # the sample and continue; a bare except would also swallow
        # KeyboardInterrupt/SystemExit.
        count += 1

files = os.listdir(path_masks_out)
print('Number of files in ' + path_masks_out + ' is ' + str(len(files)))
# fig, ax = plt.subplots(1, 1, figsize=(8,4))
# ax.hist(dcs)
# ax.set_xlabel('dice coefficient')
# ax.set_ylabel('counts')
# print(np.mean(dcs))

# files = os.listdir(path_imgs_out)
# img = Image.open(path_imgs_out+files[0])
# mask = Image.open(path_masks_out+files[0])
# fig, ax = plt.subplots(1,2,figsize=(10,8))
# ax[0].imshow(img)
# ax[1].imshow(mask)
# plt.show()
# print(files[0])

4. Second training using the dataset after localized cropping

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=1, n_classes=1, bilinear=True)
print(f'Network:\n'
      f'\t{net.n_channels} input channels\n'
      f'\t{net.n_classes} output channels (classes)\n'
      f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

# Continue from the best model of the first (uncropped) training.
# map_location lets GPU-saved weights load on a CPU-only machine.
net.load_state_dict(torch.load('checkpoints/maxValScore.pth', map_location=device))
#net.load_state_dict(torch.load('../input/cropped-models/cropped_MYO_fd1.pth', map_location=device))
#net.load_state_dict(torch.load('checkpoints_loc/maxValScore.pth', map_location=device))
net.to(device=device)

train_net(path_imgs_out, path_masks_out, 'checkpoints_loc/',
          net=net,device=device,
          lr=0.0005,epochs=epochs,
          batch_size=8,img_scale=1, val_set=val_set)

#FileLink(r'checkpoints_loc/cropped_' + target[whichclass] + '_fd' + str(val_set) + '.pth')
FileLink(r'checkpoints_loc/maxValScore.pth')

# Figure setup

In [None]:
# Start from matplotlib's 'classic' style, then apply the notebook's overrides.
plt.style.use('classic')
#mpl.rcParams['text.usetex'] = True
# One entry per setting. 'lines.linewidth' is 2, the value that was in effect;
# a duplicate, earlier assignment of 4 was always overwritten.
mpl.rcParams.update({
    'font.size': 16,
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'axes.titlesize': 20,
    'axes.labelsize': 20,
    'lines.linewidth': 2,
    'lines.markersize': 10,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.loc': 'lower right',
})

In [None]:
# progress1 = [[1.099488609948861, 1.198977219897722, 1.298465829846583, 1.397954439795444, 1.497443049744305, 1.596931659693166, 1.696420269642027, 1.795908879590888, 1.895397489539749, 1.99488609948861, 2.0943747094374707, 2.193863319386332, 2.293351929335193, 2.392840539284054, 2.4923291492329147, 2.591817759181776, 2.691306369130637, 2.790794979079498, 2.8902835890283587, 2.98977219897722, 3.0892608089260807, 3.188749418874942, 3.2882380288238027, 3.387726638772664, 3.4872152487215247, 3.586703858670386, 3.6861924686192467, 3.785681078568108, 3.8851696885169686, 3.98465829846583, 4.084146908414691, 4.183635518363552, 4.283124128312412, 4.382612738261274, 4.482101348210135, 4.581589958158996, 4.681078568107857, 4.780567178056717, 4.880055788005579, 4.97954439795444, 5.079033007903301, 5.178521617852161, 5.278010227801023, 5.377498837749884, 5.476987447698745, 5.576476057647605, 5.675964667596467, 5.775453277545328, 5.874941887494189, 5.974430497443049],
#              [1.0994475138121547, 1.1988950276243093, 1.298342541436464, 1.3977900552486187, 1.4972375690607735, 1.5966850828729282, 1.6961325966850829, 1.7955801104972375, 1.8950276243093924, 1.994475138121547, 2.0939226519337018, 2.1933701657458564, 2.292817679558011, 2.3922651933701657, 2.4917127071823204, 2.591160220994475, 2.6906077348066297, 2.790055248618785, 2.889502762430939, 2.988950276243094, 3.088397790055249, 3.1878453038674035, 3.287292817679558, 3.386740331491713, 3.4861878453038675, 3.585635359116022, 3.685082872928177, 3.7845303867403315, 3.883977900552486, 3.983425414364641]]
# score1 = [[0.8636284932798269, 0.8776765465736389, 0.867186861378806, 0.8588063607410509, 0.8804032425491177, 0.8780016388211932, 0.8665912090515604, 0.8764826971657422, 0.8777249954184707, 0.8810944180099332, 0.8792837590587382, 0.8785797990098292, 0.8796050049820725, 0.880178516008416, 0.8804260589638535, 0.8799173612983859, 0.8800301795103112, 0.879841651235308, 0.879574387657399, 0.8794345259666443, 0.8798628790037972, 0.8794224967761916, 0.8804717745099749, 0.8793627996833957, 0.8799176933814068, 0.8795922751329384, 0.8800389085497174, 0.8797656504475341, 0.8800505424032405, 0.8803096681225057, 0.8802962461296393, 0.8800350622254975, 0.8801655684198652, 0.8802403321071547, 0.879528786454882, 0.880151783933445, 0.8792021359716143, 0.880732003523379, 0.8798004619929255, 0.8800493928850913, 0.8796448829222698, 0.8792653436563453, 0.8802097172153239, 0.8803299020747749, 0.8794541772531004, 0.8802619345334112, 0.8802654049834426, 0.8796908989244577, 0.879714268810895, 0.8802875742620352],
#           [0.8948776214680774, 0.8968009986775987, 0.8937028912787742, 0.8965728282928467, 0.8907774610722319, 0.8992152239413972, 0.9011993408203125, 0.9013569976421113, 0.9027111162530616, 0.9014488331815029, 0.9026887150520974, 0.9015849296082842, 0.9016351344737601, 0.901884138584137, 0.9022486501551689, 0.9025593564865437, 0.901854322311726, 0.9021950224612622, 0.9021940015731974, 0.9022992621076867, 0.9020021405625851, 0.9020975447715597, 0.9021984148532787, 0.901768648878057, 0.9022896759053494, 0.9021218723439156, 0.9024313206368304, 0.9019861728587049, 0.9026742843871421, 0.9022086047111674]]
# progress2 =[[1.099488609948861, 1.198977219897722, 1.298465829846583, 1.397954439795444, 1.497443049744305, 1.596931659693166, 1.696420269642027, 1.795908879590888, 1.895397489539749, 1.99488609948861, 2.0943747094374707, 2.193863319386332, 2.293351929335193, 2.392840539284054, 2.4923291492329147, 2.591817759181776, 2.691306369130637, 2.790794979079498, 2.8902835890283587, 2.98977219897722, 3.0892608089260807, 3.188749418874942, 3.2882380288238027, 3.387726638772664, 3.4872152487215247, 3.586703858670386, 3.6861924686192467, 3.785681078568108, 3.8851696885169686, 3.98465829846583, 4.084146908414691, 4.183635518363552, 4.283124128312412, 4.382612738261274, 4.482101348210135, 4.581589958158996, 4.681078568107857, 4.780567178056717, 4.880055788005579, 4.97954439795444, 5.079033007903301, 5.178521617852161, 5.278010227801023, 5.377498837749884, 5.476987447698745, 5.576476057647605, 5.675964667596467, 5.775453277545328, 5.874941887494189, 5.974430497443049],
#             [1.0994475138121547, 1.1988950276243093, 1.298342541436464, 1.3977900552486187, 1.4972375690607735, 1.5966850828729282, 1.6961325966850829, 1.7955801104972375, 1.8950276243093924, 1.994475138121547, 2.0939226519337018, 2.1933701657458564, 2.292817679558011, 2.3922651933701657, 2.4917127071823204, 2.591160220994475, 2.6906077348066297, 2.790055248618785, 2.889502762430939, 2.988950276243094, 3.088397790055249, 3.1878453038674035, 3.287292817679558, 3.386740331491713, 3.4861878453038675, 3.585635359116022, 3.685082872928177, 3.7845303867403315, 3.883977900552486, 3.983425414364641]]
# score2 = [[0.8649396592256974, 0.8728411818037227, 0.8732049161074112, 0.8804709072015724, 0.8794895130760816, 0.8808796880196552, 0.884669829388054, 0.8837418604870232, 0.8867632454755355, 0.8857988946291865, 0.8855989137474372, 0.8830434898940884, 0.890810988387283, 0.8895560223229078, 0.8910183286180302, 0.8905407895847243, 0.8900842362520646, 0.8908397166096435, 0.8910987717764718, 0.8912913993913301, 0.8910834752783483, 0.890699028968811, 0.8916924547175972, 0.891591058701885, 0.8916883079373107, 0.8913915181646541, 0.8913232647642797, 0.8911220528641526, 0.8914804823544561, 0.8913716734672079, 0.891605401525692, 0.8906243978714456, 0.8910963231203507, 0.8917624050257157, 0.8916468085074911, 0.8914127568809354, 0.8912546111612903, 0.8914024075683282, 0.8914984464645386, 0.8913899891230525, 0.8914608991875941, 0.8910256976983986, 0.8916520379027542, 0.8912246263757044, 0.8907616746668913, 0.8914485561604403, 0.8909434955947253, 0.8913534952669727, 0.8911190908782336, 0.8917179788861956],
#           [0.880016878564307, 0.8894319483574401, 0.8871558004237236, 0.8902061480156919, 0.8923427855714838, 0.892469141077488, 0.8939733429157988, 0.8955392659978664, 0.8940256643802562, 0.8936227915134836, 0.8951473451675253, 0.8998220803889823, 0.8997927003718437, 0.9004315290045231, 0.9002417503519261, 0.9001701448826079, 0.8992868839426243, 0.9001439226434585, 0.900293488451775, 0.9004353348245012, 0.8997154007566736, 0.8999213170497975, 0.9004183056506705, 0.900638240448972, 0.8999398591670584, 0.900075170587986, 0.9004230156857916, 0.9000258991058837, 0.9003503373328675, 0.9000945928248953]]

# Logged validation scores of two runs. The commented-out plot below labels
# score1 as 'M' and score2 as 'Mloc'. Each score list holds one inner list per
# validation fold (selected with score*[val_set]).
# Both runs share the same x-axis ("progress") values, so progress2 is a copy.
progress1 = [1.099488609948861, 1.198977219897722, 1.298465829846583, 1.397954439795444, 1.497443049744305, 1.596931659693166, 1.696420269642027, 1.795908879590888, 1.895397489539749, 1.99488609948861, 2.0943747094374707, 2.193863319386332, 2.293351929335193, 2.392840539284054, 2.4923291492329147, 2.591817759181776, 2.691306369130637, 2.790794979079498, 2.8902835890283587, 2.98977219897722, 3.0892608089260807, 3.188749418874942, 3.2882380288238027, 3.387726638772664, 3.4872152487215247, 3.586703858670386, 3.6861924686192467, 3.785681078568108, 3.8851696885169686, 3.98465829846583, 4.084146908414691, 4.183635518363552, 4.283124128312412, 4.382612738261274, 4.482101348210135, 4.581589958158996, 4.681078568107857, 4.780567178056717, 4.880055788005579, 4.97954439795444]
progress2 = list(progress1)

score1 = [
    [0.855654885574263, 0.8697443859917777, 0.8667325365300081, 0.8631193613519474, 0.8753710693242599, 0.8710850550203907, 0.8718101382255554, 0.8777509319538973, 0.8623332721846444, 0.8721519453184945, 0.8706986064813576, 0.8780898001729226, 0.8759579512537742, 0.8786154523187754, 0.8776276342722834, 0.8794273758421138, 0.8783293804343866, 0.8783224796762272, 0.8776306108552583, 0.8783892794531218, 0.8778246908771749, 0.8780236025245822, 0.8781100815656234, 0.8788006354351433, 0.8787081034816041, 0.8781707323327357, 0.8785026584352765, 0.8775367542188994, 0.8785223693263774, 0.8785665838085875, 0.8788249200704147, 0.8775661916148906, 0.8782593717380446, 0.8783793887313531, 0.8779879978724888, 0.8781258792293315, 0.8784824634084896, 0.8781085379269659, 0.8789805891562481, 0.8789009439701937],
    [0.8922322050054022, 0.8911372651445105, 0.8925718944123451, 0.8986208603737202, 0.8929864358394703, 0.8977598700117557, 0.8927685174536197, 0.8998177749045352, 0.8988627410949545, 0.9008229273430844, 0.8996364360160016, 0.9016737062880333, 0.901095007328277, 0.9013089157165365, 0.9007232214542146, 0.9016745458257959, 0.9009371108197152, 0.9015019066790317, 0.9013351011783519, 0.9006973961566357, 0.9009881831230001, 0.900884507818425, 0.9008524646150305, 0.9007304754663021, 0.9012766645309773, 0.9011016716348365, 0.9011515277497312, 0.9011177927889722, 0.9009236350972601, 0.9012676350613857, 0.9009558736009801, 0.9012234363150089, 0.9010141111434774, 0.9012481367334406, 0.9011525740014746, 0.9014115358920808, 0.9007502464537925, 0.9011424044345288, 0.9011947233626183, 0.9011359722056287],
    [0.8694353971792304, 0.8708973565827245, 0.8680963917918827, 0.8710534870624542, 0.8734917990539385, 0.8752500661041426, 0.867904870406441, 0.8679268735906353, 0.8714740056058635, 0.8729678781136222, 0.8737756532171498, 0.8726318776607513, 0.874153847279756, 0.8751833723938983, 0.8749370950719585, 0.8755343932172527, 0.8743913795637048, 0.8751924621022266, 0.8747939713623213, 0.8751513361930847, 0.8757497162922568, 0.8749043850795083, 0.8750632560771444, 0.8754050926021908, 0.8739369939202848, 0.8756282329559326, 0.8755390203517416, 0.8758951231189396, 0.8749477241350256, 0.8746388748936031, 0.8756955123465994, 0.875213851099429, 0.8756855842859849, 0.8752649076606916, 0.876217369152152, 0.8751561719438304, 0.8750806450843811, 0.8752854738546454, 0.8750817516575689, 0.8758864402770996],
    [0.8878558732214428, 0.883214682340622, 0.8774316069625673, 0.8892701708135151, 0.8916168780553908, 0.8889593552975428, 0.8936431606610616, 0.8904137753304981, 0.8912597837902251, 0.885606206598736, 0.8905723634220305, 0.8897406586578914, 0.8911880822408766, 0.8913135329882304, 0.8916207282316118, 0.8910677674270812, 0.8904387823172978, 0.8909118743169875, 0.8910052563462939, 0.8909169847056979, 0.8909315665562948, 0.8907514455772582, 0.8906930699234917, 0.8909803018683479, 0.8913972292627607, 0.8912939613773709, 0.8917379890169416, 0.8909724851449331, 0.8909255919002351, 0.8911143271696, 0.8907399205934434, 0.8907767207849593, 0.891022115945816, 0.8911514239651817, 0.8911692613647098, 0.8910624782244364, 0.8909538303102765, 0.8909697859060197, 0.8907164548124585, 0.8913733788899013],
]

score2 = [
    [0.864508999853718, 0.8724318754916288, 0.8769387736612436, 0.8732143737831894, 0.8787931130856884, 0.8795219453013673, 0.8771879928452628, 0.882166964667184, 0.8774799838358042, 0.8808515886871182, 0.8827179828468634, 0.8800302938539155, 0.8835532908536949, 0.8822819249970573, 0.8746479336096316, 0.8848898204005494, 0.8801168896714036, 0.8863767434139641, 0.8868037462234497, 0.8822127276537369, 0.8863892980984279, 0.8801783396273243, 0.8873108686233053, 0.8876096873867269, 0.8884098955563137, 0.8888583669857103, 0.8886555749542859, 0.8888088513393791, 0.8881488211300909, 0.8887759520083057, 0.8884218700078069, 0.8887308519713732, 0.8887826325942059, 0.8884454600665034, 0.8886825770747905, 0.8890972271257517, 0.8889634244296015, 0.8890892449690371, 0.8889070566819639, 0.889069699511236],
    [0.8828102971645112, 0.8854381492797364, 0.8950134934263027, 0.8960165508249973, 0.8957556105674581, 0.8962008978458161, 0.8949247002601624, 0.8938817610131934, 0.8955622771953015, 0.9006253176547111, 0.9013269518284087, 0.9021366403457967, 0.9024915302053411, 0.9009958036402439, 0.901416985278434, 0.9027714336172064, 0.9020487146174654, 0.9015700601516886, 0.9031405829368754, 0.9016897868602833, 0.9016927823107294, 0.9028575889607693, 0.9025855888711646, 0.9026112924230859, 0.9023707889496012, 0.9027645473784589, 0.9024673231104587, 0.9022222135929351, 0.9028770264158857, 0.902535025109636, 0.9027477693050465, 0.9022137307106181, 0.9021970274600577, 0.9021573269620855, 0.9025333523750305, 0.9030954279798142, 0.9020905456644424, 0.9027980350433512, 0.9028048832365807, 0.9024732759658326],
    [0.8368520918099777, 0.8529903370401134, 0.8494134998839834, 0.8611240568368331, 0.862463134786357, 0.8619275572507278, 0.8650365160859149, 0.8648535816565804, 0.8694261973318846, 0.8581675044868303, 0.8579025618408037, 0.8681442996729976, 0.8672201633453369, 0.8716774673565574, 0.8703524729479915, 0.871675325476605, 0.8702450601950936, 0.8725874449895776, 0.8730468801830126, 0.8719258956287218, 0.8731323700884114, 0.8726673061433046, 0.8727063057215317, 0.872833279163941, 0.8727183471555295, 0.8733202193094336, 0.8723922436651976, 0.8732876324135325, 0.8717457690964574, 0.8718553807424463, 0.8721560291621996, 0.8729888058227041, 0.8734577285206836, 0.87332124295442, 0.872264766174814, 0.8726287367551223, 0.8708200985970704, 0.8728650510311127, 0.8725898032603057, 0.8718158745247385],
    [0.8721459039619991, 0.8947414855162302, 0.8938128451506296, 0.8974556454590389, 0.8902039456935156, 0.8957784871260325, 0.9012019818737393, 0.8976622536068871, 0.8986985726015908, 0.8995015436694735, 0.9051617696171715, 0.9042755436329615, 0.9045031482265109, 0.9041080077489217, 0.9045320735091255, 0.9051630638894581, 0.9052745785032, 0.9049901550724393, 0.9050465666112446, 0.905099976630438, 0.9046956045286996, 0.9054125987348103, 0.9049397848901295, 0.9051241065774646, 0.9055037512665703, 0.9052623042038509, 0.905392953327724, 0.9051953866368249, 0.9050338296663194, 0.9053828957534972, 0.9056157625856853, 0.9051467237018403, 0.9056228158019838, 0.90505831298374, 0.9050741280828204, 0.9050246704192388, 0.9055783635094052, 0.904909257377897, 0.9053372201465425, 0.9053030198528653],
]

# fig = plt.subplots(1,1,figsize=(8,5))
# plt.plot(progress1+[x+progress1[-1]-1 for x in progress2], score1[val_set]+score2[val_set], label='Mloc',color='b',ls=':')
# plt.plot(progress1, score1[val_set], label='M', color='b', ls='-')
# #plt.plot([x+progress1[-1]-1 for x in progress2], score2, label='Mloc')
# plt.xlabel('progress')
# plt.ylabel('accuracy')
# plt.ylim([0.89,0.91])
# plt.legend()
# plt.grid()
# plt.tight_layout()
# plt.show()

In [None]:
# Validation-score logs: score1 was trained with data augmentation, scorew
# without (see the plot legend below). progress*/noaug are the x-axis values.
# The no-augmentation x-values are identical to progress1, so noaug is a copy.
progress1 = [1.099488609948861, 1.198977219897722, 1.298465829846583, 1.397954439795444, 1.497443049744305, 1.596931659693166, 1.696420269642027, 1.795908879590888, 1.895397489539749, 1.99488609948861, 2.0943747094374707, 2.193863319386332, 2.293351929335193, 2.392840539284054, 2.4923291492329147, 2.591817759181776, 2.691306369130637, 2.790794979079498, 2.8902835890283587, 2.98977219897722, 3.0892608089260807, 3.188749418874942, 3.2882380288238027, 3.387726638772664, 3.4872152487215247, 3.586703858670386, 3.6861924686192467, 3.785681078568108, 3.8851696885169686, 3.98465829846583, 4.084146908414691, 4.183635518363552, 4.283124128312412, 4.382612738261274, 4.482101348210135, 4.581589958158996, 4.681078568107857, 4.780567178056717, 4.880055788005579, 4.97954439795444, 5.079033007903301, 5.178521617852161, 5.278010227801023, 5.377498837749884, 5.476987447698745, 5.576476057647605, 5.675964667596467, 5.775453277545328, 5.874941887494189, 5.974430497443049, 6.0739191073919105, 6.173407717340772, 6.272896327289633, 6.372384937238493, 6.4718735471873545, 6.571362157136216, 6.670850767085077, 6.770339377033937, 6.8698279869827985, 6.96931659693166, 7.068805206880521, 7.168293816829381, 7.2677824267782425, 7.367271036727104, 7.466759646675965, 7.566248256624825, 7.6657368665736865, 7.765225476522548, 7.864714086471409, 7.964202696420269, 8.06369130636913, 8.163179916317992, 8.262668526266854, 8.362157136215714, 8.461645746164574, 8.561134356113435, 8.660622966062297, 8.760111576011157, 8.85960018596002, 8.95908879590888]
score1 = [1.1907330283682832e-07, 1.1907330283682832e-07, 0.02427325332612402, 0.38323991251539213, 0.28456539782334345, 0.6098700153584383, 0.7034092995585227, 0.7257583384611168, 0.735852448307738, 0.767386216290143, 0.7493928232971503, 0.7647900678673569, 0.7790533900260925, 0.7475393268526817, 0.8070465022203873, 0.6034433610585271, 0.8186729416555288, 0.8242861981294594, 0.8324665719149064, 0.8166449057812594, 0.8187563042251431, 0.8032253798173399, 0.8359598432268415, 0.849766440537511, 0.845004349338765, 0.8480028303302064, 0.8459246645168382, 0.8506210896433616, 0.849363140913905, 0.8533017501539114, 0.8498111902451029, 0.8520236282932515, 0.8516871223644334, 0.8530464926544501, 0.8501871301203358, 0.8564872133488558, 0.8535619329433052, 0.8540137775090276, 0.8525879176295533, 0.8511430365698678, 0.8527966603940847, 0.8556524077240302, 0.8511891231244925, 0.8519795670801279, 0.8522080365492373, 0.8506941247959526, 0.8544574063651416, 0.853691324895742, 0.8509884403676403, 0.8556906301148084, 0.852085213271939, 0.8521156700289979, 0.8543028089464927, 0.8550138461346529, 0.8535430455694393, 0.8499389205660138, 0.852263512660046, 0.8556517253116686, 0.8492952281115006, 0.854258120059967, 0.8547801217254327, 0.8533223782266889, 0.8513119755959024, 0.8538461145089598, 0.8497290854551354, 0.851979597490661, 0.85334484066282, 0.8542870319619471, 0.8538557230209818, 0.8543270047830076, 0.8544341435237807, 0.853520581916887, 0.8542065754228708, 0.8550800449994146, 0.8549714672322176, 0.8499473600971456, 0.853166684812429, 0.8520406849530279, 0.8509265573657289, 0.8492035440036229]
progress2 = [1.0996275605214152, 1.1992551210428306, 1.2988826815642458, 1.3985102420856612, 1.4981378026070764, 1.5977653631284916, 1.697392923649907, 1.7970204841713222, 1.8966480446927374, 1.9962756052141528, 2.0959031657355682, 2.195530726256983, 2.2951582867783986, 2.394785847299814, 2.494413407821229, 2.5940409683426444, 2.69366852886406, 2.793296089385475, 2.89292364990689, 2.9925512104283056, 3.0921787709497206, 3.191806331471136, 3.2914338919925514, 3.3910614525139664, 3.490689013035382, 3.5903165735567972, 3.689944134078212, 3.7895716945996276, 3.889199255121043, 3.988826815642458]
score2 = [0.8215928904864253, 0.8324903882279688, 0.848473956390303, 0.8462184971692611, 0.8366210898574518, 0.8500317079680306, 0.8468478273372261, 0.8472242404003533, 0.8616796780605706, 0.853792527500464, 0.8691203776670962, 0.8702286129095116, 0.8664737197817588, 0.864110208287531, 0.8690099691858097, 0.8760749843655801, 0.8779594265684789, 0.8779666496782886, 0.8774571029507384, 0.8782852559673543, 0.8775566195955082, 0.876864205817787, 0.8767415844664281, 0.8768765123523011, 0.8794354893723313, 0.8781167986441631, 0.8791781213818765, 0.8796007900821919, 0.878403867994036, 0.8789295919087469]
noaug = list(progress1)
scorew = [1.1907330283682832e-07, 0.3683184080434089, 0.17003002747589227, 0.3522007988423717, 0.5373481098486452, 0.5709804168769291, 0.5893815409164039, 0.5416090324216959, 0.6560262496374092, 0.6550083148236178, 0.7859639415935594, 0.632151190115481, 0.8046263772614148, 0.7730331554704782, 0.8017043532157431, 0.772656111084685, 0.826451301574707, 0.827999899581987, 0.8280082211202505, 0.8295652793378246, 0.8336802076320259, 0.8317816756209548, 0.8370822996509318, 0.8390735898699079, 0.8325791930665776, 0.8400975149504992, 0.8344012036615488, 0.836971507996929, 0.8324971417991482, 0.8417948958825092, 0.838374991806186, 0.8401235986729058, 0.8389935396155532, 0.8420599735513026, 0.8399619411449043, 0.8395010232925415, 0.8382367625528452, 0.8408909963101757, 0.8382195903330433, 0.8400759830766794, 0.8357000265802655, 0.8398763403600576, 0.8402454329996693, 0.8387178365065127, 0.8393859547011706, 0.8403564509080381, 0.8412735182411817, 0.8387672049658639, 0.8397990440835759, 0.8364433840829499, 0.8359072500345658, 0.8365817228142096, 0.841271770243742, 0.8436129287797578, 0.8403777005721111, 0.8383661411246475, 0.839821884826738, 0.8386687064657405, 0.8369723753053315, 0.838394216128758, 0.8393560356023361, 0.8416854289113259, 0.8393667157815428, 0.8381567913658765, 0.8404526029314313, 0.8382811181399287, 0.8366442930941679, 0.8408230664778729, 0.8376644022610723, 0.8371759227343968, 0.8421115388675612, 0.8356834662203886, 0.8392643308152958, 0.8374798747957969, 0.8381880351475307, 0.8362888474853671, 0.8401335307529995, 0.8358709313431565, 0.840711519426229, 0.8426690162444601]

# Compare validation curves of the runs without and with data augmentation.
# plt.subplots returns (Figure, Axes); unpack it instead of binding the tuple to `fig`.
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.plot(noaug, scorew, label='no data augmentation')
ax.plot(progress1, score1, label='data augmentation')
# plt.plot(progress1+[x+progress1[-1]-1 for x in progress2], score1+score2, label='Mloc')
# plt.plot(progress1, score1, label='M')
#plt.plot([x+progress1[-1]-1 for x in progress2], score2, label='Mloc')
ax.set_xlabel('progress')
ax.set_ylabel('accuracy')
ax.set_ylim([0.5, 0.9])
ax.legend()
ax.grid()
fig.tight_layout()
plt.show()

5. Map the mask predicted by the second model back onto the original-size image, and compute the average Dice coefficient and Hausdorff distance over the validation set

In [None]:
def predict_original_size_img(path1_img, path1_mask, path2_img, path2_mask, filename1):
    """Predict on a localized crop and map the prediction back to the original image frame.

    Uses the module-level ``net`` and ``device``.

    Parameters
    ----------
    path1_img, path1_mask : str
        Directories holding the cropped image and cropped mask named
        ``filename1``. Paths are joined by string concatenation, so they must
        end with a separator.
    path2_img, path2_mask : str
        Directories holding the original-size image and mask.
    filename1 : str
        Crop file name of the form ``<id>_<row>_<col>_<half>_.png``. The crop
        centre (row, col) and half-size are parsed from it.

    Returns
    -------
    org_img, org_mask, org_pred : np.ndarray
        Arrays with the original shape, which is parsed from the matching
        original file name ``<id>_<H>_<W>...``.
    dc
        Dice value returned by ``predict_img`` for the crop.

    Raises
    ------
    FileNotFoundError
        If no file in ``path2_mask`` has ``<id>`` as its first ``_`` token.
    """
    # Second-stage prediction on the crop.
    img = Image.open(path1_img + filename1)
    mask = Image.open(path1_mask + filename1)
    pred, dc = predict_img(net=net, full_img=img, full_gt=mask, device=device, scale_factor=1)

    # Paste the crop prediction back into the 256x256 frame it was cut from.
    names1 = filename1.split('_')
    center, figsize = [int(names1[1]), int(names1[2])], int(names1[3])
    crop_side = 2 * figsize + 1
    roi = Image.fromarray(pred).resize((crop_side, crop_side))
    pred1 = np.zeros((256, 256))
    pred1[center[0] - figsize:center[0] + figsize + 1,
          center[1] - figsize:center[1] + figsize + 1] = np.array(roi)

    # Locate the original file for this case. Compare the whole leading token:
    # a substring test can also match another id or the numeric size fields.
    filename2 = [file for file in os.listdir(path2_mask)
                 if file.split('_')[0] == names1[0]]
    if not filename2:
        raise FileNotFoundError(f'no original file for {names1[0]!r} in {path2_mask}')
    names2 = filename2[0].split('_')
    shape = [int(names2[1]), int(names2[2])]
    short = min(shape)
    delta = abs(shape[0] - shape[1]) // 2

    # The square frame is resized to the shorter side and placed at offset
    # `delta` along the longer axis. Ending the window at `delta + short`
    # (not `shape[i] - delta`) keeps it exactly `short` wide even when the
    # side lengths differ by an odd number.
    if shape[0] < shape[1]:
        window = (slice(None), slice(delta, delta + short))
    else:
        window = (slice(delta, delta + short), slice(None))

    org_img = np.zeros(shape)
    org_mask = np.zeros(shape)
    org_pred = np.zeros(shape)
    org_img[window] = np.array(Image.open(path2_img + filename2[0]).resize((short, short)))
    org_mask[window] = np.array(Image.open(path2_mask + filename2[0]).resize((short, short)))
    org_pred[window] = np.array(Image.fromarray(pred1).resize((short, short)))

    return org_img, org_mask, org_pred, dc

In [None]:
net = UNet(n_channels=1, n_classes=1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#net.load_state_dict(torch.load('../input/cropped-models/cropped_MYO_fd5.pth'))
# map_location lets a GPU-saved checkpoint load when running on CPU only.
net.load_state_dict(torch.load('checkpoints_loc/maxValScore.pth', map_location=device))
net.to(device=device)

# path_uncrop='../input/dpcprc/masks_png_MYO/'
# path_imgs_out='cropped_imgs_MYO/'
# path_masks_out=path_imgs_out.replace('imgs','masks')

for isEDES in ['ED','ES']:

    val = BasicDataset(path_imgs_out, path_masks_out, scale=1, isval=True, val_ids=val_ids[val_set], isEDES=isEDES)
    print(len(val.ids))

    # Mean Hausdorff distance for ED and ES separately, computed on the
    # prediction pasted back into the original-size frame.
    hds = []
    dcs = []
    count = 0  # NOTE(review): never incremented, so the "are wrong" line below always reports 0
    for idx in val.ids:

        # The crop's ground-truth mask lives in path_masks_out (both crops are
        # saved under the same name); passing path_imgs_out here scored the
        # prediction against the image instead of the mask.
        org_img, org_mask, org_pred, dc = predict_original_size_img(path_imgs_out, path_masks_out, dir_img, dir_mask, idx+'.png')
        dcs.append(dc)

        if idx in val.ids[:1]:
            # Visual sanity check on the first validation case only.
            mpl.rcParams.update(mpl.rcParamsDefault)
            fig, ax = plt.subplots(1,3,figsize=(10,6))
            ax[0].imshow(org_img)
            ax[1].imshow(org_mask>128)
            ax[2].imshow(org_pred>0.5)
            plt.show()
            print(idx,org_pred.shape)

        hds.append(Hdistance(org_mask>128,org_pred>0.5))


    fig, ax = plt.subplots(1, 1, figsize=(10,5))
    ax.hist(hds)
    print(f'{count}/{len(val.ids)} are wrong')
    print(f'Average Hausdorff distance is {np.mean(hds)}')



    # Mean dice for ED and ES separately, via eval_net on the cropped validation set.
    val_loader = DataLoader(val, batch_size=1, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
    N = 1
    dcs = np.zeros([N,1])
    for ii in range(N):
        dcs[ii] = eval_net(net,val_loader,device) #,dropout=True)
        print(f'Dice coefficient is {dcs[ii]}')

    # fig, ax = plt.subplots(1, 1, figsize=(10,5))
    # ax.hist(dcs)
    # ax.set_xlabel('batch averaged dice coefficient')
    # ax.set_ylabel('counts')
    # print(np.mean(dcs),np.std(dcs))

# Recorded MYO results, one entry per run (presumably per validation fold):
# Dice (_dc) and Hausdorff distance (_hd) for the ED and ES frames.
ED_dc = [0.88066437, 0.89591653, 0.86847068, 0.90013122]
ED_hd = [3.216099005955892, 2.3189204189075694, 3.4599727812540793, 2.4873833419471487]
ES_dc = [0.89801328, 0.91469021, 0.8788926, 0.91105698]
ES_hd = [3.375262832389057, 2.6626082124870147, 4.373472434808897, 2.9845150950857677]
print(*(np.mean(values) for values in (ED_dc, ED_hd, ES_dc, ES_hd)))

# Dropout (uncertainty quantification)

In [None]:
# net = UNet(n_channels=1, n_classes=1)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# net.load_state_dict(torch.load('../input/cropped-models/cropped_MYO_fd5.pth'))
# #net.load_state_dict(torch.load('../input/dropout015/dropout0.15_MYO_fd0_1.pth'))
# net.to(device=device)

# # path_uncrop='../input/dpcprc/masks_png_MYO/'
# # path_imgs_out='cropped_imgs_MYO/'
# # path_masks_out=path_imgs_out.replace('imgs','masks')

# #calculate mean Hausdorff distance for ED and ES seperately
# dcs = []
# for ii in range(500):
#     img = Image.open(path_imgs_out+idx+'.png')
#     mask= Image.open(path_masks_out+idx+'.png')
#     pred, dc = predict_img(net=net, full_img=img, full_gt=mask, device=device, scale_factor=1)
#     dcs.append(dc)

# Results
1. Cross-validation results

In [None]:
# Cross-validation results, one value per fold (5 folds), per structure.
# Each entry is (ED, EDh, ES, ESh): ED/ES are the Dice scores and EDh/ESh the
# Hausdorff distances for the ED and ES frames.
cv_results = {
    'LV': (
        [0.9462, 0.9443, 0.9451985045316371, 0.9570678024622746, 0.9546],
        [3.1288981457606995, 3.045574807838474, 6.880659150035709, 2.3735783089168847, 2.6094133416681764],
        [0.8947, 0.909, 0.9108, 0.93119632176196937, 0.9327],
        [4.869201996536504, 3.327, 4.296308674749143, 2.6232, 2.7097230082378547],
    ),
    'RV': (
        [0.915, 0.912, 0.90086896, 0.92847941, 0.92611341],
        [8.97292256797219, 8.371719867951663, 11.57335163197323, 5.019068179659404, 5.791896375713262],
        [0.8604534, 0.84788184, 0.84088899, 0.87287045, 0.88587517],
        [9.385108289589429, 9.550419508277407, 11.744264557052109, 7.752770689232926, 8.261678977180875],
    ),
    'MYO': (
        [0.87930599, 0.89374, 0.8768134, 0.8996050850504396, 0.89201334],
        [4.196813975851807, 3.61, 4.21, 2.584, 2.64],
        [0.89878689, 0.91036824, 0.88742615, 0.91195765, 0.90956352],
        [4.108809377223461, 4.087738966126137, 4.745608943101348, 3.2631633836329916, 2.878045816357292],
    ),
}

# For each structure print the fold means in the order: ED, ES, EDh, ESh.
# The loop targets leave ED/EDh/ES/ESh bound to the last (MYO) entry for later cells.
for ED, EDh, ES, ESh in cv_results.values():
    for metric in (ED, ES, EDh, ESh):
        print(np.mean(metric))

2. Results for dropout experiments

In [None]:
# Dropout experiment results.
# p: dropout rates tried. s_mean / sigma: one mean and one standard deviation
# of the score per rate.
# NOTE(review): presumably s_mean/sigma summarise repeated dropout-enabled
# evaluations such as the scores_* lists below — confirm.
p=[0.05,0.1,0.15]
s_mean=[0.8731267429152291,0.869039029,0.8629484493881823]
sigma=[0.00022759099204778944,0.0006517089053875286,0.0005613714007445122]
# Individual repeated scores; the suffix presumably encodes the dropout rate
# (05 -> p=0.05, 10 -> p=0.1) — confirm.
scores_05 = [0.87331342,0.87333143,0.8732863,0.87340288,0.87288952,0.87289357,0.87296747,0.87269603,0.8732761,0.87343681,0.87319477,0.87306098,0.87283402,0.87314631,0.87325575,0.87356689,0.87279015,0.87288194,0.87303729,0.87305472,0.87318868,0.87316056,0.87302827,0.87295709,0.87286149,0.87341638,0.8731365,0.87292554,0.87290065,0.87341309,0.87310909,0.87339313,0.87374544,0.87270255,0.87345979,0.87300217,0.87329002,0.87331728,0.87267972,0.87304084,0.87331457,0.87336918,0.87322015,0.8729096,0.87340271,0.87279351,0.87285423,0.87334867,0.87275539,0.87296747,0.87342451,0.87333014,0.87284233,0.87318678,0.87293003,0.87306413,0.87267714,0.87320023,0.87314106,0.8733636,0.87309104,0.87320482,0.87311574,0.87292176,0.87314028,0.87282322,0.87325132,0.87301823,0.87303616,0.8734009,0.8728462,0.87314371,0.87300764,0.87294975,0.87327676,0.87287998,0.87328415,0.87301052,0.87295519,0.87368164,0.87356092,0.87339108,0.87326908,0.87285096,0.87321546,0.87306212,0.87308202,0.87311798,0.87325014,0.87308527,0.87338681,0.87291477,0.87313492,0.87340304,0.87327668,0.87321046,0.87320059,0.8731605,0.8728588,0.87306366]
scores_10 = [0.8697967784106732, 0.8693817990645766, 0.8710553497821093, 0.8687466697767376, 0.8684712124429643, 0.8690698056668044, 0.8692041685059667, 0.8695977672934532, 0.8677411967329681, 0.8691929432749749, 0.8675195140950381, 0.8692341388389468, 0.8698640009760856, 0.869625474140048, 0.8684838706254959, 0.8702223702520132, 0.8686109191924334, 0.869550997465849, 0.8678180234879256, 0.868933097999543, 0.8693770147487521, 0.8685399895906448, 0.8687310522049665, 0.8694990168139338, 0.8688508304208518, 0.8689055595546961, 0.8685190906375646, 0.8696642123907804, 0.8687148254364729, 0.8699958860874176, 0.8692010127753019, 0.8683642178028822, 0.868608687967062, 0.8689484779536724, 0.8697110473364592, 0.8693254534900189, 0.869063181951642, 0.8687352734804153, 0.8686322996765375, 0.867580402046442, 0.8683947152644396, 0.868176386244595, 0.8691585700958967, 0.8674692007340491, 0.868741292282939, 0.8694709821790457, 0.8690870878100395, 0.8687619062140584, 0.8688468013703823, 0.8691414667572827, 0.8689496354758739, 0.869881102591753, 0.8699932587891817, 0.8682038147002459, 0.8683469915017485, 0.8692991247028112, 0.8688899119943381, 0.8698239135742187, 0.868507260158658, 0.8680613169819117, 0.8692799863964319, 0.8685752669721842, 0.8691157221049071, 0.8679554709792137, 0.8683901857584715, 0.8690043222904206, 0.8694651550054551, 0.8691477528214455, 0.8690514581650496, 0.8682719898968935, 0.8698485902696848, 0.8688517032936215, 0.8693661823868751, 0.8694078106433153, 0.8705538620054721, 0.8693889805674553, 0.8688534115999937, 0.8697546669840812, 0.867976725101471, 0.8696103118360042, 0.8683804791048169, 0.8693442234396934, 0.8691976804286241, 0.8688010571151972, 0.8690754567086697, 0.8690780286490917, 0.8680252852575889, 0.8690958709269762, 0.8694042455404997, 0.8691308841854334, 0.8693637311086059, 0.8685282526165247, 0.8702637777104973, 0.8695156516134739, 0.869143930003047, 0.8694523039460182, 0.8689988732337952, 0.86930310966447, 0.8696112053096294, 0.8700269440561533]
#scores_10_1 = [0.8695041617180824, 0.8695819115202147, 0.869921554978337, 0.8694646126735994, 0.8694982252947946, 0.8698973296946962, 0.8699726377172359, 0.8692059874655593, 0.8695959871250855, 0.8699221885430474, 0.8693505477500627, 0.8695090449298298, 0.8693889092082192, 0.8693479714364137, 0.8690055620512211, 0.8693532701752827, 0.8695043024492071, 0.8697922996449041, 0.8694078948133932, 0.8699777959847288, 0.8698058038350903, 0.8694281473293418, 0.8695799793400139, 0.8690699458134611, 0.8697155366435997, 0.8693223776327708, 0.8693599015766128, 0.8697834880463278, 0.8702069995788908, 0.8691727157725577, 0.8691886113033542, 0.8695850595961747, 0.8692518714397791, 0.8693563786877927, 0.8702029786232097, 0.8696609478506038, 0.8694580989879861, 0.8702071247353038, 0.8695613113200632, 0.8695358515206157, 0.8694146290192372, 0.8699480964418324, 0.8697672185491165, 0.8696104233218592, 0.8703169019040716, 0.8696279087033788, 0.8699434115557232, 0.8696980368029452, 0.8690376171460813, 0.8690754087879334, 0.8695617002144994, 0.8691071851807136, 0.869464520257714, 0.8692518390382429, 0.8699106886063519, 0.8694578544596162, 0.8696648157021053, 0.8697054823294937, 0.8696435349214939, 0.8693264063183115, 0.8689589793161278, 0.8697070323414761, 0.869996946833226, 0.8697368759685054, 0.8701242447869456, 0.8693943607046498, 0.8687436954793335, 0.8694093369142407, 0.8691480127715653, 0.869281923485234, 0.8699288432745621, 0.8700401980286078, 0.869844733933604, 0.8692880685461091, 0.8691833309002057, 0.86898835597851, 0.8690079038669157, 0.8692386481760017, 0.8695652967402646, 0.8692005178226831, 0.869600878925575, 0.8696625985506191, 0.8696909831402365, 0.8696444400110228, 0.8689008589220302, 0.8694221834837328, 0.8699720621132072, 0.8699058495746287, 0.8694579730758957, 0.8699018192219491, 0.8700858978764126, 0.8697136565662532, 0.8698822184815047, 0.8698526408160027, 0.8701315800837149, 0.8697505363810066, 0.8698928673827825, 0.8701202650602328, 0.8695017286021781, 
#0.8693233012309902]  (last element of the commented-out scores_10_1 list; left uncommented it was a stray `]` SyntaxError)
#scores_15 = [0.87419026,0.87449493, 0.87468625, 0.8744494, 0.87405172, 0.87425401, 0.87425781, 0.87477176, 0.87409505, 0.8734193, 0.87359523, 0.8742911, 0.87487478, 0.87447447, 0.87412666, 0.87439657, 0.87452888, 0.87414008, 0.87437926, 0.87414775, 0.87437481, 0.87430009, 0.87449099, 0.87384319, 0.87483437, 0.87448268, 0.87426891, 0.8742287, 0.87464983, 0.87431562, 0.8736647, 0.87460026, 0.87459545, 0.87410114, 0.87460884, 0.87397896, 0.87425215, 0.8744169, 0.87409821, 0.87417882, 0.87404053, 0.8739908, 0.87389715, 0.87431355, 0.87454532, 0.87440259, 0.87363784, 0.8739239, 0.87421054, 0.87464649, 0.87511009, 0.87406448, 0.87459085, 0.87450476, 0.87453247, 0.8740677, 0.87456093, 0.87469369, 0.87481578, 0.87464547, 0.8742097, 0.87461432, 0.87407086, 0.87394763, 0.87489132, 0.87462682, 0.87449617, 0.87430415, 0.874796, 0.87451398, 0.87427127, 0.8737603, 0.87428971, 0.87440497, 0.87376835, 0.87410355, 0.87387992, 0.87431058, 0.87514257, 0.87470644, 0.87409007, 0.87444854, 0.87489385, 0.8747469, 0.8744604, 0.87393671, 0.87429399, 0.87445799, 0.87442261, 0.87440114, 0.87479824, 0.87423457, 0.87390477, 0.87391397, 0.87409264, 0.87421278, 0.87450037, 0.87460683, 0.87494885, 0.87416295]
# Hard-coded per-run scores (pasted literals), plotted below as the 'p=0.15'
# histogram. An earlier, commented-out scores_15 list is kept just above.
scores_15 = [0.8628673703595996, 0.8624048015847802, 0.8623990073800087, 0.8632287028804422, 0.8607268931930211, 0.8629376422986388, 0.8629292130842805, 0.8627884811162949, 0.8623025111109018, 0.8628406461328268, 0.8621940354630351, 0.86354527246207, 0.8627034287899733, 0.8628176900744439, 0.8634056971222163, 0.8635280176997184, 0.8622370509058237, 0.8630144560337066, 0.8635791182145476, 0.8634279810637235, 0.8629674334824086, 0.8638452585414051, 0.8637461336329579, 0.8628504181280732, 0.8634079901874065, 0.8630818465352058, 0.8630608800426125, 0.8627412924915552, 0.8635623157024384, 0.8625804045796395, 0.862338138744235, 0.8622362869232894, 0.8635695446655154, 0.8627199812605977, 0.8632014139462263, 0.8627511885017156, 0.8630326968058943, 0.8633510715886951, 0.8623301983997226, 0.8629188737645745, 0.8625350360758602, 0.8623312294110655, 0.8627950203418732, 0.8624043329060078, 0.8632362184301019, 0.8642266337573529, 0.8623936504684389, 0.8623649432882666, 0.8625842549279332, 0.8629636375419796, 0.8621334177162499, 0.863564831390977, 0.863170187920332, 0.8627568396553397, 0.8640693670511246, 0.8626588713005185, 0.8633089047670365, 0.8633482159301639, 0.862763608172536, 0.8630064859986305, 0.8624780905805528, 0.8624703116901219, 0.862679332792759, 0.8627989114075899, 0.8637751768529415, 0.863292491659522, 0.8626386575587094, 0.8627015389502048, 0.8626961840316654, 0.8621253264322877, 0.8630137604475021, 0.8635489182174205, 0.8638475822471082, 0.8633016800135374, 0.8621791492402554, 0.8638886228576302, 0.8632199138775468, 0.8638071192428469, 0.8627064261212944, 0.8629739084467292, 0.8633303272351622, 0.8632696686312556, 0.8630903960205615, 0.8630421718023717, 0.8639113697037101, 0.8631205774471163, 0.8623588394001126, 0.862848153039813, 0.862566682510078, 0.8635766552761197, 0.8641947719827294, 0.862480667643249, 0.8620175444707274, 0.8634789062291384, 0.8631041492530493, 0.8627016406133771, 0.8629758520796895, 0.8618405134603381, 0.8632346680760383, 
0.8627712093293667]
#print(len(scores_05),len(scores_10),len(scores_15),np.mean(scores_15))
# Overlay the three score distributions as outline ("step") histograms so the
# curves stay readable where they overlap.
# NOTE(review): the p values presumably are dropout probabilities (cf. the
# commented-out nn.Dropout(p=0.10) in DoubleConv) — confirm.
# NOTE(review): hist() defaults to 10 bins fitted to *each* series' own range,
# so bin widths differ between the curves and their heights are not directly
# comparable; pass shared `bins=` edges if that matters.
# plt.subplots returns a (Figure, Axes) pair; unpack it instead of binding the
# tuple to a name that claims to be a Figure.
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.hist(scores_05, histtype='step', label='p=0.05')
ax.hist(scores_10, histtype='step', label='p=0.10')
ax.hist(scores_15, histtype='step', label='p=0.15')
ax.set_xlabel('accuracy')
ax.legend()
fig.tight_layout()
plt.show()
#print(np.array(dcs).reshape((100,)).tolist())

# Download output files from Kaggle
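1. Copy the target file into the `/kaggle/working` directory (the notebook's working directory, which the relative path passed to `FileLink` resolves against).
2. Run the `FileLink` cell below and click the generated link to download the file.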


In [None]:
# Display a clickable download link for the checkpoint file; the path is
# relative to the notebook's current working directory.
from IPython.display import FileLink
FileLink(path='checkpoints/maxValScore.pth')

In [None]:
# ./imgs_precropped  (stray path left in this cell; as bare code it is a SyntaxError)