In [1]:
import os
import sys

import numpy as np 
import torch
import torch.nn as nn
import torch.optim as optim
from torch import nn
from torch.cuda.amp import autocast
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import transforms as T
import monai
from monai.networks import nets as MNets
from monai.networks import blocks as MBlocks
import pytorch_lightning as pl
from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import PatchDiscriminator

# Project-local modules live outside the notebook directory; extend the path before importing them.
sys.path.append('/mnt/raid/C1_ML_Analysis/source/blender/famli-ultra-sim/dl/')

from nets.cut_D import Discriminator
from nets.cut_G import Generator
from nets.cut_P import Head
from nets.lotus import UltrasoundRendering, UltrasoundRenderingLinear, UltrasoundRenderingConv1d
class testNCE(nn.Module):
    """Scratch module used to exercise the patch-wise NCE loss on ``Generator`` features.

    The generator ``self.G`` doubles as the feature encoder (it is called with
    ``encode_only=True``), ``self.H`` projects the sampled features and
    ``self.cross_entropy_loss`` scores the query/key similarity matrix.
    """

    def __init__(self):
        super().__init__()

        # self.E = nets.EfficientNetBN("efficientnet-b0", spatial_dims=2)
        self.G = Generator()
        # NOTE(review): Head is assumed to be default-constructible like Generator — confirm against nets.cut_P.
        self.H = Head()
        # Per-patch losses are kept un-reduced; the mean is taken per layer in calculate_NCE_loss.
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')

    def calculate_NCE_loss(self, src, tgt):
        """Return the NCE loss between ``tgt`` (queries) and ``src`` (keys), averaged over layers.

        The patch indices sampled while encoding ``tgt`` are reused for ``src`` so that
        query and key rows refer to the same spatial locations.
        """
        feat_q, patch_ids_q = self.G(tgt, encode_only=True)
        feat_k, _ = self.G(src, encode_only=True, patch_ids=patch_ids_q)

        feat_k_pool = self.H(feat_k)
        feat_q_pool = self.H(feat_q)

        total_nce_loss = 0.0
        n_layers = 0
        for f_q, f_k in zip(feat_q_pool, feat_k_pool):
            loss = self.patch_nce_loss(f_q, f_k)
            total_nce_loss += loss.mean()
            n_layers += 1
        # Average over the number of feature layers actually returned (5 for the current Generator).
        return total_nce_loss / max(n_layers, 1)

    def patch_nce_loss(self, feat_q, feat_k):
        """Cross-entropy over the query/key similarity matrix; row ``i`` must match column ``i``."""
        # Keys act as fixed targets: no gradient flows through them.
        feat_k = feat_k.detach()
        # 0.07 is the softmax temperature.
        out = torch.mm(feat_q, feat_k.transpose(1, 0)) / 0.07
        # Build the targets on the same device as the logits (nn.Module has no `.device` attribute).
        targets = torch.arange(0, out.size(0), dtype=torch.long, device=out.device)
        loss = self.cross_entropy_loss(out, targets)
        return loss

The 'neptune-client' package has been deprecated and will be removed in the future. Install the 'neptune' package instead. For more, see https://docs.neptune.ai/setup/upgrading/
You're importing the Neptune client library via the deprecated `neptune.new` module, which will be removed in a future release. Import directly from `neptune` instead.


In [2]:
# Instantiate the scratch NCE module; only its generator (`model.G`) is used in the next cell.
model = testNCE()

In [3]:
# Run the generator in encode-only mode on a random (4, 1, 256, 256) batch.
# Presumably returns the sampled per-layer features plus the patch indices used — confirm in nets.cut_G.
feat_q, patch_ids_q = model.G(torch.rand(4, 1, 256, 256), encode_only=True)
# Key features would reuse the same patch indices:
# feat_k, _ = self.G(src, encode_only=True, patch_ids=patch_ids_q)

In [4]:
# Inspect the shape of each feature tensor returned by the generator.
for f in feat_q:
    print(f.shape)

