In [None]:
!pip install transformers==4.30.0 torch accelerate bitsandbytes
!pip install ftfy names scipy scikit-learn


In [None]:
pip install -U bitsandbytes

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# --- 1. Load Teacher (GPT-J) Correctly in 8-bit ---
# Use the original base model, not the pre-quantized hivemind version
teacher_name = "EleutherAI/gpt-j-6B"

print(f"Loading Teacher: {teacher_name} with on-the-fly 8-bit quantization...")

tokenizer = AutoTokenizer.from_pretrained(teacher_name)

# Define the quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,              # Switch to 4-bit
    bnb_4bit_quant_type="nf4",      # Normalized Float 4 (better for accuracy)
    bnb_4bit_compute_dtype=torch.float16, # Compute in FP16 for speed
    bnb_4bit_use_double_quant=True, # Saves extra memory
)

teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_name,
    quantization_config=bnb_config, # Pass the config here
    device_map="auto",
)

print("Model loaded successfully!")


Loading Teacher: EleutherAI/gpt-j-6B with on-the-fly 8-bit quantization...


pytorch_model.bin:   0%|          | 0.00/24.2G [00:00<?, ?B/s]

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

Model loaded successfully!


In [3]:
'''

Load exampe events for few-shot generation, and define
function to convert these into few-shot prompts

'''


with open('/kaggle/input/seed-events/seed_events.txt') as f:
    seed_events = [line.strip() for line in f.readlines()]


## converts an example into text for few-shot
def ex2text(ex, include_gen = True, number = None):
    text = ''
    if number is not None:
        text += '{}. Event:'.format(number)
        
    text += ''.format()

    if include_gen:
        text += ' {}\n\n\n'.format(ex)
    return text


## create a few-shot prompt with the final 
## example left open for generation
##
### note: ex2text should be a function that takes 
## an example from examples and produces a string 
## template. This should accept an argument, include_gen 
## which is a bool to decide whether to leave it open for
## generation or include the gt
def few_shot_prompt(examples, ex2text, number = None, Person_list = None):
    template_str = ''
    
    i = -1
    for i, example in enumerate(examples[:-1]):
        
        if number:
            ex_str = ex2text(example, include_gen = True, number = i + 1)
        else:
            ex_str = ex2text(example, include_gen = True)
            
            
        if Person_list is not None:
            ex_str = name_PX_PY(ex_str, Person_list[i][0],Person_list[i][1] )
            
        template_str += ex_str
        
    i = i + 1
    
    if number:
        ex_str = ex2text(examples[-1], include_gen = False,number = i + 1)
    else:
        ex_str = ex2text(examples[-1], include_gen = False)
        
    if Person_list is not None:
        ex_str = name_PX_PY(ex_str, Person_list[i][0],Person_list[i][1] )
    
    template_str += ex_str
    
    return template_str


def scrub_PX_PY(s, PersonX, PersonY):
    return s.replace(PersonX, 'PersonX').replace(PersonY, 'PersonY')

def name_PX_PY(s, PersonX, PersonY):
    return s.replace('PersonX', PersonX).replace('PersonY', PersonY)

In [4]:

import torch.nn.functional as F

# --- Manual Implementation of compute_transition_scores ---
def compute_transition_scores(sequences, scores, normalize_logits=True):
    """
    Fallback implementation for older transformers versions.
    Calculates the log-probabilities of the selected tokens.
    """
    # 1. Stack the scores tuple into a single tensor
    # scores comes as a tuple of tensors (one per step)
    # Shape becomes: (batch_size, generated_len, vocab_size)
    scores_stack = torch.stack(scores, dim=1)
    
    # 2. Normalize logits to log-probs if requested (NLL requires this)
    if normalize_logits:
        scores_stack = F.log_softmax(scores_stack, dim=-1)
        
    # 3. Align sequences with scores
    # sequences contains [prompt + generated]. We only need [generated].
    gen_len = len(scores)
    
    # Take the last 'gen_len' tokens from the sequence
    generated_ids = sequences[:, -gen_len:]
    
    # 4. Gather the scores for the tokens that were actually selected
    # We use gather to pluck the specific score for each token ID
    generated_ids = generated_ids.unsqueeze(-1) # Shape: (batch, gen_len, 1)
    token_scores = torch.gather(scores_stack, 2, generated_ids) # Shape: (batch, gen_len, 1)
    
    return token_scores.squeeze(-1) # Shape: (batch, gen_len)

compute_transition_scores 

print("Manual compute_transition_scores defined.")

Manual compute_transition_scores defined.


