In [8]:

# This cell builds on torch.nn, so it imports what it uses itself instead of
# relying on a cell that happened to run earlier in the kernel.
import torch
import torch.nn as nn


class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, the basic building block of the encoder/decoder.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
            With the defaults (3, 1, 1) the spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Attribute names form the state_dict keys (``conv.*``, ``bn.*``); keep them
        # stable so existing checkpoints keep loading.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, batch norm and ReLU in sequence."""
        return self.relu(self.bn(self.conv(x)))

class Encoder(nn.Module):
    """Four-stage convolutional encoder for a single modality (CT or MRI).

    Stages 1-3 are each followed by a 2x2 / stride-2 max-pool; stage 4 is not.
    Channel widths are ``f, 2f, 4f, 8f`` with ``f = base_filters``.

    ``forward(x)`` returns ``(bottleneck, skips)``: the stage-4 output and a list
    of every stage's output (taken before pooling), shallowest first. The same
    list object is also stored on ``self.features`` until the next call.
    """

    def __init__(self, in_channels=1, base_filters=64):
        super(Encoder, self).__init__()
        f = base_filters
        # Creation order matters: it fixes parameter order (optimizer state) and
        # RNG consumption during weight init.
        self.enc1 = ConvBlock(in_channels, f)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc2 = ConvBlock(f, 2 * f)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc3 = ConvBlock(2 * f, 4 * f)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc4 = ConvBlock(4 * f, 8 * f)
        # Skip-connection outputs of the most recent forward pass.
        self.features = []

    def forward(self, x):
        skips = []
        self.features = skips
        for stage, pool in ((self.enc1, self.pool1),
                            (self.enc2, self.pool2),
                            (self.enc3, self.pool3)):
            x = stage(x)
            skips.append(x)
            x = pool(x)
        x = self.enc4(x)
        skips.append(x)
        return x, skips

