In [1]:
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

Downloading:   0%|          | 0.00/685 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/651 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/441 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/251M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [3]:
text = [
    "Human: hello, it's me\nPerson: Who are you?",
    "Human: Hi\nBot: Hey there! I found out something interesting today that I think you'll love. Can't wait to share it with you!\nHuman: Hey. What do you want?"
]


In [4]:
text

["Human: hello, it's me\nPerson: Who are you?",
 "Human: Hi\nBot: Hey there! I found out something interesting today that I think you'll love. Can't wait to share it with you!\nHuman: Hey. What do you want?"]

In [24]:
batch_contexts = list()
batch_targets = list()

batch_max_length = 0

batch = list()

for sample in text:
    parts = sample.split('\n')
    context_str = "\n".join(parts[:-1])
    response = parts[-1]
    
    tokenized_context = tokenizer(context_str).input_ids
    tokenized_response = tokenizer(response).input_ids[1:] + [tokenizer.eos_token_id]
    
    targets = [tokenizer.pad_token_id] * len(tokenized_context) + tokenized_response[1:]
    tokenized_context = tokenized_context + [tokenizer.pad_token_id] * len(tokenized_response[:-1])
    
    if len(tokenized_context) > batch_max_length:
        batch_max_length = len(tokenized_context)
        
    batch_contexts.append(tokenized_context)
    batch_targets.append(targets)

    
for sample_index in range(len(batch_contexts)):
    pad_sequence = [tokenizer.pad_token_id] * (batch_max_length - len(batch_contexts[sample_index]))
    batch_contexts[sample_index] += pad_sequence
    batch_targets[sample_index] += pad_sequence
     
batch_contexts = torch.tensor(batch_contexts)
batch_targets = torch.tensor(batch_targets)

In [25]:
tokenizer.decode(batch_contexts[0])

"</s>Human: hello, it's me<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>"

In [27]:
tokenizer.decode(batch_targets[0])

'<pad><pad><pad><pad><pad><pad><pad><pad>: Who are you?</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

In [28]:
criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [29]:
attention_mask = (batch_contexts != tokenizer.pad_token_id).long()

In [30]:
model_outputs = model.forward(input_ids=batch_contexts, attention_mask=attention_mask)

In [31]:
model_outputs.logits.shape

torch.Size([2, 41, 50272])

In [32]:
loss = criterion(
    model_outputs.logits.view(-1, model_outputs.logits.size(-1)),
    batch_targets.view(-1)
)

In [33]:
loss

tensor(7.4017, grad_fn=<NllLossBackward0>)