In [None]:
 # Hyperparameters
    batch_size = 4
    image_size = 256
    num_epochs = 10
    learning_rate = 1e-4

    # Create dummy data for CT and MRI images (B, 1, H, W)
    ct_images = torch.randn(batch_size, 1, image_size, image_size)
    mri_images = torch.randn(batch_size, 1, image_size, image_size)
    # For demonstration, assume the target fused image is the average of CT and MRI images
    target_fused = (ct_images + mri_images) / 2.0

    # Instantiate models
    model = KaleidoFusionNet(in_channels=1, embed_dim=64, latent_dim=64, base_filters=64)
    discriminator = Discriminator(in_channels=1, base_filters=64)

    # Optimizers
    gen_optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    dis_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        discriminator.train()
        
        gen_optimizer.zero_grad()
        dis_optimizer.zero_grad()

        # Forward pass through the fusion model
        fused_output = model(ct_images, mri_images)  # (B, 1, H, W)

        # Reconstruction loss (L1 loss) between the generated output and target fused image
        rec_loss = F.l1_loss(fused_output, target_fused)
        
        # Adversarial loss for generator: aim to fool the discriminator
        pred_fake = discriminator(fused_output)
        valid_labels = torch.ones_like(pred_fake)
        adv_loss = F.binary_cross_entropy_with_logits(pred_fake, valid_labels)

        # Total generator loss (weighted sum)
        gen_loss = rec_loss + 0.001 * adv_loss
        gen_loss.backward()
        gen_optimizer.step()

        # Update discriminator: classify real fused images vs. generated ones
        dis_optimizer.zero_grad()
        # For real images (target fused), label as 1
        pred_real = discriminator(target_fused)
        loss_real = F.binary_cross_entropy_with_logits(pred_real, valid_labels)
        # For fake images, label as 0
        pred_fake = discriminator(fused_output.detach())
        fake_labels = torch.zeros_like(pred_fake)
        loss_fake = F.binary_cross_entropy_with_logits(pred_fake, fake_labels)
        dis_loss = (loss_real + loss_fake) / 2
        dis_loss.backward()
        dis_optimizer.step()

        print(f"Epoch [{epoch+1}/{num_epochs}], Gen Loss: {gen_loss.item():.4f}, Dis Loss: {dis_loss.item():.4f}")

    # Example inference
    model.eval()
    with torch.no_grad():
        fused_result = model(ct_images, mri_images)
        print("Fused result shape:", fused_result.shape)


In [None]:
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from model import MultiModalFusionModel, get_loss_function
from dataset import MultiModalDataset

