In [1]:
from transformers import AutoTokenizer,AutoModelForCausalLM

cache_dir = "/U_20240603_ZSH_SMIL/LLM/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
tokenizer = AutoTokenizer.from_pretrained(
                    cache_dir, device_map="auto",
                    token_type_ids=None)
model = AutoModelForCausalLM.from_pretrained(
                        cache_dir, device_map="auto",
                        max_memory={0: '80GIB'})


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [01:05<00:00, 32.56s/it]


In [2]:
from transformers import StoppingCriteria
from transformers import StoppingCriteriaList
import torch
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generations when they match a particular text or token."""
    def __init__(self, stops, tokenizer, match_on='text', initial_length=None):
        super().__init__()
        self.stops = stops
        self.initial_length = initial_length
        self.tokenizer = tokenizer
        self.match_on = match_on
        if self.match_on == 'tokens':
            self.stops = [torch.tensor(self.tokenizer.encode(i)).to('cuda') for i in self.stops]
            print(self.stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        del scores  # `scores` arg is required by StoppingCriteria but unused by us.
        for stop in self.stops:
            if self.match_on == 'text':
                generation = self.tokenizer.decode(input_ids[0][self.initial_length:], skip_special_tokens=False)
                match = stop in generation
            elif self.match_on == 'tokens':
                # Can be dangerous due to tokenizer ambiguities.
                match = stop in input_ids[0][-len(stop):]
            else:
                raise
            if match:
                return True
        return False


In [7]:
inputs = tokenizer(["Answer the following question as briefly as possible.\nQuestion: Which group recorded the 1976 album 'Rastaman Vibration'?\nAnswer: wailers\n\nQuestion: Which car company produces the Meriva model?\nAnswer: vauxhall\n\nQuestion: Who directed the first two Beatles' films 'A Hard Day's Night' and 'Help! '?\nAnswer: richard lester\n\nQuestion: Which of the 'Classic' horse races, run at Epsom for three year old fillies on the Friday after the derby, is named after the estate then owned by the Earl of Derby?\nAnswer: oaks\n\nQuestion: In which country is the most northerly point on mainland Africa?\nAnswer: tunisia\n\nQuestion: Who is the host of the BBC television show QI?\nAnswer:"], return_tensors="pt").to("cuda")

In [64]:
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
    stops=['Question:'],
    initial_length=len(inputs['input_ids'][0]),
    tokenizer=tokenizer)])

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        # num_beams=4,
        
        # num_return_sequences=4,
        
        max_new_tokens=50,
        return_dict_in_generate=True,
        output_scores=True,
        output_hidden_states=True,
        top_k=20,
        temperature=10.0,
        do_sample=False,
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.pad_token_id,
    )
    
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
    # outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
)
print(transition_scores)


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


tensor([[-0.6817, -0.0199, -0.0930, -0.0012, -0.0251, -0.0437, -0.0224, -0.0008]],
       device='cuda:0')


In [56]:
import numpy as np
output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
length_penalty = model.generation_config.length_penalty
reconstructed_scores = transition_scores.cpu().sum(axis=1) / (output_length**length_penalty)
print(np.allclose(outputs.sequences_scores.cpu(), reconstructed_scores))
print(reconstructed_scores)
print(outputs.sequences_scores)

False
tensor([-12.5141,  -9.4003,  -9.4026,  -9.4805], dtype=torch.float64)
tensor([-0.1110, -0.2870, -0.3838, -0.4125], device='cuda:0')
