In [1]:
"""
This is adapted from the original code (with modifications made using ChatGPT) from
https://github.com/thuanz123/enhancing-transformers/blob/1778fc497ea11ed2cef134404f99d4d6b921cda9/enhancing/modules/stage1/layers.py
"""

# ------------------------------------------------------------------------------------
# Enhancing Transformers
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from Taming Transformers (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


import PIL
import math
import torch
import importlib
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

from functools import partial
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from typing import List, Tuple, Dict, Any, Optional, Union
from omegaconf import OmegaConf
from collections import OrderedDict
from torch.optim import lr_scheduler
from torchvision import transforms as T

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Global settings for the cells below.
# volume_size: presumably the side length (in voxels) of the cubic input volume — TODO confirm.
# pix_dim: presumably the voxel spacing used when resampling scans; not referenced by the
#          cells shown here — TODO confirm its intended use.
volume_size, pix_dim = 64, 1.5

In [3]:
# Verify that a local static file server is reachable on port 8000.
# Start it with `python -m http.server` from the repo root before running this notebook.
import requests

try:
    # A timeout keeps the notebook from hanging if the port accepts but never answers.
    r = requests.get('http://localhost:8000', timeout=10)
except requests.RequestException as err:
    # Catch only network errors (a bare `except:` would also swallow KeyboardInterrupt /
    # SystemExit) and keep the underlying cause attached for debugging.
    raise Exception('Please make sure the localhost 8000 is running. Run `python -m http.server` in the root directory of this repo.') from err
# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if r.status_code != 200:
    raise Exception('Please make sure the localhost 8000 is running. Run `python -m http.server` in the root directory of this repo.')
print('Localhost 8000 is running.')


Localhost 8000 is running.


In [4]:
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Build 1-D sine/cosine positional embeddings for a set of positions.

    Args:
        embed_dim: output width per position; must be even (half sine, half cosine).
        pos: numpy array of positions of any shape; it is flattened to M values.

    Returns:
        numpy array of shape (M, embed_dim). The first embed_dim/2 columns are
        sin(pos * omega_k) and the remaining ones cos(pos * omega_k), where
        omega_k = 1 / 10000 ** (k / (embed_dim / 2)).
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=float) / (embed_dim / 2.)
    inv_freq = 1. / 10000 ** exponents  # (D/2,)

    angles = np.outer(pos.reshape(-1), inv_freq)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)

In [5]:
def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Concatenate three 1-D sin/cos embeddings, one per grid axis slot.

    Args:
        embed_dim: total embedding width. Each of the three axis slots gets
            embed_dim // 3 channels, and each slot must itself be even (sin/cos halves),
            so embed_dim must be divisible by 6.
        grid: indexable with three entries grid[0], grid[1], grid[2], each holding the
            coordinate of every token along one axis (each entry is flattened).

    Returns:
        numpy array of shape (num_tokens, embed_dim) with the channel blocks ordered
        grid[0], grid[1], grid[2].
    """
    # Check the real requirement up front, so a bad width fails with a clear message
    # instead of an anonymous assertion inside the 1-D helper.
    assert embed_dim % 6 == 0, f"embed_dim must be divisible by 6 (3 axes x sin/cos), got {embed_dim}"
    axis_dim = embed_dim // 3

    axis_embs = [get_1d_sincos_pos_embed_from_grid(axis_dim, grid[i]) for i in range(3)]
    return np.concatenate(axis_embs, axis=1)  # (num_tokens, embed_dim)

In [6]:
def get_3d_sincos_pos_embed(embed_dim, grid_size):
    """Fixed 3-D sin/cos positional embeddings for a (D, H, W) grid of patches.

    Args:
        embed_dim: total embedding width (must be divisible by 6).
        grid_size: an int for a cubic grid, or a length-3 tuple/list (D, H, W).

    Returns:
        numpy array of shape (D*H*W, embed_dim). Row t belongs to the patch at (d, h, w)
        with t = (d*H + h)*W + w, i.e. the order produced by flattening a (D, H, W) grid
        with d slowest and w fastest (as `rearrange('b c d h w -> b (d h w) c')` does).
    """
    if not isinstance(grid_size, (tuple, list)):
        grid_size = (grid_size, grid_size, grid_size)
    depth, height, width = grid_size

    coords_d = np.arange(depth, dtype=np.float32)
    coords_h = np.arange(height, dtype=np.float32)
    coords_w = np.arange(width, dtype=np.float32)
    # indexing='ij' yields arrays of shape (D, H, W), matching the token order. The default
    # 'xy' indexing with (w, h, d) inputs yields shape (H, W, D), and reshaping that to
    # (D, H, W) scrambles positions whenever D, H and W are not all equal.
    dd, hh, ww = np.meshgrid(coords_d, coords_h, coords_w, indexing='ij')
    # Slot order (h, d, w) reproduces the earlier (w, h, d)-meshgrid output exactly for cubic
    # grids, so existing models and checkpoints see identical embeddings.
    grid = np.stack([hh, dd, ww], axis=0).reshape([3, 1, depth, height, width])

    return get_3d_sincos_pos_embed_from_grid(embed_dim, grid)


In [7]:
def init_weights(m):
    """Initialise one module in place; intended for use as `model.apply(init_weights)`.

    - nn.Linear: Xavier-uniform weight, zero bias (when present).
    - nn.LayerNorm: weight 1, bias 0 (each only when present, so norms without
      affine parameters are skipped instead of crashing).
    - nn.Conv3d / nn.ConvTranspose3d: Xavier-uniform on the weight flattened to 2-D
      (dim 0 vs. all remaining dims); the bias keeps its PyTorch default.
    Other module types are left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # elementwise_affine=False (or bias=False) leaves these attributes as None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
        w = m.weight.data
        # Initialise the kernel as if it were a 2-D Linear weight; the view shares storage,
        # so the in-place init writes straight into the conv weight.
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

In [8]:
class PreNorm(nn.Module):
    """Pre-normalisation wrapper computing ``fn(LayerNorm(x), **kwargs)``.

    Args:
        dim: size of the last input dimension, normalised by the LayerNorm.
        fn: module applied to the normalised tensor; extra forward kwargs are passed to it.
    """

    def __init__(self, dim: int, fn: nn.Module) -> None:
        super().__init__()
        # Attribute names and registration order (norm, then fn) define the state_dict keys.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

In [9]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> hidden_dim) -> Tanh -> Linear(hidden_dim -> dim)."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        layers = [nn.Linear(dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, dim)]
        # A single nn.Sequential named `net` keeps the parameter keys net.0.* / net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        out = x
        for layer in self.net:
            out = layer(out)
        return out

In [10]:
# class Attention(nn.Module):
#     def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None:
#         super().__init__()
#         inner_dim = dim_head *  heads
#         project_out = not (heads == 1 and dim_head == dim)

#         self.heads = heads
#         self.scale = dim_head ** -0.5

#         self.attend = nn.Softmax(dim = -1)
#         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

#         self.to_out = nn.Linear(inner_dim, dim) if project_out else nn.Identity()

#     def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
#         qkv = self.to_qkv(x).chunk(3, dim = -1)
#         q, k, v = map(lambda t: rearrange(t, 'b (d h w) (head dim) -> b head (d h w) dim', head = self.heads), qkv)

#         attn = torch.matmul(q, k.transpose(-1, -2)) * self.scale
#         attn = self.attend(attn)

#         out = torch.matmul(attn, v)
#         out = rearrange(out, 'b head (d h w) dim -> b (d h w) (head dim)')

#         return self.to_out(out)

In [11]:
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        dim: token width of the input and the output.
        heads: number of attention heads.
        dim_head: width of each head; q, k and v are each ``heads * dim_head`` wide.

    The output projection is an Identity only when a single head already has width ``dim``.
    """

    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None:
        super().__init__()
        inner_dim = heads * dim_head
        needs_out_proj = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5  # 1 / sqrt(d_head)

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Linear(inner_dim, dim) if needs_out_proj else nn.Identity()

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        # x is (batch, tokens, dim); the unpacking below requires exactly three dims.
        batch, tokens, _ = x.shape

        def split_heads(t):
            # (b, n, h*d) -> (b, h, n, d)
            return t.reshape(batch, tokens, self.heads, -1).transpose(1, 2)

        q, k, v = (split_heads(t) for t in self.to_qkv(x).chunk(3, dim = -1))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        # (b, h, n, d) -> (b, n, h*d)
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(merged)

In [12]:
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm transformer blocks followed by a final LayerNorm.

    Each block applies a residual self-attention step and then a residual MLP step:
        x <- x + Attention(LN(x));  x <- x + FeedForward(LN(x))
    """

    def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None:
        super().__init__()
        blocks = []
        for _ in range(depth):
            attn_block = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head))
            ff_block = PreNorm(dim, FeedForward(dim, mlp_dim))
            blocks.append(nn.ModuleList([attn_block, ff_block]))
        # Nested ModuleLists keep parameter names layers.<i>.0.* (attention) / layers.<i>.1.* (MLP).
        self.layers = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for attn_block, ff_block in self.layers:
            x = x + attn_block(x)
            x = x + ff_block(x)
        return self.norm(x)

