In [None]:
import os
import re
import sys
import json
import pickle
import contextlib
import numpy as np
from tqdm import tqdm
from typing import List, Tuple, Dict

from qiskit_aer import AerSimulator
# from pytket.extensions.qiskit import AerBackend # this import form works with qiskit v2.x
from pytket.extensions.qiskit.backends.aer import AerBackend


from lambeq import QuantumTrainer, TketModel, BinaryCrossEntropyLoss, SPSAOptimizer, Dataset
from lambeq import (AtomicType, BobcatParser, IQPAnsatz, RemoveCupsRewriter,
                    Rewriter, SimpleRewriteRule, UnifyCodomainRewriter)
from lambeq.backend.grammar import Diagram, Id

In [11]:
################################################
#### Consts for performance and reusability ####
################################################
# Module-level objects shared by the cells below: `simulator` (raw qiskit-aer
# simulator), `backend` (pytket wrapper) and `comp_pass` (compilation pass).


# Configure the raw AerSimulator for GPU + cuQuantum acceleration.
# NOTE: "single" precision stores amplitudes as 32-bit floats instead of the
# default 64-bit ("double"); this halves statevector memory and is usually
# accurate enough for QNLP workloads.
simulator = AerSimulator(
    method="statevector", device="GPU",
    precision="single",         # 32-bit floats (see note above)
    cuStateVec_enable=True,     # turn on NVIDIA cuStateVec kernels
    batched_shots_gpu=True,     # batch many shots per GPU launch
    batched_shots_gpu_max_qubits=40,
    num_threads_per_device=2    # limit CPU threads per GPU to reduce overhead
)

# Build the pytket AerBackend, then swap in the GPU-configured simulator.
# NOTE(review): `_qiskit_backend` is a private attribute of the pytket backend;
# this override may break on a pytket-qiskit upgrade - confirm after updating.
backend = AerBackend(noise_model=None, simulation_method="statevector")
backend._qiskit_backend = simulator

# Default compilation pass of the backend; the argument 2 is presumably the
# optimisation level - confirm against the pytket documentation.
comp_pass = backend.default_compilation_pass(2)

# Report the devices seen by the configured simulator and by a default one.
print("simulator object available devices:", simulator.available_devices())
print("AerSimulator available devices:", AerSimulator().available_devices())

# AerSimulator().available_methods()

simulator object available devices: ['GPU']
AerSimulator available devices: ('CPU', 'GPU')


In [12]:
#############################################
# 1. Data Loading Functions
#############################################

def flatten(encoded: List[Dict]) -> Tuple[List, List]:
    """Concatenate per-record circuits and labels into two flat lists.

    Args:
        encoded: Records that each provide a 'circuits' and a 'labels' entry.

    Returns:
        A tuple (all_circuits, all_labels), both in input order.
    """
    all_circuits: List = []
    all_labels: List = []
    for record in encoded:
        all_circuits += record['circuits']
        all_labels += record['labels']
    return all_circuits, all_labels


def load_with_pikle(file_path):
    """Load a pickled dataset file and return its 'dste' entry.

    Prints the total number of circuits found across all records.

    Args:
        file_path: Path of a pickle file holding a dict with a 'dste' key.

    Returns:
        The object stored under 'dste' (iterated here as records with a
        'circuits' entry).
    """
    # NOTE: pickle.load can execute arbitrary code - only open trusted files.
    with open(file_path, 'rb') as handle:
        payload = pickle.load(handle)

    encoded = payload['dste']
    length = sum(len(record['circuits']) for record in encoded)

    print(f"Successfully Loaded {length} circuits!")
    return encoded

# Load the pre-encoded train / test / validation splits from their pickle files.
dstF    = load_with_pikle('saves/train/ds/train_550.pkl')
dstestF = load_with_pikle('saves/test/ds/test_80.pkl')
dsvF    = load_with_pikle('saves/val/ds/val_.pkl')

Successfully Loaded 10378 circuits!
Successfully Loaded 1463 circuits!
Successfully Loaded 1427 circuits!


