### Complete PET-to-CT Translation Pipeline
 **Architecture**: ResNet-34 Encoder + ViT Bottleneck + CNN Decoder  
 **Features**:
 - TCIA API download
 - NPY/PNG preprocessing (7GB storage)
 - Mixed precision training
 - Multi-scale SSIM loss
 - Model checkpointing

### 0. Install Dependencies

In [2]:
%pip install pydicom numpy pillow tqdm requests torch torchvision pytorch-msssim einops kaggle --quiet

In [3]:
%pip install optuna




In [4]:
import os
import numpy as np
import pydicom
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm
from multiprocessing import Pool
from pytorch_msssim import MS_SSIM
from einops import rearrange
from torch.cuda.amp import autocast, GradScaler
import requests
import zipfile
import io
import random
from torch.utils.tensorboard import SummaryWriter
import optuna
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error


#### 1. Download QIN-Breast from TCIA

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Input directories for your original DICOM files
pet_dir = '/content/drive/MyDrive/PIX/PET'
ct_dir  = '/content/drive/MyDrive/PIX/CT'

# Output directories for processed .npy files
processed_pet_dir = '/content/drive/MyDrive/PIX/PET_P'
processed_ct_dir  = '/content/drive/MyDrive/PIX/CT_P'


In [5]:
"""import os
import requests
import zipfile

def download_qin_breast(destination="/content/QIN-Breast_RAW"):

    #Downloads a sample of QIN Breast DCE-MRI dataset using the TCIA API and extracts it.
    #Requires internet access.

    os.makedirs(destination, exist_ok=True)

    # TCIA base API
    BASE_URL = "https://services.cancerimagingarchive.net/services/v4/TCIA/query"
    COLLECTION = "QIN-Breast"

    # Step 1: List all studies in the collection
    study_url = f"{BASE_URL}/getSeries?Collection={COLLECTION}"
    response = requests.get(study_url)
    if not response.ok:
        raise Exception("Failed to fetch studies from TCIA")

    series_list = response.json()

    # Step 2: Pick first N series (for demo purposes)
    series_instances = [series["SeriesInstanceUID"] for series in series_list[:20]]


    for uid in series_instances:
        print(f"Downloading series: {uid}")
        download_url = f"https://services.cancerimagingarchive.net/services/v4/TCIA/query/getImage?SeriesInstanceUID={uid}"
        out_path = os.path.join(destination, f"{uid}.zip")

        # Download the ZIP
        with requests.get(download_url, stream=True) as r:
            r.raise_for_status()
            with open(out_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)

        # Extract the ZIP
        extract_dir = os.path.join(destination, uid)
        os.makedirs(extract_dir, exist_ok=True)
        with zipfile.ZipFile(out_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
        os.remove(out_path)

    print("Download complete and extracted to:", destination)


#### 2. Preprocess to NPY/PNG

In [None]:
import os
import pydicom
import numpy as np

def process_dicom_file(args):
    """
    Convert a DICOM (.dcm) file to a NumPy (.npy) file.

    Args:
        args (tuple): Contains:
                      - dicom_file (str): Full path to the DICOM file.
                      - output_dir (str): Directory where the .npy file will be saved.
    """
    dicom_file, output_dir = args

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    try:
        # Read the DICOM file
        ds = pydicom.dcmread(dicom_file)
        img_array = ds.pixel_array  # Extract pixel data

        # Generate output filename (change .dcm to .npy)
        base_name = os.path.basename(dicom_file)
        output_filename = os.path.splitext(base_name)[0] + '.npy'
        output_path = os.path.join(output_dir, output_filename)

        # Save the image array as a .npy file
        np.save(output_path, img_array)
        print(f"Converted '{dicom_file}' to '{output_path}'.")
    except Exception as e:
        print(f"Error processing {dicom_file}: {e}")

# Set your input directories (as provided) for PET and CT folders:
pet_dir = '/content/drive/MyDrive/PIX/PET'
ct_dir = '/content/drive/MyDrive/PIX/CT'

# Set your desired output directories for the processed files
output_pet_dir = '/content/drive/MyDrive/PIX/PET_P'
output_ct_dir  = '/content/drive/MyDrive/PIX/CT_P'

# List all DICOM files from each modality folder
pet_files = sorted([os.path.join(pet_dir, f) for f in os.listdir(pet_dir) if f.endswith('.dcm')])
ct_files  = sorted([os.path.join(ct_dir, f)  for f in os.listdir(ct_dir)  if f.endswith('.dcm')])

# Process PET files
print("Processing PET files...")
for dicom_file in pet_files:
    process_dicom_file((dicom_file, output_pet_dir))

# Process CT files
print("Processing CT files...")
for dicom_file in ct_files:
    process_dicom_file((dicom_file, output_ct_dir))


In [6]:
""""def process_dicom_file(args):
    Converts DICOM to normalized numpy array
    dicom_path, output_dir = args
    try:
        dicom = pydicom.dcmread(dicom_path)
        img = dicom.pixel_array.astype(np.float32)

        # Modality-specific normalization
        if "CT" in dicom.Modality:
            img = (img - img.min()) / (img.max() - img.min())  # [0,1]
        elif "PT" in dicom.Modality:
            img = (img + 1000) / 2000  # Approximate SUV scaling

        # Save as NPY
        np.save(os.path.join(output_dir, f"{dicom.Modality}_{dicom.PatientID}_{dicom.SOPInstanceUID}.npy"), img)
        return True
    except Exception as e:
        print(f"Error processing {dicom_path}: {e}")
        return False

