In [1]:
from __future__ import annotations
from typing import Any, Optional
from tactic_gen.lm_example import LmExample
from tactic_gen.train_fid import get_model, get_tokenizer, get_datasets
from tactic_gen.fid_data import FidDataset
from tactic_gen.fid_prime_model import FiDT5 
from transformers import T5ForConditionalGeneration
from transformers import T5Tokenizer, CodeLlamaTokenizer
import torch
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
MODEL_NAME = "google-t5/t5-small"

In [3]:
t5 = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
t5.config.n_passages = 8
fid_model = FiDT5(t5.config)
fid_model.load_t5(t5.state_dict())
fid_model.cuda()

FiDT5(
  (shared): Embedding(32128, 512)
  (encoder): EncoderWrapper(
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 512)
      (block): ModuleList(
        (0): CheckpointWrapper(
          (module): T5Block(
            (layer): ModuleList(
              (0): T5LayerSelfAttention(
                (SelfAttention): T5Attention(
                  (q): Linear(in_features=512, out_features=512, bias=False)
                  (k): Linear(in_features=512, out_features=512, bias=False)
                  (v): Linear(in_features=512, out_features=512, bias=False)
                  (o): Linear(in_features=512, out_features=512, bias=False)
                  (relative_attention_bias): Embedding(32, 8)
                )
                (layer_norm): T5LayerNorm()
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (1): T5LayerFF(
                (DenseReluDense): T5DenseActDense(
                  (wi): Linear(in_features=512, out_features=2048, b

In [4]:
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
test_example = LmExample("hi", "there")

In [6]:
dataset = FidDataset(None, tokenizer, 448, 64, 8)

In [7]:
test_batch = dataset.collate([test_example])

In [8]:
t5.config.d_model

512

In [9]:
outputs = fid_model.generate(
                    test_batch["input_ids"].cuda(),
                    test_batch["attention_mask"].cuda(),
                    64,
                    return_dict_in_generate=True,
                    output_scores=True,
                )

torch.Size([8, 448])
shape before view torch.Size([1, 448, 8, 512])
proj shape torch.Size([512, 4096])
state shape torch.Size([1, 448, 4096])
torch.Size([1, 448, 512])
BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[0.0368, 0.0000, 0.0193,  ..., 0.0738, 0.0000, 0.0000],
         [0.0659, 0.0367, 0.1574,  ..., 0.0651, 0.0000, 0.0000],
         [0.0447, 0.0423, 0.1159,  ..., 0.0229, 0.0000, 0.0000],
         ...,
         [0.0986, 0.0945, 0.1009,  ..., 0.0735, 0.0000, 0.0013],
         [0.1088, 0.0517, 0.0000,  ..., 0.0878, 0.0000, 0.0000],
         [0.1229, 0.0000, 0.1172,  ..., 0.0753, 0.0000, 0.0000]]],
       device='cuda:0'), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)


RuntimeError: The size of tensor a (56) must match the size of tensor b (448) at non-singleton dimension 3

In [12]:
outputs = fid_model(
                    test_batch["input_ids"].cuda(),
                    test_batch["attention_mask"].cuda(),
                    test_batch["labels"].cuda(),
                )

orig shape torch.Size([1, 8, 448])
reshaping torch.Size([8, 448])
torch.Size([8, 448])
shape before view torch.Size([1, 448, 8, 512])
proj shape torch.Size([512, 4096])
state shape torch.Size([1, 448, 4096])
torch.Size([1, 448, 512])
BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0320, 0.0668],
         [0.0000, 0.0021, 0.0468,  ..., 0.0222, 0.0287, 0.0000],
         [0.1294, 0.0965, 0.0000,  ..., 0.0213, 0.0249, 0.0000],
         ...,
         [0.0000, 0.0864, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.2473, 0.0881, 0.0372,  ..., 0.0000, 0.0000, 0.0000],
         [0.2165, 0.0441, 0.0000,  ..., 0.0097, 0.0000, 0.0000]]],
       device='cuda:0', grad_fn=<ReluBackward0>), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)


