## Imports


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path
from collections import defaultdict

from tqdm.notebook import tqdm

import torch
from torch import nn
from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoConfig

In [12]:
ELK_PATH = Path("/fsx/home-augustas/elk/")
print(ELK_PATH.resolve())

modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "training",
    ELK_PATH / "elk" / "promptsource",
]

for module in modules:
    if not str(module) in sys.path:
        sys.path.insert(0, str(module.resolve()))

print(sys.path[:3])

from reporter import Reporter
from templates import DatasetTemplates

/fsx/home-augustas/elk
['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk/elk/training', '/fsx/home-augustas/elk']


## Config

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [14]:
DATA_DIR = (
    "../../../VINC-logs/"
    "allenai/unifiedqa-v2-t5-3b-1363200/"
    "AugustasM/burns-datasets-VINC/sad-carson"
)
DATA_DIR = Path(DATA_DIR)

!ls {DATA_DIR}

ls: cannot access '../../../VINC-logs/allenai/unifiedqa-v2-t5-3b-1363200/AugustasM/burns-datasets-VINC/sad-carson': No such file or directory


In [15]:
# model_name = "allenai/unifiedqa-v2-t5-11b-1363200"
# model_name = "allenai/unifiedqa-v2-t5-large-1363200"
model_name = "allenai/unifiedqa-v2-t5-3b-1363200"

LAYER = 18

BATCH_SIZE = 16

## Tokenizer

In [16]:
model_name

'allenai/unifiedqa-v2-t5-3b-1363200'

In [17]:
tokenizer = T5Tokenizer.from_pretrained(model_name, truncation_side="left")
print(type(tokenizer))
print(tokenizer.is_fast)
list(tokenizer.get_vocab().keys())[:3]

<class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>
False


['<pad>', '</s>', '<unk>']

## Dataset

In [18]:
dataset = load_dataset("AugustasM/burns-datasets-VINC", split=f"validation[:{4*BATCH_SIZE}]")
dataset

Found cached dataset parquet (/admin/home-augustas/.cache/huggingface/datasets/AugustasM___parquet/AugustasM--burns-datasets-VINC-85ec467026b56702/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


Dataset({
    features: ['text', 'label', 'original_dataset', 'template_name'],
    num_rows: 64
})

### Templates

In [19]:
dataset_template_path = "AugustasM/burns-datasets-VINC/all"

dataset_templates = DatasetTemplates(dataset_template_path)
dataset_templates.templates = {
    x.name: x for x in dataset_templates.templates.values() if x.get_answer_choices_list(dataset[0]) is not None
}
print(len(dataset_templates.templates))

template = list(dataset_templates.templates.values())[0]
template

5


<templates.Template at 0x7f1fbbc87ee0>

### Preprocess dataset

In [23]:
def tokenization_function(q, a):
    return tokenizer(
        q, text_target=a.strip(),
        add_special_tokens=True, return_tensors="pt",
        padding="max_length",
    )

def preprocess_function(rows):
    processed_rows = defaultdict(list)

    print(len(rows))
    for text in rows["text"]:
        entry = { "text": text }

        # Get the positive and negative examples
        entry["label"] = 1
        pos_q, pos_a = template.apply(entry)

        entry["label"] = 0
        neg_q, neg_a = template.apply(entry)

        # Tokenize the inputs
        pos_inputs = tokenization_function(pos_q, pos_a)
        neg_inputs = tokenization_function(neg_q, neg_a)

        # Store the processed inputs
        processed_rows["pos_input_ids"].append(pos_inputs["input_ids"].squeeze())
        processed_rows["pos_attention_mask"].append(pos_inputs["attention_mask"].squeeze())
        processed_rows["pos_labels"].append(pos_inputs["labels"].squeeze())
        processed_rows["neg_input_ids"].append(neg_inputs["input_ids"].squeeze())
        processed_rows["neg_attention_mask"].append(neg_inputs["attention_mask"].squeeze())
        processed_rows["neg_labels"].append(neg_inputs["labels"].squeeze())
    
    return processed_rows
        

columns_to_delete = dataset.column_names
columns_to_delete.remove("label")
processed_dataset = dataset.map(
    preprocess_function, batched=True, batch_size=BATCH_SIZE,
    remove_columns=columns_to_delete,
)
processed_dataset = processed_dataset.filter(
    lambda x: max(len(x["pos_input_ids"]), len(x["neg_input_ids"])) <= tokenizer.model_max_length,
    batched=False,
)
processed_dataset.set_format(type="torch")
processed_dataset

Map:   0%|          | 0/64 [00:00<?, ? examples/s]

4
4
4
4


Filter:   0%|          | 0/64 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'pos_input_ids', 'pos_attention_mask', 'pos_labels', 'neg_input_ids', 'neg_attention_mask', 'neg_labels'],
    num_rows: 64
})

