In [2]:
!pip install ncps
!pip install torchsummary
!pip install torchinfo



In [3]:
import pandas as pd
import re
import cv2
from datetime import datetime
import matplotlib.pyplot as plt
import os
import torch
import numpy as np
import cv2
from collections import OrderedDict
from itertools import islice
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn as nn
from ncps.wirings import NCP
from ncps.torch import LTC
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from torchinfo import summary
import scipy

## Check Data

In [22]:
# Notebook-wide display setup: let pandas print up to 200 rows before
# truncating, and switch matplotlib to interactive mode.
pd.options.display.max_rows = 200
plt.ion()

def get_full_image_filepaths(data_dir):
    """Return the full paths of the entries in ``data_dir`` in numeric order.

    Entries are ordered by the first run of digits found in their name, so
    ``2.jpg`` comes before ``10.jpg``. Entries whose name contains no digit
    (for example ``.DS_Store``) cannot be ordered and are skipped instead of
    crashing the whole listing.

    :param data_dir: Directory to list.
    :return: List of ``os.path.join(data_dir, name)`` strings, sorted by number.
    """
    numbered_names = []
    for name in os.listdir(data_dir):
        match = re.search(r'\d+', name)
        if match is None:
            # No frame number in the name: nothing to sort on, so ignore it.
            continue
        numbered_names.append((int(match.group()), name))

    # Sort on the number only; the sort is stable, so ties keep listing order.
    numbered_names.sort(key=lambda pair: pair[0])

    return [os.path.join(data_dir, name) for _, name in numbered_names]

def get_steering_angles(path):
    """Parse steering angles and timestamps from a label text file.

    Each non-blank line is expected to look like
    ``<image name> <angle>,<date> <time>`` (the sullychen layout), e.g.
    ``0.jpg 0.000000,2018-07-01 17:09:44:912``. A file without the
    ``,<timestamp>`` part is not supported by this parser.

    :param path: Path to the label text file.
    :return: ``[steering_angles, timestamps]``, two lists of strings with one
             entry per non-blank line.
    """
    steering_angles = []
    timestamps = []

    # The context manager guarantees the handle is closed even if parsing fails.
    with open(path, 'r') as file:
        for line in file:
            record = line.strip()
            if not record:
                # Blank lines (typically a trailing newline) carry no label.
                continue
            # Everything after the first space is "<angle>,<timestamp>".
            fields = record.split(' ', 1)[1].split(',')
            steering_angles.append(fields[0])
            timestamps.append(fields[1])

    return [steering_angles, timestamps]

def convert_to_df(full_filepaths, steering_angles, timestamps, norm=True):
    """Assemble the per-frame table used by the rest of the pipeline.

    :param full_filepaths: Image paths, one per frame.
    :param steering_angles: Steering angles (strings or numbers), one per frame.
    :param timestamps: Timestamp strings in ``%Y-%m-%d %H:%M:%S:%f`` format.
    :param norm: When equal to ``True``, divide the angles by 360 so that a
                 wheel range of (-360, 360) degrees maps to (-1, 1).
    :return: DataFrame with columns ``filepath``, ``steering_angle`` (float)
             and ``parsed_timestamp``.
    """
    def _to_datetime(raw):
        return datetime.strptime(raw, '%Y-%m-%d %H:%M:%S:%f')

    frame = pd.DataFrame({
        'filepath': full_filepaths,
        'steering_angle': steering_angles,
        'timestamps': timestamps,
    })
    frame['parsed_timestamp'] = frame['timestamps'].apply(_to_datetime)

    # The raw string column is no longer needed once it has been parsed.
    frame = frame.drop('timestamps', axis=1).reset_index(drop=True)

    angles = frame['steering_angle'].astype('float')
    if norm == True:
        angles = angles / 360
    frame['steering_angle'] = angles

    return frame

def display_images_with_angles(data_pd, steering_wheel_img):
    """Play the driving frames as a video next to a rotating steering wheel.

    Press ``q`` in an OpenCV window to stop early. Frames that cannot be
    displayed are reported and skipped so playback continues.

    :param data_pd: DataFrame with ``filepath`` and ``steering_angle`` columns.
                    The angle is multiplied by 360 to get degrees, i.e. it is
                    treated as normalized.
    :param steering_wheel_img: Path to the steering wheel image (read as grayscale).
    :raises FileNotFoundError: If the steering wheel image cannot be read.
    """
    # steering wheel image (flag 0 = grayscale)
    img = cv2.imread(steering_wheel_img, 0)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read steering wheel image: {steering_wheel_img}")
    rows, cols = img.shape

    xs = data_pd['filepath'].values
    ys = data_pd['steering_angle'].values
    i = 0

    try:
        # Press q to stop
        while (cv2.waitKey(100) != ord('q')) and i < len(xs):
            try:
                # Driving test image displayed as video
                full_image = cv2.imread(xs[i])

                degrees = ys[i] * 360
                print("Steering angle: " + str(degrees) + " (actual)")
                cv2.imshow("frame", full_image)

                # Rotate the wheel by the negated steering angle
                M = cv2.getRotationMatrix2D((cols/2, rows/2), -degrees, 1)
                dst = cv2.warpAffine(img, M, (cols, rows))

                cv2.imshow("steering wheel", dst)
            except Exception as exc:
                # Best effort per frame: report the cause and move on, but let
                # KeyboardInterrupt / SystemExit propagate.
                print('ERROR at', i, '-', exc)
            i += 1
    finally:
        # Always tear the windows down, even when playback is interrupted.
        cv2.destroyAllWindows()

def disp_freq_steering_angles(data_pd):
    """Plot a bar chart of how often steering angles fall in each 0.1-wide bin.

    Bins are left-closed intervals covering [-1, 1). As a side effect the bin
    of every row is stored in a ``binned`` column on ``data_pd`` itself.

    :param data_pd: DataFrame with a ``steering_angle`` column.
    """
    # Edges -1.0, -0.9, ..., 1.0 built from integers to avoid float drift.
    edges = [tick / 10 for tick in range(-10, 11)]

    angles = data_pd['steering_angle'].astype('float')
    data_pd['binned'] = pd.cut(angles, bins=edges, right=False)

    # Occurrences per bin, in bin order
    counts_per_bin = data_pd['binned'].value_counts().sort_index()

    plt.figure(figsize=(12, 5))
    counts_per_bin.plot(kind='bar', color='skyblue', edgecolor='black')

    plt.xlabel("Steering Angle Range")
    plt.ylabel("Frequency")
    plt.title("Frequency of Steering Angles in 0.1 Intervals")
    plt.xticks(rotation=45)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    plt.show()

