In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import pickle
import matplotlib.pyplot as plt


# Load the per-sample feature table and the two annotation dictionaries.
# NOTE: pickle can execute arbitrary code while loading; only open trusted files.
with open('../features/Features_thina', 'rb') as f:
    df_features,start_move,stop_turn = pickle.load(f)

print(f"Number of features: {len(df_features['id'].unique())}")
# Keep only rows whose normalized start-of-movement annotation is present and above the threshold
df_features = df_features[(df_features['normalized_start_movement'].notna()) & (df_features['normalized_start_movement'] > 0.0005)]
print(f"Number of features after filtering: {len(df_features['id'].unique())}")
#df_features = df_features[df_features['normalized_stop_turn'].notnull()]
print(f"Number of start movements before filtering: {len(start_move)}")
# Test for missing values first so that a None entry is dropped instead of
# raising TypeError in the numeric comparison.
start_move = {k: v for k, v in start_move.items() if not pd.isna(v) and v > 0.0005}
print(f"Number of start movements after filtering: {len(start_move)}")
# (label name, {id: label value}) pairs; one model is trained per entry
dict_list = [('start_move', start_move)]

# Initialize an empty DataFrame for the result
df_annotated = pd.DataFrame()
# Iterate over the list of tuples to create and merge each DataFrame
for feature, data in dict_list:
    # Convert dictionary to DataFrame and reset index
    df_temp = pd.DataFrame.from_dict(data, orient='index', columns=[feature]).reset_index()
    df_temp.columns = ['id', feature]  # Rename columns
    # Merge the temporary DataFrame with the main annotated DataFrame
    if df_annotated.empty:
        df_annotated = df_temp
    else:
        df_annotated = df_annotated.merge(df_temp, on='id', how='outer')

# Columns fed to the model, in this order
features = ['rotation_x', 'rotation_y','rotation_z','normalized_time','velocity','angle','distance_to_end']
# Per-id time range, used later to map normalized predictions back to 'adjusted_time' units
max_time = df_features.groupby('id')['adjusted_time'].max()
min_time = df_features.groupby('id')['adjusted_time'].min()

X = df_features.drop(['type_trajectory'], axis=1)
grouped = X.groupby('id')
# One (timesteps, len(features)) array per id; `ids` follows the same group order
trajectories = [group[features].values for _, group in grouped]
ids = list(grouped.groups.keys())
# Encode the trajectory type of each id as an integer index for the embedding layer
trajectory_type = df_features.groupby('id')['type_trajectory'].first().values
unique_trajectory_types = np.unique(trajectory_type)
trajectory_type_to_index = {t: i for i, t in enumerate(unique_trajectory_types)}
trajectory_type_indices = np.array([trajectory_type_to_index[t] for t in trajectory_type])
num_trajectory_types = len(unique_trajectory_types)

# Scale the data - Column-wise scaling across all trajectories
def scale_trajectories_columnwise(trajectories):
    """Standardize every feature column using statistics pooled over all trajectories.

    Args:
        trajectories: sequence of 2-D arrays shaped (timesteps, num_features).
            Trajectories may differ in length; empty ones are ignored when the
            statistics are computed and are passed through unchanged.

    Returns:
        A tuple ``(scaled_trajectories, feature_means, feature_stds)``.
        ``feature_stds`` has zero (or NaN) deviations replaced by 1.0 so that
        constant columns do not cause a division by zero. When there is nothing
        to fit (no trajectories, or only empty ones) the input is returned
        unscaled together with two empty arrays, so callers can always unpack
        three values.
    """
    non_empty = [np.asarray(traj, dtype=np.float64) for traj in trajectories if len(traj) > 0]
    if not non_empty:
        return trajectories, np.array([]), np.array([])

    # Stack every timestep of every trajectory so the statistics are global per column
    stacked = np.concatenate(non_empty, axis=0)
    feature_means = stacked.mean(axis=0)
    feature_stds = stacked.std(axis=0)
    # Avoid division by zero for constant columns
    feature_stds = np.where(feature_stds > 0, feature_stds, 1.0)

    scaled_trajectories = [
        traj if len(traj) == 0 else (traj - feature_means) / feature_stds
        for traj in trajectories
    ]
    return scaled_trajectories, feature_means, feature_stds

def scale_test_trajectories(test_trajectories, means, stds):
    """Apply previously computed column statistics to a list of trajectories.

    Each non-empty trajectory is transformed as ``(traj - means) / stds``;
    empty trajectories are returned untouched. The result is a new list in
    the same order as the input.
    """
    return [
        traj if len(traj) == 0 else (traj - means) / stds
        for traj in test_trajectories
    ]

