In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.cuda.amp import GradScaler, autocast
import numpy as np
from mtcnn import MTCNN
import cv2
from PIL import Image
import matplotlib.pyplot as plt

In [2]:
import torch

# Check if CUDA is available
print("CUDA available:", torch.cuda.is_available())

# Get GPU count
print("Number of GPUs:", torch.cuda.device_count())

# Get GPU name
if torch.cuda.is_available():
    print("GPU name:", torch.cuda.get_device_name(0))
    print("Current CUDA version:", torch.version.cuda)
else:
    print("No GPU detected. Make sure your drivers and CUDA toolkit are properly installed.")

CUDA available: True
Number of GPUs: 1
GPU name: NVIDIA GeForce RTX 3050 Laptop GPU
Current CUDA version: 11.8


In [4]:
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
# Hyperparameters
nz = 100  # Size of z latent vector (input to generator)
ngf = 32  # Size of feature maps in generator (reduced for 4GB VRAM)
ndf = 32  # Size of feature maps in discriminator (reduced for 4GB VRAM)
num_epochs = 50  # Number of training epochs (adjust based on time)
batch_size = 16  # Small batch size for RTX 3050
accumulation_steps = 4  # Effective batch size = 16 * 4 = 64
lr = 0.0002  # Learning rate
beta1 = 0.5  # Beta1 for Adam optimizer
image_size = 64  # Image size (64x64 for memory efficiency)
nc = 3  # Number of color channels (RGB)

In [37]:
data_root = "./data"
celeba_dir = os.path.join(data_root, "celeba/img_align_celeba/img_align_celeba")
celeba_cropped_dir = os.path.join(data_root, "celeba/celeba_cropped")
anime_dir = os.path.join(data_root, "anime/images/images")
output_dir = "./outputs"
checkpoint_dir = "./checkpoints"
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(celeba_cropped_dir, exist_ok=True)

In [38]:
# Preprocess CelebA images (crop faces using MTCNN)
def preprocess_celeba_images(input_dir, output_dir, size=(64, 64)):
    detector = MTCNN()
    os.makedirs(output_dir, exist_ok=True)
    for filename in os.listdir(input_dir):
        if filename.endswith(".jpg"):
            img_path = os.path.join(input_dir, filename)
            img = cv2.imread(img_path)
            if img is None:
                continue
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            faces = detector.detect_faces(img_rgb)
            if faces:
                x, y, w, h = faces[0]['box']
                face = img_rgb[max(0, y):y+h, max(0, x):x+w]
                face = cv2.resize(face, size, interpolation=cv2.INTER_AREA)
                face = Image.fromarray(face)
                face.save(os.path.join(output_dir, filename))
            else:
                print(f"No face detected in {filename}")

# Preprocess CelebA (run once, comment out after processing)
# preprocess_celeba_images(celeba_dir, celeba_cropped_dir)

# Data loading
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

celeba_dataset = datasets.ImageFolder(os.path.join(data_root, "celeba/img_align_celeba"), transform=transform)
anime_dataset = datasets.ImageFolder(os.path.join(data_root, "anime/images"), transform=transform)
celeba_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
anime_loader = DataLoader(anime_dataset, batch_size=batch_size, shuffle=True, num_workers=1)


In [28]:
# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d(nz + nc, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input, condition):
        input = torch.cat([input, condition], dim=1)
        return self.main(input)

# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1)

# Initialize models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss, optimizers, and mixed precision scaler
criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerD = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
scaler = torch.amp.GradScaler(device="cuda")

# Labels
real_label = 1.0
fake_label = 0.0


In [42]:
print(len(os.listdir(anime_dataset)))

TypeError: listdir: path should be string, bytes, os.PathLike or None, not ImageFolder

