In [1]:
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
#transition matrix for visible states
# Layout: P[h, s, :] is the distribution over the NEXT visible symbol, given the
# current hidden state h and the current visible symbol s (get_next_symbols
# indexes it as P[data_hidden, data] and samples from the resulting rows).
# Earlier assignments to P[0]/P[1] were immediately overwritten, so only the
# values that actually take effect are kept here.
p_G, p_B = 0.8, 0.3
#p_G, p_B = 0, 1

P = torch.zeros(2,2,2)
# Hidden state 0: written out explicitly (p_G is not used by this matrix).
P[0,:,:] = torch.Tensor([[0.8,0.2],[0.3,0.7]])
# Hidden state 1: symmetric chain that switches symbol with probability p_B.
P[1,:,:] = torch.Tensor([[1-p_B,p_B],[p_B,1-p_B]])

In [3]:
#transition matrix for hidden error process
# Row h of P_hidden is the distribution over the next hidden state given the
# current hidden state h (get_next_symbols samples from P_hidden[data_hidden]).
#   b = P(hidden 0 -> 1),  g = P(hidden 1 -> 0)
# Presets that were tried; previously they were stacked as successive
# assignments, so only the last one ("real 4") ever took effect:
#   real   : b, g = 0., 0.7
#   real 2 : b, g = 0.1, 0.7
#   real 3 : b, g = 0.1, 0.5
#   other  : b, g = 0.2, 0.2
b, g = 0, 1 # real 4
P_hidden = torch.Tensor([[1-b,b],[g,1-g]])

In [4]:
def get_next_symbols(P, P_hidden, data, data_hidden):
    """Draw the next visible and next hidden symbol for each sequence.

    Args:
        P: visible transition tensor, indexed as P[hidden, visible] to obtain
            one row of sampling weights per sequence.
        P_hidden: hidden transition matrix, indexed as P_hidden[hidden].
        data: current visible symbols (cast to integer indices here). As called
            from get_batch this is one column of the batch, i.e. one entry per
            sequence.
        data_hidden: current hidden states, same layout as `data`.

    Returns:
        (s, s_hidden): flattened integer tensors holding the sampled next
        visible symbols and next hidden states.
    """
    hidden_idx = data_hidden.long()
    visible_idx = data.long()

    # The visible draw happens before the hidden draw; both depend only on the
    # CURRENT hidden state, and this order fixes how the RNG stream is consumed.
    visible_weights = P[hidden_idx, visible_idx]
    next_visible = torch.multinomial(visible_weights, num_samples=1).reshape(-1)

    hidden_weights = P_hidden[hidden_idx]
    next_hidden = torch.multinomial(hidden_weights, num_samples=1).reshape(-1)

    return next_visible, next_hidden

In [5]:
def get_batch(P, P_hidden, seq_length, batch_size, alpha=0.0):
    """Sample a batch of visible sequences together with next-symbol targets.

    Both the visible and the hidden chain start from a Bernoulli(alpha) draw
    and are then advanced `seq_length` times with get_next_symbols. Only the
    visible chain is returned; the hidden chain is used for generation only.

    Args:
        P: visible transition tensor passed through to get_next_symbols.
        P_hidden: hidden transition matrix passed through to get_next_symbols.
        seq_length: number of positions in the returned inputs/targets.
        batch_size: number of independent sequences to sample.
        alpha: probability that a chain starts in state 1. Defaults to 0.0
            (the "real 4" setting), so existing callers are unaffected.

    Returns:
        (x, y): integer tensors of shape (batch_size, seq_length), where y is
        the same sampled sequence as x shifted one step forward.
    """
    # Same start distribution for the visible and the hidden chain.
    start_prob = alpha * torch.ones((batch_size,))

    data = torch.zeros(batch_size, seq_length + 1)
    data_hidden = torch.zeros(batch_size, seq_length + 1)
    data[:, 0] = torch.bernoulli(start_prob)
    data_hidden[:, 0] = torch.bernoulli(start_prob)

    for i in range(seq_length):
        data[:, i + 1], data_hidden[:, i + 1] = get_next_symbols(
            P, P_hidden, data[:, i], data_hidden[:, i]
        )

    # Inputs are steps 0..seq_length-1; targets are steps 1..seq_length.
    x = data[:, :seq_length].long()
    y = data[:, 1:].long()
    return x, y