def disp_start_and_end_in_filtered_data(data_pd_filtered):
    """Plot the steering angle curve with turn starts and turn ends marked.

    :param data_pd_filtered: DataFrame with ``steering_angle`` and ``turn_shift``
        columns, where ``turn_shift`` is 1 at a turn start and -1 at a turn end.
        Note that ``filter_df_on_turns`` adds ``turn_shift`` to the frame it
        receives but drops it from the frame it returns.
    """
    shift = data_pd_filtered['turn_shift']
    turn_starts = data_pd_filtered[shift == 1]
    turn_ends = data_pd_filtered[shift == -1]

    plt.figure(figsize=(12, 5))
    plt.plot(data_pd_filtered.index, data_pd_filtered['steering_angle'], label="Steering Angle", alpha=0.5)

    # One scatter layer per event type: (rows, colour, legend label, marker)
    event_layers = (
        (turn_starts, 'red', "Turn Start", "o"),
        (turn_ends, 'blue', "Turn End", "x"),
    )
    for rows, colour, legend_label, glyph in event_layers:
        plt.scatter(rows.index, rows['steering_angle'], color=colour, label=legend_label, marker=glyph)

    plt.xlabel("Frame Index")
    plt.ylabel("Steering Angle")
    plt.title("Turn Start and End Points in Steering Angle Data")
    plt.legend()
    plt.grid(True)
    plt.show()

def filter_df_on_turns(data_pd, turn_threshold = 0.06, buffer_before = 60, buffer_after = 60):
    """Keep only the frames around turns, plus a buffer on each side.

    A frame counts as "turning" when ``abs(steering_angle) > turn_threshold``.
    A turn starts where the flag rises 0 -> 1 and ends where it falls 1 -> 0.
    Starts and ends are paired in order; a turn already in progress at the
    first row or still open at the last row has no partner and is not selected.

    Side effects: ``index``, ``turning`` and ``turn_shift`` columns are added
    to ``data_pd`` itself. Positions are used as ``.loc`` labels, so
    ``data_pd`` is assumed to have a default 0..n-1 index.

    :param data_pd: DataFrame with a ``steering_angle`` column.
    :param turn_threshold: Absolute angle above which a frame is a turn.
    :param buffer_before: Frames to include before each turn start.
    :param buffer_after: Frames to include after each turn end.
    :return: New DataFrame with the selected rows, the original row number in
             ``index`` and a ``sequence_id`` that increments at every gap in
             the original ordering.
    """
    # Preserve the original ordering so gaps can be detected after filtering
    data_pd['index'] = data_pd.index

    # Identify where turning happens
    data_pd['turning'] = (data_pd['steering_angle'].abs() > turn_threshold).astype(int)

    # 1 indicates a turn start, -1 a turn end (first row is NaN)
    data_pd['turn_shift'] = data_pd['turning'].diff()

    turn_starts = data_pd[data_pd['turn_shift'] == 1].index
    turn_ends = data_pd[data_pd['turn_shift'] == -1].index

    # If the data opens mid-turn, the first end has no matching start: drop it.
    # The emptiness check on turn_starts avoids an IndexError when the only
    # turn in the data was already in progress at the first row.
    if len(turn_starts) > 0 and len(turn_ends) > 0 and turn_starts[0] > turn_ends[0]:
        turn_ends = turn_ends[1:]

    # Selected indices for keeping
    selected_indices = set()
    last_row = len(data_pd) - 1

    for start, end in zip(turn_starts, turn_ends):
        # Include a buffer of frames before and after, clamped to the data
        start_idx = max(0, start - buffer_before)
        end_idx = min(last_row, end + buffer_after)
        selected_indices.update(range(start_idx, end_idx + 1))

    # Create filtered dataset
    data_pd_filtered = data_pd.loc[sorted(selected_indices)].reset_index(drop=True)

    # Drop temporary columns
    data_pd_filtered = data_pd_filtered.drop(columns=['turning', 'turn_shift'])

    # A new sequence starts wherever the original index is not continuous
    data_pd_filtered["sequence_id"] = (data_pd_filtered["index"].diff() != 1).cumsum()

    return data_pd_filtered

def group_data_by_sequences(data_pd_filtered, min_seq_len=40):
    """Drop sequences that are too short to be useful.

    Prints the shortest and longest sequence length (when there is any data)
    and the number of sequences kept.

    :param data_pd_filtered: DataFrame with a ``sequence_id`` column.
    :param min_seq_len: Minimum number of frames a sequence needs to be kept.
    :return: The rows of ``data_pd_filtered`` that belong to kept sequences.
    """
    sequence_lengths = data_pd_filtered.groupby("sequence_id").size()

    # min()/max() raise on an empty collection, so only report when data exists.
    if len(sequence_lengths) > 0:
        print(f"Minimum sequence length: {sequence_lengths.min()}")
        print(f"Maximum sequence length: {sequence_lengths.max()}")

    # Keep only sequences with at least min_seq_len frames
    valid_sequences = sequence_lengths[sequence_lengths >= min_seq_len].index
    data_pd_filtered = data_pd_filtered[data_pd_filtered["sequence_id"].isin(valid_sequences)]

    print(f"Total valid sequences: {len(valid_sequences)}")

    return data_pd_filtered

def get_preprocessed_data_pd(data_dir, steering_angles_txt_path, filter = True,
                             turn_threshold = 0.06, buffer_before = 60, buffer_after = 60,
                             norm=True, save_dir = 'data/csv_files'):
    """Build the per-frame table, assign sequence ids and save it as CSV.

    With ``filter`` set, only frames around turns are kept (see
    ``filter_df_on_turns``) and short sequences are dropped (see
    ``group_data_by_sequences``). Without it, every frame is kept and a new
    sequence starts whenever two consecutive frames are more than 3 seconds apart.

    :param data_dir: Directory holding the frame images.
    :param steering_angles_txt_path: Label file read by ``get_steering_angles``.
    :param filter: Whether to keep only turn segments (requires ``norm``).
    :param turn_threshold: Turn threshold; also part of the CSV file name.
    :param buffer_before: Frames kept before a turn; also part of the CSV file name.
    :param buffer_after: Frames kept after a turn; also part of the CSV file name.
    :param norm: Whether steering angles are normalized by ``convert_to_df``.
    :param save_dir: Directory the CSV is written to (created if missing).
    :return: The processed DataFrame (also written to ``save_dir``).
    :raises ValueError: If ``filter`` is set without ``norm``.
    """
    if filter and not norm:
        # Raise instead of exit(): terminating the interpreter from a helper
        # would take the whole notebook kernel down with it.
        raise ValueError('If filtering, then steering values should be normalized (-1, 1), set norm to True')

    img_paths = get_full_image_filepaths(data_dir)
    steering_angles, timestamps = get_steering_angles(steering_angles_txt_path)
    data_pd = convert_to_df(img_paths, steering_angles, timestamps, norm)

    os.makedirs(save_dir, exist_ok=True)

    if filter:
        data_pd_filtered = filter_df_on_turns(data_pd, turn_threshold = turn_threshold,
                                              buffer_before = buffer_before, buffer_after = buffer_after)
        data_pd_filtered = group_data_by_sequences(data_pd_filtered)

        data_pd_filtered.to_csv(os.path.join(save_dir,f"flt_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv"), index=False)

        return data_pd_filtered

    # Seconds elapsed since the previous frame (NaN for the first row, which
    # compares False below and therefore stays in sequence 0).
    gap_seconds = data_pd['parsed_timestamp'].diff().dt.total_seconds()

    # Every gap longer than 3 seconds opens a new sequence.
    data_pd['sequence_id'] = (gap_seconds > 3).cumsum()

    data_pd.to_csv(os.path.join(save_dir,f"ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv"), index=False)

    return data_pd