In [7]:
""" def preprocess_dataset(raw_dir="/content/QIN-Breast_RAW",
                      processed_dir="/content/QIN-Breast_PROCESSED"):
    #Parallel DICOM to NPY conversion
    os.makedirs(processed_dir, exist_ok=True)
    dicom_files = []

    for root, _, files in os.walk(raw_dir):
        dicom_files.extend([os.path.join(root, f) for f in files if f.endswith(".dcm")])

    # Process in parallel
    with Pool(4) as pool:
        results = list(tqdm(
            pool.imap(process_dicom_file, [(f, processed_dir) for f in dicom_files]),
            total=len(dicom_files),
            desc="Preprocessing"
        ))

    print(f"Successfully processed {sum(results)}/{len(dicom_files)} files")"""

' def preprocess_dataset(raw_dir="/content/QIN-Breast_RAW", \n                      processed_dir="/content/QIN-Breast_PROCESSED"):\n    #Parallel DICOM to NPY conversion\n    os.makedirs(processed_dir, exist_ok=True)\n    dicom_files = []\n    \n    for root, _, files in os.walk(raw_dir):\n        dicom_files.extend([os.path.join(root, f) for f in files if f.endswith(".dcm")])\n    \n    # Process in parallel\n    with Pool(4) as pool:\n        results = list(tqdm(\n            pool.imap(process_dicom_file, [(f, processed_dir) for f in dicom_files]),\n            total=len(dicom_files),\n            desc="Preprocessing"\n        ))\n    \n    print(f"Successfully processed {sum(results)}/{len(dicom_files)} files")'

In [8]:
import os
import numpy as np
import pydicom
import shutil
from glob import glob
from tqdm import tqdm

def preprocess_dataset(raw_dir="/content/QIN-Breast_RAW", processed_dir="/content/QIN-Breast_PROCESSED"):
    """
    Converts DICOM files into normalized NumPy arrays and saves them in processed_dir.
    Filenames follow: PT_<PatientID>_<UID>.npy or CT_<PatientID>_<UID>.npy
    """
    os.makedirs(processed_dir, exist_ok=True)

    series_dirs = glob(os.path.join(raw_dir, "*"))
    for series_path in tqdm(series_dirs, desc="Preprocessing series"):
        dicom_files = glob(os.path.join(series_path, "*.dcm"))
        if len(dicom_files) == 0:
            continue

        # Try reading the first DICOM file
        try:
            sample = pydicom.dcmread(dicom_files[0], force=True)
            modality = sample.Modality.upper()
            patient_id = sample.PatientID
            series_uid = sample.SeriesInstanceUID
        except Exception as e:
            print(f"Failed to read metadata from {series_path}: {e}")
            continue

        # Only process PET or CT
        if modality not in ["PT", "CT"]:
            continue

        # Sort slices by InstanceNumber or SliceLocation
        slices = []
        for f in dicom_files:
            try:
                dcm = pydicom.dcmread(f, force=True)
                slices.append(dcm)
            except:
                continue

        slices = sorted(slices, key=lambda s: getattr(s, 'InstanceNumber', 0))

        # Stack into volume
        try:
            volume = np.stack([s.pixel_array for s in slices]).astype(np.float32)
        except Exception as e:
            print(f"Failed to stack slices for {series_path}: {e}")
            continue

        # Normalize (min-max)
        volume -= np.min(volume)
        volume /= np.max(volume) + 1e-8  # Avoid divide-by-zero

        # Save as .npy
        save_path = os.path.join(processed_dir, f"{modality}_{patient_id}_{series_uid}.npy")
        np.save(save_path, volume)

    print(f"All done! Processed files saved to {processed_dir}")


