In [None]:
%load_ext autoreload
%autoreload 2

In [1]:
from typing import Dict
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ExponentialLR
import torchio as tio

from ml.models.ResUnet import ResUNet
from ml.models.unet_deepsup import Unet_MSS

from ml.models.building_blocks import VG_discriminator
from ml.extra_libraries.CycleGAN_losses import (CycleLoss, ReconstructionLoss, SegmentationLoss,
                                       DiscriminatorLoss, GeneratorLoss)
from ml.losses import IOU_Metric

In [2]:
from ml.tio_dataset import TioDataset
train_settings  = {
    "patch_shape" : (64, 64, 32),
    "patches_per_volume" : 32,
    "patches_queue_length" : 512,
    "batch_size" : 2,
    "num_workers": 4,
    "sampler": "weighted" #"uniform",#
}

# val_settings  = {
#     "patch_shape" : (32, 32, 32),
#     "patches_per_volume" : 32,
#     "patches_queue_length" : 1440,
#     "batch_size" : 8,
#     "num_workers": 4,
#     "sampler": "uniform",#"weighted" #"uniform",#
# }

test_settings  = {
    "patch_shape" : (192, 192, 128),
    "overlap_shape" : (24, 24, 16),
    "batch_size" : 1,
    "num_workers": 4,
}

data_dir = "/home/msst/Documents/medtech/MainData"
dataset = TioDataset(data_dir,
                 train_settings=train_settings,
                 val_settings=None,
                 test_settings=test_settings)

In [3]:
class VanGan(nn.Module):
    def __init__(self, modules):
        super(VanGan, self).__init__()
        self.gen_IS = modules['gen_IS']
        self.gen_SI = modules['gen_SI']
        self.disc_I = modules['disc_I']
        self.disc_S = modules['disc_S']

In [4]:
modules = {
    'gen_IS': ResUNet(channels_coef=8),
    'gen_SI': ResUNet(channels_coef=8),
    'disc_I': VG_discriminator(channels_coef=64),
    'disc_S': VG_discriminator(channels_coef=64),
}
model = VanGan(modules)

In [5]:
def check_None(tensor):
    if torch.isnan(tensor).sum() > 0:
        raise RuntimeError("None here")

