******Classical QUBO solvers******

In [None]:
# !pip install numpy

import numpy as np

def build_Q1(w: np.ndarray) -> np.ndarray:
    """
    Build the first discrepancy term of the QUBO from the feature matrix.

    Parameters
    ----------
    w : np.ndarray of shape (n, d)
        Feature matrix, one row per patient and one column per feature.
        Any number of feature columns is supported.

    Returns
    -------
    Q1 : np.ndarray of shape (n, n)
        Upper-triangular matrix. Diagonal entries are
        4 * sum_s (w[i, s]**2 - w[i, s] * w_sum[s]); entries above the
        diagonal are 8 * sum_s w[i, s] * w[j, s]. Entries below the
        diagonal stay zero.
    """
    n, n_features = w.shape      # n patients, n_features covariates
    w_sum = np.sum(w, axis=0)    # per-feature column sums
    Q1 = np.zeros((n, n))

    for i in range(n):
        # Diagonal part
        for s in range(n_features):
            Q1[i, i] += 4 * (w[i, s]**2 - w[i, s] * w_sum[s])

        # Off-diagonal part (upper triangle only)
        for j in range(i + 1, n):
            for s in range(n_features):
                Q1[i, j] += 8 * w[i, s] * w[j, s]

    return Q1

def build_Q2(w: np.ndarray, rho: float = 0.5) -> np.ndarray:
    """
    Build the second discrepancy term, based on the squared features.

    Parameters
    ----------
    w : np.ndarray of shape (n, d)
        Feature matrix.
    rho : float
        Weight multiplying every entry of this term.

    Returns
    -------
    Q2 : np.ndarray of shape (n, n)
        Upper-triangular matrix built from u = w**2 in the same way the
        first term is built from w, scaled by rho.
    """
    squared = w**2                               # squared[i, s] = w[i, s]^2
    squared_totals = np.sum(squared, axis=0)     # column sums of the squares
    n_patients = w.shape[0]
    Q2 = np.zeros((n_patients, n_patients))

    for row, u_row in enumerate(squared):
        # Diagonal entry for this patient
        Q2[row, row] = rho * np.sum(4 * (u_row**2 - u_row * squared_totals))
        # Entries to the right of the diagonal
        for col in range(row + 1, n_patients):
            Q2[row, col] = rho * np.sum(8 * u_row * squared[col])

    return Q2

def build_Q3(w: np.ndarray, rho: float = 0.5) -> np.ndarray:
    """
    Build the third discrepancy term, based on products of feature pairs.

    Parameters
    ----------
    w : np.ndarray of shape (n, d)
        Feature matrix.
    rho : float
        Weight of this term; every entry is scaled by 2 * rho.

    Returns
    -------
    Q3 : np.ndarray of shape (n, n)
        Upper-triangular matrix built from v_i^(ss') = w[i, s] * w[i, s']
        for all feature pairs with s' > s.
    """
    n_samples, n_features = w.shape

    # All feature pairs (s, s') with s' > s, in lexicographic order
    feature_pairs = [
        (s, sp) for s in range(n_features) for sp in range(s + 1, n_features)
    ]

    # V[i, k] holds the product of the k-th feature pair for patient i
    V = np.zeros((n_samples, len(feature_pairs)))
    for column, (s, sp) in enumerate(feature_pairs):
        V[:, column] = w[:, s] * w[:, sp]

    v_total = np.sum(V, axis=0)  # per-pair sums over all patients
    weight = 2.0 * rho

    Q3 = np.zeros((n_samples, n_samples))
    for i in range(n_samples):
        v_i = V[i]
        # diagonal entry
        Q3[i, i] = weight * np.sum(4.0 * (v_i**2 - v_i * v_total))
        # entries to the right of the diagonal
        for j in range(i + 1, n_samples):
            Q3[i, j] = weight * np.sum(8.0 * v_i * V[j])

    return Q3

def add_group_size_penalty(Q: np.ndarray, lam_size: float = 3.0) -> np.ndarray:
    """
    Return a copy of Q with the group-size penalty coefficients added.

    Parameters
    ----------
    Q : np.ndarray of shape (n, n)
        QUBO matrix; it is not modified.
    lam_size : float
        Penalty coefficient.

    Returns
    -------
    Q_new : np.ndarray of shape (n, n)
        Q plus lam_size * (1 - n) on the diagonal and 2 * lam_size on
        every entry above the diagonal.
    """
    size = Q.shape[0]
    Q_new = Q.copy()

    diagonal_shift = lam_size * (1.0 - size)
    pair_shift = 2.0 * lam_size

    for i in range(size):
        Q_new[i, i] += diagonal_shift
        for j in range(i + 1, size):
            Q_new[i, j] += pair_shift

    return Q_new



def QUBO_formulation(W: np.ndarray,
                     rho: float = 0.5 ,
                     lam_size: float = 3.0,
                     normalize: bool = True) -> np.ndarray:
    """
    Assemble the final QUBO: sum the three discrepancy terms, optionally
    rescale them, then add the group-size penalty.

    Parameters
    ----------
    W : np.ndarray of shape (n, d)
        Feature matrix.
    rho : float
        Weight passed to build_Q2 and build_Q3.
    lam_size : float
        Coefficient passed to add_group_size_penalty.
    normalize : bool
        If True, the summed discrepancy matrix is divided by its largest
        absolute entry (when that is positive) before the penalty is added.

    Returns
    -------
    Q_final : np.ndarray of shape (n, n)
        Final QUBO matrix.
    """
    discrepancy = build_Q1(W) + build_Q2(W, rho) + build_Q3(W, rho)

    if normalize:
        scale = np.abs(discrepancy).max()
        # An all-zero matrix is left as is to avoid dividing by zero.
        if scale > 0.0:
            discrepancy = discrepancy / scale

    return add_group_size_penalty(discrepancy, lam_size=lam_size)






