# Multi-Label Conditional Diffusion for Disease-Specific Chest X-ray Synthesis using ChestX-ray14

## Import the Required Libraries

In [1]:
!pip install torch matplotlib diffusers pandas numpy kaggle gradio tqdm



In [2]:
import os
import random
import pathlib
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from kaggle.api.kaggle_api_extended import KaggleApi
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import gradio as gr
%matplotlib inline

# Seed every random number generator used here (Python, NumPy, PyTorch)
# from a single constant so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x17c06a9cdf0>

## Data Preparation

In [3]:
# Authenticate
# Build the Kaggle API client and authenticate it so later dataset calls can use `api`.
# NOTE(review): presumably the credentials come from the local Kaggle configuration
# (e.g. a kaggle.json file or environment variables), not from this notebook — confirm.
api = KaggleApi()
api.authenticate()

In [4]:
# Base data directory, relative to the notebook's working directory.
# Only a Path object is built here; nothing is created on disk.
base = pathlib.Path("./data")

In [5]:
# # Download the dataset archive into the train folder (kept zipped: unzip=False)
# api.dataset_download_files(
#     "nih-chest-xrays/data",
#     path = str(base / 'train'),
#     unzip = False
# )

In [6]:
# Finding labels recognised by the notebook. The list order fixes the position
# of each label in the multi-hot vectors produced by encode_labels.
DISEASES = ['Atelectasis','Cardiomegaly','Effusion','Infiltration',
            'Mass','Nodule','Pneumonia','Pneumothorax',
            'Consolidation','Edema','Emphysema','Fibrosis',
            'Pleural_Thickening','Hernia']

def encode_labels(disease):
    """Convert a '|'-separated findings string into a multi-hot list.

    Args:
        disease: Findings for one image, e.g. "Effusion|Mass". The value is
            passed through str(), so non-string inputs (such as NaN) are
            accepted and simply match no label.

    Returns:
        A list of ints with one entry per name in DISEASES (same order):
        1 when that name appears among the findings, otherwise 0. Names that
        are not in DISEASES (e.g. "No Finding") are ignored.
    """
    # Split once, outside the per-label check, and tolerate stray whitespace
    # around each label so " Effusion " still matches "Effusion".
    present = {name.strip() for name in str(disease).split("|")}

    # The vector length follows DISEASES instead of a hard-coded 14.
    return [1 if d in present else 0 for d in DISEASES]

In [7]:
def _make_xray_transform(*augmentations):
    """Build the image pipeline shared by the training and validation sets.

    Every pipeline resizes to 256x256, reduces to a single grayscale channel,
    converts to a tensor and normalizes with mean 0.5 / std 0.5. Any
    `augmentations` are inserted, in order, between the grayscale step and
    the tensor conversion.
    """
    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.Grayscale(num_output_channels=1),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])


# Training pipeline: random horizontal flip and rotation of up to 10 degrees.
train_transform = _make_xray_transform(
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
)

# Validation pipeline: deterministic, no augmentation.
val_transform = _make_xray_transform()

