In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [5]:
from datasets import MultiDataset, GTA5Dataset, VistasDataset, CityscapesDataset, ADE20KDataset
from dataloaders import collate_fn, MultiDatasetBatchSampler
from models import MultiHeadUnet
from epochs import MultiHeadTrainEpoch, MultiHeadValidEpoch
from utils import get_subset

In [6]:
GTA5_DATASET_IMAGES = "/gpfs/space/home/arammul/datasets/GTA-images/images"
GTA5_DATASET_LABELS = "/gpfs/space/home/arammul/datasets/GTA-images/labels"

VISTAS_BASE_PATH = "/gpfs/space/home/arammul/datasets/vistas"
VISTAS_NUM_CLASSES = 124

CITYSCAPES_IMAGES_BASE_PATH = "/gpfs/space/home/arammul/datasets/cityscapes/leftImg8bit"
CITYSCAPES_LABELS_BASE_PATH = "/gpfs/space/home/arammul/datasets/cityscapes/gtFine"

ADE20K_BASE_PATH = "/gpfs/space/home/arammul/datasets/ade20k/ADE20K_2021_17_01/images/ADE"

BATCH_SIZE = 16
DATALOADER_NUM_WORKERS=8

In [7]:
GTA5_LABEL_TO_CLASS = {
  0: 'unlabeled'           ,
  1: 'ego vehicle'         ,
  2: 'rectification border',
  3: 'out of roi'          ,
  4: 'static'              ,
  5: 'dynamic'             ,
  6: 'ground'              ,
  7: 'road'                ,
  8: 'sidewalk'            ,
  9: 'parking'             ,
 10: 'rail track'          ,
 11: 'building'            ,
 12: 'wall'                ,
 13: 'fence'               ,
 14: 'guard rail'          ,
 15: 'bridge'              ,
 16: 'tunnel'              ,
 17: 'pole'                ,
 18: 'polegroup'           ,
 19: 'traffic light'       ,
 20: 'traffic sign'        ,
 21: 'vegetation'          ,
 22: 'terrain'             ,
 23: 'sky'                 ,
 24: 'person'              ,
 25: 'rider'               ,
 26: 'car'                 ,
 27: 'truck'               ,
 28: 'bus'                 ,
 29: 'caravan'             , 
 30: 'trailer'             , 
 31: 'train'               , 
 32: 'motorcycle'          , 
 33: 'bicycle'             , 
 34: 'license plate'       
}

In [8]:
dataset_configs = [
    {
        "name": "GTA5",
        "images_path": GTA5_DATASET_IMAGES,
        "labels_path": GTA5_DATASET_LABELS,
        "num_classes": len(GTA5_LABEL_TO_CLASS),
    },
    {
        "name": "Vistas",
        "num_classes": VISTAS_NUM_CLASSES,
    },
    {
        "name": "Cityscapes",
        "num_classes": 19,
    },
    {
        "name": "ADE20K",
        "num_classes": 100,
    }
]

for config in [dataset_configs[0]]:
    image_files = sorted(os.listdir(config["images_path"]))
    image_files = [path for path in image_files if path.endswith('.png')]
    image_files = get_subset(image_files, 0.4)
    
    config["image_files"] = image_files

    X_train_val, X_test = train_test_split(image_files, test_size=0.1, random_state=42)
    X_train, X_val = train_test_split(X_train_val, test_size=0.1111, random_state=42)

    config["X_train"] = X_train
    config["X_val"] = X_val
    config["X_test"] = X_test

for config in [dataset_configs[3]]:
    image_files = ADE20KDataset.get_image_files(f"{ADE20K_BASE_PATH}/training/")
    image_files = get_subset(image_files, 0.4)
    
    X_train, X_val = train_test_split(image_files, test_size=0.1, random_state=42)
    config["X_train"] = X_train
    config["X_val"] = X_val
    image_files_test = ADE20KDataset.get_image_files(f"{ADE20K_BASE_PATH}/validation/")
    config["X_test"] = X_test

In [9]:
# dataset_1 = GTA5Dataset(
#     dataset_name="gta5_1",
#     image_dir=dataset_configs[0]["images_path"], 
#     label_dir=dataset_configs[0]["labels_path"],
#     image_files=dataset_configs[0]["X_train"],
#     num_classes=len(GTA5_LABEL_TO_CLASS),
#     resize_dims=(526, 957), 
#     crop_dims=(512, 512)
# )

# dataset_2 = GTA5Dataset(
#     dataset_name="gta5_2",
#     image_dir=dataset_configs[1]["images_path"], 
#     label_dir=dataset_configs[1]["labels_path"],
#     image_files=dataset_configs[1]["X_train"],
#     num_classes=len(GTA5_LABEL_TO_CLASS),
#     resize_dims=(526, 957), 
#     crop_dims=(512, 512)
# )

# datasets = [dataset_1, dataset_2]
# multi_dataset = MultiDataset(datasets)

In [10]:
# batch_sampler = MultiDatasetBatchSampler(datasets, batch_size=BATCH_SIZE)

In [11]:
# dataloader = DataLoader(
#     dataset=multi_dataset,
#     batch_sampler=batch_sampler,
#     collate_fn=collate_fn,
#     num_workers=1
# )

# Model

In [12]:
ENCODER = 'resnet101'
ENCODER_WEIGHTS = 'imagenet'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [13]:
model = MultiHeadUnet(
    dataset_configs=dataset_configs,
    encoder_name=ENCODER, 
    encoder_weights=ENCODER_WEIGHTS, 
)

