In [2]:
import matplotlib.pyplot as plt
import wandb
from tqdm.auto import tqdm
import os, sys
import time
import numpy as np
import collections
import torch
from torch import Tensor, nn
from torch.utils.data import Dataset, DataLoader
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler
from torchvision.transforms import Resize, CenterCrop
from typing import Iterable, Dict, Callable, Tuple
import torch.nn.functional as F
import matplotlib.pyplot as plt
from random import randrange
from copy import deepcopy

from nnunet.training.model_restore import restore_model
import batchgenerators
from batchgenerators.transforms.local_transforms import *
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.paths import preprocessing_output_dir
from nnunet.training.dataloading.dataset_loading import *
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
from nnunet.run.load_pretrained_weights import load_pretrained_weights

sys.path.append('..')
from dataset import CalgaryCampinasDataset, ACDCDataset, MNMDataset
from utils import EarlyStopping, epoch_average, average_metrics
from model.layers import ConvBlock
from model.dae import AugResDAE
from model.unet import UNet2D
from model.wrapper import Frankenstein
from losses import MNMCriterionAE, SampleDice, UnetDice
from trainer.ae_trainer import AETrainerACDC

# Relative path prefix, presumably pointing at a local nnU-Net checkout (TODO confirm).
# It is not referenced by any of the cells visible in this notebook excerpt.
nnUnet_prefix = '../../../nnUNet/'

In [10]:
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Module
from typing import Iterable, Dict, Callable, Tuple, Union
from itertools import chain


class ConvBlock(nn.Module):
    """Conv block for the AE model.

    Building block of the auto encoder. It consists of

    * an optional stack ``self.block`` of ``block_size - 1`` channel-preserving,
      stride-1 convolutions (only created when ``block_size > 1``), and
    * a final ``self.sample`` convolution that maps ``in_channels`` to
      ``out_channels`` using the requested ``stride``.

    Every convolution is followed by ``nn.LayerNorm`` and ``nn.LeakyReLU``.
    With ``reverse=True`` all convolutions are ``nn.ConvTranspose2d`` instead
    of ``nn.Conv2d``.

    Note:
        All LayerNorm layers -- including the one after the sampling
        convolution -- are sized ``[channels, in_dim, in_dim]``. The feature
        maps reaching each LayerNorm must therefore be ``in_dim x in_dim``.

    Args:
        in_channels: Number of input channels (cast to ``int``).
        out_channels: Number of output channels (cast to ``int``).
        in_dim: Spatial edge length used to size the LayerNorm layers
            (cast to ``int``).
        kernel_size: Kernel size of every convolution.
        stride: Stride of the sampling convolution; the stacked convolutions
            in ``self.block`` always use stride 1.
        padding: Padding of every convolution.
        block_size: Total number of convolutions in the block, counting the
            sampling convolution.
        reverse: Use transposed convolutions if ``True``.
        residual: If ``True`` and ``block_size > 1``, the input is added to
            the output of ``self.block`` before sampling.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        in_dim: int,
        kernel_size: int = 3,
        stride: int = 2,
        padding: int = 1,
        block_size: int = 1,
        reverse: bool = False,
        residual: bool = False
    ):
        super().__init__()
        self.block_size = block_size
        self.residual = residual

        in_channels = int(in_channels)
        out_channels = int(out_channels)
        in_dim = int(in_dim)

        if reverse:
            conv_cls = nn.ConvTranspose2d
            # PyTorch requires output_padding to be smaller than the stride
            # (or dilation). The previous unconditional output_padding=1 for
            # padding != 0 made stride-1 transposed blocks fail at forward
            # time, so it is only applied when the stride is larger than 1.
            min_stride = min(stride) if isinstance(stride, (tuple, list)) else stride
            sample_extra = {
                "output_padding": 1 if (padding != 0 and min_stride > 1) else 0
            }
            block_extra = {"output_padding": 0}
        else:
            conv_cls = nn.Conv2d
            sample_extra = {}
            block_extra = {}

        if block_size > 1:
            # block_size - 1 repetitions of (conv, norm, activation), flattened
            # into a single Sequential so sub-module indices stay 0, 1, 2, ...
            self.block = nn.Sequential(
                *chain.from_iterable(
                    (
                        conv_cls(
                            in_channels,
                            in_channels,
                            kernel_size=kernel_size,
                            stride=1,
                            padding=padding,
                            **block_extra,
                        ),
                        nn.LayerNorm(torch.Size([in_channels, in_dim, in_dim])),
                        nn.LeakyReLU(),
                    )
                    for _ in range(block_size - 1)
                )
            )

        self.sample = nn.Sequential(
            conv_cls(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                **sample_extra,
            ),
            nn.LayerNorm(torch.Size([out_channels, in_dim, in_dim])),
            nn.LeakyReLU(),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the optional conv stack (with optional residual), then sample."""
        if self.block_size > 1:
            refined = self.block(x)
            x = x + refined if self.residual else refined
        return self.sample(x)

In [19]:
# Constant test batch: 2 samples, 16 channels, 32x32 spatial grid, all ones
x_in = torch.ones(2, 16, 32, 32)

In [20]:
# Residual, non-reversed ConvBlock mapping 16 -> 64 channels with stride 1,
# made of two convolutions in total (block_size=2).
layer = ConvBlock(
    in_channels=16, out_channels=64, in_dim=32,
    kernel_size=3, stride=1, padding=1,
    block_size=2, residual=True, reverse=False,
)

