In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and print the full path of every file found.
for folder, _, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(folder, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import pandas as pd

datapath = "/kaggle/input/vinbigdata-chest-xray-abnormalities-detection/"
metapath = "/kaggle/input/xraydata/xraydata"

# Bounding-box annotations indexed by image_id, extended with the box width,
# height and area derived from the x/y extents.
data = (
    pd.read_csv(f"{datapath}/train.csv")
    .set_index("image_id")
    .assign(
        width=lambda boxes: boxes["x_max"] - boxes["x_min"],
        height=lambda boxes: boxes["y_max"] - boxes["y_min"],
        area=lambda boxes: boxes["height"] * boxes["width"],
    )
)
data.head()

## Test and Train data creation
Train and Test data is created by

1. Combining all 67K annotations into 15K, one for each image
2. Splitting 15K into approx 12/3K train/test ensuring distribution of classes across gender is maintained in both.
3. Expanding train/test by creating 320x320 images of the full xray and of its 2x2 and 4x4 sections, while sampling normal xrays/sections with a probability of only 0.2. This gave us close to 80K train and 20K test images. 
4. For each dataset we also store metadata that helps us create the corresponding masks for training.

In [None]:
# Metadata of the generated training images, one row per image id.
trainmeta = pd.read_csv(f"{metapath}/traindata.csv")
trainmeta = trainmeta.set_index("id")
trainmeta.head()

In [None]:
# Metadata of the generated test images, one row per image id.
testmeta = pd.read_csv(f"{metapath}/testdata.csv")
testmeta = testmeta.set_index("id")
testmeta.head()

In [None]:
# Number of mask channels built per image by XrayDataset (one per label c0..c13).
NUM_CLASSES = 14

## Dataset Creation

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset
import random
import re
import albumentations as A
from torchvision import transforms
import albumentations.pytorch as AP
import skimage
import cv2
from PIL import Image
from skimage.transform import resize


class AlbumentationTransforms:
    """
    Helper class to create test and train transforms using Albumentations.

    The given augmentations are composed in order and followed by a final
    ``AP.ToTensor()`` step.

    Args:
        transforms_list: optional sequence of Albumentations transforms.
            The sequence is copied, so neither the caller's list nor a shared
            default is modified.
    """
    def __init__(self, transforms_list=None):
        # The previous signature used a mutable default ([]) and appended to it,
        # so every default-constructed instance added another ToTensor() to the
        # same shared list (and a caller-supplied list was modified in place).
        pipeline = list(transforms_list) if transforms_list is not None else []
        pipeline.append(AP.ToTensor())
        self.transforms = A.Compose(pipeline)

    def __call__(self, img):
        """Run the composed pipeline on ``img`` and return the resulting image."""
        img = np.array(img)
        return self.transforms(image=img)['image']

In [None]:
from skimage import io

class DataLoader:
    """
    Helper class to load test and train data.

    Holds one set of ``torch.utils.data.DataLoader`` keyword arguments and
    applies it to every dataset passed to :meth:`load`.

    Args:
        shuffle: whether the created loaders shuffle their dataset.
        batch_size: number of samples per batch.
        seed: seed for torch's global RNG (and the CUDA RNG when a GPU is
            available), so that the shuffling order is reproducible.
    """
    def __init__(self, shuffle=True, batch_size=128, seed=1):
        cuda = torch.cuda.is_available()

        # The shuffle order is drawn from the CPU generator, so it has to be
        # seeded on every machine; seeding only CUDA left `seed` without effect.
        torch.manual_seed(seed)
        if cuda:
            torch.cuda.manual_seed(seed)

        self.dataloader_args = dict(shuffle=shuffle, batch_size=batch_size)
        if cuda:
            # worker processes and pinned memory are only requested with a GPU
            self.dataloader_args.update(num_workers=4, pin_memory=True)

    def load(self, data):
        """Wrap ``data`` in a torch DataLoader built with the stored arguments."""
        return torch.utils.data.DataLoader(data, **self.dataloader_args)


      
class XrayDataset(Dataset):
    """Chest X-ray dataset reader.

    Every item is a grayscale image read from ``filepath`` together with a
    stack of ``NUM_CLASSES`` binary masks. The masks are rasterised from the
    bounding-box annotations that fall inside the view described by the
    image's metadata row.
    """

    def __init__(self, filepath, data, levels, size, channel_mean, channel_stdev, meta,
                 transforms = None, annotations = None):
        """
        Args:
            filepath (string): directory holding the images, stored as
                ``<id>.png`` with the id zero-padded to 8 digits.
            data: indexable collection of image ids (e.g. a pandas Index).
            levels: kept for backward compatibility; not used.
            size (int): side length of the square masks; must match the image size.
            channel_mean, channel_stdev: constants used to normalise the image.
            meta: DataFrame indexed by image id. ``__getitem__`` reads image_id,
                c0..c14, left/right/top/bottom, resized_*/original_* and dim_ratio.
            transforms: callable applied to the loaded image; the code that
                follows expects it to return a 2-D tensor.
            annotations: optional DataFrame of boxes indexed by image_id with
                class_id, x_min, x_max, y_min and y_max columns. When omitted,
                the notebook-level ``data`` frame is used as before.
        """
        self.transforms = transforms
        self.filepath = filepath
        self.images = data
        self.size = size # must match the image size
        self.means = channel_mean
        self.stdevs = channel_stdev
        self.meta = meta
        self.annotations = annotations

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, masks)`` as float tensors for the item at ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        tile_id = self.images[idx]
        image = io.imread(os.path.join(self.filepath, f"{tile_id:08}.png"), as_gray=True, pilmode="L")
        # get the meta
        info = self.meta.loc[tile_id]
        # scale factors from original-image pixels to pixels of the stored image
        wr = (info.resized_width/info.original_width)/info.dim_ratio
        hr = (info.resized_height/info.original_height)/info.dim_ratio

        # Resolved at call time so the notebook-level frame stays the fallback.
        boxes = self.annotations if self.annotations is not None else data
        # A list selector always yields a DataFrame, also when the image has
        # exactly one annotation row (a scalar selector would return a Series
        # and break the class_id filter below).
        annots = boxes.loc[[info['image_id']]]
        labels = [info[f'c{cid}'] for cid in range(NUM_CLASSES)]
        masks = np.zeros(shape=[NUM_CLASSES, self.size, self.size], dtype="float")
        if not info['c14']:
            left, right = int(info['left']), int(info['right'])
            top, bottom = int(info['top']), int(info['bottom'])
            # Visit every label; the earlier range(13) never drew the last class.
            for cid in range(NUM_CLASSES):
                if not labels[cid]:
                    continue
                class_instances = annots[annots["class_id"]==cid]
                for _, instance in class_instances.iterrows():
                    # clip the box to the current view and express it
                    # relative to the view's top-left corner
                    l = max(int(instance.x_min), left) - left
                    r = min(int(instance.x_max), right) - left
                    b = max(int(instance.y_min), top) - top
                    t = min(int(instance.y_max), bottom) - top

                    # boxes that do not intersect the view are skipped
                    if l<r and b<t:
                        l = round(l*wr)
                        r = round(r*wr)
                        b = round(b*hr)
                        t = round(t*hr)

                        masks[cid, b:t, l:r] = 1
        labels = np.array(labels)

        if self.transforms:
            image = self.transforms(image)

        # normalise with the dataset-level mean / standard deviation
        image = (image - self.means)/self.stdevs
        # Swap the two axes and add a leading channel axis. unsqueeze keeps the
        # transposed layout for non-square images, where the former reshape to
        # (1, shape[1], shape[0]) would have scrambled the pixels.
        image = torch.transpose(image, 1, 0).unsqueeze(0).contiguous()
        masks = torch.from_numpy(masks)
        labels = torch.from_numpy(labels)

        # Decide if we are sending meta - sex, pixel spacing, positional encoding
        # calculate that here from meta, window, dim

        return image.float(), masks.float() #, labels.float(), idx

