In [4]:
def extract_boundaries_from_labels(labels, background_class=0):
    B, T = labels.shape
    boundaries = torch.zeros(B, 2)

    for b in range(B):
        action_mask = labels[b] != background_class
        indices = action_mask.nonzero(as_tuple=True)[0]
        # print(action_mask, indices)
        if len(indices) > 0:
            start = indices[0].item()
            end = indices[-1].item()
            # # Normalize
            boundaries[b, 0] = start #/ T
            boundaries[b, 1] = (end + 1) #/ T  # +1 to make end exclusive
        else:
            # If only background: set to [0, 0]
            boundaries[b, :] = 0.0

    return boundaries

In [3]:
import numpy as np
f = np.load('/media/viplab/DATADRIVE1/driver_action_recognition/pose_resnet_features_multi/A1/train/user_id_13522_5/Rear_view_user_id_13522_NoAudio_5.npy')

In [4]:
f.shape

(2545, 50)

In [1]:
import os
import cv2
import torch
import numpy as np
from torchvision import transforms
from torch import nn
from torchvision.datasets import ImageFolder
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
from tqdm import tqdm
import cv2
import re
from mamba_ssm import Mamba
import torch.nn.functional as F


# skip_ids = [16700, 38159, 59359]
train_features_root = '/media/viplab/Storage1/driver_action_recognition/raw_features/A1/train'
train_labels_root = '/home/viplab/Documents/driver_action_recognition/data_processing/array_generation/arrays'

validation_features_root = '/media/viplab/Storage1/driver_action_recognition/raw_features/A1/valid'
validation_labels_root = '/home/viplab/Documents/driver_action_recognition/data_processing/array_generation/arrays'

num_epochs = 100
weights_save_path = "mamba_weights_6144_boundary"


class MultiViewFeatureDataset(Dataset):
    def __init__(self, features_root, labels_root, views=("Dashboard", "Rear_view", "Right_side_window")):
        self.features_root = features_root
        self.labels_root = labels_root
        self.views = views
        # self.skip_ids = ['16700', '38159', '59359']
        self.sample_keys = []
        print(features_root)
        for user_folder in os.listdir(features_root):
            # for i in self.skip_ids:
            #     if i in user_folder:
            #         continue
            print('uuuu', user_folder)
            user_path = os.path.join(features_root, user_folder)
            # print(user_path)
            if not os.path.isdir(user_path):
                continue
            for file in os.listdir(user_path):
                # print('f', file)
                if file.startswith("Dash") and file.endswith(".npy"):
                    print('1', os.path.splitext(file)[0])
                    key = os.path.join(user_folder, os.path.splitext(file)[0]) 
                    print('k', key) # e.g. user_001_1/dash_1
                    self.sample_keys.append(key)

    def __len__(self):
        return len(self.sample_keys)

    def __getitem__(self, idx):
        key = self.sample_keys[idx]

        view_features = []
        for view in self.views:
            pp = view + '_' + "_".join(key.split('_')[:3]) + "_NoAudio_" + f"{key.split('_')[-1]}"
            path = os.path.join(self.features_root, os.path.dirname(key), f"{pp}.npy")
            features = np.load(path).astype(np.float32)  # [seq_len, feat_dim]
            view_features.append(features)
            # print('111111', features.shape, path)

        min_rows = min(view_features[0].shape[0], view_features[1].shape[0], view_features[2].shape[0])
        # Concatenate features across feature dimension
        features_cat = np.concatenate((view_features[0][:min_rows, :], view_features[1][:min_rows, :], view_features[2][:min_rows, :]), axis=1)  # [seq_len, total_feat_dim]
        label_path = os.path.join(self.labels_root, os.path.dirname(key)+ ".npy")
        labels = np.load(label_path).astype(np.int64)  # [seq_len]
        # print(features_cat.shape)
        return features_cat, labels

t_dataset = MultiViewFeatureDataset(
    features_root= train_features_root,
    labels_root= train_labels_root
)

v_dataset = MultiViewFeatureDataset(
    features_root= validation_features_root,
    labels_root= validation_labels_root
)
print(len(t_dataset))
print(len(v_dataset))

class ChunkedVideoDataset(Dataset):
    def __init__(self, base_dataset, chunk_size=100, stride=50):
        self.base_dataset = base_dataset
        self.chunk_size = chunk_size
        self.stride = stride
        self.index_map = []  # (video_idx, start_frame)

        for video_idx in range(len(base_dataset)):
            features, labels = base_dataset[video_idx]
            video_len = features.shape[0]

            for start in range(0, video_len - chunk_size + 1, stride):
                labels_chunk = labels[start:start + chunk_size]

                if labels_chunk.sum() == 0:
                    continue

                self.index_map.append((video_idx, start))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        video_idx, start = self.index_map[idx]
        features, labels = self.base_dataset[video_idx]
        features_chunk = features[start:start+self.chunk_size]
        labels_chunk = labels[start:start+self.chunk_size]
        return features_chunk, labels_chunk

train_dataset = ChunkedVideoDataset(t_dataset, chunk_size=300, stride=75)
valid_dataset = ChunkedVideoDataset(v_dataset, chunk_size=300, stride=300)

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=16,
    shuffle=True
)

for i in train_dataset:
    print('train', len(i))
    break

for i in train_dataset:
    print('valid', len(i))
    break

for features, labels in train_loader:
    print('train')
    print("Concatenated features:", features.shape)  # [1, seq_len, total_feat_dim]
    print("Labels:", labels.shape)
    break 
for features, labels in valid_loader:
    print('valid')
    print("Concatenated features:", features.shape)  # [1, seq_len, total_feat_dim]
    print("Labels:", labels.shape)
    break 


/media/viplab/Storage1/driver_action_recognition/raw_features/A1/train
uuuu user_id_70176_7
1 Dashboard_user_id_70176_NoAudio_7
k user_id_70176_7/Dashboard_user_id_70176_NoAudio_7
uuuu user_id_69039_5
1 Dashboard_user_id_69039_NoAudio_5
k user_id_69039_5/Dashboard_user_id_69039_NoAudio_5
uuuu user_id_16080_5
1 Dashboard_user_id_16080_NoAudio_5
k user_id_16080_5/Dashboard_user_id_16080_NoAudio_5
uuuu user_id_61962_5
1 Dashboard_user_id_61962_NoAudio_5
k user_id_61962_5/Dashboard_user_id_61962_NoAudio_5
uuuu user_id_16700_5
1 Dashboard_user_id_16700_NoAudio_5
k user_id_16700_5/Dashboard_user_id_16700_NoAudio_5
uuuu user_id_59014_7
1 Dashboard_user_id_59014_NoAudio_7
k user_id_59014_7/Dashboard_user_id_59014_NoAudio_7
uuuu user_id_53307_5
1 Dashboard_user_id_53307_NoAudio_5
k user_id_53307_5/Dashboard_user_id_53307_NoAudio_5
uuuu user_id_50921_7
1 Dashboard_user_id_50921_NoAudio_7
k user_id_50921_7/Dashboard_user_id_50921_NoAudio_7
uuuu user_id_52046_5
1 Dashboard_user_id_52046_NoAudio_5


In [2]:
device="cuda"
# for epoch in range(num_epochs):
#     for features, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
#         features = features.to(device)  # [B, 50, 6144]
#         labels = labels.to(device)      # [B, 50]
#         print(labels)
#         true_boundaries = extract_boundaries_from_labels(labels)
#         print(true_boundaries)
#         break
#     break

In [3]:
import numpy as np
import torch

In [8]:
labels = np.array([[1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,8,8,8,8,8,0,0,0,6,6,6,6,6,6,6,6,0,0,0,0,4,4,4,4,4,4,4,0,0,0,0,0,0,6,6,6,6,6,0],
[1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,8,8,8,8,8,0,0,0,6,6,6,6,6,6,6,6,0,0,0,0,4,4,4,4,4,4,4,0,0,0,0,0,0,6,6,6,6,6,0],
[1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,8,8,8,8,8,0,0,0,6,6,6,6,6,6,6,6,0,0,0,0,4,4,4,4,4,4,4,0,0,0,0,0,0,6,6,6,6,6,0]])
extract_boundaries_from_labels(labels, background_class=0)

TypeError: nonzero() takes no keyword arguments

In [4]:
class MambaSequenceClassifier(nn.Module):
    def __init__(self, input_dim=6144, hidden_dim=2048, num_classes=16, seq_len=100):
        super().__init__()
        # self.input_proj = nn.Linear(input_dim, hidden_dim)

        self.mamba_block = Mamba(
            d_model=input_dim,
            d_state=16,
            d_conv=4,
            expand=2
        )
        self.output_layer = nn.Linear(input_dim, num_classes)
    def forward(self, x):
        x = self.mamba_block(x) 
        logits = self.output_layer(x)  # [B, L, C]
        return logits

In [10]:
import torch
from collections import defaultdict

def compute_per_class_accuracy(logits, labels, num_classes, ignore_index=None):
    preds = logits.argmax(dim=-1)  # shape: [batch, seq_len]
    labels = labels.view(-1)
    preds = preds.view(-1)

    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    for cls in range(num_classes):
        if ignore_index is not None and cls == ignore_index:
            continue
        mask = labels == cls
        class_correct[cls] += (preds[mask] == labels[mask]).sum().item()
        class_total[cls] += mask.sum().item()

    class_accuracy = {}
    for cls in range(num_classes):
        if ignore_index is not None and cls == ignore_index:
            continue
        total = class_total[cls]
        correct = class_correct[cls]
        acc = correct / total if total > 0 else 0.0
        class_accuracy[cls] = acc

    return class_accuracy

In [8]:
model = MambaSequenceClassifier(input_dim=6144, hidden_dim=2048, num_classes=16)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-6)
model.load_state_dict(torch.load('mamba_weights_6144_batch16/mamba_best.pth'))  # replace with actual path)
model.eval()
num_classes = 16  # change to your actual number of classes
model.to(device)