In [13]:
# Keep only a small slice of each split (2 train records, 1 validation
# record, 1 test record) so a training run stays short.
dst    = dstF[:2]
dsv    = dsvF[:1]
dstest = dstestF[:1]

In [24]:
# Show which keys one encoded validation record contains.
for k in dsv[0]:
    print(k)

circuits
labels
original_text_sentences


In [None]:
def initializeDefaultStageConfig(use_multi_stage_config, epochs, shots):
    """Build the default list of training-stage configurations.

    Args:
        use_multi_stage_config: If true, return three stages with increasing
            shot counts (2048, 4096, 8192) and decreasing SPSA step size `a`;
            otherwise return a single stage that uses `shots`.
        epochs: Number of epochs for every stage; also sets the SPSA
            stability constant `A` to 0.2 * epochs.
        shots: Shot count of the single-stage configuration (ignored in
            multi-stage mode).

    Returns:
        A list of dicts with the keys 'epochs', 'shots' and
        'optimizer_hparams'.
    """
    if use_multi_stage_config:
        # (shots, SPSA `a`) per stage
        schedule = [(2048, 0.1), (4096, 0.05), (8192, 0.02)]
        default_stage_config = [
            {
                'epochs': epochs,
                'shots': s,
                'optimizer_hparams': {
                    'a': a,
                    # smaller perturbation `c` once the shot count exceeds 4096
                    'c': 0.06 if s <= 4096 else 0.02,
                    'A': 0.2 * epochs
                }
            }
            for s, a in schedule
        ]
        # Announce the mode once (previously printed once per stage).
        print("Using default multi-stage configuration.")
    else:
        default_stage_config = [{
            'epochs': epochs,
            'shots': shots,
            'optimizer_hparams': {'a': 0.1, 'c': 0.06, 'A': 0.2 * epochs}
        }]
        print("Using single-stage configuration")

    return default_stage_config


