# In [None]:  (Jupyter cell marker left over from the notebook export)
import os
import csv
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
from mamba_ssm import Mamba2

class Mamba2Classifier(nn.Module):
    """Single-view classifier: a Mamba2 layer, a mean over dim 1, then a linear head.

    Args:
        input_size: Passed to ``Mamba2`` as ``d_model`` and used as the input
            width of the linear head.
        num_classes: Number of logits produced by the linear head.
        d_state: Forwarded to ``Mamba2`` as ``d_state``.
        d_conv: Forwarded to ``Mamba2`` as ``d_conv``.
        expand: Forwarded to ``Mamba2`` as ``expand``.
    """

    def __init__(self, input_size, num_classes, d_state=64, d_conv=4, expand=2):
        super().__init__()
        mamba_config = {
            'd_model': input_size,
            'd_state': d_state,
            'd_conv': d_conv,
            'expand': expand,
        }
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.mamba = Mamba2(**mamba_config)
        self.fc = nn.Linear(input_size, num_classes)

    def forward(self, x):
        """Return class logits for ``x``.

        Assumes ``x`` is (batch, seq_len, input_size) -- TODO confirm against
        the Mamba2 layer's expected layout. The mean is taken over dim 1 and
        includes every position, so any padding added upstream is averaged in.
        """
        pooled = self.mamba(x).mean(dim=1)
        return self.fc(pooled)

class EnsembleMultiViewMamba2Model(nn.Module):
    """Weighted ensemble of one ``Mamba2Classifier`` per camera view.

    Each view has its own classifier. Their logits are combined with learnable
    weights that are passed through a softmax so they sum to one.

    Args:
        input_size: Feature width forwarded to every per-view classifier.
        num_classes: Number of logits per classifier (and of the ensemble).
        d_state: Forwarded to every per-view classifier.
        d_conv: Forwarded to every per-view classifier.
        expand: Forwarded to every per-view classifier.
        num_views: Number of views / per-view classifiers. Defaults to 3,
            which keeps the parameter layout (and therefore checkpoints)
            identical to the previous hard-coded version.

    Raises:
        ValueError: If ``num_views`` is smaller than 1.
    """

    def __init__(self, input_size, num_classes, d_state=64, d_conv=4, expand=2, num_views=3):
        super().__init__()
        if num_views < 1:
            raise ValueError(f"num_views must be >= 1, got {num_views}")
        self.view_models = nn.ModuleList([
            Mamba2Classifier(input_size, num_classes, d_state, d_conv, expand)
            for _ in range(num_views)
        ])
        # Uniform initial weights; softmax is applied in forward().
        self.ensemble_weights = nn.Parameter(torch.ones(num_views) / num_views)

    def forward(self, x):
        """Combine the per-view logits into a single logit tensor.

        Args:
            x: Iterable with exactly one input per view, in the same order as
                ``self.view_models``.

        Returns:
            The softmax-weighted sum of the per-view classifier outputs.

        Raises:
            ValueError: If the number of inputs differs from the number of
                view models. Previously a shorter input was silently accepted
                and combined with weights that no longer summed to one.
        """
        views = list(x)
        if len(views) != len(self.view_models):
            raise ValueError(
                f"Expected {len(self.view_models)} views, got {len(views)}"
            )

        view_outputs = [
            view_model(view_x)
            for view_model, view_x in zip(self.view_models, views)
        ]
        ensemble_weights = F.softmax(self.ensemble_weights, dim=0)
        return sum(w * out for w, out in zip(ensemble_weights, view_outputs))

