<a href="https://colab.research.google.com/github/alim98/Thesis/blob/main/VMAE_PRO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the required packages
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install einops timm lightning wandb monai gitpython


Collecting lightning
  Downloading lightning-2.5.0.post0-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.6.1-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.0.post0-py3-none-any.whl.metadata (21 kB)
Downloading lightning-2.5.0.post0-py3-none-any.whl (815 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.2/815.2 kB[0m [31m49.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading monai-1.4.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m77.3 MB/s[0

# Download DS

In [2]:
!wget https://cloud.imi.uni-luebeck.de/s/xcZrLSQYtK68em8/download/OASIS.zip

--2025-01-10 21:18:15--  https://cloud.imi.uni-luebeck.de/s/xcZrLSQYtK68em8/download/OASIS.zip
Resolving cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)... 141.83.20.118
Connecting to cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)|141.83.20.118|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1444971437 (1.3G) [application/zip]
Saving to: ‘OASIS.zip’


2025-01-10 21:19:47 (15.2 MB/s) - ‘OASIS.zip’ saved [1444971437/1444971437]



In [10]:

!unzip -q OASIS.zip


replace OASIS/imagesTr/OASIS_0043_0000.nii.gz? [y]es, [n]o, [A]ll, [N]one, [r]ename: A


In [12]:
! cp /content/OASIS/OASIS_dataset.json /content/

In [11]:

import json
import os

json_path = "OASIS/OASIS_dataset.json"
with open(json_path, "r") as f:
    data = json.load(f)

# Update paths dynamically
base_path = "OASIS"
for entry in data["training"]:
    entry["image"] = os.path.join(base_path, entry["image"].lstrip("./"))
    entry["label"] = os.path.join(base_path, entry["label"].lstrip("./"))
    entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))
for entry in data["test"]:
    entry["image"] = os.path.join(base_path, entry["image"].lstrip("./"))
    entry["label"] = os.path.join(base_path, entry["label"].lstrip("./"))
    entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))
for entry in data["registration_test"]:
    entry["fixed"] = os.path.join(base_path, entry["fixed"].lstrip("./"))
    entry["moving"] = os.path.join(base_path, entry["moving"].lstrip("./"))
    # entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))
for entry in data["registration_val"]:
    entry["fixed"] = os.path.join(base_path, entry["fixed"].lstrip("./"))
    entry["moving"] = os.path.join(base_path, entry["moving"].lstrip("./"))
    # entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))

# Save updated JSON
with open(json_path, "w") as f:
    json.dump(data, f, indent=4)

## new oasis dataset test

In [None]:
from torch.utils.data import Dataset
import nibabel as nib
import torch
from monai.transforms import Resize

class OASIS_Dataset(Dataset):
    def __init__(self, json_path, mode="training", input_dim=(128, 128, 128), is_pair=False):
        self.input_dim = input_dim
        self.is_pair = is_pair

        # Load JSON data
        with open(json_path, "r") as f:
            data = json.load(f)

        if mode == "training":
            self.samples = data["training"]
        elif mode == "test":
            self.samples = data["test"]
        elif mode == "registration_val":
            self.samples = data["registration_val"]
        elif mode == "registration_test":
            self.samples = data["registration_test"]
        else:
            raise ValueError(f"Invalid mode: {mode}")

        self.transforms_image = Resize(spatial_size=input_dim)
        self.transforms_mask = Resize(spatial_size=input_dim, mode="nearest")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        if self.is_pair:
            # Load fixed and moving images for registration
            sample = self.samples[index]
            fixed_path = sample["fixed"]
            moving_path = sample["moving"]

            fixed = nib.load(fixed_path).get_fdata()
            moving = nib.load(moving_path).get_fdata()

            fixed = self.transforms_image(torch.from_numpy(fixed).unsqueeze(0).float())
            moving = self.transforms_image(torch.from_numpy(moving).unsqueeze(0).float())

            # Load segmentation masks if available
            # Assuming mask paths are provided; adjust accordingly
            fixed_mask_path = sample.get("fixed_mask", None)
            moving_mask_path = sample.get("moving_mask", None)

            if fixed_mask_path and moving_mask_path:
                fixed_mask = nib.load(fixed_mask_path).get_fdata()
                moving_mask = nib.load(moving_mask_path).get_fdata()

                fixed_mask = self.transforms_mask(torch.from_numpy(fixed_mask).unsqueeze(0).long())
                moving_mask = self.transforms_mask(torch.from_numpy(moving_mask).unsqueeze(0).long())
            else:
                # If masks are not available, return dummy masks
                fixed_mask = torch.zeros_like(fixed)
                moving_mask = torch.zeros_like(moving)

            return fixed, moving, fixed_mask, moving_mask
        else:
            # Load unpaired data (image, label, mask)
            sample = self.samples[index]
            image_path = sample["image"]
            label_path = sample["label"]
            mask_path = sample.get("mask", None)

            image = nib.load(image_path).get_fdata()
            label = nib.load(label_path).get_fdata()

            image = self.transforms_image(torch.from_numpy(image).unsqueeze(0).float())
            label = self.transforms_mask(torch.from_numpy(label).unsqueeze(0).long())

            if mask_path:
                mask = nib.load(mask_path).get_fdata()
                mask = self.transforms_mask(torch.from_numpy(mask).unsqueeze(0).long())
            else:
                mask = torch.zeros_like(label)

            return image, label, mask
from torch.utils.data import DataLoader

# For training (paired data)
train_dataset = OASIS_Dataset(
    json_path="OASIS/OASIS_dataset.json",
    mode="training",
    input_dim=(128, 128, 128),
    is_pair=True  # Set to True for paired data
)
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4  # Adjust based on your system
)

# For validation
val_dataset = OASIS_Dataset(
    json_path="OASIS/OASIS_dataset.json",
    mode="registration_val",
    input_dim=(128, 128, 128),
    is_pair=True  # Set to True for paired data
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4
)


In [None]:
# prompt: summerize json attributes

import json
import os

json_path = "OASIS/OASIS_dataset.json"
with open(json_path, 'r') as f:
    data = json.load(f)

def summarize_json(data):
    summary = {}
    for key, value in data.items():
        if isinstance(value, list):
            summary[key] = {
                "count": len(value),
                "example": value[0] if value else None  # Example item
            }
        elif isinstance(value, dict):
            summary[key] = summarize_json(value) # Recursive call for nested dicts
        else:
            summary[key] = value
    return summary

summary = summarize_json(data)
print(json.dumps(summary, indent=2))