In [5]:

# 1. Define the replacement function with NLL calculation
def complete_gpt3(prompt, l, model_name, num_log_probs=None, n=1, top_p=0.9, **kwargs):
    """
    Replaces OpenAI API call with local teacher_model generation.
    Calculates REAL log-probabilities so NLL filtering works correctly.
    """
    # Encode the few-shot prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(teacher_model.device)
    input_length = inputs.input_ids.shape[1]
    
    with torch.no_grad():
        # Generate 'n' sequences with scores enabled
        outputs = teacher_model.generate(
            **inputs,
            max_new_tokens=l,
            do_sample=True,
            top_p=top_p,
            num_return_sequences=n,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True, # REQUIRED for scores
            output_scores=True            # REQUIRED for scores
        )
    
    # Calculate the log-probabilities for the generated tokens
    # transition_scores[i] = log P(token_i | context)
    # These are negative numbers (e.g. -0.5). Summing them gives the NLL.
    transition_scores = compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )

    choices = []
    
    # Iterate over every generated sequence in the batch
    for idx, sequence in enumerate(outputs.sequences):
        # 1. Decode the generated text (excluding prompt)
        generated_tokens = sequence[input_length:]
        full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        
        # Stop at newline (the notebook expects 1 event per line)
        if '\n' in full_text:
            # If there's a newline, we technically only want the text before it.
            # For simplicity in NLL calculation, we will just use the full scores 
            # returned for the tokens generated so far.
            generated_text = full_text.split('\n')[0]
        else:
            generated_text = full_text

        # 2. Get the scores for these specific tokens
        # transition_scores is shape (batch_size, generated_seq_len)
        # We grab the row for this specific sequence
        # Move to CPU and convert to list
        current_scores = transition_scores[idx].cpu().numpy().tolist()
        
        # 3. Construct the Mock OpenAI Response
        # We construct 'text_offset' as a simple range so the original code's "max index" logic works.
        # The original code uses: end_ind = choice['logprobs']['text_offset'].index(max(...))
        dummy_offsets = list(range(len(current_scores)))
        
        choice_obj = {
            "text": generated_text, 
            "finish_reason": "stop",
            "logprobs": {
                # This dummy offset list ensures the NLL summation loop runs for the full length
                "text_offset": dummy_offsets, 
                # CRITICAL: These are now the REAL log probabilities
                "token_logprobs": current_scores 
            }
        }
        choices.append(choice_obj)

    return {"choices": choices}

print("Advanced Monkey patch applied! 'complete_gpt3' now calculates real NLL.")

Advanced Monkey patch applied! 'complete_gpt3' now calculates real NLL.


In [6]:
engine = 'davinci' # gpt-3 model version to use

length= 12 # maximum token length of generations. 12 tokens is up to 12 words, but likely less
top_p = 0.9 # top_p to use for nucleus sampling
n_gen = 100 # number of generations per-batch
presence_penalty = 0.5 # discourage tokens in the prompt
frequency_penalty = 0.5 # discourage tokens in the prompt based on frequency
PersonX = 'PersonX' # name to use for PersonX. Default is just "PersonX"
PersonY = 'PersonY' # name to use for PersonY. Default is just "PersonY"



n_examples_in = 10# how many few-shot examples to use per generation batch
n_batches = 20 # how many generation batches to run



batches = []

anon_heads = []

In [7]:
import json
import time
import random
import torch  # Required for GPU memory management
import gc     # Required for garbage collection
import os
# Define output file path
output_filename = '/kaggle/working/generated_events.jsonl'

if os.path.exists(output_filename):
    os.remove(output_filename)
# Use a set instead of a list for anon_heads to save memory and make lookups faster
anon_heads = set() 

