In [32]:
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from retnet.modeling_retnet import RetNetForCausalLM
from retnet.configuration_retnet import load_config_from_json
from transformers import AutoTokenizer

In [2]:
# Load the training corpus into `text`.
# NOTE(review): despite the .txt extension the file is read in binary mode and
# unpickled. pickle.load can execute arbitrary code, so only use this on a file
# you created yourself. `text` is later passed to tokenizer.encode, so it is
# presumably a single str — TODO confirm.
with open("./data/shakespeare.txt", "rb") as f:
    text = pickle.load(f)

----

In [8]:

# Read the RetNet hyper-parameters from the JSON config and build a causal LM
# from them. No checkpoint is loaded in this cell, so the weights are presumably
# freshly initialised — TODO confirm against RetNetForCausalLM.
config = load_config_from_json('configs/retnet-base/config.json')
model = RetNetForCausalLM(config)

In [9]:
# Reuse the pretrained GPT-2 tokenizer for this model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Raise the length limit; presumably so the whole corpus can be encoded in one
# call further down without hitting the tokenizer's max-length check — TODO confirm.
tokenizer.model_max_length = 1000000
# Use the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Short prompt, encoded as PyTorch tensors, used by the generation cells below.
context_inputs = tokenizer("I have a request for you, my sire.", return_tensors='pt')

In [10]:
# Generate 20 new tokens with the model's own `custom_generate` method.
# parallel_compute_prompt=True presumably encodes the prompt with the parallel
# forward pass before decoding — TODO confirm in modeling_retnet.
generated = model.custom_generate(context_inputs['input_ids'], parallel_compute_prompt=True, max_new_tokens=20)

In [11]:
# Same prompt through the `generate` method instead.
# This overwrites `generated` from the custom_generate cell.
generated = model.generate(**context_inputs, max_new_tokens=20)

In [12]:
# Inspect the generated token ids (prompt followed by the new tokens).
generated

tensor([[   40,   423,   257,  2581,   329,   345,    11,   616,   264,   557,
            13, 37243, 15841, 46194, 40791, 45726, 42678, 24962, 14478, 47854,
         35026, 19192, 20486, 35408, 34015, 33585,  7729, 43324,  7525, 31710,
         26604]])

In [13]:
# Turn the token ids back into text to eyeball the model's output.
tokenizer.batch_decode(generated)

['I have a request for you, my sire. Chern Mak Roku.? Bahamas CapitalismurstPass Thro videot Mosul tajected moms mitigation instructions nervously primarilyBoston¶']

---

In [14]:
# Tokenize the whole corpus in one call; the trailing [0] takes the first row of
# the returned tensor, presumably giving a 1-D tensor of token ids — TODO confirm.
data = tokenizer.encode(text, return_tensors='pt', return_attention_mask=False)[0]

In [15]:
# Hold out the final 10% of the token stream for validation; the first 90% is
# used for training. The split is positional (no shuffling).
n = int(len(data) * 0.9)
data_train, data_val = data[:n], data[n:]

In [16]:
# Number of sequences drawn per training step.
BATCH_SIZE = 1
# Number of tokens in each training sequence.
BLOCK_SIZE = 200

In [31]:
# Data loading function to get input (x) and target (y) batches
def get_batch(split_type, batch_size, data_train, data_val, block_size, generator=None):
    """Sample a random batch of (input, target) windows from a token stream.

    Args:
        split_type: 'train' selects ``data_train``; any other value selects
            ``data_val``.
        batch_size: number of windows to draw.
        data_train: 1-D tensor of training token ids.
        data_val: 1-D tensor of validation token ids.
        block_size: length of each window.
        generator: optional ``torch.Generator`` used to draw the window starts,
            so batches can be reproduced. ``None`` uses the global RNG.

    Returns:
        ``(x, y)``, both of shape ``(batch_size, block_size)``; ``y`` is ``x``
        shifted one position to the right (the next-token targets).

    Raises:
        ValueError: if the selected split has ``block_size`` tokens or fewer,
            because the target window needs one extra token.
    """
    source = data_train if split_type == 'train' else data_val

    # Exclusive upper bound for a window start: the target window ends at
    # start + block_size + 1, which must not run past the end of the data.
    max_start = len(source) - block_size
    if max_start < 1:
        raise ValueError(
            f"split {split_type!r} has {len(source)} tokens; "
            f"need more than block_size={block_size}"
        )

    starts = torch.randint(max_start, (batch_size,), generator=generator).tolist()
    x = torch.stack([source[s:s + block_size] for s in starts])
    y = torch.stack([source[s + 1:s + block_size + 1] for s in starts])
    return x, y

