In [1]:
import nbimporter
import os
import torch
import torchvision.transforms.v2 as transforms
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from UNetModel import UNet
from utils import (
    load_checkpoint,
    save_checkpoint,
    check_accuracy,
    save_predictions,
)
from dataset import KITTIDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

#%run -i utils.ipynb

#from dataset import get_loaders

In [2]:
# Hyperparameters and run configuration
# Use the MPS backend when PyTorch reports it as available; otherwise fall back to the CPU.
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Optimisation settings.
LEARNING_RATE = 1e-4
BATCH_SIZE = 8
NUM_EPOCHS = 3

# Data-loading settings.
NUM_WORKERS = 4
PIN_MEMORY = True

# Size passed to the Resize transform, as (height, width).
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 600

# Locations of the input images and their ground-truth masks.
IMG_DIR = "data_road/training/image_2"
MASK_DIR = "data_road/training/gt_image_2"

# When True, the training cell restores "my_checkpoint.pth.tar" before training.
LOAD_MODEL = False

# Initial best-accuracy value; it is not updated by any of the visible cells.
best_acc = -1

In [3]:
def get_loaders(image_dir, mask_dir, batch, train_transform, val_transform, val_split=0.2, num_workers=0, pin_memory=True):
    """Split the ``.png`` files of ``image_dir`` and wrap both parts in DataLoaders.

    Args:
        image_dir: Directory that is listed for ``.png`` input images.
        mask_dir: Directory forwarded to ``KITTIDataset`` as ``mask_dir``.
        batch: Batch size used for both loaders.
        train_transform: Transform forwarded to the training dataset.
        val_transform: Transform forwarded to the validation dataset.
        val_split: Fraction of the files held out for validation; must satisfy
            ``0 <= val_split < 1``.
        num_workers: Forwarded to both ``DataLoader`` instances.
        pin_memory: Forwarded to both ``DataLoader`` instances.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles, the
        validation loader does not.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)`` or ``image_dir``
            contains no ``.png`` file.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split must be in [0, 1), got {val_split!r}")

    # Sorted so the candidate list does not depend on the order os.listdir happens to return.
    all_images = sorted(img for img in os.listdir(image_dir) if img.endswith(".png"))
    if not all_images:
        raise ValueError(f"No .png images found in {image_dir!r}")

    # Honour val_split (it used to be ignored in favour of a hard-coded 0.8);
    # the default of 0.2 yields exactly the same sizes as before.
    train_size = int((1 - val_split) * len(all_images))
    val_size = len(all_images) - train_size
    # random_split returns Subset objects over the file-name list; they are handed to
    # KITTIDataset as-is. NOTE(review): assumes KITTIDataset only indexes/len()s image_files.
    train_images, val_images = torch.utils.data.random_split(all_images, [train_size, val_size])

    train_dataset = KITTIDataset(image_dir=image_dir, mask_dir=mask_dir, image_files=train_images, transform=train_transform)
    train_loader = DataLoader(train_dataset, batch_size=batch, num_workers=num_workers, pin_memory=pin_memory, shuffle=True)

    val_dataset = KITTIDataset(image_dir=image_dir, mask_dir=mask_dir, image_files=val_images, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=batch, num_workers=num_workers, pin_memory=pin_memory, shuffle=False)

    return train_loader, val_loader

In [4]:
def train(loader, model, optimizer, loss_fn, device):
    """Run one optimisation pass over every batch of ``loader``.

    Puts ``model`` into training mode, then for each ``(data, targets)`` batch
    moves both to ``device`` (targets are cast to float first), clears the
    gradients, computes ``loss_fn(model(data), targets)``, backpropagates and
    steps ``optimizer``. The current batch loss is shown on the tqdm bar.
    Nothing is returned.
    """
    model.train()
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(device=device)
        labels = labels.float().to(device=device)

        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Show the latest batch loss next to the progress bar.
        progress.set_postfix(loss=batch_loss.item())

# Conversion step shared by both pipelines: wrap as an image tensor, then cast to
# float32 with value scaling enabled (scale=True).
tensor_transform = transforms.Compose(
    [
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ]
)

# Training pipeline: resize, then convert.
# Currently disabled here: RandomHorizontalFlip(p=0.5), RandomVerticalFlip(p=0.1) and
# Normalize(mean=[0.3509, 0.3773, 0.3662], std=[0.2796, 0.2989, 0.3114]).
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_HEIGHT, IMAGE_WIDTH)),
        tensor_transform,
    ]
)

# Validation pipeline: resize, then convert.
# Currently disabled here: Normalize(mean=[0.3509, 0.3773, 0.3662], std=[0.2796, 0.2989, 0.3114]).
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_HEIGHT, IMAGE_WIDTH)),
        tensor_transform,
    ]
)

# UNet built with three input channels and a single output channel, moved to DEVICE.
model = UNet(in_channels=3, out_channels=1).to(DEVICE)

# Binary cross-entropy computed on raw logits (the sigmoid is folded into the loss),
# which is the matching loss for a single-channel output. This is NOT a multi-class
# cross-entropy; switching to several output channels would need nn.CrossEntropyLoss.
loss_fn = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader, val_loader = get_loaders(
    image_dir=IMG_DIR, mask_dir=MASK_DIR, batch=BATCH_SIZE, train_transform=train_transforms, val_transform=val_transforms, val_split=0.2,
    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY
)

if LOAD_MODEL:
    # map_location lets a checkpoint written on another device be restored onto DEVICE.
    load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)
    check_accuracy(val_loader, model, device=DEVICE)

# Removed dead code: a GradScaler that was created but never used by train(), a
# module-level `global` statement (a no-op outside a function) and an unused best_dice.

for epoch in range(NUM_EPOCHS):
    train(train_loader, model, optimizer, loss_fn, DEVICE)

    accuracy = check_accuracy(val_loader, model, device=DEVICE)

    # A checkpoint with model and optimizer state is written after every epoch.
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Accuracy: {accuracy:.4f}")
    save_predictions(val_loader, model, folder="results", device=DEVICE)


  0%|          | 0/29 [00:00<?, ?it/s]

Unique colors before in mask: [[  0   0   0]
 [255   0   0]
 [255   0 255]]
Unique colors before in mask: [[  0   0   0]
 [255   0   0]
 [255   0 255]]
Unique colors before in mask: [[  0   0   0]
 [255   0   0]
 [255   0 255]]
Unique colors before in mask: [[255   0   0]
 [255   0 255]]
Unique colors after in mask: [[-1.262295  -1.262295  -1.262295 ]
 [-1.2550071 -1.2550071 -1.2550071]
 [-1.2550071 -1.2550071 -1.2409816]
 ...
 [ 2.3215308  2.3215308  2.1812747]
 [ 2.3215308  2.3215308  2.3075054]
 [ 2.3215308  2.3215308  2.3215308]]
Unique colors after in mask: [[-1.262295  -1.262295  -1.262295 ]
 [-1.2550071 -1.2550071 -1.2550071]
 [-1.2550071 -1.2550071 -1.2409816]
 ...
 [ 2.3215308  2.3215308  2.2934797]
 [ 2.3215308  2.3215308  2.3075054]
 [ 2.3215308  2.3215308  2.3215308]]
Unique colors after in mask: [[-1.262295  -1.262295  -1.262295 ]
 [-1.1759795 -1.1759795 -1.1759795]
 [-1.1759795 -1.1759795 -1.1633861]
 ...
 [ 2.0353246  2.0353246  2.0227313]
 [ 2.0353246  2.0353246  2.0353

  0%|          | 0/29 [00:03<?, ?it/s]


Unique colors before in mask: [[255   0   0]
 [255   0 255]]
Unique colors before in mask: [[  0   0   0]
 [255   0   0]
 [255   0 255]]
Unique colors before in mask: [[255   0   0]
 [255   0 255]]
Unique colors before in mask: [[  0   0   0]
 [255   0   0]
 [255   0 255]]
Unique colors after in mask: [[-1.262295  -1.262295  -1.262295 ]
 [-1.1759795 -1.1759795 -1.1759795]
 [-1.1759795 -1.1759795 -1.1633861]
 ...
 [ 2.0353246  2.0353246  2.0101378]
 [ 2.0353246  2.0353246  2.0353246]
 [ 2.3215308  2.3215308  2.3215308]]
Unique colors after in mask: [[-1.262295  -1.262295  -1.262295 ]
 [-1.2550071 -1.2550071 -1.2550071]
 [-1.2550071 -1.2550071 -1.2409816]
 ...
 [ 2.3215308  2.3215308  2.2934797]
 [ 2.3215308  2.3215308  2.3075054]
 [ 2.3215308  2.3215308  2.3215308]]
Unique colors after in mask: [[-1.262295  -1.262295  -1.262295 ]
 [-1.2550071 -1.2550071 -1.2550071]
 [-1.2550071 -1.2550071 -1.2269559]
 ...
 [ 2.3215308  2.3215308  2.2654285]
 [ 2.3215308  2.3215308  2.3075054]
 [ 2.32153

KeyboardInterrupt: 

246   1.7330841  -0.6722455 ]
 [ 2.0353246   1.796051    0.14632222]
 [ 2.0353246   1.8464245  -0.10554482]
 [ 2.0353246   1.8590176   0.41078252]
 [ 2.0353246   1.8590176   0.42337587]
 [ 2.0353246   1.8842043  -0.05517143]
 [ 2.0353246   1.9219846   0.5870894 ]
 [ 2.0353246   1.9345777   0.15891558]
 [ 2.0353246   1.9471712   0.49893597]
 [ 2.0353246   1.9597644   0.08335549]
 [ 2.0353246   1.9723579   0.66264945]
 [ 2.0353246   1.984951    0.05816879]
 [ 2.0353246   1.984951    1.0026698 ]
 [ 2.0353246   1.9975446   1.11601   ]
 [ 2.0353246   1.9975446   1.2041634 ]
 [ 2.0353246   2.0101378   1.2041634 ]
 [ 2.0353246   2.0101378   1.2167568 ]
 [ 2.0353246   2.0227313   1.0782299 ]
 [ 2.0353246   2.0353246   0.83895636]
 [ 2.0353246   2.0353246   1.3426905 ]
 [ 2.0353246   2.0353246   1.544184  ]
 [ 2.0353246   2.0353246   1.6197442 ]
 [ 2.0353246   2.0353246   1.6323373 ]
 [ 2.0353246   2.0353246   1.657524  ]
 [ 2.0353246   2.0353246   1.7330841 ]
 [ 2.0353246   2.0353246   1.74567

Unique colors before in mask: [[  0   0   0]
 [255   0   0]
 [255   0 255]]
Unique colors before in mask: [[255   0   0]
 [255   0 255]]
Unique colors before in mask: [[255   0   0]
 [255   0 255]]
Unique colors before in mask: [[  0   0   0]
 [255   0   0]
 [255   0 255]]
Unique colors after in mask: [[-1.262295  -1.262295  -1.262295 ]
 [-1.2550071 -1.2550071 -1.2550071]
 [-1.2550071 -1.2550071 -1.1989046]
 ...
 [ 2.3215308  2.2514029  2.209326 ]
 [ 2.3215308  2.2654285  2.1953003]
 [ 2.3215308  2.3215308  2.3215308]]
Unique colors after in mask: [[-1.262295   -1.262295   -1.262295  ]
 [-1.1759795  -1.1759795  -1.1759795 ]
 [-1.1759795  -1.1759795  -1.1633861 ]
 [-1.1759795  -1.1759795  -1.1507927 ]
 [-1.1759795  -1.1759795  -1.1381994 ]
 [-1.1759795  -1.1759795  -1.1256061 ]
 [-1.1759795  -1.1759795  -1.1004194 ]
 [-1.1759795  -1.1759795  -1.087826  ]
 [-1.1759795  -1.1759795  -1.0752326 ]
 [-1.1759795  -1.1759795  -1.050046  ]
 [-1.1759795  -1.1759795  -1.0122659 ]
 [-1.1759795  -1.

In [None]:
def val_model(model, loader, device, num_classes=3):
    """Evaluate ``model`` on ``loader`` and report pixel accuracy and mean Dice.

    Args:
        model: Module whose output has the class/channel axis at dim 1.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to before the forward pass.
        num_classes: Number of class indices ``0 .. num_classes - 1`` over
            which the Dice score is averaged.

    Returns:
        ``(accuracy, dice_score)``: the fraction of matching pixels and the Dice
        score averaged over batches and classes (0-dim tensors).

    Raises:
        ValueError: If ``loader`` yields no batches.

    The model is switched to eval mode for the pass and put back into training
    mode afterwards, even if evaluation raises.
    """
    model.eval()
    num_correct = 0
    num_pixels = 0
    dice_score = 0
    num_batches = 0

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                logits = model(x)

                if logits.shape[1] == 1:
                    # argmax over a single channel is always 0, so a binary model would be
                    # scored as "all background". Threshold the logit instead:
                    # sigmoid(z) > 0.5 is equivalent to z > 0.
                    preds = (logits > 0).long().squeeze(1)
                else:
                    preds = torch.argmax(logits, dim=1)

                # A target with a singleton channel axis would broadcast against preds in
                # the comparison below and inflate the counts; drop that axis first.
                if y.dim() == preds.dim() + 1 and y.shape[1] == 1:
                    y = y.squeeze(1)

                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)

                for cls in range(num_classes):
                    pred_cls = (preds == cls).float()
                    true_cls = (y == cls).float()
                    # The small epsilon keeps the ratio defined when the class is absent from both.
                    dice_score += (2 * (pred_cls * true_cls).sum()) / (pred_cls.sum() + true_cls.sum() + 1e-8)

                num_batches += 1
    finally:
        model.train()

    if num_pixels == 0:
        raise ValueError("val_model received an empty loader; there is nothing to evaluate")

    accuracy = num_correct / num_pixels
    dice_score /= (num_batches * num_classes)

    print(f"Val Accuracy: {accuracy:.4f}")
    print(f"Val Dice Score: {dice_score:.4f}")

    return accuracy, dice_score
    

# Restore the saved weights and evaluate them on the validation loader.
checkpoint_path = "my_checkpoint.pth.tar"

# map_location lets a checkpoint written on another device be restored onto DEVICE.
load_checkpoint(torch.load(checkpoint_path, map_location=DEVICE), model)
# The model in this notebook is built with out_channels=1 and trained with
# BCEWithLogitsLoss, i.e. it is binary: evaluate two classes (0 and 1), not three.
val_acc, val_dice = val_model(model, val_loader, device=DEVICE, num_classes=2)

In [None]:
from PIL import Image

def calculate_mean_std(image_dir):
    """Compute per-channel mean and standard deviation over the ``.png`` images in a directory.

    Each image is opened as RGB, resized to ``(IMAGE_HEIGHT, IMAGE_WIDTH)`` and
    converted to a float tensor with value scaling enabled before its pixels are
    added to the running totals.

    Args:
        image_dir: Directory that is listed for ``.png`` files.

    Returns:
        ``(mean, std)``: two float32 tensors of shape ``(3,)``.

    Raises:
        ValueError: If ``image_dir`` contains no ``.png`` file.
    """
    # ToImage + ToDtype(scale=True) is the same conversion the training pipeline in this
    # notebook uses, and replaces the deprecated v2 ToTensor.
    transform = transforms.Compose([
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])

    all_images = [os.path.join(image_dir, img) for img in sorted(os.listdir(image_dir)) if img.endswith(".png")]
    if not all_images:
        raise ValueError(f"No .png images found in {image_dir!r}")

    # Accumulate sum and sum of squares over ALL pixels (in float64 to limit rounding).
    # Averaging per-image standard deviations, as was done before, ignores the
    # image-to-image variation of the mean and underestimates the dataset std.
    channel_sum = torch.zeros(3, dtype=torch.float64)
    channel_sq_sum = torch.zeros(3, dtype=torch.float64)
    pixel_count = 0

    for img_path in tqdm(all_images):
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(img_path) as handle:
            img = handle.convert("RGB")
        img_tensor = transform(img).to(torch.float64)
        channel_sum += img_tensor.sum(dim=(1, 2))
        channel_sq_sum += (img_tensor ** 2).sum(dim=(1, 2))
        pixel_count += img_tensor.shape[1] * img_tensor.shape[2]

    mean = channel_sum / pixel_count
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative values from rounding.
    variance = (channel_sq_sum / pixel_count - mean ** 2).clamp(min=0)
    std = variance.sqrt()

    return mean.float(), std.float()

# Per-channel statistics of the images in IMG_DIR, printed one per line.
mean, std = calculate_mean_std(IMG_DIR)
print(f"Mean: {mean}", f"Std: {std}", sep="\n")

In [None]:
def number_of_parameters(x):
    """Return the total element count over all tensors yielded by ``x.parameters()``."""
    return sum(tensor.numel() for tensor in x.parameters())

# Print the total parameter count of the `model` object created in an earlier cell.
print(number_of_parameters(model))