torch.Size([1024, 1])
torch.Size([1024, 128])
torch.Size([1024, 256])
torch.Size([1024, 256])
torch.Size([1024, 256])


In [5]:
from generative.networks.nets import SPADEAutoencoderKL
from generative.networks.nets import SPADEDiffusionModelUNet

In [6]:
# Small 2-D SPADE KL-autoencoder used to prototype the NCE feature extraction below:
# single-channel input/output, four resolution levels, no attention, 8 latent channels.
# label_nc is presumably the number of segmentation classes fed to the SPADE decoder — confirm.
autoencoder = SPADEAutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_res_blocks=(2, 2, 2, 2),
    num_channels=(8, 16, 32, 64),
    attention_levels=[False, False, False, False],
    latent_channels=8,
    norm_num_groups=8,
    label_nc=6,
)

In [7]:
# Smoke-test the encoder path on a random (2, 1, 256, 256) batch.
feat = autoencoder.encode(torch.rand(2, 1, 256, 256))

In [8]:
# Walk through the encoder block by block, printing the input shape of every block
# (and the final output shape) to decide which block indices to tap for NCE features.
x = torch.rand(2, 1, 256, 256)
for i, l in enumerate(autoencoder.encoder.blocks):
    print(i, x.shape)
    x = l(x)
print(x.shape)

0 torch.Size([2, 1, 256, 256])
1 torch.Size([2, 8, 256, 256])
2 torch.Size([2, 8, 256, 256])
3 torch.Size([2, 8, 256, 256])
4 torch.Size([2, 8, 128, 128])
5 torch.Size([2, 16, 128, 128])
6 torch.Size([2, 16, 128, 128])
7 torch.Size([2, 16, 64, 64])
8 torch.Size([2, 32, 64, 64])
9 torch.Size([2, 32, 64, 64])
10 torch.Size([2, 32, 32, 32])
11 torch.Size([2, 64, 32, 32])
12 torch.Size([2, 64, 32, 32])
13 torch.Size([2, 64, 32, 32])
14 torch.Size([2, 64, 32, 32])
15 torch.Size([2, 64, 32, 32])
16 torch.Size([2, 64, 32, 32])
torch.Size([2, 8, 32, 32])


In [9]:
def encode_nce(x, blocks, patch_ids=None, blocks_ids=(0, 4, 7, 10, 14, 17), num_patches = 256):
    """Run ``x`` through ``blocks`` and sample spatial patches from selected block outputs.

    Args:
        x: input tensor; every tapped block output must be 4-D (it is permuted as
            ``(0, 2, 3, 1)`` and its two spatial axes are flattened).
        blocks: iterable of callables applied sequentially.
        patch_ids: optional list of index tensors, one per tapped block (in block order).
            When given, these indices are reused instead of sampling new ones.
        blocks_ids: indices of the blocks whose outputs are tapped. Indices that do not
            exist in ``blocks`` are simply never matched.
        num_patches: maximum number of spatial locations sampled per tapped block.

    Returns:
        ``(return_feats, return_ids, feat)``: the sampled features, one
        ``(batch * n_patches, channels)`` tensor per tapped block; the freshly sampled
        patch indices (empty when ``patch_ids`` was provided); and the output of the
        last block.
    """
    feat = x
    return_ids = []
    return_feats = []
    p_id = 0
    for block_id, block in enumerate(blocks):
        feat = block(feat)
        if block_id not in blocks_ids:
            continue

        # (B, C, H, W) -> (B, H*W, C): one row per spatial location.
        feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
        if patch_ids is not None:
            patch_id = patch_ids[p_id]
            p_id += 1
        else:
            # Sample directly on the feature device so indexing needs no host/device transfer.
            patch_id = torch.randperm(feat_reshape.shape[1], device=feat.device)
            patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]
            return_ids.append(patch_id)
        # The same locations are taken from every batch item, then batch and patch axes are merged.
        x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)

        return_feats.append(x_sample)

    return return_feats, return_ids, feat