In [46]:
def train():
    for epoch in range(num_epochs):
        for i, (celeba_imgs, _) in enumerate(celeba_loader):
            anime_iter = iter(anime_loader)
            try:
                anime_imgs, _ = next(anime_iter)
            except StopIteration:
                anime_iter = iter(anime_loader)
                anime_imgs, _ = next(anime_iter)
            
            # Ensure shapes are correct
            if anime_imgs.size(1) != 3:
                raise ValueError(f"Anime images have incorrect channels: {anime_imgs.shape}")
            if celeba_imgs.size(1) != 3:
                raise ValueError(f"CelebA images have incorrect channels: {celeba_imgs.shape}")
            
            celeba_imgs = celeba_imgs.to(device)
            anime_imgs = anime_imgs.to(device)
            batch_size = celeba_imgs.size(0)
            
            # Train Discriminator
            discriminator.zero_grad(set_to_none=True)
            d_loss_total = 0
            for _ in range(accumulation_steps):
                sub_batch_size = batch_size // accumulation_steps
                sub_celeba_imgs = celeba_imgs[:sub_batch_size]
                sub_anime_imgs = anime_imgs[:sub_batch_size]
                
                with autocast():
                    real_output = discriminator(sub_anime_imgs)
                    real_loss = criterion(real_output, torch.full((sub_batch_size,), real_label, device=device, dtype=torch.float))
                    noise = torch.randn(sub_batch_size, nz, 1, 1, device=device)
                    fake_imgs = generator(noise, sub_celeba_imgs)
                    fake_output = discriminator(fake_imgs.detach())
                    fake_loss = criterion(fake_output, torch.full((sub_batch_size,), fake_label, device=device, dtype=torch.float))
                    d_loss = (real_loss + fake_loss) / accumulation_steps
                scaler.scale(d_loss).backward()
                d_loss_total += d_loss.item()
            
            scaler.step(optimizerD)
            scaler.update()
            
            # Train Generator
            generator.zero_grad(set_to_none=True)
            g_loss_total = 0
            for _ in range(accumulation_steps):
                sub_batch_size = batch_size // accumulation_steps
                sub_celeba_imgs = celeba_imgs[:sub_batch_size]
                
                with autocast():
                    noise = torch.randn(sub_batch_size, nz, 1, 1, device=device)
                    fake_imgs = generator(noise, sub_celeba_imgs)
                    fake_output = discriminator(fake_imgs)
                    g_loss = criterion(fake_output, torch.full((sub_batch_size,), real_label, device=device, dtype=torch.float)) / accumulation_steps
                scaler.scale(g_loss).backward()
                g_loss_total += g_loss.item()
            
            scaler.step(optimizerG)
            scaler.update()
            
            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i}/{len(celeba_loader)}] "
                      f"D Loss: {d_loss_total:.4f} G Loss: {g_loss_total:.4f}")
                
                # Save generated images
                with torch.no_grad():
                    noise = torch.randn(batch_size, nz, 1, 1, device=device)
                    fake_imgs = generator(noise, celeba_imgs)
                save_image(fake_imgs, f"{output_dir}/epoch_{epoch+1}_batch_{i}.png", normalize=True)
            
            # Clear GPU memory
            torch.cuda.empty_cache()
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save(generator.state_dict(), f"{checkpoint_dir}/generator_epoch_{epoch+1}.pth")
            torch.save(discriminator.state_dict(), f"{checkpoint_dir}/discriminator_epoch_{epoch+1}.pth")
            print(f"Saved checkpoint for epoch {epoch+1}")

In [31]:
# Inference function
def generate_anime_image(human_img_path, output_path, model_path=None):
    if model_path:
        generator.load_state_dict(torch.load(model_path))
    generator.eval()
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    # Load and preprocess human image
    img = Image.open(human_img_path).convert('RGB')
    img = transform(img).unsqueeze(0).to(device)
    
    # Generate anime image
    noise = torch.randn(1, nz, 1, 1, device=device)
    with torch.no_grad():
        with autocast():
            anime_img = generator(noise, img)
    
    # Save output
    save_image(anime_img, output_path, normalize=True)
    
    # Display result
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title("Human Image")
    plt.imshow(Image.open(human_img_path))
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.title("Anime Image")
    plt.imshow(Image.open(output_path))
    plt.axis('off')
    plt.show()


In [47]:
# Main execution
if __name__ == "__main__":
    # Uncomment to preprocess CelebA (run once)
    # preprocess_celeba_images(celeba_dir, celeba_cropped_dir)
    
    # Train the model
    train()
    
    # Example inference
    human_img_path = "./data/celeba/img_align_celeba/img_align_celeba/000001.jpg"
    output_path = "./outputs/test_anime.png"
    model_path = f"{checkpoint_dir}/generator_epoch_{num_epochs}.pth" if os.path.exists(f"{checkpoint_dir}/generator_epoch_{num_epochs}.pth") else None
    generate_anime_image(human_img_path, output_path, model_path)

  with autocast():


RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
or torch.nn.BCEWithLogitsLoss.  binary_cross_entropy_with_logits and BCEWithLogits are
safe to autocast.

In [45]:
# Test data loading
for batch in anime_loader:
    anime_imgs, _ = batch
    print(f"Anime batch shape: {anime_imgs.shape}")
    break

Anime batch shape: torch.Size([16, 3, 64, 64])
