In [1]:
!pip install medmnist

Collecting medmnist
  Downloading medmnist-3.0.2-py3-none-any.whl.metadata (14 kB)
Collecting fire (from medmnist)
  Downloading fire-0.7.0.tar.gz (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->medmnist)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->medmnist)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->medmnist)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->medmnist)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting n

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from medmnist import ChestMNIST
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import os

In [3]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# Load MedMNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [5]:
from medmnist import ChestMNIST

# Load MedMNIST Dataset without downloading
train_dataset = ChestMNIST(split="train", transform=transform, download=True) # Set download=True to download the dataset
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

100%|██████████| 82.8M/82.8M [00:07<00:00, 11.0MB/s]


In [6]:
# Generator Model
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Tanh()
        )

    def forward(self, z):
        return self.net(z).view(-1, 1, 28, 28)

# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))

In [7]:
# WGAN-GP Gradient Penalty
def gradient_penalty(D, real_data, fake_data):
    alpha = torch.rand(real_data.size(0), 1, 1, 1).to(device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    d_interpolates = D(interpolates)
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates, inputs=interpolates, grad_outputs=grad_outputs,
        create_graph=True, retain_graph=True)[0]
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [8]:
# Training Function
def train_gan(gan_type, num_epochs=50):
    writer = SummaryWriter(f"runs/{gan_type}")

    z_dim = 100
    generator = Generator(z_dim).to(device)
    discriminator = Discriminator().to(device)

    optim_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optim_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(num_epochs):
        for real, _ in tqdm(train_loader):
            real = real.to(device)

            # Generate fake images
            z = torch.randn(real.size(0), z_dim).to(device)
            fake = generator(z)

            # Discriminator update
            optim_D.zero_grad()
            real_loss, fake_loss = 0, 0

            if gan_type == "LS-GAN":
                real_loss = 0.5 * ((discriminator(real) - 1) ** 2).mean()
                fake_loss = 0.5 * (discriminator(fake) ** 2).mean()
            elif gan_type == "WGAN":
                real_loss = -discriminator(real).mean()
                fake_loss = discriminator(fake).mean()
            elif gan_type == "WGAN-GP":
                real_loss = -discriminator(real).mean()
                fake_loss = discriminator(fake).mean()
                gp = gradient_penalty(discriminator, real, fake)
                loss_D = real_loss + fake_loss + 10 * gp
            else:
                raise ValueError("Invalid GAN type")

            loss_D = real_loss + fake_loss
            loss_D.backward()
            optim_D.step()

            # Generator update
            if epoch % 5 == 0:
                optim_G.zero_grad()
                fake = generator(z)
                loss_G = -discriminator(fake).mean() if gan_type in ["WGAN", "WGAN-GP"] else ((discriminator(fake) - 1) ** 2).mean()
                loss_G.backward()
                optim_G.step()

                # TensorBoard Logging
                writer.add_scalar("Loss/Discriminator", loss_D.item(), epoch)
                writer.add_scalar("Loss/Generator", loss_G.item(), epoch)

        # Save generated images
        vutils.save_image(fake[:25], f"generated/{gan_type}_epoch_{epoch}.png", normalize=True)

    torch.save(generator.state_dict(), f"models/{gan_type}_generator.pth")
    writer.close()


In [9]:
# Ensure directories exist
os.makedirs("generated", exist_ok=True)
os.makedirs("models", exist_ok=True)

In [10]:
# Train all three GANs
for gan in ["LS-GAN", "WGAN", "WGAN-GP"]:
    train_gan(gan)

