In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import shutil

class HAM10000Dataset(Dataset):
    def __init__(self, csv_file, img_dirs, transform=None, device='cuda'):
        self.data = pd.read_csv(csv_file)
        self.img_dirs = img_dirs
        self.transform = transform
        self.device = device
        
        # Encode labels
        self.label_encoder = LabelEncoder()
        self.data['encoded_label'] = self.label_encoder.fit_transform(self.data['dx'])
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['image_id'] + '.jpg'
        for img_dir in self.img_dirs:
            img_path = os.path.join(img_dir, img_name)
            if os.path.exists(img_path):
                image = Image.open(img_path).convert('RGB')
                if self.transform:
                    image = self.transform(image)
                label = self.data.iloc[idx]['encoded_label']
                return image, label
        raise FileNotFoundError(f"Image {img_name} not found in directories {self.img_dirs}")

class EnhancedSLEBlock(nn.Module):
    def __init__(self, in_channels):
        super(EnhancedSLEBlock, self).__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        # Content branch - adjusted channel dimensions
        self.content_fc1 = nn.Conv2d(in_channels, in_channels, 1)  # Changed from in_channels//2
        self.content_fc2 = nn.Conv2d(in_channels, in_channels, 1)  # Input/output channels match
        
        # Style branch
        self.style_modulation = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(True)
        )
        
        self.gamma = nn.Parameter(torch.zeros(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x, skip_x):
        # Content pathway
        content = self.global_pool(x)
        content = F.relu(self.content_fc1(content))
        content = self.content_fc2(content)
        content = torch.sigmoid(content)
        
        # Style pathway
        style = self.style_modulation(skip_x)
        
        # Combine content and style
        output = skip_x * content  # Content modulation
        output = output + self.gamma * style + self.beta  # Style modulation
        return output

class EnhancedFASTGANGenerator(nn.Module):
    def __init__(self, latent_dim=256, ngf=64, output_size=64):
        super(EnhancedFASTGANGenerator, self).__init__()
        self.output_size = output_size
        
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, ngf * 16, 4, 1, 0),
            nn.BatchNorm2d(ngf * 16),
            nn.ReLU(True)
        )
        
        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 16, ngf * 8, 4, 2, 1),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True)
        )
        
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True)
        )
        
        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True)
        )
        
        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 2, 3, 4, 2, 1),
            nn.Tanh()
        )
        
        self.sle1 = EnhancedSLEBlock(ngf * 8)
        self.sle2 = EnhancedSLEBlock(ngf * 4)

    def forward(self, z):
        x0 = self.initial(z)
        x1 = self.layer1(x0)
        x1_sle = self.sle1(x0, x1)
        x2 = self.layer2(x1_sle)
        x2_sle = self.sle2(x1_sle, x2)
        x3 = self.layer3(x2_sle)
        x4 = self.layer4(x3)
        return x4

