## Load Dataset

In [None]:
import torch
from src.data.dataset_foodseg import FoodSegDataset
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

# NOTE(review): the recorded output of this cell is
# "ModuleNotFoundError: No module named 'datasets'", presumably raised from inside
# src.data.dataset_foodseg — install the `datasets` package before running.

# Parameters
image_size = 256
num_classes = 104  # 103 classes + background
batch_size = 4
num_workers = 4
val_ratio = 0.2    # fraction of the official train split held out for validation
SPLIT_SEED = 42    # fixes the train/val split so validation numbers are comparable across runs

# NOTE(review): `transform` is defined but both datasets below are built with
# transform=None, so it is currently unused. Presumably FoodSegDataset does its own
# resizing / tensor conversion — confirm before wiring this in (a bilinear Resize +
# ToTensor must not be applied to integer label masks).
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

# Raw train + val datasets; the official "val" split serves as the held-out test set.
full_train_dataset = FoodSegDataset(split="train", transform=None)
test_dataset = FoodSegDataset(split="val", transform=None)

# Split full_train_dataset -> train (80%) + val (20%) with a seeded generator so the
# split is reproducible on Restart & Run All.
train_size = int(len(full_train_dataset) * (1 - val_ratio))
val_size = len(full_train_dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)
print(f"Train: {len(train_dataset)}  Val: {len(val_dataset)}  Test: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)


ModuleNotFoundError: No module named 'datasets'

## Model & Optimizer

In [None]:
import torch
import torch.nn as nn
# NOTE(review): the dataset cell imports from `src.data...` while this uses `model.UNet`
# (no `src.` prefix) — confirm both resolve from the notebook's working directory.
from model.UNet import UNet

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# 3-channel input, one output channel per class (num_classes comes from the dataset cell).
model = UNet(n_channels=3, n_classes=num_classes)
model = model.to(device)

# Pixel-wise segmentation loss over the class channel of the model output.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

## Training

In [None]:
import os

# Output directory for the best checkpoint and the loss plot.
save_path = "../output/"
os.makedirs(save_path, exist_ok=True)

epochs = 10
valid_interval = 2          # run validation every `valid_interval` epochs
train_losses = []           # one mean training loss per epoch
valid_losses = []           # one mean validation loss per validation run
best_val_loss = float('inf')

for epoch in range(epochs):
    epoch_num = epoch + 1

    # ---- training pass: mean of per-batch losses over the epoch ----
    model.train()
    running_loss = 0
    for imgs, masks in train_loader:
        imgs = imgs.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        logits = model(imgs)  # (B, num_classes, H, W)
        batch_loss = loss_fn(logits, masks)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f"Epoch {epoch_num} loss: {epoch_loss:.4f}")

    # Validation only runs on every `valid_interval`-th epoch.
    if epoch_num % valid_interval != 0:
        continue

    # ---- validation pass (no gradients) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for val_imgs, val_masks in val_loader:
            val_imgs = val_imgs.to(device)
            val_masks = val_masks.to(device)
            val_loss += loss_fn(model(val_imgs), val_masks).item()

    val_loss /= len(val_loader)
    valid_losses.append(val_loss)
    print(f"[Epoch {epoch_num}] Validation Loss: {val_loss:.4f}")

    # Keep only the weights with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), save_path+"unet.pth")
        print(f"New best model saved (val loss = {best_val_loss:.4f})")


## Testing and visualization

In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import torch

# Fixed random colour per class id, so ground truth and prediction use the same palette.
# randint's upper bound is exclusive, so 256 makes the full 0..255 range reachable.
np.random.seed(42)
COLORMAP = np.random.randint(0, 256, size=(num_classes, 3), dtype=np.uint8)

def visualize(img, mask, pred):
    """Show an image, its ground-truth mask and the predicted mask side by side.

    Args:
        img: image tensor laid out channels-first (C, H, W); may live on any device.
        mask: per-pixel class-id tensor (H, W) used as an index into COLORMAP.
        pred: per-pixel predicted class-id tensor (H, W), e.g. argmax over logits.
    """
    img = img.permute(1, 2, 0).cpu().numpy()   # (C, H, W) -> (H, W, C) for imshow
    mask_color = COLORMAP[mask.cpu().numpy()]
    pred_color = COLORMAP[pred.cpu().numpy()]

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(img)
    ax[0].set_title("Image")
    ax[1].imshow(mask_color)
    ax[1].set_title("Ground Truth")
    ax[2].imshow(pred_color)
    ax[2].set_title("Prediction")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across samples

# Evaluate the best-validation checkpoint written by the training cell, not whatever
# weights the final epoch left in memory.
best_ckpt_path = os.path.join(save_path, "unet.pth")
if os.path.exists(best_ckpt_path):
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device))

model.eval()
test_loss = 0.0
# Randomly selected images (at most five) from the test set; seeded for reproducibility.
n_show = min(5, len(test_dataset))
test_indices = random.Random(42).sample(range(len(test_dataset)), n_show)

with torch.no_grad():
    # Mean of per-batch losses, matching how train/val losses are computed.
    for img, mask in test_loader:
        img, mask = img.to(device), mask.to(device)
        output = model(img)
        loss = loss_fn(output, mask)
        test_loss += loss.item()

    # Sampling Visualization
    for idx in test_indices:
        sample = test_dataset[idx]
        # The loaders above unpack (image, mask) tuples; also accept a dict-style sample.
        if isinstance(sample, dict):
            sample_img, sample_mask = sample["image"], sample["mask"]
        else:
            sample_img, sample_mask = sample
        img = torch.as_tensor(sample_img).to(device).unsqueeze(0)
        mask = torch.as_tensor(sample_mask).to(device).unsqueeze(0)
        pred = torch.argmax(model(img), dim=1)
        visualize(img[0], mask[0], pred[0])

# Divide by the number of batches (len(test_loss) was a TypeError on a float).
test_loss /= len(test_loader)
print(f'Testing Loss: {test_loss:.8f}, ')


NameError: name 'num_classes' is not defined

## Plot Loss Curve

In [None]:
import os
import matplotlib.pyplot as plt

# Derive the x axes from the losses actually recorded, so the plot still works if
# training was interrupted or `epochs` was edited after the run.
total_epochs = len(train_losses)
epoch_list = list(range(1, total_epochs + 1))
# Validation ran after epochs valid_interval, 2*valid_interval, ... (see training cell).
valid_epochs = [valid_interval * (i + 1) for i in range(len(valid_losses))]

fig, ax = plt.subplots(figsize=(10, 6))
# x must be the epoch list, not the scalar `epochs` (that raised a length-mismatch ValueError).
ax.plot(epoch_list, train_losses, label='Training Loss', marker='o')
ax.plot(valid_epochs, valid_losses, label='Validation Loss', marker='o')
ax.set_xlabel('Epoch')
ax.set_ylabel('Total Loss')
ax.set_title('Training And Validation Loss Over Epochs')
ax.legend()
ax.grid(True)

fig.savefig(os.path.join(save_path, 'loss_plot.png'), dpi=300, bbox_inches='tight')
plt.show()
