Synthetic ICH CT Generation - Generation 1

In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torchvision.transforms.functional import gaussian_blur
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.exposure import match_histograms

# Enable OpenCV multithreading for faster CPU performance.
# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to -1, which asks OpenCV to reset to its default thread count instead of
# failing on a None argument.
cv2.setUseOptimized(True)
cv2.setNumThreads(os.cpu_count() or -1)

# ---- Super-resolution setup (4× → 2× chain for 8×) ----
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

device = torch.device("cpu")  # force CPU

# 4× model for final super-res (RRDBNet with 3 input / 3 output channels)
rrdb4 = RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
sr4 = RealESRGANer(
    scale=4,
    model_path='RealESRGAN_x4plus.pth',
    model=rrdb4,
    tile=0, tile_pad=10, pre_pad=0,
    half=False, device=device
)
# 2× model, applied after the 4× model (4× · 2× = 8× overall)
rrdb2 = RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
sr2 = RealESRGANer(
    scale=2,
    model_path='RealESRGAN_x2plus.pth',
    model=rrdb2,
    tile=0, tile_pad=10, pre_pad=0,
    half=False, device=device
)

# ---- GAN setup ----
batch_size = 16
image_size = 64                      # training on 64×64 slices
nc, nz, ngf, ndf = 1, 100, 64, 64    # image channels, latent size, G / D feature widths
num_epochs, lr, beta1 = 100, 0.0002, 0.5

# Output root uses the same "Dataset (Synthetic)/..." layout as the other
# generation cells (this cell previously wrote to "Synthetic Dataset/...").
out_dir = "Dataset (Synthetic)/Synthetic ICH CT/Generation 1"
os.makedirs(out_dir, exist_ok=True)

# Resize + centre-crop to image_size, force nc channels, and scale ToTensor's
# [0, 1] range to [-1, 1] so real images match the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.Grayscale(num_output_channels=nc),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = datasets.ImageFolder(root="Dataset (East Cyprus Hospital)/ICH Brain CT", transform=transform)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=0)

# ---- Load a real CT slice for histogram matching ----
ref_ct_path = 'Reference CT.png'
ref_ct = cv2.imread(ref_ct_path, cv2.IMREAD_GRAYSCALE)
if ref_ct is None:
    raise FileNotFoundError(f"Reference CT not found at {ref_ct_path}")

# ----------------------------------
# Post-processing helpers
# ----------------------------------

def denoise_nl_means(tensor_img, h=3, templateWindowSize=7, searchWindowSize=21):
    """Denoise one grayscale image tensor with OpenCV non-local means.

    The tensor is mapped from [-1, 1] to 8-bit, passed through
    ``cv2.fastNlMeansDenoising`` and mapped back to [-1, 1].

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W)
            (assumed to be a single-channel image such as shape (1, H, W)).
        h: filter strength forwarded to ``cv2.fastNlMeansDenoising``.
        templateWindowSize: patch size forwarded to OpenCV.
        searchWindowSize: search-window size forwarded to OpenCV.

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    np_img = tensor_img.squeeze().cpu().numpy()
    # Round and clip before casting. A bare astype(np.uint8) truncates, so
    # float32 round-trip error (e.g. 2.9999998) silently drops a gray level,
    # and any value outside [-1, 1] would wrap around.
    np_img8 = np.clip(np.rint((np_img + 1.0) * 127.5), 0, 255).astype(np.uint8)
    denoised8 = cv2.fastNlMeansDenoising(np_img8,
                                         None,
                                         h=h,
                                         templateWindowSize=templateWindowSize,
                                         searchWindowSize=searchWindowSize)
    denoised_f = denoised8.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(denoised_f).unsqueeze(0).to(tensor_img.device)

def guided_bilateral_filter(np_img,
                            d=5, sigmaColor=50, sigmaSpace=50,
                            radius=5, eps=100.0, use_guided=False):
    """Edge-preserving smoothing: bilateral filter, optionally refined by a guided filter.

    Args:
        np_img: 8-bit single-channel image (H, W) (assumed uint8, as the
            callers in this notebook pass).
        d, sigmaColor, sigmaSpace: forwarded to ``cv2.bilateralFilter``.
        radius: guided-filter window radius (square window of side
            2*radius+1). Only used when ``use_guided`` is True.
        eps: guided-filter regulariser in squared 8-bit intensity units;
            larger values smooth more. Only used when ``use_guided`` is True.
        use_guided: when True, refine the bilateral result with a guided
            filter that uses ``np_img`` as the guide image. Defaults to False,
            in which case the plain bilateral result is returned.

    Returns:
        Filtered image with the same shape as ``np_img`` (uint8 when the
        guided step runs).
    """
    bf = cv2.bilateralFilter(np_img, d=d,
                             sigmaColor=sigmaColor,
                             sigmaSpace=sigmaSpace)
    if not use_guided:
        return bf

    # Grey-scale guided filter (He et al.): q = mean(a) * I + mean(b), with
    # a = cov(I, p) / (var(I) + eps) and b = mean(p) - a * mean(I).
    guide = np_img.astype(np.float32)
    src = bf.astype(np.float32)
    ksize = (2 * radius + 1, 2 * radius + 1)

    def box(x):
        return cv2.boxFilter(x, cv2.CV_32F, ksize)

    mean_i = box(guide)
    mean_p = box(src)
    var_i = box(guide * guide) - mean_i * mean_i
    cov_ip = box(guide * src) - mean_i * mean_p
    a = cov_ip / (var_i + eps)
    b = mean_p - a * mean_i
    q = box(a) * guide + box(b)
    return np.clip(np.rint(q), 0, 255).astype(np.uint8)

def sharp_kernel_filter(tensor_img):
    """Sharpen one grayscale image tensor with a strong 3×3 high-boost kernel.

    The kernel (centre 9, eight neighbours -1) sums to 1, so flat regions keep
    their intensity while local differences are amplified.

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W).

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    kernel = np.array([[-1, -1, -1],
                       [-1,  9, -1],
                       [-1, -1, -1]], dtype=np.float32)
    # Round and clip rather than a bare truncating astype(np.uint8): truncation
    # turns float32 round-trip error (e.g. 2.9999998) into a lost gray level and
    # wraps out-of-range values.
    np_img = np.clip(np.rint(tensor_img.squeeze().cpu().numpy() * 127.5 + 127.5),
                     0, 255).astype(np.uint8)
    # ddepth=-1 keeps the uint8 depth; the result is already within [0, 255].
    sharp = cv2.filter2D(np_img, -1, kernel)
    sharp_f = sharp.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(sharp_f).unsqueeze(0).to(tensor_img.device)

def clahe_enhance(tensor_img, clipLimit=2.0, tileGridSize=(4,4)):
    """Apply CLAHE (contrast-limited adaptive histogram equalisation) to one image tensor.

    Smaller tiles give stronger local contrast.

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W).
        clipLimit: contrast limit forwarded to ``cv2.createCLAHE``.
        tileGridSize: (cols, rows) tile grid forwarded to ``cv2.createCLAHE``.

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    # Round and clip before the uint8 cast; truncation would drop gray levels
    # on float32 round-trip error and wrap out-of-range values.
    np_img = np.clip(np.rint(tensor_img.squeeze().cpu().numpy() * 127.5 + 127.5),
                     0, 255).astype(np.uint8)
    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    cl = clahe.apply(np_img)
    cl_f = cl.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(cl_f).unsqueeze(0).to(tensor_img.device)

def match_histogram_tensor(tensor_img, ref_uint8):
    """Match the intensity histogram of an image tensor to an 8-bit reference.

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W).
        ref_uint8: 8-bit grayscale reference image (e.g. from ``cv2.imread``).

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    np_img = np.clip(np.rint(tensor_img.squeeze().cpu().numpy() * 127.5 + 127.5),
                     0, 255).astype(np.uint8)
    matched = match_histograms(np_img, ref_uint8, channel_axis=None)
    # match_histograms may return floating-point values; round to the nearest
    # gray level instead of truncating.
    matched = np.clip(np.rint(matched), 0, 255).astype(np.uint8)
    matched_f = matched.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(matched_f).unsqueeze(0).to(tensor_img.device)

# ----------------------------------
# DCGAN Generator & Discriminator
# ----------------------------------
class Generator(nn.Module):
    """DCGAN generator mapping latent vectors (N, nz, 1, 1) to images (N, nc, 64, 64) in [-1, 1]."""

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for each 4×4 transposed
        # conv + BatchNorm + ReLU stage; spatial size goes 1 → 4 → 8 → 16 → 32.
        stages = [
            (nz,      ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf,     2, 1),
        ]
        layers = []
        for c_in, c_out, stride, pad in stages:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Final 32 → 64 projection to image channels, squashed into [-1, 1].
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from latent tensor ``x``."""
        return self.model(x)

class Discriminator(nn.Module):
    """DCGAN discriminator: images (N, nc, 64, 64) -> probability of "real", shape (N,).

    The input layer deliberately has no BatchNorm. The DCGAN guidelines
    (Radford et al.; PyTorch DCGAN tutorial) report that batch-normalising the
    discriminator input causes sample oscillation and instability.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 64 → 32; no BatchNorm on the input layer (DCGAN guideline).
            nn.Conv2d(nc,   ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
            # 32 → 16
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2), nn.LeakyReLU(0.2, True),
            # 16 → 8
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4), nn.LeakyReLU(0.2, True),
            # 8 → 4
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8), nn.LeakyReLU(0.2, True),
            # 4 → 1, single score squashed to a probability for BCELoss
            nn.Conv2d(ndf*8, 1,     4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return per-sample probabilities flattened to a 1-D tensor."""
        return self.model(x).view(-1)

def _dcgan_weights_init(m):
    """Initialise one module in place following the DCGAN recipe.

    Conv / transposed-conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02)
    and shift = 0. Other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)


netG = Generator().to(device)
netD = Discriminator().to(device)
# PyTorch's default layer initialisation differs from the N(0, 0.02) scheme the
# DCGAN architecture was designed and tuned for; apply it explicitly.
netG.apply(_dcgan_weights_init)
netD.apply(_dcgan_weights_init)

criterion  = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# fixed noise for sampling: the same latents every epoch, so saved grids are comparable
num_samples = 400
fixed_noise = torch.randn(num_samples, nz, 1, 1, device=device)

# ----------------------------------
# Training Loop
# ----------------------------------
# Sharpening kernel for the final 512×512 outputs (sums to 1: high-boost).
_HEAVY_KERNEL = np.array([[-1, -1, -1],
                          [-1,  9, -1],
                          [-1, -1, -1]], dtype=np.float32)


def _to_uint8(arr):
    """Map a float array in [-1, 1] to uint8 [0, 255], rounding and clipping.

    A bare ``astype(np.uint8)`` truncates, so float32 round-trip error such as
    2.9999998 loses a gray level, and out-of-range values wrap around.
    """
    return np.clip(np.rint(arr * 127.5 + 127.5), 0, 255).astype(np.uint8)


def _postprocess_samples(raw):
    """Run the post-processing chain on a CPU batch of generator outputs.

    Each sample is bilateral-filtered. The residual ``raw - filtered`` is kept
    as a noise map for later reinjection. The filtered image then goes through
    sharpen → NL-means denoise → CLAHE → histogram match against ``ref_ct``.

    Args:
        raw: tensor (N, nc, H, W) in [-1, 1] on the CPU (nc == 1 here).

    Returns:
        (fake64, noise_map): processed samples in [-1, 1] and the residual,
        both shaped like ``raw``.
    """
    filtered = torch.zeros_like(raw)
    for idx, sample in enumerate(raw):
        smoothed = guided_bilateral_filter(_to_uint8(sample.squeeze().numpy()))
        filtered[idx] = torch.from_numpy(smoothed.astype(np.float32) / 127.5 - 1.0).unsqueeze(0)

    noise_map = raw - filtered   # retains high-freq detail

    fake64 = torch.zeros_like(filtered)
    for idx, img in enumerate(filtered):
        img = sharp_kernel_filter(img)
        img = denoise_nl_means(img, h=3)
        img = clahe_enhance(img, clipLimit=2.0, tileGridSize=(4, 4))
        fake64[idx] = match_histogram_tensor(img, ref_ct)
    return fake64, noise_map


def _show_and_save_grid(fake64, epoch):
    """Display the 20-per-row sample grid and save it as epoch_<n>.png in out_dir."""
    grid = utils.make_grid(fake64, nrow=20, padding=2, normalize=True)
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
    plt.axis('off')
    plt.title(f"Epoch {epoch+1}")
    plt.show()
    plt.close(fig)  # free the figure; otherwise one accumulates per epoch outside inline backends
    utils.save_image(fake64,
                     os.path.join(out_dir, f"epoch_{epoch+1}.png"),
                     normalize=True)


