In [1]:
# Setup: libraries, print width, and the GPT-2 tokenizer / config / model
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Wide print lines so each row of the square masks printed below fits on one line
torch.set_printoptions(linewidth=200)

MODEL_NAME = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(MODEL_NAME)
# The model is instantiated from the configuration object alone (no checkpoint name is passed here)
model = AutoModelForCausalLM.from_config(config)


In [2]:
# Three short example sentences that get packed into a single sequence further down
sentences = [
    "The cat sat on the mat",
    "The dog ate my homework",
    "My aunt is a teacher",
]
sentence1, sentence2, sentence3 = sentences

# Plain token ids only: no attention mask and no special tokens
encoded = tokenizer(sentences, return_attention_mask=False, add_special_tokens=False)
tokenized_sentences = encoded["input_ids"]
print(tokenized_sentences)

[[464, 3797, 3332, 319, 262, 2603], [464, 3290, 15063, 616, 26131], [3666, 25949, 318, 257, 4701]]


In [3]:
# Same batch again, this time asking for the attention mask and special-token handling
tokenized_sentences_2 = tokenizer(
    sentences,
    add_special_tokens=True,
    return_attention_mask=True,
)
print(tokenized_sentences_2)

{'input_ids': [[464, 3797, 3332, 319, 262, 2603], [464, 3290, 15063, 616, 26131], [3666, 25949, 318, 257, 4701]], 'attention_mask': [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]}


In [4]:
# Re-tokenize from `sentences` instead of reading `tokenized_sentences` back in.
# The previous form overwrote its own input, so running the cell a second time
# iterated over plain ints and failed on `int + list`.
per_sentence_ids = tokenizer(sentences, return_attention_mask=False, add_special_tokens=False)["input_ids"]

# Pack every sentence into one flat id list, closing each sentence with the EOS token
eos_id = tokenizer.eos_token_id
tokenized_sentences = [tok for ids in per_sentence_ids for tok in ids + [eos_id]]
print(tokenized_sentences)

[464, 3797, 3332, 319, 262, 2603, 50256, 464, 3290, 15063, 616, 26131, 50256, 3666, 25949, 318, 257, 4701, 50256]


In [5]:
# Decode the packed id list back to text to check that each sentence is followed by the EOS marker
tokenizer.decode(tokenized_sentences)

'The cat sat on the mat<|endoftext|>The dog ate my homework<|endoftext|>My aunt is a teacher<|endoftext|>'

In [6]:
# as_tensor keeps this cell re-runnable: a Python list is converted to an int64 tensor,
# while an existing int64 tensor is passed through instead of being copy-constructed
# again (which is what torch.tensor(tensor) does, with a UserWarning).
tokenized_sentences = torch.as_tensor(tokenized_sentences, dtype=torch.long)
print(tokenized_sentences)

tensor([  464,  3797,  3332,   319,   262,  2603, 50256,   464,  3290, 15063,   616, 26131, 50256,  3666, 25949,   318,   257,  4701, 50256])


In [7]:
# Square lower-triangular (causal) mask over the packed sequence:
# entry (i, j) is 1 when j <= i, and 0 otherwise
seq_len = tokenized_sentences.size(0)
attn_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.int))
print(attn_mask)
print(attn_mask.shape)

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,

In [8]:
# Length of the packed sequence
T = tokenized_sentences.size(0)

# Positions of every EOS token. nonzero() returns an (N, 1) tensor; squeeze(-1) removes
# only that trailing column, so the result stays 1-D even when N == 1. A bare squeeze()
# would collapse a single hit to a 0-d tensor and break the indexing done in the next cell.
eos_indices = (tokenized_sentences == tokenizer.eos_token_id).nonzero().squeeze(-1)

print(T)
print(eos_indices)

19
tensor([ 6, 12, 18])


In [9]:
# Length of each packed sentence, counting its closing EOS token.
# The first sentence starts at position 0, so its length is (first EOS index + 1);
# every later sentence spans the gap between consecutive EOS indices.
first_doc_len = eos_indices[[0]] + 1
other_doc_lens = eos_indices[1:] - eos_indices[:-1]
reps = torch.cat((first_doc_len, other_doc_lens))

print(reps)

tensor([7, 6, 6])


In [10]:
# Step 1: for every position, the index of the EOS token that closes its sentence
eos_per_token = torch.repeat_interleave(eos_indices, reps)
print(eos_per_token)

# Step 2: shape it as a single row
eos_row = eos_per_token.view(1, -1)
print(eos_row)

# Step 3: expand that row T times so entry (i, j) holds the closing EOS index of column j
repeated_idx = eos_row.expand(T, -1)
print(repeated_idx)

tensor([ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18])
tensor([[ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18]])
tensor([[ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 1

In [13]:
# Step 1: positions 0 .. T-1
positions = torch.arange(T)
print(positions)

# Step 2: reshape to a column
position_column = positions.view(-1, 1)
print(position_column)

# Step 3: expand the column across T columns so entry (i, j) holds the row index i
mask_indices = position_column.expand(-1, T)
print(mask_indices)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18])
tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14],
        [15],
        [16],
        [17],
        [18]])
tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3],
        [ 4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4],
        [ 5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5],
        [ 6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6],
        [ 7,  7,  7,  7,  7,  7,  7,  7,  7,  

In [14]:
# True where the row position comes strictly before the EOS index that closes the column's sentence.
# With a strict comparison, a row sitting exactly on an EOS position is False for its own sentence's columns.
repeated_idx > mask_indices

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],

In [None]:

# Recap of the intermediate tensors built so far.
# The last line used to print `mask`, which no cell above defines (NameError on a fresh
# kernel); the causal mask that does exist at this point is `attn_mask`.
print(reps)
print(repeated_idx)
print(mask_indices)
print(attn_mask)

In [None]:
# True where the row position lies past the EOS index that closes the column's sentence,
# i.e. the row token belongs to a later sentence than the column token
repeated_idx < mask_indices

In [None]:
# Same comparison as the cell directly above (row position > closing EOS index of the column),
# written as the negation of "row position is at or before that EOS index"
~(mask_indices <= repeated_idx)

In [None]:
# `mask` was never assigned in any cell above, so the in-place fill raised NameError on a
# fresh kernel. Build it here from the causal `attn_mask` instead.
# NOTE(review): assumes `attn_mask` is the intended base mask; confirm.
# The out-of-place masked_fill leaves `attn_mask` untouched, so the cell can be re-run safely.
# Entries where the row position lies past the column's closing EOS index are set to False,
# which stops a token from attending to tokens of an earlier sentence.
mask = attn_mask.bool().masked_fill(mask_indices > repeated_idx, False)
mask

In [None]:
# NOTE(review): `x` is not defined in any visible cell, so this fails on a fresh kernel unless
# it is created elsewhere; confirm. Presumably `x` is a 6-element tensor, since the target
# shape (3, 1, 1, 2) holds 6 values.
y = x.view(3,1,1,2)

In [None]:
# Display the shape of `y` produced by the view call in the preceding cell
y.shape