### Complete PET-to-CT Translation Pipeline
 **Architecture**: ResNet-34 Encoder + ViT Bottleneck + CNN Decoder  
 **Features**:
 - TCIA API download
 - NPY/PNG preprocessing (7GB storage)
 - Mixed precision training
 - Multi-scale SSIM loss
 - Model checkpointing

### 0. Install Dependencies

In [None]:
%pip install pydicom numpy pillow tqdm requests torch torchvision pytorch-msssim einops kaggle scikit-learn tensorboard pyyaml --quiet

In [None]:
import io
import os
import zipfile
from multiprocessing import Pool

import matplotlib.pyplot as plt
import numpy as np
import pydicom
import requests
import torch
import torch.nn as nn
import yaml
from einops import rearrange
from pytorch_msssim import MS_SSIM, SSIM, ms_ssim
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, transforms
from tqdm import tqdm


#### 1. Configuration Management

In [None]:
class Config:
    """Plain attribute bag holding the settings of one training run.

    Every setting is an ordinary instance attribute, which lets ``save`` and
    ``load`` round-trip the whole object through a YAML mapping.
    """

    def __init__(self):
        # Hyperparameters
        hyperparameters = {
            "batch_size": 16,
            "lr_g": 2e-4,
            "lr_d": 1e-4,
            "beta1": 0.5,
            "beta2": 0.999,
            "lambda_l1": 100,
            "lambda_ms_ssim": 1,
            "lambda_vgg": 0.1,
            "lambda_gp": 10,
        }
        # Training
        schedule = {"epochs": 100, "patience": 10}
        # Paths
        locations = {
            "tb_log_dir": "logs/exp1",
            "model_dir": "saved_models",
            "processed_dir": "/content/QIN-Breast_PROCESSED",
        }

        for group in (hyperparameters, schedule, locations):
            for attr, default in group.items():
                setattr(self, attr, default)

    def save(self, path):
        """Write all attributes of this object to ``path`` as YAML."""
        with open(path, 'w') as stream:
            yaml.dump(vars(self), stream)

    def load(self, path):
        """Update this object's attributes from the YAML mapping at ``path``."""
        with open(path, 'r') as stream:
            stored = yaml.safe_load(stream)
        vars(self).update(stored)


config = Config()


#### 1b. Download QIN-Breast from TCIA

In [None]:
def download_qin_breast(api_key, save_dir="/content/QIN-Breast_RAW", timeout=60):
    """Download the DICOM images of the QIN-Breast collection via the TCIA API.

    Each image is written to ``<save_dir>/<PatientID>/<ImageInstanceUID>.dcm``.

    Args:
        api_key: Value sent as the ``username`` query parameter of the token
            request.
        save_dir: Root directory for the downloaded files; created if missing.
        timeout: Seconds to wait on each HTTP request before giving up, so a
            stalled connection cannot hang the download forever.

    Raises:
        requests.HTTPError: If the token, series-list or image-list request
            returns an error status. A failed download of a single image is
            reported and skipped instead, so one bad file does not discard a
            long-running download.

    NOTE(review): the endpoint paths and the token handshake are kept exactly
    as before; confirm them against the current NBIA API documentation.
    """
    base_url = "https://services.cancerimagingarchive.net/nbia-api/services"
    os.makedirs(save_dir, exist_ok=True)

    # One session reuses the TCP connection across the many small requests.
    with requests.Session() as session:
        token_resp = session.get(
            f"{base_url}/getToken", params={"username": api_key}, timeout=timeout
        )
        token_resp.raise_for_status()
        token = token_resp.text.strip('"')
        session.headers["Authorization"] = f"Bearer {token}"

        # Get list of series
        series_resp = session.get(
            f"{base_url}/getSeries", params={"Collection": "QIN-Breast"}, timeout=timeout
        )
        series_resp.raise_for_status()
        series_data = series_resp.json()

        failed = 0
        for series in tqdm(series_data, desc="Downloading"):
            series_uid = series["SeriesInstanceUID"]
            patient_dir = os.path.join(save_dir, series["PatientID"])
            os.makedirs(patient_dir, exist_ok=True)

            images_resp = session.get(
                f"{base_url}/getImage",
                params={"SeriesInstanceUID": series_uid},
                timeout=timeout,
            )
            images_resp.raise_for_status()

            # Download each DICOM
            for img in images_resp.json():
                image_uid = img["ImageInstanceUID"]
                img_resp = session.get(
                    f"{base_url}/getImage",
                    params={"SeriesInstanceUID": series_uid, "ImageInstanceUID": image_uid},
                    timeout=timeout,
                )
                if not img_resp.ok:
                    # Never store an error page under a .dcm name.
                    failed += 1
                    print(f"Skipping image {image_uid}: HTTP {img_resp.status_code}")
                    continue
                with open(os.path.join(patient_dir, f"{image_uid}.dcm"), "wb") as f:
                    f.write(img_resp.content)

    if failed:
        print(f"{failed} image(s) could not be downloaded")

