In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import glob
from sklearn.metrics.pairwise import cosine_similarity
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from scipy.stats import gaussian_kde
import matplotlib as mpl

# Apply one shared look (grid style, font and text sizes) to every figure below.
plt.style.use('seaborn-v0_8-whitegrid')
mpl.rcParams.update({
    'font.family': 'Arial',
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 14,
})

def load_embedding_file(filepath):
    """Unpickle and return the object stored at ``filepath``.

    The rest of this file reads the result as a dict with an 'embeddings'
    entry and a 'training_info' -> 'epoch' entry.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    NOTE: a later definition with the same name in this file rebinds it.
    """
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)

def load_negative_pairs_file(filepath):
    """Unpickle and return the object stored at ``filepath``.

    Callers in this file read the result as a dict with 'epoch' and
    'negative_distances' entries.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)

def find_embedding_files(directory, pattern):
    """Return the paths in ``directory`` matching glob ``pattern``, sorted ascending.

    NOTE: a later definition with the same name in this file rebinds it.
    """
    matches = glob.glob(os.path.join(directory, pattern))
    matches.sort()
    return matches

def find_negative_pairs_files(directory, pattern):
    """Return the paths in ``directory`` matching glob ``pattern``, sorted ascending.

    NOTE: a later definition with the same name in this file rebinds it.
    """
    matches = glob.glob(os.path.join(directory, pattern))
    matches.sort()
    return matches

def _rowwise_cosine_similarity(left, right):
    """Return the cosine similarity between matching rows of two 2-D float arrays.

    A row whose norm is zero gets similarity 0: its norm is replaced by 1 before
    dividing, which is also how scikit-learn's ``cosine_similarity`` treats
    all-zero vectors.
    """
    left_norms = np.linalg.norm(left, axis=1)
    right_norms = np.linalg.norm(right, axis=1)
    left_norms[left_norms == 0] = 1.0
    right_norms[right_norms == 0] = 1.0
    return np.einsum('ij,ij->i', left, right) / (left_norms * right_norms)


def compute_representation_drift(embeddings_list, reference_idx=-1):
    """
    Compute representation drift over time

    For every entry, sample ``j`` of that entry is compared (cosine similarity)
    with sample ``j`` of the reference entry, and the per-sample similarities are
    summarised. When the two entries hold a different number of samples only the
    leading ``min(n, n_ref)`` samples are compared.

    Args:
        embeddings_list: List of dicts, one per checkpoint, each providing
            ``['embeddings']`` (array-like, first axis = samples) and
            ``['training_info']['epoch']``.
        reference_idx: Position in ``embeddings_list`` of the reference entry.
            The default -1 selects the last list element, so pass the list
            sorted by epoch if the reference should be the final epoch.

    Returns:
        drift_scores: Dict mapping epoch -> {'mean', 'std', 'min', 'max'} of the
            per-sample cosine similarities. Entries sharing an epoch overwrite
            each other (last one wins). An empty input yields an empty dict, and
            an entry with no comparable samples yields NaN statistics.

    Raises:
        ValueError: If compared embeddings differ in feature count or contain
            NaN/inf values.
    """
    # Nothing to compare: previously this fell through to an IndexError.
    if not embeddings_list:
        return {}

    if reference_idx == -1:
        reference_idx = len(embeddings_list) - 1

    reference_embeddings = np.asarray(
        embeddings_list[reference_idx]['embeddings'], dtype=float
    )
    drift_scores = {}

    for embedding_data in embeddings_list:
        embeddings = np.asarray(embedding_data['embeddings'], dtype=float)
        epoch = embedding_data['training_info']['epoch']

        # Compare only the samples both checkpoints have.
        min_samples = min(embeddings.shape[0], reference_embeddings.shape[0])
        if min_samples == 0:
            # np.min/np.max raise on empty input; report "no data" instead.
            drift_scores[epoch] = {
                'mean': np.nan, 'std': np.nan, 'min': np.nan, 'max': np.nan
            }
            continue

        # Flatten each sample to one feature vector, as the per-sample
        # ``reshape(1, -1)`` of the original loop did.
        current = embeddings[:min_samples].reshape(min_samples, -1)
        reference = reference_embeddings[:min_samples].reshape(min_samples, -1)

        if current.shape[1] != reference.shape[1]:
            raise ValueError(
                f"Embedding size mismatch at epoch {epoch}: "
                f"{current.shape[1]} vs {reference.shape[1]} features"
            )
        # Fail loudly instead of silently propagating NaN into the statistics.
        if not (np.isfinite(current).all() and np.isfinite(reference).all()):
            raise ValueError(f"Embeddings at epoch {epoch} contain NaN or infinity")

        similarities = _rowwise_cosine_similarity(current, reference)

        drift_scores[epoch] = {
            'mean': np.mean(similarities),
            'std': np.std(similarities),
            'min': np.min(similarities),
            'max': np.max(similarities)
        }

    return drift_scores

def visualize_combined_representation_drift(momentum_embeddings_list, basic_embeddings_list, output_dir):
    """
    Visualize representation drift comparison between momentum and basic encoders in a single plot

    Files written to ``output_dir`` (created if missing):
      * combined_representation_drift.png / .pdf - mean cosine similarity to the
        highest-epoch embeddings per epoch, with a +/- one standard deviation
        band for each model and an annotation at the common epoch where the two
        means differ most in absolute value.
      * epoch_comparison_heatmap.png / .pdf - 3x3 early/middle/late similarity
        matrices, written only when both models have at least four epochs.

    Both lists are sorted by ``entry['training_info']['epoch']`` before use, so
    callers may pass them in any order (e.g. newest-file-first). Without this the
    drift reference would be whatever entry happens to be last in the list, and
    the early/middle/late lookup would pick the wrong entries.

    Args:
        momentum_embeddings_list: Entries for the momentum encoder; each provides
            ``['embeddings']`` and ``['training_info']['epoch']``.
        basic_embeddings_list: Same structure for the basic encoder.
        output_dir: Directory receiving the figures.
    """
    momentum_color = '#2E86C1'
    basic_color = '#E74C3C'

    def epoch_of(entry):
        return entry['training_info']['epoch']

    def stage_embeddings(entries, epochs):
        """Return the embeddings of the first, middle and last epoch in ``epochs``."""
        # Look entries up by epoch value, never by position, so the result does
        # not depend on list order. On duplicate epochs the last entry wins,
        # matching how the drift scores are keyed.
        by_epoch = {epoch_of(entry): entry['embeddings'] for entry in entries}
        stages = (epochs[0], epochs[len(epochs) // 2], epochs[-1])
        return [by_epoch[epoch] for epoch in stages]

    def calculate_similarity_matrix(early, mid, late):
        """Return the 3x3 matrix of mean per-sample cosine similarities between stages."""
        # Truncate to the number of samples all three stages share.
        min_samples = min(early.shape[0], mid.shape[0], late.shape[0])
        stages = [early[:min_samples], mid[:min_samples], late[:min_samples]]

        matrix = np.zeros((3, 3))
        for i, emb_i in enumerate(stages):
            for j, emb_j in enumerate(stages):
                similarities = np.zeros(min_samples)
                for k in range(min_samples):
                    similarities[k] = cosine_similarity(
                        emb_i[k].reshape(1, -1),
                        emb_j[k].reshape(1, -1)
                    )[0][0]
                matrix[i, j] = np.mean(similarities)
        return matrix

    momentum_sorted = sorted(momentum_embeddings_list, key=epoch_of)
    basic_sorted = sorted(basic_embeddings_list, key=epoch_of)

    # With epoch-sorted input the default reference (last element) is the final epoch.
    momentum_drift = compute_representation_drift(momentum_sorted) if momentum_sorted else {}
    basic_drift = compute_representation_drift(basic_sorted) if basic_sorted else {}

    momentum_epochs = sorted(momentum_drift)
    momentum_means = [momentum_drift[e]['mean'] for e in momentum_epochs]
    momentum_stds = [momentum_drift[e]['std'] for e in momentum_epochs]

    basic_epochs = sorted(basic_drift)
    basic_means = [basic_drift[e]['mean'] for e in basic_epochs]
    basic_stds = [basic_drift[e]['std'] for e in basic_epochs]

    # ---- Drift curves -------------------------------------------------------
    fig, ax = plt.subplots(figsize=(10, 6))

    curves = (
        (momentum_epochs, momentum_means, momentum_stds, 'o', 'GAN-CL with Momentum', momentum_color),
        (basic_epochs, basic_means, basic_stds, 'x', 'GAN-CL Basic', basic_color),
    )
    for epochs, means, stds, marker, label, color in curves:
        ax.plot(epochs, means, marker=marker, label=label, color=color, linewidth=2)
        ax.fill_between(epochs,
                        [m - s for m, s in zip(means, stds)],
                        [m + s for m, s in zip(means, stds)],
                        alpha=0.2, color=color)

    ax.set_xlabel('Training Epoch', fontweight='bold')
    ax.set_ylabel('Cosine Similarity to Final Embeddings', fontweight='bold')
    ax.set_title('Temporal Analysis of Representation Drift', fontweight='bold')
    ax.legend(frameon=True, facecolor='white', framealpha=0.9)
    ax.grid(linestyle='--', alpha=0.7)

    # Annotate the common epoch with the largest gap between the two means.
    if len(momentum_epochs) > 2 and len(basic_epochs) > 2:
        common_epochs = sorted(set(momentum_epochs) & set(basic_epochs))
        # max() on an empty sequence raises, so skip when nothing overlaps.
        if common_epochs:
            gaps = {
                e: momentum_drift[e]['mean'] - basic_drift[e]['mean']
                for e in common_epochs
            }
            # Largest absolute gap, consistent with the negative-pair plot.
            max_diff_epoch = max(common_epochs, key=lambda e: abs(gaps[e]))
            midpoint = (momentum_drift[max_diff_epoch]['mean']
                        + basic_drift[max_diff_epoch]['mean']) / 2

            ax.annotate(f"Δ = {gaps[max_diff_epoch]:.3f}",
                        xy=(max_diff_epoch, midpoint),
                        xytext=(max_diff_epoch + 3, midpoint),
                        arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2", color='black'),
                        fontsize=10, fontweight='bold')

    plt.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, 'combined_representation_drift.png'), dpi=300, bbox_inches='tight')
    plt.savefig(os.path.join(output_dir, 'combined_representation_drift.pdf'), bbox_inches='tight')
    plt.close(fig)

    # ---- Early / middle / late heatmaps ---------------------------------------
    if len(momentum_epochs) < 4 or len(basic_epochs) < 4:
        return

    momentum_matrix = calculate_similarity_matrix(*stage_embeddings(momentum_sorted, momentum_epochs))
    basic_matrix = calculate_similarity_matrix(*stage_embeddings(basic_sorted, basic_epochs))

    fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)

    epoch_labels = ['Early', 'Middle', 'Late']
    sns.heatmap(momentum_matrix, annot=True, fmt=".3f", cmap="Blues",
                xticklabels=epoch_labels, yticklabels=epoch_labels,
                vmin=0, vmax=1, ax=axes[0])
    axes[0].set_title("GAN-CL with Momentum Encoder", fontweight='bold')
    axes[0].set_xlabel("Epoch", fontweight='bold')
    axes[0].set_ylabel("Epoch", fontweight='bold')

    sns.heatmap(basic_matrix, annot=True, fmt=".3f", cmap="Reds",
                xticklabels=epoch_labels, yticklabels=epoch_labels,
                vmin=0, vmax=1, ax=axes[1])
    axes[1].set_title("GAN-CL Basic", fontweight='bold')
    axes[1].set_xlabel("Epoch", fontweight='bold')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'epoch_comparison_heatmap.png'), dpi=300, bbox_inches='tight')
    plt.savefig(os.path.join(output_dir, 'epoch_comparison_heatmap.pdf'), bbox_inches='tight')
    plt.close(fig)

def _drop_nan(raw):
    """Return ``raw`` as a flat float array with NaN entries removed."""
    values = np.asarray(raw, dtype=float).ravel()
    return values[~np.isnan(values)]


def _load_negative_pairs_by_epoch(files):
    """Load negative-pair files; return (epochs, NaN-free distances) by ascending epoch."""
    records = sorted((load_negative_pairs_file(f) for f in files),
                     key=lambda record: record['epoch'])
    epochs = [record['epoch'] for record in records]
    distances = [_drop_nan(record['negative_distances']) for record in records]
    return epochs, distances


def _distance_statistics(distances):
    """Return per-epoch 'mean'/'std'/'min'/'max' lists (NaN where an epoch has no data)."""
    stats = {'mean': [], 'std': [], 'min': [], 'max': []}
    for dist in distances:
        if dist.size:
            stats['mean'].append(np.mean(dist))
            stats['std'].append(np.std(dist))
            stats['min'].append(np.min(dist))
            stats['max'].append(np.max(dist))
        else:
            # np.min/np.max raise on empty input; record a gap instead.
            for values in stats.values():
                values.append(np.nan)
    return stats


def _plot_negative_pair_statistics(momentum_epochs, m_stats, basic_epochs, b_stats,
                                   common_epochs, output_dir):
    """Plot mean +/- std, min and max of the distances per epoch for both models."""
    momentum_color = '#2E86C1'  # Blue
    basic_color = '#E74C3C'      # Red

    fig, ax = plt.subplots(figsize=(12, 6))

    # Means with a +/- one standard deviation band.
    mean_curves = (
        (momentum_epochs, m_stats, 'o', 'GAN-CL with Momentum (Mean)', momentum_color),
        (basic_epochs, b_stats, 'x', 'GAN-CL Basic (Mean)', basic_color),
    )
    for epochs, stats, marker, label, color in mean_curves:
        ax.plot(epochs, stats['mean'], marker=marker, label=label, color=color, linewidth=2.5)
        ax.fill_between(epochs,
                        [m - s for m, s in zip(stats['mean'], stats['std'])],
                        [m + s for m, s in zip(stats['mean'], stats['std'])],
                        alpha=0.2, color=color)

    # Min/max as thinner lines in the model's colour.
    range_curves = (
        (momentum_epochs, m_stats, 'Momentum', momentum_color),
        (basic_epochs, b_stats, 'Basic', basic_color),
    )
    for epochs, stats, tag, color in range_curves:
        ax.plot(epochs, stats['min'], linestyle='--', marker='.', label=f'Min ({tag})',
                color=color, alpha=0.7, linewidth=1.5)
        ax.plot(epochs, stats['max'], linestyle='-.', marker='+', label=f'Max ({tag})',
                color=color, alpha=0.7, linewidth=1.5)

    # Light shading between min and max.
    for epochs, stats, _tag, color in range_curves:
        ax.fill_between(epochs, stats['min'], stats['max'], color=color, alpha=0.1)

    # Annotate the common epoch where the two means differ most (absolute value).
    if len(momentum_epochs) > 1 and len(basic_epochs) > 1 and common_epochs:
        def mean_pair(epoch):
            return (m_stats['mean'][momentum_epochs.index(epoch)],
                    b_stats['mean'][basic_epochs.index(epoch)])

        diff_means = [(e, mean_pair(e)[0] - mean_pair(e)[1]) for e in common_epochs]
        max_diff_epoch, max_diff = max(diff_means, key=lambda item: abs(item[1]))
        midpoint = sum(mean_pair(max_diff_epoch)) / 2

        ax.annotate(f"Δ = {max_diff:.3f}",
                    xy=(max_diff_epoch, midpoint),
                    xytext=(max_diff_epoch + 2, midpoint),
                    arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2", color='black'),
                    fontsize=10, fontweight='bold')

    ax.set_xlabel('Training Epoch', fontweight='bold')
    ax.set_ylabel('Negative Pair Distance', fontweight='bold')
    ax.set_title('Evolution of Negative Pair Distances', fontweight='bold')
    ax.legend(frameon=True, facecolor='white', framealpha=0.9, loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3)
    ax.grid(linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'combined_negative_pair_statistics.png'), dpi=300, bbox_inches='tight')
    plt.savefig(os.path.join(output_dir, 'combined_negative_pair_statistics.pdf'), bbox_inches='tight')
    plt.close(fig)


def _plot_negative_pair_distributions(common_epochs, momentum_epochs, momentum_distances,
                                      basic_epochs, basic_distances, output_dir):
    """Plot per-epoch density comparisons for up to three common epochs."""
    if len(common_epochs) > 3:
        # Early, middle, late epochs
        comparison_epochs = [common_epochs[0],
                             common_epochs[len(common_epochs) // 2],
                             common_epochs[-1]]
    else:
        comparison_epochs = common_epochs

    # squeeze=False always returns a 2-D array, so a single panel needs no special case.
    fig, axes = plt.subplots(1, len(comparison_epochs), figsize=(15, 5),
                             sharey=True, squeeze=False)
    panels = axes[0]
    last_panel = len(comparison_epochs) - 1

    for i, epoch in enumerate(comparison_epochs):
        panel = panels[i]
        m_dist = momentum_distances[momentum_epochs.index(epoch)]
        b_dist = basic_distances[basic_epochs.index(epoch)]

        # A panel stays blank when either model has no valid distances.
        if m_dist.size == 0 or b_dist.size == 0:
            continue

        x = np.linspace(min(np.min(m_dist), np.min(b_dist)),
                        max(np.max(m_dist), np.max(b_dist)), 1000)

        # Evaluate both KDEs before drawing anything so a failure never leaves
        # a half-drawn panel underneath the histogram fallback.
        try:
            m_density = gaussian_kde(m_dist)(x)
            b_density = gaussian_kde(b_dist)(x)
        except (np.linalg.LinAlgError, ValueError):
            # gaussian_kde fails on too few points or degenerate (constant) data.
            panel.hist(m_dist, bins=30, alpha=0.5, label='GAN-CL with Momentum', color='#2E86C1', density=True)
            panel.hist(b_dist, bins=30, alpha=0.5, label='GAN-CL Basic', color='#E74C3C', density=True)
        else:
            panel.plot(x, m_density, label='GAN-CL with Momentum', color='#2E86C1', linewidth=2)
            panel.fill_between(x, 0, m_density, alpha=0.2, color='#2E86C1')

            panel.plot(x, b_density, label='GAN-CL Basic', color='#E74C3C', linewidth=2)
            panel.fill_between(x, 0, b_density, alpha=0.2, color='#E74C3C')

        panel.annotate(f"Mean (Momentum): {np.mean(m_dist):.3f}\nMean (Basic): {np.mean(b_dist):.3f}",
                       xy=(0.05, 0.95), xycoords='axes fraction',
                       bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8),
                       ha='left', va='top', fontsize=8)

        panel.set_title(f'Epoch {epoch}', fontweight='bold')
        panel.set_xlabel('Negative Pair Distance', fontweight='bold')
        panel.grid(linestyle='--', alpha=0.4)

        # The y-axis is shared, so label it once.
        if i == 0:
            panel.set_ylabel('Density', fontweight='bold')

        # One legend is enough; put it on the last panel.
        if i == last_panel:
            panel.legend(frameon=True, facecolor='white', framealpha=0.9)

    plt.suptitle('Comparison of Negative Pair Distributions at Different Training Stages', fontweight='bold', fontsize=14)
    plt.tight_layout()

    plt.savefig(os.path.join(output_dir, 'negative_pair_comparison_by_epoch.png'), dpi=300, bbox_inches='tight')
    plt.savefig(os.path.join(output_dir, 'negative_pair_comparison_by_epoch.pdf'), bbox_inches='tight')
    plt.close(fig)


def _evenly_spaced_subset(epochs, distances, limit=4):
    """Return at most ``limit`` (epoch, distances) pairs, evenly spaced over the list."""
    if len(epochs) <= limit:
        return epochs, distances
    picks = np.linspace(0, len(epochs) - 1, limit, dtype=int)
    return [epochs[i] for i in picks], [distances[i] for i in picks]


def _plot_kde_ridges(ax, epochs, distances, color, tag, y_shift):
    """Draw one KDE curve per epoch on 3D axes ``ax`` at y = epoch + ``y_shift``."""
    x = np.linspace(0, 2, 100)  # Assumed range for distances
    for epoch, dist in zip(epochs, distances):
        if dist.size == 0:
            continue
        try:
            density = gaussian_kde(dist)(x)
        except (np.linalg.LinAlgError, ValueError):
            # Too few points or degenerate data: skip this epoch only.
            continue

        y_pos = epoch + y_shift

        # KDE outline at this epoch position.
        ax.plot(x, [y_pos] * len(x), density, color=color, linewidth=2, alpha=0.8)

        # Filled area under the curve: a vertical sheet clipped to the density.
        x_grid, z_grid = np.meshgrid(x, np.linspace(0, max(density), 20))
        y_grid = np.ones_like(x_grid) * y_pos
        z_grid = np.minimum(z_grid, density)
        ax.plot_surface(x_grid, y_grid, z_grid, alpha=0.2, color=color, shade=False)

        # Label at the right-hand end of the curve.
        ax.text(2.0, y_pos, max(density), f'E{epoch} {tag}', color=color, fontsize=8)


def _plot_negative_pair_3d(momentum_epochs, momentum_distances,
                           basic_epochs, basic_distances, output_dir):
    """Best-effort 3D view of how both distance distributions evolve over epochs."""
    fig = None
    try:
        # Kept from the original; older Matplotlib releases needed this import
        # to make the '3d' projection available.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        # Four curves per model keep the 3D plot readable.
        m_selected_epochs, m_selected_dists = _evenly_spaced_subset(momentum_epochs, momentum_distances)
        b_selected_epochs, b_selected_dists = _evenly_spaced_subset(basic_epochs, basic_distances)

        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')

        m_color = '#2E86C1'  # Blue for momentum
        b_color = '#E74C3C'  # Red for basic

        # The basic model is shifted along y so the two models do not overlap.
        offset = 5

        _plot_kde_ridges(ax, m_selected_epochs, m_selected_dists, m_color, 'Mom', 0)
        _plot_kde_ridges(ax, b_selected_epochs, b_selected_dists, b_color, 'Basic', offset)

        # Single-point lines that exist only to create legend entries.
        ax.plot([0], [0], [0], color=m_color, linewidth=4, label='GAN-CL with Momentum')
        ax.plot([0], [0], [0], color=b_color, linewidth=4, label='GAN-CL Basic')

        ax.set_xlabel('Negative Pair Distance', fontweight='bold')
        ax.set_ylabel('Training Stage', fontweight='bold')
        ax.set_zlabel('Density', fontweight='bold')
        ax.set_title('Evolution of Negative Pair Distributions', fontweight='bold', fontsize=14)

        # Custom y-ticks; positions past max_epoch are labelled minus the offset.
        max_epoch = max(max(momentum_epochs), max(basic_epochs))
        tick_positions = np.arange(0, max_epoch + offset + 1, step=max(max_epoch // 4, 1))
        ax.set_yticks(tick_positions)
        ax.set_yticklabels([f'Epoch {e}' if e <= max_epoch else f'Epoch {e-offset}'
                            for e in tick_positions])

        ax.view_init(30, 45)
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=2)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'combined_3d_negative_pair_evolution.png'), dpi=300, bbox_inches='tight')
        plt.savefig(os.path.join(output_dir, 'combined_3d_negative_pair_evolution.pdf'), bbox_inches='tight')
    except Exception as e:
        # The 3D view is optional; report the problem and carry on.
        print(f"Could not create 3D visualization: {e}")
    finally:
        # Close the figure even when plotting failed, so it does not leak.
        if fig is not None:
            plt.close(fig)


def visualize_combined_negative_pairs(momentum_files, basic_files, output_dir):
    """
    Compare negative pair distributions between momentum and basic approaches in combined visualizations

    Files written to ``output_dir`` (created if missing):
      * combined_negative_pair_statistics.png / .pdf - always.
      * negative_pair_comparison_by_epoch.png / .pdf - only when the two models
        share at least one epoch.
      * combined_3d_negative_pair_evolution.png / .pdf - same condition, and
        best-effort: a failure is printed instead of raised.

    NaN distances are dropped once, right after loading, so the statistics plot
    and the distribution plots are based on the same values.

    Args:
        momentum_files: Paths of pickles for the momentum model; each must load
            to a dict with 'epoch' and 'negative_distances'.
        basic_files: Same for the basic model.
        output_dir: Directory receiving the figures.
    """
    os.makedirs(output_dir, exist_ok=True)

    momentum_epochs, momentum_distances = _load_negative_pairs_by_epoch(momentum_files)
    basic_epochs, basic_distances = _load_negative_pairs_by_epoch(basic_files)
    common_epochs = sorted(set(momentum_epochs) & set(basic_epochs))

    _plot_negative_pair_statistics(
        momentum_epochs, _distance_statistics(momentum_distances),
        basic_epochs, _distance_statistics(basic_distances),
        common_epochs, output_dir,
    )

    # The remaining figures compare the models at shared epochs.
    if not common_epochs:
        return

    _plot_negative_pair_distributions(common_epochs, momentum_epochs, momentum_distances,
                                      basic_epochs, basic_distances, output_dir)
    _plot_negative_pair_3d(momentum_epochs, momentum_distances,
                           basic_epochs, basic_distances, output_dir)

def find_latest_file(directory, pattern):
    """Return the most recently modified path matching ``pattern`` in ``directory``.

    Returns None when nothing matches.
    """
    candidates = glob.glob(os.path.join(directory, pattern))
    return max(candidates, key=os.path.getmtime) if candidates else None

def find_embedding_files(directory, pattern):
    """Find all embedding files matching the pattern in the directory.

    Returns the matching paths sorted ascending. ``glob`` alone returns them in
    arbitrary filesystem order, and this definition replaces an earlier one in
    this file that already sorted its result, so sorting keeps both consistent.
    """
    return sorted(glob.glob(os.path.join(directory, pattern)))

def find_negative_pairs_files(directory, pattern):
    """Find all negative pairs files matching the pattern in the directory.

    Returns the matching paths sorted ascending. ``glob`` alone returns them in
    arbitrary filesystem order, and this definition replaces an earlier one in
    this file that already sorted its result, so sorting keeps both consistent.
    """
    return sorted(glob.glob(os.path.join(directory, pattern)))

def load_embedding_file(filepath):
    """Unpickle and return the object stored at ``filepath``.

    Relies on the module-level ``pickle`` import.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(filepath, 'rb') as stream:
        loaded = pickle.load(stream)
    return loaded

