In [6]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small", model_max_length=128)
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")




In [7]:
print(tokenizer("yes"))
print(tokenizer("ja"))
print(tokenizer("nee"))
print(tokenizer("no"))

print(tokenizer("nej"))
print(tokenizer("og"))


{'input_ids': [36339, 1], 'attention_mask': [1, 1]}
{'input_ids': [432, 1], 'attention_mask': [1, 1]}
{'input_ids': [448, 265, 1], 'attention_mask': [1, 1, 1]}
{'input_ids': [375, 1], 'attention_mask': [1, 1]}
{'input_ids': [3810, 1], 'attention_mask': [1, 1]}
{'input_ids': [373, 1], 'attention_mask': [1, 1]}


In [8]:
input_text = ["yes yes yes yes yes yes yes yes yes", "yes yes"]

# Tokenize the input text
inputs = tokenizer(input_text, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
print(inputs.input_ids.shape)


torch.Size([2, 128])


In [11]:
# Prepare decoder input ids (usually the start token)
decoder_input_ids = tokenizer(["",""], return_tensors="pt").input_ids
print(decoder_input_ids.shape)

# Forward pass with decoder input ids
outputs = model.base_model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids)


torch.Size([2, 1])


In [12]:
logits = outputs.logits
print(logits.shape)

torch.Size([2, 1, 250112])


In [13]:
print('logit yes', outputs.logits[0,0,36339])
print('logit no',  outputs.logits[0,0,375])

logit yes tensor(-36.5104, grad_fn=<SelectBackward0>)
logit no tensor(-35.1915, grad_fn=<SelectBackward0>)


In [14]:
import torch
def compute_rank_loss(logits_pos, logits_neg):
    r_pos = torch.sigmoid(logits_pos)
    r_neg = torch.sigmoid(logits_neg)
    diff = torch.sigmoid(r_pos - r_neg)
    return torch.log(1e-8 + torch.exp(diff))


In [15]:
from torch.nn import CrossEntropyLoss

ce = CrossEntropyLoss()


In [26]:
positive_string = ["og", "og og"]
negative_string = ["nej", "nej javlar"]


#training step

input_pos = tokenizer(positive_string, padding='max_length', max_length=128, truncation=True, return_tensors='pt')
input_neg = tokenizer(negative_string, padding='max_length', max_length=128, truncation=True, return_tensors='pt')

# print(input_pos)

decoder_input = tokenizer(["",""], return_tensors="pt")

decoder_input_ids

print(decoder_input_ids)

outputs_pos = model.base_model(input_ids=input_pos.input_ids, 
                               attention_mask=input_pos.attention_mask,
                               decoder_input_ids=decoder_input.input_ids,
                               decoder_attention_mask=decoder_input.attention_mask)
outputs_neg = model.base_model(input_ids=input_neg.input_ids, 
                               attention_mask=input_neg.attention_mask,
                               decoder_input_ids=decoder_input.input_ids,
                               decoder_attention_mask=decoder_input.attention_mask)

#logits = [yes,no]
logits_pos = torch.stack((outputs_pos.logits[:,-1,36399], outputs_pos.logits[:,-1,375]), dim=1)
logits_neg = torch.stack((outputs_neg.logits[:,-1,36399], outputs_neg.logits[:,-1,375]), dim=1)

print(logits_pos)

target_pos = torch.tensor([1,0],dtype=torch.float).unsqueeze(0).repeat(logits_pos.shape[0],1)
target_neg = torch.tensor([0,1],dtype=torch.float).unsqueeze(0).repeat(logits_pos.shape[0],1)

loss_nll = ce(logits_pos, target_pos) + ce(logits_neg, target_neg)
loss_bpr = -compute_rank_loss(logits_pos[0], logits_neg[0]).mean(dim=0)

lamb=0.5
loss = (1-lamb)*loss_nll + lamb*loss_bpr

loss.backward()




tensor([[1],
        [1]])
tensor([[-70.1002, -60.0054],
        [-73.3915, -61.9784]], grad_fn=<StackBackward0>)
