In [None]:
# Environment and Core Libraries
import pandas as pd
import numpy as np
import random
import math
import yaml
import os #multithreading

import fasttext

# PennyLane and PyTorch
import pennylane as qml
import torch
from torch.nn import Module, Parameter
import torch.nn as nn
from torch.optim import Adam
from sklearn.model_selection import train_test_split

# Lambeq
from lambeq.backend.quantum import Diagram as LambeqDiagram
from discopy.quantum import gates
import spacy
import discopy
from lambeq import BobcatParser, Rewriter, SpacyTokeniser, AtomicType, StronglyEntanglingAnsatz
from discopy.rigid import Ty

#data handling and plotting
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, auc, f1_score
import matplotlib.pyplot as plt

# Compatibility patch for discopy: if `discopy.monoidal.Diagram` exists but has
# no `is_mixed` attribute, attach a read-only property that always returns False.
# NOTE(review): presumably needed because downstream code reads `is_mixed` on
# monoidal diagrams — confirm which discopy/lambeq versions still require this.
monoidal_module = getattr(discopy, "monoidal", None)
if monoidal_module:
    diagram_class = getattr(monoidal_module, "Diagram", None)
    # Patch only when the attribute is missing, so a native implementation wins.
    if diagram_class and not hasattr(diagram_class, "is_mixed"):
        diagram_class.is_mixed = property(lambda self: False)

# Load the spaCy model. The returned pipeline is discarded, so this call only
# serves as an early availability check for 'en_core_web_sm'
# (presumably it fails here if the model is not installed — TODO confirm).
spacy.load('en_core_web_sm')

In [None]:
#Data Loading Function
def load_data(csv_file, sample_fraction=1.0):
    """Load question pairs and their duplicate labels from a CSV file.

    The CSV must provide the columns ``question1``, ``question2`` and
    ``is_duplicate``. Rows in which any of the three is missing are dropped:
    otherwise ``astype(str)`` would turn a missing question into the literal
    string ``'nan'``, which would then be parsed like a real sentence.

    Args:
        csv_file: Path (or file-like object) accepted by ``pandas.read_csv``.
        sample_fraction: Fraction of the cleaned rows to keep. Values below 1.0
            trigger a reproducible random sample (``random_state=42``).

    Returns:
        Tuple ``(sentences1, sentences2, is_duplicate)`` of equal-length lists.
        Three empty lists are returned when reading or processing fails; the
        error is printed instead of raised (deliberate best-effort behaviour).
    """
    required_columns = ['question1', 'question2', 'is_duplicate']
    try:
        df = pd.read_csv(csv_file, encoding='utf-8')
        # Drop incomplete rows before sampling so the sample only holds usable pairs.
        df = df.dropna(subset=required_columns).reset_index(drop=True)
        if sample_fraction < 1.0:
            df = df.sample(frac=sample_fraction, random_state=42).reset_index(drop=True)

        sentences1 = df['question1'].astype(str).tolist()
        sentences2 = df['question2'].astype(str).tolist()
        is_duplicate = df['is_duplicate'].tolist()

        print(f"Loaded {len(sentences1)} sentence pairs.")
        return sentences1, sentences2, is_duplicate
    except Exception as e:
        print(f"An error occurred: {e}")
        return [], [], []

def create_balanced_training_set(training_data: list) -> list:
    """Return a shuffled, class-balanced copy of ``training_data``.

    Items labelled 1 and items labelled 0 are collected separately; the larger
    group is randomly undersampled (via the global ``random`` module) down to
    the size of the smaller one, and the union is shuffled.

    Args:
        training_data: List of dicts, each carrying a ``'label'`` entry.

    Returns:
        A new list with equally many label-1 and label-0 items.
    """
    duplicates, distinct = [], []
    for record in training_data:
        if record['label'] == 1:
            duplicates.append(record)
        if record['label'] == 0:
            distinct.append(record)

    # Shrink whichever class is larger; on a tie the label-0 group is resampled.
    target_size = min(len(duplicates), len(distinct))
    if len(duplicates) > len(distinct):
        duplicates = random.sample(duplicates, target_size)
    else:
        distinct = random.sample(distinct, target_size)

    balanced_train_set = duplicates + distinct
    random.shuffle(balanced_train_set)

    print(f"Created a balanced training set with {len(duplicates)} positive and {len(distinct)} negative pairs.")
    return balanced_train_set
def load_fasttext_model(model_path: str):
    """Load a FastText binary model from ``model_path`` and return it.

    Progress is reported on stdout before and after loading.
    """
    print(f"Loading FastText model from {model_path}...")
    loaded_model = fasttext.load_model(model_path)
    print("FastText model loaded successfully.")
    return loaded_model

In [None]:
#Plotting functions
def plot_training_history(history: dict):
    """Draw the per-epoch training loss and, when recorded, the mean fidelity.

    Args:
        history: Mapping with a ``'train_loss'`` sequence. A non-empty
            ``'avg_fidelity'`` sequence, if present, is drawn as a dashed curve.
    """
    has_fidelity = 'avg_fidelity' in history and bool(history['avg_fidelity'])

    _, ax = plt.subplots(figsize=(10, 6))
    ax.plot(history['train_loss'], label='Training Loss (Local Cost)')
    if has_fidelity:
        ax.plot(history['avg_fidelity'], label='Avg. Fidelity (from SWAP Test)', linestyle='--')
    ax.set_title('Training Progress')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Value')
    ax.legend()
    ax.grid(True)
    plt.show()
    
