In [1]:
import argparse
import random
import torch

import numpy as np
import torch.nn.functional as F

from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from my_datasets import (
  ParaphraseDetectionDataset,
  ParaphraseDetectionTestDataset,
  load_paraphrase_data
)
from evaluation_reft import model_eval_paraphrase, model_test_paraphrase
from models.gpt2 import GPT2Model

from optimizer import AdamW
import transformers

from transformers import GPT2Tokenizer
import os
import pyreft

import pyvene as pv
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import (
    Trainer,
    TrainingArguments,
    DataCollator,
    DataCollatorForSeq2Seq,
    AutoTokenizer
)
from datasets import Dataset
from dataclasses import dataclass
from typing import Dict, Optional, Sequence
from tqdm import tqdm
import os
import torch
import re
import evaluate
import numpy as np
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.utils import logging
from transformers.trainer_utils import (
    EvalPrediction,
    has_length,
    denumpify_detensorize
)
from pyreft import ReftDataCollator

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Single source of truth for the checkpoint and the trained ReFT artifacts.
model_name_or_path = "gpt2-large"
REFT_DIR = "./reft_gpt_large_PARAPHRASE_BIGGER_AND_BETTER"

# NOTE(review): `gpt2` is a second, full-precision copy of the same checkpoint and is
# not used by any cell visible here; drop it if nothing else needs it (it costs a
# whole extra model's worth of memory).
gpt2 = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)

# Load the tokenizer from the same checkpoint as the model (all GPT-2 sizes share one
# BPE vocabulary, so this matches the previously used 'gpt2' tokenizer).
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
# Left padding keeps every row's last prompt token in the final column, which the
# generation-based evaluation and its last-position intervention rely on.
gpt2_tokenizer.padding_side = "left"
# GPT-2 has no dedicated pad token; reuse EOS.
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
EOS_TOKEN = gpt2_tokenizer.eos_token

# bfloat16 base model that the saved ReFT interventions are attached to.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

reft_model = pyreft.ReftModel.load(
    load_directory=REFT_DIR, model=model, from_huggingface_hub=False
)
reft_model.set_device(device)


Intervention key: layer_19_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_24_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_29_comp_block_output_unit_pos_nunit_1#0




Intervention key: layer_35_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_19_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_24_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_29_comp_block_output_unit_pos_nunit_1#0
Intervention key: layer_35_comp_block_output_unit_pos_nunit_1#0




In [17]:
print(f"{device}")  # which device the models were placed on

cuda


In [13]:
print(f"Padding side: {gpt2_tokenizer.padding_side}")  # sanity check: generation eval expects "left"

Padding side: left


In [8]:
from types import SimpleNamespace

# Stand-in for the argparse namespace the project scripts build; it is handed to the
# dataset constructors below.
args = SimpleNamespace(
    para_train="data/quora-train.csv",
    para_dev="data/quora-dev.csv",
    para_test="data/quora-test-student.csv",
    para_dev_out="predictions/para-dev-output.csv",
    para_test_out="predictions/para-test-output.csv",
    seed=11711,
    epochs=10,
    # NOTE(review): not read by this notebook (`device` is auto-detected earlier);
    # presumably consumed by project code — confirm before relying on it.
    use_gpu=False,
    batch_size=32,
    lr=1e-5,
    model_size="gpt2-large"
)

# Seed every RNG in play so the shuffled train loader (and anything stochastic
# downstream) is reproducible under Restart & Run All.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

para_train_raw = load_paraphrase_data(args.para_train)
para_dev_raw = load_paraphrase_data(args.para_dev)

para_train_data = ParaphraseDetectionDataset(para_train_raw, args, tokenizer=gpt2_tokenizer)
para_dev_data = ParaphraseDetectionDataset(para_dev_raw, args, tokenizer=gpt2_tokenizer)

para_train_dataloader = DataLoader(para_train_data, shuffle=True, batch_size=args.batch_size,
                                   collate_fn=para_train_data.collate_fn)
