# Visualize trajectories in parameter space

## Setup

In [None]:
from typing import Dict, List
from pathlib import Path
import json

import math
import numpy as np
import torch
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA, TruncatedSVD, IncrementalPCA

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns

from tqdm.auto import tqdm

## Load trajectories

In [None]:
# Absolute path of the sibling `experiments` directory, used as the root for experiment lookups.
experiments_root = (Path("..") / "experiments").resolve()

In [None]:
# Dataset and experiment whose runs are analysed below.
dataset = "duplicated_0.5"
experiment = "parameters_evolution"
# Alternative experiments (uncomment one to switch):
# experiment = "parameter_evolution_128"
# experiment = "parameter_evolution_256"

# Model sub-directory read inside each run directory.
model = "bert_based"
# model = "lstm_based"

# Directory scanned for one sub-directory per run.
experiments_dir = experiments_root / f"{dataset}/{experiment}"

# Upper bound on how many snapshots load_trajectory keeps from each .pt file.
params_per_epoch = 2

In [None]:
# IPython shell escape: list the contents of experiments_dir.
!ls $experiments_dir

In [None]:
# Map each run directory name to its path.
# Only directories are kept (a stray file would break the later
# `experiment_dir / model` lookup), and entries are sorted because iterdir()
# yields them in arbitrary order, which would otherwise make the order of the
# trajectories (and therefore of plot colours and legend entries) vary
# between runs of the notebook.
experiments_dirs = {
    subdir.name: subdir
    for subdir in sorted(experiments_dir.iterdir())
    if subdir.is_dir()
}

In [None]:
# Display the discovered run directories.
experiments_dirs

In [None]:
# def load_epoch_param(experiment_path: Path, epoch: int = 0) -> torch.Tensor:

#     # Find paths to .pt files
#     pt_files = [path for path in experiment_path.iterdir() if path.name.startswith("param") and path.suffix == ".pt"]
    
#     # Sort them in increasing order of epochs
#     pt_files.sort(key=lambda s: int(s.stem.split("_")[-1]))

#     # Load tensors from pt files
#     tensors = []

#     for path in pt_files:
#         if f"epoch_{epoch}" not in path.name:
#             continue
        
#         print(path)
#         t = torch.load(path)
        
#         if t.dim() < 2:
#             t = t.unsqueeze(0)
        
# #         if "logarithmic" in experiment_path.name and "0" in path.name:
# #             t = t[:-1, :]
        
#         tensors.append(t)

#     # Stack tensors
#     trajectory = torch.cat(tensors, dim=0)

#     return trajectory

In [None]:
def load_trajectory(experiment_path: Path) -> torch.Tensor:
    """Load the parameter snapshots saved in a run directory as one trajectory.

    Every ``param*.pt`` file directly inside ``experiment_path`` is loaded, in
    increasing order of the integer that follows the last underscore of its
    file name. A loaded tensor with fewer than two dimensions is treated as a
    single snapshot. From each file at most ``params_per_epoch`` (module-level
    setting) evenly spaced snapshots are kept, always starting with the first.

    Args:
        experiment_path (Path): directory containing the ``param*.pt`` files

    Returns:
        torch.Tensor: the kept snapshots stacked along a new first dimension

    Raises:
        RuntimeError: if the directory contains no ``param*.pt`` file
    """

    # Find paths to .pt files
    pt_files = [
        path
        for path in experiment_path.iterdir()
        if path.name.startswith("param") and path.suffix == ".pt"
    ]

    # Fail early and name the directory: torch.stack on an empty list would
    # otherwise raise a RuntimeError that does not say what is missing.
    if not pt_files:
        raise RuntimeError(f"No 'param*.pt' files found in {experiment_path}")

    # Sort them in increasing order of epochs
    pt_files.sort(key=lambda s: int(s.stem.split("_")[-1]))

    # Load tensors from pt files
    tensors = []

    for path in pt_files:
        print(path)
        t = torch.load(path)

        # A single snapshot is promoted to a one-row batch
        if t.dim() < 2:
            t = t.unsqueeze(0)

        # Keep only params_per_epoch tensors: rows 0, increment, 2 * increment, ...
        increment = max(1, math.ceil(len(t) / params_per_epoch))
        tensors.extend(t[::increment])

    # Stack tensors
    return torch.stack(tensors, dim=0)

