## 1. Import and Initialize

# Import
Import the essential libraries, select the compute device, and load the LPIPS metric

In [None]:
import os, random, itertools
from glob import glob
import torch, lpips
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr
# Run on the CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# LPIPS perceptual distance with the AlexNet backbone, moved to the same device
# as the models; it is used by the evaluation (testing) section.
lpips_fn = lpips.LPIPS(net="alex").to(device)

# Paths and Directories
Define file paths for the training and testing CT and MRI datasets, and create an output directory to save results

In [None]:
# Unpaired data layout: domain A = CT (trainA/testA), domain B = MRI (trainB/testB).
train_ct_dir, train_mri_dir = "/path/to/data/trainA", "/path/to/data/trainB"
test_ct_dir, test_mri_dir = "/path/to/data/testA", "/path/to/data/testB"

# Checkpoints and figures are written here; the folder is created up front
# (no error if it already exists).
out_dir = "./cyclegan_v3"
os.makedirs(out_dir, exist_ok=True)

# Hyperparameters
Define the hyperparameters (image size, optimizer settings, training length, and loss weights) that govern training and its stability

In [None]:
# Square input resolution and number of images per optimisation step.
img_size, batch_size = 256, 1

# Adam optimiser settings (learning rate and the two beta coefficients).
learningRate = 2e-4
beta1, beta2 = 0.5, 0.999

# Training length and the weights of the cycle-consistency and identity losses.
total_epochs = 1000
lambda_cycle, lambda_id = 10.0, 5.0

# Dataset
Load the training and test sets. Because the data are unpaired, each CT image is paired with a randomly chosen MRI image

