In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import wandb
from tqdm.notebook import tqdm_notebook as tqdm

from dataclasses import dataclass

from src import LinregDataset, save, set_seed, train, evaluate
from src.transformer import BaseLoopedTransformer

In [None]:
@dataclass
class Config:
    """Hyperparameter bundle handed to ``BaseLoopedTransformer``.

    The per-field notes below are inferred from the field names and from how
    this notebook uses them; confirm against ``src.transformer`` before
    relying on them.
    """

    # Filled from the notebook-level ``n_dims`` by ``get_config``.
    n_dims: int = 4
    # presumably the number of transformer blocks -- TODO confirm
    num_layers: int = 1
    # presumably the number of attention heads -- TODO confirm
    attn_heads: int = 4
    # presumably the model (embedding) width -- TODO confirm
    hidden_dim: int = 32
    # presumably the hidden width of the feed-forward sub-layer -- TODO confirm
    mlp_hidden: int = 128
    # The notebook derives the number of points per sequence as (context + 1) // 2.
    context: int = 128
    # NOTE: the default is the ``nn.GELU`` class itself, not an instance.
    activation: nn.Module = nn.GELU
    
def get_config(name, n_dims):
    """Return the preset registered under ``name``.

    Args:
        name: Preset key; one of ``'tiny'``, ``'small'`` or ``'medium'``.
        n_dims: Value forwarded to ``Config.n_dims``.

    Returns:
        A 4-tuple whose first element is the ``Config`` and whose remaining
        three elements are the preset's scalars. The notebook unpacks them as
        ``(config, lr, train_bsize, test_bsize)``.

    Raises:
        KeyError: If ``name`` is not a known preset.
    """
    # Row layout:
    #   ((num_layers, attn_heads, hidden_dim, mlp_hidden, context),
    #    lr, train_bsize, test_bsize)
    presets = {
        'tiny':   ((1, 2, 26, 128, 127), 1e-3, 42, 32),
        'small':  ((2, 4, 64, 256, 127), 5e-4, 42, 32),
        'medium': ((4, 8, 128, 768, 127), 5e-4, 42, 32),
    }

    architecture, lr, train_bsize, test_bsize = presets[name]
    depth, heads, width, mlp_width, ctx_len = architecture

    config = Config(
        n_dims=n_dims,
        num_layers=depth,
        attn_heads=heads,
        hidden_dim=width,
        mlp_hidden=mlp_width,
        context=ctx_len,
    )
    return (config, lr, train_bsize, test_bsize)

Гипотеза: чем больше `b`, тем больше можно выставить `lr` (или же тем больше итераций требуется при текущих гиперпараметрах, чтобы достичь той же точности, что и при меньшем `b`).

## Тренировка

In [None]:
from torch.optim import Adam
from itertools import product

# --- Experiment grid ---------------------------------------------------------
n_dims = 8
mean, std = 0, 1

seeds = [42, 451, 1984]
models = ['medium', 'small', 'tiny']
bs = [1, 5, 10, 15]

