In [None]:
import os
import re

import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from vit_pytorch import ViT
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from scipy import linalg

# --- Config ---
# Side length (pixels) every slice is resized to; also the ViT input size.
image_size = 224
# Timepoints per subject: length of the SSL permutation and ViT class count.
num_timepoints = 5
# NOTE(review): not referenced by any code visible in this notebook.
num_slices = 10
# How a 3D volume is reduced to one 2D image: "average" or "middle".
slice_strategy = "average"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-slice preprocessing: square resize, then conversion to a float tensor.
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ]
)

# --- Utils to load 3D scan and extract representative 2D slices ---
def natural_sort_key(name):
    """Sort key that orders embedded integers numerically ("s2" < "s10")."""
    # re.split with a capturing group yields text/digit-run/text/... starting
    # with text, so odd positions are digit runs; elements at the same position
    # therefore always share a type and compare safely.
    parts = re.split(r"(\d+)", name)
    return [int(part) if i % 2 else part for i, part in enumerate(parts)]


def load_3d_scan(timepoint_folder):
    """Load every ``.png`` slice in ``timepoint_folder`` as one stacked volume.

    Slices are ordered with ``natural_sort_key`` so that "slice_2.png" comes
    before "slice_10.png" (plain ``sorted`` would put it after, which makes the
    "middle" slice strategy pick the wrong slice). Each slice is converted to
    single-channel ("L") and passed through the module-level ``transform``.

    Returns:
        Tensor of shape [Z, 1, H, W].

    Raises:
        RuntimeError: if the folder contains no ``.png`` files.
    """
    slices = sorted(
        (f for f in os.listdir(timepoint_folder) if f.endswith('.png')),
        key=natural_sort_key,
    )
    if not slices:
        raise RuntimeError(f"No .png slices found in {timepoint_folder!r}")
    volume = []
    for name in slices:
        # Context manager releases the file handle deterministically.
        with Image.open(os.path.join(timepoint_folder, name)) as img:
            volume.append(transform(img.convert('L')))
    return torch.stack(volume)  # shape: [Z, 1, H, W]

