In [2]:
import torch
import torch.nn as nn
import torch
import pandas as pd
import re
import cv2
from datetime import datetime
import matplotlib.pyplot as plt
import os
import numpy as np
import cv2
from tqdm import tqdm
from collections import OrderedDict
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image



In [3]:
class ConvLSTMCell(nn.Module):

    def __init__(self, input_dim, hidden_dim, kernel_size, bias=True):
        super().__init__()
        self.input_dim  = input_dim
    
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        # output 4*hidden for the gates
        self.conv = nn.Conv2d(
            in_channels=input_dim + hidden_dim,
            out_channels=4 * hidden_dim,
            kernel_size=kernel_size,
            padding=self.padding,
            bias=bias
        )

    def forward(self, x, hc):
        h_cur, c_cur = hc
        # concatenate on channel axis
        combined = torch.cat([x, h_cur], dim=1)
        conv_out = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(conv_out, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        device = next(self.parameters()).device
        return (
            torch.zeros(batch_size, self.hidden_dim, height, width, device=device),
            torch.zeros(batch_size, self.hidden_dim, height, width, device=device),
        )


class STConvLSTM(nn.Module):
    def __init__(
        self,
        seq_len=3,
        height=66,
        width=200,
        input_channels=3,
        hidden_channels=8,
        fc_units=50,
        dropout=0.5
    ):
        super().__init__()
        self.seq_len = seq_len
        self.height = height
        self.width  = width

        # stack of 4 ConvLSTM layers
        self.cells = nn.ModuleList([
            ConvLSTMCell(input_channels, hidden_channels, (3, 3)),
            ConvLSTMCell(hidden_channels, hidden_channels, (3, 3)),
            ConvLSTMCell(hidden_channels, hidden_channels, (3, 3)),
            ConvLSTMCell(hidden_channels, hidden_channels, (3, 3)),
        ])
        # one BatchNorm3d per layer (treating time as depth)
        self.bns = nn.ModuleList([
            nn.BatchNorm3d(hidden_channels) for _ in range(4)
        ])

        self.conv3d = nn.Conv3d(
            in_channels=hidden_channels,
            out_channels=2,
            kernel_size=(3, 3, 3),
            padding=1
        )
        self.pool3d = nn.MaxPool3d((2, 2, 2))

        #flattened size: out_channels * (seq_len//2) * (height//2) * (width//2)
        flat_size = 2 * (seq_len // 2) * (height // 2) * (width // 2)
        self.fc1     = nn.Linear(flat_size, fc_units)
        self.leaky   = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout)
        self.fc2     = nn.Linear(fc_units, 1)

    def _run_convlstm(self, cell: ConvLSTMCell, bn: nn.BatchNorm3d, x: torch.Tensor) -> torch.Tensor:

        #x: (batch, seq, channels, H, W)
        b, seq, c, h, w = x.size()
        h_t, c_t = cell.init_hidden(b, (h, w))
        outputs = []
        for t in range(seq):
            h_t, c_t = cell(x[:, t], (h_t, c_t))
            outputs.append(h_t)
        x_seq = torch.stack(outputs, dim=1)       #(b, seq, hidden, H, W)
        x_bn  = x_seq.permute(0, 2, 1, 3, 4)      #(b, hidden, seq, H, W)
        x_bn  = bn(x_bn)
        return x_bn.permute(0, 2, 1, 3, 4)        #(b, seq, hidden, H, W)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, H, W, C) 
        # x = x / 255.0
        # (batch, seq, C, H, W)

        for cell, bn in zip(self.cells, self.bns):
            x = self._run_convlstm(cell, bn, x)

        x = x.permute(0, 2, 1, 3, 4)
        x = self.conv3d(x)
        x = self.pool3d(x)
        
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.leaky(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


# example usage
if __name__ == "__main__":
    model = STConvLSTM(seq_len=6, height=224, width=224)
    sample = torch.randn(16, 6, 3, 224, 224)   # batch=4, seq=3, H=66, W=200, C=3
    out = model(sample)
    print(out.shape)  # → torch.Size([4, 1])

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

torch.Size([16, 1])


In [5]:
pd.set_option("display.max_rows", 200)
plt.ion() 

def get_full_image_filepaths(data_dir):
    filepaths = os.listdir(data_dir)
    sorted_filepaths = sorted(filepaths, key=lambda x: int(re.search(r'\d+', x).group()))
    full_filepaths = [os.path.join(data_dir, file) for file in sorted_filepaths]

    return full_filepaths

def get_steering_angles(path):
    file = open(path, 'r')
    lines = file.readlines()
    #change this to line[:-1] if using kaggle dataset else keep the same for sullychen dataset
    steering_angles = [(line[:-1]).split(' ')[1].split(',')[0] for line in lines]
    timestamps = [line.strip().split(' ', 1)[1].split(',')[1] for line in lines]

    return [steering_angles, timestamps]

def convert_to_df(full_filepaths, steering_angles, timestamps, norm=True):
    data = pd.DataFrame({'filepath':full_filepaths,'steering_angle':steering_angles, 'timestamps':timestamps})
    data['parsed_timestamp'] = data['timestamps'].apply(lambda x: datetime.strptime(x, '%Y-%m-%d %H:%M:%S:%f'))
    data = data.drop('timestamps', axis=1)
    data = data.reset_index(drop=True)
    data['steering_angle'] = data['steering_angle'].astype('float')
    #normalize steering angles to -1,1 based on actual steering range (-360, 360) of wheel
    if norm == True:
        data['steering_angle'] = data['steering_angle'] / 360
    return data

def display_images_with_angles(data_pd, steering_wheel_img):

    # steering wheel image
    img = cv2.imread(steering_wheel_img,0)
    rows,cols = img.shape
    i = 0
    xs = data_pd['filepath'].values
    ys = data_pd['steering_angle'].values

    #Press q to stop 
    while(cv2.waitKey(100) != ord('q')) and i < len(xs):
        try:
            # Driving Test Image displayed as Video
            full_image = cv2.imread(xs[i])

            degrees = ys[i] * 360
            print("Steering angle: " + str(degrees) + " (actual)")
            cv2.imshow("frame", full_image)

            # Angle at which the steering wheel image should be rotated
            M = cv2.getRotationMatrix2D((cols/2,rows/2),-degrees,1)
            dst = cv2.warpAffine(img,M,(cols,rows))

            cv2.imshow("steering wheel", dst)
        except:
            print('ERROR at', i)
        i += 1
    cv2.destroyAllWindows()

def disp_freq_steering_angles(data_pd):
    # Define bin edges from -1 to 1 with step 0.1
    bin_edges = list(range(-10, 11))  # Since steering angle is between -1 and 1, multiply by 10
    bin_edges = [x / 10 for x in bin_edges]  # Convert back to decimal values

    # Assign steering angles to bins
    data_pd['binned'] = pd.cut(data_pd['steering_angle'].astype('float'), bins=bin_edges, right=False)

    # Count occurrences in each bin
    bin_counts = data_pd['binned'].value_counts().sort_index()

    # Plot bar chart
    plt.figure(figsize=(12, 5))
    bin_counts.plot(kind='bar', color='skyblue', edgecolor='black')

    # Formatting
    plt.xlabel("Steering Angle Range")
    plt.ylabel("Frequency")
    plt.title("Frequency of Steering Angles in 0.1 Intervals")
    plt.xticks(rotation=45)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    plt.show()

def disp_start_and_end_in_filtered_data(data_pd_filtered):
    turn_starts = data_pd_filtered[data_pd_filtered['turn_shift'] == 1]
    turn_ends = data_pd_filtered[data_pd_filtered['turn_shift'] == -1]

    plt.figure(figsize=(12, 5))
    plt.plot(data_pd_filtered.index, data_pd_filtered['steering_angle'], label="Steering Angle", alpha=0.5)
    plt.scatter(turn_starts.index, turn_starts['steering_angle'], color='red', label="Turn Start", marker="o")
    plt.scatter(turn_ends.index, turn_ends['steering_angle'], color='blue', label="Turn End", marker="x")

    plt.xlabel("Frame Index")
    plt.ylabel("Steering Angle")
    plt.title("Turn Start and End Points in Steering Angle Data")
    plt.legend()
    plt.grid(True)
    plt.show()

def filter_df_on_turns(data_pd, turn_threshold = 0.06, buffer_before = 60, buffer_after = 60):
    # Parameters
    turn_threshold = turn_threshold  # Define turn threshold (absolute value)
    buffer_before = buffer_before    # Frames to include before a turn
    buffer_after = buffer_after     # Frames to include after a turn

    # Load your dataset (assuming it's a DataFrame named df)
    data_pd['index'] = data_pd.index  # Preserve original ordering if needed

    # Identify where turning happens
    data_pd['turning'] = (data_pd['steering_angle'].abs() > turn_threshold).astype(int)

    # Find where turns start and end
    data_pd['turn_shift'] = data_pd['turning'].diff()  # 1 indicates start, -1 indicates end

    # Get turn start and end indices
    turn_starts = data_pd[data_pd['turn_shift'] == 1].index
    turn_ends = data_pd[data_pd['turn_shift'] == -1].index

    # Ensure equal number of start and end points
    if len(turn_ends) > 0 and turn_starts[0] > turn_ends[0]:  
        turn_ends = turn_ends[1:]  # Drop the first turn_end if it comes before a start

    # Selected indices for keeping
    selected_indices = set()

    for start, end in zip(turn_starts, turn_ends):
        # Include a buffer of frames before and after
        start_idx = max(0, start - buffer_before)
        end_idx = min(len(data_pd) - 1, end + buffer_after)
        
        # Add indices to selection
        selected_indices.update(range(start_idx, end_idx + 1))

    # Create filtered dataset
    data_pd_filtered = data_pd.loc[sorted(selected_indices)].reset_index(drop=True)

    # Drop temporary columns
    data_pd_filtered = data_pd_filtered.drop(columns=['turning', 'turn_shift'])

    # Detect sequence breaks (where the original index is not continuous)
    data_pd_filtered["sequence_id"] = (data_pd_filtered["index"].diff() != 1).cumsum()

    return data_pd_filtered

def group_data_by_sequences(data_pd_filtered):
    sequence_lengths = data_pd_filtered.groupby("sequence_id").size()

    print(f"Minimum sequence length: {min(sequence_lengths)}")
    print(f"Maximum sequence length: {max(sequence_lengths)}")

    # plt.plot(sequence_lengths)

    # Keep only sequences with at least 10 frames (adjust as needed)
    valid_sequences = sequence_lengths[sequence_lengths >= 40].index
    data_pd_filtered = data_pd_filtered[data_pd_filtered["sequence_id"].isin(valid_sequences)]

    print(f"Total valid sequences: {len(valid_sequences)}")

    return data_pd_filtered

def get_preprocessed_data_pd(data_dir, steering_angles_txt_path, filter = True,
                             turn_threshold = 0.06, buffer_before = 60, buffer_after = 60,
                             norm=True, save_dir = 'data/csv_files'):
    img_paths = get_full_image_filepaths(data_dir)
    steering_angles, timestamps = get_steering_angles(steering_angles_txt_path)

    data_pd = convert_to_df(img_paths, steering_angles, timestamps, norm)
    if filter and norm:
        data_pd_filtered = filter_df_on_turns(data_pd, turn_threshold = turn_threshold, 
                                            buffer_before = buffer_before, buffer_after = buffer_after)
        data_pd_filtered = group_data_by_sequences(data_pd_filtered)

        # Save
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        data_pd_filtered.to_csv(os.path.join(save_dir,f"flt_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv"), index=False)

        return data_pd_filtered
    
    elif filter and not norm:
        print('Error: If filtering, then steering values should be normalized (-1, 1), set norm to True')
        exit(1)

    else:
        sequence_id = 0
        sequence_ids = [sequence_id]

        # Iterate through rows to calculate time differences (if greater than 3 seconds 
        # or not) and assign sequence IDs
        for i in range(1, len(data_pd)):
            time_diff = (data_pd['parsed_timestamp'][i] - data_pd['parsed_timestamp'][i-1]).total_seconds()
            if time_diff > 3:
                sequence_id += 1
            sequence_ids.append(sequence_id)

        # Add sequence_id column to DataFrame
        data_pd['sequence_id'] = sequence_ids

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        data_pd.to_csv(os.path.join(save_dir,f"ncp_unfiltered.csv"), index=False)
        
        return data_pd

In [6]:
def df_split_train_val(df_filtered, train_csv_filename, val_csv_filename,
                       save_dir='data/csv_files',train_size = 0.8):
    train_dataset = df_filtered[:int(train_size * len(df_filtered))]
    val_dataset = df_filtered[int(train_size * len(df_filtered)):]
    print('Train dataset length:', len(train_dataset))
    print('Val dataset length:', len(val_dataset))
    print(save_dir)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    train_dataset.to_csv(os.path.join(save_dir,train_csv_filename), index=False)
    val_dataset.to_csv(os.path.join(save_dir,val_csv_filename), index=False)

    return os.path.join(save_dir,train_csv_filename), os.path.join(save_dir,val_csv_filename)

def calculate_mean_and_std(dataset_path):
    num_pixels = 0
    channel_sum = np.zeros(3)  # Assuming RGB images, change to (1,) for grayscale
    channel_sum_squared = np.zeros(3)  # Assuming RGB images, change to (1,) for grayscale

    for root, _, files in os.walk(dataset_path):
        for file in files:
            image_path = os.path.join(root, file)
            image = Image.open(image_path).convert('RGB')  # Convert to RGB if needed

            pixels = np.array(image) / 255.0  # Normalize pixel values between 0 and 1
            num_pixels += pixels.size // 3  # Assuming RGB images, change to 1 for grayscale

            channel_sum += np.sum(pixels, axis=(0, 1))
            channel_sum_squared += np.sum(pixels ** 2, axis=(0, 1))

    mean = channel_sum / num_pixels
    std = np.sqrt((channel_sum_squared / num_pixels) - mean ** 2)
    return mean, std


In [7]:
import math
class CustomDataset(Dataset):
    def __init__(self,
                 csv_file,
                 seq_len,
                 imgh=224,
                 imgw=224,
                 step_size=1,
                 crop=True,
                 transform=None):
        """
        Now builds only those windows of length seq_len
        for which there *is* a next frame to predict.
        """
        self.df        = pd.read_csv(csv_file)
        self.seq_len   = seq_len
        self.imgh      = imgh
        self.imgw      = imgw
        self.step_size = step_size
        self.crop      = crop
        self.transform = transform

        self.sequences = []  # will hold tuples (seq_df, next_angle)

        # group once by sequence_id
        for seq_id, seq_data in self.df.groupby("sequence_id"):
            # we need at least seq_len + 1 frames to form one training example
            N = len(seq_data)
            for start in range(0, N - seq_len, step_size):
                window = seq_data.iloc[start : start + seq_len]
                next_angle = seq_data.iloc[start + seq_len]["steering_angle"]
                self.sequences.append((window.reset_index(drop=True), next_angle))

        print(f"Total examples: {len(self.sequences)} "
              f"(each is {seq_len} frames → 1 target)")

    def _crop_lower_half(self, img, keep_ratio=0.6):
        h = img.shape[0]
        return img[int(h*(1-keep_ratio)) :, :, :]

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq_df, next_angle = self.sequences[idx]

        # 1) load & preprocess the seq_len frames
        imgs = []
        for fp in seq_df["filepath"]:
            img = cv2.cvtColor(cv2.imread(fp), cv2.COLOR_BGR2RGB)
            if self.crop:
                img = self._crop_lower_half(img)
            img = cv2.resize(img, (self.imgh, self.imgw))
            if self.transform:
                img = self.transform(img)  # → C×H×W
            imgs.append(img)

        # shape (seq_len, C, H, W)
        x = torch.stack(imgs, dim=0)
        y = torch.tensor([next_angle], dtype=torch.float32)
        return x, y

In [21]:
def create_train_val_dataset(train_csv_file, 
                              val_csv_file,
                              seq_len = 32, 
                              imgw = 224,
                              imgh = 224,
                              step_size = 32,
                              crop = True,
                              mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]):
    
    transform = transforms.Compose([
        transforms.ToTensor(),  # Convert image to PyTorch tensor
        transforms.Normalize(mean=mean, std=std)
    ])

    train_dataset = CustomDataset(csv_file=train_csv_file, seq_len=seq_len, imgh = imgh, imgw=imgw,
                                  step_size=step_size,crop=crop, transform=transform)
    val_dataset = CustomDataset(csv_file=val_csv_file, seq_len=seq_len,imgh = imgh, imgw=imgw,
                                step_size=step_size,crop=crop, transform=transform)

    return train_dataset, val_dataset

def create_train_val_loader(train_dataset, val_dataset, train_sampler=None, val_sampler=None, batch_size=8,
                            num_workers=4, prefetch_factor=2, pin_memory=True, train_shuffle=False):

    train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=batch_size, num_workers=num_workers, 
                              prefetch_factor=prefetch_factor,pin_memory=pin_memory, shuffle=train_shuffle)
    val_loader = DataLoader(val_dataset, sampler=val_sampler, batch_size=batch_size, num_workers=num_workers, 
                            prefetch_factor=prefetch_factor, pin_memory=pin_memory, shuffle=False)

    print('len of train loader:', len(train_loader))
    print('len of val loader', len(val_loader))

    for (inputs, labels) in train_loader:
        print("Batch input shape:", inputs.shape)
        print("Batch label shape:", labels.shape)
        break

    return train_loader, val_loader

