In [None]:
# qmsa_pytorch.py
import os
import math
import random
import numpy as np
from tqdm import tqdm
from sklearn.cluster import MiniBatchKMeans
from sklearn.model_selection import train_test_split
import cv2
import torch
import torch.nn as nn
import torch.optim as optim

# -------------------------
# Utilities: embeddings
# -------------------------
def load_glove_embeddings(glove_path, dims=100):
    """
    Load a GloVe text file into a dict mapping word -> numpy float32 vector.

    Each line is expected to look like "<word> <v1> ... <v_dims>". Lines are
    skipped when they are blank, when they do not split into exactly
    `dims` + 1 whitespace-separated fields, or when a value field cannot be
    parsed as a float (e.g. tokens that themselves contain whitespace).

    glove_path: path to the GloVe .txt file (UTF-8)
    dims: expected embedding dimensionality
    returns: dict[str, np.ndarray] of shape (dims,), dtype float32
    """
    emb = {}
    with open(glove_path, 'r', encoding='utf8') as f:
        for line in f:
            parts = line.split()
            # blank line or wrong number of components for this dimensionality
            if len(parts) != dims + 1:
                continue
            try:
                vals = np.array(parts[1:], dtype=np.float32)
            except ValueError:
                # malformed line (non-numeric value field): skip instead of aborting the load
                continue
            emb[parts[0]] = vals
    return emb

def normalize_vec(v):
    """
    Return `v` scaled to unit L2 norm as a float32 numpy array.

    The norm is computed in float64 for accuracy. A zero vector has no
    direction and is returned unchanged (as float32 zeros), so the output
    dtype is float32 in every case.

    v: 1-D numpy array or array-like of numbers
    returns: np.ndarray, dtype float32, same shape as v
    """
    v = np.asarray(v, dtype=np.float64)
    norm = np.linalg.norm(v)
    if norm == 0:
        return v.astype(np.float32)
    return (v / norm).astype(np.float32)

# -------------------------
# Build projectors
# -------------------------
def word_projectors_from_text(text_tokens, emb_map, dim):
    """
    Build one rank-1 projector w w^T per token that has an embedding.

    text_tokens: list of tokens, in order; repeated tokens yield repeated projectors
    emb_map: dict token -> embedding vector; tokens absent from it are skipped
    dim: expected embedding dimensionality (not used by the computation)
    returns: list of float32 numpy arrays, each (len(w) x len(w)), where w is
             the L2-normalised embedding of the token
    """
    known_tokens = (tok for tok in text_tokens if tok in emb_map)
    unit_vectors = (normalize_vec(emb_map[tok]) for tok in known_tokens)
    return [np.outer(u, u).astype(np.float32) for u in unit_vectors]

def _l2_normalize_rows(X):
    """Return 2-D `X` with every row scaled to unit L2 norm (float32); all-zero rows stay zero."""
    X = np.asarray(X)
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    norms[norms == 0] = 1.0  # avoid 0/0 for degenerate rows
    return (X / norms).astype(np.float32)


def build_visual_vocab(image_paths, k=128, sample_limit=None, random_state=42):
    """
    Build a SIFT "visual vocabulary" by k-means clustering of descriptors.

    Descriptors are L2-normalised before clustering, because
    image_projectors_from_image assigns L2-normalised descriptors to the
    nearest center; clustering in the same space keeps vocabulary
    construction and assignment consistent.

    image_paths: iterable of image file paths (unreadable images are skipped)
    k: number of visual words (k-means clusters)
    sample_limit: if truthy, only the first `sample_limit` paths are used
    random_state: seed passed to MiniBatchKMeans
    returns: (k x 128) float32 array of L2-normalised cluster centers
    raises: RuntimeError if no descriptors were found;
            ValueError if fewer than k descriptors were found
    """
    paths = list(image_paths)
    if sample_limit:
        # slice up front so the progress bar reflects the real amount of work
        paths = paths[:sample_limit]
    sift = cv2.SIFT_create()
    descs = []
    for p in tqdm(paths, desc="SIFT features"):
        img = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue
        _, d = sift.detectAndCompute(img, None)
        if d is not None:
            descs.append(d)
    if not descs:
        raise RuntimeError("No SIFT descriptors found.")
    descs = np.vstack(descs)
    if descs.shape[0] < k:
        raise ValueError(
            f"Need at least k={k} SIFT descriptors to build the vocabulary, found {descs.shape[0]}."
        )
    descs = _l2_normalize_rows(descs)
    km = MiniBatchKMeans(n_clusters=k, batch_size=4096, verbose=0, random_state=random_state)
    km.fit(descs)
    # unit-norm centers so each visual word yields a proper rank-1 projector
    return _l2_normalize_rows(km.cluster_centers_)

