# Run Inference for RNN with context

In [38]:
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, balanced_accuracy_score
import os
import torch
import json

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader

import seaborn as sns
from tqdm import tqdm
import argparse

from build.lib.navi.nn.models_using_embeddings import ResetNet50GRU
from navi.datasets.frames_embeddings import FramesWithContextDataset
from navi.datasets.single_video_with_context import SingleVideoWithContext
from navi.datasets.video_with_context import VideoWithContext
#from navi.nn.models_using_embeddings import ResetNet50GRU
from navi.transforms import ToTensor

In [39]:
def create_dfs(video_csv: str, label_map_path: str) -> dict:
    # Dict to store all dataframes
    result: dict = {}
    mapping_df = pd.read_csv(video_csv)  # df that maps a video id to its folder

    # Open json file
    with open(label_map_path, 'r') as f:
        label_map: pd.DataFrame = json.load(f)

        for index, row in mapping_df.iterrows():
            video_id = row['id']
            if video_id in label_map:
                video_data = label_map[video_id]
                if isinstance(video_data, list) and all(isinstance(item, dict) for item in video_data):
                    video_data_df = pd.DataFrame(video_data)
                else:
                    raise Exception(f'Invalid data format for video ID: {video_id}')
                result[video_id] = video_data_df
            else:
                raise Exception(f'Could not find the following id in json file: {video_id}')

    return result

In [40]:
def get_output_sub_dir(output_dir: str, key: str):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    output_sub_dir = os.path.join(output_dir, key)
    if not os.path.exists(output_sub_dir):
        os.makedirs(output_sub_dir)

    return output_sub_dir

In [41]:
from navi.nn.embeddings_fc import ResNet50FC

def get_model(device, model_path, hidden_size, n_layers, fc=True):
    if fc:
        model = ResNet50FC(
            input_size=2048,
            hidden_size=hidden_size
        )
    else:
        model = ResetNet50GRU(
            input_size=2048,
            hidden_size=hidden_size,
            n_layers=n_layers,
        )
    checkpoint = torch.load(model_path)

    model.load_state_dict(checkpoint)
    model.eval()
    model = model.to(device)
    return model


In [42]:
def get_models(weights_dir: str, nb_models: int=5, verbose=False):
    weights = []

    for i in range(nb_models):
        fold_dir = os.path.join(weights_dir, f"fold_{i}")

        best_pt_path = os.path.join(fold_dir, 'best.pt')

        if os.path.exists(best_pt_path):
            weights.append(best_pt_path)
        else:
            raise Exception(f'Missing weights for fold at: {fold_dir}')

    if verbose:
        for weight in weights:
            print(weight)

    return weights

In [43]:
def infer(model_paths: list, embeddings_dir: str, test_videos: pd.DataFrame, label_map: pd.DataFrame,
          context_size: int, batch_size: int, hidden_size: int, n_layers: int, label_df, model_name='fc') -> dict:
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'
    print(f'Device: {device}')

    #model_nb = 0
    transforms = ToTensor()

    inference_dataset = FramesWithContextDataset(
                    embeddings_dir,
                    test_videos,
                    label_map,
                    transform=transforms,
                    context_size=context_size,
                )
    # Get DataLoader
    inference_loader = DataLoader(inference_dataset, batch_size=batch_size, shuffle=False)
    
    for model_nb in range(len(model_paths)):
        print(f"Processing model {model_nb}")
        model = get_model(device, model_paths[model_nb], hidden_size, n_layers)

        for key in label_df.keys():
            label_df[key][f'proba_{model_nb}'] = None
    
        label_df: dict = run_batches(device, model, inference_loader, label_df, model_nb)

        # Run inference
        # if model_name == 'fc':
        #     label_df: dict = run_batches_fc(device, model, inference_loader, label_df, model_nb)
        # else:
        #     label_df: dict = run_batches(device, model, inference_loader, label_df, model_nb)

    return label_df

In [44]:
def run_batches_fc(device, model, inference_loader, label_df, model_nb) -> dict:
    compteur = 0
    current_df_key = list(label_df.keys())[0]
    current_df = label_df[current_df_key]

    for batch, (features, labels) in tqdm(enumerate(inference_loader), desc='inference'):
        with torch.inference_mode():
            features = features.to(device)
            features = features[:, -1, :]
            logits = model(features).squeeze(dim=-1)
            predictions = (logits.sigmoid() > 0.5).long()
            predictions = predictions.squeeze()
            proba = logits.sigmoid().squeeze()
            
            if compteur >= len(current_df):
                current_df_key = list(label_df.keys())[list(label_df.keys()).index(current_df_key) + 1]
                current_df = label_df[current_df_key]
                compteur = 0 

            # Ajoute la prédiction au DataFrame courant
            col_to_write = f'proba_{model_nb}'
            current_df.at[compteur, col_to_write] = proba[-1].item()
            compteur += 1

    return label_df


