In [1]:
### Consts
# Run configuration shared by the cells below.
SEED = 42            # passed to torch.manual_seed at the start of train_loop
DEVICE = 'cuda'      # device the model and sampled batches are moved to
MODEL_N_DIMS = 10    # default `model_n_dims` for train_loop's task sampler

# Training schedule; every value below is counted in optimisation steps.
STEPS = 100_000
LOG_EVERY_STEPS = 100    # cadence of wandb logging (and evaluation on a fixed dataset)
SAVE_EVERY_STEPS = 1000  # cadence of overwriting state.pt (model + optimizer state)
KEEP_EVERY_STEPS = 1000  # cadence of writing a separate model_{i}.pt snapshot

In [6]:
### Import libraries and scripts
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb

import matplotlib.pyplot as plt
import numpy as np

import os
import tqdm

from tasks import get_task_sampler
from models import TransformerModel
from main_utils import gen_dataloader
from curriculum import CurriculumSimple

In [3]:
# Build the transformer used in the sanity-check cells below and move it to DEVICE.
# NOTE(review): n_dims=4 matches the sampler in the next cell, but train_loop defaults to
# model_n_dims=MODEL_N_DIMS (10) and builds its curriculum with dims_start=5 — pass
# model_n_dims=4 (or align these values) before training this model; confirm intended dims.
model = TransformerModel(
    n_dims=4,
    n_positions=201,
    n_embd=256,
    n_layer=12,
    n_head=8,
    pred_type="regression",
).to(DEVICE)

number of parameters: 9.48M


In [4]:
# Sample one small linear-regression batch (2 sequences x 7 points) to sanity-check shapes.
task_sampler = get_task_sampler(
    task_name="linear_regression",
    batch_size=2,
    n_points=7,
    n_dims=4,             # matches the model's n_dims=4 above
    n_dims_truncated=1,   # the output below shows only the first coordinate is non-zero
    device=DEVICE,
)

real_task = task_sampler()
xs, ys = real_task.xs.float(), real_task.ys.float()
# Loop horizon as train_step computes it for family='gpt2_loop' with a window of 20;
# not used by the plain 'gpt2' forward pass in the next cell.
n_loops = 1
horizon_start = max(0, n_loops - 20)
xs, ys, xs.shape, ys.shape

(tensor([[[-1.3078,  0.0000,  0.0000,  0.0000],
          [ 0.0933,  0.0000,  0.0000,  0.0000],
          [-0.4446,  0.0000,  0.0000,  0.0000],
          [ 0.9205,  0.0000,  0.0000,  0.0000],
          [ 0.2219,  0.0000,  0.0000,  0.0000],
          [-0.8618,  0.0000,  0.0000,  0.0000],
          [-1.5656,  0.0000,  0.0000,  0.0000]],
 
         [[ 0.1971,  0.0000,  0.0000,  0.0000],
          [ 0.6093,  0.0000,  0.0000,  0.0000],
          [-1.1747,  0.0000,  0.0000,  0.0000],
          [-1.5403,  0.0000,  0.0000,  0.0000],
          [ 0.6959,  0.0000,  0.0000,  0.0000],
          [-0.9754,  0.0000,  0.0000,  0.0000],
          [-0.3829,  0.0000,  0.0000,  0.0000]]], device='cuda:0'),
 tensor([[ 0.5135, -0.0366,  0.1746, -0.3614, -0.0871,  0.3384,  0.6147],
         [-0.3464, -1.0705,  2.0641,  2.7065, -1.2228,  1.7138,  0.6729]],
        device='cuda:0'),
 torch.Size([2, 7, 4]),
 torch.Size([2, 7]))

In [5]:
# Forward pass of the freshly constructed model on the sample batch, then its mean squared error.
y_pred = model(xs, ys, add_inputs_embeds=False)  # same call train_step makes for family='gpt2'
print(y_pred)
loss = torch.mean((y_pred - ys) ** 2)
loss