for i in range(n_batches):
    # sleep to prevent timeout from api
    time.sleep(0.05)
    
    # randomize which events are used for generation
    random.shuffle(seed_events)
    examples_in = seed_events[:n_examples_in]

    # use names defined above
    names = (PersonX, PersonY)
    
    # define prompt based on examples and names
    prompt = few_shot_prompt(examples_in + [''], ex2text, number=True)
    prompt = name_PX_PY(prompt, names[0], names[1])

    ## generate using the prompt
    print('='*20 + f' Batch {i+1}/{n_batches} ' + '='*20)
    # print(prompt) # Optional: Comment out to reduce clutter in logs
    
    result = complete_gpt3(prompt, length, engine, top_p=top_p, num_log_probs=1, n=n_gen, stop='\n\n', echo=False,
                           frequency_penalty=frequency_penalty, presence_penalty=presence_penalty)

    ### Process the output
    outputs = []
    for choice in result['choices']:
        try:
            out = choice['text']
        except:
            out = ''
            
        # Handle logprobs logic (assuming your monkey patch provides this structure)
        try:
            end_ind = choice['logprobs']['text_offset'].index(max(choice['logprobs']['text_offset']))
            nll = sum(choice['logprobs']['token_logprobs'][:end_ind + 1])
        except:
            nll = 0.0 # Fallback if monkey patch doesn't provide perfect logprobs

        text = choice['text']
        anon_text = scrub_PX_PY(out, names[0], names[1]).strip()

        ## Validations
        if choice['finish_reason'] != 'stop' or not (anon_text.startswith('PersonX')):
            continue

        outputs.append({
            'text': text,
            'anon_text': anon_text,
            'result': choice,
            'nll': nll,
            'prompt': prompt
        })
        
        # Add to set for tracking uniqueness
        anon_heads.add(anon_text)

    # --- CRITICAL CHANGE 1: Write immediately to file ---
    # We open in 'a' (append) mode so we add to the file batch by batch
    batch_data = {'prompt': prompt, 'events': outputs}
    with open(output_filename, 'a') as f:
        f.write(json.dumps(batch_data) + '\n')

    print(f'Batch saved. Total unique events so far: {len(anon_heads)}')

    # --- CRITICAL CHANGE 2: Clear Memory ---
    # Delete large variables to free up RAM
    del result
    del outputs
    del batch_data
    
    # Force Python's garbage collector to run
    gc.collect()
    
    # Clear PyTorch's cache to free up GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

Batch saved. Total unique events so far: 80
Batch saved. Total unique events so far: 169
Batch saved. Total unique events so far: 253
Batch saved. Total unique events so far: 346
Batch saved. Total unique events so far: 434
Batch saved. Total unique events so far: 518
Batch saved. Total unique events so far: 588
Batch saved. Total unique events so far: 677
Batch saved. Total unique events so far: 754
Batch saved. Total unique events so far: 830
Batch saved. Total unique events so far: 911
Batch saved. Total unique events so far: 976
Batch saved. Total unique events so far: 1051
Batch saved. Total unique events so far: 1128
Batch saved. Total unique events so far: 1195


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.57 GiB. GPU 0 has a total capacity of 15.89 GiB of which 2.05 GiB is free. Process 4743 has 13.84 GiB memory in use. Of the allocated memory 10.75 GiB is allocated by PyTorch, and 2.80 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [8]:
import json

generated_batches = []
# Use 'generated_events.jsonl' directly as saved in the previous step
with open('/kaggle/working/generated_events.jsonl', 'r') as f:
    for line in f: # Iterate line by line to save RAM
        generated_batches.append(json.loads(line))

# Extract events
all_events_list = []
for batch in generated_batches:
    for d in batch['events']:
        all_events_list.append((d['anon_text'], d['nll']))

# Deduplicate (Keep the last occurrence)
all_events_dict = {v[0]: v[1] for v in all_events_list}
# Convert back to list of tuples
all_events = list(all_events_dict.items())

# --- SAFETY CHECK FOR LOCAL MODEL ---
# Check if NLLs are actually useful (not all 0.0)
nll_values = [v[1] for v in all_events]
if sum(nll_values) == 0:
    print("Warning: All NLL values are 0.0 (likely from Monkey Patch). Skipping truncation.")
    # Do NOT truncate. Keep everything.
else:
    # Sort Ascending (Lowest NLL is best)
    all_events.sort(key=lambda v: v[1]) 
    
    # Keep top 80% (The ones with lowest NLL)
    cutoff = int(0.8 * len(all_events))
    all_events = all_events[:cutoff]
    print(f"Truncated bottom 20%. Remaining events: {len(all_events)}")

# Prepare for next step
# The next step likely expects a list of strings
final_event_list = [v[0] for v in all_events]

Truncated bottom 20%. Remaining events: 956


In [9]:
''' 

Finally, remove any events with strange formatting or degenerate properties

''' 
import os
todo_events = []

output_filename = '/kaggle/working/todo_events.txt'

if os.path.exists(output_filename):
    os.remove(output_filename)
for event,_ in all_events:

    if any(not (c.isalnum() or c in '\',".-â€™$ ') for c in event) or (len(event) < len(PersonX)):
        continue
    todo_events += [event]
    

# write these to a file for use in inference generation
with open('/kaggle/working/todo_events.txt','w') as f:
    for event in todo_events:
        f.write('{}\n'.format(event))