# SCRIPT TO PERFORM QUALITY CONTROL ON UK BIOBANK SAMPLES

## This script should only be run once

#### Initialization
##### Load packages

In [1]:
from pathlib import Path
from datetime import datetime

import dxdata
import dxpy
import networkx as nx
import hail as hl
import pyspark
from pyspark.sql.functions import col
from pyspark.sql import SparkSession
from fields import fields_for_id

from packaging import version

##### Spark, Hail and dataset configuration 

In [None]:
# Create the Spark context for this notebook and wrap it in a SparkSession;
# the same context object is passed to Hail via hl.init(sc=sc).
sc = pyspark.SparkContext()
spark = SparkSession(sparkContext=sc)

In [3]:
# Constants
DATABASE = "matrix_tables"
REFERENCE_GENOME = "GRCh38"

# Absolute path of the Hail log file under ../hail_logs. The timestamp includes
# the date (not only HH:MM) so that runs started at the same minute on
# different days do not share a log file.
LOG_FILE = str(
    Path(
        "../hail_logs",
        f"hail_{datetime.now().strftime('%Y%m%d_%H%M')}.log",
    ).resolve()
)

In [None]:
# Start Hail on the existing Spark context, using REFERENCE_GENOME as the
# default reference and writing its log to LOG_FILE.
hl.init(log=LOG_FILE, sc=sc, default_reference=REFERENCE_GENOME)

In [5]:
# Locate the dispensed dataset record (a "Dataset" object in the project root
# whose name matches the glob app*.dataset) and load its participant entity.
_dataset_record = dxpy.find_one_data_object(
    name="app*.dataset",
    name_mode="glob",
    typename="Dataset",
    folder="/",
)
dispensed_dataset_id = _dataset_record["id"]

dataset = dxdata.load_dataset(id=dispensed_dataset_id)  # type: ignore
participant = dataset["participant"]

### Filtering
#### Hard filtering

In [6]:
fields = [
    "22027",  # Outliers for heterozygosity or missing rate
    "22019",  # Sex chromosome aneuploidy
    "22021",  # Genetic kinship to other participants
    "21000",  # Ethnic background
]

# Expand each field ID into its column names via fields_for_id (presumably one
# entry per instance/array index, e.g. p21000_i0 — confirm in fields.py) and
# prepend the participant identifier column.
field_names = ["eid"]
for field_id in fields:
    field_names.extend(field.name for field in fields_for_id(field_id, participant))

In [None]:
# Pull the QC fields for every participant, with coded values replaced by
# their text labels (the comparisons below match on those labels).
df = participant.retrieve_fields(
    engine=dxdata.connect(),
    names=field_names,
    coding_values="replace",  # type: ignore
)

# Field 22021 (genetic kinship) labels that trigger exclusion.
_kinship_exclusions = [
    "Participant excluded from kinship inference process",
    "Ten or more third-degree relatives identified",
]
# Field 21000 (ethnic background, instance 0) mixed-background labels that
# trigger exclusion.
_mixed_backgrounds = [
    "White and Black Caribbean",
    "White and Black African",
    "White and Asian",
    "Any other mixed background",
]

# A participant is excluded when field 22027 or 22019 has any value, when its
# kinship label is one of the above, or when its instance-0 ethnic background
# is one of the mixed labels above.
df_filtered = df.filter(
    df.p22027.isNotNull()
    | df.p22019.isNotNull()
    | df.p22021.isin(_kinship_exclusions)
    | df.p21000_i0.isin(_mixed_backgrounds)
)

filtered_samples_to_remove = hl.Table.from_spark(
    df_filtered.select("eid")
).key_by("eid")

n_filtered = filtered_samples_to_remove.count()
print(f"Samples to be filtered: {n_filtered}")

#### Ancestry filtering

In [None]:
# Ancestry filter following Privé et al.
# ancestry.csv is produced by 01_QC_samples_pcs.ipynb and
# 01_QC_samples_ancestry.ipynb, which must be run before this notebook.
ANCESTRY_FILE = "file:///opt/notebooks/ancestry.csv"

_ancestry_raw = hl.import_table(ANCESTRY_FILE, quote='"', delimiter=",")
anc = _ancestry_raw.key_by(eid=_ancestry_raw["PC_UKBB.eid"])

# Every row of ancestry.csv whose group is not "United Kingdom" is scheduled
# for removal.
# NOTE(review): participants absent from ancestry.csv, or whose `group` is
# missing (Hail's filter drops rows where the condition is missing), are NOT
# added to this table — confirm that is intended.
ancestry_to_remove = anc.filter(anc.group != "United Kingdom")

n_ancestry = ancestry_to_remove.count()
print(f"Ancestry to remove: {n_ancestry}")

#### Withdrawn

In [None]:
# Participants whose eid starts with "w" are treated as withdrawn.
df_withdrawn = df.where(df.eid.startswith("w"))

_withdrawn_eids = df_withdrawn.select("eid")
withdrawn_to_remove = hl.Table.from_spark(_withdrawn_eids).key_by("eid")

