# Imports

In [None]:
import os
import pickle
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from dataloaders import data_loaders
from DeepLabV3Plus_ResNet50 import Deeplabv3Plus
from google.colab import drive
from plot_utils import (
    plot_training_validation_losses,
    show_images_and_masks,
    visualize_predictions,
)
from predict import make_predictions
from torch.optim.lr_scheduler import CosineAnnealingLR
from train import train_and_validate

# Prepare the Sets

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

In [None]:
# Seed every global RNG this notebook has in scope (Python's `random`,
# NumPy's legacy global RNG and torch's default generators) so anything
# drawing from them is repeatable across runs.
# NOTE(review): bit-exact GPU reproducibility additionally needs cuDNN
# determinism settings, and DataLoader workers (if any) need their own seeding.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)  # numpy was imported but previously left unseeded
torch.manual_seed(random_seed)  # seeds the RNG on all devices (CPU and CUDA)

In [None]:
drive.mount('/content/drive')
!unzip -q /content/drive/MyDrive/carseg_arrays.zip

In [None]:
# Where the extracted .npy arrays live, and where the training metrics are pickled.
SOURCE_DIR, pickle_file_path = (
    "/content/dataset",
    "/content/deeplabv3plus_train_metrics.bin",
)

In [None]:
# Held-out test files: photo_0001.npy ... photo_0030.npy.
test_list = [f"photo_{i:04d}.npy" for i in range(1, 31)]
test_set = [os.path.join(SOURCE_DIR, filename) for filename in test_list]

# photo_0032.npy ... photo_0168.npy are oversampled: each gets
# `replication_factor` extra copies on top of the single copy already picked
# up from the directory listing below (1 + replication_factor in total).
# NOTE(review): photo_0031.npy is neither a test file nor replicated — confirm
# that gap is intentional.
replicate_list = [f"photo_{i:04d}.npy" for i in range(32, 169)]
replication_factor = 2
replicated_list = [
    filename for filename in replicate_list for _ in range(replication_factor)
]

# Every regular file in SOURCE_DIR that is not a test file. os.listdir order
# is arbitrary, so sort it to make seeded shuffling/splitting downstream
# reproducible; the set keeps the exclusion check O(1) per file.
_test_files = set(test_list)
train_list = sorted(
    filename
    for filename in os.listdir(SOURCE_DIR)
    if filename not in _test_files
    and os.path.isfile(os.path.join(SOURCE_DIR, filename))
)
train_list += replicated_list

# Derive the paths from train_list so the two lists always agree (previously
# train_set skipped the isfile() filter and could include directories).
# NOTE(review): if data_loaders carves the validation split out of train_set,
# copies of the same replicated photo may land in both train and val — confirm.
train_set = [os.path.join(SOURCE_DIR, filename) for filename in train_list]

In [None]:
# Build the loaders for this architecture. Only the train and test file lists
# are passed in, so the validation loader is presumably derived from train_set
# inside data_loaders — confirm there.
model_type = "deeplabv3+"
loaders = data_loaders(model_type, train_set, test_set)
train_loader, val_loader, test_loader = loaders

In [None]:
# Pull a single training batch and display a few image/mask pairs as a
# quick visual sanity check of the data pipeline.
data_iterator = iter(train_loader)
first_batch = next(data_iterator)
images, masks = first_batch
show_images_and_masks(images, masks, num_samples=6)

In [None]:
# List the distinct label values found in each mask of the sampled batch
# (1-based numbering in the printout).
for sample_idx, mask in enumerate(masks):
    unique_labels = np.unique(mask)
    print(f"Unique labels in mask {sample_idx + 1}: {unique_labels}")

# Training

In [None]:
EPOCHS = 30
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5  # weight-decay coefficient intended for the optimizer
PATIENCE = 7  # passed to train_and_validate; presumably early-stopping patience (epochs)
# Length of the cosine LR schedule. Tied to EPOCHS (was a duplicated literal)
# so changing the run length keeps one full annealing cycle — assumes the
# scheduler is stepped once per epoch; confirm in train_and_validate.
T_MAX = EPOCHS
ETA_MIN = 1e-6  # lower bound for the cosine-annealed learning rate
INPUT_CHANNELS = 3  # presumably RGB input images
NUM_CLASSES = 9  # number of segmentation classes the model predicts

In [None]:
# Instantiate DeepLabV3+ (ResNet-50 backbone) and move its parameters to
# DEVICE; Module.to works in place and returns the module itself.
model = Deeplabv3Plus(INPUT_CHANNELS, NUM_CLASSES)
model.to(device=DEVICE)

