In [2]:
from lambeq import BobcatParser, AtomicType, RemoveCupsRewriter, UnifyCodomainRewriter, Rewriter, IQPAnsatz
from lambeq.backend.grammar import Spider
from discopro.anaphora import connect_anaphora_on_top
from discopro.grammar import tensor
import pandas as pd
import numpy as np
import os, sys, datetime, pickle
from tqdm import tqdm

remove_cups = RemoveCupsRewriter()

parser = BobcatParser()
rewriter = Rewriter(['curry'])
# rewriter = Rewriter(['auxiliary',
#                      'connector',
#                      'coordination',
#                      'determiner',
#                      'object_rel_pronoun',
#                      'subject_rel_pronoun',
#                      'postadverb',
#                      'preadverb',
#                      'prepositional_phrase'])

N = AtomicType.NOUN
S = AtomicType.SENTENCE
P = AtomicType.PREPOSITIONAL_PHRASE

ansatz = IQPAnsatz({N: 1, S: 1, P:1}, n_layers=1, n_single_qubit_params=3)

In [3]:
def sent2dig(sentence, pro1, ref, pro2=''):
    diagram = parser.sentence2diagram(sentence)
    #diagram = diagram >> Spider(S, 2, 1)
    pro_box_idx = next(i for i, box in enumerate(diagram.boxes) if (box.name.casefold() == pro1.casefold() or box.name.casefold() == pro2.casefold()))
    ref_box_idx = next(i for i, box in enumerate(diagram.boxes) if box.name.casefold() == ref.casefold())
    diagram = connect_anaphora_on_top(diagram, pro_box_idx, ref_box_idx)
    diagram = rewriter(remove_cups(diagram)).normal_form()
    return diagram

In [4]:
def gen_labels(path: str, verbose=False, frac=1, save=False):
    df = pd.read_pickle(path)
    df = df.sample(frac=frac)
    
    if not os.path.exists(os.getcwd()+'/err_logs'):
        os.mkdir(os.getcwd()+'/err_logs')
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
    f = open("err_logs/log_"+path.split('/')[-1].split('.')[-2]+'_'+timestamp+".txt", 'w')
    
    circuits, labels, diagrams, sentences = [],[],[], []
    for i, row in tqdm(df.iterrows(), total=len(df), position=0, leave=True):
        sentence, pro1, pro2, right_ref, wrong_ref = row[['Sentence', 'Pronoun 1', 'Pronoun 2', 'Right Referent', 'Wrong Referent']]

        try:
            diagram_right = sent2dig(sentence.strip(), pro1.strip(), right_ref.strip(), pro2.strip())
            diagram_wrong = sent2dig(sentence.strip(), pro1.strip(), wrong_ref.strip(), pro2.strip())
            diagrams.append(diagram_right)
            circuits.append(ansatz(diagram_right))
            labels.append([0,1])
            sentences.append(sentence)
            diagrams.append(diagram_wrong)
            circuits.append(ansatz(diagram_wrong))
            labels.append([1,0])
            sentences.append(sentence)
        except Exception as err:
            tqdm.write(f"Error: {err}".strip(), file=f)
            if verbose:
                tqdm.write(f"Error: {err}".strip(), file=sys.stderr)
    f.close()
    if save:
        f = open('wino/data/train_data_'+datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")+'.pkl', 'wb')
        pickle.dump(list(zip(circuits, labels, diagrams, sentences)), f)
        f.close()
    else:
        return circuits, labels, diagrams

In [67]:
gen_labels(path="wino/data/train_sentences.pkl", save=True)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1584/1584 [12:35<00:00,  2.10it/s]


In [10]:
data_test = pd.read_pickle("wino/data/test_data_2024-09-30_151644.pkl")

In [12]:
from random import shuffle 
shuffle(data_test)

In [13]:
from math import floor
test_data = data_test[:floor(len(data_test)*0.4)]
val_data = data_test[floor(len(data_test)*0.4):floor(len(data_test)*0.8)]
unseen_data = data_test[floor(len(data_test)*0.8):]

In [23]:
f = open('wino/data/data_final/unseen_data.pkl', 'wb')
pickle.dump(unseen_data, f)
f.close()

In [3]:
train = pd.read_pickle("wino/data/data_final/train_data.pkl")
test = pd.read_pickle("wino/data/data_final/test_data.pkl")
val = pd.read_pickle("wino/data/data_final/val_data.pkl")
unseen = pd.read_pickle("wino/data/data_final/unseen_data.pkl")

In [5]:
from lambeq import NumpyModel
model = NumpyModel.from_checkpoint('runs/Oct01_16-32-34_smew-l.cs.ucl.ac.uk/best_model.lt')
model.initialise_weights()

In [40]:
df = pd.read_pickle('wino/data/train_sentences.pkl')

In [50]:
match = df[df['Sentence'] == unseen[101][3]] 

In [75]:
df = pd.read_pickle('wino/data/train_sentences.pkl')
results = []
for tup in tqdm(unseen):
    try:
        pred = model.get_diagram_output([tup[0]])[0]
        match = df[df['Sentence'] == tup[3]]
        ans = [match['Pronoun 1'].values[0]]
        if match['Pronoun 2'].values[0]:
            ans.append(match['Pronoun 2'].values[0])
        if tup[1] == [0,1]: 
            ans.append(match['Right Referent'].values[0])
        else:
            ans.append(match['Wrong Referent'].values[0])
        results.append((pred, tup[1], tup[3], ans))
    except:
        pass

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 510/510 [00:11<00:00, 43.15it/s]


In [88]:
pred, ans, sent, con = zip(*results)
pred = np.array(pred)
ans = np.array(ans)

In [90]:
np.mean(abs(pred - ans)**2)

0.33051786460141463

In [91]:
from sklearn.metrics import r2_score
coeff = r2_score(ans, pred)

In [92]:
coeff

-0.3220714584056572

In [93]:
from sklearn.metrics import mean_squared_error
mse = mean_squared_error(ans, pred)

In [94]:
mse

0.3305178646014143