This notebook was used to attempt training the OASIS model on panoptic-segmentation images.

In [None]:
#need to run at the start of every instance and restart kernerl
! pip install torchinfo
! pip install cityscapesscripts
! pip install --upgrade scipy
! pip install ipywidgets --user
! pip install nbconvert

In [None]:
! git clone "https://github.com/ItaiBear/OASIS"
! git clone "https://github.com/google-research/deeplab2.git"

fatal: destination path 'OASIS' already exists and is not an empty directory.
fatal: destination path 'deeplab2' already exists and is not an empty directory.


In [None]:
# Sanity check of the working directory: the import cell below appends hard-coded
# paths under /home/ubuntu/bearGAN to sys.path, so the notebook expects to run from there.
!pwd

/home/ubuntu/bearGAN


In [None]:
import numpy as np
import random
import torch
from torchvision import transforms as TR
from torchinfo import summary
import matplotlib.pyplot as plt
from matplotlib import gridspec
from PIL import Image
from types import SimpleNamespace
import os
import sys
import scipy
from scipy import linalg

sys.path.append("/home/ubuntu/bearGAN")
sys.path.append("/home/ubuntu/bearGAN/OASIS")
sys.path.append("/home/ubuntu/bearGAN/OASIS/dataloaders")
sys.path.append("/home/ubuntu/bearGAN/OASIS/models")
sys.path.append("/home/ubuntu/bearGAN/OASIS/utils")
sys.path.append("/home/ubuntu/bearGAN/OASIS/models/sync_batchnorm")



#sys.path.append(os.path.abspath("/content/BearGAN"))
#sys.path.append(os.path.abspath("/content/deeplab2"))
print(sys.path)
import OASIS
#import BearGAN


import models.models as models
import models.losses as losses
import dataloaders.dataloaders as dataloaders
import utils.utils as utils
from utils.fid_scores import fid_pytorch
import config
from constants import root_project_directory

from my_dataloaders import get_dataloaders
from my_utils import configure_arguments

['/home/ubuntu/bearGAN', '/usr/lib/python38.zip', '/usr/lib/python3.8', '/usr/lib/python3.8/lib-dynload', '', '/home/ubuntu/.local/lib/python3.8/site-packages', '/usr/local/lib/python3.8/dist-packages', '/usr/lib/python3/dist-packages', '/usr/lib/python3/dist-packages/IPython/extensions', '/home/ubuntu/.ipython', '/home/ubuntu/bearGAN', '/home/ubuntu/bearGAN/OASIS', '/home/ubuntu/bearGAN/OASIS/dataloaders', '/home/ubuntu/bearGAN/OASIS/models', '/home/ubuntu/bearGAN/OASIS/utils', '/home/ubuntu/bearGAN/OASIS/models/sync_batchnorm']


In [None]:
# Report the number of visible GPUs, run Python's garbage collector (printing how many
# unreachable objects it found) and release cached CUDA memory held by this kernel.
import gc

print(torch.cuda.device_count())
print(gc.collect())
# empty_cache() returns None, so call it directly instead of printing its result.
torch.cuda.empty_cache()

1
0
None


