In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from utils_accelerate import *

tokenizer = T5Tokenizer.from_pretrained('t5-small')
# input = "predict tail: barack obama | position_held |"
# input = "translate English to German: How are you doing?"

# model = T5ForConditionalGeneration.from_pretrained('models/codex_m_accelerate_1gpu.pt')
checkpoint_location = 'models/codex_m_accelerate_1gpu/115000.pt'
model = load_accelerator_model(checkpoint_location, only_model=True)

In [51]:
input = "predict tail: united states of america | member of |"
input_ids = tokenizer(input, return_tensors="pt").input_ids  # Batch size 1
model.cpu()


T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dro

In [52]:
# outputs = model.sample(input_ids)
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    BeamSearchScorer,
)

In [109]:
import torch
encoder_input_str = ["predict tail: united states of america | member of |", "predict tail: united states of america | member of |"]
encoder_input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids
num_beams = 10
input_ids = torch.ones((len(encoder_input_str) * num_beams, 1), device=model.device, dtype=torch.long)
input_ids = input_ids * model.config.decoder_start_token_id
model_kwargs = {
    "encoder_outputs": model.get_encoder()(encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True)
}
beam_scorer = BeamSearchScorer(
    batch_size=2,
    max_length=model.config.max_length,
    num_beams=num_beams,
    device=model.device,
    num_beam_hyps_to_keep=num_beams
)
logits_processor = LogitsProcessorList([
    MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
])

In [110]:
input_ids

tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]])

In [111]:
outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)

In [113]:
outputs

tensor([[    0,  1038,  2137,    21, 20532,    11,   606,     1,     0,     0,
             0,     0,     0,     0,     0],
        [    0,  5102,    21,     8, 30693,    13,  5368,  7749,     1,     0,
             0,     0,     0,     0,     0],
        [    0,     3, 15974,    18,  5379,  3286,  1456,  8494,     1,     0,
             0,     0,     0,     0,     0],
        [    0,  1249, 12088,  1729,  3614,  3193,     1,     0,     0,     0,
             0,     0,     0,     0,     0],
        [    0,  5102,    21,  1456,  8494,    11,   606,     1,     0,     0,
             0,     0,     0,     0,     0],
        [    0,  2665,    63,    30,   539, 22902,     1,     0,     0,     0,
             0,     0,     0,     0,     0],
        [    0,  1038,     3,    17, 18329,  7021,     1,     0,     0,     0,
             0,     0,     0,     0,     0],
        [    0,  1038,     3, 20844,   827,  3193,     1,     0,     0,     0,
             0,     0,     0,     0,     0],
        

In [114]:
print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True))

Generated: ['international bank for reconstruction and development', 'organisation for the prohibition of chemical weapons', 'asia-pacific economic cooperation', 'multilateral investment guarantee agency', 'organisation for economic cooperation and development', 'treaty on open skies', 'international telecommunication union', 'international atomic energy agency', 'organisation for economic cooperation de la francophonie', 'treaty on open skies 2 days in the council of europe', 'international bank for reconstruction and development', 'organisation for the prohibition of chemical weapons', 'asia-pacific economic cooperation', 'multilateral investment guarantee agency', 'organisation for economic cooperation and development', 'treaty on open skies', 'international telecommunication union', 'international atomic energy agency', 'organisation for economic cooperation de la francophonie', 'treaty on open skies 2 days in the council of europe']


In [115]:
x = tokenizer.batch_decode(outputs, skip_special_tokens=True)

In [117]:
len(x)

20

In [118]:
outputs.shape

torch.Size([20, 15])

In [71]:
outputs

tensor([[    0,  1038,  2137,    21, 20532,    11,   606,     1]])

Generated: ['international bank for reconstruction and development']


In [3]:
print(input)
print(''.join(tokenizer.convert_ids_to_tokens(outputs[0])))

predict tail: united states of america | member of |
<pad>▁international▁bank▁for▁reconstruction▁and▁development</s>


In [4]:
from dataset import T5_Dataset

In [5]:
valid_dataset = T5_Dataset('test', dataset_name='codex-m')

100%|██████████| 20622/20622 [00:00<00:00, 840393.08it/s]


In [6]:
from eval_accelerate import removePadding, eval

In [7]:
class Args:
    batch_size = 200
args=Args()

In [8]:
acc = eval(model, valid_dataset, args)

100%|██████████| 104/104 [00:59<00:00,  1.74batches/s]


In [9]:
acc

0.10876733585491223

In [8]:
actual = tokenizer("international development association", return_tensors="pt").input_ids[0].numpy()

In [9]:
actual

array([1038,  606, 6028,    1])

In [10]:
predicted = outputs[0][1:].numpy()

In [11]:
predicted

array([ 1038,  2137,    21, 20532,    11,   606,     1])

In [30]:
actual == predicted

array([ True,  True,  True,  True])

In [25]:
actual.numpy()

array([1038,  606, 6028,    1])