In [None]:
# number of generated train and test images
print(trainmeta.shape[0], testmeta.shape[0])

## Setting the data

Below we can pick a subset of data instead. Entire data will take 7 hours to run


In [None]:
# For the full data use: train_data, test_data = trainmeta.index, testmeta.index
# Partial data: the first 10000 train ids and the first 2500 test ids.
train_data = trainmeta.index[:10000]
test_data = testmeta.index[:2500]

In [None]:
# Data Properties
IMAGE_WIDTH = 320
HIERARCHY_LEVELS = 1
BATCH_SIZE = 16
# normalisation constants passed to the datasets and used to undo it for display
channel_mean = 0.51615639
channel_stdev = 0.24926406

# Training images get photometric augmentation; test images are only converted.
transforms = AlbumentationTransforms(
                [
                    #A.Resize(80, 80),
                    # A.Normalize(mean=channel_mean, std=channel_stdev),
                    #A.CLAHE(always_apply=True),
                    A.RandomBrightnessContrast(),
                    #A.InvertImg(),
                    A.GaussNoise()
                ]
             )
test_tranforms = AlbumentationTransforms()
train = XrayDataset(f"{metapath}/traindata", train_data, HIERARCHY_LEVELS, IMAGE_WIDTH, 
                    channel_mean, channel_stdev, trainmeta, transforms = transforms)
test = XrayDataset(f"{metapath}/testdata", test_data, HIERARCHY_LEVELS, IMAGE_WIDTH, 
                   channel_mean, channel_stdev, testmeta, transforms = test_tranforms)

dataloader = DataLoader(batch_size=BATCH_SIZE, shuffle=True)

# train dataloader
train_loader = dataloader.load(train)

# test dataloader
test_loader = dataloader.load(test)

# Python 3 iterators are advanced with next(); the .next() method is not part
# of the iterator protocol and is missing from current DataLoader iterators.
images, masks = next(iter(test_loader))
images.shape
c = 0
# undo the normalisation to view the first image of the batch
im = Image.fromarray(((images[c][0].numpy()*channel_stdev + channel_mean)*255).astype('uint8'))
im
# first mask channel of the same sample
mi = Image.fromarray((masks[c][0].numpy()*255).astype('uint8'))
mi

In [None]:
!pip install torchsummary

## Model train/test setup

