In [1]:
from collections import OrderedDict
import re
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from tqdm.notebook import tqdm
import numpy as np


# Move the kernel's working directory one level up (presumably so that the
# project modules imported below are found there — confirm against the repo layout).
# The starting directory is remembered in a notebook global: os.getcwd() changes
# after the first chdir, so without this guard every re-run of the cell would
# climb one directory higher and break the relative paths used later on.
if "_notebook_start_directory" not in globals():
    _notebook_start_directory = os.getcwd()
current_directory = _notebook_start_directory
parent_directory = os.path.dirname(current_directory)
os.chdir(parent_directory)

from eval import get_run_metrics, read_run_dir, get_model_from_run
from plot_utils import basic_plot, collect_results, relevant_model_names

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Seaborn look-and-feel for the figures drawn in this notebook.
sns.set_theme(context='notebook', style='darkgrid')
palette = sns.color_palette(palette='colorblind')
# To render plot text with LaTeX, uncomment the next line
# (it needs `import matplotlib`, which this notebook does not do):
# matplotlib.rcParams['text.usetex'] = True

# Directory that is scanned for runs; relative to the current working directory.
run_dir = "../models"

In [2]:
# Collect the runs found under run_dir and display them as the cell output.
df = read_run_dir(run_dir)
df  # list all the runs in our run_dir

Unnamed: 0,run_id,task,model,kwargs,num_tasks,num_examples,n_dims,n_layer,n_head,run_name
0,pretrained_complete,linear_regression,Transformer,,-1,-1,20,12,8,fix_linear_regression_standard


In [3]:
# Name of the task sub-directory that contains the run.
# Deliberately not called `task`: a later cell binds `task` to the sampled task
# object, and re-running this cell would otherwise replace that object with a
# plain string and break every cell that calls task.evaluate(...).
task_name = "linear_regression"

run_id = "pretrained_complete"  # if you train more models, replace with the run_id from the table above

# <run_dir>/<task_name>/<run_id>
run_path = os.path.join(run_dir, task_name, run_id)
recompute_metrics = False

if recompute_metrics:
    get_run_metrics(run_path)  # these are normally precomputed at the end of training

In [4]:
from samplers import get_data_sampler
from tasks import get_task_sampler

In [5]:
# Load the model and the configuration object stored for the selected run.
model, conf = get_model_from_run(run_path)

# Build the input sampler and the task sampler from the run's training configuration.
n_dims = conf.model.n_dims
batch_size = conf.training.batch_size
data_sampler = get_data_sampler(conf.training.data, n_dims)
task_sampler = get_task_sampler(
    conf.training.task,
    n_dims,
    batch_size,
    **conf.training.task_kwargs
)

# Instantiate one task; the experiment cells below call task.evaluate(xs) on it.
# NOTE(review): no random seed is set in this notebook, so the sampled task and
# data presumably differ from one kernel run to the next — confirm.
task = task_sampler()
# `metric` is not read by any of the cells shown in this notebook.
metric = task.get_metric()

In [6]:
def R_Square_Error(ys, pred):
    """Return the coefficient of determination (R-squared) of `pred` against `ys`.

    Computed as 1 - SS_res / SS_tot, where SS_res is the sum of squared
    differences between `ys` and `pred` and SS_tot is the sum of squared
    deviations of `ys` from its overall mean. Both sums run over all elements.

    Args:
        ys: tensor of observed target values.
        pred: tensor of predicted values.

    Returns:
        A 0-dimensional tensor holding the R-squared value.
    """
    residual_ss = (ys - pred).pow(2).sum()
    total_ss = (ys - ys.mean()).pow(2).sum()
    return 1 - residual_ss / total_ss

### Experiment 3 - Retrieving Similar Samples

#### Part 1 - Linear regression with best prompts (cosine similarity)

In [9]:
# Number of candidate points per sequence from which in-context examples are retrieved
n_pairs = 1000
n_dims = conf.model.n_dims
batch_size = conf.training.batch_size # this is 64
# n_batches = 100
n_batches = 1 # We use n_batches = 1 here to speed up the experiment speed. We used n_batches = 100 for the experiment result in the paper
# Number of prompt positions evaluated; the same number of extra points is
# sampled per sequence and used as query (test) points.
prompt_length = 76

