In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory.
# Printing every file path can flood the notebook output (image folders may
# hold thousands of files), so print one summary line per directory instead:
# the directory path and how many files it directly contains.

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    print(f'{dirname}: {len(filenames)} files')

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# This notebook borrows heavily from:
# https://www.kaggle.com/nachiket273/cyclegan-pytorch

## Import Dependencies
Import libraries for navigating the filesystem, building and training the models, and processing images.

In [None]:
import numpy as np
import pandas as pd
from glob import glob
import itertools
import matplotlib.pyplot as plt
import os
import PIL
from PIL import Image
import random
import shutil
from sklearn.model_selection import GroupKFold
from sklearn.metrics import roc_curve
from sklearn import metrics
import time
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from torch.utils.data import Dataset
from torch.utils.data import random_split
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.transforms as transforms

## Initialize Device and Data Paths

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Folders of Monet paintings and of photos from the Kaggle
# "gan-getting-started" competition input.
monet_directory, photo_directory = (
    '../input/gan-getting-started/monet_jpg/',
    '../input/gan-getting-started/photo_jpg/',
)

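## Dataset and Data Loader
The Monet and photo collections are unpaired and differ in size, so each dataset item pairs one Monet painting with a randomly chosen photo.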

In [None]:
class MonetData(Dataset):
    """Unpaired Monet/photo dataset for CycleGAN training.

    Item ``idx`` is a ``(monet, photo)`` pair: the Monet image at position
    ``idx`` plus a photo drawn uniformly at random. The photo is random
    because the two directories hold different numbers of images and are
    not paired with each other.

    Args:
        monet_directory: directory containing the Monet images.
        photo_directory: directory containing the photo images.
        transform: callable applied to each image after it is converted to
            an RGB ``PIL.Image`` (defaults to ``transforms.ToTensor()``).
    """

    def __init__(self, monet_directory, photo_directory,
                 transform=transforms.ToTensor()):
        self.transform = transform
        # Sort so the index -> file mapping is reproducible; os.listdir
        # returns entries in arbitrary order.
        self.monet_paths = sorted(os.path.join(monet_directory, p) for p in os.listdir(monet_directory))
        self.photo_paths = sorted(os.path.join(photo_directory, p) for p in os.listdir(photo_directory))

    def _load(self, path):
        """Open ``path``, convert it to RGB, close the file and apply the transform."""
        with Image.open(path) as raw:
            # convert() forces the pixels to load into a new image, so the
            # result stays usable after the file handle is closed.
            image = raw.convert('RGB')
        return self.transform(image)

    def __getitem__(self, idx):
        """Return ``(monet, photo)``: Monet image ``idx`` and a random photo."""
        monet = self._load(self.monet_paths[idx])

        # Since the monet and photo directories contain a different number
        # of images, pair with a random photo; randrange(n) is in [0, n).
        photo_idx = random.randrange(len(self.photo_paths))
        photo = self._load(self.photo_paths[photo_idx])

        return monet, photo

    # Alias for the previous (misspelled) name. PyTorch only ever calls
    # __getitem__, so this exists purely for any direct callers.
    __get_item__ = __getitem__

    def __len__(self):
        """Return the size of the smaller of the two directories."""
        return min(len(self.monet_paths), len(self.photo_paths))


## Define the Models

In [None]:
class Conv_Wrapper(nn.Module):
    """Convolution followed by optional instance normalization and activation.

    This is a generic building block for CycleGAN-style networks. With
    ``transpose=True`` a ``ConvTranspose2d`` is used instead, with
    ``output_padding = stride - 1`` so that a stride-2 layer with
    ``kernel_size=3, padding=1`` exactly doubles the spatial size.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size, stride, padding: forwarded to the convolution.
        transpose: use a transposed (up-sampling) convolution.
        norm: append ``nn.InstanceNorm2d(out_channels)`` after the convolution.
        activation: ``'relu'``, ``'leaky'`` (LeakyReLU, slope 0.2), ``'tanh'``
            or ``None`` for no activation.

    Raises:
        ValueError: if ``activation`` is not one of the names above.
    """

    # Factories rather than instances so every wrapper gets its own module.
    _ACTIVATIONS = {
        'relu': lambda: nn.ReLU(inplace=True),
        'leaky': lambda: nn.LeakyReLU(0.2, inplace=True),
        'tanh': nn.Tanh,
        None: nn.Identity,
    }

    def __init__(self, in_channels=3, out_channels=64, kernel_size=3, stride=1,
                 padding=1, transpose=False, norm=True, activation='relu'):
        super().__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError(f'unknown activation: {activation!r}')

        if transpose:
            conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                      stride, padding, output_padding=stride - 1)
        else:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

        layers = [conv]
        if norm:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(self._ACTIVATIONS[activation]())
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply convolution, normalization and activation to ``x``."""
        return self.block(x)

In [None]:
class Generator(nn.Module):
    """ResNet-style CycleGAN generator: image in, image of the same size out.

    Layout: 7x7 stem -> ``n_downsampling`` stride-2 convolutions ->
    ``n_residual`` residual blocks -> ``n_downsampling`` stride-2 transposed
    convolutions -> 7x7 output convolution followed by ``tanh``. Every hidden
    layer uses instance normalization and ReLU. The output has the input's
    spatial size when height and width are divisible by ``2 ** n_downsampling``.

    Args:
        in_channels: channels of the input image (3 for RGB).
        out_channels: channels of the generated image.
        base_channels: width of the stem; doubled at every down-sampling stage.
        n_downsampling: number of down-/up-sampling stages.
        n_residual: number of residual blocks at the bottleneck.
    """

    class _Residual(nn.Module):
        """Two 3x3 conv/InstanceNorm layers wrapped in an identity skip connection."""

        def __init__(self, channels):
            super().__init__()
            self.body = nn.Sequential(
                nn.Conv2d(channels, channels, 3, 1, 1),
                nn.InstanceNorm2d(channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, 3, 1, 1),
                nn.InstanceNorm2d(channels),
            )

        def forward(self, x):
            return x + self.body(x)

    @staticmethod
    def _conv(in_channels, out_channels, kernel_size, stride, padding, transpose=False):
        """Return ``[conv, InstanceNorm2d, ReLU]``; transposed convs exactly undo stride-2."""
        if transpose:
            conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride,
                                      padding, output_padding=stride - 1)
        else:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        return [conv, nn.InstanceNorm2d(out_channels), nn.ReLU(inplace=True)]

    def __init__(self, in_channels=3, out_channels=3, base_channels=64,
                 n_downsampling=2, n_residual=9):
        super().__init__()
        layers = self._conv(in_channels, base_channels, 7, 1, 3)
        channels = base_channels
        for _ in range(n_downsampling):
            layers += self._conv(channels, channels * 2, 3, 2, 1)
            channels *= 2
        layers += [self._Residual(channels) for _ in range(n_residual)]
        for _ in range(n_downsampling):
            layers += self._conv(channels, channels // 2, 3, 2, 1, transpose=True)
            channels //= 2
        # tanh keeps generated pixels in [-1, 1].
        layers += [nn.Conv2d(channels, out_channels, 7, 1, 3), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Translate a batch of images ``x`` into the target domain."""
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """PatchGAN discriminator: outputs a grid of real/fake scores, one per image patch.

    A 4x4 stride-2 convolution with LeakyReLU(0.2) is followed by ``n_layers``
    4x4 convolutions with InstanceNorm and LeakyReLU. All but the last use
    stride 2, and the channel count doubles each time up to
    ``8 * base_channels``. A final 4x4 convolution produces one channel. No
    sigmoid is applied, so the raw scores suit a least-squares (MSE) GAN loss.

    Args:
        in_channels: channels of the input image (3 for RGB).
        base_channels: width of the first convolution.
        n_layers: number of normalized convolution layers.
    """

    def __init__(self, in_channels=3, base_channels=64, n_layers=3):
        super().__init__()
        # The first layer is not normalized.
        layers = [nn.Conv2d(in_channels, base_channels, 4, 2, 1),
                  nn.LeakyReLU(0.2, inplace=True)]
        channels = base_channels
        for i in range(1, n_layers + 1):
            stride = 2 if i < n_layers else 1
            next_channels = base_channels * min(2 ** i, 8)
            layers += [nn.Conv2d(channels, next_channels, 4, stride, 1),
                       nn.InstanceNorm2d(next_channels),
                       nn.LeakyReLU(0.2, inplace=True)]
            channels = next_channels
        layers.append(nn.Conv2d(channels, 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-patch realism scores with shape ``(batch, 1, H', W')``."""
        return self.model(x)

In [None]:
class GAN():
    """CycleGAN translating between Monet paintings and photos.

    Holds two generators, ``monet_generator`` (photo -> Monet) and
    ``photo_generator`` (Monet -> photo), plus one discriminator per domain.
    All four networks are moved to the GPU when one is available.
    """

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.monet_generator = Generator()
        self.photo_generator = Generator()
        self.monet_discriminator = Discriminator()
        self.photo_discriminator = Discriminator()

        for network in self._networks():
            # Initialize weights for each of the networks ...
            network.apply(self._init_weights)
            # ... and send it to GPU memory, if available (Module.to moves a
            # module in place).
            network.to(self.device)

    def _networks(self):
        """Return the four trainable networks."""
        return (self.monet_generator, self.photo_generator,
                self.monet_discriminator, self.photo_discriminator)

    @staticmethod
    def _init_weights(module):
        """Draw (transposed) conv weights from N(0, 0.02) and zero their biases.

        N(0, 0.02) is the initialization used by the CycleGAN paper.
        """
        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(module.weight, 0.0, 0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def train(self, data_loader, epochs, lmbda=10, identity_coefficient=0.5,
              lr=1e-3, betas=(0.9, 0.999), decay_epoch=None):
        """Train all four networks with the CycleGAN objective.

        Args:
            data_loader: iterable yielding ``(monet, photo)`` batches of image tensors.
            epochs: number of passes over ``data_loader``.
            lmbda: weight of the cycle-consistency (L1) loss.
            identity_coefficient: identity-loss weight, relative to ``lmbda``.
            lr, betas: Adam hyper-parameters. The defaults are PyTorch's Adam
                defaults; the CycleGAN paper uses ``lr=2e-4, betas=(0.5, 0.999)``.
            decay_epoch: if given, keep ``lr`` constant for the first
                ``decay_epoch`` epochs, then decay it linearly to zero at
                ``epochs``. If ``None``, the learning rate stays constant.

        Returns:
            One dict per epoch with the mean ``'generator_loss'`` and
            ``'discriminator_loss'`` over that epoch's batches.

        Side effects:
            Stores the optimizers and schedulers as ``self.adam_gen``,
            ``self.adam_desc``, ``self.gen_lr_sched`` and ``self.desc_lr_sched``.
        """
        # Materialize as lists: an itertools.chain can be iterated only once,
        # and the discriminator parameters are re-iterated on every batch to
        # freeze/unfreeze them.
        generator_parameters = list(itertools.chain(
            self.monet_generator.parameters(), self.photo_generator.parameters()))
        discriminator_parameters = list(itertools.chain(
            self.monet_discriminator.parameters(), self.photo_discriminator.parameters()))
        self.adam_gen = torch.optim.Adam(generator_parameters, lr=lr, betas=betas)
        self.adam_desc = torch.optim.Adam(discriminator_parameters, lr=lr, betas=betas)

        def lr_factor(epoch):
            # Multiplier on `lr`: 1 until decay_epoch, then linear down to 0 at `epochs`.
            if decay_epoch is None:
                return 1.0
            return 1.0 - max(0, epoch - decay_epoch) / max(1, epochs - decay_epoch)

        self.gen_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_gen, lr_factor)
        self.desc_lr_sched = torch.optim.lr_scheduler.LambdaLR(self.adam_desc, lr_factor)

        for network in self._networks():
            network.train()

        history = []
        for _ in range(epochs):
            generator_total = discriminator_total = 0.0
            batches = 0
            for monet, photo in data_loader:
                monet = monet.to(self.device)
                photo = photo.to(self.device)
                fake_monet, fake_photo, generator_loss = self._generator_step(
                    monet, photo, self.adam_gen, discriminator_parameters,
                    lmbda, identity_coefficient)
                discriminator_loss = self._discriminator_step(
                    monet, photo, fake_monet, fake_photo, self.adam_desc)
                generator_total += generator_loss
                discriminator_total += discriminator_loss
                batches += 1
            # Learning-rate schedules advance once per epoch.
            self.gen_lr_sched.step()
            self.desc_lr_sched.step()
            history.append({
                'generator_loss': generator_total / max(batches, 1),
                'discriminator_loss': discriminator_total / max(batches, 1),
            })
        return history

    def _generator_step(self, monet, photo, optimizer, discriminator_parameters,
                        lmbda, identity_coefficient):
        """Run one generator update and return ``(fake_monet, fake_photo, loss)``."""
        optimizer.zero_grad()

        # Translate each real image into the other domain.
        fake_monet = self.monet_generator(photo)
        fake_photo = self.photo_generator(monet)
        # Pass the fakes back through to attempt to recover the originals.
        cycle_monet = self.monet_generator(fake_photo)
        cycle_photo = self.photo_generator(fake_monet)
        # An image already in a generator's target domain should ideally
        # come out relatively unaltered.
        identity_monet = self.monet_generator(monet)
        identity_photo = self.photo_generator(photo)

        cycle_loss = (F.l1_loss(cycle_monet, monet) + F.l1_loss(cycle_photo, photo)) * lmbda
        identity_loss = ((F.l1_loss(identity_monet, monet) + F.l1_loss(identity_photo, photo))
                         * lmbda * identity_coefficient)

        # The discriminators only score the fakes here. Freeze them so this
        # backward pass computes no gradients for their weights, and always
        # unfreeze them afterwards.
        for p in discriminator_parameters:
            p.requires_grad_(False)
        try:
            monet_realism = self.monet_discriminator(fake_monet)
            photo_realism = self.photo_discriminator(fake_photo)
            # Least-squares adversarial loss: the generators want every
            # discriminator output to read "real" (1).
            adversarial_loss = (F.mse_loss(monet_realism, torch.ones_like(monet_realism))
                                + F.mse_loss(photo_realism, torch.ones_like(photo_realism)))
            loss = cycle_loss + identity_loss + adversarial_loss
            loss.backward()
            optimizer.step()
        finally:
            for p in discriminator_parameters:
                p.requires_grad_(True)
        return fake_monet, fake_photo, loss.item()

    def _discriminator_step(self, monet, photo, fake_monet, fake_photo, optimizer):
        """Run one update of both discriminators and return the summed loss."""
        optimizer.zero_grad()
        # detach(): the generator graph behind the fakes was freed by the
        # generator's backward pass, and this update must not reach the
        # generators anyway.
        loss = (self._lsgan_discriminator_loss(self.monet_discriminator, monet, fake_monet.detach())
                + self._lsgan_discriminator_loss(self.photo_discriminator, photo, fake_photo.detach()))
        loss.backward()
        optimizer.step()
        return loss.item()

    @staticmethod
    def _lsgan_discriminator_loss(discriminator, real, fake):
        """MSE loss pushing scores on ``real`` toward 1 and scores on ``fake`` toward 0."""
        real_score = discriminator(real)
        fake_score = discriminator(fake)
        return (F.mse_loss(real_score, torch.ones_like(real_score))
                + F.mse_loss(fake_score, torch.zeros_like(fake_score)))