In [1]:
import torch
import torch.nn as nn
import torch
import pandas as pd
import re
import cv2
from datetime import datetime
import matplotlib.pyplot as plt
import os
import numpy as np
import cv2
from tqdm import tqdm
from collections import OrderedDict
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn.functional as F
from torchinfo import summary

In [2]:
pd.set_option("display.max_rows", 200)
plt.ion() 

def get_full_image_filepaths(data_dir):
    filepaths = os.listdir(data_dir)
    sorted_filepaths = sorted(filepaths, key=lambda x: int(re.search(r'\d+', x).group()))
    full_filepaths = [os.path.join(data_dir, file) for file in sorted_filepaths]

    return full_filepaths

def get_steering_angles(path):
    file = open(path, 'r')
    lines = file.readlines()
    #change this to line[:-1] if using kaggle dataset else keep the same for sullychen dataset
    steering_angles = [(line[:-1]).split(' ')[1].split(',')[0] for line in lines]
    timestamps = [line.strip().split(' ', 1)[1].split(',')[1] for line in lines]

    return [steering_angles, timestamps]

def convert_to_df(full_filepaths, steering_angles, timestamps, norm=True):
    data = pd.DataFrame({'filepath':full_filepaths,'steering_angle':steering_angles, 'timestamps':timestamps})
    data['parsed_timestamp'] = data['timestamps'].apply(lambda x: datetime.strptime(x, '%Y-%m-%d %H:%M:%S:%f'))
    data = data.drop('timestamps', axis=1)
    data = data.reset_index(drop=True)
    data['steering_angle'] = data['steering_angle'].astype('float')
    #normalize steering angles to -1,1 based on actual steering range (-360, 360) of wheel
    if norm == True:
        data['steering_angle'] = data['steering_angle'] / 360
    return data

def display_images_with_angles(data_pd, steering_wheel_img):

    # steering wheel image
    img = cv2.imread(steering_wheel_img,0)
    rows,cols = img.shape
    i = 0
    xs = data_pd['filepath'].values
    ys = data_pd['steering_angle'].values

    #Press q to stop 
    while(cv2.waitKey(100) != ord('q')) and i < len(xs):
        try:
            # Driving Test Image displayed as Video
            full_image = cv2.imread(xs[i])

            degrees = ys[i] * 360
            print("Steering angle: " + str(degrees) + " (actual)")
            cv2.imshow("frame", full_image)

            # Angle at which the steering wheel image should be rotated
            M = cv2.getRotationMatrix2D((cols/2,rows/2),-degrees,1)
            dst = cv2.warpAffine(img,M,(cols,rows))

            cv2.imshow("steering wheel", dst)
        except:
            print('ERROR at', i)
        i += 1
    cv2.destroyAllWindows()

def disp_freq_steering_angles(data_pd):
    # Define bin edges from -1 to 1 with step 0.1
    bin_edges = list(range(-10, 11))  # Since steering angle is between -1 and 1, multiply by 10
    bin_edges = [x / 10 for x in bin_edges]  # Convert back to decimal values

    # Assign steering angles to bins
    data_pd['binned'] = pd.cut(data_pd['steering_angle'].astype('float'), bins=bin_edges, right=False)

    # Count occurrences in each bin
    bin_counts = data_pd['binned'].value_counts().sort_index()

    # Plot bar chart
    plt.figure(figsize=(12, 5))
    bin_counts.plot(kind='bar', color='skyblue', edgecolor='black')

    # Formatting
    plt.xlabel("Steering Angle Range")
    plt.ylabel("Frequency")
    plt.title("Frequency of Steering Angles in 0.1 Intervals")
    plt.xticks(rotation=45)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    plt.show()