# compute similarity
def similarity(x, x_test):
    """Cosine similarity between two feature vectors.

    Each input gets a leading dimension of size 1 and the similarity is taken
    along dim 1 with eps=1e-6, so two 1-D vectors yield a tensor of shape (1,).

    Args:
        x: tensor holding a candidate point.
        x_test: tensor holding the query point.

    Returns:
        Tensor with the cosine similarity of the two inputs.
    """
    return torch.nn.functional.cosine_similarity(x[None], x_test[None], dim=1, eps=1e-6)

# Keep every sampled batch so that a later experiment can reuse exactly the same data
xs_list = []
ys_list = []

# One list per evaluated position (index j-1): the targets of the query points
# and the corresponding model predictions, accumulated over all batches.
actual_points_worst_prompt = [[] for _ in range(prompt_length)]
predicted_points_worst_prompt = [[] for _ in range(prompt_length)]
for batch_idx in tqdm(range(n_batches)):
    # Per sequence: n_pairs candidate points followed by prompt_length query points
    xs = data_sampler.sample_xs(b_size=batch_size, n_points=n_pairs + prompt_length)
    xs_list.append(xs)
    ys = task.evaluate(xs)
    ys_list.append(ys)
    for j in range(1, prompt_length + 1):
        # Prompt for this j: j-1 retrieved examples followed by the query point
        # at index j-1; the remaining positions stay zero.
        prompt_xs = np.zeros((batch_size, prompt_length, n_dims))
        prompt_ys = np.zeros((batch_size, prompt_length))
        # A dedicated index name keeps the outer loop variable `batch_idx` intact
        for seq_idx in range(batch_size):
            # test sample: the j-th point from the end, i.e. outside the candidate pool
            x_test, y_test = xs[seq_idx, -j, :], ys[seq_idx, -j]
            # Cosine similarity of all n_pairs candidates to the query in one call
            candidates = xs[seq_idx, :n_pairs, :]
            sims = torch.nn.functional.cosine_similarity(
                candidates, x_test.expand_as(candidates), dim=1, eps=1e-6
            )
            # largest=False: retrieve the j-1 LEAST similar candidates
            selected_indices = torch.topk(sims, j - 1, largest=False).indices
            prompt_xs[seq_idx, :j-1, :] = xs[seq_idx, selected_indices, :]
            prompt_ys[seq_idx, :j-1] = ys[seq_idx, selected_indices]
            prompt_xs[seq_idx, j-1, :] = x_test
            prompt_ys[seq_idx, j-1] = y_test

        prompt_xs = torch.from_numpy(prompt_xs).float()
        prompt_ys = torch.from_numpy(prompt_ys).float()
        with torch.no_grad():
            pred = model(prompt_xs, prompt_ys)

        # Only the query position is scored
        actual_points_worst_prompt[j-1].extend(prompt_ys[:, j-1])
        predicted_points_worst_prompt[j-1].extend(pred[:, j-1])




  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# R-squared of the low-similarity prompts at every evaluated position,
# stored as plain Python floats.
worst = []

for point_idx in range(prompt_length):
    actual = torch.tensor(actual_points_worst_prompt[point_idx])
    predicted = torch.tensor(predicted_points_worst_prompt[point_idx])
    R_square = R_Square_Error(actual, predicted)
    # .item() converts the 0-d tensor to a float, so the values can be written
    # to text and plotted without depending on tensor formatting/conversion.
    worst.append(R_square.item())


# Make sure the output directory exists before writing, one value per line
os.makedirs('./data', exist_ok=True)
with open('./data/exp_3_worst.txt', 'w') as f:
    f.writelines(f"{value}\n" for value in worst)

In [None]:

