# Procrustes Distance Analysis for Barlow Twins Training

This notebook analyzes Procrustes distances during Barlow Twins training, recreating Figure 4 from the paper.

**Based on:** "A Theoretical Characterization of Optimal Data Augmentations in Self-Supervised Learning" (arXiv:2411.01767v3)
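
For readers unfamiliar with the metric, the following is a minimal sketch of an orthogonal Procrustes distance between two embedding matrices, computed with NumPy's SVD. The function name `procrustes_distance` and the unnormalized Frobenius form are assumptions made for illustration. The exact definition, any normalization, and the averaging used by `scripts/run_experiment.py` and the paper may differ.

```python
import numpy as np


def procrustes_distance(A, B):
    """Hypothetical helper: min over orthogonal R of ||A @ R - B||_F.

    A and B are (n_samples, dim) arrays with matching shapes.
    This is an illustrative sketch, not the experiment's implementation.
    """
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    if A.shape != B.shape:
        raise ValueError("A and B must have the same shape")

    # The optimal rotation comes from the SVD of the cross-covariance A^T B.
    U, _, Vt = np.linalg.svd(A.T @ B)
    R = U @ Vt

    # Frobenius norm of the residual after the best orthogonal alignment.
    return np.linalg.norm(A @ R - B)


# Usage: B is an orthogonally transformed copy of A, so the distance is ~0.
rng = np.random.default_rng(0)
A = rng.normal(size=(20, 5))
Q, _ = np.linalg.qr(rng.normal(size=(5, 5)))  # random orthogonal matrix
print(procrustes_distance(A, A @ Q))
```

Because the distance is invariant to orthogonal transformations of either input, it compares representations up to rotation and reflection rather than coordinate by coordinate.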

## Important Notes
- Run `scripts/run_experiment.py` first to generate the results files.
- The experiment uses mini-batch training (`batch_size=256`) for stability.
- A checkpoint system is available: pass the `--resume` flag to resume an interrupted experiment.
- Only RBF kernels are used, for numerical stability.
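
As a reference for the last note, an RBF (Gaussian) kernel matrix can be sketched as below. The function name `rbf_kernel` and the bandwidth parameter `gamma` (with its default of `1.0`) are assumptions for illustration; the kernel settings actually used by `scripts/run_experiment.py` are not shown in this notebook.

```python
import numpy as np


def rbf_kernel(X, Y, gamma=1.0):
    """Hypothetical helper: K[i, j] = exp(-gamma * ||X[i] - Y[j]||^2).

    X is (n, d) and Y is (m, d); gamma is an assumed bandwidth parameter.
    """
    X = np.asarray(X, dtype=float)
    Y = np.asarray(Y, dtype=float)

    # Squared distances via ||x||^2 + ||y||^2 - 2 x.y (avoids an n*m*d array).
    x_sq = np.sum(X ** 2, axis=1)[:, None]
    y_sq = np.sum(Y ** 2, axis=1)[None, :]
    sq_dists = x_sq + y_sq - 2.0 * (X @ Y.T)

    # Round-off can make tiny distances slightly negative; clip them to zero.
    sq_dists = np.maximum(sq_dists, 0.0)

    return np.exp(-gamma * sq_dists)


# Usage: kernel matrix of a small random sample with itself.
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))
K = rbf_kernel(X, X, gamma=0.5)
print(K.shape)
```

Every entry lies in the interval (0, 1] and the diagonal of `rbf_kernel(X, X)` is 1, so the kernel matrix stays bounded regardless of the scale of the inputs.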

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sys.path.insert(0, os.path.join(os.getcwd(), '..'))

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("Libraries imported")

## Load Experimental Results

In [None]:
# Load Procrustes distances from experiment
procrustes_to_target = np.load('../results/procrustes_to_target.npy')
procrustes_to_random = np.load('../results/procrustes_to_random.npy')

print(f"Loaded {len(procrustes_to_target)} iterations of training data")
print(f"Final distance to target: {procrustes_to_target[-1]:.4f}")
print(f"Final distance to random: {procrustes_to_random[-1]:.4f}")
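
Before plotting, it can help to sanity-check the loaded arrays. The sketch below assumes each file holds a 1-D array with one distance per training iteration, which is what the `len(...)` and `[-1]` usage above implies. The helper name `check_distance_curves` is hypothetical and is not part of the experiment code.

```python
import numpy as np


def check_distance_curves(to_target, to_random):
    """Hypothetical helper: validate two Procrustes distance curves.

    Raises ValueError if the arrays are not 1-D, differ in length,
    contain NaN/inf values, or contain negative distances.
    """
    to_target = np.asarray(to_target, dtype=float)
    to_random = np.asarray(to_random, dtype=float)

    if to_target.ndim != 1 or to_random.ndim != 1:
        raise ValueError("expected 1-D arrays, one value per iteration")
    if to_target.shape != to_random.shape:
        raise ValueError("curves have different lengths")
    for name, curve in (("target", to_target), ("random", to_random)):
        if not np.all(np.isfinite(curve)):
            raise ValueError(f"distance to {name} contains NaN or inf")
        if np.any(curve < 0):
            raise ValueError(f"distance to {name} contains negative values")

    return to_target, to_random


# Usage with small synthetic curves (not real experiment results).
demo_target, demo_random = check_distance_curves([0.9, 0.5, 0.2], [1.0, 1.0, 1.1])
print(len(demo_target), "iterations passed the checks")
```

In the notebook this could be called as `check_distance_curves(procrustes_to_target, procrustes_to_random)` right after loading, so that a diverged run surfaces as an error rather than as an empty or misleading plot.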

## Plot Procrustes Distance Over Training

In [None]:
plt.figure(figsize=(14, 7))

plt.plot(procrustes_to_target, linewidth=2.5, label='Distance to Target', color='#2E86AB')
plt.plot(procrustes_to_random, linewidth=2.5, label='Distance to Random', color='#A23B72', linestyle='--')

plt.xlabel('Training Iteration', fontsize=14, fontweight='bold')
plt.ylabel('Average Procrustes Distance', fontsize=14, fontweight='bold')
plt.title('Procrustes Distance During Barlow Twins Training\n(Recreating Figure 4 from Paper)', 
          fontsize=16, fontweight='bold')
plt.legend(fontsize=12, loc='upper right')
plt.grid(True, alpha=0.3)
plt.tight_layout()

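# Create the output directory if needed; savefig fails if it does not exist
os.makedirs('../results/plots', exist_ok=True)
plt.savefig('../results/plots/procrustes_analysis.png', dpi=150, bbox_inches='tight')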
plt.show()

print("Plot saved to ../results/plots/procrustes_analysis.png")