# Prototype Development Notebook for HLA Greedy Coverage Solver

---

## 1. Setup and Imports

In [None]:
import sys
import os

# The notebook runs from notebook_dev, one level below the project root
# (hla_solver_prototype), so the parent of the working directory is the root.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Put the root at the front of the import search path, but only if it is
# not already listed, so re-running the cell does not add duplicates.
if sys.path.count(project_root) == 0:
    sys.path.insert(0, project_root)

# Now imports will find hla_solver package
from hla_solver import HLASolver
from hla_solver.preprocessing import preprocess

In [None]:
import yaml
import pandas as pd
import numpy as np

---
## 2. Load Config

In [None]:
# Load YAML configuration for serial run in Jupyter-safe mode
config_path = "../configs/cover_serial_config.yaml"
# config_path = "../configs/cover_thread_pool_config.yaml"

# Decode the file as UTF-8 explicitly: without `encoding` the platform's
# locale default is used, which can differ between machines. The log files
# written elsewhere in this notebook are opened as UTF-8 as well.
with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# Echo every top-level setting so the run configuration is visible in the
# notebook output.
for key, value in config.items():
    print(f"{key}: {value}")

---
## 3. Data Preprocessing

In [None]:
# Dataset location and duplication count come from the "data" section of the
# loaded config; number_doubles falls back to 0 when the key is absent.
csv_dataset = config["data"]["dataset_path"]
number_doubles = config["data"].get("number_doubles", 0)

# Unpack the same eight values that the validation cell further down takes
# from preprocess(); the previous six-name unpacking here disagreed with it
# (no id_to_allele, no sample_allele_sets).
# NOTE(review): confirm the return arity against hla_solver.preprocessing.
(
    df,
    hla_groups,
    allele_to_id,
    id_to_allele,
    allele_matrix,
    allele_id_to_positions,
    allele_columns,
    sample_allele_sets,
) = preprocess(
    csv_dataset,
    number_doubles,
    verbose=True,
)

---
## 4. Instantiate Solver

In [None]:
# Build the solver from the path of the YAML file loaded above. Note that the
# path string is passed in, not the already-parsed `config` dict.
# NOTE(review): presumably HLASolver reads the file itself — confirm.
solver = HLASolver(config_path)

---
## 5. Run Solver

In [1]:
import sys
import os

# Optional: import time if your run is long and you want timestamps
from datetime import datetime

class Tee:
    """Write-only stream that duplicates everything to several streams.

    Intended as a stand-in for ``sys.stdout`` so that printed output reaches
    both the console and a log file.

    Args:
        *streams: Objects exposing ``write(str)`` and ``flush()``.
    """

    def __init__(self, *streams):
        self.streams = streams

    def write(self, message):
        """Send *message* to every stream and return its length.

        Returning the character count matches what text streams report from
        ``write``; callers that check the result no longer receive ``None``.
        """
        for s in self.streams:
            s.write(message)
            s.flush()  # Ensure real-time logging
        return len(message)

    def flush(self):
        """Flush every underlying stream."""
        for s in self.streams:
            s.flush()

# Set up project root path
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from hla_solver import HLASolver

# Path to config file
config_path = os.path.join(project_root, "configs", "cover_serial_config.yaml")
solver = HLASolver(config_path)

# Define log file path
log_path = os.path.join(project_root, "console_output_serial_test.log")

# Remember whichever stream is installed right now. In a notebook kernel
# sys.stdout is typically a kernel-provided stream rather than
# sys.__stdout__, so restoring sys.__stdout__ afterwards would send later
# prints (including the summary below) away from the notebook.
previous_stdout = sys.stdout

# Open log file and tee stdout to both console and log
with open(log_path, "w", encoding="utf-8") as logfile:
    tee = Tee(previous_stdout, logfile)
    sys.stdout = tee  # Override stdout
    try:
        best_filters = solver.run()
    finally:
        sys.stdout = previous_stdout  # Restore the stream that was active before

# Final summary printed to notebook only
print("\nBest filters:", best_filters)