In [None]:
class ClinicalTrial:
    rho: float  # Weight of the second-moment terms relative to the mean terms
    w: np.ndarray  # Normalized patient covariates, one row per patient

    def assert_valid(self, group1: np.ndarray, group2: np.ndarray) -> None:
        """
        Verify that a proposed split satisfies the patient constraints.

        Arguments (n = number of patients):
            - group1: np.ndarray(size = n) => binary membership vector of group 1
            - group2: np.ndarray(size = n) => binary membership vector of group 2
        Raises AssertionError when group 1 does not hold half of the patients
        or when some patient is not in exactly one group.
        """
        expected = self.w.shape[0] // 2

        # Constraint 1: group 1 holds half of the patients
        assert np.sum(group1) == expected, f"Each group should have {expected} patients"

        # Constraint 2: memberships sum to one for every patient
        assert (
            group1 + group2 == 1
        ).all(), "Every patient needs to be assigned to one group"

    def discrepancy(self, group1: np.ndarray, group2: np.ndarray) -> float:
        """
        Compute the discrepancy measure between the two patient groups.

        Arguments (n = number of patients):
            - group1: np.ndarray(size = n) => binary membership vector of group 1
            - group2: np.ndarray(size = n) => binary membership vector of group 2
        Returns:
            - float => sum of absolute mean differences, plus rho times the
              absolute variance differences, plus 2 * rho times the absolute
              covariance differences
        """
        # Reject invalid splits before computing anything
        self.assert_valid(group1, group2)

        # Group labels are interchangeable; make patient 0 belong to group1
        if group1[0] == 0:
            group1, group2 = group2, group1

        n_patients, n_covariates = self.w.shape
        signed = group1 - group2  # +1 for group 1, -1 for group 2

        # First moments: one signed average per covariate
        mean_gaps = [
            self.w[:, k].dot(signed) / n_patients for k in range(n_covariates)
        ]

        # Second moments: squares of each covariate ...
        variance_gaps = [
            (self.w[:, k] ** 2).dot(signed) / n_patients
            for k in range(n_covariates)
        ]
        # ... and products of every distinct covariate pair
        covariance_gaps = [
            (self.w[:, a] * self.w[:, b]).dot(signed) / n_patients
            for a in range(n_covariates)
            for b in range(a + 1, n_covariates)
        ]

        return (
            np.sum(np.abs(mean_gaps))
            + self.rho * np.sum(np.abs(variance_gaps))
            + 2 * self.rho * np.sum(np.abs(covariance_gaps))
        )

In [None]:
from io import StringIO
import numpy as np

