In [14]:
import discopy
import lambeq
from discopy.grammar.pregroup import Ty, Word, Cup, Diagram
from discopy.quantum import Circuit, sqrt, Ket, H, Rx, CX, SWAP,Bra, Controlled, Rz
from pytket.extensions.qiskit import tk_to_qiskit, AerBackend
import numpy as np
from discopy import CircuitFunctor, qubit
from collections import defaultdict
from nltk.tokenize import word_tokenize
from random import shuffle

import random

In [2]:
parser = lambeq.BobcatParser()

In [3]:
corpus = [('Alice loves Bob', 2), ('Alice loves Bob', 0)]

In [4]:
params = defaultdict(lambda:0.1)
def initialize_params():
    for data_point in corpus:
        sen = data_point[0]
        for word in word_tokenize(sen):
            params[word] = random.random()

initialize_params()
params

defaultdict(<function __main__.<lambda>()>,
            {'Alice': 0.8546146963243221,
             'loves': 0.5293797319917231,
             'Bob': 0.40790552080013165})

In [5]:
params

defaultdict(<function __main__.<lambda>()>,
            {'Alice': 0.8546146963243221,
             'loves': 0.5293797319917231,
             'Bob': 0.40790552080013165})

In [16]:
def ansatz_cod_len_1(phase):
    return Ket(0)>>Rz(phase)>>Rx(phase)>>Rz(phase)
def ansatz_cod_len_1_masked():
    return Ket(0)
def ansatz_cod_len_2(phase):
    return Ket(0)>>Rx(phase)
def ansatz_cod_len_2_masked():

    return Ket(0)
def ansatz_cod_len_3(phase):
    return Ket(0,0,0) >> H @ H @ H >> Controlled(Rz(phase)) @ Controlled(Rz(phase))
def ansatz_cod_len_3_masked():
    return Ket(0,0) >> H @ Circuit.id(1) >> CX
def ansatz_cod_len_4(phase):
    return Ket(0, 0, 0) >> Rx(phase) @ Circuit.id(1) @ Circuit.id(1) >> H @ Circuit.id(1) @ Circuit.id(1) >> CX @ Circuit.id(1) >> Circuit.id(1) @ CX


In [11]:
counter = 10000
m = 5000

In [17]:
s, n = Ty('s'), Ty('n')
def cnot_ar(box):
    global counter

    cod = len(box.cod)
    box_name = str(box)
    counter += 1
    if cod == 1 and counter != m:
        return ansatz_cod_len_1(params[box_name])
    elif cod == 1 and counter == m:
        return ansatz_cod_len_1_masked()

    elif cod == 2 and counter != m:
        return ansatz_cod_len_2(params[box_name])
    elif cod == 2 and counter == m:
        return ansatz_cod_len_2_masked()

    elif cod == 3 and counter != m:
        return ansatz_cod_len_3(params[box_name])
    elif cod == 3 and counter == m:
        return ansatz_cod_len_3_masked()

    elif cod == 4 and counter != m:
        return ansatz_cod_len_4(params[box_name])

def generate_functor():
    Func = CircuitFunctor(
    ob={s: qubit ** 1, n: qubit ** 1},
    ar=cnot_ar)
    return Func


In [18]:
circ = generate_functor()(parser.sentence2diagram('Alice loves Bob'))
circ_eval = Circuit.eval(
            circ,
            backend=AerBackend(),
            n_shots=1024,
            seed=1,
            compilation=AerBackend().default_compilation_pass(2))

AxiomError: Ket(0, 0, 0) >> H @ Id(2) >> Id(1) @ H @ Id(1) >> Id(2) @ H does not compose with CRz(0.529) @ Id(2) >> Id(2) @ CRz(0.529): cod=qubit @ qubit @ qubit, dom=qubit @ qubit @ qubit @ qubit.

In [137]:
data = []
def build_data():
    """
    We build a data list in form of: [[sen, masked sen, eval of sen with random params], ...]
    :return:
    """
    for tuple in corpus:
        data_point = []

        #Append the tuple
        data_point.append(tuple)

        #Get the evaluation of the origin sentence
        circ = generate_functor()(parser.sentence2diagram(tuple[0]))
        circ_eval = Circuit.eval(
            circ,
            backend=AerBackend(),
            n_shots=1024,
            seed=1,
            compilation=AerBackend().default_compilation_pass(2))
        data_point.append(np.abs(circ_eval.array))
        data.append(data_point)
build_data()
data

[[('Alice loves Bob', 2), 0.2460937500000001],
 [('Alice loves Bob', 0), 0.11718750000000006]]

In [138]:
test_dict = defaultdict()
def build_test_dict():
    for data_point in corpus:
        sen = data_point[0]
        if sen not in test_dict.keys():
            dia = generate_functor()(parser.sentence2diagram(sen))
            circ_eval = Circuit.eval(
                dia,
                backend=AerBackend(),
                n_shots=1024,
                seed=1,
                compilation=AerBackend().default_compilation_pass(2))
            test_dict[sen] = np.abs(circ_eval.array)
