In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import os
import glob

# ===================================================================
# 1. Configuration
# ===================================================================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Paths
DATASET_PATH = "C:/Users/Sanjjey Arumugam/.cache/kagglehub/datasets/jeftaadriel/osteoarthritis-initiative-oai-dataset/versions/1/train"
OUTPUT_DIR = "cvae_generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Hyperparameters
BATCH_SIZE = 64
NUM_EPOCHS = 30
LEARNING_RATE = 1e-4
IMAGE_SIZE = 64
IMAGE_CHANNELS = 1
LATENT_DIM = 100
NUM_CLASSES = 5

# ===================================================================
# 2. Dataset
# ===================================================================
class OaiDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        for i in range(NUM_CLASSES):
            class_dir = os.path.join(self.root_dir, str(i))
            if os.path.isdir(class_dir):
                paths = glob.glob(os.path.join(class_dir, '*.png'))
                self.image_paths.extend(paths)
                self.labels.extend([i] * len(paths))
            else:
                print(f"Warning: Directory not found for class {i}: {class_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("L")
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        if self.transform:
            image = self.transform(image)
        return image, label

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

if not os.path.exists(DATASET_PATH):
    print(f"ERROR: Dataset path not found: {DATASET_PATH}")
else:
    dataset = OaiDataset(root_dir=DATASET_PATH, transform=transform)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    print(f"Successfully loaded {len(dataset)} images.")

# ===================================================================
# 3. CVAE Model
# ===================================================================
class CVAE(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM, num_classes=NUM_CLASSES, img_channels=IMAGE_CHANNELS):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_channels = img_channels

        # Embed labels into latent_dim space
        self.embedding = nn.Embedding(num_classes, latent_dim)

        # ===== Encoder =====
        self.encoder = nn.Sequential(
            nn.Conv2d(img_channels + 1, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)

        # ===== Decoder =====
        self.decoder_input = nn.Linear(latent_dim + latent_dim, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def encode(self, x, y):
        y_embed = self.embedding(y).unsqueeze(2).unsqueeze(3)
        y_map = y_embed.mean(1, keepdim=True).expand(-1, 1, x.size(2), x.size(3))
        combined = torch.cat([x, y_map], dim=1)
        result = self.encoder(combined)
        mu = self.fc_mu(result)
        logvar = self.fc_logvar(result)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, y):
        y_embed = self.embedding(y)
        combined = torch.cat([z, y_embed], dim=1)
        result = self.decoder_input(combined)
        result = result.view(-1, 256, 4, 4)
        return self.decoder(result)

    def forward(self, x, y):
        mu, logvar = self.encode(x, y)
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decode(z, y)
        return reconstruction, mu, logvar

# ===================================================================
# 4. Loss
# ===================================================================
def vae_loss_function(recon_x, x, mu, logvar):
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_div

# ===================================================================
# 5. Training
# ===================================================================
model = CVAE().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# print("Starting C-VAE Training...")
# for epoch in range(NUM_EPOCHS):
#     total_loss = 0
#     for batch_idx, (real_images, labels) in enumerate(dataloader):
#         real_images = real_images.to(DEVICE)
#         labels = labels.to(DEVICE)

#         recon_images, mu, logvar = model(real_images, labels)
#         loss = vae_loss_function(recon_images, real_images, mu, logvar)
#         total_loss += loss.item()

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         if (batch_idx + 1) % 100 == 0:
#             print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item()/len(real_images):.4f}")

#     print(f"====> Epoch: {epoch+1} Average loss: {total_loss / len(dataset):.4f}")

#     if (epoch + 1) % 10 == 0:
#         model.eval()
#         with torch.no_grad():
#             num_samples = NUM_CLASSES
#             fixed_noise = torch.randn(num_samples, LATENT_DIM).to(DEVICE)
#             fixed_labels = torch.arange(0, NUM_CLASSES).to(DEVICE)
#             generated_images = model.decode(fixed_noise, fixed_labels)
#             save_image(generated_images.view(num_samples, 1, IMAGE_SIZE, IMAGE_SIZE),
#                        os.path.join(OUTPUT_DIR, f'sample_epoch_{epoch+1}.png'),
#                        nrow=NUM_CLASSES, normalize=True)
#         model.train()

# print("✅ Training Finished.")

# ===================================================================
# 6. Final Generation
# ===================================================================
# print("Generating final images for each stage...")
# model.eval()
# with torch.no_grad():
#     for stage in range(NUM_CLASSES):
#         num_final_samples = 16
#         noise = torch.randn(num_final_samples, LATENT_DIM).to(DEVICE)
#         labels = torch.full((num_final_samples,), stage, dtype=torch.long).to(DEVICE)
#         final_images = model.decode(noise, labels)
#         save_image(final_images.view(num_final_samples, 1, IMAGE_SIZE, IMAGE_SIZE),
#                    os.path.join(OUTPUT_DIR, f'final_stage_{stage}.png'),
#                    nrow=4, normalize=True)
#         print(f"✅ Saved final images for stage {stage}.")


Using device: cuda
Successfully loaded 5778 images.


In [2]:
MODEL_PATH = os.path.join(OUTPUT_DIR, "cvae_model.pth")

torch.save({
    "epoch": NUM_EPOCHS,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),
}, MODEL_PATH)

print(f"✅ Model saved at {MODEL_PATH}")


NameError: name 'loss' is not defined

In [3]:
!pip install torchmetrics pytorch-fid


Collecting pytorch-fid
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Downloading pytorch_fid-0.3.0-py3-none-any.whl (15 kB)
Installing collected packages: pytorch-fid
Successfully installed pytorch-fid-0.3.0


In [3]:
from skimage.metrics import structural_similarity as ssim
import numpy as np

def evaluate_model(model, dataloader, device):
    model.eval()
    total_mse, total_ssim, count = 0, 0, 0
    
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            recon_images, mu, logvar = model(images, labels)

            # --- MSE ---
            mse = nn.functional.mse_loss(recon_images, images, reduction="mean").item()
            total_mse += mse * images.size(0)

            # --- SSIM (convert to numpy & denormalize to [0,1]) ---
            for i in range(images.size(0)):
                orig = images[i].cpu().numpy().squeeze()
                recon = recon_images[i].cpu().numpy().squeeze()
                orig = (orig + 1) / 2.0   # denormalize from [-1,1] → [0,1]
                recon = (recon + 1) / 2.0
                total_ssim += ssim(orig, recon, data_range=1.0)
                count += 1

    avg_mse = total_mse / count
    avg_ssim = total_ssim / count
    print(f"🔎 Evaluation Results: MSE={avg_mse:.4f}, SSIM={avg_ssim:.4f}")
    return avg_mse, avg_ssim


In [5]:
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)