RuntimeError: output with shape [1, 8, 64, 448] doesn't match the broadcast shape [8, 8, 64, 448]

In [None]:
#input_ids = test_batch["input_ids"].view(test_batch["input_ids"].size(0), -1).cuda()
input_ids = test_batch["input_ids"].cuda()
attention_mask = test_batch["attention_mask"].cuda()
labels = test_batch["labels"].cuda()
model(input_ids, attention_mask, labels=labels)
#model.encoder(input_ids, attention_mask)

Seq2SeqLMOutput(loss=tensor(1.5929, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[[-54.2444, -31.2982, -34.3318,  ..., -68.7229, -69.2365, -69.2920],
         [-57.3117, -29.2449, -21.7290,  ..., -69.4969, -69.8299, -69.8837],
         [-57.6321, -29.4005, -24.8219,  ..., -68.5074, -68.8125, -68.7309],
         ...,
         [-55.4589, -30.7832, -35.0631,  ..., -70.0534, -70.5770, -70.6369],
         [-55.4020, -30.7915, -35.0198,  ..., -69.9774, -70.5003, -70.5602],
         [-55.3844, -30.8054, -34.9945,  ..., -69.9370, -70.4597, -70.5195]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[-3.4396e-01, -1.7363e+00, -7.6735e-01,  ..., -1.8445e+00,
           -1.2066e+00, -2.9358e-01],
          [-3.8958e-01,  5.2143e-01, -3.1301e-01,  ..., -1.0632e+00,
            7.7356e-01, -1.0901e+00],
          [ 7.4578e-01,  4.7222e-01,  1.0388e+00,  ...,  2.3293e-01,
           -1.3219e+00, -1.1404e+00],
          ...,
          [-3.4396e-01,

In [None]:
input_ids.shape

torch.Size([1, 1, 448])

In [None]:
input_ids = test_batch["input_ids"].view(1 * 1, 448).cuda()
attention_mask = test_batch["attention_mask"].view(1 * 1, 448).cuda()
result = model.encoder.encoder(input_ids, attention_mask) 

In [None]:
result

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[ 0.1670, -0.0549, -0.0678,  ..., -0.0046, -0.1266, -0.0658],
         [ 0.1413, -0.0544,  0.1294,  ...,  0.0837, -0.0265,  0.0722],
         [-0.1379,  0.1142,  0.0339,  ...,  0.0827,  0.1169, -0.0465],
         ...,
         [-0.0466,  0.0035,  0.0482,  ..., -0.1274, -0.0180, -0.0025],
         [-0.0466,  0.0035,  0.0482,  ..., -0.1274, -0.0180, -0.0025],
         [-0.0466,  0.0035,  0.0482,  ..., -0.1274, -0.0180, -0.0025]]],
       device='cuda:0', grad_fn=<MulBackward0>), past_key_values=None, hidden_states=None, attentions=None, cross_attentions=None)

In [None]:
# (loss, decoder_logits, decoder hidden states, encoder state)
result[2][1][1].shape

torch.Size([1, 8, 64, 64])

In [None]:

inputs = test_batch["input_ids"].cuda()
attn_mask = test_batch["attention_mask"].cuda()
with torch.no_grad():
    transformers.set_seed(1)
    outputs = model.generate(
        inputs, 
        attn_mask,
        64,
        do_sample=True,
        temperature=0.1,
        return_dict_in_generate=True,
        output_scores=True,
        #num_beams=10,
        length_penalty=0,
        num_return_sequences=10,
    )
    #tokens = model.generate(test_batch["input_ids"].cuda(), test_batch["attention_mask"].cuda(), 64)

In [None]:
outputs