class EnhancedFASTGANDiscriminator(nn.Module):
    def __init__(self, ndf=64, input_size=64):
        super(EnhancedFASTGANDiscriminator, self).__init__()
        self.input_size = input_size
        
        # Shared feature extractor
        self.features = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2)
        )
        
        # Discriminator head
        self.discriminator = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ndf * 8, 1, 1),
            nn.Flatten(),
            nn.Sigmoid()
        )
        
        # Decoder for self-supervision
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(ndf * 8, ndf * 4, 4, 2, 1),
            nn.BatchNorm2d(ndf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ndf * 4, ndf * 2, 4, 2, 1),
            nn.BatchNorm2d(ndf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ndf * 2, ndf, 4, 2, 1),
            nn.BatchNorm2d(ndf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ndf, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        features = self.features(x)
        validity = self.discriminator(features)
        reconstruction = self.decoder(features)
        return validity, reconstruction

class SyntheticImageClassifier:
    def __init__(self, num_classes, device='cuda'):
        self.device = device
        
        # EfficientNetV2
        self.efficientnet = models.efficientnet_v2_s(pretrained=True)
        self.efficientnet.classifier[1] = nn.Linear(self.efficientnet.classifier[1].in_features, num_classes)
        self.efficientnet = self.efficientnet.to(device)
        
        # ShuffleNetV2
        self.shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
        self.shufflenet.fc = nn.Linear(self.shufflenet.fc.in_features, num_classes)
        self.shufflenet = self.shufflenet.to(device)
        
        # Transformation for input images
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def classify_synthetic_images(self, synthetic_images):
        resized_images = F.interpolate(synthetic_images, size=(224, 224), mode='bilinear', align_corners=False)
        normalized_images = (resized_images - resized_images.min()) / (resized_images.max() - resized_images.min())
        normalized_images = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(normalized_images)
        
        with torch.no_grad():
            efficientnet_preds = self.efficientnet(normalized_images)
            shufflenet_preds = self.shufflenet(normalized_images)
        
        efficientnet_classes = torch.argmax(efficientnet_preds, dim=1)
        shufflenet_classes = torch.argmax(shufflenet_preds, dim=1)
        
        agreed_classification_mask = (efficientnet_classes == shufflenet_classes)
        
        return agreed_classification_mask

def enhanced_train_step(real_imgs, generator, discriminator, g_optimizer, d_optimizer, 
                       device, lambda_rec=10.0):
    batch_size = real_imgs.size(0)
    
    # Train Discriminator
    d_optimizer.zero_grad()
    
    # Real images
    real_validity, real_reconstruction = discriminator(real_imgs)
    
    # Generate fake images
    z = torch.randn(batch_size, 256, 1, 1, device=device)
    fake_imgs = generator(z)
    fake_validity, _ = discriminator(fake_imgs.detach())
    
    # Hinge loss
    d_loss_real = torch.mean(F.relu(1.0 - real_validity))
    d_loss_fake = torch.mean(F.relu(1.0 + fake_validity))
    d_loss_adv = d_loss_real + d_loss_fake
    
    # Reconstruction loss for self-supervision
    d_loss_rec = F.mse_loss(real_reconstruction, real_imgs)
    
    # Total discriminator loss
    d_loss = d_loss_adv + lambda_rec * d_loss_rec
    
    d_loss.backward()
    d_optimizer.step()
    
    # Train Generator
    g_optimizer.zero_grad()
    
    fake_validity, _ = discriminator(fake_imgs)
    g_loss = -torch.mean(fake_validity)  # Hinge loss for generator
    
    g_loss.backward()
    g_optimizer.step()
    
    return {
        'd_loss': d_loss.item(),
        'd_loss_adv': d_loss_adv.item(),
        'd_loss_rec': d_loss_rec.item(),
        'g_loss': g_loss.item()
    }, fake_imgs

def plot_data_distribution_comparison(original_csv, synthetic_images_dir):
    metadata = pd.read_csv(original_csv)
    original_class_counts = metadata['dx'].value_counts()
    synthetic_class_counts = {}
    label_encoder = LabelEncoder()
    label_encoder.fit(metadata['dx'])
    
    for class_name in label_encoder.classes_:
        class_dir = os.path.join(synthetic_images_dir, class_name)
        if os.path.exists(class_dir):
            synthetic_class_counts[class_name] = len([f for f in os.listdir(class_dir) 
                                                    if f.endswith(('.png', '.jpg'))])
        else:
            synthetic_class_counts[class_name] = 0
    
    synthetic_class_counts = pd.Series(synthetic_class_counts)
    
    plt.figure(figsize=(15, 6))
    x = np.arange(len(original_class_counts))
    width = 0.4
    
    plt.bar(x - width/2, original_class_counts.values, width, label='Original Dataset', color='blue', alpha=0.7)
    plt.bar(x + width/2, synthetic_class_counts.values, width, label='Synthetic Images', color='orange', alpha=0.7)
    
    plt.title('Comparison of Original HAM10000 Dataset and Synthetic Images', fontsize=16)
    plt.xlabel('Skin Lesion Type', fontsize=14)
    plt.ylabel('Number of Samples', fontsize=14)
    plt.xticks(x, original_class_counts.index, rotation=90, ha='right')
    plt.legend()
    
    for i, (orig, synth) in enumerate(zip(original_class_counts.values, synthetic_class_counts.values)):
        plt.text(i - width/2, orig + 50, str(int(orig)), ha='center', va='bottom', fontsize=8)
        plt.text(i + width/2, synth + 50, str(int(synth)), ha='center', va='bottom', fontsize=8)
    
    plt.tight_layout()
    plt.savefig('dataset_distribution_comparison.png')
    plt.close()

def copy_original_images_by_class(csv_file, img_dirs, output_base_dir='synthetic_images'):
    metadata = pd.read_csv(csv_file)
    os.makedirs(output_base_dir, exist_ok=True)
    copied_images = set()
    
    for class_name in metadata['dx'].unique():
        class_output_dir = os.path.join(output_base_dir, class_name)
        os.makedirs(class_output_dir, exist_ok=True)
        
        class_metadata = metadata[metadata['dx'] == class_name]
        
        for _, row in class_metadata.iterrows():
            img_filename = row['image_id'] + '.jpg'
            
            for img_dir in img_dirs:
                img_path = os.path.join(img_dir, img_filename)
                
                if os.path.exists(img_path):
                    dest_path = os.path.join(class_output_dir, img_filename)
                    
                    if img_path not in copied_images:
                        shutil.copy2(img_path, dest_path)
                        copied_images.add(img_path)
                    break
    
    print(f"Original images copied to {output_base_dir}")
    print(f"Total unique images copied: {len(copied_images)}")

def train_enhanced_fastgan(generator, discriminator, dataloader, num_epochs, device='cuda', 
                          lambda_rec=10.0, save_interval=100):
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    os.makedirs('training_progress', exist_ok=True)
    
    for epoch in range(num_epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)
            
            losses, fake_imgs = enhanced_train_step(
                real_imgs, generator, discriminator,
                g_optimizer, d_optimizer, device, lambda_rec
            )
            
            if i % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}], Batch [{i}], '
                      f'D_loss: {losses["d_loss"]:.4f}, '
                      f'D_adv: {losses["d_loss_adv"]:.4f}, '
                      f'D_rec: {losses["d_loss_rec"]:.4f}, '
                      f'G_loss: {losses["g_loss"]:.4f}')
                
                # Save sample images
                if i % save_interval == 0:
                    save_image(fake_imgs[:16] * 0.5 + 0.5,
                             f'training_progress/epoch_{epoch}_batch_{i}.png',
                             nrow=4, normalize=False)
    
    return generator, discriminator

