In [90]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel

device="cuda" if torch.cuda.is_available() else "cpu"

# model_name="distilbert/distilgpt2"
# model_name="sshleifer/tiny-gpt2"
model_name="EleutherAI/gpt-neo-125M"
# model_name= "google-t5/t5-small"

# model_name="zatochu/GPT2-Medium-Alpaca-355m-ggml"
tokenizer = AutoTokenizer.from_pretrained(model_name)





# tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
# t5_model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")





encoded_input = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)





# eos = tokenizer(tokenizer.eos_token, return_tensors="pt", return_token_type_ids=False).input_ids.flatten()[0].item()
# eos = torch.tensor([eos]).to(device).reshape(1,1)







print("testing generate")
gen_outputs = model.generate(**encoded_input)
print(type(gen_outputs))
print(tokenizer.batch_decode(gen_outputs))
print(gen_outputs)







Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


testing generate
<class 'torch.Tensor'>
['Hugging Face is an open-source company that provides a wide range of services to the public. We are a company that is dedicated to providing']
tensor([[48098,  2667, 15399,   318,   281,  1280,    12, 10459,  1664,   326,
          3769,   257,  3094,  2837,   286,  2594,   284,   262,  1171,    13,
           775,   389,   257,  1664,   326,   318,  7256,   284,  4955]])


In [91]:
def exact_mode_algo(model, encoded_input, eos:int):
    def get_next_log_probs(y, model):
        with torch.no_grad():
            outputs = model(input_ids=y)
            logits = outputs.logits
        
        next_token_logits = logits[:, -1, :]
        return torch.nn.functional.log_softmax(next_token_logits, dim=-1)[0]

    #y is the prompt with additions, p is the probability of y, gamma is the current best probability, eos is the eos string
    def DFS(  y:str, p:float, gamma:float,model, eos=1, depth=0, max_depth=-1):


        
        #If we reached max depth without finishing
        if(depth>max_depth):
            return (y,gamma*2)

        #if y is finished, return the node
        if(y[0,-1]==eos):
            print(f"p is {p}, y is {tokenizer.batch_decode(y)}, gamma is {gamma}", flush=True)
            return (y,p)
        
        #else, search for the best way to complete y
        
        #exclude the pad token
        log_probs=get_next_log_probs(y, model)

        arange=torch.arange(len(log_probs)).to(log_probs.device)
        
        best_y=None
        for idx, log_prob in enumerate(log_probs):
            newP = p + log_prob 
            #if we're doing better than the best one so far
            if newP > gamma:

                # print(f"gamma is {gamma}, newP is {newP}")
                #do a DFS
                appended_y=torch.concat((y, arange[idx].reshape(1,1)), axis=1)
                new_y, new_gamma = DFS( y=appended_y, p=newP, gamma=gamma, model=model, eos=eos, depth=depth+1, max_depth=max_depth)
                if new_gamma > gamma:
                    best_y=new_y
                    gamma=new_gamma

        return best_y, gamma

    y=encoded_input.input_ids
    ended_y=torch.concat((y, eos), axis=1)
    start_gamma=get_next_log_probs(y=y, model=model)[eos]



    best_y, gamma = DFS(y=y, gamma=start_gamma,p=0, model= model, eos=eos, depth=0, max_depth=100)
    if best_y is None: 
        return ended_y, gamma
    return best_y, gamma



In [92]:
def get_next_log_probs(y, model):
    with torch.no_grad():
        outputs = model(input_ids=y)
        logits = outputs.logits
    
    next_token_logits = logits[:, -1, :]
    return torch.nn.functional.log_softmax(next_token_logits, dim=-1)[0]

next_log_probs=get_next_log_probs(encoded_input.input_ids, model)

a=torch.topk(next_log_probs, 3)

for index, value in zip(a.indices, a.values):
    print(f"p: {value},index: {index} token: {tokenizer.decode([index])}")

tokenizer.decode(tokenizer.encode("that"))

#so, start_Gamma should be -1.0048828125

p: -1.0048828125,index: 326 token:  that
p: -2.41796875,index: 11 token: ,
p: -3.244140625,index: 351 token:  with


'that'

In [93]:

encoded_input = tokenizer("Hugging Face is an open-source company that creates", return_tensors="pt").to(device)

eos = torch.tensor([tokenizer.encode(".")]).to(device).reshape(1,1)
print(f"eos is {eos}")
# print(f"eos is {eos}")


# start_gamma=get_next_log_probs(y=encoded_input.input_ids, model=model)[eos]
# print(f"1. start_gamma is {start_gamma}")

best_y, gamma = exact_mode_algo(model, encoded_input, eos)
print(f"gamma is {gamma.item()}, best_y is {tokenizer.batch_decode(best_y)}")

# tokenizer.decode(best_y)



# get_decode_log_prob(x,gen_outputs, t5_model, tokenizer)


# print("testing blank string")
# get_decode_log_prob(x,ended_y, t5_model, tokenizer)


# 48098,  2667, 15399,   318,   281,  1280,    12, 10459,  1664,   
    # 326, 3769,   257,  3094,  2837,   286,  2594,   284,   262,  1171,    13, 775,   389,   257,  1664,   326,   318,  7256,   284,  4955

eos is tensor([[13]])
p is -8.734375, y is ['Hugging Face is an open-source company that creates open-source software.'], gamma is tensor([[-9.1406]], dtype=torch.float16)
p is -8.5546875, y is ['Hugging Face is an open-source company that creates open source software.'], gamma is -8.734375
gamma is -8.5546875, best_y is ['Hugging Face is an open-source company that creates open source software.']


In [94]:
tokenizer.batch_decode(best_y)

['Hugging Face is an open-source company that creates open source software.']