In [3]:
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import datetime
import os
import sys

In [5]:
from discopro.grammar import tensor
from discopro.anaphora import connect_anaphora_on_top
from lambeq import BobcatParser, NumpyModel, AtomicType, Rewriter, Dataset, QuantumTrainer, SPSAOptimizer , AtomicType, IQPAnsatz, RemoveCupsRewriter, UnifyCodomainRewriter, BinaryCrossEntropyLoss
from lambeq.backend.grammar import Spider, Ty
from lambeq.backend.quantum import Box, qubit, SelfConjugate, Ry, Diagram
from contextuality.model import Model, Scenario, CyclicScenario

In [6]:
# Atomic pregroup types used by the ansatz below.
N = AtomicType.NOUN
S = AtomicType.SENTENCE
P = AtomicType.PREPOSITIONAL_PHRASE

# Sentence -> string-diagram parser.
parser = BobcatParser()

# Diagram simplification: sent2dig removes cups first, then applies these rules.
rewriter = Rewriter([
    'auxiliary', 'connector', 'coordination', 'determiner',
    'object_rel_pronoun', 'subject_rel_pronoun',
    'postadverb', 'preadverb', 'prepositional_phrase',
])
remove_cups = RemoveCupsRewriter()

# One qubit per atomic type, a single IQP layer, 3 single-qubit rotation params.
ansatz = IQPAnsatz(
    {N: 1, S: 1, P: 1},
    n_layers=1,
    n_single_qubit_params=3,
)

In [12]:
def sent2dig(sentence1: str, sentence2: str, pro: str, ref: str, mode='none'):
    """Parse two sentences into one diagram with the pronoun wired to its referent.

    Both sentences are parsed with the module-level ``parser`` and tensored
    side by side. Their outputs can optionally be merged. The pronoun box is
    connected to the referent box with ``connect_anaphora_on_top``. Cups are
    then removed, the module-level ``rewriter`` is applied, and the normal form
    is returned.

    Args:
        sentence1: First sentence.
        sentence2: Second sentence.
        pro: Pronoun word; matched case-insensitively against box names.
        ref: Referent word; matched case-insensitively against box names.
        mode: How the two sentence outputs are combined:
            ``'none'`` keeps both outputs,
            ``'spider'`` merges them with a 2->1 ``Spider`` on ``S``,
            ``'box'`` applies ``UnifyCodomainRewriter(Ty('s'))``.

    Returns:
        The rewritten diagram in normal form.

    Raises:
        ValueError: if ``mode`` is unknown, or if no box matches ``pro`` or
            ``ref``. A bare ``StopIteration`` from ``next()`` has an empty
            message, which made failures unreadable in the error logs.
    """
    if mode not in ('none', 'spider', 'box'):
        raise ValueError(f"Unknown mode {mode!r}; expected 'none', 'spider' or 'box'")

    diagram = tensor(parser.sentence2diagram(sentence1),
                     parser.sentence2diagram(sentence2))

    if mode == 'spider':
        diagram = diagram >> Spider(S, 2, 1)
    elif mode == 'box':
        diagram = UnifyCodomainRewriter(Ty('s'))(diagram)

    def find_box(word: str) -> int:
        # Index of the first box whose name matches `word`, ignoring case.
        target = word.casefold()
        for idx, box in enumerate(diagram.boxes):
            if box.name.casefold() == target:
                return idx
        raise ValueError(
            f"No box named {word!r} in diagram for {sentence1!r} / {sentence2!r}")

    pro_box_idx = find_box(pro)
    ref_box_idx = find_box(ref)
    final_diagram = connect_anaphora_on_top(diagram, pro_box_idx, ref_box_idx)
    return rewriter(remove_cups(final_diagram)).normal_form()