In [None]:
class UnpairedDataset(Dataset):
    """Unpaired CT/MRI dataset yielding (CT image, randomly drawn MRI image).

    CT images are indexed deterministically (wrapping around when the MRI folder
    holds more files). The MRI partner is drawn with ``random.choice`` on every
    access, so pairs differ between epochs and between repeated lookups of the
    same index.

    Args:
        ct_dir: directory of CT images; every ``*.*`` file in it is used.
        mri_dir: directory of MRI images; every ``*.*`` file in it is used.
        tf: callable applied to each RGB PIL image (e.g. a torchvision transform).

    Raises:
        ValueError: if either directory yields no files, since no pair could
            ever be formed.
    """

    def __init__(self, ct_dir, mri_dir, tf):
        super().__init__()
        self.cts = sorted(glob(os.path.join(ct_dir, "*.*")))
        self.mris = sorted(glob(os.path.join(mri_dir, "*.*")))
        self.tf = tf
        # Fail early with a clear message instead of a ZeroDivisionError /
        # IndexError on the first __getitem__ call.
        for name, files, folder in (("CT", self.cts, ct_dir), ("MRI", self.mris, mri_dir)):
            if not files:
                raise ValueError(f"No {name} images found in {folder!r}")

    def __len__(self):
        # One epoch covers the larger of the two domains.
        return max(len(self.cts), len(self.mris))

    @staticmethod
    def _load_rgb(path):
        """Open an image file, convert it to RGB, and close the file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        img_ct = self._load_rgb(self.cts[idx % len(self.cts)])
        img_mri = self._load_rgb(random.choice(self.mris))
        return self.tf(img_ct), self.tf(img_mri)

# Resize to a fixed square, convert to a float tensor, then map each channel with
# (x - 0.5) / 0.5 so inputs share the [-1, 1] range of the generators' Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

train_ds = UnpairedDataset(train_ct_dir, train_mri_dir, transform)
test_ds = UnpairedDataset(test_ct_dir, test_mri_dir, transform)

# Training batches are shuffled; the test loader keeps file order, one image per batch.
# NOTE(review): the testing section below indexes test_ds directly, so test_loader
# appears unused in this notebook.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4)

## 2. Architecture

# ResNet Generator
Implement a ResNet-based generator; its residual connections help preserve image structure and stabilize training

In [None]:
class ResnetBlock(nn.Module):
    """Residual unit computing ``x + F(x)``.

    F is two stages of reflection pad (1) -> 3x3 conv (stride 1) -> InstanceNorm,
    with a ReLU after the first stage only. Channel count and spatial size are
    preserved, so the skip connection is a plain addition.
    """

    def __init__(self, dim):
        super().__init__()
        stages = []
        for with_relu in (True, False):
            stages += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, 3, 1, 0),
                nn.InstanceNorm2d(dim),
            ]
            if with_relu:
                stages.append(nn.ReLU(True))
        # Layer order (and thus state_dict keys block.1.* / block.5.*) is fixed.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class ResNetGenerator(nn.Module):
    """ResNet-based image-to-image generator.

    Layout: reflection-padded 7x7 conv stem -> two stride-2 3x3 convs (channels
    doubled each time) -> ``n_res`` ResnetBlocks -> two stride-2 transposed convs
    (channels halved each time) -> reflection-padded 7x7 conv to ``out_ch`` -> Tanh.
    Outputs therefore lie in [-1, 1]. The output has the input's spatial size
    when H and W are divisible by 4.

    Args:
        in_ch: input image channels.
        out_ch: output image channels.
        n_res: number of residual blocks at the bottleneck.
        ngf: channel width of the stem (and of the last decoder stage). The
            default 64 reproduces the original architecture exactly, so existing
            state_dicts still load.
    """

    def __init__(self, in_ch=3, out_ch=3, n_res=9, ngf=64):
        super().__init__()
        c = ngf
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_ch, c, 7, 1),
            nn.InstanceNorm2d(c),
            nn.ReLU(True),
        ]
        # Downsampling (encoder): halve H and W, double channels.
        for _ in range(2):
            layers += [nn.Conv2d(c, c * 2, 3, 2, 1), nn.InstanceNorm2d(c * 2), nn.ReLU(True)]
            c *= 2
        # Residual bottleneck.
        layers += [ResnetBlock(c) for _ in range(n_res)]
        # Upsampling (decoder): output_padding=1 makes each step exactly double H and W.
        for _ in range(2):
            layers += [
                nn.ConvTranspose2d(c, c // 2, 3, 2, 1, output_padding=1),
                nn.InstanceNorm2d(c // 2),
                nn.ReLU(True),
            ]
            c //= 2
        # Head: `c` is back to `ngf` here. Using it (instead of a literal 64)
        # keeps the head consistent with the stem for any width.
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(c, out_ch, 7, 1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# PatchGAN Discriminator
Implement a PatchGAN discriminator to evaluate the generated output. It scores the image in patches rather than as a whole, focusing on local high-frequency features; this lets it better detect realistic textures and details, which improves adversarial training

In [None]:
class PatchGANDiscriminator(nn.Module):
    """Fully convolutional discriminator that outputs a grid of real/fake scores.

    Each width in ``feats`` adds a 4x4 stride-2 conv followed by LeakyReLU(0.2).
    All stages except the first also apply a bias-free conv and InstanceNorm.
    A final 4x4 stride-1 conv maps to one channel. Each output element thus
    scores a local patch of the input rather than the whole image. With the
    default widths, a 256x256 input gives an (N, 1, 15, 15) output.

    Args:
        in_ch: number of input channels.
        feats: channel widths of the stride-2 stages (any non-empty sequence
            of ints). A tuple default avoids a shared mutable default argument.

    Raises:
        ValueError: if ``feats`` is empty.
    """

    def __init__(self, in_ch=3, feats=(64, 128, 256, 512)):
        super().__init__()
        feats = tuple(feats)
        if not feats:
            raise ValueError("feats must contain at least one channel width")
        layers = [nn.Conv2d(in_ch, feats[0], 4, 2, 1), nn.LeakyReLU(0.2, True)]
        for prev, f in zip(feats, feats[1:]):
            layers += [
                nn.Conv2d(prev, f, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(f),
                nn.LeakyReLU(0.2, True),
            ]
        layers.append(nn.Conv2d(feats[-1], 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

## 3. Model Initialization

# Generator

In [None]:
# G_AB translates CT -> MRI and G_BA translates MRI -> CT (same architecture).
G_AB, G_BA = ResNetGenerator().to(device), ResNetGenerator().to(device)

# Discriminator

In [None]:
# D_A judges CT images (domain A); D_B judges MRI images (domain B).
D_A, D_B = PatchGANDiscriminator().to(device), PatchGANDiscriminator().to(device)

# Weight
Initialize the weights of convolutional and instance normalization layers with a normal distribution to stabilize training and improve model convergence

In [None]:
def init_weights(net):
    """Re-initialise conv and affine instance-norm parameters of ``net`` in place.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02). InstanceNorm2d
    scales are drawn from N(1, 0.02) and shifts are set to 0. Conv biases keep
    PyTorch's default initialisation.

    InstanceNorm2d is non-affine by default (``affine=False``). Its ``weight`` and
    ``bias`` are then None, so such layers are skipped: calling
    ``nn.init.normal_`` on None would raise.

    Args:
        net: any nn.Module; every submodule is visited via ``net.modules()``.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.InstanceNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0)