def disp_start_and_end_in_filtered_data(data_pd_filtered):
    turn_starts = data_pd_filtered[data_pd_filtered['turn_shift'] == 1]
    turn_ends = data_pd_filtered[data_pd_filtered['turn_shift'] == -1]

    plt.figure(figsize=(12, 5))
    plt.plot(data_pd_filtered.index, data_pd_filtered['steering_angle'], label="Steering Angle", alpha=0.5)
    plt.scatter(turn_starts.index, turn_starts['steering_angle'], color='red', label="Turn Start", marker="o")
    plt.scatter(turn_ends.index, turn_ends['steering_angle'], color='blue', label="Turn End", marker="x")

    plt.xlabel("Frame Index")
    plt.ylabel("Steering Angle")
    plt.title("Turn Start and End Points in Steering Angle Data")
    plt.legend()
    plt.grid(True)
    plt.show()

def filter_df_on_turns(data_pd, turn_threshold = 0.06, buffer_before = 60, buffer_after = 60):
    # Parameters
    turn_threshold = turn_threshold  # Define turn threshold (absolute value)
    buffer_before = buffer_before    # Frames to include before a turn
    buffer_after = buffer_after     # Frames to include after a turn

    # Load your dataset (assuming it's a DataFrame named df)
    data_pd['index'] = data_pd.index  # Preserve original ordering if needed

    # Identify where turning happens
    data_pd['turning'] = (data_pd['steering_angle'].abs() > turn_threshold).astype(int)

    # Find where turns start and end
    data_pd['turn_shift'] = data_pd['turning'].diff()  # 1 indicates start, -1 indicates end

    # Get turn start and end indices
    turn_starts = data_pd[data_pd['turn_shift'] == 1].index
    turn_ends = data_pd[data_pd['turn_shift'] == -1].index

    # Ensure equal number of start and end points
    if len(turn_ends) > 0 and turn_starts[0] > turn_ends[0]:  
        turn_ends = turn_ends[1:]  # Drop the first turn_end if it comes before a start

    # Selected indices for keeping
    selected_indices = set()

    for start, end in zip(turn_starts, turn_ends):
        # Include a buffer of frames before and after
        start_idx = max(0, start - buffer_before)
        end_idx = min(len(data_pd) - 1, end + buffer_after)
        
        # Add indices to selection
        selected_indices.update(range(start_idx, end_idx + 1))

    # Create filtered dataset
    data_pd_filtered = data_pd.loc[sorted(selected_indices)].reset_index(drop=True)

    # Drop temporary columns
    data_pd_filtered = data_pd_filtered.drop(columns=['turning', 'turn_shift'])

    # Detect sequence breaks (where the original index is not continuous)
    data_pd_filtered["sequence_id"] = (data_pd_filtered["index"].diff() != 1).cumsum()

    return data_pd_filtered

def group_data_by_sequences(data_pd_filtered):
    sequence_lengths = data_pd_filtered.groupby("sequence_id").size()

    print(f"Minimum sequence length: {min(sequence_lengths)}")
    print(f"Maximum sequence length: {max(sequence_lengths)}")

    # plt.plot(sequence_lengths)

    # Keep only sequences with at least 10 frames (adjust as needed)
    valid_sequences = sequence_lengths[sequence_lengths >= 40].index
    data_pd_filtered = data_pd_filtered[data_pd_filtered["sequence_id"].isin(valid_sequences)]

    print(f"Total valid sequences: {len(valid_sequences)}")

    return data_pd_filtered