#### 2. Preprocess to NPY/PNG

In [None]:
def _normalize_pixels(img, modality):
    """Scale a float pixel array according to the DICOM modality string."""
    if "CT" in modality:
        lowest = img.min()
        span = img.max() - lowest
        if span > 0:
            return (img - lowest) / span  # [0,1]
        # A constant slice has no dynamic range; avoid 0/0 -> NaN.
        return np.zeros_like(img)
    if "PT" in modality:
        # NOTE(review): a fixed +1000 / 2000 window resembles a CT Hounsfield
        # mapping more than an SUV conversion, and the two modalities may be
        # swapped here - confirm the intended scaling before training.
        return (img + 1000) / 2000
    return img


def process_dicom_file(args):
    """Convert one DICOM file to a normalized float32 ``.npy`` array.

    Args:
        args: ``(dicom_path, output_dir)`` tuple (a single argument so the
            function can be used with ``Pool.imap``).

    Returns:
        True if the array was written, False if anything went wrong; the error
        is printed instead of raised so one bad file does not stop a batch.

    The output file is named ``<Modality>_<PatientID>_<SOPInstanceUID>.npy``.
    Modalities other than CT/PT are saved without rescaling.
    """
    dicom_path, output_dir = args
    try:
        dicom = pydicom.dcmread(dicom_path)
        img = _normalize_pixels(dicom.pixel_array.astype(np.float32), dicom.Modality)

        # Save as NPY
        out_name = f"{dicom.Modality}_{dicom.PatientID}_{dicom.SOPInstanceUID}.npy"
        np.save(os.path.join(output_dir, out_name), img)
        return True
    except Exception as e:
        # Deliberately broad: this runs in worker processes over many files.
        print(f"Error processing {dicom_path}: {e}")
        return False

In [None]:
def preprocess_dataset(raw_dir="/content/QIN-Breast_RAW",
                      processed_dir="/content/QIN-Breast_PROCESSED",
                      num_workers=4):
    """Convert every DICOM file under ``raw_dir`` to NPY, in parallel.

    Args:
        raw_dir: Directory tree that is searched recursively for ``.dcm``
            files (extension matched case-insensitively).
        processed_dir: Output directory for the ``.npy`` files; created if
            missing.
        num_workers: Number of worker processes used for the conversion.

    Prints a summary of how many files were converted successfully.
    """
    os.makedirs(processed_dir, exist_ok=True)

    # Sorted so the processing order does not depend on filesystem order.
    dicom_files = sorted(
        os.path.join(root, name)
        for root, _, names in os.walk(raw_dir)
        for name in names
        if name.lower().endswith(".dcm")
    )

    results = []
    if dicom_files:  # do not spawn worker processes when there is nothing to do
        jobs = [(path, processed_dir) for path in dicom_files]
        # Process in parallel
        with Pool(num_workers) as pool:
            results = list(tqdm(
                pool.imap(process_dicom_file, jobs),
                total=len(jobs),
                desc="Preprocessing"
            ))

    print(f"Successfully processed {sum(results)}/{len(dicom_files)} files")

In [None]:
# Check a sample
# The directory comes from the shared config object; a bare `processed_dir`
# name does not exist at module level.
if os.path.isdir(config.processed_dir):
    npy_names = sorted(f for f in os.listdir(config.processed_dir) if f.endswith(".npy"))
else:
    npy_names = []

if npy_names:
    sample_npy = os.path.join(config.processed_dir, npy_names[0])
    sample = np.load(sample_npy)
    print(f"Shape: {sample.shape}, Range: [{sample.min():.2f}, {sample.max():.2f}]")
    plt.imshow(sample, cmap='gray')
    plt.show()
else:
    print(f"No .npy files found in {config.processed_dir}; run preprocess_dataset() first")