In [10]:
# Query pass: sample fresh patch indices from the autoencoder's encoder blocks.
feat_q, patch_ids_q, h =  encode_nce(torch.rand(4, 1, 256, 256), autoencoder.encoder.blocks)

In [11]:
# Inspect the shape of each sampled query feature tensor.
for f in feat_q:
    print(f.shape)

torch.Size([1024, 8])
torch.Size([1024, 16])
torch.Size([1024, 32])
torch.Size([1024, 64])
torch.Size([1024, 64])


In [12]:
# Key pass: reuse the patch indices sampled for the queries so rows line up.
feat_k, _, _ =  encode_nce(torch.rand(4, 1, 256, 256), autoencoder.encoder.blocks, patch_ids=patch_ids_q)

In [13]:
# Inspect the shape of each key feature tensor.
for f in feat_k:
    print(f.shape)

torch.Size([1024, 8])
torch.Size([1024, 16])
torch.Size([1024, 32])
torch.Size([1024, 64])
torch.Size([1024, 64])


In [14]:
# Display the layout of MONAI's MLPBlock to see how its layers are arranged.
MBlocks.MLPBlock(hidden_size=128, mlp_dim=3)

MLPBlock(
  (linear1): Linear(in_features=128, out_features=3, bias=True)
  (linear2): Linear(in_features=3, out_features=128, bias=True)
  (fn): GELU(approximate='none')
  (drop1): Dropout(p=0.0, inplace=False)
  (drop2): Dropout(p=0.0, inplace=False)
)

In [15]:
class ProjectionHeads(nn.Module):
    """One two-layer MLP projection head per tapped encoder block.

    A random tensor of shape ``in_shape`` is pushed through ``blocks`` once to discover
    the channel count after every tapped block; a ``Linear -> ReLU -> Linear`` head with
    256 output features is then registered as ``mlp_<block index>``.

    Args:
        blocks: iterable of modules applied sequentially (e.g. an encoder's block list).
        blocks_ids: indices of the blocks to build heads for. Indices that do not exist
            in ``blocks`` get no head.
        in_shape: shape of the dummy input used to probe channel counts.
    """

    def __init__(self, blocks, blocks_ids=(0, 4, 7, 10, 14, 17), in_shape=(1, 1, 128, 128)):
        super().__init__()

        # Copy so that instances never share (or mutate) the caller's / default sequence.
        self.blocks_ids = list(blocks_ids)
        # Block indices that actually received a head, in block traversal order. This is the
        # order in which features are produced when blocks are tapped sequentially.
        self.head_ids = []

        x = torch.rand(in_shape)

        for i, layer in enumerate(blocks):
            # Shape probing only: do not build an autograd graph through the encoder.
            with torch.no_grad():
                x = layer(x)
            if i in self.blocks_ids:

                mlp = nn.Sequential(*[
                    nn.Linear(x.shape[1], 256),
                    nn.ReLU(),
                    nn.Linear(256, 256)
                ])

                # Attribute names are kept as `mlp_<i>` so existing state dicts still load.
                setattr(self, 'mlp_%d' % i, mlp)
                self.head_ids.append(i)

    def forward(self, feats):
        """Project and L2-normalise each feature tensor.

        Args:
            feats: list of 2-D tensors, one per tapped block, in block order.

        Returns:
            List of tensors with 256 features per row and (approximately) unit L2 norm.
        """
        return_feats = []

        # Pair features with the heads that exist, not with the raw requested ids:
        # a requested id without a matching block would otherwise shift or break the pairing.
        for feat_id, feat in zip(self.head_ids, feats):
            mlp = getattr(self, 'mlp_%d' % feat_id)
            feat = mlp(feat)
            norm = feat.pow(2).sum(1, keepdim=True).pow(1. / 2)
            # Small epsilon guards against division by zero for all-zero rows.
            feat = feat.div(norm + 1e-7)
            return_feats.append(feat)
        return return_feats

In [16]:
# Build one projection head per tapped encoder block of the prototype autoencoder.
H = ProjectionHeads(autoencoder.encoder.blocks)