In [None]:
# Copied from EVA4. Needs significant changes. This is just for an idea. 
# Follow Conclusions from https://github.com/abhinavdayal/DepthMask to make changes
# other than just input and output shape changes

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from tqdm.notebook import tqdm

class Train:
    """Runs one training pass over a dataloader.

    For every batch the loss is computed, optionally extended with an L1
    penalty on all model parameters, back-propagated and applied. Batch
    timing, loss and learning rate are reported to the run manager, and a
    scheduler (when given) is stepped after every batch.
    """
    def __init__(self, model, dataloader, optimizer, runmanager, lossfn, scheduler=None, L1lambda = 0):
        self.model, self.dataloader = model, dataloader
        self.optimizer, self.scheduler = optimizer, scheduler
        self.L1lambda = L1lambda
        self.lossfn = lossfn
        self.runmanager = runmanager

    def run(self):
        """Train the model for one epoch."""
        self.model.train()
        progress = tqdm(self.dataloader)
        # TODO: meta may come as fourth thing here
        for inputs, masks in progress:
            self.runmanager.begin_batch()
            inputs = inputs.to(self.model.device)
            masks = masks.to(self.model.device)

            # PyTorch accumulates gradients across backward passes, so they
            # are cleared before each new batch.
            self.optimizer.zero_grad()

            prediction = self.model(inputs)
            loss = self.lossfn(prediction, masks)

            # optional L1 regularisation over all parameters
            if self.L1lambda > 0:
                l1_penalty = 0.
                for weight in self.model.parameters():
                    l1_penalty += torch.sum(weight.abs())
                loss += self.L1lambda * l1_penalty

            loss.backward()
            self.optimizer.step()

            self.runmanager.track_train_loss(loss)

            # learning rate used for this batch: from the scheduler when there
            # is one, otherwise from the optimizer's first parameter group
            lr = (self.scheduler.get_last_lr()[0] if self.scheduler
                  else self.optimizer.param_groups[0]['lr'])

            batchtime = self.runmanager.end_batch(lr)
            progress.set_description(f'time: {batchtime:0.2f}, loss: {loss.item():0.4f}')

            if self.scheduler:
                self.scheduler.step()

class Test:
    """Evaluates the model on a dataloader without tracking gradients.

    Every batch loss is handed to the run manager. When the scheduler is a
    ``ReduceLROnPlateau`` it is stepped once per evaluation with the test loss
    reported by the run manager.
    """
    def __init__(self, model, dataloader, runmanager, lossfn, scheduler=None):
        self.model = model
        self.dataloader = dataloader
        self.runmanager = runmanager
        self.scheduler = scheduler
        self.lossfn = lossfn
        print("initialized tester with ", self.model.device)

    def run(self):
        """Run one evaluation pass over the dataloader."""
        self.model.eval()
        with torch.no_grad():
            pbar = tqdm(self.dataloader)
            for data, target in pbar:
                data, target = data.to(self.model.device), target.to(self.model.device)
                output = self.model(data)
                loss = self.lossfn(output, target)
                self.runmanager.track_test_loss(loss)

            if self.scheduler and isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                test_loss = self.runmanager.get_test_loss()
                # tqdm.write takes one message string; its second positional
                # argument is the output file, so the loss must be formatted
                # into the message instead of being passed separately.
                pbar.write(f"In scheduler step with loss of {test_loss}")
                self.scheduler.step(test_loss)

class ModelTrainer:
    """Drives the epoch loop: one Train pass and one Test pass per epoch.

    The scheduler is handed to the Train helper only when ``batch_scheduler``
    is set; otherwise it is stepped here once per epoch, unless it is a
    ``ReduceLROnPlateau`` (which the Test helper steps with the test loss).
    """
    def __init__(self, model, optimizer, train_loader, test_loader, runmanager, 
                 lossfn, scheduler=None, batch_scheduler=False, L1lambda = 0):
        self.model = model
        print(self.model.device)
        self.model.to(self.model.device)
        self.scheduler, self.batch_scheduler = scheduler, batch_scheduler
        self.optimizer = optimizer
        self.runmanager = runmanager
        self.lossfn = lossfn
        self.train_loader, self.test_loader = train_loader, test_loader

        per_batch_scheduler = self.scheduler if self.batch_scheduler else None
        self.train = Train(model, train_loader, optimizer, self.runmanager, self.lossfn,
                           per_batch_scheduler, L1lambda)
        self.test = Test(model, test_loader, self.runmanager, self.lossfn, self.scheduler)

    def run(self, runparams, epochs=10):
        """Train and evaluate for ``epochs`` epochs, then save the run statistics."""
        plateau_type = torch.optim.lr_scheduler.ReduceLROnPlateau
        epoch_bar = tqdm(range(1, epochs+1), desc="Epochs")
        self.runmanager.begin_run(runparams, self.train_loader, self.test_loader)
        for _ in epoch_bar:
            self.runmanager.begin_epoch()
            self.train.run()
            self.test.run()
            lr = self.optimizer.param_groups[0]['lr']
            epoch_bar.write(self.runmanager.end_epoch(lr))
            self.runmanager.savebest(self.model.name)

            # per-epoch schedulers are stepped here; per-batch and plateau
            # schedulers are handled by Train and Test respectively
            if not self.scheduler:
                pass
            elif self.batch_scheduler or isinstance(self.scheduler, plateau_type):
                pass
            else:
                self.scheduler.step()
            epoch_bar.write(f"Learning Rate = {lr:0.6f}")

        # save stats for later lookup
        self.runmanager.save(self.model.name)

