In [1]:
import torch
import torch.nn as nn
import math

import wandb
import random
import numpy as np
from tqdm.notebook import tqdm_notebook as tqdm

import matplotlib.pyplot as plt

from utils import LinregDataset, save, train, set_seed
from torch.utils.data import DataLoader

from dataclasses import dataclass

In [2]:
from transformer import SinusoidalPositionEmbedding, Layer

class LoopedTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.config = config
        self.emb = nn.Linear(config.n_dims, config.hidden_dim)
        
        if config.pe == 'sinus':
            self.pe = SinusoidalPositionEmbedding(config.hidden_dim, config.context)
        elif config.pe == 'learnable':
            self.pe = nn.Embedding(config.context, config.hidden_dim)
        
        self.layers = nn.Sequential(*[
            Layer(config) for i in range(config.num_layers)
        ])
        self.out = nn.Linear(config.hidden_dim, config.n_dims)
         
    def _get_mask(self, config, input_dim, device = 'cpu'):
        if config.mask_type == 'causal':
            mask = torch.tril(torch.ones(input_dim, input_dim))
            return mask.view(1, 1, input_dim, input_dim).to(device)
        if config.mask_type == 'none':
            return None
        
    def forward(self, x, b = 1):
        
        if len(x.shape) < 3:
            x = x.unsqueeze(0)
        
        # Довольно долго, но это рудимент от предыдущего кода.
        # Можно перенести маску в буфер в init.
        mask = self._get_mask(self.config, self.config.context, x.device)
        
        x = self.emb(x)
        if self.config.pe == 'sinus':
            x = x + self.pe(x)
        elif self.config.pe == 'learnable':
            x = x + self.pe(torch.arange(x.size(1), device=x.device))
        
        output = torch.zeros_like(x)
        
        pred_list = []
        for i in range(b):
            output = output + x
            for layer in self.layers:
                output, scores = layer(output, mask)
            prediction = self.out(output)[:, ::2, 0]
            pred_list.append(prediction)
            
        return pred_list

In [3]:
@dataclass
class Config:
    n_dims:       int = 4
    num_layers:   int = 1
    attn_heads:   int = 4
    hidden_dim:   int = 32
    mlp_hidden:   int = 128
    context:      int = 128
    mask_type:    str = 'causal'
    pe:           str = None
    activation:   nn.Module = nn.GELU
    
def get_config(name, n_dims):
    models = {
        'spe_cm': (Config(
            n_dims = n_dims, num_layers = 2,
            attn_heads = 4, hidden_dim = 64,
            mlp_hidden = 256, context = 127,
            mask_type = 'causal', pe = 'sinus'
        ), 5e-4, 64, 32),
        
        'spe': (Config(
            n_dims = n_dims, num_layers = 2,
            attn_heads = 4, hidden_dim = 64,
            mlp_hidden = 256, context = 127,
            mask_type = 'none', pe = 'sinus'
        ), 5e-4, 64, 32),
        
        'cm': (Config(
            n_dims = n_dims, num_layers = 2,
            attn_heads = 4, hidden_dim = 64,
            mlp_hidden = 256, context = 127,
            mask_type = 'causal', pe = None
        ), 5e-4, 64, 32),
        
        'none': (Config(
            n_dims = n_dims, num_layers = 2,
            attn_heads = 4, hidden_dim = 64,
            mlp_hidden = 256, context = 127,
            mask_type = 'none', pe = None
        ), 5e-4, 64, 32),
        
        'lpe': (Config(
            n_dims = n_dims, num_layers = 2,
            attn_heads = 4, hidden_dim = 64,
            mlp_hidden = 256, context = 127,
            mask_type = 'none', pe = 'learnable'
        ), 5e-4, 64, 32),
        
        'lpe_cm': (Config(
            n_dims = n_dims, num_layers = 2,
            attn_heads = 4, hidden_dim = 64,
            mlp_hidden = 256, context = 127,
            mask_type = 'causal', pe = 'learnable'
        ), 5e-4, 64, 32),
    }
    return models[name]

In [None]:
from itertools import product
from torch.optim import Adam

n_dims = 8
mean, std = 0, 1
device = 'cuda' if torch.cuda.is_available() else 'cpu'

seeds = [42, 451, 1984]
models = ['none', 'cm', 'spe', 'lpe', 'spe_cm', 'lpe_cm']
bs = [1, 5]

