In [92]:
import torch
import sys
sys.path.append('../scripts')
from tqdm import tqdm
import numpy as np



In [31]:
baseline_run_ids = {
    0: "1118045041-openml_baseline_test_id_{0}-0aa4",
    1: "1118045041-openml_baseline_test_id_{1}-f03d",
    2: "1118045041-openml_baseline_test_id_{2}-6fdc",
    3: "1118045041-openml_baseline_test_id_{3}-5443",
    4: "1118045041-openml_baseline_test_id_{4}-a369",
    5: "1118045041-openml_baseline_test_id_{5}-d0ff",
    6: "1118045041-openml_baseline_test_id_{6}-8613",
    7: "1118045041-openml_baseline_test_id_{7}-8c0e",
    8: "1118045041-openml_baseline_test_id_{8}-9e12",
    9: "1118045041-openml_baseline_test_id_{9}-0770",
}

loop_run_ids = {
    0: "1118045041-openml_loop_test_id_{0}-8fd3",
    1: "1118045041-openml_loop_test_id_{1}-ee8d",
    2: "1118045041-openml_loop_test_id_{2}-179a",
    3: "1118045041-openml_loop_test_id_{3}-ec09",
    4: "1118045041-openml_loop_test_id_{4}-6198",
    5: "1118045041-openml_loop_test_id_{5}-9c58",
    6: "1118045041-openml_loop_test_id_{6}-408d",
    7: "1118045041-openml_loop_test_id_{7}-72e2",
    8: "1118045041-openml_loop_test_id_{8}-1d55",
    9: "1118045041-openml_loop_test_id_{9}-e971",
}


# Now calculate the correct error

In [140]:
# Post-training evaluation
import random
NUM_POINTS = 41


def post_train_eval(openml_datasets_test, test_dataset_id, n_dims, model, device=torch.device('cuda:0'), which_model='gpt2'):
    with torch.no_grad():
        X, y = openml_datasets_test[test_dataset_id]['X'], openml_datasets_test[test_dataset_id]['y']
        test_loss_list = []
        for idx_n in tqdm(range(min(X.shape[0], 5000))):
            in_context_list = list(range(0, X.shape[0]))
            in_context_list.remove(idx_n)
            batch_ids = random.sample(in_context_list, NUM_POINTS - 1)
            xs, ys = X[batch_ids], y[batch_ids]
            xs, ys = torch.tensor(xs).to(device), torch.tensor(ys).to(device)
            xs_test, ys_test = X[[idx_n]], y[[idx_n]]  # [1, d], [1]
            xs_test, ys_test = torch.tensor(xs_test).to(device), torch.tensor(ys_test).to(device)
            xs = torch.cat([xs, xs_test], dim=0)
            ys = torch.cat([ys, ys_test], dim=0)

            xs = xs.reshape(1, NUM_POINTS, -1)
            B, n, d_x = xs.shape
            xs = torch.cat(
                [
                    torch.zeros(B, n, n_dims - d_x, device=device),
                    xs,
                ],
                axis=2,
            )  # xs.shape should be [B, n, d] now
            ys = ys.view(B, n)
            xs, ys = xs.float(), ys.float()

            _, acc = train_step(which_model, model, xs, ys)
            test_loss_list.append(acc)
        return test_loss_list

In [89]:
def train_step(which_model, model, xs, ys):
    if which_model in ['gpt2', 'gpt2_tying']:  # , 'gpt2_tying'
        B, n = ys.shape

        y_pred = model(xs, ys)  # [B, n]
        pred = y_pred.view(B * n, -1).data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        acc = pred.eq(ys.view(B * n).data.view_as(pred)).cpu().view(B, n)[:, -1].sum().item() / (B)
        loss = 0
    elif which_model in ['gpt2_loop']:  # , 'gpt2_neumann', 'gpt2_neumann_input'
        n_loops = 30  # K
        B, n = ys.shape
        horizon_start = 0
        y_pred_list = model(xs, ys, horizon_start, n_loops)
        y_pred_arr = y_pred_list[-1]  # torch.cat(y_pred_list, dim=0)  # [B * K, n]
        # y_star_arr = torch.cat([ys] * len(y_pred_list), dim=0)  # [B * K, n]
        # BK, n = y_star_arr.shape
        # pred = y_pred_arr.view(BK * n, -1).data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        pred = y_pred_arr.view(B * n, -1).data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        # acc = pred.eq(y_star_arr.view(BK * n).data.view_as(pred)).cpu().view(-1, B, n)[-1, :, -1].sum().item() / (B)
        acc = pred.eq(ys.view(B * n).data.view_as(pred)).cpu().view(-1, B, n)[-1, :, -1].sum().item() / (B)
        loss = 0
    return loss, acc  # , total_norm, norm_dict

In [86]:
root = '../data/'
import pickle
openml_datasets_train = pickle.load(open(root + 'openml_train2.npy', 'rb'))
openml_datasets_test = pickle.load(open(root + 'openml_test2.npy', 'rb'))
dataset_id = list(openml_datasets_train.keys())