def _super_resolve_and_save(fake64, noise_map, epoch, alpha=0.9):
    """8× super-resolve each sample, reinject its noise map, sharpen and save as PNG.

    Args:
        fake64: processed samples (N, 1, H, W) in [-1, 1].
        noise_map: residuals (N, 1, H, W) from ``_postprocess_samples``.
        epoch: zero-based epoch index, used in file names.
        alpha: weight of the reinjected noise map.
    """
    print("Applying Super-Resolution + reinjection + final sharpen…")
    for idx in range(fake64.size(0)):
        # Real-ESRGAN is fed a 3-channel 8-bit image: replicate the gray plane.
        gray64 = _to_uint8(fake64[idx].squeeze().numpy())
        arr64 = np.repeat(gray64[..., None], 3, axis=2)

        out4, _ = sr4.enhance(arr64)
        out8, _ = sr2.enhance(out4)

        # Upsample the float residual directly. The former uint8 round trip
        # biased zero noise by -0.5/127.5 and wrapped values outside [-1, 1].
        nm = noise_map[idx].squeeze().numpy().astype(np.float32)
        nm_hr = cv2.resize(nm, (out8.shape[1], out8.shape[0]),
                           interpolation=cv2.INTER_CUBIC)

        sr8 = out8.astype(np.float32) / 255.0
        combined = np.clip(sr8 + alpha * nm_hr[..., None], 0, 1)

        final = np.rint(combined[..., 0] * 255).astype(np.uint8)
        sharp_final = cv2.filter2D(final, -1, _HEAVY_KERNEL)
        Image.fromarray(sharp_final).save(
            os.path.join(out_dir, f"sr8_epoch{epoch+1}_img{idx}.png"))

    print("Super-resolution + reinjection + final sharpen complete!")


print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real = imgs.to(device)
        b_size = real.size(0)
        label_real = torch.ones(b_size, device=device)
        label_fake = torch.zeros(b_size, device=device)

        # Discriminator real loss
        netD.zero_grad()
        errD_real = criterion(netD(real), label_real)
        errD_real.backward()

        # Discriminator fake loss (detach so G receives no gradient here)
        noise    = torch.randn(b_size, nz, 1, 1, device=device)
        fake_img = netG(noise)
        errD_fake = criterion(netD(fake_img.detach()), label_fake)
        errD_fake.backward()
        optimizerD.step()

        # Generator loss: push D to label the fakes as real
        netG.zero_grad()
        errG = criterion(netD(fake_img), label_real)
        errG.backward()
        optimizerG.step()

        if i % 50 == 0:
            print(f"[{epoch+1}/{num_epochs}][{i}/{len(dataloader)}] "
                  f"Loss_D: {(errD_real+errD_fake).item():.4f} "
                  f"Loss_G: {errG.item():.4f}")

    # ---- Sampling & saving after each epoch ----
    with torch.no_grad():
        raw = netG(fixed_noise).cpu()
        fake64, noise_map = _postprocess_samples(raw)
        _show_and_save_grid(fake64, epoch)

        # Final super-res + reinjection + 512×512 sharpen, last epoch only
        if (epoch + 1) == num_epochs:
            _super_resolve_and_save(fake64, noise_map, epoch)

print("Training complete!")


Synthetic ICH CT Generation - Generation 2

In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torchvision.transforms.functional import gaussian_blur
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.exposure import match_histograms

# Enable OpenCV multithreading for faster CPU performance.
# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to -1, which asks OpenCV to reset to its default thread count instead of
# failing on a None argument.
cv2.setUseOptimized(True)
cv2.setNumThreads(os.cpu_count() or -1)

# ---- Super-resolution setup (4× → 2× chain for 8×) ----
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

device = torch.device("cpu")  # force CPU

# 4× model for final super-res (RRDBNet with 3 input / 3 output channels)
rrdb4 = RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
sr4 = RealESRGANer(
    scale=4,
    model_path='RealESRGAN_x4plus.pth',
    model=rrdb4,
    tile=0, tile_pad=10, pre_pad=0,
    half=False, device=device
)
# 2× model, applied after the 4× model (4× · 2× = 8× overall)
rrdb2 = RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
sr2 = RealESRGANer(
    scale=2,
    model_path='RealESRGAN_x2plus.pth',
    model=rrdb2,
    tile=0, tile_pad=10, pre_pad=0,
    half=False, device=device
)

# ---- GAN setup ----
batch_size = 16
image_size = 64                      # training on 64×64 slices
nc, nz, ngf, ndf = 1, 100, 64, 64    # image channels, latent size, G / D feature widths
num_epochs, lr, beta1 = 100, 0.0002, 0.5

out_dir = "Dataset (Synthetic)/Synthetic ICH CT/Generation 2"
os.makedirs(out_dir, exist_ok=True)

# Resize + centre-crop to image_size, force nc channels, and scale ToTensor's
# [0, 1] range to [-1, 1] so real images match the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.Grayscale(num_output_channels=nc),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = datasets.ImageFolder(root="Dataset (East Cyprus Hospital)/ICH Brain CT", transform=transform)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=0)

# ---- Load a real CT slice for histogram matching ----
ref_ct_path = 'Reference CT - 4.jpg'
ref_ct = cv2.imread(ref_ct_path, cv2.IMREAD_GRAYSCALE)
if ref_ct is None:
    raise FileNotFoundError(f"Reference CT not found at {ref_ct_path}")

# ----------------------------------
# Post-processing helpers
# ----------------------------------

def denoise_nl_means(tensor_img, h=3, templateWindowSize=7, searchWindowSize=21):
    """Denoise one grayscale image tensor with OpenCV non-local means.

    The tensor is mapped from [-1, 1] to 8-bit, passed through
    ``cv2.fastNlMeansDenoising`` and mapped back to [-1, 1].

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W)
            (assumed to be a single-channel image such as shape (1, H, W)).
        h: filter strength forwarded to ``cv2.fastNlMeansDenoising``.
        templateWindowSize: patch size forwarded to OpenCV.
        searchWindowSize: search-window size forwarded to OpenCV.

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    np_img = tensor_img.squeeze().cpu().numpy()
    # Round and clip before casting. A bare astype(np.uint8) truncates, so
    # float32 round-trip error (e.g. 2.9999998) silently drops a gray level,
    # and any value outside [-1, 1] would wrap around.
    np_img8 = np.clip(np.rint((np_img + 1.0) * 127.5), 0, 255).astype(np.uint8)
    denoised8 = cv2.fastNlMeansDenoising(np_img8,
                                         None,
                                         h=h,
                                         templateWindowSize=templateWindowSize,
                                         searchWindowSize=searchWindowSize)
    denoised_f = denoised8.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(denoised_f).unsqueeze(0).to(tensor_img.device)

def guided_bilateral_filter(np_img,
                            d=5, sigmaColor=50, sigmaSpace=50,
                            radius=5, eps=100.0, use_guided=False):
    """Edge-preserving smoothing: bilateral filter, optionally refined by a guided filter.

    Args:
        np_img: 8-bit single-channel image (H, W) (assumed uint8, as the
            callers in this notebook pass).
        d, sigmaColor, sigmaSpace: forwarded to ``cv2.bilateralFilter``.
        radius: guided-filter window radius (square window of side
            2*radius+1). Only used when ``use_guided`` is True.
        eps: guided-filter regulariser in squared 8-bit intensity units;
            larger values smooth more. Only used when ``use_guided`` is True.
        use_guided: when True, refine the bilateral result with a guided
            filter that uses ``np_img`` as the guide image. Defaults to False,
            in which case the plain bilateral result is returned.

    Returns:
        Filtered image with the same shape as ``np_img`` (uint8 when the
        guided step runs).
    """
    bf = cv2.bilateralFilter(np_img, d=d,
                             sigmaColor=sigmaColor,
                             sigmaSpace=sigmaSpace)
    if not use_guided:
        return bf

    # Grey-scale guided filter (He et al.): q = mean(a) * I + mean(b), with
    # a = cov(I, p) / (var(I) + eps) and b = mean(p) - a * mean(I).
    guide = np_img.astype(np.float32)
    src = bf.astype(np.float32)
    ksize = (2 * radius + 1, 2 * radius + 1)

    def box(x):
        return cv2.boxFilter(x, cv2.CV_32F, ksize)

    mean_i = box(guide)
    mean_p = box(src)
    var_i = box(guide * guide) - mean_i * mean_i
    cov_ip = box(guide * src) - mean_i * mean_p
    a = cov_ip / (var_i + eps)
    b = mean_p - a * mean_i
    q = box(a) * guide + box(b)
    return np.clip(np.rint(q), 0, 255).astype(np.uint8)

def sharp_kernel_filter(tensor_img):
    """Sharpen one grayscale image tensor with a strong 3×3 high-boost kernel.

    The kernel (centre 9, eight neighbours -1) sums to 1, so flat regions keep
    their intensity while local differences are amplified.

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W).

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    kernel = np.array([[-1, -1, -1],
                       [-1,  9, -1],
                       [-1, -1, -1]], dtype=np.float32)
    # Round and clip rather than a bare truncating astype(np.uint8): truncation
    # turns float32 round-trip error (e.g. 2.9999998) into a lost gray level and
    # wraps out-of-range values.
    np_img = np.clip(np.rint(tensor_img.squeeze().cpu().numpy() * 127.5 + 127.5),
                     0, 255).astype(np.uint8)
    # ddepth=-1 keeps the uint8 depth; the result is already within [0, 255].
    sharp = cv2.filter2D(np_img, -1, kernel)
    sharp_f = sharp.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(sharp_f).unsqueeze(0).to(tensor_img.device)

def clahe_enhance(tensor_img, clipLimit=2.0, tileGridSize=(4,4)):
    """Apply CLAHE (contrast-limited adaptive histogram equalisation) to one image tensor.

    Smaller tiles give stronger local contrast.

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W).
        clipLimit: contrast limit forwarded to ``cv2.createCLAHE``.
        tileGridSize: (cols, rows) tile grid forwarded to ``cv2.createCLAHE``.

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    # Round and clip before the uint8 cast; truncation would drop gray levels
    # on float32 round-trip error and wrap out-of-range values.
    np_img = np.clip(np.rint(tensor_img.squeeze().cpu().numpy() * 127.5 + 127.5),
                     0, 255).astype(np.uint8)
    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    cl = clahe.apply(np_img)
    cl_f = cl.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(cl_f).unsqueeze(0).to(tensor_img.device)

def match_histogram_tensor(tensor_img, ref_uint8):
    """Match the intensity histogram of an image tensor to an 8-bit reference.

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W).
        ref_uint8: 8-bit grayscale reference image (e.g. from ``cv2.imread``).

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    np_img = np.clip(np.rint(tensor_img.squeeze().cpu().numpy() * 127.5 + 127.5),
                     0, 255).astype(np.uint8)
    matched = match_histograms(np_img, ref_uint8, channel_axis=None)
    # match_histograms may return floating-point values; round to the nearest
    # gray level instead of truncating.
    matched = np.clip(np.rint(matched), 0, 255).astype(np.uint8)
    matched_f = matched.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(matched_f).unsqueeze(0).to(tensor_img.device)

# ----------------------------------
# DCGAN Generator & Discriminator
# ----------------------------------
class Generator(nn.Module):
    """DCGAN generator mapping latent vectors (N, nz, 1, 1) to images (N, nc, 64, 64) in [-1, 1]."""

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for each 4×4 transposed
        # conv + BatchNorm + ReLU stage; spatial size goes 1 → 4 → 8 → 16 → 32.
        stages = [
            (nz,      ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf,     2, 1),
        ]
        layers = []
        for c_in, c_out, stride, pad in stages:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Final 32 → 64 projection to image channels, squashed into [-1, 1].
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from latent tensor ``x``."""
        return self.model(x)

class Discriminator(nn.Module):
    """DCGAN discriminator: images (N, nc, 64, 64) -> probability of "real", shape (N,).

    The input layer deliberately has no BatchNorm. The DCGAN guidelines
    (Radford et al.; PyTorch DCGAN tutorial) report that batch-normalising the
    discriminator input causes sample oscillation and instability.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 64 → 32; no BatchNorm on the input layer (DCGAN guideline).
            nn.Conv2d(nc,   ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
            # 32 → 16
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2), nn.LeakyReLU(0.2, True),
            # 16 → 8
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4), nn.LeakyReLU(0.2, True),
            # 8 → 4
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8), nn.LeakyReLU(0.2, True),
            # 4 → 1, single score squashed to a probability for BCELoss
            nn.Conv2d(ndf*8, 1,     4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return per-sample probabilities flattened to a 1-D tensor."""
        return self.model(x).view(-1)

def _dcgan_weights_init(m):
    """Initialise one module in place following the DCGAN recipe.

    Conv / transposed-conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02)
    and shift = 0. Other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)


netG = Generator().to(device)
netD = Discriminator().to(device)
# PyTorch's default layer initialisation differs from the N(0, 0.02) scheme the
# DCGAN architecture was designed and tuned for; apply it explicitly.
netG.apply(_dcgan_weights_init)
netD.apply(_dcgan_weights_init)

criterion  = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# fixed noise for sampling: the same latents every epoch, so saved grids are comparable
num_samples = 400
fixed_noise = torch.randn(num_samples, nz, 1, 1, device=device)

# ----------------------------------
# Training Loop
# ----------------------------------
# Sharpening kernel for the final 512×512 outputs (sums to 1: high-boost).
_HEAVY_KERNEL = np.array([[-1, -1, -1],
                          [-1,  9, -1],
                          [-1, -1, -1]], dtype=np.float32)


def _to_uint8(arr):
    """Map a float array in [-1, 1] to uint8 [0, 255], rounding and clipping.

    A bare ``astype(np.uint8)`` truncates, so float32 round-trip error such as
    2.9999998 loses a gray level, and out-of-range values wrap around.
    """
    return np.clip(np.rint(arr * 127.5 + 127.5), 0, 255).astype(np.uint8)