## Model

In [6]:
class WeightedMSE(nn.Module):
    """Mean squared error in which each element is weighted by ``exp(alpha * |target|)``.

    Larger-magnitude targets therefore contribute more to the loss;
    ``alpha = 0`` reduces it to plain MSE.
    """

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha

    def forward(self, predictions, targets):
        """Return the scalar weighted loss averaged over all elements."""
        residual_sq = (predictions - targets).pow(2)

        # w(y) = exp(|y| * alpha)
        emphasis = (targets.abs() * self.alpha).exp()

        return (residual_sq * emphasis).mean()

class convolutional_head(nn.Module):
    """Convolutional feature extractor with one small linear layer per output filter.

    Five conv + ReLU layers reduce an RGB image to ``num_filters`` feature
    maps. Each map is flattened and passed through its own linear layer to
    produce ``features_per_filter`` values; the results are concatenated.

    The activations of the latest forward pass are kept in ``self.activations``
    and ``self.feature_layer`` so that ``visual_backprop`` can build a saliency
    mask from them.

    :param num_filters: Number of feature maps produced by the last conv layer.
    :param features_per_filter: Features extracted from each feature map.
    :param input_size: ``(height, width)`` of the images the head will receive.
        It fixes the input width of the per-filter linear layers; the default
        of 224x224 yields 28x28 feature maps.
    """

    def __init__(self, num_filters = 8, features_per_filter = 16, input_size = (224, 224)):
        super(convolutional_head, self).__init__()

        self.num_filters = num_filters
        self.features_per_filter = features_per_filter

        self.conv1 = nn.Conv2d(3, 24, kernel_size=5, stride=2, padding=2)
        self.conv2 = nn.Conv2d(24, 36, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(36, 48, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(48, 64, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(64, num_filters, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU()

        # One FC layer per filter; its input width is the flattened feature map.
        map_h, map_w = self._feature_map_size(input_size)
        self.fc_layers = nn.ModuleList([
            nn.Linear(map_h * map_w, features_per_filter) for _ in range(num_filters)
        ])

        self.activations = []
        self.feature_layer = None

    def _conv_layers(self):
        """Return the conv layers in the order they are applied."""
        return (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5)

    def _feature_map_size(self, input_size):
        """Compute the (height, width) of the final feature map for ``input_size``."""
        height, width = input_size
        for conv in self._conv_layers():
            # Standard convolution output-size formula, applied per axis.
            height = (height + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1) // conv.stride[0] + 1
            width = (width + 2 * conv.padding[1] - conv.dilation[1] * (conv.kernel_size[1] - 1) - 1) // conv.stride[1] + 1
        return height, width

    def forward(self, x):
        """
        Forward pass of the convolutional head.

        :param x: Input tensor of shape [batch, channels, height, width]
        :return: Feature vector of shape [batch, num_filters * features_per_filter]
        """
        self.activations = []
        batch_size = x.shape[0]

        # Apply the conv layers, remembering every post-ReLU activation.
        for conv in self._conv_layers():
            x = self.relu(conv(x))
            self.activations.append(x)

        # x is now [batch, num_filters, map_h, map_w]; handle each filter separately.
        feature_vectors = []
        for i in range(self.num_filters):
            filter_out = x[:, i].reshape(batch_size, -1)          # [batch, map_h * map_w]
            feature_vec = F.relu(self.fc_layers[i](filter_out))   # [batch, features_per_filter]
            feature_vectors.append(feature_vec)

        # [batch, num_filters * features_per_filter]
        feature_layer = torch.cat(feature_vectors, dim=1)
        self.feature_layer = feature_layer

        return feature_layer

    def visual_backprop(self, idx=0):
        """
        VisualBackprop-like mask computation using torch (GPU compatible).

        Uses the activations stored by the most recent ``forward`` call.

        :param idx: Index of the sample within that forward batch.
        :return: [H, W] attention mask in [0, 1] as a CPU numpy array, with the
                 spatial size of the first conv layer's output.
        :raises RuntimeError: If ``forward`` has not been called yet.
        """
        if self.feature_layer is None or not self.activations:
            raise RuntimeError("visual_backprop() requires a forward pass first")

        # Channel-averaged, per-channel-normalized map for each layer
        means = []
        for layer_act in self.activations:
            a = layer_act[idx].float()                                  # [C, H, W]
            per_channel_max = torch.amax(a, dim=(1, 2)) + 1e-6          # [C]
            norm = a / per_channel_max[:, None, None]                   # [C, H, W]
            means.append(norm.mean(dim=0))                              # [H, W]

        # Feature-level activation mask
        feat_layer = self.feature_layer[idx]
        feat_layer = torch.abs(feat_layer).view(self.num_filters, self.features_per_filter)  # [F, P]
        feat_mask = feat_layer.mean(dim=1)  # [F]
        feat_mask = feat_mask / (feat_mask.max() + 1e-6)

        # Rough scalar weighting of the deepest activation map
        mask = means[-1] * feat_mask.mean()  # [H, W]

        # Walk back towards the input, upsampling the mask to each layer's size
        for i in range(len(means) - 2, -1, -1):
            larger = means[i]  # [H, W]
            smaller = F.interpolate(mask[None, None], size=larger.shape, mode='bilinear', align_corners=False)
            # Index instead of squeeze() so that maps with H == 1 or W == 1 keep both axes.
            mask = larger * smaller[0, 0]

        # Normalize to [0, 1] and move to cpu
        mask = mask - mask.min()
        mask = mask / (mask.max() + 1e-6)
        return mask.detach().cpu().numpy()

    
class ConvNCPModel(nn.Module):
    """Steering model: convolutional head -> LTC cell with NCP wiring -> linear output.

    The wiring arguments are passed straight to ``ncps.wirings.NCP``; the
    remaining two configure the convolutional head.
    """

    def __init__(self, num_filters=8, features_per_filter=4, 
                 inter_neurons = 12, command_neurons = 6, motor_neurons = 1, 
                 sensory_fanout = 6, inter_fanout = 4, recurrent_command_synapses = 6,
                 motor_fanin = 6, seed = 20190120):
        super().__init__()

        # NCP wiring (parameters follow the CommandLayerWormnetArchitecture of the NCP paper)
        ncp_wiring = NCP(
            inter_neurons=inter_neurons,
            command_neurons=command_neurons,
            motor_neurons=motor_neurons,            # output neurons (1 for steering)
            sensory_fanout=sensory_fanout,
            inter_fanout=inter_fanout,
            recurrent_command_synapses=recurrent_command_synapses,
            motor_fanin=motor_fanin,
            seed=seed,                              # fixed seed for a reproducible wiring
        )

        self.conv_head = convolutional_head(num_filters, features_per_filter)

        # The LTC input width has to equal the conv head's output width.
        head_width = num_filters * features_per_filter
        self.ltc = LTC(input_size=head_width, units=ncp_wiring, return_sequences=True)

        # Maps the motor neuron output to a single steering value
        self.fc_out = nn.Linear(ncp_wiring.output_dim, 1)

    def forward(self, x):
        """
        Forward pass: Conv Head → LTC-NCP → Fully Connected.
        :param x: Input shape [batch, seq_len, channels, height, width]
        :return: Steering angles [batch, seq_len]
        """
        n_batch, n_steps, chans, height, width = x.shape

        # Fold time into the batch axis so the CNN sees individual frames
        frames = x.view(n_batch * n_steps, chans, height, width)
        per_frame = self.conv_head(frames)                      # [batch * seq_len, feature_dim]

        # Unfold back to sequences for the recurrent cell
        sequence_features = per_frame.view(n_batch, n_steps, -1)
        ltc_out, _ = self.ltc(sequence_features)

        # One steering value per time step; drop the trailing size-1 axis
        return self.fc_out(ltc_out).squeeze(-1)

## Dataset

In [12]:
def df_split_train_val(df_filtered, train_csv_filename, val_csv_filename,
                       save_dir='data/csv_files',train_size = 0.8):
    """Split a DataFrame by row order into train/val parts and save both as CSV.

    The first ``train_size`` fraction of the rows becomes the training set and
    the remainder the validation set; rows are not shuffled.

    :return: ``(train_csv_path, val_csv_path)``
    """
    split_at = int(train_size * len(df_filtered))
    train_dataset, val_dataset = df_filtered[:split_at], df_filtered[split_at:]

    print('Train dataset length:', len(train_dataset))
    print('Val dataset length:', len(val_dataset))
    print(save_dir)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    train_path = os.path.join(save_dir, train_csv_filename)
    val_path = os.path.join(save_dir, val_csv_filename)

    train_dataset.to_csv(train_path, index=False)
    val_dataset.to_csv(val_path, index=False)

    return train_path, val_path

def calculate_mean_and_std(dataset_path):
    """Compute the per-channel mean and standard deviation over all images in a tree.

    Every file found under ``dataset_path`` is opened as an image, converted
    to RGB and scaled to [0, 1] before being accumulated.

    :param dataset_path: Root directory that is walked recursively.
    :return: ``(mean, std)``, two numpy arrays of length 3 (R, G, B).
    :raises ValueError: If no pixels were found, since the statistics would be NaN.
    """
    num_pixels = 0
    channel_sum = np.zeros(3)
    channel_sum_squared = np.zeros(3)

    for root, _, files in os.walk(dataset_path):
        for file in files:
            image_path = os.path.join(root, file)
            # Close each image file as soon as its pixels have been copied out.
            with Image.open(image_path) as image:
                pixels = np.asarray(image.convert('RGB')) / 255.0

            num_pixels += pixels.size // 3  # three channels per pixel

            channel_sum += np.sum(pixels, axis=(0, 1))
            channel_sum_squared += np.sum(pixels ** 2, axis=(0, 1))

    if num_pixels == 0:
        raise ValueError(f"No images found under {dataset_path!r}")

    mean = channel_sum / num_pixels
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding; clamp so
    # the square root cannot produce NaN.
    variance = np.maximum(channel_sum_squared / num_pixels - mean ** 2, 0.0)
    std = np.sqrt(variance)
    return mean, std

class CustomDataset(Dataset):
    """Sliding-window dataset of fixed-length frame sequences described by a CSV file.

    Windows never cross ``sequence_id`` boundaries. Each item is
    ``(sequence_id, window_start, images, angles)``.
    """

    def __init__(self, csv_file, seq_len, imgh=224, imgw=224, step_size=1, crop=True, transform=None):
        """
        :param csv_file: CSV with columns ['filepath', 'steering_angle', 'sequence_id']
        :param seq_len: Length of each sequence (number of frames per item)
        :param imgh: Height the frames are resized to
        :param imgw: Width the frames are resized to
        :param step_size: Offset between the starts of consecutive windows
        :param crop: Whether to keep only the lower part of each frame before resizing
        :param transform: Per-image transform; when given, the transformed
                          frames are stacked into one tensor
        """
        self.df = pd.read_csv(csv_file)
        self.seq_len = seq_len
        self.transform = transform
        self.imgh = imgh
        self.imgw = imgw
        self.step_size = step_size
        self.crop = crop

        # (sequence_id, window_start) -> rows of that window, in creation order
        self.sequences = OrderedDict()
        for seq_id in self.df["sequence_id"].unique():
            seq_data = self.df[self.df["sequence_id"] == seq_id]
            # Only full windows: the range is empty when the sequence is too short
            for i in range(0, len(seq_data) - self.seq_len + 1, self.step_size):
                self.sequences[(seq_id, i)] = seq_data.iloc[i : i + self.seq_len]

        # Positional view of the keys so that __getitem__ is a direct lookup
        self._keys = list(self.sequences.keys())
        self.index_map = {key: i for i, key in enumerate(self._keys)}

        print(f"Total sequences extracted: {len(self.sequences)} using step_size={self.step_size} and seq_len={self.seq_len}")

    def get_ith_element(self, od, i):
        """Return the i-th ``(key, value)`` pair of the ordered mapping ``od``."""
        return next(islice(od.items(), i, None))

    def _crop_lower_half(self,img, keep_ratio=0.6):
        """Return the bottom ``keep_ratio`` portion of an H x W x C image."""
        h = img.shape[0]
        crop_start = int(h * (1 - keep_ratio))
        return img[crop_start:, :, :]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Load the frames and steering angles of the ``idx``-th window.

        :return: ``(sequence_id, window_start, images, angles)``. ``images`` is
                 a stacked tensor when a transform is set, otherwise a list of
                 RGB numpy arrays; ``angles`` is a float32 tensor of length seq_len.
        :raises FileNotFoundError: If a frame cannot be read from disk.
        """
        sequence_id, seq_num = self._keys[idx]
        seq_df = self.sequences[(sequence_id, seq_num)]

        img_names = seq_df['filepath'].tolist()
        angles = torch.tensor(seq_df['steering_angle'].tolist(), dtype=torch.float32)

        images = []
        for img_name in img_names:
            img = cv2.imread(img_name)
            if img is None:
                # cv2.imread returns None instead of raising on a missing/corrupt file
                raise FileNotFoundError(f"Could not read image: {img_name}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if self.crop:
                img = self._crop_lower_half(img)
            # cv2.resize takes dsize as (width, height)
            img = cv2.resize(img, (self.imgw, self.imgh))
            images.append(img)

        # Convert to tensors and normalize
        if self.transform:
            images = torch.stack([self.transform(img) for img in images])

        return sequence_id, seq_num, images, angles
    
def create_train_val_dataset(train_csv_file, 
                              val_csv_file,
                              seq_len = 32, 
                              imgw = 224,
                              imgh = 224,
                              step_size = 32,
                              crop = True,
                              mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]):
    """Create the train and validation ``CustomDataset`` objects with identical settings.

    Both datasets share one transform: conversion to a tensor followed by
    per-channel normalization with ``mean`` and ``std``.

    :return: ``(train_dataset, val_dataset)``
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Everything except the CSV path is the same for both splits.
    shared_settings = dict(seq_len=seq_len, imgh=imgh, imgw=imgw,
                           step_size=step_size, crop=crop, transform=preprocessing)

    train_dataset = CustomDataset(csv_file=train_csv_file, **shared_settings)
    val_dataset = CustomDataset(csv_file=val_csv_file, **shared_settings)

    return train_dataset, val_dataset

def create_train_val_loader(train_dataset, val_dataset, train_sampler=None, val_sampler=None, batch_size=8,
                            num_workers=4, prefetch_factor=2, pin_memory=True, train_shuffle=False):
    """Wrap the train and validation datasets in DataLoaders and print a summary.

    The validation loader is never shuffled. Prints the number of batches of
    both loaders and the shapes of the first training batch.

    :param train_sampler: Optional sampler for the training loader; when given,
                          ``train_shuffle`` is ignored because DataLoader
                          rejects a sampler combined with shuffling.
    :param val_sampler: Optional sampler for the validation loader.
    :param prefetch_factor: Only forwarded when ``num_workers > 0``; DataLoader
                            rejects it in single-process mode.
    :return: ``(train_loader, val_loader)``
    """
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    if num_workers > 0:
        loader_kwargs['prefetch_factor'] = prefetch_factor

    train_loader = DataLoader(train_dataset, sampler=train_sampler,
                              shuffle=bool(train_shuffle) and train_sampler is None,
                              **loader_kwargs)
    val_loader = DataLoader(val_dataset, sampler=val_sampler, shuffle=False, **loader_kwargs)

    print('len of train loader:', len(train_loader))
    print('len of val loader', len(val_loader))

    # Peek at one training batch to report its shapes
    for (_, _, inputs, labels) in train_loader:
        print("Batch input shape:", inputs.shape)
        print("Batch label shape:", labels.shape)
        break

    return train_loader, val_loader

def get_loaders_for_training(data_dir, steering_angles_path, step_size, seq_len, imgh, imgw, filter, turn_threshold, 
                                     buffer_before, buffer_after, crop=True, train_size=0.8, save_dir='data/csv_files', 
                                     norm=True, batch_size=16, num_workers=4, prefetch_factor=4, pin_memory=True, train_shuffle=False):
    """Run the whole data pipeline: preprocess, split, build datasets, build loaders.

    The intermediate CSV files are written to ``save_dir``; their names encode
    whether filtering was used plus the threshold and buffer settings.

    :return: ``(train_loader, val_loader)``
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # 1. Preprocess the raw data into a single table
    preprocessed = get_preprocessed_data_pd(data_dir, steering_angles_path, filter, turn_threshold,
                                            buffer_before, buffer_after, norm, save_dir)

    # 2. Pick CSV names that reflect how the table was produced
    if filter:
        csv_names = (f"train_flt_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv",
                     f"val_flt_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv")
    else:
        csv_names = (f"train_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv",
                     f"val_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv")

    # 3. Split into train and val CSV files
    train_csv, val_csv = df_split_train_val(preprocessed,
                                            train_csv_filename=csv_names[0],
                                            val_csv_filename=csv_names[1],
                                            save_dir=save_dir,
                                            train_size=train_size)

    # 4. Build the PyTorch datasets
    datasets = create_train_val_dataset(train_csv_file=train_csv, val_csv_file=val_csv,
                                        seq_len=seq_len, imgh=imgh, imgw=imgw,
                                        step_size=step_size, crop=crop)

    # 5. Wrap them in dataloaders
    return create_train_val_loader(datasets[0], datasets[1],
                                   batch_size=batch_size,
                                   num_workers=num_workers,
                                   prefetch_factor=prefetch_factor,
                                   pin_memory=pin_memory,
                                   train_shuffle=train_shuffle)

if __name__ == '__main__':
    # Smoke test of the data pipeline: builds the CSV files, datasets and
    # dataloaders once and relies on the helpers' prints for a summary.
    # The returned loaders are discarded.
    # NOTE(review): in a notebook these assignments become module-level names;
    # in particular `filter` shadows the Python builtin for later cells.

    #preprocessing csv file args
    data_dir = '/kaggle/input/sullychen/07012018/data'
    steering_angles_txt_path = '/kaggle/input/sullychen/07012018/data.txt'
    save_dir = '/kaggle/working/csv_files_exp'
    # filter=False / norm=False selects the unfiltered branch of
    # get_preprocessed_data_pd: all frames kept, raw (non-normalized) angles.
    filter = False
    norm=False

    # With filter=False these three only affect the generated CSV file names.
    turn_threshold = 0.06 
    buffer_before = 32 
    buffer_after = 32
    # Fraction of rows (in file order) that goes to the training split
    train_size = 0.8

    #custom pytorch dataset args
    imgh=224
    imgw=224
    # Windows of seq_len frames, a new window starting every step_size frames
    step_size = 16
    seq_len = 32
    crop=True

    #dataloader args
    batch_size = 16
    prefetch_factor = 2
    num_workers=4
    pin_memory=True
    train_shuffle=True

    get_loaders_for_training(
        #preprocessing args:
        data_dir, steering_angles_path=steering_angles_txt_path, save_dir=save_dir, filter=filter, norm=norm,
        turn_threshold=turn_threshold, buffer_before=buffer_before, buffer_after=buffer_after, train_size=train_size,
        #dataset args:
        imgh=imgh, imgw=imgw, step_size=step_size, seq_len=seq_len, crop=crop, 
        #dataloader args:
        batch_size=batch_size, prefetch_factor=prefetch_factor, num_workers=num_workers, pin_memory=pin_memory,
        train_shuffle=train_shuffle)

Train dataset length: 51060
Val dataset length: 12765
/kaggle/working/csv_files_exp
Total sequences extracted: 3190 using step_size=16 and seq_len=32
Total sequences extracted: 796 using step_size=16 and seq_len=32
len of train loader: 200
len of val loader 50
Batch input shape: torch.Size([16, 32, 3, 224, 224])
Batch label shape: torch.Size([16, 32])


## Training Script

In [43]:
import json

# Experiment configuration for the training script. It is written to
# /kaggle/working/config.json below, which is the default path read back by
# load_config().
config = {
    # --- checkpointing ---
    "load_from_ckpt": False,
    "ckpt_save_dir": "/kaggle/working/checkpoints/conv_ncp_exp/exp5_0.08_32bb_ss16_fpf16",
    "ckpt_path": "",
    # NOTE(review): these two paths match the file names get_loaders_for_training
    # generates for filter=True, tt=0.08, bb=32, ba=32 - confirm where they are read.
    "train_dataset_path": "/kaggle/working/csv_files_exp/train_flt_ncp_tt_0.08_bb_32_ba_32.csv",
    "val_dataset_path": "/kaggle/working/csv_files_exp/val_flt_ncp_tt_0.08_bb_32_ba_32.csv",
    "save_every": 10,
    "epochs": 50,
    # --- preprocessing (forwarded to get_loaders_for_training) ---
    "data_dir": "/kaggle/input/sullychen/07012018/data",
    "steering_angles_txt_path": "/kaggle/input/sullychen/07012018/data.txt",
    "csv_save_dir": "/kaggle/working/csv_files_exp",
    "filter": True,
    "norm": True,
    "turn_threshold": 0.08,
    "buffer_before": 32,
    "buffer_after": 32,
    "train_size": 0.8,
    # --- dataset ---
    "imgw": 224,
    "imgh": 224,
    "step_size": 16,
    "seq_len": 32,
    # NOTE(review): get_loaders_for_training has no mean/std parameters, so the
    # datasets use create_train_val_dataset's defaults (the same values).
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225],
    "crop": True,
    # --- dataloader ---
    "batch_size": 32,
    "prefetch_factor": 2,
    "num_workers": 4,
    "pin_memory": True,
    "train_shuffle": True,
    # --- model / optimisation ---
    # NOTE(review): presumably separate learning rates for the conv head and the
    # NCP part, and the WeightedMSE alpha - their consumers are not in this view.
    "conv_head_lr": 2.5e-5,
    "feat_per_filt": 16,
    "alpha": 0.1,
    "ncp_lr": 1e-3,
    "optim_betas": [0.9, 0.999],
    "he_init": False
}

# Persist the configuration so the training cell can reload it from disk.
with open('/kaggle/working/config.json', 'w') as f:
    json.dump(config, f, indent=4)

In [44]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.mps as mps
import torch.backends.cudnn as cudnn
import os
import numpy as np
import cv2
import json
import matplotlib.pyplot as plt
from tqdm import tqdm

def plot_loss_accuracy(train_loss, val_loss, save_dir=None):
    """Plot training and validation loss against the epoch number.

    :param train_loss: Per-epoch training losses.
    :param val_loss: Per-epoch validation losses (same length as ``train_loss``).
    :param save_dir: Despite the name, the output *file* path. When given, the
                     figure is saved there and closed; otherwise it is shown.
    """
    epoch_axis = range(1, len(train_loss) + 1)

    plt.figure(figsize=(12, 6))

    # Loss curves go in the left panel of a 1x2 grid
    plt.subplot(1, 2, 1)
    curves = (
        (train_loss, 'Train Loss', 'blue', '-'),
        (val_loss, 'Validation Loss', 'red', '--'),
    )
    for values, legend_label, colour, dash in curves:
        plt.plot(epoch_axis, values, label=legend_label, color=colour, linestyle=dash)

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss vs Epochs')
    plt.tight_layout()

    if not save_dir:
        plt.show()
        return

    plt.savefig(save_dir, bbox_inches='tight')
    plt.close()

def overlay_visual_backprop(input_tensor, mask, save_path=None, alpha=0.1,
                            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Overlays the visual backprop mask on the original input image.

    Args:
        input_tensor: [3, H, W] torch.Tensor (before batch dimension), normalized
        mask: [H, W] numpy array, already normalized [0, 1]
        save_path: optional path to save overlay image; when omitted the
            figure is shown instead
        alpha: blending factor (heatmap vs original)
        mean: per-channel mean that was used to normalize ``input_tensor``
        std: per-channel std that was used to normalize ``input_tensor``
    """
    # Denormalize the image (undo mean/std normalization)
    mean = np.asarray(mean, dtype=np.float64)
    std = np.asarray(std, dtype=np.float64)
    img = input_tensor.clone().detach().cpu().numpy()
    img = img * std[:, None, None] + mean[:, None, None]
    img = np.clip(img, 0, 1)
    img = np.transpose(img, (1, 2, 0))  # -> [H, W, 3]

    # Bring the mask to the image resolution; cv2.resize takes (width, height)
    if mask.shape != img.shape[:2]:
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]))

    # Guard the uint8 cast: values outside [0, 1] would wrap around
    mask = np.clip(mask, 0, 1)

    # Colormap the mask and blend it with the image
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
    overlayed = (1 - alpha) * img + alpha * heatmap
    overlayed = np.clip(overlayed, 0, 1)

    plt.imshow(overlayed)
    plt.axis('off')
    if save_path:
        plt.savefig(save_path, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

def train_validate(train_loader, val_loader, optimizer, model, criterion, train_params, current_epoch=0, epochs=10, 
                   save_dir = 'checkpoints/', training_losses = None, validation_losses = None, save_every=10):
    """Train and validate ``model``, dumping a VisualBackprop overlay every epoch.

    Batches are moved to the device the model's parameters live on. A
    checkpoint is written every ``save_every`` epochs of this run.

    :param train_loader: Loader yielding ``(seq_id, seq_num, inputs, targets)``.
    :param val_loader: Loader with the same item layout, used for validation.
    :param optimizer: Optimizer stepping the model parameters.
    :param model: Model exposing ``conv_head`` with a ``visual_backprop`` method.
    :param criterion: Loss called as ``criterion(predictions, targets)``.
    :param train_params: Arbitrary object stored in every checkpoint.
    :param current_epoch: Epoch index to start from (for resuming).
    :param epochs: Epoch index to stop at (exclusive).
    :param save_dir: Directory for checkpoints and the ``backprops`` sub-folder.
    :param training_losses: Existing list to append the per-epoch training
                            losses to; a fresh list is used when omitted.
    :param validation_losses: Same, for the validation losses.
    :param save_every: Checkpoint interval in epochs.
    :return: Path of the last checkpoint written, or ``None`` if none was written.
    """
    # Fresh lists per call: a mutable default would be shared between calls.
    if training_losses is None:
        training_losses = []
    if validation_losses is None:
        validation_losses = []

    # Follow the model's device instead of depending on a global variable.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    os.makedirs(save_dir, exist_ok=True)
    backprop_save_dir = os.path.join(save_dir,'backprops')
    os.makedirs(backprop_save_dir, exist_ok=True)

    # Guard the averages against empty loaders
    num_train_batches = max(len(train_loader), 1)
    num_val_batches = max(len(val_loader), 1)

    model_path = None
    this_run_epoch = 0
    for epoch in range(current_epoch, epochs): 
        # training loop
        model.train()
        total_train_loss = 0.0
        for _, _, batch_x, batch_y in tqdm(train_loader, desc=f'Training {epoch+1}/{epochs}:',
                                           total=len(train_loader), ncols=100):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            predictions = model(batch_x)
            loss = criterion(predictions, batch_y)

            total_train_loss += loss.item()

            loss.backward()
            optimizer.step()

        training_losses.append(total_train_loss / num_train_batches)

        print(f"Train Loss: {total_train_loss / num_train_batches}")

        # validation loop
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for _, _, batch_x, batch_y in tqdm(val_loader, desc=f'Validation {epoch+1}/{epochs}:',
                                               total=len(val_loader), ncols=100):
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)

                predictions = model(batch_x)
                loss = criterion(predictions, batch_y)

                total_val_loss += loss.item()

        validation_losses.append(total_val_loss / num_val_batches)

        print(f"Validation Loss: {total_val_loss / num_val_batches}")

        # visualbackprop dump for the first frame of one training batch
        with torch.no_grad():
            model.eval()
            _, _, batch_x, _ = next(iter(train_loader))

            # Only one frame is visualised, so only that frame goes through the head.
            input_image = batch_x[0, 0].to(device)          # [3, H, W]

            _ = model.conv_head(input_image.unsqueeze(0))   # stores intermediate activations
            vis_mask = model.conv_head.visual_backprop(idx=0)
            overlay_visual_backprop(input_image, vis_mask, save_path=f'{backprop_save_dir}/epoch_{epoch+1}.png', alpha=0.5)
            plt.close()

        this_run_epoch += 1
        if this_run_epoch % save_every  == 0:

            checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'training_losses': training_losses,
            'validation_losses': validation_losses,
            'train_params': train_params,
            }

            model_path = os.path.join(save_dir, f'model_epoch{epoch+1}.pth')
            torch.save(checkpoint, model_path)
            print(f"Checkpoint saved to {save_dir}\n")

    return model_path

