# **Synthetic Dental X-Ray Generation and Segmentation Analysis**

## Imports

Imports

In [None]:
import os, glob, time, random, sys
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import matplotlib.pyplot as plt
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid, save_image
import torch.optim as optim

import cv2

from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

Environment report

In [None]:
# Report library versions so results can be tied to a concrete environment.
print("Environment")
print("-----------")
for _label, _value in (
    ("Python        ", sys.version.split()[0]),
    ("NumPy         ", np.__version__),
    ("PyTorch       ", torch.__version__),
    ("OpenCV        ", cv2.__version__),
):
    print(f"{_label}: {_value}")
print()

# CUDA / device check: GPU details are only queried when CUDA is present.
if not torch.cuda.is_available():
    print("CUDA NOT available (training will be slow)")
else:
    print("CUDA available")
    for _label, _value in (
        ("GPU           ", torch.cuda.get_device_name(0)),
        ("CUDA version  ", torch.version.cuda),
        ("cuDNN         ", torch.backends.cudnn.version()),
    ):
        print(f"{_label}: {_value}")

## Globals

In [None]:
# --- Reproducibility -------------------------------------------------------
SEED = 13

# --- Paths -----------------------------------------------------------------
PROJECT_ROOT = Path("..", "workspace")
DATA_ROOT_IMG = PROJECT_ROOT / "teeth_seg_dataset" / "d2" / "img"
MASK_DIR = PROJECT_ROOT / "teeth_seg_dataset" / "d2" / "masks_machine"

# --- Data / training -------------------------------------------------------
IMG_SIZE = 256      # side length images are resized to by the dataset
BATCH_SIZE = 16
NUM_WORKERS = 2
CHANNELS = 1
EPOCHS = 10
LR = 2e-4
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- GAN hyper-parameters --------------------------------------------------
LATENT_DIM = 100    # length of the generator's input noise vector
BETA1 = 0.5         # Adam beta1
BETA2 = 0.999       # Adam beta2

## Utils