def _postprocess_samples(raw):
    """Run the post-processing chain on a CPU batch of generator outputs.

    Each sample is bilateral-filtered. The residual ``raw - filtered`` is kept
    as a noise map for later reinjection. The filtered image then goes through
    sharpen → NL-means denoise → CLAHE → histogram match against ``ref_ct``.

    Args:
        raw: tensor (N, nc, H, W) in [-1, 1] on the CPU (nc == 1 here).

    Returns:
        (fake64, noise_map): processed samples in [-1, 1] and the residual,
        both shaped like ``raw``.
    """
    filtered = torch.zeros_like(raw)
    for idx, sample in enumerate(raw):
        smoothed = guided_bilateral_filter(_to_uint8(sample.squeeze().numpy()))
        filtered[idx] = torch.from_numpy(smoothed.astype(np.float32) / 127.5 - 1.0).unsqueeze(0)

    noise_map = raw - filtered   # retains high-freq detail

    fake64 = torch.zeros_like(filtered)
    for idx, img in enumerate(filtered):
        img = sharp_kernel_filter(img)
        img = denoise_nl_means(img, h=3)
        img = clahe_enhance(img, clipLimit=2.0, tileGridSize=(4, 4))
        fake64[idx] = match_histogram_tensor(img, ref_ct)
    return fake64, noise_map


def _show_and_save_grid(fake64, epoch):
    """Display the 20-per-row sample grid and save it as epoch_<n>.png in out_dir."""
    grid = utils.make_grid(fake64, nrow=20, padding=2, normalize=True)
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
    plt.axis('off')
    plt.title(f"Epoch {epoch+1}")
    plt.show()
    plt.close(fig)  # free the figure; otherwise one accumulates per epoch outside inline backends
    utils.save_image(fake64,
                     os.path.join(out_dir, f"epoch_{epoch+1}.png"),
                     normalize=True)


def _super_resolve_and_save(fake64, noise_map, epoch, alpha=0.9):
    """8× super-resolve each sample, reinject its noise map, sharpen and save as PNG.

    Args:
        fake64: processed samples (N, 1, H, W) in [-1, 1].
        noise_map: residuals (N, 1, H, W) from ``_postprocess_samples``.
        epoch: zero-based epoch index, used in file names.
        alpha: weight of the reinjected noise map.
    """
    print("Applying Super-Resolution + reinjection + final sharpen…")
    for idx in range(fake64.size(0)):
        # Real-ESRGAN is fed a 3-channel 8-bit image: replicate the gray plane.
        gray64 = _to_uint8(fake64[idx].squeeze().numpy())
        arr64 = np.repeat(gray64[..., None], 3, axis=2)

        out4, _ = sr4.enhance(arr64)
        out8, _ = sr2.enhance(out4)

        # Upsample the float residual directly. The former uint8 round trip
        # biased zero noise by -0.5/127.5 and wrapped values outside [-1, 1].
        nm = noise_map[idx].squeeze().numpy().astype(np.float32)
        nm_hr = cv2.resize(nm, (out8.shape[1], out8.shape[0]),
                           interpolation=cv2.INTER_CUBIC)

        sr8 = out8.astype(np.float32) / 255.0
        combined = np.clip(sr8 + alpha * nm_hr[..., None], 0, 1)

        final = np.rint(combined[..., 0] * 255).astype(np.uint8)
        sharp_final = cv2.filter2D(final, -1, _HEAVY_KERNEL)
        Image.fromarray(sharp_final).save(
            os.path.join(out_dir, f"sr8_epoch{epoch+1}_img{idx}.png"))

    print("Super-resolution + reinjection + final sharpen complete!")


print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real = imgs.to(device)
        b_size = real.size(0)
        label_real = torch.ones(b_size, device=device)
        label_fake = torch.zeros(b_size, device=device)

        # Discriminator real loss
        netD.zero_grad()
        errD_real = criterion(netD(real), label_real)
        errD_real.backward()

        # Discriminator fake loss (detach so G receives no gradient here)
        noise    = torch.randn(b_size, nz, 1, 1, device=device)
        fake_img = netG(noise)
        errD_fake = criterion(netD(fake_img.detach()), label_fake)
        errD_fake.backward()
        optimizerD.step()

        # Generator loss: push D to label the fakes as real
        netG.zero_grad()
        errG = criterion(netD(fake_img), label_real)
        errG.backward()
        optimizerG.step()

        if i % 50 == 0:
            print(f"[{epoch+1}/{num_epochs}][{i}/{len(dataloader)}] "
                  f"Loss_D: {(errD_real+errD_fake).item():.4f} "
                  f"Loss_G: {errG.item():.4f}")

    # ---- Sampling & saving after each epoch ----
    with torch.no_grad():
        raw = netG(fixed_noise).cpu()
        fake64, noise_map = _postprocess_samples(raw)
        _show_and_save_grid(fake64, epoch)

        # Final super-res + reinjection + 512×512 sharpen, last epoch only
        if (epoch + 1) == num_epochs:
            _super_resolve_and_save(fake64, noise_map, epoch)

print("Training complete!")


Synthetic ICH CT Generation - Generation 3

In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torchvision.transforms.functional import gaussian_blur
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.exposure import match_histograms

# Enable OpenCV multithreading for faster CPU performance.
# os.cpu_count() returns None when the CPU count cannot be determined; fall back
# to -1, which asks OpenCV to reset to its default thread count instead of
# failing on a None argument.
cv2.setUseOptimized(True)
cv2.setNumThreads(os.cpu_count() or -1)

# ---- Super-resolution setup (4× → 2× chain for 8×) ----
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

device = torch.device("cpu")  # force CPU

# 4× model for final super-res (RRDBNet with 3 input / 3 output channels)
rrdb4 = RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
sr4 = RealESRGANer(
    scale=4,
    model_path='RealESRGAN_x4plus.pth',
    model=rrdb4,
    tile=0, tile_pad=10, pre_pad=0,
    half=False, device=device
)
# 2× model, applied after the 4× model (4× · 2× = 8× overall)
rrdb2 = RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
sr2 = RealESRGANer(
    scale=2,
    model_path='RealESRGAN_x2plus.pth',
    model=rrdb2,
    tile=0, tile_pad=10, pre_pad=0,
    half=False, device=device
)

# ---- GAN setup ----
batch_size = 16
image_size = 64                      # training on 64×64 slices
nc, nz, ngf, ndf = 1, 100, 64, 64    # image channels, latent size, G / D feature widths
num_epochs, lr, beta1 = 100, 0.0002, 0.5

out_dir = "Dataset (Synthetic)/Synthetic ICH CT/Generation 3"
os.makedirs(out_dir, exist_ok=True)

# Resize + centre-crop to image_size, force nc channels, and scale ToTensor's
# [0, 1] range to [-1, 1] so real images match the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.Grayscale(num_output_channels=nc),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = datasets.ImageFolder(root="Dataset (East Cyprus Hospital)/ICH Brain CT", transform=transform)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=0)

# ---- Load a real CT slice for histogram matching ----
ref_ct_path = 'Reference CT - 4.jpg'
ref_ct = cv2.imread(ref_ct_path, cv2.IMREAD_GRAYSCALE)
if ref_ct is None:
    raise FileNotFoundError(f"Reference CT not found at {ref_ct_path}")

# ----------------------------------
# Post-processing helpers
# ----------------------------------

def denoise_nl_means(tensor_img, h=3, templateWindowSize=7, searchWindowSize=21):
    """Denoise one grayscale image tensor with OpenCV non-local means.

    The tensor is mapped from [-1, 1] to 8-bit, passed through
    ``cv2.fastNlMeansDenoising`` and mapped back to [-1, 1].

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W)
            (assumed to be a single-channel image such as shape (1, H, W)).
        h: filter strength forwarded to ``cv2.fastNlMeansDenoising``.
        templateWindowSize: patch size forwarded to OpenCV.
        searchWindowSize: search-window size forwarded to OpenCV.

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    np_img = tensor_img.squeeze().cpu().numpy()
    # Round and clip before casting. A bare astype(np.uint8) truncates, so
    # float32 round-trip error (e.g. 2.9999998) silently drops a gray level,
    # and any value outside [-1, 1] would wrap around.
    np_img8 = np.clip(np.rint((np_img + 1.0) * 127.5), 0, 255).astype(np.uint8)
    denoised8 = cv2.fastNlMeansDenoising(np_img8,
                                         None,
                                         h=h,
                                         templateWindowSize=templateWindowSize,
                                         searchWindowSize=searchWindowSize)
    denoised_f = denoised8.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(denoised_f).unsqueeze(0).to(tensor_img.device)

def guided_bilateral_filter(np_img,
                            d=5, sigmaColor=50, sigmaSpace=50,
                            radius=5, eps=100.0, use_guided=False):
    """Edge-preserving smoothing: bilateral filter, optionally refined by a guided filter.

    Args:
        np_img: 8-bit single-channel image (H, W) (assumed uint8, as the
            callers in this notebook pass).
        d, sigmaColor, sigmaSpace: forwarded to ``cv2.bilateralFilter``.
        radius: guided-filter window radius (square window of side
            2*radius+1). Only used when ``use_guided`` is True.
        eps: guided-filter regulariser in squared 8-bit intensity units;
            larger values smooth more. Only used when ``use_guided`` is True.
        use_guided: when True, refine the bilateral result with a guided
            filter that uses ``np_img`` as the guide image. Defaults to False,
            in which case the plain bilateral result is returned.

    Returns:
        Filtered image with the same shape as ``np_img`` (uint8 when the
        guided step runs).
    """
    bf = cv2.bilateralFilter(np_img, d=d,
                             sigmaColor=sigmaColor,
                             sigmaSpace=sigmaSpace)
    if not use_guided:
        return bf

    # Grey-scale guided filter (He et al.): q = mean(a) * I + mean(b), with
    # a = cov(I, p) / (var(I) + eps) and b = mean(p) - a * mean(I).
    guide = np_img.astype(np.float32)
    src = bf.astype(np.float32)
    ksize = (2 * radius + 1, 2 * radius + 1)

    def box(x):
        return cv2.boxFilter(x, cv2.CV_32F, ksize)

    mean_i = box(guide)
    mean_p = box(src)
    var_i = box(guide * guide) - mean_i * mean_i
    cov_ip = box(guide * src) - mean_i * mean_p
    a = cov_ip / (var_i + eps)
    b = mean_p - a * mean_i
    q = box(a) * guide + box(b)
    return np.clip(np.rint(q), 0, 255).astype(np.uint8)

def sharp_kernel_filter(tensor_img):
    """Sharpen one grayscale image tensor with a strong 3×3 high-boost kernel.

    The kernel (centre 9, eight neighbours -1) sums to 1, so flat regions keep
    their intensity while local differences are amplified.

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W).

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    kernel = np.array([[-1, -1, -1],
                       [-1,  9, -1],
                       [-1, -1, -1]], dtype=np.float32)
    # Round and clip rather than a bare truncating astype(np.uint8): truncation
    # turns float32 round-trip error (e.g. 2.9999998) into a lost gray level and
    # wraps out-of-range values.
    np_img = np.clip(np.rint(tensor_img.squeeze().cpu().numpy() * 127.5 + 127.5),
                     0, 255).astype(np.uint8)
    # ddepth=-1 keeps the uint8 depth; the result is already within [0, 255].
    sharp = cv2.filter2D(np_img, -1, kernel)
    sharp_f = sharp.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(sharp_f).unsqueeze(0).to(tensor_img.device)

def clahe_enhance(tensor_img, clipLimit=2.0, tileGridSize=(4,4)):
    """Apply CLAHE (contrast-limited adaptive histogram equalisation) to one image tensor.

    Smaller tiles give stronger local contrast.

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W).
        clipLimit: contrast limit forwarded to ``cv2.createCLAHE``.
        tileGridSize: (cols, rows) tile grid forwarded to ``cv2.createCLAHE``.

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    # Round and clip before the uint8 cast; truncation would drop gray levels
    # on float32 round-trip error and wrap out-of-range values.
    np_img = np.clip(np.rint(tensor_img.squeeze().cpu().numpy() * 127.5 + 127.5),
                     0, 255).astype(np.uint8)
    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    cl = clahe.apply(np_img)
    cl_f = cl.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(cl_f).unsqueeze(0).to(tensor_img.device)

def match_histogram_tensor(tensor_img, ref_uint8):
    """Match the intensity histogram of an image tensor to an 8-bit reference.

    Args:
        tensor_img: image tensor in [-1, 1] that squeezes to 2-D (H, W).
        ref_uint8: 8-bit grayscale reference image (e.g. from ``cv2.imread``).

    Returns:
        float32 tensor of shape (1, H, W) in [-1, 1] on ``tensor_img``'s device.
    """
    np_img = np.clip(np.rint(tensor_img.squeeze().cpu().numpy() * 127.5 + 127.5),
                     0, 255).astype(np.uint8)
    matched = match_histograms(np_img, ref_uint8, channel_axis=None)
    # match_histograms may return floating-point values; round to the nearest
    # gray level instead of truncating.
    matched = np.clip(np.rint(matched), 0, 255).astype(np.uint8)
    matched_f = matched.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(matched_f).unsqueeze(0).to(tensor_img.device)

# ----------------------------------
# DCGAN Generator & Discriminator
# ----------------------------------
class Generator(nn.Module):
    """DCGAN generator mapping latent vectors (N, nz, 1, 1) to images (N, nc, 64, 64) in [-1, 1]."""

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for each 4×4 transposed
        # conv + BatchNorm + ReLU stage; spatial size goes 1 → 4 → 8 → 16 → 32.
        stages = [
            (nz,      ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf,     2, 1),
        ]
        layers = []
        for c_in, c_out, stride, pad in stages:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Final 32 → 64 projection to image channels, squashed into [-1, 1].
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from latent tensor ``x``."""
        return self.model(x)