In [None]:
def preprocess_input(opt, data):
    """Move a batch to the target device and build the generator's conditioning tensor.

    Args:
        opt: options namespace. Reads ``gpu_ids`` (the string "-1" means CPU),
            ``semantic_nc`` (number of channels of the conditioning tensor) and
            ``segmentation`` ("panoptic" selects the panoptic encoding below).
            Optionally reads ``label_divisor`` (default 1000), ``ignore_label``
            (default 255) and ``ignore_class`` (default 19) for the panoptic mode.
        data: dict with ``'image'`` and ``'label'`` tensors. ``data['label']`` is
            converted to ``long`` (and, on GPU, both entries are moved to CUDA) in
            place. The label is unpacked as ``(bs, _, h, w)`` and used directly as a
            scatter index along dim 1, so it presumably has one channel — TODO confirm.

    Returns:
        ``(image, input_semantics)``: the (possibly moved) image tensor and a float32
        tensor of shape ``(bs, semantic_nc, h, w)``.

        * non-panoptic mode: one-hot encoding of the label.
        * panoptic mode: the label is split into ``semantic = label // divisor`` and
          ``instance = label % divisor``; channel ``semantic`` is set to
          ``instance + 1``, and pixels whose semantic id equals ``ignore_label`` are
          routed to channel ``ignore_class``.
    """
    use_gpu = opt.gpu_ids != "-1"
    data['label'] = data['label'].long()
    if use_gpu:
        data['label'] = data['label'].cuda()
        data['image'] = data['image'].cuda()
    label_map = data['label']
    bs, _, h, w = label_map.size()
    # Allocate on the label's own device so the scatter target always matches the index
    # (replaces the deprecated torch.cuda.FloatTensor / torch.FloatTensor constructors).
    input_label = torch.zeros(bs, opt.semantic_nc, h, w, dtype=torch.float32, device=label_map.device)

    if opt.segmentation == "panoptic":
        divisor = getattr(opt, "label_divisor", 1000)
        ignore_label = getattr(opt, "ignore_label", 255)
        ignore_class = getattr(opt, "ignore_class", 19)
        semantic_map = torch.div(label_map, divisor, rounding_mode="floor")
        semantic_map = semantic_map.masked_fill(semantic_map == ignore_label, ignore_class)
        # Integer fmod avoids the float32 rounding of large panoptic ids; the +1 makes
        # instance id 0 still produce a non-zero value in its class channel.
        instance_map = (torch.fmod(label_map, divisor) + 1).to(input_label.dtype)
        input_semantics = input_label.scatter_(1, semantic_map, instance_map)
    else:
        input_semantics = input_label.scatter_(1, label_map, 1.0)

    return data['image'], input_semantics

In [None]:
#--- read options ---#
opt = configure_arguments(train=True)
print("configured arguments")
#--- create utils ---#
timer = utils.timer(opt)
visualizer_losses = utils.losses_saver(opt)
losses_computer = losses.losses_computer(opt)
dataloader, dataloader_val = get_dataloaders(opt)
im_saver = utils.image_saver(opt)
# NOTE(review): the original author flagged this FID computer as problematic on TPUs.
fid_computer = fid_pytorch(opt, dataloader_val)

#--- create models ---#
print("creating models")
model = models.OASIS_model(opt)
model = models.put_on_multi_gpus(model, opt)

#--- create optimizers ---#
optimizerG = torch.optim.Adam(model.module.netG.parameters(), lr=opt.lr_g, betas=(opt.beta1, opt.beta2))
optimizerD = torch.optim.Adam(model.module.netD.parameters(), lr=opt.lr_d, betas=(opt.beta1, opt.beta2))

#--- the training loop ---#
already_started = False
# Reset explicitly so that (a) a stale value from an earlier run in this kernel is never
# reused and (b) the post-training section can detect that no iteration ran.
cur_iter = None
start_epoch, start_iter = utils.get_start_iters(opt.loaded_latest_iter, len(dataloader))
for epoch in range(start_epoch, opt.num_epochs):
    for i, data_i in enumerate(dataloader):
        # Skip the first `start_iter` batches of the first epoch only
        # (presumably the ones already trained on before a resume).
        if not already_started and i < start_iter:
            continue
        already_started = True
        cur_iter = epoch*len(dataloader) + i
        image, label = preprocess_input(opt, data_i)

        #--- generator update ---#
        model.module.netG.zero_grad()
        loss_G, losses_G_list = model.forward(image, label, "losses_G", losses_computer)
        # .mean() reduces the per-device values returned by the multi-GPU wrapper.
        loss_G, losses_G_list = loss_G.mean(), [loss.mean() if loss is not None else None for loss in losses_G_list]
        loss_G.backward()
        optimizerG.step()

        #--- discriminator update ---#
        model.module.netD.zero_grad()
        loss_D, losses_D_list = model.forward(image, label, "losses_D", losses_computer)
        loss_D, losses_D_list = loss_D.mean(), [loss.mean() if loss is not None else None for loss in losses_D_list]
        loss_D.backward()
        optimizerD.step()

        #--- stats update ---#
        if not opt.no_EMA:
            utils.update_EMA(model, cur_iter, dataloader, opt)
        if cur_iter % opt.freq_print == 0:
            im_saver.visualize_batch(model, image, label, cur_iter)
            timer(epoch, cur_iter)
        if cur_iter % opt.freq_save_ckpt == 0:
            utils.save_networks(opt, cur_iter, model)
        if cur_iter % opt.freq_save_latest == 0:
            utils.save_networks(opt, cur_iter, model, latest=True)
        if cur_iter % opt.freq_fid == 0 and cur_iter > 0:
            is_best = fid_computer.update(model, cur_iter)
            if is_best:
                utils.save_networks(opt, cur_iter, model, best=True)
        visualizer_losses(cur_iter, losses_G_list+losses_D_list)