def get_preprocessed_data_pd(data_dir, steering_angles_txt_path, filter = True,
                             turn_threshold = 0.06, buffer_before = 60, buffer_after = 60,
                             norm=True, save_dir = 'data/csv_files'):
    img_paths = get_full_image_filepaths(data_dir)
    steering_angles, timestamps = get_steering_angles(steering_angles_txt_path)

    data_pd = convert_to_df(img_paths, steering_angles, timestamps, norm)
    if filter and norm:
        data_pd_filtered = filter_df_on_turns(data_pd, turn_threshold = turn_threshold, 
                                            buffer_before = buffer_before, buffer_after = buffer_after)
        data_pd_filtered = group_data_by_sequences(data_pd_filtered)

        # Save
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        data_pd_filtered.to_csv(os.path.join(save_dir,f"flt_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv"), index=False)

        return data_pd_filtered
    
    elif filter and not norm:
        print('Error: If filtering, then steering values should be normalized (-1, 1), set norm to True')
        exit(1)

    else:
        sequence_id = 0
        sequence_ids = [sequence_id]

        # Iterate through rows to calculate time differences (if greater than 3 seconds 
        # or not) and assign sequence IDs
        for i in range(1, len(data_pd)):
            time_diff = (data_pd['parsed_timestamp'][i] - data_pd['parsed_timestamp'][i-1]).total_seconds()
            if time_diff > 3:
                sequence_id += 1
            sequence_ids.append(sequence_id)

        # Add sequence_id column to DataFrame
        data_pd['sequence_id'] = sequence_ids

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        data_pd.to_csv(os.path.join(save_dir,f"ncp_unfiltered.csv"), index=False)
        
        return data_pd


In [6]:
def df_split_train_val(df_filtered, train_csv_filename, val_csv_filename,
                       save_dir='data/csv_files',train_size = 0.8):
    train_dataset = df_filtered[:int(train_size * len(df_filtered))]
    val_dataset = df_filtered[int(train_size * len(df_filtered)):]
    print('Train dataset length:', len(train_dataset))
    print('Val dataset length:', len(val_dataset))
    print(save_dir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    train_dataset.to_csv(os.path.join(save_dir,train_csv_filename), index=False)
    val_dataset.to_csv(os.path.join(save_dir,val_csv_filename), index=False)

    return os.path.join(save_dir,train_csv_filename), os.path.join(save_dir,val_csv_filename)

def calculate_mean_and_std(dataset_path):
    num_pixels = 0
    channel_sum = np.zeros(3)  # Assuming RGB images, change to (1,) for grayscale
    channel_sum_squared = np.zeros(3)  # Assuming RGB images, change to (1,) for grayscale

    for root, _, files in os.walk(dataset_path):
        for file in files:
            image_path = os.path.join(root, file)
            image = Image.open(image_path).convert('RGB')  # Convert to RGB if needed

            pixels = np.array(image) / 255.0  # Normalize pixel values between 0 and 1
            num_pixels += pixels.size // 3  # Assuming RGB images, change to 1 for grayscale

            channel_sum += np.sum(pixels, axis=(0, 1))
            channel_sum_squared += np.sum(pixels ** 2, axis=(0, 1))

    mean = channel_sum / num_pixels
    std = np.sqrt((channel_sum_squared / num_pixels) - mean ** 2)
    return mean, std

class CustomDataset(Dataset):
    def __init__(self,
                 csv_file,
                 seq_len,
                 imgh=224,
                 imgw=224,
                 step_size=1,
                 crop=True,
                 transform=None):

        self.df        = pd.read_csv(csv_file)
        self.seq_len   = seq_len
        self.imgh      = imgh
        self.imgw      = imgw
        self.step_size = step_size
        self.crop      = crop
        self.transform = transform

        self.sequences = []  # will hold tuples (seq_df, next_angle)

        # group once by sequence_id
        for seq_id, seq_data in self.df.groupby("sequence_id"):
            # we need at least seq_len + 1 frames to form one training example
            N = len(seq_data)
            for start in range(0, N - seq_len, step_size):
                window = seq_data.iloc[start : start + seq_len]
                next_angle = seq_data.iloc[start + seq_len]["steering_angle"]
                self.sequences.append((window.reset_index(drop=True), next_angle))

        print(f"Total examples: {len(self.sequences)} "
              f"(each is {seq_len} frames → 1 target)")

    def _crop_lower_half(self, img, keep_ratio=0.6):
        h = img.shape[0]
        return img[int(h*(1-keep_ratio)) :, :, :]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq_df, next_angle = self.sequences[idx]

        imgs = []
        for fp in seq_df["filepath"]:
            img = np.expand_dims(cv2.imread(fp, cv2.IMREAD_GRAYSCALE),2)
            if self.crop:
                img = self._crop_lower_half(img)
            img = cv2.resize(img, (self.imgh, self.imgw))
            if self.transform:
                img = self.transform(img)  # → C×H×W
            imgs.append(img)

        # shape (seq_len, C, H, W)
        x = torch.stack(imgs, dim=0).permute(1, 0, 2, 3)
        y = torch.tensor([next_angle], dtype=torch.float32)
        return x, y
    

def create_train_val_dataset(train_csv_file, 
                              val_csv_file,
                              seq_len = 32, 
                              imgw = 224,
                              imgh = 224,
                              step_size = 32,
                              crop = True,
                              mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]):
    
    transform = transforms.Compose([
        transforms.ToTensor(),  # Convert image to PyTorch tensor
        transforms.Normalize(mean=mean, std=std)
    ])

    train_dataset = CustomDataset(csv_file=train_csv_file, seq_len=seq_len, imgh = imgh, imgw=imgw,
                                  step_size=step_size,crop=crop, transform=transform)
    val_dataset = CustomDataset(csv_file=val_csv_file, seq_len=seq_len,imgh = imgh, imgw=imgw,
                                step_size=step_size,crop=crop, transform=transform)

    return train_dataset, val_dataset

