In [None]:
!pip install transformers==4.30.0 torch accelerate bitsandbytes
!pip install ftfy names scipy scikit-learn


In [None]:
pip install -U bitsandbytes

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# --- 1. Load Teacher (GPT-J) ---
# This cell creates the two module-level objects every later cell relies on:
# `tokenizer` and `teacher_model`.
teacher_name = "EleutherAI/gpt-j-6B"

print(f"Loading Teacher: {teacher_name}...")

tokenizer = AutoTokenizer.from_pretrained(teacher_name)

# CRITICAL FIX FOR BATCHING:
# GPT-J doesn't have a pad token by default. We must set it.
# The EOS token is reused as the pad token, so pad and EOS share one id.
tokenizer.pad_token = tokenizer.eos_token
# For generation, we MUST pad on the LEFT so the model generates continuations correctly
# (every prompt in a batch then ends at the same column, right before the new tokens).
tokenizer.padding_side = "left" 

# Quantization settings: 4-bit weights using the "nf4" quant type with double
# quantization enabled, and float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# device_map="auto" leaves device placement to the library; later cells move
# inputs to `teacher_model.device` before calling generate().
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_name,
    quantization_config=bnb_config,
    device_map="auto",
)

print("Model loaded successfully with Padding Configured!")


Loading Teacher: EleutherAI/gpt-j-6B...


pytorch_model.bin:   0%|          | 0.00/24.2G [00:00<?, ?B/s]

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

Model loaded successfully with Padding Configured!


In [3]:

import torch.nn.functional as F

# --- Manual Implementation of compute_transition_scores ---
def compute_transition_scores(sequences, scores, normalize_logits=True):
    """
    Manual stand-in for the transformers helper of the same name.

    For every generation step, look up the score of the token that was
    actually emitted at that step.

    Args:
        sequences: 2-D tensor of token ids (prompt followed by generated
            tokens). Only the last ``len(scores)`` columns are used.
        scores: Tuple with one score tensor per generation step; the tensors
            are stacked along a new dimension 1.
        normalize_logits: When True, a log-softmax over the last dimension is
            applied before the lookup, so the result holds log-probabilities.

    Returns:
        Tensor with one row per sequence and one column per generation step.
    """
    num_steps = len(scores)

    # (batch, steps, vocab) once the per-step tensors are stacked.
    per_step = torch.stack(scores, dim=1)
    if normalize_logits:
        per_step = F.log_softmax(per_step, dim=-1)

    # The generated ids are the trailing `num_steps` columns of `sequences`;
    # the extra trailing axis makes them usable as a gather index.
    chosen_ids = sequences[:, -num_steps:].unsqueeze(-1)

    return per_step.gather(2, chosen_ids).squeeze(-1)

# Confirmation message only; the former bare `compute_transition_scores`
# expression was a no-op (it is not the last expression of the cell).
print("Manual compute_transition_scores defined.")

Manual compute_transition_scores defined.


In [5]:
# 1. Define the replacement function with NLL calculation
def complete_gpt3(prompt, l, model_name, num_log_probs=None, n=1, top_p=0.9, **kwargs):
    """
    Replaces the OpenAI completion call with sampling from the local ``teacher_model``.

    Handles batch inputs (a list of prompt strings) as well as a single string.

    Args:
        prompt: Prompt string or list of prompt strings.
        l: Maximum number of new tokens to generate per sequence.
        model_name: Accepted for call compatibility; not used.
        num_log_probs: Accepted for call compatibility; not used.
        n: Number of sampled continuations per prompt.
        top_p: Nucleus-sampling threshold forwarded to ``generate``.
        **kwargs: Further OpenAI-style arguments (e.g. ``stop``, ``echo``).
            They are ignored: the returned text is always cut at the first
            newline, whatever ``stop`` says.

    Returns:
        ``{"choices": [...]}`` with one entry per generated sequence. Each
        entry has ``text`` (generation up to the first newline),
        ``finish_reason`` (always ``"stop"``) and ``logprobs`` holding
        ``token_logprobs`` (one value per generated token) and ``text_offset``
        (character offset at which each token starts, clamped to the length
        of ``text``).

        Because every token at or beyond the end of ``text`` shares the same
        (maximal) offset, the OpenAI-style idiom
        ``offsets.index(max(offsets))`` locates the first token past the kept
        text. Callers can thus sum only the log-probs that belong to the
        returned text plus its terminating token, instead of also counting
        whatever was generated after the cut or padded after EOS.
    """
    # Padding is required so that prompts of different lengths can be batched.
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(teacher_model.device)
    input_length = inputs.input_ids.shape[1]

    with torch.no_grad():
        outputs = teacher_model.generate(
            **inputs,
            max_new_tokens=l,
            do_sample=True,
            top_p=top_p,
            num_return_sequences=n,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_scores=True
        )

    transition_scores = compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )

    choices = []

    # One choice per generated sequence (batch_size * n in total); the caller
    # is responsible for grouping them back per prompt.
    for idx, sequence in enumerate(outputs.sequences):
        generated_tokens = sequence[input_length:].tolist()
        full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Manual stop handling: keep only the text before the first newline.
        generated_text = full_text.split('\n')[0]
        text_end = len(generated_text)

        current_scores = transition_scores[idx].cpu().numpy().tolist()

        # Start offset of token k = length of the decoded prefix of k tokens,
        # clamped so that tokens past the cut all report `text_end`.
        text_offset = [
            min(len(tokenizer.decode(generated_tokens[:k], skip_special_tokens=True)), text_end)
            for k in range(len(current_scores))
        ]

        choice_obj = {
            "text": generated_text,
            "finish_reason": "stop",
            "logprobs": {
                "text_offset": text_offset,
                "token_logprobs": current_scores
            }
        }
        choices.append(choice_obj)

    return {"choices": choices}