# A Unet styled Model

In [None]:

class Net(nn.Module):
    """Base module that carries a name plus summary and training shortcuts."""

    def __init__(self, name="Model"):
        super().__init__()
        self.name = name

    def summary(self, input_size): #input_size=(1, 28, 28)
        """Print the torchsummary report for one input of shape ``input_size``."""
        summary(self, input_size=input_size)

    def gotrain(self, optimizer, train_loader, test_loader, epochs, runmanager, runparams, lossfn, 
                scheduler=None, batch_scheduler=False, L1lambda=0):
        """Create a ModelTrainer for this model, keep it on ``self.trainer`` and run it."""
        trainer = ModelTrainer(self, optimizer, train_loader, test_loader, runmanager, lossfn,
                               scheduler, batch_scheduler, L1lambda)
        self.trainer = trainer
        trainer.run(runparams, epochs)


class InitialBlock(nn.Module):
    """Three 3x3 conv + batch-norm + ReLU stages at the input resolution.

    Channel widths go 1 -> planes -> 2*planes -> 4*planes.
    """
    def __init__(self, planes):
        super().__init__()
        conv_args = dict(kernel_size=3, padding=1, stride=1, bias=False)
        double, quadruple = planes*2, planes*4
        self.conv1 = nn.Conv2d(1, planes, **conv_args)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, double, **conv_args)
        self.bn2 = nn.BatchNorm2d(double)
        self.conv3 = nn.Conv2d(double, quadruple, **conv_args)
        self.bn3 = nn.BatchNorm2d(quadruple)

    def forward(self, x):
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x

class EncoderPath(nn.Module):
    """One dilated branch of an encoder block.

    A stride-2 dilated 3x3 convolution (batch-norm + ReLU) followed by a
    stride-1 dilated 3x3 convolution with batch-norm and no activation.
    Padding equals the dilation in both convolutions.
    """
    def __init__(self, inplanes, outplanes, dilation):
        super().__init__()
        shared = dict(kernel_size=3, padding=dilation, dilation=dilation, bias=False)
        self.conv1 = nn.Conv2d(inplanes, outplanes, stride=2, **shared)
        self.bn1 = nn.BatchNorm2d(outplanes)
        self.conv2 = nn.Conv2d(outplanes, outplanes, stride=1, **shared)
        self.bn2 = nn.BatchNorm2d(outplanes)

    def forward(self, x):
        reduced = F.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(reduced))


class EncoderBlock(nn.Module):
    """Halves the resolution with four parallel branches.

    A strided 1x1 shortcut and three EncoderPath branches (dilation 1, 2, 4)
    each produce ``outplanes//4`` channels; the results are concatenated in
    the order shortcut, path1, path2, path3 and passed through ReLU.
    """
    def __init__(self, inplanes, outplanes):
        super().__init__()
        branch_planes = outplanes//4
        self.direct = nn.Conv2d(inplanes, branch_planes, kernel_size=1, padding=0,
                                stride=2, bias=False)
        self.directbn = nn.BatchNorm2d(branch_planes)
        # if we need to reduce we can do groupwise here with shuffle, worth a try
        self.path1 = EncoderPath(inplanes, branch_planes, 1)
        self.path2 = EncoderPath(inplanes, branch_planes, 2)
        self.path3 = EncoderPath(inplanes, branch_planes, 4)

    def forward(self, x):
        dilated = [branch(x) for branch in (self.path1, self.path2, self.path3)]
        shortcut = self.directbn(self.direct(x))
        return F.relu(torch.cat((shortcut, *dilated), 1))

class DecoderBlock(nn.Module):
    """Doubles the resolution with pixel shuffle, then refines with two 3x3 convs.

    ``planes`` is the number of input channels. Pixel shuffle with factor 2
    turns them into ``planes//4`` channels at twice the height and width,
    and both convolutions keep that channel count.
    """
    def __init__(self, planes):
        super().__init__()
        #self.upsample = nn.ConvTranspose2d(planes*4, planes*4, kernel_size=3, stride=2, padding=1)
        shuffled = planes//4  # channels left after pixel shuffle
        # it may be useful to shuffle before adding groups?
        conv_args = dict(kernel_size=3, padding=1, stride=1, bias=False, groups=1)
        self.conv1 = nn.Conv2d(shuffled, shuffled, **conv_args)
        self.bn1 = nn.BatchNorm2d(shuffled)
        self.conv2 = nn.Conv2d(shuffled, shuffled, **conv_args)
        self.bn2 = nn.BatchNorm2d(shuffled)

    def forward(self, x):
        upsampled = F.pixel_shuffle(x, 2)
        hidden = F.relu(self.bn1(self.conv1(upsampled)))
        # no activation here: the caller applies ReLU after concatenation
        return self.bn2(self.conv2(hidden))