def image_projectors_from_image(img_path, visual_vocab):
    """
    Turn an image into a list of visual-word projectors.

    SIFT descriptors are extracted and L2-normalised. Each descriptor is then
    assigned to its nearest row of `visual_vocab` by Euclidean distance. Every
    assignment contributes one rank-1 projector s s^T, so a visual word that
    occurs several times yields several identical projectors.

    img_path: path to an image file (read as grayscale)
    visual_vocab: (k x dim) array of visual words, e.g. from build_visual_vocab
    returns: list of float32 numpy arrays (dim x dim), where dim = visual_vocab.shape[1];
             empty if the image cannot be read or has no SIFT descriptors
    """
    from sklearn.neighbors import NearestNeighbors

    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return []
    _, d = cv2.SIFT_create().detectAndCompute(img, None)
    if d is None:
        return []
    d_norm = np.array([normalize_vec(x) for x in d], dtype=np.float32)
    # named `knn` to avoid shadowing the module-level `torch.nn` alias
    knn = NearestNeighbors(n_neighbors=1).fit(visual_vocab)
    # ravel (not squeeze) keeps a 1-D index array even when only one descriptor exists
    nn_idx = knn.kneighbors(d_norm, return_distance=False).ravel()
    return [np.outer(visual_vocab[i], visual_vocab[i]).astype(np.float32) for i in nn_idx]

# -------------------------
# Density matrix estimation (QMR)
# -------------------------
# Implementation notes:
# We implement the ascent-direction formulas with a backtracking line search over the
# step size t, trying the candidates 1, 0.5, 0.25, ... from largest to smallest.
# Iteration stops when F(rho_new) - F(rho) < eps (the paper uses eps=1e-5).
# See the paper's algorithm description for the definitions of F, its gradient and Dk.

def compute_F_and_grad(rho, projectors, freqs=None):
    """
    Evaluate the weighted log-likelihood and its gradient:

        F(rho)    = sum_i f_i * log tr(P_i rho)
        grad F    = sum_i f_i * P_i / tr(P_i rho)

    rho: torch tensor (d x d), expected symmetric PSD with trace 1
    projectors: list of numpy arrays (d x d); each is moved to rho's device and
                cast to rho's dtype, so float32 projectors work with a float64 rho
    freqs: optional list of weights (e.g. occurrence counts) per projector; defaults to 1 each
    Returns: F (python float), grad (torch tensor d x d, same device/dtype as rho)
    """
    if freqs is None:
        freqs = [1.0] * len(projectors)
    F = 0.0
    grad = torch.zeros_like(rho)
    # lower bound on tr(P_i rho) to avoid log(0) and division by zero
    eps = 1e-12
    # tr(P @ rho) == sum(P * rho^T): O(d^2) per projector instead of an O(d^3) matmul
    rho_t = rho.T
    for Pi_np, f in zip(projectors, freqs):
        Pi = torch.from_numpy(Pi_np).to(device=rho.device, dtype=rho.dtype)
        tr = (Pi * rho_t).sum().clamp(min=eps)
        F += math.log(tr.item()) * f
        grad += (f / tr) * Pi
    return F, grad

def project_to_psd_trace_one(A):
    """
    Turn an arbitrary square matrix into a density matrix (symmetric, PSD, trace 1).

        S = (A + A^T) / 2 = U diag(lambda) U^T
        result = U diag(max(lambda, 0)) U^T / sum(max(lambda, 0))

    If the clipped eigenvalues sum to zero or less, the maximally mixed
    state I/d is returned instead. Serves as a safety net when numerical
    error pushes an iterate outside the PSD / trace-1 set.

    A: torch tensor (d x d), on any device
    returns: float32 torch tensor (d x d) on the CPU
    """
    M = A.cpu().numpy()
    sym = 0.5 * (M + M.T)
    lam, U = np.linalg.eigh(sym)
    lam[lam < 0] = 0.0
    total = lam.sum()
    if total <= 0:
        n = sym.shape[0]
        return torch.from_numpy(np.eye(n, dtype=np.float32) / n)
    scaled = np.diag(lam / total)
    rebuilt = U @ scaled @ U.T
    return torch.from_numpy(rebuilt.astype(np.float32))