In [14]:
train_dataset_1 = GTA5Dataset(
    dataset_name="GTA5",
    image_dir=dataset_configs[0]["images_path"], 
    label_dir=dataset_configs[0]["labels_path"],
    image_files=dataset_configs[0]["X_train"],
    num_classes=dataset_configs[0]["num_classes"],
    preprocessing_fn=preprocessing_fn,
    resize_dims=(526, 957), 
    crop_dims=(512, 512)
)

train_dataset_2 = VistasDataset(
    dataset_name="Vistas",
    image_dir=VISTAS_BASE_PATH + "/training/images/", 
    label_dir=VISTAS_BASE_PATH + "/training/v2.0/labels/",
    image_files=get_subset(sorted(os.listdir(VISTAS_BASE_PATH + "/training/images/")), 0.4),
    num_classes=VISTAS_NUM_CLASSES,
    preprocessing_fn=preprocessing_fn,
    downscale_to_height = 512,
    crop_dims=(512, 512)
)

train_dataset_3 = CityscapesDataset(
    dataset_name = "Cityscapes",
    image_dir = f"{CITYSCAPES_IMAGES_BASE_PATH}/train/", 
    label_dir = f"{CITYSCAPES_LABELS_BASE_PATH}/train/", 
    image_files = CityscapesDataset.get_image_files(f"{CITYSCAPES_IMAGES_BASE_PATH}/train/"), 
    preprocessing_fn=preprocessing_fn,
    downscale_to_height=512,
    crop_dims=(512, 512)
)

train_dataset_4 = ADE20KDataset(
    dataset_name = "ADE20K",
    image_dir = f"{ADE20K_BASE_PATH}/training/", 
    label_dir = f"{ADE20K_BASE_PATH}/training/", 
    image_files = dataset_configs[0]["X_train"], 
    preprocessing_fn=preprocessing_fn,
    downscale_to_height=512,
    crop_dims=(512, 512)
)

train_datasets = [train_dataset_1, train_dataset_2, train_dataset_3, train_dataset_4]
train_multi_dataset = MultiDataset(train_datasets)
train_batch_sampler = MultiDatasetBatchSampler(train_datasets, batch_size=BATCH_SIZE)

valid_dataset_1 = GTA5Dataset(
    dataset_name="GTA5",
    image_dir=dataset_configs[0]["images_path"], 
    label_dir=dataset_configs[0]["labels_path"],
    image_files=dataset_configs[0]["X_val"],
    num_classes=dataset_configs[0]["num_classes"],
    preprocessing_fn=preprocessing_fn,
    resize_dims=(526, 957), 
    crop_dims=(512, 512)
)

valid_dataset_2 = VistasDataset(
    dataset_name="Vistas",
    image_dir=VISTAS_BASE_PATH + "/validation/images/", 
    label_dir=VISTAS_BASE_PATH + "/validation/v2.0/labels/",
    image_files=get_subset(sorted(os.listdir(VISTAS_BASE_PATH + "/validation/images/")), 0.4),
    num_classes=VISTAS_NUM_CLASSES,
    downscale_to_height = 512,
    preprocessing_fn=preprocessing_fn,
    crop_dims=(512, 512)
)

valid_dataset_3 = CityscapesDataset(
    dataset_name = "Cityscapes",
    image_dir = f"{CITYSCAPES_IMAGES_BASE_PATH}/val/", 
    label_dir = f"{CITYSCAPES_LABELS_BASE_PATH}/val/", 
    image_files = CityscapesDataset.get_image_files(f"{CITYSCAPES_IMAGES_BASE_PATH}/val/"), 
    preprocessing_fn=preprocessing_fn,
    downscale_to_height=512,
    crop_dims=(512, 512)
)

valid_dataset_4 = ADE20KDataset(
    dataset_name = "ADE20K",
    image_dir = f"{ADE20K_BASE_PATH}/validation/", 
    label_dir = f"{ADE20K_BASE_PATH}/validation/", 
    image_files = dataset_configs[0]["X_val"], 
    preprocessing_fn=preprocessing_fn,
    downscale_to_height=512,
    crop_dims=(512, 512)
)

valid_datasets = [valid_dataset_1, valid_dataset_2, valid_dataset_3, valid_dataset_4]
valid_multi_dataset = MultiDataset(valid_datasets)
valid_batch_sampler = MultiDatasetBatchSampler(valid_datasets, batch_size=BATCH_SIZE)

In [15]:
sum([len(d) for d in train_datasets])

26151

In [16]:
[len(d) for d in valid_datasets]

[999, 800, 500, 999]

In [17]:
train_loader = DataLoader(
    dataset=train_multi_dataset,
    batch_sampler=train_batch_sampler,
    collate_fn=collate_fn,
    num_workers=DATALOADER_NUM_WORKERS
)

valid_loader = DataLoader(
    dataset=valid_multi_dataset,
    batch_sampler=valid_batch_sampler,
    collate_fn=collate_fn,
    num_workers=DATALOADER_NUM_WORKERS
)

In [18]:
loss = utils.losses.DiceLoss(activation='softmax2d')
metrics = [
    utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.001),
])

In [19]:
train_epoch = MultiHeadTrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = MultiHeadValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [20]:
NUM_EPOCHS = 5
max_score = 0
for i in range(0, NUM_EPOCHS):
    
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model-4-datasets.pth')
        print('Model saved!')


Epoch: 0
train:   2%| | 35/1634 [08:29<6:27:54, 14.56s/it, dice_loss - 0.2135, iou_score 


KeyboardInterrupt: 