def extract_2d_from_volume(volume, strategy=None):
    """Reduce a stack of slices [Z, C, H, W] to one representative slice [C, H, W].

    Args:
        volume: tensor stacked along dim 0 (as returned by ``load_3d_scan``).
        strategy: "average" (mean over all slices) or "middle" (slice Z // 2).
            ``None`` (the default) uses the module-level ``slice_strategy``,
            so existing callers behave exactly as before.

    Raises:
        ValueError: for any other strategy.
    """
    if strategy is None:
        strategy = slice_strategy
    if strategy == "average":
        return volume.mean(dim=0)  # [1, H, W]
    if strategy == "middle":
        return volume[len(volume) // 2]  # [1, H, W]
    raise ValueError("Invalid slice strategy")

def load_subject_series(subject_folder):
    """Load every timepoint of one subject as a [T, 1, H, W] tensor.

    Each sub-directory of ``subject_folder`` is one timepoint. Sub-directories
    are visited in sorted name order, and that order is taken as temporal
    order. Plain files (e.g. ``.DS_Store``) are skipped instead of being passed
    to ``load_3d_scan``, which would fail on them.

    NOTE(review): ordering is lexicographic, so timepoint folders need
    zero-padded names if there are 10 or more of them -- confirm the naming
    scheme.

    Raises:
        RuntimeError: if ``subject_folder`` has no sub-directories.
    """
    timepoints = sorted(
        name
        for name in os.listdir(subject_folder)
        if os.path.isdir(os.path.join(subject_folder, name))
    )
    if not timepoints:
        raise RuntimeError(f"No timepoint folders found in {subject_folder!r}")
    scans = [
        extract_2d_from_volume(load_3d_scan(os.path.join(subject_folder, t)))
        for t in timepoints
    ]
    return torch.stack(scans)  # [T, 1, H, W]

# --- Dataset for SSL (shuffled scan order prediction) ---
class MRI3DShuffledDataset(Dataset):
    """Self-supervised dataset: one subject's scans in random order plus the permutation.

    Each item is ``(scans[perm], perm)``. ``perm`` is a fresh random
    permutation of ``range(num_timepoints)``, and ``perm[i]`` is the original
    timepoint index of the i-th returned scan. It serves as the per-scan
    classification target.
    """

    def __init__(self, subject_paths):
        # One folder per subject; see load_subject_series for the layout.
        self.subject_paths = subject_paths

    def __len__(self):
        return len(self.subject_paths)

    def __getitem__(self, idx):
        scans = load_subject_series(self.subject_paths[idx])  # [T, 1, H, W]
        # Only the first num_timepoints scans are indexed. A new permutation is
        # drawn on every access, so repeated epochs see different orders.
        perm = torch.randperm(num_timepoints)
        return scans[perm], perm

# --- Dataset for Generation Task (T0-T3 -> T4) ---
class MRI3DTemporalDataset(Dataset):
    """Next-scan prediction dataset: the first ``num_context`` scans -> the following scan.

    With the default ``num_context=4`` an item is ``(scans[:4], scans[4])``,
    i.e. T0-T3 as input and T4 as target.

    Args:
        subject_paths: one folder per subject (see ``load_subject_series``).
        num_context: number of leading timepoints used as input (>= 1).
    """

    def __init__(self, subject_paths, num_context=4):
        if num_context < 1:
            raise ValueError(f"num_context must be >= 1, got {num_context}")
        self.subject_paths = subject_paths
        self.num_context = num_context

    def __len__(self):
        return len(self.subject_paths)

    def __getitem__(self, idx):
        scans = load_subject_series(self.subject_paths[idx])  # [T, 1, H, W]
        if scans.size(0) <= self.num_context:
            # Same exception type as the original out-of-range index, clearer message.
            raise IndexError(
                f"{self.subject_paths[idx]!r} has {scans.size(0)} timepoints; "
                f"need at least {self.num_context + 1}"
            )
        return scans[:self.num_context], scans[self.num_context]

# --- ViT Wrapper ---
class ScanOrderViT(nn.Module):
    """Single-channel ViT classifier that scores each scan over ``num_timepoints`` classes.

    The backbone is kept as ``self.vit`` because ``TemporalScanPredictor``
    reaches into it, and saved checkpoints use that attribute name.
    """

    def __init__(self):
        super().__init__()
        vit_config = dict(
            image_size=image_size,
            patch_size=16,
            num_classes=num_timepoints,
            dim=256,
            depth=4,
            heads=4,
            mlp_dim=512,
            channels=1,
        )
        self.vit = ViT(**vit_config)

    def forward(self, x):
        logits = self.vit(x)
        return logits

# --- Model: ViT encoder with linear decoder ---
class TemporalScanPredictor(nn.Module):
    """Predict the next scan from a short sequence of earlier scans.

    How it works:
    - Each input scan is encoded independently by ``encoder.vit`` with its
      classification head removed.
    - The per-scan features are averaged over the time axis.
    - One linear layer decodes the mean feature into an ``out_size x out_size``
      single-channel image.

    Args:
        encoder: module exposing a ``vit`` sub-module with an ``mlp_head``
            attribute (e.g. ``ScanOrderViT``). NOTE: ``encoder.vit.mlp_head``
            is replaced in place with ``nn.Identity()``, so the encoder passed
            in loses its classification head.
        feature_dim: width of the features ``encoder.vit`` returns once its
            head is removed. Defaults to 256, matching the ``dim=256`` that
            ``ScanOrderViT`` passes to ``ViT``.
        out_size: side length of the predicted image. ``None`` means the
            module-level ``image_size``.
    """

    def __init__(self, encoder, feature_dim=256, out_size=None):
        super().__init__()
        if out_size is None:
            out_size = image_size
        self.encoder = encoder
        self.encoder.vit.mlp_head = nn.Identity()
        self.decoder = nn.Linear(feature_dim, out_size * out_size)

    def forward(self, x):
        """Map a scan sequence [B, T, C, H, W] to one predicted scan [B, 1, H, W].

        ``H * W`` must equal ``out_size ** 2``.
        """
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. sliced or permuted
        # sequences, are accepted too.
        features = self.encoder.vit(x.reshape(B * T, C, H, W))  # [B*T, feature_dim]
        # One feature vector per sample: average over the time axis.
        features = features.reshape(B, T, -1).mean(dim=1)  # [B, feature_dim]
        return self.decoder(features).reshape(B, 1, H, W)

# --- Training Loops ---
def train_ssl_vit(model, dataloader, optimizer, criterion, epochs=5):
    """Train ``model`` to predict each scan's original timepoint index.

    Batches are ``(scans, perm)`` as produced by ``MRI3DShuffledDataset`` after
    collation: ``scans`` is [B, T, C, H, W] and ``perm`` is [B, T]. The time
    axis is folded into the batch, so every scan is classified independently
    against its ``perm`` entry. Prints the mean batch loss after each epoch.

    Returns:
        List of per-epoch mean losses (NaN for an epoch with no batches).
    """
    model.train()
    history = []
    for epoch in range(epochs):
        total_loss, num_batches = 0.0, 0
        for scans, perm in dataloader:
            scans, perm = scans.to(device), perm.to(device)
            # [B, T, C, H, W] -> [B*T, C, H, W]; flatten rather than a view with
            # a hard-coded image_size, so other resolutions also work.
            logits = model(scans.flatten(0, 1))
            loss = criterion(logits, perm.flatten())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Count batches instead of dividing by len(dataloader): this avoids
        # ZeroDivisionError on an empty loader.
        epoch_loss = total_loss / num_batches if num_batches else float("nan")
        history.append(epoch_loss)
        print(f"[SSL] Epoch {epoch+1} Loss: {epoch_loss:.4f}")
    return history

def train_generator(model, dataloader, optimizer, criterion, epochs=5):
    """Train ``model`` to regress the target scan from the preceding scans.

    Batches are ``(x_seq, x_target)`` pairs, as yielded by
    ``MRI3DTemporalDataset``. ``criterion(pred, x_target)`` is minimised with
    ``optimizer``, and the mean batch loss is printed after each epoch.

    Returns:
        List of per-epoch mean losses (NaN for an epoch with no batches).
    """
    model.train()
    history = []
    for epoch in range(epochs):
        total_loss, num_batches = 0.0, 0
        for x_seq, x_target in dataloader:
            x_seq, x_target = x_seq.to(device), x_target.to(device)
            pred = model(x_seq)
            loss = criterion(pred, x_target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Count batches instead of dividing by len(dataloader): this avoids
        # ZeroDivisionError on an empty loader.
        epoch_loss = total_loss / num_batches if num_batches else float("nan")
        history.append(epoch_loss)
        print(f"[GEN] Epoch {epoch+1} Loss: {epoch_loss:.4f}")
    return history

# --- Visualization and Evaluation ---
def _plot_real_vs_pred(real_img, pred_img):
    """Show a real and a predicted single-channel image side by side."""
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    panels = ((real_img, "Real T4"), (pred_img, "Predicted T4"))
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(img.squeeze().cpu().numpy(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


def _print_eval_metrics(pred, real):
    """Print MS-SSIM, SSIM, LPIPS, MMD and coverage of ``pred`` against ``real``.

    Both are [B, C, H, W] tensors on the same device. ``pred`` is clamped to
    [0, 1] first because the metrics below are called with ``data_range=1.0``.
    NOTE(review): ``calculate_mmd`` and ``calculate_coverage`` are defined in a
    later notebook cell, which must have been executed before this runs.
    """
    # Local imports so these work even if the metrics cell has not run yet.
    import lpips
    from piq import multi_scale_ssim, ssim

    pred = pred.clamp(0, 1)
    print("MS-SSIM:", multi_scale_ssim(pred, real, data_range=1.0).item())
    print("SSIM:", ssim(pred, real, data_range=1.0).item())

    # LPIPS weights must live on the images' device. normalize=True tells LPIPS
    # the inputs are in [0, 1] rather than its native [-1, 1]. Grayscale is
    # replicated to the 3 channels AlexNet expects.
    def to_rgb(t):
        return t.repeat(1, 3, 1, 1) if t.size(1) == 1 else t

    lpips_fn = lpips.LPIPS(net='alex').to(real.device).eval()
    print("LPIPS:", lpips_fn(to_rgb(pred), to_rgb(real), normalize=True).mean().item())

    # One flattened vector per image. Moved to CPU because the distribution
    # metrics may hand the data to CPU-only libraries.
    real_flat = real.reshape(real.size(0), -1).cpu()
    pred_flat = pred.reshape(pred.size(0), -1).cpu()
    print("MMD:", calculate_mmd(real_flat, pred_flat))
    print("Coverage:", calculate_coverage(real_flat, pred_flat))


def evaluate_and_visualize(model, dataloader):
    """Plot and score the model's prediction on the first batch of ``dataloader``.

    Batches are ``(x_seq, x_target)`` pairs, as yielded by
    ``MRI3DTemporalDataset``. The first sample's real and predicted target are
    plotted, then metrics over the whole batch are printed. An empty loader
    does nothing.
    """
    model.eval()
    with torch.no_grad():
        batch = next(iter(dataloader), None)
        if batch is None:
            return
        x_seq, x_target = (t.to(device) for t in batch)
        pred = model(x_seq)
        _plot_real_vs_pred(x_target[0], pred[0])
        _print_eval_metrics(pred, x_target)

# --- Run Full Pipeline ---
def run(data_root="sample_mri_3d"):
    """Run the full pipeline: SSL pre-training, generator training, evaluation.

    Args:
        data_root: directory with one sub-folder per subject. The default is
            the path that was previously hard-coded.

    Side effects:
        - Writes the SSL encoder weights to ``ssl_vit.pth`` in the working
          directory.
        - Prints training progress.
        - Shows a matplotlib figure.

    NOTE(review): evaluation reuses the training loader, so the reported
    metrics are on training subjects, not held-out data.
    """
    # Only directories are subjects; stray files (e.g. .DS_Store) would crash
    # loading. Sorted so the subject list is deterministic.
    subject_paths = sorted(
        os.path.join(data_root, d)
        for d in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, d))
    )

    # --- SSL Phase ---
    ssl_dataset = MRI3DShuffledDataset(subject_paths)
    ssl_loader = DataLoader(ssl_dataset, batch_size=2, shuffle=True)
    encoder = ScanOrderViT().to(device)
    ssl_optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)
    ssl_criterion = nn.CrossEntropyLoss()
    train_ssl_vit(encoder, ssl_loader, ssl_optimizer, ssl_criterion, epochs=5)
    torch.save(encoder.state_dict(), "ssl_vit.pth")

    # --- Generation Phase ---
    print("\nTraining generator to predict the 5th scan...")
    # Reload the checkpoint just written (round-trips the saved weights).
    encoder.load_state_dict(torch.load("ssl_vit.pth", map_location=device))
    generator = TemporalScanPredictor(encoder).to(device)
    gen_dataset = MRI3DTemporalDataset(subject_paths)
    gen_loader = DataLoader(gen_dataset, batch_size=2, shuffle=True)
    gen_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
    gen_criterion = nn.MSELoss()
    train_generator(generator, gen_loader, gen_optimizer, gen_criterion, epochs=5)

    # --- Evaluate and Visualize ---
    evaluate_and_visualize(generator, gen_loader)

# Guarded so importing this file does not start training; in a Jupyter kernel
# __name__ is "__main__", so the cell still runs the pipeline.
if __name__ == "__main__":
    run()


[SSL] Epoch 1 Loss: 1.6536
[SSL] Epoch 2 Loss: 1.6184
[SSL] Epoch 3 Loss: 1.6099
[SSL] Epoch 4 Loss: 1.6015
[SSL] Epoch 5 Loss: 1.5634

Training generator to predict the 5th scan...


### Evaluation Metrics

In [19]:
# Install required packages (shown for reference only; not executed in this environment)
# !pip install torch torchvision numpy scipy scikit-learn matplotlib pillow vit-pytorch lpips piq torchmetrics torch-fidelity


from torch.nn.functional import adaptive_avg_pool2d
from torchvision.models import inception_v3
from torchvision import transforms
from piq import ssim, multi_scale_ssim
import lpips 
from torchmetrics.image.kid import KernelInceptionDistance
from typing import Tuple


def calculate_fid(real_features: torch.Tensor, gen_features: torch.Tensor) -> float:
    """Frechet distance between Gaussians fitted to two sets of feature vectors.

    FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2))

    All arithmetic is done in float64 NumPy. This avoids mixing torch tensors
    with SciPy results, and it accepts CPU or CUDA tensors as well as
    array-likes.

    Args:
        real_features: [N_real, D] features of real samples.
        gen_features: [N_gen, D] features of generated samples.

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: if inputs are not 2-D with matching D, or either set has
            fewer than 2 samples (a covariance needs at least 2).
    """
    def to_f64(x):
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        return np.asarray(x, dtype=np.float64)

    real = to_f64(real_features)
    gen = to_f64(gen_features)
    if real.ndim != 2 or gen.ndim != 2 or real.shape[1] != gen.shape[1]:
        raise ValueError("expected 2-D feature matrices with the same number of columns")
    if len(real) < 2 or len(gen) < 2:
        raise ValueError("need at least 2 samples in each set to estimate a covariance")

    mu1, mu2 = real.mean(axis=0), gen.mean(axis=0)
    # atleast_2d: np.cov collapses a single feature column to a 0-d array.
    sigma1 = np.atleast_2d(np.cov(real, rowvar=False))
    sigma2 = np.atleast_2d(np.cov(gen, rowvar=False))

    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Near-singular product: nudge both covariances off zero and retry.
        offset = np.eye(sigma1.shape[0]) * 1e-6
        covmean = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset))
    if np.iscomplexobj(covmean):
        # Tiny imaginary parts come from numerical error; keep the real part.
        covmean = covmean.real

    diff = mu1 - mu2
    fid = diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return float(fid)


