In [None]:
import random
from dataclasses import dataclass
from typing import Type

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook as tqdm

from utils import LinregDataset, save

In [None]:
def set_seed(seed):
    """Seed Python's `random`, PyTorch (CPU and current CUDA device) and NumPy RNGs.

    Args:
        seed: integer seed passed unchanged to every generator.
    """
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed, np.random.seed)
    for seeder in seeders:
        seeder(seed)

In [None]:
@dataclass
class Config:
    """Architecture hyper-parameters for ModifiedLoopedTransformer.

    Every field has a default; `get_config` overrides most of them per variant.
    """
    n_dims:       int = 4              # input feature size (in_features of the embedding Linear)
    num_layers:   int = 1              # number of layers in the shared, looped stack
    attn_heads:   int = 4              # presumably consumed by transformer.Layer — confirm
    hidden_dim:   int = 32             # model width
    mlp_hidden:   int = 128            # presumably the Layer MLP width — confirm
    context:      int = 128            # size of the learned positional-embedding table
    mask_type:    str = 'causal'       # only 'causal' is implemented by the model
    operation:    str = 'add'          # input re-injection scheme between loop iterations
    # The activation *class* (not an instance), hence Type[nn.Module];
    # presumably instantiated inside transformer.Layer — confirm.
    activation:   Type[nn.Module] = nn.GELU
    
def get_config(name, n_dims):
    """Return ``(config, lr, train_batch_size, test_batch_size)`` for a model variant.

    All variants share the same architecture and optimisation settings
    (2 layers, 4 heads, width 64, MLP width 256, context 127, lr 5e-4,
    batch sizes 128 / 64); they differ only in ``Config.operation``, which is
    set to ``name``.

    Args:
        name: one of 'none', 'add', 'mul', 'linear_add', 'linear_cat', 'each_layer'.
        n_dims: input dimensionality, forwarded to ``Config.n_dims``.

    Raises:
        KeyError: if ``name`` is not one of the variants above.

    NOTE(review): 'mul' is accepted here, but ModifiedLoopedTransformer in this
    notebook does not implement a 'mul' operation.
    """
    known_variants = ('none', 'add', 'mul', 'linear_add', 'linear_cat', 'each_layer')
    if name not in known_variants:
        raise KeyError(name)

    config = Config(
        n_dims = n_dims, num_layers = 2,
        attn_heads = 4, hidden_dim = 64,
        mlp_hidden = 256, context = 127,
        operation = name,
    )
    lr, train_bsize, test_bsize = 5e-4, 128, 64
    return config, lr, train_bsize, test_bsize

In [None]:
from transformer import Layer

