In [None]:
# Imports
import sys
import os
sys.path.append(os.path.abspath('..'))

from loader import load_log_file
from plt_style import set_style
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

from tqdm import tqdm


In [None]:
# CONFIG

# Runs to compare. Each entry is a (comm_type, seed, label) tuple:
#   - comm_type and seed are the two arguments handed to load_log_file
#     when the run's log is loaded;
#   - label is prefixed to the run's "Source" name, which the comparison
#     plots use as their hue.
# Commented-out entries are runs currently excluded from the comparison.
FILES = [
    ("discrete", 28, "Parent"),
    ("aim", 21, "AIM"),
    # ("aim", 15, "Critic nudge"),
    # ("continuous", 25, "Basic"),
    ("none", 5, "No Communication")
]

# Imported from plt_style; presumably applies the shared plotting style, so
# it is called once here before any figure is created (confirm in plt_style).
set_style()

In [None]:
# Load data
# One log frame per configured run; only comm_type and seed select the log.
dfs = [load_log_file(*run[:2]) for run in FILES]

# Tag every frame with a readable "Source" name so the runs can be told
# apart once they are stacked into a single long-form frame for plotting.
labelled_frames = []
for frame, (comm_type, seed, label) in zip(dfs, FILES):
    source_name = f"{label}: {comm_type} {seed}"
    labelled_frames.append(frame.assign(Source=source_name))

df_plot = pd.concat(labelled_frames)

In [None]:
# Mean reward against timestep, one line per "Source" value in df_plot.
fig, ax = plt.subplots(figsize=(9, 5))
sns.lineplot(x="timestep", y="reward_mean", hue="Source", data=df_plot, ax=ax)

ax.set_title("Reward Mean Comparison")
ax.set_ylabel("Reward (0.0 - 0.15)")
ax.grid(True, alpha=0.2)

plt.show()

In [None]:
# Mean communication perplexity against timestep, one line per "Source" value.
fig, ax = plt.subplots(figsize=(9, 5))
sns.lineplot(
    x="timestep",
    y="communication_perplexity_mean",
    hue="Source",
    data=df_plot,
    ax=ax,
)

ax.set_title("Perplexity Comparison")
ax.set_ylabel("Perplexity")
ax.grid(True, alpha=0.2)

plt.show()

In [None]:
# Mean communication entropy against timestep, one line per "Source" value.
fig, ax = plt.subplots(figsize=(9, 5))
sns.lineplot(
    x="timestep",
    y="communication_entropy_mean",
    hue="Source",
    data=df_plot,
    ax=ax,
)

ax.set_title("Communication Entropy Comparison")
ax.set_ylabel("Communication Entropy")
ax.grid(True, alpha=0.2)

plt.show()

In [None]:
from loader import load_comm_file


COMMUNICATION_TYPE = "aim"
SEED = 1
VOCAB_SIZE = 64

# Load one communication log for every index in 0..VOCAB_SIZE-1.
comm_logs = [
    load_comm_file(COMMUNICATION_TYPE, SEED, token_id)
    for token_id in range(VOCAB_SIZE)
]

# Stamp each loaded frame (in place) with the index it was loaded for, so the
# origin of every row survives the concatenation below.
for token_id, token_log in enumerate(comm_logs):
    token_log['vocab_index'] = token_id
comms = list(comm_logs)

# Stack all logs row-wise under a fresh 0..n-1 index.
comm_df = pd.concat(comms, ignore_index=True)

In [None]:
# Number of logged rows per vocabulary token, ordered by vocab index.
# The Series is kept (rather than only its values) so that every count stays
# attached to its real vocab index: value_counts() leaves out tokens that never
# occur, so labelling by position would shift the label of every token after
# the first unused one.
token_counts = comm_df['vocab_index'].value_counts().sort_index()

lengths = token_counts.to_numpy()
vocab_indices = [f"Token {vocab_index}" for vocab_index in token_counts.index]

