In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.tensorboard import SummaryWriter

from model import unet, Discriminator

import os
import random

import matplotlib.pyplot as plt

In [2]:
!pip install pytorch_model_summary

Collecting pytorch_model_summary
  Downloading https://files.pythonhosted.org/packages/fe/45/01d67be55fe3683a9221ac956ba46d1ca32da7bf96029b8d1c7667b8a55c/pytorch_model_summary-0.1.2-py3-none-any.whl
Installing collected packages: pytorch-model-summary
Successfully installed pytorch-model-summary-0.1.2


In [None]:
# Hyper-parameters and output locations read by the training cell.
Config = {
    'num_epochs': 50,
    'batch_size': 60,
    'learning_rate': 0.001,   # passed as `lr` to the torch.optim.Adadelta optimizers
    'disc_loss_coeff': 1.0,   # weight of the adversarial term in the generator's total loss
    # Absolute output directories; adjust for the machine you train on.
    'gen_model_path': '/home/ubuntu/generator/',
    'disc_model_path': '/home/ubuntu/discriminator/',
}
# Backward-compatible alias for the original misspelled key, in case other
# cells still read it.
Config['batch_szie'] = Config['batch_size']

In [None]:
def _progress(iterable):
  """Wrap `iterable` in a tqdm progress bar if tqdm is installed, else return it unchanged."""
  try:
    from tqdm.auto import tqdm
  except ImportError:
    return iterable
  return tqdm(iterable)


def _log_clamped(x, eps=1e-8):
  """Elementwise log(x) with x clamped to >= eps so a saturated output cannot give -inf."""
  return torch.log(torch.clamp(x, min=eps))


def _save_state(model, directory, filename):
  """Write model.state_dict() to directory/filename (creating directory) and return the path."""
  os.makedirs(directory, exist_ok=True)
  path = os.path.join(directory, filename)
  torch.save(model.state_dict(), path)
  return path


def train(train_dataloader=None, generator=None, discriminator=None,
          config=None, writer=None):
  """Train the generator adversarially against the discriminator.

  Each batch performs three updates, in this order:
    1. generator step on the reconstruction (MSE) loss only;
    2. discriminator step on real targets vs. *detached* generator outputs,
       maximising mean(log D(real) + log(1 - D(fake)));
    3. generator step on reconstruction + disc_loss_coeff * mean(log(1 - D(fake))),
       using a fresh generator forward pass.
  Per-epoch mean losses are logged with `writer.add_scalar` under the tags
  'discriminator_loss', 'generator_loss' and 'total_generator_loss'. After the
  last epoch the weights are written to
  <gen_model_path>/generator_final.pth and <disc_model_path>/discriminator_final.pth.

  Args:
    train_dataloader: iterable of (inputs, target) batch pairs. Defaults to a
      module-level `train_dataloader`; ValueError if neither is available.
    generator: module to train; defaults to unet(4, 4).
    discriminator: module to train; defaults to Discriminator(4).
      NOTE(review): the log terms assume the discriminator outputs
      probabilities in (0, 1) (e.g. a final sigmoid) — TODO confirm in model.py.
    config: settings dict; defaults to the notebook-level `Config`.
    writer: object with `add_scalar(tag, value, global_step=...)`; defaults to
      a new SummaryWriter, which is closed when training ends.

  Raises:
    ValueError: if no training dataloader was passed or defined globally.
  """
  cfg = Config if config is None else config

  if train_dataloader is None:
    # The original cell relied on a global defined elsewhere in the notebook.
    train_dataloader = globals().get('train_dataloader')
    if train_dataloader is None:
      raise ValueError(
          "train() needs a train_dataloader: pass one or define a "
          "module-level `train_dataloader`.")

  if generator is None:
    generator = unet(4, 4)
  if discriminator is None:
    discriminator = Discriminator(4)

  recon_criterion = nn.MSELoss()

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  generator.to(device)
  discriminator.to(device)

  lr = cfg['learning_rate']
  disc_coeff = cfg['disc_loss_coeff']
  # Two optimizers over the same generator parameters: one for the pure
  # reconstruction step, one for the reconstruction + adversarial step.
  opt_generator = torch.optim.Adadelta(generator.parameters(), lr=lr)
  opt_generator2 = torch.optim.Adadelta(generator.parameters(), lr=lr)
  opt_discriminator = torch.optim.Adadelta(discriminator.parameters(), lr=lr)

  owns_writer = writer is None
  if owns_writer:
    writer = SummaryWriter()

  try:
    for epoch in range(cfg['num_epochs']):
      generator.train()
      discriminator.train()

      # Plain Python floats: accumulating loss tensors would keep every
      # batch's autograd graph alive for the whole epoch.
      sum_gen = sum_disc = sum_total = 0.0
      n_batches = 0

      for inputs, target in _progress(train_dataloader):
        # inputs: the (image, trimap) input; target: the target image.
        inputs = inputs.to(device)
        target = target.to(device)

        # 1) Generator: reconstruction loss only.
        opt_generator.zero_grad()
        output_gen = torch.sigmoid(generator(inputs))
        gen_loss_batch = recon_criterion(output_gen, target)
        gen_loss_batch.backward()
        opt_generator.step()

        # 2) Discriminator: maximise log D(real) + log(1 - D(fake)) by
        #    minimising its negative. The fake is detached so this step
        #    neither re-uses the freed generator graph nor updates the generator.
        opt_discriminator.zero_grad()
        disc_value = torch.mean(
            _log_clamped(discriminator(target))
            + _log_clamped(1 - discriminator(output_gen.detach())))
        (-disc_value).backward()
        opt_discriminator.step()

        # 3) Generator: reconstruction + adversarial (minimax) term. The
        #    forward pass is recomputed because step 1 updated the generator
        #    weights in place, invalidating the earlier graph.
        opt_generator2.zero_grad()
        output_gen = torch.sigmoid(generator(inputs))
        adv_loss = torch.mean(_log_clamped(1 - discriminator(output_gen)))
        total_loss = recon_criterion(output_gen, target) + disc_coeff * adv_loss
        total_loss.backward()
        opt_generator2.step()

        sum_gen += gen_loss_batch.item()
        sum_disc += disc_value.item()
        sum_total += total_loss.item()
        n_batches += 1

      # Mean loss per epoch; max() guards against an empty dataloader.
      denom = max(n_batches, 1)
      writer.add_scalar('discriminator_loss', sum_disc / denom, global_step=epoch)
      writer.add_scalar('generator_loss', sum_gen / denom, global_step=epoch)
      writer.add_scalar('total_generator_loss', sum_total / denom, global_step=epoch)
  finally:
    if owns_writer:
      writer.close()

  # Models are saved for the final epoch.
  _save_state(generator, cfg['gen_model_path'], 'generator_final.pth')
  _save_state(discriminator, cfg['disc_model_path'], 'discriminator_final.pth')

# Entry point: in a Jupyter kernel `__name__` is also '__main__', so executing
# this cell (or running the exported script) starts training immediately.
if __name__ == '__main__':
    train()