In [19]:
def gen_labels(path: str, verbose=False, limit=10):
    """Build ansatz circuits, target labels and diagrams for one CSV split.

    Reads ``path`` with its first column as the index and keeps the first
    ``limit`` rows. Each row's (sentence1, sentence2, pronoun, referent) is
    turned into a diagram with ``sent2dig`` and into a circuit with the
    module-level ``ansatz``. Rows that fail are skipped. Their error,
    prefixed with the row index, is written to
    ``err_logs/log_<file stem>_<timestamp>.txt``.

    Args:
        path: CSV file with columns sentence1, sentence2, pronoun, referent.
        verbose: If True, also echo each error to stderr.
        limit: Number of leading rows to process; ``None`` processes every
            row. Defaults to 10, matching the previously hard-coded truncation.

    Returns:
        ``(circuits, labels, diagrams)``: three lists aligned by position, one
        entry per row that produced both a diagram and a circuit.
    """
    df = pd.read_csv(path, index_col=0)
    if limit is not None:
        df = df.iloc[:limit]

    log_dir = 'err_logs'
    os.makedirs(log_dir, exist_ok=True)
    stem = os.path.splitext(os.path.basename(path))[0]
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
    log_path = os.path.join(log_dir, f"log_{stem}_{timestamp}.txt")

    circuits, labels, diagrams = [], [], []
    # `with` guarantees the log is closed even if the loop is interrupted.
    with open(log_path, 'w') as log_file:
        for i, row in tqdm(df.iterrows(), total=len(df), position=0, leave=True):
            sent1, sent2, pro, ref = row[['sentence1', 'sentence2', 'pronoun', 'referent']]
            try:
                diagram = sent2dig(sent1.strip(), sent2.strip(), pro.strip(), ref.strip())
                circ = ansatz(diagram)
            except Exception as err:
                msg = f"Error (row {i}): {err}".strip()
                tqdm.write(msg, file=log_file)
                if verbose:
                    tqdm.write(msg, file=sys.stderr)
                continue
            # Append only after both steps succeed, so the three lists stay aligned.
            diagrams.append(diagram)
            circuits.append(circ)
            # Every example currently gets the same uniform 2x2 target (a fresh list each time).
            labels.append([[0.25, 0.25], [0.25, 0.25]])

    return circuits, labels, diagrams

In [20]:
# Build (circuits, labels, diagrams) for each split of the original dataset.
DATA_DIR = 'dataset/original_data'
train_circuits, train_labels, train_diagrams = gen_labels(f'{DATA_DIR}/train.csv')
val_circuits, val_labels, val_diagrams = gen_labels(f'{DATA_DIR}/val.csv')
test_circuits, test_labels, test_diagrams = gen_labels(f'{DATA_DIR}/test.csv')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.51it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.64it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:03<00:00,  2.55it/s]


In [76]:
# One model built from the circuits of every split (presumably so that every
# symbol appearing in any split is known to the model).
model = NumpyModel.from_diagrams(
    train_circuits + val_circuits + test_circuits,
    use_jit=False,
)
# NOTE(review): the loss uses jax while the model runs with use_jit=False;
# confirm this pairing is intended.
loss = BinaryCrossEntropyLoss(use_jax=True)
def acc(y_hat, y):
    """Scaled RMS error between predictions and targets (lower is better).

    Despite its name this is an error, not an accuracy:
    ``sqrt(mean((y_hat - y) ** 2) / 2)``. Both inputs are coerced with
    ``np.asarray``, so nested lists work as well as arrays.
    """
    diff = np.asarray(y_hat) - np.asarray(y)
    return np.sqrt(np.mean(diff ** 2) / 2)

In [28]:
# Training configuration.
# SEED can be pinned with the SEED environment variable. Otherwise a random
# seed in [0, 1000] is drawn. Either way it is printed so the run can be reproduced.
SEED = int(os.environ.get('SEED', random.randint(0, 1000)))
print(f"SEED = {SEED}")
BATCH_SIZE = 20
EPOCHS = 100

In [29]:
# Wrap each split in a lambeq Dataset.
# Only the training split gets an explicit batch size; the validation split
# explicitly requests shuffling; the test split uses the defaults.
train_dataset = Dataset(train_circuits, train_labels,
                        batch_size=BATCH_SIZE)
val_dataset = Dataset(val_circuits, val_labels,
                      shuffle=True)
test_dataset = Dataset(test_circuits,
                       test_labels)

In [30]:
# SPSA-based trainer. The 'A' hyperparameter scales with the epoch count.
# The "err" evaluation uses `acc`, which is an error metric (lower is better).
# NOTE(review): the trainer captures whatever `model` is bound when this cell runs.
trainer = QuantumTrainer(
    model,
    optimizer=SPSAOptimizer,
    optim_hyperparams={'a': 0.1, 'c': 0.06, 'A': 0.01 * EPOCHS},
    loss_function=loss,
    evaluate_functions={"err": acc},
    evaluate_on_train=True,
    epochs=EPOCHS,
    seed=SEED,
    verbose='text',
)

