# Imports

In [1]:
import os
import sys

# Local checkouts that must be importable by the cells below. Each location can be
# overridden with an environment variable so the notebook is not tied to a single
# machine; the defaults are the original hard-coded paths.
_LOCAL_REPOS = {
    "ACDC_REPO_PATH": r"C:\Users\andre\Documents\School\Hoger\Masterproef\Code\Automatic-Circuit-Discovery",
    "TRANSFORMER_LENS_REPO_PATH": r"C:\Users\andre\Documents\School\Hoger\Masterproef\Code\TransformerLens",
}
for _env_var, _default_path in _LOCAL_REPOS.items():
    _repo_path = os.environ.get(_env_var, _default_path)
    # Re-running this cell must not pile up duplicate entries on sys.path.
    if _repo_path not in sys.path:
        sys.path.append(_repo_path)

In [2]:
# from acdc.text_entailment.utils import get_all_text_entailment_things
from acdc.text_entailment.utils import determine_relevant_rules_and_facts
from acdc.text_entailment.utils import parse_input_sequence
from acdc.text_entailment.utils import extract_rules_and_facts
from acdc.text_entailment.utils import generate_corrupt_examples
from acdc.text_entailment.utils import add_necessary_facts
from acdc.text_entailment.utils import reconstruct_theory
from acdc.text_entailment.utils import generate_dummy_fact
from acdc.text_entailment.utils import Fact
from acdc.text_entailment.utils import added_facts_satisfy_a_final_rule

# from transformer_lens import HookedEncoder
# from ACDCPPExperiment import ACDCPPExperiment

import numpy as np
import torch
import tqdm.notebook as tqdm
import json
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


# General set-up

In [4]:
# --- Experiment switches -------------------------------------------------
DEPTH = 0
QDep = False
NO_RCONC = False
RETRAINED = False

# num_examples
N = 2

tokenizer_name = "bert-base-uncased"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The dataset and the model share the same variant suffix; only the tag that marks
# a retrained checkpoint differs between the QDep and the plain-depth variants.
if QDep:
    _variant = f"QDep{DEPTH}-NoRconc"
    _retrained_tag = "_retrained"
else:
    _variant = f"depth{DEPTH}"
    _retrained_tag = "_retrained-1"

dataset_name = "andres-vs/ruletaker-Att-Noneg-" + _variant
model_name = (
    "andres-vs/bert-base-uncased-finetuned_Att-Noneg"
    + "-"
    + _variant
    + (_retrained_tag if RETRAINED else "")
)

In [5]:
# Show the model and dataset identifiers assembled by the configuration cell.
print(model_name)
print(dataset_name)

andres-vs/bert-base-uncased-finetuned_Att-Noneg-depth0
andres-vs/ruletaker-Att-Noneg-depth0


In [6]:
# Load the configured dataset; later cells read its 'test' split.
dataset = load_dataset(dataset_name)

# EAP experiment set-up

In [8]:
# The import of `get_all_text_entailment_things` is commented out in the imports
# cell, so on a fresh kernel this cell would fail with a NameError. Import it here,
# next to its only use.
from acdc.text_entailment.utils import get_all_text_entailment_things

# Build the experiment objects for N examples of the test split.
things = get_all_text_entailment_things(model_name, dataset['test'], N, device, metric_name="abs_logit_diff_diff")

If using BERT for interpretability research, keep in mind that BERT has some significant architectural differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning that the last LayerNorm in a block cannot be folded.


: 

# text_entailment.utils test base

In [6]:
# Pick one fixed test example (index 107) to exercise the parsing helpers on.
example = dataset['test'][107]
print(example['input'])

Charlie is quiet. Quiet, furry things are white. Gary is quiet. All quiet things are rough. Bob is furry. Rough things are white. If something is quiet and blue then it is rough. If something is white and rough then it is furry. Furry, kind things are rough.[SEP]Bob is white.


In [7]:
# Extract the raw rules, facts and query from the example's input text.
raw_rules, raw_facts, raw_query = extract_rules_and_facts(example['input'])

In [12]:
# NOTE(review): the recorded output below does not match example 107, so it looks
# like it was produced by an earlier run on a different example -- re-run to refresh.
print("raw rules:", raw_rules)

