In [1]:
import sys
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
# The MS COCO API (pycocotools) is needed to work with the COCO dataset.
# NOTE(review): this binds the pycocotools module to the name `mask`; avoid
# reusing `mask` as a variable name elsewhere in the notebook.
from pycocotools import mask as mask
import numpy as np
import matplotlib.pyplot as plt
# Put ../utils first on the module search path so the two imports below
# (presumably project-local helper modules) resolve from there.
sys.path.insert(0, "../utils")
import utils
import train_eval

In [2]:
# Seed the random number generators through the project helper, using its
# default arguments. Which libraries it seeds is defined in utils and is not
# visible from this notebook.
utils.set_random_seed()

In [3]:
import coco_dataset
import importlib
# Re-import the module so edits to coco_dataset.py are picked up without
# restarting the kernel.
importlib.reload(coco_dataset)

# Single root of the MS COCO dataset. Every path below is derived from it, so
# only this line has to change when the data lives somewhere else.
COCO_ROOT = '/home/nfs/inf6/data/datasets/coco'

# Image directories of the 2017 train / val splits.
train_data_dir = COCO_ROOT + '/train2017'
val_data_dir = COCO_ROOT + '/val2017'

# Instance annotation files matching the two splits.
train_ann_file = COCO_ROOT + '/annotations/instances_train2017.json'
val_ann_file = COCO_ROOT + '/annotations/instances_val2017.json'

# Build the datasets in "segmentation" mode; no transforms are passed, so the
# dataset's defaults apply (resize + to-tensor according to the original note).
train_set = coco_dataset.Coco_Dataset(train_data_dir, train_ann_file, mode="segmentation")
val_set = coco_dataset.Coco_Dataset(val_data_dir, val_ann_file, mode="segmentation")

# Only the training loader shuffles; validation keeps the dataset order.
batch_size = 16
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

loading annotations into memory...
Done (t=10.64s)
creating index...
index created!
loading annotations into memory...
Done (t=0.34s)
creating index...
index created!


In [4]:
# Fetch the first training sample; indexing the dataset is unpacked here into
# an image and its segmentation mask.
img, seg_mask = train_set[0]
# Last expression of the cell, so the mask's shape is displayed as output.
seg_mask.shape

torch.Size([1, 256, 256])

In [5]:
import Temporal_UNET_Template
import importlib
# Reload so edits to Temporal_UNET_Template.py take effect without a kernel restart.
importlib.reload(Temporal_UNET_Template)

# Model configuration built with no overrides, i.e. the template's defaults.
config = Temporal_UNET_Template.Temporal_UNetConfig()

# Instantiate the network from that configuration.
temp_unet = Temporal_UNET_Template.Temporal_UNet(config)

# Adam over all model parameters with a learning rate of 3e-4.
temp_unet_optim = torch.optim.Adam(temp_unet.parameters(), lr=3e-4)

# Cross-entropy loss with default arguments.
criterion = nn.CrossEntropyLoss()

<class 'temporal_modules.Conv2dGRUCell'>


In [6]:
# Pick up any edits made to train_eval.py since it was first imported.
importlib.reload(train_eval)

epochs = 10

# Wire model, optimizer, loss and both loaders into the project's Trainer.
# NOTE(review): start_epoch equals epochs (both 10). How Trainer interprets the
# two values is not visible here; presumably this resumes an earlier run —
# confirm that training actually executes further epochs.
# NOTE(review): all_labels=91 is presumably the number of COCO category ids —
# confirm against the dataset / model output channels.
temp_unet_trainer = train_eval.Trainer(
    temp_unet,
    temp_unet_optim,
    criterion,
    train_loader,
    valid_loader,
    "coco",
    epochs,
    sequence=False,
    all_labels=91,
    start_epoch=10,
)

/home/user/schwemme/CudaLab_Project/src/tboard_logs/Temporal_UNet


In [None]:
# Start the training run configured on the Trainer above. The cell has no
# execution count, so it was not run when the notebook was saved.
temp_unet_trainer.train_model()

In [None]:
# Root directory of the Cityscapes dataset.
city_home = "/home/nfs/inf6/data/datasets/cityscapes"

# Coarse-annotation datasets with semantic targets; no transforms are passed.
cityscapes_train = torchvision.datasets.Cityscapes(root=city_home, split="train", mode="coarse", target_type="semantic")
# NOTE(review): despite its name, this loads the "train_extra" split rather
# than a test split — confirm this is the intended hold-out data.
cityscapes_test = torchvision.datasets.Cityscapes(root=city_home, split="train_extra", mode="coarse", target_type="semantic")
cityscapes_valid = torchvision.datasets.Cityscapes(root=city_home, split="val", mode="coarse", target_type="semantic")


In [None]:
# Visualise one training sample: the image, its combined segmentation mask and
# every individual instance mask.
# NOTE(review): this cell reads the second element of a sample like a dict
# ("masks", "labels", "segmentation_mask"), while the earlier shape-check cell
# unpacks train_set[0] into an object with a .shape. Presumably train_set must
# be built in a mode other than "segmentation" for this cell — confirm.
img, label_dict = train_set[6]
masks = label_dict["masks"]
labels = label_dict["labels"]
seg_mask = label_dict["segmentation_mask"]

# Left: the image (channels moved last for imshow). Right: the combined mask.
_, ax = plt.subplots(1, 2)

ax[0].imshow(img.permute(1, 2, 0), cmap="gray")
ax[1].imshow(seg_mask.squeeze(), cmap="gray")

# One panel per instance mask. plt.subplots cannot create zero columns, so the
# grid is only built when there is at least one mask.
num_masks = len(masks)
if num_masks > 0:
    # squeeze=False always returns a 2-D array of axes; without it a single
    # mask would yield a bare Axes object and x[i] would fail.
    _, x = plt.subplots(1, num_masks, figsize=(20, 20), squeeze=False)
    x = x[0]  # the grid has exactly one row
    for i in range(num_masks):
        # Work on a copy so the sample's masks stay untouched and re-running
        # the cell gives the same picture. The name `instance_mask` also avoids
        # shadowing the pycocotools `mask` module imported at the top.
        if torch.is_tensor(masks[i]):
            instance_mask = masks[i].clone()
        else:
            instance_mask = np.array(masks[i])
        # Lift zero-valued pixels to 100 before display, as the original did.
        instance_mask[instance_mask == 0] = 100
        x[i].imshow(instance_mask, cmap="gray")