def train_quantum_summarizer(
    encoded_train: List[Dict],
    encoded_val:   List[Dict],
    encoded_test:  List[Dict],
    batch_size: int = 5,
    epochs:     int = 5,
    shots:      int = 4096, # 8192
    seed:       int = 42,
    checkpoint_dir:    str = 'saves/model_checkpoints',
    use_multi_stage_config: bool = True,
    stage_config:     List[Dict] = None
) -> Tuple[TketModel, QuantumTrainer, Dict[str, float]]:
    """Train a TketModel in one or more stages and score it on the test set.

    Each stage trains with its own shot count and SPSA hyper-parameters,
    saves a checkpoint, and records the last validation accuracy. Stages
    after the first start from the checkpoint with the best validation
    accuracy seen so far.

    Args:
        encoded_train: Training records with 'circuits' and 'labels' entries.
        encoded_val: Validation records in the same format.
        encoded_test: Test records in the same format.
        batch_size: Default batch size; a stage may override it with a
            'batch_size' key.
        epochs: Epochs per stage of the default configuration.
        shots: Shots of the default single-stage configuration.
        seed: Base seed; stage `idx` uses `seed + idx` unless the stage
            provides its own 'seed'.
        checkpoint_dir: Directory for the per-stage checkpoints (created if
            missing).
        use_multi_stage_config: Selects the multi- or single-stage default
            when `stage_config` is not given.
        stage_config: Optional list of stage dicts with the keys 'epochs',
            'shots', 'optimizer_hparams' and optionally 'batch_size'/'seed'.

    Returns:
        (model, trainer, test_metrics) where model and trainer belong to the
        LAST stage (not necessarily the best one) and test_metrics maps the
        metric name to its value on the test set.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_circuits, train_labels = flatten(encoded_train)
    val_circuits,   val_labels   = flatten(encoded_val)
    test_circuits,  test_labels  = flatten(encoded_test)

    if not stage_config:
        stage_config = initializeDefaultStageConfig(use_multi_stage_config, epochs, shots)

    final_model     = None
    final_trainer   = None
    best_val_acc    = -np.inf
    best_checkpoint = None

    # Common metric & validation dataset (labels are expected one-hot, 2-D).
    def acc_fn(y_hat, y):
        return np.mean(np.argmax(y_hat, 1) == np.argmax(y, 1))

    eval_funcs = {'accuracy': acc_fn}
    val_ds     = Dataset(val_circuits, val_labels, shuffle=False)

    for idx, cfg in enumerate(stage_config, start=1):
        t_epochs = cfg['epochs']
        t_shots  = cfg['shots']
        opt_hp   = cfg['optimizer_hparams']
        batch    = cfg.get('batch_size', batch_size)
        t_seed   = cfg.get('seed', seed + idx)

        print(f"\n=== Stage {idx}: epochs={t_epochs}, shots={t_shots}, batch={batch}, seed={t_seed} ===")

        backend_config = {
            'backend':     backend,
            'compilation': comp_pass,
            'shots':       t_shots,
        }

        checkpoint_path = os.path.join(checkpoint_dir, f'model_stage{idx}.lt')

        if final_model is None:
            # First stage: build from scratch. The model only creates weights
            # for symbols it has seen, so the test circuits must be included
            # here or evaluating them later fails on unknown symbols.
            diagrams = train_circuits + val_circuits + test_circuits
            model = TketModel.from_diagrams(diagrams, backend_config=backend_config)
            # Raises ValueError("Symbols not initialised ...") when the
            # diagrams expose no free symbols.
            model.initialise_weights()
        else:
            # Subsequent stages: resume from the best checkpoint so far.
            model = TketModel.from_checkpoint(best_checkpoint, backend_config=backend_config)

        trainer = QuantumTrainer(
            model,
            loss_function      = BinaryCrossEntropyLoss(),
            optimizer          = SPSAOptimizer,
            optim_hyperparams  = opt_hp,
            evaluate_functions = eval_funcs,
            evaluate_on_train  = True,
            epochs             = t_epochs,
            seed               = t_seed,
            verbose            = 'text'
        )

        # Fit with early stopping on validation accuracy.
        train_ds = Dataset(train_circuits, train_labels, batch_size=batch, shuffle=True)
        trainer.fit(
            train_ds,
            val_ds,
            early_stopping_criterion = 'accuracy',
            early_stopping_interval  = 3,
            minimize_criterion       = False
        )

        # Save checkpoint
        model.save(checkpoint_path)
        print(f"» Saved checkpoint: {checkpoint_path}")

        # Track best validation accuracy. `fit()` returns None; the per-epoch
        # validation metrics are stored on the trainer, keyed by the names
        # used in `evaluate_functions`.
        val_history   = trainer.val_eval_results['accuracy']
        final_val_acc = val_history[-1] if len(val_history) else -np.inf
        if best_checkpoint is None or final_val_acc > best_val_acc:
            best_val_acc    = final_val_acc
            best_checkpoint = checkpoint_path

        # Prepare for next stage
        final_model   = model
        final_trainer = trainer

    # Score the last-stage model on the held-out test circuits by calling the
    # model directly and applying the same metric functions.
    if test_circuits:
        test_predictions = final_model(test_circuits)
        test_metrics = {
            name: float(fn(test_predictions, test_labels))
            for name, fn in eval_funcs.items()
        }
    else:
        test_metrics = {}
    print(f"\nFinal test metrics: {test_metrics}")

    return final_model, final_trainer, test_metrics

In [None]:
# Custom 2-stage strategy: a short low-shot warm-up stage followed by a
# longer high-shot stage with smaller SPSA steps. 'A' is 0.2 x the stage's
# epochs, mirroring the default configuration.
stages = [
  # Stage 1 overrides the batch size with 8.
  {'epochs': 5,  'shots': 1024,
   'optimizer_hparams': {'a':0.15,'c':0.06,'A':0.2*5},
   'batch_size':  8},
  # Stage 2 has no 'batch_size' key, so it falls back to the `batch_size`
  # argument passed below.
  {'epochs': 15, 'shots': 8192,
   'optimizer_hparams': {'a':0.02,'c':0.02,'A':0.2*15}}
]

# `stage_config` takes precedence over `epochs`/`shots`; each stage is seeded
# with `seed + stage index` unless the stage dict provides its own 'seed'.
model, trainer, metrics = train_quantum_summarizer(
  encoded_train   = dst,
  encoded_val     = dsv,
  encoded_test    = dstest,

  batch_size = 5,
  seed       = 123,
  # epochs = 10,
  # shots = 8192,
  stage_config   = stages
)




=== Stage 1: epochs=5, shots=1024, batch=8, seed=124 ===


ValueError: Symbols not initialised. Instantiate through `from_diagrams()`.

In [None]:
#######################################################################################################
#######################################################################################################
#######################################################################################################
#######################################################################################################
#######################################################################################################
#######################################################################################################
#######################################################################################################
#######################################################################################################
#######################################################################################################
#######################################################################################################

# ### Initialisations not needed anymore ### #
# Legacy setup for the raw-text preprocessing functions defined in the cells
# below (parsing, rewriting, circuit encoding). Presumably only required to
# regenerate the pickled datasets - confirm before deleting.

# Matches every character that is not a word character, whitespace or an
# apostrophe; used to strip punctuation from sentences.
_CLEAN_RE = re.compile(r"[^\w\s']")
# NOTE(review): `device=0` presumably selects the first CUDA device - confirm.
parser    = BobcatParser(verbose='text', device=0)

# IQP ansatz configuration: the dict maps each atomic type to a number
# (presumably the qubit count per wire - confirm against the lambeq docs).
ansatz    = IQPAnsatz(
    {AtomicType.SENTENCE: 1,
     AtomicType.NOUN:     1,
     AtomicType.PREPOSITIONAL_PHRASE: 0},
    n_layers=2, n_single_qubit_params=3
)


def create_rewriter():
    """Create the lambeq Rewriter used to simplify sentence diagrams.

    The rewriter is constructed with a list of named built-in rules and then
    extended with three SimpleRewriteRule instances targeting boxes whose
    codomain is a prepositional phrase, punctuation or a conjunction.

    Returns:
        The configured Rewriter instance.
    """
    builtin_rules = [
        'coordination', 'determiner',
        'postadverb', 'preadverb',
        'connector', 'auxiliary',
        'prepositional_phrase',
        'subject_rel_pronoun',
        'object_rel_pronoun',
    ]
    configured = Rewriter(builtin_rules)

    # Prepositional-phrase boxes are templated onto a noun wire.
    # Alternative kept for reference:
    # SimpleRewriteRule(cod=AtomicType.PREPOSITIONAL_PHRASE, template=Id(AtomicType.SENTENCE))
    pp_to_noun = SimpleRewriteRule(
        cod=AtomicType.PREPOSITIONAL_PHRASE,
        template=Id(AtomicType.NOUN),
    )
    # Punctuation boxes (commas, quotes, dashes): just the wire, no box.
    punctuation_to_wire = SimpleRewriteRule(
        cod=AtomicType.PUNCTUATION,
        template=Id(AtomicType.PUNCTUATION),
    )
    # Conjunction boxes ("and", "but"): just the wire, no box.
    conjunction_to_wire = SimpleRewriteRule(
        cod=AtomicType.CONJUNCTION,
        template=Id(AtomicType.CONJUNCTION),
    )

    configured.add_rules(pp_to_noun, punctuation_to_wire, conjunction_to_wire)
    return configured

# Module-level rewriter instances used by normalize(): the rule-based
# rewriter, the cup remover and the codomain unifier (sentence output type).
rewriter = create_rewriter()
remove_cups = RemoveCupsRewriter()
unify = UnifyCodomainRewriter(output_type=AtomicType.SENTENCE)

In [None]:
def brbrpatapim(filename, left, right):
    """Load a JSONL dataset and encode the records in the slice [left:right].

    Args:
        filename: Path handed to load_cnn_extractive().
        left: Start index of the record slice.
        right: End index (exclusive) of the record slice.

    Returns:
        The (encoded_records, errors) pair produced by
        preprocess_and_encode().
    """
    records = load_cnn_extractive(filename)[left:right]
    encoded, errors = preprocess_and_encode(records)
    print("\nDone!")

    return encoded, errors

# Encode records 35..39 of the training file.
# NOTE: load_cnn_extractive() and preprocess_and_encode() are defined in later
# cells, so those cells must be executed before this one.
train, errs_train = brbrpatapim("datasets/cnn_dailymail/train.jsonl", 35, 40)

In [None]:
#################################################
#### Mainly to test if GPU works with:       ####
#### sim.run(tk_circ, n_shots=1) and         ####
#### backend.run_circuit(tk_circ, n_shots=1) ####
#################################################

# def brbrpatapim(filename, left, right):
#     dataset    = load_cnn_extractive(filename)
#     dst        = dataset[left:right]
#     dste, errs = preprocess_and_encode(dst)
#     print("Done!")
#     return dste, errs

# a, b = brbrpatapim("datasets/cnn_dailymail/test.jsonl", 10, 11)
# for i, circ in enumerate(a[:]):
#   d = circ['circuits']
#   l = len()
#   print(f"{i} circuit: {l}")
# # with open("saves/test_debug_10_25.pkl", 'wb') as f:
# #         pickle.dump({'dste': a}, f)

def load_cnn_extractive(file_path):
    """Read a JSON-lines dataset file into a list of record dicts.

    Lines that are not valid JSON, or that decode to something other than a
    JSON object, are skipped silently. Records in which both "text" and
    "summary" are missing are reported on stdout and skipped.

    Args:
        file_path: Path of a UTF-8 encoded .jsonl file.

    Returns:
        A list of dicts with the keys 'text', 'summary', 'text_sentences',
        'summary_sentences' and 'labels' (missing fields are None).
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # blank or malformed line

            # Valid JSON that is not an object (e.g. `[]`, `null`, `1`) has
            # no `.get`; treat it like a malformed line instead of crashing.
            if not isinstance(record, dict):
                continue

            text = record.get("text")
            summary = record.get("summary")
            if text is None and summary is None:
                print(f"{i}. NULL INSTANCE\ntext: {text}\n{summary}\n")
                continue

            data.append({
                'text': text,
                'summary': summary,
                'text_sentences': record.get("text_sentences"),
                'summary_sentences': record.get("summary_sentences"),
                'labels': record.get("labels"),
            })
    return data