In [8]:
class XRayDataset(Dataset):
    """Image/label dataset driven by a metadata CSV.

    Each CSV row supplies an image file name in its 'Image Index' column
    (resolved relative to the image root folder) and a findings string in its
    'Finding Labels' column (encoded with encode_labels).
    """

    def __init__(self, transform, root_img_dir, csv_path):
        """Load the metadata table and remember where the images live.

        Args:
            transform: Callable applied to each loaded PIL image; a falsy
                value (e.g. None) disables transformation.
            root_img_dir: Folder that contains the image files.
            csv_path: Path of the metadata CSV to read.
        """
        self.df = pd.read_csv(csv_path)
        self.root = pathlib.Path(root_img_dir)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the (image, multi-hot label tensor) pair for row `idx`.

        The image is loaded as single-channel ('L') and passed through the
        transform when one is set; the label is a float tensor built from
        the row's 'Finding Labels' value.
        """
        row = self.df.iloc[idx]
        img_path = self.root / row['Image Index']

        # Image.open is lazy and keeps the file handle open. convert() yields
        # an independent in-memory image, so the handle can be closed right
        # away instead of waiting for garbage collection.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('L')

        if self.transform:
            img = self.transform(img)

        onehot = torch.tensor(encode_labels(row['Finding Labels']),
                              dtype=torch.float)

        return img, onehot

In [9]:
# Folder that holds the locally available PNG images
img_dir = pathlib.Path(r"C:\Users\asust\Documents\Medical Diffusion Model\data")
img_file_paths = list(img_dir.glob("*.png"))

# File names as a set: constant-time membership tests, duplicates collapse
available_imgs = {png_path.name for png_path in img_file_paths}

# Full metadata table
df = pd.read_csv(r"C:\Users\asust\Documents\Medical Diffusion Model\Data_Entry_2017.csv")

# Keep only the rows whose image file is actually present in the folder
present_mask = df["Image Index"].isin(available_imgs)
df_filtered = df[present_mask]

# Write the reduced metadata table next to the original (without the index column)
df_filtered.to_csv(r"C:\Users\asust\Documents\Medical Diffusion Model\sample_metadata.csv", index=False)

In [None]:
# Batch size, plus the filtered metadata CSV and the image folder it refers to
bs = 32
csv_path = r"C:\Users\asust\Documents\Medical Diffusion Model\sample_metadata.csv"
root_dir = r"C:\Users\asust\Documents\Medical Diffusion Model\data"

# Training dataset (augmenting transform) and its shuffled loader;
# num_workers=0 keeps loading in the main process.
train_ds = XRayDataset(
    transform=train_transform,
    root_img_dir=root_dir,
    csv_path=csv_path,
)
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=bs,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Preview up to four samples from one training batch, titled with their findings.
imgs, labels = next(iter(train_dl))

# A small dataset can yield a batch with fewer than four images; clamp the
# count so the indexing below cannot run past the end of the batch.
n_show = min(4, imgs.shape[0])

# Undo Normalize((0.5,), (0.5,)) so pixel values are back in [0, 1] for display.
imgs_vis = imgs[:n_show]*0.5+0.5

# squeeze=False always returns a 2-D array of axes (even for a single row);
# ravel() then gives the flat, indexable sequence used in the loop.
fig, axes = plt.subplots(n_show, 1, figsize=(4, 2*n_show), squeeze=False)
axes = axes.ravel()
for i in range(n_show):
    # (C, H, W) -> (H, W, C) for matplotlib
    axes[i].imshow(imgs_vis[i].permute(1,2,0), cmap='gray')
    title = ' '.join([d for d,flag in zip(DISEASES, labels[i]) if flag==1])
    axes[i].set_title(title if title else 'No Finding')
    axes[i].axis('off')

plt.tight_layout()
plt.show()

## Diffusion Model

In [None]:
class DDPM_Scheduler(nn.Module):
    """Linear-beta noise schedule for a DDPM, stored as module buffers.

    Buffers (each of length `num_timesteps`):
        beta:      linearly spaced values from `beta_start` to `beta_end`
        alpha:     1 - beta
        alpha_hat: cumulative product of alpha up to and including each step

    Being registered buffers, the tensors follow the module through
    `.to(device)` and are included in its `state_dict`.
    """

    def __init__(self, num_timesteps = 1000, beta_start = 1e-4, beta_end = 0.02):
        """Precompute the schedule.

        Args:
            num_timesteps: Number of diffusion steps in the schedule.
            beta_start: First value of the linear beta schedule.
            beta_end: Last value of the linear beta schedule.
        """
        super().__init__()

        self.num_timesteps=num_timesteps

        # Linear beta schedule (defaults: 1e-4 to 0.02)
        beta = torch.linspace(beta_start, beta_end, num_timesteps)
        self.register_buffer("beta", beta)

        alpha = 1.0 - beta
        self.register_buffer("alpha", alpha)

        # Cumulative product: alpha_hat_t = prod(alpha_1 to alpha_t)
        alpha_hat = torch.cumprod(alpha, dim=0)
        self.register_buffer("alpha_hat", alpha_hat)

    def get_params(self, t):
        """Look up the schedule values at timestep indices `t`.

        Args:
            t: Index (or tensor of indices) into the schedule buffers.
               NOTE(review): presumably a 1-D integer tensor with one
               timestep per batch element, on the buffers' device — confirm
               against the caller.

        Returns:
            Tuple (beta_t, alpha_t, alpha_hat_t). Each is reshaped to
            (-1, 1, 1, 1) so it can broadcast against a 4-D tensor.
        """
        beta_t = self.beta[t].view(-1, 1, 1, 1)
        alpha_t = self.alpha[t].view(-1, 1, 1, 1)
        alpha_hat_t = self.alpha_hat[t].view(-1, 1, 1, 1)
        return beta_t, alpha_t, alpha_hat_t

### Forward Function for Diffusion Model

In [None]:
def diffusion_forward(x0, t, e, scheduler: DDPM_Scheduler):
    """Noise `x0` to timestep `t` in a single step.

    Computes sqrt(alpha_hat_t) * x0 + sqrt(1 - alpha_hat_t) * e, where
    alpha_hat_t is taken from `scheduler.get_params(t)`.

    Args:
        x0: Clean input tensor.
        t: Timestep indices forwarded to `scheduler.get_params`.
        e: Noise tensor combined with `x0`.
        scheduler: Schedule object providing alpha_hat for the given timesteps.

    Returns:
        The noised tensor.
    """
    # Only the cumulative alpha is needed; beta_t and alpha_t are discarded.
    _, _, cumulative_alpha = scheduler.get_params(t)

    signal_scale = cumulative_alpha.sqrt()
    noise_scale = (1 - cumulative_alpha).sqrt()

    return (signal_scale * x0) + (noise_scale * e)

### Reverse Function for Diffusion Model (UNet Architecture)