# Values of b used for the post-training extrapolation sweep.
extrp_bs = [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Every (model name, b, seed) combination. Built without loop variables so the
# `bs` and `models` lists are not overwritten by the iteration.
runs = list(product(models, bs, seeds))

# --- Train, evaluate and save one model per grid point -----------------------
for (name, b, seed) in runs:

    config, lr, train_bsize, test_bsize = get_config(name, n_dims)

    # Number of points per sequence.
    # NOTE(review): the W&B config below logs N = n_points - 1, so n_points
    # presumably counts the N in-context examples plus the query -- confirm.
    n_points = (config.context + 1) // 2

    train_loader = DataLoader(LinregDataset(
        n_dims = n_dims, n_points = n_points,
        mean = mean, std = std, random = True,
        device = device
    ), batch_size = train_bsize)
    test_loader = DataLoader(LinregDataset(
        n_dims = n_dims, n_points = n_points,
        mean = mean, std = std, random = True,
        total = test_bsize * 25, device = device
    ), batch_size = test_bsize)

    # NOTE(review): the seed is set after the datasets are constructed. If
    # LinregDataset draws random numbers in its constructor, that data is not
    # covered by `seed` -- confirm and move this call up if so.
    set_seed(seed)

    model = BaseLoopedTransformer(config).to(device)
    optimizer = Adam(model.parameters(), lr=lr)

    run_name = f'{name}_{b}_{seed}'

    # The context manager finishes the W&B run even when training or the
    # extrapolation sweep raises, so a failed run is not left open.
    with wandb.init(
        project = 'Looped Transformer',
        name = run_name,
        config = {
            'name': f'exp1_{name}_{b}',
            'experiment': 1,
            'model': name,
            'b': b,
            'train batch size': train_bsize,
            'test batch size': test_bsize,
            'lr': lr,
            'seed': seed,
            'N': n_points - 1,
        }
    ) as run:

        loss_history, eval_history = train(
            model, train_loader, test_loader, optimizer,
            steps = 7500, b = b, run = run, log_every = 75
        )

        # Evaluate the trained model at each b in the sweep, including values
        # other than the one it was trained with.
        extrapolation = []
        for bi in extrp_bs:
            result = evaluate(test_loader, model, bi)
            extrapolation.append(result)
            print(f'b = {bi} --- loss = {result:.5f}')
            run.log({'b': bi, 'MSE': result})

    save(
        name = run_name,
        model = model,
        loss = loss_history,
        eval = eval_history,
        extr = extrapolation,
        path = './results/experiment 1'
    )

    # Release cached GPU memory before the next model is built.
    torch.cuda.empty_cache()

## Валидация

In [None]:
import json
from itertools import product

# --- Validation grid ---------------------------------------------------------
n_dims = 8
mean, std = 0, 1

seeds = [42, 451, 1984]
models = ['medium', 'small', 'tiny']
bs = [15, 10, 5, 1]

# Fall back to CPU so the cell also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Every (model name, b) combination; the seeds are iterated inside each run.
runs = list(product(models, bs))

test_seed = 4815163242 % 2**31
num_launches = 10

# Size of each freshly generated validation set.
eval_bsize = 128
eval_batches = 10

# Maps '<model>_<b>' to a list of mean losses, one per (seed, launch) pair.
results = {}
pbar = tqdm(total=len(runs) * num_launches * len(seeds))
for (name, b) in runs:
    # Only the architecture is needed here; the training scalars are ignored.
    config, *_ = get_config(name, n_dims)
    n_points = (config.context + 1) // 2

    res = []

    # Re-seed for every (model, b) pair so each one sees the same sequence of
    # noise levels drawn below.
    set_seed(test_seed)
    for seed in seeds:
        # NOTE(review): this unpickles a full module, so only load checkpoints
        # produced by this project. Recent PyTorch releases default
        # `weights_only=True` in `torch.load`, which rejects such files; pass
        # `weights_only=False` there if loading fails.
        model = torch.load(
            f'./results/experiment 1/models/{name}_{b}_{seed}.pt',
            map_location=device
        )
        # Make sure any train-only layers (e.g. dropout) are disabled.
        model.eval()

        for _ in range(num_launches):

            # Noise level for this launch, drawn uniformly from [0, 2).
            # Kept in its own name so the module-level `std` is not clobbered.
            launch_std = torch.rand((1,)).item() * 2
            loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = launch_std, random = True,
                total = eval_bsize * eval_batches, device = device
            ), batch_size = eval_bsize)

            total = 0
            with torch.no_grad():
                for (x, y) in loader:

                    # The model receives the sequence without its final
                    # element and returns one prediction per loop iteration.
                    preds = model(x[:, :-1], b)
                    preds = torch.stack(preds)
                    targs = torch.stack([y] * b)

                    # Squared error at the last position, averaged over the
                    # stacked predictions and then over the remaining axes.
                    loss = (targs[:,:,-1] - preds[:,:,-1]).square().mean(dim=0).mean()

                    total += loss.item() / loader.dataset.n_dims
            res.append(total / len(loader))
            pbar.set_description(f'Run \'{name}_{b}\', seed {seed}...')
            pbar.update(1)

    results[f'{name}_{b}'] = res

pbar.close()

In [None]:
# Write the collected validation losses to disk as indented JSON.
with open('./results/experiment 1/data/evaluation.json', 'w') as out_file:
    serialized = json.dumps(results, indent=4)
    out_file.write(serialized)