In [45]:
def run_batches(device, model, inference_loader, label_df, model_nb) -> dict:
    compteur = 0
    current_df_key = list(label_df.keys())[0]
    current_df = label_df[current_df_key]

    for batch, (features, labels) in tqdm(enumerate(inference_loader), desc='inference'):
        with torch.inference_mode():
            features = features.to(device)
            logits = model(features).squeeze(dim=-1)
            predictions = (logits.sigmoid() > 0.5).long()
            predictions = predictions.squeeze()
            proba = logits.sigmoid().squeeze()
            
            if compteur >= len(current_df):
                current_df_key = list(label_df.keys())[list(label_df.keys()).index(current_df_key) + 1]
                current_df = label_df[current_df_key]
                compteur = 0 

            # Ajoute la prédiction au DataFrame courant
            col_to_write = f'proba_{model_nb}'
            current_df.at[compteur, col_to_write] = proba[-1].item()
            compteur += 1

    return label_df


In [46]:
weights_dir="/home/rob/Documents/Github/navi_lstm/output/snapshot/v25_fn_resnet/grid_0"
#weights_dir="/home/rob/Documents/Github/navi_lstm/notebooks/"

k=5
embeddings_dir='/home/rob/Documents/lstm_backup/embeddings_resnet_test' 
video_csv='/home/rob/Documents/lstm_backup/maps/mapping_test.csv'
label_map_path='/home/rob/Documents/lstm_backup/labels/ground_truth_testset.json' 
context_size=5
batch_size=1
hidden_size=32
n_layers=1

In [47]:
output_dir = '/home/rob/Documents/Github/navi_lstm/output/inference'
output_sub_dir: dict = {}
dataframe_dict = create_dfs(video_csv, label_map_path)
test_videos = pd.read_csv(video_csv)

print('Pre Processing Dataframes')
for key, dataframe in dataframe_dict.items():
    try:
        # pre_process_df(dataframe, k)
        # dataframe_dict[key] = dataframe
        # stores the path to save output
        output_sub_dir[key] = get_output_sub_dir(output_dir, key)
    except Exception as e:
        raise Exception(f'Error during Pre Processing of Dataframes for key: {key} | \nException: {e}')

video_ids = sorted(dataframe_dict.keys())

print('Collecting Models')
model_paths = get_models(weights_dir, nb_models=k)
with open(label_map_path, 'rb') as file:
    label_map = json.load(file)

results_df = infer(model_paths=model_paths,
                          embeddings_dir=embeddings_dir, test_videos=test_videos, label_map=label_map,
                          context_size=context_size, batch_size=batch_size, hidden_size=hidden_size,
                          n_layers=n_layers, label_df=dataframe_dict)

#verify_matching_keys(dataframe_dict, df_dict_one_model)

Pre Processing Dataframes
Collecting Models
Device: cpu
Loading embeddings...
Loading targets...
Indexing frames...
Dataset loaded.
Processing model 0


inference: 24124it [00:05, 4563.74it/s]


Processing model 1


inference: 24124it [00:05, 4042.29it/s]


Processing model 2


inference: 24124it [00:06, 3931.19it/s]


Processing model 3


inference: 24124it [00:04, 5013.81it/s]


Processing model 4


inference: 24124it [00:04, 5124.24it/s]


In [48]:
results_df.keys()

dict_keys(['id_x1', 'id_x2', 'id_x3', 'id_x4', 'id_x5'])

In [49]:
results_df['id_x1']

Unnamed: 0,frame,label,proba_0,proba_1,proba_2,proba_3,proba_4
0,0,0,0.695349,0.604026,0.693878,0.78162,0.740069
1,1,0,0.588238,0.511964,0.534744,0.683259,0.580529
2,2,0,0.687131,0.534376,0.54728,0.753236,0.654674
3,3,0,0.630827,0.477583,0.550627,0.702151,0.681208
4,4,0,0.587383,0.4571,0.498581,0.709815,0.602283
...,...,...,...,...,...,...,...
1365,1365,0,0.635369,0.607195,0.676168,0.814261,0.776786
1366,1366,0,0.538782,0.545471,0.646625,0.603083,0.719442
1367,1367,0,0.684252,0.669105,0.774884,0.732205,0.788397
1368,1368,0,0.652685,0.518505,0.777622,0.758464,0.867524