#### 3. Dataset splitting and Loader

In [11]:
""" def get_patient_splits(processed_dir, test_size=0.15, val_size=0.15):
    #Patient-wise splitting (prevents data leakage)
    # Extract unique patient IDs from filenames (format: Modality_PatientID_UID.npy)
    all_files = os.listdir(processed_dir)
    pet_files = [f for f in all_files if f.startswith("PT_")]
    patient_ids = list(set([f.split('_')[1] for f in pet_files]))

    # Split: Train -> Val/Test
    train_ids, test_ids = train_test_split(patient_ids, test_size=test_size, random_state=42)
    train_ids, val_ids = train_test_split(train_ids, test_size=val_size/(1-test_size), random_state=42)

    return train_ids, val_ids, test_ids"""

' def get_patient_splits(processed_dir, test_size=0.15, val_size=0.15):\n    #Patient-wise splitting (prevents data leakage)\n    # Extract unique patient IDs from filenames (format: Modality_PatientID_UID.npy)\n    all_files = os.listdir(processed_dir)\n    pet_files = [f for f in all_files if f.startswith("PT_")]\n    patient_ids = list(set([f.split(\'_\')[1] for f in pet_files]))\n    \n    # Split: Train -> Val/Test\n    train_ids, test_ids = train_test_split(patient_ids, test_size=test_size, random_state=42)\n    train_ids, val_ids = train_test_split(train_ids, test_size=val_size/(1-test_size), random_state=42)\n    \n    return train_ids, val_ids, test_ids'

In [12]:
def get_patient_splits(processed_dir, test_size=0.15, val_size=0.15):
    """Patient-wise splitting (prevents data leakage)"""
    all_files = os.listdir(processed_dir)
    pet_files = [f for f in all_files if f.startswith("PT_")]

    if len(pet_files) == 0:
        raise ValueError("No PET files found. Check if preprocessing ran and file naming is correct.")

    patient_ids = list(set([f.split('_')[1] for f in pet_files]))

    if len(patient_ids) < 3:
        raise ValueError(f"Too few patients ({len(patient_ids)}). Need at least 3 to split into train/val/test.")

    # Split: Train -> Val/Test
    train_ids, test_ids = train_test_split(patient_ids, test_size=test_size, random_state=42)
    train_ids, val_ids = train_test_split(train_ids, test_size=val_size/(1-test_size), random_state=42)

    return train_ids, val_ids, test_ids