def init_weights_he(m):
    """Apply He (Kaiming) normal initialisation to a single module.

    Intended to be passed to ``nn.Module.apply``. Only ``nn.Linear`` and
    ``nn.Conv2d`` modules are touched; every other module type is left as is.

    Args:
        m: The module to (possibly) re-initialise in place.

    Returns:
        None. Weights are drawn with ``kaiming_normal_`` (fan-in mode, ReLU
        gain) and the bias, when the layer has one, is set to zero.
    """
    # Guard: anything that is not a dense or 2-D conv layer is skipped.
    if not isinstance(m, (nn.Linear, nn.Conv2d)):
        return

    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')

    # Layers built with bias=False expose `bias` as None.
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)

def load_config(config_path='/kaggle/working/config.json'):
    """Read the experiment configuration from a JSON file.

    Args:
        config_path: Path of the JSON file to read.

    Returns:
        The parsed JSON content, exactly as produced by ``json.load``.
    """
    with open(config_path, 'r') as config_file:
        return json.load(config_file)

if __name__ == '__main__':
    # Driver: select a device, build the data loaders / model / optimizer from
    # the JSON config, optionally resume from a checkpoint, run training, and
    # finally plot the loss curves stored in the last saved checkpoint.

    if torch.cuda.is_available():
        device = torch.device('cuda')
        cudnn.benchmark = True
        print(f'Using CUDA device: {torch.cuda.get_device_name(0)}')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
        # No autotuner flag is set on this branch: the former
        # `mps.benchmark = True` referenced an undefined name `mps` and raised
        # NameError as soon as an MPS device was selected. `benchmark` is a
        # cuDNN setting and only applies to the CUDA branch above.
        print("Using MPS device")
    else:
        device = torch.device('cpu')
        print("Using CPU")

    config = load_config()

    train_loader, val_loader = get_loaders_for_training(
        # Preprocessing args:
        data_dir=config["data_dir"], steering_angles_path=config["steering_angles_txt_path"], save_dir=config["csv_save_dir"],
        filter=config["filter"], norm=config["norm"], turn_threshold=config["turn_threshold"], buffer_before=config["buffer_before"],
        buffer_after=config["buffer_after"], train_size=config["train_size"],

        # Dataset args:
        imgh=config["imgh"], imgw=config["imgw"], step_size=config["step_size"], seq_len=config["seq_len"], crop=config["crop"],

        # Dataloader args:
        batch_size=config["batch_size"], prefetch_factor=config["prefetch_factor"], num_workers=config["num_workers"],
        pin_memory=config["pin_memory"], train_shuffle=config["train_shuffle"])

    # NOTE(review): the conv-head feature size presumably equals
    # num_filters * features_per_filter; features_per_filter comes from the
    # config, so it is not a fixed 32 — TODO confirm against ConvNCPModel.
    model = ConvNCPModel(num_filters=8, features_per_filter=config['feat_per_filt'], inter_neurons = 12,
                        command_neurons = 6, motor_neurons = 1, sensory_fanout = 6, inter_fanout = 4,
                        recurrent_command_synapses = 6, motor_fanin = 6, seed = 20190120)

    # Optional He initialisation of the Linear / Conv2d layers.
    if config['he_init']:
        model.apply(init_weights_he)

    model = model.to(device)

    # Define loss function and optimizer
    criterion = WeightedMSE(config['alpha'])
    # criterion = nn.MSELoss()
    # Separate parameter groups so the conv head and the NCP/output layers can
    # use different learning rates.
    optimizer = optim.Adam([
        # Convolutional head
        {'params': model.conv_head.parameters(), 'lr': config['conv_head_lr']},
        # NCP/LTC
        {'params': model.ltc.parameters(), 'lr': config['ncp_lr']},
        # Output layer
        {'params': model.fc_out.parameters(), 'lr': config['ncp_lr']}],
        betas=config['optim_betas'])

    if config['load_from_ckpt']:
        # Resume: restore weights, optimizer state and the loss history.
        model_ckpt = torch.load(config['ckpt_path'], map_location=device)
        model.load_state_dict(model_ckpt['model_state_dict'])
        optimizer.load_state_dict(model_ckpt['optimizer_state_dict'])
        current_epoch = model_ckpt['epoch']
        training_losses = model_ckpt['training_losses']
        validation_losses = model_ckpt['validation_losses']
        loaded_train_params = model_ckpt['train_params']
        print('checkpoint loaded successfully!')

    else:
        # Fresh run: start from epoch 0 with empty loss histories.
        current_epoch = 0
        training_losses = []
        validation_losses = []

    if len(training_losses) > 0:
        print("last training and validation losses:", training_losses[-1], validation_losses[-1])
    else:
        print('Training losses:', training_losses)
        print('Validation losses:', validation_losses)
    print('Current Epoch Number:', current_epoch)

    final_model_path = train_validate(train_loader=train_loader,
          val_loader=val_loader,
          optimizer=optimizer,
          model=model,
          train_params=config,
          criterion=criterion,
          current_epoch=current_epoch,
          epochs=config['epochs'],
          save_dir=config['ckpt_save_dir'],
          training_losses=training_losses,
          validation_losses=validation_losses,
          save_every=config['save_every'])

    # Only the loss histories are read back, so map the checkpoint tensors to
    # CPU instead of whichever device they were saved from.
    final_checkpoint = torch.load(final_model_path, map_location='cpu')

    training_losses = final_checkpoint['training_losses']
    validation_losses = final_checkpoint['validation_losses']

    plot_loss_accuracy(training_losses, validation_losses, save_dir=config['ckpt_save_dir'])

