In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import itertools
from torchvision import transforms, datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(device)  # developed on an RTX 2070 Super

cuda


In [3]:
class ResidualBlock(nn.Module):
    """Identity-skip block: x + (Conv3x3 -> IN -> ReLU -> Conv3x3 -> IN)(x).

    Both convolutions keep the channel count (``in_channels``) and, with stride 1 and
    padding 1, the spatial size, so the branch output can be added to the input directly.
    The layers live in ``self.block`` (indices 0..4), which fixes the state_dict keys.
    """

    def __init__(self, in_channels):
        super().__init__()

        def conv3x3():
            # same-size 3x3 convolution (stride 1, padding 1)
            return nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

        branch = [
            conv3x3(),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(in_channels),
        ]
        self.block = nn.Sequential(*branch)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


In [4]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator used for both translation directions.

    Layout (all inside ``self.model``, which fixes the state_dict keys):
      7x7 stem conv (input_nc -> 64) + IN + ReLU
      -> two stride-2 3x3 convs (64 -> 128 -> 256), each + IN + ReLU
      -> ``n_residual_blocks`` ResidualBlocks at 256 channels
      -> two stride-2 3x3 transposed convs (256 -> 128 -> 64), each + IN + ReLU
      -> 7x7 output conv (64 -> output_nc) + Tanh.

    Args:
        input_nc: number of channels of the input image.
        output_nc: number of channels of the generated image.
        n_residual_blocks: number of residual blocks in the bottleneck (default 9).
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super().__init__()

        def conv_norm_relu(conv, channels):
            # wrap a conv layer with the InstanceNorm + ReLU that follows it
            return [conv, nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]

        layers = conv_norm_relu(nn.Conv2d(input_nc, 64, kernel_size=7, stride=1, padding=3), 64)

        # Downsampling: each step halves H/W and doubles the channels.
        channels = 64
        for _ in range(2):
            layers += conv_norm_relu(
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1),
                channels * 2,
            )
            channels *= 2

        layers += [ResidualBlock(channels) for _ in range(n_residual_blocks)]

        # Upsampling: output_padding=1 makes each step exactly double H/W.
        for _ in range(2):
            layers += conv_norm_relu(
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1),
                channels // 2,
            )
            channels //= 2

        layers += [nn.Conv2d(64, output_nc, kernel_size=7, stride=1, padding=3), nn.Tanh()]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [5]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator: maps an image to a 1-channel map of real/fake scores.

    Four 4x4 conv stages (64, 128, 256, 512 channels; strides 2, 2, 2, 1), each followed by
    LeakyReLU(0.2), with InstanceNorm on all but the first, then a final 4x4 stride-1 conv
    to a single channel. Layers live in ``self.model``, which fixes the state_dict keys.

    Args:
        input_nc: number of channels of the input image.
    """

    def __init__(self, input_nc):
        super().__init__()

        # (in_channels, out_channels, stride, use_instance_norm) per conv stage
        stages = [
            (input_nc, 64, 2, False),
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        ]

        layers = []
        for c_in, c_out, stride, normalize in stages:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1))
            if normalize:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final projection to one score channel (no norm / activation).
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [6]:
def weights_init_normal(m, std=0.02):
    """Initialise a single module in place; intended for ``net.apply(weights_init_normal)``.

    - Conv / ConvTranspose layers (class name contains "Conv"): weight ~ N(0, std), bias = 0.
    - BatchNorm / InstanceNorm layers that have affine parameters: weight ~ N(1, std), bias = 0.
      Norm layers built with ``affine=False`` (the InstanceNorm2d default) have no weight
      and are skipped instead of crashing.
    - Any other module is left untouched.

    Args:
        m: the module being visited by ``nn.Module.apply``.
        std: standard deviation of the normal initialisation (default 0.02).
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, std)
    elif 'BatchNorm' in classname or 'InstanceNorm' in classname:
        # Previously only 'BatchNorm2d' matched, so affine InstanceNorm layers were never initialised.
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, std)
    else:
        return

    # Zero biases for both conv and norm layers.
    if bias is not None:
        torch.nn.init.zeros_(bias)

