In [1]:
# '''Input Libraries'''
# import os
# import torch
# import sklearn
# import numpy as np
# import torch.nn.functional as F
# import torch.utils.data as data_utils

# from tqdm import tqdm
# from functions import mrf as mrf
# from functions import models as m
# from functions import dataset as data
# from functions import training_tools as tt
# from functions.visualization import argmax_ch
# from functions.parser import train_parser

In [2]:
# import matplotlib
# # matplotlib.use('agg')
# import matplotlib.pyplot as plt
# from functions import visualization as vis

## Load Data

In [3]:
import sae.functions.dataset

# MRI volumes and their label maps (aseg), relative to the working directory.
vols_path, aseg_path = './sae/data/vols/', './sae/data/labels/'
train_set = sae.functions.dataset.load_bucker_data(
    vols_path,
    aseg_path,
)

In [4]:
mri, aseg, onehot, _ = train_set[0]
# Shapes of one training sample: image, label map and one-hot label map.
for caption, tensor in (('mri :', mri), ('aseg :', aseg), ('onehot :', onehot)):
    print(caption, tensor.shape)

mri : torch.Size([1, 1, 160, 192, 224])
aseg : torch.Size([160, 192, 224])
onehot : torch.Size([1, 14, 160, 192, 224])


In [5]:
# onehot[0, :, 70, 70, 70]

In [6]:
# aseg[70, 70, 70]

In [7]:
# index = 16

# figure = matplotlib.pyplot.figure()
# matplotlib.pyplot.subplot(1, 2, 1)
# _ = matplotlib.pyplot.imshow(mri[0, 0, index])
# matplotlib.pyplot.subplot(1, 2, 2)
# _ = matplotlib.pyplot.imshow(aseg[index])
# figure.show()

## Choose template

In [8]:
# Path (relative to the working directory) of the probabilistic atlas that is
# loaded below as the template / prior.
atlas_path = './sae/data/prob_atlas.npz'

In [9]:
# atlas = torch.from_numpy(np.load((atlas_path))['vol_data']).float()

In [10]:
# atlas.size()

In [11]:
# atlas[70, 70, 70]

In [12]:
import sae.functions.dataset

template = sae.functions.dataset.get_prob_atlas(atlas_path)
print(template.shape)
# Number of label channels followed by the three spatial extents.
chs, dim1, dim2, dim3 = template.shape[1:5]
print(chs, dim1, dim2, dim3)

torch.Size([1, 14, 160, 192, 224])
14 160 192 224


In [18]:
import sae.functions.visualization

# Per-channel boolean mask marking the most probable class of each voxel of
# the template. The module is imported here because the notebook's import
# cell comes later in top-down order.
argm_ch = sae.functions.visualization.argmax_ch(template)

In [19]:
# Class mask of the first batch element at the corner voxel (0, 0, 0).
argm_ch[0][:, 0, 0, 0]

tensor([ True, False, False, False, False, False, False, False, False, False,
        False, False, False, False])

In [20]:
# Template class probabilities at the corner voxel (0, 0, 0).
template[0][:, 0, 0, 0]

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [13]:
import sys

# Make the modules under ./sae importable as top-level packages (for example
# `import functions.dataset` further down). The guard keeps re-runs of this
# cell from appending duplicate entries to sys.path.
# NOTE(review): this cell was executed (In [13]) before the cells below it;
# if the `sae.functions.*` modules themselves need this path, move the cell
# above the first `import sae...` — TODO confirm.
if './sae' not in sys.path:
    sys.path.append('./sae')

In [21]:
import torch
import torch.nn.functional
import torch.optim

import numpy

import sae.functions.mrf
import sae.functions.visualization
import sae.functions.training_tools

import wrapper.mrf



