In [None]:
!pip install --upgrade pip
!pip install pytket==2.0.1 pytket-qiskit lambeq==0.4.3 discopy>=1.1.0
# !pip install qiskit-aer-gpu-cu11
!pip install qiskit-aer

In [None]:
import os
import re
import sys
import json
import pickle
import contextlib
import numpy as np
from tqdm import tqdm
from typing import List, Tuple, Dict

# from qiskit.providers.aer import AerSimulator
from qiskit_aer import AerSimulator
from pytket.extensions.qiskit import AerBackend


from lambeq.backend.grammar import Diagram, Id
from lambeq import ( BobcatParser, AtomicType, IQPAnsatz,
                RemoveCupsRewriter, SimpleRewriteRule, Rewriter, UnifyCodomainRewriter,
                QuantumTrainer, TketModel, BinaryCrossEntropyLoss, SPSAOptimizer,
                Dataset)

from lambeq import ( BobcatParser, AtomicType, IQPAnsatz,
                RemoveCupsRewriter, SimpleRewriteRule, Rewriter, UnifyCodomainRewriter )

In [None]:
# Matches every character that is not a word character, whitespace, or an
# apostrophe; sentence_simplify() deletes whatever this pattern matches.
_CLEAN_RE = re.compile(r"[^\w\s']")
# Shared parser instance used by sent2diagrams().
# NOTE(review): device=0 presumably selects the first GPU — confirm against the lambeq docs.
parser    = BobcatParser(verbose='text', device=0)

# Shared ansatz used by quantum_encode(). The mapping assigns 1 to the sentence
# and noun types and 0 to the prepositional-phrase type (qubits per wire in
# lambeq's ansatz API), with 2 layers and 3 single-qubit parameters.
ansatz    = IQPAnsatz(
    {AtomicType.SENTENCE: 1,
     AtomicType.NOUN:     1,
     AtomicType.PREPOSITIONAL_PHRASE: 0},
    n_layers=2, n_single_qubit_params=3
)


def create_rewriter():
    """Build the diagram rewriter used by the preprocessing pipeline.

    The rewriter is created from a list of lambeq's named built-in rules and
    then extended with three ``SimpleRewriteRule`` instances, keyed on the
    prepositional-phrase, punctuation and conjunction atomic types.

    Returns:
        Rewriter: the configured rewriter.
    """
    builtin_rule_names = [
        'coordination', 'determiner',
        'postadverb', 'preadverb',
        'connector', 'auxiliary',
        'prepositional_phrase',
        'subject_rel_pronoun',
        'object_rel_pronoun',
    ]
    combined = Rewriter(builtin_rule_names)

    # Each custom rule names the codomain it targets and the identity wire
    # that serves as its replacement template.
    pp_rule = SimpleRewriteRule(
        cod=AtomicType.PREPOSITIONAL_PHRASE,
        template=Id(AtomicType.NOUN),
    )
    punctuation_rule = SimpleRewriteRule(
        cod=AtomicType.PUNCTUATION,
        template=Id(AtomicType.PUNCTUATION),
    )
    conjunction_rule = SimpleRewriteRule(
        cod=AtomicType.CONJUNCTION,
        template=Id(AtomicType.CONJUNCTION),
    )

    combined.add_rules(pp_rule, punctuation_rule, conjunction_rule)
    return combined

# Module-level pipeline stages; normalize() applies them in the order
# rewriter -> remove_cups -> unify.
rewriter = create_rewriter()
remove_cups = RemoveCupsRewriter()
# Configured with the sentence type as its output type.
unify = UnifyCodomainRewriter(output_type=AtomicType.SENTENCE)

In [None]:
def brbrpatapim(filename, left, right):
    """Load a JSONL dataset, take the slice ``[left:right]`` and encode it.

    Args:
        filename: path passed to ``load_cnn_extractive``.
        left: start index of the slice (inclusive).
        right: end index of the slice (exclusive).

    Returns:
        tuple: the pair returned by ``preprocess_and_encode`` for the slice
        (encoded records, error messages).
    """
    records = load_cnn_extractive(filename)
    encoded, errors = preprocess_and_encode(records[left:right])
    print("\nDone!")
    return encoded, errors

# Encode records 35..39 (slice [35:40]) of train.jsonl.
# NOTE(review): load_cnn_extractive() and preprocess_and_encode() are defined in
# later cells, so on a fresh kernel those cells must run before this one.
train, errs_train = brbrpatapim("datasets/cnn_dailymail/train.jsonl", 35, 40)