def get_loaders_for_training(data_dir, steering_angles_path, step_size, seq_len, imgh, imgw, filter, turn_threshold, 
                                     buffer_before, buffer_after, crop=True, train_size=0.8, save_dir='data/csv_files', 
                                     norm=True, batch_size=16, num_workers=4, prefetch_factor=4, pin_memory=True, train_shuffle=False):
    
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    #preprocesses data
    data_preprocessed_pd = get_preprocessed_data_pd(data_dir, steering_angles_path, filter, turn_threshold, 
                                     buffer_before, buffer_after, norm, save_dir)

    print('max:', max(data_preprocessed_pd['steering_angle']), 'rad:', max(data_preprocessed_pd['steering_angle'] * math.pi / 180.0))
    print('min:', min(data_preprocessed_pd['steering_angle']), 'rad:', min(data_preprocessed_pd['steering_angle'] * math.pi / 180.0))

    if filter:
        train_csv_filename = f"train_flt_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv"
        val_csv_filename = f"val_flt_ncp_tt_{turn_threshold}_bb_{buffer_before}_ba_{buffer_after}.csv"
    else:
        train_csv_filename = f"train_ncp_unfiltered.csv"
        val_csv_filename = f"val_ncp_unfiltered.csv"

    #splits data into train and val
    train_dataset_path, val_dataset_path = df_split_train_val(data_preprocessed_pd,
                                                              train_csv_filename=train_csv_filename,
                                                              val_csv_filename=val_csv_filename,
                                                              save_dir=save_dir,
                                                              train_size=train_size)
    #gets custom train and test pytorch dataset
    train_dataset , val_dataset = create_train_val_dataset(train_csv_file = train_dataset_path,
                                                             val_csv_file = val_dataset_path,
                                                             seq_len=seq_len, imgh=imgh, imgw=imgw,
                                                             step_size=step_size, crop=crop)
    
    #gets dataloaders from dataset
    train_loader, val_loader = create_train_val_loader(train_dataset, val_dataset,
                                                       batch_size=batch_size, 
                                                       num_workers=num_workers, 
                                                       prefetch_factor=prefetch_factor,
                                                       pin_memory=pin_memory, 
                                                       train_shuffle=train_shuffle)

    return train_loader, val_loader

