In [None]:
import os
from modulefinder import Module

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Lambda, Compose

In [None]:
from PIL import Image

In [None]:
!cp -r /kaggle/input/chest-xray-pneumonia /kaggle/working/chest-xray-pneumonia

In [None]:
# Remove macOS archive residue from the writable dataset copy:
#   1. AppleDouble "._*" metadata entries (presumably so ImageFolder does not try
#      to load them as images — TODO confirm they occur in this dataset);
#   2. any "__MACOSX" directories, recursively.
# Note: in `find`, predicate/action order matters (`-delete` / `-exec` act on
# whatever matched the tests before them), so keep the argument order as is.
!find /kaggle/working/chest-xray-pneumonia/ -name "._*" -delete
!find /kaggle/working/chest-xray-pneumonia/ -name "__MACOSX" -type d -exec rm -rf {} +

In [None]:
# Preprocessing pipeline: single-channel, 256x256 tensors.
# ToTensor yields values in [0, 1]; Normalize(mean=0.5, std=0.5) maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# ImageFolder turns each class sub-directory of train/ into an integer label.
train_dataset = datasets.ImageFolder(
    root='/kaggle/working/chest-xray-pneumonia/chest_xray/train',
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

In [None]:
class GeneratorModel(nn.Module):
    """Conditional generator: (noise, class label) -> 1x256x256 image in [-1, 1].

    The noise vector and the class label are each projected to a 64x8x8 map.
    The two maps are concatenated into a 128x8x8 tensor, which is upsampled 2x
    five times (8 -> 256). A 3x3 convolution with Tanh then produces one channel.

    Args:
        z_dim: length of the input noise vector (stored as ``self.z_dim``).
        n_classes: number of distinct class labels accepted by ``forward``.
    """

    def __init__(self, z_dim=100, n_classes=2):
        super().__init__()
        self.z_dim = z_dim

        # Label branch: class id -> 50-d embedding -> 4096 (= 64 * 8 * 8) values.
        self.embedding = nn.Embedding(n_classes, 50)
        self.label_dense = nn.Linear(50, 4096)

        # Noise branch: z -> 4096 (= 64 * 8 * 8) values.
        self.z_dense = nn.Linear(z_dim, 4096)

        # Five identical ConvTranspose(4x4, stride 2) -> BatchNorm -> ReLU stages,
        # each doubling height and width. Built in a loop; the Sequential indices
        # (and therefore the state_dict keys) match a hand-written listing.
        upsample_layers = []
        for _ in range(5):
            upsample_layers += [
                nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(True),
            ]
        self.deconv_blocks = nn.Sequential(*upsample_layers)

        self.out_conv = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, z, labels):
        """Generate one image per (noise row, label) pair.

        Args:
            z: noise of shape (B, z_dim).
            labels: integer class ids of shape (B,); cast to long before lookup.

        Returns:
            Tensor of shape (B, 1, 256, 256) with values in [-1, 1].
        """
        label_map = self.label_dense(self.embedding(labels.long())).view(-1, 64, 8, 8)
        noise_map = self.z_dense(z).view(-1, 64, 8, 8)
        # Label channels come first, then noise channels -> (B, 128, 8, 8).
        stacked = torch.cat((label_map, noise_map), dim=1)
        upsampled = self.deconv_blocks(stacked)  # (B, 128, 256, 256)
        return self.tanh(self.out_conv(upsampled))