In [None]:
# Pixel-wise multi-class loss. nn.CrossEntropyLoss expects raw, unnormalised
# logits. NOTE(review): assumes the model outputs (N, NUM_CLASSES, H, W) logits
# and the masks hold integer class ids — confirm against model and loader.
criterion = nn.CrossEntropyLoss()

# WEIGHT_DECAY is declared with the other hyperparameters but was never handed
# to the optimizer, so training ran with no weight decay at all. Adam applies
# it as a coupled L2 penalty; optim.AdamW would decouple it if preferred.
optimizer = optim.Adam(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

# Cosine-anneal the learning rate from LEARNING_RATE down to ETA_MIN over
# T_MAX scheduler steps (presumably one step per epoch in train_and_validate).
scheduler = CosineAnnealingLR(optimizer, T_max=T_MAX, eta_min=ETA_MIN)

In [None]:
# Positional arguments, in the order train_and_validate takes them.
training_args = (
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    NUM_CLASSES,
    EPOCHS,
    DEVICE,
    PATIENCE,
)
# Returns a mapping of per-epoch metric series, unpacked in the next cell.
history = train_and_validate(*training_args, use_scheduler=True)

In [None]:
# Pull each metric series out of the training history, in a fixed key order.
(
    train_losses,
    val_losses,
    iou_scores,
    dice_scores,
    class_iou,
    class_dice,
    pixel_accuracies,
) = (
    history[metric_name]
    for metric_name in (
        "train_losses",
        "val_losses",
        "iou_scores",
        "dice_scores",
        "class_iou",
        "class_dice",
        "pixel_accuracies",
    )
)

In [None]:
# Reassemble the metric series (same keys, same insertion order) and pickle
# them to pickle_file_path so the run can be analysed without retraining.
history = dict(
    train_losses=train_losses,
    val_losses=val_losses,
    iou_scores=iou_scores,
    dice_scores=dice_scores,
    class_iou=class_iou,
    class_dice=class_dice,
    pixel_accuracies=pixel_accuracies,
)

with open(pickle_file_path, mode="wb") as metrics_file:
    pickle.dump(history, metrics_file)

In [None]:
# x-axis: 1-based epoch numbers, one per recorded validation loss (presumably
# fewer than EPOCHS when training stops early).
epoch_count = len(val_losses)
epochs = range(1, epoch_count + 1)
plot_training_validation_losses(epochs, train_losses, val_losses)

In [None]:
# NOTE(review): this is the arithmetic mean of the per-epoch Dice values over
# *all* epochs run (weak early epochs included), not the final or best score.
dice_total = sum(dice_scores)
mean_dice_score = dice_total / len(dice_scores)
print(f"Mean Dice Coefficient: {mean_dice_score}")

In [None]:
# NOTE(review): arithmetic mean of the per-epoch IoU values over every epoch
# run — not the IoU of the final or best epoch.
iou_total = sum(iou_scores)
mean_iou_score = iou_total / len(iou_scores)
print(f"Mean IoU Score: {mean_iou_score}")

In [None]:
# NOTE(review): arithmetic mean of the per-epoch pixel accuracies over every
# epoch run — not the accuracy of the final or best epoch.
accuracy_total = sum(pixel_accuracies)
mean_pixel_accuracy = accuracy_total / len(pixel_accuracies)
print(f"Mean Pixel Accuracy: {mean_pixel_accuracy}")

In [None]:
# NOTE(review): this saves the model's *current* weights after training. The
# "_best" suffix is only accurate if train_and_validate restores the best
# weights before returning — confirm. The path is relative to the working
# directory, unlike the absolute pickle_file_path used for the metrics.
checkpoint_path = "deeplabv3plus_best.pth"
torch.save(model.state_dict(), checkpoint_path)

# Inference

In [None]:
# Evaluate the trained model on the held-out test loader. The result is
# indexed by "iou", "dice" and "pixel_accuracy" in the next cell;
# num_samples presumably limits how many examples are visualised — confirm
# in predict.py.
inference_args = (model, test_loader, DEVICE, NUM_CLASSES)
predictions = make_predictions(*inference_args, num_samples=4)

In [None]:
# Unpack the aggregate test-set metrics and print them.
# NOTE: this rebinds mean_pixel_accuracy, replacing the training-history mean
# computed earlier in the notebook.
mean_iou, mean_dice, mean_pixel_accuracy = (
    predictions[key] for key in ("iou", "dice", "pixel_accuracy")
)

for label, value in (
    ("Mean IoU", mean_iou),
    ("Mean Dice", mean_dice),
    ("Mean Pixel Accuracy", mean_pixel_accuracy),
):
    print(f"{label}: {value}")