def plot_parameter_history(param_history):
    """Plot mean, +/- one standard deviation, and min/max of the parameters per epoch.

    Args:
        param_history: Sequence of dicts with the keys ``'mean'``, ``'std'``,
            ``'min'`` and ``'max'`` (one dict per epoch). Nothing is drawn and a
            message is printed when the sequence is empty.
    """
    if not param_history:
        print("Parameter history is empty. Cannot plot.")
        return

    def column(key):
        # One statistic across all epochs, as an array for element-wise math.
        return np.array([snapshot[key] for snapshot in param_history])

    epoch_axis = range(len(param_history))
    mean_curve, std_curve = column('mean'), column('std')
    min_curve, max_curve = column('min'), column('max')

    _, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epoch_axis, mean_curve, label='Mean Parameter Value')
    ax.fill_between(epoch_axis, mean_curve - std_curve, mean_curve + std_curve,
                    alpha=0.2, label='1 Std. Deviation')
    ax.plot(epoch_axis, min_curve, linestyle='--', color='gray', label='Min/Max Range')
    ax.plot(epoch_axis, max_curve, linestyle='--', color='gray')

    ax.set_title('Evolution of Model Parameters During Training')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Parameter Value')
    ax.legend()
    ax.grid(True)
    plt.show()
def plot_parameter_evolution_polar(param_history):
    """Trace the mean parameter value on a polar plot, using the epoch as radius.

    The mean of each epoch is reduced modulo 4*pi and used as the angle.

    Args:
        param_history: Sequence of per-epoch dicts with a ``'mean'`` entry.
            Nothing is drawn and a message is printed when it is empty.
    """
    if not param_history:
        print("Parameter history is empty. Cannot plot.")
        return

    figure = plt.figure(figsize=(8, 8))
    ax = figure.add_subplot(111, projection='polar')

    n_epochs = len(param_history)
    radii = np.arange(n_epochs)
    raw_means = np.array([snapshot['mean'] for snapshot in param_history])
    angles = raw_means % (4 * np.pi)

    ax.plot(angles, radii, 'o-', label='Mean Parameter Path')

    # Highlight the first and the last epoch.
    if n_epochs > 0:
        ax.plot(angles[0], radii[0], 'gX', markersize=12, label='Start')
        ax.plot(angles[-1], radii[-1], 'rX', markersize=12, label='End')

    ax.set_theta_zero_location('N')  # pyright: ignore
    ax.set_theta_direction(-1)  # pyright: ignore
    ax.set_rlabel_position(0)  # pyright: ignore
    ax.set_rlim(0, n_epochs * 1.05)  # pyright: ignore
    ax.set_xlabel("Epoch")
    ax.set_title('Cyclical Evolution of Mean Parameter', pad=20)
    ax.legend()
    plt.show()
def plot_parameter_deltas(param_history):
    """Plot the epoch-to-epoch change of the mean parameter as a wrapped angle.

    Each difference between consecutive means is mapped through
    ``arctan2(sin(d), cos(d))``, i.e. onto the shortest signed angle.

    Args:
        param_history: Sequence of per-epoch dicts with a ``'mean'`` entry; at
            least two entries are required, otherwise a message is printed.
    """
    if len(param_history) < 2:
        print("Need at least 2 epochs to plot parameter deltas.")
        return

    mean_angles = np.array([snapshot['mean'] for snapshot in param_history])

    # Consecutive differences, wrapped to the shortest signed angle.
    raw_steps = np.diff(mean_angles)
    wrapped_steps = np.arctan2(np.sin(raw_steps), np.cos(raw_steps))

    # The first delta belongs to epoch 1, so the x axis runs from 1 to N-1.
    epoch_ticks = range(1, len(mean_angles))

    _, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epoch_ticks, wrapped_steps, 'o-', label='Change in Mean Parameter (Delta)')
    ax.axhline(0, color='red', linestyle='--', label='No Change')
    ax.set_title('Epoch-to-Epoch Change in Mean Parameter Value')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Shortest Angle Difference (Radians)')
    ax.legend()
    ax.grid(True)
    ax.set_xticks(epoch_ticks)
    plt.show()


def plot_confusion_matrix(y_true, y_pred, threshold=0.5):
    """Compute and plot a 2x2 confusion matrix for thresholded scores.

    Args:
        y_true: Ground-truth labels (0s and 1s); any array-like.
        y_pred: Raw scores (e.g. overlaps from 0 to 1); any array-like, so a
            plain Python list is accepted as well as a NumPy array.
        threshold (float): Scores greater than or equal to this are classified as 1.
    """
    # np.asarray lets callers pass lists; `list >= float` would raise TypeError.
    binary_preds = (np.asarray(y_pred) >= threshold).astype(int)

    # Fix the label set so the matrix stays 2x2 (and matches the tick labels)
    # even when one class is absent from y_true and the predictions.
    cm = confusion_matrix(np.asarray(y_true), binary_preds, labels=[0, 1])

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Predicted Non-Duplicate', 'Predicted Duplicate'],
                yticklabels=['Actual Non-Duplicate', 'Actual Duplicate'])
    plt.title('Confusion Matrix')
    plt.ylabel('Actual Label')
    plt.xlabel('Predicted Label')
    plt.show()