def create_train_val_loader(train_dataset, val_dataset, train_sampler=None, val_sampler=None, batch_size=8,
                            num_workers=4, prefetch_factor=2, pin_memory=True, train_shuffle=False):

    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size, num_workers=num_workers, 
                              prefetch_factor=prefetch_factor,pin_memory=pin_memory, shuffle=train_shuffle)
    val_loader = DataLoader(val_dataset, sampler=val_sampler, batch_size=batch_size, num_workers=num_workers, 
                            prefetch_factor=prefetch_factor, pin_memory=pin_memory, shuffle=False)

    print('len of train loader:', len(train_loader))
    print('len of val loader', len(val_loader))

    for (inputs, labels) in train_loader:
        print("Batch input shape:", inputs.shape)
        print("Batch label shape:", labels.shape)
        break

    return train_loader, val_loader

def get_loaders_for_training(data_dir, steering_angles_path, step_size, seq_len, imgh, imgw, filter, turn_threshold, 
                                     buffer_before, buffer_after, crop=True, train_size=0.8, save_dir='data/csv_files', 
                                     std=[0.5], mean=[0.5],
                                     norm=True, batch_size=16, num_workers=4, prefetch_factor=4, pin_memory=True, train_shuffle=False):
    
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    #preprocesses data
    data_preprocessed_pd = get_preprocessed_data_pd(data_dir, steering_angles_path, filter, turn_threshold, 
                                     buffer_before, buffer_after, norm, save_dir)

    if filter:
        train_csv_filename = f"train_flt_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv"
        val_csv_filename = f"val_flt_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv"
    else:
        train_csv_filename = f"train_ncp_unfiltered.csv"
        val_csv_filename = f"val_ncp_unfiltered.csv"

    #splits data into train and val
    train_dataset_path, val_dataset_path = df_split_train_val(data_preprocessed_pd,
                                                              train_csv_filename=train_csv_filename,
                                                              val_csv_filename=val_csv_filename,
                                                              save_dir=save_dir,
                                                              train_size=train_size)
    #gets custom train and test pytorch dataset
    train_dataset , val_dataset = create_train_val_dataset(train_csv_file = train_dataset_path,
                                                             val_csv_file = val_dataset_path,
                                                             seq_len=seq_len, imgh=imgh, imgw=imgw,
                                                             step_size=step_size, crop=crop,
                                                             std=std,
                                                             mean=mean)
    
    #gets dataloaders from dataset
    train_loader, val_loader = create_train_val_loader(train_dataset, val_dataset,
                                                       batch_size=batch_size, 
                                                       num_workers=num_workers, 
                                                       prefetch_factor=prefetch_factor,
                                                       pin_memory=pin_memory, 
                                                       train_shuffle=train_shuffle)

    return train_loader, val_loader