print("Advanced Monkey patch applied with BATCH SUPPORT.")

Advanced Monkey patch applied with BATCH SUPPORT.


In [7]:
def get_key(source, target):
    """Return a JSON string key that identifies a (source, target) pair."""
    pair = {'source': source, 'target': target}
    return json.dumps(pair)
def few_shot_prompt(examples, ex2text, number = None, Person_list = None):
    """
    Build a few-shot prompt from a list of examples.

    All examples except the last are rendered in full (``include_gen=True``);
    the last one is rendered without its generation so that the model has to
    complete it.

    Args:
        examples: Sequence of example dicts; the last entry is the query.
        ex2text: Callable ``ex2text(example, include_gen=..., [number=...])``
            returning the text for one example.
        number: If truthy, the 1-based position of each example is passed to
            ``ex2text`` as ``number``.
        Person_list: Optional sequence of ``(PersonX, PersonY)`` name pairs,
            indexed by example position, substituted into the rendered text.

    Returns:
        The concatenated prompt text.
    """
    def render(example, position, include_gen):
        # Render one example, optionally numbered and with names filled in.
        if number:
            rendered = ex2text(example, include_gen = include_gen, number = position + 1)
        else:
            rendered = ex2text(example, include_gen = include_gen)

        if Person_list is not None:
            name_x, name_y = Person_list[position][0], Person_list[position][1]
            rendered = name_PX_PY(rendered, name_x, name_y)
        return rendered

    shots = examples[:-1]
    parts = [render(example, position, True) for position, example in enumerate(shots)]

    # The query always comes last and takes the next position number.
    parts.append(render(examples[-1], len(shots), False))

    return ''.join(parts)
def scrub_PX_PY(s, PersonX, PersonY):
    """
    Replace the concrete names with the 'PersonX' / 'PersonY' placeholders.

    Only whole-word occurrences are replaced, so a name that happens to be a
    prefix of another word is left alone (with PersonY = 'Chris', the word
    "Christmas" stays intact). Possessives such as "Chris's" are still
    rewritten because the apostrophe is not a word character.

    Args:
        s: Text to scrub.
        PersonX: Name that stands for PersonX in ``s``.
        PersonY: Name that stands for PersonY in ``s``.

    Returns:
        ``s`` with PersonX replaced first, then PersonY.
    """
    import re  # local import keeps this cell self-contained

    for name, placeholder in ((PersonX, 'PersonX'), (PersonY, 'PersonY')):
        # Lookarounds instead of \b so names with non-word edges still match.
        pattern = r'(?<!\w)' + re.escape(name) + r'(?!\w)'
        s = re.sub(pattern, placeholder, s)
    return s

def name_PX_PY(s, PersonX, PersonY):
    """Fill the 'PersonX' and 'PersonY' placeholders in ``s`` with the given names."""
    with_x = s.replace('PersonX', PersonX)
    return with_x.replace('PersonY', PersonY)
def gpt3(prompt, max_len, model_name, temp=0, num_log_probs=100, echo=False, n=1, **kwargs):
    """
    Replaces the original gpt3() function to use the local teacher_model.

    Args:
        prompt: Prompt text. The tokenizer is called without padding here.
        max_len: Maximum number of new tokens to generate.
        model_name: Accepted for call compatibility; not used.
        temp: Sampling temperature. ``0`` selects greedy decoding.
        num_log_probs: Accepted for call compatibility; not used.
        echo: Accepted for call compatibility; not used.
        n: Number of sequences to return (a falsy value is treated as 1).
        **kwargs: Ignored.

    Returns:
        ``{"choices": [...]}`` mimicking the OpenAI response. Each choice has
        ``text`` (generation up to the first newline), ``finish_reason``
        (always ``"stop"``) and ``logprobs`` with ``token_logprobs`` (one
        value per generated token) and ``text_offset`` (character offset at
        which each token starts, clamped to ``len(text)``). With these
        offsets, ``offsets.index(max(offsets))`` points at the first token
        past the returned text, so callers can sum only the log-probs that
        belong to it.
    """
    print('calling local model (patch)')

    # 1. Encode Input
    inputs = tokenizer(prompt, return_tensors="pt").to(teacher_model.device)
    input_length = inputs.input_ids.shape[1]

    # 2. Setup Generation Arguments
    # The original gpt3 function uses temp=0 by default (Greedy Search)
    do_sample = True
    if temp == 0:
        do_sample = False
        temp = 1.0 # Temperature must be > 0 even if unused, to prevent errors in some configs

    # 3. Generate
    with torch.no_grad():
        outputs = teacher_model.generate(
            **inputs,
            max_new_tokens=max_len,
            do_sample=do_sample,
            temperature=temp,
            num_return_sequences=n if n else 1,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_scores=True # Crucial for NLL calculation
        )

    # 4. Calculate Scores (Logprobs)
    # transition_scores contains the log-probability of each generated token
    transition_scores = compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )

    # 5. Format Response to mimic OpenAI API
    choices = []

    for idx, sequence in enumerate(outputs.sequences):
        # A. Decode text
        generated_tokens = sequence[input_length:].tolist()
        full_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # B. Handle Stop Token ('\n') manually: keep the text before the first newline
        text_output = full_text.split('\n')[0]
        text_end = len(text_output)

        # C. Get Logprobs (tensor -> list of floats), kept for the whole sequence
        current_scores = transition_scores[idx].cpu().numpy().tolist()

        # D. Character offset at which each token starts, clamped at the end of
        #    the kept text so the downstream `.index(max(...))` logic finds the
        #    first token past it.
        text_offset = [
            min(len(tokenizer.decode(generated_tokens[:k], skip_special_tokens=True)), text_end)
            for k in range(len(current_scores))
        ]

        # E. Construct Response Object
        choice_obj = {
            "text": text_output,
            "finish_reason": "stop",
            "logprobs": {
                "text_offset": text_offset,
                "token_logprobs": current_scores
            }
        }
        choices.append(choice_obj)

    return {"choices": choices}

