In [None]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
!pip install scipy==1.2.1
import os
import argparse
import warnings
import numpy as np
import matplotlib.pyplot as plt
import json
warnings.filterwarnings("ignore")

# Torch imports
import torch
import torch.nn as nn
import torch.optim as optim

# to the directory of workspace_4471
import sys
sys.path.append('/content/drive/MyDrive/workspace_4471')
sys.path.append('/workspace_4471') # for other packages import

Collecting scipy==1.2.1
  Downloading scipy-1.2.1-cp37-cp37m-manylinux1_x86_64.whl (24.8 MB)
[K     |████████████████████████████████| 24.8 MB 1.2 MB/s 
Installing collected packages: scipy
  Attempting uninstall: scipy
    Found existing installation: scipy 1.4.1
    Uninstalling scipy-1.4.1:
      Successfully uninstalled scipy-1.4.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.[0m
Successfully installed scipy-1.2.1


In [None]:
# Local imports
from data_loader import get_image_loader
from models_v2 import CycleGenerator_v2, DCDiscriminator
from cycle_utils import create_dir, create_model, checkpoint, save_samples

In [None]:
SEED = 11
# Set the random seed manually for reproducibility.
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f7e5cee6e70>

In [None]:
class Opts():
    def __init__(self):
        self.image_size = 256
        self.g_conv_dim = 32
        self.init_zero_weights = True # Choose whether to initialize the generator conv weights to 0 (implements the identity function)

        # Training hyper-parameters
        self.train_iters = 100000
        self.batch_size = 1
        self.num_workers = 1
        self.lr = 5e-5 # original 2e-4
        self.beta1 = 0.5
        self.beta2 = 0.999
        self.lbd_cyclegan = 10
        self.lbd_identity = 5
        self.lbd_feature = 10

        self.X = 'day' #Choose the type of images for domain X
        self.Y = 'night' #Choose the type of images for domain Y.

        # Saving directories and checkpoint/sample iterations
        self.checkpoint_dir = 'drive/MyDrive/workspace_4471/checkpoints_cyclegan_w_featloss'
        self.sample_dir = 'drive/MyDrive/workspace_4471/samples_cyclegan_w_featloss'
        self.log_step = 10
        self.sample_every = 400
        self.checkpoint_every = 1600
        self.losslog_dir = 'drive/MyDrive/workspace_4471/losslog_dir_w_featloss'
        
    
def print_opts(opts):
    """Prints the values of all command-line arguments.
    """
    print('=' * 80)
    print('Opts'.center(80))
    print('-' * 80)
    for key in opts.__dict__:
        if opts.__dict__[key]:
            print('{:>30}: {:<30}'.format(key, opts.__dict__[key]).center(80))
    print('=' * 80)


In [None]:
opts = Opts()
print_opts(opts)

                                      Opts                                      
--------------------------------------------------------------------------------
                             image_size: 256                                    
                             g_conv_dim: 32                                     
                      init_zero_weights: 1                                      
                            train_iters: 100000                                 
                             batch_size: 1                                      
                            num_workers: 1                                      
                                     lr: 5e-05                                  
                                  beta1: 0.5                                    
                                  beta2: 0.999                                  
                           lbd_cyclegan: 10                                     
                           l

In [None]:
def load_checkpoint(G_XtoY, G_YtoX, D_X, D_Y, g_optimizer, d_optimizer, iter, filename):
    # Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        G_XtoY.load_state_dict(checkpoint['G_XtoY'])
        G_YtoX.load_state_dict(checkpoint['G_YtoX'])
        D_X.load_state_dict(checkpoint['D_X'])
        D_Y.load_state_dict(checkpoint['D_Y'])
        g_optimizer.load_state_dict(checkpoint['g_optimizer'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer'])
        iter = checkpoint['iter']
        
        print("=> loaded checkpoint '{}' (iter {})"
                  .format(filename, checkpoint['iter']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))
        iter = 0

    return G_XtoY, G_YtoX, D_X, D_Y, g_optimizer, d_optimizer, iter

In [None]:
"""
Loads the data, creates checkpoint and sample directories, and starts the training loop.
"""

# Create train and test dataloaders for images from the two domains X and Y
dataloader_X, test_dataloader_X = get_image_loader(img_type=opts.X, opts=opts )
dataloader_Y, test_dataloader_Y = get_image_loader(img_type=opts.Y, opts=opts )

# Create checkpoint and sample directories
create_dir(opts.checkpoint_dir)
create_dir(opts.sample_dir)
create_dir(opts.losslog_dir)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

len(dataloader_X), len(dataloader_Y)

(1626, 1626)

In [None]:
# Create generators and discriminators
G_XtoY, G_YtoX, D_X, D_Y = create_model(opts)
G_XtoY.to(device)
G_YtoX.to(device)
D_X.to(device)
D_Y.to(device)

g_params = list(G_XtoY.parameters()) + list(G_YtoX.parameters())  # Get generator parameters
d_params = list(D_X.parameters()) + list(D_Y.parameters())  # Get discriminator parameters

# Create optimizers for the generators and discriminators
g_optimizer = optim.Adam(g_params, opts.lr, [opts.beta1, opts.beta2])
d_optimizer = optim.Adam(d_params, opts.lr, [opts.beta1, opts.beta2])


In [None]:
import torchvision.models as models
import torchvision.transforms as transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Load the pretrained model
model = models.resnet18(pretrained=True)
model.to(device)
# Use the model object to select the desired layer
layer = model._modules.get('avgpool')

for param in model.parameters():
    param.requires_grad = False

def get_vector(image):
    img = normalize(image)
    my_embedding = torch.zeros(512)

    def copy_data(m, i, o):
        my_embedding.copy_(o.flatten()) # <-- flatten

    h = layer.register_forward_hook(copy_data)
    model(img)                       
    h.remove() # Detach our copy function from the layer
    return my_embedding

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [13]:
curr_iter = 17600 # remember to reset to 0 when "restart and runall"
ckpt_filepath = os.path.join(opts.checkpoint_dir, 'ckpt_{:06d}.pth.tar'.format(curr_iter))
load_checkpoint(G_XtoY, G_YtoX, D_X, D_Y, g_optimizer, d_optimizer, curr_iter, ckpt_filepath)

iter_X = iter(dataloader_X)
iter_Y = iter(dataloader_Y)

test_iter_X = iter(test_dataloader_X)
test_iter_Y = iter(test_dataloader_Y)

# Get some fixed data from domains X and Y for sampling. These are images that are held
# constant throughout training, that allow us to inspect the model's performance.
fixed_X = test_iter_X.next()[0].to(device)
fixed_Y = test_iter_Y.next()[0].to(device)

=> loading checkpoint 'drive/MyDrive/workspace_4471/checkpoints_cyclegan_w_featloss/ckpt_017600.pth.tar'
=> loaded checkpoint 'drive/MyDrive/workspace_4471/checkpoints_cyclegan_w_featloss/ckpt_017600.pth.tar' (iter 17600)


In [None]:
iter_per_epoch = min(len(iter_X), len(iter_Y))
mse_loss = torch.nn.MSELoss()
L1_loss = torch.nn.L1Loss()

Loss = []

d_real_loss_list = [] 
d_fake_loss_list = [] 
g_loss_YXY_list = []
g_loss_XYX_list = []
feature_loss_list = []

for iteration in range(curr_iter+1, opts.train_iters+1):

    # Reset data_iter for each epoch
    if iteration % iter_per_epoch == 0:
        iter_X = iter(dataloader_X)
        iter_Y = iter(dataloader_Y)

    images_X = iter_X.next()[0].to(device)
    images_Y = iter_Y.next()[0].to(device)

    # ============================================
    #            TRAIN THE DISCRIMINATORS
    # ============================================
    
    # Train with real images

    # 1. Compute the discriminator losses on real images
    d_optimizer.zero_grad()

    D_X_real_loss = mse_loss(D_X(images_X), torch.ones(len(images_X)).to(device))
    D_Y_real_loss = mse_loss(D_Y(images_Y), torch.ones(len(images_Y)).to(device))

    d_real_loss = D_X_real_loss + D_Y_real_loss
    d_real_loss_list.append(d_real_loss)
    d_real_loss.backward()
    d_optimizer.step()

    # Train with fake images
    d_optimizer.zero_grad()

    # 2. Generate fake images that look like domain X based on real images in domain Y
    fake_X = G_YtoX(images_Y)

    # 3. Compute the loss for D_X
    D_X_fake_loss = mse_loss(D_X(fake_X), torch.zeros(len(fake_X)).to(device))
    #print(D_X(fake_X))

    # 4. Generate fake images that look like domain Y based on real images in domain X
    fake_Y = G_XtoY(images_X) 

    # 5. Compute the loss for D_Y
    D_Y_fake_loss = mse_loss(D_Y(fake_Y), torch.zeros(len(fake_Y)).to(device))
    #print(D_Y(fake_Y))

    d_fake_loss = D_X_fake_loss + D_Y_fake_loss
    d_fake_loss_list.append(d_fake_loss)
    d_fake_loss.backward()
    d_optimizer.step()

    # =========================================
    #            TRAIN THE GENERATORS
    # =========================================

    #########################################
    ##           Y--X-->Y CYCLE            ##
    #########################################

    g_optimizer.zero_grad()

    # 1. Generate fake images that look like domain X based on real images in domain Y
    fake_X = G_YtoX(images_Y)

    # 2. Compute the generator loss based on domain X
    g_loss_YXY = mse_loss(D_X(fake_X), torch.ones(len(fake_X)).to(device))

    # 3. Compute the cycle consistency loss (the reconstruction loss)
    cycle_consistency_loss = L1_loss(G_XtoY(fake_X), images_Y)
    g_loss_YXY += opts.lbd_cyclegan * cycle_consistency_loss

    # 4. Compute identity loss (Y domain images put into XtoY model should still loook the same)
    identity_loss = L1_loss(G_XtoY(images_Y), images_Y)
    g_loss_YXY += opts.lbd_identity * identity_loss

    g_loss_YXY.backward()
    g_loss_YXY_list.append(g_loss_YXY)
    g_optimizer.step()

    #########################################
    ##           X--Y-->X CYCLE            ##
    #########################################

    g_optimizer.zero_grad()

    # 1. Generate fake images that look like domain Y based on real images in domain X
    fake_Y = G_XtoY(images_X)

    # 2. Compute the generator loss based on domain Y
    g_loss_XYX = mse_loss(D_Y(fake_Y), torch.ones(len(fake_Y)).to(device))

    # 3. Compute the cycle consistency loss (the reconstruction loss)
    cycle_consistency_loss = L1_loss(G_YtoX(fake_Y), images_X)
    g_loss_XYX += opts.lbd_cyclegan * cycle_consistency_loss

    # 4. Compute identity loss (X domain images put into YtoX model should still loook the same)
    identity_loss = L1_loss(G_YtoX(images_X), images_X)
    g_loss_XYX += opts.lbd_identity * identity_loss

    g_loss_XYX.backward()
    g_loss_XYX_list.append(g_loss_XYX)
    g_optimizer.step()
    

    #########################################
    ##           feature loss              ##
    ##  L1_loss(feature(X), feature(T(X))) ##
    ##  L1_loss(feature(Y), feature(T(Y))) ##
    #########################################

    g_optimizer.zero_grad()

    # denormalize back to values in [0,1]
    imgX, imgY = images_X*0.5+0.5, images_Y*0.5+0.5
    timgX, timgY = G_XtoY(images_X)*0.5+0.5, G_YtoX(images_Y)*0.5+0.5

    # normalization done in get_vector
    imgX_feat, imgY_feat = get_vector(imgX), get_vector(imgY)
    timgX_feat, timgY_feat = get_vector(timgX), get_vector(timgY)

    feature_loss = opts.lbd_feature * ( L1_loss(imgX_feat, timgX_feat) + L1_loss(imgY_feat, timgY_feat) ) 
    feature_loss_list.append(feature_loss)

    feature_loss.backward()
    g_optimizer.step()

    if iteration % opts.sample_every == 0:

      plt.plot(d_real_loss_list)
      plt.show()
      plt.plot(d_fake_loss_list,'o')
      plt.show()
      plt.plot(g_loss_YXY_list, 'r-')
      plt.show()
      plt.plot(g_loss_XYX_list, 'r+')
      plt.show()
      plt.plot(feature_loss_list, 'go')
      plt.show()

    # Print the log info
    if iteration % opts.log_step == 0:
        print('Iteration [{:5d}/{:5d}] | d_real_loss: {:6.4f} | d_Y_real_loss: {:6.4f} | d_X_real_loss: {:6.4f} | d_Y_fake_loss: {:6.4f} | d_X_fake_loss: {:6.4f} |  '
              'd_fake_loss: {:6.4f} | g_loss_XYX: {:6.4f} | g_loss_YXY: {:6.4f}'.format(
                iteration, opts.train_iters, d_real_loss.item(), D_Y_real_loss.item(),
                D_X_real_loss.item(), D_Y_fake_loss.item(),
                D_X_fake_loss.item(), d_fake_loss.item(), g_loss_XYX.item(), g_loss_YXY.item()))

    # Save the generated samples
    if iteration % opts.sample_every == 0:
        save_samples(iteration, fixed_Y, fixed_X, G_YtoX, G_XtoY, opts)

    # Save the model parameters
    if iteration % opts.checkpoint_every == 0:
        checkpoint(iteration, G_XtoY, G_YtoX, D_X, D_Y, g_optimizer, d_optimizer, opts)
        Loss.append({'Iteration': iteration, 
                     'd_real_loss': d_real_loss.item(), 'D_Y_real_loss': D_Y_real_loss.item(), 'D_X_real_loss': D_X_real_loss.item(), 
                     'd_fake_loss': d_fake_loss.item(), 'D_Y_fake_loss': D_Y_fake_loss.item(), 'D_X_fake_loss': D_X_fake_loss.item(),
                     'g_loss_XYX': g_loss_XYX.item(), 'g_loss_YXY': g_loss_YXY.item()})
        losslog_file = os.path.join(opts.losslog_dir, 'losslog')
        with open(losslog_file + ".json", "w") as f:
            json.dump(Loss, f)


Iteration [17610/100000] | d_real_loss: 0.2923 | d_Y_real_loss: 0.0000 | d_X_real_loss: 0.2922 | d_Y_fake_loss: 0.0000 | d_X_fake_loss: 0.0047 |  d_fake_loss: 0.0048 | g_loss_XYX: 1.9795 | g_loss_YXY: 4.8347
Iteration [17620/100000] | d_real_loss: 0.1646 | d_Y_real_loss: 0.0000 | d_X_real_loss: 0.1646 | d_Y_fake_loss: 0.0001 | d_X_fake_loss: 0.0118 |  d_fake_loss: 0.0119 | g_loss_XYX: 2.1004 | g_loss_YXY: 10.3151
Iteration [17630/100000] | d_real_loss: 0.0601 | d_Y_real_loss: 0.0000 | d_X_real_loss: 0.0601 | d_Y_fake_loss: 0.0007 | d_X_fake_loss: 0.0730 |  d_fake_loss: 0.0737 | g_loss_XYX: 2.0317 | g_loss_YXY: 7.7487
Iteration [17640/100000] | d_real_loss: 0.1489 | d_Y_real_loss: 0.0000 | d_X_real_loss: 0.1489 | d_Y_fake_loss: 0.0000 | d_X_fake_loss: 0.0370 |  d_fake_loss: 0.0370 | g_loss_XYX: 2.1621 | g_loss_YXY: 11.0779
Iteration [17650/100000] | d_real_loss: 0.1565 | d_Y_real_loss: 0.0000 | d_X_real_loss: 0.1565 | d_Y_fake_loss: 0.0000 | d_X_fake_loss: 0.1063 |  d_fake_loss: 0.1063 