{
  "name": "OASIS",
  "release": "1.1",
  "description": "OASIS task of Learn2Reg Dataset. Please see https://learn2reg.grand-challenge.org/ for more information. These data were prepared by Andrew Hoopes and Adrian V. Dalca for the following HyperMorph paper. If you use this collection please cite the following and refer to the OASIS Data Use Agreement. ",
  "licence": "Open Access Series of Imaging Studies (OASIS): Cross-Sectional MRI Data in Young, Middle Aged, Nondemented, and Demented Older Adults. Marcus DS, Wang TH, Parker J, Csernansky JG, Morris JC, Buckner RL. Journal of Cognitive Neuroscience, 19, 1498-1507.",
  "reference": "",
  "pairings": "unpaired",
  "provided_data": {
    "0": {
      "count": 3,
      "example": "image"
    }
  },
  "registration_direction": {
    "fixed": 0,
    "moving": 0
  },
  "modality": {
    "0": "MR"
  },
  "img_shift": {
    "fixed": "Patient A",
    "moving": "Patient B"
  },
  "labels": {
    "0": {}
  },
  "tensorImageSize": {
    "0": 

In [None]:

# 2. Import Necessary Modules
import logging
import os
import sys
import math
import yaml
import glob
import pickle
import random
import json
from functools import reduce
from typing import Tuple, Dict, Any, List, Set, Optional, Union, Callable, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import lightning as L
from lightning import LightningModule, Trainer
from lightning.pytorch.loggers import WandbLogger

from einops import rearrange
import timm
import wandb
import monai
import nibabel as nib
import warnings
import monai.transforms as transforms
import logging

# Set random seeds for reproducibility
torch.manual_seed(42)
random.seed(42)
wandb.login()
wandb_logger = WandbLogger(project="hvit_test2")  # Replace with your project name

# 3. Define Utility Classes and Functions
# 3.1. Logger Class (Ensure only this definition exists)
import logging
import os

class Logger:
    def __init__(self, save_dir):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        # Create handlers
        console_handler = logging.StreamHandler()

        # Create the directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)
        file_handler = logging.FileHandler(os.path.join(save_dir, "logfile.log"))

        # Create formatters and add to handlers
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler.setFormatter(formatter)
        file_handler.setFormatter(formatter)

        # Add handlers to the logger
        self.logger.addHandler(console_handler)
        self.logger.addHandler(file_handler)

    def info(self, message):
        self.logger.info(message)

    def warning(self, message):
        self.logger.warning(message)

    def error(self, message):
        self.logger.error(message)

    def debug(self, message):
        self.logger.debug(message)

# 3.2. Utility Functions
def read_yaml_file(file_path):
    """
    Reads a YAML file and returns the content as a dictionary.
    """
    with open(file_path, 'r') as file:
        try:
            content = yaml.safe_load(file)
            return content
        except yaml.YAMLError as e:
            print(f"Error reading YAML file: {e}")
            return None

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def get_one_hot(inp_seg, num_labels):
    B, C, H, W, D = inp_seg.shape
    inp_onehot = nn.functional.one_hot(inp_seg.long(), num_classes=num_labels)
    inp_onehot = inp_onehot.squeeze(dim=1)
    inp_onehot = inp_onehot.permute(0, 4, 1, 2, 3).contiguous()
    return inp_onehot

def DiceScore(y_pred, y_true, num_class):
    y_true = nn.functional.one_hot(y_true, num_classes=num_class)
    y_true = torch.squeeze(y_true, 1)
    y_true = y_true.permute(0, 4, 1, 2, 3).contiguous()
    intersection = y_pred * y_true
    intersection = intersection.sum(dim=[2, 3, 4])
    union = torch.pow(y_pred, 2).sum(dim=[2, 3, 4]) + torch.pow(y_true, 2).sum(dim=[2, 3, 4])
    dsc = (2.*intersection) / (union + 1e-5)
    return dsc

# 3.3. Loss Functions
class Grad3D(torch.nn.Module):
    """
    N-D gradient loss.
    """
    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred):
        dy = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :])
        dx = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :])
        dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])

        if self.penalty == 'l2':
            dy = dy * dy
            dx = dx * dx
            dz = dz * dz

        d = torch.mean(dx) + torch.mean(dy) + torch.mean(dz)
        grad = d / 3.0

        if self.loss_mult is not None:
            grad *= self.loss_mult
        return grad

class DiceLoss(nn.Module):
    """Dice loss"""
    def __init__(self, num_class=36):
        super().__init__()
        self.num_class = num_class

    def forward(self, y_pred, y_true):
        y_true = nn.functional.one_hot(y_true, num_classes=self.num_class)
        y_true = torch.squeeze(y_true, 1)
        y_true = y_true.permute(0, 4, 1, 2, 3).contiguous()
        intersection = y_pred * y_true
        intersection = intersection.sum(dim=[2, 3, 4])
        union = torch.pow(y_pred, 2).sum(dim=[2, 3, 4]) + torch.pow(y_true, 2).sum(dim=[2, 3, 4])
        dsc = (2.*intersection) / (union + 1e-5)
        dsc_loss = (1-torch.mean(dsc))
        return dsc_loss

loss_functions = {
    "mse": nn.MSELoss(),
    "dice": DiceLoss(num_class=36),
    "grad": Grad3D(penalty='l2')
}

# 4. Define the Dataset Class
class OASIS_Dataset(Dataset):
    def __init__(self, json_path, mode="training", input_dim=(128, 128, 128), is_pair=False):
        self.input_dim = input_dim
        self.is_pair = is_pair

        # Load JSON data
        with open(json_path, "r") as f:
            data = json.load(f)

        if mode == "training":
            self.samples = data["training"]
        elif mode == "test":
            self.samples = data["test"]
        elif mode == "registration_val":
            self.samples = data["registration_val"]
        elif mode == "registration_test":
            self.samples = data["registration_test"]
        else:
            raise ValueError(f"Invalid mode: {mode}")

        self.transforms_image = transforms.Resize(spatial_size=input_dim)
        self.transforms_mask = transforms.Resize(spatial_size=input_dim, mode="nearest")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        if self.is_pair:
            # Load fixed and moving images for registration
            sample = self.samples[index]
            fixed_path = sample["fixed"]
            moving_path = sample["moving"]

            # Load images
            fixed = nib.load(fixed_path).get_fdata()
            moving = nib.load(moving_path).get_fdata()

            # Apply transforms
            fixed = self.transforms_image(torch.from_numpy(fixed).unsqueeze(0).float())
            moving = self.transforms_image(torch.from_numpy(moving).unsqueeze(0).float())

            # Load segmentation masks if available
            fixed_mask_path = sample.get("fixed_mask", None)
            moving_mask_path = sample.get("moving_mask", None)

            if fixed_mask_path and moving_mask_path:
                fixed_mask = nib.load(fixed_mask_path).get_fdata()
                moving_mask = nib.load(moving_mask_path).get_fdata()

                fixed_mask = self.transforms_mask(torch.from_numpy(fixed_mask).unsqueeze(0).long())
                moving_mask = self.transforms_mask(torch.from_numpy(moving_mask).unsqueeze(0).long())
            else:
                # If masks are not available, return dummy masks
                fixed_mask = torch.zeros_like(fixed)
                moving_mask = torch.zeros_like(moving)

            return fixed, moving, fixed_mask, moving_mask
        else:
            # Load two random images for unpaired registration
            selected_samples = random.sample(self.samples, 2)

            # Load source image
            src_sample = selected_samples[0]
            src_path = src_sample["image"]
            src = nib.load(src_path).get_fdata()
            src = self.transforms_image(torch.from_numpy(src).unsqueeze(0).float())

            # Load source mask if available
            src_mask_path = src_sample.get("mask", None)
            if src_mask_path:
                src_mask = nib.load(src_mask_path).get_fdata()
                src_mask = self.transforms_mask(torch.from_numpy(src_mask).unsqueeze(0).long())
            else:
                src_mask = torch.zeros_like(src)

            # Load target image
            tgt_sample = selected_samples[1]
            tgt_path = tgt_sample["image"]
            tgt = nib.load(tgt_path).get_fdata()
            tgt = self.transforms_image(torch.from_numpy(tgt).unsqueeze(0).float())

            # Load target mask if available
            tgt_mask_path = tgt_sample.get("mask", None)
            if tgt_mask_path:
                tgt_mask = nib.load(tgt_mask_path).get_fdata()
                tgt_mask = self.transforms_mask(torch.from_numpy(tgt_mask).unsqueeze(0).long())
            else:
                tgt_mask = torch.zeros_like(tgt)

            return src, tgt, src_mask, tgt_mask

# 5. Define the Model Components
# (Assuming all model classes are defined correctly as per your initial code)

