In [1]:
import sys
import os
project_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
sys.path.append(project_root)

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils.data import DataLoader
from src.data import SST2Dataset
from tqdm import tqdm
torch.manual_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fe86dbb1c50>

In [3]:
model_name = "t-bank-ai/T-lite-instruct-0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.56s/it]


In [12]:
prompt = """You need to define if the given sentence has positive or negative sentiment.
Answer is the only number: 0 - if the sentiment is negative, 1 - if positive.
Generate the final answer bracketed with <ans> and </ans>
The sentence:
 
    """

In [15]:
sst2_ds = SST2Dataset(
    tokenizer=tokenizer,
    data_path="../../data/sst-2/test-00000-of-00001.parquet",
    prompt=prompt
)

In [16]:
print(len(sst2_ds))

1821


In [17]:
input_ids, attention_mask, label = next(iter(sst2_ds))
print(input_ids.shape, attention_mask.shape, label.shape)

torch.Size([119]) torch.Size([119]) torch.Size([])


In [18]:
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

outputs = model.generate(
    input_ids=input_ids.unsqueeze(0),
    attention_mask = attention_mask.unsqueeze(0),
    max_new_tokens=256,
    eos_token_id=terminators,
)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [19]:
ans = tokenizer.decode(outputs[0], skip_special_tokens=True)
ans[len(prompt):]

'no movement, no yuks, not much of anything.\n\tResponse:\n the sentiment is negative\n<ans>0</ans>'

In [20]:
pos = ans.find("Response:\n", len(prompt))
ans[pos:]

'Response:\n the sentiment is negative\n<ans>0</ans>'

In [21]:
dl = DataLoader(sst2_ds, batch_size=64)

results = []
all_labels = []

for input_ids, attention_mask, labels in tqdm(dl):
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask = attention_mask,
        max_new_tokens=50,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
    )
    
    ans = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    positions = [answer.find('</ans>', len(prompt)) - 1  for answer in ans]
        
    res = []
    for i, a in enumerate(ans):
        r = a[positions[i]]
        if r == '1' or r == '0':
            res.append(int(r))
        else:
            res.append(-1)
    res = torch.Tensor(res).type(torch.long)
    results.append(res)
    all_labels.append(labels)
    

100%|██████████| 29/29 [08:15<00:00, 17.10s/it]


In [None]:
results = torch.cat(results)
all_labels = torch.cat(all_labels)

torch.Size([1821])

In [33]:
accuracy = torch.mean((results == all_labels).type(torch.float))
accuracy

tensor(0.6667)

In [37]:
ratio = 1 - torch.sum((results == -1).type(torch.float)) / len(sst2_ds)
ratio

tensor(0.7155)

In [38]:
accuracy / ratio

tensor(0.9317)