In [272]:
from transformers import AutoConfig, AutoModelForCausalLM, GPT2Tokenizer
from transformers import GPT2Tokenizer

vocab_size = 10
config = AutoConfig.from_pretrained("gpt2", vocab_size=vocab_size, n_ctx=11, n_head=3, n_layer=1)
model = AutoModelForCausalLM.from_config(config)

In [273]:
def model_size(model):
    return sum(t.numel() for t in model.parameters())

print(f'Model size: {model_size(model)/1000**2:.1f}M parameters')

Model size: 7.9M parameters


In [274]:
model_ckpt = 'sortingLLM'

In [263]:
model.save_pretrained("models/" + model_ckpt, push_to_hub=True)

pytorch_model.bin:  14%|█▎        | 4.29M/31.5M [00:04<00:24, 1.13MB/s]

KeyboardInterrupt: 

pytorch_model.bin:  14%|█▍        | 4.37M/31.5M [00:20<00:24, 1.13MB/s]

In [275]:
class NumberTokenizer:
  def __init__(self, numbers_qty=10):
    self.numbers_qty = numbers_qty
    self.pad_token = '-1'
    self.encoder = {str(v):i for i,v in enumerate(range(-1, numbers_qty-2))}
    self.decoder = {i:str(v) for i,v in enumerate(range(-1, numbers_qty-2))}
    self.pad_token_id = self.encoder[self.pad_token]

  def decode(self, token_ids):
    return ' '.join(self.decoder[t] for t in token_ids)

  def __call__(self, text):
    return [self.encoder[t] for t in text.split()]

In [276]:
tokenizer = NumberTokenizer(vocab_size)
tokenizer("0 1 2 3 4")

[1, 2, 3, 4, 5]

In [277]:
import pickle
import torch
from torch.utils.data import Dataset

class SortDataset(Dataset):
    """ 
    Dataset for the Sort problem. E.g. for problem length 6:
    Input:  0 0 2 1 0 1 -> Output:  0 0 0 1 1 2
    Which will feed into the transformer concatenated as:
    input:   0 0 2 1 0 1 0 0 0 1 1
    output:  I I I I I 0 0 0 1 1 2
    where I is "ignore", as the transformer is reading the input sequence
    """

    def __init__(self, split, length=6, num_digits=3):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_digits = num_digits
    
    def __len__(self):
        return 10000 # ...
    
    def get_vocab_size(self):
        return self.num_digits
    
    def get_block_size(self):
        # the length of the sequence that will feed into transformer, 
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
        return self.length * 2 - 1

    def __getitem__(self, idx):
        
        # use rejection sampling to generate an input example from the desired split
        while True:
            # generate some random integers
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time let's try to boost the number of examples that 
            # have a large number of repeats, as this is what the model seems to struggle
            # with later in training, and they are kind of rate
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    # too many unqiue digits, re-sample
                    continue
            # figure out if this generated example is train or test based on its hash
            h = hash(pickle.dumps(inp.tolist()))
            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test
            if inp_split == self.split:
                break # ok
        
        # solve the task: i.e. sort
        sol = torch.sort(inp)[0]

        # concatenate the problem specification and the solution
        cat = torch.cat((inp, sol), dim=0)

        # the inputs to the transformer will be the offset sequence
        x = cat[:-1].clone()
        y = cat[1:].clone()
        # we only want to predict at output locations, mask out the loss at the input locations
        y[:self.length-1] = int(tokenizer.pad_token)
        x, y = ' '.join(map(str, x.tolist())), ' '.join(map(str, y.tolist()))
        tokenized_input = tokenizer(x)
        tokenized_output = tokenizer(y)
        return torch.tensor(tokenized_input), torch.tensor(tokenized_output)

In [278]:
train_dataset = SortDataset('train')
test_dataset = SortDataset('test')

In [279]:
print(test_dataset[0])

(tensor([1, 2, 1, 2, 3, 3, 1, 1, 2, 2, 3]), tensor([0, 0, 0, 0, 0, 1, 1, 2, 2, 3, 3]))


In [281]:
import torch
import torch.nn.functional as F
# from datasets import load_dataset
from accelerate import Accelerator

accelerator = Accelerator()

batch_size = 50
input_size = 11
input_to_sort_size = 6
num_epochs = 10

optimizer = torch.optim.Adam(model.parameters())
dataset = train_dataset
data = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

# Dont take in account the list of input to sort
attention_mask = torch.ones((batch_size, input_size), dtype=torch.long)
attention_mask[:, :input_to_sort_size-1] = 0

model, optimizer, data = accelerator.prepare(model, optimizer, data)
attention_mask = attention_mask.to(accelerator.device)

model.train()
for epoch in range(num_epochs):
  for source, targets in data:
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(source).logits.flatten(end_dim=1), targets.flatten(end_dim=1), ignore_index=tokenizer.pad_token_id)
    accelerator.backward(loss)
    optimizer.step()

In [286]:
def generate_solution(input, solution_length=6):
  model.eval()
  input = torch.tensor(tokenizer(input))
  input = input.to(accelerator.device)
  solution = []
  for i in range(solution_length):
    output = model(input)
    predicted = output.logits[-1].argmax()
    input = torch.cat((input, predicted.unsqueeze(0)), dim=0)
    solution.append(predicted.cpu().item())
  return tokenizer.decode(solution)

In [287]:
test_sample = '0 2 1 1 0'
generate_solution(test_sample)

'0 0 0 1 1 2'