In [1]:
# Training the brain, using text alone.

In [2]:
from visual_transformer import *

In [3]:
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from torch.utils.data import Dataset

In [4]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # -- penguins.farm version
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # -- penguins.army version

vocab_size = 10000


In [18]:
device

device(type='cuda', index=0)

In [5]:
brain = DefaultAgentBrain().to(device)

In [6]:
a = torch.randint(0, 2, (4,))
a

tensor([1, 0, 1, 1])

In [7]:
vals = torch.LongTensor([[0, 1, 2, 3, 4], [0, 1, 3, 2, 5]])
vals

tensor([[0, 1, 2, 3, 4],
        [0, 1, 3, 2, 5]])

In [8]:
vals[a]

tensor([[0, 1, 3, 2, 5],
        [0, 1, 2, 3, 4],
        [0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5]])

In [9]:
vals = vals.to(device)
vals

tensor([[0, 1, 2, 3, 4],
        [0, 1, 3, 2, 5]], device='cuda:0')

In [10]:
v = vals.to('cpu')
v

tensor([[0, 1, 2, 3, 4],
        [0, 1, 3, 2, 5]])

In [11]:
vals

tensor([[0, 1, 2, 3, 4],
        [0, 1, 3, 2, 5]], device='cuda:0')

In [12]:
def get_batch(batchsize, device = device):
    v = vals.to(device)
    inds = torch.randint(0, 2, (batchsize,), device=device)
    return v[inds]

get_batch(10)

tensor([[0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5],
        [0, 1, 3, 2, 5]], device='cuda:0')

In [13]:
criterion = nn.CrossEntropyLoss(ignore_index=0)

def get_loss(res, inputs):
    return torch.sum(criterion(res[:, :, :-1], inputs[:, 1:]))

optimizer = optim.Adam(brain.parameters(), lr=0.0001*256/256, eps=1e-9)#, #betas=(0.9, 0.98), eps=1e-9)

In [19]:
from tqdm import tqdm

criterion = nn.CrossEntropyLoss(ignore_index=0)

batches = 30000
batchsize = 256

train_loss = 0

for batch in range(batches):
    brain.train()
    inputs = get_batch(batchsize)
    src_attention_mask, src_key_padding_mask = brain.get_masks(inputs, use_masks=True)
    text_encoding = brain.get_text_encoding(inputs, src_attention_mask, src_key_padding_mask)
    img_context = torch.randn((batchsize, 256, 768), device=inputs.device) # easier for pretraining
    res = brain.get_text_decoding(text_encoding, src_attention_mask, src_key_padding_mask, img_context, return_full=True)
    loss = get_loss(res, inputs)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    train_loss += loss.item()
    if batch % 100 == 99:
        avg_loss = train_loss / 100
        train_loss = 0
        print(f"Average Training Loss at batch {batch + 1}: {avg_loss}")
        torch.save(brain.text_enc.state_dict(), 'brain_checkpoints/text_encoder_weights_DEBUG_POSITION.pth')
        torch.save(brain.text_dec.state_dict(), 'brain_checkpoints/text_decoder_weights_DEBUG_POSITION.pth')


Average Training Loss at batch 100: 0.17730373039841651
Average Training Loss at batch 200: 0.1762099702656269
Average Training Loss at batch 300: 0.17520385324954987
Average Training Loss at batch 400: 0.17463001534342765
Average Training Loss at batch 500: 0.1745309340953827
Average Training Loss at batch 600: 0.1748275390267372


KeyboardInterrupt: 

In [21]:
brain.generate()

(tensor([[0, 1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 4, 4, 4, 4,
          4, 4, 4, 4, 5, 4, 4, 5]], device='cuda:0'),
 tensor([[-5.8365e-04, -7.1559e-01, -6.7329e-04, -1.9073e-04, -3.6812e-04,
          -1.3475e-03, -1.6108e-03, -5.1584e-03, -4.7136e-02, -1.4511e-02,
          -9.2657e-02, -9.3384e-03, -4.6642e-02, -9.1591e-03, -2.2130e-02,
          -1.7346e-02, -2.5759e-01, -1.1941e-01, -3.5244e+00, -1.2056e-01,
          -1.7401e-02, -2.7031e-02, -3.2639e-02, -9.7146e-02, -2.6613e-01,
          -2.6722e-02, -1.4102e-01, -1.0009e+00, -2.4345e-02, -1.9376e-01,
          -7.1617e-01]], device='cuda:0', grad_fn=<CopySlices>),
 tensor([[0.0098, 0.6986, 0.0106, 0.0034, 0.0064, 0.0189, 0.0222, 0.0610, 0.3074,
          0.1642, 0.4694, 0.0974, 0.3087, 0.0885, 0.1943, 0.2058, 0.8567, 0.6369,
          0.3949, 0.6102, 0.1684, 0.2205, 0.3073, 0.5889, 1.0236, 0.2639, 0.7292,
          1.4098, 0.2774, 1.1117, 1.3544]], device='cuda:0',
        grad_fn=<CopySlices>))

In [32]:
seed1 = vals[:1, :-1]
seed2 = vals[1:, :-1]

In [33]:
seed1

tensor([[0, 1, 2, 3]], device='cuda:0')

In [34]:
seed2

tensor([[0, 1, 3, 2]], device='cuda:0')

In [36]:
brain.extend(seed1, is_terminated = torch.tensor([False]).to(device))

(tensor([[0, 1, 2, 3, 4]], device='cuda:0'),
 tensor([-0.0004], device='cuda:0', grad_fn=<SqueezeBackward1>),
 tensor([0.0062], device='cuda:0', grad_fn=<NegBackward0>),
 tensor([False], device='cuda:0'))

In [37]:
# Ok, so it did learn to memorize these two things.
# It knows about position; it correctly uses positional encoding. Good.
# Done here. Successful test