## Try combining report with a language model

Inspiration taken from the [original repository](https://github.com/collin-burns/discovering_latent_knowledge/blob/main/CCS.ipynb).

In [63]:
model_cfg = AutoConfig.from_pretrained(model_name)
fp32_weights = model_cfg.torch_dtype in (None, torch.float32)
is_bf16_possible = fp32_weights and torch.cuda.is_bf16_supported()
print(f"{is_bf16_possible=}")

kwargs = {
    "torch_dtype": torch.bfloat16 if is_bf16_possible else torch.float32
}
model = T5ForConditionalGeneration.from_pretrained(model_name, **kwargs).to(device)
model.eval();

is_bf16_possible=True
CPU times: user 2min 16s, sys: 4min 48s, total: 7min 4s
Wall time: 57.6 s


In [64]:
sum(p.numel() for p in model.parameters())

2851598336

In [65]:
model.lm_head.weight.dtype

torch.bfloat16

In [66]:
class MyRewardModel(nn.Module):
    def __init__(
            self, language_model, reporter_path,
            layer=-1, device="cpu", hidden_state_name="decoder_hidden_states",
        ):
        super().__init__()

        # Load the language model
        if isinstance(language_model, str):
            self.language_model = T5ForConditionalGeneration.from_pretrained(
                language_model
            ).to(device)
            self.language_model.eval()
        else:
            self.language_model = language_model

        # Load the reporter
        self.reporter = Reporter.load(reporter_path).to(device)
        self.reporter.eval()

        # Store other variables
        self.layer = layer # which layer to extract
        self.hidden_state_name = hidden_state_name
        self.pad_token_id = self.language_model.config.pad_token_id

    
    def forward(self, pos_inputs, neg_inputs):
        # Get the hidden states
        # Shape B x T x H
        pos_hidden_states = self.language_model(
            **pos_inputs, output_hidden_states=True,
        )[self.hidden_state_name][self.layer]
        neg_hidden_states = self.language_model(
            **neg_inputs, output_hidden_states=True,
        )[self.hidden_state_name][self.layer]

        pos_last_token_index = (
            pos_inputs["labels"] == self.pad_token_id
        ).int().argmax(dim=1) - 1
        neg_last_token_index = (
            neg_inputs["labels"] == self.pad_token_id
        ).int().argmax(dim=1) - 1

        # Get the last token's output
        pos_last_tokens = pos_hidden_states[range(len(pos_last_token_index)), pos_last_token_index]
        neg_last_tokens = neg_hidden_states[range(len(neg_last_token_index)), neg_last_token_index]

        # Get the logits for the two classes
        pos_logits = self.reporter(pos_last_tokens)
        neg_logits = self.reporter(neg_last_tokens)

        # Return the difference in logits which will later be
        # passed through a sigmoid function
        return pos_logits - neg_logits


In [67]:
reporter_path = DATA_DIR / "reporters" / f"layer_{LAYER}.pt"
reporter_path.resolve()

PosixPath('/fsx/home-augustas/VINC-logs/allenai/unifiedqa-v2-t5-3b-1363200/AugustasM/burns-datasets-VINC/sad-carson/reporters/layer_18.pt')

In [68]:
reward_model = MyRewardModel(model, reporter_path, layer=LAYER, device=device)

### Sift through the data

In [69]:
%%time

dataloader = DataLoader(
    processed_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=12,
)

predictions = []
labels = []

loop = tqdm(enumerate(dataloader), total=len(dataloader), leave=False)
for idx, batch in loop:
    pos_inputs = {
        "input_ids": batch["pos_input_ids"].to(device),
        "attention_mask": batch["pos_attention_mask"].to(device),
        "labels": batch["pos_labels"].to(device),
    }
    neg_inputs = {
        "input_ids": batch["neg_input_ids"].to(device),
        "attention_mask": batch["neg_attention_mask"].to(device),
        "labels": batch["neg_labels"].to(device),
    }
    
    labels.append(batch["label"])

    with torch.no_grad():
        current_predictions = reward_model(pos_inputs, neg_inputs)
    
    predictions.append(current_predictions.sigmoid())

predictions = torch.cat(predictions)
labels = torch.cat(labels)
print(predictions.shape, labels.shape)
predictions[:5], labels[:5]

  0%|          | 0/4 [00:00<?, ?it/s]

torch.Size([64]) torch.Size([64])
CPU times: user 9 s, sys: 8.3 s, total: 17.3 s
Wall time: 20 s


(tensor([0.4536, 0.4653, 0.5011, 0.4932, 0.4901], device='cuda:0'),
 tensor([0, 0, 0, 1, 1]))