In [None]:
import torch
import torch.nn as nn

import wandb
import random
import numpy as np
from tqdm.notebook import tqdm_notebook as tqdm

import matplotlib.pyplot as plt

from utils import LinregDataset, save, set_seed
from torch.utils.data import DataLoader

from dataclasses import dataclass

## Создание модели

In [None]:
from transformer import Layer

class LoopedTransformer(nn.Module):
    """Weight-tied ("looped") transformer used for in-context regression.

    The same stack of ``config.num_layers`` layers is applied ``b`` times. Before
    every iteration the embedded input is added to the hidden state (input
    injection). After the first iteration the sequence may be shortened to ``n``
    in-context pairs plus the final query token (see ``cut``), and all later
    iterations run on the shortened sequence.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.emb = nn.Linear(config.n_dims, config.hidden_dim)
        # Learned absolute positions; sequences longer than config.context cannot be embedded.
        self.pe = nn.Embedding(config.context, config.hidden_dim)
        self.layers = nn.Sequential(*[
            Layer(config) for _ in range(config.num_layers)
        ])
        self.out = nn.Linear(config.hidden_dim, config.n_dims)

    def _get_mask(self, config, input_dim, device):
        """Return a (1, 1, input_dim, input_dim) attention mask on ``device``.

        Raises:
            NotImplementedError: for any ``config.mask_type`` other than 'causal'.
        """
        if config.mask_type == 'causal':
            # Lower-triangular ones: position i may attend to positions <= i.
            mask = torch.tril(torch.ones(input_dim, input_dim, device=device))
            return mask.view(1, 1, input_dim, input_dim)
        raise NotImplementedError(f'Mask type \'{config.mask_type}\' is not implemented.')

    def cut(self, x, out, n):
        """Keep ``n`` in-context pairs plus the final (query) token.

        Assumes ``x`` and ``out`` are (batch, seq, dim) with the tokens laid out as
        x_0, y_0, x_1, y_1, ..., x_{k-1}, y_{k-1}, x_query, i.e. k = seq // 2
        complete pairs followed by one query token (TODO confirm against
        LinregDataset). The number of pairs is taken from the actual sequence
        length, so prompts shorter than ``config.context`` are handled too.

        ``config.type`` selects the kept pairs: 'default' keeps the first ``n``,
        'random' keeps ``n`` pairs drawn uniformly at random (in random order).
        When ``n >= k`` nothing is cut.

        Returns:
            (x, out, indices): the (possibly) shortened tensors, and the point
            indices of the kept pairs followed by the query index ``k``. These
            line up one-to-one with predictions read at even positions.

        Raises:
            NotImplementedError: if a cut is needed and ``config.type`` is unknown.
        """

        def change(t, idx):
            # Re-interleave the selected x tokens with their following y tokens;
            # new_zeros keeps t's dtype and device.
            temp = t.new_zeros((t.size(0), n * 2, t.size(2)))
            temp[:, ::2] = t[:, idx]
            temp[:, 1::2] = t[:, idx + 1]
            # The final query token is always kept.
            return torch.cat([temp, t[:, [-1]]], dim=1)

        # Token positions of the in-context x's: 0, 2, ..., 2k - 2.
        indices = torch.arange(0, x.size(1) - 1, 2)
        num_pairs = indices.size(0)

        if n < num_pairs:
            if self.config.type == 'default':
                indices = indices[:n]
            elif self.config.type == 'random':
                indices = indices[torch.randperm(num_pairs)[:n]]
            else:
                raise NotImplementedError(f'Cut type \'{self.config.type}\' is not implemented.')

            x = change(x, indices)
            out = change(out, indices)

        # Token positions -> point indices, then append the query point's index k.
        indices = torch.cat([indices // 2, indices.new_tensor([num_pairs])])
        return x, out, indices

    def forward(self, x, b = 1, n = 1):
        """Run ``b`` looped iterations and collect one prediction tensor per iteration.

        Args:
            x: input tokens of shape (batch, seq, n_dims); a 2-D (seq, n_dims)
                input is treated as a batch of one.
            b: number of loop iterations (the first one runs on the full sequence).
            n: number of in-context pairs kept after the first iteration (see ``cut``).

        Returns:
            (pred_list, indices): ``pred_list`` holds ``b`` tensors of shape
            (batch, len(indices)), each being output channel 0 at every even
            position; ``indices`` are the matching point indices from ``cut``.
        """
        if len(x.shape) < 3:
            x = x.unsqueeze(0)

        x = self.emb(x)
        x = x + self.pe(torch.arange(x.size(1), device=x.device))
        output = torch.zeros_like(x)

        pred_list = []

        # Iteration 1: full sequence; the hidden state starts at zero and receives the input.
        output = output + x
        mask = self._get_mask(self.config, x.size(1), x.device)
        for layer in self.layers:
            output, _ = layer(output, mask)

        # Cut the sequence and keep the indices so the loss can later be computed on them.
        x, output, indices = self.cut(x, output, n)
        prediction = self.out(output)[:, ::2, 0]
        pred_list.append(prediction)

        # The (possibly) shortened sequence needs a mask of its own length.
        mask = self._get_mask(self.config, x.size(1), x.device)

        # Iterations 2..b: re-inject the input and reuse the same weights.
        for _ in range(1, b):
            output = output + x
            for layer in self.layers:
                output, _ = layer(output, mask)
            prediction = self.out(output)[:, ::2, 0]
            pred_list.append(prediction)

        return pred_list, indices

In [None]:
@dataclass
class Config:
    """Hyper-parameters of LoopedTransformer; the same object is also passed to transformer.Layer."""
    n_dims:       int = 4          # input/output feature size (emb: n_dims -> hidden_dim, out: hidden_dim -> n_dims)
    num_layers:   int = 1          # number of Layer blocks applied once per loop iteration
    attn_heads:   int = 4          # presumably the number of attention heads inside transformer.Layer (not visible here)
    hidden_dim:   int = 32         # model width (embedding, positional embedding and head input size)
    mlp_hidden:   int = 128        # presumably the MLP width inside transformer.Layer (not visible here)
    context:      int = 128        # size of the learned positional-embedding table, i.e. the maximum sequence length
    mask_type:    str = 'causal'   # only 'causal' is implemented by LoopedTransformer._get_mask
    type:         str = 'default'  # pair-selection strategy in LoopedTransformer.cut: 'default' (first n) or 'random'
    activation:   nn.Module = nn.GELU  # NOTE(review): holds a class, not an instance; presumably instantiated inside Layer
    
def get_config(name, n_dims):
    """Look up a named training preset.

    Both presets share the same architecture and optimisation settings; they
    differ only in ``Config.type``, which selects how ``LoopedTransformer.cut``
    picks the in-context pairs ('default' keeps the first n, 'random' a random n).

    Args:
        name: preset name, 'default' or 'random'.
        n_dims: feature dimension of the regression inputs.

    Returns:
        (config, lr, train_batch_size, test_batch_size)

    Raises:
        KeyError: if ``name`` is not a known preset.
    """
    presets = {}
    for preset in ('default', 'random'):
        architecture = Config(
            n_dims=n_dims,
            num_layers=2,
            attn_heads=4,
            hidden_dim=64,
            mlp_hidden=256,
            context=129,
            type=preset,
        )
        presets[preset] = (architecture, 5e-4, 64, 64)
    return presets[name]

## Слегка переделанный цикл обучения и валидации

In [None]:
def evaluate(loader, model, b, n):
    """Average query-point MSE of ``model`` over ``loader``.

    Each batch is run through ``model`` for ``b`` loop iterations with ``n``
    in-context pairs. Only the last position (the query point) is scored. Its
    squared error is averaged over the ``b`` per-iteration predictions and over
    the batch, then divided by ``loader.dataset.n_dims``.

    The model is put in eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards, so this is safe to call from inside
    a training loop.

    Args:
        loader: iterable of ``(x, y)`` batches supporting ``len()`` and exposing
            ``.dataset.n_dims``.
        model: module where ``model(x, b, n)`` returns ``(list_of_predictions, indices)``.
        b: number of loop iterations.
        n: number of in-context pairs kept by the model.

    Returns:
        float: mean normalised loss per batch.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if len(loader) == 0:
        raise ValueError('evaluate() needs a non-empty loader.')

    was_training = model.training
    model.eval()
    try:
        total = 0
        with torch.no_grad():
            for x, y in loader:
                # The last token of x is dropped before the forward pass (presumably the query's target).
                preds, idx = model(x[:, :-1], b, n)
                preds = torch.stack(preds)
                # The same targets are compared against every iteration's prediction.
                targs = torch.stack([y[:, idx]] * b)

                # Query point only: average over iterations first, then over the batch.
                loss = (targs[:, :, -1] - preds[:, :, -1]).square().mean(dim=0).mean()

                total += loss.item() / loader.dataset.n_dims
    finally:
        model.train(was_training)
    return total / len(loader)