class FusionModule(nn.Module):
    """Fuse CT and MRI feature maps: channel concatenation, then two conv-BN-ReLU layers.

    The two inputs must share batch and spatial sizes, and their channel counts
    must add up to ``2 * in_channels`` (normally ``in_channels`` each). The
    output has ``out_channels`` channels at the same spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super(FusionModule, self).__init__()
        layers = [
            # 2C -> C: mix the concatenated modalities
            nn.Conv2d(2 * in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            # C -> out_channels
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.fusion_conv = nn.Sequential(*layers)

    def forward(self, ct_features, mri_features):
        stacked = torch.cat((ct_features, mri_features), dim=1)
        return self.fusion_conv(stacked)

class KolmogorovArnoldModule(nn.Module):
    """Kolmogorov–Arnold-inspired block: a sum of ``num_terms`` rank-1 1x1-conv branches.

    Loosely follows the Kolmogorov–Arnold idea of writing a multivariate function
    as a sum of univariate ones. Branch ``i`` computes ``phi_i(relu(psi_i(x)))``:

    - ``psi_i`` is a 1x1 conv mapping all ``in_channels`` to a single channel
      (one scalar per pixel).
    - ``phi_i`` is a 1x1 conv mapping that scalar back to ``in_channels``.

    The branch outputs are summed. The input itself is not added back. With
    ``num_terms == 0`` the forward pass returns the integer ``0``.
    """

    def __init__(self, in_channels, num_terms=3):
        super(KolmogorovArnoldModule, self).__init__()
        self.num_terms = num_terms
        # psi_i: in_channels -> 1 (all psi convs are created before any phi conv)
        self.psi = nn.ModuleList(
            nn.Conv2d(in_channels, 1, kernel_size=1) for _ in range(num_terms)
        )
        # phi_i: 1 -> in_channels
        self.phi = nn.ModuleList(
            nn.Conv2d(1, in_channels, kernel_size=1) for _ in range(num_terms)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return sum(
            self.phi[k](self.relu(self.psi[k](x))) for k in range(self.num_terms)
        )

class Decoder(nn.Module):
    """Upsampling decoder that maps bottleneck features to a single-channel image.

    There are three stages of (2x2 / stride-2 transposed conv, then ConvBlock).
    Each stage doubles the spatial size, going ``in_channels -> 4f -> 2f -> f``
    channels with ``f = base_filters``. A final 1x1 conv (no activation) then
    produces one output channel. Encoder skip features are not used here.
    """

    def __init__(self, in_channels, base_filters=64):
        super(Decoder, self).__init__()
        f = base_filters
        self.upconv3 = nn.ConvTranspose2d(in_channels, 4 * f, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(4 * f, 4 * f)
        self.upconv2 = nn.ConvTranspose2d(4 * f, 2 * f, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(2 * f, 2 * f)
        self.upconv1 = nn.ConvTranspose2d(2 * f, f, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(f, f)
        self.final_conv = nn.Conv2d(f, 1, kernel_size=1)

    def forward(self, x):
        stages = (
            (self.upconv3, self.dec3),
            (self.upconv2, self.dec2),
            (self.upconv1, self.dec1),
        )
        for upsample, refine in stages:
            x = refine(upsample(x))
        return self.final_conv(x)

class MultiModalFusionModel(nn.Module):
    """CT/MRI fusion network: two encoders -> fusion -> Kolmogorov–Arnold block -> decoder.

    ``forward(ct_image, mri_image)`` returns the decoder's single-channel image.
    The encoders' skip features are computed but not passed to the decoder.
    """

    def __init__(self, in_channels=1, base_filters=64):
        super(MultiModalFusionModel, self).__init__()
        bottleneck = base_filters * 8  # channel count of each encoder's final stage
        # Two independent (non-shared) encoders, one per modality.
        self.ct_encoder = Encoder(in_channels, base_filters)
        self.mri_encoder = Encoder(in_channels, base_filters)
        self.fusion = FusionModule(bottleneck, bottleneck)
        self.kam_module = KolmogorovArnoldModule(bottleneck, num_terms=3)
        self.decoder = Decoder(bottleneck, base_filters)

    def forward(self, ct_image, mri_image):
        ct_bottleneck, _ = self.ct_encoder(ct_image)
        mri_bottleneck, _ = self.mri_encoder(mri_image)
        fused = self.kam_module(self.fusion(ct_bottleneck, mri_bottleneck))
        return self.decoder(fused)

def get_loss_function(loss_type='ssim_l1'):
    """Build the loss selected by ``loss_type``.

    Args:
        loss_type: case-insensitive name; one of ``'l1'``, ``'l2'``/``'mse'`` or
            ``'ssim_l1'``, which computes ``0.1 * L1 + 0.9 * (1 - SSIM)``.

    Returns:
        An ``nn.Module`` called as ``criterion(output, target)``.

    Raises:
        ValueError: for an unsupported ``loss_type``.
        ImportError: for ``'ssim_l1'`` when ``pytorch_msssim`` is not installed.
    """
    key = loss_type.lower()

    if key == 'l1':
        return nn.L1Loss()

    if key in ('l2', 'mse'):
        return nn.MSELoss()

    if key == 'ssim_l1':
        # Nothing else in this notebook imports pytorch_msssim, so the SSIM loss
        # would hit a NameError on its first forward pass. Import it here so a
        # missing package is reported immediately, and only when SSIM is requested.
        import pytorch_msssim

        class SSIML1Loss(nn.Module):
            """alpha * L1 + beta * (1 - SSIM); SSIM is computed with data_range=1.0."""

            def __init__(self, alpha=0.1, beta=0.9):  # more weight on SSIM
                super(SSIML1Loss, self).__init__()
                self.alpha = alpha  # L1 weight
                self.beta = beta    # SSIM weight
                self.l1_loss = nn.L1Loss()

            def forward(self, output, target):
                l1_loss = self.l1_loss(output, target)
                ssim_loss = 1 - pytorch_msssim.ssim(output, target, data_range=1.0)
                return self.alpha * l1_loss + self.beta * ssim_loss

        return SSIML1Loss()

    raise ValueError(f"Unsupported loss type: {loss_type}")




In [None]:
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
# from model import MultiModalFusionModel, get_loss_function
from dataset import MultiModalDataset


import torch
import torch.nn as nn
import torch.nn.functional as F



def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return the per-sample mean loss.

    If ``optimizer`` is given, the model is put in training mode and updated after
    every batch. Otherwise it is put in eval mode and gradients are disabled.

    Raises:
        ValueError: if ``loader`` yields no samples (the mean would be undefined).
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    num_samples = 0

    with torch.set_grad_enabled(training), tqdm(loader, desc=desc) as pbar:
        for ct_images, mri_images, target_images in pbar:
            ct_images = ct_images.to(device)
            mri_images = mri_images.to(device)
            target_images = target_images.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(ct_images, mri_images)
            loss = criterion(outputs, target_images)
            if training:
                loss.backward()
                optimizer.step()

            # Weight by batch size so a short final batch does not skew the mean.
            batch_size = target_images.size(0)
            batch_loss = loss.item()
            total_loss += batch_loss * batch_size
            num_samples += batch_size
            pbar.set_postfix(loss=batch_loss)

    if num_samples == 0:
        raise ValueError(f"{desc}: data loader produced no samples")
    return total_loss / num_samples


def _save_checkpoint(path, epoch, model, optimizer, train_loss, val_loss):
    """Save model/optimizer state with the (0-based) epoch index and losses to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, path)