# 6. Define the LightningModule
class LiTHViT(LightningModule):
    def __init__(self, args, config, wandb_logger=None, save_model_every_n_epochs=10):
        super().__init__()
        self.automatic_optimization = False
        self.args = args
        self.config = config
        self.best_val_loss = 1e8
        self.save_model_every_n_epochs = save_model_every_n_epochs
        self.lr = args.lr
        self.last_epoch = 0
        self.tgt2src_reg = args.tgt2src_reg
        self.hvit_light = args.hvit_light
        self.precision = args.precision

        # Initialize logger
        self.custom_logger = Logger(save_dir="./logs")

        self.hvit = HierarchicalViT_Light(config) if self.hvit_light else HierarchicalViT(config)

        self.loss_weights = {
            "mse": self.args.mse_weights,
            "dice": self.args.dice_weights,
            "grad": self.args.grad_weights
        }
        self.wandb_logger = wandb_logger
        self.test_step_outputs = []

    def _forward(self, batch, calc_score: bool = False, tgt2src_reg: bool = False):
        _loss = {}
        _score = 0.

        dtype_map = {
            'bf16': torch.bfloat16,
            'fp32': torch.float32,
            'fp16': torch.float16
        }
        dtype_ = dtype_map.get(self.precision, torch.float32)

        with torch.autocast(device_type="cuda", dtype=dtype_):
            if tgt2src_reg:
                target, source = batch[0].to(dtype=dtype_), batch[1].to(dtype=dtype_)
                tgt_seg, src_seg = batch[2], batch[3]
            else:
                source, target = batch[0].to(dtype=dtype_), batch[1].to(dtype=dtype_)
                src_seg, tgt_seg = batch[2], batch[3]

            moved, flow = self.hvit(source, target)

            if calc_score:
                moved_seg = self._get_one_hot_from_src(src_seg, flow, self.args.num_labels)
                _score = DiceScore(moved_seg, tgt_seg.long(), self.args.num_labels)

            _loss = {}
            for key, weight in self.loss_weights.items():
                if key == "mse":
                    _loss[key] = weight * loss_functions[key](moved, target)
                elif key == "dice":
                    moved_seg = self._get_one_hot_from_src(src_seg, flow, self.args.num_labels)
                    _loss[key] = weight * loss_functions[key](moved_seg, tgt_seg.long())
                elif key == "grad":
                    _loss[key] = weight * loss_functions[key](flow)

            _loss["avg_loss"] = sum(_loss.values()) / len(_loss)
        return _loss, _score

    def training_step(self, batch, batch_idx):
        self.hvit.train()
        opt = self.optimizers()

        loss1, _ = self._forward(batch, calc_score=False)
        self.manual_backward(loss1["avg_loss"])
        opt.step()
        opt.zero_grad()

        if self.tgt2src_reg:
            loss2, _ = self._forward(batch, tgt2src_reg=True, calc_score=False)
            self.manual_backward(loss2["avg_loss"])
            opt.step()
            opt.zero_grad()

        total_loss = {
            key: (loss1[key].item() + loss2[key].item()) / 2 if self.tgt2src_reg and key in loss2 else loss1[key].item()
            for key in loss1.keys()
        }

        if self.wandb_logger:
            self.wandb_logger.log_metrics(total_loss, step=self.global_step)
        self.custom_logger.info(f"Batch {batch_idx} - Loss: {total_loss}")
        return total_loss

    def on_train_epoch_end(self):
        if self.current_epoch % self.save_model_every_n_epochs == 0:
            checkpoints_dir = f"./checkpoints/{self.current_epoch}"
            os.makedirs(checkpoints_dir, exist_ok=True)
            checkpoint_path = f"{checkpoints_dir}/model_epoch_{self.current_epoch}.ckpt"
            self.trainer.save_checkpoint(checkpoint_path)
            self.custom_logger.info(f"Saved model at epoch {self.current_epoch}")  # Use custom_logger

        current_lr = self.optimizers().param_groups[0]['lr']
        if self.wandb_logger:
            self.wandb_logger.log_metrics({"learning_rate": current_lr}, step=self.global_step)

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            self.hvit.eval()
            _loss, _score = self._forward(batch, calc_score=True)

        # Log each component of the validation loss
        for loss_name, loss_value in _loss.items():
            self.log(f"val_{loss_name}", loss_value, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        # Log the mean validation score if available
        if _score is not None:
            self.log("val_score", _score.mean(), on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)

        # Log to wandb
        if self.wandb_logger:
            log_dict = {f"val_{k}": v.item() for k, v in _loss.items()}
            log_dict.update({
                "val_score_mean": _score.mean().item() if _score is not None else None,
            })
            self.wandb_logger.log_metrics({k: v for k, v in log_dict.items() if v is not None}, step=self.global_step)

        return {"val_loss": _loss["avg_loss"], "val_score": _score.mean().item()}

    def on_validation_epoch_end(self):
        """
        Callback method called at the end of the validation epoch.
        Saves the best model based on validation loss and logs metrics.
        """
        val_loss = self.trainer.callback_metrics.get("val_avg_loss")
        if val_loss is not None and self.current_epoch > 0:
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                checkpoints_dir = f"./checkpoints/{self.current_epoch}"
                os.makedirs(checkpoints_dir, exist_ok=True)
                best_model_path = f"{checkpoints_dir}/best_model.ckpt"
                self.trainer.save_checkpoint(best_model_path)
                if self.wandb_logger:
                    self.wandb_logger.experiment.log({
                        "best_model_saved": best_model_path,
                        "best_val_loss": self.best_val_loss.item()
                    })
                self.custom_logger.info(f"New best model saved with validation loss: {self.best_val_loss:.4f}")

    def test_step(self, batch, batch_idx):
        """
        Performs a single test step on a batch of data.
        """
        with torch.no_grad():
            self.hvit.eval()
            _, _score = self._forward(batch, calc_score=True)

        _score = _score.mean() if isinstance(_score, torch.Tensor) else torch.tensor(_score).mean()

        self.test_step_outputs.append(_score)

        # Log to wandb only if the logger is available
        if self.wandb_logger:
            self.wandb_logger.log_metrics({"test_dice": _score.item()}, step=self.global_step)

        # Return as a dict with tensor values
        return {"test_dice": _score}

    def on_test_epoch_end(self):
        """
        Callback method called at the end of the test epoch.
        Computes and logs the average test Dice score.
        """
        # Calculate the average Dice score across all test steps
        avg_test_dice = torch.stack(self.test_step_outputs).mean()

        # Log the average test Dice score
        self.log("avg_test_dice", avg_test_dice, prog_bar=True)

        # Log to wandb if available
        if self.wandb_logger:
            self.wandb_logger.log_metrics({"total_test_dice_avg": avg_test_dice.item()})

        # Clear the test step outputs list for the next test epoch
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler for the model.
        """
        optimizer = torch.optim.Adam(self.hvit.parameters(), lr=self.lr, weight_decay=0, amsgrad=True)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=self.lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def lr_lambda(self, epoch):
        """
        Defines the learning rate schedule.
        """
        return math.pow(1 - epoch / self.trainer.max_epochs, 0.9)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, args=None, wandb_logger=None):
        """
        Loads a model from a checkpoint file.
        """
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        args = args or checkpoint.get('hyper_parameters', {}).get('args')
        config = checkpoint.get('hyper_parameters', {}).get('config')

        model = cls(args, config, wandb_logger)
        model.load_state_dict(checkpoint['state_dict'])

        if 'hyper_parameters' in checkpoint:
            hyper_params = checkpoint['hyper_parameters']
            for attr in ['lr', 'best_val_loss', 'last_epoch']:
                setattr(model, attr, hyper_params.get(attr, getattr(model, attr)))

        return model

    def on_save_checkpoint(self, checkpoint):
        """
        Callback to save additional information in the checkpoint.
        """
        checkpoint['hyper_parameters'] = {
            'config': self.config,
            'lr': self.lr,
            'best_val_loss': self.best_val_loss,
            'last_epoch': self.current_epoch
        }

    def _get_one_hot_from_src(self, src_seg, flow, num_labels):
        """
        Converts source segmentation to one-hot encoding and applies deformation.
        """
        src_seg_onehot = get_one_hot(src_seg, self.args.num_labels)
        deformed_segs = [
            self.hvit.spatial_trans(src_seg_onehot[:, i:i+1, ...].float(), flow.float())
            for i in range(num_labels)
        ]
        return torch.cat(deformed_segs, dim=1)
from torch.utils.data import DataLoader

# Initialize the dataset and dataloaders

# For training (unpaired data, paired on-the-fly)
train_dataset = OASIS_Dataset(
    json_path="OASIS/OASIS_dataset.json",
    mode="training",
    input_dim=(128, 128, 128),
    is_pair=False  # Set to False to enable on-the-fly pairing
)
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=1  # Adjust based on your system
)

# For validation
val_dataset = OASIS_Dataset(
    json_path="OASIS/OASIS_dataset.json",
    mode="registration_val",
    input_dim=(128, 128, 128),
    is_pair=True  # Set to True for validation paired data
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1
)
# Define training arguments
class Args:
    def __init__(self):
        self.lr = 0.001
        self.mse_weights = 1.0
        self.dice_weights = 1.0
        self.grad_weights = 1.0
        self.tgt2src_reg = False
        self.hvit_light = True
        self.precision = 'fp32'  # Training precision (e.g., 'bf16', 'fp16', 'fp32')
        self.num_labels = 36  # Update based on your dataset

args = Args()

# 3. Define configuration
config = {
    'WO_SELF_ATT': False,
    '_NUM_CROSS_ATT': -1,
    'out_fmaps': ['P4', 'P3', 'P2', 'P1'],  # Number of levels = 4
    'scale_level_df': 'P1',
    'upsample_df': True,
    'upsample_scale_factor': 2,
    'fpn_channels': 64,
    'start_channels': 32,
    'patch_size': [2, 2, 2, 2],  # Matches number of levels
    'backbone_net': 'fpn',
    'in_channels': 1,

    # **Debugged Lines: Update 'data_size' and add 'img_size' to match input_dim**
    'data_size': [128, 128, 128],  # Updated from [40, 48, 56]
    'img_size': [128, 128, 128],   # Added to align with input_dim

    'bias': True,
    'norm_type': 'instance',
    'kernel_size': 3,
    'depths': [1, 1, 1, 1],  # Matches number of levels
    'mlp_ratio': 2,
    'num_heads': [4, 8, 16, 32],  # Matches number of levels
    'drop_path_rate': 0.,
    'qkv_bias': True,
    'drop_rate': 0.,
    'attn_drop_rate': 0.,
    'use_seg_loss': False,
    'use_seg_proxy_loss': False,
    'num_organs': 36,  # Updated to match DiceLoss
}
# Initialize WandB logger (optional, replace with None if not using WandB)
wandb_logger = WandbLogger(project="hvit_test")  # Replace with your project name

# Instantiate the Lightning module
lit_model = LiTHViT(args, config, wandb_logger=wandb_logger)

# Define the PyTorch Lightning Trainer
trainer = Trainer(
    max_epochs=5,  # Number of epochs
    logger=wandb_logger,  # Log training metrics
    enable_checkpointing=False,  # Disable checkpointing for testing
    devices=1,  # Number of GPUs (set to 0 for CPU)
    accelerator="gpu" if torch.cuda.is_available() else "cpu",  # Use GPU if available
    precision=16 if args.precision == 'fp16' else 32  # Set precision
)

# Start training
trainer.fit(lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)


[34m[1mwandb[0m: Currently logged in as: [33malim98barnet[0m ([33malim98barnet-university-of-tehran[0m). Use [1m`wandb login --relogin`[0m to force relogin
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name | Type                  | Params | Mode 
-------------------------------------------------------
0 | hvit | HierarchicalViT_Light | 17.9 M | train
-------------------------------------------------------
17.9 M    Trainable params
0         Non-trainable params
17.9 M    Total params
71.731    Total estimated model params size (MB)
244       Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name | Type                  | Params | Mode 
-------------------------------------------------------
0 | hvit | HierarchicalViT_Light | 17.9 M | train
-------------------------------------------------------
17.9 M    Trainable params
0         Non-trainable params
17.9 M    Total params
71.731    Total estimated model params size (MB)
244       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 6

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])


2025-01-10 13:06:59,021 - __main__ - INFO - Batch 0 - Loss: {'mse': 0.003875497728586197, 'dice': 0.9468063712120056, 'grad': 1.6333691732484112e-08, 'avg_loss': 0.31689396500587463}
INFO:__main__:Batch 0 - Loss: {'mse': 0.003875497728586197, 'dice': 0.9468063712120056, 'grad': 1.6333691732484112e-08, 'avg_loss': 0.31689396500587463}


MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])


2025-01-10 13:07:01,549 - __main__ - INFO - Batch 1 - Loss: {'mse': 0.0029408661648631096, 'dice': 0.9466022253036499, 'grad': 0.0011872303439304233, 'avg_loss': 0.31691011786460876}
INFO:__main__:Batch 1 - Loss: {'mse': 0.0029408661648631096, 'dice': 0.9466022253036499, 'grad': 0.0011872303439304233, 'avg_loss': 0.31691011786460876}


MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])


2025-01-10 13:07:04,097 - __main__ - INFO - Batch 2 - Loss: {'mse': 0.004398035351186991, 'dice': 0.9476045370101929, 'grad': 0.00596154248341918, 'avg_loss': 0.3193213939666748}
INFO:__main__:Batch 2 - Loss: {'mse': 0.004398035351186991, 'dice': 0.9476045370101929, 'grad': 0.00596154248341918, 'avg_loss': 0.3193213939666748}


MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])


2025-01-10 13:07:06,634 - __main__ - INFO - Batch 3 - Loss: {'mse': 0.004532233811914921, 'dice': 0.9477837681770325, 'grad': 0.006855541840195656, 'avg_loss': 0.31972384452819824}
INFO:__main__:Batch 3 - Loss: {'mse': 0.004532233811914921, 'dice': 0.9477837681770325, 'grad': 0.006855541840195656, 'avg_loss': 0.31972384452819824}


MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 16, 16, 16])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 32, 32, 32])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])
MLP input shape after permute: torch.Size([1, 64, 64, 64, 64])


INFO: 
Detected KeyboardInterrupt, attempting graceful shutdown ...
INFO:lightning.pytorch.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

# Full

In [8]:

# import json
# import os

# json_path = "OASIS/OASIS_dataset.json"
# with open(json_path, "r") as f:
#     data = json.load(f)

# # Update paths dynamically
# base_path = "OASIS"
# for entry in data["training"]:
#     entry["image"] = os.path.join(base_path, entry["image"].lstrip("./"))
#     entry["label"] = os.path.join(base_path, entry["label"].lstrip("./"))
#     entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))
# for entry in data["test"]:
#     entry["image"] = os.path.join(base_path, entry["image"].lstrip("./"))
#     entry["label"] = os.path.join(base_path, entry["label"].lstrip("./"))
#     entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))
# for entry in data["registration_test"]:
#     entry["fixed"] = os.path.join(base_path, entry["fixed"].lstrip("./"))
#     entry["moving"] = os.path.join(base_path, entry["moving"].lstrip("./"))
#     # entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))
# for entry in data["registration_val"]:
#     entry["fixed"] = os.path.join(base_path, entry["fixed"].lstrip("./"))
#     entry["moving"] = os.path.join(base_path, entry["moving"].lstrip("./"))
#     # entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))

# # Save updated JSON
# with open(json_path, "w") as f:
#     json.dump(data, f, indent=4)

## new oasis dataset test

In [14]:
from torch.utils.data import Dataset
import nibabel as nib
import torch
from monai.transforms import Resize

class OASIS_Dataset(Dataset):
    def __init__(self, json_path, mode="training", input_dim=(128, 128, 128), is_pair=False):
        self.input_dim = input_dim
        self.is_pair = is_pair

        # Load JSON data
        with open(json_path, "r") as f:
            data = json.load(f)

        if mode == "training":
            self.samples = data["training"]
        elif mode == "test":
            self.samples = data["test"]
        elif mode == "registration_val":
            self.samples = data["registration_val"]
        elif mode == "registration_test":
            self.samples = data["registration_test"]
        else:
            raise ValueError(f"Invalid mode: {mode}")

        self.transforms_image = Resize(spatial_size=input_dim)
        self.transforms_mask = Resize(spatial_size=input_dim, mode="nearest")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        if self.is_pair:
            # Load fixed and moving images for registration
            sample = self.samples[index]
            fixed_path = sample["fixed"]
            moving_path = sample["moving"]

            fixed = nib.load(fixed_path).get_fdata()
            moving = nib.load(moving_path).get_fdata()

            fixed = self.transforms_image(torch.from_numpy(fixed).unsqueeze(0).float())
            moving = self.transforms_image(torch.from_numpy(moving).unsqueeze(0).float())

            # Load segmentation masks if available
            # Assuming mask paths are provided; adjust accordingly
            fixed_mask_path = sample.get("fixed_mask", None)
            moving_mask_path = sample.get("moving_mask", None)

            if fixed_mask_path and moving_mask_path:
                fixed_mask = nib.load(fixed_mask_path).get_fdata()
                moving_mask = nib.load(moving_mask_path).get_fdata()

                fixed_mask = self.transforms_mask(torch.from_numpy(fixed_mask).unsqueeze(0).long())
                moving_mask = self.transforms_mask(torch.from_numpy(moving_mask).unsqueeze(0).long())
            else:
                # If masks are not available, return dummy masks
                fixed_mask = torch.zeros_like(fixed)
                moving_mask = torch.zeros_like(moving)

            return fixed, moving, fixed_mask, moving_mask
        else:
            # Load unpaired data (image, label, mask)
            sample = self.samples[index]
            image_path = sample["image"]
            label_path = sample["label"]
            mask_path = sample.get("mask", None)

            image = nib.load(image_path).get_fdata()
            label = nib.load(label_path).get_fdata()

            image = self.transforms_image(torch.from_numpy(image).unsqueeze(0).float())
            label = self.transforms_mask(torch.from_numpy(label).unsqueeze(0).long())

            if mask_path:
                mask = nib.load(mask_path).get_fdata()
                mask = self.transforms_mask(torch.from_numpy(mask).unsqueeze(0).long())
            else:
                mask = torch.zeros_like(label)

            return image, label, mask
from torch.utils.data import DataLoader

# For training (paired data)
train_dataset = OASIS_Dataset(
    json_path="OASIS/OASIS_dataset.json",
    mode="training",
    input_dim=(128, 128, 128),
    is_pair=True  # Set to True for paired data
)
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4  # Adjust based on your system
)

# For validation
val_dataset = OASIS_Dataset(
    json_path="OASIS/OASIS_dataset.json",
    mode="registration_val",
    input_dim=(128, 128, 128),
    is_pair=True  # Set to True for paired data
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4
)


In [3]:


class OASIS_Dataset(Dataset):
    """
    Custom Dataset for OASIS Registration Tasks.
    """
    def __init__(self, json_path, mode="training", input_dim=(128, 128, 128), is_pair=False):
        self.input_dim = input_dim
        self.is_pair = is_pair

        # Load JSON data
        with open(json_path, "r") as f:
            data = json.load(f)

        if mode == "training":
            self.samples = data["training"]
        elif mode == "test":
            self.samples = data["test"]
        elif mode == "registration_val":
            self.samples = data["registration_val"]
        elif mode == "registration_test":
            self.samples = data["registration_test"]
        else:
            raise ValueError(f"Invalid mode: {mode}")

        self.transforms_image = Resize(spatial_size=input_dim)
        self.transforms_mask = Resize(spatial_size=input_dim, mode="nearest")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        if self.is_pair:
            # Load fixed and moving images for registration
            sample = self.samples[index]
            fixed_path = sample["fixed"]
            moving_path = sample["moving"]

            fixed = nib.load(fixed_path).get_fdata()
            moving = nib.load(moving_path).get_fdata()

            fixed = self.transforms_image(torch.from_numpy(fixed).unsqueeze(0).float())
            moving = self.transforms_image(torch.from_numpy(moving).unsqueeze(0).float())

            # Load segmentation masks if available
            fixed_mask_path = sample.get("fixed_mask", None)
            moving_mask_path = sample.get("moving_mask", None)

            if fixed_mask_path and moving_mask_path:
                fixed_mask = nib.load(fixed_mask_path).get_fdata()
                moving_mask = nib.load(moving_mask_path).get_fdata()

                fixed_mask = self.transforms_mask(torch.from_numpy(fixed_mask).unsqueeze(0).long())
                moving_mask = self.transforms_mask(torch.from_numpy(moving_mask).unsqueeze(0).long())
            else:
                # If masks are not available, return dummy masks
                fixed_mask = torch.zeros_like(fixed)
                moving_mask = torch.zeros_like(moving)

            return fixed, moving, fixed_mask, moving_mask
        else:
            # Load unpaired data (image, label, mask)
            sample = self.samples[index]
            image_path = sample["image"]
            label_path = sample["label"]
            mask_path = sample.get("mask", None)

            image = nib.load(image_path).get_fdata()
            label = nib.load(label_path).get_fdata()

            image = self.transforms_image(torch.from_numpy(image).unsqueeze(0).float())
            label = self.transforms_mask(torch.from_numpy(label).unsqueeze(0).long())

            if mask_path:
                mask = nib.load(mask_path).get_fdata()
                mask = self.transforms_mask(torch.from_numpy(mask).unsqueeze(0).long())
            else:
                mask = torch.zeros_like(label)

            return image, label, mask

# =======================
# 8. Define Training Arguments and Configuration
# =======================

class Args:
    """
    Training Arguments.
    """
    def __init__(self):
        self.lr = 0.001
        self.mse_weights = 1.0
        self.dice_weights = 1.0
        self.grad_weights = 1.0
        self.tgt2src_reg = False
        self.hvit_light = True
        self.precision = 'fp32'  # Options: 'bf16', 'fp16', 'fp32'
        self.num_labels = 36  # Update based on your dataset

args = Args()


# =======================
# 9. Initialize the Dataset and DataLoaders
# =======================

# For training (paired data)
train_dataset = OASIS_Dataset(
    json_path="OASIS/OASIS_dataset.json",
    mode="training",
    input_dim=(128, 128, 128),
    is_pair=False  # Set to True for paired data
)
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4  # Adjust based on your system
)

# For validation
val_dataset = OASIS_Dataset(
    json_path="OASIS/OASIS_dataset.json",
    mode="registration_val",
    input_dim=(128, 128, 128),
    is_pair=True  # Set to True for validation paired data
)
val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4
)

# For testing (optional)
test_dataset = OASIS_Dataset(
    json_path="OASIS/OASIS_dataset.json",
    mode="registration_test",
    input_dim=(128, 128, 128),
    is_pair=True  # Set to True for test paired data
)
test_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=4
)


In [17]:
# =======================
# 1. Import Necessary Libraries
# =======================

import json
import os
import math
import random
import logging
from typing import Dict, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import monai
from monai.transforms import Resize

import nibabel as nib

import lightning as L
from lightning import LightningModule, Trainer
from lightning.pytorch.loggers import WandbLogger

from einops import rearrange
import timm
import wandb

# =======================
# 2. Setup Logging and WandB
# =======================

# Initialize WandB
wandb.login()
wandb_logger = WandbLogger(project="hvit_mae_integration")  # Replace with your project name

import json
import os

# json_path = "OASIS/OASIS_dataset.json"
# with open(json_path, "r") as f:
#     data = json.load(f)

# # Update paths dynamically
# base_path = "OASIS"
# for entry in data["training"]:
#     entry["image"] = os.path.join(base_path, entry["image"].lstrip("./"))
#     entry["label"] = os.path.join(base_path, entry["label"].lstrip("./"))
#     entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))
# for entry in data["test"]:
#     entry["image"] = os.path.join(base_path, entry["image"].lstrip("./"))
#     entry["label"] = os.path.join(base_path, entry["label"].lstrip("./"))
#     entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))
# for entry in data["registration_test"]:
#     entry["fixed"] = os.path.join(base_path, entry["fixed"].lstrip("./"))
#     entry["moving"] = os.path.join(base_path, entry["moving"].lstrip("./"))
#     # entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))
# for entry in data["registration_val"]:
#     entry["fixed"] = os.path.join(base_path, entry["fixed"].lstrip("./"))
#     entry["moving"] = os.path.join(base_path, entry["moving"].lstrip("./"))
#     # entry["mask"] = os.path.join(base_path, entry["mask"].lstrip("./"))

# # Save updated JSON
# with open(json_path, "w") as f:
#     json.dump(data, f, indent=4)
# Custom Logger Class
class Logger:
    def __init__(self, save_dir):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

        # Create handlers
        console_handler = logging.StreamHandler()

        # Create the directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)
        file_handler = logging.FileHandler(os.path.join(save_dir, "logfile.log"))

        # Create formatters and add to handlers
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler.setFormatter(formatter)
        file_handler.setFormatter(formatter)

        # Add handlers to the logger
        self.logger.addHandler(console_handler)
        self.logger.addHandler(file_handler)

    def info(self, message):
        self.logger.info(message)

    def warning(self, message):
        self.logger.warning(message)

    def error(self, message):
        self.logger.error(message)

    def debug(self, message):
        self.logger.debug(message)

# Initialize Custom Logger
custom_logger = Logger(save_dir="./logs")

# =======================
# 3. Define Utility Functions and Losses
# =======================

def get_one_hot(inp_seg, num_labels):
    """
    Converts segmentation labels to one-hot encoding.
    """
    inp_onehot = nn.functional.one_hot(inp_seg.long(), num_classes=num_labels)  # (B, C, H, W, D, num_labels)
    inp_onehot = torch.squeeze(inp_onehot, 1)  # Remove channel dimension if present
    inp_onehot = inp_onehot.permute(0, 5, 1, 2, 3).contiguous()  # (B, num_labels, H, W, D)
    return inp_onehot

def DiceScore(y_pred, y_true, num_class):
    """
    Computes the Dice Score for multi-class segmentation.
    """
    y_true = get_one_hot(y_true, num_class)  # (B, num_labels, H, W, D)
    intersection = y_pred * y_true
    intersection = intersection.sum(dim=[2, 3, 4])
    union = torch.pow(y_pred, 2).sum(dim=[2, 3, 4]) + torch.pow(y_true, 2).sum(dim=[2, 3, 4])
    dsc = (2. * intersection) / (union + 1e-5)
    return dsc

# Gradient Loss for Regularization
class Grad3D(nn.Module):
    """
    N-D gradient loss for smoothness regularization.
    """
    def __init__(self, penalty='l1', loss_mult=None):
        super().__init__()
        self.penalty = penalty
        self.loss_mult = loss_mult

    def forward(self, y_pred):
        dy = torch.abs(y_pred[:, :, 1:, :, :] - y_pred[:, :, :-1, :, :])
        dx = torch.abs(y_pred[:, :, :, 1:, :] - y_pred[:, :, :, :-1, :])
        dz = torch.abs(y_pred[:, :, :, :, 1:] - y_pred[:, :, :, :, :-1])

        if self.penalty == 'l2':
            dy = dy ** 2
            dx = dx ** 2
            dz = dz ** 2

        grad = (dx.mean() + dy.mean() + dz.mean()) / 3.0

        if self.loss_mult is not None:
            grad *= self.loss_mult
        return grad

# Dice Loss for Segmentation
class DiceLoss(nn.Module):
    """
    Dice loss for multi-class segmentation.
    """
    def __init__(self, num_class=36):
        super().__init__()
        self.num_class = num_class

    def forward(self, y_pred, y_true):
        y_true = get_one_hot(y_true, self.num_class)  # (B, num_labels, H, W, D)
        intersection = y_pred * y_true
        intersection = intersection.sum(dim=[2, 3, 4])
        union = torch.pow(y_pred, 2).sum(dim=[2, 3, 4]) + torch.pow(y_true, 2).sum(dim=[2, 3, 4])
        dsc = (2. * intersection) / (union + 1e-5)
        dsc_loss = 1 - torch.mean(dsc)
        return dsc_loss

# Define Loss Functions
loss_functions = {
    "mse": nn.MSELoss(),
    "dice": DiceLoss(num_class=36),
    "grad": Grad3D(penalty='l2')
}

# =======================
# 4. Define the Vision Transformer (ViT) MAE Encoder for 3D Data
# =======================

# Patch Embedding for 3D ViT
class PatchEmbed3D(nn.Module):
    """
    3D Patch Embedding Layer for Vision Transformer.
    """
    def __init__(self, in_channels=1, embed_dim=768, patch_size=(16, 16, 16)):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # x: (B, C, D, H, W)
        x = self.proj(x)  # (B, embed_dim, D', H', W')
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        x = self.norm(x)
        return x

# 3D Vision Transformer Encoder
class ViTMAEEncoder3D(nn.Module):
    """
    3D Vision Transformer Masked Autoencoder Encoder.
    """
    def __init__(self, img_size=(128, 128, 128), patch_size=(16, 16, 16), in_channels=2, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4.0, drop_rate=0.0, attn_drop_rate=0.0):
        super().__init__()
        self.patch_embed = PatchEmbed3D(in_channels=in_channels, embed_dim=embed_dim, patch_size=patch_size)
        num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                                       dim_feedforward=int(embed_dim * mlp_ratio),
                                       dropout=drop_rate, activation='gelu')
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # x: (B, C, D, H, W)
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)
        x = x + self.pos_embed  # Add positional embedding
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)  # (B, num_patches, embed_dim)
        return x  # Feature representation

# =======================
# 5. Define Registration Head and Spatial Transformer
# =======================

# Registration Head
class RegistrationHead(nn.Sequential):
    """
    Registration head for generating displacement fields.
    """
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        conv3d = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2
        )
        # Initialize weights with small random values
        nn.init.normal_(conv3d.weight, mean=0.0, std=1e-5)
        nn.init.constant_(conv3d.bias, 0)
        self.add_module('conv3d', conv3d)

# Spatial Transformer
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer for applying displacement fields.
    """
    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode

        # Create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(*vectors, indexing='ij')  # For 3D, use 'ij' indexing
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)

        # Register the grid as a buffer
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # src: (B, C, D, H, W)
        # flow: (B, 3, D, H, W)
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # Normalize grid values to [-1, 1]
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # Permute to (B, D, H, W, C) as required by grid_sample
        new_locs = new_locs.permute(0, 2, 3, 4, 1).contiguous()

        return F.grid_sample(src, new_locs, align_corners=False, mode=self.mode)