class Discriminator(nn.Module):
    """DCGAN discriminator: images (N, nc, 64, 64) -> probability of "real", shape (N,).

    The input layer deliberately has no BatchNorm. The DCGAN guidelines
    (Radford et al.; PyTorch DCGAN tutorial) report that batch-normalising the
    discriminator input causes sample oscillation and instability.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 64 → 32; no BatchNorm on the input layer (DCGAN guideline).
            nn.Conv2d(nc,   ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, True),
            # 32 → 16
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*2), nn.LeakyReLU(0.2, True),
            # 16 → 8
            nn.Conv2d(ndf*2, ndf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*4), nn.LeakyReLU(0.2, True),
            # 8 → 4
            nn.Conv2d(ndf*4, ndf*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf*8), nn.LeakyReLU(0.2, True),
            # 4 → 1, single score squashed to a probability for BCELoss
            nn.Conv2d(ndf*8, 1,     4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return per-sample probabilities flattened to a 1-D tensor."""
        return self.model(x).view(-1)

def _dcgan_weights_init(m):
    """Initialise one module in place following the DCGAN recipe.

    Conv / transposed-conv weights ~ N(0, 0.02); BatchNorm scale ~ N(1, 0.02)
    and shift = 0. Other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.zeros_(m.bias)


netG = Generator().to(device)
netD = Discriminator().to(device)
# PyTorch's default layer initialisation differs from the N(0, 0.02) scheme the
# DCGAN architecture was designed and tuned for; apply it explicitly.
netG.apply(_dcgan_weights_init)
netD.apply(_dcgan_weights_init)

criterion  = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# fixed noise for sampling: the same latents every epoch, so saved grids are comparable
num_samples = 400
fixed_noise = torch.randn(num_samples, nz, 1, 1, device=device)

# ----------------------------------
# Training Loop
# ----------------------------------
# Sharpening kernel for the final 512×512 outputs (sums to 1: high-boost).
_HEAVY_KERNEL = np.array([[-1, -1, -1],
                          [-1,  9, -1],
                          [-1, -1, -1]], dtype=np.float32)


def _to_uint8(arr):
    """Map a float array in [-1, 1] to uint8 [0, 255], rounding and clipping.

    A bare ``astype(np.uint8)`` truncates, so float32 round-trip error such as
    2.9999998 loses a gray level, and out-of-range values wrap around.
    """
    return np.clip(np.rint(arr * 127.5 + 127.5), 0, 255).astype(np.uint8)


def _postprocess_samples(raw):
    """Run the post-processing chain on a CPU batch of generator outputs.

    Each sample is bilateral-filtered. The residual ``raw - filtered`` is kept
    as a noise map for later reinjection. The filtered image then goes through
    sharpen → NL-means denoise → CLAHE → histogram match against ``ref_ct``.

    Args:
        raw: tensor (N, nc, H, W) in [-1, 1] on the CPU (nc == 1 here).

    Returns:
        (fake64, noise_map): processed samples in [-1, 1] and the residual,
        both shaped like ``raw``.
    """
    filtered = torch.zeros_like(raw)
    for idx, sample in enumerate(raw):
        smoothed = guided_bilateral_filter(_to_uint8(sample.squeeze().numpy()))
        filtered[idx] = torch.from_numpy(smoothed.astype(np.float32) / 127.5 - 1.0).unsqueeze(0)

    noise_map = raw - filtered   # retains high-freq detail

    fake64 = torch.zeros_like(filtered)
    for idx, img in enumerate(filtered):
        img = sharp_kernel_filter(img)
        img = denoise_nl_means(img, h=3)
        img = clahe_enhance(img, clipLimit=2.0, tileGridSize=(4, 4))
        fake64[idx] = match_histogram_tensor(img, ref_ct)
    return fake64, noise_map


def _show_and_save_grid(fake64, epoch):
    """Display the 20-per-row sample grid and save it as epoch_<n>.png in out_dir."""
    grid = utils.make_grid(fake64, nrow=20, padding=2, normalize=True)
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
    plt.axis('off')
    plt.title(f"Epoch {epoch+1}")
    plt.show()
    plt.close(fig)  # free the figure; otherwise one accumulates per epoch outside inline backends
    utils.save_image(fake64,
                     os.path.join(out_dir, f"epoch_{epoch+1}.png"),
                     normalize=True)


def _super_resolve_and_save(fake64, noise_map, epoch, alpha=0.9):
    """8× super-resolve each sample, reinject its noise map, sharpen and save as PNG.

    Args:
        fake64: processed samples (N, 1, H, W) in [-1, 1].
        noise_map: residuals (N, 1, H, W) from ``_postprocess_samples``.
        epoch: zero-based epoch index, used in file names.
        alpha: weight of the reinjected noise map.
    """
    print("Applying Super-Resolution + reinjection + final sharpen…")
    for idx in range(fake64.size(0)):
        # Real-ESRGAN is fed a 3-channel 8-bit image: replicate the gray plane.
        gray64 = _to_uint8(fake64[idx].squeeze().numpy())
        arr64 = np.repeat(gray64[..., None], 3, axis=2)

        out4, _ = sr4.enhance(arr64)
        out8, _ = sr2.enhance(out4)

        # Upsample the float residual directly. The former uint8 round trip
        # biased zero noise by -0.5/127.5 and wrapped values outside [-1, 1].
        nm = noise_map[idx].squeeze().numpy().astype(np.float32)
        nm_hr = cv2.resize(nm, (out8.shape[1], out8.shape[0]),
                           interpolation=cv2.INTER_CUBIC)

        sr8 = out8.astype(np.float32) / 255.0
        combined = np.clip(sr8 + alpha * nm_hr[..., None], 0, 1)

        final = np.rint(combined[..., 0] * 255).astype(np.uint8)
        sharp_final = cv2.filter2D(final, -1, _HEAVY_KERNEL)
        Image.fromarray(sharp_final).save(
            os.path.join(out_dir, f"sr8_epoch{epoch+1}_img{idx}.png"))

    print("Super-resolution + reinjection + final sharpen complete!")


print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real = imgs.to(device)
        b_size = real.size(0)
        label_real = torch.ones(b_size, device=device)
        label_fake = torch.zeros(b_size, device=device)

        # Discriminator real loss
        netD.zero_grad()
        errD_real = criterion(netD(real), label_real)
        errD_real.backward()

        # Discriminator fake loss (detach so G receives no gradient here)
        noise    = torch.randn(b_size, nz, 1, 1, device=device)
        fake_img = netG(noise)
        errD_fake = criterion(netD(fake_img.detach()), label_fake)
        errD_fake.backward()
        optimizerD.step()

        # Generator loss: push D to label the fakes as real
        netG.zero_grad()
        errG = criterion(netD(fake_img), label_real)
        errG.backward()
        optimizerG.step()

        if i % 50 == 0:
            print(f"[{epoch+1}/{num_epochs}][{i}/{len(dataloader)}] "
                  f"Loss_D: {(errD_real+errD_fake).item():.4f} "
                  f"Loss_G: {errG.item():.4f}")

    # ---- Sampling & saving after each epoch ----
    with torch.no_grad():
        raw = netG(fixed_noise).cpu()
        fake64, noise_map = _postprocess_samples(raw)
        _show_and_save_grid(fake64, epoch)

        # Final super-res + reinjection + 512×512 sharpen, last epoch only
        if (epoch + 1) == num_epochs:
            _super_resolve_and_save(fake64, noise_map, epoch)

print("Training complete!")


Synthetic Normal CT Generation - Generation 1

In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torchvision.transforms.functional import gaussian_blur
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.exposure import match_histograms

# Enable OpenCV's optimised code paths and multithreading for faster CPU work.
cv2.setUseOptimized(True)
# os.cpu_count() returns None when the core count cannot be determined, which
# cv2.setNumThreads rejects; a negative value makes OpenCV use its default.
cv2.setNumThreads(os.cpu_count() or -1)

# ---- Super-resolution setup (4× → 2× chain for 8×) ----
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

device = torch.device("cpu")  # force CPU

# Real-ESRGAN upsamplers used for the final super-resolution step: the 4x model
# followed by the 2x model yields an overall 8x upscale. Both share the same
# RRDB backbone hyper-parameters and run with tile=0, full precision, on `device`.
_rrdb_kwargs = dict(num_feat=64, num_block=23, num_grow_ch=32)
_enhancer_kwargs = dict(tile=0, tile_pad=10, pre_pad=0, half=False, device=device)

# 4x stage
rrdb4 = RRDBNet(3, 3, scale=4, **_rrdb_kwargs)
sr4 = RealESRGANer(scale=4, model_path='RealESRGAN_x4plus.pth',
                   model=rrdb4, **_enhancer_kwargs)

# 2x stage
rrdb2 = RRDBNet(3, 3, scale=2, **_rrdb_kwargs)
sr2 = RealESRGANer(scale=2, model_path='RealESRGAN_x2plus.pth',
                   model=rrdb2, **_enhancer_kwargs)

# ---- GAN setup ----
# Training hyper-parameters; slices are trained at 64x64, single channel.
batch_size = 16
image_size = 64
nc = 1        # image channels
nz = 100      # latent vector length
ngf = 64      # generator base feature width
ndf = 64      # discriminator base feature width
num_epochs = 100
lr = 0.0002
beta1 = 0.5   # Adam beta1

out_dir = "Dataset (Synthetic)/Synthetic Normal Brain CT/Generation 1"
os.makedirs(out_dir, exist_ok=True)

# Resize + centre-crop to a square, force one channel, then map [0, 1] to
# [-1, 1] so real images share the generator's Tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.Grayscale(num_output_channels=nc),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataset = datasets.ImageFolder(
    root="Dataset (East Cyprus Hospital)/Normal Brain CT",
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=0
)

# ---- Load a real CT slice for histogram matching ----
ref_ct_path = 'Reference CT - 4.jpg'
ref_ct = cv2.imread(ref_ct_path, cv2.IMREAD_GRAYSCALE)
if ref_ct is None:  # cv2.imread reports failure by returning None, not raising
    raise FileNotFoundError(f"Reference CT not found at {ref_ct_path}")

# ----------------------------------
# Post-processing helpers
# ----------------------------------

def tensor_to_uint8(tensor_img):
    """Quantise a single-channel image tensor in [-1, 1] to an 8-bit array.

    The tensor is detached, squeezed and moved to the CPU, mapped linearly to
    [0, 255], rounded to the nearest grey level and clipped. Rounding (instead
    of the truncation a bare ``astype(np.uint8)`` performs) keeps a
    uint8 -> tensor -> uint8 round trip exact despite float32 error; clipping
    keeps out-of-range values from wrapping around.
    """
    arr = tensor_img.detach().squeeze().cpu().numpy()
    return np.clip(np.rint(arr * 127.5 + 127.5), 0, 255).astype(np.uint8)

def uint8_to_tensor(arr_u8, device):
    """Map an 8-bit array to float32 in [-1, 1], add a leading axis, move to ``device``."""
    scaled = arr_u8.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(scaled).unsqueeze(0).to(device)

def denoise_nl_means(tensor_img, h=3, templateWindowSize=7, searchWindowSize=21):
    """Denoise a [-1, 1] image tensor with OpenCV non-local means.

    Args:
        tensor_img: single-channel image tensor in [-1, 1] (e.g. shape (1, H, W)).
        h: filter strength passed to ``cv2.fastNlMeansDenoising``.
        templateWindowSize: patch size in pixels used to compare neighbourhoods.
        searchWindowSize: size in pixels of the search window.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    denoised8 = cv2.fastNlMeansDenoising(tensor_to_uint8(tensor_img),
                                         None,
                                         h=h,
                                         templateWindowSize=templateWindowSize,
                                         searchWindowSize=searchWindowSize)
    return uint8_to_tensor(denoised8, tensor_img.device)

def guided_bilateral_filter(np_img,
                            d=5, sigmaColor=50, sigmaSpace=50,
                            radius=5, eps=100.0):
    """Edge-preserving smoothing of an 8-bit image with ``cv2.bilateralFilter``.

    Args:
        np_img: 8-bit image array.
        d, sigmaColor, sigmaSpace: passed straight to ``cv2.bilateralFilter``.
        radius, eps: accepted for call compatibility but currently unused;
            no guided-filter stage is applied.

    Returns:
        The filtered 8-bit array (same shape as ``np_img``).
    """
    return cv2.bilateralFilter(np_img, d=d,
                               sigmaColor=sigmaColor,
                               sigmaSpace=sigmaSpace)

