In [16]:
import torch
# import torchvision
import cv2
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import os
import glob
import pandas as pd
import re
import torch.nn.functional as F
import pdb
import seaborn as sns
from plotly.subplots import make_subplots
import plotly.graph_objects as go
# import configparser
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import scipy.io as io

In [None]:
! pip uninstall imgaug==0.2.9 --yes
! pip install imgaug==0.4
! pip install segmentation_models_pytorch
! pip install inplace-abn

In [2]:
from google.colab import drive
drive.mount('/content/drive')
%matplotlib inline
os.chdir("/content/drive/My Drive/Colab Notebooks/clothes_segmentation/distributed_structure")  
# import import_ipynb
from Dataset.dataset import data_set, get_dataloaders  #TODO: mount google drive
from Models.models import Generator, Discriminator, weights_init
from Utils.losses import get_gen_loss
from config import *


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
label_list = io.loadmat('./..//clothes_data/label_list.mat')
num_of_classes = label_list['label_list'].shape[1]
softmax = torch.nn.Softmax(dim=1)

###### Get configuration params:

In [4]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(1)
torch.cuda.manual_seed_all(3)
np.random.seed(2)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

##### Get dataloaders

In [5]:
train_dl, val_dl, _ = get_dataloaders(config=config)

##### Training preperation:

In [9]:
# New parameters
adv_criterion = torch.nn.BCEWithLogitsLoss() 
recon_criterion = torch.nn.L1Loss() 
lambda_recon = 100

n_epochs = 150
input_dim = 3
real_dim = 3
batch_size = 4
lr = 0.0002
target_shape = 256
assert 'cuda' in device.type, "Cuda is not working"

##### Model:

In [10]:
generator = Generator(config=config).to(device)
discriminator = Discriminator(input_channels=6).to(device)

gen_opt = torch.optim.Adam(filter(lambda p: p.requires_grad, 
                                  generator.parameters()), lr=lr)
disc_opt = torch.optim.Adam(filter(lambda p: p.requires_grad, 
                                  discriminator.parameters()), lr=lr)

gen_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(gen_opt, 'min', patience=10)
disc_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(disc_opt, 'min', patience=10)

In [14]:
type(train_dl)

torch.utils.data.dataloader.DataLoader

device(type='cpu')

##### Training:

In [11]:
def save_gen_model_checkpoint(model, epoch, output_path, best=False):
    if not os.path.isdir(output_path):
      os.makedirs(os.path.join(output_path,'Disc'))
    if best:
        previous_best_pt = glob.glob(os.path.join(output_path, '*GEN*BEST.pt'))
        if len(previous_best_pt) > 0:
            os.remove(previous_best_pt[0])
        name = os.path.join(output_path, 'Model_GEN_stateDict__Epoch={}__BEST.pt'.format(epoch))
    else:
        name = os.path.join(output_path, 'Model_GEN_stateDict__Epoch={}.pt'.format(epoch))
    torch.save(model.state_dict(), os.path.join(output_path, name))

def save_disc_model_checkpoint(model, epoch, output_path, best=False):
    if not os.path.isdir(output_path):
      os.makedirs(os.path.join(output_path,'Disc'))
    if best:
        previous_best_pt = glob.glob(os.path.join(output_path, '*DISC*BEST.pt'))
        if len(previous_best_pt) > 0:
            os.remove(previous_best_pt[0])
        name = os.path.join(output_path, 'Model_DISC_stateDict__Epoch={}__BEST.pt'.format(epoch))
    else:
        name = os.path.join(output_path, 'Model_DISC_stateDict__Epoch={}.pt'.format(epoch))
    torch.save(model.state_dict(), os.path.join(output_path, name))  

In [17]:
mean_generator_loss = 0
mean_discriminator_loss = 0
cur_step = 0
best_gen_loss, best_disc_loss= np.inf, np.inf
losses_list = []

for epoch in range(n_epochs):
    mean_discriminator_loss, mean_generator_loss, val_disc_loss, val_gen_loss = [0] *4
    ### TRAIN ###
    generator.train()
    discriminator.train()
    for idx, train_data in enumerate(tqdm(train_dl)):
        origin_image, mask_gt =train_data['image'], train_data['gt_reg_map']
        origin_image = torch.nn.functional.interpolate(origin_image, 
                                                  size=target_shape)
        mask_gt = torch.nn.functional.interpolate(mask_gt, size=target_shape)
        cur_batch_size = len(origin_image)

        ### Update discriminator ###
        disc_opt.zero_grad() # Zero out the gradient before backpropagation
        with torch.no_grad():
            fake = generator(origin_image)
        disc_fake_hat = discriminator(fake.detach(), origin_image) # Detach generator
        disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
        disc_real_hat = discriminator(mask_gt, origin_image)
        disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        disc_loss.backward(retain_graph=True) # Update gradients
        disc_opt.step() # Update optimizer

        ### Update generator ###
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(generator, discriminator, mask_gt, origin_image, 
                                adv_criterion, recon_criterion, 
                                lambda_recon)
        gen_loss.backward() # Update gradients
        gen_opt.step() # Update optimizer

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() 
        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() 

    
    ### VAL ###
    val_batch_idx = 0
    generator.eval()
    discriminator.eval()
    with torch.no_grad():
      for val_idx, val_data in enumerate(tqdm(val_dl)):
          origin_val_image, val_mask_gt = val_data['image'], val_data['gt_reg_map']
          origin_val_image = torch.nn.functional.interpolate(origin_val_image, 
                                                      size=target_shape)
          val_mask_gt = torch.nn.functional.interpolate(val_mask_gt, 
                                                  size=target_shape)
          
          fake = generator(origin_val_image)
          disc_fake_hat = disc(fake.detach(), origin_val_image)
          disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
          disc_real_hat = disc(val_mask_gt, origin_val_image)
          disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
          disc_loss = (disc_fake_loss + disc_real_loss) / 2

          gen_loss = get_gen_loss(gen, disc, mask_gt, origin_image, 
                                  adv_criterion, recon_criterion, lambda_recon)  #TODO: add IOU loss
          
          val_disc_loss += disc_loss.item()
          val_gen_loss += gen_loss.item()

          val_batch_idx += 1
          
    ## EPOCH RESULTS ##
    losses = {'train_disc_loss': mean_discriminator_loss / idx,
              'train_gen_loss': mean_generator_loss / idx,
              'val_disc_loss': val_disc_loss / val_idx,
              'val_gen_loss': val_gen_loss / val_idx}
    losses_list.append(losses)
    
    if losses['val_disc_loss'] < best_disc_loss:
      best_disc_loss = losses['val_disc_loss']
      save_disc_model_checkpoint(model=generator, epoch=epoch, 
                                 output_path=config['Paths']['output_folder'],
                                 best=True)
    
    elif epoch %10 == 0:
      save_disc_model_checkpoint(model=generator, epoch=epoch, 
                                 output_path=config['Paths']['output_folder'],
                                 best=False)
      
    if gen_scheduler and disc_scheduler is not None:
      gen_scheduler.step(losses['val_gen_loss'])
      disc_scheduler.step(losses['val_disc_loss'])

    print("Epoch {}:\n \t train gen loss: {:.5f}, val gen loss: {:.5f},\n \t \
           train disc loss: {:.5f}, \
           val disc loss: {:.5f}".format(epoch, losses['train_gen_loss'],
                                         losses['val_gen_loss'], 
                                         losses['train_disc_loss'],
                                         losses['val_disc_loss']))
    


HBox(children=(FloatProgress(value=0.0, max=151.0), HTML(value='')))

NameError: ignored