# In [None]:
from torch.utils.data import Dataset
import pandas as pd
import torch
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
import os
from skimage.io import imread, imsave
from skimage.color import rgb2gray, rgba2rgb, gray2rgb
import pytorch_lightning as pl


# TODO: implement this function
def load_images(**kwargs):
    """
    Load new training images from some source and save them to a folder.

    This is a placeholder that still has to be implemented.

    NOTE: our database has patients whose labels were approved by doctors; only such images
    should be loaded to fine-tune the model (so you can use them if you want, but here we
    implement self-supervised learning).

    Args:
        **kwargs: free-form options for the future implementation (currently unused).

    Raises:
        NotImplementedError: always, until the function is implemented.
    """
    # NotImplemented is a comparison sentinel, not an exception class; calling it
    # raises TypeError, so the proper exception type is used here.
    raise NotImplementedError('Load images and labels is not implemented')
    
# TODO: implement this function
def load_model(path_to_ckpt):
    """
    Load the pre-trained model.

    This is a placeholder that still has to be implemented.

    Args:
        path_to_ckpt: path to the checkpoint to load the model from.

    Raises:
        NotImplementedError: always, until the function is implemented.
    """
    # NotImplemented is a comparison sentinel, not an exception class; calling it
    # raises TypeError, so the proper exception type is used here.
    raise NotImplementedError('Load model is not implemented')
    
class MyDataset(Dataset):
    """
    Dataset that supplies the training procedure with images and multi-label targets.
    """
    def __init__(self, path_to_csv, path_to_root, labels = None, aug = None, tr = None):
        """
        Read the csv index and prepare one label vector per row.

        Args:
            path_to_csv: string with path to csv file to be read by pd.read_csv();
                the 'Image Index' and 'Finding Labels' columns are used
            path_to_root: string with path to folder where loaded images are located
            aug and tr: albumentations transforms, called as ``t(image=img)['image']``;
                ``aug`` is applied before ``tr``
            labels: array of predicted labels with one row per csv row;
                if None, labels are derived from the 'Finding Labels' csv column

        Raises:
            ValueError: if ``labels`` is given but its length differs from the csv.
        """
        super().__init__()

        self.path_to_root = path_to_root
        self.csv = pd.read_csv(path_to_csv)
        self.pathologies = ['Atelectasis',
                             'Cardiomegaly',
                             'Consolidation',
                             'Edema',
                             'Effusion',
                             'Emphysema',
                             'Fibrosis',
                             'Hernia',
                             'Infiltration',
                             'Mass',
                             'Nodule',
                             'Pleural_Thickening',
                             'Pneumonia',
                             'Pneumothorax']

        self.aug = aug
        self.tr = tr

        # `labels is not None` instead of truthiness: a numpy array has no
        # unambiguous truth value and would raise in `if labels:`.
        if labels is not None:
            labels = np.asarray(labels, dtype=np.float32)
            if len(labels) != len(self.csv):
                raise ValueError(
                    f"labels has {len(labels)} rows but the csv has {len(self.csv)} rows")
            self.labels = labels
        else:
            findings = self.csv['Finding Labels']
            # One boolean column per pathology: True where the name occurs in the row.
            # regex=False -> plain substring match; na=False -> missing cells count as absent.
            columns = [findings.str.contains(pathology, regex=False, na=False).values
                       for pathology in self.pathologies]
            # Stack into (rows, pathologies) so that self.labels[idx] is one sample's vector.
            self.labels = np.stack(columns, axis=1).astype(np.float32)

    def __getitem__(self, idx):
        """
        Load the image of row ``idx`` and return ``{"image": ..., "label": ...}``.
        """
        item = self.csv.iloc[idx]

        img = imread(os.path.join(self.path_to_root,item['Image Index']))
        # Collapse colour images to a single channel; RGBA is converted to RGB first.
        # NOTE(review): rgb2gray returns floats while 2-D inputs keep their original
        # dtype — confirm the transforms handle both value ranges.
        if img.ndim==3:
            if img.shape[2]==4:
                img = rgb2gray(rgba2rgb(img))
            else:
                img = rgb2gray(img)
        # Add a trailing channel axis.
        img = np.expand_dims(img,-1)
        label = self.labels[idx]

        if self.aug:
            img = self.aug(image=img)['image']
        if self.tr:
            img = self.tr(image=img)['image']

        sample = {"image": img, 'label': label}

        return sample

    def __len__(self):
        """Return the number of rows in the csv index."""
        return len(self.csv)
 
