In [6]:
import os
import cv2
import json
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18, resnet34

In [None]:
class ImitationDataset(Dataset):
    def __init__(self, rgb_dir, seg_dir, log_path, transforms=None, img_size=(128, 128)):
        self.rgb_dir = rgb_dir
        self.seg_dir = seg_dir
        self.log_path = log_path
        self.transform = transforms
        self.img_size = img_size
        
        with open(log_path, 'r') as f:
            self.log_data = json.load(f)
            
    def __len__(self):
        return len(self.log_data)
    
    def __getitem__(self, index):
        # print(f"Loading sample {index}")
        rgb_image_path = os.path.join(self.rgb_dir, f"{index:05d}.png")
        seg_image_path = os.path.join(self.seg_dir, f"{index:05d}.png")
        
        rgb_image = Image.open(rgb_image_path).resize(self.img_size)
        seg_image = Image.open(seg_image_path).resize(self.img_size)
        
        if self.transform:
            rgb_tensor = self.transform(rgb_image)
        else:
            rgb_tensor = transforms.ToTensor()(rgb_image)
            
        seg_image = np.array(seg_image)
        
        lane_mask = np.all(seg_image == [0, 255, 0], axis=2).astype(np.uint8)
        obs_mask = np.all(seg_image == [255, 0, 0], axis=2).astype(np.uint8)
        
        seg_tensor = torch.tensor(np.stack([lane_mask, obs_mask], axis=0), dtype=torch.float32)
        
        if seg_tensor.shape[1:] != rgb_tensor.shape[1:]:
            seg_tensor = F.interpolate(
                seg_tensor.unsqueeze(0),
                size=rgb_tensor.shape[1:],
                mode='nearest'
            ).squeeze(0)
        
        input_tensor = torch.cat([rgb_tensor, seg_tensor], dim=0)
                        
        control = self.log_data[index]
        control_tensor = torch.tensor([
            float(f"{control['steer']:.3f}"),
            float(f"{control['throttle']:.3f}"),
            control['brake']
        ], dtype=torch.float32)

        return input_tensor, control_tensor
    


In [8]:
class ImitationResNet(nn.Module):
    def __init__(self, pretrained=True, backbone='resnet34'):
        super(ImitationResNet, self).__init__()

        if backbone == 'resnet34':
            base_model = resnet34(pretrained=pretrained)
        else:
            base_model = resnet18(pretrained=pretrained)
        
        self.conv1 = nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        self.maxpool = base_model.maxpool
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4
        self.avgpool = base_model.avgpool
        
        self.lstm = nn.LSTM(input_size=512, hidden_size=256, num_layers=1, batch_first=True)
        self.fc = nn.Linear(256, 512)
        self.steer_head = nn.Linear(512, 1)
        self.throttle_head = nn.Linear(512, 1)
        self.brake_head = nn.Linear(512, 1)

        self._init_weights_from_rgb(base_model)

    def _init_weights_from_rgb(self, base_model):
        old_weights = base_model.conv1.weight.data
        new_weights = torch.zeros((64, 5, 7, 7))  # (out_channels, in_channels, H, W)
        new_weights[:, :3, :, :] = old_weights
        self.conv1.weight.data = new_weights

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        
        x = x.unsqueeze(1)
        x, _ = self.lstm(x)  
        x = x.squeeze(1)  
        
        x = torch.relu(self.fc(x))
        steer = torch.tanh(self.steer_head(x))  
        throttle = torch.sigmoid(self.throttle_head(x))  
        brake = self.brake_head(x)  
        
        return torch.cat([steer, throttle, brake], dim=1)

In [11]:
rgb_dir = 'Dataset/rgb_image'
seg_dir = 'Dataset/seg_image'
log_path = 'logs/logs.json'
checkpoint_path = 'checkpoints/last_epoch.pth'

os.makedirs('checkpoints', exist_ok=True)