MambaSequenceClassifier(
  (mamba_block): Mamba(
    (in_proj): Linear(in_features=6144, out_features=24576, bias=False)
    (conv1d): Conv1d(12288, 12288, kernel_size=(4,), stride=(1,), padding=(3,), groups=12288)
    (act): SiLU()
    (x_proj): Linear(in_features=12288, out_features=416, bias=False)
    (dt_proj): Linear(in_features=384, out_features=12288, bias=True)
    (out_proj): Linear(in_features=12288, out_features=6144, bias=False)
  )
  (output_layer): Linear(in_features=6144, out_features=16, bias=True)
)

In [None]:
all_logits = []
all_labels = []
with torch.no_grad():
    for features, labels in valid_loader:
        features, labels = features.to(device), labels.to(device)
        logits = model(features).cuda()
        all_logits.append(logits.cpu())
        all_labels.append(labels.cpu())

logits = torch.cat(all_logits, dim=0)
labels = torch.cat(all_labels, dim=0)

class_acc = compute_per_class_accuracy(logits, labels, num_classes=num_classes, ignore_index=0)

print("Per-class accuracy:")
for cls, acc in class_acc.items():
    print(f"Class {cls}: {acc:.2%}")

Per-class accuracy:
Class 1: 76.25%
Class 2: 90.91%
Class 3: 73.07%
Class 4: 7.10%
Class 5: 39.16%
Class 6: 58.41%
Class 7: 60.63%
Class 8: 51.10%
Class 9: 29.41%
Class 10: 44.58%
Class 11: 46.93%
Class 12: 30.10%
Class 13: 55.49%
Class 14: 87.25%
Class 15: 23.56%


In [24]:
from sklearn.metrics import precision_recall_fscore_support

def compute_metrics(logits, labels, num_classes, ignore_index=None):
    preds = logits.argmax(dim=-1).view(-1).cpu().numpy()
    labels = labels.view(-1).cpu().numpy()

    if ignore_index is not None:
        mask = labels != ignore_index
        preds = preds[mask]
        labels = labels[mask]

    precision, recall, f1, support = precision_recall_fscore_support(
        labels,
        preds,
        labels=list(range(num_classes)),
        zero_division=0  # Avoid divide-by-zero errors
    )

    metrics = {}
    for cls in range(num_classes):
        if ignore_index is not None and cls == ignore_index:
            continue
        metrics[cls] = {
            "precision": precision[cls],
            "recall": recall[cls],
            "f1": f1[cls],
            "support": support[cls],
        }

    return metrics


In [25]:
model.eval()
num_classes = 16  # change as needed
all_logits = []
all_labels = []

with torch.no_grad():
    for features, labels in valid_loader:
        features, labels = features.to(device), labels.to(device)
        logits = model(features)
        all_logits.append(logits.cpu())
        all_labels.append(labels.cpu())

logits = torch.cat(all_logits, dim=0)
labels = torch.cat(all_labels, dim=0)

metrics = compute_metrics(logits, labels, num_classes=num_classes, ignore_index=0)

print("Per-class metrics (ignoring class 0):")
for cls, values in metrics.items():
    print(f"Class {cls}: "
          f"Precision: {values['precision']:.2%}, "
          f"Recall: {values['recall']:.2%}, "
          f"F1: {values['f1']:.2%}, "
          f"Support: {values['support']}")


Per-class metrics (ignoring class 0):
Class 1: Precision: 99.59%, Recall: 76.25%, F1: 86.37%, Support: 320
Class 2: Precision: 100.00%, Recall: 90.91%, F1: 95.24%, Support: 1155
Class 3: Precision: 95.28%, Recall: 73.07%, F1: 82.71%, Support: 995
Class 4: Precision: 96.88%, Recall: 7.10%, F1: 13.23%, Support: 1310
Class 5: Precision: 97.56%, Recall: 39.16%, F1: 55.89%, Support: 715
Class 6: Precision: 99.64%, Recall: 58.41%, F1: 73.65%, Support: 945
Class 7: Precision: 58.95%, Recall: 60.63%, F1: 59.78%, Support: 315
Class 8: Precision: 89.21%, Recall: 51.10%, F1: 64.98%, Support: 955
Class 9: Precision: 70.22%, Recall: 29.41%, F1: 41.46%, Support: 425
Class 10: Precision: 60.06%, Recall: 44.58%, F1: 51.18%, Support: 415
Class 11: Precision: 79.94%, Recall: 46.93%, F1: 59.14%, Support: 1155
Class 12: Precision: 54.48%, Recall: 30.10%, F1: 38.78%, Support: 1010
Class 13: Precision: 83.83%, Recall: 55.49%, F1: 66.78%, Support: 355
Class 14: Precision: 96.86%, Recall: 87.25%, F1: 91.81%, 

In [21]:
all_logits = []
all_labels = []
with torch.no_grad():
    for features, labels in valid_loader:
        features, labels = features.to(device), labels.to(device)
        logits = model(features).cuda()
        all_logits.append(logits.cpu())
        all_labels.append(labels.cpu())

logits = torch.cat(all_logits, dim=0)
labels = torch.cat(all_labels, dim=0)


In [22]:
len(logits)

133

In [None]:
from scipy.ndimage import median_filter

frame_rate = 30
decision_threshold = 0.4
smoothing_kernel_size = 7
positive_ratio_threshold = 0.3

def rts(labels, frame_rate):
    num_seconds = len(labels) // frame_rate
    second_level = []
    for i in range(num_seconds):
        segment = labels[i * frame_rate: (i + 1) * frame_rate]
        if len(segment) == 0:
            continue
        # Mode class in the second
        vals, counts = np.unique(segment, return_counts=True)
        second_level.append(vals[np.argmax(counts)])
    return np.array(second_level)

model.eval()
num_classes = 16  # change as needed
all_preds = []
all_labels = []
from sklearn.metrics import classification_report
with torch.no_grad():
    for features, labels in valid_loader:
        features, labels = features.to(device), labels.to(device)
        logits = model(features)
        probs = torch.softmax(logits, dim = -1)
        
        labels = labels.squeeze().cpu().numpy()
        probs = probs.cpu().numpy()

        # mask = labels != 0  # Ignore instances of the specified class
        # all_preds.append(probs[mask])
        # all_labels.append(labels[mask])
        all_preds.append(probs)
        all_labels.append(labels)
all_preds = np.concatenate(all_preds, axis = 0)
all_labels = np.concatenate(all_labels, axis = 0)

all_preds = all_preds.reshape(-1, all_preds.shape[-1])
all_labels = all_labels.reshape(-1)

smoothed_preds = np.stack([median_filter(all_preds[:, c], size = 20) for c in range(all_preds.shape[1])], axis = 1)
frame_preds = np.argmax(smoothed_preds, axis = 1)

In [5]:
from scipy.ndimage import median_filter

In [None]:

frame_rate = 30
decision_threshold = 0.4
smoothing_kernel_size = 7
positive_ratio_threshold = 0.3

In [6]:
def rts(labels, frame_rate):
    num_seconds = len(labels) // frame_rate
    second_level = []
    for i in range(num_seconds):
        segment = labels[i * frame_rate: (i + 1) * frame_rate]
        if len(segment) == 0:
            continue
        # Mode class in the second
        vals, counts = np.unique(segment, return_counts=True)
        second_level.append(vals[np.argmax(counts)])
    return np.array(second_level)


In [18]:
model.eval()
num_classes = 16  # change as needed
all_preds = []
all_labels = []
from sklearn.metrics import classification_report
with torch.no_grad():
    for features, labels in valid_loader:
        features, labels = features.to(device), labels.to(device)
        logits = model(features)
        probs = torch.softmax(logits, dim = -1)
        
        labels = labels.squeeze().cpu().numpy()
        probs = probs.cpu().numpy()

        # mask = labels != 0  # Ignore instances of the specified class
        # all_preds.append(probs[mask])
        # all_labels.append(labels[mask])
        all_preds.append(probs)
        all_labels.append(labels)
all_preds = np.concatenate(all_preds, axis = 0)
all_labels = np.concatenate(all_labels, axis = 0)

all_preds = all_preds.reshape(-1, all_preds.shape[-1])
all_labels = all_labels.reshape(-1)

smoothed_preds = np.stack([median_filter(all_preds[:, c], size = 20) for c in range(all_preds.shape[1])], axis = 1)
frame_preds = np.argmax(smoothed_preds, axis = 1)

In [19]:
second_preds = rts(frame_preds, 5)
second_labels = rts(all_labels, 5)

In [22]:
second_preds.shape

(7980,)

In [23]:
labels_to_include = [i for i in range(num_classes) if i != 0]

print(classification_report(second_labels, second_preds, labels=labels_to_include))

              precision    recall  f1-score   support

           1       0.81      0.69      0.75        64
           2       0.87      0.92      0.89       231
           3       0.80      0.73      0.77       199
           4       0.21      0.05      0.08       262
           5       0.75      0.38      0.50       143
           6       0.68      0.59      0.63       189
           7       0.49      0.68      0.57        63
           8       0.77      0.52      0.62       191
           9       0.59      0.27      0.37        85
          10       0.61      0.47      0.53        83
          11       0.66      0.46      0.54       231
          12       0.44      0.30      0.36       202
          13       0.68      0.51      0.58        71
          14       0.89      0.88      0.88       262
          15       0.23      0.23      0.23       281

   micro avg       0.64      0.50      0.56      2557
   macro avg       0.63      0.51      0.55      2557
weighted avg       0.61   

In [16]:
second_labels.shape

(2557,)

In [13]:
print(classification_report(second_labels, second_preds))

              precision    recall  f1-score   support

           0       0.82      0.91      0.86      5423
           1       0.81      0.69      0.75        64
           2       0.87      0.92      0.89       231
           3       0.80      0.73      0.77       199
           4       0.21      0.05      0.08       262
           5       0.75      0.38      0.50       143
           6       0.68      0.59      0.63       189
           7       0.49      0.68      0.57        63
           8       0.77      0.52      0.62       191
           9       0.59      0.27      0.37        85
          10       0.61      0.47      0.53        83
          11       0.66      0.46      0.54       231
          12       0.44      0.30      0.36       202
          13       0.68      0.51      0.58        71
          14       0.89      0.88      0.88       262
          15       0.23      0.23      0.23       281

    accuracy                           0.78      7980
   macro avg       0.64   