In [9]:
'''

Template Definition

'''




'''

ADD xWant

'''

'''

ADD xWant

'''

##### DEFINE THIS ######
relation_to_test = 'xWant'
########################

examples_string = """PersonX is at a party	xWant	to drink beer and dance
PersonX bleeds a lot	xWant	to see a doctor
PersonX works as a cashier	xWant	to be a store manager
PersonX gets dirty	xWant	to clean up
PersonX stays up all night studying	xWant	to sleep all day
PersonX gets PersonY's autograph	xWant	to have a relationship with PersonY
PersonX ends a friendship	xWant	to meet new people
PersonX makes his own costume	xWant	to go to a costume party
PersonX calls PersonY	xWant	to have a long chat
PersonX tells PersonY a secret	xWant	to get PersonY's advice
PersonX mows the lawn	xWant	to get a new lawnmower
PersonX makes a huge mistake	xWant	to apologize
PersonX sees PersonY's point	xWant	to agree with PersonY
PersonX wants a tattoo	xWant	to find a tattoo design
PersonX tells PersonY something	xWant	to get PersonY's opinion
PersonX leaves PersonY's bike	xWant	to keep the bike safe
PersonX visits some friends	xWant	to catch up with friends
PersonX sits next to PersonY	xWant	to become better friends with PersonY
PersonX gives PersonY indication	xWant	to date PersonY
PersonX buys some popcorn	xWant	eat popcorn from bag in car"""
    
lines = [v.split('\t') for v in  examples_string.split('\n')]

examples = [{'head':v[0], 'relation':v[1], 'tail':v[2]} for v in lines]

assert(examples[0]['relation'] == relation_to_test)



def ex2text(ex, include_gen = True, number=None):
    """
    Render one xWant example as few-shot prompt text.

    Args:
        ex: Dict with a 'head' (the event) and, when ``include_gen`` is True,
            a 'tail' (the inference).
        include_gen: Append the tail; pass False for the example to complete.
        number: Optional situation number placed in front of the event.
    """
    pieces = []

    if number is not None:
        pieces.append('Situation {}: '.format(number))

    pieces.append('{}.\n\nWants: As a result, PersonX wants'.format(ex['head']))

    if include_gen:
        pieces.append(' {}.\n\n'.format(ex['tail']))

    return ''.join(pieces)



xWant_dict = {'relation':'xWant', 
              'examples':examples[:7],
             'function':ex2text,
             'prefix':'How does this situation affect each character\'s wants?\n\n\n',
             'length':8}

'''

ADD xAttr

'''

##### DEFINE THIS ######
relation_to_test = 'xAttr'
########################

xAttr_examples_string = """PersonX bullies PersonY	xAttr	dominant
PersonX moves to another city	xAttr	adventurous
PersonX changes PersonY's mind	xAttr	persuasive
PersonX writes a story	xAttr	creative
PersonX covers PersonY's expenses	xAttr	generous
PersonX takes time off	xAttr	lazy
PersonX advises PersonY	xAttr	wise
PersonX bursts into tears	xAttr	sensitive
PersonX deals with problems	xAttr	determined
PersonX follows PersonY	xAttr	creepy"""
    
lines = [v.split('\t') for v in  xAttr_examples_string.split('\n')]

xAttr_examples = [{'head':v[0], 'relation':'xAttr', 'tail':v[2]} for v in lines]



def ex2text(ex, include_gen = True, number=None):
    """
    Render one xAttr example as few-shot prompt text.

    Args:
        ex: Dict with a 'head' (the event) and, when ``include_gen`` is True,
            a 'tail' (the attribute).
        include_gen: Append the tail; pass False for the example to complete.
        number: Optional list number placed in front of the sentence.
    """
    label = '' if number is None else '{}. '.format(number)
    stem = 'When {}, people think PersonX is'.format(ex['head'])
    answer = ' {}.\n\n'.format(ex['tail']) if include_gen else ''
    return label + stem + answer