class Encoder(nn.Module):
    """Three chained EncoderBlocks; each doubles the channels and halves the resolution.

    For ``planes`` input channels the outputs have 2*planes, 4*planes and
    8*planes channels respectively.
    """
    def __init__(self, planes):
        super().__init__()
        self.encoder1 = EncoderBlock(planes, planes*2)
        self.encoder2 = EncoderBlock(planes*2, planes*4)
        self.encoder3 = EncoderBlock(planes*4, planes*8)

    def forward(self,x):
        """Return the feature maps of all three stages, shallowest first."""
        features = []
        for stage in (self.encoder1, self.encoder2, self.encoder3):
            x = stage(x)
            features.append(x)
        return tuple(features)

class MinMaxScaler(nn.Module):
    """Rescales every channel of every sample to the range [0, 1].

    For an input of shape (N, C, ...) the minimum over the flattened trailing
    dimensions is subtracted per (sample, channel) and the result is divided
    by the remaining maximum. A constant channel maps to all zeros.
    """
    def __init__(self):
        # was misspelled `__init`, so it was never used as the constructor
        super(MinMaxScaler, self).__init__()

    def forward(self, x):
        shape = x.shape
        # reshape (rather than view) also accepts non-contiguous inputs
        flat = x.reshape(shape[0], shape[1], -1)
        flat = flat - flat.min(2, keepdim=True)[0]
        span = flat.max(2, keepdim=True)[0]
        # A constant channel has span 0; dividing by it produced NaN that then
        # propagated into the loss. Use 1 there, leaving the zeros unchanged.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return (flat / span).reshape(shape)


class Decoder(nn.Module):
    """Upsamples the deepest encoder features back to the input resolution.

    ``planes`` is the channel count of the deepest feature map. Each of the
    three DecoderBlocks doubles the resolution and divides the channels by 4;
    its output is concatenated with a 1x1-projected skip connection of equal
    width, so the channel count halves per stage. Two 3x3 convolutions and a
    1x1 convolution then produce 14 maps that are min-max scaled per channel.
    """
    def __init__(self, planes):
        super().__init__()
        half, quarter = planes//2, planes//4
        eighth, sixteenth = planes//8, planes//16
        pointwise = dict(kernel_size=1, padding=0, stride=1, bias=False)
        spatial = dict(kernel_size=3, padding=1, stride=1, bias=False)

        self.decoder1 = DecoderBlock(planes)   # planes in, planes//4 out
        self.decoder2 = DecoderBlock(half)     # planes//2 in, planes//8 out
        self.decoder3 = DecoderBlock(quarter)  # planes//4 in, planes//16 out
        # projections that halve the channels of each skip connection
        self.e2conv = nn.Conv2d(half, quarter, **pointwise)
        self.e2bn = nn.BatchNorm2d(quarter)
        self.e1conv = nn.Conv2d(quarter, eighth, **pointwise)
        self.e1bn = nn.BatchNorm2d(eighth)
        self.e0conv = nn.Conv2d(eighth, sixteenth, **pointwise)
        self.e0bn = nn.BatchNorm2d(sixteenth)

        self.conv1 = nn.Conv2d(eighth, eighth, **spatial)
        self.bn1 = nn.BatchNorm2d(eighth)

        self.conv2 = nn.Conv2d(eighth, eighth, **spatial)
        self.bn2 = nn.BatchNorm2d(eighth)

        self.conv3 = nn.Conv2d(eighth, 14, kernel_size=1, stride=1, bias=False)
        self.minmaxscaler = MinMaxScaler()

    def forward(self, *inputs):
        """Expects ``(x, e2, e1, e0)``: deepest features first, then the skips."""
        x, e2, e1, e0 = inputs

        stages = (
            (self.decoder1, self.e2conv, self.e2bn, e2),
            (self.decoder2, self.e1conv, self.e1bn, e1),
            (self.decoder3, self.e0conv, self.e0bn, e0),
        )
        d = x
        for upsample, project, norm, skip in stages:
            d = upsample(d)
            skip = norm(project(skip))
            d = F.relu(torch.cat((d, skip), 1))

        d = F.relu(self.bn1(self.conv1(d)))
        d = F.relu(self.bn2(self.conv2(d)))

        return self.minmaxscaler(self.conv3(d))


#implementation of the new resnet model
class XrayEncoderDecoder(Net):
    """Encoder-decoder network built from InitialBlock, Encoder and Decoder.

    With the default ``planes=16`` the stem outputs 64 channels, the encoder
    is built for 64 input channels and the decoder for 512.
    """
    def __init__(self,name="XrayEncoderDecoder", planes=16):
        super().__init__(name)
        self.prepLayer = InitialBlock(planes)
        self.encoder = Encoder(planes*4)
        self.decoder = Decoder(planes*32)

    def forward(self,x):
        stem = self.prepLayer(x)
        shallow, middle, deep = self.encoder(stem)
        # decoder consumes the deepest features first, then the skips
        return self.decoder(deep, middle, shallow, stem)

# Init weights

In [None]:
def init_weights(m):
    """Apply Kaiming-normal initialisation to the weight of convolution modules.

    Intended for ``model.apply(init_weights)``; other module types are left
    untouched.
    """
    # str(type(m)) looks like "<class 'torch.nn.modules.conv.Conv2d'>", so the
    # former startswith('torch.nn.modules.conv') test never matched and no
    # layer was ever re-initialised. Check the type instead.
    if isinstance(m, torch.nn.modules.conv._ConvNd):
        torch.nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')