tensor([[-0.6317,  0.0810, -0.3501,  0.4283,  0.1921, -0.5659, -0.6622],
        [ 0.1001,  0.3541, -0.6169, -0.6579,  0.4289, -0.5863, -0.2742]],
       device='cuda:0', grad_fn=<SelectBackward0>)


  y = torch.nn.functional.scaled_dot_product_attention(


tensor(2.4573, device='cuda:0', grad_fn=<MeanBackward0>)

In [75]:

def train_step(curriculum, model, xs, ys, optimizer, ctx, scaler, n_loop_window, family = 'gpt2'):
    """Run one optimisation step and report loss, predictions and gradient norms.

    Args:
        curriculum: object exposing ``n_loops`` (K); only read when family == 'gpt2_loop'.
        model: module called as ``model(xs, ys, add_inputs_embeds=False)`` for 'gpt2', or as
            ``model(xs, ys, horizon_start, n_loops)`` returning a list of predictions for
            'gpt2_loop'.
        xs: model inputs, passed straight through to ``model``.
        ys: regression targets, compared elementwise with the predictions.
        optimizer: optimizer stepped through ``scaler``.
        ctx: context manager entered around the forward pass (train_loop passes an autocast).
        scaler: GradScaler-like object providing scale / unscale_ / step / update.
        n_loop_window: T, the number of trailing loop iterations included in the
            'gpt2_loop' loss.
        family: 'gpt2' or 'gpt2_loop'.

    Returns:
        ``(loss, y_pred, total_norm, grad_norm_dict)``. The first two are detached tensors.
        ``total_norm`` is the global L2 norm of all gradients as a float. ``grad_norm_dict``
        maps each parameter name that received a gradient to its L2 norm. Norms are
        measured after unscaling, so they are in true gradient units.

    Raises:
        NotImplementedError: if ``family`` is not supported.
    """
    if family == 'gpt2':
        with ctx:
            y_pred = model(xs, ys, add_inputs_embeds=False)  # [B, n]
            loss = (ys - y_pred).square().mean()  # averaged over batch and in-context points
    elif family == 'gpt2_loop':
        n_loops = curriculum.n_loops  # K
        with ctx:
            # Only the last `n_loop_window` loop iterations contribute to the loss.
            horizon_start = max(0, n_loops - n_loop_window)
            y_pred_list = model(xs, ys, horizon_start, n_loops)  # list of [B, n]
            y_pred_arr = torch.cat(y_pred_list, dim=0)  # [B * K', n]
            y_star_arr = torch.cat([ys] * len(y_pred_list), dim=0)  # targets repeated per loop
            loss = (y_star_arr - y_pred_arr).square().mean()  # averaged over loops and points
            y_pred = y_pred_list[-1]  # final loop's prediction, [B, n]
    else:
        raise NotImplementedError(f"unsupported model family: {family!r}")

    scaler.scale(loss).backward()
    # Unscale once so the norms below are real gradient magnitudes; scaler.step will not
    # unscale again (no-op when the scaler is disabled).
    scaler.unscale_(optimizer)

    named_grads = [(name, p.grad.detach()) for name, p in model.named_parameters()
                   if p.grad is not None]
    if named_grads:
        per_param = torch.stack([g.norm(2).float() for _, g in named_grads])
        total_norm = per_param.norm(2).item()
        # single host transfer for all norms instead of one .item() per parameter
        grad_norm_dict = dict(zip((name for name, _ in named_grads), per_param.tolist()))
    else:
        total_norm, grad_norm_dict = 0.0, {}

    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad(set_to_none=True)
    return loss.detach(), y_pred.detach(), total_norm, grad_norm_dict

def _evaluate_pointwise(model, test_loader, family, n_loops):
    """Return the per-point mean squared error of ``model`` over all of ``test_loader``.

    Squared errors are summed over every test sample and divided by the total sample
    count, so a short final batch is weighted by its true size. ``n_loops`` is only used
    when family == 'gpt2_loop'. The model is put in eval mode for the pass and restored
    to its previous mode afterwards. Returns None if the loader yields no batches.

    Raises:
        NotImplementedError: if ``family`` is not 'gpt2' or 'gpt2_loop'.
    """
    sq_err_sum = None
    n_samples = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in test_loader:
                test_xs, test_ys = batch['x'].to(DEVICE), batch['y'].to(DEVICE)
                if family == 'gpt2':
                    # same forward call train_step uses for this family
                    test_pred = model(test_xs, test_ys, add_inputs_embeds=False)  # [B, n]
                elif family == 'gpt2_loop':
                    test_pred = model(test_xs, test_ys, 0, n_loops)[-1]  # final loop, [B, n]
                else:
                    raise NotImplementedError
                batch_sq_err = (test_pred - test_ys).square()
                batch_sum = batch_sq_err.sum(dim=0)
                sq_err_sum = batch_sum if sq_err_sum is None else sq_err_sum + batch_sum
                n_samples += batch_sq_err.shape[0]
    finally:
        model.train(was_training)
    if n_samples == 0:
        return None
    return sq_err_sum / n_samples


def train_loop(model,
               lr=0.0001,
               use_fixed_dataset = False,
               task_name = "linear_regression",
               batch_size=64,
               model_n_dims=MODEL_N_DIMS,
               train_size = 10000, test_size=1000,
               family = "gpt2",
               experiment_name = "linear_regression_gpt_2",
               n_loop_window=20,
               out_dir="results",):
    """Train ``model`` for STEPS steps on sampled regression tasks, logging and checkpointing.

    Args:
        model: module trained in place (moved to DEVICE and put in train mode).
        lr: Adam learning rate.
        use_fixed_dataset: if True, train on a loader of ``train_size`` samples built with
            gen_dataloader (restarted whenever exhausted). Every LOG_EVERY_STEPS steps,
            also evaluate on a separate loader of ``test_size`` samples. Otherwise a fresh
            batch is sampled every step.
        task_name, batch_size, model_n_dims: forwarded to get_task_sampler.
        train_size, test_size: dataset sizes, used only when ``use_fixed_dataset`` is True.
        family: 'gpt2' or 'gpt2_loop'; selects the forward pass in train_step and evaluation.
        experiment_name: wandb run name.
        n_loop_window: trailing loop iterations included in the training loss
            ('gpt2_loop' only).
        out_dir: directory for the model_{i}.pt snapshots; created if missing.

    Side effects: overwrites "state.pt" every SAVE_EVERY_STEPS steps and writes
    model_{i}.pt into ``out_dir`` every KEEP_EVERY_STEPS steps and on the last step.
    """
    state_path = "state.pt"
    # TORCH 2.0 ZONE ###############################
    # 'high' permits TF32 for fp32 matmuls; cuDNN is allowed TF32 as well.
    torch.set_float32_matmul_precision('high')
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    dtype = 'float16'  # 'bfloat16', 'float32'
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
    ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype, cache_enabled=False)
    ################################################

    wandb.init(
        dir="results",
        project="looped_transformers",
        notes="args.wandb.notes",
        name=experiment_name,
        mode="disabled",
        resume=True,
    )

    torch.manual_seed(SEED)

    model.to(DEVICE)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)
    # Loss scaling is only needed for float16 autocast.
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
    curriculum = CurriculumSimple(dims_start = 5,
        points_start = 11,
        loops_start= 1,
        dims_schedule = 0,
        points_schedule = 0,
        loops_schedule = 0)

    if use_fixed_dataset:
        task_sampler = get_task_sampler(
            task_name=task_name,
            batch_size=batch_size,
            n_points=curriculum.n_points,
            n_dims=model_n_dims,
            n_dims_truncated=curriculum.n_dims_truncated,
            device=DEVICE,
            sparsity=False,
        )
        train_loader = gen_dataloader(task_sampler, train_size, batch_size)
        train_iter = iter(train_loader)
        test_loader = gen_dataloader(task_sampler, test_size, batch_size)

    os.makedirs(out_dir, exist_ok=True)
    # `tqdm` is imported as a module in the imports cell, hence tqdm.tqdm.
    pbar = tqdm.tqdm(range(0, STEPS))
    for i in pbar:
        if use_fixed_dataset:
            try:
                batch = next(train_iter)
            except StopIteration:
                # Epoch finished: restart the loader and take its first batch.
                train_iter = iter(train_loader)
                batch = next(train_iter)
            xs, ys = batch['x'].to(DEVICE), batch['y'].to(DEVICE)
        else:
            # Rebuilt every step so any curriculum change is picked up.
            task_sampler = get_task_sampler(
                task_name=task_name,
                batch_size=batch_size,
                n_points=curriculum.n_points,
                n_dims= model_n_dims,
                n_dims_truncated= curriculum.n_dims_truncated,
                device=DEVICE,
                sparsity=False,
            )

            real_task = task_sampler()
            xs, ys = real_task.xs.float(), real_task.ys.float()

        loss, output, *grad_stats = train_step(
            curriculum, model, xs, ys, optimizer, ctx, scaler, n_loop_window, family=family)
        # Gradient norms are logged only when train_step reports them.
        total_norm, grad_norm_dict = grad_stats[:2] if len(grad_stats) >= 2 else (None, {})
        train_loss = loss

        # EVALUATION ======================================
        if i % LOG_EVERY_STEPS == 0:
            point_wise_tags = list(range(curriculum.n_points))  # [0, 1, ..., n-1]
            point_wise_loss = (output - ys).square().mean(dim=0)  # [n,] on the training batch
            if use_fixed_dataset:
                test_point_wise_loss = _evaluate_pointwise(
                    model, test_loader, family, curriculum.n_loops)
                if test_point_wise_loss is not None:
                    point_wise_loss = test_point_wise_loss
                    loss = point_wise_loss.mean()
            wandb.log(
                {
                    "overall_loss": loss,
                    "overall_train_loss": train_loss,
                    "loop_times": curriculum.n_loops,
                    "grad_norm/layerwise": grad_norm_dict,
                    "grad_norm": total_norm,
                    "pointwise/loss": dict(
                        zip(point_wise_tags, point_wise_loss.detach().cpu().numpy())
                    ),
                    "n_points": curriculum.n_points,
                    "n_dims": curriculum.n_dims_truncated,
                    "lr": optimizer.param_groups[0]['lr'],
                },
                step=i,
            )

        curriculum.update()

        pbar.set_description(f"loss {loss}")
        if i % SAVE_EVERY_STEPS == 0:
            training_state = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": i,
                "val_loss": loss,
            }
            torch.save(training_state, state_path)
        if (KEEP_EVERY_STEPS > 0
                and i % KEEP_EVERY_STEPS == 0
                and i > 0) or (i == STEPS - 1):
            torch.save({'model': model.state_dict()},
                       os.path.join(out_dir, f"model_{i}.pt"))

NameError: name 'MODEL_N_DIMS' is not defined