w_data = """
-0.7650636535189481,0.283894442848472,-0.655816465582355
-0.5856817886844943,1.2319371664692689,0.910198903115477
-0.854754585936175,-0.3306700946565902,-0.054686355633081954
0.22153660307055117,1.2078202901666337,-0.6382331351268004
0.400918467905005,-0.04542566114612644,-0.4096498392045905
-0.1372271265983581,0.47572063137057313,3.984424441638505
-0.1372271265983581,-0.27827825993017496,0.1980740196655155
0.8493731299911412,0.21597910153645256,-0.42613421150667297
-0.7650636535189481,1.1249358532184994,-0.08325926762335818
-0.1372271265983581,0.14113362335585988,1.382750909108507
1.925664318997869,0.538646274137218,-0.372285261986537
0.3112275354877781,0.29415104541396137,-0.20579310173550433
0.7596821975739143,0.5652579997125404,-0.24150924172334964
0.13184567065332425,-0.42297951774598436,-0.09809520269523238
0.6699912651566874,-2.1402674337783933,0.40028231990439345
-0.5856817886844943,-0.49948822877502536,-0.6052643905226355
-0.1372271265983581,0.5788410679749358,-0.6909831264934643
0.22153660307055117,-0.9424625774142155,3.0009667900962667
-0.1372271265983581,-0.18458280946707206,1.9861888311800715
1.1184459272428235,0.07848518606396536,-0.5876810600670809
-0.5856817886844943,-1.5392859645357366,-0.23546497187925275
-0.854754585936175,-1.0267330417360687,-0.41954046258584
0.3112275354877781,1.6823960629265113,-0.12941551006918903
-0.4959908562672674,1.1784365098438845,-0.0563347928632902
-0.4062999238500389,-0.034060236681667924,-0.4140456718184792
-0.4959908562672674,-1.010100713251496,-0.2310691392653641
-0.9444455183534036,-1.2224401069045752,-0.3305248521545948
0.8493731299911412,0.5339337810665824,2.334888253276789
-0.5856817886844943,-1.79653264509717,-0.11128270053689834
-0.4062999238500389,1.039279361522928,-0.439871188425075
0.04215473823609734,0.7701128455475578,-0.3568998478379267
-1.0341364507706305,1.2319371664692689,-0.2871160050924444
-1.0341364507706305,-1.3943075012451869,-0.6431784468174251
-0.9444455183534036,0.4901353160571981,-0.5448216920816665
0.13184567065332425,1.948790524154465,0.25302192733912365
2.19473711624955,1.742826856420469,-0.3865717179816751
-0.047536194181131176,-0.05928593488327222,-0.6459258422011055
0.13184567065332425,-1.2745547361562415,-0.628891990822287
-0.6753727211017212,-0.0942138247008759,-0.39426442505598025
-0.1372271265983581,0.7665091743759014,-0.790438839382695
5.692683480521414,1.1975636876011444,-0.6415300095872168
-0.1372271265983581,1.334780397598892,2.2075190032893652
-0.1372271265983581,-1.042810959271162,3.1909766548316036
-1.3929001804395398,-0.2976826431621845,-0.5349310687004171
0.49060940032223355,-0.5690668029355032,-0.6343867815896478
-0.854754585936175,0.5527837533491035,-0.4970170124056275
0.3112275354877781,-0.17515782332581153,-0.298105586627166
0.400918467905005,0.5200735073294374,-0.25854309310216816
0.22153660307055117,0.12422408939654765,-0.4662461841084069
-0.854754585936175,0.5333793701171032,-0.09864468177196847
2.2844280486667783,-0.7370533206297188,-0.2129363297330734
-1.1238273831878574,0.1691313763049011,-0.5266888825493758
2.19473711624955,1.1448946473999888,-0.6602122981962437
-0.047536194181131176,0.5769006296517378,-0.40855088105111836
-0.5856817886844943,2.2354209850386475,-0.5948242880646499
-0.1372271265983581,1.00906396477596,-0.34371234999626077
2.015355251415096,-0.4933897083306817,-0.5635039806906933
0.7596821975739143,-0.41161409328152654,1.9673966467556976
-0.4959908562672674,-2.2081827750904024,0.03377977572142715
-0.5856817886844943,-1.0563940275335575,-0.25854309310216816
1.1184459272428235,0.822504680273963,-0.5497670037722913
0.3112275354877781,-0.7772481144674374,0.16070944244746196
-0.6753727211017212,0.161369623012099,1.1431780316515756
-0.5856817886844943,0.5949189855100293,-0.642628967740689
-0.5856817886844943,1.2100379339645713,-0.6805430240354787
0.5803003327394605,0.17717033507244279,-0.40250661120702147
-1.2135183156050844,-1.703946016533026,-0.2920613167830691
-0.7650636535189481,-0.9915279464437146,0.83546974867937
-1.1238273831878574,-0.9455118376363927,0.07938653909052192
-0.1372271265983581,-0.6713556231156471,-0.5558112736163882
-0.31660899143281196,1.0620102104518552,-0.5596576271535407
0.04215473823609734,-1.776019439966191,0.16455579598461453
-0.5856817886844943,0.805317940839901,-0.026113443642805722
0.22153660307055117,-0.11472702983185462,0.03597769202837148
0.8493731299911412,-1.3275009818321366,-0.43217848135076986
-0.854754585936175,-2.067362393920999,-0.22667330665147542
-0.5856817886844943,-0.6142512953185962,-0.3579988059913989
-0.047536194181131176,-0.964361809918913,-0.33601964292195563
-0.854754585936175,-0.08340281118590664,-0.4310795231972977
-0.5856817886844943,-1.629932154776673,-0.5047097194799326
-0.854754585936175,1.709007788501833,-0.47009253764555947
0.22153660307055117,1.058960950229678,-0.6250456372851344
-1.0341364507706305,-0.8049686619417282,-0.46789462133861515
-0.4959908562672674,1.955166250073549,-0.6321888652827035
-0.4062999238500389,0.2248496767282231,-0.2931602749365413
0.22153660307055117,-0.15769387841701005,4.715231613697493
-0.9444455183534036,0.4934617817541149,-0.6332878234361756
-0.22691805901558504,-0.5552065291983574,-0.27942329801813925
-0.854754585936175,-1.2448937503587518,-0.6272435535920787
1.477209656911733,-0.44848242142232825,0.2717042159481504
-0.1372271265983581,-0.7722584159220621,1.8509070824876483
1.1184459272428235,0.30579367535315943,-0.39096755059556376
-0.6753727211017212,0.34238479801922156,0.16730319136829494
-0.7650636535189481,-0.6835526640043345,0.18268860551690522
0.400918467905005,-0.6635938698228451,0.052462064330453924
1.925664318997869,1.5379720105854517,-0.6102097022132602
-0.4062999238500389,-0.17377179595209344,-0.16458217098029823
1.387518724494506,0.6583990392261534,-0.5497670037722913
-0.31660899143281196,-0.8227098123252801,-0.10249103530912104
0.7596821975739143,0.6425983271658103,-0.5525143991559717
"""

# Parse the comma-separated text block into a 2-D array (one row per line).
w = np.loadtxt(StringIO(w_data), delimiter=",")
print(w[0])   # first row only (3 values); the full array has shape (100, 3)