100%|██████████| 1227/1227 [00:23<00:00, 51.45it/s]
100%|██████████| 1227/1227 [00:22<00:00, 54.06it/s]
100%|██████████| 1227/1227 [00:18<00:00, 65.62it/s]
100%|██████████| 1227/1227 [00:19<00:00, 64.32it/s]
100%|██████████| 1227/1227 [00:18<00:00, 67.30it/s]
100%|██████████| 1227/1227 [00:22<00:00, 54.12it/s]
100%|██████████| 1227/1227 [00:18<00:00, 66.60it/s]
100%|██████████| 1227/1227 [00:19<00:00, 64.15it/s]
100%|██████████| 1227/1227 [00:19<00:00, 62.67it/s]
100%|██████████| 1227/1227 [00:19<00:00, 63.68it/s]
100%|██████████| 1227/1227 [00:22<00:00, 54.67it/s]
100%|██████████| 1227/1227 [00:18<00:00, 66.07it/s]
100%|██████████| 1227/1227 [00:18<00:00, 67.18it/s]
100%|██████████| 1227/1227 [00:19<00:00, 63.43it/s]
100%|██████████| 1227/1227 [00:18<00:00, 67.59it/s]
100%|██████████| 1227/1227 [00:22<00:00, 54.51it/s]
100%|██████████| 1227/1227 [00:18<00:00, 67.12it/s]
100%|██████████| 1227/1227 [00:19<00:00, 63.80it/s]
100%|██████████| 1227/1227 [00:18<00:00, 67.84it/s]
100%|███████

# **Evaluation**

In [14]:
import torch
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torchvision.models.inception import inception_v3
from scipy.linalg import sqrtm
from torch.nn.functional import softmax
from medmnist import ChestMNIST
import os

In [19]:
def calculate_inception_score(images, inception_model, batch_size=32):
    """Computes Inception Score (IS)."""
    images = images.to("cuda")
    inception_model.eval()
    with torch.no_grad():
        preds = softmax(inception_model(images), dim=1).cpu().numpy()

    scores = []
    for i in range(0, len(preds), batch_size):
        part = preds[i:i+batch_size]
        py = np.mean(part, axis=0)
        kl_div = part * (np.log(part) - np.log(py))
        scores.append(np.exp(np.mean(np.sum(kl_div, axis=1))))
    return np.mean(scores), np.std(scores)

def calculate_fid(real_images, fake_images, inception_model):
    """Computes Fréchet Inception Distance (FID)."""
    inception_model.eval()
    with torch.no_grad():
        real_features = inception_model(real_images).cpu().numpy()
        fake_features = inception_model(fake_images).cpu().numpy()

    mu_real, sigma_real = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu_fake, sigma_fake = fake_features.mean(axis=0), np.cov(fake_features, rowvar=False)

    diff = mu_real - mu_fake
    cov_mean, _ = sqrtm(sigma_real @ sigma_fake, disp=False)
    if np.iscomplexobj(cov_mean):
        cov_mean = cov_mean.real

    fid_score = np.sum(diff**2) + np.trace(sigma_real + sigma_fake - 2 * cov_mean)
    return fid_score

