In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import math

import wandb
from tqdm.notebook import tqdm_notebook as tqdm

from dataclasses import dataclass

from src.linreg import LinregDataset
from src.utils import save, set_seed
from src.training import train, evaluate
from src.transformer import Transformer

In [10]:
@dataclass
class Config:
    n_dims:       int = 4
    num_layers:   int = 1
    attn_heads:   int | None = None
    head_dim:     int | None = None
    hidden_dim:   int | None = None
    dhat:         int | None = None
    mlp_hidden:   int = 128
    context:      int = 129
    method:       str = 'softmax'
    activation:   nn.Module = nn.GELU
    elu_alpha:    float = 1.0
    
    def _calculate_dhat(self):
        if self.method in ['based', 'rebased']:
            low = math.floor(math.sqrt(self.head_dim))
            top = math.ceil(math.sqrt(self.head_dim))
            lowdiff = math.fabs(self.hidden_dim - low ** 2 * self.attn_heads)
            topdiff = math.fabs(self.hidden_dim - top ** 2 * self.attn_heads)
            if lowdiff < topdiff:
                dhat = low * self.attn_heads
            else:
                dhat = top * self.attn_heads
        else:
            dhat = self.hidden_dim
        return dhat
    
    def __post_init__(self):
        if self.hidden_dim is not None and self.attn_heads is not None:
            self.head_dim = self.hidden_dim // self.attn_heads
        elif self.attn_heads is not None and self.head_dim is not None:
            self.hidden_dim = self.attn_heads * self.head_dim
        else:
            raise ValueError('You should provide either (hidden_dim, attn_heads) or (attn_heads, head_dim).')
        
        if self.dhat is None:
            self.dhat = self._calculate_dhat()
    
def get_config(name, n_dims):
    models = {
        'softmax': (Config(
            n_dims = n_dims, num_layers = 1,
            attn_heads = 4, hidden_dim = 96,
            mlp_hidden = 256, context = 257,
            method = 'softmax',
        ), 1e-3, 64, 32),
        'based': (Config(
            n_dims = n_dims, num_layers = 1,
            attn_heads = 4, hidden_dim = 96,
            mlp_hidden = 256, context = 257,
            method = 'based'
        ), 1e-3, 64, 32),
        'rebased': (Config(
            n_dims = n_dims, num_layers = 1,
            attn_heads = 4, hidden_dim = 96,
            mlp_hidden = 256, context = 257,
            method = 'rebased'
        ), 1e-3, 64, 32),
        'learnable': (Config(
            n_dims = n_dims, num_layers = 1,
            attn_heads = 4, hidden_dim = 96,
            mlp_hidden = 256, context = 257,
            method = 'learnable'
        ), 1e-3, 64, 32),
        'elu': (Config(
            n_dims = n_dims, num_layers = 1,
            attn_heads = 4, hidden_dim = 96,
            mlp_hidden = 256, context = 257,
            method = 'learnable'
        ), 1e-3, 64, 32),
        'squared': (Config(
            n_dims = n_dims, num_layers = 1,
            attn_heads = 4, hidden_dim = 96,
            mlp_hidden = 256, context = 257,
            method = 'learnable'
        ), 1e-3, 64, 32),
    }
    return models[name]

## Тренировка

In [11]:
from torch.optim import AdamW
from itertools import product

N = 4
n_dims = 8
mean, std = 0, 1

seeds = [42, 451, 1984][:1]
models = ['softmax', 'based', 'rebased', 'learnable', 'elu', 'squared']
bs = [1, 5]

runs = []
for model, bs, seed in product(models, bs, seeds):
    runs.append((model, bs, seed))

device = 'cuda' if torch.cuda.is_available() else 'cpu'

extrp_bs = [1, 5, 10, 15]