class DiscriminatorModel(nn.Module):
    """Conditional discriminator: (1x256x256 image, class label) -> score in (0, 1).

    The label is embedded and projected to a 1x256x256 plane. That plane is
    stacked with the image as a second input channel. Five stride-2 4x4
    convolutions shrink 256 -> 8. The flattened 64x8x8 map then goes through
    dropout, a single linear unit and a sigmoid. In the training loop, real
    pairs are trained toward 1 and generated pairs toward 0.

    Args:
        n_classes: number of distinct class labels accepted by ``forward``.
    """

    def __init__(self, n_classes=2):
        super().__init__()

        # Label branch: class id -> 50-d embedding -> one 256x256 plane.
        self.embedding = nn.Embedding(n_classes, 50)
        self.label_dense = nn.Linear(50, 256 * 256)

        # (in_channels, out_channels) per downsampling stage; every stage is
        # Conv(4x4, stride 2) -> BatchNorm -> LeakyReLU(0.2), halving H and W.
        stage_channels = [(2, 128), (128, 128), (128, 128), (128, 128), (128, 64)]
        down_layers = []
        for c_in, c_out in stage_channels:
            down_layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False))
            down_layers.append(nn.BatchNorm2d(c_out))
            down_layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.conv_blocks = nn.Sequential(*down_layers)

        self.flatten = nn.Flatten()
        self.drop = nn.Dropout(0.5)
        self.fc = nn.Linear(8 * 8 * 64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, image, labels, return_features=False):
        """Score (image, label) pairs.

        Args:
            image: tensor of shape (B, 1, 256, 256).
            labels: integer class ids of shape (B,); cast to long before lookup.
            return_features: if True, also return the flattened conv features.

        Returns:
            ``validity`` of shape (B, 1), or ``(validity, features)``. In the
            second form, ``features`` has shape (B, 4096) and is taken *after*
            dropout.
        """
        label_plane = self.label_dense(self.embedding(labels.long())).view(-1, 1, 256, 256)
        # Image channel first, label plane second -> (B, 2, 256, 256).
        stacked = torch.cat((image, label_plane), dim=1)

        flat_features = self.drop(self.flatten(self.conv_blocks(stacked)))
        validity = self.sigmoid(self.fc(flat_features))

        if not return_features:
            return validity
        return validity, flat_features

In [None]:
# Training configuration and model / optimizer construction.
z_dim = 100           # length of the generator's noise vector
label_dim = 2         # number of class labels; must match the ImageFolder classes — TODO confirm
image_size = 256      # NOTE(review): not read anywhere; 256x256 is hard-coded in the transform and both models
batch_size = 32       # NOTE(review): not read by the DataLoader, which hard-codes batch_size=32
num_epochs = 100
lr = 0.0001
lambda_gp = 10        # NOTE(review): unused — the training loop has no gradient-penalty term
seed = 42

# Seed torch's global RNG (all devices) so weight init, DataLoader shuffling,
# noise and dropout draws repeat across runs (GPU kernels may still be nondeterministic).
torch.manual_seed(seed)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The discriminator ends in a Sigmoid (it outputs probabilities), hence BCELoss.
criterion = nn.BCELoss()
generator = GeneratorModel(z_dim, label_dim).to(device)
# Pass label_dim explicitly rather than relying on DiscriminatorModel's default
# n_classes=2, so both networks always agree on the number of classes.
discriminator = DiscriminatorModel(n_classes=label_dim).to(device)
# Adam with beta1 = 0.5 for both networks.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