Using CUDA device: Tesla T4
Minimum sequence length: 77
Maximum sequence length: 539
Total valid sequences: 55
Train dataset length: 8791
Val dataset length: 2198
/kaggle/working/csv_files_exp
Total sequences extracted: 486 using step_size=16 and seq_len=32
Total sequences extracted: 116 using step_size=16 and seq_len=32
len of train loader: 16
len of val loader 4
Batch input shape: torch.Size([32, 32, 3, 224, 224])
Batch label shape: torch.Size([32, 32])
Training losses: []
Validation losses: []
Current Epoch Number: 0


Training 1/50:: 100%|███████████████████████████████████████████████| 16/16 [00:27<00:00,  1.75s/it]

Train Loss: 0.4888116344809532



Validation 1/50:: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.89s/it]

Validation Loss: 0.3819001838564873



Training 2/50:: 100%|███████████████████████████████████████████████| 16/16 [00:26<00:00,  1.65s/it]


Train Loss: 0.40255807153880596


Validation 2/50:: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.89s/it]

Validation Loss: 0.30517419427633286



Training 3/50:: 100%|███████████████████████████████████████████████| 16/16 [00:28<00:00,  1.76s/it]

Train Loss: 0.32814805768430233



Validation 3/50:: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.89s/it]

Validation Loss: 0.23832038789987564