In [None]:
print(classification_report(second_labels, second_preds))

              precision    recall  f1-score   support

           0       0.91      0.82      0.86      5992
           1       0.69      0.81      0.75        54
           2       0.92      0.87      0.89       243
           3       0.73      0.80      0.77       182
           4       0.05      0.21      0.08        68
           5       0.38      0.75      0.50        72
           6       0.59      0.68      0.63       164
           7       0.68      0.49      0.57        88
           8       0.52      0.77      0.62       128
           9       0.27      0.59      0.37        39
          10       0.47      0.61      0.53        64
          11       0.46      0.66      0.54       161
          12       0.30      0.44      0.36       135
          13       0.51      0.68      0.58        53
          14       0.88      0.89      0.88       259
          15       0.23      0.23      0.23       278

    accuracy                           0.78      7980
   macro avg       0.54   

In [None]:
print(classification_report(second_labels, second_preds))

              precision    recall  f1-score   support

           0       0.90      0.82      0.86      5957
           1       0.75      0.84      0.79        57
           2       0.91      0.89      0.90       238
           3       0.74      0.80      0.77       183
           4       0.06      0.22      0.09        69
           5       0.39      0.75      0.51        75
           6       0.59      0.70      0.64       159
           7       0.65      0.46      0.54        89
           8       0.52      0.75      0.62       133
           9       0.26      0.48      0.34        46
          10       0.45      0.52      0.48        71
          11       0.45      0.65      0.53       159
          12       0.30      0.41      0.34       146
          13       0.55      0.74      0.63        53
          14       0.88      0.89      0.88       258
          15       0.24      0.24      0.24       287

    accuracy                           0.78      7980
   macro avg       0.54   

In [18]:
len(second_preds)

7980

In [20]:
print(classification_report(second_preds, second_labels))

              precision    recall  f1-score   support

           0       0.89      0.82      0.86      5860
           1       0.77      0.80      0.78        61
           2       0.91      0.88      0.90       239
           3       0.73      0.80      0.76       183
           4       0.06      0.22      0.10        72
           5       0.41      0.72      0.52        81
           6       0.60      0.69      0.64       163
           7       0.65      0.38      0.48       107
           8       0.51      0.68      0.58       143
           9       0.26      0.37      0.30        60
          10       0.47      0.54      0.50        72
          11       0.47      0.65      0.55       168
          12       0.30      0.40      0.34       152
          13       0.55      0.63      0.59        62
          14       0.88      0.87      0.87       264
          15       0.23      0.23      0.23       293

    accuracy                           0.77      7980
   macro avg       0.54   

In [None]:
Class 1: Precision: 99.59%, Recall: 76.25%, F1: 86.37%, Support: 320
Class 2: Precision: 100.00%, Recall: 90.91%, F1: 95.24%, Support: 1155
Class 3: Precision: 95.28%, Recall: 73.07%, F1: 82.71%, Support: 995
Class 4: Precision: 96.88%, Recall: 7.10%, F1: 13.23%, Support: 1310
Class 5: Precision: 97.56%, Recall: 39.16%, F1: 55.89%, Support: 715
Class 6: Precision: 99.64%, Recall: 58.41%, F1: 73.65%, Support: 945
Class 7: Precision: 58.95%, Recall: 60.63%, F1: 59.78%, Support: 315
Class 8: Precision: 89.21%, Recall: 51.10%, F1: 64.98%, Support: 955
Class 9: Precision: 70.22%, Recall: 29.41%, F1: 41.46%, Support: 425
Class 10: Precision: 60.06%, Recall: 44.58%, F1: 51.18%, Support: 415
Class 11: Precision: 79.94%, Recall: 46.93%, F1: 59.14%, Support: 1155
Class 12: Precision: 54.48%, Recall: 30.10%, F1: 38.78%, Support: 1010
Class 13: Precision: 83.83%, Recall: 55.49%, F1: 66.78%, Support: 355
Class 14: Precision: 96.86%, Recall: 87.25%, F1: 91.81%, Support: 1310
Class 15: Precision: 48.53%, Recall: 23.56%, F1: 31.72%, Support: 1405

In [2]:
import os
import cv2
import torch
import numpy as np
from torchvision import transforms
from torch import nn
from torchvision.datasets import ImageFolder
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
from tqdm import tqdm
import cv2
import re
from mamba_ssm import Mamba
import torch.nn.functional as F


# skip_ids = [16700, 38159, 59359]
train_features_root = '/media/viplab/DATADRIVE1/driver_action_recognition/pose_resnet_features_multi/A1/train'
train_labels_root = '/home/viplab/Documents/driver_action_recognition/data_processing/array_generation/arrays'

validation_features_root = '/media/viplab/DATADRIVE1/driver_action_recognition/pose_resnet_features_multi/A1/valid'
validation_labels_root = '/home/viplab/Documents/driver_action_recognition/data_processing/array_generation/arrays'

num_epochs = 100
weights_save_path = "mamba_weights_pose_1e-5"


class MultiViewFeatureDataset(Dataset):
    def __init__(self, features_root, labels_root, views=("Dashboard", "Rear_view", "Right_side_window")):
        self.features_root = features_root
        self.labels_root = labels_root
        self.views = views
        # self.skip_ids = ['16700', '38159', '59359']
        self.sample_keys = []
        # print(features_root)
        for user_folder in os.listdir(features_root):
            # for i in self.skip_ids:
            #     if i in user_folder:
            #         continue
            # print('uuuu', user_folder)
            user_path = os.path.join(features_root, user_folder)
            # print(user_path)
            if not os.path.isdir(user_path):
                continue
            for file in os.listdir(user_path):
                # print('f', file)
                if file.startswith("Dash") and file.endswith(".npy"):
                    # print('1', os.path.splitext(file)[0])
                    key = os.path.join(user_folder, os.path.splitext(file)[0]) 
                    # print('k', key) # e.g. user_001_1/dash_1
                    self.sample_keys.append(key)

    def __len__(self):
        return len(self.sample_keys)

    def __getitem__(self, idx):
        key = self.sample_keys[idx]

        view_features = []
        for view in self.views:
            pp = view + '_' + "_".join(key.split('_')[:3]) + "_NoAudio_" + f"{key.split('_')[-1]}"
            path = os.path.join(self.features_root, os.path.dirname(key), f"{pp}.npy")
            features = np.load(path).astype(np.float32)  # [seq_len, feat_dim]
            view_features.append(features)
            # print('111111', features.shape, path)

        min_rows = min(view_features[0].shape[0], view_features[1].shape[0], view_features[2].shape[0])
        # Concatenate features across feature dimension
        features_cat = np.concatenate((view_features[0][:min_rows, -34:], view_features[1][:min_rows, -34:], view_features[2][:min_rows, -34:]), axis=1)  # [seq_len, total_feat_dim]
        label_path = os.path.join(self.labels_root, os.path.dirname(key)+ ".npy")
        labels = np.load(label_path).astype(np.int64)  # [seq_len]
        # print(features_cat.shape)
        return features_cat, labels

t_dataset = MultiViewFeatureDataset(
    features_root= train_features_root,
    labels_root= train_labels_root
)

v_dataset = MultiViewFeatureDataset(
    features_root= validation_features_root,
    labels_root= validation_labels_root
)
print(len(t_dataset))
print(len(v_dataset))

class ChunkedVideoDataset(Dataset):
    def __init__(self, base_dataset, chunk_size=100, stride=50):
        self.base_dataset = base_dataset
        self.chunk_size = chunk_size
        self.stride = stride
        self.index_map = []  # (video_idx, start_frame)

        for video_idx in range(len(base_dataset)):
            features, labels = base_dataset[video_idx]
            video_len = features.shape[0]

            for start in range(0, video_len - chunk_size + 1, stride):
                labels_chunk = labels[start:start + chunk_size]

                if labels_chunk.sum() == 0:
                    continue

                self.index_map.append((video_idx, start))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        video_idx, start = self.index_map[idx]
        features, labels = self.base_dataset[video_idx]
        features_chunk = features[start:start+self.chunk_size]
        labels_chunk = labels[start:start+self.chunk_size]
        return features_chunk, labels_chunk

train_dataset = ChunkedVideoDataset(t_dataset, chunk_size=300, stride=75)
valid_dataset = ChunkedVideoDataset(v_dataset, chunk_size=300, stride=300)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=8,
    shuffle=False
)

for i in train_dataset:
    print('train', len(i))
    break

for i in train_dataset:
    print('valid', len(i))
    break

for features, labels in train_loader:
    print('train')
    print("Concatenated features:", features.shape)  # [1, seq_len, total_feat_dim]
    print("Labels:", labels.shape)
    break 
for features, labels in valid_loader:
    print('valid')
    print("Concatenated features:", features.shape)  # [1, seq_len, total_feat_dim]
    print("Labels:", labels.shape)
    break 

class MambaSequenceClassifier(nn.Module):
    def __init__(self, input_dim=6246, hidden_dim=2048, num_classes=16, seq_len=100):
        super().__init__()
        # self.input_proj = nn.Linear(input_dim, hidden_dim)

        self.mamba_block = Mamba(
            d_model=input_dim,
            d_state=16,
            d_conv=4,
            expand=2
        )
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, num_classes)

        self.dropout = nn.Dropout(p=0.2)
        self.relu = nn.ReLU()
    def forward(self, x):
        """
        x: [batch_size, seq_len, input_dim]
        returns: [batch_size, seq_len, num_classes]
        """ 
        # x = self.input_proj(x)  # -> [B, L, H]
        x = self.mamba_block(x)  # [B, L, H]
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        logits = self.output_layer(x)  # [B, L, C]
        return logits
    
# class BoundaryAwareLoss(nn.Module):
#     def __init__(self, classification_weight=1.0, localization_weight=1.0, ignore_index=None):
#         super().__init__()
#         self.classification_weight = classification_weight
#         self.localization_weight = localization_weight
#         self.ignore_index = ignore_index

#     def forward(self, frame_logits, labels, pred_boundaries, true_boundaries):
#         B, T, C = frame_logits.shape

#         # Frame-wise classification loss
#         loss_cls = F.cross_entropy(
#             frame_logits.view(-1, C),
#             labels.view(-1),
#             ignore_index=self.ignore_index if self.ignore_index is not None else -100
#         )

#         # Temporal boundary regression loss (L1)
#         loss_loc = F.l1_loss(pred_boundaries, true_boundaries)

