### GAN

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import gc

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set environment variable to avoid memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Create directories
os.makedirs("MURA-v1.1/checkpoints", exist_ok=True)
os.makedirs("MURA-v1.1/synthetic", exist_ok=True)
os.makedirs("MURA-v1.1/synthetic/samples", exist_ok=True)

Using device: cuda


In [2]:
# Define initial transform
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load dataset
df_train_images = pd.read_csv("MURA-v1.1/train_image_paths.csv", header=None, names=["image_path"])
df_train_images["label"] = df_train_images["image_path"].apply(lambda x: 1 if "positive" in x else 0)
df_train_images["category"] = df_train_images["image_path"].apply(lambda x: x.split('/')[2])

# Verify dataset
print(f"Total images: {len(df_train_images)}")
print(df_train_images.head(10))
for path in df_train_images["image_path"].head():
    print(f"File {'exists' if os.path.exists(path) else 'not found'}: {path}")

# Compute category counts
categories = df_train_images["category"].unique()
category_counts = {cat: len(df_train_images[df_train_images["category"] == cat]) for cat in categories}
print("Category counts:", category_counts)
target_count = max(category_counts.values())
print(f"Target count for balancing: {target_count}")

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    dummy = torch.zeros(1, device=device)  # Warm-up
print(f"Using device: {device}")

Total images: 36808
                                          image_path  label     category
0  MURA-v1.1/train/XR_SHOULDER/patient00001/study...      1  XR_SHOULDER
1  MURA-v1.1/train/XR_SHOULDER/patient00001/study...      1  XR_SHOULDER
2  MURA-v1.1/train/XR_SHOULDER/patient00001/study...      1  XR_SHOULDER
3  MURA-v1.1/train/XR_SHOULDER/patient00002/study...      1  XR_SHOULDER
4  MURA-v1.1/train/XR_SHOULDER/patient00002/study...      1  XR_SHOULDER
5  MURA-v1.1/train/XR_SHOULDER/patient00002/study...      1  XR_SHOULDER
6  MURA-v1.1/train/XR_SHOULDER/patient00003/study...      1  XR_SHOULDER
7  MURA-v1.1/train/XR_SHOULDER/patient00003/study...      1  XR_SHOULDER
8  MURA-v1.1/train/XR_SHOULDER/patient00003/study...      1  XR_SHOULDER
9  MURA-v1.1/train/XR_SHOULDER/patient00004/study...      1  XR_SHOULDER
File exists: MURA-v1.1/train/XR_SHOULDER/patient00001/study1_positive/image1.png
File exists: MURA-v1.1/train/XR_SHOULDER/patient00001/study1_positive/image2.png
File exists: MU

  dummy = torch.zeros(1, device=device)  # Warm-up


In [3]:
# Compute category counts
categories = df_train_images["category"].unique()
category_counts = {cat: len(df_train_images[df_train_images["category"] == cat]) for cat in categories}
print("Category counts:", category_counts)

# Target count: maximum number of images in any category
target_count = max(category_counts.values())
print(f"Target count for balancing: {target_count}")

Category counts: {'XR_SHOULDER': 8379, 'XR_HUMERUS': 1272, 'XR_FINGER': 5106, 'XR_ELBOW': 4931, 'XR_WRIST': 9752, 'XR_FOREARM': 1825, 'XR_HAND': 5543}
Target count for balancing: 9752


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    dummy = torch.zeros(1, device=device)  # Warm-up
print(f"Using device: {device}")

Using device: cuda


In [5]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"Device name: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.5.1+cu121
CUDA available: True
CUDA version: 12.1
Device name: NVIDIA GeForce RTX 3060


In [6]:
class GeneratorBlock(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2.0, initial_block=False):
        super(GeneratorBlock, self).__init__()
        self.initial_block = initial_block
        if initial_block:
            self.conv = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, 4, 1, 0),
                nn.BatchNorm2d(out_channels),
                nn.PReLU()
            )
        else:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=scale_factor, mode='nearest'),
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.PReLU()
            )
        self.to_rgb = nn.Conv2d(out_channels, 1, 1)

    def forward(self, x):
        x = self.conv(x)
        rgb = self.to_rgb(x)
        return x, rgb

class DiscriminatorBlock(nn.Module):
    def __init__(self, in_channels, out_channels, initial_block=False):
        super(DiscriminatorBlock, self).__init__()
        self.initial_block = initial_block
        self.from_rgb = nn.Conv2d(1, in_channels, 1)
        if initial_block:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, 1, 0),
                nn.LeakyReLU(0.2, inplace=True)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True)
            )

    def forward(self, x, is_raw_input=False):
        if is_raw_input:
            x = self.from_rgb(x)
        x = self.conv(x)
        return x

class ProGANGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes, resolutions=[4, 8, 16, 32, 64, 112, 224]):
        super(ProGANGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        scale_factors = [resolutions[i+1] / resolutions[i] for i in range(len(resolutions)-1)]
        self.blocks = nn.ModuleList([
            GeneratorBlock(2 * latent_dim, 512, initial_block=True),  # 4x4
            GeneratorBlock(512, 512, scale_factor=scale_factors[0]),  # 8x8
            GeneratorBlock(512, 256, scale_factor=scale_factors[1]),  # 16x16
            GeneratorBlock(256, 128, scale_factor=scale_factors[2]),  # 32x32
            GeneratorBlock(128, 64, scale_factor=scale_factors[3]),   # 64x64
            GeneratorBlock(64, 32, scale_factor=scale_factors[4]),    # 112x112
            GeneratorBlock(32, 16, scale_factor=scale_factors[5]),    # 224x224
        ])
        self.current_depth = 0

    def forward(self, z, labels, alpha=1.0):
        label_embed = self.label_emb(labels)
        z = torch.cat([z, label_embed], dim=1)
        z = z.view(z.size(0), z.size(1), 1, 1)

        x = z
        prev_rgb = None
        for i in range(self.current_depth + 1):
            x, rgb = self.blocks[i](x)
            if i == self.current_depth and i > 0:
                prev_rgb_up = F.interpolate(prev_rgb, size=rgb.shape[2:], mode='nearest')
                rgb = (1 - alpha) * prev_rgb_up + alpha * rgb
            prev_rgb = rgb
        return rgb

class ProGANDiscriminator(nn.Module):
    def __init__(self, num_classes, latent_dim=100, resolutions=[4, 8, 16, 32, 64, 112, 224]):
        super(ProGANDiscriminator, self).__init__()
        self.num_classes = num_classes
        self.latent_dim = latent_dim
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.blocks = nn.ModuleList([
            DiscriminatorBlock(16, 32),    # 224x224 -> 112x112
            DiscriminatorBlock(32, 64),    # 112x112 -> 56x56
            DiscriminatorBlock(64, 128),   # 56x56 -> 28x28
            DiscriminatorBlock(128, 256),  # 28x28 -> 14x14
            DiscriminatorBlock(256, 512),  # 14x14 -> 7x7
            DiscriminatorBlock(512, 512),  # 8x8 -> 4x4
            DiscriminatorBlock(512, 512, initial_block=True),  # 4x4 -> 1x1
        ])
        self.final_layer = nn.Linear(512 + latent_dim, 1)
        self.current_depth = 0
        self.resolutions = resolutions

    def forward(self, img, labels, alpha=1.0, current_res=None):
        if current_res is None:
            raise ValueError("current_res must be provided")

        x = img
        if self.current_depth > 0:
            current_block_idx = len(self.blocks) - self.current_depth - 1
            if alpha < 1.0:
                new_x = self.blocks[current_block_idx](x, is_raw_input=True)
                downsampled_x = F.interpolate(x, scale_factor=0.5, mode='nearest')
                old_x = self.blocks[current_block_idx + 1].from_rgb(downsampled_x)
                x = alpha * new_x + (1 - alpha) * old_x
            else:
                x = self.blocks[current_block_idx](x, is_raw_input=True)
        else:
            x = self.blocks[-1](x, is_raw_input=True)

        for i in range(current_block_idx + 1, len(self.blocks)):
            x = self.blocks[i](x, is_raw_input=False)

        x = x.view(x.size(0), -1)
        label_embed = self.label_emb(labels).view(labels.size(0), -1)
        x = torch.cat([x, label_embed], dim=1)
        x = self.final_layer(x)
        return x

class MURADataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert('L')
        if self.transform:
            image = self.transform(image)
        return image, label

def compute_gradient_penalty(discriminator, real_samples, fake_samples, labels, device, current_res):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = discriminator(interpolates, labels, alpha=1.0, current_res=current_res)
    fake = torch.ones(real_samples.size(0), 1, device=device, requires_grad=False)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [7]:
def train_progan(generator, discriminator, dataloader, num_epochs_per_resolution, latent_dim, device, 
                 save_interval=10, save_dir="checkpoints", minority_label=1, category="unknown", start_depth=0):
    # Dynamically set hyperparameters based on depth
    resolutions = [4, 8, 16, 32, 64, 112, 224]
    
    for depth in range(start_depth, len(resolutions)):
        generator.current_depth = depth
        discriminator.current_depth = depth
        resolution = resolutions[depth]
        print(f"\nTraining at resolution {resolution}x{resolution}")

        # Set learning rate and other hyperparameters based on depth
        if depth < 3:  # 4x4, 8x8, 16x16
            lr = 0.0001
            fade_in_steps = 50
            n_critic = 5
        else:  # 32x32, 64x64, 112x112, 224x224
            lr = 0.00005
            fade_in_steps = 100
            n_critic = 10

        g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0, 0.9))
        d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0, 0.9))
        lambda_gp = 10

        current_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        dataset = MURADataset(dataloader.dataset.image_paths, dataloader.dataset.labels, transform=current_transform)
        current_dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0, pin_memory=True)
        print(f"Dataset size: {len(dataset)}, Batches: {len(current_dataloader)}")

        # Fade-in phase with dynamically adjusted steps
        for alpha_step, alpha in enumerate(np.linspace(0, 1, fade_in_steps)):
            alpha = float(alpha)
            print(f"Fade-in phase, alpha={alpha:.4f}")
            running_d_loss = 0.0
            running_g_loss = 0.0
            for i, (real_imgs, labels) in enumerate(tqdm(current_dataloader, desc=f"Alpha {alpha:.4f}")):
                real_imgs = real_imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                batch_size = real_imgs.size(0)

                # Discriminator update
                for _ in range(n_critic):
                    d_optimizer.zero_grad()
                    z = torch.randn(batch_size, latent_dim, device=device)
                    fake_labels = torch.full((batch_size,), minority_label, dtype=torch.long, device=device)
                    fake_imgs = generator(z, fake_labels, alpha=alpha)

                    real_validity = discriminator(real_imgs, labels, alpha=alpha, current_res=resolution)
                    fake_validity = discriminator(fake_imgs.detach(), fake_labels, alpha=alpha, current_res=resolution)
                    gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs.detach(), 
                                                              labels, device, current_res=resolution)
                    d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
                    d_loss.backward()
                    d_optimizer.step()
                    running_d_loss += d_loss.item()

                # Generator update
                g_optimizer.zero_grad()
                fake_imgs = generator(z, fake_labels, alpha=alpha)
                fake_validity = discriminator(fake_imgs, fake_labels, alpha=alpha, current_res=resolution)
                g_loss = -torch.mean(fake_validity)
                g_loss.backward()
                g_optimizer.step()
                running_g_loss += g_loss.item()

            avg_d_loss = running_d_loss / (len(current_dataloader) * n_critic)
            avg_g_loss = running_g_loss / len(current_dataloader)
            print(f"Alpha {alpha:.4f}, Avg D Loss: {avg_d_loss:.4f}, Avg G Loss: {avg_g_loss:.4f}")

            # Save checkpoint
            if alpha_step % save_interval == 0:
                torch.save({
                    'depth': depth,
                    'alpha': alpha,
                    'generator_state_dict': generator.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'g_optimizer_state_dict': g_optimizer.state_dict(),
                    'd_optimizer_state_dict': d_optimizer.state_dict(),
                }, os.path.join(save_dir, f"progan_{category}_{'positive' if minority_label == 1 else 'negative'}_depth{depth}_alpha{alpha:.4f}.pt"))

        # Stabilization phase
        for epoch in range(num_epochs_per_resolution - fade_in_steps):
            running_d_loss = 0.0
            running_g_loss = 0.0
            for i, (real_imgs, labels) in enumerate(tqdm(current_dataloader, desc=f"Epoch {epoch+1}")):
                real_imgs = real_imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                batch_size = real_imgs.size(0)

                # Discriminator update
                for _ in range(n_critic):
                    d_optimizer.zero_grad()
                    z = torch.randn(batch_size, latent_dim, device=device)
                    fake_labels = torch.full((batch_size,), minority_label, dtype=torch.long, device=device)
                    fake_imgs = generator(z, fake_labels, alpha=1.0)

                    real_validity = discriminator(real_imgs, labels, alpha=1.0, current_res=resolution)
                    fake_validity = discriminator(fake_imgs.detach(), fake_labels, alpha=1.0, current_res=resolution)
                    gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs.detach(), 
                                                              labels, device, current_res=resolution)
                    d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
                    d_loss.backward()
                    d_optimizer.step()
                    running_d_loss += d_loss.item()

                # Generator update
                g_optimizer.zero_grad()
                fake_imgs = generator(z, fake_labels, alpha=1.0)
                fake_validity = discriminator(fake_imgs, fake_labels, alpha=1.0, current_res=resolution)
                g_loss = -torch.mean(fake_validity)
                g_loss.backward()
                g_optimizer.step()
                running_g_loss += g_loss.item()

            avg_d_loss = running_d_loss / (len(current_dataloader) * n_critic)
            avg_g_loss = running_g_loss / len(current_dataloader)
            print(f"Depth {depth}, Epoch {epoch+1}/{num_epochs_per_resolution - fade_in_steps}, Avg D Loss: {avg_d_loss:.4f}, Avg G Loss: {avg_g_loss:.4f}")

            if (epoch + 1) % save_interval == 0:
                with torch.no_grad():
                    sample_z = torch.randn(5, latent_dim, device=device)
                    sample_labels = torch.full((5,), minority_label, dtype=torch.long, device=device)
                    sample_imgs = generator(sample_z, sample_labels, alpha=1.0)
                    sample_imgs = (sample_imgs * 0.5 + 0.5) * 255
                    sample_imgs = sample_imgs.cpu().numpy().astype(np.uint8)
                    for j in range(5):
                        Image.fromarray(sample_imgs[j, 0], mode='L').save(
                            f"MURA-v1.1/synthetic/samples/{category}_{'positive' if minority_label == 1 else 'negative'}_depth{depth}_epoch{epoch+1}_{j}.png"
                        )

        torch.save({
            'depth': depth,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
        }, os.path.join(save_dir, f"progan_{category}_{'positive' if minority_label == 1 else 'negative'}_depth{depth}.pt"))