def load_qe():
    """Load the pickled records stored in "saves/qe.pkl".

    Prints the type of the first circuit of the first record and the total
    number of circuits across all records.

    Returns:
        The object stored under the 'dste' key of the pickle file.
    """
    # NOTE: pickle.load can execute arbitrary code - only open trusted files.
    with open("saves/qe.pkl", 'rb') as handle:
        payload = pickle.load(handle)

    encoded = payload['dste']
    length = sum(len(record['circuits']) for record in encoded)

    first_circuit = encoded[0]['circuits'][0]
    print(type(first_circuit))
    print(f"Successfully Loaded {length} circuits!")
    return encoded

# GPU smoke test: run the first stored circuit directly on the simulator and
# print the device named in the result metadata ("unknown" if not reported).
es = load_qe()

result = simulator.run(es[0]['circuits'][0]).result()
print(result.metadata.get("device", "unknown"))


In [None]:
#############################################
# Step 2. Preprocessing and lambeq Pipeline
#############################################

def remove_by_idx(ls, remove):
    """Delete the elements of `ls` at the positions listed in `remove`.

    `ls` is modified in place and also returned. `remove` is left untouched;
    duplicate indices are applied only once.

    Args:
        ls: List to delete from.
        remove: Indices into `ls` to delete; may be empty or None.

    Returns:
        The same list object `ls`, after deletion.
    """
    if not remove:
        return ls
    # Pop from the highest index down so earlier pops do not shift the
    # positions still to be removed. Work on a sorted, de-duplicated copy so
    # the caller's index list is not reordered.
    for n in sorted(set(remove), reverse=True):
        ls.pop(n)
    return ls