build_test_dict()
test_dict

defaultdict(None, {'Alice loves Bob': 0.11718750000000006})

In [139]:
'''
from pytket.extensions.qiskit import tk_to_qiskit, AerBackend
tk_circ = circuit.to_tk()
tk_to_qiskit(tk_circ).draw()
'''

'\nfrom pytket.extensions.qiskit import tk_to_qiskit, AerBackend\ntk_circ = circuit.to_tk()\ntk_to_qiskit(tk_circ).draw()\n'

In [140]:
epochs = 1500
lr = 0.05

In [141]:
def loss(y_true, y_pred):
    return (y_true - y_pred)**2

In [142]:
def update(loss, updating_params):
    for param in updating_params:
        params[param] = params[param.strip()] + lr * loss

In [143]:
def train():
    for _ in range(epochs):
        for data_point in data:
            sen = data_point[0][0]

            global m
            m=data_point[0][1]

            y_true = data_point[1]

            global counter
            counter = -1
            circ = generate_functor()(parser.sentence2diagram(sen))
            circ_eval = Circuit.eval(
                circ,
                backend=AerBackend(),
                n_shots=1024,
                seed=1,
                compilation=AerBackend().default_compilation_pass(2))
            qugit = np.abs(circ_eval.array)
            ls = loss(y_true, qugit)
            updating_params = []
            for token in word_tokenize(sen):
                if word_tokenize(sen).index(token) != m:
                    updating_params.append(str(token))
            update(ls, updating_params)
            print(ls)


In [144]:
train()

0.0
0.016616821289062514
0.0
0.016616821289062514
0.0
0.016616821289062514
0.0
0.016616821289062514
0.0
0.015625000000000014
0.0
0.015625000000000014
0.0
0.015625000000000014
0.0
0.014663696289062514
1.52587890625e-05
0.013732910156250014
1.52587890625e-05
0.013732910156250014
1.52587890625e-05
0.012832641601562512
1.52587890625e-05
0.011123657226562512
1.52587890625e-05
0.011123657226562512
1.52587890625e-05
0.01031494140625001
6.103515625e-05
0.01031494140625001
6.103515625e-05
0.01031494140625001
6.103515625e-05
0.01031494140625001
6.103515625e-05
0.01031494140625001
6.103515625e-05
0.01031494140625001
6.103515625e-05
0.009536743164062505
0.0001373291015625
0.009536743164062505
0.0001373291015625
0.009536743164062505
0.0001373291015625
0.009536743164062505
0.000244140625
0.009536743164062505
0.000244140625
0.009536743164062505
0.000244140625
0.009536743164062505
0.00054931640625
0.009536743164062505
0.00054931640625
0.009536743164062505
0.00054931640625
0.009536743164062505
0.000549

KeyboardInterrupt: 

In [145]:
params

{'Alice': 0.11031005859375043,
 'loves': 0.14435241699218765,
 'Bob': 0.13404235839843742}

In [161]:
data_test=[('Alice loves Bob', 2)]

In [165]:
def test(data_test):
    for data_point in data_test:
       sen = data_point[0]
       test_sen = generate_functor()(parser.sentence2diagram(sen))
       test_sen_eval = Circuit.eval(test_sen,
                backend=AerBackend(),
                n_shots=1024,
                seed=1,
                compilation=AerBackend().default_compilation_pass(2))
       print(np.abs(test_sen_eval.array) - test_dict[sen])
test(data_test)

0.2851562500000001


In [162]:
generate_functor()(parser.sentence2diagram('Romeo kills kills')).draw()
circ = generate_functor()(parser.sentence2diagram('Romeo kills kills'))
np.abs(Circuit.eval(circ,
                backend=AerBackend(),
                n_shots=1024,
                seed=1,
                compilation=AerBackend().default_compilation_pass(2)).array)

0.2851562500000001


In [264]:
params['Romeo'] = 0.15

In [26]:
generate_functor()(parser.sentence2diagram(' loves loves Bob ')).draw()
circ = generate_functor()(parser.sentence2diagram(' loves loves Bob '))
np.abs(Circuit.eval(circ,
                backend=AerBackend(),
                n_shots=1024,
                seed=1,
                compilation=AerBackend().default_compilation_pass(2)).array)

AxiomError: Ket(0, 0, 0) >> Rx(0.944) @ Id(2) >> H @ Id(2) >> CX @ Id(1) >> Id(1) @ CX does not compose with Id(2) @ Ket(0, 0) >> Id(2) @ H @ Id(1) >> Id(3) @ Rx(0.944) >> Id(2) @ CX: cod=qubit @ qubit @ qubit, dom=qubit @ qubit.

KeyError: 'seal'