In [None]:
# Load one trajectory per run directory, keyed by the run (deduplicator) name.
trajectories: Dict[str, torch.Tensor] = {}

for deduplicator, experiment_dir in experiments_dirs.items():
    print(deduplicator)
    trajectories[deduplicator] = load_trajectory(experiment_dir / model)
    print()

In [None]:
# Report the name and shape of every loaded trajectory, one per line.
for k, traj in trajectories.items():
    print(k, traj.shape, sep="\n")

In [None]:
# The "dummy" run is the reference every other run is compared against.
baseline = trajectories["dummy"]

# Names of all remaining runs, in loading order.
deduplicators = [name for name in trajectories if name != "dummy"]

## Compute final distance

In [None]:
# Euclidean distance between the last point of each trajectory and the last
# point of the baseline trajectory.
final_distances: Dict[str, float] = {}

final_baseline_vector = baseline[-1]

for deduplicator in deduplicators:
    final_vector = trajectories[deduplicator][-1]
    final_distances[deduplicator] = torch.norm(final_vector - final_baseline_vector).item()

In [None]:
# Display the final distances.
final_distances

In [None]:
# Persist the final distances as JSON under ./<experiment>/<model>/.
final_distances_path = Path(f"./{experiment}/{model}/final_distances.json")

# Create the output directory if it does not exist yet.
final_distances_path.parent.mkdir(parents=True, exist_ok=True)
final_distances_path.write_text(json.dumps(final_distances, indent=4))

## Compute trajectory distance

In [None]:
def compute_distance(trajectory_a: torch.Tensor, trajectory_b: torch.Tensor) -> float:
    """Computes the distance between two trajectories, intended as the mean point-to-point distance
    between corresponding points along the two trajectories.

    ASSUMPTION: the two trajectories contain the same number of points
    ASSUMPTION: the space under consideration is Euclidean, or in any case Euclidean distance is meaningful

    Args:
        trajectory_a (torch.Tensor): float tensor of shape [num_points, dim]
        trajectory_b (torch.Tensor): float tensor of shape [num_points, dim]

    Returns:
        float: mean point-to-point distance

    Raises:
        ValueError: if the two trajectories do not have the same shape
    """

    # Enforce the documented assumption instead of letting broadcasting silently
    # pair up points that do not correspond to each other.
    if trajectory_a.shape != trajectory_b.shape:
        raise ValueError(
            "trajectories must have the same shape, got "
            f"{tuple(trajectory_a.shape)} and {tuple(trajectory_b.shape)}"
        )

    # [num_points,]
    # Plain L2 norm of the difference. torch.nn.functional.pairwise_distance adds
    # a small eps to every coordinate of the difference, which gives identical
    # points a non-zero distance that grows with the dimensionality and disagrees
    # with the torch.norm-based distances computed elsewhere in this notebook.
    distances = torch.norm(trajectory_a - trajectory_b, p=2, dim=-1)

    # float
    return distances.mean().item()

In [None]:
# Mean point-to-point distance of every run's trajectory from the baseline.
trajectory_distances: Dict[str, float] = {}

for deduplicator in deduplicators:
    print(deduplicator)
    trajectory_distances[deduplicator] = compute_distance(
        trajectory_a=trajectories[deduplicator],
        trajectory_b=baseline,
    )

In [None]:
# Display the trajectory distances.
trajectory_distances

In [None]:
# Persist the trajectory distances as JSON under ./<experiment>/<model>/.
trajectory_distances_path = Path(f"./{experiment}/{model}/trajectory_distances.json")

# Create the output directory here as well, so that this cell does not depend
# on the final-distances cell having been run first.
trajectory_distances_path.parent.mkdir(parents=True, exist_ok=True)
trajectory_distances_path.write_text(json.dumps(trajectory_distances, indent=4))

In [None]:
# Euclidean distance between the first point of each trajectory and the first
# point of the baseline trajectory.
initial_distances: Dict[str, float] = dict()

for deduplicator in deduplicators:
    trajectory = trajectories[deduplicator]

    # .item() turns the 0-dim tensor returned by torch.norm into a Python float,
    # matching the Dict[str, float] annotation and how final_distances is built.
    distance_from_baseline = torch.norm(baseline[0, :] - trajectory[0, :]).item()

    initial_distances[deduplicator] = distance_from_baseline

