In [1]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from scipy.sparse import csc_array
from scipy.optimize import milp, LinearConstraint, Bounds
from math import inf
import re

def split_into_sentences(text):
    cleaned_text = re.sub(r'\s+', ' ', text.strip())
    sentences = re.split(r'(?<=[.!?])\s+', cleaned_text)
    return sentences

def milp_select_sentences(similarity_matrix, max_key_points, redundancy_threshold, lambda_param=0.5):
    scale_factor = 10.0
    scaled_lambda = lambda_param * scale_factor

    m, n = similarity_matrix.shape
    if m != n:
        raise ValueError("similarity_matrix must be square.")
    if max_key_points < 1:
        return []
    if max_key_points > m:
        return list(range(m))
    if not (0.0 <= lambda_param <= 1.0):
        raise ValueError("lambda_param must be in [0,1].")


    S = np.copy(similarity_matrix)
    np.fill_diagonal(S, 0.0)
    S[S < redundancy_threshold] = 0.0
    r = np.ones(m, dtype=float)
    K = max_key_points

    u = np.sum(np.maximum(S, 0.0), axis=1)
    l = np.sum(np.minimum(S, 0.0), axis=1)


    integrality = np.concatenate([np.ones(m), np.zeros(m)])
    lo = np.zeros(2 * m)
    hi = np.ones(2 * m)
    lo[m:] = -inf
    hi[m:] = inf
    bounds = Bounds(lo, hi)


    c = np.concatenate([-r, scaled_lambda * np.ones(m)])

    b_lo = np.zeros(4 * m + 1)
    b_hi = np.full(4 * m + 1, np.inf)
    rows, cols, vals = [], [], []
    row_idx = 0

    for i in range(m):
        rows.append(row_idx)
        cols.append(i)
        vals.append(-l[i])
        rows.append(row_idx)
        cols.append(m + i)
        vals.append(1.0)
        row_idx += 1

        rows.append(row_idx)
        cols.append(i)
        vals.append(u[i])
        rows.append(row_idx)
        cols.append(m + i)
        vals.append(-1.0)
        row_idx += 1


        b_lo[row_idx] = -u[i]
        for j in range(m):
            if j == i:
                rows.append(row_idx)
                cols.append(i)
                vals.append(-u[i])
            else:
                rows.append(row_idx)
                cols.append(j)
                vals.append(-S[i, j])
        rows.append(row_idx)
        cols.append(m + i)
        vals.append(1.0)
        row_idx += 1


        b_lo[row_idx] = l[i]
        for j in range(m):
            if j == i:
                rows.append(row_idx)
                cols.append(i)
                vals.append(l[i])
            else:
                rows.append(row_idx)
                cols.append(j)
                vals.append(S[i, j])
        rows.append(row_idx)
        cols.append(m + i)
        vals.append(-1.0)
        row_idx += 1


    b_lo[row_idx] = K
    b_hi[row_idx] = K
    for i in range(m):
        rows.append(row_idx)
        cols.append(i)
        vals.append(1.0)

    A = csc_array((vals, (rows, cols)), shape=(4 * m + 1, 2 * m))
    constraints = LinearConstraint(A, b_lo, b_hi)


    res = milp(c=c, integrality=integrality, bounds=bounds, constraints=constraints)
    if not res.success:
        print("MILP did not converge or was infeasible. Returning empty selection.")
        return []

    x_sol = res.x[:m]
    selected_indices = [i for i in range(m) if x_sol[i] > 0.5]
    return selected_indices

class MILPExtractor:
    def __init__(self, similarity_threshold=0.75, lambda_param=0.5, model_name="jinaai/jina-colbert-v2"):
        self.similarity_threshold = similarity_threshold
        self.lambda_param = lambda_param
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

    def get_key_points(self, passages, max_key_points=5, highlight=False):
        all_sentences = []
        for text in passages:
            all_sentences.extend(split_into_sentences(text))
        if not all_sentences:
            return {"highlighted_passages": None, "key_points": []}

        N = len(all_sentences)
        inputs = self.tokenizer(all_sentences, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1)
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        sim_matrix = torch.matmul(embeddings, embeddings.T).numpy()
        sim_matrix = (sim_matrix + 1) / 2

        selected_indices = milp_select_sentences(
            similarity_matrix=sim_matrix,
            max_key_points=max_key_points,
            redundancy_threshold=self.similarity_threshold,
            lambda_param=self.lambda_param
        )
        key_points = [all_sentences[idx] for idx in selected_indices]

        if highlight:
            keypoint_set = set(key_points)
            highlighted_passages = []
            for text in passages:
                sentences = split_into_sentences(text)
                highlighted_sentences = [
                    f"{{{{{s}}}}}" if s in keypoint_set else s for s in sentences
                ]

                highlighted_passages.append(". ".join(highlighted_sentences) + ".")
            return {"highlighted_passages": highlighted_passages, "key_points": key_points}
        else:
            return {"highlighted_passages": None, "key_points": key_points}


if __name__ == "__main__":
    passages = [
        """Proper safety procedures in water treatment and sewage management are essential to protect workers,
        the public, and the environment. Staff must receive comprehensive training on hazard identification,
        chemical handling, and machinery operation. Personal protective equipment, including gloves, goggles,
        and respirators, should be worn whenever chemicals or biological contaminants are present. Regular
        maintenance of pumps, valves, and filtration systems ensures safe, efficient performance. Clear signage
        and barricades help reduce the risk of slips and falls around wet surfaces. Emergency response plans,
        updated periodically, outline proper steps for containment and cleanup in the event of spills or
        accidental releases. Supervisors should enforce strict protocols for hazardous waste disposal, including
        separating corrosive or flammable materials. Employees must wash hands thoroughly after coming into
        contact with untreated wastewater. Proper ventilation in enclosed areas, along with continuous gas
        monitoring, helps detect harmful fumes before they reach dangerous levels. Decontamination procedures,
        such as disinfecting tools and workspaces, limit the spread of pathogens. Training programs should
        incorporate up-to-date information on regulatory guidelines, reinforcing best practices. Documentation
        of all inspections, repairs, and incidents improves oversight and fosters accountability. By adhering
        to these measures, facility personnel safeguard public health, maintain compliance with environmental
        regulations, and preserve essential water resources."""
    ]

    extractor = MILPExtractor(similarity_threshold=0.75, lambda_param=0.5)
    results = extractor.get_key_points(passages, max_key_points=5, highlight=False)
    print("Key Points:")
    for i, kp in enumerate(results["key_points"], start=1):
        print(f"{i}. {kp}")


Key Points:
1. Personal protective equipment, including gloves, goggles, and respirators, should be worn whenever chemicals or biological contaminants are present.
2. Regular maintenance of pumps, valves, and filtration systems ensures safe, efficient performance.
3. Employees must wash hands thoroughly after coming into contact with untreated wastewater.
4. Decontamination procedures, such as disinfecting tools and workspaces, limit the spread of pathogens.
5. Training programs should incorporate up-to-date information on regulatory guidelines, reinforcing best practices.
