In [1]:
# Reload all imported modules before each cell runs, so edits to imported code
# (presumably the locally developed `synth` package) are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from synth import CrossEntropyDifferential

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of crashing here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# GPT-2 is the scoring model wrapped by CrossEntropyDifferential.
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
# The GPT-2 tokenizer has no pad token by default; reuse EOS so any padding
# (presumably done inside CrossEntropyDifferential — not visible here) works.
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
gpt2.config.pad_token_id = gpt2.config.eos_token_id
cross_entropy_differential = CrossEntropyDifferential(gpt2, gpt2_tokenizer, device)

# Phi-3.5-mini-instruct generates the candidate summaries.
PHI_MODEL_ID = "microsoft/Phi-3.5-mini-instruct"
phi = AutoModelForCausalLM.from_pretrained(
    PHI_MODEL_ID,
    device_map=str(device),  # "cuda" or "cpu", matching `device` above
    torch_dtype="auto",
    trust_remote_code=True,
    # Without a usable flash-attn install, the model warns that window_size is
    # unsupported and asks for attn_implementation='eager'.
    attn_implementation="eager",
)
phi_tokenizer = AutoTokenizer.from_pretrained(PHI_MODEL_ID)

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.19it/s]


In [8]:
# Source text to summarize. It is also passed as the second argument to
# `cross_entropy_differential` when the candidates are scored.
Y = """ Pizza is the best thing in the world, all the mozzarella and the tomato sauce on my margherita are amazing. 
My Italian friend Angelo told me that his favourite pizza is the quattroformaggi rossa with salsiccia. 
When I was a kid, I used to watch the pizza maker create his pizzas, he was from Romania but a very nice gentleman I have to say."""

# Chat-format prompt: a system instruction plus the source text as the user
# turn. Other system prompts tried earlier:
#   "Write 5 words that summarize the following text"
#   "Write a 5 word introduction of the following text"
message = [
    {
        "role": "system",
        "content": "Write a 5 word TL;DR: for this",
    },
    {
        "role": "user",
        "content": Y,
    },
]

pipe = pipeline(
    "text-generation",
    model=phi,
    tokenizer=phi_tokenizer,
)

# Sampling is stochastic: seed torch so the candidate set is reproducible on
# Restart & Run All (bit-exact only on the same hardware/software stack).
SEED = 0
torch.manual_seed(SEED)

generation_args = {
    "num_return_sequences": 120,  # number of candidates to score and rank
    "max_new_tokens": 30,  # hard cap per candidate; some outputs are cut off mid-phrase
    "return_full_text": False,  # keep only the generated continuation, not the prompt
    "do_sample": True,
    "temperature": 1,
}

phi_output = pipe(message, **generation_args)

In [9]:
# Score each sampled summary against the source text Y with the
# cross-entropy differential. This notebook treats LOWER as better
# (the best candidate is the one with the minimum score).
# Each score is immediately converted to a Python float, so no gradients are
# needed: run under no_grad to avoid building an autograd graph per sample.
# NOTE(review): assumes CrossEntropyDifferential.__call__ does not itself rely
# on autograd (e.g. calling .backward()) — confirm in `synth`.
with torch.no_grad():
    results = sorted(
        (
            cross_entropy_differential(sample["generated_text"], Y, diff=True).item(),
            sample["generated_text"],
        )
        for sample in phi_output
    )

assert results, "pipeline returned no generations to score"
# Sorted ascending by (score, text), so the first entry is the best; exact
# score ties are broken by the generated text.
best_result = results[0]

# Display only the top candidates; the full ranked list stays in `results`.
TOP_K = 10
results[:TOP_K]

[(-0.8807568550109863,
  ' TL;DR: Love pizza, particularly margherita; Angelo loves quattroformaggi rossa with salsiccia;'),
 (-0.8153045177459717,
  ' TL;DR: Enjoys pizza (margherita, quattroformaggi rossa with salsiccia), learnt'),
 (-0.7691218852996826,
  ' TL;DR: Incredible pizzas including margherita, quattroformaggi rossa with salsiccia, and fond'),
 (-0.7545592784881592,
  ' Pizza love, margherita, quattroformaggi rossa, Romanian pizza maker, Italian friend Angelo.'),
 (-0.7197329998016357,
  ' TL;DR: Pizza, mozzarella, margherita, quattroformaggi rossa, salsiccia;'),
 (-0.7011477947235107,
  ' TL;DR: Delights in pizza (mozzarella margherita & quattroformaggi rossa with salsiccia'),
 (-0.6679937839508057,
  ' TL;DR: Love pizza; margherita preferred, quadruformaggi rossa with salsiccia, fond memories'),
 (-0.6652042865753174,
  " TL;DR: Pizza, Romanian pizza maker's friend Angelo prefers quattroformaggi rossa, appreciates"),
 (-0.6636302471160889,
  ' TL;DR: Love pizzas; margheri