In [None]:
# Import Required Libraries
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from collections import Counter
from tqdm import tqdm
from glob import glob

# Pick the compute device: CUDA when a GPU is visible, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Seed NumPy and PyTorch (CPU and, when present, every CUDA device)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Image, model and training hyperparameters
IMG_SIZE = 512       # images are resized to IMG_SIZE x IMG_SIZE
LATENT_DIM = 512     # length of the latent / noise vector
NUM_CLASSES = 14     # 7 body parts x 2 case types
BATCH_SIZE = 32
NUM_EPOCHS = 200

# Filesystem locations; the checkpoint directory is created up front
DATASET_DIR = 'MURA-v1.1/train'
CHECKPOINT_DIR = 'checkpoints-BAGAN'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Class names are "<body part>_<case type>", ordered body part first so that
# index 2*k is the positive class and 2*k + 1 the negative class of part k
body_parts = [
    'XR_WRIST',
    'XR_SHOULDER',
    'XR_HAND',
    'XR_FOREARM',
    'XR_FINGER',
    'XR_ELBOW',
    'XR_HUMERUS',
]
case_types = ['positive', 'negative']
class_names = ["_".join((bp, ct)) for bp in body_parts for ct in case_types]

# Dataset Class
class MURADataset(Dataset):
    """Dataset of grayscale X-ray images with one-hot class labels.

    Args:
        image_paths: sequence of image paths, relative to the parent of
            DATASET_DIR (the form produced by scan_dataset).
        labels: sequence of integer class indices, parallel to image_paths.
        transform: optional callable applied to the loaded PIL image.

    Each item is an ``(image, label_onehot)`` pair where ``label_onehot`` is a
    float tensor of length NUM_CLASSES. If an image cannot be loaded or
    transformed, the error is printed and an all-zero image of shape
    (1, IMG_SIZE, IMG_SIZE) with an all-zero label vector is returned instead,
    so one bad file does not abort iteration.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Resolve the path outside the try block: an out-of-range index must
        # surface as IndexError, and the error message below needs img_path
        # to be bound.
        img_path = self.image_paths[idx]
        try:
            full_path = os.path.join(DATASET_DIR, '..', img_path)
            # The context manager releases the file handle as soon as the
            # pixel data has been converted.
            with Image.open(full_path) as raw_image:
                image = raw_image.convert('L')  # Grayscale for X-rays

            if self.transform:
                image = self.transform(image)

            label = self.labels[idx]
            label_onehot = torch.zeros(NUM_CLASSES)
            label_onehot[label] = 1.0

            return image, label_onehot
        except Exception as e:
            # Deliberate best-effort: report the failure and hand back a
            # placeholder sample (all-zero image, all-zero label).
            print(f"Error loading image {img_path}: {e}")
            placeholder = torch.zeros((1, IMG_SIZE, IMG_SIZE))
            return placeholder, torch.zeros(NUM_CLASSES)

# Helper Functions for Dataset Processing
def get_class_label(image_path):
    """Return the class index encoded in an image path, or None.

    The first entry of ``body_parts`` found as a substring of ``image_path``
    selects the body part; the case is positive when the path contains
    'positive' and negative otherwise. The index is
    ``2 * body_part_position + (0 for positive, 1 for negative)``.
    None is returned when no known body part occurs in the path.
    """
    for part_position, part_name in enumerate(body_parts):
        if part_name not in image_path:
            continue
        case_offset = 0 if 'positive' in image_path else 1
        return part_position * 2 + case_offset
    return None

def scan_dataset(dataset_dir=DATASET_DIR):
    """Collect every ``.png`` under ``dataset_dir`` and assign class labels.

    Args:
        dataset_dir: root directory to walk recursively.

    Returns:
        A tuple ``(valid_paths, class_labels, class_counts)``:
        ``valid_paths`` is a NumPy array of paths relative to the parent of
        ``dataset_dir``; ``class_labels`` is a NumPy array of the class index
        get_class_label assigned to each path; ``class_counts`` is a Counter
        mapping class index to number of images. Paths for which
        get_class_label returns None are dropped.

    Progress and the per-class distribution are printed to stdout.
    """
    print(f"Scanning MURA dataset directory: {dataset_dir}...")
    base_dir = os.path.join(dataset_dir, '..')
    image_paths = []

    for root, dirs, files in os.walk(dataset_dir):
        # os.walk yields entries in arbitrary filesystem order; sorting makes
        # the resulting arrays (and any seeded shuffle of them) reproducible.
        dirs.sort()
        for file in sorted(files):
            if file.endswith('.png'):
                full_path = os.path.join(root, file)
                image_paths.append(os.path.relpath(full_path, base_dir))

    print(f"Found {len(image_paths)} images with labels")

    print("Classifying images...")
    class_labels = []
    valid_paths = []

    for path in tqdm(image_paths):
        label = get_class_label(path)
        if label is not None:
            class_labels.append(label)
            valid_paths.append(path)

    valid_paths = np.array(valid_paths)
    class_labels = np.array(class_labels)
    print(f"Successfully classified {len(class_labels)} images into classes")

    class_counts = Counter(class_labels)
    print("Class distribution:")
    for cls in sorted(class_counts):
        print(f"{class_names[cls]}: {class_counts[cls]} images")

    return valid_paths, class_labels, class_counts

# Autoencoder Architecture (same as BAGAN-GP + WGAN-GP)
class Encoder(nn.Module):
    """Convolutional encoder mapping a 1-channel image to a latent vector.

    Five stride-2 4x4 convolutions halve the spatial size each time
    (512 -> 16) while widening channels 1 -> 64 -> 128 -> 256 -> 512 -> 1024.
    Every stage after the first applies BatchNorm, and every stage ends with
    LeakyReLU(0.2). The flattened 1024*16*16 features are projected to
    LATENT_DIM by a linear layer.
    """

    def __init__(self):
        super().__init__()
        channel_plan = [1, 64, 128, 256, 512, 1024]
        stages = []
        # Layers are appended in the same order as a hand-written Sequential
        # so that state-dict keys (model.0, model.2, model.3, ...) are kept.
        for stage_idx, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
            stages.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if stage_idx > 0:
                stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        stages.append(nn.Flatten())
        self.model = nn.Sequential(*stages)
        self.fc = nn.Linear(1024 * 16 * 16, LATENT_DIM)

    def forward(self, x):
        """Encode ``x`` and return the LATENT_DIM-wide latent tensor."""
        return self.fc(self.model(x))

class Decoder(nn.Module):
    """Class-conditional decoder turning (latent, one-hot label) into an image.

    The latent vector and label are concatenated, projected to a
    1024 x 16 x 16 feature map, batch-normalised, and upsampled by five
    stride-2 transposed convolutions (16 -> 512) down to one channel.
    The output passes through Tanh.
    """

    def __init__(self):
        super().__init__()
        self.initial_size = 16
        self.fc = nn.Linear(LATENT_DIM + NUM_CLASSES, 1024 * self.initial_size * self.initial_size)
        self.bn_initial = nn.BatchNorm2d(1024)

        # Hidden upsampling stages: ReLU -> ConvTranspose2d -> BatchNorm.
        # Module order is kept so state-dict keys (deconv.1, deconv.2, ...)
        # stay the same as a hand-written Sequential.
        channel_plan = [1024, 512, 256, 128, 64]
        blocks = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            blocks.append(nn.ReLU(inplace=False))
            blocks.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1))
            blocks.append(nn.BatchNorm2d(c_out))
        # Output stage: 256 -> 512 spatial, single channel, squashed by Tanh
        blocks.append(nn.ReLU(inplace=False))
        blocks.append(nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1))
        blocks.append(nn.Tanh())
        self.deconv = nn.Sequential(*blocks)

    def forward(self, z, labels):
        """Decode latent ``z`` conditioned on one-hot ``labels`` into images."""
        conditioned = torch.cat([z, labels], dim=1)
        feature_map = self.fc(conditioned).view(-1, 1024, self.initial_size, self.initial_size)
        feature_map = nn.functional.relu(self.bn_initial(feature_map), inplace=True)
        return self.deconv(feature_map)

class Autoencoder(nn.Module):
    """Encoder/decoder pair; the decoder is conditioned on class labels.

    Args:
        num_classes: stored on the instance as ``num_classes``.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.num_classes = num_classes

    def forward(self, x, labels):
        """Encode ``x`` and decode the latent together with ``labels``."""
        return self.decoder(self.encoder(x), labels)