class ModifiedLoopedTransformer(nn.Module):
    """Looped transformer whose shared layer stack is applied ``b`` times.

    The input is embedded (linear projection plus learned positional
    embedding), then the same ``config.num_layers`` layers are run once per
    loop iteration. ``config.operation`` selects how the embedded input ``x``
    is re-injected between iterations:

    * ``'none'``       -- never; each loop consumes the previous loop's output.
    * ``'add'``        -- ``x`` is added to the running state (initially zeros)
      before every loop.
    * ``'each_layer'`` -- ``x`` is added to the running state (initially zeros)
      before every individual layer.
    * ``'linear_add'`` -- after every loop the state becomes ``linear(state + x)``.
    * ``'linear_cat'`` -- after every loop the state becomes
      ``linear(cat([state, x], dim=-1))``.

    ``forward`` returns a list with one prediction tensor per loop iteration.
    """

    def __init__(self, config):
        super().__init__()

        functions = {
            'none': self.none,
            'add': self.add,
            'linear_add': self.linear_add,
            'linear_cat': self.linear_cat,
            'each_layer': self.each_layer,
        }
        # Fail fast, before any submodule is built, with a readable message;
        # otherwise an unsupported name (e.g. 'mul') surfaces as a bare KeyError.
        if config.operation not in functions:
            raise NotImplementedError(
                f'Operation \'{config.operation}\' is not implemented; '
                f'expected one of {sorted(functions)}.'
            )

        # Submodules are created in the same order as before so that seeded
        # initialisation produces identical weights.
        self.config = config
        self.emb = nn.Linear(config.n_dims, config.hidden_dim)
        self.pe = nn.Embedding(config.context, config.hidden_dim)
        # One shared stack, reused on every loop iteration.
        self.layers = nn.Sequential(*[
            Layer(config) for _ in range(config.num_layers)
        ])
        self.function = functions[config.operation]

        if config.operation == 'linear_add':
            self.linear = nn.Linear(config.hidden_dim, config.hidden_dim)
        elif config.operation == 'linear_cat':
            self.linear = nn.Linear(config.hidden_dim * 2, config.hidden_dim)

        self.out = nn.Linear(config.hidden_dim, config.n_dims)

    def _get_mask(self, config, input_dim, device):
        """Return a float ``(1, 1, input_dim, input_dim)`` mask for ``config.mask_type``.

        Only ``'causal'`` is supported: ones on and below the diagonal, zeros above.

        Raises:
            NotImplementedError: for any other mask type.
        """
        if config.mask_type == 'causal':
            mask = torch.tril(torch.ones(input_dim, input_dim, device=device))
            return mask.view(1, 1, input_dim, input_dim)
        else:
            raise NotImplementedError(f'Mask type \'{config.mask_type}\' is not implemented.')

    def _apply_layers(self, hidden, mask):
        """Run ``hidden`` once through every shared layer, discarding attention scores."""
        for layer in self.layers:
            hidden, _scores = layer(hidden, mask)
        return hidden

    def _predict(self, hidden):
        """Project with ``self.out`` and keep output channel 0 at positions 0, 2, 4, ...

        NOTE(review): the even positions are presumably the x-tokens of an
        interleaved (x, y) sequence — confirm against LinregDataset.
        """
        return self.out(hidden)[:, ::2, 0]

    def add(self, x, b = 1, mask = None):
        """Add ``x`` to the running state (initially zeros) before each of ``b`` loops."""
        output = torch.zeros_like(x)
        pred_list = []
        for _ in range(b):
            output = self._apply_layers(output + x, mask)
            pred_list.append(self._predict(output))
        return pred_list

    def linear_add(self, x, b = 1, mask = None):
        """After each of ``b`` loops, set the state to ``linear(state + x)``."""
        output = x
        pred_list = []
        for _ in range(b):
            output = self._apply_layers(output, mask)
            output = self.linear(output + x)
            pred_list.append(self._predict(output))
        return pred_list

    def linear_cat(self, x, b = 1, mask = None):
        """After each of ``b`` loops, set the state to ``linear(cat([state, x], -1))``."""
        output = x
        pred_list = []
        for _ in range(b):
            output = self._apply_layers(output, mask)
            output = self.linear(torch.cat([output, x], dim=-1))
            pred_list.append(self._predict(output))
        return pred_list

    def none(self, x, b = 1, mask = None):
        """Plain looping: each of ``b`` loops consumes the previous loop's output."""
        output = x
        pred_list = []
        for _ in range(b):
            output = self._apply_layers(output, mask)
            pred_list.append(self._predict(output))
        return pred_list

    def each_layer(self, x, b = 1, mask = None):
        """Add ``x`` to the running state (initially zeros) before every single layer."""
        output = torch.zeros_like(x)
        pred_list = []
        for _ in range(b):
            for layer in self.layers:
                output, _scores = layer(output + x, mask)
            pred_list.append(self._predict(output))
        return pred_list

    def forward(self, x, b = 1):
        """Embed ``x`` and run ``b`` loop iterations of the configured operation.

        Args:
            x: input tensor; one with fewer than 3 dims gets a leading batch
               dimension. Dim 1 indexes sequence positions (it sizes the
               positional embedding); assumed layout ``(batch, seq, n_dims)``.
            b: number of loop iterations.

        Returns:
            list of ``b`` prediction tensors, one per iteration.
        """
        if x.dim() < 3:
            x = x.unsqueeze(0)

        mask = self._get_mask(self.config, x.size(1), x.device)

        x = self.emb(x)
        x = x + self.pe(torch.arange(x.size(1), device=x.device))

        return self.function(x, b, mask)

In [None]:
from itertools import product
from utils import train, evaluate
from torch.optim import Adam

import os

n_dims = 12
mean, std = 0, 2
device = 'cuda' if torch.cuda.is_available() else 'cpu'

seeds = [42, 451, 1984, 31415, 271828]
# Variants trained in this pass. The evaluation cell below also loads
# checkpoints for 'add', 'linear_add' and 'linear_cat'; append them here if
# those have not been trained yet.
models = ['none', 'each_layer']
bs = [10]