In [None]:
# Display the initial distances.
initial_distances

## Project to 2D

In [None]:
# Seed passed as random_state to the TruncatedSVD, PCA and t-SNE calls below.
tsne_seed = 0

In [None]:
# Pre-allocate the matrix that will hold the points of all trajectories,
# one point per row.
num_points_per_deduplicator = baseline.shape[0]
dim = baseline.shape[1]

tot_num_points = num_points_per_deduplicator * len(trajectories)

X = torch.empty(tot_num_points, dim)

In [None]:
# Order in which trajectories are laid out and drawn: the baseline first,
# then every other run.
deduplicator_names: List[str] = ["dummy", *deduplicators]

In [None]:
# Copy each trajectory into its own contiguous block of rows of X,
# following the order of deduplicator_names.
inc = num_points_per_deduplicator

for position, name in enumerate(deduplicator_names):
    start = position * inc
    end = start + inc

    X[start:end, :] = trajectories[name]

In [None]:
# Display the shape of the stacked matrix.
X.shape

In [None]:
# def cuPCA(A: torch.Tensor, n_components: int, random_state=0):
#     np.random.seed(random_state)
#     torch.manual_seed(random_state)
    
#     A = A.cuda()
    
#     U, S, V = torch.pca_lowrank(A, q=n_components)
    
#     proj = torch.matmul(A, V[:, :n_components])
    
#     proj = proj.cpu()
    
#     return proj

In [None]:
# X_ = PCA(n_components=min(50, X.shape[0]), random_state=tsne_seed).fit_transform(X)

In [None]:
# X_baseline = X[:baseline.shape[0], :].clone()
# X_baseline = X_baseline.to('cuda:4')

# n_components = min(50, X_baseline.shape[0])
# U, S, V = torch.pca_lowrank(X_baseline, q=n_components)
# V = V.cpu()

# X_ = torch.matmul(X_, V)

In [None]:
# Reduce X to at most 50 components with a truncated SVD (seeded with
# tsne_seed) before computing the 2-D projections.
X_ = TruncatedSVD(n_components=min(50, X.shape[0]), random_state=tsne_seed).fit_transform(X)

In [None]:
# X_ = IncrementalPCA(n_components=min(50, X.shape[0]), batch_size=min(50, X.shape[0])).fit_transform(X)

In [None]:
# File in which the reduced matrix X_ is cached.
X_path = f"./{experiment}/{model}/X_.pt"

In [None]:
# Cache the reduced matrix on disk.
# NOTE(review): assumes ./<experiment>/<model>/ already exists (it is created
# by the final-distances cell) — confirm when running this cell on its own.
torch.save(X_, X_path)

In [None]:
# Reload the cached reduced matrix (lets the SVD cell be skipped on later runs).
# NOTE(review): X_ is saved as the output of fit_transform, presumably a NumPy
# array; recent PyTorch releases default torch.load to weights_only=True, which
# may refuse to unpickle it — confirm against the installed version.
X_ = torch.load(X_path)

In [None]:
# 2-D projections, keyed by 'pca' or by a t-SNE perplexity value.
X_2d = dict()
# t-SNE perplexities to try.
perplexities = [5, 30, 50, 100]

In [None]:
# PCA projection of the reduced points onto two components.
X_2d['pca'] = PCA(n_components=2, random_state=tsne_seed).fit_transform(X_)

# One t-SNE embedding per perplexity, all sharing the same seed and PCA initialisation.
for perplexity in perplexities:
    X_2d[perplexity] = TSNE(
        perplexity=perplexity,
        random_state=tsne_seed,
        init='pca',
        learning_rate='auto',
    ).fit_transform(X_)

## Visualize

In [None]:
# Use seaborn's "dark" style for the figures below.
sns.set_style("dark")
# sns.reset_orig()

# Colormap from which draw_trajectories samples one colour per trajectory.
# colormap = cm.jet
colormap = cm.tab10

# Drawing parameters read by draw_trajectories.
plt_params = dict(
    # Arrow geometry forwarded to ax.quiver
    quiver=dict(width=0.002, headwidth=3, headlength=5),
    # Marker sizes: 's_quad' for the start diamond, 's_star' for the end star
    scatter=dict(s_star=2 ** 8, s_quad=2 ** 7),
)