In [None]:
def set_seed(seed: int = 13):
    """Seed Python's, NumPy's and PyTorch's RNGs (and all CUDA devices if present).

    Args:
        seed: value passed to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(seed=SEED)  # seed every RNG once, up front, with the notebook-wide SEED

In [None]:
def pad_to_square(img, fill=0):
    """Centre a PIL image on a square canvas of side max(width, height).

    Args:
        img: PIL image.
        fill: colour used for the added border.

    Returns:
        ``img`` itself when it is already square, otherwise a new image of the
        same mode with ``img`` pasted in the middle.
    """
    width, height = img.size
    side = max(width, height)
    if width == height:
        return img
    canvas = Image.new(img.mode, (side, side), color=fill)
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(img, offset)
    return canvas

In [None]:
def to_tensor_gray(img):
    """Convert an image to a (1, H, W) float32 tensor scaled by 1/255.

    Non-"L" images are first converted to 8-bit grayscale, so the result lies
    in [0, 1].
    """
    gray = img if img.mode == "L" else img.convert("L")
    pixels = np.array(gray, dtype=np.float32)  # (H, W)
    return torch.from_numpy(pixels / 255.0).unsqueeze(0)  # (1, H, W)

In [None]:
def denorm(x):
    """Map a tensor from [-1, 1] to [0, 1], clamping out-of-range values first."""
    clipped = x.clamp(-1, 1)
    return clipped.add(1).mul(0.5)

In [None]:
def show_grid(t, nrow=8, title=None):
    """Display a batch of images in [-1, 1] as one grid figure.

    Args:
        t: image tensor accepted by ``make_grid`` (e.g. (B, C, H, W)) with
            values in [-1, 1]; it is detached and moved to the CPU here.
        nrow: number of images per grid row.
        title: optional figure title.
    """
    t = t.detach().cpu()
    grid = make_grid(denorm(t), nrow=nrow, padding=2)  # (3, H, W)
    channels = t.shape[-3] if t.dim() >= 3 else 1
    if channels == 1:
        # make_grid replicates single-channel input into 3 identical channels;
        # show one of them with an explicit gray colormap. vmin/vmax pin the
        # [0, 1] range so matplotlib does not auto-stretch the contrast.
        image = grid[0].numpy()
        imshow_kwargs = {"cmap": "gray", "vmin": 0.0, "vmax": 1.0}
    else:
        image = grid.permute(1, 2, 0).numpy()  # (H, W, 3) for RGB display
        imshow_kwargs = {}
    plt.figure(figsize=(22, 16), dpi=160)
    plt.axis("off")
    if title:
        plt.title(title)
    plt.imshow(image, **imshow_kwargs)
    plt.show()

## Data

In [None]:
def list_image_files(root):
    """Return the sorted paths (as strings) of every ``*.jpg`` file under ``root``, recursively."""
    pattern = os.path.join(root, "**", "*.jpg")
    matches = glob.glob(pattern, recursive=True)
    matches.sort()
    return matches

class DentalXRays(Dataset):
    """Grayscale ``.jpg`` images found recursively under ``root``.

    Each item is a (1, size, size) float32 tensor scaled to [-1, 1], the same
    range as the generator's Tanh output.

    Args:
        root: directory searched recursively for ``*.jpg`` files.
        size: side length every image is resized to after square padding.
    """

    def __init__(self, root, size=256):
        super().__init__()
        self.paths = list_image_files(root)
        self.size = size
        if not self.paths:
            print(f"[WARNING] No .jpg images found under: {root}")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        # Close the file handle deterministically rather than relying on
        # Pillow's lazy loading; convert() returns a new, fully loaded image,
        # so it stays valid after the file is closed.
        with Image.open(p) as src:
            img = src.convert("L")
        # Pad to a square before resizing so the aspect ratio is preserved.
        img = pad_to_square(img, fill=0)
        img = img.resize((self.size, self.size), resample=Image.BICUBIC)
        t = to_tensor_gray(img)  # [0, 1]
        return t * 2.0 - 1.0     # [-1, 1] for GAN

In [None]:
dataset = DentalXRays(DATA_ROOT_IMG, IMG_SIZE)
if len(dataset) == 0:
    raise RuntimeError(f"No .jpg images in {DATA_ROOT_IMG}")
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Pinned host memory only speeds up host->GPU copies; it has no effect
    # (beyond a warning) on CPU-only machines.
    pin_memory=DEVICE.type == "cuda",
    drop_last=True,
)
# With drop_last=True, fewer than BATCH_SIZE images yields zero batches, and
# next(iter(loader)) below would fail with a bare StopIteration.
if len(loader) == 0:
    raise RuntimeError(
        f"Only {len(dataset)} images in {DATA_ROOT_IMG}; need at least "
        f"BATCH_SIZE={BATCH_SIZE} because drop_last=True"
    )

# Quick visual sanity check
batch = next(iter(loader))
show_grid(batch[:16], nrow=4, title="Real images (normalized, d2/img)")
print(f"Loaded {len(dataset)} .jpg images from {DATA_ROOT_IMG}")
print(batch.size())

## Network

###### Vanilla GAN

In [None]:
class Generator(nn.Module):
    """Upsampling convolutional generator: latent noise -> image in [-1, 1].

    Maps ``z`` of shape (B, latent_dim) to images of shape
    (B, channels, img_size, img_size). It starts from an 8x8x128 feature map
    and doubles the resolution log2(img_size / 8) times. With the defaults it
    builds exactly the original fixed 32x32, single-channel network.

    Args:
        latent_dim: length of the input noise vector.
        img_size: output side length; must be 8 * 2**k with k >= 1.
        channels: number of output image channels.

    Raises:
        ValueError: if ``img_size`` is not 8 * 2**k with k >= 1.
    """

    def __init__(self, latent_dim, img_size=32, channels=1):
        super(Generator, self).__init__()

        base = img_size // 8
        # Require at least one upsampling stage and an exact power-of-two ratio.
        if img_size % 8 or base < 2 or base & (base - 1):
            raise ValueError(f"img_size must be 8 * 2**k with k >= 1, got {img_size}")
        n_up = base.bit_length() - 1

        layers = [
            nn.Linear(latent_dim, 128 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
        ]
        for stage in range(n_up):
            # Every stage keeps 128 channels except the last, which narrows to 64.
            out_ch = 64 if stage == n_up - 1 else 128
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(128, out_ch, kernel_size=3, padding=1),
                # NOTE(review): PyTorch's momentum weights the *new* batch
                # statistic (the opposite of Keras' convention), so 0.78 makes
                # running stats track the latest batch closely - confirm intent.
                nn.BatchNorm2d(out_ch, momentum=0.78),
                nn.ReLU(),
            ]
        layers += [
            nn.Conv2d(64, channels, kernel_size=3, padding=1),
            nn.Tanh(),  # output range [-1, 1], matching the dataset normalisation
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from noise ``z`` of shape (B, latent_dim)."""
        img = self.model(z)
        return img