class MultiViewSequenceDataset(Dataset):
    """Labelled multi-view sequence dataset stored as ``.npy`` files.

    Files are read from ``root_dir/<view>/<class>/<file>.npy`` for every view
    in ``self.views``. The class list and the file list are discovered from
    the first view only; the same file name is then opened under the other
    views, so a missing counterpart raises when the item is fetched.

    Args:
        root_dir: Directory that contains one sub-directory per view.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.views = ['frontal_view', 'left_side_mirror_view', 'right_side_mirror_view']
        reference_dir = os.path.join(root_dir, self.views[0])
        # Only directories are classes. A stray file here would otherwise
        # shift the label indices and then fail in os.listdir below.
        self.classes = sorted(
            entry for entry in os.listdir(reference_dir)
            if os.path.isdir(os.path.join(reference_dir, entry))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.data = []

        for cls in self.classes:
            # os.listdir order is arbitrary; sort so indices are reproducible.
            for file in sorted(os.listdir(os.path.join(reference_dir, cls))):
                if file.endswith('.npy'):
                    self.data.append((file, self.class_to_idx[cls]))

    def __len__(self):
        """Return the number of (file, label) samples."""
        return len(self.data)

    @staticmethod
    def _load_normalized(file_path):
        """Load one array and standardise it with its own global mean and std."""
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # file is untrusted; kept because the stored format is not visible here.
        sequence = np.load(file_path, allow_pickle=True)
        mean = np.mean(sequence)
        std = np.std(sequence)
        # A constant sequence has std == 0; dividing by it would fill the
        # sample with NaN/inf and poison the loss. Fall back to centring only.
        if not std > 0:
            std = 1.0
        sequence = (sequence - mean) / std
        # NOTE(review): squeeze() drops every size-1 axis, including a
        # length-1 sequence axis -- confirm the stored array shapes.
        return torch.tensor(sequence, dtype=torch.float32).squeeze()

    def __getitem__(self, idx):
        """Return ``(sequences, label)`` for sample ``idx``.

        ``sequences`` is a list with one float32 tensor per view, in the order
        of ``self.views``; ``label`` is a scalar ``torch.long`` tensor.
        """
        file_name, label = self.data[idx]
        class_name = self.classes[label]
        sequences = [
            self._load_normalized(
                os.path.join(self.root_dir, view, class_name, file_name)
            )
            for view in self.views
        ]
        return sequences, torch.tensor(label, dtype=torch.long)

class MultiViewTestDataset(Dataset):
    """Unlabelled multi-view sequence dataset stored as ``.npy`` files.

    Files are read from ``root_dir/<view>/<file>.npy`` for every view in
    ``self.views``. The file list comes from the first view only; the same
    file name is then opened under the other views.

    Args:
        root_dir: Directory that contains one sub-directory per view.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        # NOTE(review): these folder names differ from the ones used by
        # MultiViewSequenceDataset ('left_side_mirror_view', ...) -- confirm
        # that the test features really use this layout.
        self.views = ['frontal_view', 'left_view', 'right_view']
        reference_dir = os.path.join(root_dir, self.views[0])
        # Sorted so that the order of predictions is reproducible across runs
        # and file systems (os.listdir order is arbitrary).
        self.file_names = sorted(
            f for f in os.listdir(reference_dir) if f.endswith('.npy')
        )

    def __len__(self):
        """Return the number of test samples."""
        return len(self.file_names)

    @staticmethod
    def _load_normalized(file_path):
        """Load one array and standardise it with its own global mean and std."""
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # file is untrusted; kept because the stored format is not visible here.
        sequence = np.load(file_path, allow_pickle=True)
        mean = np.mean(sequence)
        std = np.std(sequence)
        # A constant sequence has std == 0; dividing by it would fill the
        # sample with NaN/inf. Fall back to centring only.
        if not std > 0:
            std = 1.0
        sequence = (sequence - mean) / std
        # NOTE(review): squeeze() drops every size-1 axis, including a
        # length-1 sequence axis -- confirm the stored array shapes.
        return torch.tensor(sequence, dtype=torch.float32).squeeze()

    def __getitem__(self, idx):
        """Return ``(sequences, file_name)`` for sample ``idx``.

        ``sequences`` is a list with one float32 tensor per view, in the order
        of ``self.views``.
        """
        file_name = self.file_names[idx]
        sequences = [
            self._load_normalized(os.path.join(self.root_dir, view, file_name))
            for view in self.views
        ]
        return sequences, file_name