model = XrayEncoderDecoder(name="XrayEncoderDecoder")
model.apply(init_weights)

# The training helpers read the target device from `model.device`.
use_cuda = torch.cuda.is_available()
model.device = torch.device("cuda" if use_cuda else "cpu")
model.to(model.device)
# Summarise with the image size that is actually fed to the network
# (the datasets are built with IMAGE_WIDTH) instead of an unrelated 384x384.
model.summary((1, IMAGE_WIDTH, IMAGE_WIDTH))

# Managing stats and backup

In [None]:
# COPIED from EVA4. Need to see if this would work in Kaggle context 
# Need to be cleaned up
# import standard PyTorch modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter # TensorBoard support

# import torchvision module to handle image manipulation
import torchvision
import torchvision.transforms as transforms

# calculate train time, writing train data to files etc.
import time
import pandas as pd
import json
import os

# import modules to build RunBuilder and RunManager helper classes
from collections  import OrderedDict
from collections import namedtuple
from itertools import product

# Helper class, help track loss, accuracy, epoch time, run time, 
# hyper-parameters etc. Also record to TensorBoard and write into csv, json
class RunManager():
    """Bookkeeping for training runs.

    Tracks batch/epoch counters, timings, per-batch losses and learning
    rates and per-epoch train/test loss, saves the network whenever the test
    loss improves, and writes the collected epoch records to CSV.
    """
    def __init__(self, savepath, channel_means, channel_stdevs, network):

        # epoch level tracking
        self.epoch_count = 0
        self.batch_count = 0
        self.min_val_loss = float("inf")  # best (lowest) test loss seen so far
        self.epoch_train_loss = 0
        self.epoch_test_loss = 0
        self.epoch_start_time = None
        self.batch_start_time = None
        self.savepath = savepath
        self.network = network

        # run level tracking: hyper-params used, collected results, timing
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None

        # loaders are attached in begin_run
        self.trainloader = None
        self.testloader = None
        self.batchlrs = []
        self.batchloss = []
        self.channel_means = channel_means
        self.channel_stdevs = channel_stdevs

    def begin_run(self, run, trainloader, testloader):
        """Start a run: remember its hyper-parameters and the two loaders."""
        self.run_start_time = time.time()

        self.run_params = run
        self.run_count += 1

        self.trainloader = trainloader
        self.testloader = testloader
        self.batchlrs = []
        # (an unused TensorBoard log-dir string was built here before)

    def end_run(self):
        """Reset the epoch and batch counters."""
        self.epoch_count = 0
        self.batch_count = 0

    def begin_batch(self):
        self.batch_start_time = time.time()

    def end_batch(self, lr):
        """Record the batch's learning rate and return its duration in seconds."""
        self.batch_count += 1
        self.batchlrs.append((lr, self.batch_count))
        return time.time() - self.batch_start_time

    def begin_epoch(self):
        """Start an epoch: bump the counter and clear the per-epoch accumulators."""
        self.epoch_start_time = time.time()
        self.epoch_count += 1
        self.epoch_train_loss = 0
        self.epoch_test_loss = 0
        self.batchlrs = []
        self.batchloss = []

    def end_epoch(self, lr):
        """Store this epoch's record in ``run_data`` and return it as a string."""
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        results = OrderedDict()
        results["run"] = self.run_count
        results["epoch"] = self.epoch_count
        results["train loss"] = self.get_train_loss()
        results["test loss"] = self.get_test_loss()
        results["epoch duration"] = epoch_duration
        results["run duration"] = run_duration

        # Record hyper-params into 'results'
        for k,v in self.run_params.items(): 
            results[k] = v
        self.run_data.append(results)

        return f'{results}'

    def track_train_loss(self, loss):
        """Accumulate a training batch loss into the epoch total."""
        # multiply batch size so variety of batch sizes can be compared
        self.batchloss.append((loss.item(), self.batch_count+1))
        self.epoch_train_loss += loss.item() * self.trainloader.batch_size

    def track_test_loss(self, loss):
        """Accumulate a test batch loss into the epoch total."""
        self.epoch_test_loss += loss.item() * self.testloader.batch_size

    def get_test_loss(self):
        return self.epoch_test_loss / len(self.testloader.dataset)

    def get_train_loss(self):
        return self.epoch_train_loss / len(self.trainloader.dataset)

    def savebest(self, fileName):
        """Save the network to ``<savepath>/<fileName>.pt`` if the test loss improved."""
        if self.epoch_test_loss < self.min_val_loss:
            # Remember the new best. Without this update every epoch compared
            # against the initial value and overwrote the "best" checkpoint.
            self.min_val_loss = self.epoch_test_loss
            f = os.path.join(self.savepath, f'{fileName}.pt')
            torch.save(self.network.state_dict(), f)

    def save(self, fileName):
        """Write all collected epoch records to ``<savepath>/<fileName>.csv``."""
        f = os.path.join(self.savepath, f'{fileName}.csv')
        pd.DataFrame(self.run_data).to_csv(f)

In [None]:
import os
outdir = '/kaggle/working/output'
# creates the directory when missing and is a no-op when it already exists
os.makedirs(outdir, exist_ok=True)

