# In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
import glob
from sklearn.metrics.pairwise import cosine_similarity
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from scipy.stats import gaussian_kde

def load_embedding_file(filepath):
    """Unpickle and return the object stored at ``filepath``.

    The file is opened in binary mode and whatever object it contains is
    returned unchanged. Unpickling can execute arbitrary code, so this
    should only be pointed at files from a trusted source.
    """
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)

def load_negative_pairs_file(filepath):
    """Unpickle and return the negative-pair record stored at ``filepath``.

    The file is read in binary mode and the unpickled object is handed back
    as-is. Unpickling can execute arbitrary code, so only load files from a
    trusted source.
    """
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)

def find_embedding_files(directory, pattern):
    """Return the paths in ``directory`` matching the glob ``pattern``.

    The result is sorted as plain strings, so ``epoch_10`` sorts before
    ``epoch_2`` unless the epoch numbers are zero-padded.
    """
    matches = glob.glob(os.path.join(directory, pattern))
    matches.sort()
    return matches

def find_negative_pairs_files(directory, pattern):
    """Return the paths in ``directory`` matching the glob ``pattern``.

    The matches are returned in ascending string order.
    """
    matches = glob.glob(os.path.join(directory, pattern))
    matches.sort()
    return matches

def _paired_cosine_similarity(first, second):
    """Return the cosine similarity between matching rows of two arrays.

    Each sample (row) is flattened to a vector before it is compared. A row
    whose norm is zero yields a similarity of 0 rather than NaN.

    Args:
        first: Array with one sample per row.
        second: Array with the same number of rows and per-sample size.

    Returns:
        1-D float array holding one similarity per row.
    """
    count = first.shape[0]
    if count == 0:
        return np.zeros(0)

    a = np.asarray(first, dtype=float).reshape(count, -1)
    b = np.asarray(second, dtype=float).reshape(count, -1)
    norm_a = np.linalg.norm(a, axis=1)
    norm_b = np.linalg.norm(b, axis=1)
    # Avoid 0/0: the dot product involving a zero row is already 0.
    norm_a[norm_a == 0] = 1.0
    norm_b[norm_b == 0] = 1.0
    return np.sum(a * b, axis=1) / (norm_a * norm_b)


def compute_representation_drift(embeddings_list, reference_idx=-1):
    """
    Compute representation drift relative to a reference set of embeddings.

    For every entry, each sample is compared (cosine similarity) with the
    sample at the same row of the reference embeddings. When the two arrays
    have a different number of rows, only the leading rows they share are
    compared.

    Args:
        embeddings_list: List of dicts, each with an ``'embeddings'`` array
            and a ``'training_info'`` dict containing ``'epoch'``.
        reference_idx: Position in ``embeddings_list`` of the reference
            entry. The default ``-1`` selects the last element of the list,
            which is the final epoch only if the list is ordered by epoch.

    Returns:
        drift_scores: Dictionary mapping epoch to a dict with the ``'mean'``,
        ``'std'``, ``'min'`` and ``'max'`` of the per-sample similarities.
        An empty input gives an empty dictionary; an entry with no rows in
        common with the reference gets NaN statistics. If two entries share
        an epoch, the later one in the list overwrites the earlier one.
    """
    if not embeddings_list:
        return {}

    reference_embeddings = np.asarray(embeddings_list[reference_idx]['embeddings'])
    drift_scores = {}

    for embedding_data in embeddings_list:
        embeddings = np.asarray(embedding_data['embeddings'])
        epoch = embedding_data['training_info']['epoch']

        # Compare only the rows both arrays have.
        min_samples = min(embeddings.shape[0], reference_embeddings.shape[0])
        similarities = _paired_cosine_similarity(
            embeddings[:min_samples],
            reference_embeddings[:min_samples],
        )

        if similarities.size == 0:
            # np.min / np.max raise on empty input; report "no data" instead.
            nan = float('nan')
            drift_scores[epoch] = {'mean': nan, 'std': nan, 'min': nan, 'max': nan}
            continue

        drift_scores[epoch] = {
            'mean': np.mean(similarities),
            'std': np.std(similarities),
            'min': np.min(similarities),
            'max': np.max(similarities),
        }

    return drift_scores

