In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up before each cell executes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
from fcn_dataset import CamVidDataset, rev_normalize
import torch
import tqdm

In [None]:
# Dataset locations, all relative to the CamVid root directory.
data_root = "CamVid/"
class_dict_path = "class_dict.csv"
resolution = (240, 240)
class_dict = pd.read_csv(data_root + class_dict_path)

# Training split: shuffled so each epoch visits the samples in a new order.
images_dir = "train/"
labels_dir = "train_labels/"
camvid_dataset = CamVidDataset(root=data_root, images_dir=images_dir, labels_dir=labels_dir, class_dict_path=class_dict_path, resolution=resolution)
train_loader = torch.utils.data.DataLoader(camvid_dataset, batch_size=4, shuffle=True, num_workers=4)

# Validation split: NOT shuffled, so evaluation runs are repeatable.
val_images_dir = "val/"
val_labels_dir = "val_labels/"
val_camvid_dataset = CamVidDataset(root=data_root, images_dir=val_images_dir, labels_dir=val_labels_dir, class_dict_path=class_dict_path, resolution=resolution)
val_loader = torch.utils.data.DataLoader(val_camvid_dataset, batch_size=4, shuffle=False, num_workers=4)

# Test split: NOT shuffled, so saved predictions line up with the dataset's
# sample order (prediction k belongs to test_camvid_dataset[k]).
test_images_dir = "test/"
test_labels_dir = "test_labels/"
test_camvid_dataset = CamVidDataset(root=data_root, images_dir=test_images_dir, labels_dir=test_labels_dir, class_dict_path=class_dict_path, resolution=resolution)
test_loader = torch.utils.data.DataLoader(test_camvid_dataset, batch_size=4, shuffle=False, num_workers=4)


In [None]:
# Fetch the first training sample and report the image tensor's size.
image, label = camvid_dataset[0]
print(image.size())

# Build a viewable version of the label: divide by 31 (presumably the largest
# class index, given 32 classes -- confirm against class_dict.csv), stretch to
# the 0..255 range, and wrap the 8-bit result in a PIL image.
label_scaled = label.numpy().astype(np.float32) / 31. * 255.
label_vis = Image.fromarray(label_scaled.astype(np.uint8))

# Pass the image through rev_normalize before converting it to PIL for display.
image_vis = transforms.functional.to_pil_image(rev_normalize(image))

In [None]:
# Display the PIL preview of the sample image as this cell's output.
image_vis

In [None]:
# Display the PIL preview of the sample's rescaled label array as this cell's output.
label_vis

In [None]:
from fcn_model import FCN8s

# Training configuration: the class count is declared once and shared by the
# model constructor and the metric code further down.
num_classes = 32
num_epochs = 10