xAttr_dict = {'relation':'xAttr', 
              'examples':xAttr_examples[:10],
             'function':ex2text,
             'prefix':'Next, we will discuss how people are seen in different situations. Examples:\n\n\n',
             'length':5}



'''

ADD xIntent

'''

##### DEFINE THIS ######
relation_to_test = 'xIntent'
########################



examples_string = """PersonX gets the newspaper	xIntent	to read the newspaper
PersonX works all night	xIntent	to meet a deadline
PersonX destroys PersonY	xIntent	to punish PersonY
PersonX clears her mind	xIntent	to be ready for a new task
PersonX wants to start a business	xIntent	to be self sufficient
PersonX ensures PersonY's safety	xIntent	to be helpful
PersonX buys lottery tickets	xIntent	to become rich
PersonX chases the ball	xIntent	to return it to some kids
PersonX calls customer service	xIntent	to solve a problem
PersonX hops around	xIntent	to exercise"""
    
    
lines = [v.split('\t') for v in  examples_string.split('\n')]

examples = [{'head':v[0], 'relation':v[1], 'tail':v[2]} for v in lines]


def ex2text(ex, include_gen = True, number=None):
    """
    Render one xIntent example as few-shot prompt text.

    The event is stated once and then quoted again in the intention line.

    Args:
        ex: Dict with a 'head' (the event) and, when ``include_gen`` is True,
            a 'tail' (the intention).
        include_gen: Append the tail; pass False for the example to complete.
        number: Optional list number placed in front of the example.
    """
    head = ex['head']

    prefix = '{}. '.format(number) if number is not None else ''
    body = '\n\nEvent: {}.\n\nIntention: "{}" because PersonX wants'.format(head, head)

    if not include_gen:
        return prefix + body
    return prefix + body + ' {}.\n\n'.format(ex['tail'])



xIntent_dict = {'relation':relation_to_test, 
              'examples':examples[:5],
             'function':ex2text,
               'prefix':'For each event, what do the characters want?\n\n\n'}


'''

ADD xEffect

'''

'''

ADD xEffect

'''

##### DEFINE THIS ######
relation_to_test = 'xEffect'
########################



examples_string = """PersonX recently divorced	xEffect	is single
PersonX lifts weights	xEffect	gains muscle mass
PersonX takes PersonY to a bar	xEffect	gets drunk
PersonX decides to hire a tutor	xEffect	learns to read
PersonX buys PersonY a gift	xEffect	is thanked by PersonY
PersonX hears bad news	xEffect	faints
PersonX buys something	xEffect	gets change
PersonX does a lot of work	xEffect	gets better at their job
PersonX attends the concert	xEffect	enjoys the music
PersonX performs PersonX's duties	xEffect	receives the expected salary
PersonX gets calls	xEffect	has a conversation
PersonX continues PersonY's search	xEffect	gets lost
PersonX sees logs	xEffect	calms down"""
    
lines = [v.split('\t') for v in  examples_string.split('\n')]

examples = [{'head':v[0], 'relation':v[1], 'tail':v[2]} for v in lines]

assert(examples[0]['relation'] == relation_to_test)



def ex2text(ex, include_gen = True, number=None):
    """
    Render one xEffect example as few-shot prompt text.

    This is the single effective definition: an earlier same-named draft in
    this cell was immediately shadowed and has been removed.

    Args:
        ex: Dict with a 'head' (the event) and, when ``include_gen`` is True,
            a 'tail' (the outcome).
        include_gen: Append the tail; pass False for the example to complete.
        number: Optional example number, emitted on its own line.
    """
    text = ''

    if number is not None:
        text += '{}.\n\n'.format(number)

    text += 'event: {}\n\noutcome: PersonX'.format(ex['head'])

    if include_gen:
        text += ' {}\n\n'.format(ex['tail'])
    return text



xEffect_dict = {'relation':relation_to_test, 
              'examples':examples[:10],
             'function':ex2text,
              'prefix':'What is the result?\n\n\n'}








'''

ADD xReact


'''

##### DEFINE THIS ######
relation_to_test = 'xReact'
########################



examples_string = """PersonX lives with PersonY's family	xReact	loved
PersonX expects to win	xReact	excited
PersonX comes home late	xReact	tired
PersonX sees dolphins	xReact	joyful
PersonX causes PersonY anxiety	xReact	guilty
PersonX goes broke	xReact	embarrassed
PersonX has a drink	xReact	refreshed
PersonX has a heart condition	xReact	scared about their health
PersonX shaves PersonY's hair	xReact	helpful
PersonX loses all of PersonY's money	xReact	horrible
PersonX finishes dinner	xReact	relaxed and content
PersonX decides to go see a movie	xReact	immersed in a different reality
PersonX comes home early	xReact	happy to be not at work
PersonX finds peace	xReact	in a position of grace
PersonX says anything else	xReact	glad that people listened to him"""
    
lines = [v.split('\t') for v in  examples_string.split('\n')]

examples = [{'head':v[0], 'relation':v[1], 'tail':v[2]} for v in lines]

assert(examples[0]['relation'] == relation_to_test)