def pad_sequences(sequences, maxlen, dtype='float32', padding='post', value=0):
    """Pad or truncate variable-length sequences into one dense 3-D array.

    Args:
        sequences: sequence of 2-D array-likes shaped (timesteps, num_features).
        maxlen: number of timesteps in the output; longer sequences keep only
            their first ``maxlen`` timesteps.
        dtype: dtype of the returned array.
        padding: ``'post'`` places the data at the start and the fill value after
            it; ``'pre'`` places the fill value first and the data at the end.
        value: fill value used for the padded positions.

    Returns:
        Array of shape ``(len(sequences), maxlen, num_features)``.

    Raises:
        ValueError: if ``padding`` is neither ``'pre'`` nor ``'post'``.
    """
    if padding not in ('pre', 'post'):
        raise ValueError(f"padding must be 'pre' or 'post', got {padding!r}")

    # Take the feature count from the first non-empty sequence so that an empty
    # leading sequence (or an empty list) does not raise IndexError.
    num_features = 0
    for seq in sequences:
        if len(seq) > 0:
            num_features = len(seq[0])
            break

    padded_sequences = np.full((len(sequences), maxlen, num_features), value, dtype=dtype)
    for i, seq in enumerate(sequences):
        kept = min(len(seq), maxlen)
        if kept == 0:
            continue  # nothing to copy; the row stays filled with `value`
        if padding == 'post':
            padded_sequences[i, :kept] = seq[:kept]
        else:
            padded_sequences[i, maxlen - kept:] = seq[:kept]
    return padded_sequences

class TrajectoryDataset(Dataset):
    """In-memory dataset of (sequence, trajectory-type index, label) triples.

    All three inputs are converted to tensors once, at construction time, and
    moved to ``device``: sequences and labels as float32, types as int64.
    """

    def __init__(self, sequences, types, labels, device):
        def on_device(data, dtype):
            # Build the tensor first, then move it to the target device
            return torch.tensor(data, dtype=dtype).to(device)

        self.sequences = on_device(sequences, torch.float32)
        self.types = on_device(types, torch.int64)
        self.labels = on_device(labels, torch.float32)

    def __len__(self):
        # The number of samples is defined by the label tensor
        return len(self.labels)

    def __getitem__(self, idx):
        sample = (self.sequences[idx], self.types[idx], self.labels[idx])
        return sample

class TrajectoryModel(nn.Module):
    """Bidirectional-LSTM regressor that outputs a mean and a log-variance.

    The LSTM output is averaged over the time axis, concatenated with a
    learned 2-dimensional embedding of the trajectory type, passed through one
    ReLU-activated dense layer and then through two separate linear heads.

    Args:
        num_trajectory_types: number of rows in the type embedding table.
        number_of_features: size of each timestep's feature vector.
        sequence_length: accepted for interface compatibility; not used by
            any layer.
    """

    def __init__(self, num_trajectory_types, number_of_features, sequence_length):
        super().__init__()
        lstm_hidden = 50
        type_embedding_dim = 2
        dense_width = 100

        self.lstm = nn.LSTM(
            number_of_features,
            lstm_hidden,
            bidirectional=True,
            batch_first=True,
        )
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.embedding = nn.Embedding(num_trajectory_types, type_embedding_dim)
        self.flatten = nn.Flatten()
        # Both LSTM directions are concatenated, hence 2 * lstm_hidden inputs
        self.fc1 = nn.Linear(2 * lstm_hidden + type_embedding_dim, dense_width)
        self.fc_mean = nn.Linear(dense_width, 1)
        self.fc_logvar = nn.Linear(dense_width, 1)

    def forward(self, x, type_input):
        """Return ``(mu, logvar)`` for a batch.

        ``x`` is batch-first (batch, timesteps, features) as required by the
        LSTM configuration; ``type_input`` presumably has shape (batch,) —
        confirm against the caller.
        """
        sequence_features, _ = self.lstm(x)
        # Move the time axis last so the 1-D pooling averages over timesteps
        pooled = self.avg_pool(sequence_features.transpose(1, 2)).squeeze(-1)
        type_vector = self.flatten(self.embedding(type_input))
        hidden = torch.relu(self.fc1(torch.cat((pooled, type_vector), dim=1)))
        return self.fc_mean(hidden), self.fc_logvar(hidden)