In [35]:
# Draw one training batch to sanity-check get_batch and the model's forward pass.
x_train_test, y_train_test = get_batch("train", BATCH_SIZE, data_train, data_val, BLOCK_SIZE)

In [52]:
# Forward pass on the sample batch; keep only the logits from the model output.
output_test = model(x_train_test)["logits"]

In [47]:
# Forward pass with the targets passed as the second positional argument.
# NOTE(review): the output recorded for this cell shows loss=None, so the second
# positional parameter does not appear to be used as labels. Check the forward()
# signature in modeling_retnet; labels probably need to be passed by keyword.
model(x_train_test, y_train_test)

RetNetCausalLMOutputWithPast(loss=None, logits=tensor([[[ 0.5378,  0.4868, -1.5332,  ...,  0.3570, -0.1741, -0.4460],
         [ 0.5664,  0.4663, -1.5862,  ...,  0.4134, -0.1759, -0.2898],
         [ 1.0328, -0.3031,  2.1131,  ..., -0.6309, -0.1988, -0.5183],
         ...,
         [-0.6413, -0.2282, -0.8098,  ...,  0.0620,  0.0032,  1.5376],
         [-0.2352, -0.5176, -2.5079,  ..., -0.3055,  1.5851,  1.1990],
         [-0.3700,  1.0973,  0.2156,  ..., -1.1637, -2.1866,  0.0649]]],
       grad_fn=<UnsafeViewBackward0>), past_key_values=(None, None, None, None, None, None), hidden_states=None, retentions=None, attentions=None)

In [44]:
# NOTE(review): `b` is not defined in any visible cell, so this only works from
# leftover kernel state and will fail on Restart & Run All. Looks like debugging
# residue — consider deleting this cell.
b

'past_key_values'

In [19]:
# NOTE(review): not referenced by any visible cell; presumably meant as the number
# of batches to average when estimating the loss — confirm or remove.
loss_iter = 16

In [20]:

# Rebuild the model from the same config before training. This replaces the
# `model` object used in the generation and forward-pass cells above, so anything
# holding the old model's parameters (e.g. an optimizer) must be created after this.
config = load_config_from_json('configs/retnet-base/config.json')
model = RetNetForCausalLM(config)

In [33]:
# Next-token cross-entropy over the vocabulary.
# No ignore_index is set: get_batch returns full, unpadded windows, and with the
# GPT-2 tokenizer id 0 is a real token ("!"), not padding (the pad token was set
# to EOS above), so masking id 0 would silently drop real targets from the loss.
criterion = nn.CrossEntropyLoss()
# Adam with lr=1e-4, betas=(0.9, 0.98), eps=1e-9 over the current model's parameters.
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [34]:
# Number of iterations of the training loop below. Each iteration draws a single
# random batch, so this is a step count rather than full passes over the data.
NB_EPOCHS = 10

In [61]:
# Inspect the logits from the sample forward pass.
output_test

tensor([[[ 0.5397,  0.4509, -1.5759,  ...,  0.3701, -0.1832, -0.3811],
         [ 0.5523,  0.4654, -1.5932,  ...,  0.3871, -0.1746, -0.3381],
         [ 1.0607, -0.2741,  2.1151,  ..., -0.6813, -0.2821, -0.4561],
         ...,
         [-0.7678, -0.1694, -0.9424,  ...,  0.1731,  0.0495,  1.6391],
         [-0.1331, -0.4586, -2.6028,  ..., -0.1229,  1.5787,  1.1152],
         [-0.2104,  1.1515,  0.1176,  ..., -1.2531, -2.2347,  0.1857]]],
       grad_fn=<UnsafeViewBackward0>)

In [57]:
# Size of the logits once flattened to a single dimension.
# NOTE(review): flattening with view(-1) merges the vocabulary axis into the
# others, so this shape is not usable as input to a classification loss.
output_test.contiguous().view(-1).shape

torch.Size([10051400])

In [None]:
# Put the model in training mode.
model.train()

# Each "epoch" here is one optimisation step on a single randomly drawn batch.
for epoch in range(NB_EPOCHS):
    optimizer.zero_grad()
    X, Y = get_batch("train", BATCH_SIZE, data_train, data_val, BLOCK_SIZE)
    output = model(X)['logits']
    # CrossEntropyLoss expects logits of shape (N, num_classes) and targets of
    # shape (N,), so only the leading axes are merged and the last (vocabulary)
    # axis is kept. Y is already X shifted by one token in get_batch, so it is
    # used in full rather than sliced again.
    loss = criterion(output.reshape(-1, output.size(-1)), Y.reshape(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")