##### DEFINE THIS ######
relation_to_test = 'xReact'
########################



examples_string = """PersonX lives with PersonY's family	xReact	loved
PersonX expects to win	xReact	excited
PersonX comes home late	xReact	tired
PersonX has a drink	xReact	refreshed
PersonX sees dolphins	xReact	joyful
PersonX causes PersonY anxiety	xReact	guilty
PersonX goes broke	xReact	embarrassed
PersonX has a drink	xReact	refreshed
PersonX has a heart condition	xReact	scared about their health
PersonX shaves PersonY's hair	xReact	helpful
PersonX loses all PersonY's money	xReact	bad they lost someone's belongings
PersonX finishes dinner	xReact	relaxed and content
PersonX decides to go see a movie	xReact	immersed in a different reality
PersonX comes home early	xReact	happy to be not at work
PersonX finds peace	xReact	in a position of grace
PersonX says anything else	xReact	glad that people listened to him"""
    
lines = [v.split('\t') for v in  examples_string.split('\n')]

examples = [{'head':v[0], 'relation':v[1], 'tail':v[2]} for v in lines]

assert(examples[0]['relation'] == relation_to_test)

def ex2text(ex, include_gen = True, number=None):
    """
    Render one xReact example in the "Event / Outcome" prompt format.

    This is the single effective definition: an earlier same-named draft in
    this cell was immediately shadowed and has been removed.

    Args:
        ex: Dict with a 'head' (the event) and, when ``include_gen`` is True,
            a 'tail' (the reaction).
        include_gen: Append the tail; pass False for the example to complete.
        number: Optional example number, emitted on its own line.
    """
    text = ''

    if number is not None:
        text += '{}.\n\n'.format(number)

    text += 'Event: {}\n\nOutcome: PersonX feels'.format(ex['head'])

    if include_gen:
        text += ' {}\n\n'.format(ex['tail'])
    return text



xReact_dict = {'relation':relation_to_test, 
              'examples':examples[:10],
             'function':ex2text,
              'prefix':'How do the characters react to the event?\n\n\n'}






def ex2text(ex, include_gen = True, number=None):
    """
    Render one xReact example in the "Situation N: ... PersonX feels:" format.

    This is the single effective definition: an earlier same-named draft in
    this cell was immediately shadowed and has been removed.

    Args:
        ex: Dict with a 'head' (the event) and, when ``include_gen`` is True,
            a 'tail' (the reaction).
        include_gen: Append the tail; pass False for the example to complete.
        number: Optional situation number placed in front of the event.
    """
    text = ''

    if number is not None:
        text += 'Situation {}: '.format(number)

    text += '{}.\n\nPersonX feels:'.format(ex['head'])

    if include_gen:
        text += ' {}.\n\n'.format(ex['tail'])
    return text



xReact_dict = {'relation':relation_to_test, 
              'examples':examples[:10],
             'function':ex2text,
              'prefix':'Next, how do people feel in each situation? Examples:\n\n\n',
              'length':8}






'''

ADD xNeed

'''

'''

ADD xNeed

'''

##### DEFINE THIS ######
relation_to_test = 'xNeed'
########################



examples_string = """PersonX gets a job offer	xNeed	to apply
PersonX eats the food	xNeed	to be hungry
PersonX gets a date	xNeed	to ask someone out
PersonX has a baby shower	xNeed	to invite people
PersonX makes many new friends	xNeed	to be social
PersonX changes PersonY's mind	xNeed	to make PersonY think about the issue
PersonX watches Netflix	xNeed	to have a Netflix account
PersonX rides PersonY's skateboard	xNeed	to be at the skate park
PersonX takes a quick nap	xNeed	to lie down
PersonX handles the situation	xNeed	to have an issue
PersonX gets a new phone	xNeed	to pay for the new phone
PersonX takes the trash out	xNeed	to gather the trash"""
    
lines = [v.split('\t') for v in  examples_string.split('\n')]

examples = [{'head':v[0], 'relation':v[1], 'tail':v[2]} for v in lines]

assert(examples[0]['relation'] == relation_to_test)


def ex2text(ex, include_gen = True, number = None):
    """
    Render one xNeed example as few-shot prompt text.

    Args:
        ex: Dict with a 'head' (the event) and, when ``include_gen`` is True,
            a 'tail' (the prerequisite).
        include_gen: Append the tail; pass False for the example to complete.
        number: Optional event number placed in front of the event.
    """
    segments = []
    if number is not None:
        segments.append('Event {}.: '.format(number))
    segments.append('{}\n\nPrerequisites: For this to happen, PersonX needed'.format(ex['head']))
    if include_gen:
        segments.append(' {}.\n\n'.format(ex['tail']))
    return ''.join(segments)



xNeed_dict = {'relation':relation_to_test, 
              'examples':examples[:10],
             'function':ex2text,
             'prefix':'What needs to be true for this event to take place?\n\n\n'}








'''

ADD HinderedBy

TODO: select examples

'''

##### DEFINE THIS ######
relation_to_test = 'HinderedBy'
########################