para_dev_dataloader = DataLoader(para_dev_data, shuffle=False, batch_size=args.batch_size,
                                 collate_fn=para_dev_data.collate_fn)

Loaded 283003 train examples from data/quora-train.csv
Loaded 40429 train examples from data/quora-dev.csv


In [None]:
from sklearn.metrics import f1_score, accuracy_score
def model_eval_paraphrase_intervenable(dataloader, model, device, tokenizer, TQDM_DISABLE = False,
                                       max_batches=None):
    """Score an intervenable (ReFT/pyvene) model on paraphrase detection.

    For each batch the model greedily generates a single token with the
    interventions applied at the last prompt position. A prediction is positive
    iff that token equals the first token of ``"yes"``.

    Args:
        dataloader: yields dicts with ``token_ids``, ``attention_mask``,
            ``sent_ids`` and ``labels``. Labels are compared against the
            ``"yes"`` token id, so they are presumably token ids — confirm
            against the dataset's collate_fn.
        model: intervenable model exposing ``generate(base, unit_locations=...,
            intervene_on_prompt=...)`` (returning a pair whose second element
            holds the sequences) and an ``interventions`` mapping.
        device: device the batch tensors are moved to.
        tokenizer: used to look up the ``"yes"`` token id and the eos/pad ids.
        TQDM_DISABLE: disable the progress bar.
        max_batches: optional cap on the number of batches, for quick smoke
            tests. ``None`` evaluates the whole dataloader.

    Returns:
        ``(accuracy, macro_f1, y_pred, y_true, sent_ids)``.
    """
    model.eval()  # Turn off dropout and other randomness.
    y_true, y_pred, sent_ids = [], [], []

    yes_token_id = tokenizer.encode("yes", add_special_tokens=False)[0]
    # pyvene expects "sources->base" locations laid out as [intervention][example][position].
    num_interventions = len(model.interventions)

    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, desc='eval', disable=TQDM_DISABLE)):
            if max_batches is not None and step >= max_batches:
                break
            b_ids = batch['token_ids'].to(device)
            b_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].flatten()

            batch_size, prompt_len = b_ids.shape
            # Under left padding every row's last prompt token sits in the final column.
            unit_locations = [
                [[prompt_len - 1] for _ in range(batch_size)]
                for _ in range(num_interventions)
            ]

            # Only the first generated token is inspected, so greedily generate exactly one;
            # sampling would make the metric nondeterministic.
            _, reft_response = model.generate(
                {"input_ids": b_ids, "attention_mask": b_mask},
                unit_locations={"sources->base": (None, unit_locations)},
                intervene_on_prompt=True,
                max_new_tokens=1,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )

            # generate() echoes the prompt, so column prompt_len is the first new token.
            first_generated = reft_response[:, prompt_len]
            y_pred.extend((first_generated == yes_token_id).long().cpu().tolist())
            y_true.extend((labels.cpu() == yes_token_id).long().tolist())
            sent_ids.extend(batch['sent_ids'])

    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)

    return acc, f1, y_pred, y_true, sent_ids

In [None]:
def make_data_collator(tokenizer, model) -> ReftDataCollator:
    """Return a ``ReftDataCollator`` wrapping a ``DataCollatorForSeq2Seq``.

    The inner collator pads to the longest example in each batch, pads labels
    with -100, and is given ``max_length=2048``.
    """
    seq2seq_settings = dict(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100,
        padding="longest",
        max_length=2048,
    )
    inner_collator = DataCollatorForSeq2Seq(**seq2seq_settings)
    return ReftDataCollator(data_collator=inner_collator)