def sharp_kernel_filter(tensor_img):
    """Sharpen a [-1, 1] image tensor with a strong 3x3 high-boost kernel.

    The kernel sums to 1, so flat regions keep their level while edges are
    amplified. ``cv2.filter2D`` on uint8 input with ``ddepth=-1`` saturates
    its output to [0, 255], so no extra clipping is needed.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    kernel = np.array([[-1, -1, -1],
                       [-1,  9, -1],
                       [-1, -1, -1]], dtype=np.float32)
    sharp = cv2.filter2D(tensor_to_uint8(tensor_img), -1, kernel)
    return uint8_to_tensor(sharp, tensor_img.device)

def clahe_enhance(tensor_img, clipLimit=2.0, tileGridSize=(4,4)):
    """Apply CLAHE (contrast-limited adaptive histogram equalisation) to a [-1, 1] tensor.

    Args:
        tensor_img: single-channel image tensor in [-1, 1].
        clipLimit: contrast limit passed to ``cv2.createCLAHE``.
        tileGridSize: number of tiles per (row, column). The 4x4 default is a
            coarser grid (fewer, larger tiles) than OpenCV's own 8x8 default.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    cl = clahe.apply(tensor_to_uint8(tensor_img))
    return uint8_to_tensor(cl, tensor_img.device)

def match_histogram_tensor(tensor_img, ref_uint8):
    """Match the grey-level histogram of a [-1, 1] image tensor to an 8-bit reference.

    Args:
        tensor_img: single-channel image tensor in [-1, 1].
        ref_uint8: 8-bit reference image whose histogram is the target.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    matched = match_histograms(tensor_to_uint8(tensor_img), ref_uint8,
                               channel_axis=None)
    # match_histograms returns floats; round to the nearest level before casting.
    matched8 = np.clip(np.rint(matched), 0, 255).astype(np.uint8)
    return uint8_to_tensor(matched8, tensor_img.device)

# ----------------------------------
# DCGAN Generator & Discriminator
# ----------------------------------
class Generator(nn.Module):
    """DCGAN generator: (N, nz, 1, 1) latent -> (N, nc, 64, 64) image in (-1, 1).

    Four transposed-conv + BatchNorm + ReLU stages grow the map 1->4->8->16->32,
    then a final transposed conv reaches 64x64 and Tanh bounds the output.
    """
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each BN+ReLU stage.
        stages = [
            (nz,    ngf*8, 1, 0),   # 1->4
            (ngf*8, ngf*4, 2, 1),   # 4->8
            (ngf*4, ngf*2, 2, 1),   # 8->16
            (ngf*2, ngf,   2, 1),   # 16->32
        ]
        layers = []
        for c_in, c_out, stride, pad in stages:
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                       nn.BatchNorm2d(c_out),
                       nn.ReLU(True)]
        layers += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),  # 32->64
                   nn.Tanh()]
        # Same layer order/indices as a hand-written Sequential, so saved
        # state_dict keys (model.0.weight, ...) are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    """DCGAN discriminator: (N, nc, 64, 64) image -> (N,) probability of being real."""
    def __init__(self):
        super().__init__()
        widths = [nc, ndf, ndf*2, ndf*4, ndf*8]
        layers = []
        # Four stride-2 conv stages (64->32->16->8->4), each with BatchNorm
        # (including the first stage) and LeakyReLU(0.2).
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                       nn.BatchNorm2d(c_out),
                       nn.LeakyReLU(0.2, True)]
        # A 4x4 valid conv collapses the 4x4 map to one score; Sigmoid -> [0, 1].
        layers += [nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # (N, 1, 1, 1) -> (N,) to line up with the BCE label vectors.
        return self.model(x).view(-1)

# Build both networks (generator first) and move them to the configured device.
netG, netD = Generator().to(device), Discriminator().to(device)

# D ends in Sigmoid, so plain BCE on probabilities is the matching loss.
criterion = nn.BCELoss()
# Both players use Adam with the same learning rate and betas.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

# One latent batch reused every epoch, so per-epoch sample grids are comparable.
num_samples = 400
fixed_noise = torch.randn(num_samples, nz, 1, 1, device=device)

# ----------------------------------
# Training Loop
# ----------------------------------
print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real = imgs.to(device)
        b_size = real.size(0)
        label_real = torch.ones(b_size, device=device)
        label_fake = torch.zeros(b_size, device=device)

        # Discriminator real loss
        netD.zero_grad()
        errD_real = criterion(netD(real), label_real)
        errD_real.backward()

        # Discriminator fake loss; fakes are detached so only netD gets gradients.
        noise    = torch.randn(b_size, nz, 1, 1, device=device)
        fake_img = netG(noise)
        errD_fake = criterion(netD(fake_img.detach()), label_fake)
        errD_fake.backward()
        optimizerD.step()

        # Generator loss: fakes are scored against the "real" label. This
        # backward also fills netD's grads; netD.zero_grad() clears them at the
        # start of the next iteration, before netD steps again.
        netG.zero_grad()
        errG = criterion(netD(fake_img), label_real)
        errG.backward()
        optimizerG.step()

        if i % 50 == 0:
            print(f"[{epoch+1}/{num_epochs}][{i}/{len(dataloader)}] "
                  f"Loss_D: {(errD_real+errD_fake).item():.4f} "
                  f"Loss_G: {errG.item():.4f}")

    # ---- Sampling & saving after each epoch ----
    # NOTE(review): netG is never put in eval() mode, so BatchNorm normalises
    # with this fixed batch's statistics (and updates its running stats).
    with torch.no_grad():
        raw = netG(fixed_noise).cpu()

        # 1) Bilateral filter -> filtered; the residual becomes the noise map.
        #    Round (not truncate) when quantising, otherwise `filtered` sits
        #    about half a grey level below `raw` and that bias leaks into
        #    noise_map as a constant offset.
        filtered = torch.zeros_like(raw)
        for idx in range(raw.size(0)):
            np_img = np.clip(np.rint(raw[idx].squeeze().numpy() * 127.5 + 127.5),
                             0, 255).astype(np.uint8)
            out = guided_bilateral_filter(np_img)
            tensor_out = torch.from_numpy((out.astype(np.float32)/127.5 - 1.0)) \
                            .unsqueeze(0).to(raw.device)
            filtered[idx] = tensor_out

        noise_map = raw - filtered   # high-frequency residual, values in (-2, 2)

        # 2) Sharpen → denoise → CLAHE → hist match → fake64
        proc = torch.zeros_like(filtered)
        for idx in range(filtered.size(0)):
            img = filtered[idx]
            img = sharp_kernel_filter(img)
            img = denoise_nl_means(img, h=3)
            img = clahe_enhance(img, clipLimit=2.0, tileGridSize=(4,4))
            img = match_histogram_tensor(img, ref_ct)
            proc[idx] = img
        fake64 = proc

        # visualize & save 64×64 grid
        grid = utils.make_grid(fake64, nrow=20, padding=2, normalize=True)
        fig = plt.figure(figsize=(10,10))
        plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
        plt.axis('off')
        plt.title(f"Epoch {epoch+1}")
        plt.show()
        plt.close(fig)   # don't keep one open figure per epoch alive
        utils.save_image(fake64,
                         os.path.join(out_dir, f"epoch_{epoch+1}.png"),
                         normalize=True)

        # 3) Final super-res + reinjection + 512×512 sharpen
        if (epoch + 1) == num_epochs:
            print("Applying Super-Resolution + reinjection + final sharpen…")
            alpha = 0.9   # weight of the re-injected noise map
            heavy_kernel = np.array([[-1, -1, -1],
                                     [-1,  9, -1],
                                     [-1, -1, -1]], dtype=np.float32)
            to_pil = transforms.ToPILImage()   # hoisted out of the per-image loop

            for idx in range(fake64.size(0)):
                # [-1, 1] tensor -> 8-bit RGB array (grey replicated) for Real-ESRGAN
                pil64 = to_pil((fake64[idx]*0.5 + 0.5).clamp(0,1)).convert('RGB')
                arr64 = np.array(pil64)

                out4, _ = sr4.enhance(arr64)   # 4x model
                out8, _ = sr2.enhance(out4)    # then 2x model -> 8x overall

                # Upsample the noise map as float32 ('F'-mode) data. It spans
                # (-2, 2); quantising it to uint8 first would wrap values outside
                # [-1, 1] and throw away sub-level detail.
                nm = noise_map[idx].squeeze().numpy().astype(np.float32)
                nm_hr = Image.fromarray(nm).resize((out8.shape[1], out8.shape[0]),
                                                   Image.BICUBIC)
                nm_hr_f = np.asarray(nm_hr, dtype=np.float32)

                sr8 = out8.astype(np.float32)/255.0
                combined = np.clip(sr8 + alpha * nm_hr_f[...,None], 0, 1)

                # Input was grey replicated to RGB; channel 0 is kept as the grey output.
                final = np.rint(combined * 255).astype(np.uint8)[...,0]
                sharp_final = cv2.filter2D(final, -1, heavy_kernel)
                final_gray = Image.fromarray(sharp_final)
                final_gray.save(os.path.join(out_dir,
                                             f"sr8_epoch{epoch+1}_img{idx}.png"))

            print("Super-resolution + reinjection + final sharpen complete!")

print("Training complete!")


Synthetic Normal CT Generation - Generation 2

In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torchvision.transforms.functional import gaussian_blur
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.exposure import match_histograms

# Enable OpenCV's optimised code paths and multithreading for faster CPU work.
cv2.setUseOptimized(True)
# os.cpu_count() returns None when the core count cannot be determined, which
# cv2.setNumThreads rejects; a negative value makes OpenCV use its default.
cv2.setNumThreads(os.cpu_count() or -1)

# ---- Super-resolution setup (4× → 2× chain for 8×) ----
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

device = torch.device("cpu")  # force CPU

# Real-ESRGAN upsamplers used for the final super-resolution step: the 4x model
# followed by the 2x model yields an overall 8x upscale. Both share the same
# RRDB backbone hyper-parameters and run with tile=0, full precision, on `device`.
_rrdb_kwargs = dict(num_feat=64, num_block=23, num_grow_ch=32)
_enhancer_kwargs = dict(tile=0, tile_pad=10, pre_pad=0, half=False, device=device)

# 4x stage
rrdb4 = RRDBNet(3, 3, scale=4, **_rrdb_kwargs)
sr4 = RealESRGANer(scale=4, model_path='RealESRGAN_x4plus.pth',
                   model=rrdb4, **_enhancer_kwargs)

# 2x stage
rrdb2 = RRDBNet(3, 3, scale=2, **_rrdb_kwargs)
sr2 = RealESRGANer(scale=2, model_path='RealESRGAN_x2plus.pth',
                   model=rrdb2, **_enhancer_kwargs)

# ---- GAN setup ----
# Training hyper-parameters; slices are trained at 64x64, single channel.
batch_size = 16
image_size = 64
nc = 1        # image channels
nz = 100      # latent vector length
ngf = 64      # generator base feature width
ndf = 64      # discriminator base feature width
num_epochs = 100
lr = 0.0002
beta1 = 0.5   # Adam beta1

out_dir = "Dataset (Synthetic)/Synthetic Normal Brain CT/Generation 2"
os.makedirs(out_dir, exist_ok=True)

# Resize + centre-crop to a square, force one channel, then map [0, 1] to
# [-1, 1] so real images share the generator's Tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.Grayscale(num_output_channels=nc),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataset = datasets.ImageFolder(
    root="Dataset (East Cyprus Hospital)/Normal Brain CT",
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=0
)

# ---- Load a real CT slice for histogram matching ----
ref_ct_path = 'Reference CT - 4.jpg'
ref_ct = cv2.imread(ref_ct_path, cv2.IMREAD_GRAYSCALE)
if ref_ct is None:  # cv2.imread reports failure by returning None, not raising
    raise FileNotFoundError(f"Reference CT not found at {ref_ct_path}")

# ----------------------------------
# Post-processing helpers
# ----------------------------------

def tensor_to_uint8(tensor_img):
    """Quantise a single-channel image tensor in [-1, 1] to an 8-bit array.

    The tensor is detached, squeezed and moved to the CPU, mapped linearly to
    [0, 255], rounded to the nearest grey level and clipped. Rounding (instead
    of the truncation a bare ``astype(np.uint8)`` performs) keeps a
    uint8 -> tensor -> uint8 round trip exact despite float32 error; clipping
    keeps out-of-range values from wrapping around.
    """
    arr = tensor_img.detach().squeeze().cpu().numpy()
    return np.clip(np.rint(arr * 127.5 + 127.5), 0, 255).astype(np.uint8)

def uint8_to_tensor(arr_u8, device):
    """Map an 8-bit array to float32 in [-1, 1], add a leading axis, move to ``device``."""
    scaled = arr_u8.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(scaled).unsqueeze(0).to(device)

def denoise_nl_means(tensor_img, h=3, templateWindowSize=7, searchWindowSize=21):
    """Denoise a [-1, 1] image tensor with OpenCV non-local means.

    Args:
        tensor_img: single-channel image tensor in [-1, 1] (e.g. shape (1, H, W)).
        h: filter strength passed to ``cv2.fastNlMeansDenoising``.
        templateWindowSize: patch size in pixels used to compare neighbourhoods.
        searchWindowSize: size in pixels of the search window.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    denoised8 = cv2.fastNlMeansDenoising(tensor_to_uint8(tensor_img),
                                         None,
                                         h=h,
                                         templateWindowSize=templateWindowSize,
                                         searchWindowSize=searchWindowSize)
    return uint8_to_tensor(denoised8, tensor_img.device)