In [13]:
class QinBreastDataset(Dataset):
    def __init__(self, root_dir, patient_ids=None, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.pairs = []

        # Get PET files filtered by patient IDs
        all_pet = [f for f in os.listdir(root_dir) if f.startswith("PT_")]
        if patient_ids:
            all_pet = [f for f in all_pet if f.split('_')[1] in patient_ids]

        # Create verified pairs
        for pet_file in all_pet:
            ct_file = pet_file.replace("PT_", "CT_")
            if os.path.exists(os.path.join(root_dir, ct_file)):
                self.pairs.append((pet_file, ct_file))

    def __len__(self):
        return len(self.pairs)  #returns no. of PET-CT pairs avialble

    def __getitem__(self, idx):
        pet_file, ct_file = self.pairs[idx]
        #pet = np.load(os.path.join(self.root_dir, pet_file))
        #ct = np.load(os.path.join(self.root_dir, ct_file))
        try:
            pet = np.load(os.path.join(self.root_dir, pet_file))
            ct = np.load(os.path.join(self.root_dir, ct_file))
        except Exception as e:
            raise RuntimeError(f"Error loading files: {pet_file}, {ct_file}. {e}")

        if self.transform:
            pet = self.transform(pet)
            ct = self.transform(ct)

        return pet, ct

#### 4. Model Architecture

In [14]:
# %% [code]
# ======================

class ViTBlock(nn.Module):
    def __init__(self, dim=512, heads=8, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim*4),
            nn.GELU(),
            nn.Dropout(dropout),  #GEL nad Dropout for better stability
            nn.Linear(dim*4, dim)
        )

    def forward(self, x):
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        mlp_out = self.mlp(x)
        return self.norm2(x + mlp_out)

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder (ResNet-34)
        #resnet = models.resnet34(pretrained=True)
        resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            *list(resnet.children())[1:-2]  # Remove original fc layer
        )

        # ViT Bottleneck
        self.vit = nn.Sequential(
            ViTBlock(dim=512),
            #ViTBlock(dim=512),
           # ViTBlock(dim=512)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> (h w) b c')
        x = self.vit(x)
        x = rearrange(x, '(h w) b c -> b c h w', h=h, w=w)
        #return self.decoder(x)
        return self.decoder(x.to(device))  # # Move output tensor back to GPU



Multi-Scale Discriminator is designed to assess images at different resolutions, improving adversarial learning stability

In [15]:
class MultiScaleDiscriminator(nn.Module):
    def __init__(self, input_channels=1):
        super().__init__()
        self.discriminators = nn.ModuleList([
            self._make_discriminator(input_channels, 64),
            self._make_discriminator(input_channels, 32),
            self._make_discriminator(input_channels, 16)
        ])

    def _make_discriminator(self, in_ch, base_ch):
        return nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_ch, base_ch, 4, 2, 1)),  #Improves training stability by constraining weight norms.
            nn.LeakyReLU(0.2),
            nn.utils.spectral_norm(nn.Conv2d(base_ch, base_ch*2, 4, 2, 1)),
            nn.InstanceNorm2d(base_ch*2),         #Helps normalize features, preventing vanishing or exploding gradients.
            nn.LeakyReLU(0.2),
            nn.utils.spectral_norm(nn.Conv2d(base_ch*2, 1, 4, 1, 1)),
            nn.AdaptiveAvgPool2d(1)
        )

    def forward(self, x):
        outputs = []
        x = x.to(device)  # Ensure tensor is on GPU
        for disc in self.discriminators:
            outputs.append(disc(x))
            #x = nn.functional.interpolate(x, scale_factor=0.5, mode='bilinear')
            x = nn.functional.interpolate(x, scale_factor=0.5, mode='nearest')
         #return torch.cat(outputs, dim=1)
        return torch.cat(outputs, dim=1).to(device)  # Keep output on GPU

####  5. Training Utilities

In [16]:
#from torchvision.models import vgg19
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19, VGG19_Weights

class VGGLoss(nn.Module):
    def __init__(self, requires_grad=False):
        super(VGGLoss, self).__init__()
        # Load a pre-trained VGG19 model
        vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features

        # Break the network into slices corresponding to different layers' outputs.
        self.slice1 = nn.Sequential(*[vgg[x] for x in range(2)])
        self.slice2 = nn.Sequential(*[vgg[x] for x in range(2, 7)])
        self.slice3 = nn.Sequential(*[vgg[x] for x in range(7, 12)])
        self.slice4 = nn.Sequential(*[vgg[x] for x in range(12, 21)])
        self.slice5 = nn.Sequential(*[vgg[x] for x in range(21, 30)])

        # Freeze the VGG parameters if not training them.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x, y):
        # Compute feature maps at various depths
        loss = 0
        x1, y1 = self.slice1(x), self.slice1(y)
        loss += F.l1_loss(x1, y1)

        x2, y2 = self.slice2(x), self.slice2(y)
        loss += F.l1_loss(x2, y2)

        x3, y3 = self.slice3(x), self.slice3(y)
        loss += F.l1_loss(x3, y3)

        x4, y4 = self.slice4(x), self.slice4(y)
        loss += F.l1_loss(x4, y4)

        x5, y5 = self.slice5(x), self.slice5(y)
        loss += F.l1_loss(x5, y5)

        return loss


In [17]:
class TotalLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.L1Loss()
        self.vgg = VGGLoss()
        self.ms_ssim = MS_SSIM(data_range=1.0, channel=1)

    def forward(self, gen_ct, real_ct, D_real, D_fake, D):
        # Reconstruction Losses
        l1_loss = self.l1(gen_ct, real_ct)
        ms_ssim_loss = 1 - self.ms_ssim(gen_ct, real_ct)
        vgg_loss = self.vgg(gen_ct, real_ct)

        # Adversarial Loss
        adv_loss = -torch.mean(D_fake)

        # Gradient Penalty
        gp = self._gradient_penalty(D, real_ct, gen_ct.detach())

        return 100*l1_loss + ms_ssim_loss + 0.1*vgg_loss + 10*(adv_loss + gp)

    def _gradient_penalty(self, D, real, fake):
        alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
        interpolates = (alpha * real + ((1 - alpha) * fake)).requires_grad_(True)
        #d_interpolates = D(interpolates)
        d_interpolates = D(interpolates).view(-1)

        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True
        )[0]
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