def train(args):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create model
    model = MultiModalFusionModel(in_channels=1, base_filters=args.base_filters)
    model = model.to(device)
    
    # Create dataset and dataloader
    train_dataset = MultiModalDataset(
        data_dir=args.data_dir,
        split='train',
        transform=True
    )
    
    val_dataset = MultiModalDataset(
        data_dir=args.data_dir,
        split='val',
        transform=False
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
    
    # Define loss function and optimizer
    criterion = get_loss_function(args.loss_type)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Training loop
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    
    for epoch in range(args.num_epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0
        
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.num_epochs} [Train]") as pbar:
            for batch_idx, (ct_images, mri_images, target_images) in enumerate(pbar):
                # Move data to device
                ct_images = ct_images.to(device)
                mri_images = mri_images.to(device)
                target_images = target_images.to(device)
                
                # Forward pass
                optimizer.zero_grad()
                outputs = model(ct_images, mri_images)
                
                # Calculate loss
                loss = criterion(outputs, target_images)
                
                # Backward pass and optimize
                loss.backward()
                optimizer.step()
                
                # Update progress bar
                epoch_loss += loss.item()
                pbar.set_postfix(loss=loss.item())
        
        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            with tqdm(val_loader, desc=f"Epoch {epoch+1}/{args.num_epochs} [Val]") as pbar:
                for batch_idx, (ct_images, mri_images, target_images) in enumerate(pbar):
                    # Move data to device
                    ct_images = ct_images.to(device)
                    mri_images = mri_images.to(device)
                    target_images = target_images.to(device)
                    
                    # Forward pass
                    outputs = model(ct_images, mri_images)
                    
                    # Calculate loss
                    loss = criterion(outputs, target_images)
                    
                    # Update progress bar
                    val_loss += loss.item()
                    pbar.set_postfix(loss=loss.item())
        
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        
        # Update learning rate
        scheduler.step(avg_val_loss)
        
        # Print epoch results
        print(f"Epoch {epoch+1}/{args.num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(args.output_dir, 'best_model.pth'))
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")
        
        # Save checkpoint
        if (epoch + 1) % args.save_interval == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(args.output_dir, f'checkpoint_epoch_{epoch+1}.pth'))
    
    # Plot training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.savefig(os.path.join(args.output_dir, 'loss_curve.png'))
    plt.close()
    
    print("Training completed!")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Multi-Modal Fusion Model')
    
    # Dataset parameters
    parser.add_argument('--data_dir', type=str, default='./data', help='Path to dataset directory')
    
    # Model parameters
    parser.add_argument('--base_filters', type=int, default=64, help='Number of base filters in the model')
    parser.add_argument('--loss_type', type=str, default='l1', choices=['l1', 'l2', 'mse'], help='Loss function type')
    
    # Training parameters
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--save_interval', type=int, default=10, help='Epoch interval to save checkpoints')
    parser.add_argument('--output_dir', type=str, default='./output', help='Directory to save outputs')
    
    args = parser.parse_args()
    train(args) 

usage: ipykernel_launcher.py [-h] [--data_dir DATA_DIR]
                             [--base_filters BASE_FILTERS]
                             [--loss_type {l1,l2,mse}]
                             [--batch_size BATCH_SIZE]
                             [--num_epochs NUM_EPOCHS]
                             [--learning_rate LEARNING_RATE]
                             [--num_workers NUM_WORKERS]
                             [--save_interval SAVE_INTERVAL]
                             [--output_dir OUTPUT_DIR]
ipykernel_launcher.py: error: unrecognized arguments: --f=c:\Users\yuvra\AppData\Roaming\jupyter\runtime\kernel-v3b1cca6da8e020ee975dd0994a4f0f9d80677dcc4.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [2]:
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

# ================================
# Model Components: KaleidoFusionNet
# ================================

#############################################
# 1. Vision Transformer (ViT) Encoder Block #
#############################################

class PatchEmbed(nn.Module):
    """
    Splits the input image into patches and embeds them.
    """
    def __init__(self, in_channels, embed_dim, patch_size=16):
        super(PatchEmbed, self).__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)  # (B, embed_dim, H/patch_size, W/patch_size)
        B, C, H, W = x.shape
        x = x.flatten(2)  # (B, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (B, num_patches, embed_dim)
        return x

class ViTEncoder(nn.Module):
    """
    A simplified Vision Transformer encoder that splits the image into patches,
    adds positional embeddings, and applies transformer encoder layers.
    """
    def __init__(self, in_channels=1, embed_dim=64, patch_size=16, num_layers=6, num_heads=4, dropout=0.1):
        super(ViTEncoder, self).__init__()
        self.patch_embed = PatchEmbed(in_channels, embed_dim, patch_size)
        # For simplicity, assume fixed image size 256x256 -> num_patches = (256//patch_size)^2
        num_patches = (256 // patch_size) ** 2  
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    
    def forward(self, x):
        # x: (B, C, H, W)
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)
        B, N, C = x.shape
        # If number of patches differs from assumed, interpolate positional embeddings
        if N != self.pos_embed.shape[1]:
            pos_embed = F.interpolate(self.pos_embed.transpose(1,2), size=N, mode='linear', align_corners=False).transpose(1,2)
        else:
            pos_embed = self.pos_embed
        x = x + pos_embed
        x = self.transformer(x)  # (B, num_patches, embed_dim)
        # Reshape tokens back into a feature map (assume square layout)
        h = w = int(N ** 0.5)
        x = x.transpose(1,2).reshape(B, C, h, w)
        return x

####################################
# 2. VAE Block for Latent Mapping  #
####################################

class VAEBlock(nn.Module):
    """
    Projects encoder features into a latent space using a VAE formulation.
    """
    def __init__(self, in_channels, latent_dim):
        super(VAEBlock, self).__init__()
        self.fc_mu = nn.Conv2d(in_channels, latent_dim, kernel_size=1)
        self.fc_logvar = nn.Conv2d(in_channels, latent_dim, kernel_size=1)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        # x: (B, C, H, W)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