Training 4/50:: 100%|███████████████████████████████████████████████| 16/16 [00:28<00:00,  1.75s/it]

Train Loss: 0.25109065789729357



Validation 4/50:: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.96s/it]

Validation Loss: 0.17815269902348518



Training 5/50:: 100%|███████████████████████████████████████████████| 16/16 [00:29<00:00,  1.82s/it]

Train Loss: 0.1918056532740593



Validation 5/50:: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.89s/it]

Validation Loss: 0.13516053929924965



Training 6/50:: 100%|███████████████████████████████████████████████| 16/16 [00:27<00:00,  1.75s/it]

Train Loss: 0.15188538189977407



Validation 6/50:: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.96s/it]

Validation Loss: 0.10329722426831722



Training 7/50:: 100%|███████████████████████████████████████████████| 16/16 [00:28<00:00,  1.80s/it]

Train Loss: 0.11727162217721343



Validation 7/50:: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.83s/it]


Validation Loss: 0.07907368429005146


Training 8/50:: 100%|███████████████████████████████████████████████| 16/16 [00:28<00:00,  1.78s/it]

Train Loss: 0.09264656249433756



Validation 8/50:: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.84s/it]

Validation Loss: 0.061014700680971146



Training 9/50:: 100%|███████████████████████████████████████████████| 16/16 [00:28<00:00,  1.77s/it]