model = CVAE().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]
print(f"✅ Model loaded. Resuming from epoch {start_epoch}")


✅ Model loaded. Resuming from epoch 100


In [6]:
evaluate_model(model, dataloader, DEVICE)


🔎 Evaluation Results: MSE=0.0050, SSIM=0.8559


(0.005026210344297564, np.float64(0.8559373979817012))

In [11]:
!pip install pytorch-msssim

Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Installing collected packages: pytorch-msssim
Successfully installed pytorch-msssim-1.0.0


In [22]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from pathlib import Path
from PIL import Image
from pytorch_msssim import ssim

# =========================
# Dataset
# =========================
class OAIDataset(Dataset):
    def __init__(self, root, split="train", image_size=128):
        self.root = Path(root) / split
        self.samples = []
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),  # force 128x128
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # normalize to [-1, 1]
        ])
        for stage_folder in sorted(self.root.iterdir()):
            stage = int(stage_folder.name)  # 0..4
            for img_path in stage_folder.glob("*.png"):
                side = 0
                if "_2.png" in img_path.name:
                    side = 1
                self.samples.append((img_path, stage, side))

    def __len__(self): 
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, stage, side = self.samples[idx]
        img = Image.open(img_path).convert("L")
        img = self.transform(img)
        return img, torch.tensor(stage), torch.tensor(side)


# =========================
# Evaluation function
# =========================
def evaluate_cgan_model(model, dataloader, device, latent_dim=128):
    model.eval()
    mse_loss = nn.MSELoss()
    total_mse, total_ssim, count = 0.0, 0.0, 0

    with torch.no_grad():
        for real_images, stage, side in dataloader:
            real_images = real_images.to(device)
            stage = stage.to(device)
            side = side.to(device)

            # Random latent vector (correct shape)
            z = torch.randn(real_images.size(0), latent_dim, device=device)

            # Generate images
            fake_images = model(z, stage, side)

            # Compute metrics
            mse_val = mse_loss(fake_images, real_images).item()
            ssim_val = ssim(
                fake_images.cpu(), real_images.cpu(),
                data_range=2.0, size_average=True
            ).item()

            total_mse += mse_val
            total_ssim += ssim_val
            count += 1

    avg_mse = total_mse / count
    avg_ssim = total_ssim / count
    print(f"✅ Evaluation Results: MSE={avg_mse:.4f}, SSIM={avg_ssim:.4f}")
    return avg_mse, avg_ssim


# =========================
# Load checkpoint and model
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Update path to your checkpoint
checkpoint = torch.load(r"C:\Users\Sanjjey Arumugam\results\checkpoint_epoch_200.pth", map_location=device)

# Make sure these match the original Generator definition
netG = Generator(nz=128, n_stage=5, n_side=2, image_size=128).to(device)
netG.load_state_dict(checkpoint['netG_state_dict'])

# =========================
# Dataset & Dataloader
# =========================
dataset = OAIDataset(root="C:/Users/Sanjjey Arumugam/.cache/kagglehub/datasets/jeftaadriel/osteoarthritis-initiative-oai-dataset/versions/1", 
                     split="train", image_size=128)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

# =========================
# Evaluate
# =========================
mse, ssim_score = evaluate_cgan_model(netG, dataloader, device, latent_dim=128)
print(f"Final Evaluation -> MSE: {mse:.4f}, SSIM: {ssim_score:.4f}")