class SAELoss:

    """Loss of the segmentation auto-encoder (SAE).

    ``__call__`` returns the sum of three terms:

    * a prior term: cross-entropy of ``softmax(logits)`` against the
      normalized atlas minus the entropy of ``softmax(logits)``;
    * a reconstruction term on the decoded image (see ``compute_recon_loss``);
    * a spatial-consistency term weighted by ``beta`` (see
      ``compute_consistent``).

    Warnings:
        - With ``sigma == 2`` every ``compute_recon_loss`` call appends to
          ``running_var``. At the end of each epoch call ``update_variance()``
          (if the variance should be re-estimated) and then
          ``clear_running_var()`` so epochs are not mixed.
    """

    def __init__(self,
        sigma: float,
        prior: torch.Tensor,
        alpha: float = 1.0,
        beta: float = 0.01,
        eps: float = 1e-12,
        k: int = 3,
        var: float = 1e8
    ) -> None:
        """
        Params:
            - sigma : reconstruction mode; 0 for an alpha-weighted summed
              squared error, 2 for a Gaussian likelihood with scalar variance
              ``var``
            - prior : probabilistic atlas; in this notebook its shape is
              (1, n_classes, D1, D2, D3)
            - alpha : weight of the squared error when sigma == 0
            - beta : weight of the spatial-consistency term; 0 disables it
              (no lookup table is built)
            - eps : small constant added before log() to avoid log(0)
            - k : neighbourhood size passed to the MRF lookup/consistency code
            - var : initial scalar variance used when sigma == 2
        """
        self.prior = prior
        # Normalize over the class dimension, then take the log once; the
        # prior is a constant, so it is detached from any graph.
        self.log_prior = torch.log(
            sae.functions.training_tools.normalize_dim1(prior+eps)
        ).detach()
        print('log_prior:', self.log_prior.size())

        self.alpha = alpha
        self.beta = beta
        self.sigma = sigma
        self.eps = eps
        self.k = k

        self.var = var

        self.lookup = None
        if self.beta != 0:
            # Per-channel boolean mask of the most probable class, as uint8.
            argm_ch = sae.functions.visualization.argmax_ch(self.prior)
            argm_ch = argm_ch.type(torch.uint8)
            self.lookup = sae.functions.mrf.get_lookup(
                prior = argm_ch,
                neighboor_size = self.k
            )

            print('argm_ch :', argm_ch.size())
            print('lookup :', self.lookup)

        self.running_var = []


    def __call__(self,
        x: torch.Tensor,
        logits: torch.Tensor,
        recon: torch.Tensor
    ) -> torch.Tensor:
        """Total loss: prior term + reconstruction term + consistency term."""
        prior_loss = self.compute_prior_loss(logits)
        recon_loss = self.compute_recon_loss(x, recon)
        consistent = self.compute_consistent(logits)
        return prior_loss + recon_loss + consistent

    def compute_prior_loss(self, logits: torch.Tensor) -> torch.Tensor:
        """Cross-entropy of softmax(logits) against log_prior minus its entropy.

        This equals KL(softmax(logits) || normalized prior). Each term is
        summed over the class axis (dim 1) and the three spatial axes, then
        averaged over the batch.
        """
        log_pi = torch.nn.functional.log_softmax(logits, 1)
        pi = torch.exp(log_pi)

        cce = -1*torch.sum(pi*self.log_prior, 1)   # cross entropy per voxel
        cce = torch.sum(cce, (1, 2, 3))             # sum over the voxels
        cce = cce.mean()                            # average over the batch

        h = -1*torch.sum(pi*log_pi, 1)              # entropy per voxel
        h = torch.sum(h, (1, 2, 3))
        h = h.mean()

        return cce - h

    def compute_consistent(self, logits: torch.Tensor) -> torch.Tensor:
        """Spatial-consistency term scaled by ``beta``.

        ``wrapper.mrf.spatial_consistency`` is applied to softmax(logits) with
        the lookup table built in ``__init__``. When ``beta == 0`` a zero
        tensor of shape (1,) on the logits' device is returned instead.
        """
        log_pi = torch.nn.functional.log_softmax(logits, 1)
        pi = torch.exp(log_pi)

        if self.beta != 0:  # i.e. self.lookup was built in __init__
            consistent = self.beta*wrapper.mrf.spatial_consistency(
                input = pi,
                table = self.lookup,
                neighboor_size = self.k
            )
        else:
            consistent = torch.zeros(1, device=logits.device)

        return consistent

    def compute_recon_loss(self,
        x: torch.Tensor,
        recon: torch.Tensor
    ) -> torch.Tensor:
        """Reconstruction loss between the input image ``x`` and ``recon``.

        Both tensors are 5-D (batch, channel, D1, D2, D3). The squared error
        is summed over every non-batch axis and averaged over the batch.

        * sigma == 0: ``alpha * mean_batch(sum((recon - x)^2))``.
        * sigma == 2: Gaussian negative log-likelihood with scalar variance
          ``var``. The squared error is weighted by ``0.5 / var``, with
          ``var`` rounded to a power of ten and the weight clipped to
          [0, 500]. Then ``0.5 * c * log(var + eps)`` is added, where ``c``
          is the number of voxels of one image. The unweighted per-voxel mean
          squared error is appended to ``running_var``.

        Raises:
            AssertionError: if ``sigma`` is neither 0 nor 2.
        """
        sum_dims = (1, 2, 3, 4)
        # x is the target: never back-propagate into it.
        sq_err = (recon - x.detach())**2

        if self.sigma == 0:
            mse = torch.sum(sq_err, sum_dims).mean()
            return self.alpha * mse

        if self.sigma == 2:
            # The weight moves in coarse (power-of-ten) steps of var.
            rounded_var = 10**numpy.round(numpy.log10(self.var))
            weight = numpy.clip(0.5*(1/(rounded_var)), 0, 500)
            mse = torch.sum(weight * sq_err, sum_dims).mean()

            # Record the raw per-voxel squared error: its mean over an epoch
            # is the maximum-likelihood estimate of the scalar variance.
            self.running_var.append(sq_err.detach().mean().item())

            # var is a scalar standing for var * I, so the log-determinant is
            # c * log(var), with c the number of voxels (the image has 1 ch).
            c = x.shape[-3] * x.shape[-2] * x.shape[-1]

            _var = torch.tensor(
                float(self.var + self.eps), dtype=mse.dtype, device=mse.device
            )
            return mse + 0.5 * c * torch.log(_var)

        raise AssertionError('sigma must be 0 or 2')

    def normalize_dim1(self, x):
        '''
        Ensure that dim1 sums up to one for proper probabilistic interpretation
        '''
        normalizer = torch.sum(x, dim=1, keepdim=True)
        return x/normalizer

    def update_variance(self) -> None:
        """Set ``var`` to the mean of the values recorded in ``running_var``.

        Leaves ``var`` unchanged when nothing has been recorded (e.g. sigma
        is 0, or the list was just cleared) instead of setting it to NaN.
        """
        if self.running_var:
            self.var = numpy.mean(self.running_var)

    def clear_running_var(self) -> None:
        """Empty ``running_var``; call it at the end of every epoch."""
        self.running_var.clear()