## Get average

In [50]:
def average_prediction(frame_df, k_range):
    prob_cols = [f'proba_{k}' for k in k_range]
    for col in prob_cols:
        if col not in frame_df.columns:
            raise ValueError(f"Missing column {col} in DataFrame")

    # Average probability for each row (axis=1 -> column)
    frame_df['proba_avg'] = frame_df[prob_cols].mean(axis=1)
    # Average prediction
    frame_df['pred_avg'] = np.where(frame_df['proba_avg'] > 0.5, 1, 0)

In [51]:
for key, dataframe in results_df.items():
    average_prediction(dataframe, range(5))

In [52]:
results_df['id_x1']

Unnamed: 0,frame,label,proba_0,proba_1,proba_2,proba_3,proba_4,proba_avg,pred_avg
0,0,0,0.695349,0.604026,0.693878,0.78162,0.740069,0.702988,1
1,1,0,0.588238,0.511964,0.534744,0.683259,0.580529,0.579747,1
2,2,0,0.687131,0.534376,0.54728,0.753236,0.654674,0.63534,1
3,3,0,0.630827,0.477583,0.550627,0.702151,0.681208,0.608479,1
4,4,0,0.587383,0.4571,0.498581,0.709815,0.602283,0.571032,1
...,...,...,...,...,...,...,...,...,...
1365,1365,0,0.635369,0.607195,0.676168,0.814261,0.776786,0.701956,1
1366,1366,0,0.538782,0.545471,0.646625,0.603083,0.719442,0.610681,1
1367,1367,0,0.684252,0.669105,0.774884,0.732205,0.788397,0.729769,1
1368,1368,0,0.652685,0.518505,0.777622,0.758464,0.867524,0.71496,1


In [53]:
for key, df in results_df.items():
    output_file_name = f"{key}_output.csv"
    df.to_csv(output_file_name, index=False)

In [54]:
def plot_labels_and_predictions(df, plot_title, output_dir):
    df = df.sort_values(by='frame')

    true_labels = df['label'].tolist()
    model_predictions = df['pred_0'].tolist()

    fig, axs = plt.subplots(2, 1, figsize=(15, 4))
    tick_interval = 250
    frame_numbers = df['frame'].tolist()
    xticks = np.arange(min(frame_numbers), max(frame_numbers) + tick_interval, tick_interval)
    xticklabels = np.arange(min(frame_numbers), max(frame_numbers) + tick_interval, tick_interval).astype(int)

    for i, label in enumerate(true_labels):
        color = 'white' if label == 0 else 'dimgray'
        axs[0].axhspan(0, 1, xmin=i / len(frame_numbers), xmax=(i + 1) / len(frame_numbers), facecolor=color)

    axs[0].set_yticks([])
    axs[0].set_title('True Labels', fontsize=14, fontweight='bold')
    axs[0].set_xticks(xticks / max(frame_numbers))
    axs[0].set_xticklabels(xticklabels)

    for i, label in enumerate(model_predictions):
        color = 'white' if label == 0 else 'darkgray'
        axs[1].axhspan(0, 1, xmin=i / len(frame_numbers), xmax=(i + 1) / len(frame_numbers), facecolor=color)

    axs[1].set_yticks([])
    axs[1].set_title(plot_title, fontsize=14, fontweight='bold')
    axs[1].set_xticks(xticks / max(frame_numbers))
    axs[1].set_xticklabels(xticklabels)

    plt.xlabel('Frame Number')
    plt.tight_layout()

    label_rectangle_path = os.path.join(output_dir, 'label_vis.png')
    plt.savefig(label_rectangle_path)

    plt.show()


In [55]:
# def evaluate_models(df):
#     # Initialize DataFrame
#     results_df = pd.DataFrame(
#         columns=['average', 'balanced_accuracy', 'accuracy', 'precision', 'recall', 'f1_score', 'correct_no', 'incorrect_no', 'correct_spur',
#                  'incorrect_spur'])
# 
#     # Extract true labels from df and filter None values
#     true_labels_k = [label for label in df['label'].tolist() if label is not None]
# 
#     for col in df.columns:
#         if col.startswith('pred_'):
#             # Filter None values
#             model_labels_k = [label for label in df[col].tolist() if label is not None]
# 
#             if len(true_labels_k) != len(model_labels_k):
#                 print(f"Skipping evaluation for {col} due to inconsistent label lengths after filtering None values.")
#                 continue
# 
#             model_name = col.split('_')[1]
#             results_df.loc[len(results_df)] = calculate_metrics(true_labels_k, model_labels_k, model_name)
# 
#     return results_df