# Resume from a previously saved checkpoint of this model, if there is one.
savedmodel = os.path.join(outdir, model.name+'.pt')

if os.path.exists(savedmodel):
    # map_location lets a checkpoint written on a GPU load in a CPU-only session
    model.load_state_dict(torch.load(savedmodel, map_location=model.device))

# Notes kept from the original cell (they were no-op string statements):
#   classification=False,
#   classes={'Aortic enlargement':1, 'Atelectasis':2, 'Calcification': 3,
#            'Cardiomegaly': 4, 'Consolidation': 5, 'ILD': 6, 'Infiltration':7,
#            'Lung Opacity': 8, 'Nodule/Mass': 9, 'Other lesion': 10,
#            'Pleural effusion': 11, 'Pleural thickening': 12, 'Pneumothorax': 13,
#            'Pulmonary fibrosis': 14
#           }

m = RunManager(outdir, channel_mean, channel_stdev, model)



## LOSS

We need to try MSE Loss only separately. 

In [None]:
import torch
import torch.nn.functional as F
from math import exp
import numpy as np
import torch.nn as nn

class WeightedLoss(nn.Module):
    """Sums a per-(image, class) loss weighted by class.

    Args:
        lossfn: callable ``lossfn(source, target)`` applied to single-image,
            single-channel slices of shape (1, 1, H, W).
        weights: sequence with one weight per class followed by one final
            weight that is used for every channel of a normal image, i.e. an
            image whose target masks are all zero.
    """
    def __init__(self, lossfn, weights):
        super(WeightedLoss, self).__init__()
        self.lossfn = lossfn
        self.weights = weights

    def forward(self, source, target):
        """
        We get one batch.
        In this batch some images will be normal and some may have pathology,
        so whether the "normal" weight applies is decided per image.

        Returns the weighted sum (not the mean) over all images and classes.
        """
        loss = torch.zeros((), dtype=torch.float, device=target.device)
        num_classes = len(self.weights) - 1
        for j in range(source.shape[0]): # for each image j in the batch
            # Decide per image. Testing the sum of the whole batch only
            # selected the normal weight when every image in it was normal.
            is_normal = bool(torch.sum(target[j]) == 0)
            for i in range(num_classes): # for each class i in the image
                l = self.lossfn(source[j:j+1,i:i+1,:,:], target[j:j+1,i:i+1,:,:])
                w = self.weights[-1] if is_normal else self.weights[i]
                loss = loss + w*l

        return loss

class MixedLoss(nn.Module):
    """Convex blend of two losses: ``alpha*loss1 + (1 - alpha)*loss2``."""

    def __init__(self, loss1, loss2, alpha):
        super().__init__()
        self.loss1, self.loss2 = loss1, loss2
        self.alpha = alpha

    def forward(self, source, target):
        first = self.loss1(source, target)
        second = self.loss2(source, target)
        return self.alpha*first + (1 - self.alpha)*second

# below code is adapted from
# https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py
# only change is to make MSSSIM as a loss or difference measure instead of similarity measure

def gaussian(window_size, sigma):
    """Return a 1-D Gaussian of length ``window_size`` normalised to sum to 1.

    The curve is centred on index ``window_size//2``.
    """
    centre = window_size//2
    denominator = float(2*sigma**2)
    samples = []
    for position in range(window_size):
        samples.append(exp(-(position - centre)**2/denominator))
    gauss = torch.Tensor(samples)
    return gauss/gauss.sum()


def create_window(window_size, channel=1):
    """Build a (channel, 1, window_size, window_size) Gaussian window (sigma 1.5).

    The 2-D kernel is the outer product of the 1-D Gaussian with itself and
    is repeated once per channel.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel = column.mm(column.t()).float()
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    return kernel.expand(channel, 1, window_size, window_size).contiguous()


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Structural similarity between two 4-D image batches.

    Args:
        img1, img2: tensors of shape (N, C, H, W).
        window_size: side of the Gaussian window built when ``window`` is None
            (capped to the image height and width).
        window: optional precomputed window used for the grouped convolutions.
        size_average: return the mean over everything (True) or one value per
            batch element (False).
        full: when True also return the contrast-sensitivity term.
        val_range: dynamic range of the inputs; guessed from ``img1`` when None.
    """
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is not None:
        L = val_range
    else:
        upper = 255 if torch.max(img1) > 128 else 1
        lower = -1 if torch.min(img1) < -0.5 else 0
        L = upper - lower

    channel, height, width = img1.size()[1:]
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channel=channel).to(img1.device)

    def local_mean(tensor):
        # windowed average, computed per channel without padding
        return F.conv2d(tensor, window, padding=0, groups=channel)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)

    return (ret, cs) if full else ret