token_usage = pd.DataFrame({
    'length': lengths,
    'token': vocab_indices
})

# Normalise by the most-used token, so the longest bar is exactly 1.0.
token_usage["length"] = token_usage["length"] / token_usage["length"].max()

# Most-used tokens first, so the bars read top-down by frequency.
token_usage = token_usage.sort_values(by='length', ascending=False)

fig, ax = plt.subplots(figsize=(10, 6))

sns.barplot(
    x='length',
    y='token',
    data=token_usage,
    ax=ax,
)

ax.set_title('Usage Frequency per Token', fontsize=14)
ax.set_xlabel('Uses Normalised', fontsize=12)
ax.set_ylabel('Token Index', fontsize=12)

# Many token labels share the y axis; a small font keeps them legible.
ax.tick_params(axis='y', labelsize=5)

fig.tight_layout()

plt.show()

In [None]:
# Fit a random forest that predicts the vocab index from the remaining
# columns of comm_df, then use permutation importance on held-out rows of one
# target token to see which features its predictions depend on.
TARGET_CLASS = 56        # vocab index whose feature importances are inspected
MAX_EVAL_SAMPLES = 2000  # upper bound on rows used for permutation importance
RANDOM_STATE = 42        # one seed for split, forest, sampling and shuffling

x = comm_df.drop('vocab_index', axis=1)
y = comm_df['vocab_index']

x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, random_state=RANDOM_STATE
)

n_trees = 100
step = 10

# warm_start=True keeps the trees from previous fit() calls, so raising
# n_estimators and refitting only trains the newly added trees. That lets the
# forest be grown in chunks with a progress bar in between.
rf = RandomForestClassifier(
    n_estimators=0, warm_start=True, random_state=RANDOM_STATE, n_jobs=-1
)

with tqdm(total=n_trees, desc="Building Trees") as pbar:
    while rf.n_estimators < n_trees:
        # Clamp the last chunk so the forest never exceeds n_trees, even when
        # n_trees is not a multiple of step.
        added = min(step, n_trees - rf.n_estimators)
        rf.n_estimators += added
        rf.fit(x_train, y_train)
        pbar.update(added)

# Restrict the evaluation to held-out rows of the target class.
x_test_target = x_test.loc[y_test == TARGET_CLASS]
if x_test_target.empty:
    raise ValueError(
        f"No test rows with vocab_index == {TARGET_CLASS}; "
        "cannot compute permutation importance."
    )

# Cap the evaluation size, but never ask for more rows than exist: sample()
# raises when n exceeds the population. Seeded so reruns pick the same rows.
x_test_class = x_test_target.sample(
    n=min(MAX_EVAL_SAMPLES, len(x_test_target)), random_state=RANDOM_STATE
)
y_test_class = y_test.loc[x_test_class.index]

# Every evaluated row belongs to TARGET_CLASS, so with the estimator's default
# score the importance of a feature is the drop in the share of these rows
# still classified correctly after that feature's column is shuffled.
result = permutation_importance(
    rf,
    x_test_class,
    y_test_class,
    n_repeats=5,
    n_jobs=-1,
    random_state=RANDOM_STATE,
)

target_class_importance = pd.Series(
    result.importances_mean, index=x.columns
).sort_values(ascending=False)

top_features_df = target_class_importance.head(20).reset_index()
top_features_df.columns = ['feature_name', 'importance_score']

print(top_features_df)

fig, ax = plt.subplots(figsize=(10, 6))

sns.barplot(
    x='importance_score',
    y='feature_name',
    data=top_features_df,
    ax=ax,
)

ax.set_title(f'Feature Importance for Class {TARGET_CLASS} (Vocab Index)', fontsize=14)
ax.set_xlabel('Importance Score (Accuracy Drop)', fontsize=12)
ax.set_ylabel('Feature Name', fontsize=12)

ax.tick_params(axis='y', labelsize=5)

fig.tight_layout()
plt.show()