def find_latest_files(directory, pattern, num_files=4):
    """Find the N most recent files matching the pattern in the directory based on modification time.

    Args:
        directory: Directory to search in
        pattern: File pattern to match
        num_files: Number of most recent files to return (default: 4).
            Zero or a negative value returns an empty list.

    Returns:
        List of file paths sorted by modification time (newest first). Files
        with equal modification times are ordered by path, so the result does
        not depend on the order ``glob`` happens to return.
    """
    # A negative count would otherwise slice from the end of the list and
    # return the wrong files instead of "at most N".
    if num_files <= 0:
        return []

    # Sort by path first; the mtime sort is stable, so ties keep path order.
    files = sorted(glob.glob(os.path.join(directory, pattern)))
    newest_first = sorted(files, key=os.path.getmtime, reverse=True)

    # Slicing already copes with num_files exceeding the number of matches.
    return newest_first[:num_files]
    
def main(momentum_dir='./embeddings', basic_dir='./embeddings_basic',
         output_dir='./analysis_results', num_files_to_use=5):
    """Run the drift and negative-pair comparisons and save all figures.

    Args:
        momentum_dir: Directory with the files of GAN-CL with momentum encoder.
        basic_dir: Directory with the files of GAN-CL without momentum encoder.
        output_dir: Directory receiving the figures (created if missing).
        num_files_to_use: Use only this many most recently modified files per
            model and file kind. Pass None to use every matching file.
    """
    def epoch_of(entry):
        return entry['training_info']['epoch']

    def select(directory, pattern, find_all):
        """Return every match (``find_all``) or only the N most recent ones."""
        if num_files_to_use is None:
            return find_all(directory, pattern)
        return find_latest_files(directory, pattern, num_files_to_use)

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    momentum_embedding_files = select(momentum_dir, 'epoch_*_embeddings_*.pkl', find_embedding_files)
    basic_embedding_files = select(basic_dir, 'epoch_*_embeddings_*.pkl', find_embedding_files)

    print(f"Loading {len(momentum_embedding_files)} latest momentum embedding files")
    print(f"Loading {len(basic_embedding_files)} latest basic embedding files")

    # The file lists come back newest-first (or unordered), while the drift
    # analysis takes the last list element as its reference, so order the
    # loaded checkpoints by training epoch.
    momentum_embeddings_list = sorted(
        (load_embedding_file(f) for f in momentum_embedding_files), key=epoch_of)
    basic_embeddings_list = sorted(
        (load_embedding_file(f) for f in basic_embedding_files), key=epoch_of)

    # Visualize representation drift
    print("Analyzing representation drift...")
    visualize_combined_representation_drift(momentum_embeddings_list, basic_embeddings_list, output_dir)

    # Negative pairs files are selected the same way; the visualisation
    # orders them by epoch itself.
    momentum_negative_files = select(momentum_dir, 'negative_pairs_epoch_*.pkl', find_negative_pairs_files)
    basic_negative_files = select(basic_dir, 'basic_negative_pairs_epoch_*.pkl', find_negative_pairs_files)

    print(f"Loading {len(momentum_negative_files)} latest momentum negative pairs files")
    print(f"Loading {len(basic_negative_files)} latest basic negative pairs files")

    # Compare negative pairs
    print("Analyzing negative pairs...")
    visualize_combined_negative_pairs(momentum_negative_files, basic_negative_files, output_dir)

    print(f"Analysis complete. Results saved to {output_dir}")

# Run the full analysis only when executed as a script (or notebook cell),
# not when this module is imported.
if __name__ == "__main__":
    main()

Loading 5 latest momentum embedding files
Loading 5 latest basic embedding files
Analyzing representation drift...
Loading 5 latest momentum negative pairs files
Loading 5 latest basic negative pairs files
Analyzing negative pairs...


  plt.tight_layout()


Analysis complete. Results saved to ./analysis_results
