In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
from PIL import Image


# --- Hyperparameters & runtime configuration ---
latent_dim = 128          # length of the generator's input noise vector
img_size = 32             # training resolution; Generator/Discriminator layer stacks assume 32x32
channels = 3              # RGB images
lr = 0.0002               # Adam learning rate for both networks
b1, b2 = 0.5, 0.999       # Adam betas for both networks
batch_size = 64
num_epochs = 200
seed = 42                 # global RNG seed for reproducible runs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook touches so weight init, latent noise and DataLoader
# shuffling are reproducible across Restart & Run All (torch.manual_seed also seeds CUDA).
np.random.seed(seed)
torch.manual_seed(seed)

def list_images(basePath, contains=None):
    """Lazily yield image-file paths found recursively under ``basePath``.

    Thin wrapper over ``list_files`` restricted to common raster extensions.

    Args:
        basePath: root directory to walk.
        contains: optional substring the bare filename must contain.

    Returns:
        The generator produced by ``list_files``.
    """
    image_exts = (".jpg", ".jpeg", ".png", ".bmp")
    return list_files(basePath, validExts=image_exts, contains=contains)

def list_files(basePath, validExts=(".jpg", ".jpeg", ".png", ".bmp"), contains=None):
    """Recursively yield paths of files under ``basePath`` with a matching extension.

    Args:
        basePath: root directory handed to ``os.walk``.
        validExts: extension suffix(es) accepted by ``str.endswith``; compared
            against the lower-cased text from the last ``"."`` of the filename.
        contains: if not None, only filenames containing this substring are kept.

    Yields:
        ``os.path.join(root, filename)`` for each match.

    NOTE: every space in a yielded path is backslash-escaped (``"a b"`` ->
    ``"a\\ b"``). That is shell-style quoting; Python/OpenCV file APIs need the
    raw path, so callers must undo it before opening the file.
    """
    for root, _subdirs, names in os.walk(basePath):
        for name in names:
            if contains is not None and contains not in name:
                continue

            # For names without a dot, rfind returns -1 and this slice is just the
            # last character, which will never match a dotted extension.
            extension = name[name.rfind("."):].lower()
            if not extension.endswith(validExts):
                continue

            yield os.path.join(root, name).replace(" ", "\\ ")

def load_images(directory='', size=(64,64)):
    """Load every decodable image under ``directory`` into memory as RGB arrays.

    Args:
        directory: root folder searched recursively via ``list_images``.
        size: ``dsize`` passed to ``cv2.resize`` (OpenCV orders it as width, height).

    Returns:
        A list of numpy arrays in RGB channel order, one per image that
        ``cv2.imread`` could decode. Undecodable files are skipped.
    """
    images = []

    # Sort so dataset indices don't depend on os.walk's filesystem-specific order.
    for path in sorted(list_images(directory)):
        # Skip macOS archive metadata folders (e.g. "__MACOSX"); note this
        # matches any path that contains "OSX".
        if 'OSX' in path:
            continue

        # list_files escapes spaces shell-style ("a b" -> "a\ b"). Undo that
        # BEFORE normalising Windows separators; otherwise the escape turns
        # into a bogus "/ " path component and the file can't be opened.
        path = path.replace('\\ ', ' ').replace('\\', '/')

        image = cv2.imread(path)
        if image is None:  # unreadable / not an image OpenCV can decode
            continue

        image = cv2.resize(image, size)
        # OpenCV decodes as BGR; downstream PIL/torchvision expect RGB.
        images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    return images