n_withdrawn = withdrawn_to_remove.count()
print(f"Withdrawn samples to remove: {n_withdrawn}")

#### Related individuals

In [None]:
# Remove related individuals: load the pairwise relatedness file and keep only
# the pairs that still need pruning.
RAW_REL_FILE = Path("/mnt/project/Bulk/Genotype Results/Genotype calls/ukb_rel.dat")
# Pairs with Kinship strictly greater than this value are treated as related.
# NOTE(review): labelled "2nd degree relatives", but 0.125 is the *expected*
# kinship of a 2nd-degree pair; KING's published 2nd-degree range starts at
# ~0.0884 (2^-3.5), so this cutoff keeps roughly half of 2nd-degree pairs.
# Confirm the intended threshold.
MAX_KINSHIP = 0.125

rel = hl.import_table(
    f"file://{RAW_REL_FILE}",
    types={"ID1": "str", "ID2": "str"},
    impute=True,
    delimiter=" ",
)

# Drop every pair in which either member is already excluded by an earlier
# filter: removing that member already breaks the relationship.
_already_removed = None
for _excluded in (filtered_samples_to_remove, ancestry_to_remove, withdrawn_to_remove):
    for _member in (rel.ID1, rel.ID2):
        _hit = hl.is_defined(_excluded[_member])
        _already_removed = _hit if _already_removed is None else _already_removed | _hit
rel = rel.filter(_already_removed, keep=False)

# Keep only the pairs whose kinship exceeds the cutoff.
rel = rel.filter(rel.Kinship > MAX_KINSHIP, keep=True)

In [None]:
# Hail's maximal independent set did not work here, so pruning is done on the
# driver with networkx: samples in a maximal independent set of the
# relatedness graph are kept, and every other sample that appears in a
# remaining related pair is removed.

# Collect the remaining related pairs to the driver.
rel_data = rel.select("ID1", "ID2").collect()

# Every sample that appears in at least one remaining pair, in sorted order.
ordered_ids = sorted({row.ID1 for row in rel_data} | {row.ID2 for row in rel_data})
all_ids = set(ordered_ids)

# Build the graph on integer labels (positions in ordered_ids). networkx picks
# nodes at random from candidate sets; with str nodes the set iteration order
# also depends on PYTHONHASHSEED, so int labels plus a fixed seed keep the
# selected set reproducible between runs.
node_of = {sample_id: i for i, sample_id in enumerate(ordered_ids)}
G = nx.Graph()
G.add_edges_from((node_of[row.ID1], node_of[row.ID2]) for row in rel_data)

MIS_SEED = 42
# networkx cannot choose a start node in an empty graph, so handle the
# "no related pairs left" case explicitly instead of crashing.
if G.number_of_nodes() > 0:
    independent_nodes = nx.maximal_independent_set(G, seed=MIS_SEED)
else:
    independent_nodes = []
independent_set = {ordered_ids[node] for node in independent_nodes}

# Related samples to remove: those not in the independent set.
related_samples_to_remove_ids = all_ids - independent_set

# Build the Hail table from plain dicts with an explicit schema, so an empty
# result still yields a valid (empty) table keyed by "eid"; sorting makes the
# row order deterministic.
related_samples_to_remove = hl.Table.parallelize(
    [{"eid": sample} for sample in sorted(related_samples_to_remove_ids)],
    schema=hl.tstruct(eid=hl.tstr),
    key="eid",
)

print(f"Related samples not already in filter and high kinship coefficient: {related_samples_to_remove.count()}")

#### Combine all samples to remove

In [None]:
# Combine every exclusion list into one table of unique eids. Each input is
# first reduced to its key (eid: str) so that union() sees identical schemas
# and the ancestry table's extra columns are not carried through; distinct()
# then keeps one row per eid.
final_to_remove = filtered_samples_to_remove.select().union(
    ancestry_to_remove.select(),
    withdrawn_to_remove.select(),
    related_samples_to_remove.select(),
).distinct()

print(f"Final number of samples to remove: {final_to_remove.count()}")

### Save and export

In [None]:
# Export destination. The path has no URI scheme, so it resolves against the
# cluster's default file system (presumably HDFS — it is later fetched with
# `hadoop fs -getmerge`).
SAMPLES_TO_REMOVE_FILE = "/tmp/samples_to_remove.tsv"

# Write only the eid column of the combined exclusion table.
final_to_remove["eid"].export(SAMPLES_TO_REMOVE_FILE)

In [None]:
# Copy the exported TSV from the cluster file system into the local ../tmp
# directory, then upload it to the project folder used as regenie input.
# NOTE(review): the source path duplicates SAMPLES_TO_REMOVE_FILE; IPython could
# interpolate it as {SAMPLES_TO_REMOVE_FILE}. ../tmp/ is assumed to exist.
!hadoop fs -getmerge /tmp/samples_to_remove.tsv ../tmp/samples_to_remove.tsv
!dx upload ../tmp/samples_to_remove.tsv --path WGS_Javier/Data/Input_regenie/