In [13]:
class ViTEncoder3D(nn.Module):
    """ViT encoder for 3-D volumes: Conv3d patch embedding + fixed sin/cos positions + Transformer.

    Args:
        volume_size: input size, an int (cube) or a (D, H, W) tuple; each side must be
            divisible by the matching patch side.
        patch_size: patch size, an int (cube) or a (pd, ph, pw) tuple; used as the Conv3d
            kernel size and stride.
        dim, depth, heads, mlp_dim, dim_head: Transformer hyper-parameters (dim = token width).
        channels: number of input channels.

    Example config used in this notebook:
        "dim": 240, "depth": 6, "heads": 8, "mlp_dim": 512, "channels": 1, "dim_head": 64
    """

    def __init__(self, volume_size: Union[Tuple[int, int, int], int], patch_size: Union[Tuple[int, int, int], int],
                 dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 1, dim_head: int = 64) -> None:
        super().__init__()
        vol_d, vol_h, vol_w = volume_size if isinstance(volume_size, tuple) else (volume_size,) * 3
        pat_d, pat_h, pat_w = patch_size if isinstance(patch_size, tuple) else (patch_size,) * 3

        assert vol_d % pat_d == 0 and vol_h % pat_h == 0 and vol_w % pat_w == 0, 'Volume dimensions must be divisible by the patch size.'
        grid_d, grid_h, grid_w = vol_d // pat_d, vol_h // pat_h, vol_w // pat_w
        en_pos_embedding = get_3d_sincos_pos_embed(dim, (grid_d, grid_h, grid_w))
        self.num_patches = grid_d * grid_h * grid_w
        self.patch_dim = channels * pat_d * pat_h * pat_w

        # (B, C, D, H, W) -> (B, num_patches, dim)
        self.to_patch_embedding = nn.Sequential(
            nn.Conv3d(channels, dim, kernel_size=patch_size, stride=patch_size),
            Rearrange('b c d h w -> b (d h w) c'),
        )
        # Fixed (requires_grad=False) but registered as a Parameter, so it is stored in the state_dict.
        self.en_pos_embedding = nn.Parameter(torch.from_numpy(en_pos_embedding).float().unsqueeze(0), requires_grad=False)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.apply(init_weights)

    def forward(self, volume: torch.FloatTensor) -> torch.FloatTensor:
        tokens = self.to_patch_embedding(volume)
        return self.transformer(tokens + self.en_pos_embedding)

In [14]:
class ViTDecoder3D(nn.Module):
    """ViT decoder for 3-D volumes: fixed sin/cos positions + Transformer + ConvTranspose3d un-patching.

    Takes tokens of width ``dim``, ordered like the flattened (d, h, w) patch grid, and
    returns a volume shaped (B, channels, D, H, W).

    Args mirror ViTEncoder3D: volume_size / patch_size are ints (cube) or 3-tuples, and
    each volume side must be divisible by the matching patch side.
    """

    def __init__(self, volume_size: Union[Tuple[int, int, int], int], patch_size: Union[Tuple[int, int, int], int],
                 dim: int, depth: int, heads: int, mlp_dim: int, channels: int = 1, dim_head: int = 64) -> None:
        super().__init__()
        vol_d, vol_h, vol_w = volume_size if isinstance(volume_size, tuple) else (volume_size,) * 3
        pat_d, pat_h, pat_w = patch_size if isinstance(patch_size, tuple) else (patch_size,) * 3

        assert vol_d % pat_d == 0 and vol_h % pat_h == 0 and vol_w % pat_w == 0, 'Volume dimensions must be divisible by the patch size.'
        grid_d, grid_h, grid_w = vol_d // pat_d, vol_h // pat_h, vol_w // pat_w
        de_pos_embedding = get_3d_sincos_pos_embed(dim, (grid_d, grid_h, grid_w))

        self.num_patches = grid_d * grid_h * grid_w
        self.patch_dim = channels * pat_d * pat_h * pat_w

        # Registration order (transformer, de_pos_embedding, to_voxel) keeps the parameter order unchanged.
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
        self.de_pos_embedding = nn.Parameter(torch.from_numpy(de_pos_embedding).float().unsqueeze(0), requires_grad=False)
        # (B, num_patches, dim) -> (B, dim, gd, gh, gw) -> (B, channels, D, H, W)
        self.to_voxel = nn.Sequential(
            Rearrange('b (d h w) c -> b c d h w', d=grid_d, h=grid_h, w=grid_w),
            nn.ConvTranspose3d(dim, channels, kernel_size=patch_size, stride=patch_size)
        )

        self.apply(init_weights)

    def forward(self, token: torch.FloatTensor) -> torch.FloatTensor:
        hidden = self.transformer(token + self.de_pos_embedding)
        return self.to_voxel(hidden)

    def get_last_layer(self) -> nn.Parameter:
        """Weight of the final ConvTranspose3d (passed to the loss as ``last_layer``)."""
        return self.to_voxel[-1].weight