def _mean_paired_cosine(first, second):
    """Return the mean cosine similarity between matching rows of two arrays.

    Only the leading rows present in both arrays are compared and each sample
    is flattened to a vector first. Zero-norm rows count as similarity 0.
    Returns NaN when the arrays share no rows.
    """
    first = np.asarray(first)
    second = np.asarray(second)
    count = min(first.shape[0], second.shape[0])
    if count == 0:
        return float('nan')

    a = first[:count].astype(float).reshape(count, -1)
    b = second[:count].astype(float).reshape(count, -1)
    norm_a = np.linalg.norm(a, axis=1)
    norm_b = np.linalg.norm(b, axis=1)
    # Avoid 0/0: the dot product involving a zero row is already 0.
    norm_a[norm_a == 0] = 1.0
    norm_b[norm_b == 0] = 1.0
    return float(np.mean(np.sum(a * b, axis=1) / (norm_a * norm_b)))


def _epoch_similarity_matrix(embeddings_by_epoch, epochs):
    """Build the symmetric epoch-by-epoch mean-similarity matrix.

    Row/column ``i`` corresponds to ``epochs[i]``; embeddings are looked up
    by epoch so the matrix always agrees with its tick labels.
    """
    size = len(epochs)
    matrix = np.zeros((size, size))
    for i in range(size):
        # Similarity is symmetric, so only the upper triangle is computed.
        for j in range(i, size):
            value = _mean_paired_cosine(
                embeddings_by_epoch[epochs[i]],
                embeddings_by_epoch[epochs[j]],
            )
            matrix[i, j] = value
            matrix[j, i] = value
    return matrix


def visualize_representation_drift(momentum_embeddings_list, basic_embeddings_list, output_dir):
    """
    Visualize representation drift comparison between momentum and basic encoders.

    Two figures are written to ``output_dir`` (created if missing), each as
    PNG and PDF:

    * ``representation_drift`` - mean cosine similarity of every epoch to the
      highest-epoch embeddings, with a +/- one standard deviation band.
    * ``epoch_similarity_heatmap`` - mean pairwise similarity between epochs.

    Args:
        momentum_embeddings_list: Entries for the momentum encoder; each is a
            dict with ``'embeddings'`` and ``'training_info'['epoch']``.
        basic_embeddings_list: Same structure for the basic encoder.
        output_dir: Directory that receives the figures.
    """
    def epoch_of(entry):
        return entry['training_info']['epoch']

    # The input order is not guaranteed to follow the epoch number (file names
    # sort as strings), so order by epoch explicitly. This makes the last
    # entry - the drift reference - the highest epoch.
    momentum_sorted = sorted(momentum_embeddings_list, key=epoch_of)
    basic_sorted = sorted(basic_embeddings_list, key=epoch_of)

    momentum_drift = compute_representation_drift(momentum_sorted) if momentum_sorted else {}
    basic_drift = compute_representation_drift(basic_sorted) if basic_sorted else {}

    momentum_epochs = sorted(momentum_drift)
    basic_epochs = sorted(basic_drift)

    # --- Figure 1: similarity to the final embeddings over training ---
    plt.figure(figsize=(12, 6))

    drift_series = (
        (momentum_epochs, momentum_drift, 'o', 'GAN-CL with Momentum', 'blue'),
        (basic_epochs, basic_drift, 'x', 'GAN-CL Basic', 'red'),
    )
    for epochs, drift, marker, label, color in drift_series:
        means = np.array([drift[e]['mean'] for e in epochs], dtype=float)
        stds = np.array([drift[e]['std'] for e in epochs], dtype=float)
        plt.plot(epochs, means, marker=marker, label=label, color=color)
        plt.fill_between(epochs, means - stds, means + stds, alpha=0.2, color=color)

    plt.xlabel('Training Epoch')
    plt.ylabel('Cosine Similarity to Final Embeddings')
    plt.title('Temporal Analysis of Representation Drift')
    plt.legend()
    plt.grid(linestyle='--', alpha=0.7)
    plt.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, 'representation_drift.png'), dpi=300)
    plt.savefig(os.path.join(output_dir, 'representation_drift.pdf'))
    plt.close()

    # --- Figure 2: pairwise similarity between epochs ---
    # Key the embeddings by epoch (later duplicates win, matching the drift
    # dictionary) instead of relying on list position.
    momentum_by_epoch = {epoch_of(e): e['embeddings'] for e in momentum_sorted}
    basic_by_epoch = {epoch_of(e): e['embeddings'] for e in basic_sorted}

    fig, axes = plt.subplots(1, 2, figsize=(16, 5))

    panels = (
        (axes[0], momentum_by_epoch, momentum_epochs, "Blues",
         "GAN-CL with Momentum: Embedding Similarity Between Epochs"),
        (axes[1], basic_by_epoch, basic_epochs, "Reds",
         "GAN-CL Basic: Embedding Similarity Between Epochs"),
    )
    for ax, by_epoch, epochs, cmap, panel_title in panels:
        matrix = _epoch_similarity_matrix(by_epoch, epochs)
        if matrix.size:
            # Skip drawing when there are no epochs; the axis stays empty.
            sns.heatmap(matrix, annot=True, fmt=".2f", cmap=cmap,
                        xticklabels=epochs, yticklabels=epochs,
                        vmin=0, vmax=1, ax=ax)
        ax.set_title(panel_title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Epoch")

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'epoch_similarity_heatmap.png'), dpi=300)
    plt.savefig(os.path.join(output_dir, 'epoch_similarity_heatmap.pdf'))
    plt.close(fig)