raw rules: [('If someone is big and round then they are white', 0), ('All nice, cold people are rough', 1), ('If Charlie is rough and Charlie is blue then Charlie is nice', 2), ('If Dave is nice and Dave is cold then Dave is round', 4), ('If Dave is big then Dave is white', 6), ('Nice people are blue', 10), ('All cold people are blue', 21), ('If Dave is rough and Dave is nice then Dave is big', 22)]


In [7]:
# Parse the same input into rule objects, fact objects and the query.
rules, facts, query = parse_input_sequence(example['input'])
print(f"Rules: {rules}")
print(f"Facts: {facts}")
print(f"Query: {query}")

Rules: [<acdc.text_entailment.utils.Rule object at 0x000002012D989930>, <acdc.text_entailment.utils.Rule object at 0x000002014D969D80>, <acdc.text_entailment.utils.Rule object at 0x000002014D9688B0>, <acdc.text_entailment.utils.Rule object at 0x000002014D969FF0>, <acdc.text_entailment.utils.Rule object at 0x000002014D96A0E0>, <acdc.text_entailment.utils.Rule object at 0x000002014D969EA0>]
Facts: [<acdc.text_entailment.utils.Fact object at 0x000002014D96A170>, <acdc.text_entailment.utils.Fact object at 0x000002014D96A200>, <acdc.text_entailment.utils.Fact object at 0x000002014D969E40>]
Query: Bob is white.


In [8]:
# Narrow the theory down to the rules/facts relevant to the query at the example's
# depth. Note that `query` is rebound to the value returned by the helper.
relevant_rules, relevant_facts, query = determine_relevant_rules_and_facts(rules, facts, query, example['depth'])
print(f"Relevant rules: {relevant_rules}")
print(f"Relevant facts: {relevant_facts}")
print(f"Query: {query}")

Relevant rules: [<acdc.text_entailment.utils.Rule object at 0x000002012D989930>, <acdc.text_entailment.utils.Rule object at 0x000002014D9688B0>, <acdc.text_entailment.utils.Rule object at 0x000002014D96A0E0>, <acdc.text_entailment.utils.Rule object at 0x000002014D969D80>, <acdc.text_entailment.utils.Rule object at 0x000002014D969FF0>, <acdc.text_entailment.utils.Rule object at 0x000002014D969EA0>]
Relevant facts: []
Query: Bob is white.


In [9]:
# Print the relevant rules and facts in readable form (one per line).
for rule in relevant_rules:
    print(rule)
for fact in relevant_facts:
    print(fact)

Quiet, furry things are white.
Rough things are white.
If something is white and rough then it is furry.
All quiet things are rough.
If something is quiet and blue then it is rough.
Furry, kind things are rough.


In [10]:
# Keep only the rules whose conclusion is the query itself: same property and same
# polarity, concluded either for the query's entity or for a generic placeholder.
_GENERIC_ENTITIES = ["[ANY_PERSON]", "[ANY_THING]"]
final_rules = [
    candidate
    for candidate in relevant_rules
    if (candidate.consequence_entity == query.entity
        or candidate.consequence_entity in _GENERIC_ENTITIES)
    and any(prop == query.property for prop in candidate.consequence_property)
    and candidate.consequence_negative == query.negative
]

In [11]:
# List the rules that directly conclude the query.
print("Final rules:")
for rule in final_rules:
    print(rule)

Final rules:
Quiet, furry things are white.
Rough things are white.