examples_string = """PersonX makes a doctor's appointment	HinderedBy	PersonX can't find the phone to call the doctor
PersonX rubs PersonY's forehead	HinderedBy	PersonX is afraid to touch PersonY
PersonX eats peanut butter	HinderedBy	PersonX is allergic to peanuts
PersonX looks perfect	HinderedBy	PersonX can\'t find any makeup
PersonX goes on a run	HinderedBy	PersonX has a knee injury
PersonX takes PersonY to the emergency room	HinderedBy	PersonY has no health insurance to pay for medical care
PersonX spends time with PersonY's family	HinderedBy	PersonY’s family doesn't like spending time with PersonX
PersonX moves from place to place	HinderedBy	PersonX can\'t afford to move
PersonX protests the government	HinderedBy	PersonX is arrested
PersonX has a huge fight	HinderedBy	PersonX does not like confrontation
PersonX understands PersonY's feelings	HinderedBy	PersonY will not talk to PersonX
PersonX asks PersonY a question	HinderedBy	PersonY cannot hear PersonX
PersonX loves PersonY very much	HinderedBy	PersonY is mean to PersonX
PersonX makes a request	HinderedBy	PersonX is afraid of being rejected
PersonX meets PersonY's spouse	HinderedBy	PersonY's spouse is out of the country
PersonX fills a waterbottle	HinderedBy	There is no sink nearby to fill it
PersonX schedules an appointment	HinderedBy	The receptionist won't answer PersonX's calls
PersonX becomes rich	HinderedBy	PersonX has no revenue
PersonX gets the job done	HinderedBy	PersonX is too tired to work
PersonX sits alone in a room	HinderedBy	People come into the room and bother PersonX
PersonX evaluates PersonY’s performance	HinderedBy	PersonX misses PersonY's performance
PersonX mows lawns	HinderedBy	PersonX's lawn mower breaks"""
    
lines = [v.split('\t') for v in  examples_string.split('\n')]

examples = [{'head':v[0], 'relation':v[1], 'tail':v[2]} for v in lines]

assert(examples[0]['relation'] == relation_to_test)



## converts an example into text for few-shot
def ex2text(ex, include_gen = True, number = None):
    """
    Render one HinderedBy example as few-shot prompt text.

    The event is stated once and then quoted again in the condition line.

    Args:
        ex: Dict with a 'head' (the event) and, when ``include_gen`` is True,
            a 'tail' (the hindering condition).
        include_gen: Append the tail; pass False for the example to complete.
        number: Optional example number placed in front of the event.
    """
    event = ex['head']

    label = '{}. Event: '.format(number) if number is not None else ''
    statement = '{}.\n\nCondition: "{}" is untrue because'.format(event, event)
    completion = ' {}.\n\n\n'.format(ex['tail']) if include_gen else ''

    return label + statement + completion


## dictionary defining few-shot template for HinderedBy relation
## prefix will be put at the beginning, before few-shot examples
HinderedBy_dict = {'relation':relation_to_test, 
              'examples':examples[:5],
             'function':ex2text,
                  'prefix': 'What condition stops this event from occurring?\n\n\n'}


relation_dicts = (HinderedBy_dict,
                 xNeed_dict, 
                 xWant_dict,
                 xIntent_dict,
                 xReact_dict,
                 xAttr_dict,
                 xEffect_dict)

In [13]:
'''

Figure out which heads we need to generate for

'''

import os
import json


# Run configuration. These values are written to `params.json` by the code
# below and read by the generation and post-processing cells.
batch_name = 'curie_inference_generation_0' # name of this batch (recorded in params.json)
engine = 'curie' # engine label; passed to complete_gpt3 as model_name, which ignores it
top_p = 0.5# top_p for nucleus sampling
n_gen = 10 # number of generations for each event/relation pair
PersonX = 'Alex' # name to use for PersonX (to make sentences more natural)
PersonY = 'Chris' # name to use for PersonY
length = 12 # maximum token length of generations (the number of words per generation will be less)

stop_token = '\n\n' # stop sequence passed as `stop=` to complete_gpt3
#d = HinderedBy_dict # which relation to use
n_examples_in = 5 # how many few-shot examples to use in the prompt

# names to use in few-shot examples
# NOTE(review): 'Leslie' appears twice in this list, so a shuffled
# (PersonX, PersonY) pairing can give both characters the same name -- confirm
# whether the second occurrence was meant to be a different name.
names = ['Jean','Robin','Charlie', 'Ryan','Taylor','Jordan','Riley','Jamie','Leslie','Rowan',
               'Adrian','Ali','Wyatt', 'Sydney','Stevie','Shiloh', 'Sam','Pat','Noel','Nicky','Max',
               'Madison', 'Lindsay','Leslie','Lee','Jesse','Hunter','Glen','Devin','Avery']



# Input/output locations. heads_file is an absolute, environment-specific path.
heads_file = '/kaggle/input/dataset-todo/todo_events (2).txt'
out_file = 'inference_gens.jsonl'
meta_file = 'params.json'