def calculate_kid(real_images: torch.Tensor, gen_images: torch.Tensor, subset_size: int = 50) -> Tuple[float, float]:
    """Kernel Inception Distance (mean, std) between two image batches.

    Args:
        real_images: [N, C, H, W] real images.
        gen_images: [M, C, H, W] generated images.
            Input range depends on dtype:
            - uint8 images use torchmetrics' default [0, 255] convention.
            - Floating-point images are assumed to lie in [0, 1] and are
              passed with ``normalize=True``.
            Single-channel images are replicated to the 3 channels Inception
            expects.
        subset_size: samples per KID subset. It is capped at the smaller set
            size, since torchmetrics rejects a subset larger than the sample
            count.

    Returns:
        ``(kid_mean, kid_std)`` as Python floats.

    Raises:
        ValueError: if either set has fewer than 2 images.
    """
    def to_rgb(x):
        return x.repeat(1, 3, 1, 1) if x.size(1) == 1 else x

    n = min(len(real_images), len(gen_images))
    if n < 2:
        raise ValueError("KID needs at least 2 images in each set")
    kid = KernelInceptionDistance(
        subset_size=min(subset_size, n),
        normalize=real_images.is_floating_point(),
    ).to(real_images.device)
    kid.update(to_rgb(real_images), real=True)
    kid.update(to_rgb(gen_images), real=False)
    kid_mean, kid_std = kid.compute()
    return float(kid_mean), float(kid_std)