def guided_bilateral_filter(np_img,
                            d=5, sigmaColor=50, sigmaSpace=50,
                            radius=5, eps=100.0):
    """Edge-preserving smoothing of an 8-bit image with ``cv2.bilateralFilter``.

    Args:
        np_img: 8-bit image array.
        d, sigmaColor, sigmaSpace: passed straight to ``cv2.bilateralFilter``.
        radius, eps: accepted for call compatibility but currently unused;
            no guided-filter stage is applied.

    Returns:
        The filtered 8-bit array (same shape as ``np_img``).
    """
    return cv2.bilateralFilter(np_img, d=d,
                               sigmaColor=sigmaColor,
                               sigmaSpace=sigmaSpace)

def sharp_kernel_filter(tensor_img):
    """Sharpen a [-1, 1] image tensor with a strong 3x3 high-boost kernel.

    The kernel sums to 1, so flat regions keep their level while edges are
    amplified. ``cv2.filter2D`` on uint8 input with ``ddepth=-1`` saturates
    its output to [0, 255], so no extra clipping is needed.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    kernel = np.array([[-1, -1, -1],
                       [-1,  9, -1],
                       [-1, -1, -1]], dtype=np.float32)
    sharp = cv2.filter2D(tensor_to_uint8(tensor_img), -1, kernel)
    return uint8_to_tensor(sharp, tensor_img.device)

def clahe_enhance(tensor_img, clipLimit=2.0, tileGridSize=(4,4)):
    """Apply CLAHE (contrast-limited adaptive histogram equalisation) to a [-1, 1] tensor.

    Args:
        tensor_img: single-channel image tensor in [-1, 1].
        clipLimit: contrast limit passed to ``cv2.createCLAHE``.
        tileGridSize: number of tiles per (row, column). The 4x4 default is a
            coarser grid (fewer, larger tiles) than OpenCV's own 8x8 default.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    cl = clahe.apply(tensor_to_uint8(tensor_img))
    return uint8_to_tensor(cl, tensor_img.device)

def match_histogram_tensor(tensor_img, ref_uint8):
    """Match the grey-level histogram of a [-1, 1] image tensor to an 8-bit reference.

    Args:
        tensor_img: single-channel image tensor in [-1, 1].
        ref_uint8: 8-bit reference image whose histogram is the target.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    matched = match_histograms(tensor_to_uint8(tensor_img), ref_uint8,
                               channel_axis=None)
    # match_histograms returns floats; round to the nearest level before casting.
    matched8 = np.clip(np.rint(matched), 0, 255).astype(np.uint8)
    return uint8_to_tensor(matched8, tensor_img.device)

# ----------------------------------
# DCGAN Generator & Discriminator
# ----------------------------------
class Generator(nn.Module):
    """DCGAN generator: (N, nz, 1, 1) latent -> (N, nc, 64, 64) image in (-1, 1).

    Four transposed-conv + BatchNorm + ReLU stages grow the map 1->4->8->16->32,
    then a final transposed conv reaches 64x64 and Tanh bounds the output.
    """
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each BN+ReLU stage.
        stages = [
            (nz,    ngf*8, 1, 0),   # 1->4
            (ngf*8, ngf*4, 2, 1),   # 4->8
            (ngf*4, ngf*2, 2, 1),   # 8->16
            (ngf*2, ngf,   2, 1),   # 16->32
        ]
        layers = []
        for c_in, c_out, stride, pad in stages:
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                       nn.BatchNorm2d(c_out),
                       nn.ReLU(True)]
        layers += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),  # 32->64
                   nn.Tanh()]
        # Same layer order/indices as a hand-written Sequential, so saved
        # state_dict keys (model.0.weight, ...) are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    """DCGAN discriminator: (N, nc, 64, 64) image -> (N,) probability of being real."""
    def __init__(self):
        super().__init__()
        widths = [nc, ndf, ndf*2, ndf*4, ndf*8]
        layers = []
        # Four stride-2 conv stages (64->32->16->8->4), each with BatchNorm
        # (including the first stage) and LeakyReLU(0.2).
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                       nn.BatchNorm2d(c_out),
                       nn.LeakyReLU(0.2, True)]
        # A 4x4 valid conv collapses the 4x4 map to one score; Sigmoid -> [0, 1].
        layers += [nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # (N, 1, 1, 1) -> (N,) to line up with the BCE label vectors.
        return self.model(x).view(-1)

# Build both networks (generator first) and move them to the configured device.
netG, netD = Generator().to(device), Discriminator().to(device)

# D ends in Sigmoid, so plain BCE on probabilities is the matching loss.
criterion = nn.BCELoss()
# Both players use Adam with the same learning rate and betas.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

# One latent batch reused every epoch, so per-epoch sample grids are comparable.
num_samples = 400
fixed_noise = torch.randn(num_samples, nz, 1, 1, device=device)

# ----------------------------------
# Training Loop
# ----------------------------------
print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real = imgs.to(device)
        b_size = real.size(0)
        label_real = torch.ones(b_size, device=device)
        label_fake = torch.zeros(b_size, device=device)

        # Discriminator real loss
        netD.zero_grad()
        errD_real = criterion(netD(real), label_real)
        errD_real.backward()

        # Discriminator fake loss; fakes are detached so only netD gets gradients.
        noise    = torch.randn(b_size, nz, 1, 1, device=device)
        fake_img = netG(noise)
        errD_fake = criterion(netD(fake_img.detach()), label_fake)
        errD_fake.backward()
        optimizerD.step()

        # Generator loss: fakes are scored against the "real" label. This
        # backward also fills netD's grads; netD.zero_grad() clears them at the
        # start of the next iteration, before netD steps again.
        netG.zero_grad()
        errG = criterion(netD(fake_img), label_real)
        errG.backward()
        optimizerG.step()

        if i % 50 == 0:
            print(f"[{epoch+1}/{num_epochs}][{i}/{len(dataloader)}] "
                  f"Loss_D: {(errD_real+errD_fake).item():.4f} "
                  f"Loss_G: {errG.item():.4f}")

    # ---- Sampling & saving after each epoch ----
    # NOTE(review): netG is never put in eval() mode, so BatchNorm normalises
    # with this fixed batch's statistics (and updates its running stats).
    with torch.no_grad():
        raw = netG(fixed_noise).cpu()

        # 1) Bilateral filter -> filtered; the residual becomes the noise map.
        #    Round (not truncate) when quantising, otherwise `filtered` sits
        #    about half a grey level below `raw` and that bias leaks into
        #    noise_map as a constant offset.
        filtered = torch.zeros_like(raw)
        for idx in range(raw.size(0)):
            np_img = np.clip(np.rint(raw[idx].squeeze().numpy() * 127.5 + 127.5),
                             0, 255).astype(np.uint8)
            out = guided_bilateral_filter(np_img)
            tensor_out = torch.from_numpy((out.astype(np.float32)/127.5 - 1.0)) \
                            .unsqueeze(0).to(raw.device)
            filtered[idx] = tensor_out

        noise_map = raw - filtered   # high-frequency residual, values in (-2, 2)

        # 2) Sharpen → denoise → CLAHE → hist match → fake64
        proc = torch.zeros_like(filtered)
        for idx in range(filtered.size(0)):
            img = filtered[idx]
            img = sharp_kernel_filter(img)
            img = denoise_nl_means(img, h=3)
            img = clahe_enhance(img, clipLimit=2.0, tileGridSize=(4,4))
            img = match_histogram_tensor(img, ref_ct)
            proc[idx] = img
        fake64 = proc

        # visualize & save 64×64 grid
        grid = utils.make_grid(fake64, nrow=20, padding=2, normalize=True)
        fig = plt.figure(figsize=(10,10))
        plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
        plt.axis('off')
        plt.title(f"Epoch {epoch+1}")
        plt.show()
        plt.close(fig)   # don't keep one open figure per epoch alive
        utils.save_image(fake64,
                         os.path.join(out_dir, f"epoch_{epoch+1}.png"),
                         normalize=True)

        # 3) Final super-res + reinjection + 512×512 sharpen
        if (epoch + 1) == num_epochs:
            print("Applying Super-Resolution + reinjection + final sharpen…")
            alpha = 0.9   # weight of the re-injected noise map
            heavy_kernel = np.array([[-1, -1, -1],
                                     [-1,  9, -1],
                                     [-1, -1, -1]], dtype=np.float32)
            to_pil = transforms.ToPILImage()   # hoisted out of the per-image loop

            for idx in range(fake64.size(0)):
                # [-1, 1] tensor -> 8-bit RGB array (grey replicated) for Real-ESRGAN
                pil64 = to_pil((fake64[idx]*0.5 + 0.5).clamp(0,1)).convert('RGB')
                arr64 = np.array(pil64)

                out4, _ = sr4.enhance(arr64)   # 4x model
                out8, _ = sr2.enhance(out4)    # then 2x model -> 8x overall

                # Upsample the noise map as float32 ('F'-mode) data. It spans
                # (-2, 2); quantising it to uint8 first would wrap values outside
                # [-1, 1] and throw away sub-level detail.
                nm = noise_map[idx].squeeze().numpy().astype(np.float32)
                nm_hr = Image.fromarray(nm).resize((out8.shape[1], out8.shape[0]),
                                                   Image.BICUBIC)
                nm_hr_f = np.asarray(nm_hr, dtype=np.float32)

                sr8 = out8.astype(np.float32)/255.0
                combined = np.clip(sr8 + alpha * nm_hr_f[...,None], 0, 1)

                # Input was grey replicated to RGB; channel 0 is kept as the grey output.
                final = np.rint(combined * 255).astype(np.uint8)[...,0]
                sharp_final = cv2.filter2D(final, -1, heavy_kernel)
                final_gray = Image.fromarray(sharp_final)
                final_gray.save(os.path.join(out_dir,
                                             f"sr8_epoch{epoch+1}_img{idx}.png"))

            print("Super-resolution + reinjection + final sharpen complete!")

print("Training complete!")


Synthetic Normal CT Generation - Generation 3

In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torchvision.transforms.functional import gaussian_blur
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.exposure import match_histograms

# Enable OpenCV's optimised code paths and multithreading for faster CPU work.
cv2.setUseOptimized(True)
# os.cpu_count() returns None when the core count cannot be determined, which
# cv2.setNumThreads rejects; a negative value makes OpenCV use its default.
cv2.setNumThreads(os.cpu_count() or -1)

# ---- Super-resolution setup (4× → 2× chain for 8×) ----
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

device = torch.device("cpu")  # force CPU

# Real-ESRGAN upsamplers used for the final super-resolution step: the 4x model
# followed by the 2x model yields an overall 8x upscale. Both share the same
# RRDB backbone hyper-parameters and run with tile=0, full precision, on `device`.
_rrdb_kwargs = dict(num_feat=64, num_block=23, num_grow_ch=32)
_enhancer_kwargs = dict(tile=0, tile_pad=10, pre_pad=0, half=False, device=device)

# 4x stage
rrdb4 = RRDBNet(3, 3, scale=4, **_rrdb_kwargs)
sr4 = RealESRGANer(scale=4, model_path='RealESRGAN_x4plus.pth',
                   model=rrdb4, **_enhancer_kwargs)

# 2x stage
rrdb2 = RRDBNet(3, 3, scale=2, **_rrdb_kwargs)
sr2 = RealESRGANer(scale=2, model_path='RealESRGAN_x2plus.pth',
                   model=rrdb2, **_enhancer_kwargs)

# ---- GAN setup ----
# Training hyper-parameters; slices are trained at 64x64, single channel.
batch_size = 16
image_size = 64
nc = 1        # image channels
nz = 100      # latent vector length
ngf = 64      # generator base feature width
ndf = 64      # discriminator base feature width
num_epochs = 100
lr = 0.0002
beta1 = 0.5   # Adam beta1

out_dir = "Dataset (Synthetic)/Synthetic Normal Brain CT/Generation 3"
os.makedirs(out_dir, exist_ok=True)

# Resize + centre-crop to a square, force one channel, then map [0, 1] to
# [-1, 1] so real images share the generator's Tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.Grayscale(num_output_channels=nc),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

dataset = datasets.ImageFolder(
    root="Dataset (East Cyprus Hospital)/Normal Brain CT",
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=0
)

# ---- Load a real CT slice for histogram matching ----
ref_ct_path = 'Reference CT - 4.jpg'
ref_ct = cv2.imread(ref_ct_path, cv2.IMREAD_GRAYSCALE)
if ref_ct is None:  # cv2.imread reports failure by returning None, not raising
    raise FileNotFoundError(f"Reference CT not found at {ref_ct_path}")

# ----------------------------------
# Post-processing helpers
# ----------------------------------

def tensor_to_uint8(tensor_img):
    """Quantise a single-channel image tensor in [-1, 1] to an 8-bit array.

    The tensor is detached, squeezed and moved to the CPU, mapped linearly to
    [0, 255], rounded to the nearest grey level and clipped. Rounding (instead
    of the truncation a bare ``astype(np.uint8)`` performs) keeps a
    uint8 -> tensor -> uint8 round trip exact despite float32 error; clipping
    keeps out-of-range values from wrapping around.
    """
    arr = tensor_img.detach().squeeze().cpu().numpy()
    return np.clip(np.rint(arr * 127.5 + 127.5), 0, 255).astype(np.uint8)

def uint8_to_tensor(arr_u8, device):
    """Map an 8-bit array to float32 in [-1, 1], add a leading axis, move to ``device``."""
    scaled = arr_u8.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(scaled).unsqueeze(0).to(device)

def denoise_nl_means(tensor_img, h=3, templateWindowSize=7, searchWindowSize=21):
    """Denoise a [-1, 1] image tensor with OpenCV non-local means.

    Args:
        tensor_img: single-channel image tensor in [-1, 1] (e.g. shape (1, H, W)).
        h: filter strength passed to ``cv2.fastNlMeansDenoising``.
        templateWindowSize: patch size in pixels used to compare neighbourhoods.
        searchWindowSize: size in pixels of the search window.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    denoised8 = cv2.fastNlMeansDenoising(tensor_to_uint8(tensor_img),
                                         None,
                                         h=h,
                                         templateWindowSize=templateWindowSize,
                                         searchWindowSize=searchWindowSize)
    return uint8_to_tensor(denoised8, tensor_img.device)