In [6]:
class VG_Controller:
    def __init__(self, config: Dict):
        self.config = config
        self.device = config['device']
        self.model = config["model"]
        
        otimizers_settings = config["otimizers_settings"]
        self.gen_IS_opt = otimizers_settings['gen_IS_opt'](self.model.gen_IS)
        self.gen_SI_opt = otimizers_settings['gen_SI_opt'](self.model.gen_SI)
        self.disc_I_opt = otimizers_settings['disc_I_opt'](self.model.disc_I)
        self.disc_S_opt = otimizers_settings['disc_S_opt'](self.model.disc_S)
        
        if config.get('sheduler_fn') is not None:
            self.with_sheduler = True
            self.gen_IS_sheduler = otimizers_settings['sheduler_fn'](self.gen_IS_opt)
            self.gen_SI_sheduler = otimizers_settings['sheduler_fn'](self.gen_SI_opt)
            self.disc_I_sheduler = otimizers_settings['sheduler_fn'](self.disc_I_opt)
            self.disc_S_sheduler = otimizers_settings['sheduler_fn'](self.disc_S_opt)
        else:
            self.with_sheduler = False
        
        losses = config["losses"]
        self.cycle_loss_fn = losses["cycle_loss_fn"]
        self.reconstruction_loss_fn = losses["reconstruction_loss_fn"]
        self.segmentation_loss_fn = losses["segmentation_loss_fn"]
        self.discriminator_loss_fn = losses["discriminator_loss_fn"]
        self.generator_loss_fn = losses["generator_loss_fn"]
        self.cycle_lambda = losses["cycle_lambda"]
        self.identity_lambda = losses["identity_lambda"]
        
        self.epoch = 0
        self.history = None
        
        self.metric_fn = IOU_Metric()
        
        

    def fit(self, dataset, n_epochs):
        model = self.model.to(self.device)
        if self.history is None:
            self.history = {
                'train': [],
                'val': [],
                "test": [],
            }
        
        start_epoch = self.epoch
        for epoch in range(start_epoch, start_epoch+n_epochs):
            self.epoch += 1
            print(f"Epoch {epoch + 1}/{start_epoch+n_epochs}")
            
            train_info = self.train_epoch(dataset.train_dataloader)
            print(train_info)
            self.history['train'].append(train_info)
            
            if dataset.test_dataloader is not None:
                test_info = self.test_epoch(dataset.test_dataloader)
                print(test_info)
                self.history['test'].append(test_info)
            
            if self.with_sheduler:
                self.gen_IS_sheduler.step()
                self.gen_SI_sheduler.step()
                self.disc_I_sheduler.step()
                self.disc_S_sheduler.step()
            
        return self.model.eval()

    
    def train_epoch(self, train_dataloader):
        self.model.train()
        
        gen_IS_losses = []
        gen_SI_losses = []
        disc_I_losses = []
        disc_S_losses = []
        segmentation_losses = []
        reconstruction_losses = []
        
        for patches_batch in tqdm(train_dataloader):
            real_I = patches_batch['head']['data'].float().to(self.device)  
            real_S = patches_batch['vessels']['data'].float().to(self.device) 
            
            check_None(real_I)
            check_None(real_S)
            
            #Generator outputs
            fake_S = self.model.gen_IS(real_I)
            fake_I = self.model.gen_SI(real_S)
            cycled_S = self.model.gen_IS(fake_I)
            cycled_I = self.model.gen_SI(fake_S)

            # Discriminator outputs         
            disc_real_S = self.model.disc_S(real_S)
            disc_fake_S = self.model.disc_S(fake_S)
            disc_real_I = self.model.disc_I(real_I)
            disc_fake_I = self.model.disc_I(fake_I)
            
            check_None(fake_S)
            check_None(fake_I)
            check_None(cycled_S)
            check_None(cycled_I)
            check_None(disc_real_S)
            check_None(disc_fake_S)
            check_None(disc_real_I)
            check_None(disc_fake_I)
            
            
            #Losses
            cycle_loss_I = self.cycle_loss_fn(real_S, cycled_S)
            cycle_loss_S = self.cycle_loss_fn(real_I, cycled_I)
            
            segmentation_loss = self.segmentation_loss_fn(real_S, cycled_S)
            reconstruction_loss = self.reconstruction_loss_fn(real_I, cycled_I)

            gen_IS_loss = self.generator_loss_fn(disc_fake_S)
            gen_SI_loss = self.generator_loss_fn(disc_fake_I)

            total_loss_I = gen_IS_loss + self.cycle_lambda * cycle_loss_I +\
                           self.identity_lambda * segmentation_loss
            total_loss_S = gen_SI_loss + self.cycle_lambda * cycle_loss_S +\
                           self.identity_lambda * reconstruction_loss  # + id_IS_loss
            
            
            # -----------------
            # Generators
            # -----------------
            self.gen_IS_opt.zero_grad()
            self.gen_SI_opt.zero_grad()
            
            total_loss_I.backward(retain_graph=True)
            total_loss_S.backward(retain_graph=True)
            
            self.gen_IS_opt.step()
            self.gen_SI_opt.step()
            
            # -----------------
            # Discriminators
            # -----------------
            
            self.disc_I_opt.zero_grad()
            self.disc_S_opt.zero_grad()
            
            disc_I_loss = self.discriminator_loss_fn(disc_real_I, disc_fake_I)
            disc_S_loss = self.discriminator_loss_fn(disc_real_S, disc_fake_S)
            
            disc_I_loss.backward(retain_graph=True)
            disc_S_loss.backward()
            
            self.disc_I_opt.step()
            self.disc_S_opt.step()
            
            gen_IS_losses.append(gen_IS_loss.item())
            gen_SI_losses.append(gen_SI_loss.item())
            disc_I_losses.append(disc_I_loss.item())
            disc_S_losses.append(disc_S_loss.item())
            segmentation_losses.append(segmentation_loss.item())
            reconstruction_losses.append(reconstruction_loss.item())
        
        self.model.eval()
        out = {'gen_IS_loss': sum(gen_IS_losses)/len(gen_IS_losses),
                'gen_SI_loss': sum(gen_SI_losses)/len(gen_SI_losses),
                'disc_I_loss': sum(disc_I_losses)/len(disc_I_losses),
                'disc_S_loss': sum(disc_S_losses)/len(disc_S_losses),
                'segmentation_loss': sum(segmentation_losses)/len(segmentation_losses),
                'reconstruction_loss': sum(reconstruction_losses)/len(reconstruction_losses),
               }
        return out
    
    
    def test_epoch(self, test_dataloader):
        self.model.eval()
        metrics = []
        for batch in tqdm(test_dataloader):
            patch_loader = batch["patch_loader"]
            grid_aggregator = batch["grid_aggregator"]
            GT = batch["GT"]
            sample_name = batch["sample_name"]
            head_seg = self.fast_predict(patch_loader, grid_aggregator)
            metric = self.metric_fn(GT.data, head_seg)
            metrics.append({"sample" : sample_name,
                            "seg_sum/GT_sum" : head_seg.sum()/GT.data.sum()+0.000001,
                            "metric1" : metric})
            
        return {'metrics': metrics}
    
    def fast_predict(self, patch_loader, grid_aggregator, thresh=0.5):
        for patches_batch in patch_loader:
            patch_locations = patches_batch[tio.LOCATION]
            head_patches = patches_batch['head']['data'].to(self.device)
            with torch.no_grad():
                patch_seg = self.model.gen_IS(head_patches)
                grid_aggregator.add_batch(patch_seg.cpu(), patch_locations)
        seg = grid_aggregator.get_output_tensor()
        seg[seg<thresh]=0
        seg[seg>0]=1
        return(seg)
    