In [None]:
#################################################
#### Mainly to test if GPU works with:       ####
#### sim.run(tk_circ, n_shots=1) and         ####
#### backend.run_circuit(tk_circ, n_shots=1) ####
#################################################

# def brbrpatapim(filename, left, right):
#     dataset    = load_cnn_extractive(filename)
#     dst        = dataset[left:right]
#     dste, errs = preprocess_and_encode(dst)
#     print("Done!")
#     return dste, errs

# a, b = brbrpatapim("datasets/cnn_dailymail/test.jsonl", 10, 11)
# for i, circ in enumerate(a[:]):
#   d = circ['circuits']
#   l = len()
#   print(f"{i} circuit: {l}")
# # with open("saves/test_debug_10_25.pkl", 'wb') as f:
# #         pickle.dump({'dste': a}, f)

def load_cnn_extractive(file_path):
    """Load extractive-summarisation records from a JSON Lines file.

    Every line is parsed on its own. A line is skipped when it is not valid
    JSON, or when it is valid JSON but not an object (``null``, a list, a
    number, ...), since such a value has no fields to read. A record with
    neither ``text`` nor ``summary`` is reported on stdout and skipped.

    Args:
        file_path: path of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        list[dict]: one dict per kept record with the keys ``text``,
        ``summary``, ``text_sentences``, ``summary_sentences`` and ``labels``;
        a key missing from the record is stored as ``None``.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Malformed or blank line: skip it.
                continue

            if not isinstance(record, dict):
                # Without this guard `record.get` raises AttributeError and
                # aborts the whole load on a single stray line.
                continue

            text = record.get("text")
            summary = record.get("summary")
            if text is None and summary is None:
                print(f"{i}. NULL INSTANCE\ntext: {text}\n{summary}\n")
                continue

            data.append({
                'text': text,
                'summary': summary,
                'text_sentences': record.get("text_sentences"),
                'summary_sentences': record.get("summary_sentences"),
                'labels': record.get("labels"),
            })
    return data


def load_qe(path="saves/qe.pkl"):
    """Load pickled encoded data and report how many circuits it holds.

    The file must unpickle to a mapping with a ``'dste'`` entry, which is
    iterated as a sequence of dicts that each have a ``'circuits'`` list.
    The type of the first stored circuit (if there is one) and the total
    circuit count are printed.

    Args:
        path: pickle file to read; defaults to ``"saves/qe.pkl"``.

    Returns:
        The ``'dste'`` value from the file. An empty save is returned as-is
        instead of raising ``IndexError``.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point this at files produced by this project.
    with open(path, 'rb') as f:
        loaded_data = pickle.load(f)

    dste = loaded_data['dste']
    length = sum(len(d['circuits']) for d in dste)

    # Show the type of the first available circuit; an empty save or an empty
    # first entry must not abort the load.
    for d in dste:
        if d['circuits']:
            print(type(d['circuits'][0]))
            break

    print(f"Successfully Loaded {length} circuits!")
    return dste

# Load the saved encodings when this cell runs.
es = load_qe()

# Smoke test: run the first stored circuit and print the "device" entry of the
# result metadata ("unknown" when that key is absent).
# NOTE(review): `simulator` is not defined in the visible cells — confirm it is
# created (e.g. from the imported AerSimulator) before this cell runs.
result = simulator.run(es[0]['circuits'][0]).result()
print(result.metadata.get("device", "unknown"))


In [None]:
#############################################
# Step 2. Preprocessing and lambeq Pipeline
#############################################

def remove_by_idx(ls, remove):
    """Return ``ls`` without the elements at the positions listed in ``remove``.

    Neither argument is modified: the caller's index list is not re-sorted and
    the caller's data list keeps all of its elements. Every index refers to a
    position in the original ``ls``; a repeated index removes that element
    once. Negative indices count from the end.

    Args:
        ls: the sequence to filter.
        remove: iterable of integer positions to drop; may be empty or None.

    Returns:
        ``ls`` itself when ``remove`` is empty, otherwise a new list holding
        the surviving elements in their original order.

    Raises:
        IndexError: if an index lies outside ``ls``.
    """
    if not remove:
        return ls

    size = len(ls)
    doomed = set()
    for idx in remove:
        if not -size <= idx < size:
            raise IndexError(f"index {idx} out of range for list of length {size}")
        # Normalise negative indices so that -1 and size-1 count as the same slot.
        doomed.add(idx % size)

    return [item for pos, item in enumerate(ls) if pos not in doomed]


