In [1]:
import sys
import os

# Add parent folder (ml/) to Python path
sys.path.append(os.path.abspath(".."))

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.models import ResNet18_Weights
from torch.utils.data import DataLoader, random_split
from training.dataset import RoadSightDataset
from training.transforms import train_transforms, val_transforms


In [2]:
image_dir = "../data/raw/RDD2020/Japan/images"
annotation_dir = "../data/raw/RDD2020/Japan/annotations/xmls"

dataset = RoadSightDataset(image_dir, annotation_dir, transform=train_transforms)

# Split dataset: 80% train, 20% validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

weights = ResNet18_Weights.DEFAULT
model = models.resnet18(weights=weights)
model.fc = nn.Linear(model.fc.in_features, 2)

# Load previous weights (baseline model)
model.load_state_dict(torch.load("../models/roadsight_v1.pt", map_location=device))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)  # smaller LR for fine-tuning


In [4]:
num_epochs = 5  # You can increase later

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    avg_train_loss = running_loss / len(train_loader)
    
    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = correct / total
    
    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {avg_train_loss:.4f} "
          f"Val Loss: {avg_val_loss:.4f} "
          f"Val Accuracy: {val_accuracy:.4f}")


Epoch [1/5] Train Loss: 0.2387 Val Loss: 0.2355 Val Accuracy: 0.9234
Epoch [2/5] Train Loss: 0.2071 Val Loss: 0.2252 Val Accuracy: 0.9244
Epoch [3/5] Train Loss: 0.1995 Val Loss: 0.2459 Val Accuracy: 0.9244
Epoch [4/5] Train Loss: 0.1927 Val Loss: 0.2248 Val Accuracy: 0.9220
Epoch [5/5] Train Loss: 0.1774 Val Loss: 0.2105 Val Accuracy: 0.9225


In [5]:
torch.save(model.state_dict(), "../models/roadsight_v2.pt")
