In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from pathlib import Path

from config.config import config
from models.swin_transformer import SwinTransformerSegmentation
from models.cyclegan import Generator
from utils.visualization import (
    visualize_segmentation,
    visualize_comparison,
    visualize_attention_maps,
    plot_training_curves
)
from utils.metrics import evaluate_segmentation
from data.dataset import get_transforms

# Run on GPU when available; every model and input tensor below is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# All figures below are written into ../outputs/. Create it up front because
# matplotlib's savefig raises FileNotFoundError when the parent directory is missing.
OUTPUT_DIR = Path('../outputs')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Build the Swin-Transformer segmentation network. These hyper-parameters must
# match the training-time architecture, otherwise load_state_dict (strict by
# default) rejects the checkpoint's parameters.
seg_model_kwargs = dict(
    img_size=256,
    num_classes=4,
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=[3, 6, 12, 24],
)
seg_model = SwinTransformerSegmentation(**seg_model_kwargs).to(device)

# Restore trained weights when the checkpoint exists; otherwise the model keeps
# its freshly initialized weights and only a notice is printed.
checkpoint_path = '../checkpoints/best_segmentation.pth'
checkpoint_found = os.path.exists(checkpoint_path)
if not checkpoint_found:
    print(f"Checkpoint not found: {checkpoint_path}")
else:
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    seg_model.load_state_dict(checkpoint['seg_model'])
    best_dice = checkpoint.get('dice_score', 'N/A')
    print(f"Loaded checkpoint from {checkpoint_path}")
    print(f"Best Dice Score: {best_dice}")

# Switch to evaluation mode (affects layers such as dropout/normalization, if present).
seg_model.eval()

# Load the CycleGAN generators used for cross-domain translation later on:
# G_s2t translates source -> target and G_t2s target -> source (per their names;
# the architecture lives in models.cyclegan).
G_s2t = Generator().to(device)
G_t2s = Generator().to(device)

cyclegan_checkpoint = '../checkpoints/cyclegan_epoch_100.pth'
if os.path.exists(cyclegan_checkpoint):
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    cyclegan_ckpt = torch.load(cyclegan_checkpoint, map_location=device)
    G_s2t.load_state_dict(cyclegan_ckpt['G_s2t'])
    G_t2s.load_state_dict(cyclegan_ckpt['G_t2s'])
    print(f"Loaded CycleGAN from {cyclegan_checkpoint}")
else:
    # Without this notice the generators silently stay randomly initialized and
    # every translated image produced further down is meaningless.
    print(f"CycleGAN checkpoint not found: {cyclegan_checkpoint} "
          "(generators are randomly initialized)")

G_s2t.eval()
# Trailing ';' stops Jupyter from rendering the whole module repr as cell output.
G_t2s.eval();

In [None]:
# One paired example: a T1 (source-domain) image with its mask, plus a
# T2 (target-domain) image. Both images are segmented and translated across domains.
source_img_path = '../data/t1/images/sample_001.png'
source_mask_path = '../data/t1/masks/sample_001.png'
target_img_path = '../data/t2/images/sample_001.png'

# Evaluation-time preprocessing; later cells reuse this same `transform`.
transform = get_transforms(is_train=False)

source_img = Image.open(source_img_path).convert('RGB')
source_mask = Image.open(source_mask_path).convert('L')  # single-channel label image
target_img = Image.open(target_img_path).convert('RGB')


def _to_batch(pil_image):
    """Apply `transform`, prepend a batch axis and move the tensor to `device`."""
    return transform(pil_image).unsqueeze(0).to(device)


source_tensor = _to_batch(source_img)
target_tensor = _to_batch(target_img)

# Inference only: skip building the autograd graph.
with torch.no_grad():
    pred_source, pred_target = seg_model(source_tensor), seg_model(target_tensor)
    # Cross-domain translations of each input.
    source_to_target = G_s2t(source_tensor)
    target_to_source = G_t2s(target_tensor)

class_names = ['Background', 'WT', 'TC', 'ET']

# NOTE(review): the mask is passed at its on-disk resolution with raw pixel
# values, while the image went through `transform`; confirm the spatial sizes
# match and that mask pixels are already class indices 0..3.
visualize_segmentation(
    source_tensor[0],
    torch.from_numpy(np.array(source_mask)),
    pred_source[0],
    class_names=class_names,
    save_path='../outputs/source_prediction.png',
)

print("Source domain prediction saved!")

In [None]:
# Side-by-side view of both domains: the two inputs, their CycleGAN
# translations, both segmentation predictions and the source ground-truth mask.
comparison_panels = (
    source_tensor[0],                          # source image
    target_tensor[0],                          # target image
    source_to_target[0],                       # source translated by G_s2t
    target_to_source[0],                       # target translated by G_t2s
    pred_source[0],                            # prediction on the source image
    pred_target[0],                            # prediction on the target image
    torch.from_numpy(np.array(source_mask)),   # source ground-truth mask
)
visualize_comparison(
    *comparison_panels,
    class_names=class_names,
    save_path='../outputs/domain_comparison.png',
)

print("Domain comparison saved!")

