In [1]:

import os
from models import TransformerModel
from tasks import get_task_sampler
from samplers import get_data_sampler
import torch
import numpy as np



In [19]:
def eval_batch(model, task_sampler, xs, include_noise=True, ground_truth_loss=False, smoothing=0, clamp_value=0.5):
    """Evaluate ``model`` on one batch and also report the error on clamped targets.

    A task is drawn from ``task_sampler``, labels are computed for ``xs``, and the
    model is run once per input perturbation in ``[-smoothing, smoothing]``
    (step 0.002). The per-perturbation predictions are averaged before scoring.

    Args:
        model: Callable ``model(xs, ys)`` with a ``name`` attribute. Its output
            must be assignable to a ``(xs.shape[0], xs.shape[1])`` tensor.
        task_sampler: Zero-argument callable returning a task object exposing
            ``evaluate(xs, noise=..., separate_noise=...)`` and ``get_metric()``.
        xs: Input tensor; only its first two dimensions are inspected here.
        include_noise: Forwarded to ``task.evaluate`` as ``noise``.
        ground_truth_loss: If True, the model is conditioned on noisy labels but
            scored against the noise-free targets.
        smoothing: Half-width of the perturbation range added to ``xs``.
        clamp_value: Target value that marks a "clamped" point (default 0.5,
            the value previously hard-coded).

    Returns:
        Tuple ``(metrics, clamped_error)``: the full metric tensor, and the
        entries of it at positions where the target equals ``clamp_value``.
    """
    task = task_sampler()
    if torch.cuda.is_available() and model.name.split("_")[0] in ["gpt2", "lstm"]:
        device = "cuda"
    else:
        device = "cpu"

    # NOTE(review): np.arange with a float step can include or drop the end
    # point depending on round-off; kept as-is to preserve the sampled grid.
    perturbations = np.arange(-1 * smoothing, smoothing + 0.002, 0.002)

    if ground_truth_loss:
        # Keep the clean targets separately instead of reconstructing them later
        # as (ys + noise) - noise, which is not exact in floating point.
        targets, noise = task.evaluate(xs, noise=include_noise, separate_noise=True)
        ys = targets + noise
    else:
        ys = task.evaluate(xs, noise=include_noise, separate_noise=False)
        targets = ys

    predictions = torch.zeros(len(perturbations), xs.shape[0], xs.shape[1])
    ys_on_device = ys.to(device)  # loop-invariant, move once
    # Inference only: avoid building an autograd graph for every forward pass.
    with torch.no_grad():
        for i, shift in enumerate(perturbations):
            cur_xs = xs + float(shift)
            predictions[i] = model(cur_xs.to(device), ys_on_device).detach().cpu()
    predictions = predictions.mean(dim=0)

    metrics = task.get_metric()(predictions, targets)

    # Hinge metric: keep the error only where the scored target sits at the
    # clamp value. The mask uses `targets` (not the noisy `ys`) so that it still
    # selects clamped points when ground_truth_loss=True.
    # Exact equality presumes the task writes clamp_value verbatim — confirm.
    clamped_error = metrics[torch.where(targets == clamp_value)]

    return metrics, clamped_error

def build_model():
    """Return a freshly constructed TransformerModel with this notebook's fixed hyperparameters."""
    hyperparams = dict(
        n_dims=1,
        n_positions=41,
        n_embd=512,
        n_layer=24,
        n_head=16,
    )
    return TransformerModel(**hyperparams)


In [20]:

# Build the transformer and move it onto the selected GPU.
model = build_model()
# NOTE(review): the GPU index is hard-coded; presumably this fails on hosts
# with fewer than 7 visible devices — consider making it configurable.
torch.cuda.set_device(6)
model.cuda()

# NOTE(review): absolute, user-specific checkpoint directory; not portable.
ckpt_path = '/home/riadoshi/alignment/Alignment/models/finetune/go_time/'
base_model = os.path.join(ckpt_path, "state.pt")
# NOTE(review): no map_location is passed, so tensors are restored to whatever
# device they were saved from — confirm that device exists on this machine.
state = torch.load(base_model)
# The checkpoint is a dict; the weights live under "model_state_dict".
model.load_state_dict(state["model_state_dict"])

# Positional arguments 1 and 64 are presumably n_dims and batch size —
# TODO confirm against tasks.get_task_sampler.
task_sampler = get_task_sampler(
        "clamped_chebyshev", 1, 64
)

def gen_standard(data_sampler, n_points, b_size):
    """Sample a batch of inputs from ``data_sampler``.

    Returns a pair ``(xs, None)``; the second slot is always None.
    """
    return (data_sampler.sample_xs(n_points, b_size), None)

# Input sampler; the second argument (1) is presumably n_dims — TODO confirm.
data_sampler = get_data_sampler('gaussian', 1)
# gen_standard returns None in its second slot, so xs_p is None here.
xs, xs_p = gen_standard(data_sampler, n_points=41, b_size=64)
# Despite its name, `metrics` holds the 2-tuple (metrics, clamped_error)
# returned by eval_batch; it is unpacked in a later cell.
metrics = eval_batch(model, task_sampler, xs, include_noise=False, ground_truth_loss=False, smoothing=0)


In [21]:
# Unpack eval_batch's result: the full metric tensor, and the subset of it at
# positions where the label equals 0.5.
other_metrics, clamped_mse = metrics

In [27]:
# Built-in sum() adds the tensor's slices along its first dimension.
# NOTE(review): clamped_mse comes from index-based selection in eval_batch, so
# it is presumably 1-D; in that case this is the TOTAL of the selected errors
# and the trailing .mean() on a scalar is a no-op. Confirm whether
# clamped_mse.mean() was intended.
sum(clamped_mse).mean()

tensor(569.1140)

In [28]:
# Built-in sum() adds other_metrics along its first dimension, then .mean()
# averages what remains.
# NOTE(review): if the first dimension is the batch, this value is scaled by
# the batch size relative to other_metrics.mean(), and is not on the same
# scale as the clamped figure computed above — confirm the intended reduction.
sum(other_metrics).mean()

tensor(37.0464)