#         return self.classification_weight * loss_cls + self.localization_weight * loss_loc



def train_mamba(
    model, train_loader, val_loader, optimizer, num_epochs, weights_save_path, device,
    num_classes=16, ignore_index=None
):
    model.to(device)
    best_val_accuracy = 0.0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_correct = 0
        total_count = 0

        for features, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            features = features.to(device)  # [B, 50, 6144]
            labels = labels.to(device)      # [B, 50]

            optimizer.zero_grad()
            logits = model(features)        # [B, 50, num_classes]

            loss = F.cross_entropy(
                logits.view(-1, num_classes),
                labels.view(-1),
                # ignore_index=ignore_index if ignore_index is not None else -100,
            )
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # --- Accuracy ---
            preds = logits.argmax(dim=-1)  # [B, 50]
            # if ignore_index is not None:
            #     mask = labels != ignore_index
            #     total_correct += (preds[mask] == labels[mask]).sum().item()
            #     total_count += mask.sum().item()
            # else:
            total_correct += (preds == labels).sum().item()
            total_count += labels.numel()

        avg_train_loss = total_loss / len(train_loader)
        train_accuracy = total_correct / total_count if total_count > 0 else 0.0
        print(f"Epoch {epoch+1} -  Training Loss: {avg_train_loss:.4f} - Accuracy: {train_accuracy*100:.2f}%")
        # --- Validation ---
        model.eval()
        val_loss = 0
        val_correct = 0
        val_count = 0
        with torch.no_grad():
            for features, labels in tqdm(val_loader, desc="Validation"):
                features = features.to(device)
                labels = labels.to(device)

                logits = model(features)
                loss = F.cross_entropy(
                    logits.view(-1, num_classes),
                    labels.view(-1),
                    # ignore_index=ignore_index if ignore_index is not None else -100,
                )
                val_loss += loss.item()

                preds = logits.argmax(dim=-1)
                val_correct += (preds == labels).sum().item()
                val_count += labels.numel()

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = val_correct / val_count if val_count > 0 else 0.0
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # Save the model checkpoint
            torch.save(model.state_dict(), f"{weights_save_path}/mamba_best.pth")
            print(f"Best model saved with accuracy: {val_accuracy*100:.2f}%")

        print(f"Epoch {epoch+1} - Validation Loss: {avg_val_loss:.4f} - Accuracy: {val_accuracy*100:.2f}%")
        # Save the model checkpoint
        if (epoch + 1) % 5 == 0:
            # Save the model checkpoint
            torch.save(model.state_dict(), f"{weights_save_path}/mamba_epoch_{epoch+1}.pth")
            print(f"Model saved at epoch {epoch+1}")


model = MambaSequenceClassifier(input_dim=102, hidden_dim= 68, num_classes=16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

train_mamba(model, train_loader, valid_loader, optimizer, num_epochs, weights_save_path, device="cuda", num_classes=16)

120
18
train 2
valid 2
train
Concatenated features: torch.Size([8, 300, 102])
Labels: torch.Size([8, 300])
valid
Concatenated features: torch.Size([8, 300, 102])
Labels: torch.Size([8, 300])


Epoch 1/100: 100%|██████████| 429/429 [02:43<00:00,  2.62it/s]


Epoch 1 -  Training Loss: 556503.6836 - Accuracy: 7.73%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.63it/s]


Best model saved with accuracy: 19.58%
Epoch 1 - Validation Loss: 197672.4660 - Accuracy: 19.58%


Epoch 2/100: 100%|██████████| 429/429 [02:45<00:00,  2.59it/s]


Epoch 2 -  Training Loss: 158333.3511 - Accuracy: 16.84%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.62it/s]


Best model saved with accuracy: 27.01%
Epoch 2 - Validation Loss: 92096.5593 - Accuracy: 27.01%


Epoch 3/100: 100%|██████████| 429/429 [02:44<00:00,  2.61it/s]


Epoch 3 -  Training Loss: 80720.7606 - Accuracy: 23.13%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.59it/s]


Best model saved with accuracy: 31.27%
Epoch 3 - Validation Loss: 53399.5597 - Accuracy: 31.27%


Epoch 4/100: 100%|██████████| 429/429 [02:43<00:00,  2.62it/s]


Epoch 4 -  Training Loss: 48893.6573 - Accuracy: 27.27%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Best model saved with accuracy: 34.01%
Epoch 4 - Validation Loss: 33713.1264 - Accuracy: 34.01%


Epoch 5/100: 100%|██████████| 429/429 [02:42<00:00,  2.63it/s]


Epoch 5 -  Training Loss: 32223.4761 - Accuracy: 29.42%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Best model saved with accuracy: 35.35%
Epoch 5 - Validation Loss: 22970.0344 - Accuracy: 35.35%
Model saved at epoch 5


Epoch 6/100: 100%|██████████| 429/429 [02:43<00:00,  2.63it/s]


Epoch 6 -  Training Loss: 22441.0810 - Accuracy: 30.71%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.59it/s]


Epoch 6 - Validation Loss: 16311.8316 - Accuracy: 35.25%


Epoch 7/100: 100%|██████████| 429/429 [02:42<00:00,  2.64it/s]


Epoch 7 -  Training Loss: 16055.5009 - Accuracy: 31.28%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.80it/s]


Best model saved with accuracy: 35.51%
Epoch 7 - Validation Loss: 11823.9807 - Accuracy: 35.51%


Epoch 8/100: 100%|██████████| 429/429 [02:42<00:00,  2.65it/s]


Epoch 8 -  Training Loss: 11660.4478 - Accuracy: 31.71%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.68it/s]


Epoch 8 - Validation Loss: 8855.6256 - Accuracy: 35.01%


Epoch 9/100: 100%|██████████| 429/429 [02:42<00:00,  2.65it/s]


Epoch 9 -  Training Loss: 8645.4002 - Accuracy: 31.84%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.64it/s]


Epoch 9 - Validation Loss: 6694.6798 - Accuracy: 34.71%


Epoch 10/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 10 -  Training Loss: 6450.4355 - Accuracy: 31.59%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.66it/s]


Epoch 10 - Validation Loss: 5107.1976 - Accuracy: 34.55%
Model saved at epoch 10


Epoch 11/100: 100%|██████████| 429/429 [02:48<00:00,  2.55it/s]


Epoch 11 -  Training Loss: 4838.3306 - Accuracy: 31.46%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 11 - Validation Loss: 3889.1791 - Accuracy: 34.27%


Epoch 12/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 12 -  Training Loss: 3639.0284 - Accuracy: 31.34%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.65it/s]


Epoch 12 - Validation Loss: 2953.1020 - Accuracy: 34.02%


Epoch 13/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 13 -  Training Loss: 2733.6442 - Accuracy: 31.69%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.48it/s]


Epoch 13 - Validation Loss: 2270.6568 - Accuracy: 34.70%


Epoch 14/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 14 -  Training Loss: 2065.0401 - Accuracy: 31.95%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Best model saved with accuracy: 36.58%
Epoch 14 - Validation Loss: 1749.3224 - Accuracy: 36.58%


Epoch 15/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 15 -  Training Loss: 1565.9544 - Accuracy: 32.80%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Best model saved with accuracy: 40.33%
Epoch 15 - Validation Loss: 1344.8858 - Accuracy: 40.33%
Model saved at epoch 15


Epoch 16/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 16 -  Training Loss: 1207.5348 - Accuracy: 34.10%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Best model saved with accuracy: 44.55%
Epoch 16 - Validation Loss: 1046.1406 - Accuracy: 44.55%


Epoch 17/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 17 -  Training Loss: 925.5480 - Accuracy: 35.30%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Best model saved with accuracy: 46.81%
Epoch 17 - Validation Loss: 823.9498 - Accuracy: 46.81%


Epoch 18/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 18 -  Training Loss: 726.6182 - Accuracy: 36.63%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Best model saved with accuracy: 49.60%
Epoch 18 - Validation Loss: 658.2621 - Accuracy: 49.60%


Epoch 19/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 19 -  Training Loss: 569.4438 - Accuracy: 38.08%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Best model saved with accuracy: 51.43%
Epoch 19 - Validation Loss: 523.3839 - Accuracy: 51.43%


Epoch 20/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 20 -  Training Loss: 445.7380 - Accuracy: 39.59%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.74it/s]


Best model saved with accuracy: 52.88%
Epoch 20 - Validation Loss: 421.2691 - Accuracy: 52.88%
Model saved at epoch 20


Epoch 21/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 21 -  Training Loss: 351.3379 - Accuracy: 41.02%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Best model saved with accuracy: 54.21%
Epoch 21 - Validation Loss: 339.7896 - Accuracy: 54.21%


Epoch 22/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 22 -  Training Loss: 276.3401 - Accuracy: 41.89%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.75it/s]


Best model saved with accuracy: 54.91%
Epoch 22 - Validation Loss: 278.0471 - Accuracy: 54.91%


Epoch 23/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 23 -  Training Loss: 221.6904 - Accuracy: 42.86%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.56it/s]


Best model saved with accuracy: 55.60%
Epoch 23 - Validation Loss: 230.7187 - Accuracy: 55.60%


Epoch 24/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 24 -  Training Loss: 177.3774 - Accuracy: 43.18%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.66it/s]


Best model saved with accuracy: 56.34%
Epoch 24 - Validation Loss: 191.0626 - Accuracy: 56.34%


Epoch 25/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 25 -  Training Loss: 141.1742 - Accuracy: 43.33%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.66it/s]


Best model saved with accuracy: 56.88%
Epoch 25 - Validation Loss: 161.2149 - Accuracy: 56.88%
Model saved at epoch 25


Epoch 26/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 26 -  Training Loss: 112.5787 - Accuracy: 43.38%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.74it/s]


Best model saved with accuracy: 58.23%
Epoch 26 - Validation Loss: 135.8949 - Accuracy: 58.23%


Epoch 27/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 27 -  Training Loss: 94.4103 - Accuracy: 43.49%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 27 - Validation Loss: 115.8763 - Accuracy: 58.21%


Epoch 28/100: 100%|██████████| 429/429 [02:42<00:00,  2.64it/s]


Epoch 28 -  Training Loss: 74.7427 - Accuracy: 43.01%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Epoch 28 - Validation Loss: 98.1673 - Accuracy: 58.02%