In [21]:
# Quick forward-pass check: run x_in through layer.
# tmp is not used by any later cell visible in this notebook excerpt.
tmp = layer(x_in)

In [25]:
# Scratch check: nn.ModuleList holding ten independent 1 -> 1 linear layers
test = nn.ModuleList([nn.Linear(1, 1) for _ in range(10)])

In [62]:
class ChannelAE(nn.Module):
    """Fully convolutional autoencoder (AE) that compresses along the channel axis.

    The model is assembled from ConvBlocks that all use ``kernel_size=3``,
    ``stride=1`` and ``padding=1``:

    * ``init``: one block mapping ``in_channels -> in_channels``.
    * ``encoder``: ``depth`` blocks; block ``i`` maps
      ``in_channels // 4**i -> in_channels // 4**(i + 1)``.
    * ``decoder``: ``depth`` blocks mirroring the encoder channel counts in
      reverse order.
    * ``out``: a 1x1 ``nn.Conv2d`` mapping ``in_channels -> in_channels``.

    In ``forward`` every decoder block except the first one adds the encoder
    activation with the matching channel count to its output (additive skip
    connection). There is no fully connected / dense bottleneck.

    Args:
        in_channels: Number of channels of the input feature map. Must be at
            least ``4**depth`` so the narrowest layer keeps one channel.
        in_dim: Spatial edge length passed to every ConvBlock.
        latent_dim: Stored on the instance only; it is not used to build any
            layer in this class.
        depth: Number of encoder (and decoder) blocks.
        block_size: ``block_size`` forwarded to every ConvBlock.
        residual: ``residual`` flag forwarded to every ConvBlock.

    Raises:
        ValueError: If ``in_channels // 4**depth`` is smaller than 1, i.e. the
            bottleneck would end up with zero channels.
    """

    def __init__(
        self,
        in_channels: int,
        in_dim: int,
        latent_dim: int = 128,
        depth: int = 3,
        block_size: int = 1,
        residual: bool = False
    ):
        super().__init__()

        # Fail early with a clear message instead of silently building
        # zero-channel convolutions for configurations that are too deep.
        if in_channels // 4 ** max(depth, 0) < 1:
            raise ValueError(
                "ChannelAE needs in_channels >= 4**depth so the bottleneck "
                f"keeps at least one channel; got in_channels={in_channels}, "
                f"depth={depth}."
            )

        # Flag stored on the instance; it is not read anywhere in this class.
        self.on = True
        self.in_channels = in_channels
        self.in_dim = in_dim
        self.depth = depth
        self.latent_dim = latent_dim
        self.block_size = block_size
        self.residual = residual

        self.init = self._make_block(self.in_channels, self.in_channels)
        self.encoder = self._build_encoder()
        self.decoder = self._build_decoder()
        self.out = nn.Conv2d(in_channels, in_channels, 1)

    def _make_block(self, in_channels: int, out_channels: int) -> nn.Module:
        """Create one ConvBlock (3x3, stride 1, padding 1) with the shared settings."""
        return ConvBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            in_dim=self.in_dim,
            block_size=self.block_size,
            residual=self.residual,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def _build_encoder(self) -> nn.ModuleList:
        """Build ``depth`` blocks, each dividing the channel count by 4 (integer division)."""
        return nn.ModuleList(
            self._make_block(
                self.in_channels // 4**i,
                self.in_channels // 4**(i + 1),
            )
            for i in range(self.depth)
        )

    def _build_decoder(self) -> nn.ModuleList:
        """Build ``depth`` blocks that restore the encoder channel counts in reverse order."""
        return nn.ModuleList(
            self._make_block(
                self.in_channels // 4**(i + 1),
                self.in_channels // 4**i,
            )
            for i in reversed(range(self.depth))
        )

    def forward(self, x):
        """Encode, decode with additive skip connections, then apply the 1x1 output conv."""
        # encoder_outputs[0] is the init activation, encoder_outputs[k] the
        # output of encoder block k - 1; the last entry is the bottleneck.
        encoder_outputs = []

        # Encoding
        x = self.init(x)
        encoder_outputs.append(x)
        for enc_layer in self.encoder:
            x = enc_layer(x)
            encoder_outputs.append(x)

        # Decoding
        for i, dec_layer in enumerate(self.decoder):
            x = dec_layer(x)
            if i > 0:
                # Skip connection: -(i + 2) selects the encoder activation
                # whose channel count equals this decoder block's output.
                # The first decoder block (i == 0) gets no skip connection.
                x = x + encoder_outputs[-(i + 2)]

        # Output layer
        x = self.out(x)
        return x

In [63]:
# Smoke test for ChannelAE with its default depth / block settings.
model = ChannelAE(in_dim=32, in_channels=64)

# One random sample: 64 channels on a 32x32 grid
dummy_input = torch.randn((1, 64, 32, 32))

# Forward pass and report the shape of the result
output = model(dummy_input)

print("Output shape:", output.shape)

torch.Size([1, 16, 32, 32])
torch.Size([1, 4, 32, 32])
torch.Size([1, 1, 32, 32])
torch.Size([1, 4, 32, 32]) torch.Size([1, 4, 32, 32])
torch.Size([1, 16, 32, 32]) torch.Size([1, 16, 32, 32])
torch.Size([1, 64, 32, 32]) torch.Size([1, 64, 32, 32])
Output shape: torch.Size([1, 64, 32, 32])


In [31]:
# Print the level indices in decoder order: depth - 1 down to 0
for i in range(model.depth - 1, -1, -1):
    print(i)

2
1
0