if __name__ == '__main__':

    #preprocessing csv file args
    data_dir = '/kaggle/input/sullychen/07012018/data'
    steering_angles_txt_path = '/kaggle/input/sullychen/07012018/data.txt'
    save_dir = './st_lstm_norm'
    filter = True
    norm=True

    turn_threshold = 0.08
    buffer_before = 32 
    buffer_after = 32
    train_size = 0.8

    #custom pytorch dataset args
    imgh=224
    imgw=224
    step_size = 1
    seq_len = 3
    crop=True

    #dataloader args
    batch_size = 32
    prefetch_factor = 2
    num_workers=4
    pin_memory=True
    train_shuffle=False

    get_loaders_for_training(
        #preprocessing args:
        data_dir, steering_angles_path=steering_angles_txt_path, save_dir=save_dir, filter=filter, norm=norm,
        turn_threshold=turn_threshold, buffer_before=buffer_before, buffer_after=buffer_after, train_size=train_size,
        #dataset args:
        imgh=imgh, imgw=imgw, step_size=step_size, seq_len=seq_len, crop=crop, 
        #dataloader args:
        batch_size=batch_size, prefetch_factor=prefetch_factor, num_workers=num_workers, pin_memory=pin_memory,
        train_shuffle=train_shuffle)

Minimum sequence length: 77
Maximum sequence length: 539
Total valid sequences: 55
max: 0.7016944472222223 rad: 0.012246878446989358
min: -0.9411666861111111 rad: -0.016426457482722874
Train dataset length: 8791
Val dataset length: 2198
./st_lstm_norm
Total examples: 8662 (each is 3 frames → 1 target)
Total examples: 2159 (each is 3 frames → 1 target)
len of train loader: 271
len of val loader 68
Batch input shape: torch.Size([32, 3, 3, 224, 224])
Batch label shape: torch.Size([32, 1])


