In [12]:
import json
import os

from matplotlib import pyplot as plt 
from scipy.stats import bootstrap
import numpy as np
import pandas as pd
import glob

from tqdm import tqdm

import warnings
# Globally silence all warnings raised after this point.
# NOTE(review): this also hides deprecation and numerical warnings — confirm that is intended.
warnings.filterwarnings("ignore")

In [13]:
# Root directory (relative to the current working directory) that is searched
# recursively for result JSON files further down in this notebook.
START_DIR = os.path.join("./", "../Latent English")

# Plotting functions

## Helpers

In [14]:
# Font size applied to every matplotlib figure created after this cell.
plt.rcParams['font.size'] = 16

# Default keyword arguments that the plotting helpers forward to `ax.plot`.
plt_params = dict(linewidth=2.2)


In [15]:
def plot_ci_plus_heatmap(data, heat, labels, 
                         color='blue', 
                         linestyle='-',
                         tik_step=10, 
                         method='gaussian', 
                         do_lines=True, 
                         do_colorbar=False, 
                         shift=0.5, 
                         nums=[.99, 0.18, 0.025, 0.6],
                         labelpad=10,
                         plt_params=plt_params,
                         vmin=0,
                         vmax=14):
    """Draw a confidence-band line plot with a one-row heat-map strip above it.

    A new figure with two vertically stacked axes sharing the x axis is
    created: a thin upper axis for the heat map and a tall lower axis on which
    `plot_ci` draws `data`.

    Parameters
    ----------
    data : 2-D array-like
        Forwarded to `plot_ci`; statistics are taken over axis 0 and axis 1
        runs along the x axis.
    heat : 2-D array-like
        Averaged over axis 0; the resulting vector is shown as a single-row
        image in the upper axis.
    labels : str
        Legend label forwarded to `plot_ci`.
    color, linestyle, tik_step, method, do_lines, plt_params
        Forwarded unchanged to `plot_ci`.
    do_colorbar : bool
        When true, room is reserved on the right of the figure and a colorbar
        labelled 'entropy' is added for the heat map.
    shift : float
        Extra margin added on both sides of the heat-map extent.
    nums : sequence of 4 floats
        `[left, bottom, width, height]` of the colorbar axes in figure
        coordinates (passed to `fig.add_axes`). It is only read, never mutated.
    labelpad : float
        Padding between the colorbar and its label.
    vmin, vmax : float
        Colour-scale limits of the heat map. The defaults keep the previously
        hard-coded range 0..14.

    Returns
    -------
    (fig, ax, ax2)
        The figure, the heat-map axes and the line-plot axes.
    """
    fig, (ax, ax2) = plt.subplots(
        nrows=2,
        sharex=True,
        gridspec_kw={'height_ratios': [1, 10]},
        figsize=(5, 3),
    )
    if do_colorbar:
        # Leave space on the right-hand side for the colorbar axes.
        fig.subplots_adjust(right=0.8)
    plot_ci(ax2, data, labels, color=color, linestyle=linestyle, tik_step=tik_step,
            method=method, do_lines=do_lines, plt_params=plt_params)

    y = np.mean(heat, axis=0)
    x = np.arange(np.shape(y)[0]) + 1

    # x is 1..n with unit spacing, so half a cell is always 0.5. Deriving it
    # from x[1] - x[0] raised IndexError when `heat` had a single column.
    half_step = 0.5
    extent = [x[0] - half_step - shift, x[-1] + half_step + shift, 0, 1]
    img = ax.imshow(y[np.newaxis, :], cmap="plasma", aspect="auto", extent=extent,
                    vmin=vmin, vmax=vmax)
    ax.set_yticks([])
    if do_colorbar:
        cbar_ax = fig.add_axes(nums)
        cbar = fig.colorbar(img, cax=cbar_ax)
        cbar.set_label('entropy', rotation=90, labelpad=labelpad)
    fig.tight_layout()
    return fig, ax, ax2