def sentence_simplify(sentences):
    """Delete every character matched by ``_CLEAN_RE`` from each sentence.

    ``None`` entries are skipped, so the result can be shorter than the input.

    Args:
        sentences: iterable of strings (or ``None`` placeholders).

    Returns:
        list[str]: the cleaned sentences, in input order.
    """
    cleaned = []
    for sentence in sentences:
        if sentence is None:
            continue
        cleaned.append(_CLEAN_RE.sub("", sentence))
    return cleaned


def sent2diagrams(sentences):
    """Parse each sentence into a diagram with the module-level ``parser``.

    The parser is called with ``suppress_exceptions=True``; a sentence for
    which it returns ``None`` is left out and its position is recorded.
    A one-line summary of the losses is printed.

    Args:
        sentences: sequence of untokenised sentence strings.

    Returns:
        tuple: ``(diagrams, none_idx)`` — the diagrams that were produced, and
        the positions in ``sentences`` that yielded ``None``.
    """
    diagrams = []
    none_idx = []
    for position, sentence in enumerate(sentences):
        parsed = parser.sentence2diagram(sentence, tokenised=False, suppress_exceptions=True)
        if parsed is not None:
            diagrams.append(parsed)
        else:
            none_idx.append(position)

    print(f"Whilst parsing sentences2diagrams, lost {len(sentences) - len(diagrams)} due to Null, out of {len(sentences)}.")

    return diagrams, none_idx


def normalize(sentence_diagrams):
    """Rewrite and normalise each diagram, dropping the ones that fail.

    Each diagram goes through the module-level ``rewriter``, then
    ``remove_cups``, ``normal_form()``, ``pregroup_normal_form()`` and
    ``unify``. A diagram is dropped when the rewriter returns ``None`` or when
    any stage raises. A summary of the drops is printed.

    Note that the ``drop_cups`` counter in the summary counts every exception,
    whichever stage raised it.

    Args:
        sentence_diagrams: sequence of diagrams to normalise.

    Returns:
        tuple: ``(diagrams_normalized, none_idx, errs)`` — the surviving
        diagrams, the positions of the dropped ones, and one message per
        caught exception.
    """
    diagrams_normalized = []
    none_idx = []
    errs = []
    drop_rewrite = 0
    drop_cups = 0

    for position, diagram in enumerate(sentence_diagrams):
        try:
            rewritten = rewriter(diagram)
            if rewritten is not None:
                finished = unify(
                    remove_cups(rewritten).normal_form().pregroup_normal_form()
                )
        except Exception as e:
            none_idx.append(position)
            errs.append(f"{e} ( normalize() )")
            drop_cups += 1
        else:
            if rewritten is None:
                none_idx.append(position)
                drop_rewrite += 1
            else:
                diagrams_normalized.append(finished)

    print(f"Dropped {drop_rewrite} diagrams in rewrite, {drop_cups} in cup removal, out of {len(sentence_diagrams)}.")

    return diagrams_normalized, none_idx, errs

def quantum_encode(diagrams: "List"):
    """Apply the module-level ``ansatz`` to every diagram.

    A diagram for which the ansatz raises is left out; its position and the
    exception text are recorded instead.

    Args:
        diagrams: sequence of diagrams to convert.

    Returns:
        tuple: ``(encoded_diagrams, remove, errs)`` — the circuits produced,
        the positions of the diagrams that failed, and one message per failure.
    """
    encoded_diagrams = []
    remove = []
    errs = []
    for position, diagram in enumerate(diagrams):
        try:
            circuit = ansatz(diagram)
        except Exception as e:
            errs.append(f"{e} ( quantum_encode() )")
            remove.append(position)
        else:
            encoded_diagrams.append(circuit)

    return encoded_diagrams, remove, errs