def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=True):
    """Multi-scale structural similarity over five dyadic scales.

    At every scale the SSIM and contrast-sensitivity terms are computed, then
    both images are down-sampled by 2x2 average pooling. The result combines
    the contrast-sensitivity terms of the first four scales with the SSIM of
    the last scale, each raised to its scale weight.

    Args:
        img1, img2: tensors of shape (N, C, H, W).
        window_size, size_average, val_range: forwarded to ``ssim``.
        normalize: map both terms from [-1, 1] to [0, 1] before the powers are
            taken, to avoid NaNs (not part of the original definition).
    """
    device = img1.device
    weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
    levels = weights.size()[0]
    mssim = []
    mcs = []
    for _ in range(levels):
        sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
        mssim.append(sim)
        mcs.append(cs)

        img1 = F.avg_pool2d(img1, (2, 2))
        img2 = F.avg_pool2d(img2, (2, 2))

    mssim = torch.stack(mssim)
    mcs = torch.stack(mcs)

    # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
    if normalize:
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

    pow1 = mcs ** weights
    pow2 = mssim ** weights
    # The coarsest-scale SSIM term enters the product exactly once. The former
    # torch.prod(pow1[:-1] * pow2[-1]) multiplied it into each of the four
    # finer-scale factors, i.e. raised it to the fourth power.
    output = torch.prod(pow1[:-1]) * pow2[-1]
    return output


class MSSSIM(torch.nn.Module):
    """MS-SSIM turned into a loss: ``clamp((1 - msssim) / 2, 0, 1)``.

    Args:
        window_size: Gaussian window size forwarded to ``msssim``.
        size_average: forwarded to ``msssim``.
        channel: stored on the instance; not used by ``forward``.
    """
    def __init__(self, window_size=11, size_average=True, channel=3):
        super(MSSSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = channel

    def forward(self, img1, img2):
        # TODO: store window between calls if possible
        similarity = msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
        # Turn the similarity into a loss in [0, 1]. msssim already returns a
        # scalar, so the former `size_average == 'mean'` branch (a bool/string
        # comparison that was never true) has been dropped.
        return torch.clamp((1 - similarity)*0.5, min=0, max=1)

In [None]:
L1lambda = 0
L2lambda = 1e-4
# Base loss applied to every (image, class) mask: 0.16 * L1 + 0.84 * MS-SSIM loss.
lossfn = MixedLoss(nn.L1Loss(), MSSSIM(21), 0.16)

# TODO: We have 14 classes that are imbalanced; find the imbalance in the
# dataset and weight each of the 14 accordingly.
# Weights are 1 - count/15000. WeightedLoss uses the last entry as the weight
# for normal images. The counts are created as floats so the division is a
# true division on every PyTorch version (integer tensors used to floor-divide).
class_counts = torch.tensor([3067, 186, 452, 2300, 353, 38, 613, 1322, 826, 1134, 1032, 1981, 96, 1617, 10606],
                            dtype=torch.float)
imbalance_weights = (1 - class_counts/15000).float()
imbalance_weights = imbalance_weights.to(device = ("cuda" if use_cuda else "cpu"))

# we can replace the weighted loss with MSE loss plainly and see.
criterion = WeightedLoss(lossfn, imbalance_weights) # nn.MSELoss()

optimizer = optim.SGD(model.parameters(), lr=1e-2) #optim.Adam(model.parameters(), lr=1e-5, weight_decay=L2lambda) # 
#lr_finder = LRRangeFinder(model, optimizer, criterion, 1, 1e-2, 100, train_loader)
#lr_finder.range_test()
#lr_finder = LRRangeFinder(model, optimizer, criterion, 1, 1e-2, 100, train_loader)
#lr_finder.range_test()

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
#https://research.nvidia.com/sites/default/files/pubs/2017-03_Loss-Functions-for/NN_ImgProc.pdf
L1lambda = 0
L2lambda = 0
EPOCHS = 4

max_lr = 0.3
optimizer = optim.SGD(model.parameters(), lr=max_lr/100, momentum=0.9, nesterov=True, weight_decay=L2lambda)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=len(train_loader), div_factor=5, pct_start=0.2, epochs=EPOCHS)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=int(len(train)/batch_size)+1, epochs=EPOCHS,  pct_start=5/24, div_factor=600, final_div_factor=1 )
#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=int(len(train_loader))+1, epochs=EPOCHS,  pct_start=0.2, div_factor=10, final_div_factor=10 )
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr = 0.003, max_lr=max_lr, mode='triangular2')

m = RunManager(outdir, channel_mean, channel_stdev, model)

# The scheduler is stepped per batch (batch_scheduler=True).
model.gotrain(optimizer, train_loader, test_loader, EPOCHS, m, {'lr':'cyclic', 'loss':'MSSSIM+L1'}, 
              criterion, scheduler, True, L1lambda)

model.eval()

with torch.no_grad():
    # Python 3 iterators are advanced with next(); .next() is not part of the
    # iterator protocol and is missing from current DataLoader iterators.
    images, target = next(iter(test_loader))
    images, target = images.to(model.device), target.to(model.device)
    output = model(images)

# compare total mask mass of target and prediction for every image in the batch
for i in range(output.shape[0]):
    print(torch.sum(target[i]), torch.sum(output[i]))

import matplotlib.pyplot as plt
c = 3
im = Image.fromarray(((images[c][0].cpu().numpy()*channel_stdev + channel_mean)*255).astype('uint8'))
# one predicted map per class for sample c
for x in range(14):
    plt.imshow(output[c][x].cpu().numpy())
    plt.title(f"predicted mask, class {x}")
    plt.show()

torch.cuda.empty_cache()
# kept as the last expression so the notebook actually renders the input image
im