In [48]:
def train_validate(train_loader, val_loader, optimizer, model, device, criterion, epochs=10, 
                   training_losses = None, val_losses = None, save_every=2, save_dir='/kaggle/working/checkpoints_norm_filt_sl_6_1e-5_mse'):

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    if training_losses is None: training_losses = []
    if val_losses is None: val_losses = []
    
    for epoch in range(epochs+1): 
        model.train()
        running_train_loss = 0.0
        correct_train = 0
        total_train   = 0

        for (batch_x, batch_y) in tqdm(train_loader, desc=f'Training {epoch+1}/{epochs}:', ncols=100):

            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            predictions = model(batch_x)
            loss = criterion(predictions, batch_y)
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item()
            _, pred_labels = predictions.max(dim=1)
            correct_train += (pred_labels == batch_y).sum().item()
            total_train += batch_y.size(0)

        avg_train_loss = running_train_loss / len(train_loader)

        training_losses.append(avg_train_loss)
        print(f"Train Loss: {avg_train_loss:.4f}")

        #validation loop
        model.eval()
        running_val_loss = 0.0
        correct_val = 0
        total_val   = 0
        for (batch_x, batch_y) in tqdm(val_loader, desc=f'Val {epoch+1}/{epochs}:', ncols=100):
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            with torch.no_grad():
                predictions = model(batch_x)
                loss = criterion(predictions, batch_y)

            running_val_loss += loss.item()
            _, pred_labels = predictions.max(dim=1)
            correct_val += (pred_labels == batch_y).sum().item()
            total_val += batch_y.size(0)

        avg_val_loss = running_val_loss / len(val_loader)

        val_losses.append(avg_val_loss)
        print(f"Val Loss: {avg_val_loss:.8f}")

        if epoch % save_every == 0:

            checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch + 1,
            'training_losses': training_losses,
            'val_losses': val_losses,
            }

            model_path = os.path.join(save_dir, f'model_epoch{epoch+1}.pth')
            torch.save(checkpoint, model_path)
            print(f"Checkpoint saved to {save_dir}\n")

    return training_losses, val_losses