Epoch 29/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 29 -  Training Loss: 62.9902 - Accuracy: 42.35%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Epoch 29 - Validation Loss: 84.2257 - Accuracy: 58.12%


Epoch 30/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 30 -  Training Loss: 49.0652 - Accuracy: 42.16%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.66it/s]


Epoch 30 - Validation Loss: 73.0804 - Accuracy: 57.58%
Model saved at epoch 30


Epoch 31/100: 100%|██████████| 429/429 [02:41<00:00,  2.65it/s]


Epoch 31 -  Training Loss: 41.5165 - Accuracy: 41.92%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.61it/s]


Epoch 31 - Validation Loss: 64.3538 - Accuracy: 57.09%


Epoch 32/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 32 -  Training Loss: 33.2044 - Accuracy: 42.61%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Epoch 32 - Validation Loss: 56.0605 - Accuracy: 56.47%


Epoch 33/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 33 -  Training Loss: 27.6888 - Accuracy: 44.50%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Epoch 33 - Validation Loss: 50.2124 - Accuracy: 56.07%


Epoch 34/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 34 -  Training Loss: 22.9633 - Accuracy: 46.76%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.68it/s]


Best model saved with accuracy: 59.02%
Epoch 34 - Validation Loss: 44.1731 - Accuracy: 59.02%


Epoch 35/100: 100%|██████████| 429/429 [02:41<00:00,  2.66it/s]


Epoch 35 -  Training Loss: 19.2221 - Accuracy: 49.40%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.66it/s]


Best model saved with accuracy: 61.19%
Epoch 35 - Validation Loss: 40.0314 - Accuracy: 61.19%
Model saved at epoch 35


Epoch 36/100: 100%|██████████| 429/429 [02:52<00:00,  2.49it/s]


Epoch 36 -  Training Loss: 16.3651 - Accuracy: 51.54%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.82it/s]


Best model saved with accuracy: 62.13%
Epoch 36 - Validation Loss: 35.9925 - Accuracy: 62.13%


Epoch 37/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 37 -  Training Loss: 14.0214 - Accuracy: 53.63%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Best model saved with accuracy: 62.85%
Epoch 37 - Validation Loss: 33.2537 - Accuracy: 62.85%


Epoch 38/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 38 -  Training Loss: 11.8705 - Accuracy: 55.76%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Best model saved with accuracy: 63.23%
Epoch 38 - Validation Loss: 29.5935 - Accuracy: 63.23%


Epoch 39/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 39 -  Training Loss: 9.9327 - Accuracy: 57.58%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.86it/s]


Best model saved with accuracy: 63.73%
Epoch 39 - Validation Loss: 26.8527 - Accuracy: 63.73%


Epoch 40/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 40 -  Training Loss: 8.7390 - Accuracy: 59.11%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.74it/s]


Best model saved with accuracy: 64.36%
Epoch 40 - Validation Loss: 24.7523 - Accuracy: 64.36%
Model saved at epoch 40


Epoch 41/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 41 -  Training Loss: 7.4696 - Accuracy: 60.38%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.75it/s]


Best model saved with accuracy: 64.68%
Epoch 41 - Validation Loss: 23.1088 - Accuracy: 64.68%


Epoch 42/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 42 -  Training Loss: 6.6998 - Accuracy: 61.46%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Best model saved with accuracy: 64.85%
Epoch 42 - Validation Loss: 21.9198 - Accuracy: 64.85%


Epoch 43/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 43 -  Training Loss: 6.0591 - Accuracy: 62.36%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Best model saved with accuracy: 65.10%
Epoch 43 - Validation Loss: 21.0459 - Accuracy: 65.10%


Epoch 44/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 44 -  Training Loss: 5.2891 - Accuracy: 63.15%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Best model saved with accuracy: 65.40%
Epoch 44 - Validation Loss: 18.9421 - Accuracy: 65.40%


Epoch 45/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 45 -  Training Loss: 4.5627 - Accuracy: 63.76%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.74it/s]


Best model saved with accuracy: 65.46%
Epoch 45 - Validation Loss: 18.5538 - Accuracy: 65.46%
Model saved at epoch 45


Epoch 46/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 46 -  Training Loss: 4.3031 - Accuracy: 64.02%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Best model saved with accuracy: 65.62%
Epoch 46 - Validation Loss: 17.7276 - Accuracy: 65.62%


Epoch 47/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 47 -  Training Loss: 3.9693 - Accuracy: 64.48%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Best model saved with accuracy: 65.85%
Epoch 47 - Validation Loss: 16.8299 - Accuracy: 65.85%


Epoch 48/100: 100%|██████████| 429/429 [02:41<00:00,  2.65it/s]


Epoch 48 -  Training Loss: 3.6353 - Accuracy: 64.72%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Best model saved with accuracy: 66.26%
Epoch 48 - Validation Loss: 15.4007 - Accuracy: 66.26%


Epoch 49/100: 100%|██████████| 429/429 [02:41<00:00,  2.65it/s]


Epoch 49 -  Training Loss: 3.3099 - Accuracy: 65.02%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.56it/s]


Best model saved with accuracy: 66.66%
Epoch 49 - Validation Loss: 14.5974 - Accuracy: 66.66%


Epoch 50/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 50 -  Training Loss: 3.0558 - Accuracy: 65.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Epoch 50 - Validation Loss: 14.3408 - Accuracy: 66.64%
Model saved at epoch 50


Epoch 51/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 51 -  Training Loss: 2.9588 - Accuracy: 65.32%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Best model saved with accuracy: 66.87%
Epoch 51 - Validation Loss: 13.1136 - Accuracy: 66.87%


Epoch 52/100: 100%|██████████| 429/429 [02:38<00:00,  2.72it/s]


Epoch 52 -  Training Loss: 2.8956 - Accuracy: 65.44%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Best model saved with accuracy: 66.97%
Epoch 52 - Validation Loss: 13.1510 - Accuracy: 66.97%


Epoch 53/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 53 -  Training Loss: 2.6943 - Accuracy: 65.52%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Best model saved with accuracy: 67.20%
Epoch 53 - Validation Loss: 12.5435 - Accuracy: 67.20%


Epoch 54/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 54 -  Training Loss: 2.5629 - Accuracy: 65.61%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Best model saved with accuracy: 67.24%
Epoch 54 - Validation Loss: 12.0117 - Accuracy: 67.24%


Epoch 55/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 55 -  Training Loss: 2.4558 - Accuracy: 65.71%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Best model saved with accuracy: 67.32%
Epoch 55 - Validation Loss: 11.3744 - Accuracy: 67.32%
Model saved at epoch 55


Epoch 56/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 56 -  Training Loss: 2.2489 - Accuracy: 65.83%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Best model saved with accuracy: 67.50%
Epoch 56 - Validation Loss: 10.3918 - Accuracy: 67.50%


Epoch 57/100: 100%|██████████| 429/429 [02:36<00:00,  2.75it/s]


Epoch 57 -  Training Loss: 2.2775 - Accuracy: 66.00%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Best model saved with accuracy: 67.66%
Epoch 57 - Validation Loss: 10.2133 - Accuracy: 67.66%


Epoch 58/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 58 -  Training Loss: 2.1342 - Accuracy: 66.23%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Best model saved with accuracy: 67.73%
Epoch 58 - Validation Loss: 10.2918 - Accuracy: 67.73%


Epoch 59/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 59 -  Training Loss: 2.1214 - Accuracy: 66.29%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Best model saved with accuracy: 67.89%
Epoch 59 - Validation Loss: 9.8805 - Accuracy: 67.89%


Epoch 60/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 60 -  Training Loss: 2.0289 - Accuracy: 66.32%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Best model saved with accuracy: 67.90%
Epoch 60 - Validation Loss: 9.9233 - Accuracy: 67.90%
Model saved at epoch 60


Epoch 61/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 61 -  Training Loss: 1.9670 - Accuracy: 66.32%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.82it/s]


Best model saved with accuracy: 67.94%
Epoch 61 - Validation Loss: 9.3090 - Accuracy: 67.94%


Epoch 62/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 62 -  Training Loss: 1.9402 - Accuracy: 66.34%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.65it/s]


Best model saved with accuracy: 67.95%
Epoch 62 - Validation Loss: 8.6018 - Accuracy: 67.95%


Epoch 63/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 63 -  Training Loss: 1.8833 - Accuracy: 66.37%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Best model saved with accuracy: 67.95%
Epoch 63 - Validation Loss: 9.0889 - Accuracy: 67.95%


Epoch 64/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 64 -  Training Loss: 1.8237 - Accuracy: 66.40%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Best model saved with accuracy: 67.98%
Epoch 64 - Validation Loss: 8.6892 - Accuracy: 67.98%


Epoch 65/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 65 -  Training Loss: 1.8809 - Accuracy: 66.43%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Epoch 65 - Validation Loss: 8.6043 - Accuracy: 67.98%
Model saved at epoch 65


Epoch 66/100: 100%|██████████| 429/429 [02:45<00:00,  2.59it/s]


Epoch 66 -  Training Loss: 1.8208 - Accuracy: 66.45%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.48it/s]


Best model saved with accuracy: 68.00%
Epoch 66 - Validation Loss: 8.2806 - Accuracy: 68.00%


Epoch 67/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 67 -  Training Loss: 1.7446 - Accuracy: 66.47%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Best model saved with accuracy: 68.01%
Epoch 67 - Validation Loss: 7.8826 - Accuracy: 68.01%


Epoch 68/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 68 -  Training Loss: 1.7115 - Accuracy: 66.49%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Best model saved with accuracy: 68.03%
Epoch 68 - Validation Loss: 7.8021 - Accuracy: 68.03%


Epoch 69/100: 100%|██████████| 429/429 [02:39<00:00,  2.68it/s]


Epoch 69 -  Training Loss: 1.7720 - Accuracy: 66.51%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.65it/s]


Best model saved with accuracy: 68.04%
Epoch 69 - Validation Loss: 7.8205 - Accuracy: 68.04%


Epoch 70/100: 100%|██████████| 429/429 [02:48<00:00,  2.55it/s]


Epoch 70 -  Training Loss: 1.7051 - Accuracy: 66.52%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.59it/s]


