In [20]:
import torch
import losslandscape as land
import matplotlib.pyplot as plt

from models import TransformerModelLooped, TransformerModelLoopedPyHessianWrapper, TransformerModelPyHessianWrapper
from curriculum import CurriculumSimple
from train import train_without_config, validate_model
from scripts.tasks import get_task_sampler

# Shared experiment configuration used by every cell below.
n_dims = 10          # input dimensionality passed to the task sampler and models
train_steps = 15000  # total optimisation steps per model
log_every = 500      # the training callback interval, in steps
# Fall back to the CPU when no CUDA device is present, so the notebook still
# runs (slowly) on machines without a GPU instead of failing at `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [21]:
# Build the evaluation-task sampler. The dimensionality comes from the shared
# `n_dims` constant (defined in the configuration cell) rather than a repeated
# literal, so changing the config keeps the evaluation data and models in sync.
task_sampler = get_task_sampler(
    task_name="linear_regression",
    batch_size=1,
    n_points=31,
    n_dims=n_dims,
    n_dims_truncated=n_dims,
    device=device,
)

# Draw one task instance and cast its inputs/targets with `.float()`; these
# are the fixed data on which the loss landscape is evaluated.
real_task = task_sampler()
xs, ys = real_task.xs.float(), real_task.ys.float()

In [22]:
def criterion_fn(y_pred_list, ys_in):
    """Mean squared error between a list of predictions and one shared target.

    Used as the criterion for the loss-landscape computation.

    Args:
        y_pred_list: list of prediction tensors, all compared against `ys_in`.
        ys_in: target tensor; it is repeated once per entry of `y_pred_list`.

    Returns:
        Scalar tensor: the mean of the squared differences over all entries.
    """
    # Stack every prediction along dim 0 ...
    stacked_preds = torch.cat(y_pred_list, dim=0)
    # ... and tile the target the same number of times so the two line up.
    tiled_targets = torch.cat([ys_in for _ in y_pred_list], dim=0)
    residual = tiled_targets - stacked_preds
    return torch.square(residual).mean()

In [23]:
# Looped transformer: a single layer with default_n_loops=5.
model_loop_b5 = TransformerModelLoopedPyHessianWrapper(
    n_dims=n_dims,
    n_positions=101,
    n_embd=128,
    n_layer=1,
    n_head=4,
    pred_type="regression",
    default_n_loops=5,
).to(device)

# Per-run bookkeeping: the curriculum plus the parameter/loss history that the
# training callback fills in.
# NOTE(review): the positional CurriculumSimple arguments are passed through
# unchanged; confirm their meaning against the class definition.
model_b5_config = dict(
    curriculum=CurriculumSimple(
        n_dims, 31, 5,
        [5000, n_dims, 0],
        [5000, 31, 0],
        [1000, 5, 0],
    ),
    log_steps=train_steps // log_every,
    params=[],
    losses=[],
    metrics=None,
)

def callback_b5_fn(model, loss):
    """Training hook for the looped run: store a parameter snapshot and the loss.

    Args:
        model: the model being trained; its parameters are read via
            `land.get_params` and wrapped in a `land.ParamList`.
        loss: the loss value handed over by the training loop; appended as-is.
    """
    history = model_b5_config
    snapshot = land.ParamList(land.get_params(model))
    history["params"].append(snapshot)
    history["losses"].append(loss)


# Train the looped model; `callback_b5_fn` records the parameter/loss history
# and the value returned by the trainer is kept under "metrics".
model_b5_config["metrics"] = train_without_config(
    model_loop_b5,
    model_b5_config["curriculum"],
    model_n_dims=n_dims,
    log_every_steps=log_every,
    train_steps=train_steps,
    family="gpt2_loop",
    do_wandb_log=False,
    seed=None,
    task_name="linear_regression",
    callback=callback_b5_fn,
)

number of parameters: 0.20M


  0%|          | 0/15000 [00:00<?, ?it/s]

0


loss 4.830509662628174:  85%|████████▍ | 12677/15000 [07:41<01:24, 27.48it/s] 


KeyboardInterrupt: 

In [None]:
# Turn the logged losses into a tensor on `device`. `torch.as_tensor` keeps this
# cell idempotent: on a re-run the value is already a tensor on `device` and is
# reused, whereas `torch.tensor(<tensor>)` would copy-construct and emit a
# UserWarning.
model_b5_config["losses"] = torch.as_tensor(model_b5_config["losses"], device=device)

# Landscape helper for the looped model, fed with the recorded training history.
loss_landscape = land.LossLandscapePlotting(
    model=model_loop_b5,
    criterion=criterion_fn,
    device=device,
    # Input: `ys` appended to `xs` as one extra trailing feature; target: `ys`.
    data=(torch.concatenate([xs, ys.unsqueeze(-1)], dim=-1), ys),
    parameters_history=model_b5_config["params"],
    loss_history=model_b5_config["losses"],
    mean_theta0=True,
)

In [None]:
# Build the optimisation trace from the recorded history, then evaluate the
# loss surface on a 20-point-per-axis grid over [-1, 1] x [-1, 1].
# NOTE(review): exact semantics of `every_ith` and `coef` are defined by
# LossLandscapePlotting; confirm there.
trace = loss_landscape.compute_trace(every_ith=1)
ralpha, rbeta, surface = loss_landscape.compute_landscape(
    trace,
    arange=(-1, 1),
    brange=(-1, 1),
    grid_density=20,
    coef=1,
)

In [None]:
# Draw the looped model's landscape together with its training trace.
loss_landscape.plot(
    trace=trace,  # alternative: [trace[0], trace[-1]]
    ralpha=ralpha,
    rbeta=rbeta,
    surface=surface,
    colormap="magma",
    k=0.5,
)