class CustomImageDataset(Dataset):
    """In-memory dataset of square RGB images loaded from a directory tree.

    All images are decoded and resized to ``img_size`` x ``img_size`` once, at
    construction time, via ``load_images``. Each item is ``(image, 0)``: the
    label is a constant placeholder.

    Args:
        directory: root folder to search for images.
        img_size: side length the images are resized to on load.
        transform: optional callable applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, directory, img_size=32, transform=None):
        target_size = (img_size, img_size)
        self.images = load_images(directory, size=target_size)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Convert the stored numpy array to PIL so torchvision transforms accept it.
        pil_image = Image.fromarray(self.images[idx].astype('uint8'))
        if self.transform:
            pil_image = self.transform(pil_image)
        return pil_image, 0








In [None]:

class Generator(nn.Module):
    """Transposed-convolution generator producing ``channels`` x 32 x 32 images.

    Input ``z`` is fed straight into a ``ConvTranspose2d`` with ``latent_dim``
    input channels, so it must be shaped ``(batch, latent_dim, 1, 1)``. Spatial size
    grows 1 -> 4 -> 8 -> 16 -> 32; the final ``Tanh`` bounds outputs to [-1, 1].
    """

    def __init__(self, latent_dim, channels):
        super().__init__()
        widths = [latent_dim, 512, 256, 128]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # First stage lifts 1x1 -> 4x4; later stages double the resolution.
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        layers += [
            nn.ConvTranspose2d(widths[-1], channels, 4, 2, 1),
            nn.Tanh(),
        ]
        # Kept as a single Sequential named `model` so state_dict keys are unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        return self.model(z)

In [None]:

class Discriminator(nn.Module):
    """Convolutional discriminator returning a real/fake probability per image.

    Three stride-2 convolutions shrink a 32x32 input to 4x4, then a 4x4 valid
    convolution reduces it to 1x1. Input must therefore be 32x32. Larger inputs
    would leave a spatial map that ``view(-1, 1)`` flattens into extra rows.
    Output shape is ``(batch, 1)`` with values in (0, 1) from ``Sigmoid``.
    """

    def __init__(self, channels):
        super().__init__()

        def downsample(c_in, c_out, use_norm):
            # Halves spatial resolution; BatchNorm is omitted on the input stage.
            stage = [nn.Conv2d(c_in, c_out, 4, 2, 1)]
            if use_norm:
                stage.append(nn.BatchNorm2d(c_out))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        self.model = nn.Sequential(
            *downsample(channels, 128, use_norm=False),
            *downsample(128, 256, use_norm=True),
            *downsample(256, 512, use_norm=True),
            nn.Conv2d(512, 1, 4, 1, 0),
            nn.Sigmoid(),
        )

    def forward(self, img):
        scores = self.model(img)
        return scores.view(-1, 1)

In [None]:
# --- Loss & Optimizers ---
# The Discriminator ends in Sigmoid, so BCELoss on its probabilities is the matching loss.
adversarial_loss = nn.BCELoss()
generator = Generator(latent_dim, channels).to(device)
discriminator = Discriminator(channels).to(device)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# --- Data ---
data_path = "/kaggle/input/anime-faces/data"
kaggle_root = "/kaggle/input"

# List mounted Kaggle inputs so a wrong data_path is easy to spot. Guarded so the
# cell still runs outside Kaggle instead of dying with FileNotFoundError.
if os.path.isdir(kaggle_root):
    print("Available datasets in /kaggle/input/:")
    for item in sorted(os.listdir(kaggle_root)):
        print(f"- {item}")
        item_path = os.path.join(kaggle_root, item)
        if os.path.isdir(item_path):
            print(f"  Contents: {os.listdir(item_path)}")
    print()

# CustomImageDataset already resizes to img_size x img_size. Resize/CenterCrop keep
# this safe if that changes. Normalize maps [0, 1] -> [-1, 1] to match the Generator's Tanh.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

dataset = CustomImageDataset(data_path, img_size=img_size, transform=transform)
# Fail fast with an actionable message. An empty dataset would otherwise surface
# as an opaque sampler error from DataLoader(shuffle=True).
if len(dataset) == 0:
    raise FileNotFoundError(
        f"No readable images found under {data_path!r}; check data_path against the listing above."
    )
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

print(f"Dataset loaded successfully! Total images: {len(dataset)}")

In [None]:


log_every = 200      # batches between loss printouts
sample_every = 10    # epochs between sample grids
n_samples = 25       # images per sample grid (5x5)

# One fixed set of latent points, so successive sample grids show the same "faces"
# evolving. This makes training progress directly comparable across epochs.
fixed_noise = torch.randn(n_samples, latent_dim, 1, 1, device=device)

for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        n = real_imgs.size(0)  # the last batch can be smaller than batch_size
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # --- Train Discriminator: push real -> 1, fake -> 0 ---
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), real_labels)

        z = torch.randn(n, latent_dim, 1, 1, device=device)
        fake_imgs = generator(z)
        # detach(): the D update must not backprop into the Generator.
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # --- Train Generator: make D label fakes as real ---
        # This backward also deposits grads on D's params; optimizer_D.zero_grad()
        # at the start of the next iteration clears them before they are used.
        optimizer_G.zero_grad()
        g_loss = adversarial_loss(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

        if (i + 1) % log_every == 0:
            print(f"[Epoch {epoch+1}/{num_epochs}] [Batch {i+1}/{len(dataloader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    if (epoch + 1) % sample_every == 0:
        generator.eval()
        with torch.no_grad():
            samples = generator(fixed_noise).cpu()
        generator.train()

        # Tanh output [-1, 1] -> [0, 1] for display and saving.
        samples = 0.5 * samples + 0.5
        grid = torchvision.utils.make_grid(samples, nrow=5)

        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(grid.permute(1, 2, 0))
        ax.set_title(f"Epoch {epoch+1}")
        ax.axis("off")
        plt.show()
        plt.close(fig)  # avoid accumulating open figures over a long run

        # Samples are already in [0, 1]. normalize=True would min-max stretch them
        # a second time, so the saved PNG would not match the grid shown above.
        torchvision.utils.save_image(samples, f"epoch_{epoch+1}.png", nrow=5)

print("Training completed!")