def make_dataloader(dataset: Dataset, batch_size: int, collate_fn: DataCollatorForSeq2Seq, shuffle: bool) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader that assembles batches with ``collate_fn``.

    Any callable collator works here (callers pass a ``ReftDataCollator``);
    the ``collate_fn`` annotation is nominal.
    """
    loader_options = {"shuffle": shuffle, "batch_size": batch_size, "collate_fn": collate_fn}
    return DataLoader(dataset, **loader_options)

def extract_answer(generation, trigger_tokens=""):
    """
    Extract the predicted answer from a decoded text.

    If ``trigger_tokens`` is non-empty and occurs in ``generation``, only the
    text after its last occurrence is considered. The first whitespace-separated
    word of that text is lower-cased. A word that is not exactly "yes"/"no" is
    mapped by its first letter ("y..." -> "yes", "n..." -> "no"); any other word
    is returned unchanged.

    Returns "" when there is no word to extract (empty or whitespace-only
    text, e.g. a generation consisting only of special tokens), instead of
    raising IndexError.
    """
    if trigger_tokens and trigger_tokens in generation:
        generation = generation.split(trigger_tokens)[-1]
    words = generation.split()
    if not words:
        return ""
    # Take the first word as the answer.
    answer = words[0].lower()
    # Normalize near-misses such as "Yes," or "nope" to the canonical labels.
    if answer not in ("yes", "no"):
        if answer.startswith("y"):
            answer = "yes"
        elif answer.startswith("n"):
            answer = "no"
    return answer

def _paraphrase_unit_locations(batch, input_ids, attention_mask, num_interventions, num_beams, device):
    """Build pyvene "sources->base" unit locations, laid out as [intervention][example][position]."""
    if "intervention_locations" in batch:
        locations = batch["intervention_locations"].to(device)
        if locations.dim() == 3:
            # [example, intervention, position] -> [intervention, example, position]
            locations = locations.permute(1, 0, 2)
        # Shift by each row's left-padding amount (index of its first attended token).
        # Locating it via the BOS token breaks for GPT-2, whose BOS, EOS and pad token
        # are the same id, so every pad position would match.
        left_padding = attention_mask.long().argmax(dim=1).reshape(1, -1, 1)
        locations = locations + left_padding - 1  # -1: offset for sink padding
        if num_beams > 1:
            # Beam search expands each example num_beams times along the example axis.
            locations = locations.repeat_interleave(num_beams, dim=1)
        return locations.tolist()
    # No precomputed locations: intervene on the last prompt position, which under
    # left padding is the final column of every row.
    batch_size, prompt_len = input_ids.shape
    rows = batch_size * num_beams
    return [[[prompt_len - 1] for _ in range(rows)] for _ in range(num_interventions)]


def _answer_or_empty(text, trigger_tokens):
    """Run ``extract_answer``, returning "" when the text contains no word to extract."""
    try:
        return extract_answer(text, trigger_tokens)
    except IndexError:
        return ""


def evaluate_paraphrase_detection(
    model,
    tokenizer,
    eval_dataset,
    batch_size=4,
    generation_args=None,
    data_collator = None,
    device=None,
    trigger_tokens=""
):
    """
    Evaluate an intervenable (ReFT) GPT-2 model on paraphrase detection via generation.

    Each batch is generated with the interventions applied on the prompt. Only the
    newly generated tokens are decoded; the answer is extracted from them and
    compared with the answer extracted from the decoded gold labels.

    Args:
        model: intervenable model exposing ``generate(base=..., unit_locations=...,
            intervene_on_prompt=...)``, a ``.model`` attribute (the wrapped LM, used
            to build the default collator) and an ``interventions`` mapping.
        tokenizer: tokenizer used for eos/pad ids and decoding.
        eval_dataset: dataset whose items ``data_collator`` can batch. For the default
            collator this presumably means dicts with ``input_ids``/``labels`` (and
            optionally ``intervention_locations``) — TODO confirm for the dataset used.
        batch_size: evaluation batch size.
        generation_args: overrides for the default generation settings
            (``max_length``, ``do_sample``, ``num_beams``, ``temperature``,
            ``top_p``, ``top_k``).
        data_collator: collate function; defaults to ``make_data_collator``.
        device: torch device; defaults to CUDA when available, else CPU.
        trigger_tokens: text after which the answer is expected.

    Returns:
        ``(predictions, gold_labels, accuracy)``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    if data_collator is None:
        data_collator = make_data_collator(tokenizer, model.model)
    dataloader = make_dataloader(eval_dataset, batch_size, data_collator, shuffle=False)

    # Default generation settings; generation_args may override any of them.
    generation_settings = {
        "max_length": 512,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "do_sample": False,  # Greedy decoding by default
        "num_beams": 1,      # Default to greedy decoding
    }
    if generation_args:
        generation_settings.update(generation_args)
    num_beams = generation_settings.get("num_beams", 1)
    num_interventions = len(model.interventions)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    predictions, gold_labels = [], []
    correct_count = 0

    for batch in tqdm(dataloader, position=0, leave=True):
        # The default collator yields "input_ids"; the project's collate_fn yields "token_ids".
        ids_key = "input_ids" if "input_ids" in batch else "token_ids"
        input_ids = batch[ids_key].to(device)
        attention_mask = batch["attention_mask"].to(device)
        label_ids = batch["labels"]

        unit_locations = _paraphrase_unit_locations(
            batch, input_ids, attention_mask, num_interventions, num_beams, device)

        gen_kwargs = {
            "base": {"input_ids": input_ids, "attention_mask": attention_mask},
            "unit_locations": {"sources->base": (None, unit_locations)},
            "intervene_on_prompt": True,
            "eos_token_id": tokenizer.eos_token_id,
            "early_stopping": True,
            "max_length": generation_settings.get("max_length", 512),
            "pad_token_id": tokenizer.pad_token_id,
            "do_sample": generation_settings.get("do_sample", False),
        }
        # Forward optional decoding parameters only when provided.
        for key in ("temperature", "top_p", "top_k", "num_beams"):
            if key in generation_settings:
                gen_kwargs[key] = generation_settings[key]

        with torch.no_grad():
            generation_output = model.generate(**gen_kwargs)
        # When generate returns a pair, the second element holds the intervened sequences.
        if isinstance(generation_output, (tuple, list)):
            outputs = generation_output[1]
        else:
            outputs = generation_output

        # generate() echoes the prompt; decode only the newly generated tokens so the
        # extracted answer is not the prompt's first word.
        new_tokens = outputs[:, input_ids.shape[1]:]
        decoded_outputs = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

        # Label padding uses -100, which is not a decodable token id.
        if label_ids.dim() == 1:
            label_ids = label_ids.unsqueeze(-1)
        label_ids = label_ids.masked_fill(label_ids == -100, pad_id)
        decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        for pred_text, label_text in zip(decoded_outputs, decoded_labels):
            pred_answer = _answer_or_empty(pred_text, trigger_tokens)
            label_answer = _answer_or_empty(label_text, trigger_tokens)
            predictions.append(pred_answer)
            gold_labels.append(label_answer)
            if pred_answer == label_answer:
                correct_count += 1

    total_count = len(predictions)
    accuracy = correct_count / total_count if total_count > 0 else 0.0
    print(f"Paraphrase Detection Accuracy: {accuracy:.3f}")
    return predictions, gold_labels, accuracy

In [10]:
# The chat-style answer marker is not a special token for GPT-2; it is split into
# several ordinary BPE pieces (see the printed ids).
expected_tokens = gpt2_tokenizer.encode("<|assistant|>:")
print(f"Expected tokens for '<|assistant|>:' {expected_tokens}")

Expected tokens for '<|assistant|>:' [27, 91, 562, 10167, 91, 31175]


In [None]:
# Generation-based evaluation of the ReFT model on the dev split.
evaluate_paraphrase_detection(
    model=reft_model, tokenizer=gpt2_tokenizer, eval_dataset=para_dev_data, batch_size=4
)

In [12]:
# Use the auto-detected `device` instead of a hard-coded 'cuda' so this also runs on CPU-only machines.
model_eval_paraphrase_intervenable(para_dev_dataloader, reft_model, device, gpt2_tokenizer, TQDM_DISABLE=False)

eval:   0%|          | 0/1264 [00:00<?, ?it/s]


KeyError: 0