In [97]:
# Restrict this process to one physical GPU. This runs before torch is
# imported further down in this cell.
import os

GPU = 2
os.environ["CUDA_VISIBLE_DEVICES"] = f"{GPU}"

import os
from models import TransformerModel
from tasks import get_task_sampler
from samplers import get_data_sampler
import torch
import numpy as np

import time



In [128]:
def _sample_without_replacement(pool, count):
    """Return `count` entries of the 1-D index tensor `pool`, drawn uniformly without replacement."""
    return pool[torch.randperm(len(pool))[:count]]


def eval_batch(model, data_sampler, task_sampler, percent_clamped_correct=0, window_len=41, b_size=1, last_pt_clamped=True, smoothing=0, max_attempts=1000):
    """Build one in-context prompt with a fixed share of clamped points and score the last point.

    A fresh task is drawn from `task_sampler`, then batches of 100 candidate x
    points are sampled until enough "clamped" points (y == 0.5) and enough other
    points are available. `int(percent_clamped_correct * window_len)` clamped
    points and the remaining non-clamped points are picked at random, one point
    is held out as the final (query) position, the rest are shuffled, and the
    model is evaluated on the resulting sequence of `window_len` (x, y) pairs.

    Point selection uses batch element 0 only; the same indices are applied to
    every element of the batch.

    Assumes `data_sampler.sample_xs` returns a (b_size, n_points, n_dims) tensor
    and `task.evaluate` returns a (b_size, n_points) tensor -- inferred from how
    they are indexed here; confirm against the sampler/task implementations.

    Args:
        model: callable taking `(xs, ys)` and returning per-position predictions;
            must expose a `name` attribute used to choose the device.
        data_sampler: object with `sample_xs(n_points=..., b_size=...)`.
        task_sampler: zero-argument callable returning a task with
            `evaluate(xs, noise=..., separate_noise=...)`.
        percent_clamped_correct: fraction (0..1) of the window that must be clamped.
        window_len: total number of (x, y) pairs fed to the model.
        b_size: batch size requested from the data sampler.
        last_pt_clamped: if True the final point is a clamped one, otherwise a
            non-clamped one.
        smoothing: currently unused; kept for call compatibility.
        max_attempts: maximum number of 100-point batches to draw before giving up.

    Returns:
        Tuple `(y_true, y_pred)` for the final position, each indexed at
        `window_len - 1` along the sequence dimension.

    Raises:
        ValueError: if the requested split is impossible (fraction outside
            0..1, or no point of the kind required for the final position).
        RuntimeError: if `max_attempts` batches never contain enough points of
            each kind (e.g. the sampled task never produces y == 0.5).
    """
    task = task_sampler()
    if torch.cuda.is_available() and model.name.split("_")[0] in ["gpt2", "lstm"]:
        device = "cuda"
    else:
        device = "cpu"

    num_clamped_needed = int(percent_clamped_correct * window_len)
    num_other_needed = int(window_len - num_clamped_needed)

    if num_clamped_needed < 0 or num_other_needed < 0:
        raise ValueError("percent_clamped_correct must be between 0 and 1")
    if last_pt_clamped and num_clamped_needed < 1:
        raise ValueError("last_pt_clamped=True requires at least one clamped point in the window")
    if not last_pt_clamped and num_other_needed < 1:
        raise ValueError("last_pt_clamped=False requires at least one non-clamped point in the window")

    print("needed, other", (num_clamped_needed, num_other_needed))

    # Redraw candidate points until both pools are large enough, but give up
    # after max_attempts instead of spinning forever on a task that cannot
    # satisfy the request.
    for _ in range(max_attempts):
        xs = data_sampler.sample_xs(n_points=100, b_size=b_size)
        ys = task.evaluate(xs, noise=False, separate_noise=False)

        clamped_pool = torch.where(ys[0] == 0.5)[0]
        other_pool = torch.where(ys[0] != 0.5)[0]

        print(f'clamped {clamped_pool.size(0)}, other {other_pool.size(0)}')
        if clamped_pool.size(0) >= num_clamped_needed and other_pool.size(0) >= num_other_needed:
            break
    else:
        raise RuntimeError(
            f"could not find {num_clamped_needed} clamped and {num_other_needed} other points "
            f"in {max_attempts} batches of 100 samples"
        )

    # Pick the required number of points from each pool.
    clamped_indices = _sample_without_replacement(clamped_pool, num_clamped_needed)
    remaining_indices = _sample_without_replacement(other_pool, num_other_needed)

    clamped_xs, clamped_ys = xs[:, clamped_indices], ys[:, clamped_indices]
    remaining_xs, remaining_ys = xs[:, remaining_indices], ys[:, remaining_indices]

    # Hold out the final point. Both xs and ys must drop it, otherwise the
    # concatenated ys are shifted relative to the xs.
    if last_pt_clamped:
        last_x, last_y = clamped_xs[:, -1], clamped_ys[:, -1]
        clamped_xs, clamped_ys = clamped_xs[:, :-1], clamped_ys[:, :-1]
    else:
        last_x, last_y = remaining_xs[:, -1], remaining_ys[:, -1]
        remaining_xs, remaining_ys = remaining_xs[:, :-1], remaining_ys[:, :-1]

    # Combine clamped and remaining points, then shuffle xs and ys together.
    xs = torch.cat([clamped_xs, remaining_xs], dim=1)
    ys = torch.cat([clamped_ys, remaining_ys], dim=1)

    perm = torch.randperm(xs.shape[1])
    xs = xs[:, perm]
    ys = ys[:, perm]

    print(ys.shape)
    print(last_y[:, None].shape)

    # Append the held-out point as the last position; the new axis goes in the
    # sequence dimension so this also works when n_dims > 1.
    xs = torch.cat((xs, last_x[:, None, :]), dim=1)
    ys = torch.cat((ys, last_y[:, None]), dim=1)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = model(xs.to(device), ys.to(device)).detach()

    predictions = pred.cpu()[:, window_len - 1]

    return ys[:, window_len - 1], predictions