In [49]:
class WeightedMSE(nn.Module):
    def __init__(self, alpha=0.1):
        super(WeightedMSE, self).__init__()
        self.alpha = alpha
        
    def forward(self, predictions, targets):
        # squared error
        squared_error = (predictions - targets)**2
        
        # weighting factor: w(y) = exp(|y|*alpha)
        weights = torch.exp(torch.abs(targets) * self.alpha)
        weighted_loss = squared_error * weights

        return weighted_loss.mean()

In [50]:
train_dataset_path = '/kaggle/working/st_lstm_norm/train_flt_ncp_tt_0.08_bb_32_ba_32.csv'
val_dataset_path = '/kaggle/working/st_lstm_norm/val_flt_ncp_tt_0.08_bb_32_ba_32.csv'
seq_len = 6
imgh = 224
imgw = 224
step_size = 1
crop = True

batch_size = 16
num_workers = 4
prefetch_factor = 2
pin_memory = True
train_shuffle = True

In [51]:
train_dataset, val_dataset = create_train_val_dataset(train_csv_file = train_dataset_path,
                                                         val_csv_file = val_dataset_path,
                                                         seq_len=seq_len, imgh=imgh, imgw=imgw,
                                                         step_size=step_size, crop=crop)

#gets dataloaders from dataset
train_loader, val_loader = create_train_val_loader(train_dataset, val_dataset,
                                                   batch_size=batch_size, 
                                                   num_workers=num_workers, 
                                                   prefetch_factor=prefetch_factor,
                                                   pin_memory=pin_memory, 
                                                   train_shuffle=train_shuffle)