def calculate_ms_ssim(real_images: torch.Tensor, gen_images: torch.Tensor) -> float:
    """Return the mean multi-scale SSIM of ``gen_images`` against ``real_images``.

    Both batches are scored with ``data_range=1.0``, i.e. values in [0, 1].
    """
    score = multi_scale_ssim(gen_images, real_images, data_range=1.0)
    return score.mean().item()


def calculate_4gr_ssim(real_images: torch.Tensor, gen_images: torch.Tensor) -> float:
    """Return the mean SSIM of ``gen_images`` against ``real_images`` (data range 1.0).

    NOTE: despite the name, this is currently plain single-scale SSIM from
    ``piq``. No gradient-based or four-component weighting is applied; extend
    it here if a true 4-G-r-SSIM is needed.
    """
    score = ssim(gen_images, real_images, data_range=1.0)
    return score.mean().item()

def calculate_mmd(x: torch.Tensor, y: torch.Tensor, kernel: str = "rbf", sigma: float = 1.0) -> float:
    """Biased (V-statistic) estimate of squared Maximum Mean Discrepancy.

    MMD^2 = mean k(x, x') + mean k(y, y') - 2 mean k(x, y). Each mean runs
    over all pairs, including i == j.

    Args:
        x: [N, D] samples.
        y: [M, D] samples with the same D.
        kernel: only "rbf" is implemented.
        sigma: RBF bandwidth, k(a, b) = exp(-||a - b||^2 / (2 sigma^2)).

    Returns:
        MMD^2 as a Python float.

    Raises:
        ValueError: if ``kernel`` is not "rbf" (previously it was silently ignored).
    """
    if kernel != "rbf":
        raise ValueError(f"Unsupported kernel {kernel!r}; only 'rbf' is implemented")

    def rbf_kernel(a, b, sigma):
        a_norm = (a ** 2).sum(1).view(-1, 1)
        b_norm = (b ** 2).sum(1).view(1, -1)
        # The expansion ||a||^2 + ||b||^2 - 2ab can dip slightly below zero
        # from rounding; clamp so kernel values never exceed 1.
        dist = (a_norm + b_norm - 2 * torch.mm(a, b.T)).clamp_min(0)
        return torch.exp(-dist / (2 * sigma ** 2))

    k_xx = rbf_kernel(x, x, sigma).mean()
    k_yy = rbf_kernel(y, y, sigma).mean()
    k_xy = rbf_kernel(x, y, sigma).mean()
    return float(k_xx + k_yy - 2 * k_xy)