In [8]:
# Hyperparameters
latent_dim = 100
num_epochs_per_resolution = 100  # Increased for better convergence
batch_size = 16
save_interval = 10
checkpoint_dir = "MURA-v1.1/checkpoints"
generation_batch_size = 64

# Dictionary to store synthetic image paths
synthetic_data = []

# For each category, balance to target count with equal positive/negative
for category in categories:
    print(f"\nBalancing body part {category} to target count {target_count}...")
    
    category_data = df_train_images[df_train_images["category"] == category]
    positive_count = len(category_data[category_data["label"] == 1])
    negative_count = len(category_data[category_data["label"] == 0])
    print(f"Current Positive: {positive_count}, Negative: {negative_count}")

    target_positive = target_negative = target_count // 2
    current_total = positive_count + negative_count

    if current_total >= target_count:
        if positive_count > target_positive:
            num_synthetic_positive = 0
            num_synthetic_negative = target_negative - negative_count
        elif negative_count > target_negative:
            num_synthetic_positive = target_positive - positive_count
            num_synthetic_negative = 0
        else:
            num_synthetic_positive = target_positive - positive_count
            num_synthetic_negative = target_negative - negative_count
    else:
        num_synthetic_positive = target_positive - positive_count
        num_synthetic_negative = target_negative - negative_count

    for minority_label, num_synthetic in [(1, num_synthetic_positive), (0, num_synthetic_negative)]:
        if num_synthetic > 0:
            print(f"Generating {num_synthetic} synthetic {'positive' if minority_label == 1 else 'negative'} images for {category}...")
            minority_images = category_data[category_data["label"] == minority_label]["image_path"].tolist()
            minority_labels = [minority_label] * len(minority_images)
            progan_dataset = MURADataset(minority_images, minority_labels, transform=transform)
            progan_loader = DataLoader(progan_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
        
            generator = ProGANGenerator(latent_dim, num_classes=2).to(device)
            discriminator = ProGANDiscriminator(num_classes=2, latent_dim=latent_dim).to(device)

            # Define optimizers before loading checkpoint
            import torch.optim as optim
            g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
            d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

            # Dynamic checkpoint path
            checkpoint_path = r"D:\Sem 6 project\MURA-v1.1\checkpoints\progan_XR_SHOULDER_positive_depth2.pt"
            start_depth = 0
            if os.path.exists(checkpoint_path):
                checkpoint = torch.load(checkpoint_path, weights_only=False)  # Adjust weights_only as needed
                generator.load_state_dict(checkpoint['generator_state_dict'])
                discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
                g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
                d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
                start_depth = checkpoint['depth'] + 1  # Resume from next depth
                print(f"Resuming training for {category} ({'positive' if minority_label == 1 else 'negative'}) from depth {start_depth}")
            else:
                print(f"No checkpoint found for {category} ({'positive' if minority_label == 1 else 'negative'}), starting from scratch.")

            if start_depth < len([4, 8, 16, 32, 64, 112, 224]):
                train_progan(generator, discriminator, progan_loader, num_epochs_per_resolution, latent_dim, device,
                             save_interval=save_interval, save_dir=checkpoint_dir, 
                             minority_label=minority_label, category=category, start_depth=start_depth)

            synthetic_images = []
            synthetic_labels = [minority_label] * num_synthetic
            for i in range(0, num_synthetic, generation_batch_size):
                batch_size = min(generation_batch_size, num_synthetic - i)
                batch_labels = torch.tensor(synthetic_labels[i:i + batch_size], dtype=torch.long, device=device)
                with torch.no_grad():
                    z = torch.randn(batch_size, latent_dim, device=device)
                    batch_images = generator(z, batch_labels, alpha=1.0)
                    synthetic_images.append(batch_images.cpu())
                torch.cuda.empty_cache()
                gc.collect()

            synthetic_images = torch.cat(synthetic_images, dim=0)
            synthetic_dir = f"MURA-v1.1/synthetic/{category}/{'positive' if minority_label == 1 else 'negative'}"
            os.makedirs(synthetic_dir, exist_ok=True)
            for i, img in enumerate(synthetic_images):
                img = img.squeeze().numpy()
                img = (img * 0.5) + 0.5
                img = (img * 255).astype(np.uint8)
                img_path = os.path.join(synthetic_dir, f"synthetic_{i}.png")
                Image.fromarray(img, mode='L').save(img_path)
                synthetic_data.append([img_path, minority_label, category])

            del generator, discriminator, progan_dataset, progan_loader, synthetic_images
            torch.cuda.empty_cache()
            gc.collect()

print("Positive/Negative balancing complete!")


Balancing body part XR_SHOULDER to target count 9752...
Current Positive: 4168, Negative: 4211
Generating 708 synthetic positive images for XR_SHOULDER...
Resuming training for XR_SHOULDER (positive) from depth 3

Training at resolution 32x32
Dataset size: 4168, Batches: 261
Fade-in phase, alpha=0.0000


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Alpha 0.0000: 100%|██████████████████████████████████████████████████████████████████| 261/261 [02:17<00:00,  1.90it/s]


Alpha 0.0000, Avg D Loss: -23.8654, Avg G Loss: 2406.3658
Fade-in phase, alpha=0.0101


Alpha 0.0101: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.0101, Avg D Loss: -497.8861, Avg G Loss: -505.6652
Fade-in phase, alpha=0.0202


Alpha 0.0202: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.0202, Avg D Loss: -1194.9336, Avg G Loss: -704.5119
Fade-in phase, alpha=0.0303


Alpha 0.0303: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.0303, Avg D Loss: -2340.2898, Avg G Loss: 12330.0609
Fade-in phase, alpha=0.0404


Alpha 0.0404: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.0404, Avg D Loss: -2799.7801, Avg G Loss: 30395.2043
Fade-in phase, alpha=0.0505


Alpha 0.0505: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:21<00:00,  3.22it/s]


Alpha 0.0505, Avg D Loss: -1766.2016, Avg G Loss: 9989.8116
Fade-in phase, alpha=0.0606


Alpha 0.0606: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.0606, Avg D Loss: -538.7872, Avg G Loss: -11.0271
Fade-in phase, alpha=0.0707