# One list per evaluated position (index j-1): the targets of the query points
# and the corresponding model predictions, accumulated over all batches.
actual_points_best_prompt = [[] for _ in range(prompt_length)]
predicted_points_best_prompt = [[] for _ in range(prompt_length)]
for batch_idx in tqdm(range(n_batches)):
    # Reuse the batches recorded in xs_list / ys_list instead of sampling new data
    xs = xs_list[batch_idx]
    ys = ys_list[batch_idx]
    for j in range(1, prompt_length + 1):
        # Prompt for this j: j-1 retrieved examples followed by the query point
        # at index j-1; the remaining positions stay zero.
        prompt_xs = np.zeros((batch_size, prompt_length, n_dims))
        prompt_ys = np.zeros((batch_size, prompt_length))
        # A dedicated index name keeps the outer loop variable `batch_idx`
        # (used above to index xs_list / ys_list) intact
        for seq_idx in range(batch_size):
            # test sample: the j-th point from the end, i.e. outside the candidate pool
            x_test, y_test = xs[seq_idx, -j, :], ys[seq_idx, -j]
            # Cosine similarity of all n_pairs candidates to the query in one call
            candidates = xs[seq_idx, :n_pairs, :]
            sims = torch.nn.functional.cosine_similarity(
                candidates, x_test.expand_as(candidates), dim=1, eps=1e-6
            )
            # largest=True: retrieve the j-1 MOST similar candidates
            selected_indices = torch.topk(sims, j - 1, largest=True).indices
            prompt_xs[seq_idx, :j-1, :] = xs[seq_idx, selected_indices, :]
            prompt_ys[seq_idx, :j-1] = ys[seq_idx, selected_indices]
            prompt_xs[seq_idx, j-1, :] = x_test
            prompt_ys[seq_idx, j-1] = y_test

        prompt_xs = torch.from_numpy(prompt_xs).float()
        prompt_ys = torch.from_numpy(prompt_ys).float()
        with torch.no_grad():
            pred = model(prompt_xs, prompt_ys)

        # Only the query position is scored
        actual_points_best_prompt[j-1].extend(prompt_ys[:, j-1])
        predicted_points_best_prompt[j-1].extend(pred[:, j-1])




  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# R-squared of the high-similarity prompts at every evaluated position,
# stored as plain Python floats.
best = []

for point_idx in range(prompt_length):
    actual = torch.tensor(actual_points_best_prompt[point_idx])
    predicted = torch.tensor(predicted_points_best_prompt[point_idx])
    R_square = R_Square_Error(actual, predicted)
    # .item() converts the 0-d tensor to a float, so the values can be written
    # to text and plotted without depending on tensor formatting/conversion.
    best.append(R_square.item())


# Make sure the output directory exists before writing, one value per line
os.makedirs('./data', exist_ok=True)
with open('./data/exp_3_best.txt', 'w') as f:
    f.writelines(f"{value}\n" for value in best)

In [None]:
def _read_float_lines(path):
    """Return the floats stored one per line in the text file at `path`."""
    with open(path, 'r') as fh:
        return [float(row.strip()) for row in fh]


# Reference curves plotted next to the results of this experiment
base_line = _read_float_lines('./data/base_line.txt')
w_1 = _read_float_lines('./data/exp_1_w_1.txt')

In [None]:
prompt_length = 76
plt.figure(figsize=(10, 5), facecolor='none')

# Curves in drawing order; entries without an explicit color take the next
# color from the active color cycle.
curves = [
    (best, {"label": "High Similarity", "color": "darkred"}),
    (w_1, {"label": "Normal Similarity"}),
    (worst, {"label": "Low Similarity"}),
    (base_line, {"label": "Least Squares", "color": "grey"}),
]
for values, line_style in curves:
    plt.plot(range(prompt_length), values, linewidth=3, **line_style)

# Tick positions and label sizes
x = [0, 25, 50, 75]
plt.xticks(x, fontsize=28)
plt.yticks(fontsize=28)

# Dashed reference lines: R-squared = 1 and x = 50
plt.axhline(1, ls="--", color="darkgrey")
plt.axvline(x=50, color='darkgrey', linestyle='--')  # Grey vertical line

plt.legend(loc='lower right', fontsize=17)
plt.xlabel('In-context Examples', fontsize=28)
plt.ylabel('R-Squared', fontsize=28)
plt.grid(color='lightgray', linestyle='-', linewidth=0.5)
# save fig
# plt.savefig("../pretrain_complete/test_4_3.pdf", bbox_inches = "tight", transparent=True)
plt.show()