In [1]:
# Setup
import torch; torch.set_printoptions(linewidth=200)
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

# Single source of truth for the checkpoint id, so the tokenizer and the model
# config are guaranteed to come from the same checkpoint.
MODEL_NAME = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
config = AutoConfig.from_pretrained(MODEL_NAME)
# NOTE: from_config() only builds the architecture described by `config`; per the
# transformers documentation it does not load the pretrained weights
# (from_pretrained() would be needed for that).
model = AutoModelForCausalLM.from_config(config)


In [2]:
# Three short example sentences that get packed into one token sequence.
sentence1 = "The cat sat on the mat"
sentence2 = "The dog ate my homework"
sentence3 = "My aunt is a teacher"
sentences = [sentence1, sentence2, sentence3]

# Tokenize every sentence on its own: no special tokens, no attention mask.
per_sentence_ids = tokenizer(
    sentences,
    add_special_tokens=False,
    return_attention_mask=False,
)["input_ids"]

# Concatenate all token ids into one flat list, closing each sentence with EOS.
tokenized_sentences = []
for ids in per_sentence_ids:
    tokenized_sentences.extend(ids)
    tokenized_sentences.append(tokenizer.eos_token_id)

# Decode the packed ids back to text as a sanity check of the layout.
tokenizer.decode(tokenized_sentences)


'The cat sat on the mat<|endoftext|>The dog ate my homework<|endoftext|>My aunt is a teacher<|endoftext|>'

In [3]:
# Convert the packed token ids to a tensor. as_tensor() hands back an existing
# tensor unchanged, so re-running this cell neither copies the data again nor
# triggers the "copy construct from a tensor" warning that torch.tensor() emits.
tokenized_sentences = torch.as_tensor(tokenized_sentences)

# Plain causal mask over the whole packed sequence: entry (i, j) is 1 exactly
# when j <= i, so on its own it does not separate the packed sentences.
num_tokens = tokenized_sentences.size(0)
attn_mask = torch.ones(num_tokens, num_tokens, dtype=torch.int).tril()
attn_mask


tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,

In [4]:
# store sequence length in variable for easier readability
T = tokenized_sentences.size(0)

# get indices of all EOS tokens; flatten() (unlike a bare squeeze()) keeps the
# result 1-D even when there is exactly one EOS token, so the slicing below works
eos_indices = (tokenized_sentences == tokenizer.eos_token_id).nonzero().flatten()

# index of the last token of every packed sequence; if the packed input does not
# end with EOS (e.g. it was truncated), the trailing tokens form one more
# sequence ending at T-1, so that the lengths computed below always sum to T
seq_ends = eos_indices
if seq_ends.numel() == 0 or seq_ends[-1] != T - 1:
    seq_ends = torch.cat([seq_ends, seq_ends.new_tensor([T - 1])])

# from the end indices, get length of each sequence
reps = torch.cat([seq_ends[:1] + 1, seq_ends[1:] - seq_ends[:-1]])

# repeat each end index n times along dimension 1 (n is the number of tokens in the sequence)
repeated_idx = torch.repeat_interleave(seq_ends, reps).view(1, -1).expand(T, -1)

# create tensor with all indices from 0 to T-1 repeated T times along dimension 1
mask_indices = torch.arange(T).view(-1, 1).expand(-1, T)


In [5]:
# Column j holds the index of the EOS token that closes the sequence containing
# position j; every row is the same expanded copy of that vector.
repeated_idx

tensor([[ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 12, 12, 18, 18, 18, 18, 18, 18],
        [ 6,  6,  6,  6,  6,  6,  6, 12, 12, 12, 12, 1

In [6]:
# Entry (i, j) is simply the row index i: the column 0..T-1 expanded across all T columns.
mask_indices

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3],
        [ 4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4],
        [ 5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5],
        [ 6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6],
        [ 7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7],
        [ 8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8],
        [ 9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9],
        [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
        [11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 1

In [7]:
# True at (i, j) when row position i lies beyond the EOS index of the sequence
# that column position j belongs to, i.e. position j sits in a sequence that
# ended before position i. Rows whose index does not exceed the first EOS index
# are therefore all False.
mask_indices > repeated_idx

tensor([[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],

In [10]:
# Lower-triangular boolean matrix: entry (i, j) is True exactly when j <= i.
torch.tril(torch.ones(T, T, dtype=torch.bool))

tensor([[ True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False],

In [11]:
# Start from the lower-triangular (causal) boolean matrix ...
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
# ... then set to False every entry where the row index exceeds the EOS index
# of the column's sequence, which leaves one causal block per packed sequence.
mask = mask.masked_fill(mask_indices > repeated_idx, False)
mask

tensor([[ True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False, False],