def plot_roc_curve(y_true, y_pred):
    """Plot the ROC curve of ``y_pred`` against ``y_true`` with its AUC in the legend.

    A dashed diagonal is drawn as the random-guess baseline.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y_true, y_pred)
    roc_auc = auc(false_pos_rate, true_pos_rate)

    _, ax = plt.subplots(figsize=(8, 6))
    ax.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=2,
            label=f'ROC curve (AUC = {roc_auc:.2f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Random Guess')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    ax.grid(True)
    plt.show()

In [None]:
# QNLP MODEL AND TRAINING PIPELINE
def execute_discopy_diagram(current_width, diagram, params, wires, embedding_method='simple_pad', rotation_param=None, entangling_param=None):
    """Queue a diagram's gates on ``wires`` and optionally dress the padding wires.

    The gates are applied exactly as in ``execute_discopy_diagram_local``.
    Wires from position ``current_width`` onwards are treated as padding: with
    ``embedding_method == 'parameterized'`` each of them receives an RY rotation
    (if ``rotation_param`` is given) and, when there is more than one, they are
    linked in a ring of CPHASE gates (if ``entangling_param`` is given). Any
    other embedding method leaves the padding wires untouched.

    NOTE(review): `qml.CPHASE` is looked up at call time — confirm this
    attribute exists in the installed PennyLane version.
    """
    # Same gate loop as the un-padded variant; reuse it instead of repeating it.
    execute_discopy_diagram_local(diagram, params, wires)

    padding_wires = wires[current_width:]
    if embedding_method != 'parameterized':
        return

    if rotation_param is not None:
        for wire in padding_wires:
            qml.RY(rotation_param, wires=wire)

    n_padding = len(padding_wires)
    if entangling_param is not None and n_padding > 1:
        # Ring topology: each padding wire is coupled to its successor, the last to the first.
        for k in range(n_padding):
            neighbour = padding_wires[(k + 1) % n_padding]
            qml.CPHASE(entangling_param, wires=[padding_wires[k], neighbour])

def execute_discopy_diagram_local(diagram, params, wires):
    """Queue a diagram's gates as PennyLane operations, without any padding.

    Boxes are visited in order together with their offsets. A box is applied
    only if ``qml`` has an attribute with the box's name; all other boxes are
    skipped silently and consume no parameters. A box with free symbols takes
    the next ``len(box.free_symbols)`` entries of ``params``.

    NOTE(review): diagram gate names may differ from PennyLane operation names
    — confirm that no gates are being dropped by the name lookup.
    """
    wire_lookup = dict(enumerate(wires))
    cursor = 0  # position of the next unused entry in `params`
    for box, left in zip(diagram.boxes, diagram.offsets):
        if not hasattr(qml, box.name):
            continue
        operation = getattr(qml, box.name)

        n_symbols = len(box.free_symbols)
        if n_symbols > 0:
            box_params = params[cursor:cursor + n_symbols]
            cursor += n_symbols
        else:
            box_params = []

        # The box acts on len(box.dom) consecutive positions starting at its offset.
        targets = [wire_lookup[left + k] for k in range(len(box.dom))]
        operation(*box_params, wires=targets)
def get_diagram_width(diagram):
    """Return the largest wire extent reached anywhere in ``diagram``.

    This is the maximum of ``len(diagram.cod)`` and, for every box, its offset
    plus the length of its domain.
    """
    widest = len(diagram.cod)
    for box, offset in zip(diagram.boxes, diagram.offsets):
        widest = max(widest, offset + len(box.dom))
    return widest
    
def preprocess_data_for_model(data_pairs, Tokeniser, ansatz, parser, rewriter, qubit_limit=20):
    """Turn sentence pairs into circuits, keeping only pairs within the qubit limit.

    Each sentence is tokenised, parsed, rewritten and passed through the ansatz.
    If that fails with an error mentioning ``Ty(p)`` (or the related missing
    ``label`` attribute), both sentences of the pair are retried with an
    implicit subject ``"You "`` prepended. Pairs that still fail, or fail for
    any other reason, are skipped with a printed warning.

    Args:
        data_pairs: Iterable of ``(sentence1, sentence2, label)`` tuples.
        Tokeniser: Object providing ``tokenise_sentence(str)``.
        ansatz: Callable mapping a rewritten diagram to a circuit.
        parser: Object providing ``sentence2diagram(tokens, tokenised=True)``.
        rewriter: Callable applied to the parsed diagram.
        qubit_limit: Maximum allowed ``len(circuit.cod)`` for either circuit.

    Returns:
        ``(filtered_pairs, symbols, n_max)`` where ``filtered_pairs`` is a list
        of dicts (keys ``s1``, ``s2``, ``label``, ``d1``, ``d2``,
        ``structural_disparity``, ``width1``, ``width2``), ``symbols`` is the
        list of all free symbols sorted by name, and ``n_max`` is the largest
        width among the kept circuits (0 if none were kept).
    """
    print(f"Starting preprocessing with qubit limit {qubit_limit}...")
    filtered_pairs, all_symbols, n_max = [], set(), 0

    def to_circuit(sentence):
        # Full pipeline for one sentence: tokenise -> parse -> rewrite -> ansatz.
        tokens = Tokeniser.tokenise_sentence(sentence)
        return ansatz(rewriter(parser.sentence2diagram(tokens, tokenised=True)))

    for i, (s1, s2, is_duplicate) in enumerate(data_pairs):
        try:
            d1, d2 = to_circuit(s1), to_circuit(s2)
        except Exception as e:
            # Only the Ty(p)-style failures are worth a retry; the exact message
            # may vary between library versions, so both known forms are checked.
            error_str = str(e)
            is_type_p_error = "Ty(p)" in error_str or "'Ty' object has no attribute 'label'" in error_str
            if not is_type_p_error:
                print(f"Warning: Failed to process pair #{i+1}. Reason: {e}")
                continue

            print(f"Warning: Pair #{i+1} failed with Ty(p) related error. Retrying with implicit 'You'.")
            try:
                # Both sentences get the prefix, whichever of them failed.
                d1, d2 = to_circuit("You " + s1), to_circuit("You " + s2)
            except Exception as e_retry:
                print(f"Warning: Retry failed for pair #{i+1}. Reason: {e_retry}")
                continue
            print(f"  --> Retry successful for pair #{i+1}.")

        try:
            # Read everything that can fail before mutating the accumulators, so
            # a failure never leaves a pair stored without its symbols.
            width1, width2 = len(d1.cod), len(d2.cod)
            symbols1, symbols2 = d1.free_symbols, d2.free_symbols
        except Exception as e_post_process:
            print(f"Warning: Error after processing pair #{i+1}. Reason: {e_post_process}")
            continue

        if width1 > qubit_limit or width2 > qubit_limit:
            continue

        # The original sentences are stored even when the retry prefix was used.
        filtered_pairs.append({
            's1': s1, 's2': s2, 'label': is_duplicate, 'd1': d1, 'd2': d2,
            'structural_disparity': abs(width1 - width2),
            'width1': width1, 'width2': width2,
        })
        all_symbols.update(symbols1)
        all_symbols.update(symbols2)
        n_max = max(n_max, width1, width2)

    print(f"Preprocessing complete. Found {len(filtered_pairs)} valid pairs.")
    print(f"Total unique parameters (symbols) found: {len(all_symbols)}")
    print(f"N_Max for the filtered dataset is: {n_max}")
    return filtered_pairs, sorted(all_symbols, key=lambda s: s.name), n_max

class QNLPModel(nn.Module):
    """Maps a diagram's words to exactly as many circuit parameters as it needs.

    The word vectors of a diagram are averaged into one sentence vector, which
    a small MLP turns into ``max_params`` candidate values; the first
    ``len(diagram.free_symbols)`` of them are returned.

    Args:
        symbols: Iterable of symbols (objects with a ``name``) found in the dataset.
        embedding_dim: Dimension of the word vectors fed to the encoder.
        max_params: Output size of the encoder, i.e. the most parameters that
            can be generated for one diagram without zero-padding.
    """
    def __init__(self, symbols, embedding_dim=100, max_params=50):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(embedding_dim // 2, max_params)
        )
        # Keys have '.' replaced by '_'; forward() applies the same replacement on lookup.
        self.symbol_to_word = {
            s.name.replace('.', '_'): self._word_from_symbol_name(s.name) for s in symbols
        }

    @staticmethod
    def _word_from_symbol_name(name):
        """Extract the word from a symbol name such as ``'cat__n_0'``.

        lambeq-style names look like ``<word>__<type>_<index>``, so the word is
        everything before the last ``'__'``. Taking the last ``'_'``-separated
        piece (the previous behaviour) yields the index instead; it is kept
        only as a fallback for names that contain no ``'__'``.
        """
        head, separator, _ = name.rpartition('__')
        return head if separator else name.split('_')[-1]

    def forward(self, diagram, fasttext_model):
        """Return a 1-D tensor with one parameter per free symbol of ``diagram``.

        Returns an empty tensor when the diagram has no free symbols, and zeros
        when none of its symbols maps to a word.
        """
        symbols = diagram.free_symbols
        num_required_params = len(symbols)

        if num_required_params == 0:
            return torch.tensor([])  # nothing to parameterise

        words = [self.symbol_to_word.get(s.name.replace('.', '_')) for s in symbols]
        vectors = [fasttext_model.get_word_vector(w) for w in words if w]

        if not vectors:
            return torch.zeros(num_required_params)

        # Stack first: building a tensor from a list of ndarrays is slow.
        sentence_vector = torch.as_tensor(np.stack(vectors), dtype=torch.float32).mean(dim=0)
        all_params = self.encoder(sentence_vector)

        missing = num_required_params - all_params.shape[0]
        if missing > 0:
            # Keep the "exact number of parameters" contract instead of handing
            # back too few values, which would leave later gates without angles.
            print(f"Warning: Diagram requires {num_required_params} params, but model only outputs "
                  f"{all_params.shape[0]}. Padding with zeros.")
            return torch.cat([all_params, all_params.new_zeros(missing)])

        return all_params[:num_required_params]

def train_model(model, fasttext_model, training_data, validation_data, n_max, device_name,
                base_learning_rate, lambda_penalty, epochs, n_layers, embedding_method):
    """Train ``model`` with a local (per-qubit) cost and log the SWAP-test value.

    For every training pair the PauliZ expectation values of both circuits are
    measured, zero-padded to ``n_max`` entries, and compared by mean squared
    error. The loss is ``(mse - (1 - label))**2`` plus the constant
    ``lambda_penalty * structural_disparity``. After each optimiser step the
    ancilla expectation value of a SWAP test on the two circuits (using the
    pre-step parameters) is logged as the pair's fidelity.

    Args:
        model: Module called as ``model(diagram, fasttext_model)``.
        fasttext_model: Word-vector model handed through to ``model``.
        training_data: Pair dicts with ``d1``, ``d2``, ``width1``, ``width2``,
            ``label`` and ``structural_disparity``.
        validation_data: Accepted for interface compatibility; currently unused,
            so ``history['val_loss']`` stays empty.
        n_max: Register width used for padding and for sizing the devices.
        device_name: PennyLane device name.
        base_learning_rate: Learning rate for Adam.
        lambda_penalty: Weight of the structural-disparity term.
        epochs: Number of passes over ``training_data``.
        n_layers: Accepted for interface compatibility; currently unused.
        embedding_method: Padding mode forwarded to ``execute_discopy_diagram``.

    Returns:
        ``(model, history)`` where ``history`` holds the per-epoch lists
        ``train_loss``, ``avg_fidelity``, ``param_history`` and (empty) ``val_loss``.
    """
    optimizer = Adam(model.parameters(), lr=base_learning_rate)
    dev = qml.device(device_name, wires=n_max)
    swap_dev = qml.device(device_name, wires=1 + 2 * n_max)
    history = {'train_loss': [], 'val_loss': [], 'avg_fidelity': [], 'param_history': []}
    print("--- Starting training with LOCAL cost function ---")

    @qml.qnode(dev, interface="torch")
    def local_expval_qnode(params, diagram, num_qubits):
        execute_discopy_diagram_local(diagram, params, wires=range(num_qubits))
        return [qml.expval(qml.PauliZ(i)) for i in range(num_qubits)]

    # Built once: the diagrams and widths are passed as arguments (like the
    # local QNode above) instead of re-creating the QNode for every pair.
    @qml.qnode(swap_dev, interface="torch")
    def swap_test_qnode(p1, p2, d1, d2, width1, width2):
        qml.Hadamard(wires=0)
        # Fixed placeholder angles, only used by the 'parameterized' padding method.
        dummy_theta = torch.tensor(0.1)
        dummy_phi = torch.tensor(0.1)
        execute_discopy_diagram(width1, d1, p1, range(1, 1 + n_max), embedding_method, dummy_theta, dummy_phi)
        execute_discopy_diagram(width2, d2, p2, range(1 + n_max, 1 + 2 * n_max), embedding_method, dummy_theta, dummy_phi)
        for j in range(n_max):
            qml.CSWAP(wires=[0, 1 + j, 1 + n_max + j])
        qml.Hadamard(wires=0)
        return qml.expval(qml.PauliZ(0))

    def padded_expvals(params, diagram, width):
        # Per-qubit <Z> values, extended with zeros up to n_max entries.
        values = torch.stack(local_expval_qnode(params, diagram, width))
        if width < n_max:
            values = torch.cat([values, torch.zeros(n_max - width)])
        return values

    for epoch in range(epochs):
        model.train()
        total_train_loss, total_fidelity, num_trained_pairs = 0, 0, 0
        for pair in training_data:
            optimizer.zero_grad()
            params1 = model(pair['d1'], fasttext_model)
            params2 = model(pair['d2'], fasttext_model)

            if params1.nelement() == 0 or params2.nelement() == 0:
                # A circuit without parameters is most likely an error case.
                print("Skipping pair with no parameters needed.\n")
                continue

            exp_vals1 = padded_expvals(params1, pair['d1'], pair['width1'])
            exp_vals2 = padded_expvals(params2, pair['d2'], pair['width2'])

            mse = torch.mean((exp_vals1 - exp_vals2) ** 2)
            target_distance = 1 - pair['label']
            local_loss = (mse - target_distance) ** 2
            # Constant w.r.t. the model parameters: shifts the loss, not the gradient.
            structural_penalty = torch.tensor(lambda_penalty * pair['structural_disparity'])
            loss = local_loss + structural_penalty
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

            with torch.no_grad():
                fidelity = swap_test_qnode(
                    params1, params2, pair['d1'], pair['d2'], pair['width1'], pair['width2']
                ).item()
                total_fidelity += fidelity
            num_trained_pairs += 1

        avg_train_loss = total_train_loss / num_trained_pairs if num_trained_pairs > 0 else 0
        avg_fidelity = total_fidelity / num_trained_pairs if num_trained_pairs > 0 else 0
        history['train_loss'].append(avg_train_loss)
        history['avg_fidelity'].append(avg_fidelity)

        # Summary statistics over all trainable weights of the encoder.
        all_params = torch.cat([p.data.flatten() for p in model.parameters()]).detach().numpy()
        if all_params.size > 0:
            history['param_history'].append({
                'mean': np.mean(all_params), 'std': np.std(all_params),
                'min': np.min(all_params), 'max': np.max(all_params),
            })
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Avg Fidelity: {avg_fidelity:.4f}")

    return model, history

def evaluate_model(model, fasttext_model, test_data, n_max, device_name, embedding_method, threshold=0.5):
    """Score every test pair with a SWAP test and report F1, accuracy and plots.

    Args:
        model: Trained module called as ``model(diagram, fasttext_model)``.
        fasttext_model: Word-vector model handed through to ``model``.
        test_data: Pair dicts with ``d1``, ``d2``, ``width1``, ``width2`` and ``label``.
        n_max: Register width of each of the two compared circuits.
        device_name: PennyLane device name.
        embedding_method: Padding mode forwarded to ``execute_discopy_diagram``.
        threshold: Scores greater than or equal to this count as "duplicate".

    Returns:
        None. Metrics are printed and the confusion matrix / ROC curve are shown.
    """
    print("\n--- Starting Evaluation on Test Set ---")
    model.eval()
    swap_dev = qml.device(device_name, wires=1 + 2 * n_max)

    # Built once; the per-pair diagrams and widths are passed as arguments.
    @qml.qnode(swap_dev, interface="torch")
    def swap_test_qnode(p1, p2, d1, d2, width1, width2):
        qml.Hadamard(wires=0)
        # Fixed placeholder angles, only used by the 'parameterized' padding method.
        dummy_theta = torch.tensor(0.1)
        dummy_phi = torch.tensor(0.1)
        execute_discopy_diagram(width1, d1, p1, range(1, 1 + n_max), embedding_method, dummy_theta, dummy_phi)
        execute_discopy_diagram(width2, d2, p2, range(1 + n_max, 1 + 2 * n_max), embedding_method, dummy_theta, dummy_phi)
        for j in range(n_max):
            qml.CSWAP(wires=[0, 1 + j, 1 + n_max + j])
        qml.Hadamard(wires=0)
        return qml.expval(qml.PauliZ(0))

    predictions, true_labels = [], []
    with torch.no_grad():
        for pair in test_data:
            params1 = model(pair['d1'], fasttext_model)
            params2 = model(pair['d2'], fasttext_model)
            measured_overlap = swap_test_qnode(
                params1, params2, pair['d1'], pair['d2'], pair['width1'], pair['width2']
            )
            predictions.append(measured_overlap.item())
            true_labels.append(pair['label'])

    if not predictions:
        print("No valid pairs in the test set to evaluate.")
        return

    # Convert once: the plotting helper compares scores element-wise against the
    # threshold, which a plain Python list does not support.
    scores = np.array(predictions)
    true_labels_arr = np.array(true_labels)
    binary_preds = (scores >= threshold).astype(int)

    f1 = f1_score(true_labels_arr, binary_preds)
    print(f"Test Set F1 Score: {f1:.4f}")

    accuracy = np.mean(binary_preds == true_labels_arr)
    print(f"Test Set Accuracy: {accuracy:.4f}")

    print("\n--- Evaluation Plots ---")
    plot_confusion_matrix(true_labels_arr, scores, threshold=threshold)
    plot_roc_curve(true_labels_arr, scores)

In [None]:
#Click this to run model
def main(config_path: str):
    """Run the whole workflow: load config and data, preprocess, train, evaluate, plot.

    Args:
        config_path: Path to a YAML file providing the ``qnlp``, ``data``,
            ``training`` and ``simulation`` sections read below. An optional
            ``training.seed`` (default 42) seeds ``random``, NumPy and PyTorch.
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    print("Configuration loaded:\n", yaml.dump(config, indent=2))

    # --- Reproducibility ---
    # The data splits below already use random_state=42; class balancing
    # (`random`) and model initialisation (`torch`) need explicit seeds as well.
    seed = (config.get('training') or {}).get('seed', 42)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # --- Initialize Lambeq Objects ---
    tokeniser = SpacyTokeniser()
    parser = BobcatParser()
    rewriter = Rewriter(config['qnlp']['rewrite_rules'])
    N = AtomicType.NOUN
    S = AtomicType.SENTENCE
    ansatz = StronglyEntanglingAnsatz({N: 1, S: 1}, n_layers=config['qnlp']['n_layers'])

    # --- Load Data and Models ---
    fasttext_model = load_fasttext_model(config['data']['fasttext_path'])
    if not fasttext_model:
        return

    sentences1, sentences2, value = load_data(config['data']['path'], config['data']['sample_fraction'])
    data_pairs = list(zip(sentences1, sentences2, value))

    filtered_data, symbols, n_max = preprocess_data_for_model(
        data_pairs, tokeniser, ansatz, parser, rewriter, config['data']['qubit_limit']
    )

    if not (filtered_data and n_max > 0):
        print("\nNo data to train on.")
        return

    # --- Create Datasets ---
    # 20% test, then 25% of the remainder for validation (60/20/20 overall);
    # only the training part is class-balanced.
    train_val_data_raw, test_data = train_test_split(filtered_data, test_size=0.2, random_state=42)
    train_data_raw, val_data = train_test_split(train_val_data_raw, test_size=0.25, random_state=42)
    training_data = create_balanced_training_set(train_data_raw)
    print(f"\nData split into {len(training_data)} training pairs, {len(val_data)} validation pairs, and {len(test_data)} test pairs.")

    # --- Initialize Model and Train ---
    embedding_dim = fasttext_model.get_dimension()
    model = QNLPModel(symbols, embedding_dim=embedding_dim)

    trained_model, history = train_model(
        model=model, fasttext_model=fasttext_model, training_data=training_data,
        validation_data=val_data, n_max=n_max, device_name=config['simulation']['device'],
        base_learning_rate=config['training']['base_learning_rate'],
        lambda_penalty=config['training']['lambda_penalty'], epochs=config['training']['epochs'],
        n_layers=config['qnlp']['n_layers'], embedding_method=config['qnlp']['embedding_method']
    )

    # --- Evaluate and Plot ---
    evaluate_model(
        model=trained_model, fasttext_model=fasttext_model, test_data=test_data,
        n_max=n_max, device_name=config['simulation']['device'],
        embedding_method=config['qnlp']['embedding_method']
    )
    plot_training_history(history)
    if 'param_history' in history and history['param_history']:
        plot_parameter_evolution_polar(history['param_history'])
        plot_parameter_deltas(history['param_history'])