for x, y in train_loader:
    print(x.shape)
    break

Total examples: 8533 (each is 6 frames → 1 target)
Total examples: 2120 (each is 6 frames → 1 target)
len of train loader: 534
len of val loader 133
Batch input shape: torch.Size([16, 6, 3, 224, 224])
Batch label shape: torch.Size([16, 1])
torch.Size([16, 6, 3, 224, 224])


In [52]:
model = STConvLSTM(seq_len=6, height=224, width=224).to(device=torch.device('cuda'))
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

train_validate(train_loader, val_loader, optimizer, model, torch.device('cuda'), criterion, 10)

Training 1/10:: 100%|█████████████████████████████████████████████| 534/534 [03:34<00:00,  2.49it/s]


Train Loss: 0.0230


Val 1/10:: 100%|██████████████████████████████████████████████████| 133/133 [00:30<00:00,  4.38it/s]


Val Loss: 0.02427235
Checkpoint saved to /kaggle/working/checkpoints_norm_filt_sl_6_1e-5_mse



Training 2/10:: 100%|█████████████████████████████████████████████| 534/534 [03:34<00:00,  2.49it/s]


Train Loss: 0.0128


Val 2/10:: 100%|██████████████████████████████████████████████████| 133/133 [00:29<00:00,  4.46it/s]