### Record the run parameters once (skipped when the file already exists).
if not os.path.isfile(meta_file):
    # The rendering functions are not JSON-serialisable, so they are left out.
    relation_param_dict = [
        {key: value for key, value in relation.items() if key != 'function'}
        for relation in relation_dicts
    ]
    param_dict = dict(
        batch_name=batch_name,
        engine=engine,
        top_p=top_p,
        n_gen=n_gen,
        PersonX=PersonX,
        PersonY=PersonY,
        n_examples_in=n_examples_in,
        length=length,
        stop_token=stop_token,
        names=names,
        relation_param_dict=relation_param_dict,
    )
    serialized_params = json.dumps(param_dict)
    with open(meta_file, 'w') as f:
        f.write(serialized_params)


### Make sure the output file exists so it can be read and appended to later.
if not os.path.isfile(out_file):
    with open(out_file, 'w') as f:
        pass
    




In [14]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import random
import time
import json
import torch
import gc

t_start = time.time()
print_step = 100          # print a progress report every `print_step` heads
RELATION_BATCH_SIZE = 3   # number of relation prompts sent to the model at once

# All events to generate for, one per line.
with open(heads_file) as f:
    heads_all = [head.strip() for head in f]

# Heads already present in the output file (allows resuming an interrupted run).
heads_done = []
try:
    with open(out_file, 'r') as f:
        for line in f:
            try:
                heads_done.append(json.loads(line)['head'])
            except (ValueError, KeyError, TypeError):
                # Best-effort resume: skip lines that are not valid JSON
                # (e.g. truncated by an interrupted write) or lack a 'head'.
                continue
except FileNotFoundError:
    pass

# Set lookup instead of scanning the list once per head.
heads_done_lookup = set(heads_done)
heads_todo = [head for head in heads_all if head not in heads_done_lookup]
print('{} heads todo'.format(len(heads_todo)))

def chunker(seq, size):
    """Return a generator over consecutive slices of ``seq`` holding at most ``size`` items each."""
    starts = range(0, len(seq), size)
    return (seq[start:start + size] for start in starts)

# Main generation loop: for every remaining head, build one few-shot prompt
# per relation, sample `n_gen` continuations for each, and append one JSON
# line with all results to `out_file`.
for i, head_to_test in enumerate(heads_todo):
    # Release cached GPU memory between heads.
    gc.collect()
    torch.cuda.empty_cache()

    time.sleep(0.02)

    ex_result = {'head': head_to_test, 'tails_by_relation': {}}

    # Fresh random (PersonX, PersonY) name pairs for the few-shot examples.
    random.shuffle(names)
    names_list = list(zip(names[:15], names[15:]))

    relation_tasks = []
    for d in relation_dicts:
        examples_in = d['examples'][:n_examples_in]
        prompt = few_shot_prompt(examples_in + [{'head':head_to_test}], d['function'], number=True,
                                 Person_list=names_list[:len(examples_in)]+ [(PersonX,PersonY)])
        prompt = d['prefix'] + prompt
        prompt = name_PX_PY(prompt, PersonX, PersonY)

        relation_tasks.append({'d': d, 'prompt': prompt})

    for batch_tasks in chunker(relation_tasks, RELATION_BATCH_SIZE):
        batch_prompts = [t['prompt'] for t in batch_tasks]

        try:
            result = complete_gpt3(batch_prompts, length, engine, top_p=top_p, num_log_probs=1, n=n_gen, stop=stop_token, echo=False)
        except RuntimeError as e:
            # Retry exactly once after freeing cached memory when CUDA ran out
            # of memory; any other runtime error propagates unchanged.
            if "out of memory" not in str(e):
                raise
            torch.cuda.empty_cache()
            result = complete_gpt3(batch_prompts, length, engine, top_p=top_p, num_log_probs=1, n=n_gen, stop=stop_token, echo=False)

        for idx, task in enumerate(batch_tasks):
            d = task['d']
            relation_outputs = []

            # The choices are sliced in groups of n_gen, one group per prompt
            # in batch order.
            start_idx = idx * n_gen
            choices = result['choices'][start_idx : start_idx + n_gen]

            for choice in choices:
                out = choice['text']
                try:
                    # Sum token log-probs up to the first token carrying the maximal offset.
                    end_ind = choice['logprobs']['text_offset'].index(max(choice['logprobs']['text_offset']))
                    nll = sum(choice['logprobs']['token_logprobs'][:end_ind + 1])
                except (KeyError, ValueError, TypeError, IndexError):
                    # Missing or empty logprobs: fall back to 0.0 instead of
                    # aborting the run. Narrowed from a bare `except` so that
                    # e.g. KeyboardInterrupt is no longer swallowed.
                    nll = 0.0

                relation_outputs.append({'text':out, 'result':choice, 'nll':nll})

            ex_result['tails_by_relation'][d['relation']] = {'relation':d['relation'], 'prompt':task['prompt'], 'tails':relation_outputs}

        del result
        torch.cuda.empty_cache()

    # Append after every head so an interrupted run can be resumed.
    with open(out_file,'a') as f:
        f.write(json.dumps(ex_result) + '\n')

    if (i % print_step) == 0:
        elapsed_min = (time.time() - t_start) / 60.
        print('='*50)
        # 'avg_rate' is the average number of minutes spent per head so far.
        print('time: {:.2f} min, avg_rate: {:.2f}'.format(elapsed_min, elapsed_min / (i + 1)))
        print('{}) {}'.format(i, ex_result['head']))
        print('='*50)