In [13]:
# Scratch version (with trace prints) of the two-step search: for each final rule,
# try to make its premises true for the query entity, preferably through an
# intermediate rule, and collect the facts that must be added to the theory.
# Reads `final_rules`, `rules`, `facts` and `query` from earlier cells.
facts_to_add = []
# NOTE(review): this flag is initialised once and never reset inside the loop, so
# after one final rule finds an intermediate rule it stays True for every later
# final rule -- confirm that this is intended.
one_int_rule_found = False
for final_rule in final_rules:
    print("Final rule:", final_rule)
    # Identify the premises we need to make true
    needed_premises = []
    for prop in final_rule.premise_property:
        needed_premises.append(Fact(
            # questionable? should needed premises not always have query.entity?
            entity=query.entity if final_rule.premise_entity in ["[ANY_PERSON]", "[ANY_THING]"] else final_rule.premise_entity,
            property=prop,
            negative=final_rule.premise_negative
        ))
    # print("Needed premises:")
    for premise in needed_premises:
        print("Needed premise:", premise)
        # print(premise)

        premise_added = False

        # Check if the premise already exists in facts
        # (the entity is compared with query.entity, not with premise.entity)
        if any(fact.entity == query.entity and 
                fact.property == premise.property and
                fact.negative == premise.negative
                for fact in facts + facts_to_add):
            premise_added = True
            print("premise in theory already:", premise_added)
            continue
        print("premise in theory already:", premise_added)

        # Try to find an intermediate rule to derive this premise
        for int_rule in rules:
            rule_is_general_entity = int_rule.consequence_entity in ["[ANY_PERSON]", "[ANY_THING]"]
            premise_is_general_entity = premise.entity in ["[ANY_PERSON]", "[ANY_THING]"]
            consequence_property_match = any(prop == premise.property for prop in int_rule.consequence_property)
            
            if ((int_rule.consequence_entity == premise.entity or rule_is_general_entity or premise_is_general_entity) and
                consequence_property_match and
                int_rule.consequence_negative == premise.negative):
                print("Intermediate rule found:")
                print(int_rule)
                # Found a rule that can derive the premise
                one_int_rule_found = True
                int_needed_premises = []
                # Prefer the rule's own premise objects; otherwise rebuild them from
                # the premise_* attributes.
                if int_rule.premises:
                    int_needed_premises = int_rule.premises
                else:
                    for prop in int_rule.premise_property:
                        int_needed_premises.append(Fact(
                            entity=premise.entity if int_rule.premise_entity in ["[ANY_PERSON]", "[ANY_THING]"] else int_rule.premise_entity,
                            property=prop,
                            negative=int_rule.premise_negative
                        ))
                print("Needed premises for intermediate rule:")
                all_premises_exist = True
                for int_premise in int_needed_premises:
                    print(int_premise)
                    if not any(fact.entity == query.entity and 
                                fact.property == int_premise.property and
                                fact.negative == int_premise.negative
                                for fact in facts + facts_to_add):
                        # verify that the added fact(s) do not make any final rule immediately true without intermediate rule, if it does -> 
                        facts_to_add.append(Fact(
                            entity=query.entity,
                            property=int_premise.property,
                            negative=int_premise.negative
                        ))
                        all_premises_exist = False
                print("facts to add:", [str(fact) for fact in facts_to_add])

                print("Final immediately satisfied:", added_facts_satisfy_a_final_rule(facts + facts_to_add, final_rules, query))
                # `and` binds tighter than `or`, so this condition reads as
                # `all_premises_exist or (facts_to_add and not <final rule satisfied>)`.
                if all_premises_exist or facts_to_add and not added_facts_satisfy_a_final_rule(facts + facts_to_add, final_rules, query):
                    premise_added = True
                    break
                else:
                    # NOTE(review): this clears every pending fact, including facts
                    # added for earlier premises of the same final rule -- confirm.
                    facts_to_add = []
        
        # If no rule can derive the premise, add it directly
        if not premise_added:
            facts_to_add.append(premise)

            

    # If we've found a way to satisfy this rule using at least one intermediate rule, we're done
    if facts_to_add and one_int_rule_found:
        break
    else:
        facts_to_add = []
print("Facts to add:")
for fact in facts_to_add:
    print(fact)

Final rule: Quiet, furry things are white.
Needed premise: Bob is quiet.
premise in theory already: False
Needed premise: Bob is furry.
premise in theory already: True
Final rule: Rough things are white.
Needed premise: Bob is rough.
premise in theory already: False
Intermediate rule found:
All quiet things are rough.
Needed premises for intermediate rule:
[ANY_THING] is quiet.
facts to add: ['Bob is quiet.']
Checking rule: Quiet, furry things are white.
Final immediately satisfied: True
Checking rule: Quiet, furry things are white.
Intermediate rule found:
If something is quiet and blue then it is rough.
Needed premises for intermediate rule:
[ANY_THING] is quiet.
[ANY_THING] is blue.
facts to add: ['Bob is quiet.', 'Bob is blue.']
Checking rule: Quiet, furry things are white.
Final immediately satisfied: True
Checking rule: Quiet, furry things are white.
Intermediate rule found:
Furry, kind things are rough.
Needed premises for intermediate rule:
[ANY_THING] is furry.
[ANY_THING] is 