if __name__ == '__main__':

    #preprocessing csv file args
    data_dir = '/kaggle/input/sullychen-edgemaps/edge_maps/edge_maps'
    steering_angles_txt_path = '/kaggle/input/sullychen/07012018/data.txt'
    save_dir = '/kaggle/working/convnet-edgemap'
    filter = True
    norm=True

    turn_threshold = 0.08
    buffer_before = 32
    buffer_after = 32
    train_size = 0.8

    #custom pytorch dataset args
    imgh=224
    imgw=224
    step_size = 1
    seq_len = 8
    crop=False

    #dataloader args
    batch_size = 16
    prefetch_factor = 2
    num_workers=4
    pin_memory=True
    train_shuffle=False

    get_loaders_for_training(
        #preprocessing args:
        data_dir, steering_angles_path=steering_angles_txt_path, save_dir=save_dir, filter=filter, norm=norm,
        turn_threshold=turn_threshold, buffer_before=buffer_before, buffer_after=buffer_after, train_size=train_size,
        #dataset args:
        imgh=imgh, imgw=imgw, step_size=step_size, seq_len=seq_len, crop=crop, 
        #dataloader args:
        batch_size=batch_size, prefetch_factor=prefetch_factor, num_workers=num_workers, pin_memory=pin_memory,
        train_shuffle=train_shuffle)

Minimum sequence length: 77
Maximum sequence length: 539
Total valid sequences: 55
Train dataset length: 8791
Val dataset length: 2198
/kaggle/working/convnet-edgemap
Total examples: 8447 (each is 8 frames → 1 target)
Total examples: 2094 (each is 8 frames → 1 target)
len of train loader: 528
len of val loader 131
Batch input shape: torch.Size([16, 1, 8, 224, 224])
Batch label shape: torch.Size([16, 1])