def sentence_simplify(sentences):
    """Strip every `_CLEAN_RE` match from each sentence.

    None entries are skipped, so the result can be shorter than the input.

    Args:
        sentences: Iterable of strings (or None).

    Returns:
        A new list with the cleaned sentences.
    """
    cleaned = []
    for sentence in sentences:
        if sentence is None:
            continue
        cleaned.append(_CLEAN_RE.sub("", sentence))
    return cleaned


def sent2diagrams(sentences):
    """Parse each sentence into a diagram with the module-level `parser`.

    Sentences for which the parser returns None are dropped and their
    positions reported. A one-line summary is printed.

    Args:
        sentences: Sequence of untokenised sentence strings.

    Returns:
        (diagrams, none_idx): the parsed diagrams and the indices of the
        sentences that could not be parsed.
    """
    diagrams = []
    failed_positions = []
    for position, sentence in enumerate(sentences):
        parsed = parser.sentence2diagram(sentence, tokenised=False, suppress_exceptions=True)
        if parsed is not None:
            diagrams.append(parsed)
        else:
            failed_positions.append(position)

    print(f"Whilst parsing sentences2diagrams, lost {len(sentences) - len(diagrams)} due to Null, out of {len(sentences)}.")

    return diagrams, failed_positions


def normalize(sentence_diagrams):
    """Rewrite and normalise each diagram, dropping the ones that fail.

    Every diagram goes through `rewriter`, `remove_cups`, `normal_form()`,
    `pregroup_normal_form()` and `unify`. A diagram is dropped when the
    rewriter returns None or when any of these steps raises.

    Note: the "cup removal" counter in the printed summary counts every
    exception raised in the pipeline, not only those from cup removal.

    Args:
        sentence_diagrams: Sequence of diagrams to normalise.

    Returns:
        (diagrams_normalized, none_idx, errs): the surviving diagrams, the
        indices of dropped inputs and the collected error messages.
    """
    diagrams_normalized, none_idx, errs = [], [], []
    drop_rewrite = 0
    drop_cups = 0
    for position, diagram in enumerate(sentence_diagrams):
        try:
            rewritten = rewriter(diagram)
            if rewritten is None:
                none_idx.append(position)
                drop_rewrite += 1
                continue

            without_cups = remove_cups(rewritten)
            unified = unify(without_cups.normal_form().pregroup_normal_form())
        except Exception as e:
            none_idx.append(position)
            errs.append(f"{e} ( normalize() )")
            drop_cups += 1
        else:
            diagrams_normalized.append(unified)

    print(f"Dropped {drop_rewrite} diagrams in rewrite, {drop_cups} in cup removal, out of {len(sentence_diagrams)}.")

    return diagrams_normalized, none_idx, errs