def generate_synthetic_images(generator, classifier, num_classes, num_images_per_class,
                            device='cuda', batch_size=64, output_dir='synthetic_images'):
    os.makedirs(output_dir, exist_ok=True)
    generator.eval()
    
    with torch.no_grad():
        for class_idx in range(num_classes):
            class_dir = os.path.join(output_dir, f'class_{class_idx}')
            os.makedirs(class_dir, exist_ok=True)
            
            num_generated = 0
            while num_generated < num_images_per_class:
                # Generate images
                z = torch.randn(batch_size, 256, 1, 1, device=device)
                fake_imgs = generator(z)
                
                # Filter images using classifier
                valid_mask = classifier.classify_synthetic_images(fake_imgs)
                valid_images = fake_imgs[valid_mask]
                
                # Save valid images
                for idx, img in enumerate(valid_images):
                    if num_generated >= num_images_per_class:
                        break
                    save_image(img * 0.5 + 0.5,
                             os.path.join(class_dir, f'synthetic_{num_generated}.png'))
                    num_generated += 1
                
                print(f'Class {class_idx}: Generated {num_generated}/{num_images_per_class} images')

def main():
    # Set device and random seeds
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Dataset parameters
    csv_file = '/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_metadata.csv'
    img_dirs = ['/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_1', '/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_2']
    
    # Data preprocessing
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # Initialize dataset and dataloader
    dataset = HAM10000Dataset(csv_file, img_dirs, transform=transform, device=device)
    num_classes = len(dataset.label_encoder.classes_)
    print(f"Number of classes: {num_classes}")
    print("Class labels:", dataset.label_encoder.classes_)
    
    batch_size = 64
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    
    # Initialize models
    generator = EnhancedFASTGANGenerator(latent_dim=256, output_size=64).to(device)
    discriminator = EnhancedFASTGANDiscriminator(input_size=64).to(device)
    classifier = SyntheticImageClassifier(num_classes=num_classes, device=device)
    
    # Training parameters
    num_epochs = 100
    print("Starting training...")
    
    # Train the model
    generator, discriminator = train_enhanced_fastgan(
        generator, discriminator, dataloader, 
        num_epochs=num_epochs, device=device
    )
    
    # Generate synthetic images for each class
    print("Generating synthetic images...")
    generate_synthetic_images(
        generator, classifier, 
        num_classes=num_classes,
        num_images_per_class=1000,  # Adjust as needed
        device=device,
        output_dir='synthetic_images'
    )
    
    # Plot distribution comparison
    plot_data_distribution_comparison(csv_file, 'synthetic_images')
    
    print("Training and generation complete!")

if __name__ == "__main__":
    main()

Using device: cuda
Number of classes: 7
Class labels: ['akiec' 'bcc' 'bkl' 'df' 'mel' 'nv' 'vasc']


Downloading: "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_v2_s-dd5fe13b.pth
100%|██████████| 82.7M/82.7M [00:00<00:00, 94.6MB/s]
Downloading: "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth" to /root/.cache/torch/hub/checkpoints/shufflenetv2_x1-5666bf0f80.pth
100%|██████████| 8.79M/8.79M [00:00<00:00, 61.5MB/s]

Starting training...





RuntimeError: Given groups=1, weight of size [512, 512, 1, 1], expected input[64, 1024, 1, 1] to have 512 channels, but got 1024 channels instead