def build_model():
    """Construct the TransformerModel used in this notebook.

    The hyper-parameters are fixed: 1-dimensional inputs, 41 positions,
    embedding width 512, 24 layers and 16 attention heads.

    Returns:
        A freshly initialised `TransformerModel` (no weights loaded).
    """
    hyperparams = {
        "n_dims": 1,
        "n_positions": 41,
        "n_embd": 512,
        "n_layer": 24,
        "n_head": 16,
    }
    return TransformerModel(**hyperparams)


In [129]:

# Build the model and move it to CUDA device 0 (the first device visible to
# this process).
model = build_model()
torch.cuda.set_device(0)
model.cuda()

# Load the checkpoint weights.
# ckpt_path = '/home/riadoshi/alignment/Alignment/models/finetune/go_time/'
ckpt_path = '/home/riadoshi/alignment/prev/ckpts/ckpt/'
base_model = os.path.join(ckpt_path, "state.pt")
state = torch.load(base_model, map_location='cuda:0')
model.load_state_dict(state["model_state_dict"])
# The model is only evaluated in this notebook, so put it in inference mode
# (disables dropout and other train-only behaviour).
model.eval()
b_size = 1

# Task and input samplers passed to eval_batch in the evaluation cells.
task_sampler = get_task_sampler(
        "clamped_chebyshev", 1, b_size
)

data_sampler = get_data_sampler('gaussian', 1)



In [130]:
# Trial one: fixed context window (eval_batch's default window_len=41),
# varying the fraction of clamped points in the context.
#
# This run evaluates the variant where the final (query) point is clamped
# (last_pt_clamped=True); set it to False for the "last point not clamped" variant.
last_pt_clamped=True
percent_clamped = [0.1]  # fractions of the window that must be clamped
num_per_pc = 10  # number of eval_batch calls averaged per fraction
mse = []  # one averaged error per entry of percent_clamped

for pc in percent_clamped:
    pc_mse = 0
    for _ in range(num_per_pc):
        y, pred = eval_batch(model, data_sampler, task_sampler, percent_clamped_correct=pc, last_pt_clamped=last_pt_clamped)
        # NOTE(review): sqrt of the squared error is the absolute error |pred - y|,
        # so despite its name `mse` ends up holding a mean absolute error --
        # confirm which metric is intended.
        pc_mse+= np.sqrt((pred-y)**2)
        
    mse.append(pc_mse/num_per_pc)
    

needed, other (4, 37)
clamped 6, other 94
torch.Size([1, 40])
torch.Size([1, 1])
needed, other (4, 37)
clamped 0, other 100
clamped 0, other 100


KeyboardInterrupt: 