def calculate_coverage(real_features: torch.Tensor, gen_features: torch.Tensor, nearest_k: int = 5) -> float:
    """Coverage (Naeem et al., 2020): fraction of real samples reached by generated ones.

    For each real sample, the radius is its distance to its ``nearest_k``-th
    nearest *other* real sample. A real sample is covered when at least one
    generated sample lies strictly inside that radius. The result is in
    [0, 1]; 1.0 means every real neighbourhood contains a generated sample.

    (The previous version thresholded gen-to-real distances at their own 95th
    percentile. That always counted ~95% of generated samples, so the result
    did not depend on sample quality.)

    Args:
        real_features: [N_real, D] real feature vectors (tensor on any device,
            or array-like).
        gen_features: [N_gen, D] generated feature vectors.
        nearest_k: neighbourhood size; capped at N_real - 1.

    Raises:
        ValueError: for non-2-D or mismatched inputs, fewer than 2 real
            samples, no generated samples, or ``nearest_k < 1``.
    """
    real = torch.as_tensor(real_features).detach().to("cpu", torch.float64)
    gen = torch.as_tensor(gen_features).detach().to("cpu", torch.float64)
    if real.ndim != 2 or gen.ndim != 2 or real.shape[1] != gen.shape[1]:
        raise ValueError("expected 2-D feature matrices with the same number of columns")
    if real.shape[0] < 2 or gen.shape[0] == 0:
        raise ValueError("need at least 2 real samples and 1 generated sample")
    if nearest_k < 1:
        raise ValueError(f"nearest_k must be >= 1, got {nearest_k}")

    k = min(nearest_k, real.shape[0] - 1)
    # Exact pairwise distances, so each sample's distance to itself is exactly 0.
    mode = "donot_use_mm_for_euclid_dist"
    real_dists = torch.cdist(real, real, compute_mode=mode)
    # The smallest entry of each row is the sample to itself, so the k-th
    # *other* neighbour is the (k+1)-th smallest distance.
    radii = real_dists.kthvalue(k + 1, dim=1).values
    nearest_gen = torch.cdist(real, gen, compute_mode=mode).min(dim=1).values
    return float((nearest_gen < radii).double().mean())


def calculate_lpips(real_images: torch.Tensor, gen_images: torch.Tensor, normalize: bool = False) -> float:
    """Mean LPIPS (AlexNet) distance between two image batches.

    Args:
        real_images: [N, 3, H, W] reference images.
        gen_images: [N, 3, H, W] images to compare.
        normalize: set True when images are in [0, 1]; LPIPS then rescales
            them to its native [-1, 1] range. The default ``False`` keeps the
            previous behaviour (inputs assumed already in [-1, 1]).

    Returns:
        Mean LPIPS distance as a Python float.
    """
    # Build the network on the images' device instead of a hard-coded
    # .cuda(), which fails on CPU-only machines.
    loss_fn = lpips.LPIPS(net='alex').to(real_images.device).eval()
    with torch.no_grad():
        return loss_fn(gen_images, real_images, normalize=normalize).mean().item()

