In [2]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """MLP generator that maps latent noise vectors to images with values in [-1, 1].

    Args:
        latent_dim: size of each input noise vector (default 100).
        img_shape: ``(channels, height, width)`` of the generated images
            (default ``(1, 28, 28)``).

    With the default arguments the layer layout, and therefore the ``state_dict``
    keys (``model.0.0.weight`` ... ``model.4.bias``), is the same as the original
    hard-coded version, so previously saved checkpoints still load.
    """

    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)
        out_features = torch.Size(self.img_shape).numel()

        self.model = nn.Sequential(
            self.block(latent_dim, 128),
            self.block(128, 256),
            self.block(256, 512),
            self.block(512, 1024),
            nn.Linear(1024, out_features),
            # Tanh bounds outputs to [-1, 1], the same range SkinLesionDataset
            # normalises real images to.
            nn.Tanh(),
        )

    def block(self, input_dim, output_dim):
        """Return one hidden stage: ``Linear(input_dim, output_dim)`` + ``LeakyReLU(0.2)``."""
        return nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, z):
        """Map ``z`` of shape ``(N, latent_dim)`` to images of shape ``(N, *img_shape)``."""
        x = self.model(z)
        return x.view(x.size(0), *self.img_shape)

In [3]:
class Discriminator(nn.Module):
    """MLP discriminator that scores each input image with a probability of being real.

    Args:
        img_shape: ``(channels, height, width)`` of the images to score
            (default ``(1, 28, 28)``, i.e. 784 input features).

    With the default argument the layer layout, and therefore the ``state_dict``
    keys, is the same as the original hard-coded version.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super(Discriminator, self).__init__()
        self.img_shape = tuple(img_shape)
        self.in_features = torch.Size(self.img_shape).numel()

        self.model = nn.Sequential(
            nn.Linear(self.in_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probabilities in (0, 1), as nn.BCELoss expects
        )

    def forward(self, x):
        """Return an ``(N, 1)`` tensor of real-probabilities for a batch ``x``.

        ``x`` may be ``(N, *img_shape)`` or already flattened to ``(N, in_features)``.

        Raises:
            ValueError: if a sample does not contain exactly ``in_features`` values
                (e.g. full-resolution RGB photos passed to the default 1x28x28 model).
        """
        # flatten (unlike view) also accepts non-contiguous inputs.
        flat = torch.flatten(x, start_dim=1)
        if flat.size(1) != self.in_features:
            raise ValueError(
                f"Discriminator expects {self.in_features} values per sample "
                f"(img_shape={self.img_shape}), got {flat.size(1)}; resize the "
                f"input or construct the model with a matching img_shape."
            )
        return self.model(flat)

In [5]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import pandas as pd

class SkinLesionDataset(Dataset):
    """Image dataset listing files from the ``image_path`` column of a CSV file.

    Each item is a float32 tensor of shape ``(C, H, W)`` with values scaled from
    ``[0, 255]`` to ``[-1, 1]``. Colour images are returned in RGB order.

    Args:
        csv_file: path (or buffer) accepted by ``pd.read_csv``; the table must
            contain an ``image_path`` column.
        image_size: optional output size, either an int (square) or a
            ``(height, width)`` tuple. ``None`` (default) keeps each image's
            native resolution.
        grayscale: if True, return single-channel images (``C == 1``);
            otherwise return 3-channel RGB images (default).
    """

    def __init__(self, csv_file, image_size=None, grayscale=False):
        self.df = pd.read_csv(csv_file)
        self.image_paths = self.df['image_path'].tolist()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        self.grayscale = grayscale

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None (no exception) for missing or undecodable files.
            raise OSError(f"Failed to read image: {image_path}")

        # OpenCV decodes colour images as BGR; torchvision/matplotlib expect RGB.
        if self.grayscale:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.image_size is not None:
            height, width = self.image_size
            # cv2.resize takes dsize as (width, height).
            image = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)

        # Grayscale images are 2-D; add the channel axis only after resizing
        # because cv2.resize drops a trailing singleton channel.
        if image.ndim == 2:
            image = image[:, :, np.newaxis]

        image = image.astype(np.float32)
        image = (image - 127.5) / 127.5
        return torch.from_numpy(image).permute(2, 0, 1)

# Build the training dataset from the HAM10000 metadata CSV and batch it.
# NOTE(review): SkinLesionDataset yields bare image tensors (no labels) at each
# file's native resolution.
dataset = SkinLesionDataset(csv_file="HAM10000_metadata_paths.csv")
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

In [None]:
import torch
from torchvision.utils import save_image
import os

# Train the GAN: per batch, one generator update followed by one discriminator update.
SEED = 0
LATENT_DIM = 100          # must match the Generator's input size
SAMPLE_INTERVAL = 400     # save a sample grid every this many batches
SAMPLE_DIR = "output"
CHECKPOINT_DIR = "saved_models"
n_epochs = 20

torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# save_image and torch.save do not create missing directories.
os.makedirs(SAMPLE_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

generator = Generator().to(device)
discriminator = Discriminator().to(device)

criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)


def _match_generator_shape(imgs, target_shape):
    """Convert an ``(N, C, H, W)`` batch of real images to ``target_shape`` = (C, H, W).

    The discriminator can only score tensors shaped like the generator output,
    while the dataset yields images at their native size and channel count.
    """
    channels, height, width = target_shape
    if imgs.size(1) != channels and channels == 1:
        imgs = imgs.mean(dim=1, keepdim=True)  # colour -> grey, independent of channel order
    if tuple(imgs.shape[-2:]) != (height, width):
        # "area" averages source pixels, which suits shrinking large photos.
        imgs = nn.functional.interpolate(imgs, size=(height, width), mode="area")
    if imgs.size(1) != channels:
        if imgs.size(1) != 1:
            raise ValueError(
                f"Cannot map {imgs.size(1)}-channel images to {channels} channels"
            )
        imgs = imgs.expand(-1, channels, -1, -1)
    return imgs


for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):
        # SkinLesionDataset yields bare image tensors, not (image, label) pairs.
        imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
        batch_size = imgs.size(0)
        real = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        # Generator step: push D(G(z)) towards "real".
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, LATENT_DIM, device=device)
        generated_imgs = generator(z)
        g_loss = criterion(discriminator(generated_imgs), real)
        g_loss.backward()
        optimizer_G.step()

        real_imgs = _match_generator_shape(imgs.to(device), generated_imgs.shape[1:])

        # Discriminator step. optimizer_D.zero_grad() also discards the gradients
        # that g_loss.backward() accumulated in the discriminator.
        optimizer_D.zero_grad()
        real_loss = criterion(discriminator(real_imgs), real)
        fake_loss = criterion(discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        if i % SAMPLE_INTERVAL == 0:
            save_image(
                generated_imgs.detach()[:25],
                os.path.join(SAMPLE_DIR, f"{epoch}_{i}.png"),
                nrow=5,
                normalize=True,
            )

    torch.save(generator.state_dict(), os.path.join(CHECKPOINT_DIR, f"generator_{epoch}.pth"))
    torch.save(discriminator.state_dict(), os.path.join(CHECKPOINT_DIR, f"discriminator_{epoch}.pth"))

In [None]:
# Reload the last checkpoint written by the training loop and sample 25 images.
# Training saves saved_models/{generator,discriminator}_{epoch}.pth for
# epoch in range(n_epochs), so the final checkpoint index is n_epochs - 1.
checkpoint_epoch = n_epochs - 1
device = next(generator.parameters()).device

generator.load_state_dict(
    torch.load(f"saved_models/generator_{checkpoint_epoch}.pth", map_location=device)
)
discriminator.load_state_dict(
    torch.load(f"saved_models/discriminator_{checkpoint_epoch}.pth", map_location=device)
)

generator.eval()
discriminator.eval()

with torch.no_grad():
    test_z = torch.randn(25, 100, device=device)  # 100 = Generator latent size
    generated_images = generator(test_z)

import matplotlib.pyplot as plt

def show_images(images, cols=5, figsize=(15, 15)):
    """Display a batch of ``(C, H, W)`` image tensors in a grid.

    Args:
        images: sequence (or batched tensor) of ``(C, H, W)`` images with values
            in [-1, 1], e.g. Tanh generator output. Single-channel images are
            drawn with a grey colormap.
        cols: number of grid columns.
        figsize: matplotlib figure size in inches.
    """
    # plt.subplot needs integer grid dimensions; ceil-divide so a full last row
    # does not add an empty extra row.
    rows = max(1, (len(images) + cols - 1) // cols)
    plt.figure(figsize=figsize)
    for i, image in enumerate(images):
        plt.subplot(rows, cols, i + 1)
        img = image.detach().cpu().permute(1, 2, 0).numpy()
        img = ((img + 1) / 2).clip(0, 1)  # rescale [-1, 1] -> [0, 1]
        if img.shape[-1] == 1:
            plt.imshow(img[..., 0], cmap="gray", vmin=0, vmax=1)
        else:
            plt.imshow(img)
        plt.axis('off')
    plt.show()

# Plot the sampled images as a 5-column grid.
show_images(generated_images[:25], cols=5)