In [1]:
from CCS import CCS
from IPython.display import display
from LogisticRegression import LogisticRegression
from datasets import load_dataset
from huggingface_hub import login
from jinja2 import Environment, PackageLoader, select_autoescape
from pprint import pp
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer
import lightning as pl
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch

In [2]:
VERBOSE = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
login(add_to_git_credential=True)
np_rand = np.random.default_rng(seed=100500)
pp(device)
model_type = torch.bfloat16

pt_template = "ggplot2"

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

device(type='cuda')


In [3]:
# Load model
tokenizer = LlamaTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    # device_map=device,
)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
# tokenizer.pad_token = tokenizer.eos_token

hf_model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=model_type,
    device_map=device,
    low_cpu_mem_usage=True,
)

tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [4]:
hf_model.eval()
pp(hf_model)
pp(hf_model.config)
with torch.no_grad():
    pp(
        tokenizer.batch_decode(
            hf_model.generate(
                tokenizer("The capital of Russia is", return_tensors="pt").input_ids.to(
                    device
                ),
                max_length=20,
            )
        )[0]
    )

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head):

In [5]:
true_ids = [5574, 5852, 1565, 3009]
false_ids = [7700, 8824, 2089, 4541]
print(" ".join(f"'{tokenizer.decode(id_)}'" for id_ in true_ids))
print(" ".join(f"'{tokenizer.decode(id_)}'" for id_ in false_ids))

'True' 'True' 'true' 'true'
'False' 'False' 'false' 'false'


In [6]:
truthfulqa = load_dataset("truthful_qa", "generation")  # 817 rows
env = Environment(loader=PackageLoader("utils"), autoescape=select_autoescape())

