In [1]:
import time
import os.path
import argparse
import numpy as np
import matplotlib.pyplot as plt
from UNet_model import *
from dataset_handler import *

In [2]:
# hyper-parameters (learning rate and how many epochs we will train for)
#lr = 0.0002
lr = 0.00001
epochs = 12

In [3]:
# cityscapes dataset loading
img_data = CityscapesDataset("/mnt/data/course/psarin/inm705/", split='train', mode='fine')
img_batch = torch.utils.data.DataLoader(img_data, batch_size=64, shuffle=True, num_workers=4)
print(img_data)
#/mnt/data/course/psarin/inm705/leftImg8bit
#/mnt/data/course/psarin/inm705/gtFine_trainvaltest/gtFine

Dataset CityscapesDataset
    Number of images: 2975
    Split: train
    Mode: gtFine
    Root Location: /mnt/data/course/psarin/inm705/





In [4]:
# loss function
# we want to do some binary classification on it (cityscapes classes)
recon_loss_func = nn.CrossEntropyLoss()
num_classes = img_data.num_classes  # background, road, sky, car

In [5]:
# initiate generator and optimizer
print("creating unet model...")
generator = nn.DataParallel(UnetGenerator(3, img_data.num_classes, 64), device_ids=[i for i in range(1)]).cuda()
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, weight_decay=0.001)

creating unet model...


In [6]:
# load pretrained model if it is there
file_model = './checkpoints_epochs/train_UNet_17epoch.pkl'
if os.path.isfile(file_model):
    generator = torch.load(file_model)
    print("    - model restored from file....")
    print("    - filename = %s" % file_model)

    - model restored from file....
    - filename = ./checkpoints_epochs/train_UNet_17epoch.pkl


In [7]:
# or log file that has the output of our loss
file_loss = open('./unet_loss', 'w')

In [8]:
# make the result directory
if not os.path.exists('./result/'):
    os.makedirs('./result/')

In [9]:
# finally!!! the training loop!!!
for epoch in range(epochs):
    for idx_batch, (imagergb, labelmask, labelrgb) in enumerate(img_batch):

        # zero the grad of the network before feed-forward
        gen_optimizer.zero_grad()

        # send to the GPU and do a forward pass
        x = Variable(imagergb).cuda(0)
        y_ = Variable(labelmask).cuda(0)
        y = generator.forward(x)

        # we "squeeze" the groundtruth if we are using cross-entropy loss
        # this is because it expects to have a [N, W, H] image where the values
        # in the 2D image correspond to the class that that pixel should be 0 < pix[u,v] < classes
        
        y_ = torch.squeeze(y_)

        # finally calculate the loss and back propagate
        loss = recon_loss_func(y, y_)
        file_loss.write(str(loss.item())+"\n")
        loss.backward()
        gen_optimizer.step()

        # every 400 images, save the current images
        # also checkpoint the model to disk
        if idx_batch % 400 == 0:

            # nice debug print of this epoch and its loss
            print("epoch = "+str(epoch)+" | loss = "+str(loss.item()))

            # save the original image and label batches to file
            #print(x.cpu().data)
            #print(labelrgb.float())
            v_utils.save_image(x.cpu().data, "./result/original_image_{}_{}.png".format(epoch, idx_batch))
            v_utils.save_image(labelrgb.float(), "./result/label_image_{}_{}.png".format(epoch, idx_batch))

            # max over the classes should be the prediction
            # our prediction is [N, classes, W, H]
            # so we max over the second dimension and take the max response
            # if we are doing rgb reconstruction, then just directly save it to file
            
            y_threshed = torch.zeros((y.size()[0], 3, y.size()[2], y.size()[3]))
            for idx in range(0, y.size()[0]):
                maxindex = torch.argmax(y[idx], dim=0).cpu().int()
                y_threshed[idx] = img_data.class_to_rgb(maxindex)
            v_utils.save_image(y_threshed, "./result/gen_image_{}_{}.png".format(epoch, idx_batch))
            

            # finally checkpoint this file to disk
            torch.save(generator, "./checkpoints_epochs/train_UNet.pkl")



epoch = 0 | loss = 1.1652849912643433




epoch = 1 | loss = 1.1659525632858276




epoch = 2 | loss = 1.1529920101165771




epoch = 3 | loss = 1.154135823249817




epoch = 4 | loss = 1.1487218141555786




epoch = 5 | loss = 1.1526705026626587




epoch = 6 | loss = 1.1709208488464355




epoch = 7 | loss = 1.1548783779144287




epoch = 8 | loss = 1.1639699935913086




epoch = 9 | loss = 1.1569418907165527




epoch = 10 | loss = 1.1505744457244873




epoch = 11 | loss = 1.155949354171753