--- START [Preprocessing Dataset] ---
Loading dataset from ../data/simulated_HLA_samples_5000.csv...
[Timing] Load dataset: 0.0160 seconds
Dataset loaded: 5000 samples, 13 columns
[Timing] Detect allele columns: 0.0000 seconds
Identified allele columns: ['HLA-A 1', 'HLA-A 2', 'HLA-B 1', 'HLA-B 2', 'HLA-C 1', 'HLA-C 2', 'HLA-DQB1 1', 'HLA-DQB1 2', 'HLA-DRB1 1', 'HLA-DRB1 2']
[Timing] Truncate alleles: 0.0580 seconds
[Timing] Index alleles: 0.0060 seconds
[Timing] Build allele matrix: 0.3740 seconds
[Timing] Group HLA columns: 0.0000 seconds
HLA groups identified: [('HLA-A', 'HLA-A 1', 'HLA-A 2'), ('HLA-B', 'HLA-B 1', 'HLA-B 2'), ('HLA-C', 'HLA-C 1', 'HLA-C 2'), ('HLA-DQB', 'HLA-DQB1 1', 'HLA-DQB1 2'), ('HLA-DRB', 'HLA-DRB1 1', 'HLA-DRB1 2')]
[Timing] Validate/coerce alleles: 0.0010 seconds
[Timing] Data duplication (augmentation): 0.0000 seconds
[Timing] Build sample allele sets: 0.3930 seconds
--- END [Preprocessing Dataset] | Time: 0.87s ---
[Preprocessing] Completed. Total samples: 

---
## 5b. Run Solver with Restart

In [None]:
import sys
import os

# Optional: import time if your run is long and you want timestamps
from datetime import datetime

class Tee:
    """Write-only stream that duplicates everything to several streams.

    Intended as a stand-in for ``sys.stdout`` so that printed output reaches
    both the console and a log file.

    Args:
        *streams: Objects exposing ``write(str)`` and ``flush()``.
    """

    def __init__(self, *streams):
        self.streams = streams

    def write(self, message):
        """Send *message* to every stream and return its length.

        Returning the character count matches what text streams report from
        ``write``; callers that check the result no longer receive ``None``.
        """
        for s in self.streams:
            s.write(message)
            s.flush()  # Ensure real-time logging
        return len(message)

    def flush(self):
        """Flush every underlying stream."""
        for s in self.streams:
            s.flush()

# Set up project root path
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from hla_solver import HLASolver

# Path to config file
config_path = os.path.join(project_root, "configs", "cover_serial_config.yaml")
solver = HLASolver(config_path)

# Define log file path
log_path = os.path.join(project_root, "console_output_serial_restart_test.log")

# Remember whichever stream is installed right now. In a notebook kernel
# sys.stdout is typically a kernel-provided stream rather than
# sys.__stdout__, so restoring sys.__stdout__ afterwards would send later
# prints (including the summary below) away from the notebook.
previous_stdout = sys.stdout

# Open log file and tee stdout to both console and log
with open(log_path, "w", encoding="utf-8") as logfile:
    tee = Tee(previous_stdout, logfile)
    sys.stdout = tee  # Override stdout
    try:
        # NOTE(review): the restart file path is relative, unlike the
        # project_root-based paths above — confirm which directory the
        # solver resolves it against.
        best_filters = solver.run(restart_file="restarts/greedy_restart_20251016_180855.csv")
    finally:
        sys.stdout = previous_stdout  # Restore the stream that was active before

# Final summary printed to notebook only
print("\nBest filters:", best_filters)

---
## 6. Results

In [None]:
# Heading and result of the most recent solver run, one per line.
print("\nFinal best global filters:", best_filters, sep="\n")

---
## 7. Optional: Validation and Metrics

In [None]:
from hla_solver.validation import print_validation_report

# Dump the preprocessed DataFrame so the report below can be checked by eye.
print( df )
# Example: assuming best_filters contains sample indices or a structure from which you extract selected indices
# NOTE(review): the IDs below are hard-coded rather than derived from
# best_filters — replace them with the solver's actual selection before
# relying on the report.
selected_indices = ['SIM_44', 'SIM_52', 'SIM_8', 'SIM_17', 'SIM_41']
print_validation_report(df, hla_groups, selected_indices)

---
## Validation
- Create a small dataset
- Find best single sample cover
- Print best single sample cover and a tabular list of all samples that it covers to identify any incorrectness