def analyze_negative_pairs(negative_pairs_files, title, color, output_dir, filename_prefix):
    """
    Analyze the evolution of negative pair distances over epochs.

    Each file is loaded with ``load_negative_pairs_file`` and must provide
    ``'negative_distances'`` and ``'epoch'``. Two figures are written to
    ``output_dir`` (created if missing), each as PNG and PDF:

    * ``<filename_prefix>_negative_pair_statistics`` - mean (+/- std band),
      min and max distance per epoch.
    * ``<filename_prefix>_negative_pair_distribution`` - kernel density
      estimates for up to five evenly spaced epochs.

    Args:
        negative_pairs_files: Paths of the pickled negative-pair records.
        title: Model name appended to the figure titles.
        color: Matplotlib colour used for every curve.
        output_dir: Directory that receives the figures.
        filename_prefix: Prefix of the output file names.
    """
    records = []
    for file_path in negative_pairs_files:
        data = load_negative_pairs_file(file_path)
        negative_distances = data['negative_distances']
        records.append((
            data['epoch'],
            np.mean(negative_distances),
            np.std(negative_distances),
            np.min(negative_distances),
            np.max(negative_distances),
            negative_distances,
        ))

    # Plot in epoch order regardless of the order the files were listed in.
    records.sort(key=lambda record: record[0])
    epochs = [record[0] for record in records]
    mean_distances = np.array([record[1] for record in records], dtype=float)
    std_distances = np.array([record[2] for record in records], dtype=float)
    min_distances = [record[3] for record in records]
    max_distances = [record[4] for record in records]
    distance_distributions = [record[5] for record in records]

    # Create statistics plot
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, mean_distances, marker='o', label='Mean Distance', color=color)
    plt.fill_between(epochs,
                     mean_distances - std_distances,
                     mean_distances + std_distances,
                     alpha=0.2, color=color)
    plt.plot(epochs, min_distances, linestyle='--', marker='x', label='Min Distance', color=color, alpha=0.5)
    plt.plot(epochs, max_distances, linestyle='--', marker='+', label='Max Distance', color=color, alpha=0.5)

    plt.xlabel('Training Epoch')
    plt.ylabel('Negative Pair Distance')
    plt.title(f'Evolution of Negative Pair Distances: {title}')
    plt.legend()
    plt.grid(linestyle='--', alpha=0.7)
    plt.tight_layout()

    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, f'{filename_prefix}_negative_pair_statistics.png'), dpi=300)
    plt.savefig(os.path.join(output_dir, f'{filename_prefix}_negative_pair_statistics.pdf'))
    plt.close()

    # Create distribution evolution plot
    plt.figure(figsize=(12, 8))

    # Select a subset of epochs for clearer visualization
    if len(epochs) > 5:
        indices = np.linspace(0, len(epochs) - 1, 5, dtype=int)
        selected_epochs = [epochs[i] for i in indices]
        selected_distributions = [distance_distributions[i] for i in indices]
    else:
        selected_epochs = epochs
        selected_distributions = distance_distributions

    # Plot KDE for selected epochs
    for i, (epoch, distribution) in enumerate(zip(selected_epochs, selected_distributions)):
        # Work on a flat float array so list inputs are accepted too, and
        # drop NaN/inf values, which a KDE cannot handle.
        values = np.asarray(distribution, dtype=float).ravel()
        values = values[np.isfinite(values)]

        # gaussian_kde needs at least two points with non-zero spread;
        # otherwise it raises, so such epochs are skipped.
        if values.size < 2 or values.max() == values.min():
            continue

        kde = gaussian_kde(values)
        # Start at 0 as before, but extend left if the data goes negative.
        x = np.linspace(min(0.0, values.min()), values.max(), 1000)
        y = kde(x)

        # Later epochs are drawn more opaque.
        alpha = 0.4 + 0.6 * (i / (len(selected_epochs) - 1 or 1))
        plt.plot(x, y, label=f'Epoch {epoch}', color=color, alpha=alpha)
        plt.fill_between(x, 0, y, alpha=0.1, color=color)

    plt.xlabel('Negative Pair Distance')
    plt.ylabel('Density')
    plt.title(f'Evolution of Negative Pair Distance Distribution: {title}')
    plt.legend()
    plt.grid(linestyle='--', alpha=0.4)
    plt.tight_layout()

    plt.savefig(os.path.join(output_dir, f'{filename_prefix}_negative_pair_distribution.png'), dpi=300)
    plt.savefig(os.path.join(output_dir, f'{filename_prefix}_negative_pair_distribution.pdf'))
    plt.close()

