# Demonstration of SUM2HLA Pipeline on Google Colab **TPU**

This notebook demonstrates running the **SUM2HLA** pipeline on Google Colab's TPU runtime. It uses the **TPU v5e-1** hardware, which is available on the **Google Colab Free Tier**. The setup shows how JAX-based hardware acceleration can be applied to SUM2HLA to reduce run time relative to standard (non-accelerated) environments.
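
Before running the pipeline, it can help to confirm which backend JAX actually selected. The snippet below is a minimal sketch and not part of SUM2HLA: `describe_backend` is a hypothetical helper that wraps the public `jax.default_backend()` and `jax.devices()` calls and returns empty values when JAX is unavailable.

```python
def describe_backend():
    """Report the JAX backend and visible devices (hypothetical helper)."""
    try:
        import jax
        return {
            "backend": jax.default_backend(),  # e.g. "tpu", "gpu", or "cpu"
            "devices": [str(d) for d in jax.devices()],
        }
    except (ImportError, RuntimeError):
        # JAX is missing, or no backend could be initialized
        return {"backend": None, "devices": []}


info = describe_backend()
print(f"Backend: {info['backend']}, device count: {len(info['devices'])}")
```

On a TPU runtime the backend is expected to be `tpu`; if it reports `cpu`, the Colab runtime type is probably not set to TPU.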

### Environment & Dependencies
The default Google Colab environment comes pre-installed with most of the software and Python packages required by SUM2HLA:
* **System Tools:** `git`, `git-lfs`
* **Python Libraries:** `pandas`, `numpy`, `scipy`, `jax`, `jaxlib`, `threadpoolctl`
* **External Tools:** **PLINK** is not pre-installed; this notebook installs it automatically (from https://www.cog-genomics.org/plink/) to complete the required environment.
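
The list above can be checked programmatically. The following is a minimal sketch using only the standard library; `check_environment` is a hypothetical helper, and it assumes each package's import name matches the name listed above and that the tools are expected on `PATH`.

```python
import importlib.util
import shutil

PYTHON_MODULES = ["pandas", "numpy", "scipy", "jax", "jaxlib", "threadpoolctl"]
SYSTEM_TOOLS = ["git", "git-lfs", "plink"]


def check_environment(modules=PYTHON_MODULES, tools=SYSTEM_TOOLS):
    """Return {name: True/False} for each Python module and command-line tool."""
    status = {}
    for name in modules:
        # find_spec locates a top-level module without importing it
        status[name] = importlib.util.find_spec(name) is not None
    for name in tools:
        # shutil.which returns None when the executable is not on PATH
        status[name] = shutil.which(name) is not None
    return status


for name, available in check_environment().items():
    print(f"{name:15s} {'OK' if available else 'MISSING'}")
```

On a fresh Colab session, `plink` would be expected to show as missing until the installation cell has been run.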

### Setup Process
Because the core dependencies are already available, this notebook performs only the minimal setup needed to reproduce the analysis:
1.  **PLINK Installation:** Download the PLINK binary and place it in `/usr/local/bin`.
2.  **Repository Setup:** Clone the official SUM2HLA GitHub repository.
3.  **Data Retrieval:** Download the example dataset (approx. 300 MB) using `git-lfs`.
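
One thing that can go wrong in the data retrieval step is that LFS-tracked files remain as small text pointers instead of the real data. The sketch below spots this, assuming the standard Git LFS pointer format (a first line of `version https://git-lfs.github.com/spec/v1`); `is_lfs_pointer` and `find_unpulled_files` are hypothetical helper names, not part of SUM2HLA.

```python
from pathlib import Path

# First bytes of every Git LFS pointer file
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path):
    """Return True if the file at `path` is still an un-pulled LFS pointer."""
    with open(path, "rb") as fh:
        return fh.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX


def find_unpulled_files(directory):
    """List files under `directory` whose content is an LFS pointer."""
    return sorted(
        str(p)
        for p in Path(directory).rglob("*")
        if p.is_file() and is_lfs_pointer(p)
    )
```

For example, `find_unpulled_files("/content/SUM2HLA/example")` should return an empty list once `git lfs pull` has finished, assuming the example data lives under that directory.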

In [1]:
import os
import sys

# ==========================================
# 1. Install PLINK (Official Source)
# ==========================================
# Check whether the PLINK binary already exists at the install location
if not os.path.exists('/usr/local/bin/plink'):
    print("Installing PLINK from official source...")

    # Download a pinned PLINK 1.9 build (Linux x86_64, 22 Oct 2024)
    !wget -q https://s3.amazonaws.com/plink1-assets/plink_linux_x86_64_20241022.zip -O plink.zip

    # Unzip to /usr/local/bin for global access
    !unzip -q -o plink.zip -d /usr/local/bin/
    !chmod +x /usr/local/bin/plink
    !rm plink.zip

    print("PLINK installation complete.")
else:
    print("PLINK is already installed.")

# ==========================================
# 2. Clone SUM2HLA Repository & Pull LFS Data
# ==========================================
REPO_URL = "https://github.com/WansonChoi/SUM2HLA.git"
REPO_DIR = "/content/SUM2HLA"

if not os.path.exists(REPO_DIR):
    print("Cloning SUM2HLA repository...")
    !git clone {REPO_URL}

    # Change directory to the repository
    %cd {REPO_DIR}

    # Download large files via Git LFS (approx. 300MB)
    print("Pulling Git LFS files (approx. 300MB)...")
    !git lfs pull
    print("Repository setup complete.")
else:
    print("Repository already exists. Checking for updates...")
    %cd {REPO_DIR}
    !git pull
    # Also refresh LFS-tracked files in case an earlier pull was interrupted
    !git lfs pull

# ==========================================
# 3. Verify Environment and Dependencies
# ==========================================
import jax
import pandas
import numpy

print("\n[Environment Check]")
print(f"JAX version: {jax.__version__}")
print(f"Pandas version: {pandas.__version__}")
print("PLINK version:")
!plink --version

Installing PLINK from official source...
PLINK installation complete.
Cloning SUM2HLA repository...
Cloning into 'SUM2HLA'...
remote: Enumerating objects: 1296, done.
remote: Counting objects: 100% (1296/1296), done.
remote: Compressing objects: 100% (928/928), done.
remote: Total 1296 (delta 432), reused 1221 (delta 363), pack-reused 0 (from 0)
Receiving objects: 100% (1296/1296), 39.78 MiB | 34.15 MiB/s, done.
Resolving deltas: 100% (432/432), done.
/content/SUM2HLA
Pulling Git LFS files (approx. 300MB)...
Repository setup complete.





[Environment Check]
JAX version: 0.7.2
Pandas version: 2.2.2
PLINK version:
PLINK v1.9.0-b.7.7 64-bit (22 Oct 2024)
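
The example run in the next cell reports progress per batch: with a batch size of 30, the 100-th batch starts at index 3000. The sketch below reproduces that numbering under the assumption that candidates are enumerated as index tuples of length `N_causal` and then split into fixed-size batches. `batched_combinations` and the item count of 3030 are hypothetical and chosen only for illustration; this is not SUM2HLA's actual implementation.

```python
from itertools import combinations, islice


def batched_combinations(n_items, n_causal, batch_size):
    """Yield lists of index tuples of length `n_causal`, `batch_size` at a time."""
    combos = combinations(range(n_items), n_causal)
    while True:
        batch = list(islice(combos, batch_size))
        if not batch:
            return
        yield batch


# Hypothetical item count; prints a progress line every 100 batches
for i, batch in enumerate(batched_combinations(n_items=3030, n_causal=1, batch_size=30)):
    if i % 100 == 0:
        print(f"=====[{i}]: {i}-th batch / First 5 items: {batch[:5]}")
```

With `n_causal=1` each candidate is a single index, so batch `i` simply covers indices `30*i` through `30*i + 29`.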


In [2]:
# Run the example command
!python SUM2HLA.py \
    --sumstats example/WTCCC.RA.GWASsummary.N4798.assoc.logistic \
    --ref example/REF_1kG.EUR.hg19.SNP+HLA \
    --out OUT.WTCCC_RA.REF_1kG.EUR

Namespace(sumstats='example/WTCCC.RA.GWASsummary.N4798.assoc.logistic', ref='example/REF_1kG.EUR.hg19.SNP+HLA', out='OUT.WTCCC_RA.REF_1kG.EUR', batch_size=30, skip_SWCA=False, gpu_id=0, plink_path='/usr/local/bin/plink')
2025-12-24 08:50:10,541 [INFO] JAX with tpu (Total cores: 1)
2025-12-24 08:50:10,541 [INFO] Device details: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]
2025-12-24 08:50:10,542 [INFO] SUM2HLA start. (2025-12-24 08:50:10.542014)




Batch size: 30
=====[0]: 0-th batch / First 5 items: [(0,), (1,), (2,), (3,), (4,)] (2025-12-24 08:50:37.744730)
=====[100]: 100-th batch / First 5 items: [(3000,), (3001,), (3002,), (3003,), (3004,)] (2025-12-24 08:51:45.580869)

Total time for the LL when `N_causal`=1: 0:01:25.631564



Postprocessing the calculated LLs.


   rank  rank_p            SNP  ...  LL+Lprior_diff  LL+Lprior_diff_acc     logPP
0     1     0.0  HLA_DRB1_0401  ...             0.0                 0.0 -0.000906

[1 rows x 9 columns]
=====[ ROUND