## this notebook plots the average number of adjectives per layer

In [None]:
import os
# Move to the parent directory so the relative 'summaries/...' paths used when
# loading the calibration summaries resolve from there.
# NOTE(review): chdir('..') is relative, so re-running this cell without
# restarting the kernel moves up one more directory each time.
os.chdir('..')
import numpy as np
import json
import matplotlib.pyplot as plt
from collections import defaultdict

## data loading and parameter setup

In [None]:
# Models whose calibration summaries are compared; the order here fixes the
# order of `datas` (and therefore of the a/b/c arrays built from it later).
target_names = ['beats-esc50-frozen', 'beats-esc50-finetuned', 'ast-esc50']


def _load_summary(model_name):
    """Load the processed top-5 ESC-50 calibration summary for one model.

    Args:
        model_name: one of ``target_names``; interpolated into the file name.

    Returns:
        The parsed JSON object (a dict mapping entry keys to per-entry data).
    """
    in_json = f'summaries/calibration_{model_name}_esc50_esc50_top5_processed.json'
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(in_json, 'r', encoding='utf-8') as f:
        return json.load(f)


datas = [_load_summary(name) for name in target_names]

## get the average number of adjectives per layer

In [None]:
def _avg_adjectives_per_layer(model_data):
    """Average the adjective count of every entry, grouped by layer.

    Args:
        model_data: mapping of entry key -> item, where each item has an
            ``'adj_after_rbf_llmf'`` sequence. Keys containing ``'fc'`` are
            pooled into layer 11; other keys are assumed to look like
            ``'layer<N>_...'`` (TODO confirm against the summary format).

    Returns:
        dict of ``{layer_index (int): mean adjective count}``, in ascending
        layer order.
    """
    counts = defaultdict(list)
    for key, item in model_data.items():
        if 'fc' in key:
            # NOTE(review): the fc entries are merged into layer 11's average;
            # confirm this is intended rather than a separate 13th point.
            layer_id = 11
        else:
            prefix = key.split('_')[0]
            # str.strip('layer') would strip a *set of characters* from both
            # ends; remove the literal 'layer' prefix instead.
            if prefix.startswith('layer'):
                prefix = prefix[len('layer'):]
            layer_id = int(prefix)
        counts[layer_id].append(len(item['adj_after_rbf_llmf']))

    return {layer: sum(v) / len(v) for layer, v in sorted(counts.items())}


# One {layer: mean adjective count} dict per model, in target_names order.
model_layers = [_avg_adjectives_per_layer(model_data) for model_data in datas]

# Per-layer averages as arrays, in ascending layer order.
# beats-frozen
a = np.array(list(model_layers[0].values()))
# beats-finetuned
b = np.array(list(model_layers[1].values()))
# ast
c = np.array(list(model_layers[2].values()))

## plotting

In [None]:
# 1-based transformer layer indices for the x axis; assumes 12 averages per model.
x = range(1, 13)

# Fail early with a readable message instead of a shape error inside np.polyfit.
if not (len(a) == len(b) == len(c) == len(x)):
    raise ValueError(
        f'expected {len(x)} layer averages per model, got '
        f'{len(a)} (BEATs-frozen), {len(b)} (BEATs-finetuned), {len(c)} (AST)'
    )

# Fit degree-1 (linear) least-squares models to each model's per-layer averages
coeff_a = np.polyfit(x, a, 1)
coeff_b = np.polyfit(x, b, 1)
coeff_c = np.polyfit(x, c, 1)

# Predict y values on the same layer indices
y_pred_a = np.polyval(coeff_a, x)
y_pred_b = np.polyval(coeff_b, x)
y_pred_c = np.polyval(coeff_c, x)

# Plot the data points and regression lines.
plt.rcParams.update({'font.size': 13})
# (averages, fitted line, legend label, marker, colour). Explicit colours keep
# each regression line the same colour as its scatter points instead of relying
# on how Matplotlib's colour cycle is shared between scatter() and plot().
series = [
    (c, y_pred_c, 'AST', '^', 'C0'),
    (b, y_pred_b, 'BEATs-finetuned', 'o', 'C1'),
    (a, y_pred_a, 'BEATs-frozen', 's', 'C2'),
]
for values, _, label, marker, color in series:
    plt.scatter(x, values, label=label, marker=marker, color=color)
for _, fitted, _, _, color in series:
    plt.plot(x, fitted, linewidth=2, color=color)

plt.xlabel('transformer layer', fontsize=20)
plt.ylabel('avg. number of adjectives', fontsize=20)
plt.legend()
plt.grid(True)
# Save before show(): some backends clear the figure once it has been shown.
plt.savefig('adj_num_per_layer_regression.jpg', format='jpg', dpi=1000)
plt.show()