In [None]:
def show_generated_samples(generator, z, labels, epoch):
    """Display one grid of generator samples for fixed noise and labels.

    Args:
        generator: conditional generator called as ``generator(z, labels)``.
        z: noise batch, one row per sample.
        labels: class label per sample.
        epoch: epoch number shown in the figure title.

    The generator is switched to eval mode for this forward pass. BatchNorm
    then uses its running statistics instead of updating them from this tiny
    batch. The generator's previous train/eval mode is restored afterwards.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            samples = generator(z, labels).cpu()
    finally:
        generator.train(was_training)

    # normalize=True rescales the grid to [0, 1] using its min/max, for display.
    grid = torchvision.utils.make_grid(samples, nrow=len(samples), normalize=True, padding=2)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.set_title(f"Epoch {epoch}")
    ax.axis("off")
    ax.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    plt.show()
    plt.close(fig)


# Fixed noise/labels, drawn once before training, so previews are comparable across epochs.
z_fixed = torch.randn(4, z_dim, device=device)
labels_fixed = torch.tensor([0, 1, 0, 1], device=device)

# Conditional GAN training:
# - Discriminator: BCE on real (target 1) vs. generated (target 0) image/label pairs.
# - Generator: feature matching only — match the batch-mean of the discriminator's
#   flattened features on real and generated images.
for epoch in range(num_epochs):
    for i, (real_images, labels) in enumerate(train_loader):
        real_images = real_images.to(device)
        labels = labels.to(device)
        # Local name so the config-cell `batch_size` is not overwritten by a short last batch.
        cur_batch_size = real_images.size(0)

        # ---- Discriminator step ----
        z = torch.randn(cur_batch_size, z_dim, device=device)
        fake_images = generator(z, labels).detach()  # no generator gradients during the D step

        real_validity = discriminator(real_images, labels)
        fake_validity = discriminator(fake_images, labels)
        d_loss_real = criterion(real_validity, torch.ones_like(real_validity))
        d_loss_fake = criterion(fake_validity, torch.zeros_like(fake_validity))
        d_loss = (d_loss_real + d_loss_fake) / 2

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step (feature matching) ----
        # Recompute the real features with the just-updated discriminator, so both
        # sides of the loss come from the same feature extractor. They are only
        # targets, so no gradient is needed.
        with torch.no_grad():
            _, real_features = discriminator(real_images, labels, return_features=True)

        z = torch.randn(cur_batch_size, z_dim, device=device)
        fake_images = generator(z, labels)
        _, fake_features = discriminator(fake_images, labels, return_features=True)

        g_loss = torch.mean((real_features.mean(0) - fake_features.mean(0)) ** 2)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        if i % 50 == 0:
            print(f"[{epoch}/{num_epochs}] Step {i} | D_loss: {d_loss.item():.4f} | G_loss: {g_loss.item():.4f}")

    if epoch % 10 == 0:
        show_generated_samples(generator, z_fixed, labels_fixed, epoch)

In [None]:
from torchvision.utils import save_image
import zipfile
save_dir = "/kaggle/working/generated_images"
os.makedirs(save_dir, exist_ok=True)

# Generation parameters
num_per_label = 50  # images generated for each class label
# Read the noise size and class count from the trained generator instead of
# re-hard-coding them, so they cannot drift from the values it was built with.
z_dim = generator.z_dim
num_classes = generator.embedding.num_embeddings

def generate_images(generator, z_dim, num_per_label, num_classes, save_dir):
    """Sample ``num_per_label`` images for each class and save them as PNG files.

    Files are named ``label{label}_img{idx:04d}.png``; ``idx`` runs continuously
    across labels. Generator outputs are assumed to lie in [-1, 1] (Tanh) and
    are mapped to [0, 1] before saving.

    The generator is put in eval mode for sampling. Its previous train/eval
    mode is restored afterwards, even if saving fails.

    Args:
        generator: conditional generator called as ``generator(z, labels)``.
        z_dim: length of each noise vector.
        num_per_label: images per class; values <= 0 generate nothing.
        num_classes: labels ``0 .. num_classes - 1`` are generated.
        save_dir: existing directory that receives the PNG files.

    Returns:
        List of the file paths written, in generation order.
    """
    if num_per_label <= 0:
        # Nothing to generate. Also avoids feeding a zero-size batch to the
        # generator (e.g. `.view(-1, 64, 8, 8)` cannot infer -1 for zero elements).
        return []

    device = next(generator.parameters()).device
    was_training = generator.training
    generator.eval()  # BatchNorm uses its running statistics while sampling
    saved_paths = []
    idx = 0  # global image counter, continues across labels
    try:
        with torch.no_grad():
            for label in range(num_classes):
                labels = torch.full((num_per_label,), label, dtype=torch.long, device=device)
                z = torch.randn(num_per_label, z_dim, device=device)
                fake_images = generator(z, labels)

                for img in fake_images:
                    save_path = os.path.join(save_dir, f"label{label}_img{idx:04d}.png")
                    # Map the assumed Tanh range [-1, 1] back to [0, 1] for saving.
                    save_image(img * 0.5 + 0.5, save_path)
                    saved_paths.append(save_path)
                    idx += 1
    finally:
        # Restore whatever mode the caller had the generator in.
        generator.train(was_training)
    return saved_paths

generate_images(generator, z_dim, num_per_label, num_classes, save_dir)

# Bundle every file found under save_dir into a zip archive. Entries are stored
# under their bare file name, so the archive is flat.
zip_path = "/kaggle/working/generated_images.zip"
with zipfile.ZipFile(zip_path, 'w') as archive:
    entries = (
        (os.path.join(dirpath, name), name)
        for dirpath, _subdirs, names in os.walk(save_dir)
        for name in names
    )
    for full_path, arcname in entries:
        archive.write(full_path, arcname)

# Prints (in Vietnamese) "Saved <N> images in <zip_path>".
print(f"Đã lưu {num_classes*num_per_label} ảnh trong {zip_path}")