# Loop counts at which every trained model is re-evaluated (extrapolation in b).
extrapolation_bs = [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
models_dir = './results/experiment 4/models/'
os.makedirs(models_dir, exist_ok=True)

for name in models:
    for b in bs:
        for seed in seeds:
            # Seed before anything stochastic, including dataset construction
            # (LinregDataset may sample eagerly; not visible here), so that a
            # (name, b, seed) run is fully reproducible.
            set_seed(seed)

            config, lr, train_bsize, test_bsize = get_config(name, n_dims)
            n_points = (config.context + 1) // 2

            train_loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = std, random = True,
                device = device
            ), batch_size = train_bsize)
            test_loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = std, random = True,
                total = test_bsize * 10, device = device
            ), batch_size = test_bsize)

            model = ModifiedLoopedTransformer(config).to(device)
            optimizer = Adam(model.parameters(), lr=lr)

            run_name = f'{name}_{b}_{seed}'
            run = wandb.init(
                project = 'Looped Transformer',
                name = run_name,
                config = {
                    'name': f'exp4_{name}_{b}',
                    'experiment': 4,
                    'model': name,
                    'b': b,
                    'train batch size': train_bsize,
                    'test batch size': test_bsize,
                    'lr': lr,
                    'seed': seed,
                    'N': n_points - 1,
                }
            )

            loss_history, eval_history = train(
                model, train_loader, test_loader, optimizer,
                b = b, steps = 7500, run = run, log_every = 75
            )

            extrapolation = []
            for bi in extrapolation_bs:
                result = evaluate(test_loader, model, bi)
                extrapolation.append(result)
                print(f'b = {bi} --- loss = {result:.5f}')
                run.log({'b': bi, 'MSE': result})

            run.finish()

            save(4, run_name, loss_history, eval_history, extrapolation)
            torch.save(model, f'{models_dir}{run_name}.pt')

            torch.cuda.empty_cache()

In [None]:
import json
from itertools import product

seeds = [42, 451, 1984, 31415, 271828]
models = ['none', 'each_layer', 'add', 'linear_add', 'linear_cat']
bs = [10]

test_seed = 4815163242 % 2**16
num_launches = 15

# Every (model name, b) pair, in the same order as listing them by hand.
runs = list(product(models, bs))

# NOTE: `n_dims`, `mean` and `device` come from the training cell above,
# which must have been executed in this kernel first.
results = {}
pbar = tqdm(total=len(runs) * len(seeds) * num_launches)
for name, b in runs:
    config = get_config(name, n_dims)[0]
    n_points = (config.context + 1) // 2

    # Reseed per variant so every variant sees the same sequence of random
    # noise levels (and, if LinregDataset uses the seeded RNGs, the same data).
    set_seed(test_seed)
    res = []
    for seed in seeds:
        # Checkpoints are whole pickled modules (torch.save(model, ...)), so
        # weights_only=False is required on PyTorch >= 2.6 (the keyword itself
        # needs >= 1.13). Only load trusted, locally produced files this way.
        # map_location lets CUDA-trained checkpoints load on a CPU-only machine.
        model = torch.load(
            f'./results/experiment 4/models/{name}_{b}_{seed}.pt',
            map_location=device, weights_only=False,
        )
        # Disable any train-only behaviour (e.g. dropout) inside the layers.
        model.eval()

        for _ in range(num_launches):
            # Noise std drawn uniformly from [0, 2) for this launch.
            test_std = torch.rand((1,)).item() * 2
            loader = DataLoader(LinregDataset(
                n_dims = n_dims, n_points = n_points,
                mean = mean, std = test_std, random = True,
                total = 128 * 25, device = device
            ), batch_size = 128)

            total = 0
            with torch.no_grad():
                for (x, y) in loader:
                    # One row per loop iteration: shape (b, batch, positions).
                    preds = torch.stack(model(x[:, :-1], b))
                    targs = torch.stack([y] * b)

                    # Squared error at the last position, averaged first over
                    # loop iterations (dim 0), then over the batch.
                    loss = (targs[:,:,-1] - preds[:,:,-1]).square().mean(dim=0).mean()

                    total += loss.item() / loader.dataset.n_dims
            res.append(total / len(loader))
            pbar.set_description(f'Run \'{name}_{b}\', seed {seed}...')
            pbar.update(1)

    results[f'{name}_{b}'] = res
pbar.close()

In [None]:
# Persist the per-variant evaluation losses gathered in the previous cell.
with open('./results/experiment 4/data/evaluation.json', 'w') as fp:
    json.dump(results, fp, indent=4)