In [3]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.preprocessing import LabelEncoder

class HAM10000Dataset(Dataset):
    def __init__(self, csv_file, img_dirs, transform=None, device='cuda'):
        self.data = pd.read_csv(csv_file)
        self.img_dirs = img_dirs
        self.transform = transform
        self.device = device
        
        # Encode labels
        self.label_encoder = LabelEncoder()
        self.data['encoded_label'] = self.label_encoder.fit_transform(self.data['dx'])
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_name = self.data.iloc[idx]['image_id'] + '.jpg'
        for img_dir in self.img_dirs:
            img_path = os.path.join(img_dir, img_name)
            if os.path.exists(img_path):
                image = Image.open(img_path).convert('RGB')
                if self.transform:
                    image = self.transform(image)
                label = self.data.iloc[idx]['encoded_label']
                return image, label
        raise FileNotFoundError(f"Image {img_name} not found in directories {self.img_dirs}")

class SLEBlock(nn.Module):
    def __init__(self, in_channels):
        super(SLEBlock, self).__init__()
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, in_channels // 2, 1)
        self.fc2 = nn.Conv2d(in_channels // 2, in_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        x = self.global_pool(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return y * x

class FASTGANGenerator(nn.Module):
    def __init__(self, latent_dim=256, ngf=64):
        super(FASTGANGenerator, self).__init__()
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, ngf * 16, 4, 1, 0),
            nn.BatchNorm2d(ngf * 16),
            nn.ReLU(True)
        )
        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 16, ngf * 8, 4, 2, 1),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True)
        )
        self.sle1 = SLEBlock(ngf * 8)
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True)
        )
        self.sle2 = SLEBlock(ngf * 4)
        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(ngf * 4, 3, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, z):
        x = self.initial(z)
        x = self.layer1(x)
        x = self.sle1(x, x)
        x = self.layer2(x)
        x = self.sle2(x, x)
        x = self.layer3(x)
        return x

class FASTGANDiscriminator(nn.Module):
    def __init__(self, ndf=64):
        super(FASTGANDiscriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ndf * 2, 1, 1),
            nn.Flatten(),
            nn.Sigmoid()
        )
        self.auxiliary = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ndf * 2, 1, 1),
            nn.Flatten(),
            nn.Sigmoid()
        )

    def forward(self, x):
        downsampled = F.interpolate(x, scale_factor=0.5)
        main_out = self.main(x)
        aux_out = self.auxiliary(downsampled)
        return main_out, aux_out

class FeatureMatchingLoss(nn.Module):
    def __init__(self):
        super(FeatureMatchingLoss, self).__init__()
        self.l1_loss = nn.L1Loss()
        
    def forward(self, real_features, fake_features):
        loss = 0
        for real_feat, fake_feat in zip(real_features, fake_features):
            loss += self.l1_loss(fake_feat, real_feat.detach())
        return loss

class MemoryBank:
    def __init__(self, max_size=1000, feature_dim=256):
        self.max_size = max_size
        self.features = []
        self.feature_dim = feature_dim
        
    def update(self, new_features):
        self.features.extend(new_features.detach().cpu())
        if len(self.features) > self.max_size:
            self.features = self.features[-self.max_size:]
            
    def sample(self, n_samples):
        if len(self.features) == 0:
            return torch.randn(n_samples, self.feature_dim)
        indices = torch.randint(0, len(self.features), (n_samples,))
        return torch.stack([self.features[i] for i in indices])

class ProgressiveGrowingManager:
    def __init__(self, start_size=16, target_size=224, n_steps=4):
        self.current_size = start_size
        self.target_size = target_size
        self.n_steps = n_steps
        self.alpha = 0.0
        
    def step(self):
        self.alpha = min(1.0, self.alpha + 0.1)
        if self.alpha >= 1.0 and self.current_size < self.target_size:
            self.current_size = min(self.current_size * 2, self.target_size)
            self.alpha = 0.0
            
    def get_size(self):
        return self.current_size

