# Wow! We made it to our final notebook together - it went by so fast, huh?

# So our datasets are all cleaned up individually, but now it's time to bring them together! 
---
# This tutorial will walk us through the combination of multiple plink BED datasets 
  - If a singular study has multiple arrays/sets of data - run this tutorial on those intra-study datasets before merging everything together!
# This tutorial will also walk us through the necessary post-merger quality control measures
---
# The basic steps of our merger will include:

  - **Performing the raw merger**
  - **Eliminate non-overlapping variants**
  - **Inspect for duplicate samples**
  - **Conduct fine-grained batch effect analyses**
---
# You've got this! Let's get our necessary packages loaded up and we will get to work!

In [4]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from qqman import qqman
import subprocess
import IPython
import seaborn as sns
from itertools import combinations
from collections import Counter
from tabulate import tabulate
import tempfile
import glob
import itertools


In [None]:
#If you do not already have the software PRIMUS downloaded (we will be using it later), you can download it from https://primus.gs.washington.edu/primusweb/res/form.html
# The archive is saved into the "primus" directory and then unpacked into the current working directory.
!wget -nc -P primus https://primus.gs.washington.edu/docroot/versions/PRIMUS_v1.9.0.tgz
!tar -xvf primus/PRIMUS_v1.9.0.tgz



In [None]:
#Set up plink to work in jupyter notebook (Compute Canada)
# Each line loads the required modules and prints the executable's location with `which`;
# those printed paths are what the path-configuration cell further down asks for.
!module load StdEnv/2020 && module load plink/1.9b_6.21-x86_64 && which plink
!module load StdEnv/2020 && module load plink/2.00a3.6 && which plink2
!module load StdEnv/2020 && module load gcc/9.3.0 && module load flashpca/2.0 && which flashpca
# The commented-out commands below build the genome_gap_hg38.bed file that `genome_gap` points to.
# Uncomment them and replace "Directory" with a real path if you do not have that file yet.
#!wget -P "Directory" -nc https://hgdownload.soe.ucsc.edu/goldenPath/hg38/database/gap.txt.gz
#!gunzip "Directory"/gap.txt.gz 
#!cut -f 2-4,8 "Directory"/gap.txt > "Directory"/genome_gap_hg38.bed

In [None]:
#If you do not already have the software Flashpca you can download it from  https://github.com/gabraham/flashpca.git
!git clone https://github.com/gabraham/flashpca.git
!cd gabraham
!make

In [None]:
# Executable locations: paste the paths printed by the `which` commands above,
# or the absolute paths of your own installations.
plink_path = "Absolute Path to Plink"
plink2_path = "Absolute Path to Plink 2"
primus_path = "Absolute path to run_PRIMUS.pl"
flash_path = "Absolute path to flashpca"
# BED file of genomic ranges handed to PLINK's `--exclude range` in the PCA step.
genome_gap = "<Directory>/genome_gap_hg38.bed"

In [4]:

# Shared QC table to track all studies.
# count_variants() appends one row per call; the rows follow the column order of `headers`.
shared_qc_table = []

def count_variants(study_name, bim_file, fam_file, step_name):
    """
    Count variants per chromosome class and individuals per sex, and log the result.

    Besides returning the counts, one row
    ``[study_name, step_name, autosomal, X, Y, MT, individuals, males, females, ambiguous]``
    is appended to the module-level ``shared_qc_table``.

    Args:
        study_name : Name of the study (e.g., "Study1").
        bim_file : Path to the BIM file.
        fam_file : Path to the FAM file.
        step_name : Name of the step (e.g., "Start", "After Class1", "After Class2").

    Returns:
        dict: A dictionary containing counts for autosomal, X, Y, MT variants,
              total individuals, males, females, and ambiguous individuals.

    Raises:
        FileNotFoundError: If either file does not exist.
        IndexError / ValueError: If a non-blank FAM line has fewer than five
            columns or a non-integer sex code.
    """
    autosomal = 0
    x_chr = 0
    y_chr = 0
    mt_chr = 0

    with open(bim_file, 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                # Blank line (e.g. a trailing newline at the end of the file).
                continue
            chrom = parts[0]
            chrom_clean = chrom.replace("chr", "") if chrom.startswith("chr") else chrom
            # Code 25 is tallied together with X here.
            if chrom_clean in ('X', '23', '25'):
                x_chr += 1
            elif chrom_clean in ('Y', '24'):
                y_chr += 1
            elif chrom_clean in ('MT', 'M', '26'):
                mt_chr += 1
            elif chrom_clean.isdigit() and 1 <= int(chrom_clean) <= 22:
                autosomal += 1
            # Anything else (e.g. chromosome 0) is not counted in any class.

    individuals = 0
    males = 0
    females = 0
    ambiguous = 0

    with open(fam_file, 'r') as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            sex_code = int(parts[4])  # fifth FAM column
            if sex_code == 1:
                males += 1
            elif sex_code == 2:
                females += 1
            elif sex_code == 0:
                ambiguous += 1
            # Every non-blank line is one individual, whatever its sex code.
            individuals += 1

    counts = {
        "autosomal": autosomal,
        "x_chr": x_chr,
        "y_chr": y_chr,
        "mt_chr": mt_chr,
        "individuals": individuals,
        "males": males,
        "females": females,
        "ambiguous": ambiguous,
    }
    # dicts keep insertion order, so the values line up with the table columns.
    shared_qc_table.append([study_name, step_name, *counts.values()])
    return counts
# Column titles for the tabulate() printouts of shared_qc_table, in row order.
headers = [
    "Study Name",
    "Step Name",
    "Autosomal",
    "X Chr",
    "Y Chr",
    "MT Chr",
    "Individuals",
    "Males",
    "Females",
    "Ambiguous",
]

def save_qc_table(filename="Merger_results.txt"):
    """
    Dump every row of the shared QC table to a tab-separated text file.

    The file starts with a fixed header line, followed by one line per row
    of ``shared_qc_table``. An existing file is overwritten.

    Args:
        filename : Name of the output file.
    """
    with open(filename, "w") as out:
        # Column names first...
        out.write("Study\tStep\tAutosomal\tX_Chr\tY_Chr\tMT_Chr\tIndividuals\tMales\tFemales\tAmbiguous\n")
        # ...then each logged row, with every cell converted to text.
        out.writelines(
            "\t".join(str(cell) for cell in record) + "\n"
            for record in shared_qc_table
        )

# Before we dive in, we need to look at the overlap that our data has with each other.
# Why is this a concern you may ask?
---
# In the world of bioinformatics - number of variants is the name of the game!
- The more variants we have, the finer and more detailed our analyses can be.
# But why exactly do we need to be concerned with overlapping data?
---
# Different biotechnology companies use different chips to genotype their genetic data
- Different chips will focus on different positions and thus we are presented with our problem of overlap.
---
# While we love and cherish our data, we need to be pragmatic about the overlap of variants while also balancing the demographic needs of our merged dataset!


- Move onto the next cell so we can see what we are dealing with in our data!

In [5]:

def combo(list_of_bim_files):
    """
    Rank every combination of two or more .bim files by the number of SNP IDs they share.

    Files that cannot be found, or that contain no SNP IDs, are dropped before
    the combinations are built. The resulting table is shown and returned.

    Args:
        list_of_bim_files : Paths of the .bim files to compare.

    Returns:
        pd.DataFrame: One row per combination with the columns "Rank",
        "Intersecting SNPs", "Best Combination" and "Excluded Files", sorted by
        shared SNPs (descending) and then by number of files (descending).
        Empty when no file could be read.
    """

    def read_bim_file(bim_file_path):
        """
        Reads a .bim file and returns a set of SNP IDs from the second column.
        """
        snp_ids = set()
        try:
            with open(bim_file_path, 'r') as f:
                for line in f:
                    parts = line.split()
                    if len(parts) > 1:
                        snp_ids.add(parts[1])
        except FileNotFoundError:
            print(f"Error: File {bim_file_path} not found!")
        return snp_ids

    def show(table):
        """
        Render the table with IPython when available, otherwise print it.
        """
        # `display` is only a built-in inside an IPython kernel, so import it
        # explicitly and fall back to print() when IPython is not installed.
        try:
            from IPython.display import display as ipython_display
        except ImportError:
            print(table)
        else:
            ipython_display(table)

    # Read SNP IDs from all .bim files
    snp_dict = {bim_file: read_bim_file(bim_file) for bim_file in list_of_bim_files}

    # Remove any files that couldn't be read
    snp_dict = {path: ids for path, ids in snp_dict.items() if ids}

    if not snp_dict:
        print("No valid .bim files were successfully read. Exiting.")
        return pd.DataFrame()  # Return an empty DataFrame if no files were read

    file_names = list(snp_dict)
    intersections = []

    # Consider all combinations of two or more files
    for size in range(2, len(file_names) + 1):
        for members in itertools.combinations(file_names, size):
            shared = set.intersection(*(snp_dict[name] for name in members))
            # Keep the input order so the "Excluded Files" column is deterministic.
            excluded = [name for name in file_names if name not in members]
            intersections.append((len(shared), members, excluded))

    # Most shared variants first; among equals, the combination with more files wins.
    intersections.sort(key=lambda item: (-item[0], -len(item[1])))

    results = [
        {
            "Rank": rank,
            "Intersecting SNPs": shared_count,
            "Best Combination": ", ".join(os.path.basename(name) for name in members),
            "Excluded Files": ", ".join(os.path.basename(name) for name in excluded),
        }
        for rank, (shared_count, members, excluded) in enumerate(intersections, 1)
    ]

    results_df = pd.DataFrame(results)

    # Display the table in the notebook
    show(results_df)

    return results_df



# Plug in your bim files from the last notebook down below to see how the overlap changes across every combination of your datasets
- If any of your datasets is split across more than one array, you can run this script on those bim files as well as a handy sanity check!

In [None]:

# Replace the placeholder paths with the .bim files produced by the previous notebook.
list_of_bim_files = ['path/to/file1.bim', 'path/to/file2.bim', 'path/to/file3.bim']
# combo() already shows the table itself; storing the returned DataFrame keeps Jupyter
# from echoing it a second time and lets later cells reuse it.
overlap_table = combo(list_of_bim_files)

# How are we looking? If we see in our permutations that a singular study or array is causing a sharp decline in your variant counts, that may be cause for concern.
---
# Up next we have the raw merger of our datasets
- Include the studies you think thread the needle in terms of variant counts and your research priorities! 

# In this section we are going to:
- Affix a designated suffix to our sample IDs 
  - This will help us keep track of who comes from where once everything is merged together
- Merge the datasets together without minding for overlap
  - We should expect a single set of plink files at the end

In [6]:
class RawMerger:
    """
    Tag each study's sample IDs with a suffix and merge the studies into one PLINK fileset.

    Intended call order: update_ids() -> merge_studies() -> get_output_files().
    Uses the notebook-level ``plink_path`` executable path and ``count_variants`` helper.
    """

    def __init__(self, output_dir, studies, suffixes):
        """
        Initialize the GWAS merger class.

        Args:
            output_dir (str): Directory to store output files (created if missing).
            studies (list you make before running): List of study prefixes (e.g., ["/path/to/study1", "/path/to/study2"]).
            suffixes (list you make before running): List of suffixes corresponding to each study (e.g., ["suffix1", "suffix2"]).

        Raises:
            ValueError: If the two lists differ in length, or if two studies
                share the same file basename.
        """
        if len(studies) != len(suffixes):
            raise ValueError("The number of studies and suffixes must match.")

        # Intermediate files are named after each study's basename, so two
        # studies with the same basename would overwrite each other's files.
        basenames = [os.path.basename(study) for study in studies]
        if len(set(basenames)) != len(basenames):
            raise ValueError(
                "Study prefixes must have unique basenames; "
                "rename or copy the files so that each study has its own name."
            )

        self.output_dir = os.path.abspath(output_dir)
        os.makedirs(self.output_dir, exist_ok=True)
        self.merge_list = os.path.join(self.output_dir, "merge_list.txt")
        self.variant_sets = {}  # Store variant sets for each study, keyed by suffix
        self.studies = studies  # List of study prefixes
        self.suffixes = suffixes  # List of suffixes
        self.current_base = None  # Prefix of the merged fileset; set by merge_studies()

    def update_ids(self):
        """
        Update IDs in the PLINK files using the provided suffixes.

        For each study this writes ``<basename>_update_ids.txt``, runs PLINK to
        produce ``<basename>_updated.*`` in the output directory, records that
        prefix in the merge list, and stores the study's variant IDs in
        ``self.variant_sets[suffix]``.

        Raises:
            subprocess.CalledProcessError: If PLINK exits with an error.
        """
        # Start from an empty merge list so that re-running this step (common in
        # a notebook) does not list the same fileset twice.
        with open(self.merge_list, 'w'):
            pass

        for study, suffix in zip(self.studies, self.suffixes):
            base_name = os.path.basename(study)
            update_ids_file = os.path.join(self.output_dir, f"{base_name}_update_ids.txt")
            output_prefix = os.path.join(self.output_dir, f"{base_name}_updated")

            # Create update IDs file: old FID, old IID, new FID, new IID
            with open(f"{study}.fam", 'r') as f_in, open(update_ids_file, 'w') as f_out:
                for line in f_in:
                    parts = line.split()
                    if len(parts) < 2:
                        continue  # blank line
                    # You can attach the suffix with anything you want; just change the divider in {parts[1]}-{suffix}
                    f_out.write(f"{parts[0]} {parts[1]} {parts[0]} {parts[1]}-{suffix}\n")

            # Run PLINK to update IDs
            subprocess.run([plink_path, "--bfile", study, "--update-ids", update_ids_file, "--keep-allele-order", "--make-bed", "--out", output_prefix], check=True)

            # Add to merge list
            with open(self.merge_list, 'a') as f:
                f.write(f"{output_prefix}\n")

            # Extract variant IDs (second column) from the study's .bim file
            with open(f"{study}.bim", 'r') as f:
                variant_ids = {fields[1] for fields in map(str.split, f) if len(fields) > 1}
            self.variant_sets[suffix] = variant_ids

    def merge_studies(self):
        """
        Merge all PLINK studies into a single dataset.

        The first fileset in the merge list is used as the base and the rest
        are passed to PLINK via ``--merge-list``. The result is written to
        ``<output_dir>/merged_data.*`` and logged in the shared QC table.

        Raises:
            ValueError: If the merge list is empty.
            subprocess.CalledProcessError: If PLINK exits with an error.
        """
        with open(self.merge_list, 'r') as f:
            files = [line.strip() for line in f if line.strip()]
        if not files:
            raise ValueError("No files to merge!")

        first_file, *remaining = files
        merge_temp = os.path.join(self.output_dir, "merge_temp.txt")
        with open(merge_temp, 'w') as f:
            f.writelines(f"{prefix}\n" for prefix in remaining)

        base_name = os.path.join(self.output_dir, "merged_data")
        subprocess.run([plink_path, "--bfile", first_file, "--merge-list", merge_temp, "--keep-allele-order", "--make-bed", "--out", base_name], check=True)

        # Track QC after merging
        count_variants("Merged", f"{base_name}.bim", f"{base_name}.fam", "After Merge")
        self.current_base = base_name

    def get_output_files(self):
        """
        Return the prefix of the merged PLINK fileset.

        Raises:
            RuntimeError: If merge_studies() has not been run yet.
        """
        if self.current_base is None:
            raise RuntimeError("No merged dataset yet: run merge_studies() first.")
        return self.current_base

In [None]:
# Where the merged files go, which PLINK prefixes to merge, and the tag that
# is appended to the sample IDs of each study (one suffix per study, same order).
output_dir = "/path/to/output"
studies = [
    "/path/to/study1",
    "/path/to/study2",
]
suffixes = [
    "suffix1",
    "suffix2",
]
raw_merger = RawMerger(output_dir=output_dir, studies=studies, suffixes=suffixes)

In [None]:
# Tag the sample IDs of every study, merge them, then print the running QC table.
raw_merger.update_ids()
raw_merger.merge_studies()
print(
    tabulate(
        shared_qc_table,
        headers=headers,
        tablefmt="pretty",
    )
)

# Take a look at the genotyping rates of the datasets on their own versus that of the final merger. Do you see how big a role the intersection of variants plays?

# This next section of code is going to now filter out those non-overlapping variants!
---
# What if say, you have ten datasets and only one dataset does not have a particular variant, seems silly to throw that variant out, right? 
- Our code will include some flexibility in the percentage of overlap needed for a variant to be kept in the final merged reference panel
    - Just remember, the lower your threshold for accepting variants, the higher the missingness will be and the lower the overall genotyping rate. 

In [8]:

class IntersectionAnalysis:
    """
    Inspect how much the studies' variant sets overlap and filter the merged dataset accordingly.

    Uses the notebook-level ``plink_path`` executable path and ``count_variants`` helper.
    """

    def __init__(self, variant_sets, base_name, output_dir):
        """
        Initialize the intersection analysis class.

        Args:
            variant_sets (dict): Dictionary mapping suffixes to sets of variant IDs.
            base_name (str): Prefix of the merged PLINK dataset.
            output_dir (str): Directory to store output files (created if missing).
        """
        self.variant_sets = variant_sets
        self.base_name = base_name
        self.output_dir = os.path.abspath(output_dir)
        os.makedirs(self.output_dir, exist_ok=True)
        self.intersection_output = None
        self.pairwise_results = None  # set by calculate_pairwise_intersections()
        self.current_base = None  # set by filter_merged_dataset()

    def calculate_pairwise_intersections(self):
        """
        Calculate pairwise intersections between datasets.

        Returns:
            pd.DataFrame: Columns "Set1", "Set2", "IntersectionSize",
            "PercentSet1" and "PercentSet2". Each dataset appears once paired
            with itself (100%), and every pair of different datasets appears
            in both orders. The table is also stored in ``self.pairwise_results``.
        """
        labels = list(self.variant_sets)
        data_sets = [self.variant_sets[label] for label in labels]

        # A set shares all of its variants with itself.
        pairwise_results = [
            (label, label, len(variants), 100.0, 100.0)
            for label, variants in zip(labels, data_sets)
        ]

        # Calculate pairwise intersections
        for (i, set_a), (j, set_b) in combinations(enumerate(data_sets), 2):
            intersection_count = len(set_a & set_b)
            percent_a = (intersection_count / len(set_a)) * 100 if len(set_a) > 0 else 0
            percent_b = (intersection_count / len(set_b)) * 100 if len(set_b) > 0 else 0
            pairwise_results.append((labels[i], labels[j], intersection_count, percent_a, percent_b))
            pairwise_results.append((labels[j], labels[i], intersection_count, percent_b, percent_a))

        # Convert results to a DataFrame
        df = pd.DataFrame(pairwise_results, columns=["Set1", "Set2", "IntersectionSize", "PercentSet1", "PercentSet2"])

        # Display pairwise results
        print("Pairwise Intersection Results:")
        print(df)
        self.pairwise_results = df
        return df

    def generate_heatmap(self):
        """
        Generate and display a heatmap for pairwise intersection results.

        Each plotted cell is the percentage of the row dataset's variants that
        are also in the column dataset; only the upper triangle is shown. The
        pairwise table is computed first if it has not been already.
        """
        if self.pairwise_results is None:
            self.calculate_pairwise_intersections()

        heatmap_data = self.pairwise_results.pivot(index="Set1", columns="Set2", values="PercentSet1")
        heatmap_data = heatmap_data.fillna(0)

        # Blank out everything below the diagonal.
        upper_triangle = np.triu(np.ones(heatmap_data.shape)).astype(bool)
        heatmap_data = heatmap_data.where(upper_triangle)

        # Create the heatmap
        plt.figure(figsize=(10, 8))
        sns.heatmap(heatmap_data, annot=True, fmt=".1f", cmap="Blues", cbar_kws={'label': 'Percent in Common'})
        plt.title("Heatmap with Percent in Common")
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

    def get_best_intersection(self, threshold=1.0):
        """
        Calculate the "best" intersection (variants present in most studies).

        Args:
            threshold (float): Fraction of studies (between 0 and 1) a variant
                must be present in. The default of 1.0 keeps only variants
                found in every study.

        Returns:
            set: Set of variants that meet the threshold.

        Raises:
            ValueError: If threshold is outside the range 0 to 1, e.g. when a
                percentage such as 90 is passed instead of 0.9.
        """
        if not 0 <= threshold <= 1:
            raise ValueError(
                f"threshold must be a fraction between 0 and 1 (e.g. 0.9 for 90%), got {threshold!r}."
            )

        # Count how many studies each variant appears in (once per study).
        variant_counts = Counter()
        for variant_set in self.variant_sets.values():
            variant_counts.update(set(variant_set))

        total_studies = len(self.variant_sets)
        # The small tolerance keeps float rounding in total_studies * threshold
        # (e.g. 7.000000000000001) from excluding variants that exactly meet it.
        threshold_count = total_studies * threshold - 1e-9

        # Select variants that meet the threshold
        best_intersection = {variant for variant, count in variant_counts.items() if count >= threshold_count}

        print(f"Best intersection includes {len(best_intersection)} variants (present in at least {threshold * 100}% of studies).")
        if best_intersection:
            # Average share of studies carrying each selected variant
            average_percentage = np.mean([(variant_counts[variant] / total_studies) * 100 for variant in best_intersection])
            print(f"On average, these variants are shared across {average_percentage:.2f}% of studies.")
        else:
            print("No variants meet this threshold.")

        return best_intersection

    def filter_merged_dataset(self, threshold):
        """
        Filter the merged dataset using the optimal intersection.

        Writes the selected variant IDs to ``best_intersection.txt``, runs PLINK
        with ``--extract`` to produce ``merged_data_filtered.*`` in the output
        directory, and logs the result in the shared QC table.

        Args:
            threshold (float): Passed on to get_best_intersection().

        Raises:
            subprocess.CalledProcessError: If PLINK exits with an error (its
                captured output is printed first).
        """
        best_intersection = self.get_best_intersection(threshold)

        # Write the best intersection to a text file, one ID per line, in a stable order
        intersection_file = os.path.join(self.output_dir, "best_intersection.txt")
        with open(intersection_file, 'w') as file:
            file.writelines(f"{variant}\n" for variant in sorted(best_intersection))

        print(f"Best intersection saved to: {intersection_file}")

        filtered_prefix = os.path.join(self.output_dir, "merged_data_filtered").strip()
        self.base_name = self.base_name.strip()
        intersection_file = intersection_file.strip()

        try:
            subprocess.run([plink_path, "--bfile", self.base_name, "--extract", intersection_file, "--keep-allele-order", "--make-bed", "--out", filtered_prefix], capture_output=True, text=True, check=True)
        except subprocess.CalledProcessError as err:
            # The output is captured, so surface it before failing.
            print(err.stdout)
            print(err.stderr)
            raise
        count_variants("Merged", f"{filtered_prefix}.bim", f"{filtered_prefix}.fam", "After Intersection")
        self.current_base = filtered_prefix

    def get_output_files(self):
        """
        Return the prefix of the filtered PLINK fileset.

        Raises:
            RuntimeError: If filter_merged_dataset() has not been run yet.
        """
        if self.current_base is None:
            raise RuntimeError("No filtered dataset yet: run filter_merged_dataset() first.")
        return self.current_base

In [None]:
# Prefix of the merged fileset; use it as the base name in the next cell.
final_base = raw_merger.get_output_files()
print("Final base file:", final_base)

In [None]:
# Replace 'path/to/base/name' with the base file printed by the previous cell.
intersection = IntersectionAnalysis(
    variant_sets=raw_merger.variant_sets,
    base_name='path/to/base/name',
    output_dir=output_dir,
)

# Here we will take another look at the final overlap of variants, and we can see the percentage of each overlap in a heatmap format.

In [None]:
# Print the pairwise overlap table, then draw it as a heatmap of percentages.
intersection.calculate_pairwise_intersections()
intersection.generate_heatmap()

# Now with that information in hand, we can choose the threshold we think is appropriate and plug it in below!

In [None]:
# Replace the placeholder with your threshold written as a decimal fraction, e.g. 0.9 to keep
# variants present in at least 90% of the studies. The cell will not run until you do.
intersection.filter_merged_dataset(DECIMAL OF PERCENTAGE THRESHOLD)
print(tabulate(shared_qc_table, headers=headers, tablefmt="pretty"))

# This next section should look familiar from our "twin" screening back in the quality control notebook. We have the same concern about duplicate samples in our final dataset, so let us go ahead and nip them in the bud!

In [16]:

class TwinProcessor:
    """
    Find duplicate samples / twins in a PLINK fileset and remove one sample of each pair.

    Intended call order: identify_twins() -> screen_twins() -> remove_twins().
    Uses the notebook-level ``plink_path`` and ``plink2_path`` executable paths
    and the ``count_variants`` helper.
    """

    # Pairs with a KING kinship above this value are treated as twins/duplicates.
    # The same value is used for plink2's table filter and the Python-side check.
    KINSHIP_THRESHOLD = 0.354
    TWINS_HEADER = "FID1\tIID1\tFID2\tIID2\n"

    def __init__(self, study_name, base_name, out_dir):
        """
        Args:
            study_name : Name of the study.
            base_name : Base name of the input files (including full path, e.g., "path/to/data_base_name").
            out_dir : Directory where all output files will be saved.
        """
        self.study_name = study_name
        self.base_name = base_name
        self.out_dir = out_dir

        # Create the output directory if it doesn't exist
        os.makedirs(self.out_dir, exist_ok=True)
        self.intermediate_files = []
        # Until twins are removed, the current dataset is the input dataset
        # (full prefix, so the path stays usable from any working directory).
        self.current_base = base_name
        self.current_bim = f"{self.current_base}.bim"
        self.current_fam = f"{self.current_base}.fam"
        # Files produced along the way
        self.twins_file = os.path.join(self.out_dir, "twins.txt")
        self.remove = os.path.join(self.out_dir, "twins_to_remove.txt")
        self.tied = os.path.join(self.out_dir, "twins_tied.txt")

    def _read_twin_pairs(self, kin_path):
        """Return [(FID1, IID1, FID2, IID2), ...] for the .kin0 rows above KINSHIP_THRESHOLD."""
        if not os.path.exists(kin_path):
            return []

        pairs = []
        with open(kin_path, "r") as kin_file:
            header = next(kin_file, "").lstrip("#").split()
            column = {name: index for index, name in enumerate(header)}
            # Locate columns by name rather than by position.
            # NOTE(review): plink2 is expected to label the ID columns IID1/IID2 or
            # ID1/ID2 and to omit FID1/FID2 when the data has no family IDs - confirm
            # against the plink2 version in use.
            iid1 = column.get("IID1", column.get("ID1"))
            iid2 = column.get("IID2", column.get("ID2"))
            kinship = column.get("KINSHIP")
            if iid1 is None or iid2 is None or kinship is None:
                # Unrecognised header: fall back to the fixed eight-column layout.
                fid1, iid1, fid2, iid2, kinship = 0, 1, 2, 3, 7
            else:
                fid1, fid2 = column.get("FID1"), column.get("FID2")

            for line in kin_file:
                fields = line.split()
                if not fields:
                    continue
                if float(fields[kinship]) > self.KINSHIP_THRESHOLD:
                    pairs.append((
                        fields[fid1] if fid1 is not None else "0",
                        fields[iid1],
                        fields[fid2] if fid2 is not None else "0",
                        fields[iid2],
                    ))
        return pairs

    def identify_twins(self):
        """
        Compute pairwise KING kinship with PLINK 2 and write the twin pairs to twins.txt.

        Raises:
            subprocess.CalledProcessError: If PLINK or PLINK 2 exits with an error.
        """
        # Create a temporary copy for processing
        temp_base = os.path.join(self.out_dir, "temp")
        subprocess.run([plink_path, "--bfile", self.base_name, "--keep-allele-order", "--make-bed", "--out", temp_base], check=True)

        # Calculate relatedness using PLINK2. The table filter must not be stricter
        # than KINSHIP_THRESHOLD, otherwise pairs between the two values never
        # reach the check below.
        relatedness_base = os.path.join(self.out_dir, "relatedness")
        subprocess.run([plink2_path, "--bfile", temp_base, "--make-king-table", "--king-table-filter", str(self.KINSHIP_THRESHOLD), "--out", relatedness_base], check=True)
        kin_path = f"{relatedness_base}.kin0"

        # Create twins.txt file with a header, one line per flagged pair
        with open(self.twins_file, "w") as twins_file:
            twins_file.write(self.TWINS_HEADER)
            for pair in self._read_twin_pairs(kin_path):
                twins_file.write("\t".join(pair) + "\n")

        self.intermediate_files.append(f"{temp_base}.*")
        self.intermediate_files.append(kin_path)
        self.intermediate_files.append(self.twins_file)

    def screen_twins(self):
        """
        Decide which sample of each twin pair to drop, based on genotype missingness.

        The sample with fewer called genotypes is written to twins_to_remove.txt;
        pairs with identical call counts go to twins_tied.txt for manual review.
        Pairs with a sample missing from the missingness report are skipped.

        Raises:
            subprocess.CalledProcessError: If PLINK exits with an error.
        """
        # Ensure twins.txt exists
        if not os.path.exists(self.twins_file):
            print("Twins file does not exist. Creating an empty file.")
            with open(self.twins_file, "w") as twins_file:
                twins_file.write(self.TWINS_HEADER)

        # Run PLINK to calculate missingness
        missing_base = os.path.join(self.out_dir, "missingness")
        subprocess.run([plink_path, "--bfile", self.base_name, "--missing", "--out", missing_base], check=True)

        # Read missingness data. The .imiss columns are
        # FID IID MISS_PHENO N_MISS N_GENO F_MISS, so the number of genotypes
        # actually called for a sample is N_GENO - N_MISS.
        called_genotypes = {}
        imiss_path = f"{missing_base}.imiss"
        if os.path.exists(imiss_path):
            with open(imiss_path, "r") as imiss_file:
                next(imiss_file, None)  # Skip the header line
                for line in imiss_file:
                    fields = line.split()
                    if len(fields) < 5:
                        continue
                    called_genotypes[(fields[0], fields[1])] = int(fields[4]) - int(fields[3])
            self.intermediate_files.append(imiss_path)

        twins_to_remove = []
        already_flagged = set()
        twins_tied = []

        # Process twins file
        with open(self.twins_file, "r") as twins_file:
            header = next(twins_file, self.TWINS_HEADER)
            for line in twins_file:
                fields = line.split()
                if len(fields) < 4:
                    continue
                twin1 = (fields[0], fields[1])
                twin2 = (fields[2], fields[3])
                if twin1 not in called_genotypes or twin2 not in called_genotypes:
                    continue
                if called_genotypes[twin1] == called_genotypes[twin2]:
                    twins_tied.append((twin1, twin2))
                    continue
                # Keep the better-genotyped sample, drop the other one.
                worse = twin1 if called_genotypes[twin1] < called_genotypes[twin2] else twin2
                if worse not in already_flagged:
                    already_flagged.add(worse)
                    twins_to_remove.append(worse)

        # Create twins_to_remove.txt with header
        with open(self.remove, "w") as remove_file:
            remove_file.write("FID\tIID\n")
            for fid, iid in twins_to_remove:
                remove_file.write(f"{fid}\t{iid}\n")

        # Create twins_tied.txt with the header from twins.txt
        with open(self.tied, "w") as tied_file:
            tied_file.write(header)
            for twin1, twin2 in twins_tied:
                tied_file.write(f"{twin1[0]}\t{twin1[1]}\t{twin2[0]}\t{twin2[1]}\n")

    def remove_twins(self):
        """
        Remove flagged twins from the dataset.

        Writes ``<basename>_twins_filtered.*`` to the output directory, logs the
        result in the shared QC table, reports how many tied pairs still need
        manual curation, and deletes the intermediate files.

        Raises:
            subprocess.CalledProcessError: If PLINK exits with an error.
        """
        twin_base = os.path.join(self.out_dir, f"{os.path.basename(self.base_name)}_twins_filtered")
        subprocess.run([plink_path, "--bfile", self.base_name, "--remove", self.remove, "--keep-allele-order", "--make-bed", "--out", twin_base], check=True)

        # Update current BIM and FAM files
        self.current_base = twin_base
        self.current_bim = f"{twin_base}.bim"
        self.current_fam = f"{twin_base}.fam"
        count_variants(self.study_name, self.current_bim, self.current_fam, "Removal of Homozygotic Twins")

        # Count the tied pairs, leaving out the header line.
        with open(self.tied, "r") as tied_file:
            next(tied_file, None)
            total_manual = sum(1 for line in tied_file if line.strip())
        print(f"Total number of twin pairs needing manual curation: {total_manual}")

        # Entries are either exact paths or wildcard patterns such as "temp.*".
        for entry in self.intermediate_files:
            if os.path.exists(entry):
                os.remove(entry)
            else:
                for path in glob.glob(entry):
                    os.remove(path)
        self.intermediate_files = []

    def get_output_files(self):
        """Return the prefix of the current dataset (the input prefix until remove_twins() has run)."""
        return self.current_base

In [None]:
# Prefix of the intersection-filtered fileset; use it as the base name in the next cell.
final_base = intersection.get_output_files()
print("Final base file:", final_base)

In [None]:
# Replace the placeholders with the base file printed above and your output directory.
twins = TwinProcessor(
    study_name="Merged Dataset",
    base_name="Path/to/last/basename",
    out_dir="Path/to/output/directory",
)

In [None]:
# Flag closely related pairs, then work out which sample of each pair to drop.
# The results are written to twins_to_remove.txt and twins_tied.txt in the output directory.
twins.identify_twins()
twins.screen_twins()

# For this final section, if you want to "manually curate" your samples, you can add them by hand to twins_to_remove.txt.
- Take a look at your flagged tied twins a bit more closely 
  - See if the flagged relationships actually make sense in the context of your data; if not, it may speak to the quality of the data or simply a bad statistical guess!

In [None]:
# Drop the flagged samples and print the running QC table.
twins.remove_twins()
print(
    tabulate(
        shared_qc_table,
        headers=headers,
        tablefmt="pretty",
    )
)

# So now that we have merged our datasets together, trimmed for only overlapping variants, and removed any duplicate samples, what's next?
---
# We need to screen for batch effects!
- What exactly is a batch effect?
# Batch effects are  technical artifacts that introduce non-biological variations in genetic data due to differences in experimental conditions, equipment, or processing across multiple datasets.
- Unless you want your groundbreaking genetic discovery to be just a sneaky batch effect in disguise! 
  - Screening for batch effects ensures that the biological signals you’re studying aren’t overshadowed by technical noise, keeping your data clean and your conclusions reliable!
---
# What we'll do is screen our datasets for variants that are statistically associated with a particular dataset 
  - Indicating a "batch-affected" variant
# We'll do this through **G**enome-**W**ide **A**ssociation **S**tudy (GWAS)
  - We will be running this GWAS through a jackknifing statistical method:
    - This means for each study we will pull it out as our "case" and then pool the significant values.
  - We will also run our GWAS with the first twenty **P**rincipal **C**omponents (PCs) as covariates:
    - We are using PCs to measure the genetic distance between samples.
    - This is to ensure that we are not throwing out variants that are associated with a particular ancestry as opposed to a genuine batch effect!

In [26]:
class GWAS:
    """Run PCA and one case/control GWAS per study on a merged PLINK dataset.

    Every method shells out to external tools through the module-level names
    ``plink_path``, ``primus_path``, ``flash_path`` and ``genome_gap``; these
    must be defined elsewhere in the notebook before a method is called.
    """

    def __init__(self, base_name, output_dir, suffixes):
        """
        Initialize the PCA and GWAS analysis class.

        Args:
            base_name (str): Prefix of the merged PLINK dataset.
            output_dir (str): Directory to store output files (created if missing).
            suffixes (list): List of suffixes corresponding to each study (e.g., ["s1", "s2"]).
        """
        self.base_name = base_name
        self.output_dir = os.path.abspath(output_dir)
        # Kept on the instance for callers; run_gwas() derives the study labels
        # from the IIDs found in the .fam file rather than from this list.
        self.suffixes = suffixes
        os.makedirs(self.output_dir, exist_ok=True)

    @staticmethod
    def _split_by_relatedness(primus_unrelated, fam_path, unrelated_path, related_path):
        """Split the .fam lines into an unrelated and a related sample list.

        The first whitespace-delimited field of every line in
        ``primus_unrelated`` is matched against the second (IID) column of
        ``fam_path``. Matching .fam lines are written to ``unrelated_path``,
        all remaining lines to ``related_path``.

        Returns:
            tuple: (number of unrelated lines, number of related lines).
        """
        with open(primus_unrelated, 'r') as handle:
            keep_ids = {fields[0] for fields in (line.split() for line in handle) if fields}

        n_unrelated = 0
        n_related = 0
        with open(fam_path, 'r') as fam, \
                open(unrelated_path, 'w') as unrelated_out, \
                open(related_path, 'w') as related_out:
            for line in fam:
                fields = line.split()
                if len(fields) < 2:
                    continue  # blank or malformed line
                record = line.rstrip("\n") + "\n"
                if fields[1] in keep_ids:
                    unrelated_out.write(record)
                    n_unrelated += 1
                else:
                    related_out.write(record)
                    n_related += 1
        return n_unrelated, n_related

    @staticmethod
    def _write_phenotype_file(fam_file, pheno_file):
        """Write one "FID IID label" row per sample and return the distinct labels.

        The label is whatever follows the last '-' in the IID. Each sample is
        written exactly once, and labels are returned as strings in order of
        first appearance so they can be passed straight to PLINK.
        """
        labels = {}
        with open(fam_file, 'r') as f_in, open(pheno_file, 'w') as f_out:
            for line in f_in:
                parts = line.split()
                if len(parts) < 2:
                    continue  # blank or malformed line
                phenotype = parts[1].split('-')[-1]
                f_out.write(f"{parts[0]} {parts[1]} {phenotype}\n")
                labels[phenotype] = None
        return list(labels)

    def run_pca(self, n_pcs=20):
        """
        Run PCA on the merged dataset.

        Unrelated samples define the principal components; related samples (if
        any) are projected onto them. The combined table is written to
        ``<output_dir>/PCA/pca_merged_results.txt``.

        Args:
            n_pcs (int): Number of principal components to compute (default 20).

        Raises:
            subprocess.CalledProcessError: If an external tool exits non-zero.
            RuntimeError: If no sample is classified as unrelated.
        """
        pca_dir = os.path.join(self.output_dir, "PCA")
        os.makedirs(pca_dir, exist_ok=True)

        qc_base = os.path.join(pca_dir, "qc_filtered_your_data")
        pruned_base = os.path.join(pca_dir, "ld_pruned_dataset")
        unrelated_base = os.path.join(pca_dir, "pruned_unrelated")
        related_base = os.path.join(pca_dir, "pruned_related")

        # Step 1: Quality Control (QC) for Your Data
        print("Performing Quality Control on Your Data...")
        subprocess.run([
            plink_path, "--bfile", self.base_name,
            "--maf", "0.05", "--geno", "0.05", "--hwe", "1e-25",
            "--exclude", "range", genome_gap,
            "--keep-allele-order", "--make-bed", "--out", qc_base,
        ], check=True)

        # Step 2: LD Pruning
        print("Performing LD Pruning...")
        subprocess.run([
            plink_path, "--bfile", qc_base,
            "--indep-pairwise", "50", "5", "0.2",
            "--out", os.path.join(pca_dir, "ld_pruned_ref"),
        ], check=True)
        subprocess.run([
            plink_path, "--bfile", qc_base,
            "--extract", os.path.join(pca_dir, "ld_pruned_ref.prune.in"),
            "--keep-allele-order", "--make-bed", "--out", pruned_base,
        ], check=True)

        # Step 3: Relatedness Screening
        print("Performing Relatedness Screening...")
        genome_prefix = os.path.join(pca_dir, "IBD_Projection")
        subprocess.run([plink_path, "--bfile", pruned_base, "--genome", "--out", genome_prefix], check=True)
        # The option name and its value must be separate argv entries.
        subprocess.run([
            "perl", primus_path, "-p", genome_prefix + ".genome",
            "--degree_rel_cutoff", "3", "--no_PR",
        ], check=True)

        unrelated_list = os.path.join(pca_dir, "unrelated_samples.txt")
        related_list = os.path.join(pca_dir, "individuals_related.txt")
        # Done in Python rather than awk/grep: `grep -v` exits with status 1
        # when every sample is unrelated, which check=True treats as a failure.
        n_unrelated, n_related = self._split_by_relatedness(
            os.path.join(pca_dir, "IBD_Projection.genome_PRIMUS", "IBD_Projection.genome_unrelated_samples.txt"),
            pruned_base + ".fam",
            unrelated_list,
            related_list,
        )
        if n_unrelated == 0:
            raise RuntimeError("No unrelated samples were identified; cannot compute the reference PCA.")

        # Step 4: Split the Data
        print("Splitting the Data...")
        subprocess.run([
            plink_path, "--bfile", pruned_base, "--keep", unrelated_list,
            "--keep-allele-order", "--make-bed", "--out", unrelated_base,
        ], check=True)
        if n_related:
            subprocess.run([
                plink_path, "--bfile", pruned_base, "--keep", related_list,
                "--keep-allele-order", "--make-bed", "--out", related_base,
            ], check=True)

        # Step 5: Run PCA on Unrelated Individuals
        print("Running PCA on Unrelated Individuals...")
        subprocess.run([
            flash_path, "--bfile", unrelated_base,
            "--numthreads", "12",
            "--outvec", os.path.join(pca_dir, "pca_ref.eigenvec"),
            "--outpve", os.path.join(pca_dir, "pca_ref.pve"),
            "--outval", os.path.join(pca_dir, "pca_ref.eigenval"),
            "--outload", os.path.join(pca_dir, "pca_ref.SNPloadings"),
            "--outmeansd", os.path.join(pca_dir, "pca_ref.meanSD"),
            "--outpc", os.path.join(pca_dir, "pca_ref.PC"),
            "--memory", "25000",
            "-d", str(n_pcs),
        ], check=True)

        # Column names shared by the reference and projected tables.
        pc_columns = ["FID", "IID"] + [f"PC{i}" for i in range(1, n_pcs + 1)]
        ref_pca = pd.read_csv(os.path.join(pca_dir, "pca_ref.PC"), sep=r'\s+', header=None)
        ref_pca.columns = pc_columns
        pca_frames = [ref_pca]

        # Step 6: Project Related Data onto Reference PCA
        if n_related:
            print("Projecting Related Data onto Reference PCA...")
            subprocess.run([
                flash_path, "--bfile", related_base,
                "--numthreads", "12", "--project",
                "--inmeansd", os.path.join(pca_dir, "pca_ref.meanSD"),
                "--inload", os.path.join(pca_dir, "pca_ref.SNPloadings"),
                "--outproj", os.path.join(pca_dir, "pca_proj.PC"),
                "-v",
            ], check=True)
            proj_pca = pd.read_csv(os.path.join(pca_dir, "pca_proj.PC"), sep=r'\s+', header=None)
            proj_pca.columns = pc_columns
            pca_frames.append(proj_pca)
        else:
            print("No related individuals found; skipping the projection step.")

        # Step 7: Merge PCA Results
        print("Merging PCA Results...")
        merged_path = os.path.join(pca_dir, "pca_merged_results.txt")
        merged_pca = pd.concat(pca_frames, ignore_index=True)
        merged_pca.to_csv(merged_path, sep="\t", index=False)

        print("PCA results merged and saved to:", merged_path)

    def run_gwas(self):
        """
        Run one logistic GWAS per study label, using the PCs as covariates.

        The study label of each sample is the text after the last '-' in its
        IID. For every distinct label, PLINK treats the samples carrying it as
        cases (``--make-pheno``) and writes ``GWAS_<label>.*`` to the output
        directory. Requires run_pca() to have produced the covariate file.

        Raises:
            subprocess.CalledProcessError: If PLINK exits non-zero.
        """
        united_pheno_file = os.path.join(self.output_dir, "phenotype_united.txt")
        unique_phenos = self._write_phenotype_file(f"{self.base_name}.fam", united_pheno_file)
        covar_file = os.path.join(self.output_dir, "PCA", "pca_merged_results.txt")

        for pheno in unique_phenos:
            assoc_output = os.path.join(self.output_dir, f"GWAS_{pheno}")
            subprocess.run([
                plink_path, "--bfile", self.base_name,
                "--make-pheno", united_pheno_file, pheno,
                "--logistic", "--covar", covar_file,
                "--adjust", "--allow-no-sex",
                "--out", assoc_output,
            ], check=True)

In [None]:
# Ask the TwinProcessor for its output basename and echo it.
# NOTE(review): presumably this is the basename to pass to GWAS(...) in the next cell — confirm.
final_base = twins.get_output_files()

print(f"Final base file: {final_base}")

In [None]:
# Build the GWAS helper: arguments are (merged PLINK basename, output directory, list of study suffixes).
# Replace the two placeholder paths with your own.
# NOTE(review): `suffixes` is not defined in this section; presumably an earlier cell sets it — confirm.
Flag = GWAS('Path/to/last/basename', 'output/directory', suffixes)

In [None]:
# This next stage is going to take awhile, so get comfy!
# run_pca() must come first: run_gwas() reads the PCA/pca_merged_results.txt covariate file it writes.
Flag.run_pca()
Flag.run_gwas()

# Thanks for sticking in there! 
# Now that we have the raw results of our GWASs,  we will process them.
# We will graph these values two different ways:
- **Manhattan-Plot** 
  - A Manhattan plot visually maps genetic associations across the genome, with each point representing a variant and taller peaks highlighting stronger signals of potential biological importance (or in our case, a batch effect).
- **QQ Plot**
  - A QQ plot compares the distribution of observed p-values to the expected null distribution, helping to identify deviations that may signal batch effects or other technical artifacts in your data

In [31]:
class BatchProcessor:
    """Summarise per-study GWAS results and strip batch-associated variants.

    Reads the PLINK ``*.assoc.logistic`` / ``*.assoc.logistic.adjusted`` files
    found in ``data_dir``, plots them, collects the significant SNPs and
    removes them from the dataset. ``remove_batchs`` relies on the
    module-level ``plink_path``; the plotting methods rely on ``qqman``,
    ``plt`` and ``np`` being imported by the notebook.
    """

    ADJUSTED_SUFFIX = '.assoc.logistic.adjusted'
    LOGISTIC_SUFFIX = '.assoc.logistic'
    # Cut-off applied to the BONF column of the .adjusted files.
    SIGNIFICANCE_THRESHOLD = 5e-8

    def __init__(self, base_name, data_dir, output_dir):
        """
        Args:
            base_name (str): Prefix of the PLINK dataset to clean.
            data_dir (str): Directory holding the GWAS result files.
            output_dir (str): Directory for the SNP list and the final dataset.
        """
        self.base_name = base_name
        self.data_dir = data_dir
        self.output_dir = output_dir
        self.gwas_data = None
        # Path of the SNP exclusion list; set by process_significant_snps().
        self.significant_snps = None

    def _adjusted_files(self):
        """Return the names of all ``*.adjusted`` files in data_dir, sorted."""
        return sorted(f for f in os.listdir(self.data_dir) if f.endswith('.adjusted'))

    def read_gwas_results(self):
        """Load every GWAS result into ``self.gwas_data`` (columns CHR, SNP, P, BP).

        P is the UNADJ column of the .adjusted file; BP comes from the third
        column of the matching .assoc.logistic file.

        Raises:
            FileNotFoundError: If an .adjusted file has no matching logistic file.
            ValueError: If no .adjusted file exists or a file lacks required columns.
        """
        assoc_files = self._adjusted_files()
        if not assoc_files:
            raise ValueError(f"No .adjusted files found in: {self.data_dir}")

        gwas_data_list = []
        for adjusted_file in assoc_files:
            # Only *.assoc.logistic.adjusted has a companion logistic file;
            # anything else would otherwise be paired with itself.
            if not adjusted_file.endswith(self.ADJUSTED_SUFFIX):
                raise FileNotFoundError(f"Matching logistic file not found for: {adjusted_file}")
            logistic_file = adjusted_file[:-len(self.ADJUSTED_SUFFIX)] + self.LOGISTIC_SUFFIX
            logistic_path = os.path.join(self.data_dir, logistic_file)

            # Check if the logistic file exists
            if not os.path.exists(logistic_path):
                raise FileNotFoundError(f"Matching logistic file not found for: {adjusted_file}")

            # Read the adjusted file (UNADJ values)
            df_adjusted = pd.read_csv(os.path.join(self.data_dir, adjusted_file), sep=r'\s+')
            if not all(col in df_adjusted.columns for col in ['CHR', 'SNP', 'UNADJ']):
                raise ValueError(f"File {adjusted_file} does not have required columns: CHR, SNP, UNADJ")
            df_adjusted = df_adjusted[['CHR', 'SNP', 'UNADJ']].rename(columns={'UNADJ': 'P'})

            # Read the logistic file (BP values)
            df_logistic = pd.read_csv(logistic_path, sep=r'\s+')
            if len(df_logistic.columns) < 3:
                raise ValueError(f"File {logistic_file} does not have at least 3 columns (expecting BP in column 3)")
            df_logistic = df_logistic.iloc[:, :3].copy()
            df_logistic.columns = ['CHR', 'SNP', 'BP']
            # The logistic file has one row per SNP *and* per TEST (ADD plus
            # each covariate). Keep one position per SNP so the merge does
            # not multiply every p-value.
            df_logistic = df_logistic.drop_duplicates(subset=['CHR', 'SNP'])

            # Merge BP into the adjusted file
            df = pd.merge(df_adjusted, df_logistic, on=['CHR', 'SNP'], how='left')

            # Check for missing BP values
            if df['BP'].isna().any():
                print(f"Warning: Some SNPs in {adjusted_file} are missing BP information from {logistic_file}")

            gwas_data_list.append(df)

        # Combine all GWAS results into a single dataframe
        self.gwas_data = pd.concat(gwas_data_list, ignore_index=True)

    def generate_manhattan_plot(self):
        """Draw a Manhattan plot of the loaded results.

        Raises:
            ValueError: If read_gwas_results() has not been called.
        """
        if self.gwas_data is None:
            raise ValueError("GWAS data not loaded. Please call read_gwas_results() first.")

        # Create Manhattan plot using qqman
        figure, axes = plt.subplots(figsize=(16, 6))
        qqman.manhattan(
            self.gwas_data,
            ax=axes,
            suggestiveline=-np.log10(1e-5),  # Threshold that should give someone pause about a variant while still technically fine
            genomewideline=-np.log10(5e-8),  # Genome-wide significance threshold
            cmap=plt.get_cmap("jet"),
            cmap_var=10,
            xrotation=45,   # Rotate x-axis labels by 45 degrees
        )
        plt.show()

    def generate_QQ_plot(self):
        """Draw a QQ plot of the loaded p-values.

        Raises:
            ValueError: If read_gwas_results() has not been called.
        """
        if self.gwas_data is None:
            raise ValueError("GWAS data not loaded. Please call read_gwas_results() first.")
        figure, axes = plt.subplots(figsize=(8, 8))  # Adjust the figure size as needed
        # Generate the QQ plot
        qqman.qqplot(
            self.gwas_data['P'],
            ax=axes,
            marker='o',
            title='QQ Plot',
        )
        plt.show()

    def process_significant_snps(self):
        """Collect SNPs whose BONF value is below the threshold in any GWAS.

        Writes the de-duplicated SNP IDs, one per line, to
        ``<output_dir>/all_significant_snps.txt``, stores that path in
        ``self.significant_snps`` and shows a per-file count table.
        """
        concatenated_snps_file = os.path.join(self.output_dir, 'all_significant_snps.txt')

        # (file name, number of significant SNPs) for each result file
        counts = []
        # dict used as an insertion-ordered set of SNP IDs
        unique_snps = {}

        for adjusted_file in self._adjusted_files():
            df = pd.read_csv(os.path.join(self.data_dir, adjusted_file), sep=r'\s+')
            significant_snps = df.loc[df['BONF'] < self.SIGNIFICANCE_THRESHOLD, 'SNP'].drop_duplicates()
            counts.append((adjusted_file, len(significant_snps)))
            unique_snps.update(dict.fromkeys(significant_snps))

        # Save the concatenated SNPs to a single file
        pd.Series(list(unique_snps), dtype=object).to_csv(concatenated_snps_file, index=False, header=False)

        # Save the concatenated SNPs as a class attribute
        self.significant_snps = concatenated_snps_file

        # Show the counts as a table; fall back to print outside IPython.
        counts_df = pd.DataFrame(counts, columns=['File', 'Significant SNPs'])
        try:
            from IPython.display import display as show
        except ImportError:
            show = print
        show(counts_df)

        print(f"Processing complete. SNPs that need to be removed due to batch effects can be found: {concatenated_snps_file}")

    def remove_batchs(self, base):
        """Write a copy of the dataset without the significant SNPs.

        Args:
            base (str): Output basename, joined onto output_dir (an absolute
                path is used as-is). Becomes the new ``self.base_name``.

        Raises:
            ValueError: If process_significant_snps() has not been called.
            subprocess.CalledProcessError: If PLINK exits non-zero.
        """
        if self.significant_snps is None:
            raise ValueError("Significant SNPs not computed. Please call process_significant_snps() first.")
        Final_base = os.path.join(self.output_dir, base)
        subprocess.run([plink_path, "--bfile", self.base_name, "--exclude", self.significant_snps, "--keep-allele-order", "--make-bed", "--out", Final_base], check=True)
        self.base_name = Final_base

    def cleanup_intermediate_files(self):
        """Delete intermediate files from output_dir, keeping the final dataset.

        WARNING: removes every matching file in output_dir (including all
        .bed/.bim/.fam) whose name does not start with the final basename.
        """
        # Get the base name without the directory path
        final_base_name = os.path.basename(self.base_name)

        patterns_to_delete = [
            '*.bed',
            '*.bim',
            '*.fam',
            '*.log',
            '*update*',
            'GWAS*',
            'id_mapping.txt',
            'merged*',
            '*united*',
        ]

        for pattern in patterns_to_delete:
            for file_path in glob.glob(os.path.join(self.output_dir, pattern)):
                # Directories can match the wildcards too; os.remove would fail on them.
                if not os.path.isfile(file_path):
                    continue
                # Skip files that belong to the final dataset (e.g., final_base.bed/.bim/.fam)
                if os.path.basename(file_path).startswith(final_base_name):
                    continue
                os.remove(file_path)
        print("You have made it through the merger! Happy Hunting!")


In [None]:
# Echo the basename reported by the TwinProcessor again, for use as the first BatchProcessor argument.
# NOTE(review): this repeats an earlier cell; confirm that get_output_files() returns a single basename.
final_base = twins.get_output_files()
print(f"Final base file: {final_base}")

In [None]:
# Arguments are (PLINK basename to clean, directory containing the GWAS_* result files, output directory).
# All three strings are placeholders: replace them with your own paths.
Batch = BatchProcessor('path/to/last/basename','path/to/GWAS/output', 'path/to/output/directory')

# Let us start with the manhattan plot since this will give us the best image of where we are at in terms of variants significantly associated with a particular study.
---
- For the manhattan plot there will be two lines:
  - One showing the line for genome-wide significance
  - A lower "suspicion" line, alerting us to variants approaching significance
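- For both the manhattan plot and QQ plot we will be looking at the unadjusted P-values (PLINK's `UNADJ` column, i.e., not yet corrected for multiple testing; ancestry is already accounted for by the PC covariates). We will handle the multiple-testing correction in the last step of this class, but keep that in mind!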

In [None]:
# Load the GWAS result files first: generate_manhattan_plot() raises ValueError if no data has been read.
Batch.read_gwas_results()
Batch.generate_manhattan_plot()

# Now let us move on to the QQ plot. If there were no significance in our variants, we should expect to see a fairly straight line. If we see a leftward skew, it can be telling us that there are some significant variants to be found!

In [None]:
# Plots the P column loaded by read_gwas_results(); raises ValueError if that has not been called.
Batch.generate_QQ_plot()

# Up to this point we have been looking at the unadjusted P-values, but we didn't do all that processing for nothing! 
# Let us now grab the variants that are *truly* batch effected (i.e., still significant after Bonferroni correction for multiple testing, with ancestry already accounted for by the PC covariates)

In [None]:
# Writes all_significant_snps.txt to the output directory and shows a per-file count of significant SNPs.
Batch.process_significant_snps()

# As you can see, the number of truly batch effected variants is smaller than what the manhattan plot would have suggested! 
- We have made it to the end! Lastly we will screen out these significant variants and clean up the intermediate files, and then you are off to do amazing things!

In [None]:
# NOTE(review): this cell uses a hard-coded, machine-specific absolute path and performs the same
# two steps as the next cell, which uses a placeholder path. Run only one of the two cells.
# cleanup_intermediate_files() deletes files from the output directory, so run it last.
Batch.remove_batchs("/home/belleza/scratch/Tutorial/test/FINAL_FILE_NAME")
Batch.cleanup_intermediate_files()

In [None]:
# Replace the placeholder with your final basename. It is joined onto the BatchProcessor output
# directory with os.path.join, so a relative name lands inside that directory and an absolute path is used as-is.
# cleanup_intermediate_files() then deletes intermediate files from the output directory, keeping the final dataset.
Batch.remove_batchs("path_to_output_directory/FINAL_FILE_NAME")
Batch.cleanup_intermediate_files()

# You did it! You have successfully lifted over, quality controlled, and merged multiple open-access datasets to make your very own reference panel!
# Use it in good health!
---
- If you just can't get enough there are some optional tasks that can be done below!
  - They include:
    - Removing the suffixes we added if you do not want them in the final dataset
    - Changing your underscores to a "." (If you convert your data to a VCF it will conjoin your FID and IID as FID_IID, so if there are underscores in the FID or IID, this can cause some headaches)
    - Converting your data to **V**ariant **C**all **F**ormat (VCF), a single file format commonly used in bioinformatics research!


In [None]:
#Optional Final Processing steps
class Optional:
    """Optional post-processing steps for the final PLINK dataset.

    The steps can be chained: each one reads the dataset at
    ``self.current_base`` and the ID-rewriting steps repoint
    ``current_base`` at the dataset they just wrote. All steps call PLINK
    through the module-level ``plink_path``, defined elsewhere in the notebook.
    """

    def __init__(self, base_name, output_dir):
        """
        Args:
            base_name (str): Prefix of the PLINK dataset to start from.
            output_dir (str): Directory in which outputs are written.
        """
        self.base_name = base_name
        self.output_dir = output_dir
        self.current_base = base_name

    @staticmethod
    def _strip_suffix(fid, iid):
        """Return the IDs with everything after the final '.' of the IID removed."""
        return fid, iid.rsplit('.', 1)[0]

    @staticmethod
    def _dots_for_underscores(fid, iid):
        """Return the IDs with every '_' replaced by '.'."""
        return fid.replace('_', '.'), iid.replace('_', '.')

    @staticmethod
    def _write_id_mapping(fam_file, mapping_file, rename):
        """Write "old_FID old_IID new_FID new_IID" rows for every sample in fam_file.

        ``rename`` is called as ``rename(fid, iid)`` and returns the new pair.
        """
        with open(fam_file, 'r') as infile, open(mapping_file, 'w') as outfile:
            for line in infile:
                fields = line.split()
                if len(fields) < 2:
                    continue  # blank or malformed line
                fid, iid = fields[:2]
                new_fid, new_iid = rename(fid, iid)
                outfile.write(f"{fid} {iid} {new_fid} {new_iid}\n")

    def _update_ids(self, base, rename):
        """Rewrite sample IDs of the current dataset and make the result current."""
        Final_base = os.path.join(self.output_dir, base)
        id_mapping_file = os.path.join(self.output_dir, "id_mapping.txt")
        try:
            # Build the mapping from the dataset PLINK will actually read, so
            # the old IDs match even after a previous step has renamed samples.
            self._write_id_mapping(f"{self.current_base}.fam", id_mapping_file, rename)
            subprocess.run([plink_path, '--bfile', self.current_base, '--update-ids', id_mapping_file, '--make-bed', '--keep-allele-order', '--out', Final_base], check=True)
        finally:
            # The mapping is only an intermediate; never leave it behind.
            if os.path.exists(id_mapping_file):
                os.remove(id_mapping_file)
        self.current_base = Final_base

    #Remove the suffix we added (Will not work if you have not deleted duplicate samples [tied_twins])
    def Suffix_remover(self, base):
        """Strip everything after the final '.' from each IID.

        Args:
            base (str): Output basename, joined onto output_dir.

        Raises:
            subprocess.CalledProcessError: If PLINK exits non-zero.
        """
        self._update_ids(base, self._strip_suffix)

    #Change all _ to "."
    def Underscore_Remover(self, base):
        """Replace every '_' in FID and IID with '.'.

        Args:
            base (str): Output basename, joined onto output_dir.

        Raises:
            subprocess.CalledProcessError: If PLINK exits non-zero.
        """
        self._update_ids(base, self._dots_for_underscores)

    #Convert to VCF
    def VCFMaker(self, base):
        """Export the current dataset as ``<output_dir>/<base>.vcf``.

        ``current_base`` is left untouched: the export produces no
        .bed/.bim/.fam, so repointing it would break any later step.

        Raises:
            subprocess.CalledProcessError: If PLINK exits non-zero.
        """
        Final_base = os.path.join(self.output_dir, base)
        subprocess.run([plink_path, '--bfile', self.current_base, '--recode', 'vcf', '--out', Final_base], check=True)
        

In [None]:
# Arguments are (final PLINK basename, output directory); replace both placeholders.
WrapUp = Optional("Path/to/final/basename", "Path/to/output/directory")

In [None]:
#Suffix Remover
# The argument is the output basename; it is joined onto the output directory given to Optional(...).
WrapUp.Suffix_remover("path/to/ouput/basename") #If you want the name to be the same, input the Path/to/final/basename that you plugged in when setting up WrapUp

In [None]:
#Underscore Removal
# The argument is the output basename; it is joined onto the output directory given to Optional(...).
WrapUp.Underscore_Remover("path/to/ouput/basename") #If you want the name to be the same, input the Path/to/final/basename that you plugged in when setting up WrapUp

In [None]:
#Convert to VCF
# The argument is the output basename; it is joined onto the output directory given to Optional(...).
WrapUp.VCFMaker("path/to/ouput/basename") #If you want the name to be the same, input the Path/to/final/basename that you plugged in when setting up WrapUp