In [None]:
import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Set the style for the plots.
# Matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8*"
# and 3.8 removed the old names, so use whichever one this install provides.
_seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_seaborn_style)
sns.set_context("notebook", font_scale=1.5)

# Specify the path to your run directory
run_dir = "/home/jakub/projects/double_descend/double_descent_torch/runs/run_20241006_215645"

# Function to load results
def load_results(k, base_dir=None):
    """Load the saved results file for the ResNet18 run with parameter ``k``.

    Args:
        k: Value inserted into the file name ``resnet18k_{k}_results.pt``.
        base_dir: Directory that contains the ``results`` sub-folder.
            Defaults to the module-level ``run_dir`` when omitted.

    Returns:
        The object stored in the results file, as returned by ``torch.load``.

    Raises:
        FileNotFoundError: If no results file exists for ``k``.
    """
    root = run_dir if base_dir is None else base_dir
    results_path = os.path.join(root, "results", f"resnet18k_{k}_results.pt")
    # map_location="cpu" lets results that were saved from a GPU run be read
    # on a CPU-only machine, and keeps any tensors directly plottable.
    return torch.load(results_path, map_location="cpu")

# Read the saved metrics for every k
k_values = range(1, 11)  # Adjust this range based on your actual k values
all_results = {}
for k in k_values:
    all_results[k] = load_results(k)

# Flatten the per-epoch metrics into one record per (k, epoch) pair so the
# plots below can filter a single DataFrame.
metric_names = ('train_loss', 'train_acc', 'test_loss', 'test_acc')
data = []
for k, history in all_results.items():
    # zip stops at the shortest metric list, one tuple per epoch
    per_epoch = zip(*(history[name] for name in metric_names))
    data.extend(
        {'k': k, 'epoch': epoch_idx, **dict(zip(metric_names, row))}
        for epoch_idx, row in enumerate(per_epoch)
    )

df = pd.DataFrame(data)

# Accuracy curves for every k: solid line for train, dashed line for test
plt.figure(figsize=(12, 8))
for k in k_values:
    per_k = df.loc[df['k'] == k]  # rows belonging to this k only
    epochs = per_k['epoch']
    plt.plot(epochs, per_k['train_acc'], label=f'Train k={k}')
    plt.plot(epochs, per_k['test_acc'], linestyle='--', label=f'Test k={k}')
plt.title('Training and Test Accuracy for Different k Values')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# Anchor the legend outside the axes, to the right of the plot area
plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1))
plt.tight_layout()
plt.show()

# Plot final test accuracy vs k.
# Each k's own last recorded epoch is used, so runs of different lengths
# (e.g. one that stopped early) still contribute a point instead of raising
# an IndexError because they never reached the global maximum epoch.
last_epoch_rows = df.loc[df.groupby('k')['epoch'].idxmax()]
final_by_k = last_epoch_rows.set_index('k')['test_acc']
# A k with no recorded epochs becomes NaN, which matplotlib draws as a gap.
final_test_acc = [final_by_k.get(k, np.nan) for k in k_values]
plt.figure(figsize=(10, 6))
plt.plot(k_values, final_test_acc, marker='o')
plt.xlabel('k')
plt.ylabel('Final Test Accuracy')
plt.title('Final Test Accuracy vs k')
plt.grid(True)
plt.show()

# Heatmap of test accuracy: one row per epoch, one column per k
pivot_df = df.pivot(index='epoch', columns='k', values='test_acc')
heat_fig, heat_ax = plt.subplots(figsize=(12, 8))
sns.heatmap(pivot_df, cmap='viridis', ax=heat_ax)
# Set the axis labels explicitly after drawing the heatmap
heat_ax.set_xlabel('k')
heat_ax.set_ylabel('Epoch')
heat_ax.set_title('Test Accuracy Heatmap')
plt.show()

# Loss curves for a hand-picked subset of k: solid for train, dashed for test
k_to_plot = [1, 5, 10]  # Adjust these values as needed
loss_fig, loss_ax = plt.subplots(figsize=(12, 8))
for k in k_to_plot:
    selected = df.loc[df['k'] == k]  # rows belonging to this k only
    epochs = selected['epoch']
    loss_ax.plot(epochs, selected['train_loss'], label=f'Train Loss k={k}')
    loss_ax.plot(epochs, selected['test_loss'], linestyle='--', label=f'Test Loss k={k}')
loss_ax.set_title('Training and Test Loss for Selected k Values')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
plt.show()