def collate_fn(batch):
    """Collate ``(views, label)`` samples into per-view padded batches.

    Args:
        batch: Non-empty list of ``(views, label)`` pairs, where ``views`` is
            a list of tensors (one per view) and ``label`` is either a tensor
            or any other object (e.g. a file name).

    Returns:
        ``(padded_views, labels)``. ``padded_views`` holds one tensor per
        view, zero-padded along the first sample dimension with the batch
        dimension first. ``labels`` is a stacked tensor when the labels are
        tensors, otherwise the tuple of labels unchanged.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    sequences, labels = zip(*batch)
    # Derive the view count from the data instead of assuming three views.
    num_views = len(sequences[0])
    padded_sequences = [
        nn.utils.rnn.pad_sequence(
            [sample[view_idx] for sample in sequences], batch_first=True
        )
        for view_idx in range(num_views)
    ]

    if isinstance(labels[0], torch.Tensor):
        return padded_sequences, torch.stack(labels)
    return padded_sequences, labels

def calculate_metrics(loader, model, device):
    """Evaluate ``model`` on ``loader`` and return classification metrics.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        loader: Iterable yielding ``(views, labels)`` batches, where ``views``
            is a list of tensors and ``labels`` is a tensor.
        model: Callable taking the list of per-view tensors and returning a
            score tensor whose dim 1 indexes the classes.
        device: Device the inputs and labels are moved to.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``; the last three use
        ``average='weighted'``.
    """
    model.eval()
    truths = []
    guesses = []

    with torch.no_grad():
        for views, labels in loader:
            views_on_device = [view.to(device) for view in views]
            labels = labels.to(device)
            logits = model(views_on_device)
            # max over dim 1 returns (values, indices); keep the indices.
            batch_guesses = logits.max(1)[1]
            truths.extend(labels.cpu().numpy())
            guesses.extend(batch_guesses.cpu().numpy())

    accuracy = accuracy_score(truths, guesses)
    precision = precision_score(truths, guesses, average='weighted')
    recall = recall_score(truths, guesses, average='weighted')
    f1 = f1_score(truths, guesses, average='weighted')
    return (accuracy, precision, recall, f1)

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device,
                checkpoint_path='best_multi_view_mamba2_model.pth'):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    After every epoch the metrics are recomputed on both loaders with
    ``calculate_metrics``, the scheduler is stepped once, and the model's
    ``state_dict`` is written to ``checkpoint_path`` if the validation
    accuracy improved.

    Args:
        model: Module to train; moved to ``device`` in place.
        train_loader: Loader yielding ``(views, targets)`` training batches.
        val_loader: Loader used for validation metrics.
        criterion: Loss taking ``(scores, targets)``.
        optimizer: Optimizer over the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.
        device: Device used for the model and the batches.
        checkpoint_path: Where the best ``state_dict`` is saved. The default
            is the path that was previously hard-coded.

    Returns:
        The best validation accuracy seen, or ``None`` if no epoch ran.
    """
    model.to(device)
    # None (rather than 0) guarantees that the first evaluated epoch always
    # writes a checkpoint, even when validation accuracy is exactly 0. Callers
    # that reload the checkpoint afterwards would otherwise fail or silently
    # pick up a stale file from an earlier run.
    best_accuracy = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for data, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            data = [d.to(device) for d in data]
            targets = targets.to(device)

            scores = model(data)
            loss = criterion(scores, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError for an empty training loader.
        avg_train_loss = train_loss / max(len(train_loader), 1)
        # calculate_metrics puts the model in eval mode; model.train() at the
        # top of the loop restores training mode for the next epoch.
        train_metrics = calculate_metrics(train_loader, model, device)
        val_metrics = calculate_metrics(val_loader, model, device)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"Train Loss: {avg_train_loss:.4f}")
        print(f"Train: Accuracy: {train_metrics[0]:.4f}, Precision: {train_metrics[1]:.4f}, Recall: {train_metrics[2]:.4f}, F1: {train_metrics[3]:.4f}")
        print(f"Val: Accuracy: {val_metrics[0]:.4f}, Precision: {val_metrics[1]:.4f}, Recall: {val_metrics[2]:.4f}, F1: {val_metrics[3]:.4f}")

        scheduler.step()

        val_accuracy = val_metrics[0]
        if best_accuracy is None or val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)
            print(f"New best model saved with accuracy: {best_accuracy:.4f}")

    return best_accuracy

