# In [None]:
import torch
import pickle
import argparse
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

import matplotlib
matplotlib.rcParams.update({'font.size': 13})


def idx2pos(idx: np.array) -> np.array:
    '''Convert target-index entries into 0-based token positions.

    Only every second entry of ``idx`` (indices 0, 2, 4, ...) is used. The
    first whitespace-separated field of each such entry is parsed as an
    integer and shifted down by one, because the CLS token was removed.
    '''
    every_other = idx[::2]
    positions = [int(entry.split()[0]) - 1 for entry in every_other]
    return np.array(positions)


def lexical_replacement_distance(X: torch.tensor, pos: int) -> float:
    '''Return the diagonal entry ``X[pos, pos]`` as a Python float.

    This is read as the cosine distance between the word at ``pos`` and its
    replacement. It assumes ``X`` is a square distance matrix between the
    original and the replaced sentence tokens (TODO confirm).

    Accepts both ``torch.Tensor`` and ``numpy.ndarray`` inputs. ``X.shape``
    is used rather than the torch-only ``X.size(0)``, which fails on numpy,
    where ``.size`` is an int.

    Raises:
        IndexError: if ``pos`` is negative or ``>= X.shape[0]``.
    '''
    n_rows = X.shape[0]
    # Reject negatives explicitly. Both torch and numpy would otherwise wrap
    # them around and silently return a distance from the end of the matrix.
    if not 0 <= pos < n_rows:
        raise IndexError(f"Index {pos} is out of bounds for matrix size {n_rows}")
    return X[pos, pos].item()


def context_replacement_distances(X: torch.tensor, pos: int, n_words: int = 0) -> np.array:
    '''Return the diagonal of ``X`` around ``pos``, excluding ``pos`` itself.

    The entry at ``pos`` is dropped first. The remaining diagonal values are
    then windowed to at most ``n_words`` entries before and ``n_words``
    entries after the removed position. ``n_words == 0`` means no window
    limit: every other diagonal entry is returned.
    '''
    size = X.shape[0]
    window = size if n_words == 0 else n_words
    diag = X.diagonal()
    # Everything except the target position; indices after `pos` shift left by one.
    others = np.concatenate([diag[:pos], diag[pos + 1:]])
    lo = max(pos - window, 0)
    hi = min(pos + window, size)
    return others[lo:hi]