def estimate_density_matrix(projectors, max_iter=200, eps=1e-5, device='cpu', freqs=None):
    """
    Maximum-likelihood density matrix for a bag of projectors.

    Maximises F(rho) = sum_i w_i * log tr(P_i rho) over symmetric PSD matrices
    with trace 1. It uses the globally convergent iteration from the paper
    (Algorithm 1, formulas 4-8: Dbar, Dtilde, Dk) with a backtracking line
    search over the step size t.

    The weights w_i are `freqs` (default: 1 per projector) normalised to sum
    to 1. The update formulas rely on this normalisation: it gives
    tr(grad @ rho) == 1, so Dbar and Dtilde are traceless and

        rho + t * Dk == (I + t*grad) @ rho @ (I + t*grad) / q(t),

    which is PSD with trace 1. Rescaling F by a positive constant does not
    move its maximiser. Note that `eps` is compared against the improvement
    of this weighted-average log-likelihood.

    projectors: non-empty list of numpy arrays (d x d); an empty list raises IndexError
    max_iter: maximum number of outer iterations
    eps: stop once an accepted step improves F by less than eps
    device: torch device used for the iterations
    freqs: optional non-negative weights per projector (e.g. occurrence counts)
    returns: rho, float32 torch tensor (d x d) on the CPU, symmetric PSD with trace 1
    raises: ValueError if the weights do not have a positive sum
    """
    d = projectors[0].shape[0]
    weights = np.ones(len(projectors)) if freqs is None else np.asarray(freqs, dtype=np.float64)
    total = float(weights.sum())
    if not total > 0:
        raise ValueError("freqs must sum to a positive value")
    weights = (weights / total).tolist()

    # Random positive diagonal start, normalised to trace 1.
    diag = np.random.rand(d).astype(np.float32) + 0.1
    diag = diag / diag.sum()
    rho = torch.from_numpy(np.diag(diag)).to(device)

    # Step sizes tried by the backtracking line search, largest first.
    t_candidates = [1.0, 0.5, 0.25, 0.125, 0.0625, 0.03, 0.01, 0.005]
    for _ in range(max_iter):
        F_val, grad = compute_F_and_grad(rho, projectors, weights)
        # Dbar = (grad rho + rho grad) / 2 - rho
        Dbar = (grad @ rho + rho @ grad) * 0.5 - rho
        # Dtilde = grad rho grad / tr(grad rho grad) - rho
        grg = grad @ rho @ grad
        tr_term = torch.trace(grg).item()  # loop-invariant w.r.t. t
        Dtilde = grg / max(tr_term, 1e-12) - rho

        accepted = None
        for t in t_candidates:
            # q(t) = 1 + 2t + t^2 tr(grad rho grad)
            # Dk   = 2/q * Dbar + t * tr(grad rho grad)/q * Dtilde
            q = 1.0 + 2.0 * t + (t * t) * tr_term
            Dk = (2.0 / q) * Dbar + (t * tr_term / q) * Dtilde
            rho_new = rho + t * Dk
            rho_new = (rho_new + rho_new.T) / 2.0
            # Safety net against floating-point drift: project back onto the
            # PSD / trace-1 set if needed, otherwise just renormalise the trace.
            try:
                eigvals = torch.linalg.eigvalsh(rho_new)
            except RuntimeError:  # torch.linalg.LinAlgError subclasses RuntimeError
                eigvals = torch.from_numpy(np.linalg.eigvalsh(rho_new.cpu().numpy()))
            tr = torch.trace(rho_new).item()
            if (eigvals < -1e-8).any() or tr <= 0:
                rho_new = project_to_psd_trace_one(rho_new).to(device)
            else:
                rho_new = rho_new / tr
            F_new, _ = compute_F_and_grad(rho_new, projectors, weights)
            # Accept the largest step that strictly increases F -- on every
            # iteration, including the first one.
            if F_new - F_val > 1e-12:
                accepted = (rho_new, F_new)
                break
        if accepted is None:
            # no candidate step improves F: stop
            break
        rho, F_best = accepted
        if F_best - F_val < eps:
            break

    # final symmetrisation and PSD / trace-1 projection
    rho = (rho + rho.T) / 2.0
    return project_to_psd_trace_one(rho)

