In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.cuda.amp import GradScaler, autocast
import numpy as np
from mtcnn import MTCNN
import cv2
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
import torch

# Check if CUDA is available
print("CUDA available:", torch.cuda.is_available())

# Get GPU count
print("Number of GPUs:", torch.cuda.device_count())

# Get GPU name
if torch.cuda.is_available():
    print("GPU name:", torch.cuda.get_device_name(0))
    print("Current CUDA version:", torch.version.cuda)
else:
    print("No GPU detected. Make sure your drivers and CUDA toolkit are properly installed.")

CUDA available: True
Number of GPUs: 1
GPU name: NVIDIA GeForce RTX 3050 Laptop GPU
Current CUDA version: 11.8


In [3]:
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
# Hyperparameters
nz = 100  # Size of z latent vector (input to generator)
ngf = 16  # Size of feature maps in generator (reduced for 4GB VRAM)
ndf = 16  # Size of feature maps in discriminator (reduced for 4GB VRAM)
num_epochs = 30  # Number of training epochs (adjust based on time)
batch_size = 16  # Small batch size for RTX 3050
accumulation_steps = 4  # Effective batch size = 16 * 4 = 64
lr = 0.0002  # Learning rate
beta1 = 0.5  # Beta1 for Adam optimizer
image_size = 64  # Image size (64x64 for memory efficiency)
nc = 3  # Number of color channels (RGB)

In [5]:
data_root = "./data"
celeba_dir = os.path.join(data_root, "celeba/img_align_celeba/img_align_celeba")
celeba_cropped_dir = os.path.join(data_root, "celeba/celeba_cropped")
anime_dir = os.path.join(data_root, "anime/images/images")
output_dir = "./outputs"
checkpoint_dir = "./checkpoints"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(celeba_cropped_dir, exist_ok=True)

In [6]:
# Preprocess CelebA images (crop faces using MTCNN)
def preprocess_celeba_images(input_dir, output_dir, size=(64, 64)):
    detector = MTCNN()
    os.makedirs(output_dir, exist_ok=True)
    for filename in os.listdir(input_dir):
        if filename.endswith(".jpg"):
            img_path = os.path.join(input_dir, filename)
            img = cv2.imread(img_path)
            if img is None:
                continue
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            faces = detector.detect_faces(img_rgb)
            if faces:
                x, y, w, h = faces[0]['box']
                face = img_rgb[max(0, y):y+h, max(0, x):x+w]
                face = cv2.resize(face, size, interpolation=cv2.INTER_AREA)
                face = Image.fromarray(face)
                face.save(os.path.join(output_dir, filename))
            else:
                print(f"No face detected in {filename}")

# Preprocess CelebA (run once, comment out after processing)
# preprocess_celeba_images(celeba_dir, celeba_cropped_dir)

# Data loading
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

celeba_dataset = datasets.ImageFolder(os.path.join(data_root, "celeba/img_align_celeba"), transform=transform)
anime_dataset = datasets.ImageFolder(os.path.join(data_root, "anime/images"), transform=transform)
celeba_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=True, num_workers=4,pin_memory=True,drop_last=True)
anime_loader = DataLoader(anime_dataset, batch_size=batch_size, shuffle=True, num_workers=4,pin_memory=True, drop_last=True)