def train_step(real_imgs, generator, discriminator, g_optimizer, d_optimizer, 
               feature_matching, memory_bank, prog_manager):
    batch_size = real_imgs.size(0)
    real_imgs = F.interpolate(real_imgs, size=prog_manager.get_size())
    d_optimizer.zero_grad()
    real_main, real_aux = discriminator(real_imgs)
    z = torch.randn(batch_size, 256, 1, 1, device=real_imgs.device)
    fake_imgs = generator(z)
    fake_main, fake_aux = discriminator(fake_imgs.detach())
    d_loss = (F.binary_cross_entropy(real_main, torch.ones_like(real_main)) +
              F.binary_cross_entropy(real_aux, torch.ones_like(real_aux)) +
              F.binary_cross_entropy(fake_main, torch.zeros_like(fake_main)) +
              F.binary_cross_entropy(fake_aux, torch.zeros_like(fake_aux)))
    d_loss.backward()
    d_optimizer.step()
    g_optimizer.zero_grad()
    fake_main, fake_aux = discriminator(fake_imgs)
    g_loss = (F.binary_cross_entropy(fake_main, torch.ones_like(fake_main)) +
              F.binary_cross_entropy(fake_aux, torch.ones_like(fake_aux)) +
              feature_matching(real_main, fake_main))
    g_loss.backward()
    g_optimizer.step()
    memory_bank.update(fake_imgs)
    prog_manager.step()
    return d_loss.item(), g_loss.item()

def train_fastgan(generator, discriminator, dataloader, num_epochs, progressive_steps=[16], device='cuda'):
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    feature_matching = FeatureMatchingLoss()
    memory_bank = MemoryBank()
    prog_manager = ProgressiveGrowingManager()
    
    for step in progressive_steps:
        for epoch in range(num_epochs):
            for i, (real_imgs, _) in enumerate(dataloader):
                real_imgs = real_imgs.to(device)
                d_loss, g_loss = train_step(real_imgs, generator, discriminator, g_optimizer, d_optimizer, 
                                           feature_matching, memory_bank, prog_manager)
                if i % 100 == 0:
                    print(f'Step: {step}, Epoch [{epoch}/{num_epochs}], '
                          f'D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}')

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    torch.manual_seed(42)
    np.random.seed(42)
    
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    csv_file = '/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_metadata.csv'
    img_dirs = ['/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_1', '/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_2']
    
    dataset = HAM10000Dataset(csv_file, img_dirs, transform=transform, device=device)
    
    num_classes = len(dataset.label_encoder.classes_)
    print("Unique Classes:", dataset.label_encoder.classes_)
    print("Number of Classes:", num_classes)
    
    batch_size = 64
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    
    latent_dim = 256
    generator = FASTGANGenerator(latent_dim).to(device)
    discriminator = FASTGANDiscriminator().to(device)
    
    num_epochs = 1
    progressive_steps = [16]
    
    train_fastgan(generator, discriminator, data_loader, num_epochs, progressive_steps, device)
    
    os.makedirs('synthetic_images', exist_ok=True)
    
    synthetic_images_by_class = {}
    
    with torch.no_grad():
        for class_idx in range(num_classes):
            z = torch.randn(100, latent_dim, 1, 1).to(device)
            synthetic_images = generator(z)
            
            synthetic_images_by_class[class_idx] = synthetic_images.cpu()
            
            class_name = dataset.label_encoder.inverse_transform([class_idx])[0]
            class_dir = os.path.join('synthetic_images', class_name)
            os.makedirs(class_dir, exist_ok=True)
            
            for i, img in enumerate(synthetic_images):
                save_path = os.path.join(class_dir, f'synthetic_image_{i}.png')
                save_image((img * 0.5 + 0.5), save_path)
    
    print("Synthetic image generation and filtering complete!")

if __name__ == "__main__":
    main()

Using device: cuda
Unique Classes: ['akiec' 'bcc' 'bkl' 'df' 'mel' 'nv' 'vasc']
Number of Classes: 7
Step: 16, Epoch [0/1], D_loss: 2.7883, G_loss: 3.7204
Step: 16, Epoch [0/1], D_loss: 1.9488, G_loss: 17.9825
Synthetic image generation and filtering complete!