Val Loss: 0.02229142


Training 3/10:: 100%|█████████████████████████████████████████████| 534/534 [03:33<00:00,  2.50it/s]


Train Loss: 0.0101


Val 3/10:: 100%|██████████████████████████████████████████████████| 133/133 [00:29<00:00,  4.46it/s]


Val Loss: 0.02250050
Checkpoint saved to /kaggle/working/checkpoints_norm_filt_sl_6_1e-5_mse



Training 4/10:: 100%|█████████████████████████████████████████████| 534/534 [03:33<00:00,  2.50it/s]


Train Loss: 0.0079


Val 4/10:: 100%|██████████████████████████████████████████████████| 133/133 [00:29<00:00,  4.47it/s]


Val Loss: 0.02278229


Training 5/10:: 100%|█████████████████████████████████████████████| 534/534 [03:33<00:00,  2.50it/s]


Train Loss: 0.0069


Val 5/10:: 100%|██████████████████████████████████████████████████| 133/133 [00:29<00:00,  4.48it/s]


Val Loss: 0.02209171
Checkpoint saved to /kaggle/working/checkpoints_norm_filt_sl_6_1e-5_mse



Training 6/10:: 100%|█████████████████████████████████████████████| 534/534 [03:33<00:00,  2.50it/s]


Train Loss: 0.0062