def _plot_loss_curves(train_losses, val_losses, path):
    """Plot per-epoch training/validation loss and save the figure to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses, label='Training Loss')
    ax.plot(val_losses, label='Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.legend()
    fig.savefig(path)
    plt.close(fig)


def train(args):
    """Train MultiModalFusionModel on the 'train' split and validate on 'val'.

    After every epoch, the lowest-validation-loss model is written to
    ``<output_dir>/best_model.pth``. Every ``save_interval`` epochs (when it is
    positive), ``checkpoint_epoch_<n>.pth`` is also written. The loss curves are
    saved to ``loss_curve.png`` at the end.

    Args:
        args: namespace with ``data_dir``, ``base_filters``, ``loss_type``,
            ``batch_size``, ``num_epochs``, ``learning_rate``, ``num_workers``,
            ``save_interval`` and ``output_dir``.

    Raises:
        ValueError: if the train or validation split yields no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = MultiModalFusionModel(in_channels=1, base_filters=args.base_filters).to(device)

    train_dataset = MultiModalDataset(data_dir=args.data_dir, split='train', transform=True)
    val_dataset = MultiModalDataset(data_dir=args.data_dir, split='val', transform=False)

    # Pinned host memory only speeds up host->GPU copies; on CPU it is useless.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )

    criterion = get_loss_function(args.loss_type)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    # The scheduler's ``verbose`` flag is deprecated in recent PyTorch releases;
    # learning-rate drops are logged explicitly in the loop instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    os.makedirs(args.output_dir, exist_ok=True)

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(args.num_epochs):
        epoch_tag = f"Epoch {epoch+1}/{args.num_epochs}"

        avg_train_loss = _run_epoch(
            model, train_loader, criterion, device, f"{epoch_tag} [Train]", optimizer=optimizer
        )
        avg_val_loss = _run_epoch(model, val_loader, criterion, device, f"{epoch_tag} [Val]")
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        previous_lrs = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(avg_val_loss)
        for group_idx, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
            if group['lr'] < old_lr:
                print(f"Reducing learning rate of group {group_idx} to {group['lr']:.4e}.")

        print(f"{epoch_tag} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            _save_checkpoint(
                os.path.join(args.output_dir, 'best_model.pth'),
                epoch, model, optimizer, avg_train_loss, avg_val_loss,
            )
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")

        # Guard against save_interval <= 0, which would otherwise divide by zero.
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            _save_checkpoint(
                os.path.join(args.output_dir, f'checkpoint_epoch_{epoch+1}.pth'),
                epoch, model, optimizer, avg_train_loss, avg_val_loss,
            )

    _plot_loss_curves(train_losses, val_losses, os.path.join(args.output_dir, 'loss_curve.png'))

    print("Training completed!")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Multi-Modal Fusion Model')

    # Dataset parameters
    parser.add_argument('--data_dir', type=str, default='./data', help='Path to dataset directory')

    # Model parameters
    parser.add_argument('--base_filters', type=int, default=64, help='Number of base filters in the model')
    # 'ssim_l1' is also supported by get_loss_function (it needs the pytorch_msssim package).
    parser.add_argument('--loss_type', type=str, default='l1', choices=['l1', 'l2', 'mse', 'ssim_l1'], help='Loss function type')

    # Training parameters
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for training')
    parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--save_interval', type=int, default=10, help='Epoch interval to save checkpoints')
    parser.add_argument('--output_dir', type=str, default='./output', help='Directory to save outputs')

    # parse_known_args() ignores the extra arguments Jupyter passes to the kernel.
    args, unknown = parser.parse_known_args()
    train(args)