# =======================
# 6. Define the HierarchicalViT_Module
# =======================

class HierarchicalViT_Module(nn.Module):
    """
    Hierarchical Vision Transformer (HViT) with MAE Backbone for Image Registration.
    """
    def __init__(self, config: Dict[str, Any]):
        super(HierarchicalViT_Module, self).__init__()
        self.config = config  # Initialize the config attribute

        # Initialize the 3D ViT MAE encoder
        self.encoder = ViTMAEEncoder3D(
            img_size=config['data_size'],
            patch_size=(16, 16, 16),
            in_channels=2,  # source + target
            embed_dim=config['fpn_channels'],  # Match FPN channels
            depth=config['vit_depth'],
            num_heads=config['vit_num_heads'],
            mlp_ratio=config['vit_mlp_ratio'],
            drop_rate=config['vit_drop_rate'],
            attn_drop_rate=config['vit_attn_drop_rate']
        )

        # Decoder: Simple linear layers to map transformer features to desired output
        self.decoder = nn.Sequential(
            nn.Linear(config['fpn_channels'], config['fpn_channels'] * 2),
            nn.ReLU(),
            nn.Linear(config['fpn_channels'] * 2, config['fpn_channels']),
            nn.ReLU()
        )

        # Reshape and Upsample to match spatial dimensions
        self.upsample = nn.Sequential(
            nn.ConvTranspose3d(config['fpn_channels'], config['fpn_channels'], kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose3d(config['fpn_channels'], config['fpn_channels'], kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose3d(config['fpn_channels'], config['fpn_channels'], kernel_size=2, stride=2),
            nn.ReLU(),
            nn.ConvTranspose3d(config['fpn_channels'], config['fpn_channels'], kernel_size=2, stride=2),
            nn.ReLU()
        )
        # This upsampling sequence doubles the spatial dimensions each step:
        # From 8x8x8 -> 16x16x16 -> 32x32x32 -> 64x64x64 -> 128x128x128

        # Registration head to generate displacement fields
        self.reg_head = RegistrationHead(
            in_channels=config['fpn_channels'],
            out_channels=config['ndims'],
            kernel_size=config['kernel_size']
        )

        # Spatial Transformer for applying the flow
        self.spatial_trans = SpatialTransformer(size=config['data_size'], mode='bilinear')

    def forward(self, x):
        """
        Forward pass of the HierarchicalViT with MAE backbone.
        Args:
            x: Concatenated source and target images (B, 2, D, H, W)
        Returns:
            moved: Source image moved according to the predicted flow
            flow: Predicted displacement field
        """
        features = self.encoder(x)  # (B, num_patches, embed_dim)
        decoded = self.decoder(features)  # (B, num_patches, embed_dim)

        # Reshape decoded features to spatial dimensions
        # num_patches = (128/16)^3 = 8^3 = 512
        # decoded: (B,512,768) -> (B,768,8,8,8)
        decoded_spatial = decoded.permute(0, 2, 1).reshape(-1, self.config['fpn_channels'], 8, 8, 8)

        # Upsample to match original spatial dimensions
        upsampled = self.upsample(decoded_spatial)  # (B,768,128,128,128)

        # Generate displacement field
        flow = self.reg_head(upsampled)  # (B,3,128,128,128)

        # Apply displacement field to source image
        moved = self.spatial_trans(x[:, :1, ...], flow)  # Apply flow to source image

        return moved, flow

# =======================
# 7. Define the OASIS Dataset Class
# =======================
# Define Configuration Dictionary
config = {
    'WO_SELF_ATT': False,
    '_NUM_CROSS_ATT': -1,
    'out_fmaps': ['P4', 'P3', 'P2', 'P1'],  # Number of levels = 4
    'scale_level_df': 'P1',
    'upsample_df': True,
    'upsample_scale_factor': 2,
    'fpn_channels': 768,  # Match ViT MAE embed_dim
    'start_channels': 32,
    'patch_size': [16, 16, 16, 16],  # Matches number of levels
    'backbone_net': 'mae_vit',
    'in_channels': 1,
    'data_size': [128, 128, 128],  # 3D data size
    'img_size': [128, 128, 128],   # Align with input_dim
    'bias': True,
    'norm_type': 'instance',
    'kernel_size': 3,
    'vit_depth': 12,
    'vit_num_heads': 12,
    'vit_mlp_ratio': 4.0,
    'vit_drop_rate': 0.0,
    'vit_attn_drop_rate': 0.0,
    'num_organs': 36,  # Updated to match DiceLoss
    'ndims': 3,  # Number of spatial dimensions
}
# =======================
# 10. Define the LightningModule
# =======================

class LiTHViT(LightningModule):
    """
    LightningModule for Hierarchical Vision Transformer with MAE Backbone.
    """
    def __init__(self, args, config, wandb_logger=None, save_model_every_n_epochs=10):
        super(LiTHViT, self).__init__()
        self.automatic_optimization = False  # Manual optimization
        self.args = args
        self.config = config
        self.best_val_loss = 1e8
        self.save_model_every_n_epochs = save_model_every_n_epochs
        self.lr = args.lr
        self.tgt2src_reg = args.tgt2src_reg
        self.hvit_light = args.hvit_light
        self.precision = args.precision

        # Initialize the HierarchicalViT_Module
        self.hvit = HierarchicalViT_Module(config)

        # Define loss weights
        self.loss_weights = {
            "mse": self.args.mse_weights,
            "dice": self.args.dice_weights,
            "grad": self.args.grad_weights
        }

        # Initialize loggers
        self.wandb_logger = wandb_logger
        self.custom_logger = custom_logger
        self.test_step_outputs = []

    def _forward(self, batch, calc_score: bool = False, tgt2src_reg: bool = False):
        """
        Forward pass to compute losses and optionally scores.
        """
        _loss = {}
        _score = 0.

        dtype_map = {
            'bf16': torch.bfloat16,
            'fp32': torch.float32,
            'fp16': torch.float16
        }
        dtype_ = dtype_map.get(self.precision, torch.float32)

        with torch.autocast(device_type="cuda", dtype=dtype_):
            if tgt2src_reg:
                target, source = batch[1].to(dtype=dtype_), batch[0].to(dtype=dtype_)
                tgt_seg, src_seg = batch[3], batch[2]
            else:
                source, target = batch[0].to(dtype=dtype_), batch[1].to(dtype=dtype_)
                src_seg, tgt_seg = batch[2], batch[3]

            # Concatenate source and target along channel dimension
            concatenated_input = torch.cat([source, target], dim=1)  # (B, 2, D, H, W)

            moved, flow = self.hvit(concatenated_input)

            if calc_score:
                moved_seg = self._get_one_hot_from_src(src_seg, flow, self.args.num_labels)
                _score = DiceScore(moved_seg, tgt_seg.long(), self.args.num_labels)

            _loss = {}
            for key, weight in self.loss_weights.items():
                if key == "mse":
                    _loss[key] = weight * loss_functions[key](moved, target)
                elif key == "dice":
                    moved_seg = self._get_one_hot_from_src(src_seg, flow, self.args.num_labels)
                    _loss[key] = weight * loss_functions[key](moved_seg, tgt_seg.long())
                elif key == "grad":
                    _loss[key] = weight * loss_functions[key](flow)

            _loss["avg_loss"] = sum(_loss.values()) / len(_loss)
        return _loss, _score

    def training_step(self, batch, batch_idx):
        """
        Training step.
        """
        opt = self.optimizers()

        # Forward pass and compute losses
        loss_dict, _ = self._forward(batch, calc_score=False)

        # Backward pass
        self.manual_backward(loss_dict["avg_loss"])

        # Optimizer step
        opt.step()
        opt.zero_grad()

        # Logging
        if self.wandb_logger:
            self.wandb_logger.log_metrics(loss_dict, step=self.global_step)
        self.custom_logger.info(f"Batch {batch_idx} - Loss: {loss_dict}")
        return loss_dict

    def on_train_epoch_end(self):
        """
        Actions to perform at the end of each training epoch.
        """
        # Save model checkpoint every N epochs
        if self.current_epoch % self.save_model_every_n_epochs == 0:
            checkpoints_dir = f"./checkpoints/epoch_{self.current_epoch}"
            os.makedirs(checkpoints_dir, exist_ok=True)
            checkpoint_path = f"{checkpoints_dir}/model_epoch_{self.current_epoch}.ckpt"
            self.trainer.save_checkpoint(checkpoint_path)
            self.custom_logger.info(f"Saved model at epoch {self.current_epoch}")

        # Log learning rate
        current_lr = self.optimizers().param_groups[0]['lr']
        if self.wandb_logger:
            self.wandb_logger.log_metrics({"learning_rate": current_lr}, step=self.global_step)

    def validation_step(self, batch, batch_idx):
        """
        Validation step.
        """
        with torch.no_grad():
            _loss, _score = self._forward(batch, calc_score=True)

        # Log each component of the validation loss
        for loss_name, loss_value in _loss.items():
            self.log(f"val_{loss_name}", loss_value, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        # Log the mean validation score if available
        if _score is not None:
            self.log("val_score", _score.mean(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        # Log to WandB
        if self.wandb_logger:
            log_dict = {f"val_{k}": v.item() for k, v in _loss.items()}
            if _score is not None:
                log_dict["val_score_mean"] = _score.mean().item()
            self.wandb_logger.log_metrics(log_dict, step=self.global_step)

        return {"val_loss": _loss["avg_loss"], "val_score": _score.mean().item() if _score is not None else 0.0}

    def on_validation_epoch_end(self):
        """
        Actions to perform at the end of the validation epoch.
        """
        val_loss = self.trainer.callback_metrics.get("val_avg_loss")
        if val_loss is not None and self.current_epoch > 0:
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                checkpoints_dir = f"./checkpoints/best"
                os.makedirs(checkpoints_dir, exist_ok=True)
                best_model_path = f"{checkpoints_dir}/best_model.ckpt"
                self.trainer.save_checkpoint(best_model_path)
                if self.wandb_logger:
                    self.wandb_logger.experiment.log({
                        "best_model_saved": best_model_path,
                        "best_val_loss": self.best_val_loss.item()
                    })
                self.custom_logger.info(f"New best model saved with validation loss: {self.best_val_loss:.4f}")

    def test_step(self, batch, batch_idx):
        """
        Test step.
        """
        with torch.no_grad():
            _loss, _score = self._forward(batch, calc_score=True)

        _score_mean = _score.mean().item() if _score is not None else 0.0
        self.test_step_outputs.append(_score_mean)

        # Log to WandB
        if self.wandb_logger:
            self.wandb_logger.log_metrics({"test_dice": _score_mean}, step=self.global_step)

        return {"test_dice": _score_mean}

    def on_test_epoch_end(self):
        """
        Actions to perform at the end of the test epoch.
        """
        # Calculate the average Dice score across all test steps
        avg_test_dice = sum(self.test_step_outputs) / len(self.test_step_outputs) if self.test_step_outputs else 0.0

        # Log the average test Dice score
        self.log("avg_test_dice", avg_test_dice, prog_bar=True)

        # Log to WandB
        if self.wandb_logger:
            self.wandb_logger.log_metrics({"total_test_dice_avg": avg_test_dice})

        # Clear the test step outputs list for the next test epoch
        self.test_step_outputs.clear()

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler for the model.
        """
        optimizer = torch.optim.Adam(self.hvit.parameters(), lr=self.lr, weight_decay=0, amsgrad=True)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=self.lr_lambda)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }

    def lr_lambda(self, epoch):
        """
        Defines the learning rate schedule.
        """
        return math.pow(1 - epoch / self.trainer.max_epochs, 0.9)

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, args=None, wandb_logger=None):
        """
        Loads a model from a checkpoint file.
        """
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        args = args or checkpoint.get('hyper_parameters', {}).get('args')
        config = checkpoint.get('hyper_parameters', {}).get('config')

        model = cls(args, config, wandb_logger)
        model.load_state_dict(checkpoint['state_dict'])

        if 'hyper_parameters' in checkpoint:
            hyper_params = checkpoint['hyper_parameters']
            for attr in ['lr', 'best_val_loss', 'last_epoch']:
                setattr(model, attr, hyper_params.get(attr, getattr(model, attr)))

        return model

    def on_save_checkpoint(self, checkpoint):
        """
        Callback to save additional information in the checkpoint.
        """
        checkpoint['hyper_parameters'] = {
            'config': self.config,
            'lr': self.lr,
            'best_val_loss': self.best_val_loss,
            'last_epoch': self.current_epoch
        }
    def get_one_hot(inp_seg, num_labels):
        """
        Converts segmentation labels to one-hot encoding.
        Args:
            inp_seg: Input segmentation tensor of shape (B, D, H, W)
            num_labels: Number of segmentation classes
        Returns:
            One-hot encoded tensor of shape (B, num_labels, D, H, W)
        """
        # Debugging: Print input shape
        print(f"Input segmentation shape: {inp_seg.shape}")

        # Apply one-hot encoding
        inp_onehot = nn.functional.one_hot(inp_seg.long(), num_classes=num_labels)  # Expected output shape: (B, D, H, W, num_labels)
        print(f"One-hot encoded shape before permute: {inp_onehot.shape}")

        # Check if the shape matches expectations
        if inp_onehot.dim() == 5:  # (B, D, H, W, num_labels)
            inp_onehot = inp_onehot.permute(0, 4, 1, 2, 3).contiguous()  # Reorder to (B, num_labels, D, H, W)
        else:
            raise ValueError(f"Unexpected one-hot encoded shape: {inp_onehot.shape}")

        print(f"One-hot encoded shape after permute: {inp_onehot.shape}")
        return inp_onehot

    def _get_one_hot_from_src(self, src_seg, flow, num_labels):
        """
        Converts source segmentation to one-hot encoding and applies deformation.
        Args:
            src_seg: Source segmentation tensor
            flow: Deformation field
            num_labels: Number of segmentation classes
        Returns:
            Deformed one-hot encoded segmentation
        """
        src_seg_onehot = get_one_hot(src_seg, num_labels)  # (B, num_labels, D, H, W)
        deformed_segs = []

        for i in range(num_labels):
            # Extract one class at a time and apply deformation
            class_seg = src_seg_onehot[:, i:i+1, ...]  # Keep the channel dimension
            deformed_class = self.hvit.spatial_trans(class_seg.float(), flow.float())
            deformed_segs.append(deformed_class)

        # Concatenate along the channel (class) dimension
        return torch.cat(deformed_segs, dim=1)  # (B, num_labels, D, H, W)

# =======================
# 11. Initialize the LightningModule and Trainer
# =======================

# Initialize the LightningModule
lit_model = LiTHViT(args, config, wandb_logger=wandb_logger)

# Define the PyTorch Lightning Trainer
trainer = Trainer(
    max_epochs=50,  # Number of epochs
    logger=wandb_logger,  # Log training metrics
    enable_checkpointing=False,  # Disable automatic checkpointing
    devices=1,  # Number of GPUs (set to 0 for CPU)
    accelerator="gpu" if torch.cuda.is_available() else "cpu",  # Use GPU if available
    precision=16 if args.precision == 'fp16' else 32  # Set precision
)

# =======================
# 12. Start Training and Testing
# =======================

# Start Training
trainer.fit(lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Start Testing (Optional)
trainer.test(lit_model, dataloaders=test_loader)


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name | Type                   | Params | Mode 
--------------------------------------------------------
0 | hvit | HierarchicalViT_Module | 113 M  | train
--------------------------------------------------------
113 M     Trainable params
0         Non-trainable params
113 M     Total params
452.177   Total estimated model params size (MB)
144       Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name | Type                   | Params | Mode 
--------------------------------------------------------
0 | hvit | HierarchicalViT_Module | 113 M  | train
--------------------------------------------------------
113 M     Trainable params
0         Non-trainable params
113 M     Total params
452.177   Total estimated model params size (MB)
144       Modules in train mode
0         Modules in eval m

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 6.00 GiB. GPU 0 has a total capacity of 14.75 GiB of which 2.31 GiB is free. Process 182254 has 12.44 GiB memory in use. Of the allocated memory 12.25 GiB is allocated by PyTorch, and 66.84 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [23]:
! pip install torchinfo
from torchinfo import summary

# Initialize the model
hvit_module = HierarchicalViT_Module(config)

# Print summary
summary(hvit_module, input_size=(1, 2, 128, 128, 128), col_names=["input_size", "output_size", "num_params"])


Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [ViTMAEEncoder3D: 1, PatchEmbed3D: 2, Conv3d: 3, LayerNorm: 3, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, TransformerEncoderLayer: 3, MultiheadAttention: 4, Dropout: 4, LayerNorm: 4, Linear: 4, Dropout: 4, Linear: 4, Dropout: 4, LayerNorm: 4, LayerNorm: 2, Sequential: 1, Linear: 2, ReLU: 2, Linear: 2, ReLU: 2, ConvTranspose3d: 2, ReLU: 2, ConvTranspose3d: 2, ReLU: 2, ConvTranspose3d: 2, ReLU: 2]

In [11]:
from torch.utils.tensorboard import SummaryWriter

# Initialize the model and TensorBoard writer
hvit_module = HierarchicalViT_Module(config)
writer = SummaryWriter("runs/HierarchicalViT")

# Add the model to TensorBoard
dummy_input = torch.randn(1, 2, 128, 128, 128)  # Batch size 1, 2 channels (source+target)
writer.add_graph(hvit_module, dummy_input)
writer.close()

# Launch TensorBoard
# In terminal:
# tensorboard --logdir=runs/HierarchicalViT


In [14]:
!pip install torchviz

Collecting torchviz
  Downloading torchviz-0.0.3-py3-none-any.whl.metadata (2.1 kB)
Downloading torchviz-0.0.3-py3-none-any.whl (5.7 kB)
Installing collected packages: torchviz
Successfully installed torchviz-0.0.3


In [16]:
from torchviz import make_dot

# Initialize the model
hvit_module = HierarchicalViT_Module(config)

# Generate dummy input
dummy_input = torch.randn(1, 2, 128, 128, 128)  # Batch size 1, 2 channels (source+target)

# Perform a forward pass
moved, flow = hvit_module(dummy_input)

# Visualize the model
dot = make_dot((moved, flow), params=dict(hvit_module.named_parameters()))
dot.render("HierarchicalViT_Model", format="pdf")  # Saves to PDF


'HierarchicalViT_Model.pdf'

In [4]:
cp /content/OASIS_dataset.json /content/OASIS