dataset = ImitationDataset(rgb_dir, seg_dir, log_path, None, img_size=(128, 128))
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_set, batch_size=2048, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=2048, shuffle=False, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# model = ImitationCNN().to(device)
model = ImitationResNet(pretrained=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.amp.GradScaler('cuda')

start_epoch = 0
num_epochs = 100
best_rmse = float('inf')

if os.path.exists(checkpoint_path):
    print("Resuming training from last checkpoint")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    train_losses = checkpoint['train_losses']
    val_steer = checkpoint['val_steer']
    val_throttle = checkpoint['val_throttle']
    val_brake = checkpoint['val_brake']
else:
    print("Starting training, no model found")
    train_losses = []
    val_steer, val_throttle, val_brake = [], [], []

for epoch in range(start_epoch, num_epochs):
    model.train()
    running_loss = 0.0
    
    for i, (inputs, targets) in enumerate(train_loader):
        # print(f"Processing batch {i+1}/{len(train_loader)}")
        inputs = inputs.to(device)
        targets = targets.to(device)


        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            outputs= model(inputs)
            loss_steer = F.mse_loss(outputs[:, 0], targets[:, 0])
            loss_throttle = F.smooth_l1_loss(outputs[:, 1], targets[:, 1])
            loss_brake = F.binary_cross_entropy_with_logits(outputs[:, 2], targets[:, 2])
            loss = loss_steer + 2.0 * loss_throttle + 2.0 * loss_brake
        
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
    
    model.eval()
    total_se, total_th, total_br = 0, 0, 0
    n=0
    
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            with autocast(device_type='cuda'):
                outputs = model(inputs)
                total_se += F.mse_loss(outputs[:, 0], targets[:, 0], reduction='sum').item()
                total_th += F.smooth_l1_loss(outputs[:, 1], targets[:, 1], reduction='sum').item()
                total_br += F.binary_cross_entropy_with_logits(outputs[:, 2], targets[:, 2], reduction='sum').item()
            n += inputs.size(0)
            
    # val_rmse = ((total_se +total_th + total_br)/(3*n)) **0.5
    val_st = (total_se/n)**0.5     
    val_th = (total_th/n)**0.5     
    val_br = total_br/n
    val_rmse =  val_st + val_th + val_br
    
    if  val_rmse < best_rmse:
        best_rmse = val_rmse
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_losses': train_losses,
            'val_steer': val_steer,
            'val_throttle': val_throttle,
            'val_brake': val_brake
        }, 'models/bc_model.pth')

    
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_steer': val_steer,
        'val_throttle': val_throttle,
        'val_brake': val_brake
    }, 'checkpoints/last_epoch.pth') 

    
    
    train_losses.append(running_loss/len(train_loader))
    val_steer.append(val_st)
    val_throttle.append(val_th)
    val_brake.append(val_br)
    
    
    print(f"[VAL] Steer MSE: {val_st:.4f}, Throttle RMSE: {val_th:.4f}, Brake RMSE: {val_br:.4f}")
    
    with open('logs/val_metrics.csv', 'a') as f:
        f.write(f"{epoch+1},{val_st:.4f},{val_th:.4f},{val_br:.4f}\n")


cuda
Resuming training from last checkpoint
Epoch [18/100], Loss: 0.0358
[VAL] Steer MSE: 0.0317, Throttle RMSE: 0.0489, Brake RMSE: 0.0164
Epoch [19/100], Loss: 0.0272
[VAL] Steer MSE: 0.0312, Throttle RMSE: 0.0486, Brake RMSE: 0.0143
Epoch [20/100], Loss: 0.0218
[VAL] Steer MSE: 0.0322, Throttle RMSE: 0.0488, Brake RMSE: 0.0140
Epoch [21/100], Loss: 0.0180
[VAL] Steer MSE: 0.0321, Throttle RMSE: 0.0485, Brake RMSE: 0.0139
Epoch [22/100], Loss: 0.0146
[VAL] Steer MSE: 0.0329, Throttle RMSE: 0.0489, Brake RMSE: 0.0144
Epoch [23/100], Loss: 0.0141
[VAL] Steer MSE: 0.0317, Throttle RMSE: 0.0490, Brake RMSE: 0.0152
Epoch [24/100], Loss: 0.0138
[VAL] Steer MSE: 0.0334, Throttle RMSE: 0.0486, Brake RMSE: 0.0152
Epoch [25/100], Loss: 0.0134
[VAL] Steer MSE: 0.0309, Throttle RMSE: 0.0485, Brake RMSE: 0.0154
Epoch [26/100], Loss: 0.0131
[VAL] Steer MSE: 0.0319, Throttle RMSE: 0.0485, Brake RMSE: 0.0161
Epoch [27/100], Loss: 0.0130
[VAL] Steer MSE: 0.0307, Throttle RMSE: 0.0485, Brake RMSE: 0.0