In [184]:
from models import TransformerModel

n_dims = 20
n_positions = 101
n_embd = 256
n_layer = 12
n_head = 8
device = torch.device('cuda:0')

model = TransformerModel(n_dims, n_positions, n_embd, n_layer, n_head, pred_type='classification')
step = -1
# baseline_loss_list = {}

# for i in baseline_run_ids.keys():
for i in [1]:
    run_id = baseline_run_ids[i]
    state_dict = torch.load("../results2/openml_baseline/" + run_id + "/state.pt", map_location='cpu')['model_state_dict']
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    baseline_loss_list[i] = []
    for j in range(5):
        seed = 4242 + j
        torch.manual_seed(seed)
        random.seed(seed)
        test_acc = post_train_eval(openml_datasets_test, dataset_id[i], n_dims, model, device)
        print(i, np.mean(test_acc))
        baseline_loss_list[i].append(np.mean(test_acc))

number of parameters: 9.48M


100%|███████████████████████████████████████████████████████████████████████████████| 410/410 [00:01<00:00, 237.77it/s]


1 0.5048780487804878


100%|███████████████████████████████████████████████████████████████████████████████| 410/410 [00:01<00:00, 240.01it/s]


1 0.5024390243902439


100%|███████████████████████████████████████████████████████████████████████████████| 410/410 [00:01<00:00, 239.86it/s]


1 0.5219512195121951


100%|███████████████████████████████████████████████████████████████████████████████| 410/410 [00:01<00:00, 239.99it/s]


1 0.5146341463414634


100%|███████████████████████████████████████████████████████████████████████████████| 410/410 [00:01<00:00, 238.94it/s]

1 0.5121951219512195





In [187]:
from models import TransformerModelLooped

n_dims = 20
n_positions = 101
n_embd = 256
n_layer = 1
n_head = 8
device = torch.device('cuda:0')

model = TransformerModelLooped(n_dims, n_positions, n_embd, n_layer, n_head, pred_type='classification')
step = -1
# loop_loss_list = {}

# for i in baseline_run_ids.keys():
for i in [2]:    
    run_id = loop_run_ids[i]
    state_dict = torch.load("../results2/openml_loop/" + run_id + "/state.pt", map_location='cpu')['model_state_dict']
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    

    loop_loss_list[i] = []
    for j in range(5):
        seed = 4242 + j
        torch.manual_seed(seed)
        random.seed(seed)
        test_acc = post_train_eval(openml_datasets_test, dataset_id[i], n_dims, model, device, which_model='gpt2_loop')
        print(i, np.mean(test_acc))
        loop_loss_list[i].append(np.mean(test_acc))


number of parameters: 0.79M


  5%|████▏                                                                             | 8/156 [00:00<00:02, 73.12it/s]

0


100%|████████████████████████████████████████████████████████████████████████████████| 156/156 [00:02<00:00, 74.77it/s]


2 0.7371794871794872


100%|████████████████████████████████████████████████████████████████████████████████| 156/156 [00:02<00:00, 75.58it/s]


2 0.7115384615384616


100%|████████████████████████████████████████████████████████████████████████████████| 156/156 [00:02<00:00, 75.21it/s]


2 0.7243589743589743


100%|████████████████████████████████████████████████████████████████████████████████| 156/156 [00:02<00:00, 75.08it/s]


2 0.7051282051282052


100%|████████████████████████████████████████████████████████████████████████████████| 156/156 [00:02<00:00, 75.38it/s]

2 0.7243589743589743





In [185]:
for i in baseline_loss_list.keys():
    data_list = baseline_loss_list[i]
    print(i, np.mean(data_list), np.std(data_list))

0 0.6267942583732058 0.008006316043388285
1 0.511219512195122 0.0070013171192230906
2 0.6564102564102564 0.006537204504606111
3 0.39495798319327735 0.009703365868733209
4 0.4058536585365854 0.003650397450511142
5 0.46341463414634154 0.004081268422117445
6 0.48319327731092443 0.00547832134890979
7 0.6680851063829787 0.007036982919629349
8 0.5326424870466322 0.0038773651676414036
9 0.65 0.004545454545454542


In [188]:
for i in loop_loss_list.keys():
    data_list = loop_loss_list[i]
    print(i, np.mean(data_list), np.std(data_list))

0 0.662200956937799 0.008342390322565884
1 0.504390243902439 0.007619755781372353
2 0.7205128205128206 0.011176663957796594
3 0.4011204481792717 0.009767840769839053
4 0.415609756097561 0.004779492181040352
5 0.462439024390244 0.004523716339266201
6 0.5680672268907563 0.01489079424089863
7 0.756838905775076 0.00607902735562309
8 0.5098445595854922 0.005283958045173854
9 0.6484848484848484 0.007725787141807243