# Apply the custom initialisation to both generators and both discriminators.
for net in (G_AB, G_BA, D_A, D_B):
    init_weights(net)

## 4. Training setup

# Losses

Adversarial loss (MSE): trains the generators to fool the discriminators  
Cycle-consistency loss (MAE): ensures input images can be accurately reconstructed from the generated output, maintaining cyclic integrity  
Identity loss (MAE): ensures a generator preserves color and content when the input already belongs to its target domain

In [None]:
# Least-squares (MSE) adversarial criterion; L1 (mean absolute error) for the
# cycle-consistency and identity terms. Each module is placed on `device`.
adv_loss, cycle_loss, identity_loss = (
    criterion.to(device) for criterion in (nn.MSELoss(), nn.L1Loss(), nn.L1Loss())
)

# Optimizer
Create Adam optimizers for the generators and discriminators, fake-image buffers for training stability, and a helper that builds real/fake target tensors

In [None]:
from collections import deque

# One Adam optimiser over both generators (they are updated jointly from a single
# loss) and one per discriminator. Betas come from the hyperparameter cell
# rather than being hard-coded here.
opt_G = optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()),
    lr=learningRate,
    betas=(beta1, beta2),
)
opt_D_A = optim.Adam(D_A.parameters(), lr=learningRate, betas=(beta1, beta2))
opt_D_B = optim.Adam(D_B.parameters(), lr=learningRate, betas=(beta1, beta2))

# History buffers of previously generated fakes for the discriminator updates.
# A bounded deque keeps at most `pool_size` images; appending beyond that drops
# the oldest entry. Memory therefore stays constant over a long run, whereas an
# unbounded list would retain every (on-device) fake image ever produced.
pool_size = 50
fake_A_buffer = deque(maxlen=pool_size)
fake_B_buffer = deque(maxlen=pool_size)


def target_tensor(pred, real):
    """Return a target shaped like ``pred``: all ones if ``real`` else all zeros.

    ``pred`` should be the discriminator output the target is compared with, so
    that the loss sees identical shapes.
    """
    return torch.ones_like(pred) if real else torch.zeros_like(pred)


# Per-epoch mean losses, appended by the training loop and plotted afterwards.
losses = {'adv': [], 'cycle': []}

## 5. Training

In [None]:
def query_image_pool(pool, images, max_size=50):
    """Return fakes for a discriminator update, mixing in previously generated ones.

    Image-history buffer. While ``pool`` holds fewer than ``max_size`` images,
    each new image is stored and returned unchanged. Once the pool is full, each
    new image is, with probability 0.5, swapped with a randomly chosen stored
    image (the stored one is returned); otherwise the new image is returned.
    ``pool`` may be a list or a deque and never grows past ``max_size`` entries
    through this function.

    Args:
        pool: mutable sequence of previously generated (1, C, H, W) tensors.
        images: batch of generated images, shape (N, C, H, W).
        max_size: pool capacity; 0 or less disables the history.

    Returns:
        A detached (N, C, H, W) tensor to feed to the discriminator.
    """
    images = images.detach()
    if max_size <= 0:
        return images
    picked = []
    for img in images.split(1, dim=0):  # each item keeps a leading dim of size 1
        if len(pool) < max_size:
            pool.append(img)
            picked.append(img)
        elif random.random() < 0.5:
            slot = random.randrange(len(pool))
            picked.append(pool[slot])
            pool[slot] = img
        else:
            picked.append(img)
    return torch.cat(picked, dim=0)


