In [None]:
import os
import cv2
import torch
import numpy as np
from torchvision import transforms
from torch import nn
from torchvision.datasets import ImageFolder
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
from tqdm import tqdm
import cv2
import re
from mamba_ssm import Mamba
import torch.nn.functional as F


# skip_ids = [16700, 38159, 59359]

# Per-split feature directories (one sub-folder per recording inside each).
train_features_root, validation_features_root = (
    '/media/viplab/DATADRIVE1/driver_action_recognition/pose_resnet_features_multi/A1/train',
    '/media/viplab/DATADRIVE1/driver_action_recognition/pose_resnet_features_multi/A1/valid',
)

# Both splits read their label arrays from the same directory.
train_labels_root = validation_labels_root = (
    '/home/viplab/Documents/driver_action_recognition/data_processing/array_generation/arrays'
)

# Training schedule and the directory checkpoints are written to.
num_epochs, weights_save_path = 100, "mamba_weights_pose_1e-4"


class MultiViewFeatureDataset(Dataset):
    """Dataset of whole recordings with features from several camera views.

    ``features_root`` is scanned one level deep: every ``Dash*.npy`` file found
    in a sub-folder becomes one sample.  For each sample the matching file of
    every view in ``views`` is loaded, the arrays are cut to the shortest
    view's row count, the last ``TAIL_DIMS`` columns of each are kept, and the
    results are concatenated along the feature axis.

    Args:
        features_root: Directory containing one sub-folder per recording.
        labels_root: Directory containing ``<sub-folder name>.npy`` label files.
        views: File-name prefixes of the views to load.  Any non-empty
            sequence is supported; views are concatenated in the given order.

    Raises:
        ValueError: If ``views`` is empty.
    """

    # Number of trailing feature columns kept from every view's array.
    # NOTE(review): presumably the pose part of each feature row -- confirm.
    TAIL_DIMS = 34

    def __init__(self, features_root, labels_root, views=("Dashboard", "Rear_view", "Right_side_window")):
        if len(views) == 0:
            raise ValueError("views must contain at least one view name")

        self.features_root = features_root
        self.labels_root = labels_root
        self.views = views
        # self.skip_ids = ['16700', '38159', '59359']
        self.sample_keys = []

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        for user_folder in sorted(os.listdir(features_root)):
            user_path = os.path.join(features_root, user_folder)
            if not os.path.isdir(user_path):
                continue
            for file in sorted(os.listdir(user_path)):
                if file.startswith("Dash") and file.endswith(".npy"):
                    # Key is "<sub-folder>/<file name without extension>".
                    key = os.path.join(user_folder, os.path.splitext(file)[0])
                    self.sample_keys.append(key)

    def __len__(self):
        """Return the number of discovered samples."""
        return len(self.sample_keys)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            tuple: ``(features, labels)`` where ``features`` is a float32 array
            of shape ``[min_rows, len(views) * TAIL_DIMS]`` (fewer columns if a
            view's array is narrower than ``TAIL_DIMS``) and ``labels`` is the
            int64 label array exactly as stored -- it is NOT trimmed to
            ``min_rows``.
        """
        key = self.sample_keys[idx]
        folder = os.path.dirname(key)

        # File names are rebuilt from the key: its first three '_'-separated
        # tokens (taken from the sub-folder name) and its last token.
        tokens = key.split('_')
        prefix = "_".join(tokens[:3])
        suffix = tokens[-1]

        view_features = []
        for view in self.views:
            file_name = f"{view}_{prefix}_NoAudio_{suffix}.npy"
            path = os.path.join(self.features_root, folder, file_name)
            view_features.append(np.load(path).astype(np.float32))

        # Views may have different row counts; keep only the common prefix so
        # the per-view arrays can be concatenated side by side.
        min_rows = min(feats.shape[0] for feats in view_features)
        features_cat = np.concatenate(
            [feats[:min_rows, -self.TAIL_DIMS:] for feats in view_features],
            axis=1,
        )

        label_path = os.path.join(self.labels_root, folder + ".npy")
        labels = np.load(label_path).astype(np.int64)
        return features_cat, labels

# Build the whole-recording datasets for the training and validation splits.
t_dataset = MultiViewFeatureDataset(train_features_root, train_labels_root)
v_dataset = MultiViewFeatureDataset(validation_features_root, validation_labels_root)

# Report the number of samples in each split: train on the first line,
# validation on the second.
print(len(t_dataset), len(v_dataset), sep="\n")

class ChunkedVideoDataset(Dataset):
    """Fixed-length sliding-window view over a dataset of whole sequences.

    ``base_dataset[i]`` must return ``(features, labels)`` whose first axis is
    time.  Every window of ``chunk_size`` consecutive frames, taken every
    ``stride`` frames, becomes one sample -- except windows whose labels sum to
    zero, which are dropped.

    Args:
        base_dataset: Indexable dataset yielding ``(features, labels)`` pairs.
        chunk_size: Number of frames per window.
        stride: Step in frames between consecutive window starts.

    Raises:
        ValueError: If ``chunk_size`` or ``stride`` is not positive.
    """

    def __init__(self, base_dataset, chunk_size=100, stride=50):
        if chunk_size <= 0 or stride <= 0:
            raise ValueError("chunk_size and stride must be positive")

        self.base_dataset = base_dataset
        self.chunk_size = chunk_size
        self.stride = stride
        self.index_map = []  # (video_idx, start_frame)

        # NOTE: this loads every base item once up front just to inspect its
        # labels; __getitem__ loads the item again on each access.
        for video_idx in range(len(base_dataset)):
            features, labels = base_dataset[video_idx]
            # Features and labels can differ in length.  Bounding windows by the
            # shorter one guarantees every chunk has exactly chunk_size features
            # AND chunk_size labels, so samples can be stacked into a batch.
            usable_len = min(len(features), len(labels))

            for start in range(0, usable_len - chunk_size + 1, stride):
                labels_chunk = labels[start:start + chunk_size]

                # Skip windows whose labels are all zero.
                # NOTE(review): presumably label 0 is the background class -- confirm.
                if labels_chunk.sum() == 0:
                    continue

                self.index_map.append((video_idx, start))

    def __len__(self):
        """Return the number of retained windows."""
        return len(self.index_map)

    def __getitem__(self, idx):
        """Return ``(features_chunk, labels_chunk)`` for window ``idx``."""
        video_idx, start = self.index_map[idx]
        features, labels = self.base_dataset[video_idx]
        stop = start + self.chunk_size
        return features[start:stop], labels[start:stop]

# Training windows overlap (stride < chunk_size); validation windows are
# back-to-back (stride == chunk_size).
train_dataset = ChunkedVideoDataset(t_dataset, chunk_size=300, stride=75)
valid_dataset = ChunkedVideoDataset(v_dataset, chunk_size=300, stride=300)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=8,
    shuffle=False
)

# Sanity check: each sample should be a (features, labels) pair, i.e. length 2.
for i in train_dataset:
    print('train', len(i))
    break

# Inspect the validation dataset here (this loop previously iterated
# train_dataset again while printing 'valid').
for i in valid_dataset:
    print('valid', len(i))
    break

# Sanity check: print the shape of the first batch from each loader.
for features, labels in train_loader:
    print('train')
    print("Concatenated features:", features.shape)  # [batch, chunk_size, total_feat_dim]
    print("Labels:", labels.shape)  # [batch, chunk_size]
    break
for features, labels in valid_loader:
    print('valid')
    print("Concatenated features:", features.shape)  # [batch, chunk_size, total_feat_dim]
    print("Labels:", labels.shape)  # [batch, chunk_size]
    break

class MambaSequenceClassifier(nn.Module):
    """Per-frame classifier: a single Mamba block followed by a small MLP head.

    The Mamba block runs directly on the raw input features (its ``d_model`` is
    ``input_dim``; there is no input projection).  Each time step is then mapped
    through ``Linear -> ReLU -> Dropout(0.2) -> Linear`` to class logits.

    Args:
        input_dim: Size of the last axis of the input; also the Mamba width.
        hidden_dim: Width of the hidden layer in the classification head.
        num_classes: Number of output logits per time step.
        seq_len: Accepted for interface compatibility; not used by the model.
    """

    def __init__(self, input_dim=6246, hidden_dim=2048, num_classes=16, seq_len=100):
        super().__init__()
        # self.input_proj = nn.Linear(input_dim, hidden_dim)

        # Sequence mixer operating at the input feature width.
        mamba_config = dict(d_model=input_dim, d_state=16, d_conv=4, expand=2)
        self.mamba_block = Mamba(**mamba_config)

        # Classification head applied independently to every time step.
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.output_layer = nn.Linear(in_features=hidden_dim, out_features=num_classes)

        self.dropout = nn.Dropout(p=0.2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        x: [batch_size, seq_len, input_dim]
        returns: [batch_size, seq_len, num_classes]
        """
        # x = self.input_proj(x)  # -> [B, L, H]
        mixed = self.mamba_block(x)  # last axis stays input_dim (fc1 consumes it)
        hidden = self.fc1(mixed)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.output_layer(hidden)  # [B, L, C]
    
# class BoundaryAwareLoss(nn.Module):
#     def __init__(self, classification_weight=1.0, localization_weight=1.0, ignore_index=None):
#         super().__init__()
#         self.classification_weight = classification_weight
#         self.localization_weight = localization_weight
#         self.ignore_index = ignore_index

#     def forward(self, frame_logits, labels, pred_boundaries, true_boundaries):
#         B, T, C = frame_logits.shape

#         # Frame-wise classification loss
#         loss_cls = F.cross_entropy(
#             frame_logits.view(-1, C),
#             labels.view(-1),
#             ignore_index=self.ignore_index if self.ignore_index is not None else -100
#         )

#         # Temporal boundary regression loss (L1)
#         loss_loc = F.l1_loss(pred_boundaries, true_boundaries)

#         return self.classification_weight * loss_cls + self.localization_weight * loss_loc



def _run_epoch(model, loader, device, num_classes, ignore_index, desc,
               optimizer=None, max_grad_norm=None):
    """Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

    Returns:
        tuple: ``(mean per-batch loss, accuracy over non-ignored frames)``.
        Both are 0.0 when there is nothing to average.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    num_batches = 0
    correct = 0
    counted = 0

    # Gradients are only needed while training.
    with torch.set_grad_enabled(training):
        for features, labels in tqdm(loader, desc=desc):
            features = features.to(device)  # [B, T, feat_dim]
            labels = labels.to(device)      # [B, T]

            if training:
                optimizer.zero_grad()

            logits = model(features)        # [B, T, num_classes]
            # reshape (not view) so non-contiguous model outputs also work.
            loss = F.cross_entropy(
                logits.reshape(-1, num_classes),
                labels.reshape(-1),
                ignore_index=ignore_index,
            )

            if training:
                loss.backward()
                if max_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()

            loss_sum += loss.item()
            num_batches += 1

            # Accuracy is measured on exactly the frames the loss is computed on.
            preds = logits.argmax(dim=-1)
            valid = labels != ignore_index
            correct += ((preds == labels) & valid).sum().item()
            counted += valid.sum().item()

    mean_loss = loss_sum / num_batches if num_batches > 0 else 0.0
    accuracy = correct / counted if counted > 0 else 0.0
    return mean_loss, accuracy


def train_mamba(
    model, train_loader, val_loader, optimizer, num_epochs, weights_save_path, device,
    num_classes=16, ignore_index=None, max_grad_norm=None
):
    """Train ``model`` for frame-wise classification and checkpoint it.

    After every epoch the model is evaluated on ``val_loader``.  The weights
    are written to ``<weights_save_path>/mamba_best.pth`` whenever validation
    accuracy strictly improves, and to
    ``<weights_save_path>/mamba_epoch_<N>.pth`` every fifth epoch.

    Args:
        model: Module mapping ``[B, T, feat_dim]`` to ``[B, T, num_classes]``.
        train_loader: Iterable of ``(features, labels)`` training batches.
        val_loader: Iterable of ``(features, labels)`` validation batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of epochs to run.
        weights_save_path: Directory for checkpoints; created if missing.
        device: Device the model and batches are moved to.
        num_classes: Size of the last axis of the model output.
        ignore_index: Label value excluded from both loss and accuracy.
            ``None`` means ``-100``, the ``F.cross_entropy`` default.
        max_grad_norm: If given, gradients are clipped to this total norm
            before each optimizer step.  ``None`` (default) disables clipping.
    """
    model.to(device)
    best_val_accuracy = 0.0
    resolved_ignore_index = -100 if ignore_index is None else ignore_index

    # torch.save does not create missing directories.
    if weights_save_path:
        os.makedirs(weights_save_path, exist_ok=True)

    for epoch in range(num_epochs):
        avg_train_loss, train_accuracy = _run_epoch(
            model, train_loader, device, num_classes, resolved_ignore_index,
            desc=f"Epoch {epoch+1}/{num_epochs}",
            optimizer=optimizer, max_grad_norm=max_grad_norm,
        )
        print(f"Epoch {epoch+1} -  Training Loss: {avg_train_loss:.4f} - Accuracy: {train_accuracy*100:.2f}%")

        # --- Validation ---
        avg_val_loss, val_accuracy = _run_epoch(
            model, val_loader, device, num_classes, resolved_ignore_index,
            desc="Validation",
        )
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), f"{weights_save_path}/mamba_best.pth")
            print(f"Best model saved with accuracy: {val_accuracy*100:.2f}%")

        print(f"Epoch {epoch+1} - Validation Loss: {avg_val_loss:.4f} - Accuracy: {val_accuracy*100:.2f}%")

        # Periodic checkpoint, independent of validation performance.
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f"{weights_save_path}/mamba_epoch_{epoch+1}.pth")
            print(f"Model saved at epoch {epoch+1}")


# Model sized for the concatenated multi-view input: 102 features per frame
# (3 views x 34 columns each) and 16 output classes.
model = MambaSequenceClassifier(
    num_classes=16,
    hidden_dim=68,
    input_dim=102,
)
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4)

# Launch training on the GPU with the loaders and schedule defined above.
train_mamba(
    model=model,
    train_loader=train_loader,
    val_loader=valid_loader,
    optimizer=optimizer,
    num_epochs=num_epochs,
    weights_save_path=weights_save_path,
    device="cuda",
    num_classes=16,
)

120
18
train 2
valid 2
train
Concatenated features: torch.Size([8, 300, 102])
Labels: torch.Size([8, 300])
valid
Concatenated features: torch.Size([8, 300, 102])
Labels: torch.Size([8, 300])


Epoch 1/100: 100%|██████████| 429/429 [02:44<00:00,  2.61it/s]


Epoch 1 -  Training Loss: 349663.7817 - Accuracy: 26.25%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.68it/s]


Best model saved with accuracy: 35.88%
Epoch 1 - Validation Loss: 22973.7009 - Accuracy: 35.88%


Epoch 2/100: 100%|██████████| 429/429 [02:42<00:00,  2.64it/s]


Epoch 2 -  Training Loss: 14005.9453 - Accuracy: 30.40%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Best model saved with accuracy: 44.97%
Epoch 2 - Validation Loss: 7333.0403 - Accuracy: 44.97%


Epoch 3/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 3 -  Training Loss: 4660.6458 - Accuracy: 36.14%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.65it/s]


Best model saved with accuracy: 52.69%
Epoch 3 - Validation Loss: 3132.7622 - Accuracy: 52.69%


Epoch 4/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 4 -  Training Loss: 2132.2641 - Accuracy: 41.66%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Best model saved with accuracy: 56.20%
Epoch 4 - Validation Loss: 1732.9698 - Accuracy: 56.20%


Epoch 5/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 5 -  Training Loss: 1154.6605 - Accuracy: 44.38%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Best model saved with accuracy: 58.86%
Epoch 5 - Validation Loss: 1019.2828 - Accuracy: 58.86%
Model saved at epoch 5


Epoch 6/100: 100%|██████████| 429/429 [02:39<00:00,  2.68it/s]


Epoch 6 -  Training Loss: 727.0447 - Accuracy: 45.99%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.66it/s]


Best model saved with accuracy: 60.00%
Epoch 6 - Validation Loss: 749.4398 - Accuracy: 60.00%


Epoch 7/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 7 -  Training Loss: 514.2861 - Accuracy: 46.43%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.83it/s]


Best model saved with accuracy: 61.03%
Epoch 7 - Validation Loss: 573.2818 - Accuracy: 61.03%


Epoch 8/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 8 -  Training Loss: 391.1259 - Accuracy: 46.67%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.74it/s]


Best model saved with accuracy: 62.00%
Epoch 8 - Validation Loss: 448.1941 - Accuracy: 62.00%


Epoch 9/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 9 -  Training Loss: 305.5563 - Accuracy: 46.64%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.62it/s]


Best model saved with accuracy: 62.68%
Epoch 9 - Validation Loss: 358.3714 - Accuracy: 62.68%


Epoch 10/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 10 -  Training Loss: 243.5893 - Accuracy: 46.46%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Best model saved with accuracy: 63.02%
Epoch 10 - Validation Loss: 284.4771 - Accuracy: 63.02%
Model saved at epoch 10


Epoch 11/100: 100%|██████████| 429/429 [02:46<00:00,  2.58it/s]


Epoch 11 -  Training Loss: 201.9556 - Accuracy: 46.11%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.79it/s]


Best model saved with accuracy: 63.57%
Epoch 11 - Validation Loss: 234.7738 - Accuracy: 63.57%


Epoch 12/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 12 -  Training Loss: 166.4479 - Accuracy: 45.71%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.80it/s]


Best model saved with accuracy: 63.83%
Epoch 12 - Validation Loss: 192.5371 - Accuracy: 63.83%


Epoch 13/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 13 -  Training Loss: 139.8802 - Accuracy: 45.37%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Best model saved with accuracy: 64.00%
Epoch 13 - Validation Loss: 160.8265 - Accuracy: 64.00%


Epoch 14/100: 100%|██████████| 429/429 [02:40<00:00,  2.68it/s]


Epoch 14 -  Training Loss: 117.0165 - Accuracy: 44.88%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.74it/s]


Epoch 14 - Validation Loss: 142.2359 - Accuracy: 63.99%


Epoch 15/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 15 -  Training Loss: 97.0249 - Accuracy: 44.47%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.75it/s]


Best model saved with accuracy: 64.16%
Epoch 15 - Validation Loss: 125.7763 - Accuracy: 64.16%
Model saved at epoch 15


Epoch 16/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 16 -  Training Loss: 78.9953 - Accuracy: 43.57%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.53it/s]


Epoch 16 - Validation Loss: 113.0872 - Accuracy: 63.99%


Epoch 17/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 17 -  Training Loss: 66.4091 - Accuracy: 42.64%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Epoch 17 - Validation Loss: 93.8821 - Accuracy: 60.53%


Epoch 18/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 18 -  Training Loss: 39.5334 - Accuracy: 32.44%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.61it/s]


Epoch 18 - Validation Loss: 67.0071 - Accuracy: 26.07%


Epoch 19/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 19 -  Training Loss: 21.5636 - Accuracy: 13.12%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.82it/s]


Epoch 19 - Validation Loss: 54.7436 - Accuracy: 10.22%


Epoch 20/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 20 -  Training Loss: 14.1838 - Accuracy: 16.00%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Best model saved with accuracy: 65.40%
Epoch 20 - Validation Loss: 39.1077 - Accuracy: 65.40%
Model saved at epoch 20


Epoch 21/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 21 -  Training Loss: 11.1477 - Accuracy: 64.27%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.75it/s]


Best model saved with accuracy: 66.15%
Epoch 21 - Validation Loss: 37.1305 - Accuracy: 66.15%


Epoch 22/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 22 -  Training Loss: 8.6803 - Accuracy: 65.05%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Best model saved with accuracy: 66.63%
Epoch 22 - Validation Loss: 31.3627 - Accuracy: 66.63%


Epoch 23/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 23 -  Training Loss: 6.8222 - Accuracy: 65.53%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.85it/s]


Best model saved with accuracy: 66.94%
Epoch 23 - Validation Loss: 29.3243 - Accuracy: 66.94%


Epoch 24/100: 100%|██████████| 429/429 [02:42<00:00,  2.64it/s]


Epoch 24 -  Training Loss: 5.7671 - Accuracy: 65.76%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Best model saved with accuracy: 67.35%
Epoch 24 - Validation Loss: 26.3449 - Accuracy: 67.35%


Epoch 25/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 25 -  Training Loss: 4.8398 - Accuracy: 66.00%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Best model saved with accuracy: 67.36%
Epoch 25 - Validation Loss: 23.7830 - Accuracy: 67.36%
Model saved at epoch 25


Epoch 26/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 26 -  Training Loss: 4.2395 - Accuracy: 66.11%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Epoch 26 - Validation Loss: 19.4554 - Accuracy: 67.32%


Epoch 27/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 27 -  Training Loss: 3.5441 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Best model saved with accuracy: 67.38%
Epoch 27 - Validation Loss: 16.0246 - Accuracy: 67.38%


Epoch 28/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 28 -  Training Loss: 3.2250 - Accuracy: 66.20%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.61it/s]


Best model saved with accuracy: 67.61%
Epoch 28 - Validation Loss: 16.2283 - Accuracy: 67.61%


Epoch 29/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 29 -  Training Loss: 2.9242 - Accuracy: 66.25%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Epoch 29 - Validation Loss: 12.0397 - Accuracy: 67.60%


Epoch 30/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 30 -  Training Loss: 2.7536 - Accuracy: 66.27%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.75it/s]


Epoch 30 - Validation Loss: 12.0780 - Accuracy: 67.57%
Model saved at epoch 30


Epoch 31/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 31 -  Training Loss: 2.4141 - Accuracy: 66.27%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Best model saved with accuracy: 67.70%
Epoch 31 - Validation Loss: 10.6168 - Accuracy: 67.70%


Epoch 32/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 32 -  Training Loss: 2.3220 - Accuracy: 66.28%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Epoch 32 - Validation Loss: 7.3355 - Accuracy: 67.70%


Epoch 33/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 33 -  Training Loss: 2.2655 - Accuracy: 66.28%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Epoch 33 - Validation Loss: 8.4969 - Accuracy: 67.64%


Epoch 34/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 34 -  Training Loss: 2.1931 - Accuracy: 66.28%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.80it/s]


Best model saved with accuracy: 67.71%
Epoch 34 - Validation Loss: 7.2458 - Accuracy: 67.71%


Epoch 35/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 35 -  Training Loss: 2.3149 - Accuracy: 66.29%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.63it/s]


Best model saved with accuracy: 67.79%
Epoch 35 - Validation Loss: 6.0527 - Accuracy: 67.79%
Model saved at epoch 35


Epoch 36/100: 100%|██████████| 429/429 [02:50<00:00,  2.51it/s]


Epoch 36 -  Training Loss: 2.0725 - Accuracy: 66.30%


Validation: 100%|██████████| 17/17 [00:07<00:00,  2.14it/s]


Best model saved with accuracy: 67.94%
Epoch 36 - Validation Loss: 5.3050 - Accuracy: 67.94%


Epoch 37/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 37 -  Training Loss: 1.9819 - Accuracy: 66.30%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.79it/s]


Best model saved with accuracy: 67.96%
Epoch 37 - Validation Loss: 5.1431 - Accuracy: 67.96%


Epoch 38/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 38 -  Training Loss: 1.9639 - Accuracy: 66.30%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.81it/s]


Best model saved with accuracy: 67.96%
Epoch 38 - Validation Loss: 5.7108 - Accuracy: 67.96%


Epoch 39/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 39 -  Training Loss: 1.9145 - Accuracy: 66.30%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Best model saved with accuracy: 67.96%
Epoch 39 - Validation Loss: 6.7831 - Accuracy: 67.96%


Epoch 40/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 40 -  Training Loss: 1.9134 - Accuracy: 66.30%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Epoch 40 - Validation Loss: 4.0809 - Accuracy: 67.96%
Model saved at epoch 40


Epoch 41/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 41 -  Training Loss: 5.0000 - Accuracy: 66.26%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.74it/s]


Best model saved with accuracy: 67.97%
Epoch 41 - Validation Loss: 3.1088 - Accuracy: 67.97%


Epoch 42/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 42 -  Training Loss: 3.0218 - Accuracy: 66.20%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Epoch 42 - Validation Loss: 1.7237 - Accuracy: 67.96%


Epoch 43/100: 100%|██████████| 429/429 [02:36<00:00,  2.73it/s]


Epoch 43 -  Training Loss: 2.3116 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.80it/s]


Epoch 43 - Validation Loss: 1.7706 - Accuracy: 67.96%


Epoch 44/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 44 -  Training Loss: 1.7771 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Epoch 44 - Validation Loss: 1.7066 - Accuracy: 67.96%


Epoch 45/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 45 -  Training Loss: 1.8014 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 45 - Validation Loss: 1.6608 - Accuracy: 67.96%
Model saved at epoch 45


Epoch 46/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 46 -  Training Loss: 1.7021 - Accuracy: 66.20%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.79it/s]


Epoch 46 - Validation Loss: 1.6424 - Accuracy: 67.96%


Epoch 47/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 47 -  Training Loss: 1.6720 - Accuracy: 66.20%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.70it/s]


Epoch 47 - Validation Loss: 1.6252 - Accuracy: 67.96%


Epoch 48/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 48 -  Training Loss: 1.6569 - Accuracy: 66.20%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.75it/s]


Epoch 48 - Validation Loss: 1.6092 - Accuracy: 67.96%


Epoch 49/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 49 -  Training Loss: 1.6424 - Accuracy: 66.20%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.79it/s]


Epoch 49 - Validation Loss: 1.5943 - Accuracy: 67.96%


Epoch 50/100: 100%|██████████| 429/429 [02:39<00:00,  2.68it/s]


Epoch 50 -  Training Loss: 1.6285 - Accuracy: 66.20%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Epoch 50 - Validation Loss: 1.5807 - Accuracy: 67.96%
Model saved at epoch 50


Epoch 51/100: 100%|██████████| 429/429 [02:35<00:00,  2.76it/s]


Epoch 51 -  Training Loss: 2.2899 - Accuracy: 66.22%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.85it/s]


Epoch 51 - Validation Loss: 1.5681 - Accuracy: 67.96%


Epoch 52/100: 100%|██████████| 429/429 [02:35<00:00,  2.76it/s]


Epoch 52 -  Training Loss: 1.9552 - Accuracy: 66.20%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.84it/s]


Epoch 52 - Validation Loss: 1.5566 - Accuracy: 67.96%


Epoch 53/100: 100%|██████████| 429/429 [02:36<00:00,  2.75it/s]


Epoch 53 -  Training Loss: 2.3249 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 53 - Validation Loss: 1.5461 - Accuracy: 67.96%


Epoch 54/100: 100%|██████████| 429/429 [02:34<00:00,  2.77it/s]


Epoch 54 -  Training Loss: 1.6070 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Epoch 54 - Validation Loss: 1.5365 - Accuracy: 67.96%


Epoch 55/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 55 -  Training Loss: 1.5781 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.66it/s]


Epoch 55 - Validation Loss: 1.5279 - Accuracy: 67.96%
Model saved at epoch 55


Epoch 56/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 56 -  Training Loss: 1.5692 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Epoch 56 - Validation Loss: 1.5201 - Accuracy: 67.96%


Epoch 57/100: 100%|██████████| 429/429 [02:35<00:00,  2.75it/s]


Epoch 57 -  Training Loss: 1.5635 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.80it/s]


Epoch 57 - Validation Loss: 1.5131 - Accuracy: 67.96%


Epoch 58/100: 100%|██████████| 429/429 [02:35<00:00,  2.76it/s]


Epoch 58 -  Training Loss: 1.5572 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.85it/s]


Epoch 58 - Validation Loss: 1.5069 - Accuracy: 67.96%


Epoch 59/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 59 -  Training Loss: 1.5523 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.81it/s]


Epoch 59 - Validation Loss: 1.5014 - Accuracy: 67.96%


Epoch 60/100: 100%|██████████| 429/429 [02:35<00:00,  2.75it/s]


Epoch 60 -  Training Loss: 1.5474 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 60 - Validation Loss: 1.4966 - Accuracy: 67.96%
Model saved at epoch 60


Epoch 61/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 61 -  Training Loss: 1.5437 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.81it/s]


Epoch 61 - Validation Loss: 1.4924 - Accuracy: 67.96%


Epoch 62/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 62 -  Training Loss: 1.5404 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 62 - Validation Loss: 1.4887 - Accuracy: 67.96%


Epoch 63/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 63 -  Training Loss: 1.5376 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.70it/s]


Epoch 63 - Validation Loss: 1.4856 - Accuracy: 67.96%


Epoch 64/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 64 -  Training Loss: 1.5350 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.80it/s]


Epoch 64 - Validation Loss: 1.4829 - Accuracy: 67.96%


Epoch 65/100: 100%|██████████| 429/429 [02:35<00:00,  2.76it/s]


Epoch 65 -  Training Loss: 1.5328 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Epoch 65 - Validation Loss: 1.4806 - Accuracy: 67.96%
Model saved at epoch 65


Epoch 66/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 66 -  Training Loss: 1.5314 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.81it/s]


Epoch 66 - Validation Loss: 1.4787 - Accuracy: 67.96%


Epoch 67/100: 100%|██████████| 429/429 [02:44<00:00,  2.61it/s]


Epoch 67 -  Training Loss: 1.5301 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Epoch 67 - Validation Loss: 1.4771 - Accuracy: 67.96%


Epoch 68/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 68 -  Training Loss: 1.5295 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.84it/s]


Epoch 68 - Validation Loss: 1.4757 - Accuracy: 67.96%


Epoch 69/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 69 -  Training Loss: 1.5282 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.73it/s]


Epoch 69 - Validation Loss: 1.4747 - Accuracy: 67.96%


Epoch 70/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 70 -  Training Loss: 1.5277 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.66it/s]


Epoch 70 - Validation Loss: 1.4738 - Accuracy: 67.96%
Model saved at epoch 70


Epoch 71/100: 100%|██████████| 429/429 [02:46<00:00,  2.58it/s]


Epoch 71 -  Training Loss: 1.5295 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.86it/s]


Epoch 71 - Validation Loss: 1.4731 - Accuracy: 67.96%


Epoch 72/100: 100%|██████████| 429/429 [02:35<00:00,  2.76it/s]


Epoch 72 -  Training Loss: 1.5289 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Epoch 72 - Validation Loss: 1.4725 - Accuracy: 67.96%


Epoch 73/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 73 -  Training Loss: 1.5260 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Epoch 73 - Validation Loss: 1.4720 - Accuracy: 67.96%


Epoch 74/100: 100%|██████████| 429/429 [02:43<00:00,  2.63it/s]


Epoch 74 -  Training Loss: 1.5830 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Epoch 74 - Validation Loss: 1.4716 - Accuracy: 67.96%


Epoch 75/100: 100%|██████████| 429/429 [02:40<00:00,  2.67it/s]


Epoch 75 -  Training Loss: 1.5434 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.88it/s]


Epoch 75 - Validation Loss: 1.4713 - Accuracy: 67.96%
Model saved at epoch 75


Epoch 76/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 76 -  Training Loss: 1.6873 - Accuracy: 66.19%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.81it/s]


Epoch 76 - Validation Loss: 1.4711 - Accuracy: 67.96%


Epoch 77/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 77 -  Training Loss: 80.0433 - Accuracy: 66.10%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Best model saved with accuracy: 67.98%
Epoch 77 - Validation Loss: 7.8416 - Accuracy: 67.98%


Epoch 78/100: 100%|██████████| 429/429 [02:42<00:00,  2.64it/s]


Epoch 78 -  Training Loss: 2.0752 - Accuracy: 66.24%


Validation: 100%|██████████| 17/17 [00:07<00:00,  2.32it/s]


Epoch 78 - Validation Loss: 4.4445 - Accuracy: 67.96%


Epoch 79/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 79 -  Training Loss: 1.6420 - Accuracy: 66.22%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.74it/s]


Epoch 79 - Validation Loss: 2.5828 - Accuracy: 67.96%


Epoch 80/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 80 -  Training Loss: 1.5307 - Accuracy: 66.22%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 80 - Validation Loss: 2.4165 - Accuracy: 67.96%
Model saved at epoch 80


Epoch 81/100: 100%|██████████| 429/429 [02:47<00:00,  2.56it/s]


Epoch 81 -  Training Loss: 1.5360 - Accuracy: 66.22%


Validation: 100%|██████████| 17/17 [00:07<00:00,  2.32it/s]


Epoch 81 - Validation Loss: 2.3826 - Accuracy: 67.96%


Epoch 82/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 82 -  Training Loss: 1.5584 - Accuracy: 66.22%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Epoch 82 - Validation Loss: 2.4221 - Accuracy: 67.96%


Epoch 83/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 83 -  Training Loss: 1.5347 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.71it/s]


Epoch 83 - Validation Loss: 2.5040 - Accuracy: 67.96%


Epoch 84/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 84 -  Training Loss: 1.5937 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.75it/s]


Epoch 84 - Validation Loss: 3.0844 - Accuracy: 67.96%


Epoch 85/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 85 -  Training Loss: 1.5255 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.70it/s]


Epoch 85 - Validation Loss: 3.3589 - Accuracy: 67.96%
Model saved at epoch 85


Epoch 86/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 86 -  Training Loss: 1.5658 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.77it/s]


Epoch 86 - Validation Loss: 2.8725 - Accuracy: 67.96%


Epoch 87/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 87 -  Training Loss: 1.5301 - Accuracy: 66.22%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Epoch 87 - Validation Loss: 3.0537 - Accuracy: 67.96%


Epoch 88/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 88 -  Training Loss: 1.5265 - Accuracy: 66.22%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Epoch 88 - Validation Loss: 3.0922 - Accuracy: 67.96%


Epoch 89/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 89 -  Training Loss: 1.5830 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 89 - Validation Loss: 2.8712 - Accuracy: 67.96%


Epoch 90/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 90 -  Training Loss: 1.5510 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:05<00:00,  2.85it/s]


Epoch 90 - Validation Loss: 3.9883 - Accuracy: 67.96%
Model saved at epoch 90


Epoch 91/100: 100%|██████████| 429/429 [02:39<00:00,  2.69it/s]


Epoch 91 -  Training Loss: 1.5608 - Accuracy: 66.20%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.80it/s]


Epoch 91 - Validation Loss: 4.0288 - Accuracy: 67.96%


Epoch 92/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 92 -  Training Loss: 1.5261 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.72it/s]


Epoch 92 - Validation Loss: 3.2797 - Accuracy: 67.96%


Epoch 93/100: 100%|██████████| 429/429 [02:37<00:00,  2.73it/s]


Epoch 93 -  Training Loss: 1.5233 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.80it/s]


Epoch 93 - Validation Loss: 3.1233 - Accuracy: 67.96%


Epoch 94/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 94 -  Training Loss: 1.5236 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.67it/s]


Epoch 94 - Validation Loss: 3.1557 - Accuracy: 67.96%


Epoch 95/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 95 -  Training Loss: 1.5230 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.78it/s]


Epoch 95 - Validation Loss: 3.2478 - Accuracy: 67.96%
Model saved at epoch 95


Epoch 96/100: 100%|██████████| 429/429 [02:39<00:00,  2.70it/s]


Epoch 96 -  Training Loss: 1.5237 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.69it/s]


Epoch 96 - Validation Loss: 3.4188 - Accuracy: 67.96%


Epoch 97/100: 100%|██████████| 429/429 [02:36<00:00,  2.74it/s]


Epoch 97 -  Training Loss: 1.5234 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.83it/s]


Epoch 97 - Validation Loss: 3.4109 - Accuracy: 67.96%


Epoch 98/100: 100%|██████████| 429/429 [02:37<00:00,  2.72it/s]


Epoch 98 -  Training Loss: 1.5236 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.76it/s]


Epoch 98 - Validation Loss: 3.3980 - Accuracy: 67.96%


Epoch 99/100: 100%|██████████| 429/429 [02:38<00:00,  2.70it/s]


Epoch 99 -  Training Loss: 1.5239 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.79it/s]


Epoch 99 - Validation Loss: 3.3928 - Accuracy: 67.96%


Epoch 100/100: 100%|██████████| 429/429 [02:38<00:00,  2.71it/s]


Epoch 100 -  Training Loss: 1.5913 - Accuracy: 66.21%


Validation: 100%|██████████| 17/17 [00:06<00:00,  2.82it/s]

Epoch 100 - Validation Loss: 1.4697 - Accuracy: 67.96%
Model saved at epoch 100





In [1]:
import os
import cv2
import torch
import numpy as np
from torchvision import transforms
from torch import nn
from torchvision.datasets import ImageFolder
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from PIL import Image
from tqdm import tqdm
import cv2
import re
from mamba_ssm import Mamba
import torch.nn.functional as F


# skip_ids = [16700, 38159, 59359]
# Root folder of the training feature arrays (.npy); MultiViewFeatureDataset
# scans its sub-folders for files starting with "Dash".
train_features_root = '/media/viplab/DATADRIVE1/driver_action_recognition/pose_resnet_features_multi/A1/train'
# Folder holding the label arrays (.npy), looked up by sub-folder name.
train_labels_root = '/home/viplab/Documents/driver_action_recognition/data_processing/array_generation/arrays'

# Validation counterparts. NOTE: validation_labels_root points at the same
# directory as train_labels_root.
validation_features_root = '/media/viplab/DATADRIVE1/driver_action_recognition/pose_resnet_features_multi/A1/valid'
validation_labels_root = '/home/viplab/Documents/driver_action_recognition/data_processing/array_generation/arrays'

# Number of passes over the training loader.
num_epochs = 100
# Directory that train_mamba writes its .pth checkpoints into.
# NOTE(review): the name suggests a learning rate of 1e-4, but the optimizer
# created at the bottom of this cell uses lr=1e-5 — confirm which is intended.
weights_save_path = "mamba_weights_pose_1e-4"


class MultiViewFeatureDataset(Dataset):
    """Per-recording dataset of pre-extracted view features and frame labels.

    A sample is discovered for every ``Dash*.npy`` file found in a direct
    sub-folder of ``features_root``. For each sample the feature file of every
    view in ``views`` is loaded, but only a column slice of the third view is
    returned; the other views only contribute to the common frame count.

    Args:
        features_root: directory whose sub-folders contain the ``.npy`` feature
            files of all views.
        labels_root: directory containing one ``<sub-folder name>.npy`` label
            array per sub-folder of ``features_root``.
        views: file-name prefixes of the views to load. ``__getitem__`` reads
            the view at index 2, so at least three views are required.
    """

    def __init__(self, features_root, labels_root, views=("Dashboard", "Rear_view", "Right_side_window")):
        self.features_root = features_root
        self.labels_root = labels_root
        self.views = views
        self.sample_keys = []
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        for user_folder in sorted(os.listdir(features_root)):
            user_path = os.path.join(features_root, user_folder)
            if not os.path.isdir(user_path):
                continue
            for file_name in sorted(os.listdir(user_path)):
                if file_name.startswith("Dash") and file_name.endswith(".npy"):
                    # Key is "<sub-folder>/<file name without extension>".
                    key = os.path.join(user_folder, os.path.splitext(file_name)[0])
                    self.sample_keys.append(key)

    def __len__(self):
        """Return the number of discovered samples."""
        return len(self.sample_keys)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            ``(features, labels)`` where ``features`` is a float32 array holding
            columns ``-34:-12`` of the third view and ``labels`` is an int64
            array. Both are truncated to the same number of rows: the shortest
            of all loaded views and the label array.
        """
        key = self.sample_keys[idx]
        key_parts = key.split('_')
        # File stem shared by all views: the first three '_'-separated tokens
        # of the key, then "_NoAudio_", then the key's last token.
        # NOTE(review): this assumes a specific folder/file naming scheme —
        # confirm against the feature-extraction script.
        clip_stem = "_".join(key_parts[:3]) + "_NoAudio_" + f"{key_parts[-1]}"
        sample_dir = os.path.join(self.features_root, os.path.dirname(key))

        view_features = []
        for view in self.views:
            path = os.path.join(sample_dir, f"{view}_{clip_stem}.npy")
            view_features.append(np.load(path).astype(np.float32))

        label_path = os.path.join(self.labels_root, os.path.dirname(key) + ".npy")
        labels = np.load(label_path).astype(np.int64)

        # Views (and the label array) may differ in length; keep only the rows
        # present in all of them so features and labels stay aligned.
        num_frames = min(min(feats.shape[0] for feats in view_features), labels.shape[0])

        # Only the third view's columns -34:-12 are used as model input.
        # NOTE(review): presumably the pose part of the feature vector — confirm.
        features_cat = view_features[2][:num_frames, -34:-12]
        return features_cat, labels[:num_frames]

# Build one per-recording dataset for each split.
t_dataset = MultiViewFeatureDataset(features_root=train_features_root, labels_root=train_labels_root)
v_dataset = MultiViewFeatureDataset(features_root=validation_features_root, labels_root=validation_labels_root)

# Report the number of samples per split: training first, then validation.
print(len(t_dataset), len(v_dataset), sep="\n")

class ChunkedVideoDataset(Dataset):
    """Expose fixed-length windows of the sequences held by ``base_dataset``.

    Every item of ``base_dataset`` is read once at construction time to
    enumerate window start positions. A window is dropped when the number of
    frames labelled ``0`` is at least ``len(window) // 2``.

    Args:
        base_dataset: indexable collection of ``(features, labels)`` pairs.
        chunk_size: number of frames per window.
        stride: step between consecutive window starts.
    """

    def __init__(self, base_dataset, chunk_size=100, stride=50):
        self.base_dataset = base_dataset
        self.chunk_size = chunk_size
        self.stride = stride
        # Entries are (video_idx, start_frame) pairs, in scan order.
        self.index_map = []

        for video_idx in range(len(base_dataset)):
            features, labels = base_dataset[video_idx]
            self.index_map.extend(
                (video_idx, start)
                for start in self._kept_starts(labels, features.shape[0])
            )

    def _kept_starts(self, labels, video_len):
        """Yield the start frame of every window that passes the label-0 filter."""
        for start in range(0, video_len - self.chunk_size + 1, self.stride):
            window = labels[start:start + self.chunk_size]
            zero_frames = np.sum(window == 0)
            if zero_frames < len(window) // 2:
                yield start

    def __len__(self):
        """Return the number of retained windows."""
        return len(self.index_map)

    def __getitem__(self, idx):
        """Re-read the underlying item and return the ``idx``-th window of it."""
        video_idx, start = self.index_map[idx]
        stop = start + self.chunk_size
        features, labels = self.base_dataset[video_idx]
        return features[start:stop], labels[start:stop]

# 150-frame windows: training windows start every 75 frames, validation
# windows every 150 frames.
train_dataset = ChunkedVideoDataset(t_dataset, chunk_size=150, stride=75)
valid_dataset = ChunkedVideoDataset(v_dataset, chunk_size=150, stride=150)

# Batches of 8 windows; only the training loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)

# Sanity check: each chunked sample is a (features, labels) pair, so len() is 2.
for sample in train_dataset:
    print('train', len(sample))
    break

# Check the validation dataset itself, not the training one a second time.
for sample in valid_dataset:
    print('valid', len(sample))
    break

# Print the shapes of the first batch from each loader.
for features, labels in train_loader:
    print('train')
    print("Concatenated features:", features.shape)  # [batch, chunk_size, feat_dim]
    print("Labels:", labels.shape)                   # [batch, chunk_size]
    break
for features, labels in valid_loader:
    print('valid')
    print("Concatenated features:", features.shape)  # [batch, chunk_size, feat_dim]
    print("Labels:", labels.shape)                   # [batch, chunk_size]
    break

class MambaSequenceClassifier(nn.Module):
    """Frame-wise classifier: one Mamba block followed by a two-layer linear head.

    Args:
        input_dim: size of each frame's feature vector; also used as the Mamba
            ``d_model`` (there is no input projection).
        hidden_dim: width of the hidden linear layer of the head.
        num_classes: number of logits produced per frame.
        seq_len: accepted for compatibility but not used by this class.
    """

    def __init__(self, input_dim=6246, hidden_dim=2048, num_classes=16, seq_len=100):
        super().__init__()

        # Sequence model applied directly at the input feature width.
        self.mamba_block = Mamba(d_model=input_dim, d_state=16, d_conv=4, expand=2)

        # Per-frame head: input_dim -> hidden_dim -> num_classes.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, num_classes)

        self.dropout = nn.Dropout(p=0.2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Compute per-frame class logits.

        x: [batch_size, seq_len, input_dim]
        returns: [batch_size, seq_len, num_classes]
        """
        mixed = self.mamba_block(x)
        hidden = self.dropout(self.relu(self.fc1(mixed)))
        return self.output_layer(hidden)
    
# class BoundaryAwareLoss(nn.Module):
#     def __init__(self, classification_weight=1.0, localization_weight=1.0, ignore_index=None):
#         super().__init__()
#         self.classification_weight = classification_weight
#         self.localization_weight = localization_weight
#         self.ignore_index = ignore_index

#     def forward(self, frame_logits, labels, pred_boundaries, true_boundaries):
#         B, T, C = frame_logits.shape

#         # Frame-wise classification loss
#         loss_cls = F.cross_entropy(
#             frame_logits.view(-1, C),
#             labels.view(-1),
#             ignore_index=self.ignore_index if self.ignore_index is not None else -100
#         )

#         # Temporal boundary regression loss (L1)
#         loss_loc = F.l1_loss(pred_boundaries, true_boundaries)

#         return self.classification_weight * loss_cls + self.localization_weight * loss_loc



def _run_epoch(model, loader, device, num_classes, ignore_index, desc, optimizer, training):
    """Run one pass over ``loader`` and return ``(mean batch loss, accuracy)``.

    When ``training`` is true the model is put in train mode and ``optimizer``
    is stepped after every batch; otherwise the model is put in eval mode and
    gradients are disabled. Accuracy is the fraction of correctly predicted
    frames (frames labelled ``ignore_index`` are excluded when it is set).
    """
    model.train(training)
    # -100 is F.cross_entropy's own default, so ``None`` keeps stock behaviour.
    loss_ignore_index = -100 if ignore_index is None else ignore_index

    loss_sum = 0.0
    num_batches = 0
    num_correct = 0
    num_frames = 0

    with torch.set_grad_enabled(training):
        for features, labels in tqdm(loader, desc=desc):
            features = features.to(device)  # [B, L, feat_dim]
            labels = labels.to(device)      # [B, L]

            if training:
                optimizer.zero_grad()
            logits = model(features)        # [B, L, num_classes]

            # reshape (not view) so non-contiguous tensors are accepted too.
            loss = F.cross_entropy(
                logits.reshape(-1, num_classes),
                labels.reshape(-1),
                ignore_index=loss_ignore_index,
            )
            if training:
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            num_batches += 1

            preds = logits.argmax(dim=-1)   # [B, L]
            if ignore_index is None:
                num_correct += (preds == labels).sum().item()
                num_frames += labels.numel()
            else:
                valid = labels != ignore_index
                num_correct += (preds[valid] == labels[valid]).sum().item()
                num_frames += valid.sum().item()

    # Guard both averages so an empty loader does not raise ZeroDivisionError.
    mean_loss = loss_sum / num_batches if num_batches else float("nan")
    accuracy = num_correct / num_frames if num_frames > 0 else 0.0
    return mean_loss, accuracy


def train_mamba(
    model, train_loader, val_loader, optimizer, num_epochs, weights_save_path, device,
    num_classes=16, ignore_index=None
):
    """Train ``model`` with frame-wise cross-entropy and validate after every epoch.

    Args:
        model: module mapping ``[B, L, feat_dim]`` features to
            ``[B, L, num_classes]`` logits.
        train_loader: iterable of ``(features, labels)`` training batches.
        val_loader: iterable of ``(features, labels)`` validation batches.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        weights_save_path: directory for checkpoints; created if missing.
        device: device the model and batches are moved to.
        num_classes: size of the last logits dimension.
        ignore_index: optional label value excluded from both the loss and the
            accuracy. ``None`` (default) counts every frame.

    Side effects:
        Prints per-epoch metrics, writes ``mamba_best.pth`` whenever the
        validation accuracy improves and ``mamba_epoch_<n>.pth`` every fifth
        epoch. Returns ``None``.
    """
    model.to(device)
    # torch.save does not create missing directories.
    os.makedirs(weights_save_path, exist_ok=True)
    best_val_accuracy = 0.0

    for epoch in range(num_epochs):
        avg_train_loss, train_accuracy = _run_epoch(
            model, train_loader, device, num_classes, ignore_index,
            desc=f"Epoch {epoch+1}/{num_epochs}", optimizer=optimizer, training=True,
        )
        print(f"Epoch {epoch+1} -  Training Loss: {avg_train_loss:.4f} - Accuracy: {train_accuracy*100:.2f}%")

        # --- Validation ---
        avg_val_loss, val_accuracy = _run_epoch(
            model, val_loader, device, num_classes, ignore_index,
            desc="Validation", optimizer=None, training=False,
        )
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # Keep the weights with the best validation accuracy seen so far.
            torch.save(model.state_dict(), f"{weights_save_path}/mamba_best.pth")
            print(f"Best model saved with accuracy: {val_accuracy*100:.2f}%")

        print(f"Epoch {epoch+1} - Validation Loss: {avg_val_loss:.4f} - Accuracy: {val_accuracy*100:.2f}%")
        # Periodic checkpoint every fifth epoch, independent of accuracy.
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f"{weights_save_path}/mamba_epoch_{epoch+1}.pth")
            print(f"Model saved at epoch {epoch+1}")


# 22 input features per frame, a 14-unit hidden layer, 16 output classes.
model = MambaSequenceClassifier(input_dim=22, hidden_dim=14, num_classes=16)
# NOTE(review): lr is 1e-5 here while weights_save_path is named "..._1e-4" —
# confirm which value is intended.
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

train_mamba(
    model=model,
    train_loader=train_loader,
    val_loader=valid_loader,
    optimizer=optimizer,
    num_epochs=num_epochs,
    weights_save_path=weights_save_path,
    device="cuda",
    num_classes=16,
)

120
18
train 2
valid 2
train
Concatenated features: torch.Size([8, 150, 22])
Labels: torch.Size([8, 150])
valid
Concatenated features: torch.Size([8, 150, 22])
Labels: torch.Size([8, 150])


Epoch 1/100: 100%|██████████| 91/91 [02:57<00:00,  1.95s/it]


Epoch 1 -  Training Loss: 15232762.1758 - Accuracy: 1.66%


Validation: 100%|██████████| 7/7 [00:12<00:00,  1.84s/it]


Best model saved with accuracy: 0.55%
Epoch 1 - Validation Loss: 5832299.7679 - Accuracy: 0.55%


Epoch 2/100: 100%|██████████| 91/91 [00:24<00:00,  3.74it/s]


Epoch 2 -  Training Loss: 8782966.0659 - Accuracy: 1.38%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.89it/s]


Epoch 2 - Validation Loss: 3606837.4464 - Accuracy: 0.38%


Epoch 3/100: 100%|██████████| 91/91 [00:23<00:00,  3.82it/s]


Epoch 3 -  Training Loss: 5135322.6538 - Accuracy: 1.57%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.63it/s]


Epoch 3 - Validation Loss: 2428443.2143 - Accuracy: 0.29%


Epoch 4/100: 100%|██████████| 91/91 [00:25<00:00,  3.58it/s]


Epoch 4 -  Training Loss: 3267249.8448 - Accuracy: 1.77%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.66it/s]


Epoch 4 - Validation Loss: 1831838.9554 - Accuracy: 0.28%


Epoch 5/100: 100%|██████████| 91/91 [00:24<00:00,  3.67it/s]


Epoch 5 -  Training Loss: 2377216.9451 - Accuracy: 2.04%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.23it/s]


Epoch 5 - Validation Loss: 1502263.6786 - Accuracy: 0.33%
Model saved at epoch 5


Epoch 6/100: 100%|██████████| 91/91 [00:25<00:00,  3.52it/s]


Epoch 6 -  Training Loss: 1891373.2109 - Accuracy: 2.15%


Validation: 100%|██████████| 7/7 [00:18<00:00,  2.65s/it]


Epoch 6 - Validation Loss: 1268490.0312 - Accuracy: 0.32%


Epoch 7/100: 100%|██████████| 91/91 [01:21<00:00,  1.12it/s]


Epoch 7 -  Training Loss: 1570958.2301 - Accuracy: 2.24%


Validation: 100%|██████████| 7/7 [00:07<00:00,  1.14s/it]


Epoch 7 - Validation Loss: 1095245.4554 - Accuracy: 0.25%


Epoch 8/100: 100%|██████████| 91/91 [00:22<00:00,  4.03it/s]


Epoch 8 -  Training Loss: 1340610.5659 - Accuracy: 2.33%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.29it/s]


Epoch 8 - Validation Loss: 953237.9688 - Accuracy: 0.27%


Epoch 9/100: 100%|██████████| 91/91 [00:22<00:00,  4.05it/s]


Epoch 9 -  Training Loss: 1163612.2431 - Accuracy: 2.40%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.66it/s]


Epoch 9 - Validation Loss: 833259.5089 - Accuracy: 0.29%


Epoch 10/100: 100%|██████████| 91/91 [00:22<00:00,  4.05it/s]


Epoch 10 -  Training Loss: 1018832.1102 - Accuracy: 2.67%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.52it/s]


Epoch 10 - Validation Loss: 730512.5893 - Accuracy: 0.42%
Model saved at epoch 10


Epoch 11/100: 100%|██████████| 91/91 [00:22<00:00,  4.04it/s]


Epoch 11 -  Training Loss: 899947.7964 - Accuracy: 2.84%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.67it/s]


Epoch 11 - Validation Loss: 644110.8438 - Accuracy: 0.53%


Epoch 12/100: 100%|██████████| 91/91 [00:23<00:00,  3.90it/s]


Epoch 12 -  Training Loss: 799657.4176 - Accuracy: 3.08%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.64it/s]


Best model saved with accuracy: 0.65%
Epoch 12 - Validation Loss: 567316.4598 - Accuracy: 0.65%


Epoch 13/100: 100%|██████████| 91/91 [02:13<00:00,  1.47s/it]


Epoch 13 -  Training Loss: 712129.7799 - Accuracy: 3.33%


Validation: 100%|██████████| 7/7 [00:12<00:00,  1.78s/it]


Best model saved with accuracy: 0.78%
Epoch 13 - Validation Loss: 500623.9219 - Accuracy: 0.78%


Epoch 14/100: 100%|██████████| 91/91 [00:50<00:00,  1.80it/s]


Epoch 14 -  Training Loss: 631053.3015 - Accuracy: 3.55%


Validation: 100%|██████████| 7/7 [00:07<00:00,  1.07s/it]


Best model saved with accuracy: 1.03%
Epoch 14 - Validation Loss: 442891.6205 - Accuracy: 1.03%


Epoch 15/100: 100%|██████████| 91/91 [00:54<00:00,  1.68it/s]


Epoch 15 -  Training Loss: 566754.2766 - Accuracy: 3.91%


Validation: 100%|██████████| 7/7 [00:10<00:00,  1.48s/it]


Best model saved with accuracy: 1.22%
Epoch 15 - Validation Loss: 392934.6451 - Accuracy: 1.22%
Model saved at epoch 15


Epoch 16/100: 100%|██████████| 91/91 [00:21<00:00,  4.18it/s]


Epoch 16 -  Training Loss: 507672.6899 - Accuracy: 4.39%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.16it/s]


Best model saved with accuracy: 1.60%
Epoch 16 - Validation Loss: 347850.3058 - Accuracy: 1.60%


Epoch 17/100: 100%|██████████| 91/91 [00:21<00:00,  4.31it/s]


Epoch 17 -  Training Loss: 453822.7038 - Accuracy: 4.71%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.48it/s]


Best model saved with accuracy: 2.15%
Epoch 17 - Validation Loss: 308656.6384 - Accuracy: 2.15%


Epoch 18/100: 100%|██████████| 91/91 [00:26<00:00,  3.48it/s]


Epoch 18 -  Training Loss: 407307.3896 - Accuracy: 5.00%


Validation: 100%|██████████| 7/7 [00:24<00:00,  3.44s/it]


Best model saved with accuracy: 2.58%
Epoch 18 - Validation Loss: 274411.7165 - Accuracy: 2.58%


Epoch 19/100: 100%|██████████| 91/91 [03:21<00:00,  2.22s/it]


Epoch 19 -  Training Loss: 369053.8032 - Accuracy: 5.22%


Validation: 100%|██████████| 7/7 [00:18<00:00,  2.58s/it]


Best model saved with accuracy: 3.26%
Epoch 19 - Validation Loss: 244556.9308 - Accuracy: 3.26%


Epoch 20/100: 100%|██████████| 91/91 [04:59<00:00,  3.29s/it]


Epoch 20 -  Training Loss: 333615.3425 - Accuracy: 5.49%


Validation: 100%|██████████| 7/7 [00:20<00:00,  2.86s/it]


Best model saved with accuracy: 3.90%
Epoch 20 - Validation Loss: 220006.4554 - Accuracy: 3.90%
Model saved at epoch 20


Epoch 21/100: 100%|██████████| 91/91 [02:08<00:00,  1.42s/it]


Epoch 21 -  Training Loss: 302272.2259 - Accuracy: 5.72%


Validation: 100%|██████████| 7/7 [00:13<00:00,  1.94s/it]


Best model saved with accuracy: 4.42%
Epoch 21 - Validation Loss: 198299.4358 - Accuracy: 4.42%


Epoch 22/100: 100%|██████████| 91/91 [00:25<00:00,  3.59it/s]


Epoch 22 -  Training Loss: 276396.5978 - Accuracy: 6.01%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.93it/s]


Best model saved with accuracy: 4.92%
Epoch 22 - Validation Loss: 179461.9626 - Accuracy: 4.92%


Epoch 23/100: 100%|██████████| 91/91 [00:34<00:00,  2.63it/s]


Epoch 23 -  Training Loss: 251617.9451 - Accuracy: 6.20%


Validation: 100%|██████████| 7/7 [00:04<00:00,  1.72it/s]


Best model saved with accuracy: 5.25%
Epoch 23 - Validation Loss: 162910.1529 - Accuracy: 5.25%


Epoch 24/100: 100%|██████████| 91/91 [01:36<00:00,  1.06s/it]


Epoch 24 -  Training Loss: 229947.2176 - Accuracy: 6.44%


Validation: 100%|██████████| 7/7 [00:02<00:00,  3.08it/s]


Best model saved with accuracy: 5.50%
Epoch 24 - Validation Loss: 148347.9459 - Accuracy: 5.50%


Epoch 25/100: 100%|██████████| 91/91 [00:22<00:00,  3.98it/s]


Epoch 25 -  Training Loss: 211046.9323 - Accuracy: 6.68%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.12it/s]


Best model saved with accuracy: 5.68%
Epoch 25 - Validation Loss: 135607.1730 - Accuracy: 5.68%
Model saved at epoch 25


Epoch 26/100: 100%|██████████| 91/91 [00:21<00:00,  4.24it/s]


Epoch 26 -  Training Loss: 194012.3377 - Accuracy: 7.13%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.19it/s]


Best model saved with accuracy: 6.08%
Epoch 26 - Validation Loss: 124105.3923 - Accuracy: 6.08%


Epoch 27/100: 100%|██████████| 91/91 [00:21<00:00,  4.26it/s]


Epoch 27 -  Training Loss: 177998.1832 - Accuracy: 7.38%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.72it/s]


Best model saved with accuracy: 6.53%
Epoch 27 - Validation Loss: 113765.9202 - Accuracy: 6.53%


Epoch 28/100: 100%|██████████| 91/91 [00:21<00:00,  4.17it/s]


Epoch 28 -  Training Loss: 164094.8170 - Accuracy: 7.75%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.71it/s]


Best model saved with accuracy: 7.83%
Epoch 28 - Validation Loss: 104623.1763 - Accuracy: 7.83%


Epoch 29/100: 100%|██████████| 91/91 [00:22<00:00,  4.07it/s]


Epoch 29 -  Training Loss: 150949.7791 - Accuracy: 8.16%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.95it/s]


Best model saved with accuracy: 8.96%
Epoch 29 - Validation Loss: 96152.7679 - Accuracy: 8.96%


Epoch 30/100: 100%|██████████| 91/91 [00:24<00:00,  3.74it/s]


Epoch 30 -  Training Loss: 139035.6242 - Accuracy: 8.67%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.74it/s]


Best model saved with accuracy: 9.45%
Epoch 30 - Validation Loss: 88570.2227 - Accuracy: 9.45%
Model saved at epoch 30


Epoch 31/100: 100%|██████████| 91/91 [00:23<00:00,  3.93it/s]


Epoch 31 -  Training Loss: 128288.9431 - Accuracy: 8.97%


Validation: 100%|██████████| 7/7 [00:10<00:00,  1.55s/it]


Best model saved with accuracy: 10.00%
Epoch 31 - Validation Loss: 81623.8471 - Accuracy: 10.00%


Epoch 32/100: 100%|██████████| 91/91 [00:28<00:00,  3.18it/s]


Epoch 32 -  Training Loss: 118746.8200 - Accuracy: 9.43%


Validation: 100%|██████████| 7/7 [00:08<00:00,  1.21s/it]


Best model saved with accuracy: 10.51%
Epoch 32 - Validation Loss: 75025.5960 - Accuracy: 10.51%


Epoch 33/100: 100%|██████████| 91/91 [00:22<00:00,  4.11it/s]


Epoch 33 -  Training Loss: 109517.4936 - Accuracy: 9.87%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.72it/s]


Best model saved with accuracy: 11.13%
Epoch 33 - Validation Loss: 69214.8962 - Accuracy: 11.13%


Epoch 34/100: 100%|██████████| 91/91 [00:24<00:00,  3.68it/s]


Epoch 34 -  Training Loss: 101056.1255 - Accuracy: 10.43%


Validation: 100%|██████████| 7/7 [00:02<00:00,  2.91it/s]


Best model saved with accuracy: 11.62%
Epoch 34 - Validation Loss: 63855.0647 - Accuracy: 11.62%


Epoch 35/100: 100%|██████████| 91/91 [00:22<00:00,  4.12it/s]


Epoch 35 -  Training Loss: 93970.3716 - Accuracy: 10.89%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.09it/s]


Best model saved with accuracy: 12.45%
Epoch 35 - Validation Loss: 59213.6356 - Accuracy: 12.45%
Model saved at epoch 35


Epoch 36/100: 100%|██████████| 91/91 [00:22<00:00,  3.96it/s]


Epoch 36 -  Training Loss: 87455.6725 - Accuracy: 11.15%


Validation: 100%|██████████| 7/7 [00:03<00:00,  2.03it/s]


Best model saved with accuracy: 13.36%
Epoch 36 - Validation Loss: 55078.5938 - Accuracy: 13.36%


Epoch 37/100: 100%|██████████| 91/91 [00:51<00:00,  1.78it/s]


Epoch 37 -  Training Loss: 81633.4351 - Accuracy: 11.53%


Validation: 100%|██████████| 7/7 [00:05<00:00,  1.27it/s]


Best model saved with accuracy: 14.36%
Epoch 37 - Validation Loss: 51277.2958 - Accuracy: 14.36%


Epoch 38/100: 100%|██████████| 91/91 [00:23<00:00,  3.82it/s]


Epoch 38 -  Training Loss: 76281.4908 - Accuracy: 11.92%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.12it/s]


Best model saved with accuracy: 15.27%
Epoch 38 - Validation Loss: 47840.7807 - Accuracy: 15.27%


Epoch 39/100: 100%|██████████| 91/91 [00:28<00:00,  3.15it/s]


Epoch 39 -  Training Loss: 71864.8627 - Accuracy: 12.34%


Validation: 100%|██████████| 7/7 [00:03<00:00,  1.87it/s]


Best model saved with accuracy: 16.33%
Epoch 39 - Validation Loss: 44847.4442 - Accuracy: 16.33%


Epoch 40/100: 100%|██████████| 91/91 [00:26<00:00,  3.42it/s]


Epoch 40 -  Training Loss: 67838.5169 - Accuracy: 12.61%


Validation: 100%|██████████| 7/7 [00:12<00:00,  1.77s/it]


Best model saved with accuracy: 17.84%
Epoch 40 - Validation Loss: 42038.2983 - Accuracy: 17.84%
Model saved at epoch 40


Epoch 41/100: 100%|██████████| 91/91 [00:25<00:00,  3.54it/s]


Epoch 41 -  Training Loss: 63503.3583 - Accuracy: 12.83%


Validation: 100%|██████████| 7/7 [00:03<00:00,  2.03it/s]


Best model saved with accuracy: 18.91%
Epoch 41 - Validation Loss: 39393.8703 - Accuracy: 18.91%


Epoch 42/100: 100%|██████████| 91/91 [00:21<00:00,  4.17it/s]


Epoch 42 -  Training Loss: 59878.2555 - Accuracy: 13.26%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.14it/s]


Best model saved with accuracy: 19.44%
Epoch 42 - Validation Loss: 36848.2533 - Accuracy: 19.44%


Epoch 43/100: 100%|██████████| 91/91 [00:21<00:00,  4.17it/s]


Epoch 43 -  Training Loss: 56150.3250 - Accuracy: 13.52%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.74it/s]


Best model saved with accuracy: 20.21%
Epoch 43 - Validation Loss: 34652.5586 - Accuracy: 20.21%


Epoch 44/100: 100%|██████████| 91/91 [00:22<00:00,  4.06it/s]


Epoch 44 -  Training Loss: 52801.9799 - Accuracy: 13.87%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.24it/s]


Best model saved with accuracy: 20.99%
Epoch 44 - Validation Loss: 32464.1571 - Accuracy: 20.99%


Epoch 45/100: 100%|██████████| 91/91 [00:21<00:00,  4.17it/s]


Epoch 45 -  Training Loss: 49713.3931 - Accuracy: 14.11%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.49it/s]


Best model saved with accuracy: 21.58%
Epoch 45 - Validation Loss: 30448.0304 - Accuracy: 21.58%
Model saved at epoch 45


Epoch 46/100: 100%|██████████| 91/91 [00:27<00:00,  3.31it/s]


Epoch 46 -  Training Loss: 46450.8865 - Accuracy: 14.42%


Validation: 100%|██████████| 7/7 [00:08<00:00,  1.21s/it]


Best model saved with accuracy: 22.24%
Epoch 46 - Validation Loss: 28488.9760 - Accuracy: 22.24%


Epoch 47/100: 100%|██████████| 91/91 [00:22<00:00,  4.05it/s]


Epoch 47 -  Training Loss: 43765.7187 - Accuracy: 14.54%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.52it/s]


Best model saved with accuracy: 23.35%
Epoch 47 - Validation Loss: 26627.4392 - Accuracy: 23.35%


Epoch 48/100: 100%|██████████| 91/91 [00:22<00:00,  4.12it/s]


Epoch 48 -  Training Loss: 40406.0064 - Accuracy: 14.94%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.83it/s]


Best model saved with accuracy: 23.62%
Epoch 48 - Validation Loss: 24969.7016 - Accuracy: 23.62%


Epoch 49/100: 100%|██████████| 91/91 [00:23<00:00,  3.95it/s]


Epoch 49 -  Training Loss: 37997.7833 - Accuracy: 15.37%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.15it/s]


Epoch 49 - Validation Loss: 23398.5230 - Accuracy: 23.35%


Epoch 50/100: 100%|██████████| 91/91 [00:22<00:00,  4.11it/s]


Epoch 50 -  Training Loss: 35258.2553 - Accuracy: 15.64%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.21it/s]


Epoch 50 - Validation Loss: 22013.6455 - Accuracy: 23.54%
Model saved at epoch 50


Epoch 51/100: 100%|██████████| 91/91 [00:24<00:00,  3.72it/s]


Epoch 51 -  Training Loss: 33229.9652 - Accuracy: 16.00%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.13it/s]


Epoch 51 - Validation Loss: 20746.7817 - Accuracy: 23.58%


Epoch 52/100: 100%|██████████| 91/91 [00:25<00:00,  3.57it/s]


Epoch 52 -  Training Loss: 30818.4450 - Accuracy: 16.32%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.38it/s]


Best model saved with accuracy: 23.78%
Epoch 52 - Validation Loss: 19513.3111 - Accuracy: 23.78%


Epoch 53/100: 100%|██████████| 91/91 [00:22<00:00,  3.98it/s]


Epoch 53 -  Training Loss: 28939.1825 - Accuracy: 16.42%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.93it/s]


Epoch 53 - Validation Loss: 18376.6131 - Accuracy: 23.48%


Epoch 54/100: 100%|██████████| 91/91 [00:23<00:00,  3.90it/s]


Epoch 54 -  Training Loss: 26963.4049 - Accuracy: 16.80%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.59it/s]


Epoch 54 - Validation Loss: 17342.9203 - Accuracy: 23.35%


Epoch 55/100: 100%|██████████| 91/91 [00:27<00:00,  3.36it/s]


Epoch 55 -  Training Loss: 25546.6240 - Accuracy: 17.01%


Validation: 100%|██████████| 7/7 [00:02<00:00,  3.38it/s]


Epoch 55 - Validation Loss: 16365.3936 - Accuracy: 23.42%
Model saved at epoch 55


Epoch 56/100: 100%|██████████| 91/91 [00:28<00:00,  3.15it/s]


Epoch 56 -  Training Loss: 24115.3181 - Accuracy: 17.02%


Validation: 100%|██████████| 7/7 [00:09<00:00,  1.40s/it]


Epoch 56 - Validation Loss: 15489.9965 - Accuracy: 23.02%


Epoch 57/100: 100%|██████████| 91/91 [00:23<00:00,  3.92it/s]


Epoch 57 -  Training Loss: 22801.5002 - Accuracy: 17.02%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.78it/s]


Epoch 57 - Validation Loss: 14736.7744 - Accuracy: 23.09%


Epoch 58/100: 100%|██████████| 91/91 [00:22<00:00,  4.06it/s]


Epoch 58 -  Training Loss: 21434.3896 - Accuracy: 16.81%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.49it/s]


Epoch 58 - Validation Loss: 14042.9955 - Accuracy: 23.65%


Epoch 59/100: 100%|██████████| 91/91 [00:23<00:00,  3.84it/s]


Epoch 59 -  Training Loss: 20254.2935 - Accuracy: 17.05%


Validation: 100%|██████████| 7/7 [00:02<00:00,  2.73it/s]


Best model saved with accuracy: 23.81%
Epoch 59 - Validation Loss: 13446.8060 - Accuracy: 23.81%


Epoch 60/100: 100%|██████████| 91/91 [00:24<00:00,  3.65it/s]


Epoch 60 -  Training Loss: 19303.1019 - Accuracy: 17.01%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.21it/s]


Epoch 60 - Validation Loss: 12907.2492 - Accuracy: 23.71%
Model saved at epoch 60


Epoch 61/100: 100%|██████████| 91/91 [00:22<00:00,  4.11it/s]


Epoch 61 -  Training Loss: 18691.0921 - Accuracy: 16.79%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.18it/s]


Epoch 61 - Validation Loss: 12404.9662 - Accuracy: 23.37%


Epoch 62/100: 100%|██████████| 91/91 [00:22<00:00,  3.97it/s]


Epoch 62 -  Training Loss: 17625.1891 - Accuracy: 16.89%


Validation: 100%|██████████| 7/7 [00:05<00:00,  1.22it/s]


Epoch 62 - Validation Loss: 11890.0253 - Accuracy: 22.46%


Epoch 63/100: 100%|██████████| 91/91 [00:24<00:00,  3.73it/s]


Epoch 63 -  Training Loss: 16886.9633 - Accuracy: 16.73%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.99it/s]


Epoch 63 - Validation Loss: 11418.3509 - Accuracy: 22.10%


Epoch 64/100: 100%|██████████| 91/91 [00:23<00:00,  3.81it/s]


Epoch 64 -  Training Loss: 16220.2910 - Accuracy: 16.58%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.21it/s]


Epoch 64 - Validation Loss: 10970.4707 - Accuracy: 21.81%


Epoch 65/100: 100%|██████████| 91/91 [00:24<00:00,  3.67it/s]


Epoch 65 -  Training Loss: 15562.3009 - Accuracy: 16.30%


Validation: 100%|██████████| 7/7 [00:05<00:00,  1.39it/s]


Epoch 65 - Validation Loss: 10534.5840 - Accuracy: 21.24%
Model saved at epoch 65


Epoch 66/100: 100%|██████████| 91/91 [00:22<00:00,  4.05it/s]


Epoch 66 -  Training Loss: 14926.3596 - Accuracy: 16.19%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.84it/s]


Epoch 66 - Validation Loss: 10130.9510 - Accuracy: 20.13%


Epoch 67/100: 100%|██████████| 91/91 [00:23<00:00,  3.91it/s]


Epoch 67 -  Training Loss: 14387.3044 - Accuracy: 16.21%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.17it/s]


Epoch 67 - Validation Loss: 9725.1774 - Accuracy: 19.50%


Epoch 68/100: 100%|██████████| 91/91 [00:23<00:00,  3.80it/s]


Epoch 68 -  Training Loss: 13764.2097 - Accuracy: 15.77%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.19it/s]


Epoch 68 - Validation Loss: 9350.1759 - Accuracy: 19.22%


Epoch 69/100: 100%|██████████| 91/91 [00:23<00:00,  3.87it/s]


Epoch 69 -  Training Loss: 13319.1977 - Accuracy: 15.63%


Validation: 100%|██████████| 7/7 [00:02<00:00,  3.47it/s]


Epoch 69 - Validation Loss: 8983.4150 - Accuracy: 18.92%


Epoch 70/100: 100%|██████████| 91/91 [00:23<00:00,  3.87it/s]


Epoch 70 -  Training Loss: 12691.0433 - Accuracy: 15.67%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.47it/s]


Epoch 70 - Validation Loss: 8635.2387 - Accuracy: 18.62%
Model saved at epoch 70


Epoch 71/100: 100%|██████████| 91/91 [00:26<00:00,  3.44it/s]


Epoch 71 -  Training Loss: 12332.8027 - Accuracy: 15.35%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.99it/s]


Epoch 71 - Validation Loss: 8295.0095 - Accuracy: 18.33%


Epoch 72/100: 100%|██████████| 91/91 [00:24<00:00,  3.70it/s]


Epoch 72 -  Training Loss: 11784.2643 - Accuracy: 15.31%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.14it/s]


Epoch 72 - Validation Loss: 7947.0585 - Accuracy: 18.06%


Epoch 73/100: 100%|██████████| 91/91 [00:23<00:00,  3.87it/s]


Epoch 73 -  Training Loss: 11410.1989 - Accuracy: 15.09%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.93it/s]


Epoch 73 - Validation Loss: 7624.3534 - Accuracy: 18.42%


Epoch 74/100: 100%|██████████| 91/91 [00:26<00:00,  3.45it/s]


Epoch 74 -  Training Loss: 10922.1388 - Accuracy: 15.01%


Validation: 100%|██████████| 7/7 [00:01<00:00,  3.97it/s]


Epoch 74 - Validation Loss: 7280.8838 - Accuracy: 18.24%


Epoch 75/100: 100%|██████████| 91/91 [02:23<00:00,  1.57s/it]


Epoch 75 -  Training Loss: 10474.2928 - Accuracy: 15.11%


Validation: 100%|██████████| 7/7 [00:17<00:00,  2.47s/it]


Epoch 75 - Validation Loss: 6962.8897 - Accuracy: 17.37%
Model saved at epoch 75


Epoch 76/100: 100%|██████████| 91/91 [02:33<00:00,  1.69s/it]


Epoch 76 -  Training Loss: 10147.1622 - Accuracy: 14.86%


Validation: 100%|██████████| 7/7 [00:26<00:00,  3.78s/it]


Epoch 76 - Validation Loss: 6673.8905 - Accuracy: 16.99%


Epoch 77/100: 100%|██████████| 91/91 [05:08<00:00,  3.40s/it]


Epoch 77 -  Training Loss: 9684.4089 - Accuracy: 14.51%


Validation: 100%|██████████| 7/7 [00:20<00:00,  2.99s/it]


Epoch 77 - Validation Loss: 6391.0182 - Accuracy: 16.69%


Epoch 78/100: 100%|██████████| 91/91 [03:27<00:00,  2.28s/it]


Epoch 78 -  Training Loss: 9409.8214 - Accuracy: 14.47%


Validation: 100%|██████████| 7/7 [00:21<00:00,  3.10s/it]


Epoch 78 - Validation Loss: 6126.7661 - Accuracy: 16.40%


Epoch 79/100: 100%|██████████| 91/91 [00:57<00:00,  1.59it/s]


Epoch 79 -  Training Loss: 9082.0018 - Accuracy: 14.19%


Validation: 100%|██████████| 7/7 [00:19<00:00,  2.72s/it]


Epoch 79 - Validation Loss: 5867.5908 - Accuracy: 16.59%


Epoch 80/100: 100%|██████████| 91/91 [00:31<00:00,  2.92it/s]


Epoch 80 -  Training Loss: 8772.3573 - Accuracy: 14.29%


Validation: 100%|██████████| 7/7 [00:04<00:00,  1.74it/s]


Epoch 80 - Validation Loss: 5634.1934 - Accuracy: 16.61%
Model saved at epoch 80


Epoch 81/100: 100%|██████████| 91/91 [01:41<00:00,  1.12s/it]


Epoch 81 -  Training Loss: 8400.4420 - Accuracy: 14.12%


Validation: 100%|██████████| 7/7 [00:21<00:00,  3.07s/it]


Epoch 81 - Validation Loss: 5378.2861 - Accuracy: 16.47%


Epoch 82/100: 100%|██████████| 91/91 [03:01<00:00,  2.00s/it]


Epoch 82 -  Training Loss: 8091.0233 - Accuracy: 14.01%


Validation: 100%|██████████| 7/7 [00:11<00:00,  1.65s/it]


Epoch 82 - Validation Loss: 5146.8485 - Accuracy: 16.38%


Epoch 83/100: 100%|██████████| 91/91 [02:10<00:00,  1.43s/it]


Epoch 83 -  Training Loss: 7862.1177 - Accuracy: 14.08%


Validation: 100%|██████████| 7/7 [00:10<00:00,  1.49s/it]


Epoch 83 - Validation Loss: 4911.4368 - Accuracy: 16.47%


Epoch 84/100: 100%|██████████| 91/91 [00:25<00:00,  3.52it/s]


Epoch 84 -  Training Loss: 7631.9520 - Accuracy: 13.84%


Validation: 100%|██████████| 7/7 [00:02<00:00,  3.06it/s]


Epoch 84 - Validation Loss: 4705.4619 - Accuracy: 16.41%


Epoch 85/100: 100%|██████████| 91/91 [00:26<00:00,  3.49it/s]


Epoch 85 -  Training Loss: 7315.3037 - Accuracy: 13.59%


Validation: 100%|██████████| 7/7 [00:02<00:00,  3.36it/s]


Epoch 85 - Validation Loss: 4488.6364 - Accuracy: 16.27%
Model saved at epoch 85


Epoch 86/100: 100%|██████████| 91/91 [01:40<00:00,  1.11s/it]


Epoch 86 -  Training Loss: 6950.3338 - Accuracy: 13.68%


Validation: 100%|██████████| 7/7 [00:08<00:00,  1.23s/it]


Epoch 86 - Validation Loss: 4287.0873 - Accuracy: 16.16%


Epoch 87/100: 100%|██████████| 91/91 [00:33<00:00,  2.71it/s]


Epoch 87 -  Training Loss: 6709.0540 - Accuracy: 13.52%


Validation: 100%|██████████| 7/7 [00:02<00:00,  3.38it/s]


Epoch 87 - Validation Loss: 4084.4079 - Accuracy: 15.71%


Epoch 88/100: 100%|██████████| 91/91 [01:20<00:00,  1.12it/s]


Epoch 88 -  Training Loss: 6492.0300 - Accuracy: 13.45%


Validation: 100%|██████████| 7/7 [00:17<00:00,  2.49s/it]


Epoch 88 - Validation Loss: 3897.9497 - Accuracy: 15.45%


Epoch 89/100: 100%|██████████| 91/91 [05:21<00:00,  3.53s/it]


Epoch 89 -  Training Loss: 6296.8828 - Accuracy: 13.36%


Validation: 100%|██████████| 7/7 [00:22<00:00,  3.16s/it]


Epoch 89 - Validation Loss: 3718.2849 - Accuracy: 15.15%


Epoch 90/100: 100%|██████████| 91/91 [02:47<00:00,  1.84s/it]


Epoch 90 -  Training Loss: 6047.8163 - Accuracy: 13.29%


Validation: 100%|██████████| 7/7 [00:21<00:00,  3.13s/it]


Epoch 90 - Validation Loss: 3557.7406 - Accuracy: 14.92%
Model saved at epoch 90


Epoch 91/100: 100%|██████████| 91/91 [00:44<00:00,  2.05it/s]


Epoch 91 -  Training Loss: 5784.2334 - Accuracy: 13.20%


Validation: 100%|██████████| 7/7 [00:02<00:00,  3.32it/s]


Epoch 91 - Validation Loss: 3394.2185 - Accuracy: 14.48%


Epoch 92/100: 100%|██████████| 91/91 [01:49<00:00,  1.21s/it]


Epoch 92 -  Training Loss: 5597.9015 - Accuracy: 13.01%


Validation: 100%|██████████| 7/7 [00:08<00:00,  1.24s/it]


Epoch 92 - Validation Loss: 3239.9464 - Accuracy: 14.24%


Epoch 93/100: 100%|██████████| 91/91 [00:23<00:00,  3.89it/s]


Epoch 93 -  Training Loss: 5417.0885 - Accuracy: 12.96%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.07it/s]


Epoch 93 - Validation Loss: 3090.4091 - Accuracy: 13.95%


Epoch 94/100: 100%|██████████| 91/91 [00:24<00:00,  3.75it/s]


Epoch 94 -  Training Loss: 5173.6765 - Accuracy: 12.90%


Validation: 100%|██████████| 7/7 [00:01<00:00,  4.19it/s]


Epoch 94 - Validation Loss: 2941.9427 - Accuracy: 13.60%


Epoch 95/100: 100%|██████████| 91/91 [04:37<00:00,  3.05s/it]


Epoch 95 -  Training Loss: 5077.6497 - Accuracy: 12.71%


Validation: 100%|██████████| 7/7 [00:23<00:00,  3.33s/it]


Epoch 95 - Validation Loss: 2792.6519 - Accuracy: 13.62%
Model saved at epoch 95


Epoch 96/100: 100%|██████████| 91/91 [09:51<00:00,  6.51s/it]


Epoch 96 -  Training Loss: 4799.4850 - Accuracy: 12.64%


Validation: 100%|██████████| 7/7 [00:29<00:00,  4.17s/it]


Epoch 96 - Validation Loss: 2650.3584 - Accuracy: 13.70%


Epoch 97/100: 100%|██████████| 91/91 [08:47<00:00,  5.80s/it]


Epoch 97 -  Training Loss: 4656.4807 - Accuracy: 12.52%


Validation: 100%|██████████| 7/7 [00:25<00:00,  3.67s/it]


Epoch 97 - Validation Loss: 2510.4855 - Accuracy: 13.43%


Epoch 98/100: 100%|██████████| 91/91 [04:51<00:00,  3.20s/it]


Epoch 98 -  Training Loss: 4453.3357 - Accuracy: 12.35%


Validation: 100%|██████████| 7/7 [00:26<00:00,  3.82s/it]


Epoch 98 - Validation Loss: 2359.2720 - Accuracy: 12.35%


Epoch 99/100: 100%|██████████| 91/91 [02:17<00:00,  1.51s/it]


Epoch 99 -  Training Loss: 4234.3090 - Accuracy: 12.39%


Validation: 100%|██████████| 7/7 [00:12<00:00,  1.76s/it]


Epoch 99 - Validation Loss: 2227.8290 - Accuracy: 11.90%


Epoch 100/100: 100%|██████████| 91/91 [02:18<00:00,  1.52s/it]


Epoch 100 -  Training Loss: 4094.6617 - Accuracy: 12.27%


Validation: 100%|██████████| 7/7 [00:20<00:00,  2.94s/it]

Epoch 100 - Validation Loss: 2097.0962 - Accuracy: 11.88%
Model saved at epoch 100