In [6]:
# Smoke test: sample two sequences of length 50 and display (inputs, targets);
# the targets are the inputs shifted one step forward.
get_batch(P, P_hidden, seq_length=50, batch_size=2)

(tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1,
          1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0,
          0, 1],
         [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          1, 1]]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1,
          1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0,
          1, 1],
         [0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
          1, 1]]))

In [7]:
from models import GPTBase

### Training

In [8]:
from types import SimpleNamespace

# Model hyperparameters: passed to GPTBase(config) in the training cell, which
# also reads config.sequence_length as the length of every sampled batch.
# NOTE(review): `dropout` is a bool here. If GPTBase forwards it to a dropout
# layer as a probability, True would act as p=1.0 -- confirm against models.py.
config = SimpleNamespace(**{
    "n_embd": 8,  # n_embd = 32,
    "sequence_length": 1024,
    "bias": False,
    "dropout": True,
})

In [9]:
# Optimisation settings read by the training cell.
train_args = SimpleNamespace(**{
    "iterations": 8000,       # loop bound and OneCycleLR total_steps
    "beta1": 0.9,             # AdamW betas[0]
    "beta2": 0.95,            # AdamW betas[1]
    "lr": 2e-3,               # AdamW lr and OneCycleLR max_lr
    "batch_size": 16,         # sequences per training / validation batch
    "weight_decay": 1e-3,     # AdamW weight_decay
    "scheduler": 'cos',       # OneCycleLR anneal_strategy
    "warmup_percent": 0.02,   # OneCycleLR pct_start
})

In [10]:
@torch.no_grad()
def eval(model, P, P_hidden, sequence_length, batch_size, max_num_batches=20):
    """Estimate validation accuracy, loss and perplexity on freshly sampled batches.

    NOTE: this name shadows Python's builtin eval(); it is kept because the
    training cell calls the function under this name.

    Args:
        model: network in eval mode; called as model(x, y) and expected to
            return a pair whose first element holds the logits.
        P, P_hidden: transition tensors forwarded to get_batch.
        sequence_length: length of each sampled sequence.
        batch_size: sequences per validation batch.
        max_num_batches: number of validation batches to average over (>= 1).

    Returns:
        (val_acc, val_loss, val_perplexity) as Python floats, where the loss is
        the mean binary cross-entropy over batches and the perplexity is
        exp(val_loss).

    Raises:
        AssertionError: if the model is still in training mode.
        ValueError: if max_num_batches is smaller than 1.
    """
    assert not model.training
    if max_num_batches < 1:
        # Otherwise torch.stack below fails on an empty list with an opaque error.
        raise ValueError("max_num_batches must be at least 1")

    loss_list_val, acc_list = [], []

    for _ in range(max_num_batches):
        x, y = get_batch(P, P_hidden, sequence_length, batch_size)
        # NOTE(review): the training loop calls model(x) without targets while
        # this call also passes y; the second return value is discarded here.
        # Confirm GPTBase produces the same logits for both call forms.
        outputs, _ = model(x, y)
        outputs = outputs.squeeze(-1)
        val_loss = F.binary_cross_entropy_with_logits(outputs.view(-1), y.float().view(-1))
        loss_list_val.append(val_loss)
        # A positive logit is read as predicting symbol 1.
        acc_list.append(((outputs > 0) == y.to(bool)).float().mean())

    val_acc = torch.stack(acc_list).mean().item()
    val_loss = torch.stack(loss_list_val).mean().item()
    # Use the exact exponential rather than a truncated constant for e.
    val_perplexity = math.exp(val_loss)

    return val_acc, val_loss, val_perplexity

In [None]:
# best_val_loss and text_table are not updated in this cell; they are kept
# defined in case other cells read them.
itr, best_val_loss, text_table = 0, float('inf'), None

# Evaluation cadence (in training iterations), validation batches per
# evaluation, and where the running statistics are checkpointed.
eval_freq = getattr(train_args, 'eval_freq', 1)
eval_batches = 10
stats_path = './stats_real_6.pt'