for name in models:
    for b in bs:
        for seed in seeds:
        
            config, lr, train_bsize, test_bsize = get_config(name, n_dims)
            n_points = (config.context + 1) // 2
            
            train_loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = std, random = True,
                device = device
            ), batch_size = train_bsize)
            test_loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = std, random = True,
                total = test_bsize * 10, device = device
            ), batch_size = test_bsize)

            set_seed(seed)
            
            model = LoopedTransformer(config).to(device)
            optimizer = Adam(model.parameters(), lr=lr)
            
            if 'spe' in name:
                pe = 'spe'
            elif 'lpe' in name:
                pe = 'lpe'
            else:
                pe = 'none'
            cm = 'cm' in name
                
            run_name = f'{name}_{b}_{seed}'
            run = wandb.init(
                project = 'Looped Transformer',
                name = run_name,
                config = {
                    'name': f'exp3_{name}_{b}',
                    'experiment': 3,
                    'model': name,
                    'b': b,
                    'train batch size': train_bsize,
                    'test batch size': test_bsize,
                    'lr': lr,
                    'seed': seed,
                    'N': n_points - 1,
                    'pe': pe,
                    'cm': cm
                }
            )
            
            loss_history, eval_history = train(
                model, train_loader, test_loader, optimizer, 
                b = b, steps = 7500, run = run, log_every = 75
            )
                
            run.finish()
            
            save(3, run_name, loss_history, eval_history)
            torch.save(model, './results/experiment 3/models/' + run_name + '.pt')
            
            torch.cuda.empty_cache()

In [14]:
import json
from itertools import product

n_dims = 8
mean, std = 0, 1

seeds = [42, 451, 1984]
models = ['none', 'cm', 'spe', 'lpe', 'spe_cm', 'lpe_cm']
bs = [1, 5]

device = 'cuda'

runs = []
for model, b in product(models, bs):
    runs.append((model, b))
    
test_seed = 4815163242 % 2**31
num_launches = 10
    
results = {}
pbar = tqdm(range(len(runs) * num_launches * 3))
for (name, b) in runs:
    config, lr, train_bsize, test_bsize = get_config(name, n_dims)
    n_points = (config.context + 1) // 2
    
    set_seed(test_seed)
    res = []
    for seed in seeds:
        model = torch.load(f'./results/experiment 3/models/{name}_{b}_{seed}.pt')
            
        for i in range(num_launches):
            
            std = torch.rand((1,)).item() * 2
            loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = std, random = True,
                total = 128 * 10, device = device
            ), batch_size = 128)
            
            total = 0
            with torch.no_grad():
                for (x, y) in loader:
                    
                    preds = model(x[:, :-1], b)
                    preds = torch.stack(preds)
                    targs = torch.stack([y] * b)
                    
                    # First by predictions, then by batches
                    loss = (targs[:,:,-1] - preds[:,:,-1]).square().mean(dim=0).mean()
                    
                    total += loss.item() / loader.dataset.n_dims
            res.append(total / len(loader))
            pbar.set_description(f'Run \'{name}_{b}\', seed {seed}...')
            pbar.update(1)
                
    results[f'{name}_{b}'] = res

  0%|          | 0/360 [00:00<?, ?it/s]

In [16]:
results

{'none_1': [0.0007617591239977628,
  0.16063800901174546,
  0.06604306474328041,
  1.0238654613494873,
  1.6386443734169007,
  0.6003895491361618,
  2.2944544196128844,
  3.3825491666793823,
  0.3290462464094162,
  1.2762592792510987,
  0.6742267310619354,
  3.141448736190796,
  0.9908895373344422,
  0.35089159607887266,
  0.2803729087114334,
  0.07493693009018898,
  0.09332494884729385,
  2.6483384370803833,
  1.0674304366111755,
  0.022384471818804742,
  2.101412773132324,
  3.437103033065796,
  1.5022632718086242,
  0.03170825019478798,
  3.2148009061813356,
  2.2701428174972533,
  0.03521009273827076,
  0.6284147381782532,
  3.0248604297637938,
  0.36026417911052705],
 'none_5': [0.015313539281487465,
  0.16762558221817017,
  0.06613628529012203,
  1.0227354764938354,
  1.6741648197174073,
  0.6005726456642151,
  2.2859718799591064,
  3.3874714851379393,
  0.3365749180316925,
  1.2562395572662353,
  0.6721282243728638,
  3.0953964471817015,
  0.9927003383636475,
  0.368436565995216

In [15]:
with open('./results/experiment 3/data/evaluation.json', 'w') as f:
    f.write(json.dumps(results, indent=4))