def nll_criterion_gaussian(mu, logvar, target, reduction='mean'):
    """Gaussian negative log-likelihood style loss on a predicted mean and log-variance.

    Per element the loss is ``exp(-logvar) * (target - mu)**2 + logvar``.
    The elements are averaged when ``reduction == 'mean'``; any other value
    of ``reduction`` sums them.
    """
    squared_error = torch.pow(target - mu, 2)
    per_element = torch.exp(-logvar) * squared_error + logvar
    if reduction == 'mean':
        return per_element.mean()
    return per_element.sum()

def calculate_mae(outputs, labels, ids, max_time, min_time):
    """Accumulate the absolute error between predictions and labels in de-normalized time.

    Prediction and label ``i`` are both mapped back with the time range of
    ``ids[i]``: ``value * (max_time[id] - min_time[id]) + min_time[id]``
    (presumably milliseconds, judging by the variable names — confirm).

    Returns:
        ``(total_absolute_error, number_of_samples)``; the caller divides to
        obtain the mean.
    """
    error_sum = 0.0
    count = 0
    for position, (pred, lbl) in enumerate(zip(outputs, labels)):
        _id = ids[position]
        time_span = max_time[_id] - min_time[_id]
        offset = min_time[_id]
        true_value_ms = lbl.item() * time_span + offset
        pred_value_ms = pred.item() * time_span + offset
        error_sum += abs(pred_value_ms - true_value_ms)
        count = position + 1
    return error_sum, count

# Repeat the train/test split, training, and evaluation `num_runs` times
num_runs = 1
mae_results = []