Alpha 0.0707: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.0707, Avg D Loss: -6.1127, Avg G Loss: -13186.0964
Fade-in phase, alpha=0.0808


Alpha 0.0808: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.0808, Avg D Loss: 390.0960, Avg G Loss: -31472.5776
Fade-in phase, alpha=0.0909


Alpha 0.0909: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.0909, Avg D Loss: 1707.0580, Avg G Loss: -5160.1198
Fade-in phase, alpha=0.1010


Alpha 0.1010: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.1010, Avg D Loss: 722.7686, Avg G Loss: 4039.7216
Fade-in phase, alpha=0.1111


Alpha 0.1111: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.1111, Avg D Loss: -382.0679, Avg G Loss: -9129.7762
Fade-in phase, alpha=0.1212


Alpha 0.1212: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.1212, Avg D Loss: -1206.5261, Avg G Loss: -7346.7566
Fade-in phase, alpha=0.1313


Alpha 0.1313: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.1313, Avg D Loss: -1415.0479, Avg G Loss: -7995.2602
Fade-in phase, alpha=0.1414


Alpha 0.1414: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.1414, Avg D Loss: -758.1689, Avg G Loss: -19890.4647
Fade-in phase, alpha=0.1515


Alpha 0.1515: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.1515, Avg D Loss: -679.8765, Avg G Loss: 10839.6419
Fade-in phase, alpha=0.1616


Alpha 0.1616: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.1616, Avg D Loss: 234.0592, Avg G Loss: 1946.7881
Fade-in phase, alpha=0.1717


Alpha 0.1717: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.1717, Avg D Loss: 150.2831, Avg G Loss: -4568.1535
Fade-in phase, alpha=0.1818


Alpha 0.1818: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.1818, Avg D Loss: 489.9748, Avg G Loss: -3117.2110
Fade-in phase, alpha=0.1919


Alpha 0.1919: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.1919, Avg D Loss: -1304.6859, Avg G Loss: -11012.7098
Fade-in phase, alpha=0.2020


Alpha 0.2020: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.26it/s]


Alpha 0.2020, Avg D Loss: -1345.0904, Avg G Loss: -9960.9634
Fade-in phase, alpha=0.2121


Alpha 0.2121: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:19<00:00,  3.27it/s]


Alpha 0.2121, Avg D Loss: -955.0456, Avg G Loss: -5005.1374
Fade-in phase, alpha=0.2222


Alpha 0.2222: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.26it/s]


Alpha 0.2222, Avg D Loss: -682.5065, Avg G Loss: -10251.2588
Fade-in phase, alpha=0.2323


Alpha 0.2323: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.2323, Avg D Loss: -1957.4077, Avg G Loss: -21634.8861
Fade-in phase, alpha=0.2424


Alpha 0.2424: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.2424, Avg D Loss: -2410.9111, Avg G Loss: -10744.2514
Fade-in phase, alpha=0.2525


Alpha 0.2525: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:19<00:00,  3.26it/s]


Alpha 0.2525, Avg D Loss: -1861.8455, Avg G Loss: -10889.1515
Fade-in phase, alpha=0.2626


Alpha 0.2626: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.26it/s]


Alpha 0.2626, Avg D Loss: -3256.0978, Avg G Loss: -4507.7050
Fade-in phase, alpha=0.2727


Alpha 0.2727: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.2727, Avg D Loss: -2667.3603, Avg G Loss: 17261.9497
Fade-in phase, alpha=0.2828


Alpha 0.2828: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:21<00:00,  3.21it/s]


Alpha 0.2828, Avg D Loss: -4288.8154, Avg G Loss: 7652.4371
Fade-in phase, alpha=0.2929


Alpha 0.2929: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.26it/s]


Alpha 0.2929, Avg D Loss: -5028.3060, Avg G Loss: 14400.0480
Fade-in phase, alpha=0.3030


Alpha 0.3030: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.26it/s]


Alpha 0.3030, Avg D Loss: -5354.9326, Avg G Loss: 21756.7783
Fade-in phase, alpha=0.3131


Alpha 0.3131: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.3131, Avg D Loss: -5616.0365, Avg G Loss: 18194.1960
Fade-in phase, alpha=0.3232


Alpha 0.3232: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.3232, Avg D Loss: -6060.8776, Avg G Loss: 14051.7336
Fade-in phase, alpha=0.3333


Alpha 0.3333: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.3333, Avg D Loss: -5709.3856, Avg G Loss: 17262.8763
Fade-in phase, alpha=0.3434


Alpha 0.3434: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.26it/s]


Alpha 0.3434, Avg D Loss: -6966.4034, Avg G Loss: 32082.0866
Fade-in phase, alpha=0.3535


Alpha 0.3535: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.26it/s]


Alpha 0.3535, Avg D Loss: -9229.4275, Avg G Loss: 27457.0577
Fade-in phase, alpha=0.3636


Alpha 0.3636: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.3636, Avg D Loss: -9640.6923, Avg G Loss: 13017.6106
Fade-in phase, alpha=0.3737


Alpha 0.3737: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.3737, Avg D Loss: -3621.4622, Avg G Loss: -3018.0652
Fade-in phase, alpha=0.3838


Alpha 0.3838: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.3838, Avg D Loss: -3326.8937, Avg G Loss: 17476.9214
Fade-in phase, alpha=0.3939


Alpha 0.3939: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.26it/s]


Alpha 0.3939, Avg D Loss: -3399.7061, Avg G Loss: 28414.9603
Fade-in phase, alpha=0.4040


Alpha 0.4040: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.26it/s]


Alpha 0.4040, Avg D Loss: -3705.7930, Avg G Loss: 30531.2496
Fade-in phase, alpha=0.4141


Alpha 0.4141: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.4141, Avg D Loss: -1003.8509, Avg G Loss: 18297.4447
Fade-in phase, alpha=0.4242


Alpha 0.4242: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.25it/s]


Alpha 0.4242, Avg D Loss: -3270.5927, Avg G Loss: 50443.2188
Fade-in phase, alpha=0.4343


Alpha 0.4343: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.4343, Avg D Loss: -753.3604, Avg G Loss: 31942.3985
Fade-in phase, alpha=0.4444


Alpha 0.4444: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.24it/s]


Alpha 0.4444, Avg D Loss: 6774.6008, Avg G Loss: 26958.5118
Fade-in phase, alpha=0.4545


Alpha 0.4545: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:21<00:00,  3.22it/s]


Alpha 0.4545, Avg D Loss: 5011.7360, Avg G Loss: 60198.1558
Fade-in phase, alpha=0.4646


Alpha 0.4646: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.4646, Avg D Loss: 3355.9344, Avg G Loss: 50288.2621
Fade-in phase, alpha=0.4747


Alpha 0.4747: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.23it/s]


Alpha 0.4747, Avg D Loss: 665.7271, Avg G Loss: 50348.8392
Fade-in phase, alpha=0.4848


Alpha 0.4848: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:20<00:00,  3.22it/s]


Alpha 0.4848, Avg D Loss: 4587.1067, Avg G Loss: 38669.4820
Fade-in phase, alpha=0.4949


Alpha 0.4949: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.35it/s]