class BigBroNet(pl.LightningModule):
    """
    Wrapper for metrics logging and training.

    Wraps ``model`` and adds the optimizer, the loss and the per-step /
    per-epoch hooks. Batches are mappings with an 'image' and a 'label' entry.
    """
    def __init__(self, model):
        """
        Args:
            model: the network to train; it is called on ``batch['image']``.
        """
        super().__init__()
        self.model = model
        # BCELoss expects probabilities in [0, 1], so the wrapped model is assumed
        # to already apply a sigmoid — TODO confirm.
        self.criterion = torch.nn.BCELoss()

    def configure_optimizers(self):
        """Return the Adam optimizer (lr=1e-5) and no schedulers."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=0.00001)

        return [optimizer], []

    def forward(self,x):
        """
        Run the wrapped model on a batch.

        Args:
            x: batch mapping with 'image' and 'label' entries.

        Returns:
            dict with the prediction under "y_pred" and the batch label under "y_true".
        """
        pred = self.model(x['image'])
        return {
            "y_pred": pred,
            "y_true": x['label']
        }

    def training_step(self, batch, batch_idx):
        """Compute the training loss and return it with the targets and predictions."""
        out = self._step(batch, batch_idx, postfix="/train")

        # Only the loss has to carry the autograd graph; the tensors kept for the
        # epoch-end metrics are detached so stored outputs do not hold on to it.
        return {
            "loss": out["loss"],
            "y_true": out['y_true'].detach(),
            'y_pred': out['y_pred'].detach()
        }

    def training_epoch_end(self, outputs):
        """Aggregate the training-step outputs into logged metrics."""
        out = self._epoch_end(outputs, postfix="/train_full")
        return {"log": out}

    def validation_step(self, batch, batch_idx):
        """Compute loss, targets and predictions for one validation batch."""
        out = self._step(batch, batch_idx, postfix="/val")
        return out

    def validation_epoch_end(self, outputs):
        """Aggregate the validation-step outputs into logged metrics."""
        out = self._epoch_end(outputs, postfix="/val")
        return {"log": out}

    def test_step(self, batch, batch_idx):
        """Compute loss, targets and predictions for one test batch."""
        out = self._step(batch, batch_idx, postfix="/test")
        return out

    def test_epoch_end(self, outputs):
        """Aggregate the test-step outputs into logged metrics."""
        out = self._epoch_end(outputs, postfix="/test")
        return {'log': out}

    def test_end(self, outputs):
        """Older hook name, kept so existing callers keep working; same as test_epoch_end."""
        return self.test_epoch_end(outputs)

    def _step(self, batch, batch_idx, postfix = ""):
        """
        Shared step: forward pass plus loss.

        Returns:
            dict with "loss", "y_pred" and "y_true". ``batch_idx`` and ``postfix``
            are accepted for a uniform signature but not used.
        """
        res = self(batch)
        loss = self.loss(res['y_true'], res['y_pred'])

        return {
            "loss": loss,
            **res
        }

    def _epoch_end(self, outputs, postfix=""):
        """
        Shared epoch-end aggregation.

        Args:
            outputs: list of per-step dicts with "loss", "y_true" and "y_pred".
            postfix: suffix appended to every metric name.

        Returns:
            dict with the mean of the per-step losses, the metrics returned by
            classification_metrics_vector and the loss over the whole epoch.
        """
        res = {}

        res.update( {f"avg_loss{postfix}":torch.stack([x["loss"] for x in outputs]).mean()} )

        # Concatenate the per-batch tensors along the first axis directly on the
        # CPU instead of round-tripping every row through numpy and Python lists.
        y_true = torch.cat([batch['y_true'].detach().cpu() for batch in outputs])
        y_pred = torch.cat([batch['y_pred'].detach().cpu() for batch in outputs])

        res.update(classification_metrics_vector(y_true, y_pred, postfix))
        res.update({f"loss{postfix}_full":self.loss(y_true, y_pred)})

        return res

    def loss(self, y_true, y_pred):
        """
        Binary cross-entropy between predictions and targets.

        Note the argument order: targets first, predictions second.
        """
        return self.criterion(y_pred.float(),y_true.float()).float()
    

    
from albumentations.core.composition import *
from albumentations.pytorch import ToTensor, ToTensorV2
from albumentations.augmentations.transforms import *
    
# Preprocessing applied to every sample: normalise, resize to 224x224, convert to tensor.
# NOTE(review): Normalize receives 3-channel statistics while MyDataset yields
# single-channel images — confirm this combination is intended.
tr = Compose(
    [
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        Resize(224, 224),
        ToTensorV2(),
    ]
)

# Random augmentations for training samples: noise or blur, horizontal flip,
# an always-applied shift/scale/rotate, and a brightness/contrast or gamma change.
aug = Compose(
    [
        OneOf([GaussNoise(), Blur(blur_limit=7, p=0.1)], p=0.7),
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(
            shift_limit=0.2,
            scale_limit=0.15,
            rotate_limit=20,
            interpolation=1,
            p=1,
            border_mode=4,
        ),
        OneOf(
            [
                RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3),
                RandomGamma(),
            ],
            p=0.8,
        ),
    ]
)


# Validation data: labels are read from the csv, only the preprocessing transform is applied.
# TODO: change path_to_csv and path_to_root
dataset = MyDataset(path_to_csv, path_to_root, tr = tr)

# TODO: change path
model = load_model(path_to_ckpt)

# TODO: write predictor that will save all predictions to numpy array so that later we can pass them as labels
# predicted_labels = predict_labels(model, images)

# Wrap the network in the Lightning module defined above (the class is BigBroNet).
# Both names are bound: `model` was rebound here originally and `lig_model` is
# the name passed to trainer.fit below.
model = lig_model = BigBroNet(model)

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the best weights according to the full-epoch validation loss
# ("loss/val_full" is the key produced by BigBroNet._epoch_end with postfix "/val").
checkpoint = ModelCheckpoint(filepath="./", mode="min", monitor="loss/val_full", save_weights_only=True, save_top_k=1)
trainer = Trainer(checkpoint_callback = checkpoint,
                  gpus=0,
                  num_nodes=1,
                  max_epochs=10)


# dataloader params
params = dict(batch_size=32, num_workers = 6, pin_memory = True, shuffle=True)
# Validation uses the same settings but without shuffling.
val_params = dict(params, shuffle=False)
# training
# TODO: change path_to_csv and path_to_root
# TODO: Finish all TODOs above so that predicted_labels will exists
train_dataset = MyDataset(path_to_csv, path_to_root, labels=predicted_labels, tr = tr, aug = aug)
trainer.fit(lig_model,
            train_dataloader = DataLoader(train_dataset, **params),
            val_dataloaders = DataLoader(dataset, **val_params))


"""
CONGRATS! YOU HAVE TUNED THE MODEL! NOW REPLACE THE ONE ON PRODUCTION WITH IT AND THAT IS IT!
"""