# Discriminator Network (modified for BGAN, no spectral normalization)
class Discriminator(nn.Module):
    """Label-conditioned discriminator with an auxiliary class head.

    Five stride-2 4x4 convolutions reduce a 1-channel 512x512 image to a
    512 x 16 x 16 feature map (InstanceNorm on every stage but the first,
    LeakyReLU(0.2) throughout). The flattened features are concatenated with
    a linear embedding of the one-hot label and fed to two linear heads.

    Args:
        num_classes: width of the one-hot label and of the class-logit output.
    """

    def __init__(self, num_classes):
        super().__init__()
        channel_plan = [1, 32, 64, 128, 256, 512]
        layers = []
        # Same module order as a hand-written Sequential, so state-dict keys
        # (main.0, main.2, main.3, ...) are unchanged.
        for stage_idx, (c_in, c_out) in enumerate(zip(channel_plan[:-1], channel_plan[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if stage_idx > 0:
                layers.append(nn.InstanceNorm2d(c_out, affine=True))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.main = nn.Sequential(*layers)

        self.flatten = nn.Flatten()
        self.embed_dim = 512
        self.class_embedding = nn.Linear(num_classes, self.embed_dim)
        ds_size = 16  # spatial size left after five halvings of 512
        self.feature_size = 512 * ds_size * ds_size
        head_inputs = self.feature_size + self.embed_dim
        self.adv_layer = nn.Linear(head_inputs, 1)
        self.aux_layer = nn.Linear(head_inputs, num_classes)

    def forward(self, img, labels):
        """Return ``(validity, label_logits)`` for images and one-hot labels.

        ``validity`` is the raw (pre-sigmoid) output of the adversarial head;
        ``label_logits`` has one column per class.
        """
        flat_features = self.flatten(self.main(img))
        combined = torch.cat([flat_features, self.class_embedding(labels)], dim=1)
        return self.adv_layer(combined), self.aux_layer(combined)

# Checkpoint Functions
def save_checkpoint(epoch, generator, discriminator, gen_optimizer, disc_optimizer, d_losses, g_losses,
                    checkpoint_dir=None):
    """Write the full training state to ``checkpoint_epoch_<epoch>.pth``.

    Args:
        epoch: epoch number stored in the checkpoint and used in the file name.
        generator, discriminator: modules whose ``state_dict()`` is saved.
        gen_optimizer, disc_optimizer: optimizers whose ``state_dict()`` is saved.
        d_losses, g_losses: loss histories stored alongside the weights.
        checkpoint_dir: directory to write into. Defaults to the module-level
            CHECKPOINT_DIR when None. The directory is created if missing.
    """
    target_dir = CHECKPOINT_DIR if checkpoint_dir is None else checkpoint_dir
    os.makedirs(target_dir, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'gen_optimizer_state_dict': gen_optimizer.state_dict(),
        'disc_optimizer_state_dict': disc_optimizer.state_dict(),
        'd_losses': d_losses,
        'g_losses': g_losses,
    }
    checkpoint_path = os.path.join(target_dir, f'checkpoint_epoch_{epoch}.pth')
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")

def load_checkpoint(checkpoint_path, generator, discriminator, gen_optimizer, disc_optimizer,
                    map_location=None):
    """Restore models, optimizers and loss histories from a checkpoint file.

    Args:
        checkpoint_path: path of the ``.pth`` file to read.
        generator, discriminator: modules that receive the saved weights.
        gen_optimizer, disc_optimizer: optimizers that receive the saved state.
        map_location: forwarded to ``torch.load``. When None, tensors are
            mapped onto the device of the generator's first parameter, so a
            checkpoint written on a GPU can be resumed on a CPU-only host.

    Returns:
        ``(epoch, d_losses, g_losses)`` read from the checkpoint, or
        ``(0, [], [])`` when ``checkpoint_path`` does not exist (a message is
        printed and nothing is modified in that case).
    """
    if not os.path.exists(checkpoint_path):
        print(f"No checkpoint found at {checkpoint_path}")
        return 0, [], []

    if map_location is None:
        first_param = next(generator.parameters(), None)
        if first_param is not None:
            map_location = first_param.device

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    gen_optimizer.load_state_dict(checkpoint['gen_optimizer_state_dict'])
    disc_optimizer.load_state_dict(checkpoint['disc_optimizer_state_dict'])
    epoch = checkpoint['epoch']
    d_losses = checkpoint['d_losses']
    g_losses = checkpoint['g_losses']
    print(f"Loaded checkpoint from {checkpoint_path}, resuming from epoch {epoch + 1}")
    return epoch, d_losses, g_losses

# BAGAN with BGAN Training Function
def _random_onehot_labels(batch_size, device):
    """Sample uniform class indices and return them with their one-hot encoding."""
    labels_idx = torch.randint(0, NUM_CLASSES, (batch_size,), device=device)
    labels_onehot = torch.zeros(batch_size, NUM_CLASSES, device=device)
    labels_onehot.scatter_(1, labels_idx.unsqueeze(1), 1)
    return labels_idx, labels_onehot


def _save_epoch_samples(generator, epoch_number, checkpoint_dir, device):
    """Save one generated image per class as a 2x7 grid and as single PNGs."""
    with torch.no_grad():
        noise = torch.randn(NUM_CLASSES, LATENT_DIM, device=device)
        sample_labels_onehot = torch.eye(NUM_CLASSES, device=device)
        fake_images = generator(noise, sample_labels_onehot)
        fake_images = (fake_images + 1) / 2  # Scale from [-1,1] to [0,1]
        images = fake_images.cpu().numpy()

    # Grid overview of all classes
    fig, axes = plt.subplots(2, 7, figsize=(14, 4))
    axes = axes.flatten()
    for i in range(NUM_CLASSES):
        axes[i].imshow(images[i].squeeze(), cmap='gray')
        axes[i].set_title(class_names[i])
        axes[i].axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(checkpoint_dir, f'samples_epoch_{epoch_number}.png'))
    plt.close()

    # One file per class
    for i in range(NUM_CLASSES):
        plt.imsave(
            os.path.join(checkpoint_dir, f'epoch_{epoch_number}_class_{i}.png'),
            images[i].squeeze(),
            cmap='gray',
        )


def train_bagan_bgan(dataloader, generator, discriminator, num_epochs, device, checkpoint_dir, resume_epoch=12):
    """Train a conditional generator/discriminator pair with a boundary-seeking loss.

    Args:
        dataloader: yields ``(real_images, labels_onehot)`` batches.
        generator: called as ``generator(noise, labels_onehot)``.
        discriminator: called as ``discriminator(images, labels_onehot)`` and
            returning ``(validity_logits, class_logits)``.
        num_epochs: total number of epochs (upper bound of the epoch range).
        device: device batches and sampled noise/labels are placed on.
        checkpoint_dir: directory the resume checkpoint is read from and the
            sample images are written to. Checkpoints themselves are written
            through save_checkpoint, which is called without a directory and
            therefore uses its own default location.
        resume_epoch: epoch number of the checkpoint to resume from, or None
            to start fresh. If the checkpoint file is missing, load_checkpoint
            reports it and training starts from epoch 0.

    After every epoch a checkpoint and one sample image per class are saved.
    """
    # Optimizers
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    gen_optimizer = optim.Adam(generator.parameters(), lr=0.0006, betas=(0.5, 0.999))

    # Loss functions
    bce_loss = nn.BCEWithLogitsLoss()
    classification_loss = nn.CrossEntropyLoss()

    # Loss tracking
    d_losses = []
    g_losses = []

    # Resume training if specified
    start_epoch = 0
    if resume_epoch is not None:
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{resume_epoch}.pth')
        start_epoch, d_losses, g_losses = load_checkpoint(
            checkpoint_path, generator, discriminator, gen_optimizer, disc_optimizer
        )

    save_every = 1  # checkpoint / sample interval in epochs

    for epoch in range(start_epoch, num_epochs):
        generator.train()
        discriminator.train()
        total_d_loss = 0.0
        total_g_loss = 0.0
        batches_used = 0

        for real_images, labels_onehot in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            real_images = real_images.to(device)
            labels_onehot = labels_onehot.to(device)

            # MURADataset returns an all-zero label vector for images it could
            # not load; argmax would silently turn those into class 0, so
            # drop such samples before they reach the losses.
            has_label = labels_onehot.sum(dim=1) > 0
            if not bool(has_label.all()):
                real_images = real_images[has_label]
                labels_onehot = labels_onehot[has_label]

            batch_size = real_images.size(0)
            if batch_size == 0:
                continue
            labels = torch.argmax(labels_onehot, dim=1)

            # Targets for BCE loss (smoothed: 0.9 for real, 0.1 for fake)
            real_label = torch.full((batch_size, 1), 0.9, device=device)
            fake_label = torch.full((batch_size, 1), 0.1, device=device)

            # ---- Train Discriminator ----
            disc_optimizer.zero_grad()

            # Real images: adversarial + classification loss
            real_validity, real_label_logits = discriminator(real_images, labels_onehot)
            d_loss_real = bce_loss(real_validity, real_label)
            class_loss_real = classification_loss(real_label_logits, labels)

            # Fake images: detached so no gradient reaches the generator here
            noise = torch.randn(batch_size, LATENT_DIM, device=device)
            _, fake_labels_onehot = _random_onehot_labels(batch_size, device)
            fake_images = generator(noise, fake_labels_onehot)
            fake_validity, _ = discriminator(fake_images.detach(), fake_labels_onehot)
            d_loss_fake = bce_loss(fake_validity, fake_label)

            d_loss = d_loss_real + d_loss_fake + class_loss_real
            d_loss.backward()
            disc_optimizer.step()

            total_d_loss += d_loss.item()

            # ---- Train Generator ----
            gen_optimizer.zero_grad()

            noise = torch.randn(batch_size, LATENT_DIM, device=device)
            gen_labels_idx, gen_labels_onehot = _random_onehot_labels(batch_size, device)
            fake_images = generator(noise, gen_labels_onehot)
            fake_validity, fake_label_logits = discriminator(fake_images, gen_labels_onehot)

            # Boundary-seeking loss: 0.5 * E[(log D - log(1 - D))^2].
            # With D = sigmoid(logit), log(D / (1 - D)) is the logit itself,
            # so square it directly instead of going through sigmoid and log,
            # which saturates or overflows for large-magnitude logits.
            g_loss_adv = 0.5 * torch.mean(fake_validity ** 2)
            class_loss_fake = classification_loss(fake_label_logits, gen_labels_idx)

            g_loss = g_loss_adv + 0.5 * class_loss_fake
            g_loss.backward()
            gen_optimizer.step()

            total_g_loss += g_loss.item()
            batches_used += 1

        # End of epoch: average over the batches that actually contributed
        # (also avoids dividing by zero for an empty dataloader)
        avg_d_loss = total_d_loss / max(batches_used, 1)
        avg_g_loss = total_g_loss / max(batches_used, 1)
        d_losses.append(avg_d_loss)
        g_losses.append(avg_g_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], D Loss: {avg_d_loss:.4f}, G Loss: {avg_g_loss:.4f}")

        # Save checkpoint and generate samples
        if (epoch + 1) % save_every == 0:
            save_checkpoint(epoch + 1, generator, discriminator, gen_optimizer, disc_optimizer, d_losses, g_losses)
            # NOTE(review): the generator is still in train() mode here, so
            # its BatchNorm layers use batch statistics while sampling.
            _save_epoch_samples(generator, epoch + 1, checkpoint_dir, device)

# Function to Generate Balanced Dataset
def _generate_class_images(generator, class_idx, count, num_classes, latent_dim, device, out_dir, desc,
                           batch_size=16):
    """Generate ``count`` images of one class and save them as synthetic_<n>.png in ``out_dir``."""
    num_batches = (count + batch_size - 1) // batch_size
    for batch in tqdm(range(num_batches), desc=desc):
        first_index = batch * batch_size
        current_batch_size = min(batch_size, count - first_index)

        noise = torch.randn(current_batch_size, latent_dim, device=device)
        labels_onehot = torch.zeros(current_batch_size, num_classes, device=device)
        labels_onehot[:, class_idx] = 1.0

        fake_images = generator(noise, labels_onehot)
        fake_images = (fake_images + 1) / 2  # Scale from [-1,1] to [0,1]

        for i in range(current_batch_size):
            img = fake_images[i].cpu().numpy().squeeze()
            img = (img * 255).astype(np.uint8)
            img_pil = Image.fromarray(img, mode='L')
            img_pil.save(os.path.join(out_dir, f"synthetic_{first_index + i}.png"))


def generate_balanced_dataset(generator, class_counts, class_names, num_classes, latent_dim, device, output_dir):
    """Generate synthetic images for every class that is below the target count.

    The target per class is ``total_images / num_classes``, i.e. an equal
    share per body part, split 50/50 between its positive and negative class.
    Classes already at or above the target get no synthetic images.

    Args:
        generator: called as ``generator(noise, labels_onehot)``; its output
            is rescaled from [-1, 1] to [0, 1] before saving.
        class_counts: mapping from class index to current number of images.
        class_names: class names indexed by class index, laid out as
            positive/negative pairs (even index positive, odd index negative).
        num_classes: number of classes (width of the one-hot label).
        latent_dim: length of the noise vector fed to the generator.
        device: device for noise and label tensors.
        output_dir: root directory; one sub-directory per class name is
            created beneath it.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Calculate target counts for balanced dataset
    total_images = sum(class_counts.values())
    num_body_parts = max(1, num_classes // 2)  # classes come in positive/negative pairs
    target_per_body_part = total_images / num_body_parts
    target_per_class = target_per_body_part / 2  # 50/50 split between positive/negative

    print(f"Total images: {total_images}")
    print(f"Target per body part: {target_per_body_part:.0f}")
    print(f"Target per class: {target_per_class:.0f}")

    # NOTE(review): the generator's train/eval mode is left as the caller set
    # it; in train mode BatchNorm layers use per-batch statistics — confirm
    # this is intended for generation.
    with torch.no_grad():
        for body_part_idx in range(0, num_classes, 2):
            # Strip only the trailing "_positive"/"_negative" so that names
            # such as "XR_WRIST_positive" give "XR_WRIST", not just "XR".
            body_part_name = class_names[body_part_idx].rsplit('_', 1)[0]
            positive_class = body_part_idx
            negative_class = body_part_idx + 1

            current_positive = class_counts.get(positive_class, 0)
            current_negative = class_counts.get(negative_class, 0)

            images_needed_positive = max(0, int(target_per_class) - current_positive)
            images_needed_negative = max(0, int(target_per_class) - current_negative)

            print(f"Balancing {body_part_name}:")
            print(f"  Positive: {current_positive} -> need {images_needed_positive} more")
            print(f"  Negative: {current_negative} -> need {images_needed_negative} more")

            # Create directories
            pos_dir = os.path.join(output_dir, class_names[positive_class])
            neg_dir = os.path.join(output_dir, class_names[negative_class])
            os.makedirs(pos_dir, exist_ok=True)
            os.makedirs(neg_dir, exist_ok=True)

            if images_needed_positive > 0:
                _generate_class_images(
                    generator, positive_class, images_needed_positive, num_classes, latent_dim,
                    device, pos_dir, f"Generating {body_part_name} positive",
                )

            if images_needed_negative > 0:
                _generate_class_images(
                    generator, negative_class, images_needed_negative, num_classes, latent_dim,
                    device, neg_dir, f"Generating {body_part_name} negative",
                )

# Load and prepare training data
image_paths, labels, class_counts = scan_dataset(DATASET_DIR)
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize to [-1, 1] for tanh activation
])
dataset = MURADataset(image_paths, labels, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Load and prepare validation data
# NOTE(review): val_dataloader is built here but not referenced again in this cell.
val_image_paths, val_labels, val_class_counts = scan_dataset('MURA-v1.1/valid')
val_dataset = MURADataset(val_image_paths, val_labels, transform=transform)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Load pretrained autoencoder; its decoder is the generator's starting point
autoencoder = Autoencoder(NUM_CLASSES).to(device)
checkpoint_path = 'checkpoints-BAGAN-GP-WGAN-GP_Old-2/best_autoencoder.pth'  # Use best checkpoint
if not os.path.exists(checkpoint_path):
    # exit() is the interactive helper installed by the site module and is not
    # meant for programs; raising SystemExit stops the run with the message
    # on stderr and a non-zero status.
    raise SystemExit(f"Checkpoint not found at {checkpoint_path}. Please train the autoencoder first.")
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
autoencoder.load_state_dict(torch.load(checkpoint_path, map_location=device))
print(f"Loaded autoencoder checkpoint from {checkpoint_path}")

# Extract the pretrained decoder (generator) and initialize discriminator
generator = autoencoder.decoder.to(device)
discriminator = Discriminator(NUM_CLASSES).to(device)

# Train BAGAN with BGAN, resuming from the epoch-41 checkpoint when it exists
train_bagan_bgan(dataloader, generator, discriminator, NUM_EPOCHS, device, CHECKPOINT_DIR, resume_epoch=41)

# Generate balanced dataset
output_dir = 'synthetic_mura'
generate_balanced_dataset(generator, class_counts, class_names, NUM_CLASSES, LATENT_DIM, device, output_dir)

Using device: cuda
Scanning MURA dataset directory: MURA-v1.1/train...
Found 36812 images with labels
Classifying images...


100%|███████████████████████████████████████████████████████████████████████| 36812/36812 [00:00<00:00, 1166547.43it/s]

Successfully classified 36812 images into classes
Class distribution:
XR_WRIST_positive: 3987 images
XR_WRIST_negative: 5765 images
XR_SHOULDER_positive: 4169 images
XR_SHOULDER_negative: 4211 images
XR_HAND_positive: 1484 images
XR_HAND_negative: 4059 images
XR_FOREARM_positive: 661 images
XR_FOREARM_negative: 1164 images
XR_FINGER_positive: 1970 images
XR_FINGER_negative: 3138 images
XR_ELBOW_positive: 2007 images
XR_ELBOW_negative: 2925 images
XR_HUMERUS_positive: 599 images
XR_HUMERUS_negative: 673 images
Scanning MURA dataset directory: MURA-v1.1/valid...





Found 3197 images with labels
Classifying images...


100%|█████████████████████████████████████████████████████████████████████████| 3197/3197 [00:00<00:00, 1305157.67it/s]

Successfully classified 3197 images into classes
Class distribution:
XR_WRIST_positive: 295 images
XR_WRIST_negative: 364 images
XR_SHOULDER_positive: 278 images
XR_SHOULDER_negative: 285 images
XR_HAND_positive: 189 images
XR_HAND_negative: 271 images
XR_FOREARM_positive: 151 images
XR_FOREARM_negative: 150 images
XR_FINGER_positive: 247 images
XR_FINGER_negative: 214 images
XR_ELBOW_positive: 230 images
XR_ELBOW_negative: 235 images
XR_HUMERUS_positive: 140 images
XR_HUMERUS_negative: 148 images



  autoencoder.load_state_dict(torch.load(checkpoint_path, map_location=device))


Loaded autoencoder checkpoint from checkpoints-BAGAN-GP-WGAN-GP_Old-2/best_autoencoder.pth


  checkpoint = torch.load(checkpoint_path)


Loaded checkpoint from checkpoints-BAGAN\checkpoint_epoch_41.pth, resuming from epoch 42


Epoch 42/200: 100%|████████████████████████████████████████████████████████████████| 1151/1151 [21:08<00:00,  1.10s/it]


Epoch [42/200], D Loss: 0.6802, G Loss: 2.5148
Checkpoint saved at checkpoints-BAGAN\checkpoint_epoch_42.pth


Epoch 43/200:   0%|▏                                                                  | 4/1151 [00:02<13:42,  1.40it/s]