#### 3. Dataset splitting and Loader

In [None]:
def get_patient_splits(processed_dir, test_size=0.15, val_size=0.15):
    """Split patients (not slices) into train/val/test to prevent data leakage.

    Patient IDs are taken from the PET filenames in ``processed_dir``, which
    follow the pattern ``Modality_PatientID_UID.npy``.

    Args:
        processed_dir: Directory containing the preprocessed ``.npy`` files.
        test_size: Fraction of all patients assigned to the test split.
        val_size: Fraction of all patients assigned to the validation split.

    Returns:
        ``(train_ids, val_ids, test_ids)`` lists of patient ID strings.

    Raises:
        ValueError: Propagated from ``train_test_split`` when there are too
            few patients for the requested fractions.
    """
    # Sorted so the seeded split is reproducible: iterating a set of strings
    # can yield a different order in every interpreter session.
    patient_ids = sorted({
        name.split('_')[1]
        for name in os.listdir(processed_dir)
        if name.startswith("PT_")
    })

    # Split: Train -> Val/Test
    train_ids, test_ids = train_test_split(
        patient_ids, test_size=test_size, random_state=42
    )
    # val_size is a fraction of ALL patients, so rescale it to the remainder.
    train_ids, val_ids = train_test_split(
        train_ids, test_size=val_size / (1 - test_size), random_state=42
    )

    return train_ids, val_ids, test_ids