Train Loss: 0.07249165256507695



Validation 9/50:: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.75s/it]

Validation Loss: 0.04848985932767391



Training 10/50:: 100%|██████████████████████████████████████████████| 16/16 [00:28<00:00,  1.79s/it]

Train Loss: 0.059347369242459536



Validation 10/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.93s/it]

Validation Loss: 0.04031631117686629





Checkpoint saved to /kaggle/working/checkpoints/conv_ncp_exp/exp5_0.08_32bb_ss16_fpf16



Training 11/50:: 100%|██████████████████████████████████████████████| 16/16 [00:27<00:00,  1.75s/it]

Train Loss: 0.051698652910999954



Validation 11/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.93s/it]

Validation Loss: 0.035405082860961556



Training 12/50:: 100%|██████████████████████████████████████████████| 16/16 [00:27<00:00,  1.70s/it]

Train Loss: 0.04265775578096509



Validation 12/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.76s/it]

Validation Loss: 0.03302645101211965



Training 13/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.66s/it]

Train Loss: 0.03870675002690405



Validation 13/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.78s/it]

Validation Loss: 0.03224399173632264



Training 14/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.66s/it]

Train Loss: 0.03718192514497787



Validation 14/50:: 100%|██████████████████████████████████████████████| 4/4 [00:06<00:00,  1.73s/it]