# FCN-8s network and the Adam optimizer (learning rate 1e-3) that updates it.
model = FCN8s(num_classes=num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Define the loss function, the segmentation metrics, and the evaluation helper
def loss_fn(outputs, labels):
    """
    Cross-entropy between the network outputs and the integer class labels.

    The original FCN paper describes a per-pixel multinomial logistic loss;
    that is the standard cross-entropy loss, averaged over all elements
    (the default ``reduction='mean'``).

    Args:
        outputs: Raw, unnormalized class scores with the class axis at dim 1.
        labels: Integer class indices, one per scored element.

    Returns:
        A scalar tensor holding the mean loss.
    """
    return torch.nn.functional.cross_entropy(outputs, labels)

def calculate_metrics(pred, target, num_classes):
    """
    Calculate the pixel accuracy, mean IoU, and frequency weighted IoU.

    Args:
        pred: Array-like of predicted class indices. Any shape is accepted,
            e.g. a single (H, W) map or a batched (N, H, W) stack.
        target: Array-like of ground-truth class indices, same shape as ``pred``.
        num_classes: Number of classes; indices ``0 .. num_classes - 1`` are scored.

    Returns:
        Tuple ``(pixel_acc, mean_iou, freq_iou)`` of Python floats:
          * pixel_acc - fraction of elements where ``pred == target``.
          * mean_iou  - IoU averaged over the classes that occur in ``pred`` or
            ``target``; 0.0 when none of the scored classes occurs at all.
          * freq_iou  - per-class IoU weighted by the class's share of ``target``.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)

    # Normalize by the total element count. Using shape[0] * shape[1] is only
    # right for 2-D input and undercounts batched (N, H, W) arrays by a factor of W.
    total_pixels = target.size
    pixel_acc = (pred == target).sum() / total_pixels

    ious = np.full(num_classes, np.nan)   # NaN marks "class never seen"
    class_pixels = np.zeros(num_classes)  # ground-truth pixel count per class
    for cls in range(num_classes):
        pred_mask = pred == cls
        target_mask = target == cls
        class_pixels[cls] = target_mask.sum()
        union = (pred_mask | target_mask).sum()
        # A class absent from both arrays has no defined IoU (0 / 0); leave it
        # out instead of letting NaN poison the averages.
        if union > 0:
            ious[cls] = (pred_mask & target_mask).sum() / union

    scored = ~np.isnan(ious)
    mean_iou = ious[scored].mean() if scored.any() else 0.0
    freq_iou = (class_pixels[scored] * ious[scored]).sum() / total_pixels
    return float(pixel_acc), float(mean_iou), float(freq_iou)

def eval_model(model, dataloader, device, save_pred=False):
    """
    Evaluate ``model`` on every batch of ``dataloader`` and print the results.

    The loss is the mean of the per-batch losses. Pixel accuracy, mean IoU and
    frequency weighted IoU are computed over the predictions of ALL batches,
    not just the final one. The metric code reads the notebook-level
    ``num_classes`` variable.

    Args:
        model: Network to evaluate; called as ``model(images)``.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.
        save_pred: When True, write the concatenated predictions of all batches
            to ``test_pred.npy`` in dataloader order.

    Returns:
        None. The model's previous train/eval mode is restored on exit, even
        if evaluation raises.
    """
    print("Starting eval ....")
    was_training = model.training
    model.eval()
    batch_losses = []
    pred_batches = []
    label_batches = []
    try:
        with torch.no_grad():
            for images, labels in tqdm.tqdm(dataloader):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                batch_losses.append(loss_fn(outputs, labels).item())
                # The class axis is dim 1; argmax picks the predicted class per pixel.
                pred_batches.append(outputs.argmax(dim=1).cpu().numpy())
                label_batches.append(labels.cpu().numpy())
    finally:
        model.train(was_training)

    if not batch_losses:
        # Nothing to average; avoid dividing by zero below.
        print("Evaluation skipped: dataloader yielded no batches.")
        return

    all_preds = np.concatenate(pred_batches, axis=0)
    all_labels = np.concatenate(label_batches, axis=0)
    loss = sum(batch_losses) / len(batch_losses)

    # Collapse the sample axis into the row axis so the whole evaluation set is
    # scored as one 2-D pixel grid and every pixel is counted.
    # Assumes the last axis is the image width -- TODO confirm label layout.
    width = all_preds.shape[-1]
    pixel_acc, mean_iou, freq_iou = calculate_metrics(
        all_preds.reshape(-1, width), all_labels.reshape(-1, width), num_classes
    )
    print('Pixel accuracy: {:.4f}, Mean IoU: {:.4f}, Frequency weighted IoU: {:.4f}, Loss: {:.4f}'.format(pixel_acc, mean_iou, freq_iou, loss))

    if save_pred:
        np.save('test_pred.npy', all_preds)

    
# Train the model
device = "cpu"  # single place to change the compute device
model.to(device)
model.train()  # make sure layers such as dropout/batch-norm are in training mode

loss_list = []  # losses collected since the last progress print
for epoch in range(num_epochs):
    # Iterate over train_loader, the loader built in the data-loading cell
    # (the name `dataloader_train` is never defined in this notebook).
    for i, (images, labels) in enumerate(tqdm.tqdm(train_loader)):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

        # Every 10 steps report the mean loss over the steps since the last report.
        if (i+1) % 10 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), sum(loss_list)/len(loss_list)))
            loss_list = []

    # Quick training-side indicator computed on the LAST batch of the epoch only.
    # Collapse the batch into one 2-D pixel grid before scoring so every pixel
    # is counted (assumes the last axis is the image width -- TODO confirm).
    batch_pred = torch.argmax(outputs, dim=1).cpu().numpy()
    batch_true = labels.cpu().numpy()
    width = batch_pred.shape[-1]
    pixel_acc, mean_iou, freq_iou = calculate_metrics(batch_pred.reshape(-1, width), batch_true.reshape(-1, width), num_classes)
    print('Pixel accuracy: {:.4f}, Mean IoU: {:.4f}, Frequency weighted IoU: {:.4f}'.format(pixel_acc, mean_iou, freq_iou))
    # eval the model
    eval_model(model, val_loader, device)