for run in range(num_runs):
    print(f"\nRun {run+1}/{num_runs}")
    # Split the data; the seed changes with the run index so every run gets a different split
    X_train, X_test, id_train, id_test, type_train_lst, type_test_lst = train_test_split(
        trajectories, ids, trajectory_type_indices, test_size=0.2, random_state=42 + run)
    type_train = np.array(type_train_lst)
    type_test = np.array(type_test_lst)

    # Fit the scaling statistics on the training split only and reuse them for the test split
    X_train_scaled, train_means, train_stds = scale_trajectories_columnwise(X_train)
    X_test_scaled = scale_test_trajectories(X_test, train_means, train_stds)

    max_sequence_length = max(len(sequence) for sequence in X_train)
    X_train_padded = pad_sequences(X_train_scaled, max_sequence_length, padding='post')
    X_test_padded = pad_sequences(X_test_scaled, max_sequence_length, padding='post')
    sequence_length = X_train_padded.shape[1]
    number_of_features = X_train_padded.shape[2]

    # Use CPU only for training and evaluation to avoid MPS memory issues
    device = torch.device("cpu")

    for name, label in dict_list:
        y_train = np.array([label[_id] for _id in id_train])
        y_test = np.array([label[_id] for _id in id_test])

        train_dataset = TrajectoryDataset(X_train_padded, type_train, y_train, device)
        test_dataset = TrajectoryDataset(X_test_padded, type_test, y_test, device)

        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

        model = TrajectoryModel(num_trajectory_types, number_of_features, sequence_length).to(device)
        criterion = nll_criterion_gaussian
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        epochs = 100
        for epoch in range(epochs):
            model.train()
            for sequences, types, labels in train_loader:
                sequences, types, labels = sequences.to(device), types.to(device), labels.to(device)
                optimizer.zero_grad()
                mu, logvar = model(sequences, types)
                # squeeze(-1) drops only the output dimension, so a final batch
                # of size one keeps its batch axis
                loss = criterion(mu.squeeze(-1), logvar.squeeze(-1), labels)
                loss.backward()
                optimizer.step()

        # Evaluate on test set
        model.eval()
        with torch.no_grad():
            all_preds = []
            all_labels = []
            for sequences, types, labels in test_loader:
                sequences, types, labels = sequences.to(device), types.to(device), labels.to(device)
                mu, logvar = model(sequences, types)
                # A plain squeeze() would turn a single-sample batch into a
                # 0-d array that extend() cannot iterate
                all_preds.extend(mu.squeeze(-1).cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Calculate MAE for this run in de-normalized time; the loader is not
        # shuffled, so predictions line up with id_test
        total_mae = 0.0
        for pred, lbl, _id in zip(all_preds, all_labels, id_test):
            true_value_ms = lbl * (max_time[_id] - min_time[_id]) + min_time[_id]
            pred_value_ms = pred * (max_time[_id] - min_time[_id]) + min_time[_id]
            total_mae += abs(pred_value_ms - true_value_ms)
        mean_mae = total_mae / len(id_test)
        print(f"Run {run+1} MAE: {mean_mae:.4f}")
        mae_results.append(mean_mae)

        # ===============================================
        # EVALUATION ON NANA'S TRAILS
        # ===============================================
        # This section reuses the module-level TrajectoryDataset, TrajectoryModel,
        # scaling and padding helpers. Re-declaring them here shadowed the
        # originals with a dataset class that takes no `device` argument, which
        # made the dataset construction below fail.

        print("Loading Nana's data for evaluation...")

        # Load Nana's features
        # NOTE: pickle can execute arbitrary code while loading; only open trusted files.
        with open('all_features_3d_aligned_NANA', 'rb') as f:
            df_features_nana, start_move_nana, stop_turn_nana = pickle.load(f)

        print(f"Number of features: {len(df_features_nana['id'].unique())}")
        df_features_nana = df_features_nana[(df_features_nana['normalized_start_movement'].notna()) & (df_features_nana['normalized_start_movement'] > 0.0005)]
        print(f"Number of features after filtering: {len(df_features_nana['id'].unique())}")
        print(f"Number of start movements before filtering: {len(start_move_nana)}")
        # Test for missing values first so a None entry is dropped instead of raising TypeError
        start_move_nana = {k: v for k, v in start_move_nana.items() if not pd.isna(v) and v > 0.0005}
        print(f"Number of start movements after filtering: {len(start_move_nana)}")
        dict_list_nana = [('start_move', start_move_nana)]

        df_annotated_nana = pd.DataFrame()
        for feature, data in dict_list_nana:
            df_temp = pd.DataFrame.from_dict(data, orient='index', columns=[feature]).reset_index()
            df_temp.columns = ['id', feature]
            if df_annotated_nana.empty:
                df_annotated_nana = df_temp
            else:
                df_annotated_nana = df_annotated_nana.merge(df_temp, on='id', how='outer')

        # Calculate time ranges for Nana data (needed for denormalization)
        max_time_nana = df_features_nana.groupby('id')['adjusted_time'].max()
        min_time_nana = df_features_nana.groupby('id')['adjusted_time'].min()

        print(f"Time ranges calculated for {len(max_time_nana)} trajectories")

        # Prepare Nana's data for evaluation
        print("Preprocessing Nana's data...")

        # Prepare trajectory data for Nana, using the same feature columns as the training data
        X_nana = df_features_nana.drop(['type_trajectory'], axis=1)
        grouped_nana = X_nana.groupby('id')
        trajectories_nana = [group[features].values for _, group in grouped_nana]
        ids_nana = list(grouped_nana.groups.keys())

        # Get trajectory types for Nana
        trajectory_type_nana = df_features_nana.groupby('id')['type_trajectory'].first().values
        num_trajectory_types_nana = len(np.unique(trajectory_type_nana))

        # Map Nana trajectory types to the same indices used in training. The
        # training mapping must not be rebuilt from Nana's data: the embedding
        # rows were learned for the training indices, and a larger Nana
        # vocabulary would index outside the embedding table.
        trajectory_type_indices_nana = []
        for t in trajectory_type_nana:
            if t in trajectory_type_to_index:
                trajectory_type_indices_nana.append(trajectory_type_to_index[t])
            else:
                print(f"Warning: Trajectory type '{t}' not seen in training data. Using index 0.")
                trajectory_type_indices_nana.append(0)
        trajectory_type_indices_nana = np.array(trajectory_type_indices_nana)

        # Scale with the statistics fitted on the training split so the model
        # sees inputs on the scale it was trained on
        thina_means, thina_stds = train_means, train_stds
        X_nana_scaled = scale_test_trajectories(trajectories_nana, thina_means, thina_stds)

        # Pad the *scaled* trajectories.
        # NOTE(review): the padded length is Nana's own maximum, not the training
        # maximum; switching would truncate longer Nana trajectories — confirm
        # which is intended.
        max_sequence_length_nana = max(len(sequence) for sequence in trajectories_nana)
        X_nana_padded = pad_sequences(X_nana_scaled, max_sequence_length_nana, padding='post')

        sequence_length_nana = X_nana_padded.shape[1]
        number_of_features_nana = X_nana_padded.shape[2]

        # Stay on the device the model was trained on; moving only the data to
        # another backend would cause a device mismatch in the forward pass
        print("Using device:", device)
        print(f"Padded Nana trajectories shape: {X_nana_padded.shape}")
        print(f"Trajectory types distribution in Nana:")
        unique_nana, counts_nana = np.unique(trajectory_type_nana, return_counts=True)
        for t, c in zip(unique_nana, counts_nana):
            print(f"  {t}: {c} trajectories")

        # Evaluate on Nana's data
        print("Evaluating trained model on Nana's data...")

        # Use the last trained model instance directly
        model_eval = model
        model_eval.eval()

        # Prepare labels for Nana
        y_nana = np.array([start_move_nana[_id] for _id in ids_nana])

        # Create dataset and dataloader for Nana
        nana_dataset = TrajectoryDataset(X_nana_padded, trajectory_type_indices_nana, y_nana, device)
        nana_loader = DataLoader(nana_dataset, batch_size=8, shuffle=False)

        # Evaluate the model
        predictions_nana = {}
        nana_mae_list = []

        print("Making predictions on Nana's data...")
        with torch.no_grad():
            # The loader is not shuffled, so a running offset recovers the ids of each batch
            index_offset = 0
            for sequences, types, labels in nana_loader:
                sequences, types, labels = sequences.to(device), types.to(device), labels.to(device)
                mu, logvar = model_eval(sequences, types)
                batch_size = sequences.size(0)
                batch_ids = ids_nana[index_offset:index_offset + batch_size]
                index_offset += batch_size
                # squeeze(-1) keeps the batch axis so tolist() always yields a list
                for _id, pred_mu, pred_logvar in zip(batch_ids, mu.squeeze(-1).tolist(), logvar.squeeze(-1).tolist()):
                    predictions_nana[_id] = {"mu": pred_mu, "logvar": pred_logvar}
                    true_value_ms = start_move_nana[_id] * (max_time_nana[_id] - min_time_nana[_id]) + min_time_nana[_id]
                    pred_value_ms = pred_mu * (max_time_nana[_id] - min_time_nana[_id]) + min_time_nana[_id]
                    mae = abs(pred_value_ms - true_value_ms)
                    nana_mae_list.append(mae)
        print(f"Evaluation completed on {len(predictions_nana)} Nana trajectories")

        # Analyze Nana evaluation results
        print("Analyzing results on Nana's data...")

        # Convert predictions to DataFrame
        df_results_nana = pd.DataFrame.from_dict(predictions_nana, orient='index')
        df_results_nana['variance'] = np.exp(df_results_nana['logvar'])
        df_results_nana['true_y'] = [start_move_nana[_id] for _id in df_results_nana.index]

        # Add trajectory types
        trajectory_types_nana = df_features_nana.groupby("id")["type_trajectory"].first().to_dict()
        df_results_nana['type_trajectory'] = df_results_nana.index.map(trajectory_types_nana)
        # These errors are in normalized label units, unlike nana_mae_list above
        df_results_nana['absolute_error'] = abs(df_results_nana['mu'] - df_results_nana['true_y'])
        df_results_nana['residuals'] = df_results_nana['true_y'] - df_results_nana['mu']

        # Print evaluation metrics
        print("\n=== EVALUATION ON NANA'S DATA ===")
        print(f"Number of trajectories: {len(df_results_nana)}")
        print(f"MAE: {df_results_nana['absolute_error'].mean():.4f}")
        print(f"RMSE: {np.sqrt((df_results_nana['residuals']**2).mean()):.4f}")

        print("\nMAE by trajectory type:")
        for traj_type in df_results_nana['type_trajectory'].unique():
            subset = df_results_nana[df_results_nana['type_trajectory'] == traj_type]
            print(f"  {traj_type}: {subset['absolute_error'].mean():.4f} (n={len(subset)})")

        # Filter the DataFrame to include only rows where both true_y and mu are between 0 and 1
        filtered_df = df_results_nana[
            (df_results_nana['true_y'] >= 0) & (df_results_nana['true_y'] <= 1) &
            (df_results_nana['mu'] >= 0) & (df_results_nana['mu'] <= 1)
        ]

        # Plotting
        plt.figure(figsize=(5, 5))
        plt.scatter(filtered_df['true_y'], filtered_df['mu'], alpha=0.6)
        plt.xlabel("True Values")
        plt.ylabel("Predicted Values")
        plt.title("True vs Predicted Values - Movement Onset point (Nana's Data) Trained on Thina")
        plt.plot([0, 1], [0, 1], color='red', label='Identity Line')
        plt.legend()
        plt.show()


# Report the actual number of runs rather than a hard-coded count
print(f"\nMAE results for {num_runs} runs:")
print(mae_results)
print(f"Mean MAE over {num_runs} runs: {np.mean(mae_results):.4f}")