In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torchvision import datasets, models, transforms
import torchvision

import os
import sys
sys.path.insert(0, "../utils")
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from tqdm import tqdm

import src.datasets.cityscapes_loader as cityscapes_loader
import utils.train_eval as train_eval
import importlib
import visualizations as vis

%load_ext autoreload
%autoreload 2



In [11]:
importlib.reload(cityscapes_loader)

is_sequence = True

# Defaults to the shared cluster location; set CITYSCAPES_ROOT to run elsewhere.
dataset_root_dir = os.environ.get("CITYSCAPES_ROOT", "/home/nfs/inf6/data/datasets/cityscapes/")

# One resolution shared by both splits. A previously commented-out variant
# loaded the validation split at (1024, 2048) instead.
img_size = (512, 1024)

train_ds = cityscapes_loader.cityscapesLoader(
    root=dataset_root_dir, split='train', img_size=img_size,
    is_transform=True, is_sequence=is_sequence)
val_ds = cityscapes_loader.cityscapesLoader(
    root=dataset_root_dir, split='val', img_size=img_size,
    is_transform=True, is_sequence=is_sequence)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=2, shuffle=True, drop_last=True)
valid_loader = torch.utils.data.DataLoader(val_ds, batch_size=1, shuffle=False, drop_last=True)

Found 2975 train images
Found 500 val images


In [3]:
from src.architectures.architecture_configs import *
import src.architectures.Temporal_UNET_Template as Temporal_UNET_Template
import src.architectures.UNet_Template as UNet_Template
import utils.utils
import utils

"""
# for BaselineVanillaSmallShallow
encoder_blocks = SmallShallow_NetworkSize.encoder_blocks
decoder_blocks = SmallShallow_NetworkSize.decoder_blocks

config = Temporal_VanillaUNetConfig(
    encoder_blocks=encoder_blocks,
    decoder_blocks=decoder_blocks,
    temporal_cell= Conv2dRNNCell
    )

temp_unet = Temporal_UNET_Template.Temporal_UNet(config)
"""
"""
# for BaselineVanillaSmallDeep
encoder_blocks = SmallDeep_NetworkSize.encoder_blocks
decoder_blocks = SmallDeep_NetworkSize.decoder_blocks

config = Temporal_VanillaUNetConfig(
    encoder_blocks=encoder_blocks,
    decoder_blocks=decoder_blocks,
    temporal_cell= Conv2dRNNCell
    )
"""

# for BaselineVanillaOriginalSizes
encoder_blocks = Original_Dimensions.encoder_blocks
decoder_blocks = Original_Dimensions.decoder_blocks


config = Temporal_VanillaUNetConfig(
        encoder_blocks=encoder_blocks,
        decoder_blocks=decoder_blocks,
        temporal_cell= Conv2dRNNCell
        )

unet = UNet_Template.UNet(config)