def quantum_encode(diagrams: "List"):
    """Apply the module-level `ansatz` to each diagram.

    Diagrams for which the ansatz raises are skipped; their positions and
    error messages are reported.

    Args:
        diagrams: Sequence of diagrams to convert.

    Returns:
        (encoded_diagrams, remove, errs): the circuits produced, the indices
        of the diagrams that failed and the collected error messages.
    """
    encoded_diagrams = []
    remove = []
    errs = []
    for position, diagram in enumerate(diagrams):
        try:
            circuit = ansatz(diagram)
        except Exception as e:
            remove.append(position)
            errs.append(f"{e} ( quantum_encode() )")
        else:
            encoded_diagrams.append(circuit)

    return encoded_diagrams, remove, errs

def will_train(
    circuits,
    qubit_limit: int = 40,
    mem_limit_bytes: int = 7 * 2**30,   # 7 GiB
):
    """Filter out circuits that cannot be simulated within the given limits.

    Each circuit is converted to pytket, checked against the qubit and
    memory limits, compiled with `comp_pass`, has its free symbols bound to
    0.0 and is executed for one shot on the pytket `backend`. Any failure
    marks the circuit as invalid.

    Args:
        circuits: Sequence of circuits exposing `to_tk()`.
        qubit_limit: Maximum number of qubits allowed per circuit.
        mem_limit_bytes: Maximum statevector size in bytes, estimated as
            16 bytes per amplitude (a conservative double-precision figure).

    Returns:
        (valid, invalid_idxs, errs): the circuits that passed (original
        objects, not the tket conversions), the indices of the rejected
        circuits and the collected error messages.
    """
    valid, invalid_idxs, errs = [], [], []
    for idx, circ in enumerate(circuits):
        try:
            tk_circ = circ.to_tk()

            n_qubits = tk_circ.n_qubits
            if n_qubits > qubit_limit:
                raise RuntimeError(f"Needs {n_qubits} qubits > limit {qubit_limit} qubits.")

            needed = 16 * (2 ** n_qubits)
            if needed > mem_limit_bytes:
                raise RuntimeError(f"Needs {needed} bytes > limit {mem_limit_bytes} bytes ({needed/2**20:.0f} MiB > {(mem_limit_bytes/2**20):.0f} MiB). ")

            comp_pass.apply(tk_circ)

            # Bind every free parameter to a dummy value so the circuit is
            # concrete and can be executed.
            syms = tk_circ.free_symbols()
            if syms:
                bind_map = {s: 0.0 for s in syms}
                tk_circ.symbol_substitution(bind_map)

            # Dry run through the pytket backend: it accepts pytket circuits
            # and blocks until the result is available, so simulation errors
            # surface here instead of being lost in an un-awaited job.
            backend.run_circuit(tk_circ, n_shots=1)

        except Exception as e:
            invalid_idxs.append(idx)
            errs.append(f"{e} ( will_train() )")
            continue

        valid.append(circ)

    return valid, invalid_idxs, errs