In [31]:
print("Learning parameters: " + datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S"))
trainer.fit(train_dataset, val_dataset, eval_interval=1, log_interval=1)
# Score the model the trainer actually optimised. Another cell may have rebound
# the global `model` after the trainer was built.
test_acc = acc(trainer.model(test_dataset.data), test_dataset.targets)
# `acc` is an error metric (lower is better), so label it as such.
print('Test error:', test_acc)

2024-07-03_13_14_30


Epoch 1:    train/loss: 1.1192   valid/loss: 1.1773   train/err: 0.1678   valid/err: 0.1566
Epoch 2:    train/loss: 0.9870   valid/loss: 1.4101   train/err: 0.1674   valid/err: 0.1707
Epoch 3:    train/loss: 0.9530   valid/loss: 1.2091   train/err: 0.1739   valid/err: 0.1586
Epoch 4:    train/loss: 0.9837   valid/loss: 1.4382   train/err: 0.1555   valid/err: 0.1697
Epoch 5:    train/loss: 0.9521   valid/loss: 1.1242   train/err: 0.1369   valid/err: 0.1350
Epoch 6:    train/loss: 0.9174   valid/loss: 1.1377   train/err: 0.1486   valid/err: 0.1616
Epoch 7:    train/loss: 0.8519   valid/loss: 1.1717   train/err: 0.1406   valid/err: 0.1672
Epoch 8:    train/loss: 0.9551   valid/loss: 1.1002   train/err: 0.1505   valid/err: 0.1495
Epoch 9:    train/loss: 1.0021   valid/loss: 1.4796   train/err: 0.1675   valid/err: 0.1816
Epoch 10:   train/loss: 1.0625   valid/loss: 1.4797   train/err: 0.1699   valid/err: 0.1835
Epoch 11:   train/loss: 1.1152   valid/loss: 1.2572   train/err: 0.1934   valid/

Test accuracy: 0.14306237778219807


Epoch 100:  train/loss: 0.7508   valid/loss: 1.0864   train/err: 0.1533   valid/err: 0.1547

Training completed!


In [113]:
def train(trainer: QuantumTrainer, EPOCH_ARR: list[int], BATCH_ARR: list[int], SEED_N: int,
          train_dataset: Dataset, val_dataset: Dataset, test_dataset: Dataset):
    """Grid-search epochs x batch size x seeds, printing one result row per run.

    Each run starts from a clean, seed-determined state, so runs are
    independent and the printed seed actually governs the result:

    * ``random`` and ``np.random`` are seeded with the run's seed, and
      ``trainer.seed`` is set to match;
    * the model's weights are re-drawn with ``initialise_weights()``;
    * a fresh optimizer of the trainer's current optimizer class is built with
      ``{'a': 0.1, 'c': 0.06, 'A': 0.01 * epochs}``. The optimizer instance
      already exists on the trainer, so reassigning
      ``trainer.optim_hyperparams`` alone is not expected to reach it.
      NOTE(review): confirm against the installed lambeq version.

    The trainer's verbosity is set to ``'suppress'`` for the search and
    restored afterwards. The score column is the module-level ``acc`` on
    ``test_dataset``, an error metric (lower is better).

    Args:
        trainer: Trainer whose model, loss function and optimizer class are reused.
        EPOCH_ARR: Epoch counts to try.
        BATCH_ARR: Training batch sizes to try.
        SEED_N: Number of distinct seeds sampled from ``range(1000)``. The same
            seeds are reused for every (epochs, batch size) pair.
        train_dataset: Training data; its ``batch_size`` is overwritten per run.
        val_dataset: Validation data passed to ``fit``.
        test_dataset: Data whose ``.data``/``.targets`` are scored after each run.
    """
    seeds = random.sample(range(1000), SEED_N)
    model = trainer.model
    optimizer_cls = type(trainer.optimizer)
    previous_verbose = trainer.verbose
    trainer.verbose = 'suppress'  # previously misspelled 'supress'

    try:
        print("%0s %23s %7s %7s  %12s" % ("Time", "Epochs", "Batch", "Seed", "Error"))
        for epochs in EPOCH_ARR:
            for batch_size in BATCH_ARR:
                for seed in seeds:
                    # Reset RNGs and weights so each run is independent of the last.
                    random.seed(seed)
                    np.random.seed(seed)
                    trainer.seed = seed
                    model.initialise_weights()

                    hyperparams = {'a': 0.1, 'c': 0.06, 'A': 0.01 * epochs}
                    trainer.epochs = epochs
                    trainer.optim_hyperparams = hyperparams
                    # Rebuild the optimizer so the new hyperparameters take effect
                    # and its internal iteration state starts from scratch.
                    trainer.optimizer = optimizer_cls(model=model,
                                                      loss_fn=trainer.loss_function,
                                                      hyperparams=hyperparams)
                    train_dataset.batch_size = batch_size

                    time = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
                    print("%0s %8s %7s %7s" % (time, epochs, batch_size, seed), end='')
                    trainer.fit(train_dataset, val_dataset, eval_interval=1, log_interval=1)
                    test_acc = acc(model(test_dataset.data), test_dataset.targets)
                    print("%14s" % (round(test_acc, 6)))
    finally:
        trainer.verbose = previous_verbose

In [146]:
class data_scenario:
    """Build and plot empirical contextuality models for anaphora examples.

    For each data row, the (sentence1, sentence2, pronoun, referent) example is
    compiled into a circuit. Four measurement settings are then appended as
    ``Ry`` gates on qubits 0 and 1. The model's output for each setting is
    collected into a ``Model`` over ``scenario``.

    Attributes:
        scenario: Scenario passed to every empirical ``Model``.
        data: DataFrame with columns sentence1, sentence2, pronoun, referent.
        model: lambeq model used to evaluate the circuits.
        emp_models: Empirical models produced by the last ``get_models`` call.
    """

    def __init__(self, scenario, data, model):
        self.scenario = scenario
        self.data = data
        self.model = model
        try:
            # Only draw random weights when the model has none. Calling
            # initialise_weights() unconditionally would overwrite weights
            # loaded from a checkpoint.
            if len(self.model.weights) == 0:
                self.model.initialise_weights()
        except Exception as err:
            print("Model is not valid or none provided: ", err)
        self.emp_models = []

    def gen_emp(self, s1, s2, p, n):
        """Return the empirical ``Model`` for one example, or None if it fails.

        Args:
            s1, s2: The two sentences.
            p: The pronoun.
            n: The referent noun.
        Errors from building the circuit are reported on stderr, and None is returned.
        """
        try:
            # Strip inputs as gen_labels does, so stray whitespace cannot break box matching.
            diag = ansatz(sent2dig(s1.strip(), s2.strip(), p.strip(), n.strip()))
        except Exception as err:
            tqdm.write(f"Error: {err}".strip(), file=sys.stderr)
            return None
        # (qubit-0 angle, qubit-1 angle) for the XX, XZ, ZX, ZZ settings, in that order.
        # NOTE(review): check the phase unit lambeq's Ry expects (radians vs.
        # fractions of a turn). These values look like radian CHSH angles.
        settings = [
            (0, np.pi/8),
            (0, 3*np.pi/8),
            (np.pi/4, np.pi/8),
            (np.pi/4, 3*np.pi/8),
        ]
        dist = [
            self.model.get_diagram_output(
                [diag.apply_gate(Ry(angle0), 0).apply_gate(Ry(angle1), 1)]
            )[0].flatten()
            for angle0, angle1 in settings
        ]
        return Model(self.scenario, dist)

    def get_models(self):
        """Rebuild ``emp_models`` from every row of ``data`` and return it.

        The list is reset first, so re-running the cell does not duplicate models.
        Rows whose circuit cannot be built are skipped.
        """
        self.emp_models = []
        for _, row in tqdm(self.data.iterrows(), total=len(self.data)):
            emp_model = self.gen_emp(*row[['sentence1', 'sentence2', 'pronoun', 'referent']])
            # Test for None explicitly rather than relying on the Model's truthiness.
            if emp_model is not None:
                self.emp_models.append(emp_model)
        return self.emp_models

    def plot_cnxt(self):
        """Scatter signalling fraction against CbD direct influence for ``emp_models``."""
        sf, di = [], []
        for mod in self.emp_models:
            sf.append(mod.signalling_fraction())
            di.append(mod.CbD_direct_influence())
        fig, ax = plt.subplots()
        ax.scatter(sf, di)
        # Reference lines at direct influence = 2 and signalling fraction = 1/6.
        # NOTE(review): document the source of these thresholds.
        ax.axhline(y=2, color='g', linestyle='-')
        ax.axvline(x=1/6, color='g', linestyle='-')
        ax.set(xlabel='signalling fraction',
               ylabel='CbD direct influence',
               title='Empirical models: signalling vs. direct influence')
        plt.show()

In [147]:
# Empirical contextuality models on the test split, evaluated with a trained checkpoint.
# NOTE(review): the checkpoint path is specific to one machine and run; update it for yours.
CHECKPOINT_PATH = 'runs/Jul02_12-46-52_smew.cs.ucl.ac.uk/best_model.lt'
TEST_CSV = 'dataset/original_data/test.csv'

test_scenario = data_scenario(
    scenario=CyclicScenario(['a', 'b', 'A', 'B'], 2),
    # index_col=0 matches how gen_labels reads the same file.
    data=pd.read_csv(TEST_CSV, index_col=0),
    model=NumpyModel.from_checkpoint(CHECKPOINT_PATH),
)
tmp = test_scenario  # alias kept for later cells that refer to `tmp`

In [None]:
# Build one empirical model per test row. The returned list is displayed as the cell output.
tmp.get_models()

In [None]:
# Scatter signalling fraction vs. CbD direct influence for the empirical models built by get_models.
tmp.plot_cnxt()