checkpoint_every = 100  # epochs between weight snapshots
n_batches = max(len(train_loader), 1)

for epoch in range(1, total_epochs + 1):
    sum_adv, sum_cycle = 0.0, 0.0
    for real_CT, real_MRI in train_loader:
        real_CT, real_MRI = real_CT.to(device), real_MRI.to(device)

        # ---------------- Generators (G_AB: CT->MRI, G_BA: MRI->CT) ----------------
        opt_G.zero_grad()
        # Identity: a generator given an image already in its output domain
        # should leave it unchanged.
        id_CT = G_BA(real_CT)
        id_MRI = G_AB(real_MRI)
        loss_id = (identity_loss(id_CT, real_CT) + identity_loss(id_MRI, real_MRI)) * lambda_id
        # Adversarial: targets must match the discriminator's patch-grid output
        # shape (not the image shape), so they are built from the prediction.
        fake_CT = G_BA(real_MRI)
        fake_MRI = G_AB(real_CT)
        pred_fake_CT = D_A(fake_CT)
        pred_fake_MRI = D_B(fake_MRI)
        loss_GAN_BA = adv_loss(pred_fake_CT, target_tensor(pred_fake_CT, True))
        loss_GAN_AB = adv_loss(pred_fake_MRI, target_tensor(pred_fake_MRI, True))
        # Cycle: translating to the other domain and back should reproduce the input.
        rec_CT = G_BA(fake_MRI)
        rec_MRI = G_AB(fake_CT)
        loss_cyc = (cycle_loss(rec_CT, real_CT) + cycle_loss(rec_MRI, real_MRI)) * lambda_cycle

        loss_G = loss_id + loss_GAN_BA + loss_GAN_AB + loss_cyc
        loss_G.backward()
        opt_G.step()

        sum_adv += (loss_GAN_BA + loss_GAN_AB).item()
        sum_cycle += loss_cyc.item()

        # ---------------- Discriminator A (CT): real -> 1, fake -> 0 ----------------
        # zero_grad also discards the gradients loss_G.backward() left in D_A.
        opt_D_A.zero_grad()
        pred_real_A = D_A(real_CT)
        pred_fake_A = D_A(query_image_pool(fake_A_buffer, fake_CT))
        loss_D_A = 0.5 * (
            adv_loss(pred_real_A, target_tensor(pred_real_A, True))
            + adv_loss(pred_fake_A, target_tensor(pred_fake_A, False))
        )
        loss_D_A.backward()
        opt_D_A.step()

        # ---------------- Discriminator B (MRI): real -> 1, fake -> 0 ----------------
        opt_D_B.zero_grad()
        pred_real_B = D_B(real_MRI)
        pred_fake_B = D_B(query_image_pool(fake_B_buffer, fake_MRI))
        loss_D_B = 0.5 * (
            adv_loss(pred_real_B, target_tensor(pred_real_B, True))
            + adv_loss(pred_fake_B, target_tensor(pred_fake_B, False))
        )
        loss_D_B.backward()
        opt_D_B.step()

    losses['adv'].append(sum_adv / n_batches)
    losses['cycle'].append(sum_cycle / n_batches)

    print(f"Epoch {epoch}/{total_epochs} | Adv: {losses['adv'][-1]:.4f} | Cycle: {losses['cycle'][-1]:.4f}")

    # Save weights from inside the loop so intermediate snapshots are written
    # every `checkpoint_every` epochs.
    if epoch % checkpoint_every == 0:
        for tag, model in (("G_BA", G_BA), ("G_AB", G_AB), ("D_A", D_A), ("D_B", D_B)):
            torch.save(model.state_dict(), os.path.join(out_dir, f"{tag}_ep{epoch}.pth"))

