In [1]:
!pip install torch torchvision numpy matplotlib pillow tqdm scikit-image

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [2]:
# === Imports ===
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

In [3]:
# === Dataset ===
class CTImageDataset(Dataset):
    """Dataset of grayscale images read from the ``.jpg`` files of one directory.

    Args:
        image_dir: Directory scanned (non-recursively) for entries whose name
            ends in ``.jpg`` (case-insensitive).
        transform: Optional callable applied to each loaded PIL image; whatever
            it returns is what ``__getitem__`` yields.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> file mapping stable across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.lower().endswith('.jpg')
        )
        self.transform = transform

    def __len__(self):
        """Number of ``.jpg`` files found in ``image_dir`` at construction time."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load file ``idx`` as a single-channel ("L") image and apply ``transform`` if set."""
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # The context manager releases the file handle as soon as the pixels
        # have been converted, instead of waiting for garbage collection.
        with Image.open(img_path) as raw:
            image = raw.convert("L")
        if self.transform:
            image = self.transform(image)
        return image

In [4]:
# === Transform ===
# Preprocessing + augmentation pipeline handed to CTImageDataset; steps run top to bottom.
transform = transforms.Compose([
    # Bring every image to a fixed 128x128 grid.
    transforms.Resize((128, 128)),
    # With probability 0.5 apply the whole group below (all three, in this order);
    # otherwise skip the group entirely.
    transforms.RandomApply([
        transforms.RandomRotation(10),
        transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
        transforms.ColorJitter(brightness=0.1, contrast=0.1)
    ], p=0.5),
    transforms.RandomHorizontalFlip(),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # (value - 0.5) / 0.5 on the single channel.
    transforms.Normalize((0.5,), (0.5,))
])