In [7]:
# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Downsample condition from [sub_batch_size, 3, 64, 64] to [sub_batch_size, 3, 1, 1]
        self.downsample = nn.Sequential(
            nn.Conv2d(nc, nc, 4, 2, 1, bias=False),  # [sub_batch_size, 3, 64, 64] → [sub_batch_size, 3, 32, 32]
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nc, nc, 4, 2, 1, bias=False),  # [sub_batch_size, 3, 32, 32] → [sub_batch_size, 3, 16, 16]
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nc, nc, 4, 2, 1, bias=False),  # [sub_batch_size, 3, 16, 16] → [sub_batch_size, 3, 8, 8]
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nc, nc, 4, 2, 1, bias=False),  # [sub_batch_size, 3, 8, 8] → [sub_batch_size, 3, 4, 4]
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nc, nc, 4, 4, 0, bias=False),  # [sub_batch_size, 3, 4, 4] → [sub_batch_size, 3, 1, 1]
        )
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz + nc, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input, condition):
        # Downsample condition to [sub_batch_size, 3, 1, 1]
        condition = self.downsample(condition)  # [sub_batch_size, 3, 64, 64] → [sub_batch_size, 3, 1, 1]
        # Concatenate noise and condition
        input = torch.cat([input, condition], dim=1)  # [sub_batch_size, nz+nc, 1, 1]
        output = self.main(input)
        return output
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),        # Input: [4, 3, 64, 64] → [4, 32, 32, 32]
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),   # [4, 32, 32, 32] → [4, 64, 16, 16]
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),  # [4, 64, 16, 16] → [4, 128, 8, 8]
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),  # [4, 128, 8, 8] → [4, 256, 4, 4]
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),        # [4, 256, 4, 4] → [4, 1, 1, 1]
        )

    def forward(self, input):
        return self.main(input).view(-1)  # [sub_batch_size, 1, 1, 1] → [sub_batch_size]

# Initialize models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss, optimizers, and mixed precision scaler
criterion = nn.BCEWithLogitsLoss()
optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
scaler = torch.amp.GradScaler()

# Labels
real_label = 1.0
fake_label = 0.0

In [8]:
print("Testing data loaders...")
for batch in anime_loader:
    anime_imgs, _ = batch
    print(f"Anime batch shape: {anime_imgs.shape}")
    break
for batch in celeba_loader:
    celeba_imgs, _ = batch
    print(f"CelebA batch shape: {celeba_imgs.shape}")
    break

Testing data loaders...
Anime batch shape: torch.Size([16, 3, 64, 64])
CelebA batch shape: torch.Size([16, 3, 64, 64])


In [9]:
print(f"Anime dataset size: {len(anime_dataset)}")
print(f"CelebA dataset size: {len(celeba_dataset)}")

Anime dataset size: 63565
CelebA dataset size: 202599


In [10]:
import time

In [11]:
from tqdm import tqdm