In [10]:
class Conv2Plus1D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding):
        super().__init__()
        kT, kH, kW = kernel_size
        pT, pH, pW = padding

        #spatial conv: only HxW
        self.spatial = nn.Conv3d(in_channels, out_channels, kernel_size=(1, kH, kW), padding=(0, pH, pW), bias=False)
        self.bn_spatial = nn.BatchNorm3d(out_channels)

        #temporal conv: only T
        self.temporal = nn.Conv3d(out_channels, out_channels, kernel_size=(kT, 1, 1), padding=(pT, 0, 0), bias=False)
        self.bn_temporal = nn.BatchNorm3d(out_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.spatial(x)
        x = self.bn_spatial(x)
        x = self.relu(x)
        x = self.temporal(x)
        x = self.bn_temporal(x)
        return self.relu(x)
    

class ResidualMain(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        p = tuple(k//2 for k in kernel_size)
        self.conv1 = Conv2Plus1D(in_channels, out_channels, kernel_size, p)
        self.conv2 = Conv2Plus1D(out_channels, out_channels, kernel_size, p)

    def forward(self, x):
        return self.conv2(self.conv1(x))
    
class Project(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm3d(out_channels)
        )

    def forward(self, x):
        return self.proj(x)
    
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.main = ResidualMain(in_channels, out_channels, kernel_size)
        self.need_proj = (in_channels != out_channels)
        if self.need_proj:
            self.proj = Project(in_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        res = self.proj(x) if self.need_proj else x
        out = self.main(x)
        return self.relu(out + res)
    
class ResizeVideo(nn.Module):
    def __init__(self, out_size):
        super().__init__()
        self.out_size = out_size

    def forward(self, x):
        # x: [B, C, T, H, W]
        B, C, T, H, W = x.shape
        # collapse batch such that we treat frames as images
        frames = x.permute(0,2,1,3,4).reshape(B*T, C, H, W)
        frames = F.interpolate(frames, size=self.out_size,
                               mode='bilinear', align_corners=False)
        # reshape back
        h, w = self.out_size
        return frames.reshape(B, T, C, h, w).permute(0,2,1,3,4)
    
class TemporalResNet(nn.Module):
    def __init__(self, 
                 in_channels=1, 
                 seq_len=16,
                 height=224,
                 width=224):
        super().__init__()

        self.stem = nn.Sequential(
            Conv2Plus1D(in_channels, 16, kernel_size=(3,7,7), padding=(1,3,3)),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True),
            ResizeVideo((height//2, width//2))
        )
        # 4 residual stages, downsampling spatially between them
        self.stage1 = nn.Sequential(
            ResidualBlock(16,  16, (3,3,3)),
            ResizeVideo((height//4, width//4))
        )
        self.stage2 = nn.Sequential(
            ResidualBlock(16,  32, (3,3,3)),
            ResizeVideo((height//8, width//8))
        )
        self.stage3 = nn.Sequential(
            ResidualBlock(32,  64, (3,3,3)),
            ResizeVideo((height//16, width//16))
        )
        self.stage4 = nn.Sequential(
            ResidualBlock(64,  128, (3,3,3)),
            ResizeVideo((height//32, width//32))
        )
        self.stage5 = ResidualBlock(128, 256, (3,3,3))

        # global spatio‑temporal pooling + linear regressor
        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))  # [B, 128, 1,1,1]
        self.fc1   = nn.Linear(256, 128)
        self.fc2   = nn.Linear(128, 1)

    def forward(self, x):
        # x: [B, C, T, H, W]
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.pool(x).view(x.size(0), -1) # [B, 128]
        return self.fc2((self.fc1(x))) # [B,1]

In [8]:
def process_images(input_dir, output_dir):

    os.makedirs(output_dir, exist_ok=True)
    exts = ('.jpg', '.jpeg')

    for fname in tqdm(os.listdir(input_dir)):
        if not fname.lower().endswith(exts):
            continue

        img_path = os.path.join(input_dir, fname)
        img = cv2.imread(img_path)
        if img is None:
            print(f"Skipping {fname} (couldn’t read)")
            continue

        height, width = img.shape[:2]

        # grayscale + blur
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        blur = cv2.GaussianBlur(gray, (3, 3), 0)

        # canny edges
        edges = cv2.Canny(blur, 50, 70)

        # build and apply ROI mask
        mask = np.zeros_like(edges)
        roi_corners = np.array([[
            (0, height),
            (width, height),
            (width, int(height * 0.52)),
            (0, int(height * 0.52))]], dtype=np.int32)
        cv2.fillPoly(mask, roi_corners, 255)
        roi_edges = cv2.bitwise_and(edges, mask)

        y0 = int(height*0.52)
        cropped_edges = roi_edges[y0:,:]

        # save the masked edges with the same filename
        out_path = os.path.join(output_dir, fname)
        cv2.imwrite(out_path, cropped_edges)
        # print(f"Saved ROI edges to {out_path}")

if __name__ == "__main__":
    input_folder  = "/kaggle/input/sullychen/07012018/data"
    output_folder = "/kaggle/working/edge_maps"
    process_images(input_folder, output_folder)

100%|██████████| 63825/63825 [17:43<00:00, 60.04it/s]


In [8]:
def plot_loss_accuracy(train_loss, val_loss, save_dir=None):
    epochs = range(1, len(train_loss) + 1)

    plt.figure(figsize=(12, 6))

    # Plot Loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_loss, label='Train Loss', color='blue', linestyle='-')
    plt.plot(epochs, val_loss, label='Validation Loss', color='red', linestyle='--')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss vs Epochs')
    plt.tight_layout()
    if save_dir:
        plt.savefig(save_dir, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

def train_validate(train_loader, val_loader, optimizer, model, device, criterion, epochs=10, 
                   training_losses = None, val_losses = None, save_every=2, save_dir='/kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges'):

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    if training_losses is None: training_losses = []
    if val_losses is None: val_losses = []
    
    for epoch in range(epochs): 
        model.train()
        running_train_loss = 0.0
        correct_train = 0
        total_train   = 0

        for (batch_x, batch_y) in tqdm(train_loader, desc=f'Training {epoch+1}/{epochs}:', ncols=100):

            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            predictions = model(batch_x)
            loss = criterion(predictions, batch_y)
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
            _, pred_labels = predictions.max(dim=1)
            correct_train += (pred_labels == batch_y).sum().item()
            total_train += batch_y.size(0)

        avg_train_loss = running_train_loss / len(train_loader)

        training_losses.append(avg_train_loss)
        print(f"Train Loss: {avg_train_loss:.4f}")

        #validation loop
        model.eval()
        running_val_loss = 0.0
        correct_val = 0
        total_val   = 0
        for (batch_x, batch_y) in tqdm(val_loader, desc=f'Val {epoch+1}/{epochs}:', ncols=100):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            with torch.no_grad():
                predictions = model(batch_x)
                loss = criterion(predictions, batch_y)

            running_val_loss += loss.item()
            _, pred_labels = predictions.max(dim=1)
            correct_val += (pred_labels == batch_y).sum().item()
            total_val += batch_y.size(0)

        avg_val_loss = running_val_loss / len(val_loader)

        val_losses.append(avg_val_loss)
        print(f"Val Loss: {avg_val_loss:.4f}")

        if (epoch+1) % save_every == 0:

            checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'training_losses': training_losses,
            'val_losses': val_losses,
            }

            model_path = os.path.join(save_dir, f'model_epoch{epoch+1}.pth')
            torch.save(checkpoint, model_path)
            print(f"Checkpoint saved to {save_dir}\n")

    return training_losses, val_losses

In [None]:
train_dataset_path = '/kaggle/working/convnet-edgemap/train_flt_ncp_tt_0.08_bb_32_ba_32.csv'
val_dataset_path = '/kaggle/working/convnet-edgemap/val_flt_ncp_tt_0.08_bb_32_ba_32.csv'
seq_len = 8
imgh = 123
imgw = 455
step_size = 1
crop = False

batch_size = 32
num_workers = 4
prefetch_factor = 2
pin_memory = True
train_shuffle = False



train_dataset , val_dataset = create_train_val_dataset(train_csv_file = train_dataset_path,
                                                     val_csv_file = val_dataset_path,
                                                     seq_len=seq_len, imgh=imgh, imgw=imgw,
                                                     step_size=step_size, crop=crop,
                                                     mean=[0.5], std=[0.5])

#gets dataloaders from dataset
train_loader, val_loader = create_train_val_loader(train_dataset, val_dataset,
                                                batch_size=batch_size, 
                                                num_workers=num_workers, 
                                                prefetch_factor=prefetch_factor,
                                                pin_memory=pin_memory, 
                                                train_shuffle=train_shuffle)

model = TemporalResNet(in_channels=1, height=123, width=455).to(device=torch.device('cuda'))
print(summary(model, input_size=(32, 1, 8, 455, 123)))
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

train_losses, val_losses = train_validate(train_loader, val_loader, optimizer, model, 
                                          torch.device('cuda'), criterion, 20,
                                          save_dir='/kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages')

Total examples: 8447 (each is 8 frames → 1 target)
Total examples: 2094 (each is 8 frames → 1 target)
len of train loader: 264
len of val loader 66
Batch input shape: torch.Size([32, 1, 8, 455, 123])
Batch label shape: torch.Size([32, 1])
Layer (type:depth-idx)                        Output Shape              Param #
TemporalResNet                                [32, 1]                   --
├─Sequential: 1-1                             [32, 16, 8, 61, 227]      --
│    └─Conv2Plus1D: 2-1                       [32, 16, 8, 455, 123]     --
│    │    └─Conv3d: 3-1                       [32, 16, 8, 455, 123]     784
│    │    └─BatchNorm3d: 3-2                  [32, 16, 8, 455, 123]     32
│    │    └─ReLU: 3-3                         [32, 16, 8, 455, 123]     --
│    │    └─Conv3d: 3-4                       [32, 16, 8, 455, 123]     768
│    │    └─BatchNorm3d: 3-5                  [32, 16, 8, 455, 123]     32
│    │    └─ReLU: 3-6                         [32, 16, 8, 455, 123]     --
│   

Training 1/20:: 100%|█████████████████████████████████████████████| 264/264 [03:20<00:00,  1.31it/s]


Train Loss: 0.0483


Val 1/20:: 100%|████████████████████████████████████████████████████| 66/66 [00:14<00:00,  4.42it/s]


Val Loss: 0.0338


Training 2/20:: 100%|█████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0377


Val 2/20:: 100%|████████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.92it/s]


Val Loss: 0.0244
Checkpoint saved to /kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages



Training 3/20:: 100%|█████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0274


Val 3/20:: 100%|████████████████████████████████████████████████████| 66/66 [00:13<00:00,  5.00it/s]


Val Loss: 0.0273


Training 4/20:: 100%|█████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0238


Val 4/20:: 100%|████████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.86it/s]


Val Loss: 0.0300
Checkpoint saved to /kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages



Training 5/20:: 100%|█████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0217


Val 5/20:: 100%|████████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.97it/s]


Val Loss: 0.0279


Training 6/20:: 100%|█████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0187


Val 6/20:: 100%|████████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.97it/s]


Val Loss: 0.0281
Checkpoint saved to /kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages



Training 7/20:: 100%|█████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0162


Val 7/20:: 100%|████████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.77it/s]


Val Loss: 0.0265


Training 8/20:: 100%|█████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0129


Val 8/20:: 100%|████████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.91it/s]


Val Loss: 0.0268
Checkpoint saved to /kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages



Training 9/20:: 100%|█████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0099


Val 9/20:: 100%|████████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.82it/s]


Val Loss: 0.0246


Training 10/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0075


Val 10/20:: 100%|███████████████████████████████████████████████████| 66/66 [00:14<00:00,  4.56it/s]


Val Loss: 0.0305
Checkpoint saved to /kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages



Training 11/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0072


Val 11/20:: 100%|███████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.90it/s]


Val Loss: 0.0291


Training 12/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0081


Val 12/20:: 100%|███████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.94it/s]


Val Loss: 0.0295
Checkpoint saved to /kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages



Training 13/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0069


Val 13/20:: 100%|███████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.91it/s]


Val Loss: 0.0235


Training 14/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0065


Val 14/20:: 100%|███████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.96it/s]


Val Loss: 0.0248
Checkpoint saved to /kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages



Training 15/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0074


Val 15/20:: 100%|███████████████████████████████████████████████████| 66/66 [00:13<00:00,  5.00it/s]


Val Loss: 0.0256


Training 16/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0086


Val 16/20:: 100%|███████████████████████████████████████████████████| 66/66 [00:15<00:00,  4.34it/s]


Val Loss: 0.0269
Checkpoint saved to /kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages



Training 17/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0088


Val 17/20:: 100%|███████████████████████████████████████████████████| 66/66 [00:13<00:00,  5.00it/s]


Val Loss: 0.0260


Training 18/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0070


Val 18/20:: 100%|███████████████████████████████████████████████████| 66/66 [00:13<00:00,  4.91it/s]


Val Loss: 0.0272
Checkpoint saved to /kaggle/working/convnet_sl_8_ss_4_filt_mse_1e-3_edges_5stages



Training 19/20:: 100%|████████████████████████████████████████████| 264/264 [03:20<00:00,  1.32it/s]


Train Loss: 0.0063


Val 19/20::  55%|███████████████████████████▊                       | 36/66 [00:07<00:05,  5.31it/s]