In [None]:
import numpy as np
import torch
from numpy import linalg as LA

def eigenvalues_to_angles(eigenvalues):
    """Return the complex argument (in radians, range [-pi, pi]) of each eigenvalue.

    Parameters
    ----------
    eigenvalues : iterable of complex or real numbers
        For example the ``w`` array returned by ``numpy.linalg.eig``.

    Returns
    -------
    list
        ``arctan2(imag, real)`` for every entry, in input order.
    """
    return [np.arctan2(np.imag(z), np.real(z)) for z in eigenvalues]

def to_degree_angles(angles):
    """Convert radian angles to de-duplicated absolute angles in degrees.

    Taking ``abs`` folds ``+theta`` and ``-theta`` together. When the two
    angles are exact negatives, a conjugate eigenvalue pair therefore
    contributes one entry.

    Caveat: because the result is a ``set``, any exactly repeated angle is
    also kept only once. For example, several eigenvalues equal to 1 all give
    0 degrees, so downstream histogram counts reflect distinct values rather
    than multiplicities.

    Parameters
    ----------
    angles : iterable of float
        Angles in radians.

    Returns
    -------
    set
        ``abs(np.degrees(angle))`` for every distinct input angle.
    """
    return {abs(np.degrees(theta)) for theta in angles}

In [None]:
def _load_rotation_eig(path, size=None):
    """Load a saved rotation matrix and eigendecompose it.

    Parameters
    ----------
    path : str
        File written with ``torch.save``; the loaded object must support
        ``.detach()``. Presumably it is a 2-D square tensor (TODO confirm).
    size : int or None
        If given, only the leading ``size x size`` sub-block is decomposed.

    Returns
    -------
    tuple
        ``(tensor, eigenvalues, eigenvectors)``. ``tensor`` is the full loaded
        object; the other two come from ``numpy.linalg.eig``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    # map_location="cpu" lets GPU-saved tensors load on CPU-only machines.
    rotation = torch.load(path, map_location="cpu")
    # .cpu() is required before .numpy() in case the tensor lives on a GPU.
    matrix = rotation.detach().cpu().numpy()
    if size is not None:
        matrix = matrix[:size, :size]
    w, v = LA.eig(matrix)
    return rotation, w, v


# NLI task. The original code sliced an undefined `R` here instead of `R_nli`.
# The 512 cut-off is kept as-is; presumably only the leading 512x512 block is
# of interest (TODO confirm why).
R_nli, w_nli, v_nli = _load_rotation_eig(
    "./saved_models_nli/rotation_matrix.bin", size=512
)

# arithmetic task
R_arthmetic, w_arthmetic, v_arthmetic = _load_rotation_eig(
    "./saved_models_arithmetic/rotation_matrix.bin"
)

# equality task
R_equality, w_equality, v_equality = _load_rotation_eig(
    "./saved_models_equality/rotation_matrix.bin"
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

import os

plt.rcParams["font.family"] = "DejaVu Serif"
font = {'family' : 'DejaVu Serif',
        'size'   : 12}
plt.rc('font', **font)
params = {'mathtext.default': 'regular' }
plt.rcParams.update(params)

# Where the final figure is written.
FIG_PATH = "./fig/rotation-degree.png"


def _style_panel(ax, label):
    """Apply the shared frame, grid, background and legend styling to one panel.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Panel that already contains a histogram.
    label : str
        Legend text naming the task whose eigenvalue angles are plotted.
    """
    ax.set_ylabel("Frequency", fontsize=14)
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_linewidth(2)
    ax.xaxis.grid(color='grey', linestyle='-.', linewidth=1, alpha=0.5)
    ax.yaxis.grid(color='grey', linestyle='-.', linewidth=1, alpha=0.5)
    ax.set_facecolor("white")
    ax.legend(labels=[label], loc="lower right")


# (eigenvalues, legend label) for each panel, top to bottom.
panels = [
    (w_equality, 'Hierarchical Equality'),
    # This panel plots the arithmetic eigenvalues; its legend previously read
    # 'MoNLI', a leftover from an earlier three-panel layout.
    (w_arthmetic, 'Arithmetic'),
]

with plt.rc_context({
    'axes.edgecolor':'black', 'xtick.color':'black',
    'ytick.color':'black', 'figure.facecolor':'white'
}):

    # squeeze=False keeps `axes` indexable even if only one panel is configured.
    fig, axes = plt.subplots(len(panels), 1, figsize=(5, 3.8), squeeze=False)
    axes = axes[:, 0]

    for ax, (eigvals, label) in zip(axes, panels):
        # Distribution of the (de-duplicated, absolute) rotation angles in degrees.
        sns.histplot(
            to_degree_angles(eigenvalues_to_angles(eigvals)),
            legend=False,
            ax=ax,
        )
        _style_panel(ax, label)

    # Only the bottom panel carries the shared x-axis label.
    axes[-1].set_xlabel("Basis Vector Rotation Degree(s)", fontsize=14)

    plt.tight_layout()

    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(FIG_PATH), exist_ok=True)
    plt.savefig(FIG_PATH, dpi=200, bbox_inches='tight')