In [13]:
# Scratch version (without trace prints) of the two-step search: select the rules
# that conclude the query, then look for facts that make one of them fire through
# an intermediate rule. Reads `rules`, `facts` and `query` from earlier cells.
facts_to_add = []
# First, find rules that directly infer the query
final_rules = []
for rule in rules:
    is_general_entity = rule.consequence_entity in ["[ANY_PERSON]", "[ANY_THING]"]
    consequence_property_match = any(prop == query.property for prop in rule.consequence_property)
    
    if ((rule.consequence_entity == query.entity or is_general_entity) and 
        consequence_property_match and
        rule.consequence_negative == query.negative):
        final_rules.append(rule)

# For each final rule, find intermediate rules and necessary facts
# NOTE(review): the flag below is never reset inside the loop, so it stays True for
# all later final rules once any intermediate rule has matched -- confirm intent.
one_int_rule_found = False
for final_rule in final_rules:
    # Identify the premises we need to make true
    needed_premises = []
    for prop in final_rule.premise_property:
        needed_premises.append(Fact(
            # questionable? should needed premises not always have query.entity?
            entity=query.entity if final_rule.premise_entity in ["[ANY_PERSON]", "[ANY_THING]"] else final_rule.premise_entity,
            property=prop,
            negative=final_rule.premise_negative
        ))
    
    for premise in needed_premises:
        premise_added = False
                
        # Check if the premise already exists in facts
        # (the entity is compared with query.entity, not with premise.entity)
        if any(fact.entity == query.entity and 
                fact.property == premise.property and
                fact.negative == premise.negative
                for fact in facts + facts_to_add):
            premise_added = True
            continue
        
        # Try to find an intermediate rule to derive this premise
        for int_rule in rules:
            rule_is_general_entity = int_rule.consequence_entity in ["[ANY_PERSON]", "[ANY_THING]"]
            premise_is_general_entity = premise.entity in ["[ANY_PERSON]", "[ANY_THING]"]
            consequence_property_match = any(prop == premise.property for prop in int_rule.consequence_property)
            
            if ((int_rule.consequence_entity == premise.entity or rule_is_general_entity or premise_is_general_entity) and
                consequence_property_match and
                int_rule.consequence_negative == premise.negative):
                
                # Found a rule that can derive the premise
                one_int_rule_found = True
                int_needed_premises = []
                # Prefer the rule's own premise objects; otherwise rebuild them from
                # the premise_* attributes.
                if int_rule.premises:
                    int_needed_premises = int_rule.premises
                else:
                    for prop in int_rule.premise_property:
                        int_needed_premises.append(Fact(
                            entity=premise.entity if int_rule.premise_entity in ["[ANY_PERSON]", "[ANY_THING]"] else int_rule.premise_entity,
                            property=prop,
                            negative=int_rule.premise_negative
                        ))
                
                all_premises_exist = True
                for int_premise in int_needed_premises:
                    if not any(fact.entity == query.entity and 
                                fact.property == int_premise.property and
                                fact.negative == int_premise.negative
                                for fact in facts + facts_to_add):
                        facts_to_add.append(Fact(
                            entity=query.entity,
                            property=int_premise.property,
                            negative=int_premise.negative
                        ))
                        all_premises_exist = False

                # Whenever a premise was missing a fact was just appended, so this
                # condition holds for every matching intermediate rule: the first
                # match always ends the search for this premise.
                if all_premises_exist or facts_to_add:
                    premise_added = True
                    break
        # If no rule can derive the premise, add it directly
        if not premise_added:
            facts_to_add.append(premise)

    # If we've found a way to satisfy this rule using at least one intermediate rule, we're done
    if facts_to_add and one_int_rule_found:
        break
    else:
        facts_to_add = []

print("Facts to add:")
for fact in facts_to_add:
    print(fact)

Facts to add:
Bob is quiet.