Epoch 70 - Validation Loss: 7.6453 - Accuracy: 68.00%
Model saved at epoch 70


Epoch 71/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 71 -  Training Loss: 1.6735 - Accuracy: 66.54%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.81it/s]


Epoch 71 - Validation Loss: 7.6523 - Accuracy: 68.02%


Epoch 72/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 72 -  Training Loss: 1.6457 - Accuracy: 66.56%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.74it/s]


Best model saved with accuracy: 68.05%
Epoch 72 - Validation Loss: 7.1547 - Accuracy: 68.05%


Epoch 73/100: 100%|██████████| 429/429 [02:44<00:00,  2.60it/s]


Epoch 73 -  Training Loss: 1.6688 - Accuracy: 66.56%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Epoch 73 - Validation Loss: 7.1729 - Accuracy: 68.03%


Epoch 74/100: 100%|██████████| 429/429 [02:43<00:00,  2.62it/s]


Epoch 74 -  Training Loss: 1.6468 - Accuracy: 66.58%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.82it/s]


Epoch 74 - Validation Loss: 7.0800 - Accuracy: 68.04%


Epoch 75/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 75 -  Training Loss: 1.6736 - Accuracy: 66.58%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.60it/s]


Epoch 75 - Validation Loss: 6.6154 - Accuracy: 68.04%
Model saved at epoch 75


Epoch 76/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 76 -  Training Loss: 1.6353 - Accuracy: 66.60%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.75it/s]


Epoch 76 - Validation Loss: 6.3801 - Accuracy: 68.02%


Epoch 77/100: 100%|██████████| 429/429 [02:42<00:00,  2.64it/s]


Epoch 77 -  Training Loss: 1.6450 - Accuracy: 66.61%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.65it/s]


Epoch 77 - Validation Loss: 6.2189 - Accuracy: 68.01%


Epoch 78/100: 100%|██████████| 429/429 [02:43<00:00,  2.63it/s]


Epoch 78 -  Training Loss: 1.6658 - Accuracy: 66.62%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Best model saved with accuracy: 68.05%
Epoch 78 - Validation Loss: 6.2260 - Accuracy: 68.05%


Epoch 79/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 79 -  Training Loss: 1.5858 - Accuracy: 66.64%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 79 - Validation Loss: 6.1619 - Accuracy: 68.03%


Epoch 80/100: 100%|██████████| 429/429 [02:42<00:00,  2.64it/s]


Epoch 80 -  Training Loss: 1.5689 - Accuracy: 66.65%


Validation: 100%|██████████| 17/17 [00:10<00:00,  1.56it/s]


Epoch 80 - Validation Loss: 6.1149 - Accuracy: 68.01%
Model saved at epoch 80


Epoch 81/100: 100%|██████████| 429/429 [02:43<00:00,  2.62it/s]


Epoch 81 -  Training Loss: 1.5488 - Accuracy: 66.66%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.70it/s]


Best model saved with accuracy: 68.06%
Epoch 81 - Validation Loss: 6.0360 - Accuracy: 68.06%


Epoch 82/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 82 -  Training Loss: 1.5936 - Accuracy: 66.68%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Best model saved with accuracy: 68.07%
Epoch 82 - Validation Loss: 5.8676 - Accuracy: 68.07%


Epoch 83/100: 100%|██████████| 429/429 [02:39<00:00,  2.68it/s]


Epoch 83 -  Training Loss: 1.6157 - Accuracy: 66.69%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Best model saved with accuracy: 68.09%
Epoch 83 - Validation Loss: 5.5343 - Accuracy: 68.09%


Epoch 84/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 84 -  Training Loss: 1.5666 - Accuracy: 66.69%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Best model saved with accuracy: 68.12%
Epoch 84 - Validation Loss: 5.5383 - Accuracy: 68.12%


Epoch 85/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 85 -  Training Loss: 1.5378 - Accuracy: 66.70%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.85it/s]


Epoch 85 - Validation Loss: 5.6371 - Accuracy: 68.10%
Model saved at epoch 85


Epoch 86/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 86 -  Training Loss: 1.5343 - Accuracy: 66.72%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Epoch 86 - Validation Loss: 5.5882 - Accuracy: 68.10%


Epoch 87/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 87 -  Training Loss: 1.5182 - Accuracy: 66.72%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.65it/s]


Epoch 87 - Validation Loss: 5.3152 - Accuracy: 68.11%


Epoch 88/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 88 -  Training Loss: 1.5076 - Accuracy: 66.73%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.68it/s]


Epoch 88 - Validation Loss: 5.5529 - Accuracy: 68.11%


Epoch 89/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 89 -  Training Loss: 1.5487 - Accuracy: 66.74%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Best model saved with accuracy: 68.13%
Epoch 89 - Validation Loss: 5.3394 - Accuracy: 68.13%


Epoch 90/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 90 -  Training Loss: 1.5142 - Accuracy: 66.74%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.70it/s]


Best model saved with accuracy: 68.15%
Epoch 90 - Validation Loss: 5.3546 - Accuracy: 68.15%
Model saved at epoch 90


Epoch 91/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 91 -  Training Loss: 1.5223 - Accuracy: 66.75%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.70it/s]


Epoch 91 - Validation Loss: 5.4022 - Accuracy: 68.13%


Epoch 92/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 92 -  Training Loss: 1.5105 - Accuracy: 66.76%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.79it/s]


Best model saved with accuracy: 68.15%
Epoch 92 - Validation Loss: 5.2328 - Accuracy: 68.15%


Epoch 93/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 93 -  Training Loss: 1.4888 - Accuracy: 66.76%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.66it/s]


Epoch 93 - Validation Loss: 5.6857 - Accuracy: 68.15%


Epoch 94/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 94 -  Training Loss: 1.4761 - Accuracy: 66.78%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.75it/s]


Epoch 94 - Validation Loss: 5.2512 - Accuracy: 68.14%


Epoch 95/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 95 -  Training Loss: 1.4900 - Accuracy: 66.78%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Epoch 95 - Validation Loss: 4.9887 - Accuracy: 68.14%
Model saved at epoch 95


Epoch 96/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 96 -  Training Loss: 1.4932 - Accuracy: 66.78%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Epoch 96 - Validation Loss: 5.2542 - Accuracy: 68.15%


Epoch 97/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 97 -  Training Loss: 1.5194 - Accuracy: 66.79%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.84it/s]


Epoch 97 - Validation Loss: 5.4447 - Accuracy: 68.15%


Epoch 98/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 98 -  Training Loss: 1.4802 - Accuracy: 66.79%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.79it/s]


Epoch 98 - Validation Loss: 5.0617 - Accuracy: 68.14%


Epoch 99/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 99 -  Training Loss: 1.4751 - Accuracy: 66.80%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.99it/s]


Best model saved with accuracy: 68.18%
Epoch 99 - Validation Loss: 5.4762 - Accuracy: 68.18%


Epoch 100/100: 100%|██████████| 429/429 [02:08<00:00,  3.33it/s]


Epoch 100 -  Training Loss: 1.4821 - Accuracy: 66.80%


Validation: 100%|██████████| 17/17 [00:04<00:00,  3.81it/s]

Epoch 100 - Validation Loss: 5.6688 - Accuracy: 68.16%
Model saved at epoch 100





In [None]:
import os
import cv2
import torch
import numpy as np
from torchvision import transforms
from torch import nn
from torchvision.datasets import ImageFolder
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
from tqdm import tqdm
import cv2
import re
from mamba_ssm import Mamba
import torch.nn.functional as F


# skip_ids = [16700, 38159, 59359]
train_features_root = '/media/viplab/DATADRIVE1/driver_action_recognition/pose_resnet_features_multi/A1/train'
train_labels_root = '/home/viplab/Documents/driver_action_recognition/data_processing/array_generation/arrays'

validation_features_root = '/media/viplab/DATADRIVE1/driver_action_recognition/pose_resnet_features_multi/A1/valid'
validation_labels_root = '/home/viplab/Documents/driver_action_recognition/data_processing/array_generation/arrays'

num_epochs = 100
weights_save_path = "mamba_weights_pose"


class MultiViewFeatureDataset(Dataset):
    def __init__(self, features_root, labels_root, views=("Dashboard", "Rear_view", "Right_side_window")):
        self.features_root = features_root
        self.labels_root = labels_root
        self.views = views
        # self.skip_ids = ['16700', '38159', '59359']
        self.sample_keys = []
        # print(features_root)
        for user_folder in os.listdir(features_root):
            # for i in self.skip_ids:
            #     if i in user_folder:
            #         continue
            # print('uuuu', user_folder)
            user_path = os.path.join(features_root, user_folder)
            # print(user_path)
            if not os.path.isdir(user_path):
                continue
            for file in os.listdir(user_path):
                # print('f', file)
                if file.startswith("Dash") and file.endswith(".npy"):
                    # print('1', os.path.splitext(file)[0])
                    key = os.path.join(user_folder, os.path.splitext(file)[0]) 
                    # print('k', key) # e.g. user_001_1/dash_1
                    self.sample_keys.append(key)

    def __len__(self):
        return len(self.sample_keys)

    def __getitem__(self, idx):
        key = self.sample_keys[idx]

        view_features = []
        for view in self.views:
            pp = view + '_' + "_".join(key.split('_')[:3]) + "_NoAudio_" + f"{key.split('_')[-1]}"
            path = os.path.join(self.features_root, os.path.dirname(key), f"{pp}.npy")
            features = np.load(path).astype(np.float32)  # [seq_len, feat_dim]
            view_features.append(features)
            # print('111111', features.shape, path)

        min_rows = min(view_features[0].shape[0], view_features[1].shape[0], view_features[2].shape[0])
        # Concatenate features across feature dimension
        features_cat = view_features[2][:min_rows, -34: -12] # [seq_len, total_feat_dim]
        label_path = os.path.join(self.labels_root, os.path.dirname(key)+ ".npy")
        labels = np.load(label_path).astype(np.int64)  # [seq_len]
        # print(features_cat.shape)
        return features_cat, labels

t_dataset = MultiViewFeatureDataset(
    features_root= train_features_root,
    labels_root= train_labels_root
)

v_dataset = MultiViewFeatureDataset(
    features_root= validation_features_root,
    labels_root= validation_labels_root
)
print(len(t_dataset))
print(len(v_dataset))