# Entry point of this cell: run the full workflow (load, preprocess, train,
# evaluate, plot) with the settings read from the YAML file named below.
# The path is relative, so it depends on the notebook's working directory.
config_file_path = 'config.yaml'
main(config_file_path)

In [None]:
def preprocess_for_lambeq_model(data_pairs, Tokeniser, ansatz, parser, rewriter, qubit_limit=20):
    """Build circuit pairs (within the qubit limit) for the PennylaneModel branch.

    Each sentence is tokenised, parsed, rewritten and passed through the
    ansatz. Pairs that raise at any stage are skipped with a printed warning;
    pairs where either circuit is wider than ``qubit_limit`` are dropped silently.

    Args:
        data_pairs: Iterable of ``(sentence1, sentence2, label)`` tuples.
        Tokeniser: Object providing ``tokenise_sentence(str)``.
        ansatz: Callable mapping a rewritten diagram to a circuit.
        parser: Object providing ``sentence2diagram(tokens, tokenised=True)``.
        rewriter: Callable applied to the parsed diagram.
        qubit_limit: Maximum allowed ``len(circuit.cod)`` for either circuit.

    Returns:
        ``(circuits1, circuits2, labels, symbols, n_max)``: three aligned lists
        for the kept pairs, all free symbols sorted by name, and the largest
        kept width (0 if nothing was kept).
    """
    print(f"Starting preprocessing for PennylaneModel with qubit limit {qubit_limit}...")

    def to_circuit(sentence):
        # Full pipeline for one sentence: tokenise -> parse -> rewrite -> ansatz.
        tokens = Tokeniser.tokenise_sentence(sentence)
        return ansatz(rewriter(parser.sentence2diagram(tokens, tokenised=True)))

    circuits1, circuits2, labels = [], [], []
    all_symbols, n_max = set(), 0

    for i, (s1, s2, is_duplicate) in enumerate(data_pairs):
        try:
            c1, c2 = to_circuit(s1), to_circuit(s2)
            # Width is determined by the ansatz output.
            width1, width2 = len(c1.cod), len(c2.cod)
            # Read the symbols inside the try so that a failure cannot leave a
            # pair stored without its symbols being registered.
            symbols1, symbols2 = c1.free_symbols, c2.free_symbols
        except Exception as e:
            print(f"Warning: Failed to process pair #{i+1}. Reason: {e}")
            continue

        if width1 > qubit_limit or width2 > qubit_limit:
            continue

        circuits1.append(c1)
        circuits2.append(c2)
        labels.append(is_duplicate)
        all_symbols.update(symbols1)
        all_symbols.update(symbols2)
        n_max = max(n_max, width1, width2)

    print(f"Preprocessing complete. Found {len(circuits1)} valid pairs.")
    return circuits1, circuits2, labels, sorted(all_symbols, key=lambda s: s.name), n_max
