In [1]:
import torch.nn as nn
from generator import Generator
import torch
from discriminator import Discriminator
import torchvision.utils as vutils
from torch import optim
import fiftyone.zoo as foz
from torchvision import transforms, utils as vutils
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
import pretrain

In [2]:
# Hyper-parameters shared by the data pipeline, pre-training and GAN training.
beta1 = 0.5  # Adam beta1, used by both the generator and discriminator optimizers
batch_size = 32
img_size = 128  # 64
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU;
# on a CUDA machine this is the same "cuda:0" as before.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
from torch.utils.data import DataLoader
from preprocess_dataset import (
    SelectiveEdgeSmoothing,
    FiftyOnePyTorchDataset,
    ShuffledDatasetSampler,
    visualize_batch,
)


def _build_transform(*extra_steps):
    """Resize to (img_size, img_size), run `extra_steps`, then tensorise and
    normalise each channel with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1])."""
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        *extra_steps,
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


# Plain pipeline for photos and cartoons; the "gaussian" pipeline inserts
# SelectiveEdgeSmoothing after the resize.
generic_transform = _build_transform()
gaussian_transform = _build_transform(
    SelectiveEdgeSmoothing(img_size=img_size, radius=5, alpha=0.5)
)

# --- Photo domain: COCO-2017 validation subset pulled through FiftyOne ---------
fiftyone_dataset = foz.load_zoo_dataset(
    "coco-2017", split="validation", max_samples=776
)
fo_dataset = FiftyOnePyTorchDataset(fiftyone_dataset, transform=generic_transform)
real_images_loader = DataLoader(fo_dataset, batch_size=batch_size, shuffle=True)

# --- Cartoon domain: the same image folder, read with two transforms -----------
cartoon_path = '../../GhibliDataset'
cartoon_dataset = ImageFolder(root=cartoon_path, transform=generic_transform)
# One sampler object is shared by both cartoon loaders.
sampler = ShuffledDatasetSampler(cartoon_dataset, seed=42)
cartoon_images_loader = DataLoader(
    cartoon_dataset, batch_size=batch_size, sampler=sampler
)

cartoon_edge_dataset = ImageFolder(root=cartoon_path, transform=gaussian_transform)
cartoon_edge_images_loader = DataLoader(
    cartoon_edge_dataset, batch_size=batch_size, sampler=sampler
)

# --- Sanity check: show one batch from each loader -------------------------------
# The COCO loader yields image tensors directly.
visualize_batch(next(iter(real_images_loader)), "Some COCO Dataset Images")
# ImageFolder loaders yield (images, class_indices); only the images are shown.
for loader, title in (
    (cartoon_images_loader, "Some Cartoon Dataset Images"),
    (cartoon_edge_images_loader, "Some Cartoon Edge Dataset Images"),
):
    batch_img, _ = next(iter(loader))
    visualize_batch(batch_img, title)

In [4]:
# --- Generator pre-training (pretrain.fit) or loading a saved checkpoint -------
# Set to True to re-run pre-training instead of loading the checkpoint file.
RUN_PRETRAINING = False
PRETRAINED_GENERATOR_PATH = 'pretrained_generator.pth'

epochs = 5
lr = 0.0002
if RUN_PRETRAINING:
    pretrained_generator = pretrain.fit(real_images_loader, epochs, lr, beta1, device=device)
else:
    # map_location remaps saved tensors onto `device`, so a checkpoint written on
    # a GPU can still be loaded on a CPU-only machine (and vice versa).
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    pretrained_generator = torch.load(PRETRAINED_GENERATOR_PATH, map_location=device)

In [None]:
from loss import ContentLoss

# --- CartoonGAN adversarial training ------------------------------------------
# Hyper-parameters for this stage (they overwrite the pre-training values above).
epochs = 200
lr = 0.0001
# Weight of the content term in the generator loss; 1.0 keeps the original
# objective (the CartoonGAN paper reports using 10).
content_weight = 1.0
log_every = 1  # print a progress line every N batches

# NOTE(review): `pretrained_generator` from the previous cell is not used here;
# the generator starts from a fresh initialisation — confirm this is intended.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# nn.BCELoss expects probabilities in [0, 1], so this assumes Discriminator ends
# in a sigmoid — TODO confirm. The SAME loss must drive both the D and G updates:
# feeding one D output to a probability loss and a logits loss is inconsistent.
# If D returns raw logits, switch adversarial_loss to nn.BCEWithLogitsLoss.
adversarial_loss = nn.BCELoss().to(device)
co_loss = ContentLoss().to(device)
# Logits-based BCE; kept defined for other cells but not used in this loop.
bce_loss = nn.BCEWithLogitsLoss().to(device)

optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
real_label = 1.
fake_label = 0.

# zip() stops at the shortest loader, so this is the true number of steps per epoch.
n_batches = min(len(real_images_loader), len(cartoon_images_loader), len(cartoon_edge_images_loader))
assert n_batches > 0, "one of the data loaders is empty"


def show_image_grid(images, title):
    """Show a batch of image tensors as one grid figure titled `title`.

    `normalize=True` rescales the grid into [0, 1] for display.
    """
    grid = vutils.make_grid(images.detach().cpu(), padding=2, normalize=True)
    plt.figure(figsize=(16, 30))
    plt.axis(False)
    plt.title(title)
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()


for epoch in range(epochs):
    batches = zip(real_images_loader, cartoon_images_loader, cartoon_edge_images_loader)
    for i, (real_images, (cartoon_images, _), (cartoon_edge_images, _)) in enumerate(batches):
        real_images = real_images.to(device)
        cartoon_images = cartoon_images.to(device)
        cartoon_edge_images = cartoon_edge_images.to(device)

        # ---- Discriminator: cartoons -> real; generated and edge-smoothed -> fake.
        discriminator.zero_grad()
        generated_images = generator(real_images)
        cartoon_pred = discriminator(cartoon_images)
        cartoon_edge_pred = discriminator(cartoon_edge_images)
        generated_pred = discriminator(generated_images.detach())

        # Targets are shaped like each prediction: the three loaders can yield
        # different batch sizes (e.g. the final, partial COCO batch), so one label
        # tensor sized from a single batch would not match the others.
        loss_d = (
            adversarial_loss(cartoon_pred, torch.full_like(cartoon_pred, real_label))
            + adversarial_loss(generated_pred, torch.full_like(generated_pred, fake_label))
            + adversarial_loss(cartoon_edge_pred, torch.full_like(cartoon_edge_pred, fake_label))
        )
        loss_d.backward()
        optimizer_D.step()

        # ---- Generator: adversarial term plus ContentLoss(generated, input photo).
        generator.zero_grad()
        generated_pred = discriminator(generated_images)
        loss_g = (
            adversarial_loss(generated_pred, torch.full_like(generated_pred, real_label))
            + content_weight * co_loss(generated_images, real_images)
        )
        loss_g.backward()
        optimizer_G.step()

        if i % log_every == 0:
            print(f"Epoch [{epoch}/{epochs}] Batch {i}/{n_batches} "
                  f"Loss D: {loss_d.item():.4f}, Loss G: {loss_g.item():.4f}")

    # End of epoch: compare the last input batch with its generated counterpart.
    show_image_grid(real_images, "Original images")
    show_image_grid(generated_images, "CartoonGAN images")

In [None]:
# generator = Generator().to(device)
# discriminator = Discriminator().to(device)

# criterion = nn.BCELoss()
# real_label = 1.
# fake_label = 0.

# optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
# optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))

In [None]:
# for epoch in range(epochs):
#     for i, (images, _) in enumerate(cartoon_images_loader):

#         real_images = images.to(device)
#         current_batch_size = real_images.size(0)
#         real_labels = torch.full((current_batch_size,), real_label, device=device)
#         fake_labels = torch.full((current_batch_size,), fake_label, device=device)

#         discriminator.zero_grad()
#         outputs_real = discriminator(real_images).view(-1)
#         loss_real = criterion(outputs_real, real_labels)
#         loss_real.backward()

#         noise = torch.randn(current_batch_size, 3, 64, 64, device=device)  
#         fake_images = generator(noise)
#         outputs_fake = discriminator(fake_images.detach()).view(-1)
#         loss_fake = criterion(outputs_fake, fake_labels)
#         loss_fake.backward()

#         optimizer_D.step()


#         generator.zero_grad()
#         outputs_fake = discriminator(fake_images).view(-1)
#         loss_generator = criterion(outputs_fake, real_labels) 
#         loss_generator.backward()

#         optimizer_G.step()

#        
#         if i % 1 == 0:
#             print(f"Epoch [{epoch}/{epochs}] Batch {i}/{len(cartoon_images_loader)} \
#                   Loss D: {loss_real + loss_fake}, Loss G: {loss_generator}")

#    
#     torch.save(generator.state_dict(), 'generator.pth')
#     torch.save(discriminator.state_dict(), 'discriminator.pth')

#     with torch.no_grad():
#         fake_images = generator(noise).detach().cpu()
#     img_grid = vutils.make_grid(fake_images, padding=2, normalize=True)
#     # Convert to plt
#     plt.imshow(img_grid.permute(1, 2, 0).squeeze())
#     plt.show()