In [56]:
def evaluate_models(df):
    # Initialize DataFrame
    results_df = pd.DataFrame(columns=['model', 'balanced_accuracy', 'accuracy', 'precision', 'recall', 'f1_score', 'correct_no', 'incorrect_no', 'correct_spur', 'incorrect_spur'])

    # Extract true labels from df and filter None values
    true_labels_k = [label for label in df['label'].tolist() if label is not None]
    model_labels_k = [label for label in df['pred_avg'].tolist() if label is not None]

    if len(true_labels_k) != len(model_labels_k):
        print("Skipping evaluation due to inconsistent label lengths after filtering None values.")
    else:
        results_df.loc[len(results_df)] = calculate_metrics(true_labels_k, model_labels_k, 'pred_avg')

    return results_df

In [57]:
def calculate_metrics(true_labels, pred_labels, label_name):
    # Calculate metrics
    accuracy = round(accuracy_score(true_labels, pred_labels) * 100, 2)  # TP + TN / n
    balanced_accuracy = round(balanced_accuracy_score(true_labels, pred_labels) * 100, 2)  

    precision = round(precision_score(true_labels, pred_labels, pos_label=1, average='binary') * 100,
                      2)  # TP / TP + FP
    recall = round(recall_score(true_labels, pred_labels, pos_label=1, average='binary') * 100, 2)  # TP / TP + FN
    f1 = round(f1_score(true_labels, pred_labels, pos_label=1, average='binary') * 100,
               2)  # 2 * [(recall * precision) / (recall + precision)]

    # Calculate percentages
    correct_no = round(
        sum([1 for true, pred in zip(true_labels, pred_labels) if true == pred and true == 0]) / true_labels.count(
            0) * 100, 2)
    incorrect_no = round(
        sum([1 for true, pred in zip(true_labels, pred_labels) if true != pred and true == 0]) / true_labels.count(
            0) * 100, 2)
    correct_spur = round(
        sum([1 for true, pred in zip(true_labels, pred_labels) if true == pred and true == 1]) / true_labels.count(
            1) * 100, 2)
    incorrect_spur = round(
        sum([1 for true, pred in zip(true_labels, pred_labels) if true != pred and true == 1]) / true_labels.count(
            1) * 100, 2)

    return [label_name, balanced_accuracy, accuracy, precision, recall, f1, correct_no, incorrect_no, correct_spur, incorrect_spur]


In [58]:
df_id_x1 = results_df['id_x1']  

results_frame_df = evaluate_models(df_id_x1)

In [59]:
results_frame_df

Unnamed: 0,model,balanced_accuracy,accuracy,precision,recall,f1_score,correct_no,incorrect_no,correct_spur,incorrect_spur
0,pred_avg,61.97,52.48,41.95,93.53,57.92,30.42,69.58,93.53,6.47


In [60]:
df_id_x = results_df['id_x2']  
results_frame_df = evaluate_models(df_id_x)
results_frame_df

Unnamed: 0,model,balanced_accuracy,accuracy,precision,recall,f1_score,correct_no,incorrect_no,correct_spur,incorrect_spur
0,pred_avg,77.76,81.15,45.28,72.73,55.81,82.8,17.2,72.73,27.27


In [61]:
df_id_x = results_df['id_x3']  
results_frame_df = evaluate_models(df_id_x)
results_frame_df

Unnamed: 0,model,balanced_accuracy,accuracy,precision,recall,f1_score,correct_no,incorrect_no,correct_spur,incorrect_spur
0,pred_avg,83.6,87.26,42.35,79.03,55.15,88.17,11.83,79.03,20.97


In [62]:
df_id_x = results_df['id_x4']  
results_frame_df = evaluate_models(df_id_x)
results_frame_df

Unnamed: 0,model,balanced_accuracy,accuracy,precision,recall,f1_score,correct_no,incorrect_no,correct_spur,incorrect_spur
0,pred_avg,80.59,78.54,46.13,83.88,59.53,77.3,22.7,83.88,16.12


In [63]:
df_id_x = results_df['id_x5']  
results_frame_df = evaluate_models(df_id_x)
results_frame_df

Unnamed: 0,model,balanced_accuracy,accuracy,precision,recall,f1_score,correct_no,incorrect_no,correct_spur,incorrect_spur
0,pred_avg,90.41,91.91,80.05,87.57,83.64,93.25,6.75,87.57,12.43


In [64]:
plot_title = "Predictions for id_x1"
output_dir = "fold_0" 
plot_labels_and_predictions(df_id_x1, plot_title, output_dir)

KeyError: 'pred_0'