from lambeq.training import SPSAOptimizer
from lambeq.training.quantum_model import QuantumModel
from sympy import Symbol

def initialize_weights_from_fasttext(symbols: "list[Symbol]", fasttext_model) -> torch.Tensor:
    """Create one initial weight per symbol from the FastText vector of its word.

    A freshly initialised linear layer projects each word vector to a scalar,
    so the values depend on the current PyTorch RNG state. Symbols for which no
    word can be extracted get the default value 0.1.

    Args:
        symbols: Symbols (objects with a ``name``) in the order the weights
            should be returned.
        fasttext_model: Object providing ``get_dimension()`` and
            ``get_word_vector(word)``.

    Returns:
        A float32 tensor with ``len(symbols)`` entries.
    """
    def word_of(name):
        # lambeq-style names look like '<word>__<type>_<index>': the word is the
        # part before the last '__'. The last '_'-separated piece (the previous
        # behaviour) is the index; it is kept only as a fallback for names
        # without '__'.
        head, separator, _ = name.rpartition('__')
        return head if separator else name.split('_')[-1]

    embedding_dim = fasttext_model.get_dimension()
    # Maps an embedding to a single scalar.
    # NOTE: this assumes one parameter per symbol; adjust if the ansatz needs more.
    encoder = nn.Linear(embedding_dim, 1)

    initial_weights = []
    with torch.no_grad():  # only the initial values are needed, no graph
        for sym in symbols:
            word = word_of(sym.name)
            if word:
                vector = fasttext_model.get_word_vector(word)
                vector_tensor = torch.as_tensor(vector, dtype=torch.float32)
                initial_weights.append(encoder(vector_tensor).squeeze().item())
            else:
                initial_weights.append(0.1)  # default for symbols without a word

    return torch.tensor(initial_weights, dtype=torch.float32)