for (name, b, seed) in runs:

    config, lr, train_bsize, test_bsize = get_config(name, n_dims)
    if config.context != N * 2 + 1:
        config.context = N * 2 + 1

    train_loader = DataLoader(LinregDataset(
        n_dims = n_dims, n_points = N + 1,
        xmean = mean, xstd = std, device = device
    ), batch_size = 1)
    test_loader = DataLoader(LinregDataset(
        n_dims = n_dims, n_points = N + 1,
        xmean = mean, xstd = std, device = device,
        total = test_bsize * 5
    ), batch_size = 1)

    set_seed(seed)

    model = Transformer(config).to(device)
    optimizer = AdamW(model.parameters(), lr=lr)

    run_name = f'{name}_{b}_{seed}'
    print(run_name)
    # run = wandb.init(
    #     project = 'Linear Transformer',
    #     name = run_name,
    #     config = {
    #         'name': f'{name}_{b}',
    #         'model': name,
    #         'b': b,
    #         'train batch size': train_bsize,
    #         'test batch size': test_bsize,
    #         'lr': lr,
    #         'seed': seed,
    #         'N': N,
    #     }
    # )

    loss_history, eval_history = train(
        model, train_loader, test_loader, optimizer, 
        steps = 75, b = b, run = None, log_every = 1
    )

    extrapolation = []
    for bi in extrp_bs:
        result = evaluate(test_loader, model, bi)
        extrapolation.append(result)
        print(f'b = {bi} --- loss = {result:.5f}')
        # run.log({'b': bi, 'MSE': result})
    
    # run.finish()

    save(
        name = run_name, 
        model = model,
        loss = loss_history, 
        eval = eval_history, 
        extr = extrapolation,
        path = './results/'
    )

    torch.cuda.empty_cache()

softmax_1_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 1.10198
b = 5 --- loss = 1.60882
b = 10 --- loss = 3.98357
b = 15 --- loss = 6.25741
softmax_5_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 1.08826
b = 5 --- loss = 1.13193
b = 10 --- loss = 1.89257
b = 15 --- loss = 2.42076
based_1_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 1.05669
b = 5 --- loss = 1.73280
b = 10 --- loss = 2.97090
b = 15 --- loss = 5.78900
based_5_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 1.04130
b = 5 --- loss = 1.11826
b = 10 --- loss = 1.27768
b = 15 --- loss = 1.99170
rebased_1_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 0.91004
b = 5 --- loss = 1.49189
b = 10 --- loss = 3.64890
b = 15 --- loss = 7.27225
rebased_5_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 0.85395
b = 5 --- loss = 1.02606
b = 10 --- loss = 1.46918
b = 15 --- loss = 2.28624
learnable_1_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 0.93074
b = 5 --- loss = 1.29131
b = 10 --- loss = 1.87351
b = 15 --- loss = 3.07824
learnable_5_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 0.93117
b = 5 --- loss = 1.24895
b = 10 --- loss = 1.78055
b = 15 --- loss = 2.87314
elu_1_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 0.93074
b = 5 --- loss = 1.29131
b = 10 --- loss = 1.87351
b = 15 --- loss = 3.07824
elu_5_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 0.93117
b = 5 --- loss = 1.24895
b = 10 --- loss = 1.78055
b = 15 --- loss = 2.87314
squared_1_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 0.93074
b = 5 --- loss = 1.29131
b = 10 --- loss = 1.87351
b = 15 --- loss = 3.07824
squared_5_42


  0%|          | 0/75 [00:00<?, ?it/s]

b = 1 --- loss = 0.93117
b = 5 --- loss = 1.24895
b = 10 --- loss = 1.78055
b = 15 --- loss = 2.87314


## Валидация

In [None]:
import json
from itertools import product

n_dims = 8
mean, std = 0, 1

seeds = [42, 451, 1984]
models = ['medium', 'small', 'tiny']
bs = [15, 10, 5, 1]

device = 'cuda'

runs = []
for model, b in product(models, bs):
    runs.append((model, b))
    
test_seed = 4815163242 % 2**31
num_launches = 10
    
results = {}
pbar = tqdm(range(len(runs) * num_launches * len(seeds)))
for (name, b) in runs:
    config, lr, train_bsize, test_bsize = get_config(name, n_dims)
    n_points = (config.context + 1) // 2
    
    res = []
    
    set_seed(test_seed)
    for seed in seeds:
        model = torch.load(f'./results/experiment 1/models/{name}_{b}_{seed}.pt')
            
        for i in range(num_launches):
            
            std = torch.rand((1,)).item() * 2
            loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = std, random = True,
                total = 128 * 10, device = device
            ), batch_size = 128)
            
            total = 0
            with torch.no_grad():
                for (x, y) in loader:
                    
                    preds = model(x[:, :-1], b)
                    preds = torch.stack(preds)
                    targs = torch.stack([y] * b)
                    
                    loss = (targs[:,:,-1] - preds[:,:,-1]).square().mean(dim=0).mean()
                    
                    total += loss.item() / loader.dataset.n_dims
            res.append(total / len(loader))
            pbar.set_description(f'Run \'{name}_{b}\', seed {seed}...')
            pbar.update(1)
                
    results[f'{name}_{b}'] = res