def _finite_distances(values):
    """Return ``values`` as a flat float array with NaN/inf entries removed."""
    flat = np.asarray(values, dtype=float).ravel()
    return flat[np.isfinite(flat)]


def compare_negative_pairs(momentum_files, basic_files, output_dir):
    """
    Compare negative pair distributions between momentum and basic approaches.

    First runs ``analyze_negative_pairs`` for each model. Then, for up to
    three epochs present in both file sets (first, middle and last common
    epoch), writes to ``output_dir`` an overlaid histogram
    (``negative_pair_comparison_epoch_<epoch>``) and, when both samples allow
    a density estimate, an overlaid KDE plot
    (``negative_pair_kde_comparison_epoch_<epoch>``), each as PNG and PDF.

    Args:
        momentum_files: Paths of the momentum model's negative-pair records.
        basic_files: Paths of the basic model's negative-pair records.
        output_dir: Directory that receives the figures (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)

    # Analyze each model type individually
    analyze_negative_pairs(momentum_files, 'GAN-CL with Momentum', 'blue', output_dir, 'momentum')
    analyze_negative_pairs(basic_files, 'GAN-CL Basic', 'red', output_dir, 'basic')

    # Index the records by epoch; the first file seen for an epoch wins.
    momentum_by_epoch = {}
    for record in (load_negative_pairs_file(f) for f in momentum_files):
        momentum_by_epoch.setdefault(record['epoch'], record)
    basic_by_epoch = {}
    for record in (load_negative_pairs_file(f) for f in basic_files):
        basic_by_epoch.setdefault(record['epoch'], record)

    common_epochs = sorted(set(momentum_by_epoch).intersection(basic_by_epoch))
    if not common_epochs:
        return

    # Select a subset of common epochs for comparison
    if len(common_epochs) > 3:
        comparison_epochs = [
            common_epochs[0],
            common_epochs[len(common_epochs) // 2],
            common_epochs[-1],
        ]
    else:
        comparison_epochs = common_epochs

    for epoch in comparison_epochs:
        # Filter before plotting: a NaN in the data makes plt.hist fail
        # because the automatic bin range is not finite.
        momentum_dist = _finite_distances(momentum_by_epoch[epoch]['negative_distances'])
        basic_dist = _finite_distances(basic_by_epoch[epoch]['negative_distances'])

        # Histogram comparison
        plt.figure(figsize=(10, 6))
        plt.hist(momentum_dist, bins=50, alpha=0.5, label='GAN-CL with Momentum', color='blue')
        plt.hist(basic_dist, bins=50, alpha=0.5, label='GAN-CL Basic', color='red')

        plt.xlabel('Negative Pair Distance')
        plt.ylabel('Frequency')
        plt.title(f'Comparison of Negative Pair Distributions at Epoch {epoch}')
        plt.legend()
        plt.grid(linestyle='--', alpha=0.4)
        plt.tight_layout()

        plt.savefig(os.path.join(output_dir, f'negative_pair_comparison_epoch_{epoch}.png'), dpi=300)
        plt.savefig(os.path.join(output_dir, f'negative_pair_comparison_epoch_{epoch}.pdf'))
        plt.close()

        # KDE comparison. gaussian_kde needs at least two points with
        # non-zero spread in each sample; skip the figure otherwise so no
        # empty figure is left open.
        kde_possible = all(
            sample.size >= 2 and sample.max() > sample.min()
            for sample in (momentum_dist, basic_dist)
        )
        if not kde_possible:
            continue

        momentum_kde = gaussian_kde(momentum_dist)
        basic_kde = gaussian_kde(basic_dist)

        x_min = min(np.min(momentum_dist), np.min(basic_dist))
        x_max = max(np.max(momentum_dist), np.max(basic_dist))
        x = np.linspace(x_min, x_max, 1000)
        # Evaluate each density once and reuse it for the line and the fill.
        momentum_density = momentum_kde(x)
        basic_density = basic_kde(x)

        plt.figure(figsize=(10, 6))
        plt.plot(x, momentum_density, label='GAN-CL with Momentum', color='blue')
        plt.fill_between(x, 0, momentum_density, alpha=0.2, color='blue')

        plt.plot(x, basic_density, label='GAN-CL Basic', color='red')
        plt.fill_between(x, 0, basic_density, alpha=0.2, color='red')

        plt.xlabel('Negative Pair Distance')
        plt.ylabel('Density')
        plt.title(f'Comparison of Negative Pair Distributions at Epoch {epoch}')
        plt.legend()
        plt.grid(linestyle='--', alpha=0.4)
        plt.tight_layout()

        plt.savefig(os.path.join(output_dir, f'negative_pair_kde_comparison_epoch_{epoch}.png'), dpi=300)
        plt.savefig(os.path.join(output_dir, f'negative_pair_kde_comparison_epoch_{epoch}.pdf'))
        plt.close()

def main():
    """Run the drift and negative-pair analyses on the default directories.

    Embeddings and negative-pair records are read from ``./embeddings``
    (momentum encoder) and ``./embeddings_basic`` (no momentum encoder);
    all figures go to ``./analysis_results``.
    """
    output_dir = './analysis_results'
    momentum_dir = './embeddings'  # GAN-CL with momentum encoder
    basic_dir = './embeddings_basic'  # GAN-CL without momentum encoder

    os.makedirs(output_dir, exist_ok=True)

    # Locate the embedding snapshots of both models, then load them all.
    embedding_pattern = 'epoch_*_embeddings_*.pkl'
    momentum_embedding_files = find_embedding_files(momentum_dir, embedding_pattern)
    basic_embedding_files = find_embedding_files(basic_dir, embedding_pattern)

    momentum_embeddings_list = list(map(load_embedding_file, momentum_embedding_files))
    basic_embeddings_list = list(map(load_embedding_file, basic_embedding_files))

    print("Analyzing representation drift...")
    visualize_representation_drift(momentum_embeddings_list, basic_embeddings_list, output_dir)

    # The two models store their negative-pair records under different names.
    momentum_negative_files = find_negative_pairs_files(
        momentum_dir, 'negative_pairs_epoch_*.pkl'
    )
    basic_negative_files = find_negative_pairs_files(
        basic_dir, 'basic_negative_pairs_epoch_*.pkl'
    )

    print("Analyzing negative pairs...")
    compare_negative_pairs(momentum_negative_files, basic_negative_files, output_dir)

    print(f"Analysis complete. Results saved to {output_dir}")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

# Output:
# Analyzing representation drift...
# Analyzing negative pairs...
# Analysis complete. Results saved to ./analysis_results