from lambeq.training.trainer import Trainer
from lambeq.training.dataset import Dataset
from torch.optim import Adam
import pennylane as qml

def train_with_lambeq_pennylane_model(circuits1, circuits2, labels, symbols, fasttext_model, n_max, device_name, config):
    """Conceptual scaffold for training circuit pairs in the lambeq/PennyLane style.

    Only the initial weights are computed and a custom pair-comparison module
    is defined; no training is performed yet.

    Args:
        circuits1, circuits2: Aligned lists of circuits to compare (unused so far).
        labels: Duplicate labels aligned with the circuits (unused so far).
        symbols: Globally ordered list of symbols; position i owns weight i.
        fasttext_model: Word-vector model used to initialise the weights.
        n_max: Register width of each compared circuit.
        device_name: PennyLane device name.
        config: Parsed configuration (unused so far).

    Returns:
        ``(None, None)`` — placeholder for ``(trained_model, history)``.
    """
    # 1. Initialize weights
    initial_weights = initialize_weights_from_fasttext(symbols, fasttext_model)

    # Position of every symbol in the global weight vector (first occurrence
    # wins, matching list.index) — avoids a linear search per lookup.
    symbol_index = {}
    for position, symbol in enumerate(symbols):
        symbol_index.setdefault(symbol, position)

    # 2. Model definition.
    # PennylaneModel is typically used to classify single diagrams; comparing
    # pairs needs either two models or a custom PyTorch module wrapping QNodes.
    # The class below sketches the latter, using a SWAP test.

    class PairComparisonModel(torch.nn.Module):
        """Scores circuit pairs with a SWAP test; one shared weight per symbol."""

        def __init__(self, n_max, device_name, initial_weights):
            super().__init__()
            self.n_max = n_max
            self.device = qml.device(device_name, wires=1 + 2 * n_max)
            # Store weights as trainable parameters.
            self.weights = torch.nn.Parameter(initial_weights)

        @staticmethod
        def _select_weights(weights, circuit):
            # Index into `weights` rather than rebuilding a tensor from scalars:
            # torch.tensor([...]) would cut the autograd link to the parameters.
            # NOTE(review): free_symbols is iterated in its own order — confirm
            # it matches the order in which the gates consume parameters.
            positions = [symbol_index[s] for s in circuit.free_symbols if s in symbol_index]
            return weights[torch.tensor(positions, dtype=torch.long)]

        def get_qnode(self, circuit1, circuit2, width1, width2):
            """Return a QNode running a SWAP test on the two given circuits."""
            @qml.qnode(self.device, interface='torch')
            def swap_test_qnode(weights):
                params1 = self._select_weights(weights, circuit1)
                params2 = self._select_weights(weights, circuit2)

                qml.Hadamard(wires=0)
                # Fixed placeholder padding angles for now.
                execute_discopy_diagram(width1, circuit1, params1, range(1, 1 + self.n_max), 'parameterized', torch.tensor(0.1), torch.tensor(0.1))
                execute_discopy_diagram(width2, circuit2, params2, range(1 + self.n_max, 1 + 2 * self.n_max), 'parameterized', torch.tensor(0.1), torch.tensor(0.1))
                for j in range(self.n_max):
                    qml.CSWAP(wires=[0, 1 + j, 1 + self.n_max + j])
                qml.Hadamard(wires=0)
                return qml.expval(qml.PauliZ(0))
            return swap_test_qnode

        def forward(self, data):
            """Return one SWAP-test score per ``(circuit1, circuit2, width1, width2)`` item."""
            results = [
                self.get_qnode(c1, c2, w1, w2)(self.weights)
                for c1, c2, w1, w2 in data
            ]
            if not results:
                return torch.zeros(0)
            return torch.stack(results)

    # --- Structure one would use IF PennylaneModel supported pairs easily ---
    # (Conceptual - the library might need extending or a different approach.)

    print("Note: Direct pair comparison with lambeq.PennylaneModel might require custom adaptation.")
    print("Setting up a basic structure assuming such adaptation is possible or using a workaround.")

    # Placeholders (conceptual), kept as notes until a real training path exists:
    # bce_loss = torch.nn.BCELoss()  # Binary Cross Entropy for 0/1 labels
    # optimizer = Adam               # optimizer class handed to the Trainer
    #
    # model = PennylaneModel(model_spec={'backend': device_name, 'n_wires': 1 + 2*n_max}) # Simplified
    #
    # Need to structure data appropriately for the model's forward pass
    # train_data = list(zip(circuits1, circuits2)) # Example structure
    # train_dataset = Dataset(train_data, labels)
    #
    # trainer = Trainer(model=model,
    #                   loss_function=bce_loss,
    #                   optimizer=optimizer,
    #                   learning_rate=config['training']['base_learning_rate'],
    #                   epochs=config['training']['epochs'],
    #                   evaluate_functions={'acc': lambda y_hat, y: (y_hat.round() == y).float().mean()},
    #                   evaluate_on_train=True,
    #                   verbose='text',
    #                   seed=0)
    #
    # trainer.fit(train_dataset, val_dataset=None) # Assuming train_dataset is correctly formatted

    print("Conceptual setup for lambeq.PennylaneModel complete.")
    print("Actual implementation for pair comparison may require a custom PyTorch module wrapping QNodes.")

    # Return None for now as the direct PennylaneModel path is unclear for pairs
    return None, None
