In [8]:
import torch

from retnet import GPTR
import run_model

2023-11-25 09:59:06.735850: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-25 09:59:06.735876: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-25 09:59:06.736643: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [87]:
class GPTRConfig(object):
    def __init__(self, vocab_size,
                 context_window,
                 embedding_dim=64,
                 nlayers=6,
                 nheads=4,
                 nhidden=None,
                 nclasses=None):
        self.vocab_size = vocab_size
        self.context_window = context_window
        self.embedding_dim = embedding_dim
        self.nlayers = nlayers
        self.nheads = nheads
        self.nhidden = nhidden
        if self.nhidden is None:
            self.nhidden = 4*self.embedding_dim
        self.nclasses = nclasses

    def from_model_config(model_config):
        vocab_size = model_config.vocab_size
        context_window = model_config.max_position_embeddings
        embedding_dim = model_config.hidden_size
        nlayers = model_config.num_hidden_layers
        nheads = model_config.num_attention_heads
        nhidden = model_config.intermediate_size
        nclasses = model_config.num_labels
        return GPTRConfig(vocab_size,
                          context_window,
                          embedding_dim,
                          nlayers,
                          nheads,
                          nhidden,
                          nclasses)

class GPTRAutoregressive(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = GPTR(config.vocab_size,
                          config.context_window,
                          config.embedding_dim,
                          config.nlayers,
                          config.nheads,
                          config.nhidden)
    
    def forward(self, x):
        return self.model(x)
    

class GPTRClassifier(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.nclasses is not None
        self.model = GPTR(config.vocab_size,
                          config.context_window,
                          config.embedding_dim,
                          config.nlayers,
                          config.nheads,
                          config.nhidden)
        self.classifier = torch.nn.Linear(config.embedding_dim,
                                          config.nclasses)

    def get_which_index_from_mask(self, mask):
        """
        From a mask of shape [..., D], get from where we should call our classifier
        """
        # mask : [..., D]
        # return : [...]
        out = torch.max(1-mask, dim=-1)
        index_seq = (out.indices*out.values + mask.shape[-1]*(1-out.values))-1
        index_batch = torch.arange(index_seq.shape[0]).to(index_seq)
        return index_seq, index_batch
    
    def forward(self, input_ids, attention_mask):
        x = input_ids['input_ids']
        mask = attention_mask['attention_mask']
        index_seq, index_batch = self.get_which_index_from_mask(mask)
        print(index_seq)
        print(index_batch)
        x = self.model.decode(x)[index_batch, index_seq, :]
        x = self.classifier(x)
        return x

In [88]:
x 

tensor([[13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,  1, 15,  2, 19,  2, 11,  2,
         16,  2, 15,  2, 19,  2, 19,  2,  4,  2, 18,  2, 17,  2,  9,  2,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,

In [89]:
config = GPTRConfig(vocab_size=100,
                    context_window=30,
                    nlayers=4,
                    nclasses=10)
model = GPTRClassifier(config)

In [90]:
task = run_model.TASKS['listops-tiny']
config, model_config = task.config_getter()
dataset = task.dataset_fn(config)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=run_model.transformers_collator)

In [91]:
data = next(iter(dataloader))
x = data[0]['input_ids']
model.forward(x).shape

torch.Size([2, 10])

In [96]:
torch.argmax(1-data[0]['attention_mask'], axis=-1)

tensor([34, 70])

In [92]:
model.forward(data[0])

tensor([33, 69])
tensor([0, 1])


tensor([[ 1.2323,  0.5281, -1.0114, -0.1157,  0.1694,  0.4346, -0.4739, -1.2117,
          1.0016,  0.2805],
        [ 0.9401,  0.1419, -0.5465, -0.5943, -0.4775,  0.5544, -0.7533, -1.1351,
          0.7218,  0.1537]], grad_fn=<AddmmBackward0>)

In [99]:
mask = torch.tril(torch.ones(5, 5, dtype=torch.long))
out = torch.max(1-mask, dim=-1)
(out.indices*out.values + mask.shape[-1]*(1-out.values))-1

tensor([0, 1, 2, 3, 4])

In [100]:
model.get_which_index_from_mask(mask)

(tensor([0, 1, 2, 3, 4]), tensor([0, 1, 2, 3, 4]))

In [24]:
x = data[0]['input_ids']
model.forward(x).shape

torch.Size([1, 10])

In [25]:
model(x)

tensor([[ 0.2973,  0.0945,  0.1017, -0.0455,  0.0804, -0.6638, -0.4700,  0.9078,
          0.1955, -0.6841]], grad_fn=<AddmmBackward0>)