def will_train(
    circuits,
    qubit_limit: int = 40,
    mem_limit_bytes: int = 7 * 2**30,   # 7 GiB
):
    """Keep only the circuits that pass the size checks and a one-shot trial run.

    Each circuit is converted with ``to_tk()``, checked against the qubit and
    memory budgets, compiled with ``comp_pass``, has its free symbols bound to
    0.0, and is run once on ``simulator``. Any exception along the way marks
    the circuit invalid.

    NOTE(review): ``comp_pass`` and ``simulator`` are module-level objects that
    are not defined in the visible part of the notebook — confirm they exist
    before this function is called.

    Args:
        circuits: sequence of objects exposing ``to_tk()``.
        qubit_limit: largest number of qubits a circuit may use.
        mem_limit_bytes: largest allowed estimate of ``16 * 2**n_qubits``
            bytes (presumably one 16-byte complex amplitude per basis state —
            TODO confirm).

    Returns:
        tuple: ``(valid, invalid_idxs, errs)`` — the original (unconverted)
        circuits that passed, the positions of those that did not, and one
        message per rejected circuit.
    """
    valid, invalid_idxs, errs = [], [], []
    for idx, circ in enumerate(circuits):
        try:
            tk_circ = circ.to_tk()

            # Enforce the qubit budget; the parameter used to be accepted but
            # was never checked.
            if tk_circ.n_qubits > qubit_limit:
                raise RuntimeError(f"Needs {tk_circ.n_qubits} qubits > limit {qubit_limit} qubits. ")

            needed = 16 * (2 ** tk_circ.n_qubits)
            if needed > mem_limit_bytes:
                raise RuntimeError(f"Needs {needed} bytes > limit {mem_limit_bytes} bytes ({needed/2**20:.0f} MiB > {(mem_limit_bytes/2**20):.0f} MiB). ")

            comp_pass.apply(tk_circ)

            # Bind every free symbol to 0.0 so the trial run gets a concrete
            # circuit.
            syms = tk_circ.free_symbols()
            if syms:
                bind_map = {s: 0.0 for s in syms}
                tk_circ.symbol_substitution(bind_map)

            # The result is discarded; only whether the call raises matters.
            simulator.run(tk_circ, n_shots=1)
            # backend.run_circuit(tk_circ, n_shots=1)

        except Exception as e:
            invalid_idxs.append(idx)
            errs.append(f"{e} ( will_train() )")
            continue

        valid.append(circ)

    return valid, invalid_idxs, errs



def preprocess_and_encode(dataset):
    """Run every record through the sentence -> diagram -> circuit pipeline.

    For each record the sentences are cleaned, parsed, normalised, encoded and
    trial-run. After each stage the sentences and labels that belong to
    dropped items are removed, so the three output lists stay aligned.

    The record's own ``text_sentences`` and ``labels`` lists are copied first
    and never modified, so the same ``dataset`` can be processed again.
    ``None`` sentences are removed together with their labels before cleaning,
    because ``sentence_simplify`` would otherwise drop them silently and shift
    every later index. A record whose ``text_sentences`` or ``labels`` is
    ``None`` is skipped and reported in the returned errors.

    Output written to stdout/stderr by the stages is discarded while a record
    is processed.

    Args:
        dataset: iterable of dicts with ``'text_sentences'`` and ``'labels'``.

    Returns:
        tuple: ``(encoded_data, errors)`` — one dict per processed record with
        the keys ``'circuits'``, ``'labels'`` and
        ``'original_text_sentences'``, and the collected error messages.
    """
    encoded_data, errors = [], []

    # Open the null device once for the whole run rather than once per record.
    with open(os.devnull, 'w') as devnull:
        # The progress bar is created outside the redirection below.
        for i, ds_dict in enumerate(tqdm(dataset, desc="Filtering and Encoding dataset")):
            with contextlib.redirect_stdout(devnull), \
                 contextlib.redirect_stderr(devnull):

                raw_sentences = ds_dict['text_sentences']
                raw_labels = ds_dict['labels']
                if raw_sentences is None or raw_labels is None:
                    errors.append(f"record {i}: missing text_sentences or labels ( preprocess_and_encode() )")
                    continue

                # Work on copies so the caller's dataset is left untouched.
                text_sentences = list(raw_sentences)
                labels = list(raw_labels)

                # Drop None sentences and their labels up front to keep the
                # stage indices below aligned with both lists.
                missing = [idx for idx, sent in enumerate(text_sentences) if sent is None]
                text_sentences = remove_by_idx(text_sentences, missing)
                labels = remove_by_idx(labels, missing)

                sentences_simplified = sentence_simplify(text_sentences)
                diagrams, remove = sent2diagrams(sentences_simplified)
                text_sentences = remove_by_idx(text_sentences, remove)
                labels = remove_by_idx(labels, remove)

                normalized_diagrams, remove, errs2 = normalize(diagrams)
                text_sentences = remove_by_idx(text_sentences, remove)
                labels = remove_by_idx(labels, remove)

                circuits, remove, errs3 = quantum_encode(normalized_diagrams)
                text_sentences = remove_by_idx(text_sentences, remove)
                labels = remove_by_idx(labels, remove)

                circuits, remove, errs4 = will_train(circuits)
                text_sentences = remove_by_idx(text_sentences, remove)
                labels = remove_by_idx(labels, remove)

                encoded_data.append({'circuits': circuits, 'labels': labels, 'original_text_sentences': text_sentences})

                errors += errs2 + errs3 + errs4

    return encoded_data, errors