In [1]:
import torch
import torch.nn as nn

import wandb
import random
import numpy as np
from tqdm.notebook import tqdm_notebook as tqdm

import matplotlib.pyplot as plt

from transformer import BaseLoopedTransformer
from utils import LinregDataset, save, train, set_seed
from torch.utils.data import DataLoader

from dataclasses import dataclass

In [3]:
@dataclass
class Config:
    """Architecture hyperparameters passed to ``BaseLoopedTransformer``.

    The defaults describe a small single-layer model. How each field is
    consumed is decided by ``BaseLoopedTransformer``, which is defined in the
    ``transformer`` module and not shown in this notebook.
    """
    # Dimensionality of the regression problem; the notebook passes the same
    # value to LinregDataset.
    n_dims:       int = 4
    num_layers:   int = 1
    attn_heads:   int = 4
    hidden_dim:   int = 32
    mlp_hidden:   int = 128
    # The notebook derives the number of points per sample from this value as
    # (context + 1) // 2.
    context:      int = 128
    # The default is the nn.GELU *class*, not an instance, so the field is
    # annotated as a Module subclass rather than as a Module.
    activation:   type[nn.Module] = nn.GELU
    
def get_config(name, n_dims):
    """Return the preset for one of the named model variants.

    Args:
        name: Preset key, one of ``'L1'``, ``'L2'``, ``'L4'`` or ``'L8'``.
        n_dims: Value stored in the ``n_dims`` field of the returned config.

    Returns:
        A 4-tuple ``(config, lr, train_bsize, test_bsize)``. The batch sizes
        are 64 and 32 for every preset.

    Raises:
        KeyError: If ``name`` is not one of the known presets.
    """
    # Presets differ only in depth and learning rate; every other setting is
    # shared between them.
    depth_and_lr = {
        'L1': (1, 1e-3),
        'L2': (2, 5e-4),
        'L4': (4, 5e-4),
        'L8': (8, 5e-4),
    }
    depth, learning_rate = depth_and_lr[name]

    preset = Config(
        n_dims=n_dims,
        num_layers=depth,
        attn_heads=4,
        hidden_dim=64,
        mlp_hidden=256,
        context=127,
    )
    return preset, learning_rate, 64, 32

In [None]:
import os

from torch.optim import Adam

# Experiment 2: train every (model preset, b) pair once per seed, log each
# run to Weights & Biases, and store the histories and the trained model.
n_dims = 8
mean, std = 0, 1

seeds = [42, 451, 1984]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Preset name -> values of `b` that are forwarded to `train`.
models = {
    'L1': [1, 2, 4, 8],
    'L2': [1, 2, 4],
    'L4': [1, 2],
    'L8': [1]
}

# torch.save does not create missing parent directories. Create the
# checkpoint directory up front so a finished training run cannot be lost to
# a missing folder.
models_dir = './results/experiment 2/models/'
os.makedirs(models_dir, exist_ok=True)

for name, bs in models.items():
    for b in bs:
        for seed in seeds:

            config, lr, train_bsize, test_bsize = get_config(name, n_dims)
            n_points = (config.context + 1) // 2

            # Seed before building the datasets and the model so that every
            # random draw of this run is covered by the seed.
            # NOTE(review): LinregDataset is not visible here; if it samples
            # lazily this ordering makes no difference, if it samples at
            # construction time this is what makes the data reproducible.
            set_seed(seed)

            train_loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = std, random = True,
                device = device
            ), batch_size = train_bsize)
            test_loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = std, random = True,
                total = test_bsize * 10, device = device
            ), batch_size = test_bsize)

            model = BaseLoopedTransformer(config).to(device)
            optimizer = Adam(model.parameters(), lr=lr)

            run_name = f'{name}_{b}_{seed}'
            run = wandb.init(
                project = 'Looped Transformer',
                name = run_name,
                config = {
                    'name': f'exp2_{name}_{b}',
                    'experiment': 2,
                    'model': name,
                    'b': b,
                    'train batch size': train_bsize,
                    'test batch size': test_bsize,
                    'lr': lr,
                    'seed': seed,
                    'N': n_points - 1,
                }
            )

            # Always close the W&B run, even when training fails or is
            # interrupted; otherwise the next wandb.init would start while
            # this run is still open.
            try:
                loss_history, eval_history = train(
                    model, train_loader, test_loader, optimizer,
                    b = b, steps = 7500, run = run, log_every = 75
                )
            finally:
                run.finish()

            save(2, run_name, loss_history, eval_history)
            # The whole module object is pickled (not just its state_dict);
            # the evaluation cell loads it back with torch.load.
            torch.save(model, models_dir + run_name + '.pt')

            torch.cuda.empty_cache()