In [15]:
class BaseQuantizer(nn.Module):
    """Shared codebook and forward logic for the quantizers below.

    Subclasses implement ``quantize(z) -> (z_q, loss, encoding_indices)``. ``forward`` either
    quantizes once, or (``use_residual=True``) runs ``num_quantizers`` rounds of residual
    quantization, summing the per-round codes, stacking the indices on a new last axis and
    averaging the losses.

    Args:
        embed_dim: width of each codebook vector (last dim of ``z``).
        n_embed: number of codebook entries.
        straight_through: if True, ``forward`` returns ``z + (z_q - z).detach()`` so the
            value is ``z_q`` while gradients pass to ``z`` unchanged.
        use_norm: L2-normalise vectors along the last dim inside ``norm``.
        use_residual: enable residual (multi-round) quantization.
        num_quantizers: number of residual rounds; required (>= 1) when ``use_residual``.

    Raises:
        ValueError: if ``use_residual`` is set without a positive ``num_quantizers``.
    """

    def __init__(self, embed_dim: int, n_embed: int, straight_through: bool = True, use_norm: bool = True,
                 use_residual: bool = False, num_quantizers: Optional[int] = None) -> None:
        super().__init__()
        # Fail at construction instead of with `range(None)` deep inside forward().
        if use_residual and (num_quantizers is None or num_quantizers < 1):
            raise ValueError("num_quantizers must be a positive int when use_residual=True.")
        self.straight_through = straight_through
        # Kept as a flag plus the `norm` method (not a lambda attribute) so the module can be pickled.
        self.use_norm = use_norm

        self.use_residual = use_residual
        self.num_quantizers = num_quantizers

        self.embed_dim = embed_dim
        self.n_embed = n_embed

        self.embedding = nn.Embedding(self.n_embed, self.embed_dim)
        self.embedding.weight.data.normal_()

    def norm(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """L2-normalise ``x`` along its last dimension when ``use_norm`` is set; otherwise return it as is."""
        return F.normalize(x, dim=-1) if self.use_norm else x

    def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
        """Quantize ``z``; subclasses must override. Returns (z_q, loss, encoding_indices)."""
        raise NotImplementedError(f"{type(self).__name__} must implement quantize().")

    def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
        if not self.use_residual:
            z_q, loss, encoding_indices = self.quantize(z)
        else:
            z_q = torch.zeros_like(z)
            # Each round quantizes what the previous rounds left unexplained.
            residual = z.detach().clone()

            losses = []
            encoding_indices = []

            for _ in range(self.num_quantizers):
                z_qi, loss, indices = self.quantize(residual.clone())
                residual.sub_(z_qi)
                z_q.add_(z_qi)

                encoding_indices.append(indices)
                losses.append(loss)

            # indices: (..., num_quantizers); loss: mean over rounds
            losses, encoding_indices = map(partial(torch.stack, dim = -1), (losses, encoding_indices))
            loss = losses.mean()

        # preserve gradients with straight-through estimator
        if self.straight_through:
            z_q = z + (z_q - z).detach()

        return z_q, loss, encoding_indices

In [16]:
class VectorQuantizer(BaseQuantizer):
    """Nearest-neighbour codebook quantizer with a straight-through estimator.

    Squared Euclidean distances are computed between the (optionally L2-normalised) input
    vectors and codebook vectors; the returned codes are the normalised codebook entries.

    Args:
        embed_dim: codebook vector width (last dim of the input).
        n_embed: number of codebook entries.
        beta: weight of the commitment term that pulls the input toward its code.
        use_norm: L2-normalise inputs and codebook before comparing them.
        use_residual, num_quantizers: residual (multi-round) quantization settings.
        **kwargs: ignored, so a shared config dict may carry extra keys.
    """

    def __init__(self, embed_dim: int, n_embed: int, beta: float = 0.25, use_norm: bool = True,
                 use_residual: bool = False, num_quantizers: Optional[int] = None, **kwargs) -> None:
        super().__init__(embed_dim, n_embed, True,
                         use_norm, use_residual, num_quantizers)

        self.beta = beta

    def quantize(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
        """Snap every last-dim vector of ``z`` to its nearest codebook entry.

        Returns:
            z_qnorm: normalised codebook vectors, same shape as ``z``.
            loss: ``beta * commitment + codebook`` loss (scalar).
            encoding_indices: LongTensor of shape ``z.shape[:-1]``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, also work.
        z_flat = self.norm(z.reshape(-1, self.embed_dim))
        codebook = self.norm(self.embedding.weight)

        # ||z||^2 + ||e||^2 - 2 z.e  ->  (num_vectors, n_embed)
        d = torch.sum(z_flat ** 2, dim=1, keepdim=True) + \
            torch.sum(codebook ** 2, dim=1) - 2 * \
            torch.einsum('b d, n d -> b n', z_flat, codebook)

        encoding_indices = torch.argmin(d, dim=1).view(*z.shape[:-1])

        z_q = self.embedding(encoding_indices).view(z.shape)
        z_qnorm, z_norm = self.norm(z_q), self.norm(z)

        # beta * commitment (gradient reaches z only) + codebook term (gradient reaches the codebook only)
        loss = self.beta * torch.mean((z_qnorm.detach() - z_norm)**2) +  \
               torch.mean((z_qnorm - z_norm.detach())**2)

        return z_qnorm, loss, encoding_indices

In [17]:
class GumbelQuantizer(BaseQuantizer):
    """Codebook quantizer that samples codes with Gumbel-softmax over negative squared distances.

    Args:
        embed_dim: codebook vector width (last dim of the input).
        n_embed: number of codebook entries.
        temp_init: default Gumbel-softmax temperature, stored as ``self.temperature``.
        use_norm: L2-normalise inputs and codebook before comparing them.
        use_residual, num_quantizers: residual (multi-round) quantization settings.
        **kwargs: ignored, so a shared config dict may carry extra keys.

    Gradients flow through the Gumbel-softmax sample itself, so the base-class
    straight-through estimator is disabled (``straight_through=False``).
    """

    def __init__(self, embed_dim: int, n_embed: int, temp_init: float = 1.0,
                 use_norm: bool = True, use_residual: bool = False, num_quantizers: Optional[int] = None, **kwargs) -> None:
        super().__init__(embed_dim, n_embed, False,
                         use_norm, use_residual, num_quantizers)

        self.temperature = temp_init

    def quantize(self, z: torch.FloatTensor, temp: Optional[float] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor]:
        """Sample a codebook mixture (train) or one-hot code (eval) for every last-dim vector of ``z``.

        Args:
            z: input whose last dimension is ``embed_dim``.
            temp: temperature override; defaults to ``self.temperature``.

        Returns:
            z_qnorm: sampled (normalised) codebook vectors, same shape as ``z``.
            loss: KL divergence of the code distribution to a uniform prior (scalar).
            encoding_indices: argmax of the sample, shape ``z.shape[:-1]``.
        """
        # force hard = True when we are in eval mode, as we must quantize
        hard = not self.training
        temp = self.temperature if temp is None else temp

        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, also work.
        z_flat = self.norm(z.reshape(-1, self.embed_dim))
        codebook = self.norm(self.embedding.weight)

        # logits = -||z - e||^2
        logits = - torch.sum(z_flat ** 2, dim=1, keepdim=True) - \
                 torch.sum(codebook ** 2, dim=1) + 2 * \
                 torch.einsum('b d, n d -> b n', z_flat, codebook)
        logits = logits.view(*z.shape[:-1], -1)  # (..., n_embed)

        # NOTE(review): F.gumbel_softmax adds Gumbel noise even with hard=True, so in eval mode codes
        # are sampled rather than the argmax of the logits — confirm this is intended.
        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=-1, hard=hard)
        z_qnorm = torch.matmul(soft_one_hot, codebook)

        # KL(q || uniform) = sum q * (log q + log n_embed); log_softmax is the numerically stable log q.
        log_probs = F.log_softmax(logits, dim=-1)
        loss = torch.sum(log_probs.exp() * (log_probs + math.log(self.n_embed)), dim=-1).mean()

        # get encoding via argmax
        encoding_indices = soft_one_hot.argmax(dim=-1)

        return z_qnorm, loss, encoding_indices

In [18]:
class ViTVQ3D(pl.LightningModule):
    """3-D ViT autoencoder with a vector-quantized bottleneck.

    Pipeline: ViTEncoder3D -> Linear (dim -> embed_dim) -> VectorQuantizer -> Linear (embed_dim -> dim) -> ViTDecoder3D.

    Args:
        volume_key: key of the volume tensor inside each batch dict.
        volume_size, patch_size: forwarded to both ViTEncoder3D and ViTDecoder3D.
        encoder, decoder: keyword configs for ViTEncoder3D / ViTDecoder3D (must contain "dim").
        quantizer: keyword config for VectorQuantizer (must contain "embed_dim").
        path: optional checkpoint whose "state_dict" entry is loaded non-strictly.
        ignore_keys: state_dict key prefixes to drop when restoring from ``path``.
        scheduler: optional LR-scheduler config used by configure_optimizers.
        loss: optional loss module, called as ``loss(qloss, x, xrec, optimizer_idx, global_step,
            batch_idx, last_layer=..., split=...)`` and returning ``(loss, log_dict)``. It is
            required by training_step / validation_step. If it has a ``discriminator`` attribute,
            a second optimizer is created for it.

    NOTE(review): ``self.learning_rate`` (read in configure_optimizers) and
    ``initialize_from_config`` are not defined in this notebook; presumably the training
    script sets/imports them — confirm.
    """

    def __init__(self, volume_key: str, volume_size: int, patch_size: int, encoder: OmegaConf, decoder: OmegaConf, quantizer: OmegaConf,
                 path: Optional[str] = None, ignore_keys: List[str] = list(), scheduler: Optional[OmegaConf] = None,
                 loss: Optional[nn.Module] = None) -> None:
        super().__init__()
        self.path = path
        self.ignore_keys = ignore_keys
        self.volume_key = volume_key
        self.scheduler = scheduler

        self.encoder = ViTEncoder3D(volume_size=volume_size, patch_size=patch_size, **encoder)
        self.decoder = ViTDecoder3D(volume_size=volume_size, patch_size=patch_size, **decoder)
        self.quantizer = VectorQuantizer(**quantizer)
        # Linear maps between the transformer width and the codebook width.
        self.pre_quant = nn.Linear(encoder["dim"], quantizer["embed_dim"])
        self.post_quant = nn.Linear(quantizer["embed_dim"], decoder["dim"])

        # training_step / validation_step call self.loss. It is only set when given, so a model
        # built without a loss (e.g. for inference) is unchanged. It is set before checkpoint
        # loading so loss weights (e.g. a discriminator) can be restored too.
        if loss is not None:
            self.loss = loss

        if path is not None:
            self.init_from_ckpt(path, ignore_keys)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``(reconstruction, quantizer loss)`` for a volume batch ``x``."""
        quant, diff = self.encode(x)
        dec = self.decode(quant)

        return dec, diff

    def init_from_ckpt(self, path: str, ignore_keys: List[str] = list()):
        """Load ``state_dict`` from a checkpoint, dropping keys that start with any ignored prefix.

        Loading is non-strict, so missing/unexpected keys are tolerated.
        NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        """
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # any(): each key is deleted at most once, even if it matches several prefixes
            # (deleting it once per matching prefix raised KeyError).
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Encode ``x`` into quantized tokens; returns ``(quant, embedding loss)``."""
        h = self.encoder(x)
        h = self.pre_quant(h)
        quant, emb_loss, _ = self.quantizer(h)

        return quant, emb_loss

    def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
        """Decode quantized tokens back into a volume."""
        quant = self.post_quant(quant)
        dec = self.decoder(quant)

        return dec

    def encode_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
        """Encode ``x`` and return only the codebook indices."""
        h = self.encoder(x)
        h = self.pre_quant(h)
        _, _, codes = self.quantizer(h)

        return codes

    def decode_codes(self, code: torch.LongTensor) -> torch.FloatTensor:
        """Decode codebook indices into a volume (residual codes are summed over their second-to-last axis)."""
        quant = self.quantizer.embedding(code)
        quant = self.quantizer.norm(quant)

        if self.quantizer.use_residual:
            # codes carry one index per residual round: (..., num_quantizers) -> embeddings summed over rounds
            quant = quant.sum(-2)

        dec = self.decode(quant)

        return dec

    def get_input(self, batch: Tuple[Any, Any], key: str = 'volume') -> Any:
        """Fetch ``batch[key]`` as a contiguous tensor laid out as (B, C, D, H, W) (doubles cast to float)."""
        x = batch[key]
        if len(x.shape) == 4:
            # Assumes a 4-D batch is (B, D, H, W) without a channel axis — TODO confirm with the
            # data loader. Conv3d expects (B, C, D, H, W), so the channel axis goes at dim 1
            # (appending it last would give (B, D, H, W, 1)).
            x = x.unsqueeze(1)
        if x.dtype == torch.double:
            x = x.float()

        return x.contiguous()

    def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """Autoencoder step (optimizer_idx 0) or discriminator step (optimizer_idx 1)."""
        x = self.get_input(batch, self.volume_key)
        xrec, qloss = self(x)

        if optimizer_idx == 0:
            # autoencoder
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx,
                                            last_layer=self.decoder.get_last_layer(), split="train")

            self.log("train/total_loss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            del log_dict_ae["train/total_loss"]

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)

            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, batch_idx,
                                                last_layer=self.decoder.get_last_layer(), split="train")

            self.log("train/disc_loss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            del log_dict_disc["train/disc_loss"]

            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)

            return discloss

    def validation_step(self, batch: Tuple[Any, Any], batch_idx: int) -> Dict:
        """Log validation losses (plus discriminator metrics when available); returns all logged metrics."""
        x = self.get_input(batch, self.volume_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, batch_idx,
                                        last_layer=self.decoder.get_last_layer(), split="val")
        # Copy before the keys below are removed, so the full set can be returned.
        metrics = dict(log_dict_ae)

        rec_loss = log_dict_ae["val/rec_loss"]

        self.log("val/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/total_loss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        del log_dict_ae["val/rec_loss"]
        del log_dict_ae["val/total_loss"]

        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)

        if hasattr(self.loss, 'discriminator'):
            discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, batch_idx,
                                                last_layer=self.decoder.get_last_layer(), split="val")

            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            metrics.update(log_dict_disc)

        # Return the metrics themselves, not the bound `self.log_dict` method.
        return metrics

    def configure_optimizers(self) -> Tuple[List, List]:
        """AdamW for the autoencoder (plus one for the loss's discriminator, if any) and optional per-step LR schedulers."""
        lr = self.learning_rate
        optim_groups = list(self.encoder.parameters()) + \
                       list(self.decoder.parameters()) + \
                       list(self.pre_quant.parameters()) + \
                       list(self.post_quant.parameters()) + \
                       list(self.quantizer.parameters())

        optimizers = [torch.optim.AdamW(optim_groups, lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)]
        schedulers = []

        # getattr: a model built without a loss still gets its autoencoder optimizer.
        loss_module = getattr(self, "loss", None)
        if hasattr(loss_module, 'discriminator'):
            optimizers.append(torch.optim.AdamW(loss_module.discriminator.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4))

        if self.scheduler is not None:
            self.scheduler.params.start = lr
            scheduler = initialize_from_config(self.scheduler)

            schedulers = [
                {
                    'scheduler': lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                } for optimizer in optimizers
            ]

        return optimizers, schedulers

    def log_images(self, batch: Tuple[Any, Any], *args, **kwargs) -> Dict:
        """Return ``{"originals": x, "reconstructions": decode(encode(x))}`` for a batch."""
        log = dict()
        x = self.get_input(batch, self.volume_key).to(self.device)
        quant, _ = self.encode(x)

        log["originals"] = x
        log["reconstructions"] = self.decode(quant)

        return log

In [19]:
# model = ViTVQ3D(
#     volume_key="volume", volume_size=96, patch_size=8,
#     encoder={
#         "dim": 240, "depth": 6, "heads": 8, "mlp_dim": 512, "channels": 1, "dim_head": 64
#     },
#     decoder={
#         "dim": 240, "depth": 6, "heads": 8, "mlp_dim": 512, "channels": 1, "dim_head": 64
#     },
#     quantizer={
#         "embed_dim": 64, "n_embed": 512, "beta": 0.25, "use_norm": True, "use_residual": False
#     }
# )

In [20]:
class ViTVQGumbel3D(ViTVQ3D):
    """ViTVQ3D variant that replaces the vector quantizer with a GumbelQuantizer.

    Extra args:
        loss: stored as ``self.loss``; training_step calls it directly, so it should be an
            instantiated loss module. NOTE(review): annotated as a config — confirm callers
            pass a built module (or build it with initialize_from_config first).
        temperature_scheduler: optional config; when given it is built with
            ``initialize_from_config`` and called with ``global_step`` to set the Gumbel
            temperature each training step. NOTE(review): initialize_from_config is not
            defined in this notebook — confirm it is importable.
    """

    def __init__(self, volume_key: str, volume_size: int, patch_size: int, encoder: OmegaConf, decoder: OmegaConf, quantizer: OmegaConf, loss: OmegaConf,
                 path: Optional[str] = None, ignore_keys: List[str] = list(), temperature_scheduler: Optional[OmegaConf] = None, scheduler: Optional[OmegaConf] = None) -> None:
        # The parent signature is (..., quantizer, path, ignore_keys, scheduler); passing `loss`
        # positionally shifted every argument and exceeded it (TypeError). Use keywords, and
        # defer checkpoint loading until the Gumbel quantizer is in place.
        super().__init__(volume_key, volume_size, patch_size, encoder, decoder, quantizer,
                         path=None, ignore_keys=ignore_keys, scheduler=scheduler)
        self.loss = loss

        self.temperature_scheduler = initialize_from_config(temperature_scheduler) \
                                     if temperature_scheduler else None
        self.quantizer = GumbelQuantizer(**quantizer)

        if path is not None:
            self.init_from_ckpt(path, ignore_keys)

    def training_step(self, batch: Tuple[Any, Any], batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """Update the Gumbel temperature (if scheduled), run the parent step, and log the temperature."""
        if self.temperature_scheduler:
            self.quantizer.temperature = self.temperature_scheduler(self.global_step)

        loss = super().training_step(batch, batch_idx, optimizer_idx)

        if optimizer_idx == 0:
            self.log("temperature", self.quantizer.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)

        return loss


In [21]:
# # create a 64*64*64 tensor float32
# dx, dy, dz = 96, 96, 96
# x = torch.randn(1, 1, dx, dy, dz).float()
# # forward
# out, diff = model(x)
# # print(out.shape, diff)
# # recon loss is between x and out
# # diff is the embedding loss
# # show the recon loss in MSE loss
# recon_loss = F.mse_loss(x, out)
# print("Reconstruction loss is ", recon_loss)

In [22]:
# import nibabel as nib

# nii_path = "s0001.nii.gz"
# nii_file = nib.load(nii_path)
# nii_data = nii_file.get_fdata()
# print("nii_data.shape is ", nii_data.shape)
# # show max and min value
# print("max value is ", nii_data.max())
# print("min value is ", nii_data.min())


In [23]:
# # clip the value from -1024 to 2976
# # and normalize the value to 0-1
# nii_data_clip = np.clip(nii_data, -1024, 2976)
# nii_data_norm = (nii_data_clip - (-1024)) / (2976 - (-1024))

In [24]:
# # cut the center
# dx, dy, dz = (96, 96, 96)
# nii_data_norm_cut = nii_data_norm[32:32+dx, 32:32+dy, 32:32+dz]
# print("nii_data_norm_cut.shape is ", nii_data_norm_cut.shape)
# # input the nii_data_norm_cut to the model
# nii_data_norm_cut_tensor = torch.tensor(nii_data_norm_cut).float()
# nii_data_norm_cut_tensor = nii_data_norm_cut_tensor.unsqueeze(0).unsqueeze(0)
# print("nii_data_norm_cut_tensor.shape is ", nii_data_norm_cut_tensor.shape)
# out, cb_loss = model(nii_data_norm_cut_tensor)
# print("Recon loss is ", F.mse_loss(nii_data_norm_cut_tensor, out))

In [25]:
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from typing import Optional, Sequence, Tuple, Union
from monai.networks.layers.factories import Act, Norm

# Channel-wise normalisation of feature maps (used before comparing activations).
def normalize_activations(features):
    """Unit-normalise every feature map along the channel axis (dim 1).

    Args:
        features: iterable of tensors shaped (B, C, ...).

    Returns:
        list of tensors ``f / (||f||_2 over dim 1 + 1e-8)``; the epsilon avoids division
        by zero where all channels are zero.
    """
    return [f / (torch.norm(f, p=2, dim=1, keepdim=True) + 1e-8) for f in features]

# Euclidean distance between paired feature maps.
def l2_distance(features1, features2):
    """Per-position L2 distance between paired feature maps.

    For each pair (a, b) taken from the two iterables (zip stops at the shorter one),
    returns ``||a - b||_2`` over the channel axis (dim 1): the square root of the summed
    squared channel differences. Dim 1 is removed from each result.
    """
    return [torch.norm(a - b, p=2, dim=1) for a, b in zip(features1, features2)]

# Spatial averaging of per-position distances.
def average_spatial(distances):
    """Average each distance map over dims 1, 2 and 3, e.g. (B, D, H, W) -> (B,)."""
    return [dist.mean(dim=[1, 2, 3]) for dist in distances]

# Single scalar summarising all layers.
def average_all_layers(distances):
    """Stack the per-layer distance tensors (all the same shape) and return their overall mean."""
    stacked = torch.stack(distances)
    return stacked.mean()

class UNet3D_encoder(nn.Module):
    """Frozen UNet encoder used as the feature extractor of a perceptual loss.

    It rebuilds the contracting path of a UNet as three strided ``ResidualUnit``s
    (``down1``-``down3``) followed by a stride-1 ``bottom`` unit. It copies every
    matching tensor from a pretrained ``state_dict`` file, freezes all parameters
    and switches to eval mode. ``forward(y, z)`` returns the mean L2 distance
    between the channel-normalised feature maps of ``y`` and ``z`` over the four levels.

    Args:
        spatial_dims: spatial dimensionality passed to every ``ResidualUnit``.
        in_channels: channels of the input volumes.
        out_channels: kept for signature parity with the full UNet; unused.
        channels: feature channels per level; the first four are used
            (``down1``, ``down2``, ``down3``, ``bottom``).
        strides: strides of ``down1``-``down3``; the first three are used.
        kernel_size: kernel size of every unit.
        up_kernel_size: validated and stored for parity with the UNet; unused.
        num_res_units: ``subunits`` of every ``ResidualUnit``.
        act, norm, dropout, bias, adn_ordering: forwarded to every ``ResidualUnit``.
        dimensions: deprecated alias; overrides ``spatial_dims`` when not ``None``.
        pretrained_path: ``state_dict`` file to copy weights from, or ``None`` to
            keep the random initialisation (a warning is emitted).

    Raises:
        ValueError: if ``channels`` has fewer than 4 entries, ``strides`` has
            fewer than ``len(channels) - 1`` entries, or a kernel-size sequence
            does not match the number of spatial dimensions.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        channels: Sequence[int],
        strides: Sequence[int],
        kernel_size: Union[Sequence[int], int] = 3,
        up_kernel_size: Union[Sequence[int], int] = 3,
        num_res_units: int = 0,
        act: Union[Tuple, str] = Act.PRELU,
        norm: Union[Tuple, str] = Norm.INSTANCE,
        dropout: float = 0.0,
        bias: bool = True,
        adn_ordering: str = "NDA",
        dimensions: Optional[int] = None,
        pretrained_path: Optional[str] = None,
    ) -> None:
        # local import so the strides warning below cannot fail with NameError
        import warnings

        super().__init__()

        if len(channels) < 2:
            raise ValueError("the length of `channels` should be no less than 2.")
        # down1, down2, down3 and bottom index channels[0..3]
        if len(channels) < 4:
            raise ValueError("the length of `channels` should be no less than 4 (down1, down2, down3, bottom).")
        delta = len(strides) - (len(channels) - 1)
        if delta < 0:
            raise ValueError("the length of `strides` should equal to `len(channels) - 1`.")
        if delta > 0:
            warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.")
        if dimensions is not None:
            spatial_dims = dimensions
        if isinstance(kernel_size, Sequence):
            if len(kernel_size) != spatial_dims:
                raise ValueError("the length of `kernel_size` should equal to `dimensions`.")
        if isinstance(up_kernel_size, Sequence):
            if len(up_kernel_size) != spatial_dims:
                raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.")

        self.dimensions = spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.strides = strides
        self.kernel_size = kernel_size
        self.up_kernel_size = up_kernel_size
        self.num_res_units = num_res_units
        self.act = act
        self.norm = norm
        self.dropout = dropout
        self.bias = bias
        self.adn_ordering = adn_ordering
        self.pretrained_path = pretrained_path

        # input - down1 ------------- up1 -- output
        #         |                   |
        #         down2 ------------- up2
        #         |                   |
        #         down3 ------------- up3
        #         |                   |
        #         down4 -- bottom --  up4
        # Only the left (encoder) column is built here, e.g. 1 -> (32, 64, 128, 256).

        def _unit(in_ch: int, out_ch: int, stride: int) -> ResidualUnit:
            # every level shares the same conv / normalisation configuration
            return ResidualUnit(
                self.dimensions, in_ch, out_ch, stride,
                kernel_size=self.kernel_size, subunits=self.num_res_units,
                act=self.act, norm=self.norm, dropout=self.dropout,
                bias=self.bias, adn_ordering=self.adn_ordering,
            )

        # registration order (down1, down2, down3, bottom) fixes the state_dict key names
        self.down1 = _unit(self.in_channels, self.channels[0], self.strides[0])
        self.down2 = _unit(self.channels[0], self.channels[1], self.strides[1])
        self.down3 = _unit(self.channels[1], self.channels[2], self.strides[2])
        self.bottom = _unit(self.channels[2], self.channels[3], 1)

        self.load_weights(self.pretrained_path)
        self.set_all_parameters_ignore_grad()
        # a frozen feature extractor should not apply dropout, so the loss stays deterministic
        self.eval()

    def load_weights(self, path: Optional[str]) -> None:
        """Copy the tensors of a full-model ``state_dict`` file whose keys exist in this encoder.

        Keys present only in the checkpoint (e.g. decoder weights) are ignored.
        Encoder tensors with no matching checkpoint key keep their current
        values, and a warning reports how many. ``path=None`` skips loading.
        """
        import warnings

        if path is None:
            warnings.warn("UNet3D_encoder: no `pretrained_path` given; the encoder keeps its random initialisation.")
            return
        # load on CPU: the module itself is still on CPU here and is moved by the caller
        pretrain_dict = torch.load(path, map_location="cpu")
        current_dict = self.state_dict()
        # 1. keep only the checkpoint entries that belong to this encoder
        matched = {k: v for k, v in pretrain_dict.items() if k in current_dict}
        # a silent key-name mismatch would leave a random "perceptual" network
        unmatched = [k for k in current_dict if k not in matched]
        if unmatched:
            warnings.warn(
                f"UNet3D_encoder: {len(unmatched)}/{len(current_dict)} encoder tensors were not found "
                f"in '{path}' and keep their random initialisation (e.g. '{unmatched[0]}')."
            )
        # 2. overwrite matching entries, 3. load the merged state dict
        current_dict.update(matched)
        self.load_state_dict(current_dict)

    def set_all_parameters_ignore_grad(self):
        """Freeze the encoder: no parameter receives gradients."""
        for param in self.parameters():
            param.requires_grad = False

    def _extract_features(self, x: torch.Tensor) -> list:
        """Return the outputs of ``down1``, ``down2``, ``down3`` and ``bottom`` for ``x``."""
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.bottom(x3)
        return [x1, x2, x3, x4]

    def forward(self, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Perceptual distance between ``y`` and ``z`` as a scalar tensor.

        Both inputs pass through the same frozen encoder. At every level, the
        feature maps are L2-normalised over channels, their channel-wise L2
        distance is averaged over space, and the per-level, per-sample values
        are averaged into one number. Gradients still flow back to the inputs.
        """
        y_fea_norm = normalize_activations(self._extract_features(y))
        z_fea_norm = normalize_activations(self._extract_features(z))

        l2_dist = l2_distance(y_fea_norm, z_fea_norm)
        avg_dist = average_spatial(l2_dist)
        return average_all_layers(avg_dist)

In [26]:
# from monai.networks.nets.unet import UNet as UNet

# model = torch.load("model_best_181.pth")
# # save the state_dict
# torch.save(model.state_dict(), "model_best_181_state_dict.pth")

In [27]:
# # here we set the encoder model parameters

# preceptual_loss = dict()
# preceptual_loss["spatial_dims"] = 3
# preceptual_loss["in_channels"] = 1
# preceptual_loss["out_channels"] = 1
# preceptual_loss["channels"] = [32, 64, 128, 256]
# preceptual_loss["strides"] = [2, 2, 2]
# preceptual_loss["num_res_units"] = 4
# preceptual_loss["pretrained_path"] = "model_best_181_state_dict.pth"

# preceptual_model = UNet3D_encoder(**preceptual_loss)
# print("input size is ", nii_data_norm_cut_tensor.size(), "output size is ", out.size())
# total_dist = preceptual_model(nii_data_norm_cut_tensor, out)
# print("total_dist is ", total_dist)


In [28]:
# reconL2_loss = F.mse_loss(nii_data_norm_cut_tensor, out)
# reconL1_loss = F.l1_loss(nii_data_norm_cut_tensor, out)
# perceptual_loss = total_dist
# codebook_loss = cb_loss
# print("Recon L2 loss is ", reconL2_loss)
# print("Recon L1 loss is ", reconL1_loss)
# print("Perceptual loss is ", perceptual_loss)
# print("Codebook loss is ", codebook_loss)

In [29]:
# # show the plot for max and min for both xn_y and xn_z for x1 to x4

# def describe_data(name, data):
#     # print(name, "MAX: ", data.max(), "MIN: ", data.min(), "MIN: ", data.mean(), "STD: ", data.std())
#     print(f"{name} MAX: {data.max():4f} MIN: {data.min():4f} MEAN: {data.mean():4f} STD: {data.std():4f}")
    
# def check_max_and_min(features, names):
#     for idx, feature in enumerate(features):
#         feature_numpy = feature.detach().numpy()
#         feature_numpy_channel_max = feature_numpy.max(axis=(0, 2, 3, 4))
#         feature_numpy_channel_min = feature_numpy.min(axis=(0, 2, 3, 4))
#         feature_numpy_channel_mean = feature_numpy.mean(axis=(0, 2, 3, 4))
#         feature_numpy_channel_std = feature_numpy.std(axis=(0, 2, 3, 4))
#         print("-"*20)
#         describe_data(names[idx], feature_numpy_channel_max)
#         describe_data(names[idx], feature_numpy_channel_min)
#         describe_data(names[idx], feature_numpy_channel_mean)
#         describe_data(names[idx], feature_numpy_channel_std)

# check_max_and_min([x1_y, x2_y, x3_y, x4_y], ["x1_y", "x2_y", "x3_y", "x4_y"])
# check_max_and_min([x1_z, x2_z, x3_z, x4_z], ["x1_z", "x2_z", "x3_z", "x4_z"])

In [30]:
# # apply the channel wise norm for all feature maps
# # Normalize activations
# def normalize_activations(features):
#     normalized_features = []
#     for feature in features:
#         # Compute the norm along the channel dimension
#         norm = torch.norm(feature, p=2, dim=1, keepdim=True)
#         # Normalize the feature map
#         normalized_feature = feature / (norm + 1e-8)  # Add a small value to avoid division by zero
#         normalized_features.append(normalized_feature)
#     return normalized_features

# # Compute l2 distance for each pair of feature maps
# def l2_distance(features1, features2):
#     l2_distances = []
#     for f1, f2 in zip(features1, features2):
#         # Compute the l2 distance
#         l2_dist = torch.norm(f1 - f2, p=2, dim=1)  # Sum across channel dimension
#         l2_distances.append(l2_dist)
#     return l2_distances

# # Average the l2 distances across spatial dimensions
# def average_spatial(distances):
#     averaged_distances = []
#     for dist in distances:
#         # Average across spatial dimensions (dimensions 1, 2, 3)
#         avg_dist = torch.mean(dist, dim=[1, 2, 3])
#         averaged_distances.append(avg_dist)
#     return averaged_distances

# # Overall average across all layers
# def average_all_layers(distances):
#     total_distance = torch.stack(distances).mean()
#     return total_distance

# y_features = normalize_activations([x1_y, x2_y, x3_y, x4_y])
# z_features = normalize_activations([x1_z, x2_z, x3_z, x4_z])

# check_max_and_min(y_features, ["x1_y_norm", "x2_y_norm", "x3_y_norm", "x4_y_norm"])
# check_max_and_min(z_features, ["x1_z_norm", "x2_z_norm", "x3_z_norm", "x4_z_norm"])

In [31]:
# # Calculate l2 distances
# l2_distances = l2_distance(y_features, z_features)

# # Average across spatial dimensions
# averaged_spatial_distances = average_spatial(l2_distances)

# # Final perceptual loss
# final_distance = average_all_layers(averaged_spatial_distances)

# print(f'Final Distance: {final_distance.item()}')

In [32]:
import math

from monai.transforms import (
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    ScaleIntensityRanged,
    Spacingd,
    SpatialPadd,
    CenterSpatialCropd,
    RandSpatialCropd,
    # random flip and rotate
    RandFlipd,
    RandRotated,
)

# Every sample ends up exactly volume_size^3 so batches can be stacked by the collate function.
roi_size = (volume_size, volume_size, volume_size)
# RandRotated takes its ranges in RADIANS: a bare 15 would mean +/-15 rad (~860 deg),
# i.e. effectively arbitrary rotations. Rotate by at most +/-15 degrees per plane.
max_rotation_rad = math.radians(15)


def _base_transforms():
    """Deterministic preprocessing shared by training and validation.

    Loads the volume and puts channels first. It reorients to RAS, resamples to
    ``pix_dim`` isotropic spacing, linearly maps [-1024, 2976] to [0, 1] with
    clipping, and crops to the foreground. Finally it pads (constant mode) any
    dimension smaller than ``roi_size``, so the later ROI crop always returns a
    full-size cube. A new list of transform instances is returned on every call.
    """
    return [
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        Orientationd(keys=["image"], axcodes="RAS"),
        Spacingd(
            keys=["image"],
            pixdim=(pix_dim, pix_dim, pix_dim),
            mode="bilinear",
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-1024,
            a_max=2976,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image"], source_key="image"),
        # volumes thinner than the ROI would otherwise yield smaller crops and break batching
        SpatialPadd(keys=["image"], spatial_size=roi_size, mode="constant"),
    ]


train_transforms = Compose(
    _base_transforms()
    + [
        # random crop to the target size
        RandSpatialCropd(keys=["image"], roi_size=roi_size, random_center=True, random_size=False),
        # random flips along each axis and a small random rotation
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image"], prob=0.5, spatial_axis=2),
        RandRotated(
            keys=["image"],
            prob=0.5,
            range_x=max_rotation_rad,
            range_y=max_rotation_rad,
            range_z=max_rotation_rad,
        ),
    ]
)
# Validation must be deterministic so losses are comparable across epochs:
# a fixed center crop and no random flips.
val_transforms = Compose(
    _base_transforms()
    + [
        CenterSpatialCropd(keys=["image"], roi_size=roi_size),
    ]
)

In [33]:
import shutil
import subprocess

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# Show GPU memory usage. nvidia-smi is run through subprocess (instead of the IPython
# `!nvidia-smi` escape) so this cell is plain Python, and only when the tool exists.
if shutil.which("nvidia-smi") is not None:
    smi = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    print(smi.stdout)

cuda:0
Tue Jun 25 22:33:19 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000    Off  | 00000000:17:00.0 Off |                  Off |
| 30%   34C    P8    18W / 230W |      1MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A5000    Off  | 00000000:65:00.0 Off |                  Off |
| 30%   29C    P8    17W / 230W |      1MiB / 24564MiB |      0%      Default |
|

In [34]:
# # divide the training, validation and validation set
# import glob

# ct_file_list = sorted(glob.glob("tsv1_ct/*.nii.gz"))
# print("total ct files are ", len(ct_file_list))

# # divide the training, validation and test set by 5:3:2

# import random
# import json

# random.shuffle(ct_file_list)
# # divide the whole dataset into 10 pieces, and save them into chunk_0 to chunk_9

# chunk_size = len(ct_file_list) // 10

# # save them into data_chunks.json
# # data_chunks.json is a dictionary, key is the chunk number, value is the list of file names
# data_chunk = dict()
# for i in range(10):
#     start = i * chunk_size
#     end = (i + 1) * chunk_size
#     if i == 9:
#         end = len(ct_file_list)
#     # data_chunk[f"chunk_{i}"] = ct_file_list[start:end]
#     # in each item, it will be {"image": "tsv1_ct/xxx.nii.gz"}
#     data_chunk[f"chunk_{i}"] = [{"image": ct_file} for ct_file in ct_file_list[start:end]]

# with open("data_chunks.json", "w") as f:
#     json.dump(data_chunk, f)

In [35]:
import json

# data_chunks.json maps "chunk_0".."chunk_9" to lists of {"image": path} records.
# chunk_0-4 -> training, chunk_5-7 -> validation, chunk_8-9 -> testing (5:3:2).
with open("data_chunks.json", "r") as chunk_file:
    data_chunk = json.load(chunk_file)

train_files = [record for k in range(0, 5) for record in data_chunk[f"chunk_{k}"]]
val_files = [record for k in range(5, 8) for record in data_chunk[f"chunk_{k}"]]
test_files = [record for k in range(8, 10) for record in data_chunk[f"chunk_{k}"]]

num_train_files, num_val_files, num_test_files = (
    len(train_files),
    len(val_files),
    len(test_files),
)

print("Train files are ", len(train_files))
print("Val files are ", len(val_files))
print("Test files are ", len(test_files))

Train files are  575
Val files are  345
Test files are  233


In [36]:
from monai.data import (
    DataLoader,
    CacheDataset,
)

In [37]:
import torch
from torch.utils.data import DataLoader, Dataset, default_collate
import numpy as np
import random

class RobustCacheDataset(CacheDataset):
    """``CacheDataset`` that yields ``None`` for samples that fail to load.

    Any exception raised while fetching item ``idx`` is printed and swallowed,
    and ``None`` is returned in its place (``collate_fn`` drops such entries).
    Constructor arguments are forwarded unchanged to ``CacheDataset``.
    """

    def __init__(self, data, transform=None, cache_num=0, cache_rate=1.0, num_workers=0):
        super().__init__(
            data,
            transform=transform,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=num_workers,
        )

    def __getitem__(self, idx):
        try:
            sample = super().__getitem__(idx)
        except Exception as e:
            # best effort: report and skip the broken sample instead of crashing the worker
            print(f"Error loading data at index {idx}: {e}")
            return None
        else:
            return sample

def worker_init_fn(worker_id):
    """Seed the RNGs of a DataLoader worker, uniquely per worker and per epoch.

    Inside a worker, the seed is derived from torch's per-worker seed
    (``get_worker_info().seed``). Torch draws that seed fresh for every
    iterator, i.e. every epoch with non-persistent workers. It then reseeds
    ``numpy``, ``random`` and every MONAI randomizable transform reachable from
    the dataset; that last step is what ``monai.data.DataLoader`` does.
    Seeding with ``worker_id`` alone replays the same augmentation stream every
    epoch. It also leaves MONAI transforms with the identical RandomState
    inherited by all workers.

    Outside a worker (``get_worker_info()`` is ``None``) it falls back to
    seeding ``numpy`` and ``random`` with ``worker_id``.
    """
    info = torch.utils.data.get_worker_info()
    if info is None:
        seed = worker_id
    else:
        # numpy only accepts 32-bit seeds; torch's worker seed is 64-bit
        seed = info.seed % (2 ** 32)
    np.random.seed(seed)
    random.seed(seed)
    if info is not None:
        from monai.data.utils import set_rnd

        # reseed the transforms' private RandomState objects held by the dataset
        set_rnd(info.dataset, seed=seed)

def collate_fn(batch):
    """Collate a batch after dropping samples that failed to load.

    ``RobustCacheDataset`` returns ``None`` for samples whose loading raised.
    Those entries are removed before ``default_collate`` stacks the rest, so a
    batch may be smaller than ``batch_size``.

    Raises:
        RuntimeError: if every sample in the batch failed; ``default_collate``
            would otherwise fail with an opaque error on the empty list.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        raise RuntimeError(f"collate_fn: all {len(batch)} samples in this batch failed to load.")
    return default_collate(samples)

In [38]:
# Only a fraction of the volumes is cached in RAM. Here cache_rate is the binding limit;
# cache_num is set to the full list length. Uncached samples are loaded and transformed
# on the fly by the DataLoader workers.
train_ds = RobustCacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_num=num_train_files,
    cache_rate=0.2,  # 575 train files * 0.2 -> 115 cached volumes
    num_workers=4,
)

val_ds = RobustCacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_num=num_val_files,
    cache_rate=0.1,  # 345 val files * 0.1 -> 34 cached volumes
    num_workers=2,
)


Loading dataset:   0%|                                                                         | 0/115 [00:00<?, ?it/s]

Loading dataset: 100%|███████████████████████████████████████████████████████████████| 115/115 [00:36<00:00,  3.15it/s]
Loading dataset: 100%|█████████████████████████████████████████████████████████████████| 34/34 [00:12<00:00,  2.74it/s]


In [39]:
train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,  # new sample order every epoch
    num_workers=4,
    worker_init_fn=worker_init_fn,
    collate_fn=collate_fn,  # drops samples the dataset returned as None
    timeout=60,  # seconds to wait for a batch from a worker before raising
)
val_loader = DataLoader(
    val_ds,
    batch_size=16,
    shuffle=False,  # fixed order for validation
    num_workers=4,
    worker_init_fn=worker_init_fn,
    collate_fn=collate_fn,
    timeout=60,
)

In [40]:
# ViT-VQ autoencoder for volume_size^3 volumes split into 8^3 patches; the encoder and
# decoder transformers share the same hyper-parameters (each gets its own dict).
model = ViTVQ3D(
    volume_key="volume",
    volume_size=volume_size,
    patch_size=8,
    encoder=dict(dim=360, depth=6, heads=16, mlp_dim=1024, channels=1, dim_head=128),
    decoder=dict(dim=360, depth=6, heads=16, mlp_dim=1024, channels=1, dim_head=128),
    quantizer=dict(embed_dim=128, n_embed=1024, beta=0.25, use_norm=True, use_residual=False),
).to(device)

2024-06-25 22:34:09,711 - Created a temporary directory at /tmp/tmpienn0c3f
2024-06-25 22:34:09,712 - Writing /tmp/tmpienn0c3f/_remote_module_non_scriptable.py


In [41]:
# AdamW (decoupled weight decay) over every parameter of the VQ model
learning_rate = 5e-4
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

In [42]:
# Frozen UNet encoder for the perceptual loss. Its weights come from the state_dict of a
# previously trained full UNet; only the encoder tensors are used.
preceptual_loss = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=[32, 64, 128, 256],
    strides=[2, 2, 2],
    num_res_units=4,
    pretrained_path="model_best_181_state_dict.pth",
)

preceptual_model = UNet3D_encoder(**preceptual_loss).to(device)
# total_dist = preceptual_model(nii_data_norm_cut_tensor, out)
# print("total_dist is ", total_dist)


In [43]:
# A tiny training logger: every logger.log() call appends one line to the log file
# and echoes it to the notebook output.
import time


class simple_logger():
    """Append-only text logger for training metrics.

    Args:
        log_file_path: file the log lines are appended to (created if missing).

    Attributes:
        log_file_path: as given.
        log_dict: maps each key to the most recent entry logged under it, as
            ``{"time": str, "global_epoch": ..., "msg": ...}``. The full history
            lives only in the log file.
    """

    def __init__(self, log_file_path):
        self.log_file_path = log_file_path
        self.log_dict = dict()

    def log(self, global_epoch, key, msg):
        """Record ``msg`` under ``key``, append it to the log file and print it.

        Each line has the form ``"<time> Global epoch: <epoch>, <key>, <msg>"``,
        where ``<time>`` is the local time formatted as ``%Y-%m-%d-%H-%M-%S``.
        """
        current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
        # overwrite: only the latest entry per key is kept in memory
        self.log_dict[key] = {
            "time": current_time,
            "global_epoch": global_epoch,
            "msg": msg,
        }
        log_str = f"{current_time} Global epoch: {global_epoch}, {key}, {msg}\n"
        with open(self.log_file_path, "a", encoding="utf-8") as f:
            f.write(log_str)
        # log_str already ends with a newline; don't let print add a blank line
        print(log_str, end="")

In [44]:
import time

# One log file per run, stamped with the start time. simple_logger writes plain-text
# lines (not JSON), so the file gets a .log extension.
current_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
log_file_path = f"train_log_{current_time}.log"
logger = simple_logger(log_file_path)

num_epoch = 1000
# weights of the terms summed into the total loss
loss_weights = {
    "reconL2": 1.0,
    "reconL1": 0.1,
    "perceptual": 0.05,
    "codebook": 0.1,
}
val_per_epoch = 20  # validate on epochs where idx_epoch % val_per_epoch == 0 (epoch 0 included)
num_train_batch = len(train_loader)
num_val_batch = len(val_loader)

In [45]:
# Loss names in logging order; "total" is the loss_weights-weighted sum of the other four.
loss_names = ("reconL2", "reconL1", "perceptual", "codebook", "total")
# Weights of the epoch with the lowest mean validation total loss are saved here.
best_ckpt_path = "vitvq3d_best_state_dict.pth"
best_val_total = float("inf")


def _compute_losses(x):
    """Reconstruct ``x`` with ``model`` and return every loss term as a tensor, keyed by ``loss_names``.

    Shared by the training and validation passes so both compute the loss identically.
    """
    xrec, cb_loss = model(x)
    losses = {
        "reconL2": F.mse_loss(x, xrec),
        "reconL1": F.l1_loss(x, xrec),
        # frozen encoder: gradients reach xrec, never the encoder's own parameters
        "perceptual": preceptual_model(x, xrec),
        "codebook": cb_loss,
    }
    losses["total"] = (
        loss_weights["reconL2"] * losses["reconL2"]
        + loss_weights["reconL1"] * losses["reconL1"]
        + loss_weights["perceptual"] * losses["perceptual"]
        + loss_weights["codebook"] * losses["codebook"]
    )
    return losses


def _log_epoch_stats(epoch, split, history):
    """Convert each loss list in ``history`` to an array and log its mean and std, e.g. ``train_total_mean``."""
    for key in history:
        history[key] = np.asanyarray(history[key])
        logger.log(epoch, f"{split}_{key}_mean", history[key].mean())
        logger.log(epoch, f"{split}_{key}_std", history[key].std())


for idx_epoch in range(num_epoch):
    # ---- training ----
    model.train()
    epoch_loss_train = {name: [] for name in loss_names}

    for idx_batch, batch in enumerate(train_loader):
        x = batch["image"].to(device)
        optimizer.zero_grad()
        losses = _compute_losses(x)
        for name in loss_names:
            epoch_loss_train[name].append(losses[name].item())
        total_loss = losses["total"]
        print(f"<{idx_epoch}> [{idx_batch}/{num_train_batch}] Total loss: {total_loss.item()}")
        total_loss.backward()
        optimizer.step()

    _log_epoch_stats(idx_epoch, "train", epoch_loss_train)

    # ---- validation ----
    if idx_epoch % val_per_epoch == 0:
        model.eval()
        epoch_loss_val = {name: [] for name in loss_names}
        with torch.no_grad():
            for idx_batch, batch in enumerate(val_loader):
                x = batch["image"].to(device)
                losses = _compute_losses(x)
                for name in loss_names:
                    epoch_loss_val[name].append(losses[name].item())
                total_loss = losses["total"]
                print(f"<{idx_epoch}> [{idx_batch}/{num_val_batch}] Total loss: {total_loss.item()}")

        _log_epoch_stats(idx_epoch, "val", epoch_loss_val)

        # Keep the best weights on disk so an aborted or crashed run is not lost.
        val_totals = epoch_loss_val["total"]
        val_total = float(val_totals.mean()) if len(val_totals) else float("inf")
        if val_total < best_val_total:
            best_val_total = val_total
            torch.save(model.state_dict(), best_ckpt_path)
            logger.log(idx_epoch, "checkpoint", f"saved {best_ckpt_path} (val_total_mean={val_total})")


<0> [0/18] Total loss: 0.9360347986221313
<0> [1/18] Total loss: 1.0179909467697144
<0> [2/18] Total loss: 0.4376298785209656
<0> [3/18] Total loss: 0.2809537351131439
<0> [4/18] Total loss: 0.20731933414936066
<0> [5/18] Total loss: 0.1707051396369934
<0> [6/18] Total loss: 0.1455327719449997
<0> [7/18] Total loss: 0.12855277955532074
<0> [8/18] Total loss: 0.11861488223075867
<0> [9/18] Total loss: 0.11713769286870956
<0> [10/18] Total loss: 0.1187293529510498
<0> [11/18] Total loss: 0.11842599511146545
<0> [12/18] Total loss: 0.11711177974939346
<0> [13/18] Total loss: 0.12094827741384506
<0> [14/18] Total loss: 0.11760570108890533
<0> [15/18] Total loss: 0.11644783616065979
<0> [16/18] Total loss: 0.11109942197799683
<0> [17/18] Total loss: 0.10906561464071274
2024-06-25-22-35-58 Global epoch: 0, train_reconL2_mean, 0.16187313116259044

2024-06-25-22-35-58 Global epoch: 0, train_reconL2_std, 0.25110155183205246

2024-06-25-22-35-58 Global epoch: 0, train_reconL1_mean, 0.26168151282

OMP: Error #179: Function Can't open SHM failed:
OMP: System error #0: Success
OMP: Error #179: Function Can't open SHM failed:
OMP: System error #0: Success
OMP: Error #179: Function Can't open SHM failed:
OMP: System error #0: Success
OMP: Error #179: Function Can't open SHM failed:
OMP: System error #0: Success


RuntimeError: DataLoader worker (pid(s) 17472, 17473, 17474) exited unexpectedly

In [None]:
# # show all the data shape in tsv1_ct

# import os
# import glob
# import nibabel as nib
# ct_file_list = sorted(glob.glob("tsv1_ct/*.nii.gz"))
# print("total ct files are ", len(ct_file_list))
# len_ct = len(ct_file_list)

# for idx_ct, ct_path in enumerate(ct_file_list):
#     ct_file = nib.load(ct_path)
#     ct_filename = ct_path.split("/")[-1]
#     try:
#         ct_data = ct_file.get_fdata()
#         shapes = ct_data.shape
#         # if any shapes is less than 64, we need to move it to tsv1_ct_small
#         if shapes[0] < 64 or shapes[1] < 64 or shapes[2] < 64:
#             print(f"[{idx_ct+1}]/[{len_ct}] {ct_path} shape is {ct_data.shape} the pixel dimension is {shapes}")
#             cmd = f"mv {ct_path} tsv1_ct_small/{ct_filename}"
#             print(cmd)
#             os.system(cmd)
#     except Exception as e:
#         print(f"[{idx_ct+1}]/[{len_ct}] {ct_path} has error: {e}")
#     # print(f"{ct_path} shape is {ct_data.shape} the pixel dimension is {}")
#     # cmd = f"mv {ct_file} tsv1_ct_small/{ct_filename}"
#     # print(cmd)

total ct files are  1172
[713]/[1172] tsv1_ct/s0864.nii.gz has error: CRC check failed 0x5ab25329 != 0x84f031a9
[787]/[1172] tsv1_ct/s0955.nii.gz shape is (149, 149, 29) the pixel dimension is (149, 149, 29)
mv tsv1_ct/s0955.nii.gz tsv1_ct_small/s0955.nii.gz
[797]/[1172] tsv1_ct/s0968.nii.gz shape is (193, 193, 59) the pixel dimension is (193, 193, 59)
mv tsv1_ct/s0968.nii.gz tsv1_ct_small/s0968.nii.gz
[870]/[1172] tsv1_ct/s1047.nii.gz shape is (129, 129, 62) the pixel dimension is (129, 129, 62)
mv tsv1_ct/s1047.nii.gz tsv1_ct_small/s1047.nii.gz
[871]/[1172] tsv1_ct/s1048.nii.gz shape is (235, 235, 33) the pixel dimension is (235, 235, 33)
mv tsv1_ct/s1048.nii.gz tsv1_ct_small/s1048.nii.gz
[886]/[1172] tsv1_ct/s1065.nii.gz shape is (319, 319, 54) the pixel dimension is (319, 319, 54)
mv tsv1_ct/s1065.nii.gz tsv1_ct_small/s1065.nii.gz
[920]/[1172] tsv1_ct/s1109.nii.gz shape is (267, 267, 60) the pixel dimension is (267, 267, 60)
mv tsv1_ct/s1109.nii.gz tsv1_ct_small/s1109.nii.gz
[953]/