In [10]:
# theory_index of the first relevant fact, or None when there is no relevant fact.
theory_index = relevant_facts[0].theory_index if relevant_facts else None
print(theory_index)

3


In [11]:
# Generate a dummy fact for the given theory_index; the result is used as a list
# in the following cells (indexed with [0] and concatenated to a list of facts).
dummy_fact = generate_dummy_fact(rules, facts, query, theory_index)

In [12]:
# Show the first (here: only inspected) generated dummy fact.
print("Dummy fact:", dummy_fact[0])

Dummy fact: Fiona is white.


In [13]:
# Drop the relevant facts from the theory and append the dummy fact(s) instead.
non_relevant_facts = [fact for fact in facts if fact not in relevant_facts] + dummy_fact
print("Non-relevant facts:", non_relevant_facts)

Non-relevant facts: [<acdc.text_entailment.utils.Fact object at 0x0000023A5E1D9B70>, <acdc.text_entailment.utils.Fact object at 0x0000023A5E1D9C90>, <acdc.text_entailment.utils.Fact object at 0x0000023A5E1D9D50>, <acdc.text_entailment.utils.Fact object at 0x0000023A5E1D9DB0>, <acdc.text_entailment.utils.Fact object at 0x0000023A5E1D9E10>, <acdc.text_entailment.utils.Fact object at 0x0000023A5E1D9E70>, <acdc.text_entailment.utils.Fact object at 0x0000023A5E1D9ED0>, <acdc.text_entailment.utils.Fact object at 0x0000023A5E1D9F30>, <acdc.text_entailment.utils.Fact object at 0x0000023A5E1D9F90>, <acdc.text_entailment.utils.Fact object at 0x0000023A5E182860>]


In [14]:
# Rebuild the theory text from the rules and the modified fact list.
print(reconstruct_theory(rules, non_relevant_facts))

Red things are big. Anne is big. Erin is green. Fiona is white. All red things are round. If something is green and cold then it is furry. Anne is green. Anne is cold. Erin is round. If something is rough then it is furry. If something is cold then it is furry. Anne is rough. Anne is furry. All furry things are round. Erin is furry. Erin is cold. If something is round then it is rough. All rough things are cold.


In [11]:
# Dump every parsed attribute of one relevant rule, one value per line.
rule_number = 1
_inspected_rule = relevant_rules[rule_number]
for _attr in (
    "premise_entity",
    "premise_property",
    "premise_negative",
    "consequence_entity",
    "consequence_property",
    "consequence_negative",
    "format_type",
    "premises",
    "theory_index",
):
    print(getattr(_inspected_rule, _attr))

erin
['furry']
False
erin
['quiet']
False
if_then
[<acdc.text_entailment.utils.Fact object at 0x00000210A369AD70>]
3


In [None]:
def added_facts_satisfy_a_final_rule2(facts, final_rules, query):
    """
    Tell whether some final rule already fires on the given facts.

    A final rule that fires makes the query true right away, i.e. without any
    intermediate reasoning step/rule.

    Args:
        facts (list): Fact objects of the theory, newly added facts included
        final_rules (list): Rule objects that directly conclude the query
        query (Fact): The query; its entity stands in for placeholder entities

    Returns:
        bool: True as soon as every premise of one final rule is matched by a
        fact, False if no final rule is fully matched
    """
    placeholder_entities = ("[ANY_PERSON]", "[ANY_THING]")

    def _matches(fact, premise):
        # Entity matches literally, or the premise is generic and the fact is
        # about the query's entity.
        if fact.entity == premise.entity:
            entity_ok = True
        else:
            entity_ok = premise.entity in placeholder_entities and fact.entity == query.entity
        return (
            entity_ok
            and fact.property == premise.property
            and fact.negative == premise.negative
        )

    for rule in final_rules:
        print("Checking rule:", rule)
        if all(any(_matches(fact, premise) for fact in facts) for premise in rule.premises):
            return True

    return False