def guided_bilateral_filter(np_img,
                            d=5, sigmaColor=50, sigmaSpace=50,
                            radius=5, eps=100.0):
    """Edge-preserving smoothing of an 8-bit image with ``cv2.bilateralFilter``.

    Args:
        np_img: 8-bit image array.
        d, sigmaColor, sigmaSpace: passed straight to ``cv2.bilateralFilter``.
        radius, eps: accepted for call compatibility but currently unused;
            no guided-filter stage is applied.

    Returns:
        The filtered 8-bit array (same shape as ``np_img``).
    """
    return cv2.bilateralFilter(np_img, d=d,
                               sigmaColor=sigmaColor,
                               sigmaSpace=sigmaSpace)

def sharp_kernel_filter(tensor_img):
    """Sharpen a [-1, 1] image tensor with a strong 3x3 high-boost kernel.

    The kernel sums to 1, so flat regions keep their level while edges are
    amplified. ``cv2.filter2D`` on uint8 input with ``ddepth=-1`` saturates
    its output to [0, 255], so no extra clipping is needed.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    kernel = np.array([[-1, -1, -1],
                       [-1,  9, -1],
                       [-1, -1, -1]], dtype=np.float32)
    sharp = cv2.filter2D(tensor_to_uint8(tensor_img), -1, kernel)
    return uint8_to_tensor(sharp, tensor_img.device)

def clahe_enhance(tensor_img, clipLimit=2.0, tileGridSize=(4,4)):
    """Apply CLAHE (contrast-limited adaptive histogram equalisation) to a [-1, 1] tensor.

    Args:
        tensor_img: single-channel image tensor in [-1, 1].
        clipLimit: contrast limit passed to ``cv2.createCLAHE``.
        tileGridSize: number of tiles per (row, column). The 4x4 default is a
            coarser grid (fewer, larger tiles) than OpenCV's own 8x8 default.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    clahe = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=tileGridSize)
    cl = clahe.apply(tensor_to_uint8(tensor_img))
    return uint8_to_tensor(cl, tensor_img.device)

def match_histogram_tensor(tensor_img, ref_uint8):
    """Match the grey-level histogram of a [-1, 1] image tensor to an 8-bit reference.

    Args:
        tensor_img: single-channel image tensor in [-1, 1].
        ref_uint8: 8-bit reference image whose histogram is the target.

    Returns:
        Float32 tensor in [-1, 1] with a leading axis, on ``tensor_img.device``.
    """
    matched = match_histograms(tensor_to_uint8(tensor_img), ref_uint8,
                               channel_axis=None)
    # match_histograms returns floats; round to the nearest level before casting.
    matched8 = np.clip(np.rint(matched), 0, 255).astype(np.uint8)
    return uint8_to_tensor(matched8, tensor_img.device)