# -------------------------
# Simple PyTorch classifier (MLP)
# -------------------------
class MLPClassifier(nn.Module):
    """
    Feed-forward classifier mapping flat feature vectors to class logits.

    Layout (held in ``self.net``, an nn.Sequential):
        0: Linear(input_dim, hidden)
        1: ReLU
        2: Dropout(0.5)
        3: Linear(hidden, hidden // 2)
        4: ReLU
        5: Linear(hidden // 2, nclass)
    Parameter names are therefore net.0.*, net.3.* and net.5.*.
    """

    def __init__(self, input_dim, hidden=512, nclass=2):
        super().__init__()
        half = hidden // 2
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, half),
            nn.ReLU(),
            nn.Linear(half, nclass),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalised) class logits for `x`, whose last dimension is input_dim."""
        logits = self.net(x)
        return logits

def train_mlp(X_train, y_train, X_val, y_val, input_dim, epochs=30, lr=1e-3, device='cpu'):
    """
    Train an MLPClassifier with full-batch Adam and keep the best validation checkpoint.

    Each epoch is one full-batch optimisation step on (X_train, y_train),
    followed by a validation-accuracy evaluation on (X_val, y_val).

    X_train, X_val: numpy feature arrays with input_dim columns
    y_train, y_val: numpy integer class labels
    input_dim: number of input features
    epochs: number of full-batch optimisation steps
    lr: Adam learning rate
    device: torch device used for training
    returns: the model, restored to the weights of the epoch with the highest
             validation accuracy (the earliest such epoch on ties), in eval mode.
             If no epoch produced a comparable accuracy (epochs <= 0, or an
             empty validation set), the final weights are returned.
    """
    model = MLPClassifier(input_dim).to(device)
    opt = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    X_train_t = torch.from_numpy(X_train).float().to(device)
    y_train_t = torch.from_numpy(y_train).long().to(device)
    X_val_t = torch.from_numpy(X_val).float().to(device)
    y_val_t = torch.from_numpy(y_val).long().to(device)

    best_acc = -1.0
    best_state = None
    for _ in range(epochs):
        model.train()
        opt.zero_grad()
        loss = loss_fn(model(X_train_t), y_train_t)
        loss.backward()
        opt.step()

        model.eval()
        with torch.no_grad():
            val_preds = model(X_val_t).argmax(dim=1)
            acc = (val_preds == y_val_t).float().mean().item()
        if acc > best_acc:
            best_acc = acc
            # clone: state_dict() tensors alias the live parameters
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()  # dropout off for downstream inference
    return model

# -------------------------
# QIMF fusion
# -------------------------
def qimf_fusion(pt, pi, alpha, beta, cos_theta):
    """
    Quantum-interference fusion of text and image positive-class probabilities.

        P_u = alpha^2 * P_t + beta^2 * P_i + 2 * alpha * beta * sqrt(P_t * P_i) * cos(theta)

    pt, pi: positive-class probabilities from the text and image models
            (python floats or numpy arrays of broadcastable shape)
    alpha, beta: modality amplitudes; alpha^2 + beta^2 == 1 is expected but not checked
    cos_theta: cosine of the interference angle
    returns: fused positive-class probability, clipped to [0, 1]
    """
    cross_amp = 2 * alpha * beta
    interference = cross_amp * np.sqrt(pt * pi) * cos_theta
    direct = alpha * alpha * pt + beta * beta * pi
    fused = direct + interference
    return np.clip(fused, 0.0, 1.0)

# -------------------------
# Example pipeline (driver)
# -------------------------
def example_pipeline(glove_path, image_paths, texts_tokens_list, labels, visual_k=128,
                     glove_dims=100, vocab_sample_limit=500):
    """
    End-to-end example pipeline:
      1. estimate a text and an image density matrix for every sample,
      2. train one MLP per modality on the flattened density matrices,
      3. grid-search the QIMF fusion parameters on the validation split.

    glove_path: path to GloVe embeddings (text format)
    image_paths: list of image file paths aligned with texts_tokens_list and labels
    texts_tokens_list: list of token lists
    labels: binary 0/1 labels
    visual_k: number of SIFT visual words
    glove_dims: dimensionality of the GloVe file (lines of another length are skipped)
    vocab_sample_limit: number of leading images used to build the visual vocabulary
    returns: (model_text, model_img, best_params), where best_params is
             (alpha, beta, cos_theta, val_accuracy) of the best fusion found
    raises: ValueError if the three inputs differ in length or no embedding could be read
    """
    n = len(labels)
    if len(image_paths) != n or len(texts_tokens_list) != n:
        raise ValueError(
            f"image_paths ({len(image_paths)}), texts_tokens_list ({len(texts_tokens_list)}) "
            f"and labels ({n}) must have the same length"
        )
    emb = load_glove_embeddings(glove_path, dims=glove_dims)
    if not emb:
        raise ValueError(f"No {glove_dims}-dimensional embeddings could be read from {glove_path}")
    text_dim = len(next(iter(emb.values())))
    # build visual vocab (sample limit for speed)
    visual_vocab = build_visual_vocab(image_paths, k=visual_k, sample_limit=vocab_sample_limit)
    vis_dim = visual_vocab.shape[1]

    # produce projectors for each sample and estimate density matrices
    rhos_text = []
    rhos_image = []
    for tokens, imgp in tqdm(zip(texts_tokens_list, image_paths), total=n, desc="Estimating rhos"):
        P_text = word_projectors_from_text(tokens, emb, text_dim)
        if not P_text:
            # no known token: fall back to a single identity projector
            P_text = [np.eye(text_dim, dtype=np.float32)]
        rho_t = estimate_density_matrix(P_text, device='cpu')
        rhos_text.append(rho_t.cpu().numpy().astype(np.float32))

        P_img = image_projectors_from_image(imgp, visual_vocab)
        if not P_img:
            # unreadable image or no descriptors: same identity fallback
            P_img = [np.eye(vis_dim, dtype=np.float32)]
        rho_i = estimate_density_matrix(P_img, device='cpu')
        rhos_image.append(rho_i.cpu().numpy().astype(np.float32))

    # Flatten and prepare datasets
    X_text = np.array([r.flatten() for r in rhos_text], dtype=np.float32)
    X_img = np.array([r.flatten() for r in rhos_image], dtype=np.float32)
    y = np.array(labels, dtype=np.int64)

    # Split the sample indices once so text rows, image rows and labels stay aligned.
    train_idx, val_idx = train_test_split(np.arange(n), test_size=0.2, random_state=42)
    Xt_tr, Xt_val = X_text[train_idx], X_text[val_idx]
    Xi_tr, Xi_val = X_img[train_idx], X_img[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]

    # train MLPs
    model_text = train_mlp(Xt_tr, y_tr, Xt_val, y_val, input_dim=Xt_tr.shape[1])
    model_img = train_mlp(Xi_tr, y_tr, Xi_val, y_val, input_dim=Xi_tr.shape[1])

    # get positive-class probabilities on the validation set
    device = 'cpu'
    with torch.no_grad():
        pt_logits = model_text(torch.from_numpy(Xt_val).float().to(device))
        pi_logits = model_img(torch.from_numpy(Xi_val).float().to(device))
        pt_probs = torch.softmax(pt_logits, dim=1)[:, 1].cpu().numpy()
        pi_probs = torch.softmax(pi_logits, dim=1)[:, 1].cpu().numpy()

    # Grid search alpha/cos to maximise validation accuracy. The reported accuracy is
    # measured on the same split used for selection, so it is an optimistic estimate.
    best_acc = -1
    best_params = None
    alphas = np.linspace(0.1, 0.9, 9)
    cos_list = np.linspace(-1.0, 1.0, 21)
    for a in alphas:
        b = math.sqrt(max(0.0, 1.0 - a * a))
        for cos_theta in cos_list:
            pu = qimf_fusion(pt_probs, pi_probs, a, b, cos_theta)
            preds = (pu >= 0.5).astype(int)
            acc = (preds == y_val).mean()
            if acc > best_acc:
                best_acc = acc
                best_params = (a, b, cos_theta, acc)
    print("Best fusion params alpha, beta, cos, acc:", best_params)
    return model_text, model_img, best_params

# -------------------------
# If run as script, show usage example (user must adapt to dataset)
# -------------------------
if __name__ == "__main__":
    # Running the module directly does no work; it only points to the example driver,
    # which needs dataset paths supplied by the caller.
    usage = "This script is a library; adapt example_pipeline() to your dataset paths."
    print(usage)