In [30]:
def add_necessary_facts2(rules, facts, query, num_reasoning_steps=1):
    """
    Finds the facts that must be added so the query becomes true in two reasoning steps.

    This is a notebook copy of the search with trace prints. For every rule that
    directly concludes the query (a "final rule") it tries to make the rule's
    premises true for the query entity: a premise that is already a fact is
    skipped, a premise that an intermediate rule can conclude contributes the
    missing premises of that intermediate rule, and any other premise is added
    directly. A final rule is only accepted when an intermediate rule was used
    while processing that same final rule.

    Args:
        rules (list): List of Rule objects
        facts (list): List of Fact objects
        query (Query): The query to make true
        num_reasoning_steps (int): Number of reasoning steps; only the value 2
            triggers a search, any other value returns an empty list

    Returns:
        list: Fact objects to add to the theory; empty when no final rule can be
        satisfied through an intermediate rule (or num_reasoning_steps != 2)
    """
    general_entities = ["[ANY_PERSON]", "[ANY_THING]"]
    facts_to_add = []

    if num_reasoning_steps == 2:
        # First, find rules that directly infer the query
        final_rules = []
        for rule in rules:
            is_general_entity = rule.consequence_entity in general_entities
            consequence_property_match = any(prop == query.property for prop in rule.consequence_property)

            if ((rule.consequence_entity == query.entity or is_general_entity) and
                consequence_property_match and
                rule.consequence_negative == query.negative):
                final_rules.append(rule)

        # For each final rule, find intermediate rules and necessary facts
        for final_rule in final_rules:
            print("Final rule:", final_rule)
            # Reset the flag for every final rule. If it were shared across final
            # rules, an intermediate rule found for an earlier (rejected) final rule
            # would let a later final rule be accepted even though all of its
            # premises were added directly, i.e. a one-step proof.
            one_int_rule_found = False

            # Identify the premises we need to make true
            needed_premises = []
            for prop in final_rule.premise_property:
                needed_premises.append(Fact(
                    # questionable? should needed premises not always have query.entity?
                    entity=query.entity if final_rule.premise_entity in general_entities else final_rule.premise_entity,
                    property=prop,
                    negative=final_rule.premise_negative
                ))

            for premise in needed_premises:
                print("Needed premise:", premise)
                premise_added = False

                # Check if the premise already exists in facts
                # (the entity is compared with query.entity, not with premise.entity)
                if any(fact.entity == query.entity and
                        fact.property == premise.property and
                        fact.negative == premise.negative
                        for fact in facts + facts_to_add):
                    premise_added = True
                    print("premise in theory already:", premise_added)
                    continue

                print("premise in theory already:", premise_added)

                # Try to find an intermediate rule to derive this premise
                for int_rule in rules:
                    rule_is_general_entity = int_rule.consequence_entity in general_entities
                    premise_is_general_entity = premise.entity in general_entities
                    consequence_property_match = any(prop == premise.property for prop in int_rule.consequence_property)

                    if ((int_rule.consequence_entity == premise.entity or rule_is_general_entity or premise_is_general_entity) and
                        consequence_property_match and
                        int_rule.consequence_negative == premise.negative):

                        # Found a rule that can derive the premise
                        one_int_rule_found = True
                        int_needed_premises = []
                        # Prefer the rule's own premise objects; otherwise rebuild
                        # them from the premise_* attributes.
                        if int_rule.premises:
                            int_needed_premises = int_rule.premises
                        else:
                            for prop in int_rule.premise_property:
                                int_needed_premises.append(Fact(
                                    entity=premise.entity if int_rule.premise_entity in general_entities else int_rule.premise_entity,
                                    property=prop,
                                    negative=int_rule.premise_negative
                                ))

                        all_premises_exist = True
                        for int_premise in int_needed_premises:
                            if not any(fact.entity == query.entity and
                                        fact.property == int_premise.property and
                                        fact.negative == int_premise.negative
                                        for fact in facts + facts_to_add):
                                facts_to_add.append(Fact(
                                    entity=query.entity,
                                    property=int_premise.property,
                                    negative=int_premise.negative
                                ))
                                all_premises_exist = False

                        # Whenever a premise was missing a fact was just appended,
                        # so the first matching intermediate rule always ends the
                        # search for this premise.
                        if all_premises_exist or facts_to_add:
                            premise_added = True
                            break
                # If no rule can derive the premise, add it directly
                if not premise_added:
                    facts_to_add.append(premise)

            # If we've found a way to satisfy this rule using at least one intermediate rule, we're done
            if facts_to_add and one_int_rule_found:
                break
            else:
                facts_to_add = []

    return facts_to_add