In [7]:
import json

# Evaluate every saved experiment-2 model on freshly generated data whose
# standard deviation is redrawn for each launch.
n_dims = 8
mean, std = 0, 1

seeds = [42, 451, 1984]
models = {
    'L1': [1, 2, 4, 8],
    'L2': [1, 2, 4],
    'L4': [1, 2],
    'L8': [1]
}

# Fall back to the CPU when CUDA is unavailable, matching the device
# selection used by the training cell.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

test_seed = 4815163242 % 2**31
num_launches = 10

results = {}

# One progress tick per (model, b, seed, launch) evaluation; derived from the
# tables above so it stays correct when they are edited.
total_evaluations = sum(len(bs) for bs in models.values()) * len(seeds) * num_launches
pbar = tqdm(total=total_evaluations)

for name, bs in models.items():
    # Only the context length is needed here, and it depends on the preset
    # name alone.
    config = get_config(name, n_dims)[0]
    n_points = (config.context + 1) // 2

    for b in bs:
        res = []

        # Re-seed for every (model, b) pair so each configuration starts from
        # the same random state.
        set_seed(test_seed)
        for seed in seeds:
            # The checkpoint is a fully pickled module, so it needs
            # weights_only=False on PyTorch versions where torch.load defaults
            # to weights-only loading. map_location lets a model saved on GPU
            # be loaded on a CPU-only machine.
            model = torch.load(
                f'./results/experiment 2/models/{name}_{b}_{seed}.pt',
                map_location=device, weights_only=False
            )
            # Evaluate in inference mode (relevant if the model contains
            # dropout or normalisation layers; it is a no-op otherwise).
            model.eval()

            pbar.set_description(f'Run \'{name}_{b}\', seed {seed}...')

            for launch in range(num_launches):

                # Separate name so the notebook-level `std` is not clobbered.
                test_std = torch.rand((1,)).item() * 2
                loader = DataLoader(LinregDataset(
                    n_dims = n_dims, n_points = n_points,
                    mean = mean, std = test_std, random = True,
                    total = 128 * 10, device = device
                ), batch_size = 128)

                running_loss = 0
                with torch.no_grad():
                    for (x, y) in loader:

                        # The model output is stacked along a new leading
                        # axis; the targets are replicated b times to match.
                        preds = model(x[:, :-1], b)
                        preds = torch.stack(preds)
                        targs = torch.stack([y] * b)

                        # Squared error at the last position: averaged over
                        # the stacked predictions first, then over the batch.
                        loss = (targs[:,:,-1] - preds[:,:,-1]).square().mean(dim=0).mean()

                        running_loss += loss.item() / loader.dataset.n_dims
                res.append(running_loss / len(loader))
                pbar.update(1)

        results[f'{name}_{b}'] = res

pbar.close()

  0%|          | 0/300 [00:00<?, ?it/s]

In [9]:
# Persist the per-configuration evaluation losses as pretty-printed JSON.
evaluation_path = './results/experiment 2/data/evaluation.json'
with open(evaluation_path, 'w') as out_file:
    json.dump(results, out_file, indent=4)