#--- after training ---#
if cur_iter is None:
    # e.g. start_epoch >= opt.num_epochs or an empty dataloader.
    print("No training iterations were run; skipping final checkpoint and FID evaluation")
else:
    # NOTE(review): the final EMA statistics refresh is disabled here — confirm this is intended.
    #utils.update_EMA(model, cur_iter, dataloader, opt, force_run_stats=True)
    utils.save_networks(opt, cur_iter, model)
    utils.save_networks(opt, cur_iter, model, latest=True)
    is_best = fid_computer.update(model, cur_iter)
    if is_best:
        utils.save_networks(opt, cur_iter, model, best=True)

print("The training has successfully finished")

----------------- Options ---------------
                EMA_decay: 0.999                         	[default: 0.9999]
             add_vgg_loss: False                         
               batch_size: 14                            	[default: 20]
                    beta1: 0.0                           
                    beta2: 0.999                         
               channels_D: 64                            
               channels_G: 64                            
          checkpoints_dir: /home/ubuntu/bearGAN/pretrained_checkpoints
           continue_train: False                         	[default: True]
                 dataroot: /home/ubuntu/bearGAN/dataset  
             dataset_mode: cityscapes                    
                 freq_fid: 100000                        	[default: 5000]
               freq_print: 100                           	[default: 1000]
           freq_save_ckpt: 500                           	[default: 20000]
         freq_save_latest: 500      



--- Now computing Inception activations for real set ---




--- Finished FID stats for real set ---
creating models
Created OASIS_Generator with 68923331 parameters
Created OASIS_Discriminator with 22250389 parameters




[epoch 0/200 - iter 0], time:0.000
[epoch 0/200 - iter 100], time:6.748
[epoch 0/200 - iter 200], time:6.768
[epoch 1/200 - iter 300], time:6.722
[epoch 1/200 - iter 400], time:6.697
[epoch 2/200 - iter 500], time:6.732
[epoch 2/200 - iter 600], time:6.714
[epoch 3/200 - iter 700], time:6.703
[epoch 3/200 - iter 800], time:6.699
[epoch 4/200 - iter 900], time:6.723
[epoch 4/200 - iter 1000], time:6.726
[epoch 5/200 - iter 1100], time:6.728
[epoch 5/200 - iter 1200], time:6.710
[epoch 6/200 - iter 1300], time:6.708
[epoch 6/200 - iter 1400], time:6.725
[epoch 7/200 - iter 1500], time:6.739
[epoch 7/200 - iter 1600], time:6.808
[epoch 8/200 - iter 1700], time:6.699
[epoch 8/200 - iter 1800], time:6.708
[epoch 8/200 - iter 1900], time:6.710
[epoch 9/200 - iter 2000], time:6.758
[epoch 9/200 - iter 2100], time:6.711
[epoch 10/200 - iter 2200], time:6.726
[epoch 10/200 - iter 2300], time:6.682
[epoch 11/200 - iter 2400], time:6.696
[epoch 11/200 - iter 2500], time:6.741
[epoch 12/200 - iter