# ----------------------------------
# DCGAN Generator & Discriminator
# ----------------------------------
class Generator(nn.Module):
    """DCGAN generator: (N, nz, 1, 1) latent -> (N, nc, 64, 64) image in (-1, 1).

    Four transposed-conv + BatchNorm + ReLU stages grow the map 1->4->8->16->32,
    then a final transposed conv reaches 64x64 and Tanh bounds the output.
    """
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each BN+ReLU stage.
        stages = [
            (nz,    ngf*8, 1, 0),   # 1->4
            (ngf*8, ngf*4, 2, 1),   # 4->8
            (ngf*4, ngf*2, 2, 1),   # 8->16
            (ngf*2, ngf,   2, 1),   # 16->32
        ]
        layers = []
        for c_in, c_out, stride, pad in stages:
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                       nn.BatchNorm2d(c_out),
                       nn.ReLU(True)]
        layers += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),  # 32->64
                   nn.Tanh()]
        # Same layer order/indices as a hand-written Sequential, so saved
        # state_dict keys (model.0.weight, ...) are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    """DCGAN discriminator: (N, nc, 64, 64) image -> (N,) probability of being real."""
    def __init__(self):
        super().__init__()
        widths = [nc, ndf, ndf*2, ndf*4, ndf*8]
        layers = []
        # Four stride-2 conv stages (64->32->16->8->4), each with BatchNorm
        # (including the first stage) and LeakyReLU(0.2).
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                       nn.BatchNorm2d(c_out),
                       nn.LeakyReLU(0.2, True)]
        # A 4x4 valid conv collapses the 4x4 map to one score; Sigmoid -> [0, 1].
        layers += [nn.Conv2d(ndf*8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # (N, 1, 1, 1) -> (N,) to line up with the BCE label vectors.
        return self.model(x).view(-1)

# Build both networks (generator first) and move them to the configured device.
netG, netD = Generator().to(device), Discriminator().to(device)

# D ends in Sigmoid, so plain BCE on probabilities is the matching loss.
criterion = nn.BCELoss()
# Both players use Adam with the same learning rate and betas.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

# One latent batch reused every epoch, so per-epoch sample grids are comparable.
num_samples = 400
fixed_noise = torch.randn(num_samples, nz, 1, 1, device=device)

# ----------------------------------
# Training Loop
# ----------------------------------
print("Starting Training Loop...")
for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real = imgs.to(device)
        b_size = real.size(0)
        label_real = torch.ones(b_size, device=device)
        label_fake = torch.zeros(b_size, device=device)

        # Discriminator real loss
        netD.zero_grad()
        errD_real = criterion(netD(real), label_real)
        errD_real.backward()

        # Discriminator fake loss; fakes are detached so only netD gets gradients.
        noise    = torch.randn(b_size, nz, 1, 1, device=device)
        fake_img = netG(noise)
        errD_fake = criterion(netD(fake_img.detach()), label_fake)
        errD_fake.backward()
        optimizerD.step()

        # Generator loss: fakes are scored against the "real" label. This
        # backward also fills netD's grads; netD.zero_grad() clears them at the
        # start of the next iteration, before netD steps again.
        netG.zero_grad()
        errG = criterion(netD(fake_img), label_real)
        errG.backward()
        optimizerG.step()

        if i % 50 == 0:
            print(f"[{epoch+1}/{num_epochs}][{i}/{len(dataloader)}] "
                  f"Loss_D: {(errD_real+errD_fake).item():.4f} "
                  f"Loss_G: {errG.item():.4f}")

    # ---- Sampling & saving after each epoch ----
    # NOTE(review): netG is never put in eval() mode, so BatchNorm normalises
    # with this fixed batch's statistics (and updates its running stats).
    with torch.no_grad():
        raw = netG(fixed_noise).cpu()

        # 1) Bilateral filter -> filtered; the residual becomes the noise map.
        #    Round (not truncate) when quantising, otherwise `filtered` sits
        #    about half a grey level below `raw` and that bias leaks into
        #    noise_map as a constant offset.
        filtered = torch.zeros_like(raw)
        for idx in range(raw.size(0)):
            np_img = np.clip(np.rint(raw[idx].squeeze().numpy() * 127.5 + 127.5),
                             0, 255).astype(np.uint8)
            out = guided_bilateral_filter(np_img)
            tensor_out = torch.from_numpy((out.astype(np.float32)/127.5 - 1.0)) \
                            .unsqueeze(0).to(raw.device)
            filtered[idx] = tensor_out

        noise_map = raw - filtered   # high-frequency residual, values in (-2, 2)

        # 2) Sharpen → denoise → CLAHE → hist match → fake64
        proc = torch.zeros_like(filtered)
        for idx in range(filtered.size(0)):
            img = filtered[idx]
            img = sharp_kernel_filter(img)
            img = denoise_nl_means(img, h=3)
            img = clahe_enhance(img, clipLimit=2.0, tileGridSize=(4,4))
            img = match_histogram_tensor(img, ref_ct)
            proc[idx] = img
        fake64 = proc

        # visualize & save 64×64 grid
        grid = utils.make_grid(fake64, nrow=20, padding=2, normalize=True)
        fig = plt.figure(figsize=(10,10))
        plt.imshow(grid.permute(1, 2, 0).numpy(), cmap='gray')
        plt.axis('off')
        plt.title(f"Epoch {epoch+1}")
        plt.show()
        plt.close(fig)   # don't keep one open figure per epoch alive
        utils.save_image(fake64,
                         os.path.join(out_dir, f"epoch_{epoch+1}.png"),
                         normalize=True)

        # 3) Final super-res + reinjection + 512×512 sharpen
        if (epoch + 1) == num_epochs:
            print("Applying Super-Resolution + reinjection + final sharpen…")
            alpha = 0.9   # weight of the re-injected noise map
            heavy_kernel = np.array([[-1, -1, -1],
                                     [-1,  9, -1],
                                     [-1, -1, -1]], dtype=np.float32)
            to_pil = transforms.ToPILImage()   # hoisted out of the per-image loop

            for idx in range(fake64.size(0)):
                # [-1, 1] tensor -> 8-bit RGB array (grey replicated) for Real-ESRGAN
                pil64 = to_pil((fake64[idx]*0.5 + 0.5).clamp(0,1)).convert('RGB')
                arr64 = np.array(pil64)

                out4, _ = sr4.enhance(arr64)   # 4x model
                out8, _ = sr2.enhance(out4)    # then 2x model -> 8x overall

                # Upsample the noise map as float32 ('F'-mode) data. It spans
                # (-2, 2); quantising it to uint8 first would wrap values outside
                # [-1, 1] and throw away sub-level detail.
                nm = noise_map[idx].squeeze().numpy().astype(np.float32)
                nm_hr = Image.fromarray(nm).resize((out8.shape[1], out8.shape[0]),
                                                   Image.BICUBIC)
                nm_hr_f = np.asarray(nm_hr, dtype=np.float32)

                sr8 = out8.astype(np.float32)/255.0
                combined = np.clip(sr8 + alpha * nm_hr_f[...,None], 0, 1)

                # Input was grey replicated to RGB; channel 0 is kept as the grey output.
                final = np.rint(combined * 255).astype(np.uint8)[...,0]
                sharp_final = cv2.filter2D(final, -1, heavy_kernel)
                final_gray = Image.fromarray(sharp_final)
                final_gray.save(os.path.join(out_dir,
                                             f"sr8_epoch{epoch+1}_img{idx}.png"))

            print("Super-resolution + reinjection + final sharpen complete!")

print("Training complete!")


In [None]:
pip install streamlit torch torchvision numpy Pillow scikit-image torchmetrics scikit-learn matplotlib

In [None]:
import streamlit as st
import os
import io
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torchvision import transforms, datasets
from torchvision.utils import make_grid, save_image
from skimage.metrics import structural_similarity as ssim
from skimage.exposure import match_histograms
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt

# -- Device --
# Use the GPU when torch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# -- DCGAN Generator Definition --
nz = 100   # length of the latent noise vector fed to the generator
ngf = 64   # base feature-map width of the generator
nc = 1     # number of output image channels (1 = grayscale)
nrows = 8  # images per row in the preview / download grids
class Generator(nn.Module):
    """DCGAN generator mapping (N, nz, 1, 1) noise to (N, nc, 64, 64) images.

    Four ConvTranspose-BatchNorm-ReLU stages (1x1 -> 4x4 -> 8x8 -> 16x16 ->
    32x32) are followed by a final ConvTranspose to 64x64 and a Tanh, so the
    output lies in [-1, 1]. Layer order (and therefore ``state_dict`` keys
    under ``model.``) matches the checkpoints this app loads.
    """

    def __init__(self):
        super().__init__()
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # The first stage projects the 1x1 latent to 4x4 (stride 1, no
            # padding); every later stage doubles height and width.
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        layers.extend([
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# -- Postprocessing Helpers --
def match_histogram_tensor(tensor_img, ref_img_uint8):
    """Histogram-match one generated image to a reference CT slice.

    Args:
        tensor_img: image tensor with values in [-1, 1] (the generator's Tanh
            range). Singleton dimensions are squeezed, so (1, H, W) and
            (H, W) inputs both work.
        ref_img_uint8: reference image as a uint8 NumPy array (the app
            passes a grayscale 'L' image).

    Returns:
        A float32 tensor of shape (1, H, W) with values mapped back to [-1, 1].
    """
    img = tensor_img.detach().squeeze().cpu().numpy()
    # Clip before casting: converting out-of-range floats to uint8 is
    # platform-dependent (typically wraps), which turns small overshoots into
    # bright/dark speckles. Round rather than truncate to avoid a -0.5 bias.
    img_uint8 = np.rint(np.clip((img + 1.0) * 127.5, 0, 255)).astype(np.uint8)
    matched = match_histograms(img_uint8, ref_img_uint8, channel_axis=None)
    matched = np.rint(np.clip(matched, 0, 255)).astype(np.uint8)
    f = matched.astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(f).unsqueeze(0)

# -- Metric Functions --
def _to_inception_input(images):
    """Convert an image batch in [-1, 1] to 3-channel float images in [0, 1].

    torchmetrics' Inception-based metrics need 3-channel input and, with
    ``normalize=True``, floats in [0, 1]. Single-channel batches are
    replicated across the three channels; (N, H, W) batches gain a channel axis.
    """
    imgs = ((images.detach().float() + 1.0) / 2.0).clamp(0.0, 1.0)
    if imgs.dim() == 3:
        imgs = imgs.unsqueeze(1)
    if imgs.size(1) == 1:
        imgs = imgs.repeat(1, 3, 1, 1)
    return imgs


def compute_inception_score(images, resize=True):
    """Return the Inception Score ``(mean, std)`` of a batch of images.

    Args:
        images: tensor of shape (N, C, H, W) or (N, H, W) with values in
            [-1, 1]; single-channel images are replicated to RGB.
        resize: accepted for backward compatibility and ignored; the
            torchmetrics Inception feature extractor resizes inputs itself.

    Returns:
        Tuple ``(mean, std)`` of 0-d tensors.
    """
    # torchmetrics rejects `feature=None` and unknown kwargs such as `dims`,
    # so the default logits feature is used.
    is_metric = InceptionScore(normalize=True).to(device)
    # update() takes only the images; it has no `resize` argument.
    is_metric.update(_to_inception_input(images).to(device))
    return is_metric.compute()


def compute_fid(real_images, fake_images, resize=True):
    """Return the Frechet Inception Distance between two image batches.

    Args:
        real_images: tensor of shape (N, C, H, W) or (N, H, W) in [-1, 1].
        fake_images: same layout and range as ``real_images``.
        resize: accepted for backward compatibility and ignored.

    Returns:
        The FID as a 0-d tensor.

    Raises:
        RuntimeError: from torchmetrics if either batch has fewer than two
            images.
    """
    fid_metric = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
    # Accumulate with update() rather than calling the metric: update() accepts
    # only (imgs, real), and calling the metric would also try to compute a
    # per-batch FID from a single distribution.
    fid_metric.update(_to_inception_input(real_images).to(device), real=True)
    fid_metric.update(_to_inception_input(fake_images).to(device), real=False)
    return fid_metric.compute()


def compute_ssim_batch(real_imgs, fake_imgs, data_range=2.0):
    """Return the mean SSIM over position-paired single-channel images.

    Images are paired by index. If the batches differ in length, the extra
    images of the longer one are ignored (``zip`` semantics).

    Args:
        real_imgs: tensor of shape (N, 1, H, W), (N, H, W) or (H, W).
        fake_imgs: tensor with the same H and W as ``real_imgs``.
        data_range: intensity range passed to skimage's SSIM. The default of
            2.0 matches images spanning [-1, 1] or [0, 2].

    Returns:
        Mean SSIM as a Python float, or NaN if there are no pairs.
    """
    def _as_stack(t):
        # Drop only the channel axis of (N, 1, H, W). A blanket squeeze()
        # would also drop N when N == 1, and SSIM would then compare
        # individual pixel rows instead of whole images.
        arr = t.detach().cpu().numpy()
        if arr.ndim == 4 and arr.shape[1] == 1:
            arr = arr[:, 0]
        elif arr.ndim == 2:
            arr = arr[None]
        return arr

    scores = [ssim(r, f, data_range=data_range)
              for r, f in zip(_as_stack(real_imgs), _as_stack(fake_imgs))]
    if not scores:
        return float("nan")
    return float(np.mean(scores))

# -- Classification AUC Evaluation --
def evaluate_auc(real_dir, fake_images, synthetic_label=1, test_size=0.3,
                 random_state=0):
    """Compare a real-only classifier with one augmented by synthetic CTs.

    Two logistic-regression classifiers are trained on flattened 64x64
    grayscale pixels to predict the (binary) ImageFolder class of
    ``real_dir``:

    * "real only": trained on a training split of the real images;
    * "real + synthetic": trained on the same split plus ``fake_images``,
      all labelled ``synthetic_label``.

    Both are scored by ROC AUC on the same held-out split of real images, so
    the two numbers are directly comparable.

    Args:
        real_dir: ImageFolder root containing exactly two class sub-folders.
        fake_images: generated images in [-1, 1], e.g. shape (N, 1, 64, 64).
        synthetic_label: class index (as in ``ImageFolder.class_to_idx``)
            assigned to the synthetic images. NOTE(review): defaults to 1;
            set it to the index of the class the generator was trained on.
        test_size: fraction of the real images held out for evaluation.
        random_state: seed for the stratified train/test split.

    Returns:
        ``({'AUC_real_only': float, 'AUC_real+synthetic': float},
        (fpr_real, tpr_real, fpr_mix, tpr_mix))``. Class index 1 is treated
        as the positive class for AUC and the ROC curves.

    Raises:
        ValueError: if ``real_dir`` does not have exactly two classes, if
            ``synthetic_label`` is not 0 or 1, or if real and synthetic images
            have different pixel counts.
    """
    from sklearn.model_selection import train_test_split

    # Same preprocessing as the generator's training data: square 64x64
    # grayscale scaled to [-1, 1], so real and synthetic pixels share a scale.
    dataset = datasets.ImageFolder(root=real_dir, transform=transforms.Compose([
        transforms.Resize(64), transforms.CenterCrop(64), transforms.Grayscale(),
        transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]))
    if len(dataset.classes) != 2:
        raise ValueError(
            f"evaluate_auc needs exactly 2 classes in {real_dir!r}, "
            f"found {dataset.classes}")
    if synthetic_label not in (0, 1):
        raise ValueError(f"synthetic_label must be 0 or 1, got {synthetic_label!r}")

    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
    x_parts, y_parts = [], []
    for imgs, labels in loader:
        x_parts.append(imgs.reshape(imgs.size(0), -1).numpy())
        y_parts.append(labels.numpy())
    X_real = np.vstack(x_parts)
    y_real = np.concatenate(y_parts)

    X_fake = fake_images.detach().reshape(fake_images.size(0), -1).cpu().numpy()
    if X_fake.shape[1] != X_real.shape[1]:
        raise ValueError(
            f"synthetic images have {X_fake.shape[1]} pixels, "
            f"real images have {X_real.shape[1]}")
    y_fake = np.full(X_fake.shape[0], synthetic_label, dtype=y_real.dtype)

    # Evaluate on held-out real images only. Stratify so both classes appear
    # in the test split; ROC AUC is undefined with a single class.
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_real, y_real, test_size=test_size, stratify=y_real,
        random_state=random_state)

    clf_real = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
    probs_real = clf_real.predict_proba(X_te)[:, 1]

    clf_mix = LogisticRegression(max_iter=1000).fit(
        np.vstack([X_tr, X_fake]), np.concatenate([y_tr, y_fake]))
    probs_mix = clf_mix.predict_proba(X_te)[:, 1]

    auc_real = roc_auc_score(y_te, probs_real)
    auc_mix = roc_auc_score(y_te, probs_mix)
    fpr_real, tpr_real, _ = roc_curve(y_te, probs_real)
    fpr_mix, tpr_mix, _ = roc_curve(y_te, probs_mix)
    return ({'AUC_real_only': auc_real, 'AUC_real+synthetic': auc_mix},
            (fpr_real, tpr_real, fpr_mix, tpr_mix))

# -- App UI --
st.title("Synthetic CT Generator & Metrics Dashboard")
st.sidebar.header("Configuration")

# Input params
ckpt = st.sidebar.file_uploader("Upload DCGAN Generator Checkpoint", type=['pth','pt'])
num_images = st.sidebar.slider("Number of Synthetic CTs", 1, 400, 100)
real_dir = st.sidebar.text_input("Path to Real CT ImageFolder", "./Dataset_Normal_CT")
ref_ct_file = st.sidebar.file_uploader("Reference CT for Hist Matching", type=['png','jpg','jpeg'])

# Load Generator
netG = Generator().to(device)
if ckpt:
    # SECURITY: torch.load unpickles the uploaded file. weights_only=True
    # restricts it to tensors and plain containers, so a malicious upload
    # cannot run code. A plain state_dict (what load_state_dict needs) still loads.
    state = torch.load(ckpt, map_location=device, weights_only=True)
    netG.load_state_dict(state)
    st.sidebar.success("Generator loaded.")
# Inference mode: BatchNorm uses the checkpoint's running statistics instead
# of per-batch statistics. In train mode every image would depend on the rest
# of the batch, and the running stats would be overwritten on each generation.
netG.eval()

# Read reference CT
ref_uint8 = None
if ref_ct_file:
    ref_img = Image.open(ref_ct_file).convert('L')
    ref_uint8 = np.array(ref_img)

# Buttons (sidebar): each flag is truthy on the run where it was clicked.
generate = st.sidebar.button("Generate Synthetic CTs")
compute_metrics_btn = st.sidebar.button("Compute Metrics")
evaluate_auc_btn = st.sidebar.button("Evaluate AUC")

# Storage: keep the generated batch in session state, starting empty.
if "synthetic" not in st.session_state:
    st.session_state["synthetic"] = None

# Generate synthetic CTs
if generate:
    if not ckpt:
        st.warning("No checkpoint uploaded: images come from an untrained generator.")
    st.info("Generating images...")
    noise = torch.randn(num_images, nz, 1, 1, device=device)
    with torch.no_grad():
        raw = netG(noise).cpu()
    # Postprocessing: histogram-match each image to the reference CT if one was given.
    if ref_uint8 is not None:
        fake = torch.stack([match_histogram_tensor(img, ref_uint8) for img in raw])
    else:
        fake = raw
    st.session_state.synthetic = fake
    # normalize=True rescales the grid to [0, 1].
    grid = make_grid(fake, nrow=nrows, normalize=True)
    # st.image works with NumPy arrays / PIL images, not torch tensors:
    # pass an (H, W, C) float array in [0, 1].
    st.image(grid.permute(1, 2, 0).numpy(), caption="Synthetic CTs", use_column_width=True)

# Compute and display metrics
if compute_metrics_btn and st.session_state.synthetic is None:
    st.warning("Generate synthetic CTs first.")
elif compute_metrics_btn:
    st.info("Computing metrics...")
    # Load real images with the same 64x64 grayscale [-1, 1] preprocessing as
    # the generator output. CenterCrop makes non-square scans square so they
    # can be batched and compared with the 64x64 fakes.
    real_dataset = datasets.ImageFolder(root=real_dir, transform=transforms.Compose([
        transforms.Resize(64), transforms.CenterCrop(64), transforms.Grayscale(),
        transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]))
    # Shuffle so the sample is not drawn only from the first class folder
    # (ImageFolder lists samples grouped by class).
    real_loader = torch.utils.data.DataLoader(real_dataset, batch_size=num_images, shuffle=True)
    real_imgs, _ = next(iter(real_loader))
    fake_imgs = st.session_state.synthetic
    # Both batches are in [-1, 1] here.
    is_score, is_std = compute_inception_score(fake_imgs)
    fid_score = compute_fid(real_imgs, fake_imgs)
    # Shift to [0, 2] so SSIM sees non-negative intensities (compute_ssim_batch
    # uses data_range=2). Pairs are formed by index, not by anatomy.
    ssim_score = compute_ssim_batch(real_imgs + 1, fake_imgs + 1)
    st.metric("Inception Score", f"{is_score:.3f} ± {is_std:.3f}")
    st.metric("FID", f"{fid_score:.3f}")
    st.metric("SSIM", f"{ssim_score:.3f}")

# Evaluate AUC
if evaluate_auc_btn and st.session_state.synthetic is None:
    st.warning("Generate synthetic CTs first.")
elif evaluate_auc_btn:
    st.info("Evaluating AUC for classification task...")
    aucs, curves = evaluate_auc(real_dir, st.session_state.synthetic)
    st.metric("AUC (Real Only)", f"{aucs['AUC_real_only']:.3f}")
    st.metric("AUC (Real + Synthetic)", f"{aucs['AUC_real+synthetic']:.3f}")
    fpr_r, tpr_r, fpr_m, tpr_m = curves
    fig, ax = plt.subplots()
    ax.plot(fpr_r, tpr_r, label='Real Only')
    ax.plot(fpr_m, tpr_m, label='Real+Synthetic')
    ax.plot([0,1], [0,1], 'k--')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('ROC Curves')
    ax.legend()
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; release it so repeated
    # clicks in the long-running app process do not accumulate figures.
    plt.close(fig)

# Download synthetic
if st.session_state.synthetic is not None:
    buffer = io.BytesIO()
    # A BytesIO has no filename for PIL to infer the image format from, so the
    # format must be given explicitly or saving raises "unknown file extension".
    save_image(st.session_state.synthetic, buffer, nrow=nrows, normalize=True, format="png")
    st.download_button(label="Download Grid of Synthetic CTs",
                       data=buffer.getvalue(), file_name="synthetic_cts.png",
                       mime="image/png")