Alpha 0.4949, Avg D Loss: 4051.2152, Avg G Loss: 28107.5005
Fade-in phase, alpha=0.5051


Alpha 0.5051: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.37it/s]


Alpha 0.5051, Avg D Loss: 4661.7464, Avg G Loss: 25682.4784
Fade-in phase, alpha=0.5152


Alpha 0.5152: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.5152, Avg D Loss: 8902.9673, Avg G Loss: 16308.8151
Fade-in phase, alpha=0.5253


Alpha 0.5253: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.5253, Avg D Loss: 9105.5305, Avg G Loss: -1897.0887
Fade-in phase, alpha=0.5354


Alpha 0.5354: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.5354, Avg D Loss: 7013.1279, Avg G Loss: 15916.3803
Fade-in phase, alpha=0.5455


Alpha 0.5455: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.5455, Avg D Loss: 5091.0304, Avg G Loss: 28108.7663
Fade-in phase, alpha=0.5556


Alpha 0.5556: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.35it/s]


Alpha 0.5556, Avg D Loss: 5238.0666, Avg G Loss: -36111.7696
Fade-in phase, alpha=0.5657


Alpha 0.5657: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.5657, Avg D Loss: 7556.1629, Avg G Loss: 42336.4614
Fade-in phase, alpha=0.5758


Alpha 0.5758: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.38it/s]


Alpha 0.5758, Avg D Loss: 6333.7911, Avg G Loss: -1262.2097
Fade-in phase, alpha=0.5859


Alpha 0.5859: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.37it/s]


Alpha 0.5859, Avg D Loss: 6936.2875, Avg G Loss: 10941.0361
Fade-in phase, alpha=0.5960


Alpha 0.5960: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.35it/s]


Alpha 0.5960, Avg D Loss: 6518.6997, Avg G Loss: -21050.0833
Fade-in phase, alpha=0.6061


Alpha 0.6061: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.6061, Avg D Loss: 7371.2436, Avg G Loss: 19104.6760
Fade-in phase, alpha=0.6162


Alpha 0.6162: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.6162, Avg D Loss: 8953.5805, Avg G Loss: -73358.8367
Fade-in phase, alpha=0.6263


Alpha 0.6263: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.6263, Avg D Loss: 5312.9312, Avg G Loss: -1860.9376
Fade-in phase, alpha=0.6364


Alpha 0.6364: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.37it/s]


Alpha 0.6364, Avg D Loss: 6947.0841, Avg G Loss: -11713.0378
Fade-in phase, alpha=0.6465


Alpha 0.6465: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.6465, Avg D Loss: 8703.0777, Avg G Loss: 37790.2045
Fade-in phase, alpha=0.6566


Alpha 0.6566: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.6566, Avg D Loss: 6984.1266, Avg G Loss: 17587.5968
Fade-in phase, alpha=0.6667


Alpha 0.6667: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.6667, Avg D Loss: 8238.2782, Avg G Loss: -10361.0412
Fade-in phase, alpha=0.6768


Alpha 0.6768: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.37it/s]


Alpha 0.6768, Avg D Loss: 26159.2243, Avg G Loss: 1776.5195
Fade-in phase, alpha=0.6869


Alpha 0.6869: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.6869, Avg D Loss: 21185.3032, Avg G Loss: -66987.3914
Fade-in phase, alpha=0.6970


Alpha 0.6970: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.6970, Avg D Loss: 56216.4120, Avg G Loss: 100372.5667
Fade-in phase, alpha=0.7071


Alpha 0.7071: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.36it/s]


Alpha 0.7071, Avg D Loss: 19268.5636, Avg G Loss: 139255.4553
Fade-in phase, alpha=0.7172


Alpha 0.7172: 100%|██████████████████████████████████████████████████████████████████| 261/261 [01:17<00:00,  3.37it/s]


Alpha 0.7172, Avg D Loss: 17833.5435, Avg G Loss: 53773.2284
Fade-in phase, alpha=0.7273


Alpha 0.7273:  67%|█████████████████████████████████████████▌                    | 175/261 [2:22:38<1:10:06, 48.91s/it]


KeyboardInterrupt: 

In [None]:
# Target number of images per category (match the largest category)
target_count = max(category_counts.values())
print(f"Target number of images per category: {target_count}")