# Keyword arguments for ax.tick_params that switch off every tick and every
# tick label on both axes. Each key appears exactly once.
tick_params = {
    'axis': 'both',
    'which': 'both',
    'bottom': False,
    'top': False,
    'left': False,
    'right': False,
    'labelbottom': False,
    'labeltop': False,
    'labelleft': False,
    'labelright': False,
}

In [None]:
def draw_trajectories(points_2d, labels, path=None):
    """Draw one arrow-chained trajectory per label on a single figure.

    ``points_2d`` is split into ``len(labels)`` consecutive, equally sized
    groups of rows; the i-th group is drawn as the trajectory of the i-th
    label. Each trajectory starts at a diamond marker, follows arrows between
    consecutive points and ends at a star marker. The figure is shown, and
    also saved when ``path`` is given.

    Args:
        points_2d: 2-D points, one per row (column 0 is x, column 1 is y)
        labels: legend label of each trajectory, in row-group order
        path: where to save the figure; nothing is saved when falsy
    """
    fig, ax = plt.subplots(figsize=(16, 9))

    # One colour per trajectory, sampled evenly across the colormap
    colors = list(map(colormap, np.linspace(0, 1, num=len(labels))))

    # Number of points in each trajectory
    inc = points_2d.shape[0] // len(labels)

    quiver_style = plt_params['quiver']
    marker_sizes = plt_params['scatter']

    for i, label in enumerate(labels):
        start = i * inc
        end = start + inc

        # Arrow tails are all points but the last; arrow heads all but the first
        tails = points_2d[start:end-1]
        heads = points_2d[start+1:end]

        xx, yy = tails[:, 0], tails[:, 1]
        xx_, yy_ = heads[:, 0], heads[:, 1]

        # Displacement from each point to the next one
        uu, vv = xx_ - xx, yy_ - yy

        # Diamond on the first point of the trajectory
        ax.scatter(
            xx[0], yy[0],
            marker="D", color=colors[i], edgecolors='black',
            s=marker_sizes['s_quad'],
        )

        # Arrows in data coordinates, so each one spans exactly one step
        ax.quiver(
            xx, yy, uu, vv,
            label=label, color=colors[i],
            angles='xy', scale_units='xy', scale=1,
            width=quiver_style['width'],
            headwidth=quiver_style['headwidth'],
            headlength=quiver_style['headlength'],
        )

        # Star on the last point of the trajectory
        ax.scatter(
            xx_[-1], yy_[-1],
            marker="*", color=colors[i], edgecolors='black',
            s=marker_sizes['s_star'],
        )
        ax.tick_params(**tick_params)

    plt.legend()

    fig.tight_layout()

    if path:
        plt.savefig(path)

    plt.show()

### PCA only

In [None]:
# Draw the PCA projection and save it as pca.png.
points_2d = X_2d['pca']

draw_trajectories(points_2d, deduplicator_names, f"./{experiment}/{model}/pca.png")

### T-sne

In [None]:
# Draw and save the t-SNE projection for the first perplexity value.
perplexity = perplexities[0]

print("T-sne with perplexity:", perplexity)

points_2d = X_2d[perplexity]

draw_trajectories(points_2d, deduplicator_names, f"./{experiment}/{model}/tsne_{perplexity}.png")

In [None]:
# Draw and save the t-SNE projection for the second perplexity value.
perplexity = perplexities[1]

print("T-sne with perplexity:", perplexity)

points_2d = X_2d[perplexity]

draw_trajectories(points_2d, deduplicator_names, f"./{experiment}/{model}/tsne_{perplexity}.png")

In [None]:
# Draw and save the t-SNE projection for the third perplexity value.
perplexity = perplexities[2]

print("T-sne with perplexity:", perplexity)

points_2d = X_2d[perplexity]

draw_trajectories(points_2d, deduplicator_names, f"./{experiment}/{model}/tsne_{perplexity}.png")

In [None]:
# Draw and save the t-SNE projection for the fourth perplexity value.
perplexity = perplexities[3]

print("T-sne with perplexity:", perplexity)

points_2d = X_2d[perplexity]

draw_trajectories(points_2d, deduplicator_names, f"./{experiment}/{model}/tsne_{perplexity}.png")