# Preparations

## CityScapes download

In [None]:
from utils import pretty_extract

# !pip install -q gdown

# file_id = "1MI8QsvjW0R6WDJiL49L7sDGpPWYAQB6O"
# !gdown https://drive.google.com/uc?id={file_id}

# pretty_extract("Cityscapes.zip", ".")


## GTA5 download

In [None]:
from utils import pretty_extract

# !pip install -q gdown

# file_id = "1PWavqXDxuifsyYvs2PFua9sdMl0JG8AE"
# !gdown https://drive.google.com/uc?id={file_id}

# pretty_extract("Gta5_extended.zip", "./Gta5_extended")


## DeepLab weights download

In [None]:
# !pip install -q gdown

# file_id = "1KgYgBTmvq7UcBwKui2b4TomnbTmzJMBf"
# !gdown https://drive.google.com/uc?id={file_id}

## Dataset Utilization

In [None]:
from torch.utils.data import DataLoader
import torchvision.transforms as TF

import matplotlib.pyplot as plt

from utils import tensorToImageCompatible, decode_segmap
from datasets import cityscapes, gta5 

def test_dataset(dataset="GTA5", batch_size=3, height=512, width=1024, sample_index=0):
    """Load one batch from a dataset and plot a single sample for visual inspection.

    Draws a 2x2 figure with the input image, the colour-coded label, the label
    re-coloured from class IDs via ``decode_segmap``, and the raw class-ID map.

    Args:
        dataset: ``"Cityscapes"`` or ``"GTA5"``.
        batch_size: Number of samples in the batch pulled from the loader.
        height: Target height passed to the resize transforms.
        width: Target width passed to the resize transforms.
        sample_index: Index, inside the first batch, of the sample to plot.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Fail fast, before any transform or dataset object is built.
    if dataset not in ("Cityscapes", "GTA5"):
        raise ValueError("Wrong dataset name")

    # Images: to tensor, resize, then normalise with the mean/std constants below.
    transform = TF.Compose([
        TF.ToTensor(),
        TF.Resize((height, width)),
        TF.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225]),
    ])
    # Labels: nearest-neighbour resize so no new (interpolated) values are created.
    target_transform = TF.Compose([
        TF.ToTensor(),
        TF.Resize((height, width), interpolation=TF.InterpolationMode.NEAREST),
    ])

    if dataset == "Cityscapes":
        data = cityscapes.CityScapes(
            "./Cityscapes/Cityspaces",
            split="train",
            transform=transform,
            target_transform=target_transform,
        )
    else:  # "GTA5" -- only the first element of the returned pair is needed here
        data, _ = gta5.GTA5_dataset_splitter(
            "./Gta5_extended",
            train_split_percent=0.6,
            split_seed=42,
            augment=False,
            transform=transform,
            target_transform=target_transform,
        )

    # shuffle=False keeps the inspected sample the same from run to run.
    dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
    img_tensor, color_tensor, label = next(iter(dataloader))

    img = tensorToImageCompatible(img_tensor[sample_index])
    color = tensorToImageCompatible(color_tensor[sample_index])
    # assumes label is laid out as (batch, 1, H, W) -- TODO confirm against the dataset classes
    class_ids = label[sample_index, 0]
    decoded_from_labelId = decode_segmap(class_ids)

    fig, ax = plt.subplots(2, 2, figsize=(10, 10), layout="tight")

    panels = (
        (ax[0, 0], "Image", img),
        (ax[0, 1], "Colored by label", color),
        (ax[1, 0], "Reconstructed from class ID", decoded_from_labelId),
        (ax[1, 1], "Raw Classes", class_ids),
    )
    for axis, title, picture in panels:
        axis.set_title(title)
        axis.imshow(picture)
        axis.axis('off')

    fig.show()

# test_dataset()

Cityscapes image:

![Cityscapes example](cityscapes_example.png)

GTA5 image:

![GTA5 example](gta_example.png)


# Main

## Logging

In [None]:
# Module-level logging settings and counters. pipeline() declares these names
# global and resets them to the same values at the start of every run.
ENABLE_PRINT = False  # presumably toggles console printing -- TODO confirm where it is read
ENABLE_WANDB_LOG = True  # presumably toggles Weights & Biases logging -- TODO confirm where it is read
log_per_epoch = 20  # presumably the number of batch-level log entries per epoch -- TODO confirm
n_classes = 19  # number of classes handed to the model constructors in pipeline()

train_step = 0  # training step counter; saved in and restored from checkpoints by pipeline()
val_step = 0  # validation step counter; saved in and restored from checkpoints by pipeline()

## Device

In [None]:
import torch

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

## Machine Learning

In [None]:
# TODO: validate/batch_miou always shows a peak right when an epoch starts.
# Possibly because fewer images -> fewer classes seen -> lower denominator when calculating mIoU.
# TODO: something is wrong in validate: batch_loss goes up and batch_miou goes down while the epoch-level metrics are fine.
# TODO: is the number of log entries from validate and train different?
# Found that validate/step went back to 0 -> typo: validate_step instead of val_step.

def pipeline():
    """Train and validate a segmentation model, logging to Weights & Biases.

    Every setting (batch size, input resolution, backbone, epoch range,
    optimiser hyper-parameters, dataset) is hard-coded at the top of the body.
    The function:

    1. resets the module-level logging globals and recreates ``./models`` empty;
    2. builds the train/validation datasets and their loaders;
    3. builds the model, a cross-entropy criterion that ignores label 255, and
       an SGD optimiser;
    4. opens a wandb run and, when ``start_epoch > 0``, restores the model and
       optimiser state plus the step counters from the ``epoch_{start_epoch}``
       model artifact;
    5. for each epoch: gets the learning rate from ``poly_lr_scheduler`` (which
       is handed the optimiser), trains, validates, logs the metrics, and
       periodically saves and uploads a checkpoint;
    6. logs ``num_flops`` and ``latency`` for the model and closes the run.

    The wandb run is always closed: with exit code 1 if anything raises after
    the run was opened, normally otherwise.

    Raises:
        ValueError: If the configured dataset or backbone name is not recognised.
        AssertionError: If the start/end/max epoch settings are inconsistent.
    """
    import os
    import shutil

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision.transforms as TF
    import wandb
    from torch.utils.data import DataLoader

    from train import train, validate
    from utils import poly_lr_scheduler, num_flops, latency, log_confusion_matrix
    from datasets import cityscapes, gta5
    from models.bisenet.build_bisenet import BiSeNet
    from models.deeplabv2.deeplabv2 import get_deeplab_v2

    global device
    global n_classes
    global ENABLE_PRINT
    global ENABLE_WANDB_LOG
    global train_step
    global val_step
    global log_per_epoch

    # Reset the shared logging state so that re-running this cell starts clean.
    ENABLE_PRINT = False
    ENABLE_WANDB_LOG = True
    train_step = 0
    val_step = 0
    log_per_epoch = 20

    # Start from an empty checkpoint directory. Plain-Python equivalent of
    # `rm -rf` followed by `mkdir`, so the function does not depend on IPython
    # shell escapes or on a POSIX shell.
    models_root_dir = "./models"
    shutil.rmtree(models_root_dir, ignore_errors=True)
    os.makedirs(models_root_dir, exist_ok=True)

    B = 3
    H = 512
    W = 1024
    n_classes = 19

    backbone = "BiSeNet"
    context_path = "resnet101"

    start_epoch = 0
    end_epoch = 2
    max_epoch = 50

    assert start_epoch < end_epoch <= max_epoch, "Check your start/end/max epoch settings."

    init_lr = 0.001
    lr_decay_iter = 1  # only recorded in the wandb config below
    momentum = 0.9
    weight_decay = 5e-4
    dataset = "Cityscapes"

    transform = TF.Compose([
        TF.ToTensor(),
        TF.Resize((H, W)),
        TF.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225]),
    ])
    # Nearest-neighbour resize so label values are never interpolated.
    target_transform = TF.Compose([
        TF.ToTensor(),
        TF.Resize((H, W), interpolation=TF.InterpolationMode.NEAREST),
    ])

    # Dataset objects
    if dataset == "Cityscapes":
        data_train = cityscapes.CityScapes("./Cityscapes/Cityspaces", split="train", transform=transform, target_transform=target_transform)
        data_val = cityscapes.CityScapes("./Cityscapes/Cityspaces", split="val", transform=transform, target_transform=target_transform)
    elif dataset == "GTA5":
        data_train, data_val = gta5.GTA5_dataset_splitter("./Gta5_extended", train_split_percent=0.8, split_seed=42, transform=transform, target_transform=target_transform)
    else:
        raise ValueError("Wrong dataset name")
    train_loader = DataLoader(data_train, batch_size=B, shuffle=True)
    # NOTE(review): validation is shuffled too; kept as-is -- confirm it is intended.
    val_loader = DataLoader(data_val, batch_size=B, shuffle=True)

    # Architecture
    if backbone == "BiSeNet":
        model = BiSeNet(n_classes, context_path).to(device)
        architecture = backbone + "-" + context_path
    elif backbone == "DeepLab":
        model = get_deeplab_v2(num_classes=n_classes, pretrain=True).to(device)
        architecture = backbone
    else:
        raise ValueError("Wrong model name")

    # Loss and optimiser
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=momentum, weight_decay=weight_decay)

    # TODO: wandb can't let us reuse the same run_id, need way to manage it
    # Wandb setup and metrics. run_id is used only to name artifacts and
    # checkpoint files; it is not passed to wandb.init.
    run_name = "step_2"
    run_id = f"{run_name}_{architecture}_{dataset}"
    run = wandb.init(
        entity="Machine_learning_and_Deep_learning_labs",
        project="Semantic Segmentation",
        name=run_name,
        resume="allow", # <----------------  IMPORTANT CONFIG KEY
        config={
            "initial_learning_rate": init_lr,
            "lr_decay_iter": lr_decay_iter,
            "momentum": momentum,
            "weight_decay": weight_decay,
            "architecture": architecture,
            "dataset": dataset,
            "start_epoch": start_epoch,
            "end_epoch": end_epoch,
            "max_epoch": max_epoch,
            "batch": B,
            "lr_scheduler": "poly"
        },
    )

    # From here on the run is open: make sure it is closed even if something
    # raises (including a KeyboardInterrupt from the notebook).
    try:
        # Each metric family is plotted against its own step counter.
        wandb.define_metric("epoch/step")
        wandb.define_metric("epoch/*", step_metric="epoch/step")

        wandb.define_metric("train/step")
        wandb.define_metric("train/*", step_metric="train/step")

        wandb.define_metric("validate/step")
        wandb.define_metric("validate/*", step_metric="validate/step")

        # Loading from a starting point
        if start_epoch > 0:
            artifact = run.use_artifact(f'Machine_learning_and_Deep_learning_labs/Semantic Segmentation/{run_id}:epoch_{start_epoch}', type='model')
            artifact_dir = artifact.download()

            artifact_path = os.path.join(artifact_dir, run_id + f"_epoch_{start_epoch}.pth")

            checkpoint = torch.load(artifact_path, map_location=device)

            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

            # Continue the step counters right after the last saved values.
            train_step = checkpoint["train_step"] + 1
            val_step = checkpoint["validate_step"] + 1

        # Main Loop
        for epoch in range(start_epoch + 1, end_epoch + 1):
            print("-----------------------------")
            print(f"Epoch {epoch}")

            # The scheduler receives a 0-based iteration index, hence epoch-1.
            lr = poly_lr_scheduler(optimizer, init_lr, epoch - 1, max_iter=max_epoch)

            print(f"[Poly LR] 100xLR: {100.*lr:.6f}")

            run.log({
                "epoch/step": epoch,
                "epoch/100xlearning_rate": 100.*lr,
            })

            train_loss, train_mIou, train_hist = train(model, train_loader, criterion, optimizer)

            print(f'[Train Loss] : {train_loss:.6f} [mIoU]: {100.*train_mIou:.2f}%')

            # log_confusion_matrix("Confusion Matrix - Train", train_hist, "epoch/train_confusion_matrix", "epoch/step", epoch)
            run.log({
                    "epoch/step": epoch,
                    "epoch/train_loss": train_loss,
                    "epoch/train_mIou": 100*train_mIou
                },
                commit=True,
            )

            val_loss, val_mIou, val_hist = validate(model, val_loader, criterion)

            print(f'[Validation Loss] : {val_loss:.6f} [mIoU]: {100.*val_mIou:.2f}%')

            # log_confusion_matrix("Confusion Matrix - Validate", val_hist, "epoch/validate_confusion_matrix", "epoch/step", epoch)
            run.log({
                    "epoch/step": epoch,
                    "epoch/val_loss": val_loss,
                    "epoch/val_mIou": 100*val_mIou
                },
                commit=True
            )

            # Checkpoint every second epoch and always on the last one.
            if epoch % 2 == 0 or epoch == end_epoch:
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "train_step": train_step,
                    "validate_step": val_step,
                }

                file_name = f"{run_id}_epoch_{epoch}.pth"

                # TODO: add some tables to artifact to enable comparisons

                # Saving the progress
                file_path = os.path.join(models_root_dir, file_name)
                torch.save(checkpoint, file_path)

                print(f"Model saved to {file_path}")

                artifact = wandb.Artifact(name=run_id, type="model")
                artifact.add_file(file_path)

                # The epoch alias is what the resume branch above looks up.
                run.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])

            if (epoch % 10) == 0:
                log_confusion_matrix("Confusion Matrix - Train", train_hist, "epoch/train_confusion_matrix", "epoch/step", epoch)
                log_confusion_matrix("Confusion Matrix - Validate", val_hist, "epoch/validate_confusion_matrix", "epoch/step", epoch)

        # TODO: need to check if works
        run.config["end_epoch"] = min(end_epoch, run.config["end_epoch"])

        # Measure at the training resolution instead of repeating the literals.
        run.log({
            "model/flops": num_flops(model, H, W),
            "model/latency": latency(model, H, W)
        })
    except BaseException:
        # Mark the run as failed instead of leaving it open in the kernel.
        run.finish(exit_code=1)
        raise

    run.finish()

# wandb.finish()
# Launch the training/validation run with the settings hard-coded in pipeline().
pipeline()
