# Random forests on embeddings

This script reads all embedding csvs in the folder_path, computes random forests. 

Fix
- These have a random 0.8:0.2 training split, this is currently not the same random split as the fullt rained resnets, so fix this.
- With both cases, could maybe do a more comprehensive sweep of the random splits, e.g 5 fold cross-val to get error bars

In [1]:
import os
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from datetime import datetime

In [None]:
# Path to the folder containing the CSV files
folder_path = '/home/ben/reef-audio-representation-learning/code/simclr-pytorch-reefs/evaluation/embeddings/raw_embeddings'

In [2]:
# Function to calculate metrics
def calculate_metrics(y_true, y_pred):
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    return accuracy, precision, recall, f1

# Initialize an empty DataFrame to store metrics
columns = ['Filename', 'Test Accuracy', 'Test Precision', 'Test Recall', 'Test F1',
           'Train Accuracy', 'Train Precision', 'Train Recall', 'Train F1']
results_df = pd.DataFrame(columns=columns)

# Path to the folder containing the CSV files
folder_path = '/home/ben/reef-audio-representation-learning/code/simclr-pytorch-reefs/evaluation/embeddings/raw_embeddings'

# Loop through each file in the folder
for filename in os.listdir(folder_path):
    # Check if the file is a CSV file
    if filename.endswith('.csv'):
        # Full path to the file
        filepath = os.path.join(folder_path, filename)
        
        # Read the CSV file into a DataFrame
        df = pd.read_csv(filepath)
        
        # Extract features and labels
        X = df.drop(columns=['Label'])
        y = df['Label']
        
        # Split the data into training and testing sets (80:20 ratio)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0, stratify=y)
        
        # Initialize and train the Random Forest Classifier
        clf = RandomForestClassifier(random_state=0)
        clf.fit(X_train, y_train)
        
        # Make predictions on test set
        y_pred_test = clf.predict(X_test)
        # Make predictions on training set
        y_pred_train = clf.predict(X_train)
        
        # Calculate metrics for test set
        accuracy_test, precision_test, recall_test, f1_test = calculate_metrics(y_test, y_pred_test)
        # Calculate metrics for training set
        accuracy_train, precision_train, recall_train, f1_train = calculate_metrics(y_train, y_pred_train)
        
        # Create a DataFrame for the new row and concatenate it to the existing DataFrame
        new_row_df = pd.DataFrame({
            'Filename': [filename],
            'Test Accuracy': [accuracy_test],
            'Test Precision': [precision_test],
            'Test Recall': [recall_test],
            'Test F1': [f1_test],
            'Train Accuracy': [accuracy_train],
            'Train Precision': [precision_train],
            'Train Recall': [recall_train],
            'Train F1': [f1_train]
        })
        results_df = pd.concat([results_df, new_row_df], ignore_index=True)
        
        # Print metrics
        print(f"Results for {filename}:")
        print("--- Test Metrics ---")
        print(f"Accuracy: {accuracy_test}")
        print(f"Precision: {precision_test}")
        print(f"Recall: {recall_test}")
        print(f"F1 Score: {f1_test}")
        print("--- Training Metrics ---")
        print(f"Accuracy: {accuracy_train}")
        print(f"Precision: {precision_train}")
        print(f"Recall: {recall_train}")
        print(f"F1 Score: {f1_train}")
        print("-" * 40)
        

# Generate a timestamp
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")

# Save the DataFrame to a CSV file with a timestamp in the filename
results_df.to_csv(f"RF_results/RF_results-{current_time}.csv", index=False)


Results for ImageNet-kenya-embeddings.csv:
--- Test Metrics ---
Accuracy: 0.7058823529411765
Precision: 0.6900452488687784
Recall: 0.7058823529411765
F1 Score: 0.6954248366013073
--- Training Metrics ---
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1 Score: 1.0
----------------------------------------
Results for ImageNet-australia-embeddings.csv:
--- Test Metrics ---
Accuracy: 0.7425
Precision: 0.7435289096432427
Recall: 0.7425
F1 Score: 0.74222772803774
--- Training Metrics ---
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1 Score: 1.0
----------------------------------------
Results for ImageNet-florida-embeddings.csv:
--- Test Metrics ---
Accuracy: 0.9099009900990099
Precision: 0.9090399177126222
Recall: 0.9099009900990099
F1 Score: 0.9092349986300611
--- Training Metrics ---
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1 Score: 1.0
----------------------------------------
Results for ImageNet-french_polynesia-embeddings.csv:
--- Test Metrics ---
Accuracy: 0.967706013363029
Precision: 0.9677

In [3]:
results_df

Unnamed: 0,Filename,Test Accuracy,Test Precision,Test Recall,Test F1,Train Accuracy,Train Precision,Train Recall,Train F1
0,ImageNet-kenya-embeddings.csv,0.705882,0.690045,0.705882,0.695425,1.0,1.0,1.0,1.0
1,ImageNet-australia-embeddings.csv,0.7425,0.743529,0.7425,0.742228,1.0,1.0,1.0,1.0
2,ImageNet-florida-embeddings.csv,0.909901,0.90904,0.909901,0.909235,1.0,1.0,1.0,1.0
3,ImageNet-french_polynesia-embeddings.csv,0.967706,0.967708,0.967706,0.967706,1.0,1.0,1.0,1.0
4,ImageNet-indonesia-embeddings.csv,0.972569,0.973675,0.972569,0.972995,1.0,1.0,1.0,1.0
5,ReefCLR-indonesia_embeddings.csv,0.967581,0.96988,0.967581,0.968399,1.0,1.0,1.0,1.0
6,ImageNet-bermuda-embeddings.csv,0.505682,0.436883,0.505682,0.457595,0.96162,0.959761,0.96162,0.959583
7,ReefCLR-australia_embeddings.csv,0.6625,0.662555,0.6625,0.662472,1.0,1.0,1.0,1.0
8,ReefCLR-bermuda_embeddings.csv,0.616477,0.556576,0.616477,0.584562,0.958778,0.957328,0.958778,0.957746
9,ReefCLR-kenya_embeddings.csv,0.860606,0.85999,0.860606,0.855597,1.0,1.0,1.0,1.0