# Checkpoint
Save model weights every 100 epochs for reproducibility and recoverability

In [None]:
# This cell runs once, after the training-loop cell has finished, so `epoch`
# holds the last completed epoch. Weights are saved unconditionally so that the
# final model is written even when total_epochs is not a multiple of 100.
for tag, model in (("G_BA", G_BA), ("G_AB", G_AB), ("D_A", D_A), ("D_B", D_B)):
    torch.save(model.state_dict(), os.path.join(out_dir, f"{tag}_ep{epoch}.pth"))

# Visualization during Training
Generate and save visual results with SSIM and PSNR metrics to monitor training quality and model performance

In [None]:
def denorm(x):
    """Return the first image of a [-1, 1]-normalised (N, C, H, W) batch as an (H, W, C) array in [0, 1]."""
    return x[0].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5


# This cell runs once, after the training-loop cell has finished, so it shows a
# cycle-consistency check of the final model; `epoch` is the last completed epoch.
with torch.no_grad():
    real_CT, real_MRI = next(iter(train_loader))
    real_CT, real_MRI = real_CT.to(device), real_MRI.to(device)
    # MRI to CT to MRI
    fake_CT = G_BA(real_MRI)
    rec_MRI = G_AB(fake_CT)
    # CT to MRI to CT
    fake_MRI = G_AB(real_CT)
    rec_CT = G_BA(fake_MRI)

img_orig_MRI, img_rec_MRI = denorm(real_MRI), denorm(rec_MRI)
img_orig_CT, img_rec_CT = denorm(real_CT), denorm(rec_CT)

# channel_axis=-1: images are (H, W, C), replacing the deprecated `multichannel`.
# data_range=1.0: images were mapped back to [0, 1]; recent scikit-image
# releases reject float inputs for SSIM without an explicit data_range.
ssim_MRI = ssim(img_orig_MRI, img_rec_MRI, channel_axis=-1, data_range=1.0)
psnr_MRI = psnr(img_orig_MRI, img_rec_MRI, data_range=1.0)
ssim_CT = ssim(img_orig_CT, img_rec_CT, channel_axis=-1, data_range=1.0)
psnr_CT = psnr(img_orig_CT, img_rec_CT, data_range=1.0)

# Visualize Output
fig, axs = plt.subplots(2, 3, figsize=(12, 8))
axs[0, 0].imshow(img_orig_MRI); axs[0, 0].set_title("Original MRI")
axs[0, 1].imshow(denorm(fake_CT)); axs[0, 1].set_title("Generated CT")
axs[0, 2].imshow(img_rec_MRI); axs[0, 2].set_title(f"Reconstructed MRI\nSSIM {ssim_MRI:.3f}, PSNR {psnr_MRI:.1f}dB")

axs[1, 0].imshow(img_orig_CT); axs[1, 0].set_title("Original CT")
axs[1, 1].imshow(denorm(fake_MRI)); axs[1, 1].set_title("Generated MRI")
axs[1, 2].imshow(img_rec_CT); axs[1, 2].set_title(f"Reconstructed CT\nSSIM {ssim_CT:.3f}, PSNR {psnr_CT:.1f}dB")

fig.suptitle(f"Epoch {epoch} Cycle-Consistency Check")
fig.tight_layout()
fig.savefig(os.path.join(out_dir, f"train_cycle_vis_ep{epoch}.png"))
plt.close(fig)

# Convergence
Plot and save the training loss curves to visualize CycleGAN convergence over epochs

In [None]:
# Plot against the number of epochs actually recorded rather than total_epochs.
# The x and y lengths then always match, even if training was interrupted or the
# loop cell was re-run and appended more entries to `losses`.
epochs_done = range(1, len(losses['adv']) + 1)
fig, ax = plt.subplots()
ax.plot(epochs_done, losses['adv'], label='Adversarial Loss')
ax.plot(epochs_done, losses['cycle'], label='Cycle Loss')
ax.set_xlabel('Epoch'); ax.set_ylabel('Loss'); ax.legend()
ax.set_title('Training Convergence')
fig.savefig(os.path.join(out_dir, "convergence.png"))
plt.close(fig)