In [16]:
def plot_ci(
    ax,
    data,
    label,
    color='blue',
    linestyle='-',
    tik_step=10,
    method='gaussian',
    do_lines=True,
    plt_params=plt_params,
):
    """Plot a central curve with a shaded interval band on `ax`.

    `data` is treated as a 2-D array-like: statistics are computed over
    axis 0 and the x coordinates are 1..`np.shape(data)[1]`.

    Parameters
    ----------
    ax : matplotlib axes to draw on.
    data : 2-D array-like of values.
    label : legend label of the central curve.
    color, linestyle : appearance of the curve; the band reuses `color`.
    tik_step : spacing of the x ticks and dashed vertical guide lines.
    method : how the curve and band are computed:
        'gaussian'  - mean +/- 1.96 * std / sqrt(number of rows);
        'np'        - median with the 5 % and 95 % quantiles;
        'bootstrap' - mean with a 95 % percentile bootstrap interval.
    do_lines : when true, set the x ticks and draw the guide lines.
    plt_params : extra keyword arguments forwarded to `ax.plot`.

    Raises
    ------
    ValueError
        If `method` is not one of the three names above.
    """
    if do_lines:
        n_cols = np.shape(data)[1]
        upper = max(round(n_cols / 10) * 10 + 1, n_cols + 1)
        ax.set_xticks(np.arange(0, upper, tik_step))
        for pos in range(0, upper, tik_step):
            ax.axvline(pos, color='black', linestyle='--', alpha=0.2, linewidth=1)

    if method == 'gaussian':
        center = np.mean(data, axis=0)
        spread = np.std(data, axis=0)
        # Half-width of the normal-approximation 95 % interval of the mean.
        half_width = (1.96 / (np.shape(data)[0] ** 0.5)) * spread
        band_high = center + half_width
        band_low = center - half_width
    elif method == 'np':
        center = np.quantile(data, 0.5, axis=0)
        band_high = np.quantile(data, 0.95, axis=0)
        band_low = np.quantile(data, 0.05, axis=0)
    elif method == 'bootstrap':
        interval = bootstrap(
            (data,), np.mean, confidence_level=0.95, method='percentile'
        ).confidence_interval
        center = np.mean(data, axis=0)
        band_high = interval.high
        band_low = interval.low
    else:
        raise ValueError('method not implemented')

    df = pd.DataFrame({
        'x': np.arange(np.shape(data)[1]) + 1,
        'y': center,
        'y_upper': band_high,
        'y_lower': band_low,
    })

    # Central curve first, then the translucent band in the same colour.
    ax.plot(df['x'], df['y'], label=label, color=color, linestyle=linestyle, **plt_params)
    ax.fill_between(df['x'], df['y_lower'], df['y_upper'], color=color, alpha=0.3)
    ax.spines[['right', 'top']].set_visible(False)

In [17]:
def read_json(path_name: str):
    """Load and return the parsed contents of the JSON file at `path_name`.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Raises whatever `open` or `json.load` raise for a missing file or for
    invalid JSON.
    """
    with open(path_name, "r", encoding="utf-8") as f:
        return json.load(f)


## Plotting

In [18]:
def plot_probs(
        source_lang, 
        target_lang,
        backbone_lang,
        model_name,
        out_dir,
        
        latent_token_probs, 
        entropy,
        out_token_probs,
        **kwargs
):
    """Plot per-layer token probabilities with an entropy strip and save as PDF.

    The latent-token probabilities are drawn (labelled with `backbone_lang`)
    beneath a heat map of `entropy`. Unless `target_lang` is 'en', the
    output-token probabilities are added as a second curve labelled with
    `target_lang`. The figure is written to
    `<out_dir>/<model_name>/translation/prob/<source>-<backbone>-<target>.pdf`
    (directories are created as needed) and then closed.

    Parameters
    ----------
    source_lang, target_lang, backbone_lang, model_name : str
        Used for labels and for building the output path.
    out_dir : str
        Root directory for the saved figure.
    latent_token_probs, entropy, out_token_probs : 2-D array-like
        Statistics are taken over axis 0; axis 1 runs along the x axis.
    **kwargs
        Ignored. Accepted so that a dict with additional keys can be splatted
        into this function.
    """
    fig, ax, ax2 = plot_ci_plus_heatmap(
        latent_token_probs, 
        entropy, 
        backbone_lang, 
        color='tab:orange', 
        tik_step=8, 
        do_colorbar=True,
        nums=[.99, 0.18, 0.025, 0.6]
    )
    try:
        if target_lang != 'en':
            plot_ci(ax2, out_token_probs, target_lang, color='tab:blue', do_lines=False)

        # The x values run from 1 to n_layers. Rounding alone can give a limit
        # below n_layers (32 layers -> 31) and cut off the final layers, so the
        # limit is never allowed to be smaller than n_layers + 1. This is the
        # same bound plot_ci uses for its tick range.
        n_layers = np.shape(out_token_probs)[1]
        x_max = max(round(n_layers / 10) * 10 + 1, n_layers + 1)

        ax2.set_xlabel('layer')
        ax2.set_ylabel('probability')
        ax2.set_xlim(0, x_max)
        ax2.set_ylim(0, 1)
        ax2.legend(loc='upper left')

        save_dir = os.path.join(out_dir, model_name, "translation", "prob")
        file_name = f"{source_lang}-{backbone_lang}-{target_lang}.pdf"
        save_path = os.path.join(save_dir, file_name)

        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting or saving failed, so
        # figures do not accumulate when the function is called in a loop.
        plt.close(fig)