# For each category, balance the total number of images
for category in categories:
    print(f"\nBalancing body part {category}...")
    
    current_count = category_counts[category]
    num_synthetic_total = target_count - current_count

    if num_synthetic_cing needed for {category}")
        continue

    # Get current positive/negative ratio
    category_data = df_train_images[df_train_images["category"] == category]
    positive_count = len(category_data[category_data["label"] == 1])
    negative_count = len(category_data[category_dtotal <= 0:
        print(f"No balanata["label"] == 0])
    total = positive_count + negative_count
    positive_ratio = positive_count / total

    # Calculate how many positive/negative synthetic images to generate
    num_positive_synthetic = int(num_synthetic_total * positive_ratio)
    num_negative_synthetic = num_synthetic_total - num_positive_synthetic

    # Generate synthetic positive images
    if num_positive_synthetic > 0:
        positive_images = category_data[category_data["label"] == 1][0].tolist()
        positive_labels = [1] * len(positive_images)
        bagan_dataset = MURADataset(positive_images, positive_labels, transform=transform)
        bagan_loader = DataLoader(bagan_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

        generator = Generator(latent_dim, num_classes=2).to(device)
        discriminator = Discriminator(num_classes=2).to(device)
        train_bagan(generator, discriminator, bagan_loader, num_epochs, latent_dim, device)

        synthetic_positive = generate_synthetic_images(generator, num_positive_synthetic, [1] * num_positive_synthetic, latent_dim, device)
        synthetic_dir = f"MURA-v1.1/synthetic/{category}/positive"
        os.makedirs(synthetic_dir, exist_ok=True)
        for i, img in enumerate(synthetic_positive):
            img = img.cpu().permute(1, 2, 0).numpy()
            img = (img * 0.5) + 0.5
            img = (img * 255).astype(np.uint8)
            img_path = os.path.join(synthetic_dir, f"synthetic_{len(positive_images) + i}.png")
            Image.fromarray(img).save(img_path)
            synthetic_data.append([img_path, 1, category])

    # Generate synthetic negative images
    if num_negative_synthetic > 0:
        negative_images = category_data[category_data["label"] == 0][0].tolist()
        negative_labels = [0] * len(negative_images)
        bagan_dataset = MURADataset(negative_images, negative_labels, transform=transform)
        bagan_loader = DataLoader(bagan_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

        generator = Generator(latent_dim, num_classes=2).to(device)
        discriminator = Discriminator(num_classes=2).to(device)
        train_bagan(generator, discriminator, bagan_loader, num_epochs, latent_dim, device)

        synthetic_negative = generate_synthetic_images(generator, num_negative_synthetic, [0] * num_negative_synthetic, latent_dim, device)
        synthetic_dir = f"MURA-v1.1/synthetic/{category}/negative"
        os.makedirs(synthetic_dir, exist_ok=True)
        for i, img in enumerate(synthetic_negative):
            img = img.cpu().permute(1, 2, 0).numpy()
            img = (img * 0.5) + 0.5
            img = (img * 255).astype(np.uint8)
            img_path = os.path.join(synthetic_dir, f"synthetic_{len(negative_images) + i}.png")
            Image.fromarray(img).save(img_path)
            synthetic_data.append([img_path, 0, category])

print("Body part balancing complete!")

In [None]:
    # Generate synthetic images
    synthetic_labels = [minority_label] * num_synthetic
    synthetic_images = generate_synthetic_images(generator, num_synthetic, synthetic_labels, latent_dim, device)

    # Save synthetic images
    synthetic_dir = f"MURA-v1.1/synthetic/{category}/{'positive' if minority_label == 1 else 'negative'}"
    os.makedirs(synthetic_dir, exist_ok=True)
    for i, img in enumerate(synthetic_images):
        img = img.cpu().squeeze().numpy()  # Remove channel dimension for grayscale
        img = (img * 0.5) + 0.5  # Denormalize
        img = (img * 255).astype(np.uint8)
        img_path = os.path.join(synthetic_dir, f"synthetic_{i}.png")
        Image.fromarray(img, mode='L').save(img_path)  # Save as grayscale
        synthetic_data.append([img_path, minority_label, category])

print("Positive/Negative balancing complete!")

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm
import gc

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set environment variable to avoid memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Create directories for checkpoints and synthetic images
os.makedirs("MURA-v1.1/checkpoints", exist_ok=True)
os.makedirs("MURA-v1.1/synthetic", exist_ok=True)
os.makedirs("MURA-v1.1/synthetic/samples", exist_ok=True)

In [None]:
# Define transforms
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Scale to [-1, 1]
])

# Load the dataset (assuming train.csv is provided)
df_train_images = pd.read_csv("MURA-v1.1/train_image_paths.csv", header=None, names=["image_path"])
df_train_images["label"] = df_train_images["image_path"].apply(lambda x: 1 if "positive" in x else 0)
df_train_images["category"] = df_train_images["image_path"].apply(lambda x: x.split('/')[2])

# Verify the dataset
print(f"Total images: {len(df_train_images)}")
print(df_train_images.head(10))

# Verify file existence for the first few paths
for path in df_train_images["image_path"].head():
    if not os.path.exists(path):
        print(f"File not found: {path}")
    else:
        print(f"File exists: {path}")

# Compute category counts
categories = df_train_images["category"].unique()
category_counts = {cat: len(df_train_images[df_train_images["category"] == cat]) for cat in categories}
print("Category counts:", category_counts)

# Target count: maximum number of images in any category
target_count = max(category_counts.values())
print(f"Target count for balancing: {target_count}")

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    dummy = torch.zeros(1, device=device)  # Warm-up
print(f"Using device: {device}")

In [None]:
class GeneratorBlock(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2.0, initial_block=False):
        super(GeneratorBlock, self).__init__()
        self.initial_block = initial_block
        if initial_block:
            self.conv = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, 4, 1, 0),
                nn.BatchNorm2d(out_channels),
                nn.PReLU()
            )
        else:
            self.conv = nn.Sequential(
                nn.Upsample(scale_factor=scale_factor, mode='nearest'),
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.PReLU()
            )
        self.to_rgb = nn.Conv2d(out_channels, 1, 1)  # 1 channel for grayscale

    def forward(self, x):
        x = self.conv(x)
        rgb = self.to_rgb(x)
        return x, rgb

class DiscriminatorBlock(nn.Module):
    def __init__(self, in_channels, out_channels, initial_block=False):
        super(DiscriminatorBlock, self).__init__()
        self.initial_block = initial_block
        self.from_rgb = nn.Conv2d(1, in_channels, 1)
        if initial_block:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, 1, 0),
                nn.LeakyReLU(0.2, inplace=True)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True)
            )

    def forward(self, x, is_raw_input=False):
        if is_raw_input:
            x = self.from_rgb(x)
        x = self.conv(x)
        return x