In [None]:
# Validation run: preprocess a dataset and rank its best single-sample covers.
# NOTE(review): the original note for this cell refers to
# data/simulated_HLA_samples_100.csv (only 100 samples, generated as per the
# reformat and generate function), but the path used below points at the
# 20000-sample file — confirm which dataset is intended.

import os
import sys

# Make the project root (parent of the working directory) importable.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Package imports; HLASolver itself is not needed in this cell.
from hla_solver.selectors import (
    find_top_covering_samples,
    sample_coverage_percent,
    find_matching_samples_by_filter,
)
from hla_solver.preprocessing import preprocess
from hla_solver.utils import TimingLogger

# Dataset to validate against.
csv_path = "../data/simulated_HLA_samples_20000.csv"

# Timing logger; its `timing` attribute is handed to the selector below.
timing_logger = TimingLogger(verbose=True)

# Preprocess the dataset with preprocess()'s default options.
(
    df,
    hla_groups,
    allele_to_id,
    id_to_allele,
    allele_matrix,
    allele_id_to_positions,
    allele_columns,
    sample_allele_sets,
) = preprocess(csv_path)

# NOTE(review): a signature sketch previously kept here listed only df,
# hla_groups, top_n, allele_sets_by_sample, coverage_type="intersection" and
# verbose; the call below also passes allele_to_id, enable_timing and timing.
# Confirm against hla_solver.selectors.find_top_covering_samples.
top_samples = find_top_covering_samples(
    df=df,
    hla_groups=hla_groups,
    top_n=3,
    allele_sets_by_sample=sample_allele_sets,
    allele_to_id=allele_to_id,
    verbose=True,
    enable_timing=True,
    timing=timing_logger.timing,
)

print(top_samples)

In [None]:
import time
import pandas as pd

# Display more columns without wrapping
pd.set_option('display.max_columns', None)      # show all columns
pd.set_option('display.width', 2000)            # increase overall width
pd.set_option('display.max_colwidth', None)     # show full content of each column

print("\nTop Covering Samples:")
for sample in top_samples:
    print(sample)

# Pick the top sample (first entry of the ranked list)
top_sample_id = top_samples[0]["Sample ID"]
top_sample_alleles = top_samples[0]["Alleles"]

print(f"\nTop Sample ID: {top_sample_id}")
print("Top Sample Alleles (by group):")
for group, alleles in top_sample_alleles.items():
    print(f"  {group}: {alleles}")

# Time the sample coverage computation. perf_counter() is a monotonic,
# high-resolution clock, so the measured interval cannot be distorted by
# wall-clock adjustments the way time.time() differences can.
start_time = time.perf_counter()

coverage_pct, allele_dict = sample_coverage_percent(
    df=df,
    hla_groups=hla_groups,
    sample_id=top_sample_id,
    allele_sets_by_sample=sample_allele_sets,
    allele_to_id=allele_to_id,
    coverage_type="intersection",
    verbose=True
)

elapsed = time.perf_counter() - start_time
print(f"\n sample_coverage_percent executed in {elapsed:.4f} seconds.")

# print("\nCovered Sample IDs:", covered_sample_ids)

# print("\nCovered Sample IDs:", covered_sample_ids)

In [None]:
# Repeat the coverage computation for the top sample (same arguments as the
# timed call in the previous cell) and show the second returned value.
coverage_pct, allele_dict = sample_coverage_percent(
                df=df,
                hla_groups=hla_groups,
                sample_id=top_sample_id,
                allele_sets_by_sample=sample_allele_sets,
                allele_to_id=allele_to_id,
                coverage_type="intersection",
                verbose=True
            )

# Print the returned allele dictionary for inspection.
print( allele_dict )

In [None]:
from pprint import pprint

# Print first 2 entries from the dictionary.
# The per-sample value gets its own name (`sample_sets`) so this loop does
# not overwrite `allele_dict`, the coverage result produced by the
# sample_coverage_percent cells above.
for i, (sample_id, sample_sets) in enumerate(sample_allele_sets.items()):
    if i >= 2:
        break
    print(f"Sample ID: {sample_id}")
    pprint(sample_sets)
    print()


In [None]:
# Spot check: print one element of sample SIM_1's 'HLA-A' entry.
# NOTE(review): if that entry is a set/frozenset, which element comes first
# after list() is arbitrary — confirm before reading meaning into it.
print( list(sample_allele_sets['SIM_1']['HLA-A'])[0] )