def preprocess_and_encode(dataset):
    """Turn raw dataset records into circuits with aligned labels.

    For every record the sentences are cleaned, parsed, normalised, encoded
    with the ansatz and checked for trainability. Whenever a sentence is
    dropped at any step, its label and original text are dropped too, so the
    three output lists stay index-aligned.

    stdout/stderr are silenced while a record is processed, so the helper
    functions' progress prints do not appear.

    Args:
        dataset: Iterable of dicts with 'text_sentences' and 'labels' lists
            of equal length. The input lists are not modified.

    Returns:
        (encoded_data, errors): one dict per record with the keys 'circuits',
        'labels' and 'original_text_sentences', plus all error messages
        collected by the normalise / encode / trainability steps.
    """
    encoded_data, errors = [], []

    def _drop(remove, text_sentences, labels):
        # Remove the same positions from both lists to keep them aligned.
        return remove_by_idx(text_sentences, remove), remove_by_idx(labels, remove)

    for ds_dict in tqdm(dataset, desc="Filtering and Encoding dataset"):
        with open(os.devnull, 'w') as devnull, \
             contextlib.redirect_stdout(devnull), \
             contextlib.redirect_stderr(devnull):

            # Copy: remove_by_idx() pops in place and must not alter the
            # caller's dataset (re-running the cell would otherwise start
            # from already-shortened lists).
            text_sentences = list(ds_dict['text_sentences'])
            labels         = list(ds_dict['labels'])

            # sentence_simplify() silently skips None sentences; drop them
            # (and their labels) first so later indices refer to the same
            # positions in all lists.
            remove = [j for j, s in enumerate(text_sentences) if s is None]
            text_sentences, labels = _drop(remove, text_sentences, labels)

            sentences_simplified = sentence_simplify(text_sentences)
            diagrams, remove     = sent2diagrams(sentences_simplified)
            text_sentences, labels = _drop(remove, text_sentences, labels)

            normalized_diagrams, remove, errs_normalize = normalize(diagrams)
            text_sentences, labels = _drop(remove, text_sentences, labels)

            circuits, remove, errs_encode = quantum_encode(normalized_diagrams)
            text_sentences, labels = _drop(remove, text_sentences, labels)

            circuits, remove, errs_train = will_train(circuits)
            text_sentences, labels = _drop(remove, text_sentences, labels)

            encoded_data.append({'circuits': circuits, 'labels': labels, 'original_text_sentences': text_sentences})

            errors += errs_normalize + errs_encode + errs_train

    return encoded_data, errors