def get_baselines(model, baseline_filename: str = 'random', n_layers: int = 4, n_words: int = 0):
    '''Compute per-layer baseline distances for one relation file.

    Target positions are loaded from
    ``{model}/target_index/{baseline_filename}.npy`` via ``idx2pos``.

    For every layer ``1..n_layers``, the pickled object at
    ``{model}/cosine_distances/{layer}/{baseline_filename}.pkl`` is loaded.
    It is indexed per sentence; each item is presumably a square distance
    matrix (TODO confirm). Two averages are then taken over the sentences:

    * the lexical distance at each sentence's target position, and
    * the mean context distance within ``n_words`` of that position
      (``n_words == 0`` uses the whole sentence).

    Returns:
        ``(wd_baseline, cd_baseline)``: lexical and context baselines, one
        float per layer and always the same length. A layer that fails to
        load or process is printed and recorded as ``nan``, not ``0``, so it
        cannot be mistaken for a real distance. If the target-index file
        cannot be loaded, ``([], [])`` is returned.
    '''
    idx_filename = f'{model}/target_index/{baseline_filename}.npy'
    try:
        idx_target = idx2pos(np.load(idx_filename))
    except Exception as e:
        print(f"Error loading target index file: {e}")
        return [], []

    wd_baseline, cd_baseline = [], []
    for layer in range(1, n_layers + 1):
        path = f'{model}/cosine_distances/{layer}/{baseline_filename}.pkl'
        try:
            # NOTE: pickle.load can execute arbitrary code; only load files
            # produced by this project's own pipeline.
            with open(path, mode='rb') as f:
                cd_matrix = pickle.load(f)

            # zip() pairs each sentence with its target position and stops at
            # the shorter of the two sequences.
            pairs = list(zip(cd_matrix, idx_target))
            cd_values = [
                np.mean(context_replacement_distances(matrix, target, n_words))
                for matrix, target in pairs
            ]
            cd_mean = np.mean(cd_values)
            wd_values = [lexical_replacement_distance(matrix, target) for matrix, target in pairs]
            wd_mean = np.mean(wd_values)
        except Exception as e:
            print(f"Error processing layer {layer} for {baseline_filename}: {e}")
            cd_mean = wd_mean = float('nan')
        # Append both only after both are known. Appending one before the
        # other could fail would leave the two lists out of step.
        cd_baseline.append(cd_mean)
        wd_baseline.append(wd_mean)

    return wd_baseline, cd_baseline


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Self-embedding Distance Plots')
    parser.add_argument('-l', '--layers', type=int, default=4, help='Number of layers to process')
    parser.add_argument('-c', '--context_words', type=int, default=5, help='Number of context words')
    parser.add_argument('-m', '--model', type=str, default='bert', help='Model name (e.g., bert, xlmr)')
    args = parser.parse_args()

    # Parameters
    n_layers = args.layers
    n_words = args.context_words
    model = args.model
    layers = range(1, n_layers + 1)

    # Baselines
    # NOTE(review): wd_baseline / cd_baseline are computed (and any loading
    # errors printed) but are never drawn on the figure below. Confirm
    # whether they were meant to be plotted as reference lines.
    wd_baseline, cd_baseline = get_baselines(model, 'random', n_layers, n_words)

    # Relation types (one line per panel) and POS tags (one panel each)
    filenames = ['antonyms', 'synonyms', 'hypernyms', 'random', 'unknown']
    pos_tags = ['v', 'r', 'a', 'n']

    # Colors for plotting
    colors = {'antonyms': '#D55E00', 'synonyms': '#56B4E9', 'hypernyms': '#009E73', 'random': '#000000',
              'unknown': '#F0E442'}

    # Plot
    fig, axs = plt.subplots(1, len(pos_tags), figsize=(16, 4))
    metric = 'Self-embedding Distance'

    # Label every panel up front so panels that receive no data are still identifiable.
    for ax, pos in zip(axs, pos_tags):
        ax.set_title(f'POS: {pos}')
        ax.set_xlabel('Layers')
        ax.set_ylabel(metric)

    for filename in filenames:
        for ax, pos in zip(axs, pos_tags):
            # Keep x and y in lockstep: only layers that loaded successfully
            # contribute a point. Otherwise plot() gets mismatched lengths.
            plotted_layers, values = [], []
            for layer in layers:
                rd_filename = f'{model}/metrics/{layer}/rd_{filename}_{pos}.npy'
                try:
                    # NOTE: allow_pickle=True unpickles arbitrary objects; only
                    # load metric files produced by this project.
                    rd_matrix = np.load(rd_filename, allow_pickle=True)
                    # NOTE(review): the sentence index j is also used as the word
                    # position here, whereas get_baselines() uses a per-sentence
                    # target index. Confirm which is intended.
                    distances = [
                        lexical_replacement_distance(rd_matrix[j], j)
                        for j in range(len(rd_matrix))
                    ]
                    mean_distance = np.mean(distances)
                except FileNotFoundError:
                    print(f"File not found: {rd_filename}")
                    continue
                except Exception as e:
                    print(f"Error processing {filename} {pos} for layer {layer}: {e}")
                    continue
                plotted_layers.append(layer)
                values.append(mean_distance)

            if values:
                ax.plot(plotted_layers, values, color=colors[filename], label=filename)

    # plt.legend() would only describe the current (last) panel. Gather the
    # unique labels from every panel so each relation appears exactly once.
    legend_entries = {}
    for ax in axs:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
    if legend_entries:
        axs[-1].legend(list(legend_entries.values()), list(legend_entries.keys()), loc='upper right')
    plt.tight_layout()
    plt.savefig(f'{model}_sed_plot.png')
    plt.show()