In [5]:
# === Sinusoidal Time Embedding ===
def sinusoidal_embedding(timesteps, dim):
    """Map a batch of scalar timesteps to sinusoidal feature vectors.

    Args:
        timesteps: 1-D tensor of shape ``(N,)`` (integer or floating point).
        dim: Width of the returned embedding.

    Returns:
        Tensor of shape ``(N, dim)`` laid out as ``[sin | cos]``. When ``dim``
        is odd, one zero column is appended so the width is exactly ``dim``
        (previously the result silently had ``dim - 1`` columns).
    """
    device = timesteps.device
    half_dim = dim // 2
    # Geometric frequency ladder: exp(-k * ln(10000) / half_dim), k = 0 .. half_dim - 1.
    freqs = torch.exp(torch.arange(half_dim, device=device) * -np.log(10000.0) / half_dim)
    angles = timesteps[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if dim % 2 == 1:
        # 2 * (dim // 2) == dim - 1 for odd dim; pad the last axis on the right
        # so a downstream nn.Linear(dim, ...) receives `dim` features.
        emb = F.pad(emb, (0, 1))
    return emb

In [6]:
# === UNet Generator with Time Conditioning ===
class UNetBlock(nn.Module):
    """Two ``3x3 conv -> InstanceNorm -> ReLU`` stages with an optional embedding bias.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolution stages.
        emb_dim: Width of the conditioning vector. When ``None`` no projection
            layer is created and any ``emb`` passed to ``forward`` is ignored.
    """

    def __init__(self, in_channels, out_channels, emb_dim=None):
        super().__init__()
        self.use_emb = emb_dim is not None

        # Stage 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.InstanceNorm2d(out_channels)
        self.act1 = nn.ReLU(inplace=True)

        # Projects the conditioning vector to one bias value per output channel.
        if self.use_emb:
            self.emb_proj = nn.Linear(emb_dim, out_channels)

        # Stage 2
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.InstanceNorm2d(out_channels)
        self.act2 = nn.ReLU(inplace=True)

    def forward(self, x, emb=None):
        """Run both stages; between them add the projected ``emb`` as a per-channel bias.

        The bias is only added when the block was built with ``emb_dim`` and
        ``emb`` is not ``None``.
        """
        h = self.conv1(x)
        h = self.norm1(h)
        h = self.act1(h)

        if self.use_emb and emb is not None:
            channel_bias = self.emb_proj(emb)
            # (batch, out_channels) -> (batch, out_channels, 1, 1) so it broadcasts over H and W.
            h = h + channel_bias.view(emb.shape[0], -1, 1, 1)

        h = self.conv2(h)
        h = self.norm2(h)
        return self.act2(h)

class UNetGenerator(nn.Module):
    """Three-level U-Net conditioned on a timestep embedding.

    Every ``UNetBlock`` (encoder, bottleneck, decoder) receives the same
    embedding vector. The network output is the tanh-bounded 1x1-conv result
    added to the network input.

    Args:
        in_channels: Channels of the input and of the output.
        base_channels: Width of the first encoder level; deeper levels use
            2x, 4x and 8x this value.
        emb_dim: Width of the sinusoidal embedding and of the MLP applied to it.
    """

    def __init__(self, in_channels=1, base_channels=64, emb_dim=128):
        super().__init__()
        self.emb_dim = emb_dim
        c1 = base_channels
        c2 = base_channels * 2
        c4 = base_channels * 4
        c8 = base_channels * 8

        # Small MLP applied on top of the raw sinusoidal embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )

        # Encoder: each level is followed by a 2x max-pool.
        self.enc1 = UNetBlock(in_channels, c1, emb_dim)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = UNetBlock(c1, c2, emb_dim)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = UNetBlock(c2, c4, emb_dim)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = UNetBlock(c4, c8, emb_dim)

        # Decoder: 2x transposed-conv upsampling, then a block over
        # [upsampled, skip] (hence the doubled input width of each dec block).
        self.up3 = nn.ConvTranspose2d(c8, c4, 2, stride=2)
        self.dec3 = UNetBlock(c8, c4, emb_dim)
        self.up2 = nn.ConvTranspose2d(c4, c2, 2, stride=2)
        self.dec2 = UNetBlock(c4, c2, emb_dim)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = UNetBlock(c2, c1, emb_dim)

        self.final = nn.Sequential(
            nn.Conv2d(c1, in_channels, 1),
            nn.Tanh(),
        )

    def forward(self, x, t):
        """Return ``final(decoder features) + x`` for input ``x`` at timesteps ``t``.

        ``x`` is pooled by 2 three times and upsampled by 2 three times, so its
        height and width are expected to be divisible by 8 for the skip
        concatenations to line up.
        """
        emb = self.time_embed(sinusoidal_embedding(t, self.emb_dim))

        # Encoder path; keep each level's output for the skip connections.
        skip1 = self.enc1(x, emb)
        skip2 = self.enc2(self.pool1(skip1), emb)
        skip3 = self.enc3(self.pool2(skip2), emb)
        h = self.bottleneck(self.pool3(skip3), emb)

        # Decoder path, deepest level first.
        stages = (
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, block, skip in stages:
            h = block(torch.cat([upsample(h), skip], dim=1), emb)

        return self.final(h) + x

In [7]:
# === Discriminator (PatchGAN) ===
from torch.nn.utils import spectral_norm

class Discriminator(nn.Module):
    """Convolutional patch discriminator that emits a grid of logits.

    Four spectrally-normalised 4x4 convolutions with stride 2 (the last three
    followed by InstanceNorm), each followed by LeakyReLU(0.2), then a plain
    1x1 convolution down to one channel. No sigmoid is applied.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self, in_channels=1):
        super().__init__()

        def down(c_in, c_out):
            # 4x4 conv, stride 2, padding 1, wrapped in spectral normalisation.
            return spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))

        # First stage has no normalisation layer.
        layers = [down(in_channels, 64), nn.LeakyReLU(0.2, inplace=True)]

        prev = 64
        for width in (128, 256, 512):
            layers.append(down(prev, width))
            layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            prev = width

        # Final 1x1 projection to a single logit per spatial location.
        layers.append(nn.Conv2d(512, 1, kernel_size=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logit map produced by ``self.model`` for ``x``."""
        return self.model(x)

In [8]:
# === Perceptual Loss ===
class VGGPerceptualLoss(nn.Module):
    """L1 distance between early VGG-16 feature maps of two image batches.

    Uses the first eight layers of torchvision's pretrained VGG-16 ``features``
    stack, put in eval mode with all parameters frozen. Inputs are tiled three
    times along the channel axis and passed to the extractor as-is; no other
    rescaling or normalisation is applied here.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        extractor = backbone.features[:8].eval()
        # Frozen: the extractor only provides a feature space, it is never trained.
        for weight in extractor.parameters():
            weight.requires_grad = False
        self.vgg = extractor

    def forward(self, x, y):
        """Return the mean absolute difference between the features of ``x`` and ``y``."""
        # Tile channels 3x along dim 1 (same result as repeat(1, 3, 1, 1)).
        x_tiled = torch.cat([x, x, x], dim=1)
        y_tiled = torch.cat([y, y, y], dim=1)
        feats_x = self.vgg(x_tiled)
        feats_y = self.vgg(y_tiled)
        return F.l1_loss(feats_x, feats_y)

In [9]:
# === Diffusion Scheduler ===
class DiffusionScheduler:
    """Forward-diffusion noise schedule built from a clipped cosine beta schedule.

    Attributes:
        timesteps: Number of diffusion steps.
        betas, alphas, alphas_cumprod: NumPy arrays of length ``timesteps``.
        sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod: float32 tensors of
            length ``timesteps`` used to mix signal and noise in ``add_noise``.
    """

    def __init__(self, timesteps=1000):
        self.timesteps = timesteps
        self.betas = self._cosine_beta_schedule(timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas)

        signal_scale = np.sqrt(self.alphas_cumprod)
        noise_scale = np.sqrt(1 - self.alphas_cumprod)
        self.sqrt_alphas_cumprod = torch.tensor(signal_scale, dtype=torch.float32)
        self.sqrt_one_minus_alphas_cumprod = torch.tensor(noise_scale, dtype=torch.float32)

    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """Return ``timesteps`` betas from a squared-cosine curve, clipped to [1e-5, 0.1]."""
        grid = np.linspace(0, timesteps, timesteps + 1)
        curve = np.cos(((grid / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
        # Normalise so the curve starts at exactly 1.
        curve = curve / curve[0]
        step_ratio = curve[1:] / curve[:-1]
        return np.clip(1 - step_ratio, 1e-5, 0.1)

    def add_noise(self, x, t):
        """Mix ``x`` with fresh Gaussian noise at per-sample timesteps ``t``.

        Returns ``(noised, noise)`` where
        ``noised = sqrt_alphas_cumprod[t] * x + sqrt_one_minus_alphas_cumprod[t] * noise``.
        As a side effect the two lookup tensors are moved to ``x.device`` and
        stored back on the scheduler.
        """
        noise = torch.randn_like(x)
        step_idx = t.long()

        # Keep the lookup tables on the same device as the data.
        target = x.device
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(target)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(target)

        # Per-sample scalars reshaped to (N, 1, 1, 1) to broadcast over a 4-D batch.
        keep = self.sqrt_alphas_cumprod[step_idx][:, None, None, None]
        spread = self.sqrt_one_minus_alphas_cumprod[step_idx][:, None, None, None]

        noised = keep * x + spread * noise
        return noised, noise

In [10]:
# === Training Loop ===
def train(generator, discriminator, dataloader, scheduler, device, num_epochs=100):
    """Adversarially train the time-conditioned generator against the discriminator.

    For every batch: noise the real images at random timesteps, let the
    generator predict the noise, update the discriminator on real vs.
    reconstructed images, then update the generator with a weighted sum of
    noise-L1, adversarial, perceptual and MSE terms.

    Args:
        generator: Module called as ``generator(noised, t)``.
        discriminator: Module returning logits for an image batch.
        dataloader: Iterable yielding image batches (tensors).
        scheduler: Object exposing ``timesteps`` and ``add_noise(x, t)``.
        device: Device the batches and the perceptual loss are moved to.
        num_epochs: Number of passes over ``dataloader``.

    Raises:
        ValueError: If ``dataloader`` yields no batches during an epoch.
    """
    gen_opt = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    disc_opt = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
    perceptual_loss = VGGPerceptualLoss().to(device)
    bce_loss = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()

    generator.train()
    discriminator.train()

    for epoch in range(num_epochs):
        # Curriculum over noise levels: small timesteps first, full range from epoch 20 on.
        if epoch < 10:
            t_upper = 50
        elif epoch < 20:
            t_upper = 200
        else:
            t_upper = scheduler.timesteps
        # Never sample a timestep beyond the scheduler's tables (matters when
        # the scheduler was built with fewer than 50 / 200 steps).
        t_range = (0, min(t_upper, scheduler.timesteps))

        print(f"\nEpoch {epoch + 1}/{num_epochs} (t_range = {t_range})")

        batches_seen = 0
        for real in tqdm(dataloader):
            real = real.to(device)
            bs = real.size(0)
            t = torch.randint(*t_range, (bs,), device=device)

            noised, noise = scheduler.add_noise(real, t)
            pred_noise = generator(noised, t)
            # NOTE(review): plain subtraction, without the scheduler's per-timestep
            # scale factors, is used as the image estimate — confirm this is intended.
            reconstructed = noised - pred_noise

            # --- Discriminator step (generator output detached) ---
            disc_real = discriminator(real)
            disc_fake = discriminator(reconstructed.detach())

            # Smoothed targets: 0.9 for real, 0.1 for fake.
            real_labels = torch.full_like(disc_real, 0.9)
            fake_labels = torch.full_like(disc_fake, 0.1)

            loss_disc = bce_loss(disc_real, real_labels) + bce_loss(disc_fake, fake_labels)
            disc_opt.zero_grad()
            loss_disc.backward()
            disc_opt.step()

            # --- Generator step (scored by the just-updated discriminator) ---
            disc_fake = discriminator(reconstructed)
            loss_gan = bce_loss(disc_fake, real_labels)
            loss_recon = l1_loss(pred_noise, noise)
            loss_percep = perceptual_loss(reconstructed, real)
            loss_mse = F.mse_loss(reconstructed, real)

            loss_gen = (
                1.0 * loss_recon +
                0.1 * loss_gan +
                0.2 * loss_percep +
                0.1 * loss_mse
            )

            gen_opt.zero_grad()
            loss_gen.backward()
            gen_opt.step()
            batches_seen += 1

        if batches_seen == 0:
            # Without this guard the summary print below fails with an
            # UnboundLocalError that hides the real problem.
            raise ValueError("dataloader yielded no batches; nothing to train on")

        # Values reported are those of the last batch of the epoch.
        print(f"Loss Gen: {loss_gen.item():.4f}, Loss Disc: {loss_disc.item():.4f}, Percep: {loss_percep.item():.4f}")

In [11]:
# === Image Generation and Metric Evaluation ===
def generate_images_from_files(generator, scheduler, real_image_paths, device, save_dir="generated_images", resize=(128, 128)):
    """Noise each input image at a random timestep, denoise once, and save the result.

    For every path the image is loaded as grayscale, resized to ``resize``,
    scaled to [-1, 1], noised via ``scheduler.add_noise`` at a timestep drawn
    uniformly from ``[0, scheduler.timesteps)``, and reconstructed as
    ``noised - generator(noised, t)``. The clamped result is written to
    ``save_dir`` as ``<input stem>_synthetic.png``.

    The generator is switched to eval mode and left in that mode.
    """
    generator.eval()
    os.makedirs(save_dir, exist_ok=True)
    to_tensor = transforms.ToTensor()
    to_pil = transforms.ToPILImage()

    with torch.no_grad():
        for path in real_image_paths:
            stem, _ext = os.path.splitext(os.path.basename(path))

            gray = Image.open(path).convert("L")
            gray = gray.resize(resize)
            # [0, 1] tensor with a batch axis, rescaled to [-1, 1].
            img_tensor = to_tensor(gray).unsqueeze(0).to(device) * 2 - 1

            t = torch.randint(0, scheduler.timesteps, (1,), device=device)
            noised, _ = scheduler.add_noise(img_tensor, t)
            reconstructed = noised - generator(noised, t)

            # Back to [0, 1] for image export.
            out = reconstructed.squeeze().cpu().clamp(-1, 1) * 0.5 + 0.5
            out_name = f"{stem}_synthetic.png"
            to_pil(out).save(os.path.join(save_dir, out_name))
            print(f"Saved: {out_name}")


def calculate_metrics(real_dir, synthetic_dir, image_list, resize=(128, 128)):
    """Print per-image and average MAE / SSIM / PSNR between real and synthetic images.

    For each ``filename`` in ``image_list`` the real image ``real_dir/filename``
    is compared with ``synthetic_dir/<stem>_synthetic.png``. Both are loaded as
    8-bit grayscale and resized to ``resize``; pairs with a missing file are
    skipped with a message. Nothing is returned.

    All metrics are computed on 0-255 gray levels (``data_range=255.0``), so
    the MAE is reported in gray levels rather than Hounsfield units.
    """
    mae_total = 0.0
    ssim_total = 0.0
    psnr_total = 0.0
    count = 0
    for filename in image_list:
        base_name = os.path.splitext(filename)[0]
        real_path = os.path.join(real_dir, filename)
        synth_path = os.path.join(synthetic_dir, f"{base_name}_synthetic.png")
        if not (os.path.exists(real_path) and os.path.exists(synth_path)):
            print(f"Skipping missing file: {filename}")
            continue

        # Context managers release the file handles once the pixels are in memory.
        with Image.open(real_path) as real_file:
            real_np = np.array(real_file.convert("L").resize(resize)).astype(np.float32)
        with Image.open(synth_path) as synth_file:
            synth_np = np.array(synth_file.convert("L").resize(resize)).astype(np.float32)

        mae = np.mean(np.abs(real_np - synth_np))
        ssim_val = ssim(real_np, synth_np, data_range=255.0)
        psnr_val = psnr(real_np, synth_np, data_range=255.0)

        mae_total += mae
        ssim_total += ssim_val
        psnr_total += psnr_val
        count += 1
        print(f"[{filename}] MAE: {mae:.2f} (8-bit gray levels), SSIM: {ssim_val:.4f}, PSNR: {psnr_val:.2f} dB")

    if count > 0:
        print("\n=== Average Metrics ===")
        print(f"MAE  : {mae_total / count:.2f} (8-bit gray levels)")
        print(f"SSIM : {ssim_total / count:.4f}")
        print(f"PSNR : {psnr_total / count:.2f} dB")
    else:
        print("No matching image pairs found for evaluation.")

In [12]:
# === Execution Entry ===
# Seed the RNGs before anything stochastic happens (weight init, shuffling,
# augmentation, noise sampling) so a Restart & Run All is repeatable as far as
# the backend allows.
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)
scheduler = DiffusionScheduler(timesteps=1000)

# Training data: every .jpg in this directory, with the augmenting `transform`.
data_path = "/content/sample_data/CHAOS/Train"
dataset = CTImageDataset(image_dir=data_path, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

train(generator, discriminator, dataloader, scheduler, device, num_epochs=30)

# NOTE(review): the evaluation below reads from the same directory that was used
# for training, so the reported metrics are not held-out numbers.
real_image_dir = "/content/sample_data/CHAOS/Train"
synthetic_image_dir = "/content/sample_data/Synthetic_CT"
all_real_files = sorted(glob.glob(os.path.join(real_image_dir, "*.jpg")))
# Only the first 25 files (in sorted order) are reconstructed and scored.
real_image_paths = all_real_files[:25]
generate_images_from_files(generator, scheduler, real_image_paths, device, save_dir=synthetic_image_dir)
image_list = [os.path.basename(p) for p in real_image_paths]
calculate_metrics(real_image_dir, synthetic_image_dir, image_list)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:06<00:00, 86.0MB/s]



Epoch 1/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:06<00:00,  3.00it/s]


Loss Gen: 1.0488, Loss Disc: 1.3376, Percep: 0.9092

Epoch 2/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:04<00:00,  3.87it/s]


Loss Gen: 0.9731, Loss Disc: 1.3658, Percep: 0.6775

Epoch 3/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:04<00:00,  3.86it/s]


Loss Gen: 0.9635, Loss Disc: 1.3678, Percep: 0.5541

Epoch 4/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:04<00:00,  3.84it/s]


Loss Gen: 0.9584, Loss Disc: 1.3522, Percep: 0.6133

Epoch 5/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:04<00:00,  3.83it/s]


Loss Gen: 0.9614, Loss Disc: 1.3287, Percep: 0.5837

Epoch 6/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:05<00:00,  3.79it/s]


Loss Gen: 0.9395, Loss Disc: 1.3533, Percep: 0.4964

Epoch 7/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:04<00:00,  3.80it/s]


Loss Gen: 0.9550, Loss Disc: 1.4103, Percep: 0.5704

Epoch 8/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:05<00:00,  3.78it/s]


Loss Gen: 0.9441, Loss Disc: 1.3224, Percep: 0.5420

Epoch 9/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:05<00:00,  3.76it/s]


Loss Gen: 0.9377, Loss Disc: 1.3653, Percep: 0.4836

Epoch 10/30 (t_range = (0, 50))


100%|██████████| 19/19 [00:05<00:00,  3.75it/s]


Loss Gen: 0.9361, Loss Disc: 1.3892, Percep: 0.4906

Epoch 11/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.73it/s]


Loss Gen: 0.9280, Loss Disc: 1.2750, Percep: 0.8545

Epoch 12/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.71it/s]


Loss Gen: 0.8898, Loss Disc: 1.3269, Percep: 0.9092

Epoch 13/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.68it/s]


Loss Gen: 0.8964, Loss Disc: 1.3602, Percep: 0.8295

Epoch 14/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.68it/s]


Loss Gen: 0.8835, Loss Disc: 1.3932, Percep: 0.8158

Epoch 15/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.66it/s]


Loss Gen: 0.8993, Loss Disc: 1.3449, Percep: 0.6957

Epoch 16/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.65it/s]


Loss Gen: 0.8905, Loss Disc: 1.2603, Percep: 0.7715

Epoch 17/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.63it/s]


Loss Gen: 0.8856, Loss Disc: 1.3581, Percep: 0.6787

Epoch 18/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.62it/s]


Loss Gen: 0.8759, Loss Disc: 1.3560, Percep: 0.6930

Epoch 19/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.60it/s]


Loss Gen: 0.8639, Loss Disc: 1.3318, Percep: 0.7413

Epoch 20/30 (t_range = (0, 200))


100%|██████████| 19/19 [00:05<00:00,  3.58it/s]


Loss Gen: 0.8833, Loss Disc: 1.3390, Percep: 0.7094

Epoch 21/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.57it/s]


Loss Gen: 0.6980, Loss Disc: 1.1940, Percep: 1.8108

Epoch 22/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.54it/s]


Loss Gen: 0.7077, Loss Disc: 1.1754, Percep: 1.4766

Epoch 23/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.52it/s]


Loss Gen: 0.7143, Loss Disc: 1.1638, Percep: 1.3002

Epoch 24/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.48it/s]


Loss Gen: 0.6884, Loss Disc: 1.1037, Percep: 1.4825

Epoch 25/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.49it/s]


Loss Gen: 0.7101, Loss Disc: 1.1743, Percep: 1.3693

Epoch 26/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.47it/s]


Loss Gen: 0.7986, Loss Disc: 1.2689, Percep: 1.1180

Epoch 27/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.52it/s]


Loss Gen: 0.6863, Loss Disc: 1.0962, Percep: 1.2625

Epoch 28/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.52it/s]


Loss Gen: 0.6696, Loss Disc: 1.0204, Percep: 1.4396

Epoch 29/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.55it/s]


Loss Gen: 0.7410, Loss Disc: 1.3096, Percep: 1.0695

Epoch 30/30 (t_range = (0, 1000))


100%|██████████| 19/19 [00:05<00:00,  3.56it/s]


Loss Gen: 0.6453, Loss Disc: 1.1356, Percep: 1.3340
Saved: 1_i0000,0000b_synthetic.png
Saved: 1_i0001,0000b_synthetic.png
Saved: 1_i0002,0000b_synthetic.png
Saved: 1_i0003,0000b_synthetic.png
Saved: 1_i0004,0000b_synthetic.png
Saved: 1_i0005,0000b_synthetic.png
Saved: 1_i0006,0000b_synthetic.png
Saved: 1_i0007,0000b_synthetic.png
Saved: 1_i0008,0000b_synthetic.png
Saved: 1_i0009,0000b_synthetic.png
Saved: 1_i0010,0000b_synthetic.png
Saved: 1_i0011,0000b_synthetic.png
Saved: 1_i0012,0000b_synthetic.png
Saved: 1_i0013,0000b_synthetic.png
Saved: 1_i0014,0000b_synthetic.png
Saved: 1_i0015,0000b_synthetic.png
Saved: 1_i0016,0000b_synthetic.png
Saved: 1_i0017,0000b_synthetic.png
Saved: 1_i0018,0000b_synthetic.png
Saved: 1_i0019,0000b_synthetic.png
Saved: 1_i0020,0000b_synthetic.png
Saved: 1_i0021,0000b_synthetic.png
Saved: 1_i0022,0000b_synthetic.png
Saved: 1_i0023,0000b_synthetic.png
Saved: 1_i0024,0000b_synthetic.png
[1_i0000,0000b.jpg] MAE: 8.42 HU, SSIM: 0.7962, PSNR: 27.22 dB
[1_i0001,0