unet_optim = torch.optim.Adam(unet.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

utils.utils.load_model(unet, unet_optim, "models/UNet_Original/UNet_Original/Layers3_InitDim64cityscapes_epoch_79.pth", "cuda")

# epochs=100
# temp_unet_trainer = utils.train_eval.Trainer(
#             temp_unet, temp_unet_optim, criterion,
#             train_loader, valid_loader, "cityscapes", epochs,
#             sequence=True, all_labels=20, start_epoch=82)

# load_model = True
# if load_model:
#     temp_unet_trainer.load_model("cityscapes")

(UNet(
   (encoder): UNetEncoder(
     (in_block): ConvBlock(
       (0): ConvWithNorm(
         (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
         (batchnorm2d): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (activation): ReLU(inplace=True)
       )
       (1): ConvWithNorm(
         (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
         (batchnorm2d): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         (activation): ReLU(inplace=True)
       )
     )
     (downsample_blocks): ModuleList(
       (0): DownsampleBlock(
         (downsampling): Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
         (conv_block): ConvBlock(
           (0): ConvWithNorm(
             (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
             (batchnorm2d): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 

In [12]:
import shutil
import os

@torch.no_grad()
def vis_seq(model, loader):
    """Plot the decoded prediction for the first sample of the first batch.

    Args:
        model: segmentation network. It is switched to eval mode and called on
            the whole batch from ``loader`` after moving it to CUDA.
        loader: iterable yielding ``(imgs, targets)`` pairs. Only the first
            batch is used, and ``targets`` is ignored.
    """
    model.eval()
    for imgs, _targets in loader:
        preds = model(imgs.cuda())
        print(preds.shape)
        if preds.shape[0] > 0:
            # get_decoded_img_seq returns ONE decoded image, not a list of
            # frames. Iterating over it (as before) plotted it row by row.
            plt.imshow(get_decoded_img_seq(preds[0]))
            plt.show()
        # Only the first batch is visualised.
        break
    return

@torch.no_grad()
def save_vis_seq(model, loader, model_name="default", max_batches=22):
    """Save per-frame input images and decoded predictions as PNG files.

    For each batch ``k`` and each frame ``i`` along dim 1 of the input, writes
        imgs/<model_name>/<k>/original/imgs_<i>.png   (un-normalized input)
        imgs/<model_name>/<k>/predicted/imgs_<i>.png  (decoded prediction)
    Only the first sample of each batch is written, for both outputs.

    Args:
        model: per-frame segmentation network, called on ``imgs[:, i]``.
        loader: iterable yielding ``(imgs, targets)``. ``imgs`` is indexed as
            (batch, frame, channel, H, W) -- assumed from the indexing, TODO confirm.
        model_name: sub-directory of ``imgs`` to write into.
        max_batches: number of batches to export. The default 22 matches the
            old hard-coded ``k > 20`` cut-off; ``None`` exports every batch.
    """
    # Normalization statistics undone before saving the inputs.
    # NOTE(review): assumed to match the loader's normalization -- confirm.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    out_root = os.path.join("imgs", f"{model_name}")
    os.makedirs(out_root, exist_ok=True)

    model.eval()
    for k, (imgs, _targets) in enumerate(loader):
        original_dir = os.path.join(out_root, f"{k}", "original")
        predicted_dir = os.path.join(out_root, f"{k}", "predicted")
        os.makedirs(original_dir, exist_ok=True)
        os.makedirs(predicted_dir, exist_ok=True)

        # Un-normalize the whole clip once per batch, instead of once per frame.
        unnorm_imgs = imgs.cpu() * std + mean
        imgs = imgs.cuda()
        for i in range(imgs.shape[1]):
            preds = model(imgs[:, i, :, :, :])
            decoded = get_decoded_img_seq(preds[0])
            torchvision.utils.save_image(
                torch.from_numpy(decoded.transpose(2, 0, 1)),
                os.path.join(predicted_dir, f"imgs_{i}.png"))
            # Sample 0 only, so the original matches the saved prediction.
            torchvision.utils.save_image(
                unnorm_imgs[0, i],
                os.path.join(original_dir, f"imgs_{i}.png"))
        print(k)

        if max_batches is not None and k + 1 >= max_batches:
            break

    return

def get_decoded_img_seq(preds):
    """Decode a prediction tensor into one colour-coded segmentation image.

    Takes the arg-max of ``preds`` over dim 1 and decodes index 0 of the
    result with ``CityscapesVisualizer.decode_segmap``.

    NOTE(review): if ``preds`` is a single frame shaped (classes, H, W), dim 1
    is H rather than the class axis. The callers presumably pass a tensor with
    a leading time dim, e.g. (T, classes, H, W), so this decodes frame 0.
    TODO confirm against the model's output shape.

    Args:
        preds: prediction tensor (see note above).

    Returns:
        The array produced by ``decode_segmap``. Callers transpose it with
        ``(2, 0, 1)``, i.e. treat it as one (H, W, 3) image.
    """
    visualizer = vis.CityscapesVisualizer()
    predicted_class = torch.argmax(preds, dim=1)[0]
    return visualizer.decode_segmap(predicted_class.cpu().numpy())

# Export original and predicted frames for a subset of the validation clips.
save_vis_seq(model=unet, loader=valid_loader, model_name="BaselineVanillaOriginalSizes_low_res")



0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21


In [8]:
import imageio
import PIL
import os

def create_gifs(model_name="default", mode="side-by-side", transparency=0.5, fps=8, num_frames=12):
    """Assemble the PNG frames under imgs/<model_name>/ into one GIF per clip.

    Every sub-directory of imgs/<model_name> except "gifs" is treated as one
    clip containing original/imgs_<j>.png and predicted/imgs_<j>.png. Each clip
    is written to imgs/<model_name>/gifs/<clip>.gif.

    Args:
        model_name: sub-directory of ``imgs`` to read from and write into.
        mode: "side-by-side" pastes the prediction to the right of the
            original; "overlay" pastes a semi-transparent prediction onto it.
        transparency: overlay mode only. 0 gives an opaque prediction, 1 a
            fully transparent one.
        fps: frame rate forwarded to ``imageio.mimsave``.
        num_frames: frames per clip (imgs_0 ... imgs_<num_frames - 1>). ``None``
            uses as many frames as there are imgs_*.png files in the clip's
            "original" folder.

    Raises:
        ValueError: if ``mode`` is not one of the allowed modes.
    """
    import numpy as np
    # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
    from PIL import Image

    allowed_modes = ["side-by-side", "overlay"]
    if mode not in allowed_modes:
        raise ValueError(f"mode must be one of {allowed_modes}")

    imgs_root = f"imgs/{model_name}"
    gifs_dir = os.path.join(imgs_root, "gifs")
    # Sorted for a deterministic order; the "gifs" output folder is not a clip.
    clip_names = sorted(
        item for item in os.listdir(imgs_root)
        if item != "gifs" and os.path.isdir(os.path.join(imgs_root, item))
    )
    os.makedirs(gifs_dir, exist_ok=True)

    def _load_rgb(path):
        # Copy the pixels into memory so the file handle is closed right away.
        with Image.open(path) as im:
            return im.convert("RGB")

    for clip in clip_names:
        original_dir = os.path.join(imgs_root, clip, "original")
        predicted_dir = os.path.join(imgs_root, clip, "predicted")
        if num_frames is None:
            n_frames = sum(
                1 for name in os.listdir(original_dir)
                if name.startswith("imgs_") and name.endswith(".png")
            )
        else:
            n_frames = num_frames

        frames = []
        for j in range(n_frames):
            # Original and prediction always come from the SAME clip directory.
            original = _load_rgb(os.path.join(original_dir, f"imgs_{j}.png"))
            prediction = _load_rgb(os.path.join(predicted_dir, f"imgs_{j}.png"))

            if mode == "side-by-side":
                (width1, height1) = original.size
                (width2, height2) = prediction.size
                frame = Image.new('RGB', (width1 + width2, max(height1, height2)))
                frame.paste(im=original, box=(0, 0))
                frame.paste(im=prediction, box=(width1, 0))
            else:  # "overlay"
                prediction.putalpha(int(255 * (1 - transparency)))
                original.paste(prediction, (0, 0), mask=prediction)
                frame = original
            frames.append(np.asarray(frame))

        imageio.mimsave(os.path.join(gifs_dir, f"{clip}.gif"), frames, fps=fps)

# Build overlay GIFs from the frames exported for this model.
create_gifs(model_name="BaselineVanillaOriginalSizes_low_res", mode="overlay", transparency=0.45, fps=8)