In [None]:

# Install the filter first so the FutureWarning whose message mentions
# torch.load / weights_only=False is silenced for the rest of the notebook.
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.load.*weights_only=False.*")

import time

# Provides SJLT, SJLT_batch, SJLT_reverse and backward_SJLT_indices, which the
# benchmark function below calls (presumably a project-local module — confirm).
import SJLT

import torch
from _dattri.utlis import compute_pairwise_distance_metrics, compute_pairwise_inner_product_rank_correlation
import matplotlib.pyplot as plt

# Ask TorchDynamo to suppress its own errors rather than raise them
# (NOTE(review): presumably this falls back to eager execution — confirm in the torch docs).
import torch._dynamo
torch._dynamo.config.suppress_errors = True

# Pick the compute device once for the whole notebook: the GPU when PyTorch
# reports CUDA as available, the CPU otherwise.
# (To force a CPU-only run, assign the CPU device here instead.)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# Target dimensions swept by the benchmark: powers of two from 2**8 (256)
# up to 2**14 (16384). Raise the upper exponent to 15 to also try 32768.
projection_dims = [2 ** exponent for exponent in range(8, 15)]

# Fraction of coordinates zeroed in every input vector before projecting.
# A finer 0.1-step grid (0, 0.1, ..., 0.9) is possible but takes longer to run.
sparsity_levels = [
    0,
    0.1,
    0.3,
    0.5,
    0.7,
    0.9,
]

def _sync_if_cuda(on_cuda):
    """Wait for pending CUDA work so wall-clock timestamps bracket it; no-op on CPU."""
    if on_cuda:
        torch.cuda.synchronize()


def _plot_metric_panels(projection_dims, sparsity_levels, panels):
    """Draw a 3x2 grid with one subplot per (series, title, y-label) entry.

    Each subplot shows the metric against the projection dimension, one line
    per sparsity level; ``series`` maps a sparsity level to its list of values.
    """
    fig, axes = plt.subplots(3, 2, figsize=(14, 18))
    for ax, (series, title, ylabel) in zip(axes.flatten(), panels):
        for sparsity in sparsity_levels:
            ax.plot(projection_dims, series[sparsity], marker='o', label=f'Sparsity: {sparsity}')
        ax.set_xticks(projection_dims)
        ax.set_title(title)
        ax.set_xlabel("Projection Dimension")
        ax.set_ylabel(ylabel)
        ax.legend()

    plt.tight_layout()
    plt.show()