[[-0.76506365  0.28389444 -0.65581647]
 [-0.58568179  1.23193717  0.9101989 ]
 [-0.85475459 -0.33067009 -0.05468636]
 [ 0.2215366   1.20782029 -0.63823314]
 [ 0.40091847 -0.04542566 -0.40964984]
 [-0.13722713  0.47572063  3.98442444]
 [-0.13722713 -0.27827826  0.19807402]
 [ 0.84937313  0.2159791  -0.42613421]
 [-0.76506365  1.12493585 -0.08325927]
 [-0.13722713  0.14113362  1.38275091]
 [ 1.92566432  0.53864627 -0.37228526]
 [ 0.31122754  0.29415105 -0.2057931 ]
 [ 0.7596822   0.565258   -0.24150924]
 [ 0.13184567 -0.42297952 -0.0980952 ]
 [ 0.66999127 -2.14026743  0.40028232]
 [-0.58568179 -0.49948823 -0.60526439]
 [-0.13722713  0.57884107 -0.69098313]
 [ 0.2215366  -0.94246258  3.00096679]
 [-0.13722713 -0.18458281  1.98618883]
 [ 1.11844593  0.07848519 -0.58768106]
 [-0.58568179 -1.53928596 -0.23546497]
 [-0.85475459 -1.02673304 -0.41954046]
 [ 0.31122754  1.68239606 -0.12941551]
 [-0.49599086  1.17843651 -0.05633479]
 [-0.40629992 -0.03406024 -0.41404567]
 [-0.49599086 -1.01010071

In [None]:
# Show the dimensions of the loaded matrix as (rows, columns).
print(w.shape)

(100, 3)


In [None]:
# Rows of w are patients, columns are covariates.
n_patients , n_covariates = w.shape
rho = 0.5

# ClinicalTrial defines no __init__, so its data is attached as attributes.
trial = ClinicalTrial()
trial.w = w
trial.rho = rho





In [None]:
"""
Complete QUBO Pipeline for Patient Group Matching
==================================================

Dependencies:
    pip install numpy torch

Hardware:
    - GPU recommended (CUDA-enabled for faster solving)
    - CPU fallback available

Usage:
    W = np.random.randn(100, 3)  # Your feature matrix
    group1, group2 = solve_patient_matching(W, rho=0.5, lam_size=3.0)
"""

import numpy as np
import torch
import time
from typing import Tuple, Dict, List

# ============================================================================
# PART 1: QUBO FORMULATION
# ============================================================================

def build_Q1(w: np.ndarray) -> np.ndarray:
    """
    Build the first discrepancy term (mean matching).

    Parameters:
    -----------
    w : np.ndarray of shape (n, d)
        Feature matrix

    Returns:
    --------
    Q1 : np.ndarray of shape (n, n)
        Upper-triangular first QUBO term
    """
    n_rows = w.shape[0]
    column_totals = np.sum(w, axis=0)
    Q1 = np.zeros((n_rows, n_rows))

    for i in range(n_rows):
        features = range(w.shape[1])

        # Diagonal entry, accumulated feature by feature
        linear = 0.0
        for s in features:
            linear += 4 * (w[i, s]**2 - w[i, s] * column_totals[s])
        Q1[i, i] = linear

        # Entries to the right of the diagonal
        for j in range(i + 1, n_rows):
            coupling = 0.0
            for s in features:
                coupling += 8 * w[i, s] * w[j, s]
            Q1[i, j] = coupling

    return Q1


def build_Q2(w: np.ndarray, rho: float = 0.5) -> np.ndarray:
    """
    Build the second discrepancy term (variance matching).

    Parameters:
    -----------
    w : np.ndarray of shape (n, d)
        Feature matrix
    rho : float
        Weight applied to every entry of this term

    Returns:
    --------
    Q2 : np.ndarray of shape (n, n)
        Upper-triangular second QUBO term, built from the squared features
    """
    patients = range(w.shape[0])
    u = w * w                      # element-wise squares of the features
    totals = u.sum(axis=0)         # column sums of the squares

    Q2 = np.zeros((len(patients), len(patients)))

    # Diagonal
    for i in patients:
        Q2[i, i] = rho * np.sum(4 * (u[i]**2 - u[i] * totals))

    # Strict upper triangle
    for i in patients:
        for j in patients[i + 1:]:
            Q2[i, j] = rho * np.sum(8 * u[i] * u[j])

    return Q2


def build_Q3(w: np.ndarray, rho: float = 0.5) -> np.ndarray:
    """
    Build the third discrepancy term (covariance matching).

    Parameters:
    -----------
    w : np.ndarray of shape (n, d)
        Feature matrix
    rho : float
        Weight for this term (entries are scaled by 2 * rho)

    Returns:
    --------
    Q3 : np.ndarray of shape (n, n)
        Upper-triangular third QUBO term
    """
    n_samples, n_features = w.shape

    # Enumerate the feature pairs (s, s') with s' > s
    pairs = []
    for first in range(n_features):
        pairs.extend((first, second) for second in range(first + 1, n_features))

    # One column of per-patient products for each feature pair
    V = np.zeros((n_samples, len(pairs)))
    for k, (first, second) in enumerate(pairs):
        V[:, k] = w[:, first] * w[:, second]
    pair_totals = np.sum(V, axis=0)

    scale = 2.0 * rho
    Q3 = np.zeros((n_samples, n_samples))
    for i, products_i in enumerate(V):
        # Diagonal entry
        Q3[i, i] = scale * np.sum(4.0 * (products_i**2 - products_i * pair_totals))
        # Entries to the right of the diagonal
        for j in range(i + 1, n_samples):
            Q3[i, j] = scale * np.sum(8.0 * products_i * V[j])

    return Q3


def add_group_size_penalty(Q: np.ndarray, lam_size: float = 3.0) -> np.ndarray:
    """
    Add the group-size penalty (encouraging an equal split) to a QUBO.

    Parameters:
    -----------
    Q : np.ndarray of shape (n, n)
        QUBO matrix (left unmodified)
    lam_size : float
        Penalty coefficient

    Returns:
    --------
    Q_new : np.ndarray of shape (n, n)
        Copy of Q with lam_size * (1 - n) added on the diagonal and
        2 * lam_size added above the diagonal
    """
    n = Q.shape[0]
    Q_new = Q.copy()

    # Linear part of the penalty
    for k in range(n):
        Q_new[k, k] += lam_size * (1.0 - n)

    # Pairwise part of the penalty, upper triangle only
    upper_pairs = ((r, c) for r in range(n) for c in range(r + 1, n))
    for r, c in upper_pairs:
        Q_new[r, c] += 2.0 * lam_size

    return Q_new


def QUBO_formulation(W: np.ndarray,
                     rho: float = 0.5,
                     lam_size: float = 3.0,
                     normalize: bool = True) -> np.ndarray:
    """
    Build the complete QUBO from the three discrepancy terms and the
    group-size penalty.

    Parameters:
    -----------
    W : np.ndarray of shape (n, d)
        Feature matrix
    rho : float
        Weight passed to the second and third terms
    lam_size : float
        Group size penalty coefficient
    normalize : bool
        Whether to divide the summed discrepancy terms by their largest
        absolute entry before the penalty is added

    Returns:
    --------
    Q_final : np.ndarray of shape (n, n)
        Final QUBO matrix
    """
    terms = (build_Q1(W), build_Q2(W, rho), build_Q3(W, rho))
    Q = terms[0] + terms[1] + terms[2]

    if normalize:
        largest = np.abs(Q).max()
        # Skip the rescale for an all-zero matrix.
        Q = Q / largest if largest > 0.0 else Q

    return add_group_size_penalty(Q, lam_size=lam_size)


# ============================================================================
# PART 2: SOLVERS
# ============================================================================

class QUBOSimulatedAnnealing:
    """Simulated annealing for QUBO problems, batched over parallel chains.

    The energy of a binary vector x is x^T Q x for the matrix Q given to the
    constructor (e.g. an upper-triangular QUBO). Runs on CUDA when available,
    otherwise on the CPU.
    """

    def __init__(self, Q: np.ndarray, device='cuda'):
        """
        Parameters:
        -----------
        Q : np.ndarray of shape (n, n)
            QUBO coefficient matrix
        device : str
            Requested torch device; ignored in favour of 'cpu' when CUDA is
            not available
        """
        self.n = Q.shape[0]
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')

        # Symmetrize with (Q + Q^T) / 2 so that x^T Q_sym x == x^T Q x.
        # Adding Q^T without halving would count every pair coefficient twice.
        Q_sym = 0.5 * (Q + Q.T)
        self.Q = torch.from_numpy(Q_sym).float().to(self.device)

    def energy(self, x: torch.Tensor) -> torch.Tensor:
        """Compute QUBO energy x^T Q x for one vector or a batch of row vectors."""
        if x.dim() == 1:
            return x @ self.Q @ x
        else:
            return torch.sum(x * (x @ self.Q), dim=-1)

    def solve(self,
              n_iterations: int = 10000,
              T_initial: float = 100.0,
              T_final: float = 0.01,
              n_parallel: int = 64,
              verbose: bool = False) -> Tuple[np.ndarray, float]:
        """
        Run n_parallel independent annealing chains with single-bit flips.

        Parameters:
        -----------
        n_iterations : int
            Number of flip proposals per chain
        T_initial, T_final : float
            Start and end temperature of the geometric cooling schedule
        n_parallel : int
            Number of chains evolved simultaneously
        verbose : bool
            Print the best energy every 2000 iterations

        Returns:
        --------
        best_x : np.ndarray of shape (n,)
            Lowest-energy binary vector seen by any chain (int dtype)
        best_energy : float
            Energy of best_x
        """
        x = torch.randint(0, 2, (n_parallel, self.n),
                          dtype=torch.float32, device=self.device)

        current_energy = self.energy(x)
        best_x = x.clone()
        best_energy = current_energy.clone()

        # Geometric cooling factor; with no iterations there is nothing to cool.
        alpha = (T_final / T_initial) ** (1.0 / n_iterations) if n_iterations > 0 else 1.0
        T = T_initial

        # Row index of every chain, built once on the working device.
        chain_idx = torch.arange(n_parallel, device=self.device)

        for iteration in range(n_iterations):
            # Propose flipping one random bit in every chain.
            flip_idx = torch.randint(0, self.n, (n_parallel,), device=self.device)
            x_new = x.clone()
            x_new[chain_idx, flip_idx] = 1 - x_new[chain_idx, flip_idx]

            new_energy = self.energy(x_new)
            delta_E = new_energy - current_energy

            # Metropolis rule: always accept improvements, otherwise accept
            # with probability exp(-delta_E / T).
            accept_prob = torch.exp(-delta_E / T)
            accept = (delta_E < 0) | (torch.rand(n_parallel, device=self.device) < accept_prob)

            x = torch.where(accept.unsqueeze(1), x_new, x)
            current_energy = torch.where(accept, new_energy, current_energy)

            # Track the best state of each chain separately.
            improved = current_energy < best_energy
            best_x = torch.where(improved.unsqueeze(1), x, best_x)
            best_energy = torch.where(improved, current_energy, best_energy)

            T *= alpha

            if verbose and iteration % 2000 == 0:
                print(f"  Iteration {iteration}, Best Energy={best_energy.min().item():.4f}")

        best_idx = torch.argmin(best_energy)
        return best_x[best_idx].cpu().numpy().astype(int), best_energy[best_idx].item()


class QUBOTabuSearch:
    """Tabu search for QUBO problems using single-bit flips.

    The energy of a binary vector x is x^T Q x for the matrix Q given to the
    constructor (e.g. an upper-triangular QUBO).
    """

    def __init__(self, Q: np.ndarray):
        # Symmetrize with (Q + Q^T) / 2 so that x^T Q_sym x == x^T Q x.
        # Adding Q^T without halving would count every pair coefficient twice.
        self.Q = 0.5 * (Q + Q.T)
        self.n = Q.shape[0]

    def energy(self, x: np.ndarray) -> float:
        """Return the QUBO energy x^T Q x of a binary vector."""
        return x @ self.Q @ x

    def solve(self,
              max_iterations: int = 5000,
              tabu_tenure: int = 10,
              n_restarts: int = 5,
              verbose: bool = False) -> Tuple[np.ndarray, float]:
        """
        Run tabu search from several random starting points.

        Parameters:
        -----------
        max_iterations : int
            Number of moves per restart
        tabu_tenure : int
            Number of iterations a flipped bit stays forbidden
        n_restarts : int
            Number of independent random starts
        verbose : bool
            Print the restart counter

        Returns:
        --------
        best_x : np.ndarray of shape (n,)
            Lowest-energy binary vector found (None if n_restarts is 0)
        best_energy : float
            Energy of best_x (inf if n_restarts is 0)
        """
        diagonal = np.diag(self.Q)

        global_best_x = None
        global_best_energy = float('inf')

        for restart in range(n_restarts):
            if verbose:
                print(f"  Restart {restart+1}/{n_restarts}")

            x = np.random.randint(0, 2, self.n)
            current_energy = self.energy(x)

            best_x = x.copy()
            best_energy = current_energy

            # Bit i may not be flipped while tabu_until[i] > iteration.
            tabu_until = np.zeros(self.n, dtype=np.int64)

            for iteration in range(max_iterations):
                # For symmetric Q, flipping bit i by d_i = 1 - 2 x_i changes
                # the energy by 2 d_i (Q x)_i + Q_ii, so all neighbours are
                # evaluated with a single matrix-vector product.
                flip_sign = 1 - 2 * x
                neighbor_energies = (
                    current_energy + 2.0 * flip_sign * (self.Q @ x) + diagonal
                )
                neighbor_energies[tabu_until > iteration] = np.inf

                # Nothing to do when every move is tabu (or there are no bits).
                if not np.isfinite(neighbor_energies).any():
                    continue

                # Take the best admissible move even if it increases the energy.
                best_flip_idx = int(np.argmin(neighbor_energies))
                x[best_flip_idx] = 1 - x[best_flip_idx]
                current_energy = neighbor_energies[best_flip_idx]
                tabu_until[best_flip_idx] = iteration + tabu_tenure

                if current_energy < best_energy:
                    best_x = x.copy()
                    best_energy = current_energy

            # Recompute from scratch so the reported value carries no
            # accumulated rounding from the incremental updates.
            best_energy = self.energy(best_x)

            if best_energy < global_best_energy:
                global_best_x = best_x
                global_best_energy = best_energy

        return global_best_x, global_best_energy


# ============================================================================
# PART 3: GROUP BALANCING AND VALIDATION
# ============================================================================

def balance_groups(solution: np.ndarray, target_size: int = 50) -> np.ndarray:
    """
    Return a copy of the solution with exactly target_size ones.

    Surplus ones (or zeros) are chosen uniformly at random with
    np.random.choice and flipped; all other entries are kept.

    Parameters:
    -----------
    solution : np.ndarray
        Binary solution vector (not modified)
    target_size : int
        Desired number of ones

    Returns:
    --------
    balanced : np.ndarray
        Binary vector containing target_size ones
    """
    result = solution.copy()
    surplus = int(result.sum()) - target_size

    if surplus == 0:
        return result

    # With too many ones turn some into zeros; otherwise do the opposite.
    from_value, to_value = (1, 0) if surplus > 0 else (0, 1)

    candidates = np.where(result == from_value)[0]
    chosen = np.random.choice(candidates, abs(surplus), replace=False)
    result[chosen] = to_value

    return result


def validate_solution(group1: np.ndarray, group2: np.ndarray) -> Dict:
    """
    Check a pair of group vectors and report the individual results.

    Returns:
    --------
    validation : dict
        'size_correct'  - group1 contains len(group1) // 2 ones
        'complementary' - group2 equals 1 - group1
        'binary'        - every entry of group1 is 0 or 1
        'valid'         - all three checks above passed
        'group1_size', 'group2_size' - sums of the two vectors
    """
    half = len(group1) // 2

    size_ok = group1.sum() == half
    is_complement = np.array_equal(group2, 1 - group1)
    is_binary = np.all((group1 == 0) | (group1 == 1))

    return {
        'valid': all((size_ok, is_complement, is_binary)),
        'size_correct': size_ok,
        'complementary': is_complement,
        'binary': is_binary,
        'group1_size': int(group1.sum()),
        'group2_size': int(group2.sum()),
    }


# ============================================================================
# PART 4: MAIN PIPELINE
# ============================================================================

def solve_patient_matching(W: np.ndarray,
                          rho: float = 0.5,
                          lam_size: float = 3.0,
                          n_runs: int = 5,
                          use_gpu: bool = True,
                          verbose: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run the full pipeline: build the QUBO, solve it with the available
    solvers, balance the best solution and validate the result.

    Parameters:
    -----------
    W : np.ndarray of shape (n, d)
        Feature matrix (n patients, d features)
    rho : float
        Weight forwarded to QUBO_formulation
    lam_size : float
        Group size penalty coefficient
    n_runs : int
        Number of simulated-annealing runs; tabu search uses min(3, n_runs)
    use_gpu : bool
        Run simulated annealing when CUDA is available; otherwise only
        tabu search is used
    verbose : bool
        Print progress

    Returns:
    --------
    group1 : np.ndarray of shape (n,)
        Binary array with exactly n // 2 ones
    group2 : np.ndarray of shape (n,)
        Binary array = 1 - group1

    Raises:
    -------
    ValueError
        If the balanced solution fails validation
    """
    n = W.shape[0]
    target_size = n // 2
    banner = "=" * 70

    if verbose:
        print("\n".join([
            banner,
            "PATIENT MATCHING PIPELINE",
            banner,
            f"Number of patients: {n}",
            f"Number of features: {W.shape[1]}",
            f"Target group size: {target_size}",
            f"Rho: {rho}",
            f"Lambda (size penalty): {lam_size}",
        ]))

    # Step 1: Build QUBO
    if verbose:
        print("\n[1/4] Building QUBO matrix...")
    start_time = time.time()
    Q = QUBO_formulation(W, rho=rho, lam_size=lam_size, normalize=True)
    if verbose:
        print(f"      QUBO shape: {Q.shape}")
        print(f"      Time: {time.time() - start_time:.2f}s")

    # Step 2: Solve with multiple methods, pooling all candidates
    if verbose:
        print("\n[2/4] Solving with multiple methods...")

    all_solutions = []
    all_energies = []

    def collect(solve_once, runs):
        # Run one solver `runs` times and record every (solution, energy).
        for run in range(runs):
            if verbose:
                print(f"    Run {run+1}/{runs}...")
            solution, energy = solve_once()
            all_solutions.append(solution)
            all_energies.append(energy)
            if verbose:
                print(f"      Energy: {energy:.4f}, Size: {solution.sum()}")

    # Method 1: Simulated Annealing (GPU)
    if use_gpu and torch.cuda.is_available():
        if verbose:
            print("\n  Method 1: Simulated Annealing (GPU)")
        solver_sa = QUBOSimulatedAnnealing(Q, device='cuda')
        collect(
            lambda: solver_sa.solve(n_iterations=15000, n_parallel=64, verbose=False),
            n_runs,
        )

    # Method 2: Tabu Search (CPU), with fewer runs because it is slower
    if verbose:
        print("\n  Method 2: Tabu Search (CPU)")
    solver_tabu = QUBOTabuSearch(Q)
    collect(
        lambda: solver_tabu.solve(max_iterations=5000, n_restarts=3, verbose=False),
        min(3, n_runs),
    )

    # Step 3: Select the lowest-energy candidate and balance it
    if verbose:
        print("\n[3/4] Selecting best solution and balancing groups...")

    best_idx = np.argmin(all_energies)
    best_solution = all_solutions[best_idx]
    best_energy = all_energies[best_idx]

    if verbose:
        print(f"      Best energy: {best_energy:.4f}")
        print(f"      Original group size: {best_solution.sum()}")

    # Force an exact target_size / (n - target_size) split
    group1 = balance_groups(best_solution, target_size=target_size)
    group2 = 1 - group1

    if verbose:
        print(f"      Balanced group1 size: {group1.sum()}")
        print(f"      Balanced group2 size: {group2.sum()}")

    # Step 4: Validate
    if verbose:
        print("\n[4/4] Validating solution...")

    validation = validate_solution(group1, group2)

    if verbose:
        print("\n".join([
            f"      Valid: {validation['valid']}",
            f"      Size correct: {validation['size_correct']}",
            f"      Complementary: {validation['complementary']}",
            f"      Binary: {validation['binary']}",
        ]))

    if not validation['valid']:
        raise ValueError("Solution validation failed!")

    if verbose:
        print("\n" + banner)
        print("✓ SUCCESS!")
        print(banner)
        print(f"Group 1: {group1.sum()} patients")
        print(f"Group 2: {group2.sum()} patients")
        print(f"Total time: {time.time() - start_time:.2f}s")

    return group1, group2



# Example run of the pipeline on the matrix `w` loaded in an earlier cell.
if __name__ == "__main__":



    # Reuse the patient covariate matrix defined at module level.
    W = w

    print("Example Usage:")
    print("-" * 70)

    # Solve the matching problem; group1/group2 stay available to later cells
    group1,group2 = solve_patient_matching(
        W=W,
        rho=0.5,
        lam_size=3.0,
        n_runs=5,
        use_gpu=True,
        verbose=True
    )





Example Usage:
----------------------------------------------------------------------
PATIENT MATCHING PIPELINE
Number of patients: 100
Number of features: 3
Target group size: 50
Rho: 0.5
Lambda (size penalty): 3.0

[1/4] Building QUBO matrix...
      QUBO shape: (100, 100)
      Time: 0.07s

[2/4] Solving with multiple methods...

  Method 1: Simulated Annealing (GPU)
    Run 1/5...
      Energy: -3827.0403, Size: 25
    Run 2/5...
      Energy: -3827.0935, Size: 25
    Run 3/5...
      Energy: -3827.1147, Size: 25
    Run 4/5...
      Energy: -3827.0845, Size: 25
    Run 5/5...
      Energy: -3827.1138, Size: 25

  Method 2: Tabu Search (CPU)
    Run 1/3...
      Energy: -3826.9016, Size: 25
    Run 2/3...
      Energy: -3826.7739, Size: 25
    Run 3/3...
      Energy: -3827.1526, Size: 25

[3/4] Selecting best solution and balancing groups...
      Best energy: -3827.1526
      Original group size: 25
      Balanced group1 size: 50
      Balanced group2 size: 50

[4/4] Validating s

In [None]:
# Score the current split with the ClinicalTrial discrepancy measure.
disc = trial.discrepancy(group1, group2)
print(f"Discrepancy: {disc:.4f}")

Discrepancy: 0.7573


In [None]:
# Score whatever group1/group2 currently hold; the result depends on the
# most recent solver run.
disc = trial.discrepancy(group1, group2)
print(f"Discrepancy: {disc:.4f}")

Discrepancy: 0.5201


D-Wave Implementation

In [None]:
!pip install dwave-ocean-sdk dwave-inspector

In [None]:
import os
from dwave.system import DWaveSampler, EmbeddingComposite, LeapHybridSampler
from dimod import BinaryQuadraticModel
import dwave.inspector

# Read the API token from a local file and expose it through the
# DWAVE_API_TOKEN environment variable.
with open("tokens/dwave_token.txt", 'r') as file:
    os.environ['DWAVE_API_TOKEN'] = file.read().strip()

def create_sample_qubo(size=100):
    """
    Generate a reproducible random symmetric matrix for testing.

    The global NumPy random generator is re-seeded with 42 on every call,
    so the same `size` always yields the same matrix.
    """
    np.random.seed(42)
    raw = 10 * np.random.randn(size, size)
    # Average with the transpose to make the matrix symmetric.
    return (raw + raw.T) / 2

def qubo_to_dict(Q):
    """
    Convert a matrix into a {(i, j): value} dictionary.

    Only entries on or above the diagonal (j >= i) that are non-zero are
    kept; everything below the diagonal is ignored.
    """
    entries = {}
    for i in range(Q.shape[0]):
        for j in range(i, Q.shape[1]):
            value = Q[i, j]
            if value != 0:
                entries[(i, j)] = value
    return entries

def solve_quantum(Q, num_reads=100, use_hybrid=False):
    """
    Sample a QUBO on a D-Wave solver and split the variables into two groups.

    Parameters
    ----------
    Q : QUBO coefficients accepted by ``BinaryQuadraticModel.from_qubo``
        (the script in this file passes the dict built by ``qubo_to_dict``).
    num_reads : int
        Number of reads requested from the quantum annealer; not used by
        the hybrid solver.
    use_hybrid : bool
        Use ``LeapHybridSampler`` instead of the embedded ``DWaveSampler``.

    Returns
    -------
    group1 : np.ndarray
        Values of the lowest-energy sample, ordered by variable label.
    group2 : np.ndarray
        Complement of group1 (1 - group1).
    response : the sample set returned by the sampler.
    """
    model = BinaryQuadraticModel.from_qubo(Q, offset=0.0)

    # Make the API token available to the D-Wave client.
    with open("tokens/dwave_token.txt", 'r') as file:
        os.environ['DWAVE_API_TOKEN'] = file.read().strip()

    if use_hybrid:
        sampler = LeapHybridSampler()
        print("Using Hybrid Solver...")
        response = sampler.sample(model)
    else:
        sampler = EmbeddingComposite(DWaveSampler())
        print("Using Quantum Annealer with Embedding...")
        response = sampler.sample(model, num_reads=num_reads)

    # Read the best sample by variable label instead of relying on the
    # iteration order of its values, and size the groups by the number of
    # model variables: len(Q) is the number of coefficients when Q is a dict.
    # NOTE(review): assumes labels are sortable (integer patient indices) and
    # that every patient appears in at least one non-zero coefficient.
    best_sample = response.first.sample
    variables = sorted(model.variables)
    group1 = np.array([best_sample[v] for v in variables], dtype=int)
    group2 = 1 - group1

    return group1, group2, response

def solve_qubo(qubo_dict, num_reads=100, chain_strength=None):
    """
    Sample a QUBO dictionary on the quantum annealer via EmbeddingComposite.

    Prints the number of variables and interactions of the model and
    returns the sample set produced by the sampler.
    """
    model = BinaryQuadraticModel.from_qubo(qubo_dict)
    annealer = EmbeddingComposite(DWaveSampler())
    print(f"Variables: {len(model.variables)}, Interactions: {len(model.quadratic)}")

    sample_options = {
        'num_reads': num_reads,
        'chain_strength': chain_strength,
        'label': 'QUBO_100x100',
    }
    return annealer.sample(model, **sample_options)

def solve_hybrid(qubo_dict):
    """Sample a QUBO dictionary with LeapHybridSampler and return the sample set."""
    model = BinaryQuadraticModel.from_qubo(qubo_dict)
    hybrid_sampler = LeapHybridSampler()
    return hybrid_sampler.sample(model, label='QUBO_100x100_Hybrid')

def analyze(sampleset, top_n=5):
    """
    Print a short summary of a sample set and return its best sample.

    Parameters
    ----------
    sampleset : object exposing ``first``, ``__len__`` and
        ``data(fields)`` (e.g. the sample set returned by a sampler).
    top_n : int
        Number of records to list.

    Returns
    -------
    (sample, energy) of ``sampleset.first``.
    """
    from itertools import islice

    best = sampleset.first
    print(f"\nBest Energy: {best.energy:.4f}")
    print(f"Total Samples: {len(sampleset)}\n")

    # data() yields records lazily and cannot be sliced with [:top_n];
    # islice takes the first top_n records from the iterator instead.
    records = sampleset.data(['sample', 'energy', 'num_occurrences'])
    for rank, record in enumerate(islice(records, top_n), start=1):
        print(f"{rank}. Energy: {record.energy:.4f}, Occurs: {record.num_occurrences}")

    return best.sample, best.energy

# Run the QUBO on the D-Wave quantum annealer and on the hybrid solver.
if __name__ == "__main__":
    # NOTE(review): W and rho come from earlier cells; no module-level
    # definition of lam_size is visible in this file, so this line appears to
    # raise NameError unless it is defined elsewhere — confirm.
    Q = QUBO_formulation(W, rho=rho, lam_size=lam_size, normalize=True)
    qubo_dict = qubo_to_dict(Q)


    print("\n=== Quantum Annealer (EmbeddingComposite) ===")
    # solve_quantum receives the dictionary form of the QUBO, not the matrix.
    group1_qa, group2_qa, response_qa = solve_quantum(qubo_dict, num_reads=100, use_hybrid=False)
    print(f"Group1: {group1_qa[:10]}...")
    print(f"Group2: {group2_qa[:10]}...")
    print(f"Energy: {response_qa.first.energy:.4f}")

    print("\n=== Hybrid Solver (LeapHybridSampler) ===")
    group1_hy, group2_hy, response_hy = solve_quantum(qubo_dict, use_hybrid=True)
    print(f"Group1: {group1_hy[:10]}...")
    print(f"Group2: {group2_hy[:10]}...")
    print(f"Energy: {response_hy.first.energy:.4f}")

    # Open the problem inspector for the quantum-annealer response only.
    dwave.inspector.show(response_qa)