def train(
    model,
    train_loader,
    test_loader,
    optimizer,
    b = None,
    steps = None,
    run = None,
    log_every = 1,
    n = None
):
    """Train a looped transformer for ``steps`` optimizer steps.

    Every step unrolls the model for ``b`` iterations and minimises the squared
    error between the targets at the kept points (plus the query) and the
    predictions of every iteration, divided by ``train_loader.dataset.n_dims``.
    The model is evaluated on ``test_loader`` once before training and then
    about five times during it.

    Args:
        model: module exposing ``.config.context``; ``model(x, b, n)`` returns
            ``(list_of_predictions, indices)``.
        train_loader: DataLoader whose dataset has a writable ``length`` and a
            readable ``n_dims``.
        test_loader: loader passed to ``evaluate``.
        optimizer: optimizer over ``model``'s parameters.
        b: number of loop iterations per forward pass (required).
        steps: number of optimizer steps (required).
        run: optional wandb run used for logging.
        log_every: log the train loss every ``log_every`` steps (and at step 1).
        n: in-context pairs kept by the model; defaults to ``(context - 1) // 2``.

    Returns:
        (loss_history, eval_history): per-step train losses and evaluation losses.

    Raises:
        ValueError: if ``b`` or ``steps`` is not given.
    """
    if b is None or steps is None:
        raise ValueError('train() requires both `b` and `steps`.')

    # Size one pass over the loader to exactly `steps` batches
    # (assumes the dataset's `length` attribute controls its size — TODO confirm in utils).
    train_loader.dataset.length = steps * train_loader.batch_size

    if n is None:
        n = (model.config.context - 1) // 2

    # About five evaluations per run; max(1, ...) avoids a zero modulus when steps < 5.
    evaluate_every = max(1, steps // 5)

    loss_history = []
    eval_history = []

    val_loss = evaluate(test_loader, model, b, n)
    eval_history.append(val_loss)

    if run is not None:
        run.log({'Eval Loss': val_loss}, commit = False, step = 1)

    model.train()
    pbar = tqdm(total=steps)
    step = 0
    try:
        for x, y in train_loader:
            optimizer.zero_grad()

            preds, idx = model(x[:, :-1], b, n)
            preds = torch.stack(preds)
            # The same targets are compared against every iteration's prediction.
            targs = torch.stack([y[:, idx]] * b)

            # Mean over positions, then over iterations, then over the batch.
            loss = (targs - preds).square().mean(dim=2).mean(dim=0).mean()
            loss = loss / train_loader.dataset.n_dims

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_history.append(loss_value)

            pbar.set_description(f'Train loss: {loss_value:.5f}')
            pbar.update(1)

            step += 1

            if step % evaluate_every == 0:
                val_loss = evaluate(test_loader, model, b, n)
                eval_history.append(val_loss)
                if run is not None:
                    run.log({'Eval Loss': val_loss}, commit = False, step = step)
            if run is not None and (step % log_every == 0 or step == 1):
                run.log({'Train Loss': loss_value}, step = step)
    finally:
        pbar.close()

    return loss_history, eval_history

## Тренировка

In [None]:
from torch.optim import Adam
from itertools import product

n_dims = 12
mean, std = 0, 2

seeds = [42, 451, 1984]
models = ['default', 'random'][1:]   # subset of presets (re)trained in this session
bs = [10]
ns = [64, 48, 32, 24][:1]            # subset of n values (re)trained in this session

# One independent run per (preset name, b, n, seed) combination.
runs = list(product(models, bs, ns, seeds))

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# After training, each model is also evaluated at these loop counts (including b beyond the trained one).
extrp_bs = [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]

for (name, b, n, seed) in runs:

    config, lr, train_bsize, test_bsize = get_config(name, n_dims)
    n_points = 64 + 1

    # Seed before the datasets are built so that anything they sample at
    # construction time is reproducible as well, not only the model init.
    set_seed(seed)

    train_loader = DataLoader(LinregDataset(
        n_dims = n_dims, n_points = n_points,
        mean = mean, std = std, random = True,
        device = device
    ), batch_size = train_bsize)
    test_loader = DataLoader(LinregDataset(
        n_dims = n_dims, n_points = n_points,
        mean = mean, std = std, random = True,
        total = test_bsize * 25, device = device
    ), batch_size = test_bsize)

    model = LoopedTransformer(config).to(device)
    optimizer = Adam(model.parameters(), lr=lr)

    run_name = f'{name}_{b}_{n}_{seed}'
    run = wandb.init(
        project = 'Looped Transformer',
        name = run_name,
        config = {
            'name': f'exp5_{name}_{b}_{n}',
            'experiment': 5,
            'model': name,
            'b': b,
            'n': n,
            'train batch size': train_bsize,
            'test batch size': test_bsize,
            'lr': lr,
            'seed': seed,
            'N': n_points - 1,
        }
    )

    try:
        loss_history, eval_history = train(
            model, train_loader, test_loader, optimizer,
            b = b, steps = 7500, run = run, log_every = 75, n = n
        )

        extrapolation = []
        for bi in extrp_bs:
            result = evaluate(test_loader, model, bi, n)
            extrapolation.append(result)
            print(f'b = {bi} --- loss = {result:.5f}')
            run.log({'b': bi, 'MSE': result})
    finally:
        # Close the wandb run even if training fails, so the next wandb.init starts cleanly.
        run.finish()

    save(
        name = run_name,
        model = model,
        loss = loss_history,
        eval = eval_history,
        extr = extrapolation,
        path = './results/experiment 5'
        )

    torch.cuda.empty_cache()

## Валидация

In [None]:
import json
from itertools import product

# n_dims, mean and device are defined in the training cell above.
seeds = [42, 451, 1984]
models = ['default', 'random']
bs = [10]
ns = [64, 48, 32, 24]

# One evaluation group per (preset name, b, n); each group pools all seeds.
runs = list(product(models, bs, ns))

test_seed = 4815163242 % 2**16
num_launches = 10

results = {}
pbar = tqdm(total=len(runs) * len(seeds) * num_launches)
for (name, b, n) in runs:
    n_points = 64 + 1

    res = []

    # Re-seed per group so every configuration starts its test sampling from the same RNG state.
    set_seed(test_seed)
    for seed in seeds:
        # The checkpoint is a whole pickled nn.Module (it is called directly), which
        # torch>=2.6 refuses to load under its weights_only=True default.
        # weights_only=False executes the pickle: only use it on our own, trusted result files.
        # map_location lets GPU-trained checkpoints load on a CPU-only machine.
        model = torch.load(
            f'./results/experiment 5/models/{name}_{b}_{n}_{seed}.pt',
            map_location=device, weights_only=False,
        )
        model.eval()

        for _ in range(num_launches):
            # Per-launch dataset `std` in [0, 2); a separate name so the global `std` is not overwritten.
            test_std = torch.rand((1,)).item() * 2
            loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = test_std, random = True,
                total = 128 * 10, device = device
            ), batch_size = 128)

            # Same query-point loss as the evaluation used during training.
            res.append(evaluate(loader, model, b, n))
            pbar.set_description(f'Run \'{name}_{b}_{n}\', seed {seed}...')
            pbar.update(1)

    results[f'{name}_{b}_{n}'] = res
pbar.close()

In [None]:
import json
from pathlib import Path

# Persist the per-configuration losses from the validation cell: {'<name>_<b>_<n>': [loss, ...]}.
evaluation_path = Path('./results/experiment 5/data/evaluation.json')
# Don't rely on some earlier step having created the directory.
evaluation_path.parent.mkdir(parents=True, exist_ok=True)
with evaluation_path.open('w') as f:
    json.dump(results, f, indent=4)