In [22]:
# Criterion with a plain squared-error reconstruction term (sigma=0).
sae_loss = SAELoss(prior=template, sigma=0)

log_prior: torch.Size([1, 14, 160, 192, 224])
argm_ch : torch.Size([1, 14, 160, 192, 224])
lookup : [[9.93031455e-01 0.00000000e+00 1.36885061e-07 0.00000000e+00
  1.54816243e-03 5.32330791e-08 9.12567071e-08 0.00000000e+00
  1.36124588e-06 3.38904596e-04 0.00000000e+00 1.52094512e-08
  2.40081187e-05 5.05581169e-03]
 [0.00000000e+00 8.29217037e-01 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 6.20877635e-02
  1.08695200e-01 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00]
 [1.86254424e-04 0.00000000e+00 8.33550630e-01 0.00000000e+00
  4.40905610e-02 3.53262557e-02 0.00000000e+00 0.00000000e+00
  4.16278637e-02 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 4.52184350e-02]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 8.32721867e-01
  0.00000000e+00 0.00000000e+00 1.03891362e-03 1.36927945e-03
  8.17003408e-02 0.00000000e+00 7.29456461e-02 0.00000000e+00
  0.00000000e+00 1.02239533e-02]
 [1.05162970e-02 0.00000

In [37]:

import wrapper.models
import sae.functions.models

class SegmentationAutoEncoder(torch.nn.Module):
    """Segmentation auto-encoder: image -> class logits -> sampled labels -> image.

    The encoder (``Simple_Unet``) maps the image to ``latent_dim`` channels.
    Its output goes through ``wrapper.models.enforcer(prior, ...)`` to give
    the logits. A Gumbel-softmax sample of the logits is decoded back to an
    image by ``Simple_Decoder``.
    """

    def __init__(self,
        in_channels: int,
        out_channels: int,
        latent_dim: int
    ) -> None:
        """
        Params:
            - in_channels : number of channels of the input image
            - out_channels : number of channels of the reconstructed image
            - latent_dim : number of segmentation classes, i.e. channels of
              the logits (``template.shape[1]`` in this notebook)
        """
        super().__init__()

        # Feature counts per level of the U-Net encoder / decoder paths.
        enc_nf = [4, 8, 16, 32]
        dec_nf = [32, 16, 8, 4]
        self.encoder = sae.functions.models.Simple_Unet(
            input_ch = in_channels,
            out_ch = latent_dim,
            use_bn = False,
            enc_nf = enc_nf,
            dec_nf = dec_nf
        )

        # Pretrained encoder weights can be loaded with load_encoder_from().

        self.decoder = sae.functions.models.Simple_Decoder(
            input_ch = latent_dim,
            out_ch = out_channels
        )


    def forward(self,
        x: torch.Tensor,
        prior,
        tau: float,
    ) -> tuple:
        """
        Params:
            - x : input image, 5-D (batch, channel, D1, D2, D3)
            - prior : probabilistic atlas passed to ``wrapper.models.enforcer``
            - tau : Gumbel-softmax temperature

        Returns:
            (logits, recon) in both train and eval mode: the class logits,
            (batch, latent_dim, D1, D2, D3), and the decoded image.
        """
        logits = wrapper.models.enforcer(prior, self.encoder(x))
        n_batch, chs, dim1, dim2, dim3 = logits.size()

        # Put the class channel last and flatten the voxels before sampling
        # (presumably gumbel_softmax samples over the last axis — TODO
        # confirm), then restore the (batch, class, D1, D2, D3) layout.
        flat = logits.permute(0, 2, 3, 4, 1).reshape(n_batch, dim1*dim2*dim3, chs)
        # NOTE(review): only wrapper.mrf and wrapper.models are imported in
        # this notebook; `wrapper.training_tools` resolves only if the
        # `wrapper` package exposes it — consider an explicit import.
        pred = wrapper.training_tools.gumbel_softmax(flat, tau)
        pred = pred.view(n_batch, dim1, dim2, dim3, chs).permute(0, 4, 1, 2, 3)

        recon = self.decoder(pred)

        return logits, recon


    def load_encoder_from(self, pth_file_loaded) -> None:
        """Load pretrained weights into the encoder.

        ``pth_file_loaded`` is an encoder state dict, e.g. the ``'u1'`` entry
        of a checkpoint loaded with::

            torch.load(
                f = path/of/your/file.pth,
                map_location = torch.device(...)
            )

        The encoder is meant to be pretrained first to map the images to the
        template (also called prior). A template can be built by segmenting
        many images and counting, per voxel, how often each class is
        predicted.
        """
        self.encoder.load_state_dict(pth_file_loaded)

In [38]:
# One-channel image in and out; one latent channel per template class.
sae_model = SegmentationAutoEncoder(1, 1, chs)

In [39]:
x, _, _, _ = train_set[0]
# Display the shape of the input volume.
x.size()

torch.Size([1, 1, 160, 192, 224])

In [40]:
# Forward pass with Gumbel-softmax temperature 2/3.
out = sae_model(x, template, tau=2/3)

In [41]:
logits, recon = out
for caption, tensor in (('logit :', logits), ('recon :', recon)):
    print(caption, tensor.size())

logit : torch.Size([1, 14, 160, 192, 224])
recon : torch.Size([1, 1, 160, 192, 224])


In [None]:
x_loss = sae_loss(x=x, logits=logits, recon=recon)

In [None]:
# Loss as a plain Python float.
float(x_loss)

In [None]:
import sae.functions.dataset

# Rebuild the template and the criterion from the probabilistic atlas, using
# the module imported here (not the `data` alias, which is only imported
# further down) and the same ./sae/data/ location as the data cells.
atlas_path = './sae/data/prob_atlas.npz'
template = sae.functions.dataset.get_prob_atlas(atlas_path)
sae_loss = SAELoss(sigma=0, prior=template)

In [None]:
def create_train_step_for_segmentation_auto_encoder_model(
    model: 'SegmentationAutoEncoder',
    optimizer: torch.optim.Optimizer,
    criterion: 'SAELoss',
    tau: float = 2/3
):
    """Build the training-step function of the segmentation auto-encoder.

    Params:
        - model : Segmentation Auto-Encoder (SAE), called as
          ``model(x, criterion.prior, tau)``
        - optimizer : optimizer over the model parameters
        - criterion : SAELoss; its ``prior`` attribute is passed to the model
        - tau : Gumbel-softmax temperature (2/3 as in the forward cell)

    Returns:
        ``train_step(engine, batch) -> dict``. ``engine`` is unused
        (presumably kept for an ignite-style Engine signature — TODO
        confirm). Each ``batch[i]`` must unpack into four items, the first
        being the model input. One optimizer update is made per call. The
        dict holds ``'loss'`` (mean loss over the batch, a float) and
        ``'logits'`` / ``'recons'`` (per-sample outputs, detached).

    Raises (from train_step):
        ValueError: if the batch is empty.
    """

    def train_step(engine, batch) -> dict:

        model.train()
        optimizer.zero_grad()

        size_of_batch = len(batch)
        if size_of_batch == 0:
            raise ValueError('train_step received an empty batch')

        batch_loss = 0.0
        acc_logits = []
        acc_recons = []
        for i in range(size_of_batch):
            x, _, _, _ = batch[i]
            logits, recon = model(x, criterion.prior, tau)
            loss = criterion(x, logits, recon)
            batch_loss += loss.item()
            # Scale each sample so the accumulated gradient is the gradient
            # of the mean loss that is reported.
            (loss / size_of_batch).backward()
            # Detach so the returned outputs hold no autograd graph references.
            acc_logits.append(logits.detach())
            acc_recons.append(recon.detach())

        optimizer.step()

        return {
            'loss' : batch_loss / size_of_batch,
            'logits' : acc_logits,
            'recons' : acc_recons
        }

    return train_step

In [None]:
def create_train_step_for_unfolding_model(
    model: 'SegmentationAutoEncoder',
    optimizer: torch.optim.Optimizer,
    criterion: 'SAELoss'
):
    """Build the training step of the unfolding model (not implemented yet).

    ``create_train_step`` expects the returned step to return a dict with a
    ``'predictions'`` entry, which is fed to the SAE training step.

    Raises:
        NotImplementedError: always, until the unfolding step is written.
    """
    raise NotImplementedError(
        'create_train_step_for_unfolding_model is not implemented yet'
    )

In [None]:
import pathlib

import functions.dataset


def from_sae_config(config: dict) -> tuple:
    """Build the SAE model, its optimizer and its criterion from a config dict.

    Keys read:
        - config['dataset']['template'] : path of the probabilistic atlas
        - config['model']['in_channels'], config['model']['out_channels']
        - config['train']['learning_rate'] : Adam learning rate
        - config['train']['loss']['sigma'] : required; the optional keys
          'alpha', 'beta', 'eps', 'k' and 'var' fall back to SAELoss defaults

    The model's latent dimension is the template's channel count
    (``template.shape[1]``), not a config value.

    Returns:
        (model, optimizer, criterion)
    """
    template = functions.dataset.get_prob_atlas(
        path = pathlib.Path(config['dataset']['template'])
    )

    model = SegmentationAutoEncoder(
        in_channels = config['model']['in_channels'],
        out_channels = config['model']['out_channels'],
        latent_dim = template.shape[1]
    )

    optimizer = torch.optim.Adam(
        params = model.parameters(),
        lr = config['train']['learning_rate']
    )

    loss_config = config['train']['loss']
    criterion = SAELoss(
        sigma = loss_config['sigma'],
        prior = template,
        alpha = loss_config.get('alpha', 1.0),
        beta = loss_config.get('beta', 0.01),
        eps = loss_config.get('eps', 1e-12),
        # Neighbourhood size of the spatial-consistency (MRF) term.
        k = loss_config.get('k', 3),
        var = loss_config.get('var', 1e8)
    )

    return model, optimizer, criterion


def from_unfolding_config(config: dict) -> tuple:
    """Build the unfolding model, optimizer and criterion (not implemented yet).

    Raises:
        NotImplementedError: always, until the unfolding model is written.
    """
    raise NotImplementedError('from_unfolding_config is not implemented yet')


# Wire the unfolding model and the SAE into their training steps.
# NOTE(review): both configs are empty placeholders, so this cell cannot run
# as written: from_unfolding_config is a stub that does not return a
# (model, optimizer, criterion) triple, and from_sae_config indexes
# config['dataset'], config['model'] and config['train'] (KeyError on {}).
unfolding_config: dict = {}
sae_config : dict = {}

unfold_model, unfold_optim, unfold_criterion = \
        from_unfolding_config(unfolding_config)
    
# Rebinds sae_model, replacing the model built in the earlier cell.
sae_model, sae_optim, sae_criterion = \
    from_sae_config(sae_config)

unfold_train_step = \
    create_train_step_for_unfolding_model(
        model=unfold_model,
        optimizer=unfold_optim,
        criterion=unfold_criterion
    )
            

sae_train_step = \
    create_train_step_for_segmentation_auto_encoder_model(
        model=sae_model,
        optimizer=sae_optim,
        criterion=sae_criterion
    )



In [None]:
import typing

def create_train_step(
   unfolding_train_step: typing.Callable,
   sae_train_step: typing.Callable
) -> typing.Callable :
    """Chain the unfolding step and the SAE step into one training step.

    The returned ``train_step(engine, batch)`` runs
    ``unfolding_train_step(engine, batch)`` first. It then passes that
    result's ``'predictions'`` entry as the batch of ``sae_train_step``.
    Both results are returned in a dict under ``'unfolding'`` and ``'sae'``.
    """

    def train_step(engine, batch) -> dict:
        unfolding_result = unfolding_train_step(engine, batch)
        return {
            'unfolding' : unfolding_result,
            'sae' : sae_train_step(engine, unfolding_result['predictions']),
        }

    return train_step

In [2]:
import os
import torch
import sklearn
import numpy as np
import torch.nn.functional as F
import torch.utils.data as data_utils

from tqdm import tqdm
from functions import mrf as mrf
from functions import models as m
from functions import dataset as data
from functions import training_tools as tt
from functions.visualization import argmax_ch
from functions.parser import train_parser

In [3]:
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from functions import visualization as vis


        
"""Load Data"""
# Paths here have no './sae/' prefix, unlike the cells above.
# NOTE(review): presumably meant to run with ./sae as the working directory —
# confirm. This cell rebinds train_set, template, atlas_path, chs and
# dim1-dim3, which were defined earlier in the notebook.
vols_path = './data/vols/'
aseg_path = './data/labels/'
train_set = data.load_bucker_data(vols_path,
                                    aseg_path)

"""Choose template"""    
atlas_path = './data/prob_atlas.npz'
template = data.get_prob_atlas(atlas_path)
# Channel count (number of classes) and spatial extents of the template.
chs = template.shape[1]
dim1 = template.shape[2]
dim2 = template.shape[3]
dim3 = template.shape[4]


        


"""Making Model"""
# Feature counts per level of the U-Net encoder / decoder paths.
enc_nf = [4, 8, 16, 32]
dec_nf = [32, 16, 8, 4]

# Encoder: 1-channel image -> chs class channels.
u1 = m.Simple_Unet(input_ch=1,
                    out_ch=chs,
                    use_bn= False,
                    enc_nf= enc_nf,
                    dec_nf= dec_nf)
# u1 = torch.nn.DataParallel(u1)
# u1.cuda()

# Decoder: chs class channels -> 1-channel image.
u2 = m.Simple_Decoder(chs, 1)
# u2 = torch.nn.DataParallel(u2)
# u2.cuda()


"""Pretrained Model"""
# In order to obtain a good initialization, the encoder was pretrained by
# mapping the training data to the probabilistic template.
# Only the encoder (u1) weights are loaded below; u2 keeps its initial
# weights even though the message mentions both.
# NOTE(review): the recorded RuntimeError asks for map_location, which is
# already passed here; that output may be stale — re-run to confirm.
print('============ Loading pretrained weight for enc and dec ============')
summary = torch.load(
    './weights/pretrained_encoder.pth.tar',
    map_location=torch.device('cpu')
)                        
u1.load_state_dict(summary['u1']) 



RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.