In [None]:
import sys
sys.path.append('..')
from config import load_config
from model import CustomModel

import torch.nn as nn
import torch
import torch.nn.functional as F

2024-11-17 14:18:39,982 - modelscope - INFO - PyTorch version 2.2.0 Found.
2024-11-17 14:18:39,983 - modelscope - INFO - Loading ast index from /Users/xy/.cache/modelscope/ast_indexer
2024-11-17 14:18:40,009 - modelscope - INFO - Loading done! Current index file version is 1.14.0, with md5 b6a37aa50898b7ca29cb870cc35ad7a7 and a total number of 976 components indexed


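Test HF model: check that per-token NLLs computed by hand from GPT-2's logits reproduce the model's built-in `outputs.loss`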

In [1]:
# Load the pretrained GPT-2 LM-head model and its tokenizer from the Hugging Face hub
# (weights are downloaded on first use, as the progress bars below show).
from transformers import GPT2LMHeadModel, GPT2Tokenizer

HF_MODEL_NAME = 'gpt2'
tokenizer1 = GPT2Tokenizer.from_pretrained(HF_MODEL_NAME)
model1 = GPT2LMHeadModel.from_pretrained(HF_MODEL_NAME)



model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [11]:
# Log-probabilities over dim 1: the class (vocabulary) axis once logits are laid out as (B, V, L).
log_softmax = nn.LogSoftmax(dim=1)
# reduction='none' keeps one loss value per token instead of averaging them.
criterian = nn.NLLLoss(reduction='none')

In [15]:
input_str = 'Hello, world!'

# Tokenize and run one forward pass. The input ids double as `labels` so that
# `outputs.loss` is populated for comparison with the hand-computed NLLs below.
input_encoded = tokenizer1(input_str, return_tensors='pt')
input_ids = input_encoded['input_ids']
with torch.no_grad():
    outputs = model1(**input_encoded, labels=input_ids)

print('input_ids: ', input_ids)
print('input tokens:', tokenizer1.convert_ids_to_tokens(input_ids[0]))

print('outputs.logits: ', outputs.logits.shape)
print('outputs.loss: ', outputs.loss)

input_ids:  tensor([[15496,    11,   995,     0]])
input tokens: ['Hello', ',', 'Ġworld', '!']
outputs.logits:  torch.Size([1, 4, 50257])
outputs.loss:  tensor(4.1283)


In [12]:
# nn.NLLLoss expects the class (vocabulary) axis at dim 1, so reorder
# (B, L, V) -> (B, V, L).
logits = outputs.logits.permute(0, 2, 1)
print('permuted logits: ', logits.shape)

permuted logits:  torch.Size([1, 50257, 4])


In [38]:
# Next-token alignment: the logits at position t score token t+1, so drop the
# last logit column and the first target token.
targets = input_encoded['input_ids']
shift_targets = targets[:, 1:]
shift_logits = logits[..., :-1]

print('shift_logits: ', shift_logits.shape)
print('shift_targets: ', shift_targets)

shift_logits:  torch.Size([1, 50257, 3])
shift_targets:  tensor([[ 11, 995,   0]])


In [39]:
# Per-token NLLs of the shifted (next-token) predictions; squeeze drops the size-1 batch axis.
nlls = criterian(log_softmax(shift_logits), shift_targets).squeeze()
print('nlls: ', nlls)
print('mean nlls: ', nlls.mean())

# The mean per-token NLL should reproduce the model's built-in loss (outputs.loss).
# The two are computed by different code paths and differ only by float rounding
# (e.g. 4.1282 vs 4.1283 when printed), hence the small absolute tolerance.
torch.testing.assert_close(nlls.mean(), outputs.loss, rtol=0.0, atol=1e-4)

nlls:  tensor([2.3432, 8.0267, 2.0149])
mean nlls:  tensor(4.1282)


In [None]:
# Manually double-check the NLLs above by indexing the log-probabilities directly.
# Compute the log-softmax once and reuse it; [0] drops the batch axis -> (V, L-1).
log_probs = log_softmax(shift_logits)[0]
print('log_probs.shape: ', log_probs.shape)

# Greedy (argmax) prediction at each position, as a sanity read of the model.
max_prob_next_ids = log_probs.argmax(dim=0)
print('most probable next token ids: ', max_prob_next_ids)
print('most probable next tokens: ', tokenizer1.convert_ids_to_tokens(max_prob_next_ids.tolist()))

print()
# NLL of the true next token at every position t: -log p(target_t | prefix).
# Iterating over the targets (instead of hard-coding nll0..nll2) keeps this
# check valid for inputs of any length.
manual_nlls = torch.stack([-log_probs[tok, t] for t, tok in enumerate(shift_targets[0])])
for t, nll in enumerate(manual_nlls):
    print(f'nll{t}: ', nll)

# These must equal the NLLs computed with nn.NLLLoss in the previous cell.
torch.testing.assert_close(manual_nlls, nlls)

torch.Size([50257, 3])
most probable next token ids:  tensor([ 11, 314,  13])
most probable next tokens:  [',', 'ĠI', '.']

log_probs.shape:  torch.Size([50257, 3])
nll0:  tensor(2.3432)
nll1:  tensor(8.0267)
nll2:  tensor(2.0149)


In [19]:
# For contrast: score every position against its *own* token instead of the next one.
# The resulting NLLs are far higher than the shifted ones (see output), which
# confirms the one-position shift is required.
unshifted_log_probs = log_softmax(logits)
nlls_ns = criterian(unshifted_log_probs, targets).squeeze()
print('nlls_ns: ', nlls_ns)

nlls_ns:  tensor([10.0143,  9.0923,  9.8088,  7.7748])


Test custom model: check that the NLLs returned by the custom model agree with NLLs recomputed from its own logits

In [None]:
# Load the project's custom model wrapper (imported from the local `model` module
# as `CustomModel`; the bare name `Model` is not defined in this notebook).
# Machine-specific checkpoint location -- adjust per environment.
MODEL_PATH = "/data1/model/mistral-7b-base/"
model = CustomModel(MODEL_PATH)

In [None]:
# Plain forward pass on the probe sentence; the result unpacks into (logits, nlls).
forward_result = model.forward("Hello, world!")
logits, nlls = forward_result

In [None]:
# Same forward pass, additionally returning the token ids the text was encoded to.
logits, nlls, token_ids = model.forward("Hello, world!", return_tokens=True)
print(*(t.shape for t in (logits, nlls, token_ids)))

print(nlls)

In [None]:
# Recompute NLLs from the custom model's logits with nn.NLLLoss, to compare
# against the `nlls` the model itself returns (the comparison is in the next cell).
logits, nlls, token_ids = model.forward("Hello, world!", return_tokens=True)

log_softmax = nn.LogSoftmax(dim=1)        # normalize over dim 1 (vocabulary axis after permuting)
criterian = nn.NLLLoss(reduction='none')  # keep one loss per token

# Add a leading batch axis to the logits. token_ids is used as returned;
# presumably it is already batched as (1, L) -- TODO confirm against CustomModel.
logits = logits.unsqueeze(0)
print(logits.shape, token_ids.shape)

In [None]:
# nn.NLLLoss requires the class axis at dim 1: (B, L, V) -> (B, V, L).
# Bind the permuted tensor to a new name instead of overwriting `logits`, so
# re-running this cell does not permute twice.
logits_bvl = logits.permute(0, 2, 1)
shift_logits = logits_bvl[:, :, :-1]  # predictions made at positions 0..L-2
shift_token_ids = token_ids[..., 1:]  # the tokens those positions should predict

# Upcast before log-softmax so low-precision logits (e.g. bf16/fp16) don't add
# rounding noise to the comparison; this is a no-op for float32 logits.
nlls2 = criterian(log_softmax(shift_logits.float()),
                  shift_token_ids)
print(nlls2)
print(nlls)
# Observed: the model-returned `nlls` are inconsistent with nlls2 (see next cell).

In [None]:
# Try to fix nlls: recompute per-token NLLs directly from the logits, where the
# logit row at position i scores token i+1.
# Assumes logits is (L, V) and token_ids is (1, L) -- TODO confirm against CustomModel.
logits, nlls, token_ids = model.forward("Hello, world!", return_tokens=True)

# log_softmax is the numerically stable form of log(softmax(.)): it cannot
# underflow to log(0) = -inf for very unlikely tokens. Upcast for the same reason.
log_probs = F.log_softmax(logits.detach().float(), dim=-1)  # shape: [L, V]
next_ids = token_ids[0, 1:].to(log_probs.device)            # targets for positions 0..L-2

# One NLL per predicted position (L-1 values). The last position has no next
# token, so no placeholder zero is appended for it.
nlls_cands = -log_probs[:-1].gather(-1, next_ids.unsqueeze(-1)).squeeze(-1)

print(nlls_cands) # Now, it is fixed -- expected to agree with nlls2 from the previous cell