In [None]:
class QinBreastDataset(Dataset):
    """Paired PET/CT slices stored as ``.npy`` files in a single directory.

    A PET file ``PT_<PatientID>_<UID>.npy`` is paired with the CT file that
    has the same name with the ``CT_`` prefix; PET files without such a
    partner are dropped.

    NOTE(review): the preprocessing step names files after SOPInstanceUID.
    If PET and CT slices never share a UID, no pair will be found - confirm
    the pairing key.

    Args:
        root_dir: Directory containing the ``.npy`` files.
        patient_ids: Optional iterable of patient IDs to keep. ``None`` keeps
            every patient; an empty iterable keeps none.
        transform: Optional callable applied to the PET and the CT array
            separately.
    """

    def __init__(self, root_dir, patient_ids=None, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Get PET files filtered by patient IDs (sorted for a stable order)
        pet_files = sorted(f for f in os.listdir(root_dir) if f.startswith("PT_"))
        # `is not None` matters: an empty split must yield an empty dataset,
        # not silently fall back to all patients.
        if patient_ids is not None:
            wanted = set(patient_ids)
            pet_files = [f for f in pet_files if f.split('_')[1] in wanted]

        # Create verified pairs
        self.pairs = []
        for pet_file in pet_files:
            # Swap only the leading modality prefix; the patient ID or UID
            # could itself contain "PT_".
            ct_file = pet_file.replace("PT_", "CT_", 1)
            if os.path.exists(os.path.join(root_dir, ct_file)):
                self.pairs.append((pet_file, ct_file))

    def __len__(self):
        """Return the number of verified PET/CT pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(pet, ct)``, transformed if requested."""
        pet_file, ct_file = self.pairs[idx]
        pet = np.load(os.path.join(self.root_dir, pet_file))
        ct = np.load(os.path.join(self.root_dir, ct_file))

        if self.transform is not None:
            pet = self.transform(pet)
            ct = self.transform(ct)

        return pet, ct

#### 4. Model Architecture

In [None]:
# %% [code]
# ======================

class ViTBlock(nn.Module):
    """Transformer encoder block with post-normalisation.

    Self-attention and a two-layer GELU MLP are each wrapped in a residual
    connection that is followed by LayerNorm. ``nn.MultiheadAttention`` is
    created with its default ``batch_first=False``, so the input is laid out
    as (sequence, batch, dim).
    """

    def __init__(self, dim=512, heads=8, dropout=0.1):
        super().__init__()
        hidden_dim = dim * 4  # MLP expansion width
        self.attention = nn.MultiheadAttention(dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        feed_forward = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Apply attention and MLP sub-layers; the output has the shape of ``x``."""
        # Only the attended values are needed, not the attention weights.
        attended = self.attention(x, x, x)[0]
        hidden = self.norm1(x + attended)
        return self.norm2(hidden + self.mlp(hidden))

class Generator(nn.Module):
    """PET-to-CT generator: ResNet-34 encoder, ViT bottleneck, CNN decoder.

    The encoder is a torchvision ResNet-34 whose stem convolution is replaced
    by a single-channel one and whose average-pool and fc layers are dropped.
    It reduces the spatial size by a factor of 32 (stem conv, max-pool and
    three stride-2 stages), so the decoder uses five stride-2 transposed
    convolutions to return to the input resolution. The previous three-stage
    decoder produced an output only a quarter of the input size, which cannot
    be compared against the target image.

    Input: (batch, 1, H, W); H and W should be multiples of 32 so the output
    size equals the input size. Output: (batch, 1, H, W) in [-1, 1] (Tanh).

    NOTE(review): the decoder layout differs from the three-stage version,
    so decoder weights from older checkpoints will not load.
    NOTE(review): the preprocessing scales CT slices to [0, 1] while this
    network ends in Tanh - confirm the intended target range.

    Args:
        pretrained: Load ImageNet weights for the ResNet-34 stages. The new
            one-channel stem convolution is always randomly initialised.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Encoder (ResNet-34)
        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=`; kept as-is for compatibility.
        resnet = models.resnet34(pretrained=pretrained)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            # Skip the RGB stem conv ([0]) and the avgpool + fc head ([-2:]).
            *list(resnet.children())[1:-2]
        )

        # ViT Bottleneck
        self.vit = nn.Sequential(
            ViTBlock(dim=512),
            ViTBlock(dim=512),
            ViTBlock(dim=512)
        )

        # Decoder: five x2 upsampling steps undo the encoder's /32.
        self.decoder = nn.Sequential(
            *self._up_block(512, 256),
            *self._up_block(256, 128),
            *self._up_block(128, 64),
            *self._up_block(64, 32),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    @staticmethod
    def _up_block(in_ch, out_ch):
        """Return the layers of one x2 upsampling stage."""
        return [
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(out_ch),
            nn.ReLU(True),
        ]

    def forward(self, x):
        """Translate a batch of PET slices into synthetic CT slices."""
        x = self.encoder(x)
        b, c, h, w = x.shape
        # Flatten the feature map into a token sequence (seq, batch, dim).
        x = rearrange(x, 'b c h w -> (h w) b c')
        x = self.vit(x)
        x = rearrange(x, '(h w) b c -> b c h w', h=h, w=w)
        return self.decoder(x)


In [None]:
class MultiScaleDiscriminator(nn.Module):
    """Three small spectral-normalised critics applied at successive half scales.

    The first critic sees the input as given, the second a bilinearly
    half-sized copy, the third a quarter-sized copy. Each critic pools its
    score map to a single value, so the result has shape (batch, 3, 1, 1).
    """

    def __init__(self, input_channels=1):
        super().__init__()
        base_widths = (64, 32, 16)  # one critic per scale
        self.discriminators = nn.ModuleList(
            [self._make_discriminator(input_channels, width) for width in base_widths]
        )

    def _make_discriminator(self, in_ch, base_ch):
        """Build one critic: two stride-2 convs, a scoring conv and global pooling."""
        def sn_conv(c_in, c_out, stride):
            conv = nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1)
            return nn.utils.spectral_norm(conv)

        wide_ch = base_ch * 2
        layers = [
            sn_conv(in_ch, base_ch, 2),
            nn.LeakyReLU(0.2),
            sn_conv(base_ch, wide_ch, 2),
            nn.InstanceNorm2d(wide_ch),
            nn.LeakyReLU(0.2),
            sn_conv(wide_ch, 1, 1),
            nn.AdaptiveAvgPool2d(1),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x`` at every scale and concatenate the scores on dim 1."""
        scores = []
        scaled = x
        for critic in self.discriminators:
            scores.append(critic(scaled))
            # Halve the resolution for the next critic.
            scaled = nn.functional.interpolate(scaled, scale_factor=0.5, mode='bilinear')
        return torch.cat(scores, dim=1)

#### 5. Training Utilities

In [None]:
class TotalLoss(nn.Module):
    """Weighted sum of reconstruction, adversarial and gradient-penalty terms.

    ``total = lambda_l1 * L1 + lambda_ms_ssim * (1 - MS-SSIM)
              + lambda_vgg * VGG + lambda_gp * (adversarial + gradient penalty)``

    The default weights reproduce the values that used to be hard-coded.

    Args:
        lambda_l1: Weight of the L1 reconstruction term.
        lambda_ms_ssim: Weight of the ``1 - MS-SSIM`` term.
        lambda_vgg: Weight of the perceptual (VGG) term.
        lambda_gp: Weight applied to the sum of the adversarial term and the
            gradient penalty.
        vgg_loss: Optional module computing the perceptual term; when omitted
            ``VGGLoss()`` is instantiated.
    """

    def __init__(self, lambda_l1=100, lambda_ms_ssim=1, lambda_vgg=0.1,
                 lambda_gp=10, vgg_loss=None):
        super().__init__()
        self.lambda_l1 = lambda_l1
        self.lambda_ms_ssim = lambda_ms_ssim
        self.lambda_vgg = lambda_vgg
        self.lambda_gp = lambda_gp

        self.l1 = nn.L1Loss()
        # NOTE(review): VGGLoss is not defined or imported in the visible part
        # of this file - make sure it is available, or pass `vgg_loss`.
        self.vgg = VGGLoss() if vgg_loss is None else vgg_loss
        # NOTE(review): data_range=1.0 assumes images in [0, 1], while the
        # generator ends in Tanh ([-1, 1]) - confirm the expected range.
        self.ms_ssim = MS_SSIM(data_range=1.0, channel=1)

    def forward(self, gen_ct, real_ct, D_real, D_fake, D):
        """Return the scalar total loss.

        Args:
            gen_ct: Generated CT batch.
            real_ct: Ground-truth CT batch.
            D_real: Discriminator output on real images (accepted for
                interface compatibility; not used by any term).
            D_fake: Discriminator output on ``gen_ct``.
            D: The discriminator, needed for the gradient penalty.
        """
        # Reconstruction Losses
        l1_loss = self.l1(gen_ct, real_ct)
        ms_ssim_loss = 1 - self.ms_ssim(gen_ct, real_ct)
        vgg_loss = self.vgg(gen_ct, real_ct)

        # Adversarial Loss
        adv_loss = -torch.mean(D_fake)

        # Gradient Penalty (gen_ct is detached: the term only constrains D)
        gp = self._gradient_penalty(D, real_ct, gen_ct.detach())

        return (self.lambda_l1 * l1_loss
                + self.lambda_ms_ssim * ms_ssim_loss
                + self.lambda_vgg * vgg_loss
                + self.lambda_gp * (adv_loss + gp))

    def _gradient_penalty(self, D, real, fake):
        """WGAN-GP penalty: mean of ``(||grad_x D(x)||_2 - 1)^2`` over the batch.

        ``x`` is a random per-sample interpolation between ``real`` and
        ``fake``. The gradient norm is taken over all non-batch dimensions of
        each sample. The penalty is also computable while autograd is
        disabled (e.g. during validation); in that case it is a plain value
        without a graph.
        """
        # Only build a differentiable penalty if the caller can backprop it.
        track_graph = torch.is_grad_enabled()
        alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)

        # The input gradient is needed even if the caller is inside no_grad().
        with torch.enable_grad():
            interpolates = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)
            d_interpolates = D(interpolates)
            gradients = torch.autograd.grad(
                outputs=d_interpolates,
                inputs=interpolates,
                grad_outputs=torch.ones_like(d_interpolates),
                create_graph=track_graph,
                retain_graph=track_graph,
            )[0]

        # Flatten first: norm(dim=1) on a (B, C, H, W) tensor would give a
        # per-pixel norm across channels, not one norm per sample.
        per_sample_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
        return ((per_sample_norm - 1) ** 2).mean()

def psnr(output, target):
    """Peak signal-to-noise ratio (dB) of two tensors normalized to [-1, 1].

    Both tensors are first mapped to [0, 1], so the peak value is 1.0. The
    error is averaged over every element of the batch.
    """
    # [-1,1] → [0,1]
    output_unit = output.add(1).div(2)
    target_unit = target.add(1).div(2)
    rmse = (output_unit - target_unit).pow(2).mean().sqrt()
    return 20 * torch.log10(1.0 / rmse)

In [None]:
class TensorBoardLogger:
    """Thin convenience wrapper around a ``SummaryWriter``."""

    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def log_scalar(self, tag, value, step):
        """Record a single scalar ``value`` under ``tag`` at ``step``."""
        self.writer.add_scalar(tag, value, global_step=step)

    def log_images(self, tag, images, step):
        """Record a batch of ``images`` under ``tag`` at ``step``."""
        self.writer.add_images(tag, images, global_step=step)

    def close(self):
        """Close the underlying writer."""
        self.writer.close()

def visualize_samples(generator, dataloader, device, num_samples=3):
    """Plot PET inputs, generated CTs and real CTs from the first batch.

    The figure has three rows (input, generated, real) and one column per
    sample. Channel 0 of each image is shown in grayscale.

    Args:
        generator: Model mapping a PET batch to a CT batch.
        dataloader: Loader yielding ``(pet, ct)`` batches; only the first
            batch is used.
        device: Device on which the generator is run.
        num_samples: Maximum number of samples to show; clamped to the batch
            size so a short batch does not raise IndexError.

    The generator is switched to eval mode for inference and put back into
    the mode it was in before the call.
    """
    was_training = generator.training
    generator.eval()
    try:
        pet_batch, ct_batch = next(iter(dataloader))
        with torch.no_grad():
            fake_ct = generator(pet_batch.to(device)).cpu()
    finally:
        generator.train(was_training)

    shown = min(num_samples, pet_batch.size(0))
    rows = (
        ("PET Input", pet_batch),
        ("Generated CT", fake_ct),
        ("Real CT", ct_batch),
    )

    plt.figure(figsize=(15, 5))
    for row, (title, batch) in enumerate(rows):
        for col in range(shown):
            plt.subplot(3, shown, row * shown + col + 1)
            plt.imshow(batch[col][0], cmap='gray')
            plt.title(title)
            plt.axis('off')
    plt.tight_layout()
    plt.show()

#### 6. Main Training Loop

In [None]:
def _build_loaders():
    """Create the patient-wise train/val/test DataLoaders from the global config."""
    # ToTensor turns each 2-D float array into a (1, H, W) tensor, the
    # channel-first layout the models' Conv2d(1, ...) layers expect.
    to_tensor = transforms.ToTensor()
    splits = get_patient_splits(config.processed_dir)
    loaders = []
    for ids, shuffle in zip(splits, (True, False, False)):
        dataset = QinBreastDataset(config.processed_dir, ids, transform=to_tensor)
        loaders.append(DataLoader(dataset, batch_size=config.batch_size, shuffle=shuffle))
    return loaders


def _train_one_epoch(G, D, criterion, opt_G, opt_D, scaler, loader, device, epoch):
    """Run one epoch of alternating D/G updates; return (mean G loss, mean D loss)."""
    G.train()
    D.train()
    losses_g, losses_d = [], []

    for pet, ct in tqdm(loader, desc=f"Train Epoch {epoch}"):
        pet, ct = pet.to(device), ct.to(device)

        # Discriminator Update (WGAN-GP critic loss)
        opt_D.zero_grad()
        with autocast():
            fake_ct = G(pet)
            D_real = D(ct)
            D_fake = D(fake_ct.detach())
            # NOTE(review): under fp16 autocast the unscaled gradient penalty
            # may underflow - see the AMP gradient-penalty recipe if unstable.
            gp = criterion._gradient_penalty(D, ct, fake_ct.detach())
            loss_D = D_fake.mean() - D_real.mean() + config.lambda_gp * gp
        scaler.scale(loss_D).backward()
        scaler.step(opt_D)
        losses_d.append(loss_D.item())

        # Generator Update (re-score the same fakes with the updated critic)
        opt_G.zero_grad()
        with autocast():
            D_fake = D(fake_ct)
            loss_G = criterion(fake_ct, ct, D_real.detach(), D_fake, D)
        scaler.scale(loss_G).backward()
        scaler.step(opt_G)
        # One scaler serves both optimizers, so update once per iteration.
        scaler.update()
        losses_g.append(loss_G.item())

    return np.mean(losses_g), np.mean(losses_d)


def _validate(G, D, criterion, loader, device):
    """Return (mean loss, mean PSNR, mean MS-SSIM) over the validation loader."""
    G.eval()
    D.eval()
    losses, psnrs, ssims = [], [], []

    for pet, ct in tqdm(loader, desc="Validating"):
        pet, ct = pet.to(device), ct.to(device)
        with torch.no_grad():
            fake_ct = G(pet)
            D_real, D_fake = D(ct), D(fake_ct)

        # Loss: evaluated outside no_grad because the criterion's gradient
        # penalty differentiates through D.
        loss = criterion(fake_ct, ct, D_real, D_fake, D)
        losses.append(loss.item())

        # Metrics
        with torch.no_grad():
            # psnr() rescales [-1,1] -> [0,1] itself; pass the raw tensors.
            psnrs.append(psnr(fake_ct, ct).item())
            # ms_ssim defaults to data_range=255; these inputs are in [0,1].
            ssims.append(ms_ssim((fake_ct + 1) / 2, (ct + 1) / 2, data_range=1.0).item())

    return np.mean(losses), np.mean(psnrs), np.mean(ssims)


def train():
    """Train the PET-to-CT GAN with early stopping and evaluate on the test split.

    All settings are read from the module-level ``config``. TotalLoss is
    built with its default weights; ``config.lambda_l1`` / ``lambda_ms_ssim``
    / ``lambda_vgg`` are not forwarded to it here.

    Returns:
        The metrics dictionary produced by ``evaluate`` on the test split,
        computed with the best generator checkpoint of this run (or the final
        weights if no checkpoint was written).
    """
    # Initialize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger = TensorBoardLogger(config.tb_log_dir)
    os.makedirs(config.model_dir, exist_ok=True)
    best_g_path = f"{config.model_dir}/best_G.pth"
    best_d_path = f"{config.model_dir}/best_D.pth"

    try:
        # Data
        train_loader, val_loader, test_loader = _build_loaders()

        # Model & Optimizers
        G = Generator().to(device)
        D = MultiScaleDiscriminator().to(device)
        opt_G = torch.optim.Adam(G.parameters(), lr=config.lr_g,
                                 betas=(config.beta1, config.beta2))
        opt_D = torch.optim.Adam(D.parameters(), lr=config.lr_d,
                                 betas=(config.beta1, config.beta2))
        criterion = TotalLoss().to(device)
        scaler = GradScaler()

        best_val_loss = float('inf')
        no_improve = 0
        saved_best = False
        last_epoch = 0

        for epoch in range(config.epochs):
            last_epoch = epoch

            # Training
            avg_train_g, avg_train_d = _train_one_epoch(
                G, D, criterion, opt_G, opt_D, scaler, train_loader, device, epoch
            )
            logger.log_scalar('Loss/Train_G', avg_train_g, epoch)
            logger.log_scalar('Loss/Train_D', avg_train_d, epoch)

            # Validation
            avg_val_loss, avg_val_psnr, avg_val_ssim = _validate(
                G, D, criterion, val_loader, device
            )
            logger.log_scalar('Loss/Val', avg_val_loss, epoch)
            logger.log_scalar('Metrics/Val_PSNR', avg_val_psnr, epoch)
            logger.log_scalar('Metrics/Val_SSIM', avg_val_ssim, epoch)

            # Early Stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                no_improve = 0
                torch.save(G.state_dict(), best_g_path)
                torch.save(D.state_dict(), best_d_path)
                saved_best = True
            else:
                no_improve += 1

            if no_improve >= config.patience:
                print(f"Early stopping at epoch {epoch}")
                break

            # Visualization
            if epoch % 5 == 0:
                visualize_samples(G, val_loader, device)

        # Final Test Evaluation
        if saved_best:  # never load a stale checkpoint from an earlier run
            G.load_state_dict(torch.load(best_g_path, map_location=device))
        test_metrics = evaluate(G, test_loader, device)
        # evaluate() returns (mean, std) per metric; log the mean.
        logger.log_scalar('Metrics/Test_PSNR', test_metrics['psnr'][0], last_epoch)
        logger.log_scalar('Metrics/Test_SSIM', test_metrics['ssim'][0], last_epoch)
    finally:
        # Flush and close the event file even if training fails midway.
        logger.close()

    return test_metrics

#### 6b. Comprehensive Evaluation

In [None]:
def evaluate(model, loader, device, lpips_model=None):
    """Compute PSNR, MS-SSIM, MAE and LPIPS of ``model`` over ``loader``.

    Each metric is computed once per batch; the result reports the mean and
    standard deviation across batches.

    Args:
        model: Generator mapping a PET batch to a CT batch.
        loader: Iterable of ``(pet, ct)`` batches.
        device: Device on which the model and metrics run.
        lpips_model: Optional perceptual-distance module called as
            ``lpips_model(fake, real)``. When omitted, ``LPIPS(net='vgg')``
            is used if that name is available.

    Returns:
        Dict mapping ``'psnr'``, ``'ssim'``, ``'mae'`` and ``'lpips'`` to a
        ``(mean, std)`` tuple of floats; ``(nan, nan)`` for a metric that
        produced no values.
    """
    model.eval()
    metrics = {
        'psnr': [],
        'ssim': [],
        'mae': [],
        'lpips': []
    }

    # Initialize LPIPS (Perceptual Metric)
    if lpips_model is None:
        try:
            lpips_model = LPIPS(net='vgg').to(device)
        except NameError:
            # NOTE(review): LPIPS is not imported anywhere in the visible part
            # of this file (it normally comes from the `lpips` package).
            print("LPIPS is not available; the 'lpips' metric will be reported as NaN")

    with torch.no_grad():
        for pet, ct in tqdm(loader, desc="Evaluating"):
            pet, ct = pet.to(device), ct.to(device)
            fake_ct = model(pet)

            # Convert to [0,1]
            fake_ct_norm = (fake_ct + 1) / 2
            ct_norm = (ct + 1) / 2

            # Calculate metrics
            # psnr() rescales [-1,1] -> [0,1] itself, so it gets the raw tensors.
            metrics['psnr'].append(psnr(fake_ct, ct).item())
            # ms_ssim defaults to data_range=255; these inputs are in [0,1].
            metrics['ssim'].append(ms_ssim(fake_ct_norm, ct_norm, data_range=1.0).item())
            metrics['mae'].append(torch.abs(fake_ct_norm - ct_norm).mean().item())
            if lpips_model is not None:
                # LPIPS yields one distance per sample; average before .item().
                # NOTE(review): the raw tensors are passed on the assumption
                # that they are in the [-1, 1] range LPIPS expects - confirm.
                metrics['lpips'].append(lpips_model(fake_ct, ct).mean().item())

    # Aggregate results
    def _summary(values):
        if not values:
            return (float('nan'), float('nan'))
        return (float(np.mean(values)), float(np.std(values)))

    return {name: _summary(values) for name, values in metrics.items()}


#### 7. Hyperparameter Tuning

In [None]:
def run_experiment(config_path):
    """Train with the settings stored in a YAML file and save the results.

    ``train()`` reads the module-level ``config``, so the loaded settings are
    installed there for the duration of the run; the previous config is put
    back afterwards, even if training fails.

    Args:
        config_path: Path to a YAML mapping of ``Config`` attribute overrides.

    Returns:
        The metrics returned by ``train()``. A copy converted to plain lists
        of floats is written to ``<tb_log_dir>/results.yaml``.
    """
    global config

    experiment = Config()
    experiment.load(config_path)

    # Run training with loaded config
    previous = config
    config = experiment
    try:
        results = train()
    finally:
        config = previous

    # Save results as plain floats so the YAML stays readable and portable.
    serializable = {
        name: [float(value) for value in np.atleast_1d(stats)]
        for name, stats in results.items()
    }
    os.makedirs(experiment.tb_log_dir, exist_ok=True)
    with open(f"{experiment.tb_log_dir}/results.yaml", 'w') as f:
        yaml.dump(serializable, f)

    return results

# Sample override file for run_experiment(): a YAML mapping of Config attributes
hyperparams = """
batch_size: 16
lr_g: 0.0002
lr_d: 0.0001
lambda_l1: 100
lambda_ms_ssim: 1
lambda_vgg: 0.1
"""

# Persist the sample next to the notebook so it can be passed by path.
with open("hp_config.yaml", 'w') as config_file:
    print(hyperparams, end="", file=config_file)

# Run multiple experiments
# for hp in hyperparam_grid:
#     run_experiment(hp)


#### 8. Execution & Visualization

In [None]:
if __name__ == "__main__":
    # TensorBoard can be started separately with: !tensorboard --logdir=logs/

    # Baseline run using the default configuration
    base_metrics = train()

    # Compare with different hyperparams
    # tuned_metrics = run_experiment("hp_config.yaml")

    # Report the headline metric as mean ± std
    print("\nModel Comparison:")
    base_psnr = base_metrics['psnr']
    psnr_mean, psnr_std = base_psnr[0], base_psnr[1]
    print(f"Base Model PSNR: {psnr_mean:.2f} ± {psnr_std:.2f}")
    # print(f"Tuned Model PSNR: {tuned_metrics['psnr'][0]:.2f} ± {tuned_metrics['psnr'][1]:.2f}")