def train():
    for epoch in range(num_epochs):
        start_epoch = time.time()
        epoch_iterator = tqdm(enumerate(celeba_loader), total=len(celeba_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

        for i, (celeba_imgs, _) in epoch_iterator:
            start_batch = time.time()
            # Data loading time
            anime_iter = iter(anime_loader)
            try:
                anime_imgs, _ = next(anime_iter)
            except StopIteration:
                anime_iter = iter(anime_loader)
                anime_imgs, _ = next(anime_iter)
            data_load_time = time.time()
            
            celeba_imgs = celeba_imgs.to(device)
            anime_imgs = anime_imgs.to(device)
            batch_size = celeba_imgs.size(0)
            
            if batch_size < accumulation_steps:
                epoch_iterator.set_postfix({"status": f"Skipping batch {i} with size {batch_size}"})
                continue
            
            # Discriminator training
            start_d_train = time.time()
            discriminator.zero_grad(set_to_none=True)
            d_loss_total = 0
            for _ in range(accumulation_steps):
                sub_batch_size = batch_size // accumulation_steps
                if sub_batch_size == 0:
                    break
                sub_celeba_imgs = celeba_imgs[:sub_batch_size]
                sub_anime_imgs = anime_imgs[:sub_batch_size]
                
                with torch.amp.autocast('cuda'):
                    real_output = discriminator(sub_anime_imgs)
                    real_loss = criterion(real_output, torch.full((sub_batch_size,), real_label, device=device, dtype=torch.float))
                    noise = torch.randn(sub_batch_size, nz, 1, 1, device=device)
                    fake_imgs = generator(noise, sub_celeba_imgs)
                    fake_output = discriminator(fake_imgs.detach())
                    fake_loss = criterion(fake_output, torch.full((sub_batch_size,), fake_label, device=device, dtype=torch.float))
                    d_loss = (real_loss + fake_loss) / accumulation_steps
                scaler.scale(d_loss).backward()
                d_loss_total += d_loss.item()
            
            scaler.step(optimizerD)
            scaler.update()
            end_d_train = time.time()
            
            # Generator training
            start_g_train = time.time()
            for _ in range(2):
                generator.zero_grad(set_to_none=True)
                g_loss_total = 0
                for _ in range(accumulation_steps):
                    sub_batch_size = batch_size // accumulation_steps
                    if sub_batch_size == 0:
                        break
                    sub_celeba_imgs = celeba_imgs[:sub_batch_size]
                    
                    with torch.amp.autocast('cuda'):
                        noise = torch.randn(sub_batch_size, nz, 1, 1, device=device)
                        fake_imgs = generator(noise, sub_celeba_imgs)
                        fake_output = discriminator(fake_imgs)
                        g_loss = criterion(fake_output, torch.full((sub_batch_size,), real_label, device=device, dtype=torch.float)) / accumulation_steps
                    scaler.scale(g_loss).backward()
                    g_loss_total += g_loss.item()
                
                scaler.step(optimizerG)
                scaler.update()
            end_g_train = time.time()
            
            end_batch = time.time()
            epoch_iterator.set_postfix({
                "D Loss": f"{d_loss_total:.4f}", 
                "G Loss": f"{g_loss_total:.4f}", 
                "Batch Time": f"{end_batch - start_batch:.2f} sec"
            })
            
            torch.cuda.empty_cache()
        
        end_epoch = time.time()
        print(f"Epoch {epoch+1} took {(end_epoch - start_epoch)/60:.2f} minutes")
        
        if (epoch + 1) % 10 == 0:
            torch.save(generator.state_dict(), f"{checkpoint_dir}/generator_epoch_{epoch+1}.pth")
            torch.save(discriminator.state_dict(), f"{checkpoint_dir}/discriminator_epoch_{epoch+1}.pth")
            print(f"Saved checkpoint for epoch {epoch+1}")

In [12]:
def generate_anime_image(human_img_path, output_path, model_path=None):
    if model_path:
        generator.load_state_dict(torch.load(model_path))
    generator.eval()
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # Load and preprocess human image
    img = Image.open(human_img_path).convert('RGB')
    img = transform(img).unsqueeze(0).to(device)
    
    # Generate anime image
    noise = torch.randn(1, nz, 1, 1, device=device)
    with torch.no_grad():
        with torch.amp.autocast('cuda'):  # Updated API
            anime_img = generator(noise, img)
    
    # Save output
    save_image(anime_img, output_path, normalize=True)
    
    # Display result
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title("Human Image")
    plt.imshow(Image.open(human_img_path))
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.title("Anime Image")
    plt.imshow(Image.open(output_path))
    plt.axis('off')
    plt.show()

In [13]:
# Main execution
if __name__ == "__main__":
    # Uncomment to preprocess CelebA (run once)
    # preprocess_celeba_images(celeba_dir, celeba_cropped_dir)
    
    # Train the model
    train()
    
    # Example inference
    human_img_path = "./data/celeba/img_align_celeba/img_align_celeba/000001.jpg"
    output_path = "./outputs/test_anime.png"
    model_path = f"{checkpoint_dir}/generator_epoch_{num_epochs}.pth" if os.path.exists(f"{checkpoint_dir}/generator_epoch_{num_epochs}.pth") else None
    generate_anime_image(human_img_path, output_path, model_path)

Epoch 1/30:   0%|          | 6/12662 [01:54<67:03:55, 19.08s/it, D Loss=1.1755, G Loss=0.7678, Batch Time=16.01 sec]


KeyboardInterrupt: 

In [45]:
# Test data loading
for batch in anime_loader:
    anime_imgs, _ = batch
    print(f"Anime batch shape: {anime_imgs.shape}")
    break

Anime batch shape: torch.Size([16, 3, 64, 64])