In [7]:
config = {"device": "cuda",
          "otimizers_settings":{
            "gen_IS_opt" : lambda model: torch.optim.Adam(model.parameters(), lr=5e-2, betas=(0.5, 0.9)),
            "gen_SI_opt" : lambda model: torch.optim.Adam(model.parameters(), lr=5e-2, betas=(0.5, 0.9)),
            "disc_I_opt" : lambda model: torch.optim.Adam(model.parameters(), lr=2e-3, betas=(0.5, 0.9)),
            "disc_S_opt" : lambda model: torch.optim.Adam(model.parameters(), lr=2e-3, betas=(0.5, 0.9)),
            "sheduler_fn": lambda optimizer: ExponentialLR(optimizer, 0.98)
            },
          "model": model,
          "losses":{
            "cycle_loss_fn": CycleLoss(),
            "reconstruction_loss_fn": ReconstructionLoss(),
            "segmentation_loss_fn": SegmentationLoss(),
            "discriminator_loss_fn": DiscriminatorLoss(),
            "generator_loss_fn": GeneratorLoss(),
            "cycle_lambda" : 10, 
            "identity_lambda" : 5,
            }
          }
vg_controller = VG_Controller(config)

In [8]:
vg_controller.fit(dataset, 50)

Epoch 1/50


100%|███████████████████████████████████████████| 64/64 [00:36<00:00,  1.75it/s]


{'gen_IS_loss': 0.2923352918587625, 'gen_SI_loss': 0.2953451885841787, 'disc_I_loss': 0.13579580781515688, 'disc_S_loss': 0.21238530334085226, 'segmentation_loss': 0.015306337736546993, 'reconstruction_loss': 0.03378802476237297}