## 6. Testing
Generate qualitative and quantitative results on 10 random test samples for both "MRI to CT" and "CT to MRI" translation. The cell visualizes inputs, predictions, and reconstructions, and calculates the average PSNR, SSIM, and LPIPS to assess image quality

In [None]:
def to_img(x):
    """Return the first image of a [-1, 1]-normalised (N, C, H, W) tensor as an (H, W, C) array in [0, 1]."""
    return x[0].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5


def cycle_metrics(real, rec):
    """Return (PSNR in dB, SSIM, LPIPS) between ``real`` and its cycle reconstruction ``rec``.

    PSNR/SSIM use the [0, 1] numpy images (channel_axis=-1, data_range=1.0).
    LPIPS is given the normalised tensors themselves.
    """
    orig, recon = to_img(real), to_img(rec)
    return (
        psnr(orig, recon, data_range=1.0),
        ssim(orig, recon, channel_axis=-1, data_range=1.0),
        lpips_fn(real, rec).item(),
    )


# Evaluate up to 10 random test items; random.sample raises if asked for more
# items than the test set contains.
n_samples = min(10, len(test_ds))
if n_samples == 0:
    raise ValueError("test_ds is empty; nothing to evaluate")
test_indices = random.sample(range(len(test_ds)), n_samples)
# squeeze=False keeps the axes array 2-D even when n_samples == 1.
fig1, axs1 = plt.subplots(n_samples, 3, figsize=(9, 3 * n_samples), squeeze=False)
fig2, axs2 = plt.subplots(n_samples, 3, figsize=(9, 3 * n_samples), squeeze=False)
metrics_BA, metrics_AB = [], []

# Inference only: no autograd graphs are built. Without no_grad, generator
# outputs require grad and converting them to numpy raises a RuntimeError.
with torch.no_grad():
    for i, idx in enumerate(test_indices):
        real_CT, real_MRI = test_ds[idx]
        real_CT = real_CT.unsqueeze(0).to(device)
        real_MRI = real_MRI.unsqueeze(0).to(device)

        # MRI -> CT -> MRI cycle
        fake_CT = G_BA(real_MRI)
        rec_MRI = G_AB(fake_CT)
        p, s, lp = cycle_metrics(real_MRI, rec_MRI)
        metrics_BA.append((p, s, lp))
        axs1[i, 0].imshow(to_img(real_MRI)); axs1[i, 0].set_title("Original MRI")
        axs1[i, 1].imshow(to_img(fake_CT)); axs1[i, 1].set_title("Generated CT")
        axs1[i, 2].imshow(to_img(rec_MRI)); axs1[i, 2].set_title(f"Reconstructed MRI\nPSNR{p:.1f},SSIM{s:.3f}")

        # CT -> MRI -> CT cycle
        fake_MRI = G_AB(real_CT)
        rec_CT = G_BA(fake_MRI)
        p2, s2, lp2 = cycle_metrics(real_CT, rec_CT)
        metrics_AB.append((p2, s2, lp2))
        axs2[i, 0].imshow(to_img(real_CT)); axs2[i, 0].set_title("Original CT")
        axs2[i, 1].imshow(to_img(fake_MRI)); axs2[i, 1].set_title("Generated MRI")
        axs2[i, 2].imshow(to_img(rec_CT)); axs2[i, 2].set_title(f"Reconstructed CT\nPSNR{p2:.1f},SSIM{s2:.3f}")

# Save figures
fig1.tight_layout(); fig1.savefig(os.path.join(out_dir, "test_MRI2CT.png"))
fig2.tight_layout(); fig2.savefig(os.path.join(out_dir, "test_CT2MRI.png"))
plt.close('all')

# Print average cycle-consistency metrics (columns: PSNR, SSIM, LPIPS)
print("Test MRI→CT avg PSNR,SSIM,LPIPS:", np.mean(metrics_BA, 0))
print("Test CT→MRI avg PSNR,SSIM,LPIPS:", np.mean(metrics_AB, 0))