In [9]:
# Run the imported library helper (not the local `add_necessary_facts2` copy) on
# the relevant rules, asking for a two-step derivation.
necessary_facts = add_necessary_facts(relevant_rules, facts, query, num_reasoning_steps=2)
print(f"Necessary facts: {necessary_facts}")

Necessary facts: [<acdc.text_entailment.utils.Fact object at 0x0000025D92AFE260>]


In [10]:
# Print the facts to add in readable form.
for fact in necessary_facts:
    print(fact)

Bob is quiet.


In [14]:
# Dump the attributes of the first fact that has to be added, one value per line.
_first_necessary_fact = necessary_facts[0]
for _attr in ("entity", "property", "negative", "is_query", "theory_index"):
    print(getattr(_first_necessary_fact, _attr))

Charlie
young
False
False
None


In [13]:
# Extend the original facts with the facts that make the query derivable.
new_facts = facts + necessary_facts
for fact in new_facts:
    print(fact)

Anne is big.
Erin is green.
Anne is red.
Anne is green.
Anne is cold.
Erin is round.
Anne is rough.
Anne is furry.
Erin is furry.
Erin is cold.
Erin is red.


In [14]:
# Rebuild the theory text from the rules and the extended fact list.
new_theory = reconstruct_theory(rules, new_facts)
print(f"New theory: {new_theory}")

New theory: Red things are big. Anne is big. Erin is green. Anne is red. All red things are round. If something is green and something is cold then it is furry. Anne is green. Anne is cold. Erin is round. If something is rough then something is furry. If something is cold then something is furry. Anne is rough. Anne is furry. All furry things are round. Erin is furry. Erin is cold. If something is round then something is rough. All rough things are cold. Erin is red.


In [14]:
# Dump every parsed attribute of one rule of the full theory, one value per line.
rule_number = 1
_inspected_rule = rules[rule_number]
for _attr in (
    "premise_entity",
    "premise_property",
    "premise_negative",
    "consequence_entity",
    "consequence_property",
    "consequence_negative",
    "format_type",
    "premises",
    "theory_index",
):
    print(getattr(_inspected_rule, _attr))

fiona
['blue', 'nice']
False
fiona
['red']
False
if_then
[<acdc.text_entailment.utils.Fact object at 0x0000026F3FBD4400>, <acdc.text_entailment.utils.Fact object at 0x0000026F3FBD7010>]
2


In [15]:
# Entity of the first premise object of the rule selected by `rule_number`.
print(rules[rule_number].premises[0].entity)

erin


In [11]:
# Print all parsed rules in readable form.
for rule in rules:
    print(rule)

Red things are big.
All red things are round.
If something is green and cold then it is furry.
If something is rough then it is furry.
If something is cold then it is furry.
All furry things are round.
If something is round then it is rough.
All rough things are cold.


In [12]:
# Dump the attributes of the fact at index 5, one value per line.
_inspected_fact = facts[5]
for _attr in ("entity", "property", "negative", "is_query", "theory_index"):
    print(getattr(_inspected_fact, _attr))

Erin
round
False
False
8


In [27]:
# Dump the attributes of the first relevant fact, one value per line.
_first_relevant_fact = relevant_facts[0]
for _attr in ("entity", "property", "negative", "is_query", "theory_index"):
    print(getattr(_first_relevant_fact, _attr))


Erin
kind
False
False
2


In [20]:
# Dump the attributes of the query, one value per line.
for _attr in ("entity", "property", "negative", "is_query"):
    print(getattr(query, _attr))

Charlie
big
True
True


In [36]:
# Build the inverted query via its invert() method and show it.
inverted_query = query.invert()
print(inverted_query)

Erin is kind.


In [30]:
# Dump the attributes of the inverted query, one value per line.
for _attr in ("entity", "property", "negative", "is_query"):
    print(getattr(inverted_query, _attr))

Erin
kind
False
True


In [37]:
# Compare the first relevant fact with the inverted query; the recorded result is
# True even though their is_query values differ in the outputs above.
relevant_facts[0] == inverted_query