class ChunkedVideoDataset(Dataset):
    def __init__(self, base_dataset, chunk_size=100, stride=50):
        self.base_dataset = base_dataset
        self.chunk_size = chunk_size
        self.stride = stride
        self.index_map = []  # (video_idx, start_frame)

        for video_idx in range(len(base_dataset)):
            features, labels = base_dataset[video_idx]
            video_len = features.shape[0]

            for start in range(0, video_len - chunk_size + 1, stride):
                labels_chunk = labels[start:start + chunk_size]

                if labels_chunk.sum() == 0:
                    continue

                self.index_map.append((video_idx, start))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        video_idx, start = self.index_map[idx]
        features, labels = self.base_dataset[video_idx]
        features_chunk = features[start:start+self.chunk_size]
        labels_chunk = labels[start:start+self.chunk_size]
        return features_chunk, labels_chunk

train_dataset = ChunkedVideoDataset(t_dataset, chunk_size=150, stride=75)
valid_dataset = ChunkedVideoDataset(v_dataset, chunk_size=150, stride=150)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=8,
    shuffle=False
)

for i in train_dataset:
    print('train', len(i))
    break

for i in train_dataset:
    print('valid', len(i))
    break

for features, labels in train_loader:
    print('train')
    print("Concatenated features:", features.shape)  # [1, seq_len, total_feat_dim]
    print("Labels:", labels.shape)
    break 
for features, labels in valid_loader:
    print('valid')
    print("Concatenated features:", features.shape)  # [1, seq_len, total_feat_dim]
    print("Labels:", labels.shape)
    break 

class MambaSequenceClassifier(nn.Module):
    def __init__(self, input_dim=6246, hidden_dim=2048, num_classes=16, seq_len=100):
        super().__init__()
        # self.input_proj = nn.Linear(input_dim, hidden_dim)

        self.mamba_block = Mamba(
            d_model=input_dim,
            d_state=16,
            d_conv=4,
            expand=2
        )
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, num_classes)

        self.dropout = nn.Dropout(p=0.2)
        self.relu = nn.ReLU()
    def forward(self, x):
        """
        x: [batch_size, seq_len, input_dim]
        returns: [batch_size, seq_len, num_classes]
        """ 
        # x = self.input_proj(x)  # -> [B, L, H]
        x = self.mamba_block(x)  # [B, L, H]
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        logits = self.output_layer(x)  # [B, L, C]
        return logits
    
# class BoundaryAwareLoss(nn.Module):
#     def __init__(self, classification_weight=1.0, localization_weight=1.0, ignore_index=None):
#         super().__init__()
#         self.classification_weight = classification_weight
#         self.localization_weight = localization_weight
#         self.ignore_index = ignore_index

#     def forward(self, frame_logits, labels, pred_boundaries, true_boundaries):
#         B, T, C = frame_logits.shape

#         # Frame-wise classification loss
#         loss_cls = F.cross_entropy(
#             frame_logits.view(-1, C),
#             labels.view(-1),
#             ignore_index=self.ignore_index if self.ignore_index is not None else -100
#         )

#         # Temporal boundary regression loss (L1)
#         loss_loc = F.l1_loss(pred_boundaries, true_boundaries)

#         return self.classification_weight * loss_cls + self.localization_weight * loss_loc

def train_mamba(
    model, train_loader, val_loader, optimizer, num_epochs, weights_save_path, device,
    num_classes=16, ignore_index=None
):
    model.to(device)
    best_val_accuracy = 0.0

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        total_correct = 0
        total_count = 0

        for features, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            features = features.to(device)  # [B, 50, 6144]
            labels = labels.to(device)      # [B, 50]

            optimizer.zero_grad()
            logits = model(features)        # [B, 50, num_classes]

            loss = F.cross_entropy(
                logits.view(-1, num_classes),
                labels.view(-1),
                # ignore_index=ignore_index if ignore_index is not None else -100,
            )
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # --- Accuracy ---
            preds = logits.argmax(dim=-1)  # [B, 50]
            # if ignore_index is not None:
            #     mask = labels != ignore_index
            #     total_correct += (preds[mask] == labels[mask]).sum().item()
            #     total_count += mask.sum().item()
            # else:
            total_correct += (preds == labels).sum().item()
            total_count += labels.numel()

        avg_train_loss = total_loss / len(train_loader)
        train_accuracy = total_correct / total_count if total_count > 0 else 0.0
        print(f"Epoch {epoch+1} -  Training Loss: {avg_train_loss:.4f} - Accuracy: {train_accuracy*100:.2f}%")
        # --- Validation ---
        model.eval()
        val_loss = 0
        val_correct = 0
        val_count = 0
        with torch.no_grad():
            for features, labels in tqdm(val_loader, desc="Validation"):
                features = features.to(device)
                labels = labels.to(device)

                logits = model(features)
                loss = F.cross_entropy(
                    logits.view(-1, num_classes),
                    labels.view(-1),
                    # ignore_index=ignore_index if ignore_index is not None else -100,
                )
                val_loss += loss.item()

                preds = logits.argmax(dim=-1)
                val_correct += (preds == labels).sum().item()
                val_count += labels.numel()

        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = val_correct / val_count if val_count > 0 else 0.0
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # Save the model checkpoint
            torch.save(model.state_dict(), f"{weights_save_path}/mamba_best.pth")
            print(f"Best model saved with accuracy: {val_accuracy*100:.2f}%")

        print(f"Epoch {epoch+1} - Validation Loss: {avg_val_loss:.4f} - Accuracy: {val_accuracy*100:.2f}%")
        # Save the model checkpoint
        if (epoch + 1) % 5 == 0:
            # Save the model checkpoint
            torch.save(model.state_dict(), f"{weights_save_path}/mamba_epoch_{epoch+1}.pth")
            print(f"Model saved at epoch {epoch+1}")


model = MambaSequenceClassifier(input_dim=22, hidden_dim=14, num_classes=16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

train_mamba(model, train_loader, valid_loader, optimizer, num_epochs, weights_save_path, device="cuda", num_classes=16)

120
18
train 2
valid 2
train
Concatenated features: torch.Size([8, 150, 22])
Labels: torch.Size([8, 150])
valid
Concatenated features: torch.Size([8, 150, 22])
Labels: torch.Size([8, 150])


Epoch 1/100: 100%|██████████| 432/432 [01:38<00:00,  4.37it/s]


Epoch 1 -  Training Loss: 474348.3134 - Accuracy: 12.18%


Validation: 100%|██████████| 33/33 [00:06<00:00,  4.90it/s]


Best model saved with accuracy: 26.61%
Epoch 1 - Validation Loss: 210369.6477 - Accuracy: 26.61%


Epoch 2/100: 100%|██████████| 432/432 [01:31<00:00,  4.73it/s]


Epoch 2 -  Training Loss: 187716.5730 - Accuracy: 20.67%


Validation: 100%|██████████| 33/33 [00:06<00:00,  5.02it/s]


Best model saved with accuracy: 28.24%
Epoch 2 - Validation Loss: 118510.5402 - Accuracy: 28.24%


Epoch 3/100: 100%|██████████| 432/432 [01:29<00:00,  4.80it/s]


Epoch 3 -  Training Loss: 118326.7703 - Accuracy: 21.66%


Validation: 100%|██████████| 33/33 [00:06<00:00,  4.89it/s]


Best model saved with accuracy: 28.54%
Epoch 3 - Validation Loss: 77649.9833 - Accuracy: 28.54%


Epoch 4/100: 100%|██████████| 432/432 [01:38<00:00,  4.38it/s]


Epoch 4 -  Training Loss: 80673.8606 - Accuracy: 22.29%


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.92it/s]


Best model saved with accuracy: 28.80%
Epoch 4 - Validation Loss: 52688.5330 - Accuracy: 28.80%


Epoch 5/100: 100%|██████████| 432/432 [04:36<00:00,  1.56it/s]


Epoch 5 -  Training Loss: 57056.0536 - Accuracy: 21.97%


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.81it/s]


Epoch 5 - Validation Loss: 37523.3264 - Accuracy: 27.12%
Model saved at epoch 5


Epoch 6/100: 100%|██████████| 432/432 [03:10<00:00,  2.27it/s]


Epoch 6 -  Training Loss: 42085.6241 - Accuracy: 21.26%


Validation: 100%|██████████| 33/33 [00:07<00:00,  4.23it/s]


Epoch 6 - Validation Loss: 27605.7857 - Accuracy: 25.57%


Epoch 7/100: 100%|██████████| 432/432 [03:52<00:00,  1.85it/s]


Epoch 7 -  Training Loss: 32118.4018 - Accuracy: 19.58%


Validation: 100%|██████████| 33/33 [00:25<00:00,  1.29it/s]


Epoch 7 - Validation Loss: 20467.8663 - Accuracy: 24.03%


Epoch 8/100: 100%|██████████| 432/432 [03:55<00:00,  1.84it/s]


Epoch 8 -  Training Loss: 25064.3649 - Accuracy: 17.80%


Validation: 100%|██████████| 33/33 [00:21<00:00,  1.51it/s]


Epoch 8 - Validation Loss: 15879.8592 - Accuracy: 21.24%


Epoch 9/100: 100%|██████████| 432/432 [12:06<00:00,  1.68s/it]


Epoch 9 -  Training Loss: 19950.9981 - Accuracy: 15.72%


Validation: 100%|██████████| 33/33 [00:20<00:00,  1.62it/s]


Epoch 9 - Validation Loss: 12316.4119 - Accuracy: 18.57%


Epoch 10/100: 100%|██████████| 432/432 [01:53<00:00,  3.82it/s]


Epoch 10 -  Training Loss: 16198.0521 - Accuracy: 15.16%


Validation: 100%|██████████| 33/33 [00:07<00:00,  4.33it/s]


Epoch 10 - Validation Loss: 9833.9239 - Accuracy: 16.24%
Model saved at epoch 10


Epoch 11/100: 100%|██████████| 432/432 [02:04<00:00,  3.47it/s]


Epoch 11 -  Training Loss: 13346.7548 - Accuracy: 15.75%


Validation: 100%|██████████| 33/33 [00:07<00:00,  4.21it/s]


