## Imports


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

import torch
from torch import nn

from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer

In [3]:
ELK_PATH = Path("../../../elk/")
print(ELK_PATH.resolve())

modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "training",
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    if not str(module) in sys.path:
        sys.path.insert(0, str(module.resolve()))

print(sys.path[:3])

from reporter import Reporter
from templates import DatasetTemplates

/fsx/home-augustas/elk
['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk/elk/training', '/fsx/home-augustas/elk']


## Config

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
DATA_DIR = (
    "../../../VINC-logs/"
    "allenai/unifiedqa-v2-t5-11b-1363200/"
    "AugustasM/burns-datasets-VINC/strange-montalcini"
)
DATA_DIR = Path(DATA_DIR)

!ls {DATA_DIR}

cfg.yaml  fingerprints.yaml  lr_eval.csv  reporters	  train_lm_eval.csv
eval.csv  lm_eval.csv	     lr_models	  train_eval.csv


In [6]:
model_name = "allenai/unifiedqa-v2-t5-11b-1363200"
model_name = "allenai/unifiedqa-v2-t5-3b-1363200"

LAYER = 24

## Dataset

In [7]:
dataset = load_dataset("AugustasM/burns-datasets-VINC", split="validation[:64]")
dataset

Found cached dataset parquet (/admin/home-augustas/.cache/huggingface/datasets/AugustasM___parquet/AugustasM--burns-datasets-VINC-85ec467026b56702/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


Dataset({
    features: ['text', 'label', 'original_dataset', 'template_name'],
    num_rows: 64
})

In [8]:
dataset_template_path = "AugustasM/burns-datasets-VINC/all"

dataset_templates = DatasetTemplates(dataset_template_path)
dataset_templates.templates = {
    x.name: x for x in dataset_templates.templates.values() if x.get_answer_choices_list(dataset[0]) is not None
}
print(dataset_templates.templates)

template = list(dataset_templates.templates.values())[0]
template

{'Consider the text below': <templates.Template object at 0x7f213f2c5de0>, 'Is the given text truthful': <templates.Template object at 0x7f213f2c5e10>, 'Text first': <templates.Template object at 0x7f213f2c5e40>, 'No question no choices': <templates.Template object at 0x7f213f2c4c70>, 'No question with choices': <templates.Template object at 0x7f213f2c5d80>}


<templates.Template at 0x7f213f2c5de0>

## Tokenizer

In [9]:
model_name

'allenai/unifiedqa-v2-t5-3b-1363200'

In [10]:
tokenizer = T5Tokenizer.from_pretrained(model_name, truncation_side="left")
tokenizer

T5Tokenizer(name_or_path='allenai/unifiedqa-v2-t5-3b-1363200', vocab_size=32100, model_max_length=512, is_fast=False, padding_side='right', truncation_side='left', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<ex

## Try combining report with a language model

Inspiration taken from the [original repository](https://github.com/collin-burns/discovering_latent_knowledge/blob/main/CCS.ipynb).

In [21]:
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
model.eval();

In [11]:
# class MyRewardModel(nn.Module):
#     def __init__(
#             self, language_model_name, reporter_path,
#             layer=-1, device="cpu", hidden_state_name="decoder_hidden_states",
#         ):
#         super().__init__()

#         # Load the language model and the reporter
#         self.language_model = T5ForConditionalGeneration.from_pretrained(
#             language_model_name
#         ).to(device)
#         self.language_model.eval()

#         self.reporter = Reporter.load(reporter_path).to(device)
#         self.reporter.eval()

#         self.layer = layer # which layer to extract
#         self.hidden_state_name = hidden_state_name

    
#     def forward(self, pos_inputs, neg_inputs):
#         # Get the hidden states
#         pos_hidden_states = self.language_model(
#             **pos_inputs, output_hidden_states=True,
#         )[self.hidden_state_name][self.layer]
#         neg_hidden_states = self.language_model(
#             **neg_inputs, output_hidden_states=True,
#         )[self.hidden_state_name][self.layer]
        
#         # Find the index of the last non-padding token
#         pos_last_token_index = torch.sum(pos_inputs["attention_mask"], dim=1) - 1
#         neg_last_token_index = torch.sum(neg_inputs["attention_mask"], dim=1) - 1

#         # Get the last token's output
#         pos_last_tokens = pos_hidden_states[range(len(pos_last_token_index)), pos_last_token_index]
#         neg_last_tokens = neg_hidden_states[range(len(neg_last_token_index)), neg_last_token_index]

#         # Get the logits for the two classes
#         pos_logits = self.reporter(pos_last_tokens)
#         neg_logits = self.reporter(neg_last_tokens)

#         # Return the difference in logits which will later be
#         # passed through a sigmoid function
#         return pos_logits - neg_logits


In [42]:
class MyRewardModel(nn.Module):
    def __init__(
            self, language_model, reporter_path,
            layer=-1, device="cpu", hidden_state_name="decoder_hidden_states",
        ):
        super().__init__()

        # Load the language model and the reporter
        # self.language_model = T5ForConditionalGeneration.from_pretrained(
        #     language_model
        # ).to(device)
        # self.language_model.eval()
        self.language_model = language_model

        self.reporter = Reporter.load(reporter_path).to(device)
        self.reporter.eval()

        self.layer = layer # which layer to extract
        self.hidden_state_name = hidden_state_name

    
    def forward(self, pos_inputs, neg_inputs):
        '''
            NOTE: only works for a single input at a time for now
        '''

        # Get the hidden states
        pos_hidden_states = self.language_model(
            **pos_inputs, output_hidden_states=True,
        )[self.hidden_state_name][self.layer]
        neg_hidden_states = self.language_model(
            **neg_inputs, output_hidden_states=True,
        )[self.hidden_state_name][self.layer]

        # Get the last token's output
        # Shape B x T x H -> B x H
        pos_last_tokens = pos_hidden_states[:, -1, :]
        neg_last_tokens = neg_hidden_states[:, -1, :]

        # Get the logits for the two classes
        # Shape B x H -> B
        pos_logits = self.reporter(pos_last_tokens)
        neg_logits = self.reporter(neg_last_tokens)

        # Return the difference in logits which will later be
        # passed through a sigmoid function
        return pos_logits - neg_logits


In [43]:
reporter_path = DATA_DIR / "reporters" / f"layer_{LAYER}.pt"
reporter_path.resolve()

PosixPath('/fsx/home-augustas/VINC-logs/allenai/unifiedqa-v2-t5-11b-1363200/AugustasM/burns-datasets-VINC/strange-montalcini/reporters/layer_24.pt')

In [44]:
%%time

reward_model = MyRewardModel(model, reporter_path, layer=LAYER, device=device)

CPU times: user 79.3 ms, sys: 0 ns, total: 79.3 ms
Wall time: 8.24 ms


In [45]:
item = dataset[0]
item_copy = item.copy()

# Get the positive and negative examples
item_copy["label"] = 1
pos_q, pos_a = template.apply(item_copy)

item_copy["label"] = 0
neg_q, neg_a = template.apply(item_copy)

# Tokenize the inputs
pos_inputs = tokenization_function(pos_q, pos_a).to(device)
neg_inputs = tokenization_function(neg_q, neg_a).to(device)

In [46]:
with torch.no_grad():
    outputs = reward_model(pos_inputs, neg_inputs)

outputs

tensor([0.2552], device='cuda:0')

### Sift through the data

In [47]:
tokenization_function = lambda q, a: tokenizer(
    q, text_target=a.strip(),
    add_special_tokens=True, return_tensors="pt",
)

In [48]:
%%time

predictions = []
labels = []
for idx, item in enumerate(dataset):
    labels.append(item["label"])
    item_copy = item.copy()

    # Get the positive and negative examples
    item_copy["label"] = 1
    pos_q, pos_a = template.apply(item_copy)

    item_copy["label"] = 0
    neg_q, neg_a = template.apply(item_copy)

    # Tokenize the inputs
    pos_inputs = tokenization_function(pos_q, pos_a).to(device)
    neg_inputs = tokenization_function(neg_q, neg_a).to(device)

    with torch.no_grad():
        prediction = reward_model(pos_inputs, neg_inputs)
    
    print(prediction.item())
    predictions.append(prediction.item())

# predictions = torch.tensor(predictions)
predictions, labels

0.2552218437194824
0.3841745853424072
0.3256509304046631


0.28431081771850586
0.37639713287353516
0.37061095237731934
0.28977108001708984
0.36952924728393555
0.3606541156768799
0.255687952041626
0.3672020435333252
0.35164308547973633
0.2697887420654297
0.3173236846923828
0.3689241409301758
0.40074586868286133
0.28209733963012695
0.30278587341308594
0.22183847427368164
0.30509042739868164
0.21025776863098145
0.3518798351287842
0.1934523582458496
0.30237603187561035
0.26209402084350586
0.35160207748413086
0.23949551582336426
0.3721284866333008
0.2751791477203369
0.32799458503723145
0.36185717582702637
0.2610764503479004
0.31461048126220703
0.34275054931640625
0.271038293838501
0.30406808853149414
0.2901172637939453
0.35800886154174805
0.2638258934020996
0.3245968818664551
0.3276844024658203
0.3547680377960205
0.43056797981262207
0.4248487949371338
0.3674638271331787
0.35486602783203125
0.4332923889160156
0.31902360916137695
0.22350382804870605
0.22410941123962402
0.34844207763671875
0.3510615825653076
0.2791781425476074
0.28778505325317383
0.32

([0.2552218437194824,
  0.3841745853424072,
  0.3256509304046631,
  0.28431081771850586,
  0.37639713287353516,
  0.37061095237731934,
  0.28977108001708984,
  0.36952924728393555,
  0.3606541156768799,
  0.255687952041626,
  0.3672020435333252,
  0.35164308547973633,
  0.2697887420654297,
  0.3173236846923828,
  0.3689241409301758,
  0.40074586868286133,
  0.28209733963012695,
  0.30278587341308594,
  0.22183847427368164,
  0.30509042739868164,
  0.21025776863098145,
  0.3518798351287842,
  0.1934523582458496,
  0.30237603187561035,
  0.26209402084350586,
  0.35160207748413086,
  0.23949551582336426,
  0.3721284866333008,
  0.2751791477203369,
  0.32799458503723145,
  0.36185717582702637,
  0.2610764503479004,
  0.31461048126220703,
  0.34275054931640625,
  0.271038293838501,
  0.30406808853149414,
  0.2901172637939453,
  0.35800886154174805,
  0.2638258934020996,
  0.3245968818664551,
  0.3276844024658203,
  0.3547680377960205,
  0.43056797981262207,
  0.4248487949371338,
  0.3674638

In [55]:
def get_prompt(template, entry, label):
    entry_copy = entry.copy()
    entry_copy["label"] = label

    return "".join(template.apply(entry_copy))


print(get_prompt(template, dataset[0], label=0))
print(get_prompt(template, dataset[0], label=1))

Consider the text below in triple backticks: 
``` 
Consider the following example: ''' Chicken Genome Sheds New Light on Human DNA A new study states that 60 of the genes in chicken have close relations to human DNA. This may not comfort those who frequently eat the creature, but may ponder this the next time they order a batch of chicken wings. ''' Choice 1: Science and technology. Choice 2: Business.Between choice 1 and choice 2, what is the topic of this example? Business 
``` 

Question: is the given text truthful? Answer either yes or no. 
Answer:no
Consider the text below in triple backticks: 
``` 
Consider the following example: ''' Chicken Genome Sheds New Light on Human DNA A new study states that 60 of the genes in chicken have close relations to human DNA. This may not comfort those who frequently eat the creature, but may ponder this the next time they order a batch of chicken wings. ''' Choice 1: Science and technology. Choice 2: Business.Between choice 1 and choice 2, wha

In [35]:
# Get input texts
pos_input_texts = [get_prompt(template, item, label=1) for item in dataset]
neg_input_texts = [get_prompt(template, item, label=0) for item in dataset]

# Get the inputs (input_ids and attention_masks) for the positive and negative texts
pos_inputs = tokenizer(pos_input_texts, return_tensors="pt", padding=True)
neg_inputs = tokenizer(neg_input_texts, return_tensors="pt", padding=True)

pos_inputs["input_ids"].shape, neg_inputs["input_ids"].shape

(torch.Size([32, 199]), torch.Size([32, 200]))

In [36]:
pos_inputs["attention_mask"].sum(dim=1), neg_inputs["attention_mask"].sum(dim=1)

(tensor([136, 109, 133, 112, 142, 124,  96, 122,  98, 113, 135, 115, 161, 113,
         138,  89, 147, 108, 155, 199, 125,  89, 116, 125, 119, 116, 101, 132,
         111, 113, 114, 126]),
 tensor([137, 110, 134, 113, 143, 125,  97, 123,  99, 114, 136, 116, 162, 114,
         139,  90, 148, 109, 156, 200, 126,  90, 117, 126, 120, 117, 102, 133,
         112, 114, 115, 127]))

In [56]:
language_model = T5ForConditionalGeneration.from_pretrained("gpt2")
reporter = Reporter.load(reporter_path)

my_reward_model = MyRewardModel(language_model, reporter)

In [57]:
with torch.no_grad():
    credences = my_reward_model(pos_inputs, neg_inputs)

outputs = torch.sigmoid(credences)
outputs

tensor([0.5005, 0.5015, 0.4993, 0.5002])

In [58]:
# Compare ground-truth with predictions
dataset["label"], (outputs > 0.5).int().tolist()

([1, 1, 0, 1], [1, 1, 0, 1])