Val 6/10:: 100%|██████████████████████████████████████████████████| 133/133 [00:30<00:00,  4.43it/s]


Val Loss: 0.02050576


Training 7/10:: 100%|█████████████████████████████████████████████| 534/534 [03:33<00:00,  2.50it/s]


Train Loss: 0.0056


Val 7/10:: 100%|██████████████████████████████████████████████████| 133/133 [00:29<00:00,  4.47it/s]


Val Loss: 0.02059266
Checkpoint saved to /kaggle/working/checkpoints_norm_filt_sl_6_1e-5_mse



Training 8/10:: 100%|█████████████████████████████████████████████| 534/534 [03:33<00:00,  2.50it/s]


Train Loss: 0.0053


Val 8/10:: 100%|██████████████████████████████████████████████████| 133/133 [00:30<00:00,  4.42it/s]


Val Loss: 0.02127501


Training 9/10:: 100%|█████████████████████████████████████████████| 534/534 [03:33<00:00,  2.50it/s]


Train Loss: 0.0048


Val 9/10:: 100%|██████████████████████████████████████████████████| 133/133 [00:30<00:00,  4.43it/s]


Val Loss: 0.02066096
Checkpoint saved to /kaggle/working/checkpoints_norm_filt_sl_6_1e-5_mse



Training 10/10:: 100%|████████████████████████████████████████████| 534/534 [03:33<00:00,  2.50it/s]


Train Loss: 0.0050


Val 10/10:: 100%|█████████████████████████████████████████████████| 133/133 [00:30<00:00,  4.41it/s]


Val Loss: 0.02045179


Training 11/10:: 100%|████████████████████████████████████████████| 534/534 [03:33<00:00,  2.50it/s]


Train Loss: 0.0044


Val 11/10:: 100%|█████████████████████████████████████████████████| 133/133 [00:30<00:00,  4.42it/s]

Val Loss: 0.02083567
Checkpoint saved to /kaggle/working/checkpoints_norm_filt_sl_6_1e-5_mse






([0.02301717814919626,
  0.012803915207019626,
  0.010104719254673085,
  0.007880288538564876,
  0.006871680675025625,
  0.006234068525972521,
  0.00563676723174198,
  0.0052824584860529065,
  0.004815897225356078,
  0.0050171273902201614,
  0.004384999162903795],
 [0.02427234873897746,
  0.022291423887700626,
  0.022500502761178546,
  0.022782294821699633,
  0.0220917121557995,
  0.02050576242399437,
  0.020592663940646218,
  0.021275006522182935,
  0.020660957566828335,
  0.02045179403194881,
  0.020835673522958672])