Epoch 11 - Validation Loss: 8058.5620 - Accuracy: 17.20%


Epoch 12/100: 100%|██████████| 432/432 [02:23<00:00,  3.01it/s]


Epoch 12 -  Training Loss: 11122.2370 - Accuracy: 17.26%


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.99it/s]


Epoch 12 - Validation Loss: 6747.3726 - Accuracy: 21.76%


Epoch 13/100: 100%|██████████| 432/432 [02:15<00:00,  3.19it/s]


Epoch 13 -  Training Loss: 9344.0997 - Accuracy: 19.89%


Validation: 100%|██████████| 33/33 [00:07<00:00,  4.15it/s]


Epoch 13 - Validation Loss: 5662.1745 - Accuracy: 24.64%


Epoch 14/100: 100%|██████████| 432/432 [01:57<00:00,  3.69it/s]


Epoch 14 -  Training Loss: 7889.4506 - Accuracy: 21.95%


Validation: 100%|██████████| 33/33 [00:07<00:00,  4.27it/s]


Epoch 14 - Validation Loss: 4693.4367 - Accuracy: 24.97%


Epoch 15/100: 100%|██████████| 432/432 [01:49<00:00,  3.93it/s]


Epoch 15 -  Training Loss: 6672.8013 - Accuracy: 23.28%


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.68it/s]


Epoch 15 - Validation Loss: 3972.8136 - Accuracy: 23.99%
Model saved at epoch 15


Epoch 16/100: 100%|██████████| 432/432 [02:06<00:00,  3.40it/s]


Epoch 16 -  Training Loss: 5704.6391 - Accuracy: 22.71%


Validation: 100%|██████████| 33/33 [00:08<00:00,  3.84it/s]


Epoch 16 - Validation Loss: 3314.8255 - Accuracy: 21.88%


Epoch 17/100: 100%|██████████| 432/432 [01:52<00:00,  3.86it/s]


Epoch 17 -  Training Loss: 4909.4611 - Accuracy: 21.99%


Validation: 100%|██████████| 33/33 [00:08<00:00,  4.11it/s]


Epoch 17 - Validation Loss: 2812.9289 - Accuracy: 18.77%


Epoch 18/100: 100%|██████████| 432/432 [01:58<00:00,  3.63it/s]


Epoch 18 -  Training Loss: 4100.4174 - Accuracy: 19.76%


Validation: 100%|██████████| 33/33 [00:08<00:00,  4.05it/s]


Epoch 18 - Validation Loss: 2350.4913 - Accuracy: 15.40%


Epoch 19/100: 100%|██████████| 432/432 [01:51<00:00,  3.87it/s]


Epoch 19 -  Training Loss: 3453.3409 - Accuracy: 16.81%


Validation: 100%|██████████| 33/33 [00:10<00:00,  3.01it/s]


Epoch 19 - Validation Loss: 1968.4330 - Accuracy: 13.33%


Epoch 20/100: 100%|██████████| 432/432 [04:06<00:00,  1.75it/s]


Epoch 20 -  Training Loss: 2922.2331 - Accuracy: 14.49%


Validation: 100%|██████████| 33/33 [00:10<00:00,  3.11it/s]


Epoch 20 - Validation Loss: 1634.0681 - Accuracy: 11.82%
Model saved at epoch 20


Epoch 21/100: 100%|██████████| 432/432 [14:03<00:00,  1.95s/it]


Epoch 21 -  Training Loss: 2420.6201 - Accuracy: 12.25%


Validation: 100%|██████████| 33/33 [00:09<00:00,  3.63it/s]


Epoch 21 - Validation Loss: 1355.5737 - Accuracy: 11.08%


Epoch 22/100: 100%|██████████| 432/432 [08:32<00:00,  1.19s/it]


Epoch 22 -  Training Loss: 2074.5120 - Accuracy: 10.90%


Validation: 100%|██████████| 33/33 [00:21<00:00,  1.52it/s]


Epoch 22 - Validation Loss: 1113.6183 - Accuracy: 10.19%


Epoch 23/100: 100%|██████████| 432/432 [12:50<00:00,  1.78s/it]


Epoch 23 -  Training Loss: 1768.5274 - Accuracy: 10.08%


Validation: 100%|██████████| 33/33 [00:10<00:00,  3.23it/s]


Epoch 23 - Validation Loss: 927.2852 - Accuracy: 9.47%


Epoch 24/100: 100%|██████████| 432/432 [10:53<00:00,  1.51s/it]


Epoch 24 -  Training Loss: 1458.0638 - Accuracy: 9.67%


Validation: 100%|██████████| 33/33 [00:44<00:00,  1.35s/it]


Epoch 24 - Validation Loss: 752.4512 - Accuracy: 9.54%


Epoch 25/100: 100%|██████████| 432/432 [26:41<00:00,  3.71s/it] 


Epoch 25 -  Training Loss: 1254.0751 - Accuracy: 9.52%


Validation: 100%|██████████| 33/33 [00:10<00:00,  3.14it/s]


Epoch 25 - Validation Loss: 631.6634 - Accuracy: 9.39%
Model saved at epoch 25


Epoch 26/100: 100%|██████████| 432/432 [03:49<00:00,  1.88it/s]


Epoch 26 -  Training Loss: 1062.2324 - Accuracy: 9.33%


Validation: 100%|██████████| 33/33 [00:18<00:00,  1.75it/s]


Epoch 26 - Validation Loss: 527.9455 - Accuracy: 9.46%


Epoch 27/100: 100%|██████████| 432/432 [17:51<00:00,  2.48s/it]


Epoch 27 -  Training Loss: 901.6127 - Accuracy: 9.34%


Validation: 100%|██████████| 33/33 [00:19<00:00,  1.72it/s]


Epoch 27 - Validation Loss: 438.3549 - Accuracy: 9.67%


Epoch 28/100:  63%|██████▎   | 273/432 [13:57<10:53,  4.11s/it]

In [2]:
import torch
import torch.nn as nn
from mamba_ssm import Mamba
# input_dim = 9408
# model = Mamba(
#             d_model=input_dim,
#             d_state=16,
#             d_conv=4,
#             expand=2
#         )

In [15]:
import numpy as np
data = np.load('/media/viplab/DATADRIVE1/driver_action_recognition/split_feat_vit/train/Side/6/Right_side_window_user_id_13522_NoAudio_7_2_88.npy')

In [4]:
data.shape

(15, 196, 768)

In [16]:
data = data.reshape(15, -1)

In [17]:
import torch
data = torch.from_numpy(data)

In [18]:
data

tensor([[ 0.4556,  0.2571,  0.4065,  ..., -0.1371, -0.0616, -0.0349],
        [ 0.4541,  0.2583,  0.4045,  ..., -0.1359, -0.0609, -0.0352],
        [ 0.4553,  0.2600,  0.4060,  ..., -0.1376, -0.0624, -0.0359],
        ...,
        [ 0.4519,  0.2571,  0.4019,  ..., -0.1361, -0.0605, -0.0349],
        [ 0.4519,  0.2551,  0.4019,  ..., -0.1361, -0.0608, -0.0352],
        [ 0.4524,  0.2534,  0.4014,  ..., -0.1346, -0.0588, -0.0330]],
       dtype=torch.float16)

In [19]:
device = "cuda"
data = data.to(device)

In [12]:
data = data.type(torch.cuda.HalfTensor)

In [9]:
data

tensor([[ 0.4556,  0.2571,  0.4065,  ..., -0.1371, -0.0616, -0.0349],
        [ 0.4541,  0.2583,  0.4045,  ..., -0.1359, -0.0609, -0.0352],
        [ 0.4553,  0.2600,  0.4060,  ..., -0.1376, -0.0624, -0.0359],
        ...,
        [ 0.4519,  0.2571,  0.4019,  ..., -0.1361, -0.0605, -0.0349],
        [ 0.4519,  0.2551,  0.4019,  ..., -0.1361, -0.0608, -0.0352],
        [ 0.4524,  0.2534,  0.4014,  ..., -0.1346, -0.0588, -0.0330]],
       device='cuda:0', dtype=torch.float16)

In [9]:
from torch import nn

In [16]:
conv

Conv1d(15, 15, kernel_size=(16,), stride=(16,))

In [20]:
data = data.unsqueeze(0)

In [12]:
conv = nn.Conv1d(15, 15, kernel_size=16, stride=16).to(device).half()

outputs = conv(data.unsqueeze(0))
# data = conv(data)

In [13]:
outputs.shape

torch.Size([1, 15, 9408])

In [36]:
outputs = outputs.unsqueeze(0)

In [37]:
outputs.shape

torch.Size([1, 15, 9408])

In [38]:
y = model(outputs)

In [26]:
device = "cuda"
model.to(device).half()

Mamba(
  (in_proj): Linear(in_features=9408, out_features=37632, bias=False)
  (conv1d): Conv1d(18816, 18816, kernel_size=(4,), stride=(1,), padding=(3,), groups=18816)
  (act): SiLU()
  (x_proj): Linear(in_features=18816, out_features=620, bias=False)
  (dt_proj): Linear(in_features=588, out_features=18816, bias=True)
  (out_proj): Linear(in_features=18816, out_features=9408, bias=False)
)

In [40]:
y.shape

torch.Size([1, 15, 9408])

In [21]:
class MambaSequenceClassifier(nn.Module):
    def __init__(self, input_dim=9408, hidden_dim=2048, num_classes=16):#  input_dim=451584, hidden_dim=301056, num_classes=16
        super().__init__()
        self.conv = nn.Conv1d(15, 15, kernel_size=16, stride=16)
        self.mamba_block = Mamba(
            d_model=input_dim,
            d_state=16,
            d_conv=4,
            expand=2
        )
        self.fc1 = nn.Linear(input_dim,num_classes) #hidden_dim)
        self.relu = nn.ReLU()
    def forward(self, x):
        """
        x: [batch_size, seq_len, input_dim]
        returns: [batch_size, seq_len, num_classes]
        """ 
        x = self.conv(x)
        x = self.mamba_block(x) 
        logits = self.relu(self.fc1(x))  
        return logits

In [None]:
model = MambaSequenceClassifier().to(device).half()

In [23]:
y = model(data)

In [25]:
y.shape

torch.Size([1, 15, 16])