True

In [38]:
# Keep every fact that does not compare equal to the inverted query.
non_relevant_facts = [fact for fact in facts if fact != inverted_query]

In [39]:
# Print the remaining facts in readable form.
for fact in non_relevant_facts:
    print(fact)

Erin is young.
Charlie is blue.
Erin is quiet.
Charlie is young.
Charlie is smart.
Charlie is kind.
Erin is cold.
Erin is smart.
Erin is blue.
Charlie is quiet.
Erin is round.
Charlie is round.
Charlie is cold.


In [32]:
# Print all parsed facts in readable form.
for fact in facts:
    print(fact)

Charlie is quiet.
Gary is quiet.
Bob is furry.


# text_entailment.utils test corruption

In [7]:
# Take the first 140 test examples as a small sample for the corruption check.
selected_examples = dataset['test'].select(range(140))

In [8]:
# Generate a corrupted counterpart for every selected example.
corrupted_examples = generate_corrupt_examples(selected_examples)

In [9]:
# Print each original example next to its corrupted counterpart for a manual check.
_example_pairs = zip(selected_examples, corrupted_examples)
for position, (original, corrupted) in enumerate(_example_pairs, start=1):
    print(f"Example {position}:")
    for field in ('depth', 'proof_strategy', 'input', 'label'):
        print(original[field])
    for field in ('input', 'label'):
        print(corrupted[field])
    print("")

Example 1:
0
proof
If someone is big and round then they are white. All nice, cold people are rough. If Charlie is rough and Charlie is blue then Charlie is nice. Charlie is cold. If Dave is nice and Dave is cold then Dave is round. Bob is nice. If Dave is big then Dave is white. Charlie is blue. Dave is round. Bob is round. Nice people are blue. Dave is cold. Dave is white. Charlie is nice. Charlie is white. Charlie is round. Charlie is big. Bob is rough. Bob is blue. Bob is big. Bob is cold. All cold people are blue. If Dave is rough and Dave is nice then Dave is big.[SEP]Dave is cold.
True
If someone is big and round then they are white. All nice, cold people are rough. If charlie is rough and charlie is blue then charlie is nice. Charlie is cold. If dave is nice and dave is cold then dave is round. Bob is nice. If dave is big then dave is white. Charlie is blue. Dave is round. Bob is round. Nice people are blue. Anne is green. Dave is white. Charlie is nice. Charlie is white. Charl

In [10]:
# Corrupt the whole test split; the evaluation cell below runs the model on it.
corrupted_dataset = generate_corrupt_examples(dataset['test'])

In [13]:
import os

from transformers import AutoTokenizer, BertForSequenceClassification
import evaluate
from huggingface_hub import login

# SECURITY: never store an access token in the notebook. The token that used to be
# hard-coded here has been exposed and should be revoked on the Hugging Face Hub.
# The token is now read from the HF_TOKEN environment variable; when it is not set,
# login() receives None and prompts for the token interactively instead.
login(token=os.environ.get("HF_TOKEN"), add_to_git_credential=True)

In [15]:
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(model_name)

# Define the metric
metric = evaluate.load("accuracy")

# Check if CUDA is available and move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Function to compute predictions
def compute_predictions(batch):
    # Move inputs to the same device as the model
    # The line below now includes padding='max_length' and truncation=True
    # The max_length parameter is added to control the sequence length
    inputs = tokenizer(batch['input'], padding='max_length', truncation=True, max_length=512, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Return a dictionary containing the logits and move them back to CPU
    return {"logits": outputs.logits.cpu().numpy()}

# Compute predictions and evaluate
predictions = corrupted_dataset.map(compute_predictions, batched=True, batch_size=16)

predicted_labels = np.argmax(predictions["logits"], axis=-1)
true_labels = corrupted_dataset["test"]["label"]
true_labels = np.argmax(true_labels, axis=1)  # Convert to class indices

accuracy_result = metric.compute(predictions=predicted_labels, references=true_labels)

print(f"Accuracy on the test set: {accuracy_result['accuracy']}")

Map:   1%|▏         | 64/4618 [01:17<1:31:48,  1.21s/ examples]


KeyboardInterrupt: 