In [17]:
from torch.nn.modules.loss import _Loss
from monai.utils import LossReduction

class NCELoss(_Loss):
    """
    Patch-wise NCE (InfoNCE) loss over lists of projected features.

    For every pair of feature matrices the similarity matrix ``source @ target.T`` is
    divided by ``temperature`` and scored with cross-entropy, where row ``i`` must
    match column ``i``. The per-layer means are averaged over the number of layers.

    Args:
        reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction applied by the
        underlying cross-entropy criterion. Defaults to ``"none"``.
        - ``"none"``: no reduction will be applied (the per-row losses are averaged per layer).
        - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
        - ``"sum"``: the output will be summed.
        temperature: softmax temperature dividing the similarity logits. Defaults to ``0.07``.
    """

    def __init__(
        self,
        reduction: "LossReduction | str" = "none",
        temperature: float = 0.07,
    ) -> None:
        # Accept either an enum member or a plain string and hand torch the plain string.
        reduction = str(getattr(reduction, "value", reduction))
        super().__init__(reduction=reduction)

        self.temperature = temperature
        self.criterion = torch.nn.CrossEntropyLoss(reduction=reduction)

    def forward(
        self, features_source: list, features_target: list
    ) -> torch.Tensor:
        """

        Args:
            features_source: list of 2-D outputs of a projection head (queries).
            features_target: list of 2-D outputs of a projection head (keys), row-aligned
                with ``features_source``.

        Returns:
            Scalar tensor: the mean over layers of the per-layer NCE loss
            (zero when no feature pairs are given).
        """

        pairs = list(zip(features_source, features_target))
        if not pairs:
            return torch.zeros(())

        total_nce_loss = 0.0
        for feat_s, feat_t in pairs:
            loss = self.patch_nce_loss(feat_s, feat_t)
            total_nce_loss += loss.mean()
        # Average over the layers actually provided (previously a hard-coded 5).
        return total_nce_loss / len(pairs)

    def patch_nce_loss(self, feat_s, feat_t):
        """Cross-entropy of one similarity matrix against the identity assignment."""
        # Targets are treated as constants: no gradient flows through them.
        feat_t = feat_t.detach()
        out = torch.mm(feat_s, feat_t.transpose(1, 0)) / self.temperature
        # Targets must live on the same device as the logits.
        targets = torch.arange(0, out.size(0), dtype=torch.long, device=out.device)
        loss = self.criterion(out, targets)
        return loss
    

In [19]:
# End-to-end check of the NCE pipeline on random inputs:
# encoder features -> projection heads -> NCE loss.
nce_l = NCELoss(reduction='none')

# Queries sample fresh patch indices; keys reuse them so that rows correspond.
feat_q, patch_ids_q, _ =  encode_nce(torch.rand(4, 1, 256, 256), autoencoder.encoder.blocks)
feat_k, _, _ =  encode_nce(torch.rand(4, 1, 256, 256), autoencoder.encoder.blocks, patch_ids=patch_ids_q)

feat_q_pool = H(feat_q)
feat_k_pool = H(feat_k)

# Last expression of the cell: displays the scalar loss.
nce_l(feat_q_pool, feat_k_pool)

tensor(7.3205, grad_fn=<DivBackward0>)