def test_random_projection_quality(projection_dims, sparsity_levels=sparsity_levels, mode="SJLT", activation_fn="relu", c=20, blow_up=1):
    """Benchmark a random projection over a grid of sparsity levels and target dimensions.

    For every sparsity level, ten saved gradient batches are loaded from
    ``result/grad/{activation_fn}/grad_t_0_{i}.pt`` (i = 0..9), concatenated
    along dim 0, and a ``sparsity`` fraction of each row's coordinates is
    zeroed. For every projection dimension the batch is then projected and
    the following are recorded: the pairwise-distance metric, the rank
    correlation of pairwise inner products, set-up ("overhead") time,
    projection ("computation") time, their sum, and peak GPU memory in GB.
    Finally the six metrics are plotted in a 3x2 grid.

    Args:
        projection_dims: Iterable of target dimensions to evaluate.
        sparsity_levels: Iterable of fractions of coordinates to zero per row.
        mode: "SJLT", "SJLT_batch" or "SJLT_reverse" select the matching
            routine in the ``SJLT`` module; any other value uses a dense
            Gaussian matrix scaled by ``1 / sqrt(proj_dim)``.
        activation_fn: Sub-directory name used to locate the saved gradients.
        c: Number of random (index, sign) pairs drawn per input coordinate;
            forwarded unchanged to the SJLT routines.
        blow_up: Multiplier on the index range for "SJLT"/"SJLT_batch";
            forwarded to those routines.

    Returns:
        None. Progress is printed and the figure is shown with ``plt.show()``.
    """
    # One list of per-projection-dimension results for every sparsity level.
    relative_errors_dist = {sparsity: [] for sparsity in sparsity_levels}
    spearman_correlations = {sparsity: [] for sparsity in sparsity_levels}
    times = {sparsity: [] for sparsity in sparsity_levels}
    overhead_times = {sparsity: [] for sparsity in sparsity_levels}
    computation_times = {sparsity: [] for sparsity in sparsity_levels}
    memory_usage = {sparsity: [] for sparsity in sparsity_levels}

    for sparsity in sparsity_levels:
        print(f"\tSparsity: {sparsity}")

        # Reload for every sparsity level: the sparsification below zeroes
        # entries in place, so the previous level's batch cannot be reused.
        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints, or pass weights_only=True if the installed torch supports it.
        batch_vec = torch.cat(
            [
                torch.load(f"result/grad/{activation_fn}/grad_t_0_{i}.pt", map_location=device)
                for i in range(10)
            ],
            dim=0,
        )

        batch_size, original_dim = batch_vec.size()
        # CUDA-only calls (synchronize, memory stats) are skipped on the CPU fallback.
        on_cuda = batch_vec.is_cuda

        # Sparsify: every row gets its own random set of coordinates zeroed.
        num_elements_to_drop = int(sparsity * original_dim)
        if num_elements_to_drop > 0:
            for i in range(batch_size):
                indices_to_drop = torch.randperm(original_dim)[:num_elements_to_drop]
                batch_vec[i, indices_to_drop] = 0

        for proj_dim in projection_dims:
            if on_cuda:
                torch.cuda.reset_peak_memory_stats()
            print(f"\t\tProjection dimension: {proj_dim}")

            # --- Overhead: drawing the random projection parameters ---
            _sync_if_cuda(on_cuda)
            overhead_start_time = time.perf_counter()

            if mode in ["SJLT", "SJLT_batch"]:
                rand_indices = torch.randint(proj_dim * blow_up, (original_dim, c), device=device)
                rand_signs = torch.randint(0, 2, (original_dim, c), device=device) * 2 - 1
            elif mode == "SJLT_reverse":
                rand_indices = torch.randint(proj_dim, (original_dim, c), device=device)
                rand_signs = torch.randint(0, 2, (original_dim, c), device=device) * 2 - 1
                pos_indices, neg_indices = SJLT.backward_SJLT_indices(original_dim, proj_dim, c, device, rand_indices, rand_signs)
            else:
                proj_matrix = torch.randn(proj_dim, original_dim, device=batch_vec.device) / (proj_dim ** 0.5)

            _sync_if_cuda(on_cuda)
            overhead_end_time = time.perf_counter()

            # --- Computation: applying the projection ---
            # The device is already idle after the sync above, so no second sync is needed.
            computation_start_time = time.perf_counter()

            # Forward the caller's `c` so it always matches the width of the
            # index/sign tensors drawn above.
            if mode == "SJLT":
                batch_vec_p = SJLT.SJLT(batch_vec, proj_dim, rand_indices=rand_indices, rand_signs=rand_signs, c=c, blow_up=blow_up)
            elif mode == "SJLT_batch":
                batch_vec_p = SJLT.SJLT_batch(batch_vec, proj_dim, rand_indices=rand_indices, rand_signs=rand_signs, c=c, blow_up=blow_up)
            elif mode == "SJLT_reverse":
                batch_vec_p = SJLT.SJLT_reverse(batch_vec, proj_dim, pos_indices=pos_indices, neg_indices=neg_indices, c=c)
            else:
                batch_vec_p = batch_vec @ proj_matrix.T

            _sync_if_cuda(on_cuda)
            computation_end_time = time.perf_counter()

            # Quality metrics comparing the original and projected batches.
            relative_error = compute_pairwise_distance_metrics(batch_vec, batch_vec_p)
            relative_errors_dist[sparsity].append(relative_error)

            spearman_corr = compute_pairwise_inner_product_rank_correlation(batch_vec, batch_vec_p)
            spearman_correlations[sparsity].append(spearman_corr)

            # Timings.
            overhead_time = overhead_end_time - overhead_start_time
            computation_time = computation_end_time - computation_start_time
            total_time = overhead_time + computation_time

            overhead_times[sparsity].append(overhead_time)
            computation_times[sparsity].append(computation_time)
            times[sparsity].append(total_time)

            # Peak GPU memory since the reset above, in GB; NaN when running on CPU.
            peak_mem = torch.cuda.max_memory_allocated() if on_cuda else float("nan")
            memory_usage[sparsity].append(peak_mem / 1024**3)

    # One panel per recorded metric, in row-major order of the 3x2 grid.
    panels = [
        (relative_errors_dist, "Relative Error Between Pairwise Distances vs. Projection Dimension", "Relative Error (Mean)"),
        (spearman_correlations, "Spearman Rank Correlation of Pairwise Inner Products vs. Projection Dimension", "Spearman Rank Correlation"),
        (overhead_times, "Overhead Time vs. Projection Dimension", "Overhead Time (seconds)"),
        (computation_times, "Computation Time vs. Projection Dimension", "Computation Time (seconds)"),
        (times, "Total Time vs. Projection Dimension", "Total Time (seconds)"),
        (memory_usage, "Peak GPU Memory Usage vs. Projection Dimension", "Peak Memory Usage (GB)"),
    ]
    _plot_metric_panels(projection_dims, sparsity_levels, panels)

In [None]:
# Run the benchmark for every saved activation-function variant, trying each
# SJLT implementation in turn with the same c / blow_up settings.
for activation_fn in ("relu", "tanh", "sigmoid", "leaky_relu", "linear"):
    print(f"Activation function: {activation_fn}")
    for projection_mode in ("SJLT", "SJLT_batch", "SJLT_reverse"):
        test_random_projection_quality(
            projection_dims,
            mode=projection_mode,
            activation_fn=activation_fn,
            c=20,
            blow_up=1,
        )

Activation function: relu
	Sparsity: 0
		Projection dimension: 256
		Projection dimension: 512
