In [None]:
#| eval: false
import torch
import numpy as np
import os, time
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import torch.nn as nn
from torch.optim import Adam
from MVLidarImplementation import data
from MVLidarImplementation import model
from pathlib import Path
import matplotlib.pyplot as plt

In [None]:
#| eval: false
# --- optimisation hyper-parameters ---
INIT_LR = 1e-4       # learning rate given to the Adam optimizer
NUM_EPOCHS = 40      # number of passes of the training loop
BATCH_SIZE = 4       # samples per batch for both data loaders

# --- model ---
N_CLASSES = 7        # argument passed to the MVLidar constructor

# --- output artefacts (relative to the notebook's working directory) ---
MODEL_PATH = "mvlidar.pth"      # target of torch.save
PLOT_PATH = "plot.png"          # target of plt.savefig
TEST_PATHS = "test_paths.txt"   # NOTE(review): not referenced in the visible cells

In [None]:
#| eval: false
# Choose the device for training and evaluation, in order of preference:
# CUDA if available, otherwise MPS if available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using {DEVICE} device")

# Memory pinning in the data loaders is requested only when running on CUDA.
PIN_MEMORY = DEVICE == "cuda"

Using cpu device


In [None]:
#| eval: false
# insert the path to your dataset
train_path = Path("../train")
test_path = Path("../test")

# Each entry pairs a raw split directory with the directory that receives
# the output of data.merge_images for that split.
data_paths = [
    (train_path, "../train-merged"),
    (test_path, "../test-merged"),
]

for data_path, merged_path in data_paths:
    merged_dir_path = Path(merged_path)
    # Make sure the output directory (and any missing parents) exists first.
    merged_dir_path.mkdir(parents=True, exist_ok=True)
    data.merge_images(Path(data_path), merged_dir_path)

In [None]:
#| eval: false
# Each entry pairs the directory that will hold the remapped masks with the
# split directory handed to data.remap_segmentation_masks.
masks_paths = [
    ("../train_segmentation_mask", train_path),
    ("../test_segmentation_mask", test_path),
]

for masks_path, data_path in masks_paths:
    masks_dir = Path(masks_path)
    # Create the mask output directory up front; an existing one is reused.
    masks_dir.mkdir(parents=True, exist_ok=True)
    data.remap_segmentation_masks(data_path, masks_dir)

In [None]:
#| eval: false
# Locations of the inputs handed to SemanticDataset. These names were
# previously used without ever being defined in the notebook, so this cell
# failed with NameError on a fresh kernel.
# NOTE(review): assumed to be the directories filled by the merge_images and
# remap_segmentation_masks cells above -- confirm this is the intended pairing,
# and that SemanticDataset accepts pathlib.Path arguments.
imgs_train_path = Path("../train-merged")
masks_train_path = Path("../train_segmentation_mask")
imgs_test_path = Path("../test-merged")
masks_test_path = Path("../test_segmentation_mask")

# The only preprocessing step is conversion to a tensor.
transform = transforms.Compose([
    transforms.ToTensor()
])

# Both splits are built the same way and share the same transform.
train_dataset = data.SemanticDataset(image_folder_path=imgs_train_path,
                        mask_folder_path=masks_train_path,
                        transform=transform)

test_dataset = data.SemanticDataset(image_folder_path=imgs_test_path,
                        mask_folder_path=masks_test_path,
                        transform=transform)

In [None]:
#| eval: false
# os.cpu_count() returns None when the number of CPUs cannot be determined.
# Fall back to 0 (load batches in the main process) rather than passing a
# non-integer worker count to DataLoader.
NUM_WORKERS = os.cpu_count() or 0

# Keyword arguments shared by the training and test loaders.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS,
)

# Training samples are shuffled; the test loader keeps the dataset order.
trainLoader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

testLoader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
#| eval: false
# Build the network (constructed with N_CLASSES) and move it to the chosen device.
mvlidar = model.MVLidar(N_CLASSES).to(DEVICE)

# reduction='none' leaves the loss un-reduced so that apply_loss_binary_mask
# can zero out entries before averaging.
lossFunc = CrossEntropyLoss(reduction='none')
opt = Adam(mvlidar.parameters(), lr=INIT_LR)

# Number of batches per epoch, used to average the accumulated loss.
# len(loader) also counts the final, smaller batch that the loaders yield
# (they are created with the default drop_last=False). The previous
# len(dataset) // BATCH_SIZE missed that batch and evaluated to 0 for a
# dataset smaller than one batch.
trainSteps = len(trainLoader)
testSteps = len(testLoader)

# Per-epoch loss history, consumed by the plotting cell.
H = {"train_loss": [], "test_loss": []}

In [None]:
def apply_loss_binary_mask(pred, y):
    """Return the mean loss with entries labelled 0 zeroed out.

    The module-level ``lossFunc`` is evaluated on ``(pred, y)`` and the result
    is multiplied by a 0/1 mask that is 1 wherever ``y != 0``.

    Note that the final mean runs over *all* entries of the masked loss,
    including the zeroed ones, not only over the entries where ``y != 0``.

    Args:
        pred: Model output, forwarded unchanged to ``lossFunc``.
        y: Target tensor; entries equal to 0 contribute zero loss.

    Returns:
        The masked loss reduced with ``.mean()``.
    """
    elementwise_loss = lossFunc(pred, y)
    keep_mask = (y != 0).int()
    return (elementwise_loss * keep_mask).mean()

In [None]:
#| eval: false
# The directive above matches the cells this loop depends on (model, loaders,
# optimizer), which are all marked as not evaluated.
print("[INFO] training the network...")
startTime = time.time()

for e in tqdm(range(NUM_EPOCHS)):

    # ---- training pass ----
    mvlidar.train()

    # Running sums are plain Python floats. Adding the loss tensors themselves
    # would keep every batch's autograd history referenced for the whole epoch.
    totalTrainLoss = 0.0
    totalTestLoss = 0.0

    for x, y in trainLoader:
        x, y = x.to(DEVICE), y.to(DEVICE)

        pred = mvlidar(x)
        loss = apply_loss_binary_mask(pred, y)

        # Standard update: clear old gradients, back-propagate, step.
        opt.zero_grad()
        loss.backward()
        opt.step()

        totalTrainLoss += loss.item()

    # ---- evaluation pass (no gradient tracking) ----
    mvlidar.eval()
    with torch.no_grad():
        for x, y in testLoader:
            x, y = x.to(DEVICE), y.to(DEVICE)

            pred = mvlidar(x)
            loss = apply_loss_binary_mask(pred, y)
            totalTestLoss += loss.item()

    # Average the summed batch losses over the per-epoch step counts.
    avgTrainLoss = totalTrainLoss / trainSteps
    avgTestLoss = totalTestLoss / testSteps

    # Record this epoch's averages for the plotting cell.
    H["train_loss"].append(avgTrainLoss)
    H["test_loss"].append(avgTestLoss)
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(
        avgTrainLoss, avgTestLoss))

endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

In [None]:
#| eval: false
# The directive above matches the cells that produce `H` and `mvlidar`, which
# are all marked as not evaluated.

# plot the per-epoch training and test loss recorded in H
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["test_loss"], label="test_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(PLOT_PATH)

# serialize the model to disk
# NOTE(review): this pickles the whole module object rather than its
# state_dict, so loading the file needs the model's class to be importable.
# Kept as-is because code that reads MODEL_PATH is outside this notebook.
torch.save(mvlidar, MODEL_PATH)