100%|█████████████████████████████████████████████| 2/2 [00:24<00:00, 12.15s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(194.4020), 'metric1': tensor([0.0051])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(210.8595), 'metric1': tensor([0.0047])}]}
Epoch 2/50


100%|███████████████████████████████████████████| 64/64 [00:38<00:00,  1.67it/s]


{'gen_IS_loss': 0.3649800890125334, 'gen_SI_loss': 0.5539629943668842, 'disc_I_loss': 0.04123399007949047, 'disc_S_loss': 0.14607500151032582, 'segmentation_loss': 0.0, 'reconstruction_loss': 0.017128313956163765}


100%|█████████████████████████████████████████████| 2/2 [00:23<00:00, 11.59s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(194.4020), 'metric1': tensor([0.0051])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(210.8595), 'metric1': tensor([0.0047])}]}
Epoch 3/50


100%|███████████████████████████████████████████| 64/64 [00:37<00:00,  1.70it/s]


{'gen_IS_loss': 0.5056046452373266, 'gen_SI_loss': 0.8282253611832857, 'disc_I_loss': 0.007073982458678074, 'disc_S_loss': 0.13542151456931606, 'segmentation_loss': 0.0, 'reconstruction_loss': 0.02023580162460803}


100%|█████████████████████████████████████████████| 2/2 [00:23<00:00, 11.76s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(194.4020), 'metric1': tensor([0.0051])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(210.8595), 'metric1': tensor([0.0047])}]}
Epoch 4/50


100%|███████████████████████████████████████████| 64/64 [00:37<00:00,  1.71it/s]


{'gen_IS_loss': 0.512560420203954, 'gen_SI_loss': 0.905830075033009, 'disc_I_loss': 0.0018229260113002965, 'disc_S_loss': 0.14924840751336887, 'segmentation_loss': 0.0, 'reconstruction_loss': 0.025686402682225662}


100%|█████████████████████████████████████████████| 2/2 [00:23<00:00, 11.72s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(194.4020), 'metric1': tensor([0.0051])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(210.8595), 'metric1': tensor([0.0047])}]}
Epoch 5/50


100%|███████████████████████████████████████████| 64/64 [00:37<00:00,  1.69it/s]


{'gen_IS_loss': 0.5143529488705099, 'gen_SI_loss': 0.934304446913302, 'disc_I_loss': 0.0014057661892366013, 'disc_S_loss': 0.14690564398188144, 'segmentation_loss': 0.0, 'reconstruction_loss': 0.024561628453056983}


100%|█████████████████████████████████████████████| 2/2 [00:23<00:00, 11.63s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(194.4020), 'metric1': tensor([0.0051])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(210.8595), 'metric1': tensor([0.0047])}]}
Epoch 6/50


100%|███████████████████████████████████████████| 64/64 [00:37<00:00,  1.69it/s]


{'gen_IS_loss': 0.49681749008595943, 'gen_SI_loss': 0.9650933062657714, 'disc_I_loss': 0.00023165656375567778, 'disc_S_loss': 0.13772882346529514, 'segmentation_loss': 0.0, 'reconstruction_loss': 0.016847658077779215}


100%|█████████████████████████████████████████████| 2/2 [00:22<00:00, 11.49s/it]


{'metrics': [{'sample': 'P62_CTA_0', 'seg_sum/GT_sum': tensor(194.4020), 'metric1': tensor([0.0051])}, {'sample': 'new_CTA_0', 'seg_sum/GT_sum': tensor(210.8595), 'metric1': tensor([0.0047])}]}
Epoch 7/50


100%|███████████████████████████████████████████| 64/64 [00:37<00:00,  1.69it/s]


{'gen_IS_loss': 0.5187529241666198, 'gen_SI_loss': 0.9593536015599966, 'disc_I_loss': 0.00456753137564192, 'disc_S_loss': 0.14310877659590915, 'segmentation_loss': 0.0, 'reconstruction_loss': 0.027155279101805263}


  0%|                                                     | 0/2 [00:04<?, ?it/s]


KeyboardInterrupt: 