stats = {'train_loss': [], 'val_loss': [], 'norm': [], 'val_acc': []}

model = GPTBase(config)
model.train()

opt = optim.AdamW(model.parameters(), lr=train_args.lr, betas=(train_args.beta1, train_args.beta2),
                  weight_decay=train_args.weight_decay)

# One-cycle schedule: starts at max_lr/div_factor, warms up over the first
# pct_start fraction of the steps, then anneals to (max_lr/div_factor)/final_div_factor.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=opt, max_lr=train_args.lr, total_steps=train_args.iterations,
                                               pct_start=train_args.warmup_percent, anneal_strategy=train_args.scheduler,
                                               cycle_momentum=False, div_factor=1e2, final_div_factor=.05)
#scheduler = None

t0 = time.time()
print('Starting pumping')
while itr < train_args.iterations:
    x, y = get_batch(P, P_hidden, config.sequence_length, train_args.batch_size)
    outputs, contrib_norm = model(x)
    outputs = outputs.squeeze(-1)

    loss = F.binary_cross_entropy_with_logits(outputs.view(-1), y.float().view(-1))
    loss.backward()

    opt.step()
    if scheduler is not None:
        scheduler.step()
    opt.zero_grad(set_to_none=True)
    itr += 1

    train_loss = loss.item()
    # If the model returns the norm as a tensor, detach it so the stored
    # history does not keep autograd graph references alive across iterations.
    if torch.is_tensor(contrib_norm):
        contrib_norm = contrib_norm.detach().cpu()
    stats['norm'].append(contrib_norm)
    stats['train_loss'].append(train_loss)

    if itr % eval_freq == 0 or itr == train_args.iterations: # from here it's only evaluation code, all the training is above
        # t0 is reset after each evaluation, so dt covers only the training
        # steps since the previous evaluation.
        t1 = time.time()
        dt = t1 - t0

        model.eval()
        current_lr = scheduler.get_last_lr()[0] if scheduler is not None else train_args.lr
        val_acc, val_loss, val_perplexity = eval(model, P, P_hidden, config.sequence_length, train_args.batch_size,
                                                 max_num_batches=eval_batches)

        print_string = f"{itr} [train] loss={train_loss:.5f} [val] loss={val_loss:.5f}, pp={val_perplexity:.2f}, acc={val_acc:.3f}"
        print_string += f" [time per itr] {dt*1000/eval_freq:.2f}ms"
        if scheduler is not None:
            print_string += f" [lr] {current_lr:.5f}"
        print(print_string)

        stats['val_loss'].append(val_loss)
        stats['val_acc'].append(val_acc)
        torch.save(stats, stats_path)

        model.train()
        t0 = time.time()

Starting pumping
1 [train] loss=0.69283 [val] loss=0.69282, pp=2.00, acc=0.571716 [time per itr] 298.61ms [lr] 0.00002
2 [train] loss=0.69281 [val] loss=0.69282, pp=2.00, acc=0.576080 [time per itr] 227.01ms [lr] 0.00002
3 [train] loss=0.69281 [val] loss=0.69282, pp=2.00, acc=0.576221 [time per itr] 248.45ms [lr] 0.00002
4 [train] loss=0.69282 [val] loss=0.69282, pp=2.00, acc=0.577307 [time per itr] 220.38ms [lr] 0.00002
5 [train] loss=0.69279 [val] loss=0.69282, pp=2.00, acc=0.580170 [time per itr] 217.22ms [lr] 0.00002
6 [train] loss=0.69282 [val] loss=0.69281, pp=2.00, acc=0.586243 [time per itr] 222.26ms [lr] 0.00003
7 [train] loss=0.69282 [val] loss=0.69282, pp=2.00, acc=0.583759 [time per itr] 252.52ms [lr] 0.00003
8 [train] loss=0.69283 [val] loss=0.69281, pp=2.00, acc=0.588171 [time per itr] 246.55ms [lr] 0.00003
9 [train] loss=0.69280 [val] loss=0.69281, pp=2.00, acc=0.593042 [time per itr] 220.07ms [lr] 0.00004
10 [train] loss=0.69281 [val] loss=0.69279, pp=2.00, acc=0.603754