Validation Loss: 0.0322459468152374



Training 15/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.67s/it]

Train Loss: 0.03836087021045387



Validation 15/50:: 100%|██████████████████████████████████████████████| 4/4 [00:06<00:00,  1.73s/it]

Validation Loss: 0.03264010697603226



Training 16/50:: 100%|██████████████████████████████████████████████| 16/16 [00:25<00:00,  1.61s/it]

Train Loss: 0.037699098931625485



Validation 16/50:: 100%|██████████████████████████████████████████████| 4/4 [00:06<00:00,  1.71s/it]

Validation Loss: 0.03304972080513835



Training 17/50:: 100%|██████████████████████████████████████████████| 16/16 [00:27<00:00,  1.69s/it]

Train Loss: 0.03433856804622337



Validation 17/50:: 100%|██████████████████████████████████████████████| 4/4 [00:06<00:00,  1.72s/it]

Validation Loss: 0.03339187055826187



Training 18/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.67s/it]

Train Loss: 0.03442566329613328



Validation 18/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.82s/it]

Validation Loss: 0.03357815649360418



Training 19/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.68s/it]

Train Loss: 0.034266244794707745



Validation 19/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.81s/it]

Validation Loss: 0.033637679647654295



Training 20/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.66s/it]

Train Loss: 0.03838872816413641



Validation 20/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.88s/it]

Validation Loss: 0.033744022250175476





Checkpoint saved to /kaggle/working/checkpoints/conv_ncp_exp/exp5_0.08_32bb_ss16_fpf16



Training 21/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.63s/it]

Train Loss: 0.03465953574050218



Validation 21/50:: 100%|██████████████████████████████████████████████| 4/4 [00:06<00:00,  1.71s/it]

Validation Loss: 0.033483573934063315



Training 22/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.65s/it]

Train Loss: 0.034837951650843024



Validation 22/50:: 100%|██████████████████████████████████████████████| 4/4 [00:06<00:00,  1.70s/it]

Validation Loss: 0.03353830939158797



Training 23/50:: 100%|██████████████████████████████████████████████| 16/16 [00:27<00:00,  1.71s/it]

Train Loss: 0.0353165838168934



Validation 23/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.80s/it]

Validation Loss: 0.03343341359868646



Training 24/50:: 100%|██████████████████████████████████████████████| 16/16 [00:27<00:00,  1.74s/it]

Train Loss: 0.03464849619194865



Validation 24/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.83s/it]

Validation Loss: 0.03339860402047634



Training 25/50:: 100%|██████████████████████████████████████████████| 16/16 [00:27<00:00,  1.70s/it]

Train Loss: 0.033631901344051585



Validation 25/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.84s/it]

Validation Loss: 0.033495324198156595



Training 26/50:: 100%|██████████████████████████████████████████████| 16/16 [00:27<00:00,  1.71s/it]

Train Loss: 0.03555608540773392



Validation 26/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.76s/it]

Validation Loss: 0.03351765451952815



Training 27/50:: 100%|██████████████████████████████████████████████| 16/16 [00:28<00:00,  1.78s/it]

Train Loss: 0.03438604937400669



Validation 27/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.81s/it]

Validation Loss: 0.033464611042290926



Training 28/50:: 100%|██████████████████████████████████████████████| 16/16 [00:27<00:00,  1.71s/it]

Train Loss: 0.03421174769755453



Validation 28/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.79s/it]

Validation Loss: 0.033304489916190505



Training 29/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.67s/it]

Train Loss: 0.03400142549071461



Validation 29/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.81s/it]

Validation Loss: 0.03304178989492357



Training 30/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.67s/it]

Train Loss: 0.03470766940154135



Validation 30/50:: 100%|██████████████████████████████████████████████| 4/4 [00:07<00:00,  1.79s/it]

Validation Loss: 0.03288056259043515





Checkpoint saved to /kaggle/working/checkpoints/conv_ncp_exp/exp5_0.08_32bb_ss16_fpf16



Training 31/50:: 100%|██████████████████████████████████████████████| 16/16 [00:26<00:00,  1.69s/it]

Train Loss: 0.03587806364521384



Validation 31/50::  25%|███████████▌                                  | 1/4 [00:06<00:19,  6.53s/it]


KeyboardInterrupt: 