In [None]:
# Run one forward pass, then read the attention masks back from the model.
# NOTE(review): presumably get_attention_masks() returns what the most recent
# forward pass recorded; verify in models.swin_transformer.
with torch.no_grad():
    _ = seg_model(source_tensor)
    attention_masks = seg_model.get_attention_masks()

has_attention = len(attention_masks) > 0
if not has_attention:
    print("No attention masks available")
else:
    # NOTE(review): [0][0] takes element 0 of the first entry; confirm whether the
    # outer index is the layer and the inner one the batch sample (or vice versa).
    visualize_attention_maps(
        attention_masks[0][0],
        save_path='../outputs/attention_maps.png',
    )
    print("Attention maps saved!")

In [None]:
# Run the segmentation model over a fixed subset of target-domain (T2) images.
# NOTE: this cell only runs inference. No ground-truth masks are loaded, so no
# metric is computed and `all_dice_scores` stays empty.
# TODO: load the matching masks and score each prediction with
# `evaluate_segmentation` to turn this into a real evaluation.
NUM_EVAL_IMAGES = 10

test_dir = Path('../data/t2/images')
# sorted(): Path.glob yields entries in filesystem order, which is not
# guaranteed, so sort to make "the first N images" reproducible.
test_images = sorted(test_dir.glob('*.png'))[:NUM_EVAL_IMAGES]
if not test_images:
    print(f"No .png images found in {test_dir}")

all_dice_scores = []
target_pred_classes = {}  # image file name -> predicted class map, on CPU

for img_path in test_images:
    # Context manager closes the file handle deterministically.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = seg_model(img_tensor)
    # pred[0] is treated as (classes, H, W): argmax over the class axis.
    target_pred_classes[img_path.name] = torch.argmax(pred[0], dim=0).cpu()
    print(f"Processed: {img_path.name}")

print(f"\nEvaluation completed! ({len(test_images)} image(s) processed, no metrics computed)")

In [None]:
# Grid of target-domain (T2) predictions, one column per image.
# Rows: input image, predicted class map, prediction overlaid on the input.
NUM_GRID_IMAGES = 4

# sorted(): Path.glob order depends on the filesystem; sort so the selected
# images are reproducible across runs and machines.
test_images = sorted(Path('../data/t2/images').glob('*.png'))[:NUM_GRID_IMAGES]

# squeeze=False keeps `axes` 2-D even when NUM_GRID_IMAGES == 1.
fig, axes = plt.subplots(3, NUM_GRID_IMAGES, figsize=(4 * NUM_GRID_IMAGES, 12), squeeze=False)

for idx, img_path in enumerate(test_images):
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        pred = seg_model(img_tensor)
    # pred[0] is indexed as (classes, H, W): argmax over dim 0 yields the class map.
    num_classes = pred[0].shape[0]
    pred_class = torch.argmax(pred[0], dim=0).cpu().numpy()

    # Undo input normalization for display.
    # NOTE(review): (x + 1) / 2 assumes the transform maps pixels to [-1, 1];
    # confirm against get_transforms. Clipping keeps imshow's RGB input in [0, 1].
    img_display = img_tensor[0].cpu().numpy().transpose(1, 2, 0)
    img_display = np.clip((img_display + 1) / 2, 0.0, 1.0)

    # Pin the color scale to the full class range. Without vmin/vmax matplotlib
    # rescales per image, so one class gets different colors in different columns
    # whenever an image does not contain every class.
    class_cmap = dict(cmap='jet', vmin=0, vmax=num_classes - 1)

    # Original image
    axes[0, idx].imshow(img_display)
    axes[0, idx].set_title(f'Image {idx+1}')

    # Prediction
    axes[1, idx].imshow(pred_class, **class_cmap)
    axes[1, idx].set_title(f'Prediction {idx+1}')

    # Overlay: only foreground classes are tinted. Class 0 is 'Background' in
    # class_names, and tinting it would wash out the whole image.
    axes[2, idx].imshow(img_display)
    axes[2, idx].imshow(np.ma.masked_equal(pred_class, 0), alpha=0.5, **class_cmap)
    axes[2, idx].set_title(f'Overlay {idx+1}')

# Hide ticks on every panel, including columns left empty when fewer than
# NUM_GRID_IMAGES images were found.
for ax in axes.flat:
    ax.axis('off')

plt.tight_layout()
# savefig does not create missing directories.
Path('../outputs').mkdir(parents=True, exist_ok=True)
plt.savefig('../outputs/batch_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

print("Batch predictions saved!")

In [None]:
# Human-readable recap of the experiment setup and the figures written above.
SUMMARY_RULE = "=" * 60

summary_header_lines = (
    "Model: Swin Transformer with Meta Attention",
    "Task: T1 -> T2 Brain Tumor Segmentation",
    f"Classes: {class_names}",
    "Image Size: 256x256",
)
summary_output_lines = (
    "- source_prediction.png",
    "- domain_comparison.png",
    "- attention_maps.png",
    "- batch_predictions.png",
)

print(SUMMARY_RULE)
print("MA-UDA Results Summary")
print(SUMMARY_RULE)
for summary_line in summary_header_lines:
    print(summary_line)
print(SUMMARY_RULE)
print("\nVisualization outputs saved in: ../outputs/")
for summary_line in summary_output_lines:
    print(summary_line)
print(SUMMARY_RULE)