def evaluate_model(model, data_loader, device):
    """Run ``model`` over ``data_loader`` and collect the predicted class per file.

    The model is switched to eval mode; it is not moved to ``device`` here,
    only the input tensors are.

    Args:
        model: Callable taking a list of per-view tensors and returning a
            score tensor whose dim 1 indexes the classes.
        data_loader: Iterable yielding ``(views, file_names)`` batches.
        device: Device the input tensors are moved to.

    Returns:
        ``(predictions, file_names)``: two parallel lists, one entry per
        sample, in loader order.
    """
    model.eval()
    predicted_classes, seen_files = [], []

    with torch.no_grad():
        for views, batch_files in tqdm(data_loader, desc="Evaluating"):
            logits = model([view.to(device) for view in views])
            # torch.max over dim 1 returns (values, indices); keep the indices.
            batch_classes = torch.max(logits, 1)[1]
            predicted_classes.extend(batch_classes.cpu().numpy())
            seen_files.extend(batch_files)

    return predicted_classes, seen_files

def save_results_to_csv(file_names, predictions, class_names, output_file):
    """Write one-hot encoded predictions to ``output_file`` as CSV.

    The header is ``frontal_view_video_name`` followed by ``class_names``.
    Each following row holds the file name without its last extension and a
    1 in the column whose index equals the prediction, 0 elsewhere.

    Args:
        file_names: File names, parallel to ``predictions``.
        predictions: Predicted class indices.
        class_names: List of column names, one per class.
        output_file: Path of the CSV file to create or overwrite.
    """
    with open(output_file, 'w', newline='') as csvfile:
        out = csv.writer(csvfile)
        out.writerow(['frontal_view_video_name'] + class_names)

        column_count = len(class_names)
        out.writerows(
            [os.path.splitext(name)[0]]
            + [1 if column == predicted else 0 for column in range(column_count)]
            for name, predicted in zip(file_names, predictions)
        )

def main():
    """Train the multi-view ensemble, reload the best checkpoint and export test predictions.

    Steps: build the training/validation datasets and loaders, train with
    ``train_model``, reload the best checkpoint it saved, predict on the test
    set, write the one-hot predictions to CSV and print a short summary.
    """
    from collections import Counter

    # Model hyper-parameters.
    input_size = 512
    num_classes = 6
    d_state = 32
    d_conv = 4
    expand = 6

    # Optimisation hyper-parameters.
    batch_size = 16
    learning_rate = 0.001
    num_epochs = 20

    # Paths are named once so the messages below cannot drift from the files
    # that are actually read and written.
    checkpoint_path = 'best_multi_view_mamba2_model.pth'  # written by train_model
    output_csv = '/workspace/multiview_test_results.csv'

    train_dataset = MultiViewSequenceDataset('/workspace/data/VGG16_Training_Features')
    val_dataset = MultiViewSequenceDataset('/workspace/data/VGG16_val_features')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    model = EnsembleMultiViewMamba2Model(input_size, num_classes, d_state, d_conv, expand)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.8)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device)

    # map_location keeps the reload working even if the checkpoint was saved
    # from a different device than the one available now.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    test_dataset = MultiViewTestDataset('/workspace/data/VGG16_test_features')
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    predictions, file_names = evaluate_model(model, test_loader, device)

    # NOTE(review): this order must match the label indices used in training
    # (MultiViewSequenceDataset sorts the class folder names) -- confirm the
    # folder names sort into this same order.
    class_names = ['Left Lane Change', 'Left Turn', 'Right Lane Change', 'Right Turn', 'Slow-Stop', 'Straight']
    save_results_to_csv(file_names, predictions, class_names, output_csv)
    # The message previously named a different file than the one written.
    print(f"Testing completed. Results saved to {output_csv}")

    print(f"Total predictions: {len(predictions)}")
    print(f"Unique classes predicted: {set(predictions)}")
    print(f"Class distribution: {dict(Counter(predictions))}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()