In [19]:
def plot_energy(
        source_lang, 
        target_lang,
        backbone_lang,
        model_name,
        out_dir,
        
        energy,
        out_token_probs,
        **kwargs
):
    """Plot the per-layer energy curve with its confidence band and save as PDF.

    The figure is written to
    `<out_dir>/<model_name>/translation/energy/<source>-<target>.pdf`
    (directories are created as needed) and then closed.

    Parameters
    ----------
    source_lang, target_lang, model_name : str
        Used for building the output path.
    backbone_lang : str
        Accepted for signature parity with `plot_probs`; not used here.
    out_dir : str
        Root directory for the saved figure.
    energy : 2-D array-like
        Values passed to `plot_ci`; statistics are taken over axis 0.
    out_token_probs : 2-D array-like
        Only the size of its second axis is used, to choose the x-axis limit.
    **kwargs
        Ignored. Accepted so that a dict with additional keys can be splatted
        into this function.
    """
    fig, ax = plt.subplots(figsize=(5,3))
    try:
        plot_ci(ax, energy, 'energy', color='tab:green', do_lines=True, tik_step=5)

        # The x values run from 1 to n_layers. Rounding alone can give a limit
        # below n_layers (32 layers -> 31) and cut off the final layers, so the
        # limit is never allowed to be smaller than n_layers + 1.
        n_layers = np.shape(out_token_probs)[1]
        x_max = max(round(n_layers / 10) * 10 + 1, n_layers + 1)

        ax.set_xlabel('layer')
        ax.set_ylabel('energy')
        ax.set_xlim(0, x_max)

        save_dir = os.path.join(out_dir, model_name, "translation", "energy")
        file_name = f"{source_lang}-{target_lang}.pdf"
        save_path = os.path.join(save_dir, file_name)

        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Close this specific figure even if plotting or saving failed.
        plt.close(fig)
    

# Visualize

In [20]:
def metadata_from_path(path: str):
    """Derive plotting metadata from the path of a result file.

    The path is lower-cased first, so every returned value is lower-case.
    The following pieces are extracted:

    * `model_name`: "llama3" or "llama2", whichever substring occurs in the
      path ("llama3" is checked first).
    * `source_lang` / `target_lang`: the first and third space-separated
      tokens of the file name, with everything from ".json" onwards removed.
    * `backbone_lang`: the second space-separated token of the first path
      component that contains "latent" and has such a token.

    Parameters
    ----------
    path : str
        Path of the result file, using `os.path.sep` as separator.

    Returns
    -------
    dict
        Keys 'model_name', 'source_lang', 'target_lang', 'backbone_lang'.

    Raises
    ------
    ValueError
        If the model name, the languages in the file name, or the backbone
        language cannot be determined.
    """
    path = path.lower()
    config = {}
    if "llama3" in path:
        config['model_name'] = "llama3"
    elif "llama2" in path:
        config['model_name'] = "llama2"
    else:
        raise ValueError('model name not found')

    base_name = os.path.basename(path)
    ext_pos = base_name.find(".json")
    # str.find returns -1 when ".json" is absent; slicing with -1 would
    # silently drop the last character of the name, so only cut when found.
    if ext_pos != -1:
        base_name = base_name[:ext_pos]
    name_tokens = base_name.split(" ")
    if len(name_tokens) < 3:
        raise ValueError('source/target language not found in file name')
    config['source_lang'] = name_tokens[0]
    config['target_lang'] = name_tokens[2]

    for b in path.split(os.path.sep):
        if 'latent' in b:
            dir_tokens = b.split(" ")
            # A component without a second token carries no language; keep
            # looking instead of failing with an IndexError.
            if len(dir_tokens) > 1:
                config['backbone_lang'] = dir_tokens[1]
                break
    if 'backbone_lang' not in config:
        raise ValueError('backbone language not found')
    return config

In [21]:
# Collect every JSON result file below START_DIR (searched recursively).
# glob does not guarantee any particular order, so the list is sorted to make
# the processing order reproducible. glob already returns plain strings for a
# string pattern, so no conversion is needed.
paths = sorted(glob.glob(os.path.join(START_DIR, "./**", "*.json"), recursive=True))

In [22]:
# Render the probability and energy figures for every result file.
# Files whose "entropy" entry is empty are skipped.
for path in tqdm(paths):
    metadata = metadata_from_path(path)
    data = read_json(path)
    if len(data["entropy"]) >= 1:
        plot_probs(**metadata, **data, out_dir="../visuals/")
        plot_energy(**metadata, **data, out_dir="../visuals/")

100%|██████████| 371/371 [00:41<00:00,  9.04it/s]