In [23]:
class SPADELotus(pl.LightningModule):
    """Lightning module combining an ultrasound renderer, a SPADE KL-autoencoder and a patch GAN.

    Components built in ``__init__``:
        * ``USR``: ``UltrasoundRendering`` turning a label map into a simulated image.
        * ``G``: ``SPADEAutoencoderKL`` that encodes the simulated image and decodes it
          conditioned on the one-hot label map.
        * ``D``: ``PatchDiscriminator`` used for the adversarial loss.
        * ``H``: ``ProjectionHeads`` on top of the encoder blocks for the NCE loss.

    Optimisation is manual (``automatic_optimization = False``) with one ``GradScaler``
    for the generator/heads and one for the discriminator.

    Hyper-parameters are taken from ``**kwargs`` via ``save_hyperparameters``; the methods
    below read ``num_labels``, ``lr``, ``betas``, ``weight_decay``, the grid settings used
    in ``on_train_start``, the loss weights and ``warm_up_epochs`` from ``self.hparams``.
    """

    def __init__(self, **kwargs):
        super().__init__()

        self.save_hyperparameters()
        
        # self.D = MNets.Discriminator(in_shape=(1, 128, 128))
        self.D = PatchDiscriminator(spatial_dims=2, num_layers_d=3, num_channels=16, in_channels=1, out_channels=1)

        self.USR = UltrasoundRendering(**kwargs)

        self.G = SPADEAutoencoderKL(
                spatial_dims=2,
                in_channels=1,
                out_channels=1,
                num_res_blocks=(2, 2, 2, 2),
                num_channels=(8, 16, 32, 64),
                attention_levels=[False, False, False, False],
                latent_channels=8,
                norm_num_groups=8,
                label_nc=self.hparams.num_labels,
            )
        
        self.H = ProjectionHeads(self.G.encoder.blocks)

        self.perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="alex")
        self.adv_loss = PatchAdversarialLoss(criterion="least_squares")
        self.nce_loss = NCELoss(reduction='none')
        
        self.l1 = nn.L1Loss()
        # self.mse = nn.MSELoss()

        self.automatic_optimization = False

        # Pad 80 pixels on top, then centre-crop to 256x256.
        self.transform_us = T.Compose([T.Pad((0, 80, 0, 0)), T.CenterCrop(256)])
        
        self.scaler_g = torch.cuda.amp.GradScaler()
        self.scaler_d = torch.cuda.amp.GradScaler()

    def configure_optimizers(self):
        """Return the generator, discriminator and projection-head optimizers (in that order)."""
        opt_gen = optim.AdamW(
            list(self.USR.parameters()) + list(self.G.parameters()),
            # self.USR.parameters(),
            lr=self.hparams.lr,
            betas=self.hparams.betas,
            weight_decay=self.hparams.weight_decay            
        )
        opt_disc = optim.AdamW(
            # The discriminator is stored as `self.D`; no `D_Y` attribute exists on this module.
            self.D.parameters(),
            lr=self.hparams.lr,
            betas=self.hparams.betas,
            weight_decay=self.hparams.weight_decay
        )        
        opt_head = optim.AdamW(
            self.H.parameters(),
            lr=self.hparams.lr,
            betas=self.hparams.betas,
            weight_decay=self.hparams.weight_decay
        )

        return [opt_gen, opt_disc, opt_head]
    
    # This is only called during inference time to set a custom grid
    def init_grid(self, w, h, center_x, center_y, r1, r2, theta):
        """Compute a grid from the given parameters and install it (plus its inverse and fan mask) on ``USR``."""
        grid = self.USR.compute_grid(w, h, center_x, center_y, r1, r2, theta)
        inverse_grid, mask = self.USR.compute_grid_inverse(grid)
        
        self.USR.grid = self.USR.normalize_grid(grid)
        self.USR.inverse_grid = self.USR.normalize_grid(inverse_grid)
        self.USR.mask_fan = mask

    def forward(self, X):
        """Render the label map ``X`` and return the masked autoencoder reconstruction."""
        M = self.transform_us(self.USR.mask_fan)
        S = self.one_hot(self.transform_us(X))*M
        X = self.transform_us(self.USR(X))
        # G returns (reconstruction, z_mu, z_sigma); only the reconstruction is masked and returned.
        return self.G(X, S)[0]*M

    # def scheduler_step(self):
    #     self.scheduler_disc.step()
    #     self.scheduler_gen.step()
    #     self.scheduler_mlp.step()

    def set_requires_grad(self, nets, requires_grad=False):
        """Set ``requires_grad`` for all parameters of the given networks to avoid unnecessary computations.

        Parameters:
            nets (network list)   -- a list of networks (a single network is also accepted)
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def on_train_start(self):
        """Create (or load from the working directory) the pool of grids sampled during training.

        Grids are regenerated when ``hparams.create_grids`` is set or no ``grid_t.pt`` file
        exists; otherwise the three tensors are loaded from disk.
        """
        # NOTE(review): validation_step also reads these tensors; if a sanity-check validation
        # runs before this hook the attributes will not exist yet — confirm trainer settings.

        # Define the file names directly without using out_dir
        grid_t_file = 'grid_t.pt'
        inverse_grid_t_file = 'inverse_grid_t.pt'
        mask_fan_t_file = 'mask_fan_t.pt'

        if self.hparams.create_grids or not os.path.exists(grid_t_file):
            grid_tensor = []
            inverse_grid_t = []
            mask_fan_t = []

            for i in range(self.hparams.n_grids):

                grid_w, grid_h = self.hparams.grid_w, self.hparams.grid_h
                center_x = self.hparams.center_x
                r1 = self.hparams.r1

                # Randomise the probe geometry uniformly inside the configured ranges.
                # NOTE(review): center_y stays a 1-element tensor while r2/theta are floats — confirm intent.
                center_y = self.hparams.center_y_start + (self.hparams.center_y_end - self.hparams.center_y_start) * (torch.rand(1))
                r2 = self.hparams.r2_start + ((self.hparams.r2_end - self.hparams.r2_start) * torch.rand(1)).item()
                theta = self.hparams.theta_start + ((self.hparams.theta_end - self.hparams.theta_start) * torch.rand(1)).item()
                
                grid, inverse_grid, mask = self.USR.init_grids(grid_w, grid_h, center_x, center_y, r1, r2, theta)

                grid_tensor.append(grid.unsqueeze(dim=0))
                inverse_grid_t.append(inverse_grid.unsqueeze(dim=0))
                mask_fan_t.append(mask.unsqueeze(dim=0))

            self.grid_t = torch.cat(grid_tensor).to(self.device)
            self.inverse_grid_t = torch.cat(inverse_grid_t).to(self.device)
            self.mask_fan_t = torch.cat(mask_fan_t).to(self.device)

            # Save tensors directly to the current directory
            
            torch.save(self.grid_t, grid_t_file)
            torch.save(self.inverse_grid_t, inverse_grid_t_file)
            torch.save(self.mask_fan_t, mask_fan_t_file)

            # print("Grids SAVED!")
            # print(self.grid_t.shape, self.inverse_grid_t.shape, self.mask_fan_t.shape)
        
        else:
            # Load tensors directly from the current directory
            self.grid_t = torch.load(grid_t_file).to(self.device)
            self.inverse_grid_t = torch.load(inverse_grid_t_file).to(self.device)
            self.mask_fan_t = torch.load(mask_fan_t_file).to(self.device)

    def _sample_grids(self, batch_size):
        """Pick one random grid / inverse grid / fan mask per batch item from the stored pool."""
        # `high` is exclusive, so use the pool size itself. Reading it from the tensor also
        # stays correct when the pool was loaded from disk with a different number of grids.
        grid_idx = torch.randint(low=0, high=self.grid_t.shape[0], size=(batch_size,), device=self.grid_t.device)
        return self.grid_t[grid_idx], self.inverse_grid_t[grid_idx], self.mask_fan_t[grid_idx]

    def _generator_loss(self, X_s, Y):
        """Shared train/validation forward pass.

        Args:
            X_s: label map batch (``labeled['seg']``).
            Y: batch of real ultrasound images.

        Returns:
            ``(loss_g, Y_fake)``: the weighted generator loss and the decoded image.
        """
        grid, inverse_grid, mask_fan = self._sample_grids(X_s.shape[0])

        X_ = self.USR(X_s, grid=grid, inverse_grid=inverse_grid, mask_fan=mask_fan)
        X_s = X_s*mask_fan
        # The rendered image is used as a fixed input: no gradient flows back into USR.
        X = self.transform_us(X_.detach())
        X_s = self.transform_us(X_s)

        labels = self.one_hot(X_s)

        with autocast(enabled=True):

            feat_q, patch_ids_q, h = self.encode_nce(X, self.G.encoder.blocks)
            z_mu, z_sigma = self.quant_conv(h)
            z = self.G.sampling(z_mu, z_sigma)
            Y_fake = self.G.decode(z, labels)

            # Reconstruction target is the tensor that was actually encoded (after pad/crop).
            recons_loss = self.l1(Y_fake, X)

            # Reuse the query patch locations for the real image so NCE rows correspond.
            feat_k, _, _ = self.encode_nce(Y, self.G.encoder.blocks, patch_ids=patch_ids_q)

            feat_q_pool = self.H(feat_q)
            feat_k_pool = self.H(feat_k)

            # Y_fake, z_mu, z_sigma = self.G(X, labels)

            p_loss = self.perceptual_loss(Y_fake.float(), Y.float())

            # KL divergence of N(z_mu, z_sigma^2) from N(0, 1), summed per sample, averaged over the batch.
            kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
            kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

            logits_fake = self.D(Y_fake.contiguous().float())[-1]
            a_loss = self.adv_loss(logits_fake, target_is_real=True, for_discriminator=False)

            nce_loss = self.nce_loss(feat_q_pool, feat_k_pool)

            loss_g = self.hparams.recons_weight*recons_loss + self.hparams.nce_weight*nce_loss + self.hparams.adv_weight * a_loss + (self.hparams.kl_weight * kl_loss) + (self.hparams.perceptual_weight * p_loss)

        return loss_g, Y_fake

    def training_step(self, train_batch, batch_idx):
        """One manual-optimisation step: update generator + heads, then (during warm-up) the discriminator."""

        # Y is the real ultrasound
        labeled, Y = train_batch
        X_s = labeled['seg']

        opt_gen, opt_disc, opt_head = self.optimizers()

        # update G (and the projection heads)
        # D is frozen so the generator loss does not accumulate gradients in it.
        self.set_requires_grad(self.D, False)
        loss_g, Y_fake = self._generator_loss(X_s, Y)

        opt_gen.zero_grad(set_to_none=True)
        opt_head.zero_grad(set_to_none=True)
        self.scaler_g.scale(loss_g).backward()
        # Both optimizers consume gradients of the same scaled loss: step each through the
        # scaler (so gradients are unscaled exactly once), then update the scale once.
        self.scaler_g.step(opt_gen)
        self.scaler_g.step(opt_head)
        self.scaler_g.update()

        self.log("train_loss_g", loss_g)

        # update D
        # Done after the generator backward: stepping D first would modify, in place, weights
        # that the generator's graph still needs for its backward pass.
        # NOTE(review): D is only trained while current_epoch < warm_up_epochs — confirm this is intended.
        if self.current_epoch < self.hparams.warm_up_epochs:
            self.set_requires_grad(self.D, True)
            opt_disc.zero_grad(set_to_none=True)

            with autocast(enabled=True):
                loss_d = self.compute_D_loss(Y, Y_fake)

            self.scaler_d.scale(loss_d).backward()
            self.scaler_d.step(opt_disc)
            self.scaler_d.update()
            self.set_requires_grad(self.D, False)

            # Only logged when it was actually computed.
            self.log("train_loss_d", loss_d)

    def validation_step(self, val_batch, batch_idx):
        """Compute and log the generator loss on a validation batch."""

        labeled, Y = val_batch
        X_s = labeled['seg']

        loss_g, _ = self._generator_loss(X_s, Y)

        self.log("val_loss", loss_g, sync_dist=True)
        
    def quant_conv(self, h):
        """Map encoder output ``h`` to the latent mean and standard deviation."""
        z_mu = self.G.quant_conv_mu(h)
        z_log_var = self.G.quant_conv_log_sigma(h)
        # Clamp the log-variance to keep exp() numerically stable.
        z_log_var = torch.clamp(z_log_var, -30.0, 20.0)
        z_sigma = torch.exp(z_log_var / 2)

        return z_mu, z_sigma
    
    def compute_D_loss(self, Y, Y_fake):
        """Discriminator loss: mean of the fake and real adversarial terms, scaled by ``adv_weight``."""
        # Both inputs are detached so only D receives gradients from this loss.
        logits_fake = self.D(Y_fake.contiguous().detach())[-1]
        loss_d_fake = self.adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
        logits_real = self.D(Y.contiguous().detach())[-1]
        loss_d_real = self.adv_loss(logits_real, target_is_real=True, for_discriminator=True)
        discriminator_loss = (loss_d_fake + loss_d_real) * 0.5

        loss_d = self.hparams.adv_weight * discriminator_loss
        
        return loss_d
    
    def encode_nce(self, x, blocks, patch_ids=None, blocks_ids=(0, 4, 7, 10, 14, 17), num_patches = 256):
        """Run ``x`` through ``blocks`` and sample spatial patches from selected block outputs.

        Args:
            x: input tensor; tapped block outputs must be 4-D.
            blocks: iterable of modules applied sequentially.
            patch_ids: optional list of index tensors (one per tapped block, in block order)
                to reuse instead of sampling new locations.
            blocks_ids: indices of the blocks to tap; indices beyond ``blocks`` never match.
            num_patches: maximum number of spatial locations sampled per tapped block.

        Returns:
            ``(return_feats, return_ids, feat)``: sampled features per tapped block, the newly
            sampled indices (empty when ``patch_ids`` was given) and the last block's output.
        """
        feat = x
        return_ids = []
        return_feats = []
        p_id = 0

        for block_id, block in enumerate(blocks):
            feat = block(feat)
            if block_id not in blocks_ids:
                continue

            # (B, C, H, W) -> (B, H*W, C): one row per spatial location.
            feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
            if patch_ids is not None:
                patch_id = patch_ids[p_id]
                p_id += 1
            else:
                # Sample on the feature device to avoid host/device transfers when indexing.
                patch_id = torch.randperm(feat_reshape.shape[1], device=feat.device)
                patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]
                return_ids.append(patch_id)
            x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)

            return_feats.append(x_sample)

        return return_feats, return_ids, feat
    
    def one_hot(self, input_label):
        """One-hot encode channel 0 of ``input_label`` into ``hparams.num_labels`` float channels.

        The result has the shape of ``input_label`` with dimension 1 replaced by the number
        of labels, and lives on the same device as the input. Values outside
        ``[0, num_labels)`` produce all-zero vectors.
        """
        num_labels = self.hparams.num_labels
        # Broadcast-compare against the channel indices: (B, 1, ...) == (1, N, 1, ...).
        view_shape = [1, num_labels] + [1] * (input_label.dim() - 2)
        channels = torch.arange(num_labels, device=input_label.device).view(view_shape)
        return (input_label[:, :1, ...] == channels).to(torch.float32)

In [24]:
# Instantiate the Lightning module; all keyword arguments are stored as hyper-parameters
# and also forwarded to UltrasoundRendering.
sl_model = SPADELotus(num_labels=330, grid_w=128, grid_h=128, center_x=64.0, center_y=-15, r1=20.0, r2=125.0, theta=np.pi/6, alpha_coeff_boundary_map=0.1, beta_coeff_scattering=10, tgc=8)

SPADELotus(
  (D): PatchDiscriminator(
    (initial_conv): Convolution(
      (conv): Conv2d(1, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (adn): ADN(
        (D): Dropout(p=0.0, inplace=False)
        (A): LeakyReLU(negative_slope=0.2)
      )
    )
    (0): Convolution(
      (conv): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (adn): ADN(
        (N): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (D): Dropout(p=0.0, inplace=False)
        (A): LeakyReLU(negative_slope=0.2)
      )
    )
    (1): Convolution(
      (conv): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (adn): ADN(
        (N): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (D): Dropout(p=0.0, inplace=False)
        (A): LeakyReLU(negative_slope=0.2)
      )
    )
    (2): Convolution(
      (conv): Conv2d(64, 128, kernel_size=(4, 4), stride=(1