In [None]:
class Discriminator(nn.Module):
    """Binary real/fake classifier for single-channel images.

    Returns a (B, 1) probability (Sigmoid output) that each image is real.
    An adaptive pooling layer fixes the spatial size fed to the linear head,
    so the network accepts any input resolution - in particular both the
    256x256 real images and the generator's 32x32 samples.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            # Asymmetric pad: one extra pixel on the right and bottom.
            nn.ZeroPad2d((0, 1, 0, 1)),
            # NOTE(review): PyTorch's momentum weights the *new* batch statistic
            # (the opposite of Keras' convention) - confirm 0.8x is intended.
            nn.BatchNorm2d(64, momentum=0.82),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128, momentum=0.82),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.25),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256, momentum=0.8),
            nn.LeakyReLU(0.25),
            nn.Dropout(0.25),
            # The conv trunk yields 33x33 maps for 256x256 input and 5x5 for
            # 32x32 input - never the 8x8 the linear head is sized for - so pool
            # to a fixed 8x8 before flattening.
            nn.AdaptiveAvgPool2d((8, 8)),
            nn.Flatten(),
            nn.Linear(256 * 8 * 8, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return the (B, 1) probability that each image in ``img`` is real."""
        validity = self.model(img)
        return validity

## Train

###### Vanilla GAN

In [None]:
# Build the two players on the target device. The generator is created first,
# then the discriminator, which fixes the order in which their weights draw
# from the RNG.
generator = Generator(LATENT_DIM).to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# The discriminator outputs probabilities (it ends in nn.Sigmoid), so plain BCE applies.
adversarial_loss = nn.BCELoss()

# Both players use identical Adam settings.
_adam_kwargs = {"lr": LR, "betas": (BETA1, BETA2)}
optimizer_G = optim.Adam(generator.parameters(), **_adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **_adam_kwargs)

In [None]:
for epoch in range(EPOCHS):
    # Wrap the loader itself (not enumerate) so tqdm knows len(loader) and
    # can draw a real progress bar.
    for i, batch in enumerate(tqdm(loader, desc=f"Epoch {epoch + 1}/{EPOCHS}")):
        real_images = batch.to(DEVICE)
        n_real = real_images.size(0)

        # BCE targets: 1 = real, 0 = generated.
        valid = torch.ones(n_real, 1, device=DEVICE)
        fake = torch.zeros(n_real, 1, device=DEVICE)

        # ---- Discriminator step ----
        optimizer_D.zero_grad()

        z = torch.randn(n_real, LATENT_DIM, device=DEVICE)
        fake_images = generator(z)

        # The discriminator must compare real and generated images at the same
        # resolution; resample the real batch to the generator's output size
        # when they differ (e.g. a 32x32 generator vs IMG_SIZE-sized data).
        if real_images.shape[-2:] != fake_images.shape[-2:]:
            real_images = nn.functional.interpolate(
                real_images, size=fake_images.shape[-2:], mode="area"
            )

        real_loss = adversarial_loss(discriminator(real_images), valid)
        # detach(): the discriminator update must not push gradients into G.
        fake_loss = adversarial_loss(discriminator(fake_images.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step ----
        optimizer_G.zero_grad()

        # Reuse fake_images: G's weights have not changed since it was produced,
        # so a second generator(z) forward pass would only repeat the work.
        g_loss = adversarial_loss(discriminator(fake_images), valid)
        g_loss.backward()
        optimizer_G.step()

        if (i + 1) % 100 == 0:
            # tqdm.write keeps the log line from corrupting the progress bar.
            tqdm.write(
                f"Epoch [{epoch+1}/{EPOCHS}] "
                f"Batch {i+1}/{len(loader)} "
                f"Discriminator Loss: {d_loss.item():.4f} "
                f"Generator Loss: {g_loss.item():.4f}"
            )
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, LATENT_DIM, device=DEVICE)
            generated = generator(z).cpu()
        grid = make_grid(generated, nrow=4, normalize=True)  # (3, H, W)
        plt.imshow(grid.permute(1, 2, 0).numpy())
        plt.axis("off")
        plt.show()

## Evaluation