In [23]:
import os
import argparse
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Import the correct model class
# from model import KaleidoFusionNet
from dataset import MultiModalDataset

def _to_unit_range(img):
    """Min-max rescale an array to [0, 1]; the epsilon avoids division by zero on flat images."""
    return (img - img.min()) / (img.max() - img.min() + 1e-8)


def _save_sample_figure(ct_img, mri_img, target_img, output_img, path):
    """Save CT, MRI, target and model output side by side as a single image at ``path``."""
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    panels = (
        (ct_img, 'CT Image'),
        (mri_img, 'MRI Image'),
        (target_img, 'Target Image'),
        (output_img, 'Fused Image (Output)'),
    )
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(path)
    # Close this specific figure: this runs in a loop and open figures accumulate.
    plt.close(fig)


def evaluate(args):
    """Evaluate a MultiModalFusionModel checkpoint on the 'test' split.

    The following are averaged over test samples:

    - L1 and MSE loss.
    - PSNR and SSIM, computed on channel 0 after per-image min-max rescaling.

    Example panels are saved for the first ``args.num_samples_to_save`` batches
    (at most 4 per batch). The metrics are printed and written to
    ``<output_dir>/evaluation_results.txt``.

    Args:
        args: namespace with ``data_dir``, ``base_filters``, ``checkpoint_path``,
            ``batch_size``, ``num_workers``, ``output_dir`` and
            ``num_samples_to_save``.

    Raises:
        ValueError: if the test split yields no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = MultiModalFusionModel(in_channels=1, base_filters=args.base_filters).to(device)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
    # Only load trusted checkpoints; consider weights_only=True where the installed
    # PyTorch version supports it.
    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']} with validation loss {checkpoint['val_loss']:.4f}")

    test_dataset = MultiModalDataset(data_dir=args.data_dir, split='test', transform=False)
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        # Pinned host memory only helps host->GPU copies.
        pin_memory=device.type == 'cuda',
    )

    os.makedirs(args.output_dir, exist_ok=True)

    l1_loss = nn.L1Loss()
    l2_loss = nn.MSELoss()

    model.eval()
    # All totals are per-sample sums. Batch losses are weighted by batch size so
    # every metric is averaged the same way and a short last batch doesn't skew it.
    total_l1_loss = 0.0
    total_l2_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    num_samples = 0

    with torch.no_grad():
        with tqdm(test_loader, desc="Evaluating") as pbar:
            for batch_idx, (ct_images, mri_images, target_images) in enumerate(pbar):
                ct_images = ct_images.to(device)
                mri_images = mri_images.to(device)
                target_images = target_images.to(device)

                outputs = model(ct_images, mri_images)

                batch_l1_loss = l1_loss(outputs, target_images).item()
                batch_l2_loss = l2_loss(outputs, target_images).item()
                batch_size = outputs.size(0)
                total_l1_loss += batch_l1_loss * batch_size
                total_l2_loss += batch_l2_loss * batch_size
                num_samples += batch_size

                # Channel 0 of each tensor, copied to the host once per batch.
                ct_np = ct_images[:, 0].cpu().numpy()
                mri_np = mri_images[:, 0].cpu().numpy()
                target_np = target_images[:, 0].cpu().numpy()
                output_np = outputs[:, 0].cpu().numpy()

                for i in range(batch_size):
                    # NOTE(review): output and target are each min-max rescaled
                    # independently, so PSNR/SSIM ignore global offset/scale errors
                    # in the output. Kept so numbers stay comparable to earlier runs.
                    output_unit = _to_unit_range(output_np[i])
                    target_unit = _to_unit_range(target_np[i])
                    total_psnr += psnr(target_unit, output_unit, data_range=1.0)
                    total_ssim += ssim(target_unit, output_unit, data_range=1.0)

                pbar.set_postfix(L1=batch_l1_loss, L2=batch_l2_loss)

                if batch_idx < args.num_samples_to_save:
                    for i in range(min(batch_size, 4)):  # save up to 4 images per batch
                        _save_sample_figure(
                            ct_np[i], mri_np[i], target_np[i], output_np[i],
                            os.path.join(args.output_dir, f'sample_{batch_idx}_{i}.png'),
                        )

    if num_samples == 0:
        raise ValueError("Test split produced no samples; nothing to evaluate")

    avg_l1_loss = total_l1_loss / num_samples
    avg_l2_loss = total_l2_loss / num_samples
    avg_psnr = total_psnr / num_samples
    avg_ssim = total_ssim / num_samples

    report = [
        "Evaluation Results:",
        f"L1 Loss: {avg_l1_loss:.4f}",
        f"L2 Loss: {avg_l2_loss:.4f}",
        f"PSNR: {avg_psnr:.4f} dB",
        f"SSIM: {avg_ssim:.4f}",
    ]
    print("\n".join(report))
    with open(os.path.join(args.output_dir, 'evaluation_results.txt'), 'w') as f:
        f.write("\n".join(report) + "\n")

    print(f"Evaluation completed! Results saved to {args.output_dir}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evaluate Multi-Modal Fusion Model')

    # Dataset parameters
    parser.add_argument('--data_dir', type=str, default='./data', help='Path to dataset directory')

    # Model parameters
    parser.add_argument('--base_filters', type=int, default=64, help='Number of base filters in the model')
    # Default is the best checkpoint the training cell writes with its default
    # --output_dir ('./output'); pass --checkpoint_path to evaluate another file.
    parser.add_argument('--checkpoint_path', type=str, default=os.path.join('./output', 'best_model.pth'), help='Path to model checkpoint')

    # Evaluation parameters
    parser.add_argument('--batch_size', type=int, default=8, help='Batch size for evaluation')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of data loading workers')
    parser.add_argument('--output_dir', type=str, default='./evaluation', help='Directory to save outputs')
    parser.add_argument('--num_samples_to_save', type=int, default=5, help='Number of sample batches to save')

    # Use parse_known_args() to ignore unknown arguments passed by Jupyter
    args, unknown = parser.parse_known_args()
    evaluate(args)


Using device: cpu
Loaded checkpoint from epoch 64 with validation loss 0.0560


Evaluating: 100%|██████████| 4/4 [01:07<00:00, 16.95s/it, L1=0.0572, L2=0.0164]

Evaluation Results:
L1 Loss: 0.0590
L2 Loss: 0.0178
PSNR: 20.8359 dB
SSIM: 0.3959
Evaluation completed! Results saved to ./evaluation





In [None]:
#0.1990
#0.0937
# 1.055
#0.1998
#0.2609
#0.3249
#0.2775
#0.3084
#0.3306
#0.3503 
#0.3680

#0.3959