953 heads todo
time: 0.52 min, avg_rate: 0.52
0) PersonX is a good friend
time: 50.95 min, avg_rate: 0.50
100) PersonX makes an offhand remark about PersonY's appearance
time: 100.73 min, avg_rate: 0.50
200) PersonX has been fired from a job
time: 150.45 min, avg_rate: 0.50
300) PersonX decides to go backpacking
time: 199.81 min, avg_rate: 0.50
400) PersonX doesn't get the joke
time: 249.07 min, avg_rate: 0.50
500) PersonX uses a lot of water
time: 298.14 min, avg_rate: 0.50
600) PersonX wants to sleep all day
time: 347.04 min, avg_rate: 0.50
700) PersonX plays the guitar
time: 396.06 min, avg_rate: 0.49
800) PersonX goes to church
time: 445.04 min, avg_rate: 0.49
900) PersonX gets a letter


In [15]:
'''

ONLY RUN once you have generated all inferences

this block takes the full set of generated inferences,
and puts them in a standardized format, including removing inferences
that repeat for a given event/relation input.

It also automatically divides generations into a train/val/test split
based on the event (i.e. all inferences for a given event are sorted into the 
same split)

generations are saved as a jsonl file where each entry has the following keys:

head: the event for generation
relation: relation for generation
inference: the generated inference (the tail)
split: the dataset split

'''

import json
    
    
full_dataset_file = 'unique_dataset.jsonl'
    

def name_PX_PY(s, PersonX, PersonY):
    """Return ``s`` with the 'PersonX' / 'PersonY' placeholders replaced by the given names."""
    named = s.replace('PersonX', PersonX)
    named = named.replace('PersonY', PersonY)
    return named
def scrub_PX_PY(s, PersonX, PersonY):
    """
    Replace the concrete names with the 'PersonX' / 'PersonY' placeholders.

    Only whole-word occurrences are replaced, so generated text such as
    "Christmas" is not corrupted when PersonY is 'Chris'. Possessives such as
    "Chris's" are still rewritten because the apostrophe is not a word
    character.

    Args:
        s: Text to scrub.
        PersonX: Name that stands for PersonX in ``s``.
        PersonY: Name that stands for PersonY in ``s``.

    Returns:
        ``s`` with PersonX replaced first, then PersonY.
    """
    import re  # local import keeps this cell self-contained

    for name, placeholder in ((PersonX, 'PersonX'), (PersonY, 'PersonY')):
        # Lookarounds instead of \b so names with non-word edges still match.
        pattern = r'(?<!\w)' + re.escape(name) + r'(?!\w)'
        s = re.sub(pattern, placeholder, s)
    return s
    
def process_tail(tail):
    """
    Normalise one generated inference.

    Drops a single trailing period and a single leading space, then swaps the
    concrete names back to the 'PersonX' / 'PersonY' placeholders using the
    module-level ``PersonX`` and ``PersonY`` settings.

    Empty input, or input that becomes empty after the period is dropped, is
    handled without raising.

    Args:
        tail: Raw generated text.

    Returns:
        The cleaned inference string.
    """
    # endswith/startswith are safe on the empty string, unlike tail[-1] / tail[0].
    if tail.endswith('.'):
        tail = tail[:-1]
    if tail.startswith(' '):
        tail = tail[1:]

    return scrub_PX_PY(tail, PersonX, PersonY)




import json
import random
# Flatten the raw generations into one record per unique
# (head, relation, inference) triple and write them to `full_dataset_file`.
data_out = []

with open(out_file) as f, open(full_dataset_file,'w') as f_out:
    for line in f:

        # One draw per input line (i.e. per head), so all inferences of a head
        # land in the same split: roughly 80% train, 10% val, 10% test.
        # NOTE(review): no random seed is set in this cell, so the split
        # assignment differs between runs.
        r = random.random()
        if r< 0.8:
            split = 'train'
        elif r < 0.9:
            split = 'val'
        else:
            split = 'test'

        d = json.loads(line)

        for relation, generated in d['tails_by_relation'].items():

            # remove short inferences (degenerate), then normalise the rest
            inferences = [process_tail(v['text']) for v in generated['tails'] if len(v['text']) > 2]

            # do not include repeats; dict.fromkeys keeps the first-seen order,
            # so the output order is reproducible (iterating a set is not)
            for inference in dict.fromkeys(inferences):
                data_out.append({'split':split,'head':d['head'], 'relation':relation, 'inference':inference })

    for record in data_out:
        f_out.write(json.dumps(record) + '\n')


print('total of {} unique generated examples written to {}'.format(len(data_out), full_dataset_file))

total of 39144 unique generated examples written to unique_dataset.jsonl


In [16]:
import json  # This import was missing or lost

# Save Raw Data
# Writes every entry of `raw_data` as one JSON line to raw_atomic10x.jsonl.
# NOTE(review): `raw_data` is not assigned in any visible cell, and the
# recorded output of this cell is "NameError: name 'raw_data' is not defined".
# Define it (or load it from the intended source) before this cell -- TODO
# confirm which data this was meant to be.
with open("raw_atomic10x.jsonl", "w") as f:
    for entry in raw_data:
        f.write(json.dumps(entry) + "\n")

NameError: name 'raw_data' is not defined