✅ Evaluation Results: MSE=0.1603, SSIM=0.3317
Final Evaluation -> MSE: 0.1603, SSIM: 0.3317


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.utils import save_image
from pathlib import Path
from PIL import Image
import os
from pytorch_fid import fid_score

# =========================
# Dataset
# =========================
class OAIDataset(Dataset):
    def __init__(self, root, split="train", image_size=128):
        self.root = Path(root) / split
        self.samples = []
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # [-1,1]
        ])
        for stage_folder in sorted(self.root.iterdir()):
            stage = int(stage_folder.name)  # 0..4
            for img_path in stage_folder.glob("*.png"):
                side = 0
                if "_2.png" in img_path.name:
                    side = 1
                self.samples.append((img_path, stage, side))

    def __len__(self): 
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, stage, side = self.samples[idx]
        img = Image.open(img_path).convert("L")
        img = self.transform(img)
        return img, torch.tensor(stage), torch.tensor(side)

# =========================
# GAN Generator
# =========================
class Generator(nn.Module):
    def __init__(self, nz=128, ngf=64, nc=1, n_stage=5, n_side=2, image_size=128):
        super().__init__()
        self.stage_emb = nn.Embedding(n_stage, n_stage)
        self.side_emb  = nn.Embedding(n_side, n_side)
        self.init_size = image_size // 16
        cond_dim = n_stage + n_side
        self.fc = nn.Linear(nz + cond_dim, ngf*8*self.init_size*self.init_size)

        def block(in_feat, out_feat):
            return nn.Sequential(
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_feat, out_feat, 3, padding=1),
                nn.BatchNorm2d(out_feat),
                nn.ReLU(True)
            )
        self.conv_blocks = nn.Sequential(
            block(ngf*8, ngf*4),
            block(ngf*4, ngf*2),
            block(ngf*2, ngf),
            block(ngf, ngf//2),
            nn.Conv2d(ngf//2, nc, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, z, stage, side):
        cond = torch.cat([self.stage_emb(stage), self.side_emb(side)], dim=1)
        x = torch.cat([z, cond], dim=1)
        out = self.fc(x).view(x.size(0), -1, self.init_size, self.init_size)
        return self.conv_blocks(out)

# =========================
# Paths and Device
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = r"C:\Users\Sanjjey Arumugam\results\checkpoint_epoch_200.pth"
dataset_root = "C:/Users/Sanjjey Arumugam/.cache/kagglehub/datasets/jeftaadriel/osteoarthritis-initiative-oai-dataset/versions/1"

generated_root = Path("cgan_generated")
generated_root.mkdir(exist_ok=True)

# =========================
# Load checkpoint
# =========================
checkpoint = torch.load(checkpoint_path, map_location=device)
netG = Generator(nz=128, n_stage=5, n_side=2, image_size=128).to(device)
netG.load_state_dict(checkpoint['netG_state_dict'])
netG.eval()

# =========================
# Dataset & Dataloader
# =========================
dataset = OAIDataset(root=dataset_root, split="train", image_size=128)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

latent_dim = 128

# =========================
# Generate images per stage
# =========================
print("🚀 Generating images for FID evaluation...")
stage_dirs = []
with torch.no_grad():
    for stage in range(5):
        stage_dir = generated_root / f"stage_{stage}"
        stage_dir.mkdir(exist_ok=True)
        stage_dirs.append(stage_dir)

        # Filter dataset for this stage
        stage_samples = [(img, s, side) for img, s, side in dataset.samples if s == stage]
        for i in range(0, len(stage_samples), 64):
            batch_samples = stage_samples[i:i+64]
            imgs = []
            stages = []
            sides = []
            for img_path, s, side in batch_samples:
                img = Image.open(img_path).convert("L")
                img = transforms.Compose([
                    transforms.Resize((128,128)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5],[0.5])
                ])(img)
                imgs.append(img)
                stages.append(torch.tensor(s))
                sides.append(torch.tensor(side))
            imgs = torch.stack(imgs).to(device)
            stages = torch.stack(stages).to(device)
            sides = torch.stack(sides).to(device)
            
            z = torch.randn(imgs.size(0), latent_dim, device=device)
            fake_imgs = netG(z, stages, sides)
            for j in range(fake_imgs.size(0)):
                save_image(fake_imgs[j], stage_dir / f"img_{i+j}.png", normalize=True)

# =========================
# Compute FID per stage
# =========================
print("🚀 Computing FID per stage...")
for stage in range(5):
    real_stage_dir = Path(dataset_root) / "train" / str(stage)
    fake_stage_dir = generated_root / f"stage_{stage}"
    fid_val = fid_score.calculate_fid_given_paths([str(real_stage_dir), str(fake_stage_dir)],
                                                  batch_size=64,
                                                  device=device,
                                                  dims=2048)
    print(f"Stage {stage} -> FID: {fid_val:.4f}")


🚀 Generating images for FID evaluation...
🚀 Computing FID per stage...


100%|██████████| 36/36 [00:11<00:00,  3.07it/s]