In [None]:
# Contour view of the same landscape and trace for the looped model.
loss_landscape.plot_contour(
    trace=trace,
    ralpha=ralpha,
    rbeta=rbeta,
    surface=surface,
    colormap="magma",
    k=1,
)

In [None]:
# Non-looped baseline: a two-layer transformer with otherwise identical sizes.
model_l2 = TransformerModelPyHessianWrapper(
    n_dims=n_dims,
    n_positions=101,
    n_embd=128,
    n_layer=2,
    n_head=4,
    pred_type="regression",
).to(device)

# Per-run bookkeeping: the curriculum plus the parameter/loss history that the
# training callback fills in.
# NOTE(review): the positional CurriculumSimple arguments are passed through
# unchanged; confirm their meaning against the class definition.
model_l2_config = dict(
    curriculum=CurriculumSimple(
        n_dims, 31, 0,
        [5000, n_dims, 0],
        [5000, 31, 0],
        [1000, 0, 0],
    ),
    log_steps=train_steps // log_every,
    params=[],
    losses=[],
    metrics=None,
)

def callback_l2_fn(model, loss):
    """Training hook for the 2-layer run: store a parameter snapshot and the loss.

    Args:
        model: the model being trained; its parameters are read via
            `land.get_params` and wrapped in a `land.ParamList`.
        loss: the loss value handed over by the training loop; appended as-is.
    """
    history = model_l2_config
    snapshot = land.ParamList(land.get_params(model))
    history["params"].append(snapshot)
    history["losses"].append(loss)

# Train the 2-layer model; `callback_l2_fn` records the parameter/loss history
# and the value returned by the trainer is kept under "metrics".
model_l2_config["metrics"] = train_without_config(
    model_l2,
    model_l2_config["curriculum"],
    model_n_dims=n_dims,
    log_every_steps=log_every,
    train_steps=train_steps,
    family="gpt2",
    do_wandb_log=False,
    seed=None,
    task_name="linear_regression",
    callback=callback_l2_fn,
)

In [None]:
# Turn the logged losses into a tensor on `device`. `torch.as_tensor` keeps this
# cell idempotent: on a re-run the value is already a tensor on `device` and is
# reused, whereas `torch.tensor(<tensor>)` would copy-construct and emit a
# UserWarning.
model_l2_config["losses"] = torch.as_tensor(model_l2_config["losses"], device=device)

def criterion_fn(y_pred, ys):
    """Mean squared error between predictions and targets.

    This definition replaces the earlier `criterion_fn` in the notebook
    namespace, so it accepts both call shapes: a single prediction tensor, or
    a list/tuple of prediction tensors. That way re-running a cell that
    passes a list of predictions after this cell has executed still works.

    Args:
        y_pred: a prediction tensor, or a list/tuple of prediction tensors
            that are all compared against the same `ys`.
        ys: target tensor.

    Returns:
        Scalar tensor: the mean of the squared differences.
    """
    if isinstance(y_pred, (list, tuple)):
        n_preds = len(y_pred)
        # Stack the predictions along dim 0 and tile the target to match.
        y_pred = torch.cat(list(y_pred), dim=0)
        ys = torch.cat([ys] * n_preds, dim=0)
    return (ys - y_pred).square().mean()


# `get_task_sampler` is already imported in the notebook's first cell, so the
# duplicate mid-notebook import is not repeated here.
# NOTE(review): the looped model's landscape is evaluated on
# "linear_regression" data, while this cell samples "noisy_linear_regression";
# confirm the mismatch is intentional before comparing the two plots.
task_sampler = get_task_sampler(
    task_name="noisy_linear_regression",
    batch_size=1,
    n_points=31,
    n_dims=n_dims,
    n_dims_truncated=n_dims,
    device=device,
)

# Draw one task instance and cast its inputs/targets with `.float()`.
real_task = task_sampler()
xs, ys = real_task.xs.float(), real_task.ys.float()

# Landscape helper for the 2-layer model, fed with the recorded training history.
loss_landscape = land.LossLandscapePlotting(
    model=model_l2,
    criterion=criterion_fn,
    device=device,
    # Input: `ys` appended to `xs` as one extra trailing feature; target: `ys`.
    data=(torch.concatenate([xs, ys.unsqueeze(-1)], dim=-1), ys),
    parameters_history=model_l2_config["params"],
    loss_history=model_l2_config["losses"],
    mean_theta0=True,
)
  

In [None]:
# Build the optimisation trace from the recorded history, then evaluate the
# loss surface on a 20-point-per-axis grid over [-1, 1] x [-1, 1].
# NOTE(review): exact semantics of `every_ith` and `coef` are defined by
# LossLandscapePlotting; confirm there.
trace = loss_landscape.compute_trace(every_ith=1)
ralpha, rbeta, surface = loss_landscape.compute_landscape(
    trace,
    arange=(-1, 1),
    brange=(-1, 1),
    grid_density=20,
    coef=1,
)

In [None]:
# Draw the 2-layer model's landscape together with its training trace.
loss_landscape.plot(
    trace=trace,  # alternative: [trace[0], trace[-1]]
    ralpha=ralpha,
    rbeta=rbeta,
    surface=surface,
    colormap="magma",
    k=0.5,
)

In [None]:
# Contour view of the same landscape and trace for the 2-layer model.
loss_landscape.plot_contour(
    trace=trace,
    ralpha=ralpha,
    rbeta=rbeta,
    surface=surface,
    colormap="magma",
    k=1,
)