######################################################
# 3. Cross-Modal Attention Fusion (Feature Fusion)   #
######################################################

class CrossModalAttentionFusion(nn.Module):
    """
    Fuses latent representations from CT and MRI using multi-head attention.
    """
    def __init__(self, latent_dim):
        super(CrossModalAttentionFusion, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=latent_dim, num_heads=4, batch_first=True)
    
    def forward(self, z_ct, z_mri):
        # Flatten spatial dimensions: (B, latent_dim, H, W) -> (B, N, latent_dim)
        B, C, H, W = z_ct.shape
        z_ct_flat = z_ct.flatten(2).transpose(1, 2)   # (B, N, C)
        z_mri_flat = z_mri.flatten(2).transpose(1, 2)   # (B, N, C)
        # Use CT latent as query and MRI latent as key/value (or vice versa)
        fused, _ = self.attention(z_ct_flat, z_mri_flat, z_mri_flat)
        # Reshape back to (B, latent_dim, H, W)
        fused = fused.transpose(1,2).reshape(B, C, H, W)
        return fused

#############################################
# 4. GAN-Based Decoder (Generator) for Output #
#############################################

class GANDecoder(nn.Module):
    """
    Decodes the fused latent representation into a fused image using transposed convolutions.
    """
    def __init__(self, latent_dim, base_filters=64):
        super(GANDecoder, self).__init__()
        self.up1 = nn.ConvTranspose2d(latent_dim, base_filters*8, kernel_size=2, stride=2)
        self.conv1 = nn.Sequential(
            nn.Conv2d(base_filters*8, base_filters*8, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters*8),
            nn.ReLU(inplace=True)
        )
        self.up2 = nn.ConvTranspose2d(base_filters*8, base_filters*4, kernel_size=2, stride=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(base_filters*4, base_filters*4, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters*4),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.ConvTranspose2d(base_filters*4, base_filters*2, kernel_size=2, stride=2)
        self.conv3 = nn.Sequential(
            nn.Conv2d(base_filters*2, base_filters*2, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters*2),
            nn.ReLU(inplace=True)
        )
        self.up4 = nn.ConvTranspose2d(base_filters*2, base_filters, kernel_size=2, stride=2)
        self.conv4 = nn.Sequential(
            nn.Conv2d(base_filters, base_filters, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_filters),
            nn.ReLU(inplace=True)
        )
        self.final_conv = nn.Conv2d(base_filters, 1, kernel_size=1)
    
    def forward(self, x):
        x = self.up1(x)
        x = self.conv1(x)
        x = self.up2(x)
        x = self.conv2(x)
        x = self.up3(x)
        x = self.conv3(x)
        x = self.up4(x)
        x = self.conv4(x)
        x = self.final_conv(x)
        return x

#########################################
# 5. Diffusion Refinement Module        #
#########################################

class DiffusionRefinement(nn.Module):
    """
    Applies iterative refinement (a simplified diffusion-like process) to the generated image.
    """
    def __init__(self, channels, num_steps=3):
        super(DiffusionRefinement, self).__init__()
        self.num_steps = num_steps
        self.refinement_block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        for _ in range(self.num_steps):
            residual = self.refinement_block(x)
            x = x + residual
        return x

#########################################
# 6. KaleidoFusionNet: The Complete Model#
#########################################

class KaleidoFusionNet(nn.Module):
    """
    Complete multi-modal fusion model.
    
    Input: ct_image and mri_image (each with shape: (B, 1, H, W))
    Output: fused image with the same input/output interface.
    """
    def __init__(self, in_channels=1, embed_dim=64, latent_dim=64, base_filters=64):
        super(KaleidoFusionNet, self).__init__()
        # Dual-stream ViT encoders for CT and MRI images
        self.ct_encoder = ViTEncoder(in_channels=in_channels, embed_dim=embed_dim)
        self.mri_encoder = ViTEncoder(in_channels=in_channels, embed_dim=embed_dim)
        
        # VAE blocks to map encoder features to a shared latent space
        self.ct_vae = VAEBlock(in_channels=embed_dim, latent_dim=latent_dim)
        self.mri_vae = VAEBlock(in_channels=embed_dim, latent_dim=latent_dim)
        
        # Cross-modal attention fusion module
        self.fusion = CrossModalAttentionFusion(latent_dim=latent_dim)
        
        # GAN-based decoder (Generator)
        self.decoder = GANDecoder(latent_dim=latent_dim, base_filters=base_filters)
        
        # Diffusion refinement module to further improve output quality
        self.diffusion = DiffusionRefinement(channels=1)
    
    def forward(self, ct_image, mri_image):
        # Encode CT and MRI images
        ct_features = self.ct_encoder(ct_image)    # (B, embed_dim, H', W')
        mri_features = self.mri_encoder(mri_image)    # (B, embed_dim, H', W')
        
        # Project features into latent space using VAE blocks
        z_ct, mu_ct, logvar_ct = self.ct_vae(ct_features)   # (B, latent_dim, H', W')
        z_mri, mu_mri, logvar_mri = self.mri_vae(mri_features)  # (B, latent_dim, H', W')
        
        # Fuse the latent representations using cross-modal attention
        fused_latent = self.fusion(z_ct, z_mri)   # (B, latent_dim, H', W')
        
        # Decode fused latent representation via the GAN-based decoder
        gen_image = self.decoder(fused_latent)    # (B, 1, H_out, W_out)
        
        # Apply diffusion refinement to further enhance the output image
        refined_image = self.diffusion(gen_image)
        
        return refined_image

#########################################
# Optional: Get Loss Function           #
#########################################

def get_loss_function(loss_type='l1'):
    """
    Return the specified loss function.
    """
    if loss_type.lower() == 'l1':
        return nn.L1Loss()
    elif loss_type.lower() in ['l2', 'mse']:
        return nn.MSELoss()
    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")

# ======================================
# Training Script
# ======================================

# (Assuming you have a MultiModalDataset defined in dataset.py)
from dataset import MultiModalDataset  # Ensure this file exists and implements your dataset

def train(args):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create model
    model = KaleidoFusionNet(in_channels=1, embed_dim=64, latent_dim=64, base_filters=args.base_filters)
    model = model.to(device)
    
    # Create dataset and dataloader
    train_dataset = MultiModalDataset(
        data_dir=args.data_dir,
        split='train',
        transform=True
    )
    
    val_dataset = MultiModalDataset(
        data_dir=args.data_dir,
        split='val',
        transform=False
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
    
    # Define loss function and optimizer
    criterion = get_loss_function(args.loss_type)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Training loop
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    
    for epoch in range(args.num_epochs):
        # Training phase
        model.train()
        epoch_loss = 0.0
        
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.num_epochs} [Train]") as pbar:
            for batch_idx, (ct_images, mri_images, target_images) in enumerate(pbar):
                # Move data to device
                ct_images = ct_images.to(device)
                mri_images = mri_images.to(device)
                target_images = target_images.to(device)
                
                optimizer.zero_grad()
                outputs = model(ct_images, mri_images)
                
                # Calculate loss
                loss = criterion(outputs, target_images)
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item()
                pbar.set_postfix(loss=loss.item())
        
        avg_train_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            with tqdm(val_loader, desc=f"Epoch {epoch+1}/{args.num_epochs} [Val]") as pbar:
                for batch_idx, (ct_images, mri_images, target_images) in enumerate(pbar):
                    ct_images = ct_images.to(device)
                    mri_images = mri_images.to(device)
                    target_images = target_images.to(device)
                    
                    outputs = model(ct_images, mri_images)
                    loss = criterion(outputs, target_images)
                    val_loss += loss.item()
                    pbar.set_postfix(loss=loss.item())
        
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        
        scheduler.step(avg_val_loss)
        
        print(f"Epoch {epoch+1}/{args.num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(args.output_dir, 'best_model.pth'))
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")
        
        # Save checkpoint every save_interval epochs
        if (epoch + 1) % args.save_interval == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, os.path.join(args.output_dir, f'checkpoint_epoch_{epoch+1}.pth'))
    
    # Plot training and validation loss
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.savefig(os.path.join(args.output_dir, 'loss_curve.png'))
    plt.close()
    
    print("Training completed!")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train KaleidoFusionNet for Multi-Modal Fusion')
    
    # Dataset parameters
    parser.add_argument('--data_dir', type=str, default='./data', help='Path to dataset directory')
    
    # Model parameters
    parser.add_argument('--base_filters', type=int, default=64, help='Number of base filters in the model')
    parser.add_argument('--loss_type', type=str, default='l1', choices=['l1', 'l2', 'mse'], help='Loss function type')
    
    # Training parameters
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--save_interval', type=int, default=10, help='Epoch interval to save checkpoints')
    parser.add_argument('--output_dir', type=str, default='./output', help='Directory to save outputs')
    
    args, unknown = parser.parse_known_args()
    train(args)




Using device: cpu


Epoch 1/5 [Train]: 100%|██████████| 16/16 [01:27<00:00,  5.49s/it, loss=0.676]
Epoch 1/5 [Val]: 100%|██████████| 4/4 [00:14<00:00,  3.68s/it, loss=0.797]


Epoch 1/5 - Train Loss: 0.9984, Val Loss: 0.7805
Saved best model with validation loss: 0.7805


Epoch 2/5 [Train]: 100%|██████████| 16/16 [01:49<00:00,  6.82s/it, loss=0.517]
Epoch 2/5 [Val]: 100%|██████████| 4/4 [00:14<00:00,  3.67s/it, loss=0.705]


Epoch 2/5 - Train Loss: 0.5775, Val Loss: 0.6893
Saved best model with validation loss: 0.6893


Epoch 3/5 [Train]: 100%|██████████| 16/16 [01:36<00:00,  6.05s/it, loss=0.444]
Epoch 3/5 [Val]: 100%|██████████| 4/4 [00:13<00:00,  3.48s/it, loss=0.441]


Epoch 3/5 - Train Loss: 0.4755, Val Loss: 0.4620
Saved best model with validation loss: 0.4620


Epoch 4/5 [Train]: 100%|██████████| 16/16 [01:23<00:00,  5.19s/it, loss=0.429]
Epoch 4/5 [Val]: 100%|██████████| 4/4 [00:13<00:00,  3.46s/it, loss=0.421]


Epoch 4/5 - Train Loss: 0.4354, Val Loss: 0.4388
Saved best model with validation loss: 0.4388


Epoch 5/5 [Train]: 100%|██████████| 16/16 [01:23<00:00,  5.21s/it, loss=0.396]
Epoch 5/5 [Val]: 100%|██████████| 4/4 [00:17<00:00,  4.45s/it, loss=0.403]


Epoch 5/5 - Train Loss: 0.4254, Val Loss: 0.4244
Saved best model with validation loss: 0.4244
Training completed!


In [None]:
import os
import argparse
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Import the correct model class
# from model import KaleidoFusionNet
from dataset import MultiModalDataset

def evaluate(args):
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create model (KaleidoFusionNet)
    model = KaleidoFusionNet(in_channels=1, embed_dim=64, latent_dim=64, base_filters=args.base_filters)
    model = model.to(device)
    
    # Load checkpoint
    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']} with validation loss {checkpoint['val_loss']:.4f}")
    
    # Create dataset and dataloader
    test_dataset = MultiModalDataset(
        data_dir=args.data_dir,
        split='test',
        transform=False
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True
    )
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Evaluation metrics
    l1_loss = nn.L1Loss()
    l2_loss = nn.MSELoss()
    
    # Evaluation loop
    model.eval()
    total_l1_loss = 0.0
    total_l2_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    
    with torch.no_grad():
        with tqdm(test_loader, desc="Evaluating") as pbar:
            for batch_idx, (ct_images, mri_images, target_images) in enumerate(pbar):
                # Move data to device
                ct_images = ct_images.to(device)
                mri_images = mri_images.to(device)
                target_images = target_images.to(device)
                
                # Forward pass
                outputs = model(ct_images, mri_images)
                
                # Calculate losses
                batch_l1_loss = l1_loss(outputs, target_images).item()
                batch_l2_loss = l2_loss(outputs, target_images).item()
                
                # Calculate PSNR and SSIM for each sample in the batch
                for i in range(outputs.size(0)):
                    # Convert to numpy arrays for PSNR and SSIM calculation
                    output_np = outputs[i, 0].cpu().numpy()
                    target_np = target_images[i, 0].cpu().numpy()
                    
                    # Normalize to [0, 1]
                    output_np = (output_np - output_np.min()) / (output_np.max() - output_np.min() + 1e-8)
                    target_np = (target_np - target_np.min()) / (target_np.max() - target_np.min() + 1e-8)
                    
                    # Calculate PSNR and SSIM
                    batch_psnr = psnr(target_np, output_np, data_range=1.0)
                    batch_ssim = ssim(target_np, output_np, data_range=1.0)
                    
                    total_psnr += batch_psnr
                    total_ssim += batch_ssim
                
                total_l1_loss += batch_l1_loss
                total_l2_loss += batch_l2_loss
                pbar.set_postfix(L1=batch_l1_loss, L2=batch_l2_loss)
                
                # Save sample images
                if batch_idx < args.num_samples_to_save:
                    for i in range(min(outputs.size(0), 4)):  # Save up to 4 images per batch
                        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
                        
                        ct_img = ct_images[i, 0].cpu().numpy()
                        mri_img = mri_images[i, 0].cpu().numpy()
                        target_img = target_images[i, 0].cpu().numpy()
                        output_img = outputs[i, 0].cpu().numpy()
                        
                        axes[0].imshow(ct_img, cmap='gray')
                        axes[0].set_title('CT Image')
                        axes[0].axis('off')
                        
                        axes[1].imshow(mri_img, cmap='gray')
                        axes[1].set_title('MRI Image')
                        axes[1].axis('off')
                        
                        axes[2].imshow(target_img, cmap='gray')
                        axes[2].set_title('Target Image')
                        axes[2].axis('off')
                        
                        axes[3].imshow(output_img, cmap='gray')
                        axes[3].set_title('Fused Image (Output)')
                        axes[3].axis('off')
                        
                        plt.tight_layout()
                        plt.savefig(os.path.join(args.output_dir, f'sample_{batch_idx}_{i}.png'))
                        plt.close()
    
    num_samples = len(test_dataset)
    avg_l1_loss = total_l1_loss / len(test_loader)
    avg_l2_loss = total_l2_loss / len(test_loader)
    avg_psnr = total_psnr / num_samples
    avg_ssim = total_ssim / num_samples
    
    print(f"Evaluation Results:")
    print(f"L1 Loss: {avg_l1_loss:.4f}")
    print(f"L2 Loss: {avg_l2_loss:.4f}")
    print(f"PSNR: {avg_psnr:.4f} dB")
    print(f"SSIM: {avg_ssim:.4f}")
    
    with open(os.path.join(args.output_dir, 'evaluation_results.txt'), 'w') as f:
        f.write(f"Evaluation Results:\n")
        f.write(f"L1 Loss: {avg_l1_loss:.4f}\n")
        f.write(f"L2 Loss: {avg_l2_loss:.4f}\n")
        f.write(f"PSNR: {avg_psnr:.4f} dB\n")
        f.write(f"SSIM: {avg_ssim:.4f}\n")
    
    print(f"Evaluation completed! Results saved to {args.output_dir}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate Multi-Modal Fusion Model')
    
    # Dataset parameters
    parser.add_argument('--data_dir', type=str, default='./data', help='Path to dataset directory')
    
    # Model parameters
    parser.add_argument('--base_filters', type=int, default=64, help='Number of base filters in the model')
    # parser.add_argument('--checkpoint_path', type=str, required=True, help='Path to model checkpoint')

    parser.add_argument('--checkpoint_path', type=str, default=r"C:\Users\aggar\Downloads\Telegram Desktop\DL_updated_11march\DL\output\best_model.pth", help='Path to model checkpoint')


    
    # Evaluation parameters
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--output_dir', type=str, default='./evaluation', help='Directory to save outputs')
    parser.add_argument('--num_samples_to_save', type=int, default=5, help='Number of sample batches to save')
    
    # Use parse_known_args() to ignore unknown arguments passed by Jupyter
    args, unknown = parser.parse_known_args()
    evaluate(args)


Using device: cpu




Loaded checkpoint from epoch 4 with validation loss 0.4244


Evaluating: 100%|██████████| 4/4 [00:29<00:00,  7.47s/it, L1=0.436, L2=0.451]

Evaluation Results:
L1 Loss: 0.4414
L2 Loss: 0.4477
PSNR: 11.7819 dB
SSIM: 0.0757
Evaluation completed! Results saved to ./evaluation