class ProGANGenerator(nn.Module):
    def __init__(self, latent_dim, num_classes, resolutions=[4, 8, 16, 32, 64, 112, 224]):
        super(ProGANGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        scale_factors = [resolutions[i+1] / resolutions[i] for i in range(len(resolutions)-1)]
        self.blocks = nn.ModuleList([
            GeneratorBlock(2 * latent_dim, 512, initial_block=True),  # 4x4
            GeneratorBlock(512, 512, scale_factor=scale_factors[0]),  # 8x8
            GeneratorBlock(512, 256, scale_factor=scale_factors[1]),  # 16x16
            GeneratorBlock(256, 128, scale_factor=scale_factors[2]),  # 32x32
            GeneratorBlock(128, 64, scale_factor=scale_factors[3]),   # 64x64
            GeneratorBlock(64, 32, scale_factor=scale_factors[4]),    # 112x112
            GeneratorBlock(32, 16, scale_factor=scale_factors[5]),    # 224x224
        ])
        self.current_depth = 0

    def forward(self, z, labels, alpha=1.0):
        label_embed = self.label_emb(labels)
        z = torch.cat([z, label_embed], dim=1)
        z = z.view(z.size(0), z.size(1), 1, 1)

        x = z
        prev_rgb = None
        for i in range(self.current_depth + 1):
            x, rgb = self.blocks[i](x)
            if i == self.current_depth and i > 0:
                prev_rgb_up = nn.functional.interpolate(prev_rgb, size=rgb.shape[2:], mode='nearest')
                rgb = (1 - alpha) * prev_rgb_up + alpha * rgb
            prev_rgb = rgb

        return rgb

class ProGANDiscriminator(nn.Module):
    def __init__(self, num_classes, latent_dim=100, resolutions=[4, 8, 16, 32, 64, 112, 224]):
        super(ProGANDiscriminator, self).__init__()
        self.num_classes = num_classes
        self.latent_dim = latent_dim
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.blocks = nn.ModuleList([
            DiscriminatorBlock(16, 32),    # 224x224 -> 112x112
            DiscriminatorBlock(32, 64),    # 112x112 -> 56x56
            DiscriminatorBlock(64, 128),   # 56x56 -> 28x28
            DiscriminatorBlock(128, 256),  # 28x28 -> 14x14
            DiscriminatorBlock(256, 512),  # 14x14 -> 7x7
            DiscriminatorBlock(512, 512),  # 8x8 -> 4x4
            DiscriminatorBlock(512, 512, initial_block=True),  # 4x4 -> 1x1
        ])
        self.final_layer = nn.Linear(512 + latent_dim, 1)
        self.current_depth = 0
        self.resolutions = resolutions

    def forward(self, img, labels, alpha=1.0, current_res=None):
        if current_res is None:
            raise ValueError("current_res must be provided")

        x = img
        if self.current_depth > 0:
            current_block_idx = len(self.blocks) - self.current_depth - 1
            if alpha < 1.0:
                # Fade-in phase
                new_x = self.blocks[current_block_idx](x, is_raw_input=True)
                downsampled_x = F.interpolate(x, scale_factor=0.5, mode='nearest')
                old_x = self.blocks[current_block_idx + 1].from_rgb(downsampled_x)
                x = alpha * new_x + (1 - alpha) * old_x
            else:
                # Stabilization phase
                x = self.blocks[current_block_idx](x, is_raw_input=True)
        else:
            # Depth 0: Apply only the initial block
            x = self.blocks[-1](x, is_raw_input=True)

        # Apply remaining blocks
        for i in range(len(self.blocks) - self.current_depth, len(self.blocks)):
            x = self.blocks[i](x, is_raw_input=False)

        # Final processing
        x = x.view(x.size(0), -1)
        label_embed = self.label_emb(labels).view(labels.size(0), -1)
        x = torch.cat([x, label_embed], dim=1)
        x = self.final_layer(x)
        return x

class MURADataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert('L')
        if self.transform:
            image = self.transform(image)
        return image, label

In [None]:
def compute_gradient_penalty(discriminator, real_samples, fake_samples, labels, device, current_res):
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = discriminator(interpolates, labels, alpha=1.0, current_res=current_res)
    fake = torch.ones(real_samples.size(0), 1, device=device, requires_grad=False)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [None]:
def train_progan(generator, discriminator, dataloader, num_epochs_per_resolution, latent_dim, device, 
                 save_interval=10, save_dir="checkpoints", minority_label=1, category="unknown", start_depth=0):
    g_optimizer = optim.Adam(generator.parameters(), lr=0.002, betas=(0, 0.99))  # Reduced LR
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.002, betas=(0, 0.99))  # Reduced LR
    lambda_gp = 25
    n_critic = 5  # Number of discriminator updates per generator update
    resolutions = [4, 8, 16, 32, 64, 112, 224]

    for depth in range(start_depth, len(resolutions)):
        generator.current_depth = depth
        discriminator.current_depth = depth
        resolution = resolutions[depth]
        print(f"\nTraining at resolution {resolution}x{resolution}")

        current_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        dataset = MURADataset(dataloader.dataset.image_paths, dataloader.dataset.labels, transform=current_transform)
        current_dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)
        print(f"Dataset size: {len(dataset)}, Batches: {len(current_dataloader)}")

        # Fade-in phase
        for alpha_step, alpha in enumerate(np.linspace(0, 1, num_epochs_per_resolution // 2)):
            alpha = float(alpha)
            print(f"Fade-in phase, alpha={alpha:.4f}")
            running_d_loss = 0.0
            running_g_loss = 0.0
            for i, (real_imgs, labels) in enumerate(current_dataloader):
                real_imgs = real_imgs.to(device)
                labels = labels.to(device)
                batch_size = real_imgs.size(0)

                # Discriminator update
                for _ in range(n_critic):
                    d_optimizer.zero_grad()
                    z = torch.randn(batch_size, latent_dim, device=device)
                    fake_labels = torch.full((batch_size,), minority_label, dtype=torch.long, device=device)
                    fake_imgs = generator(z, fake_labels, alpha=alpha)

                    real_validity = discriminator(real_imgs, labels, alpha=alpha, current_res=resolution)
                    fake_validity = discriminator(fake_imgs.detach(), fake_labels, alpha=alpha, current_res=resolution)
                    gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs.detach(), 
                                                              labels, device, current_res=resolution)
                    d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
                    d_loss = torch.clamp(d_loss, -10, 10)  # Clip to prevent extreme values
                    d_loss.backward()
                    d_optimizer.step()
                    running_d_loss += d_loss.item()

                # Generator update
                g_optimizer.zero_grad()
                fake_imgs = generator(z, fake_labels, alpha=alpha)
                fake_validity = discriminator(fake_imgs, fake_labels, alpha=alpha, current_res=resolution)
                g_loss = -torch.mean(fake_validity)
                g_loss = torch.clamp(g_loss, -100, 100)  # Clip to prevent extreme values
                g_loss.backward()
                g_optimizer.step()
                running_g_loss += g_loss.item()

            avg_d_loss = running_d_loss / (len(current_dataloader) * n_critic)
            avg_g_loss = running_g_loss / len(current_dataloader)
            print(f"Alpha {alpha:.4f}, Avg D Loss: {avg_d_loss:.4f}, Avg G Loss: {avg_g_loss:.4f}")

        # Stabilization phase
        for epoch in range(num_epochs_per_resolution // 2):
            running_d_loss = 0.0
            running_g_loss = 0.0
            for i, (real_imgs, labels) in enumerate(current_dataloader):
                real_imgs = real_imgs.to(device)
                labels = labels.to(device)
                batch_size = real_imgs.size(0)

                # Discriminator update
                for _ in range(n_critic):
                    d_optimizer.zero_grad()
                    z = torch.randn(batch_size, latent_dim, device=device)
                    fake_labels = torch.full((batch_size,), minority_label, dtype=torch.long, device=device)
                    fake_imgs = generator(z, fake_labels, alpha=1.0)

                    real_validity = discriminator(real_imgs, labels, alpha=1.0, current_res=resolution)
                    fake_validity = discriminator(fake_imgs.detach(), fake_labels, alpha=1.0, current_res=resolution)
                    gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs.detach(), 
                                                              labels, device, current_res=resolution)
                    d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
                    d_loss = torch.clamp(d_loss, -10, 10)  # Clip to prevent extreme values
                    d_loss.backward()
                    d_optimizer.step()
                    running_d_loss += d_loss.item()

                # Generator update
                g_optimizer.zero_grad()
                fake_imgs = generator(z, fake_labels, alpha=1.0)
                fake_validity = discriminator(fake_imgs, fake_labels, alpha=1.0, current_res=resolution)
                g_loss = -torch.mean(fake_validity)
                g_loss = torch.clamp(g_loss, -100, 100)  # Clip to prevent extreme values
                g_loss.backward()
                g_optimizer.step()
                running_g_loss += g_loss.item()

            avg_d_loss = running_d_loss / (len(current_dataloader) * n_critic)
            avg_g_loss = running_g_loss / len(current_dataloader)
            print(f"Depth {depth}, Epoch {epoch+1}/{num_epochs_per_resolution//2}, Avg D Loss: {avg_d_loss:.4f}, Avg G Loss: {avg_g_loss:.4f}")

            if (epoch + 1) % save_interval == 0:
                with torch.no_grad():
                    sample_z = torch.randn(5, latent_dim, device=device)
                    sample_labels = torch.full((5,), minority_label, dtype=torch.long, device=device)
                    sample_imgs = generator(sample_z, sample_labels, alpha=1.0)
                    sample_imgs = (sample_imgs * 0.5 + 0.5) * 255
                    sample_imgs = sample_imgs.cpu().numpy().astype(np.uint8)
                    for j in range(5):
                        Image.fromarray(sample_imgs[j, 0], mode='L').save(
                            f"MURA-v1.1/synthetic/samples/{category}_{'positive' if minority_label == 1 else 'negative'}_depth{depth}_epoch{epoch+1}_{j}.png"
                        )

        torch.save({
            'depth': depth,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
        }, os.path.join(save_dir, f"progan_{category}_{'positive' if minority_label == 1 else 'negative'}_depth{depth}.pt"))

In [None]:
# Hyperparameters
latent_dim = 100
num_epochs_per_resolution = 100
batch_size = 16
save_interval = 10
checkpoint_dir = "MURA-v1.1/checkpoints"
generation_batch_size = 64

# Dictionary to store synthetic image paths
synthetic_data = []

# For each category, balance to target count with equal positive/negative
for category in categories:
    print(f"\nBalancing body part {category} to target count {target_count}...")
    
    category_data = df_train_images[df_train_images["category"] == category]
    positive_count = len(category_data[category_data["label"] == 1])
    negative_count = len(category_data[category_data["label"] == 0])
    print(f"Current Positive: {positive_count}, Negative: {negative_count}")

    target_positive = target_negative = target_count // 2
    current_total = positive_count + negative_count

    if current_total >= target_count:
        if positive_count > target_positive:
            num_synthetic_positive = 0
            num_synthetic_negative = target_negative - negative_count
        elif negative_count > target_negative:
            num_synthetic_positive = target_positive - positive_count
            num_synthetic_negative = 0
        else:
            num_synthetic_positive = target_positive - positive_count
            num_synthetic_negative = target_negative - negative_count
    else:
        num_synthetic_positive = target_positive - positive_count
        num_synthetic_negative = target_negative - negative_count

    if num_synthetic_positive > 0:
        print(f"Generating {num_synthetic_positive} synthetic positive images for {category}...")
        minority_label = 1
        minority_images = category_data[category_data["label"] == minority_label]["image_path"].tolist()
        minority_labels = [minority_label] * len(minority_images)
        progan_dataset = MURADataset(minority_images, minority_labels, transform=transform)
        progan_loader = DataLoader(progan_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

        generator = ProGANGenerator(latent_dim, num_classes=2).to(device)
        discriminator = ProGANDiscriminator(num_classes=2, latent_dim=latent_dim).to(device)
        g_optimizer = optim.Adam(generator.parameters(), lr=0.001, betas=(0, 0.99))
        d_optimizer = optim.Adam(discriminator.parameters(), lr=0.001, betas=(0, 0.99))

        checkpoint_path = os.path.join(checkpoint_dir, f"progan_{category}_positive_depth{len([4, 8, 16, 32, 64, 112, 224])-1}.pt")
        start_depth = 0
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
            g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
            d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
            start_depth = checkpoint['depth']
            print(f"Resuming training for {category} (positive) from depth {start_depth}")
        else:
            print(f"No checkpoint found for {category} (positive), starting from scratch.")

        if start_depth < len([4, 8, 16, 32, 64, 112, 224]):
            train_progan(generator, discriminator, progan_loader, num_epochs_per_resolution, latent_dim, device,
                         save_interval=save_interval, save_dir=checkpoint_dir, 
                         minority_label=minority_label, category=category, start_depth=start_depth)

        synthetic_images = []
        synthetic_labels = [minority_label] * num_synthetic_positive
        for i in range(0, num_synthetic_positive, generation_batch_size):
            batch_size = min(generation_batch_size, num_synthetic_positive - i)
            batch_labels = torch.tensor(synthetic_labels[i:i + batch_size], dtype=torch.long, device=device)
            with torch.no_grad():
                z = torch.randn(batch_size, latent_dim, device=device)
                batch_images = generator(z, batch_labels, alpha=1.0)
                synthetic_images.append(batch_images.cpu())
            torch.cuda.empty_cache()
            gc.collect()

        synthetic_images = torch.cat(synthetic_images, dim=0)
        synthetic_dir = f"MURA-v1.1/synthetic/{category}/positive"
        os.makedirs(synthetic_dir, exist_ok=True)
        for i, img in enumerate(synthetic_images):
            img = img.squeeze().numpy()
            img = (img * 0.5) + 0.5
            img = (img * 255).astype(np.uint8)
            img_path = os.path.join(synthetic_dir, f"synthetic_{i}.png")
            Image.fromarray(img, mode='L').save(img_path)
            synthetic_data.append([img_path, minority_label, category])

        del generator, discriminator, progan_dataset, progan_loader, synthetic_images
        torch.cuda.empty_cache()
        gc.collect()

    if num_synthetic_negative > 0:
        print(f"Generating {num_synthetic_negative} synthetic negative images for {category}...")
        minority_label = 0
        minority_images = category_data[category_data["label"] == minority_label]["image_path"].tolist()
        minority_labels = [minority_label] * len(minority_images)
        progan_dataset = MURADataset(minority_images, minority_labels, transform=transform)
        progan_loader = DataLoader(progan_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

        generator = ProGANGenerator(latent_dim, num_classes=2).to(device)
        discriminator = ProGANDiscriminator(num_classes=2, latent_dim=latent_dim).to(device)
        g_optimizer = optim.Adam(generator.parameters(), lr=0.001, betas=(0, 0.99))
        d_optimizer = optim.Adam(discriminator.parameters(), lr=0.001, betas=(0, 0.99))

        checkpoint_path = os.path.join(checkpoint_dir, f"progan_{category}_negative_depth{len([4, 8, 16, 32, 64, 112, 224])-1}.pt")
        start_depth = 0
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
            generator.load_state_dict(checkpoint['generator_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
            g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
            d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
            start_depth = checkpoint['depth']
            print(f"Resuming training for {category} (negative) from depth {start_depth}")
        else:
            print(f"No checkpoint found for {category} (negative), starting from scratch.")

        if start_depth < len([4, 8, 16, 32, 64, 112, 224]):
            train_progan(generator, discriminator, progan_loader, num_epochs_per_resolution, latent_dim, device,
                         save_interval=save_interval, save_dir=checkpoint_dir, 
                         minority_label=minority_label, category=category, start_depth=start_depth)

        synthetic_images = []
        synthetic_labels = [minority_label] * num_synthetic_negative
        for i in range(0, num_synthetic_negative, generation_batch_size):
            batch_size = min(generation_batch_size, num_synthetic_negative - i)
            batch_labels = torch.tensor(synthetic_labels[i:i + batch_size], dtype=torch.long, device=device)
            with torch.no_grad():
                z = torch.randn(batch_size, latent_dim, device=device)
                batch_images = generator(z, batch_labels, alpha=1.0)
                synthetic_images.append(batch_images.cpu())
            torch.cuda.empty_cache()
            gc.collect()

        synthetic_images = torch.cat(synthetic_images, dim=0)
        synthetic_dir = f"MURA-v1.1/synthetic/{category}/negative"
        os.makedirs(synthetic_dir, exist_ok=True)
        for i, img in enumerate(synthetic_images):
            img = img.squeeze().numpy()
            img = (img * 0.5) + 0.5
            img = (img * 255).astype(np.uint8)
            img_path = os.path.join(synthetic_dir, f"synthetic_{i}.png")
            Image.fromarray(img, mode='L').save(img_path)
            synthetic_data.append([img_path, minority_label, category])

        del generator, discriminator, progan_dataset, progan_loader, synthetic_images
        torch.cuda.empty_cache()
        gc.collect()

print("Positive/Negative balancing complete!")