SampleEncoderDecoderOutput(sequences=tensor([[   0,  879, 1737, 8976,  276,    5,    1],
        [   0,  879, 1737, 8976,  276,    5,    1],
        [   0,  879, 1737, 8976,  276,    5,    1],
        [   0,  879, 1737, 8976,  276,    5,    1],
        [   0,  879, 1737, 8976,  276,    5,    1],
        [   0,  879, 1737, 8976,  276,    5,    1],
        [   0,  879, 1737, 8976,  276,    5,    1],
        [   0,  879, 1737, 8976,  276,    5,    1],
        [   0,  879, 1737, 8976,  276,    5,    1],
        [   0,  879, 1737, 8976,  276,    5,    1]], device='cuda:0'), scores=(tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        ...,
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
        [-inf, -inf, -inf,  ..., -inf, -inf, -inf]], device='cuda:0'), tensor([[     -inf,      -inf, -157.7079,  ...,      -inf,      -inf,
  

In [None]:
tokenizer.batch_decode(outputs.sequences)

['<pad> generalize dependent P.</s>',
 '<pad> generalize dependent P.</s>',
 '<pad> generalize dependent P.</s>',
 '<pad> generalize dependent P.</s>',
 '<pad> generalize dependent P.</s>',
 '<pad> generalize dependent P.</s>',
 '<pad> generalize dependent P.</s>',
 '<pad> generalize dependent P.</s>',
 '<pad> generalize dependent P.</s>',
 '<pad> generalize dependent P.</s>']

In [None]:
outputs.sequences.shape

torch.Size([10, 7])

In [None]:
len(outputs.scores)

6

In [None]:
outputs.scores[0].shape

torch.Size([10, 32128])

In [None]:
outputs.scores[0][outputs.sequences[0]]

../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [182,0,0], thread: [64,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [182,0,0], thread: [65,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [182,0,0], thread: [66,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [182,0,0], thread: [67,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [182,0,0], thread: [68,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [182,0,0], thread: [69,0,0] Assertion 

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
model.compute_transition_scores(outputs.sequences, outputs.scores)

tensor([[-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576],
        [-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576],
        [-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576],
        [-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576],
        [-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576],
        [-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576],
        [-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576],
        [-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576],
        [-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576],
        [-135.2088,   49.1707,   20.9431,  -10.5498,  -59.5263,  -16.1576]],
       device='cuda:0')

In [None]:
the_bools = ~((outputs.sequences == tokenizer.pad_token_id) + (outputs.sequences == tokenizer.eos_token_id))
#torch.where(the_bools, 1, 0).sum(axis=1)

TypeError: where() received an invalid combination of arguments - got (int, int), but expected one of:
 * (Tensor condition, Tensor other)
      didn't match because some of the arguments have invalid types: (!int!, !int!)
 * (Tensor condition, Number other)
      didn't match because some of the arguments have invalid types: (!int!, !int!)


In [None]:
model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices).sum(axis=1)

tensor([-1.8814, -2.6457, -2.8369, -3.1251, -3.9642, -4.0253, -4.5893, -4.9442,
        -5.1175, -5.3550], device='cuda:0')

In [None]:
if "beam_indices" in outputs.keys():
    seq_scores = output.sequences_scores 

odict_keys(['sequences', 'sequences_scores', 'scores', 'beam_indices'])

In [None]:
model.compute_transition_scores(outputs.sequences, outputs.scores)

tensor([[-1.0924e+00, -5.3725e-04, -2.0845e-01, -3.4957e-01, -4.6140e-02,
         -1.5571e-03, -3.8826e-02, -1.6582e-01, -7.4209e-01, -2.8463e-04,
         -4.6652e+01, -4.4954e+01, -3.9539e+01, -4.7255e+01, -4.5628e+01,
         -3.8009e+01, -3.0419e+01, -3.6212e+01, -4.0046e+01, -4.7286e+01,
         -4.5628e+01, -3.9329e+01, -4.7004e+01, -4.0645e+01, -3.8275e+01,
         -3.0197e+01, -3.6068e+01, -3.9653e+01, -4.6780e+01, -4.5773e+01,
         -3.8774e+01, -4.1612e+01, -4.1023e+01, -3.8042e+01, -2.9946e+01,
         -3.6146e+01, -3.9438e+01, -4.6120e+01, -4.5308e+01, -3.8505e+01,
         -4.1189e+01, -4.1466e+01, -3.7899e+01, -2.9849e+01, -3.6151e+01,
         -3.9236e+01, -4.1129e+01, -4.4634e+01, -3.8401e+01, -4.0940e+01,
         -4.1558e+01, -3.7693e+01, -2.9866e+01, -3.6126e+01, -3.8958e+01,
         -4.0932e+01, -4.3907e+01, -3.8285e+01, -4.0571e+01, -4.2507e+01,
         -3.7361e+01, -3.0210e+01, -3.6002e+01],
        [-1.0924e+00, -1.1111e+01, -1.9473e+01, -1.3342e+01, -1

In [None]:
outputs

SampleEncoderDecoderOutput(sequences=tensor([[    0,    20,  7593,    41,     9,     3,     2,    58,   305,   137,
             1,     0,     0,     0,     0],
        [    0,  1581, 23966,   834,   994,     9,  9208,  4416,     1,     0,
             0,     0,     0,     0,     0],
        [    0,    20,  7593,    41,     9,     3,     2,    58,   431,  3670,
             3,     7, 10296,     5,     1],
        [    0,    20,  7593,    41,     9,     3,     2,    58,   305,  3670,
           653, 22112, 10696,     5,     1],
        [    0,    20,  7593,    41,     9,     3,     2,    58,   305,   137,
             1,     0,     0,     0,     0],
        [    0,     3,     7, 10296,     5,     1,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [    0,     3,     7, 10296,     5,     1,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [    0,    20,  7593,    41,     9,     3,     2,    58,   305,  3670,
             3, 

In [None]:
outputs.keys()

odict_keys(['sequences', 'scores'])

In [None]:
outputs.scores

(tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, -inf,  ..., -inf, -inf, -inf]], device='cuda:0'),
 tensor([[    -inf,     -inf,     -inf,  ...,     -inf,     -inf,     -inf],
         [    -inf,     -inf, -38.5619,  ...,     -inf,     -inf,     -inf],
         [    -inf,     -inf,     -inf,  ...,     -inf,     -inf,     -inf],
         ...,
         [    -inf,     -inf,     -inf,  ...,     -inf,     -inf,     -inf],
         [    -inf,     -inf, -38.5619,  ...,     -inf,     -inf,     -inf],
         [    -inf,     -inf,     -inf,  ...,     -inf,     -inf,     -inf]],
        device='cuda:0'),
 tensor([[    -inf,     -inf, -44.0177,  ...,     -inf,     -inf,     -inf],
         [    -inf,     -inf, -45.5157,  ...,     -inf,     -inf,    

In [None]:
tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

['simpl.',
 'destruct (a? 5).',
 'induction a.',
 'destruct a.',
 'lia.',
 'case a.',
 'destruct (a? 5); simpl.',
 'destruct (a? 5); auto.',
 'destruct (a? 5); trivial.',
 'destruct (a? 5); reflexivity.']

In [None]:
tokenizer.tokenize("hi <= ☂ there")

['▁hi', '▁', '<', '=', '▁', '☂', '▁there']

In [None]:
class SampleResult:
    def __init__(
        self, tactics: list[str], scores: list[float], num_tokens: list[int]
    ) -> None:
        self.tactics = tactics
        self.scores = scores
        self.num_tokens = num_tokens

    def to_json(self) -> Any:
        return {
            "tactics": self.tactics,
            "scores": self.scores,
            "num_tokens": self.num_tokens,
        }

    @classmethod
    def from_json(cls, json_data: Any) -> SampleResult:
        tactics = json_data["tactics"]
        scores = json_data["scores"]
        num_tokens = json_data["num_tokens"]
        return cls(tactics, scores, num_tokens)


In [None]:

def fuzzy_starts_with(s1: str, s2: str) -> bool:
    """some nonempty prefix of s1 matches some nonempty suffix of s2"""
    if len(s2) == 0:
        return False
    if s1.startswith(s2):
        return True
    return fuzzy_starts_with(s1, s2[1:])


def should_stop_now(
    input_ids: torch.Tensor,
    tokenizer: CodeLlamaTokenizer | T5Tokenizer,
    stop_strings: list[str],
) -> bool:
    """input ids is a one dimensional tensor"""
    consider_len = 1
    cur_candidate = tokenizer.decode(input_ids[(-1 * consider_len) :])
    any_matched = True
    while any_matched:
        any_matched = False
        for stop_string in stop_strings:
            if stop_string in cur_candidate:
                return True
            any_matched |= fuzzy_starts_with(cur_candidate, stop_string)
        consider_len += 1
        cur_candidate = tokenizer.decode(input_ids[(-1 * consider_len) :])
    return False


In [None]:
def prepare_batches_fid(
    input_ids: torch.LongTensor,
    attention_mask: torch.Tensor,
    beam_scores: torch.Tensor,
    batch_size: int,
) -> tuple[tuple[torch.LongTensor], tuple[torch.Tensor], tuple[torch.Tensor]]:
    split_input_ids = torch.split(input_ids, batch_size)
    split_attention_mask = torch.split(attention_mask, batch_size)
    split_scores = torch.split(beam_scores, batch_size)
    return split_input_ids, split_attention_mask, split_scores, 


In [None]:
class CompletedCandidate:
    def __init__(self, indices: torch.LongTensor, score: torch.Tensor) -> None:
        self.indices = indices
        self.score = score

    def __lt__(self, other: CompletedCandidate) -> bool:
        return float(self.score) < float(other.score)

    def __le__(self, other: CompletedCandidate) -> bool:
        return float(self.score) <= float(other.score)


def fidt5_beam_sample(
    input_ids: torch.LongTensor,
    attention_mask: torch.Tensor,
    tokenizer: T5Tokenizer,
    beam_width: int,
    n_recs: int,
    stop_strings: list[str],
    batch_size: int = 2,
    max_seq_len: int = 512,
) -> SampleResult: 

    beam_scores = torch.zeros((input_ids.shape[0],), dtype=torch.float32).to("cuda")
    completed_heap: list[CompletedCandidate] = []

    while True:
        batched_input_ids, batched_attention_masks, batched_scores = prepare_batches_fid(
            input_ids,
            attention_mask, 
            beam_scores,
            batch_size,
        )

        next_scores_list: list[torch.Tensor] = []
        next_input_id_list: list[torch.Tensor] = []

        for input_ids_batch, attention_mask_batch, past_batch in zip(
            batched_input_ids, batched_attention_masks, batched_scores 
        ):
            with torch.no_grad():
                output_batch = model(**batch_inputs)



SyntaxError: expected ':' (2590927809.py, line 40)

In [None]:
tokenizer.batch_decode(tokens)

['<pad> simpl.</s>']

In [None]:
model.device

device(type='cuda', index=0)

In [None]:
#test_batch["input_ids"].shape
#test_batch["attention_mask"]
test_batch["return_dict"] = False
model(**test_batch)


(tensor(4.9911, grad_fn=<NllLossBackward0>),
 tensor([[[-20.5744, -12.8595, -14.6405,  ..., -45.1446, -45.2421, -45.2430],
          [-25.0881, -15.1740,  -5.2820,  ..., -45.1663, -45.2774, -45.3447],
          [-21.0880,  -8.3345,  -7.0697,  ..., -37.5994, -37.6230, -37.6437],
          ...,
          [-14.8440,  -8.7909,  -9.5105,  ..., -37.0198, -37.2101, -37.2129],
          [-15.4429, -10.7294, -12.1775,  ..., -38.6211, -38.7931, -38.8090],
          [-14.3118,  -8.4187,  -9.7325,  ..., -36.4805, -36.5963, -36.5941]]],
        grad_fn=<UnsafeViewBackward0>),
 ((tensor([[[[-4.3993e-01, -2.3827e+00, -1.1007e+00,  ..., -2.4550e+00,
              -1.7357e+00, -5.3955e-01],
             [-5.7156e-01,  1.8579e-01, -6.1753e-03,  ..., -1.7960e+00,
               6.6674e-01, -1.0925e+00],
             [-4.8580e-01,  9.8874e-01, -2.8671e-01,  ..., -1.6572e+00,
               7.7291e-01, -7.1863e-01],
             ...,
             [-7.9248e-01, -1.6456e+00, -6.5908e-01,  ..., -1.3585e+00,
 