from lambeq import IQPAnsatz # Ensure IQPAnsatz is imported

def main_pennylane_model_branch(config_path: str):
    """Run the lambeq PennylaneModel branch of the pipeline end to end.

    The function reads a YAML configuration, builds the lambeq tokeniser,
    parser, rewriter and IQP ansatz, loads the fastText model and the
    sentence pairs, converts the pairs to circuits with
    ``preprocess_for_lambeq_model`` and hands the result to
    ``train_with_lambeq_pennylane_model``.

    Args:
        config_path: Path to a YAML configuration file. The keys read here
            are ``qnlp.rewrite_rules``, ``qnlp.n_layers``,
            ``data.fasttext_path``, ``data.path``, ``data.sample_fraction``,
            ``data.qubit_limit`` and ``simulation.device``. The whole
            configuration mapping is also forwarded to the training function.

    Returns:
        None. Progress and outcomes are reported with ``print``. The function
        returns early if the fastText model could not be loaded.
    """
    # Explicit UTF-8 keeps config parsing independent of the platform's
    # default locale encoding (and matches the encoding used by load_data).
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # Concatenate instead of passing two print arguments: print's separator
    # would otherwise indent the first line of the YAML dump by one space.
    print("Configuration loaded:\n" + yaml.dump(config, indent=2))

    qnlp_cfg = config['qnlp']
    data_cfg = config['data']

    # --- Initialize Lambeq Objects ---
    tokeniser = SpacyTokeniser()
    parser = BobcatParser()  # Or OncillaParser / discocirc when ready
    rewriter = Rewriter(qnlp_cfg['rewrite_rules'])
    N = AtomicType.NOUN
    S = AtomicType.SENTENCE
    # Use IQPAnsatz for preprocessing as it's most compatible
    ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=qnlp_cfg['n_layers'])

    # --- Load Data and Models ---
    fasttext_model = load_fasttext_model(data_cfg['fasttext_path'])
    if not fasttext_model:
        # Make the early exit visible instead of returning silently.
        print("Aborting: fastText model could not be loaded.")
        return

    sentences1, sentences2, labels_raw = load_data(
        data_cfg['path'], data_cfg['sample_fraction']
    )
    data_pairs = list(zip(sentences1, sentences2, labels_raw))

    # Preprocess to get circuits and labels
    circuits1, circuits2, labels, symbols, n_max = preprocess_for_lambeq_model(
        data_pairs, tokeniser, ansatz, parser, rewriter, data_cfg['qubit_limit']
    )

    # Guard clause: nothing usable came out of preprocessing.
    if not (circuits1 and n_max > 0):
        print("\nNo data to train on after preprocessing.")
        return

    # --- Split Data (if necessary for Trainer) ---
    # Trainer might handle splitting, or do it manually:
    # train_c1, test_c1, train_c2, test_c2, train_labels, test_labels = train_test_split(...)

    print(f"\nPreprocessing successful. Proceeding with {len(circuits1)} pairs.")

    # --- Train using PennylaneModel structure ---
    # NOTE(review): the training function appears to return (None, None) for
    # now because pair comparison is not implemented -- confirm against its
    # definition. The history value is not used in this function.
    trained_model, _history = train_with_lambeq_pennylane_model(
        circuits1, circuits2, labels, symbols, fasttext_model, n_max,
        config['simulation']['device'], config
    )

    # Compare against None explicitly rather than relying on the truthiness
    # of a model object.
    if trained_model is not None:
        # --- Evaluate (conceptual) ---
        # test_acc = trainer.test(test_dataset) # Assuming test_dataset setup
        # print(f"Test accuracy: {test_acc['acc']:.2f}")
        print("Evaluation step needs implementation based on chosen model structure.")
    else:
        print("Training function needs further implementation for pair comparison.")

# Example call for this branch
# config_file_path = 'config.yaml'
# main_pennylane_model_branch(config_file_path)