In [7]:
class ImageDataset(Dataset):
    """Flat-directory image dataset that yields (optionally transformed) RGB images.

    Only regular files whose suffix is in ``extensions`` are indexed, so non-image files
    (e.g. metadata sitting next to the images) and sub-directories are skipped rather
    than crashing ``Image.open`` during training. Files are sorted by name.

    Args:
        root: directory containing the image files (not searched recursively).
        transform: optional callable applied to each PIL image.
        extensions: accepted file suffixes (case-insensitive); defaults to IMG_EXTENSIONS.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, root, transform=None, extensions=None):
        self.transform = transform
        self.root = root
        suffixes = tuple(ext.lower() for ext in (extensions or self.IMG_EXTENSIONS))
        self.files = sorted(
            name for name in os.listdir(root)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(root, name))
        )

    def __getitem__(self, index):
        img_path = os.path.join(self.root, self.files[index])
        # The context manager closes the underlying file; convert('RGB') loads the pixels
        # and forces 3 channels so grayscale / palette / RGBA files match the 3-channel
        # Normalize used downstream.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img

    def __len__(self):
        return len(self.files)

In [8]:
# Preprocessing shared by both domains: resize so the shorter side is int(256 * 1.12) px
# with bicubic resampling, take a random 256x256 crop, random horizontal flip, then map
# pixels from [0, 1] (ToTensor) to [-1, 1] (Normalize with mean=std=0.5), the same range
# as the generator's Tanh output.
transform = transforms.Compose([
    # InterpolationMode is torchvision's supported way to choose the resampling filter;
    # passing PIL's integer constant (Image.BICUBIC) is deprecated.
    transforms.Resize(int(256 * 1.12), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop(256),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Build the datasets and dataloaders
batch_size = 1

celeba_dataset = ImageDataset(root='./datasets/img_align_celeba/img_align_celeba/', transform=transform)
caricature_dataset = ImageDataset(root='./datasets/cartoonset10k/', transform=transform)

celeba_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
caricature_loader = DataLoader(caricature_dataset, batch_size=batch_size, shuffle=True, num_workers=4)


In [9]:
# Instantiate both generators (A->B and B->A) and one discriminator per domain (RGB in/out).
G_A2B = Generator(input_nc=3, output_nc=3).to(device)
G_B2A = Generator(input_nc=3, output_nc=3).to(device)
D_A = Discriminator(input_nc=3).to(device)
D_B = Discriminator(input_nc=3).to(device)

# Apply the custom weight initialisation, visiting the networks in construction order.
for net in (G_A2B, G_B2A, D_A, D_B):
    net.apply(weights_init_normal)

# Losses: MSE for the adversarial term, L1 for the cycle-consistency and identity terms.
criterion_GAN, criterion_cycle, criterion_identity = (
    torch.nn.MSELoss().to(device),
    torch.nn.L1Loss().to(device),
    torch.nn.L1Loss().to(device),
)

# Every optimizer is Adam with the same hyper-parameters; both generators share one optimizer.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = torch.optim.Adam(itertools.chain(G_A2B.parameters(), G_B2A.parameters()), **adam_kwargs)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), **adam_kwargs)
optimizer_D_B = torch.optim.Adam(D_B.parameters(), **adam_kwargs)


In [10]:
# Model training
n_epochs = 200
checkpoint_every = 10   # save sample images + weights every N epochs (and after the last epoch)
log_every = 100         # print losses every N batches instead of flooding the output every batch

# torch.save / save_image do not create missing directories; create them up front instead
# of crashing at the first checkpoint after a full epoch of training.
os.makedirs("output", exist_ok=True)
os.makedirs("models", exist_ok=True)

# zip() stops at the shorter loader, so this is the real number of batches per epoch.
n_batches = min(len(celeba_loader), len(caricature_loader))


def discriminator_step(D, optimizer_D, real, fake):
    """Run one update of discriminator ``D`` (real -> 1, fake -> 0) and return its loss.

    ``fake`` is detached so this update does not backpropagate into the generator.
    """
    optimizer_D.zero_grad()

    pred_real = D(real)
    loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

    pred_fake = D(fake.detach())
    loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

    loss_D = (loss_real + loss_fake) / 2
    loss_D.backward()
    optimizer_D.step()
    return loss_D


for epoch in range(n_epochs):
    # evaluate_model() switches the generators to eval mode; make sure a re-run trains again.
    for net in (G_A2B, G_B2A, D_A, D_B):
        net.train()

    for i, (real_A, real_B) in enumerate(zip(celeba_loader, caricature_loader)):
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # ---- Generators A2B and B2A ----
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image already in its output domain should leave it unchanged.
        loss_id_A = criterion_identity(G_B2A(real_A), real_A)
        loss_id_B = criterion_identity(G_A2B(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # Adversarial loss: each generator tries to make the opposite discriminator predict "real" (1).
        fake_B = G_A2B(real_A)
        pred_fake = D_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        fake_A = G_B2A(real_B)
        pred_fake = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake))

        loss_GAN = (loss_GAN_A2B + loss_GAN_B2A) / 2

        # Cycle-consistency loss: A -> B -> A and B -> A -> B should reconstruct the inputs.
        recovered_A = G_B2A(fake_B)
        loss_cycle_A = criterion_cycle(recovered_A, real_A)

        recovered_B = G_A2B(fake_A)
        loss_cycle_B = criterion_cycle(recovered_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total generator loss
        loss_G = loss_GAN + 10.0 * loss_cycle + 5.0 * loss_identity
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminators A and B ----
        loss_D_A = discriminator_step(D_A, optimizer_D_A, real_A, fake_A)
        loss_D_B = discriminator_step(D_B, optimizer_D_B, real_B, fake_B)

        # Print the losses
        if i % log_every == 0 or i == n_batches - 1:
            print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{n_batches}] "
                  f"[D loss: {loss_D_A.item() + loss_D_B.item()}] "
                  f"[G loss: {loss_G.item()}]")

    # Save sample translations of the last batch and the model weights periodically, and
    # always after the final epoch so the fully trained model is not lost.
    if epoch % checkpoint_every == 0 or epoch == n_epochs - 1:
        with torch.no_grad():
            fake_B = G_A2B(real_A)
            fake_A = G_B2A(real_B)
            save_image(fake_B, f"output/fake_B_{epoch}.png", normalize=True)
            save_image(fake_A, f"output/fake_A_{epoch}.png", normalize=True)

        torch.save(G_A2B.state_dict(), f"models/G_A2B_{epoch}.pth")
        torch.save(G_B2A.state_dict(), f"models/G_B2A_{epoch}.pth")
        torch.save(D_A.state_dict(), f"models/D_A_{epoch}.pth")
        torch.save(D_B.state_dict(), f"models/D_B_{epoch}.pth")

In [None]:
# Model evaluation
def evaluate_model(generator_A2B, generator_B2A, test_loader_A, test_loader_B, device,
                   output_dir="evaluation", max_batches=None):
    """Translate paired batches from two loaders and save the results as images.

    Pairs come from ``zip(test_loader_A, test_loader_B)``, so iteration stops at the
    shorter loader. For pair ``i`` it writes ``fake_B_{i}.png`` = generator_A2B(real_A)
    and ``fake_A_{i}.png`` = generator_B2A(real_B) into ``output_dir``.

    Args:
        generator_A2B: module translating domain A -> B.
        generator_B2A: module translating domain B -> A.
        test_loader_A: iterable of domain-A image tensors.
        test_loader_B: iterable of domain-B image tensors.
        device: device each input batch is moved to.
        output_dir: destination directory, created if missing (default "evaluation").
        max_batches: optional cap on the number of pairs processed (None = all of them).

    The generators run in eval mode without gradient tracking; their previous
    train/eval mode is restored afterwards, even if saving fails.
    """
    os.makedirs(output_dir, exist_ok=True)

    was_training = (generator_A2B.training, generator_B2A.training)
    generator_A2B.eval()
    generator_B2A.eval()
    try:
        pairs = zip(test_loader_A, test_loader_B)
        if max_batches is not None:
            pairs = itertools.islice(pairs, max_batches)

        with torch.no_grad():
            for i, (real_A, real_B) in enumerate(pairs):
                real_A = real_A.to(device)
                real_B = real_B.to(device)

                fake_B = generator_A2B(real_A)
                fake_A = generator_B2A(real_B)

                save_image(fake_B, os.path.join(output_dir, f"fake_B_{i}.png"), normalize=True)
                save_image(fake_A, os.path.join(output_dir, f"fake_A_{i}.png"), normalize=True)
    finally:
        generator_A2B.train(was_training[0])
        generator_B2A.train(was_training[1])

In [None]:
# Translate every (CelebA, caricature) pair with the trained generators and save the results.
# NOTE(review): these are the *training* loaders, whose transform applies random crop and
# horizontal-flip augmentation, so this is not a held-out evaluation. zip() also walks every
# pair, writing two images per batch — confirm that volume is intended.
evaluate_model(G_A2B, G_B2A, celeba_loader, caricature_loader, device)