def psnr(output, target):
    """Compute PSNR between [-1,1] normalized tensors"""
    output = (output + 1) / 2  # [-1,1] → [0,1]
    target = (target + 1) / 2
    mse = torch.mean((output - target) ** 2)
    mse = torch.clamp(mse, min=1e-8)  # Avoid division by zero
    return 20 * torch.log10(1.0 / torch.sqrt(mse))

#### 6. Main Training Loop

In [18]:
def train():
    # Initialize
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = Generator().to(device)
    D = MultiScaleDiscriminator().to(device)
    opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(D.parameters(), lr=1e-4, betas=(0.5, 0.999))
    criterion = TotalLoss()
    scaler = GradScaler()

    # Initialize SummaryWriter for TensorBoard logging
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()

    # Data
    train_ids, val_ids, test_ids = get_patient_splits("/content/QIN-Breast_PROCESSED")

    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

    train_dataset = QinBreastDataset("/content/QIN-Breast_PROCESSED", train_ids, train_transform)
    val_dataset = QinBreastDataset("/content/QIN-Breast_PROCESSED", val_ids, eval_transform)
    test_dataset = QinBreastDataset("/content/QIN-Breast_PROCESSED", test_ids, eval_transform)

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0) # batch_size=16. num_workers=4
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

    # Training
    best_val_loss = float('inf')
    global_step = 0  # Counter for logging iterations
    for epoch in range(15):
        # Training
        G.train()
        D.train()
        train_losses = []
        #for pet, ct in tqdm(train_loader, desc=f"Train Epoch {epoch}"):
        for i, (pet, ct) in enumerate(train_loader):
            if i >= 10:  # Limit to 10 batches for demonstration
                break

            pet, ct = pet.to(device), ct.to(device)

            # Train Discriminator
            opt_D.zero_grad()
            with autocast():
                fake_ct = G(pet)
                D_real = D(ct)
                D_fake = D(fake_ct.detach())
                loss_D = criterion(fake_ct, ct, D_real, D_fake, D)
            scaler.scale(loss_D).backward()
            scaler.step(opt_D)

            # Train Generator
            opt_G.zero_grad()
            with autocast():
                fake_ct = G(pet)
                D_fake = D(fake_ct)
                loss_G = criterion(fake_ct, ct, D_real, D_fake, D)
            scaler.scale(loss_G).backward()
            scaler.step(opt_G)
            scaler.update()

            train_losses.append(loss_G.item())
             # Log the losses for this batch
            writer.add_scalar("Train/Generator Loss", loss_G.item(), global_step)
            writer.add_scalar("Train/Discriminator Loss", loss_D.item(), global_step)
            global_step += 1  # Increment global step for TensorBoard logging

        # Validation
        G.eval()
        val_losses = []
        with torch.no_grad():
            #for pet, ct in tqdm(val_loader, desc="Validating"):
            for i, (pet, ct) in enumerate(val_loader):
                if i >= 10:  # Process only 2 batches from validation set
                   break
                pet, ct = pet.to(device), ct.to(device)
                fake_ct = G(pet)
                loss = criterion(fake_ct, ct, D(ct), D(fake_ct), D)
                #change by copilot # Adjust TotalLoss to allow this, or use an alternative loss function
                #loss = criterion(fake_ct, ct)
                val_losses.append(loss.item())

        avg_val_loss = np.mean(val_losses)
        writer.add_scalar("Validation/Loss", avg_val_loss, epoch)
        print(f"Epoch {epoch} | Train Loss: {np.mean(train_losses):.4f} | Val Loss: {avg_val_loss:.4f}")

        # Visualization & Evaluation Logging
        # --------------------
        # Log a few sample images from the validation set once every epoch
        # (or every few epochs if desired)
        sample_pet, sample_ct = next(iter(val_loader))
        sample_pet = sample_pet.to(device)
        with torch.no_grad():
            sample_fake_ct = G(sample_pet)
        # Convert images to [0,1] range for viewing if they were normalized to [-1,1]
        sample_fake_ct_img = (sample_fake_ct + 1) / 2.0
        sample_ct_img = (sample_ct + 1) / 2.0

        # Log images to TensorBoard under "Evaluation/Real_CT" and "Evaluation/Fake_CT"
        writer.add_images("Evaluation/Real_CT", sample_ct_img, epoch)
        writer.add_images("Evaluation/Fake_CT", sample_fake_ct_img, epoch)

        # You can also log evaluation metrics like PSNR or SSIM if computed per epoch
        # For example:
        # epoch_psnr = ...  # Compute PSNR value across the validation set
        # writer.add_scalar("Evaluation/PSNR", epoch_psnr, epoch)

        # Save best model
        #if avg_val_loss < best_val_loss:
           # best_val_loss = avg_val_loss
           # torch.save(G.state_dict(), "best_generator.pth")
          #  print("Saved new best model!")
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
            "epoch": epoch,
            "G_state_dict": G.state_dict(),
            "D_state_dict": D.state_dict(),
            "G_optimizer": opt_G.state_dict(),
            "D_optimizer": opt_D.state_dict()
                }, "best_model.pth")
            print("Saved new best model!")


    # Final Test
    G.load_state_dict(torch.load("best_generator.pth"))
    G.eval()
    test_psnr = []
    test_ssim = []

    with torch.no_grad():
        #for pet, ct in tqdm(test_loader, desc="Testing"):
        for i, (pet, ct) in enumerate(test_loader):
              if i >= 10:  # Process only 2 batches from test set
                   break
              pet, ct = pet.to(device), ct.to(device)
              fake_ct = G(pet)

              # Convert to [0,1] for metrics
              fake_ct = (fake_ct + 1) / 2
              ct = (ct + 1) / 2

              test_psnr.append(psnr(fake_ct, ct).cpu().numpy())
              test_ssim.append(ms_ssim(fake_ct, ct, data_range=1.0).cpu().numpy())

    print(f"\nFinal Test Results:")
    print(f"PSNR: {np.mean(test_psnr):.2f} ± {np.std(test_psnr):.2f} dB")
    print(f"SSIM: {np.mean(test_ssim):.4f} ± {np.std(test_ssim):.4f}")

    # Make sure to close the writer after training:
    writer.close()

