In [1]:
import torch
import torch.nn

import numpy


On suppose que l'on dispose déjà de nos images super-résolues.

Tâches :

- Nous appliquons le k-means sur nos images super-résolues
    - Nous devons :
        - récupérer les clusters
        - récupérer l'image segmentée, comme avant

- Nous calculons la probabilité qu'un pixel appartienne à un cluster
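    - $d(x, C_{i}) = \min \{\, d(x, y) \mid y \in C_{i} \,\}$, avec $d(x, y) = |x - y|$
    - $P(x \notin C_{i}) = \frac{d(x, C_{i})}{\sum\limits_{k=0}^{N-1} d(x, C_{k})}$, où $N$ est le nombre de clusters
    - $P(x \in C_{i}) = 1 - P(x \notin C_{i})$
    - Remarque : $\sum_{i} P(x \in C_{i}) = N - 1$ ; ces valeurs ne forment donc une distribution de probabilité que pour $N = 2$.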

In [2]:
def distance_cluster(
    pixel_value: float,
    cluster: numpy.ndarray
) -> float:
    """Distance from a pixel value to a cluster: d(x, C) = min over y in C of |x - y|.

    Args:
        pixel_value: intensity of the pixel ``x``.
        cluster: intensities of the members of cluster ``C`` (array or
            array-like, any shape; every element is considered).

    Returns:
        The smallest absolute difference between ``pixel_value`` and an
        element of ``cluster``, as a Python float.

    Raises:
        ValueError: if ``cluster`` is empty (the minimum is undefined).
    """
    # Work in float64: with unsigned integer images (e.g. uint8) the
    # subtraction would otherwise wrap around (3 - 5 -> 254).
    values = numpy.asarray(cluster, dtype=numpy.float64)
    if values.size == 0:
        raise ValueError("cannot compute the distance to an empty cluster")
    return float(numpy.min(numpy.abs(float(pixel_value) - values)))
    
def probability_clusters(
    pixel_value: float,
    clusters: numpy.ndarray
) -> numpy.ndarray:
    """Membership probabilities of one pixel for every cluster.

    For each cluster ``C_i``::

        P(x not in C_i) = d(x, C_i) / sum_k d(x, C_k)
        P(x in C_i)     = 1 - P(x not in C_i)

    where ``d`` is ``distance_cluster``.

    Args:
        pixel_value: intensity of the pixel ``x``.
        clusters: sequence of clusters, each an array of member intensities
            (clusters may have different sizes).

    Returns:
        1-D float array of length ``len(clusters)`` holding ``P(x in C_i)``.
        The entries sum to ``len(clusters) - 1``, so they form a probability
        distribution only when there are two clusters.
        If every distance is zero (the value is present in every cluster),
        no cluster is closer than another, so ``P(x not in C_i)`` is taken as
        ``1 / len(clusters)`` instead of the NaN produced by 0 / 0.
    """
    distances = numpy.array(
        [distance_cluster(pixel_value, cluster) for cluster in clusters],
        dtype=numpy.float64,
    )

    total = distances.sum()
    if total > 0:
        proba_not_in = distances / total
    else:
        # All distances are zero (they are non-negative), so fall back to the
        # uniform value; max(..., 1) keeps an empty cluster list well-defined.
        proba_not_in = numpy.full(distances.shape, 1.0 / max(distances.size, 1))

    return 1.0 - proba_not_in

def probability_map(
    img: numpy.ndarray,
    clusters: numpy.ndarray
) -> numpy.ndarray:
    """Per-pixel cluster membership probabilities for a 2-D image.

    Args:
        img: 2-D image of shape ``(n, m)``.
        clusters: sequence of clusters, each an array of member intensities.
            A plain list of differently-sized arrays is accepted, since
            k-means clusters rarely have equal sizes.

    Returns:
        Float array of shape ``(len(clusters), n, m)`` where ``[:, i, j]`` is
        ``probability_clusters`` evaluated at ``img[i, j]``.
    """
    n, m = img.shape
    # len() works both for a regular ndarray of clusters and for a ragged list.
    img_prob = numpy.zeros(shape=(len(clusters), n, m))

    for i, j in numpy.ndindex(n, m):
        img_prob[:, i, j] = probability_clusters(
            pixel_value=img[i, j],
            clusters=clusters,
        )

    return img_prob

def likely_probable(
    proba_map: numpy.ndarray
) -> numpy.ndarray:
    """Mask of the most probable cluster(s) for every pixel.

    NumPy counterpart of ``index_likely_probable``: entry ``[c, i, j]`` is 1
    when cluster ``c`` has the highest probability at pixel ``(i, j)`` and 0
    otherwise. On ties, every maximal cluster is marked.

    Args:
        proba_map: array of shape ``(nb_classes, n, m)``, e.g. the output of
            ``probability_map``.

    Returns:
        Array with the same shape and dtype as ``proba_map``, holding 0/1.
    """
    if proba_map.shape[0] == 0:
        # No class at all: nothing is maximal, and max over an empty axis raises.
        return numpy.zeros_like(proba_map)
    # Compare each class against the per-pixel maximum, broadcast along axis 0.
    is_max = proba_map == proba_map.max(axis=0, keepdims=True)
    return is_max.astype(proba_map.dtype)

def index_likely_probable(
    proba_map: torch.Tensor
) -> torch.Tensor:
    """Boolean mask of the most probable class(es) per sample and per pixel.

    Args:
        proba_map: tensor of shape ``(batch, nb_classes, n, m)``.

    Returns:
        Bool tensor with the same shape and device as ``proba_map``, where
        ``[b, c, i, j]`` is True when class ``c`` reaches the maximum
        probability of sample ``b`` at pixel ``(i, j)``. On ties, every
        maximal class is True.
    """
    # Reduce over the class axis only: the maximum must be taken per sample,
    # not across the whole batch at a given pixel.
    per_pixel_max = proba_map.max(dim=1, keepdim=True).values
    return proba_map == per_pixel_max

In [3]:
# template[0, :, 70, 70, 70]
# import sys
# sys.path.append('./sae')

In [4]:
# import torch
# import torch.nn.functional
# import torch.optim

# import numpy

# import sae.functions.mrf
# import sae.functions.visualization
# import sae.functions.training_tools

# import wrapper.mrf
# import wrapper2D.mrf

# class SAELoss2D:

#     """
#     Warnings:
#         - Running var must be clear to each epoch's end with clear_running_var()
#     """

#     def __init__(self,
#         sigma: float,
#         alpha: float = 1.0, 
#         beta: float = 0.01, 
#         eps: float = 1e-12,
#         k: int = 3,
#         var: float = 1e8
#     ) -> None:
        
#         self.alpha = alpha
#         self.beta = beta
#         self.sigma = sigma
#         self.eps = eps
#         self.k = k

#         self.var = var

#         # self.lookup = None
#         # if self.beta != 0:
#         #     argm_ch = sae.functions.visualization.argmax_ch(self.prior)
#         #     argm_ch = argm_ch.type(torch.uint8)
#         #     self.lookup = sae.functions.mrf.get_lookup(
#         #         prior = argm_ch,
#         #         neighboor_size = self.k
#         #     )

#         self.running_var = []

#     def __call__(self,
#         x: torch.Tensor,
#         proba_map: torch.Tensor,
#         logits: torch.Tensor,
#         recon: torch.Tensor
#     ) -> torch.Tensor:
        
 
#         prior = proba_map # ie template
        
#         log_prior = torch.log(
#             sae.functions.training_tools.normalize_dim1(
#                 prior+self.eps
#             )
#         ).detach()

#         lookup = None
#         if self.beta != 0:
#             argm_ch = index_likely_probable(prior)
#             argm_ch = argm_ch.type(torch.uint8)
#             # print(argm_ch)
#             lookup = wrapper2D.mrf.get_lookup(
#                 prior = argm_ch,
#                 neighboor_size = self.k
#             )
        
        
#         prior_loss = self.compute_prior_loss(logits, log_prior)
#         recon_loss = self.compute_recon_loss(x, recon)
#         consistent = self.compute_consistent(logits, lookup)

#         return prior_loss + recon_loss + consistent
    
#     def compute_prior_loss(self, 
#         logits: torch.Tensor,
#         log_prior: torch.Tensor
#     ) -> torch.Tensor:

#         log_pi = torch.nn.functional.log_softmax(logits, 1)
#         pi = torch.exp(log_pi)
        
#         cce = -1*torch.sum(pi*log_prior,1)      #cross entropy
#         cce = torch.sum(cce,(1,2))            #cce over all the dims
#         cce = cce.mean()               
            
#         h = -1*torch.sum(pi*log_pi,1)
#         h = torch.sum(h,(1,2))
#         h = h.mean()

#         prior_loss = cce - h

#         return prior_loss
    
#     def compute_consistent(self, 
#         logits: torch.Tensor,
#         lookup: torch.Tensor
#     ) -> torch.Tensor:
        
#         if self.beta != 0: # ie not(self.lookup is None)
#             log_pi = torch.nn.functional.log_softmax(logits, 1)
#             pi = torch.exp(log_pi)
#             consistent = self.beta*wrapper2D.mrf.spatial_consistency(
#                 inumpyut = pi,
#                 table = lookup,
#                 neighboor_size = self.k
#             )
#         else:
#             consistent = torch.zeros(1, device=logits.device)
        
#         return consistent
    
#     def compute_recon_loss(self, 
#         x: torch.Tensor, 
#         recon: torch.Tensor
#     ) -> torch.Tensor:
        
#         _, _, dim1, dim2 = x.size()
        
#         if self.sigma == 0:
            
#             mse = (recon-x.detach())**2  #mse
#             mse = torch.sum(mse,(1,2))    #mse over all dims
#             mse = mse.mean()                  #avarage over all batches
#             recon_loss = self.alpha * mse 
        
#         elif self.sigma == 2:

#             # Estimated Variance
#             mse = (recon-x.detach())**2
#             self.running_var.append(mse.detach().mean().item())

#             rounded_var = 10**numpy.round(numpy.log10(self.var))

#             # Weight Reconstruction loss
#             mse = numpy.clip(0.5*(1/(rounded_var)),0, 500) * mse
#             mse = torch.sum(mse,(1,2))    #mse over all dims
#             mse = mse.mean()                  #avarage over all batches

#             # Since args.var is a scalar now, we need to account for
#             # the fact that we doing log det of a matrix
#             # Therefore, we multiply by the dimension of the image

#             c = dim1*dim2 #chs is 1 for image

#             _var = torch.from_numpy(numpy.array(self.var+self.eps)).float()
#             recon_loss = mse + 0.5 * c * torch.log(_var)

#         else:

#             raise AssertionError('sigma must be 0 or 2')
        
#         return recon_loss
    
#     def update_variance(self) -> None:
#         self.var = numpy.mean(self.running_var)

#     def clear_running_var(self) -> None:
#         """ Running var must be clear to each end epoch
#         """
#         self.running_var.clear()

In [5]:
NB_ROWS, NB_COLS = 512, 512
NB_CLASSES = 2
SEED = 0

# Seed torch's global RNG so the random prior below is reproducible across runs.
torch.manual_seed(SEED)

# The complement-based prior below only makes sense for two classes.
assert NB_CLASSES == 2, "the prior construction assumes exactly two classes"

# All-zero dummy input, shape (batch=1, channels=1, rows, cols).
img = torch.zeros(size=(1, 1, NB_ROWS, NB_COLS))

# Random two-class prior: class 0 is a uniform draw and class 1 its complement,
# so the class probabilities sum to 1 at every pixel.
prior = torch.zeros(size=(1, NB_CLASSES, NB_ROWS, NB_COLS))
prior[0, 0, :, :] = torch.rand(size=(NB_ROWS, NB_COLS))
prior[0, 1, :, :] = 1 - prior[0, 0, :, :]

In [6]:
import sys

# Make the project's `sae` directory importable. The membership check keeps
# re-runs of this cell from appending the same entry to sys.path repeatedly.
SAE_DIR = './sae'
if SAE_DIR not in sys.path:
    sys.path.append(SAE_DIR)

import wrapper2D.defineme

In [7]:
# Single-channel in/out autoencoder; latent dimension = number of classes.
model = wrapper2D.defineme.SegmentationAutoEncoder(
    latent_dim=NB_CLASSES, in_channels=1, out_channels=1,
)

In [8]:
# Forward pass on the all-zero image with the random prior.
# NOTE(review): the meaning of the third positional argument (2/3) is defined in
# wrapper2D.defineme.SegmentationAutoEncoder, which is not visible here -- consider
# naming it as a config constant.
out = model(img, prior, 2/3)

In [9]:
# The model returns (reconstruction, class logits); show their shapes.
(recon, logits) = out
print('logit :', logits.shape)
print('recon :', recon.shape)

logit : torch.Size([1, 2, 512, 512])
recon : torch.Size([1, 1, 512, 512])


In [10]:
# Loss object from wrapper2D.defineme (imported in an earlier cell).
# NOTE(review): sigma presumably selects the reconstruction-loss variant (the
# commented-out SAELoss2D draft in this notebook only accepts 0 or 2) -- confirm.
sae_loss = wrapper2D.defineme.SAELoss2D(sigma=2)

In [11]:
# Evaluate the SAE loss on this forward pass, with the random prior as proba_map.
loss = sae_loss(
    recon=recon,
    logits=logits,
    proba_map=prior,
    x=img,
)

In [12]:
# Display the loss as a plain Python number.
loss.item()

2510721.0

In [13]:
# def create_train_step(model: SegmentationAutoEncoder, criterion: SA)