def evaluate_gan(gan_type, generator, data_loader, z_dim=100):
    """Evaluates the trained GAN model."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    writer = SummaryWriter(f"runs/{gan_type}_evaluation")
    generator.to(device).eval()
    inception_model = inception_v3(pretrained=True, transform_input=False).to(device)

    real_images, _ = next(iter(data_loader))
    real_images = real_images.to(device)

    with torch.no_grad():
        z = torch.randn(real_images.size(0), z_dim, device=device)
        fake_images = generator(z)

    # Compute IS and FID
    inception_score, is_std = calculate_inception_score(fake_images, inception_model)
    fid_score = calculate_fid(real_images, fake_images, inception_model)

    # Log metrics
    writer.add_scalar("Metrics/Inception Score", inception_score, 0)
    writer.add_scalar("Metrics/FID", fid_score, 0)
    vutils.save_image(fake_images[:25], f"generated/{gan_type}_evaluation.png", normalize=True)
    writer.add_image("Generated Images", vutils.make_grid(fake_images[:25], normalize=True), 0)
    writer.close()

    print(f"{gan_type} - IS: {inception_score:.3f} ± {is_std:.3f}, FID: {fid_score:.3f}")

import torch
import torchvision.transforms as transforms
from medmnist import ChestMNIST
import os

# Convert grayscale to 3-channel by repeating across 3 channels
class GrayscaleToRGB:
    def __call__(self, img):
        return img.repeat(3, 1, 1)  # Repeat the grayscale channel 3 times

def main():
    os.makedirs("generated", exist_ok=True)

    dataset = ChestMNIST(root="./data", split="test", download=True, transform=transforms.Compose([
        transforms.ToTensor(),
        GrayscaleToRGB(),  # Convert grayscale to RGB by repeating channels
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for 3 channels
    ]))

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    z_dim = 100
    for gan_type in ["LS-GAN", "WGAN", "WGAN-GP"]:
        generator = Generator(z_dim)
        generator.load_state_dict(torch.load(f"models/{gan_type}_generator.pth"))
        evaluate_gan(gan_type, generator, data_loader, z_dim)

In [20]:
if __name__ == "__main__":
    main()

RuntimeError: Failed to setup the default `root` directory. Please specify and create the `root` directory manually.

In [None]:
import torch
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance
from medmnist import ChestMNIST
import torchvision.transforms as transforms

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load real images without downloading
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
real_dataset = ChestMNIST(split="test", transform=transform, download=True) # Set download=True to download the dataset
real_loader = DataLoader(real_dataset, batch_size=64, shuffle=True)

In [None]:
import torch
# Import from their specific submodules
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance

In [None]:
# Compute Metrics
def compute_metrics():
    inception = InceptionScore(feature=2048).to(device)  # Specify feature=2048
    fid = FrechetInceptionDistance(feature=2048).to(device)

    # Get real images
    real_images = next(iter(real_loader))[0].to(device)
    real_images = real_images.float()  # Ensure float before scaling
    real_images = (real_images * 255).byte()  # Convert to uint8
    real_images = real_images.repeat(1, 3, 1, 1)  # Convert grayscale to RGB if needed
    fid.update(real_images, real=True)

    for gan in ["LS-GAN", "WGAN", "WGAN-GP"]:
        generator = Generator().to(device)  # Ensure Generator is defined
        generator.load_state_dict(torch.load(f"models/{gan}_generator.pth"))
        generator.eval()

        # Generate fake images
        latent_dim = 100  # Adjust this if needed
        fake_images = torch.cat([generator(torch.randn(64, latent_dim, device=device)) for _ in range(10)], dim=0)

        fake_images = fake_images.float()  # Ensure float before scaling
        fake_images = (fake_images + 1) / 2  # Rescale to [0,1]
        fake_images = (fake_images * 255).byte()
        fake_images = fake_images.repeat(1, 3, 1, 1)

        inception.update(fake_images)
        fid.update(fake_images, real=False)

        score, _ = inception.compute()
        fid_value = fid.compute()

        print(f"{gan} - Inception Score: {score.item()}, FID: {fid_value.item()}")


In [None]:
compute_metrics()



ModuleNotFoundError: InceptionScore metric requires that `Torch-fidelity` is installed. Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`.

In [None]:
pip install torchmetrics[image]



In [None]:
pip install torch-fidelity



In [None]:
compute_metrics()

ModuleNotFoundError: InceptionScore metric requires that `Torch-fidelity` is installed. Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`.

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance

# Remove the InceptionScore-related parts from compute_metrics()
def compute_metrics():
    fid = FrechetInceptionDistance().to(device)

    # Get real images
    real_images = next(iter(real_loader))[0].to(device)
    real_images = (real_images * 255).byte()  # Convert to uint8
    real_images = real_images.repeat(1, 3, 1, 1)  # Convert grayscale to RGB if needed
    fid.update(real_images, real=True)

    for gan in ["LS-GAN", "WGAN", "WGAN-GP"]:
        generator = Generator().to(device)
        generator.load_state_dict(torch.load(f"models/{gan}_generator.pth"))
        generator.eval()

        # Generate fake images
        fake_images = torch.cat([generator(torch.randn(64, 100).to(device)) for _ in range(10)], dim=0)
        fake_images = (fake_images + 1) / 2  # Rescale to [0,1]
        fake_images = (fake_images * 255).byte()
        fake_images = fake_images.repeat(1, 3, 1, 1)

        fid.update(fake_images, real=False)

        fid_value = fid.compute()
        print(f"{gan} - FID: {fid_value.item()}")

compute_metrics()


ModuleNotFoundError: FrechetInceptionDistance metric requires that `Torch-fidelity` is installed. Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`.

In [None]:
pip show torch-fidelity

Name: torch-fidelity
Version: 0.3.0
Summary: High-fidelity performance metrics for generative models in PyTorch
Home-page: https://www.github.com/toshas/torch-fidelity
Author: Anton Obukhov
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.11/dist-packages
Requires: numpy, Pillow, scipy, torch, torchvision, tqdm
Required-by: 


In [None]:
import torch_fidelity
print("Torch-fidelity is working correctly!")

Torch-fidelity is working correctly!