In [19]:
%pip install beautifulsoup4




####  7. Execute Pipeline

In [21]:
if __name__ == "__main__":
    # Step 1-2: Download and preprocess (one-time)
    if not os.path.exists("/content/QIN-Breast_PROCESSED"):
        download_qin_breast()
        preprocess_dataset()

    # Step 3-6: Train
    train()

  scaler = GradScaler()


ValueError: No PET files found. Check if preprocessing ran and file naming is correct.

In [22]:
import os

processed_dir = "/content/QIN-Breast_PROCESSED"
files = os.listdir(processed_dir)

print("Sample files:", files[:10])
print("Total files:", len(files))


Sample files: ['CT_QIN-BREAST-01-0005_1.3.6.1.4.1.14519.5.2.1.8162.7003.184534544545899155746390970758.npy', 'CT_QIN-BREAST-01-0005_1.3.6.1.4.1.14519.5.2.1.8162.7003.194451068311289985577103378471.npy', 'CT_QIN-BREAST-01-0007_1.3.6.1.4.1.14519.5.2.1.8162.7003.843576143914773065152505268737.npy', 'CT_QIN-BREAST-01-0003_1.3.6.1.4.1.14519.5.2.1.8162.7003.154745635425029481935829340289.npy', 'CT_QIN-BREAST-01-0003_1.3.6.1.4.1.14519.5.2.1.8162.7003.304251602254400173448218579444.npy', 'CT_QIN-BREAST-01-0002_1.3.6.1.4.1.14519.5.2.1.8162.7003.135793651241397654453131409563.npy', 'CT_QIN-BREAST-01-0003_1.3.6.1.4.1.14519.5.2.1.8162.7003.187597513862753106512454405500.npy', 'CT_QIN-BREAST-01-0002_1.3.6.1.4.1.14519.5.2.1.8162.7003.129443967611089824437584755973.npy', 'CT_QIN-BREAST-01-0007_1.3.6.1.4.1.14519.5.2.1.8162.7003.838123888394889102410256573858.npy', 'CT_QIN-BREAST-01-0005_1.3.6.1.4.1.14519.5.2.1.8162.7003.287873048257814352778755034563.npy']
Total files: 830