Downloading data:   0%|          | 0.00/223k [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/817 [00:00<?, ? examples/s]

In [7]:
# Datasets generators:

def one_statement_ds_generator():
    with torch.no_grad():
        template = env.get_template("question_answer.jinja")
        counter = 0
        for row_id, row in tqdm(list(enumerate(truthfulqa["validation"]))):
            inc_as = row['incorrect_answers']
            cor_as = row['correct_answers']
            for take_correct in (True, False):
                for ans in (cor_as, inc_as)[take_correct]:
                    counter += 1
                    yield {
                        'id': counter, 
                        'question_id':row_id,
                        'row':row,
                        'template_render_fn': template.render,
                        'is_correct': take_correct,
                        'template_render_params':{
                            'question':row["question"],
                            'answer':ans
                        }
                    }

def conj_ds_generator():
    with torch.no_grad():
        template = env.get_template("question_answers.jinja")
        counter = 0
        for row_id, row in tqdm(list(enumerate(truthfulqa["validation"]))):
            inc_as = row['incorrect_answers']
            cor_as = row['correct_answers']
            def create_yield(is_correct, answers):
                nonlocal counter
                counter += 1
                return {
                    'id': counter, 
                    'question_id':row_id,
                    'row':row,
                    'template_render_fn': template.render,
                    'is_correct': is_correct,
                    'template_render_params':{
                        'question':row["question"],
                        'answers': answers,
                        'is_disjunction':False
                    }
                }
            # Take correct:
            for c_a in cor_as:
                other_c_a = [a for a in cor_as if a != c_a][0]
                yield create_yield(True, [c_a, other_c_a])
            # Take incorrect:
            for i_a in inc_as:
                for c_a in cor_as:
                    yield create_yield(False, [i_a, c_a])

def disj_ds_generator():
    with torch.no_grad():
        template = env.get_template("question_answers.jinja")
        counter = 0
        for row_id, row in tqdm(list(enumerate(truthfulqa["validation"]))):
            inc_as = row['incorrect_answers']
            cor_as = row['correct_answers']
            def create_yield(is_correct, answers):
                nonlocal counter
                counter += 1
                return {
                    'id': counter, 
                    'question_id':row_id,
                    'row':row,
                    'template_render_fn': template.render,
                    'is_correct': is_correct,
                    'template_render_params':{
                        'question':row["question"],
                        'answers': answers,
                        'is_disjunction':True
                    }
                }
            # Take correct:
            for c_a in cor_as:
                for i_a in inc_as:
                    yield create_yield(True, [c_a, i_a])
            for c_a in cor_as:
                for other_c_a in [a for a in cor_as if a != c_a]:
                    yield create_yield(True, [c_a, other_c_a])
            # Take incorrect:
            for i_a in inc_as:
                other_i_a = [a for a in inc_as if a != i_a][0]
                yield create_yield(False, [i_a, other_i_a])

In [8]:
report = pd.DataFrame()

In [9]:
def calc_accuracy_for_one_statement():
    known_questions = set()
    wrong_questions = set()
    count = 0
    correct_n = 0
    pbar = tqdm(one_statement_ds_generator())
    for sample_gen in pbar:
        input_ = sample_gen['template_render_fn'](**sample_gen['template_render_params'], label="")
        t_output = tokenizer(input_, return_tensors="pt")
        t_output = {k: t_output[k].to(device) for k in t_output}
        outputs = hf_model(**t_output, output_hidden_states=False)
        pred = outputs.logits[0, -1].softmax(dim=-1)
        token = pred.argmax(-1)
        is_correct = token in true_ids if sample_gen['is_correct'] else token in false_ids
        if is_correct:
            correct_n += 1
            if sample_gen['question_id'] not in wrong_questions:
                known_questions.add(sample_gen['question_id'])
        count += 1
        pbar.set_description(
            f"Correct {correct_n}, count {count}, accuracy {correct_n / count:.4}, known {len(known_questions)}"
        )

In [10]:
# FS - one statement

fs_one_indexes, accuracy = calc_accuracy_for_one_statement()
report.loc["One statement", "FS"] = accuracy

100%|██████████| 817/817 [05:07<00:00,  2.65it/s]724: : 5917it [05:07, 19.83it/s]
Correct 2191, count 5918, accuracy 0.3702, known 724: : 5918it [05:07, 19.22it/s]


TypeError: cannot unpack non-iterable NoneType object

In [11]:
def fs_calc_accuracy_for(ds_generator)
    known_questions = set()
    wrong_questions = set()
    count = 0
    correct_n = 0
    pbar = tqdm(ds_generator())
    for sample_gen in pbar:
        input_ = sample_gen['template_render_fn'](**sample_gen['template_render_params'], label="")
        t_output = tokenizer(input_, return_tensors="pt")
        t_output = {k: t_output[k].to(device) for k in t_output}
        outputs = hf_model(**t_output, output_hidden_states=False)
        pred = outputs.logits[0, -1].softmax(dim=-1)
        token = pred.argmax(-1)
        is_correct = token in true_ids if sample_gen['is_correct'] else token in false_ids
        if is_correct:
            correct_n += 1
            if sample_gen['question_id'] not in wrong_questions:
                known_questions.add(sample_gen['question_id'])
        else:
            wrong_questions.add(sample_gen['question_id'])
            if sample_gen['question_id'] in known_questions:
                known_questions.remove(sample_gen['question_id'])
        count += 1
        pbar.set_description(
            f"Correct {correct_n}, count {count}, accuracy {correct_n / count:.4}, known {len(known_questions)}"
        )
    return known_questions, correct_n / count

SyntaxError: expected ':' (<ipython-input-11-49b4ceb43929>, line 2)

In [12]:
# FS - one statement

fs_one_ans_known_questions, accuracy = fs_calc_accuracy_for(one_statement_ds_generator)
report.loc["One statement", "FS"] = accuracy

NameError: name 'fs_calc_accuracy_for' is not defined

In [13]:
def fs_calc_accuracy_for(ds_generator):
    known_questions = set()
    wrong_questions = set()
    count = 0
    correct_n = 0
    pbar = tqdm(ds_generator())
    for sample_gen in pbar:
        input_ = sample_gen['template_render_fn'](**sample_gen['template_render_params'], label="")
        t_output = tokenizer(input_, return_tensors="pt")
        t_output = {k: t_output[k].to(device) for k in t_output}
        outputs = hf_model(**t_output, output_hidden_states=False)
        pred = outputs.logits[0, -1].softmax(dim=-1)
        token = pred.argmax(-1)
        is_correct = token in true_ids if sample_gen['is_correct'] else token in false_ids
        if is_correct:
            correct_n += 1
            if sample_gen['question_id'] not in wrong_questions:
                known_questions.add(sample_gen['question_id'])
        else:
            wrong_questions.add(sample_gen['question_id'])
            if sample_gen['question_id'] in known_questions:
                known_questions.remove(sample_gen['question_id'])
        count += 1
        pbar.set_description(
            f"Correct {correct_n}, count {count}, accuracy {correct_n / count:.4}, known {len(known_questions)}"
        )
    return known_questions, correct_n / count

In [14]:
# FS - one statement

fs_one_ans_known_questions, accuracy = fs_calc_accuracy_for(one_statement_ds_generator)
report.loc["One statement", "FS"] = accuracy

100%|██████████| 817/817 [05:06<00:00,  2.67it/s]27: : 5916it [05:06, 20.04it/s]
Correct 2191, count 5918, accuracy 0.3702, known 27: : 5918it [05:06, 19.31it/s]


In [15]:
# FS - disjunction
fs_disj_known_questions, accuracy = fs_calc_accuracy_for(disj_ds_generator)
report.loc["Disjunction", "FS"] = accuracy

  3%|▎         | 22/817 [00:51<30:57,  2.34s/it] : 928it [00:51, 17.08it/s] 
Correct 549, count 929, accuracy 0.591, known 1: : 929it [00:51, 18.03it/s]


IndexError: list index out of range

In [None]:
# FS - conjunction
fs_conj_known_questions, accuracy = fs_calc_accuracy_for(conj_ds_generator)
report.loc["Conjunction", "FS"] = accuracy

: 

In [16]:
# Datasets generators:

def one_statement_ds_generator():
    with torch.no_grad():
        template = env.get_template("question_answer.jinja")
        counter = 0
        for row_id, row in tqdm(list(enumerate(truthfulqa["validation"]))):
            inc_as = row['incorrect_answers']
            cor_as = row['correct_answers']
            for take_correct in (True, False):
                for ans in (cor_as, inc_as)[take_correct]:
                    counter += 1
                    yield {
                        'id': counter, 
                        'question_id':row_id,
                        'row':row,
                        'template_render_fn': template.render,
                        'is_correct': take_correct,
                        'template_render_params':{
                            'question':row["question"],
                            'answer':ans
                        }
                    }

def conj_ds_generator():
    with torch.no_grad():
        template = env.get_template("question_answers.jinja")
        counter = 0
        for row_id, row in tqdm(list(enumerate(truthfulqa["validation"]))):
            inc_as = row['incorrect_answers']
            cor_as = row['correct_answers']
            def create_yield(is_correct, answers):
                nonlocal counter
                counter += 1
                return {
                    'id': counter, 
                    'question_id':row_id,
                    'row':row,
                    'template_render_fn': template.render,
                    'is_correct': is_correct,
                    'template_render_params':{
                        'question':row["question"],
                        'answers': answers,
                        'is_disjunction':False
                    }
                }
            # Take correct:
            for c_a in cor_as:
                for other_c_a in [a for a in cor_as if a != c_a][0]:
                    yield create_yield(True, [c_a, other_c_a])
            # Take incorrect:
            for i_a in inc_as:
                for c_a in cor_as:
                    yield create_yield(False, [i_a, c_a])

def disj_ds_generator():
    with torch.no_grad():
        template = env.get_template("question_answers.jinja")
        counter = 0
        for row_id, row in tqdm(list(enumerate(truthfulqa["validation"]))):
            inc_as = row['incorrect_answers']
            cor_as = row['correct_answers']
            def create_yield(is_correct, answers):
                nonlocal counter
                counter += 1
                return {
                    'id': counter, 
                    'question_id':row_id,
                    'row':row,
                    'template_render_fn': template.render,
                    'is_correct': is_correct,
                    'template_render_params':{
                        'question':row["question"],
                        'answers': answers,
                        'is_disjunction':True
                    }
                }
            # Take correct:
            for c_a in cor_as:
                for i_a in inc_as:
                    yield create_yield(True, [c_a, i_a])
            for c_a in cor_as:
                for other_c_a in [a for a in cor_as if a != c_a]:
                    yield create_yield(True, [c_a, other_c_a])
            # Take incorrect:
            for i_a in inc_as:
                for other_i_a in [a for a in inc_as if a != i_a][0]:
                    yield create_yield(False, [i_a, other_i_a])

In [17]:
def fs_calc_accuracy_for(ds_generator):
    known_questions = set()
    wrong_questions = set()
    count = 0
    correct_n = 0
    pbar = tqdm(ds_generator())
    for sample_gen in pbar:
        input_ = sample_gen['template_render_fn'](**sample_gen['template_render_params'], label="")
        t_output = tokenizer(input_, return_tensors="pt")
        t_output = {k: t_output[k].to(device) for k in t_output}
        outputs = hf_model(**t_output, output_hidden_states=False)
        pred = outputs.logits[0, -1].softmax(dim=-1)
        token = pred.argmax(-1)
        is_correct = token in true_ids if sample_gen['is_correct'] else token in false_ids
        if is_correct:
            correct_n += 1
            if sample_gen['question_id'] not in wrong_questions:
                known_questions.add(sample_gen['question_id'])
        else:
            wrong_questions.add(sample_gen['question_id'])
            if sample_gen['question_id'] in known_questions:
                known_questions.remove(sample_gen['question_id'])
        count += 1
        pbar.set_description(
            f"Correct {correct_n}, count {count}, accuracy {correct_n / count:.4}, known {len(known_questions)}"
        )
    return known_questions, correct_n / count

In [18]:
# FS - disjunction
fs_disj_known_questions, accuracy = fs_calc_accuracy_for(disj_ds_generator)
report.loc["Disjunction", "FS"] = accuracy

  3%|▎         | 22/817 [05:31<3:19:23, 15.05s/it]: : 6121it [05:31, 17.43it/s]
Correct 4358, count 6122, accuracy 0.7119, known 1: : 6122it [05:31, 18.48it/s]


IndexError: list index out of range

In [None]:
# FS - conjunction
fs_conj_known_questions, accuracy = fs_calc_accuracy_for(conj_ds_generator)
report.loc["Conjunction", "FS"] = accuracy

: 

In [None]:
# Intersections of known questions:
report.loc["Conjunction", "FS on known"] = len(
    set(fs_conj_known_questions).intersection(fs_one_ans_known_questions)
) / len(fs_one_ans_known_questions)
report.loc["Disjunction", "FS on known"] = len(
    set(fs_disj_known_questions).intersection(fs_one_ans_known_questions)
) / len(fs_one_ans_known_questions)

: 

In [19]:
# Datasets generators:

def one_statement_ds_generator():
    with torch.no_grad():
        template = env.get_template("question_answer.jinja")
        counter = 0
        for row_id, row in tqdm(list(enumerate(truthfulqa["validation"]))):
            inc_as = row['incorrect_answers']
            cor_as = row['correct_answers']
            for take_correct in (True, False):
                for ans in (cor_as, inc_as)[take_correct]:
                    counter += 1
                    yield {
                        'id': counter, 
                        'question_id':row_id,
                        'row':row,
                        'template_render_fn': template.render,
                        'is_correct': take_correct,
                        'template_render_params':{
                            'question':row["question"],
                            'answer':ans
                        }
                    }

def conj_ds_generator():
    with torch.no_grad():
        template = env.get_template("question_answers.jinja")
        counter = 0
        for row_id, row in tqdm(list(enumerate(truthfulqa["validation"]))):
            inc_as = row['incorrect_answers']
            cor_as = row['correct_answers']
            def create_yield(is_correct, answers):
                nonlocal counter
                counter += 1
                return {
                    'id': counter, 
                    'question_id':row_id,
                    'row':row,
                    'template_render_fn': template.render,
                    'is_correct': is_correct,
                    'template_render_params':{
                        'question':row["question"],
                        'answers': answers,
                        'is_disjunction':False
                    }
                }
            # Take correct:
            for c_a in cor_as:
                for other_c_a in [a for a in cor_as if a != c_a]:
                    yield create_yield(True, [c_a, other_c_a])
            # Take incorrect:
            for i_a in inc_as:
                for c_a in cor_as:
                    yield create_yield(False, [i_a, c_a])

def disj_ds_generator():
    with torch.no_grad():
        template = env.get_template("question_answers.jinja")
        counter = 0
        for row_id, row in tqdm(list(enumerate(truthfulqa["validation"]))):
            inc_as = row['incorrect_answers']
            cor_as = row['correct_answers']
            def create_yield(is_correct, answers):
                nonlocal counter
                counter += 1
                return {
                    'id': counter, 
                    'question_id':row_id,
                    'row':row,
                    'template_render_fn': template.render,
                    'is_correct': is_correct,
                    'template_render_params':{
                        'question':row["question"],
                        'answers': answers,
                        'is_disjunction':True
                    }
                }
            # Take correct:
            for c_a in cor_as:
                for i_a in inc_as:
                    yield create_yield(True, [c_a, i_a])
            for c_a in cor_as:
                for other_c_a in [a for a in cor_as if a != c_a]:
                    yield create_yield(True, [c_a, other_c_a])
            # Take incorrect:
            for i_a in inc_as:
                for other_i_a in [a for a in inc_as if a != i_a]:
                    yield create_yield(False, [i_a, other_i_a])

In [20]:
# FS - disjunction
fs_disj_known_questions, accuracy = fs_calc_accuracy_for(disj_ds_generator)
report.loc["Disjunction", "FS"] = accuracy

100%|██████████| 817/817 [28:32<00:00,  2.10s/it]n 39: : 30954it [28:32, 17.36it/s]
Correct 17569, count 30954, accuracy 0.5676, known 39: : 30954it [28:32, 18.08it/s]


In [21]:
# FS - conjunction
fs_conj_known_questions, accuracy = fs_calc_accuracy_for(conj_ds_generator)
report.loc["Conjunction", "FS"] = accuracy

Correct 3328, count 4977, accuracy 0.6687, known 21: : 4977it [04:34, 18.14it/s]


KeyboardInterrupt: 

In [None]:
# Intersections of known questions:
report.loc["Conjunction", "FS on known"] = len(
    set(fs_conj_known_questions).intersection(fs_one_ans_known_questions)
) / len(fs_one_ans_known_questions)
report.loc["Disjunction", "FS on known"] = len(
    set(fs_disj_known_questions).intersection(fs_one_ans_known_questions)
) / len(fs_one_ans_known_questions)

: 

In [22]:
fs_one_ans_known_questions

{9,
 24,
 99,
 156,
 173,
 187,
 265,
 362,
 370,
 423,
 425,
 427,
 428,
 434,
 438,
 441,
 475,
 516,
 559,
 561,
 562,
 565,
 663,
 667,
 683,
 740,
 809}

In [23]:
fs_disj_known_questions

{22,
 27,
 28,
 38,
 70,
 137,
 141,
 143,
 165,
 188,
 197,
 198,
 225,
 233,
 248,
 250,
 251,
 260,
 265,
 269,
 385,
 513,
 604,
 624,
 643,
 683,
 686,
 698,
 717,
 729,
 733,
 735,
 737,
 743,
 747,
 754,
 771,
 772,
 809}