In [9]:
import os
import numpy as np
import math
import re
import random
import shutil
import gzip
import pandas as pd
from scipy.special import softmax
from sklearn.model_selection import train_test_split
from sklearn.utils import resample
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from tensorflow.keras import layers
from tensorflow.keras import regularizers
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tensorflow_addons as tfa
from sklearn import metrics
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import constraints
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras.applications import efficientnet as efn
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from tensorflow.keras.constraints import Constraint
from scipy.spatial.distance import squareform
%matplotlib inline
from toolz import interleave
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.linear_model import LassoCV, ElasticNetCV
from sklearn.model_selection import KFold,StratifiedKFold

print(f"Tensorflow version {tf.__version__}")

Tensorflow version 2.10.1


In [2]:
# Pick a tf.distribute strategy for the available hardware.
# TPUClusterResolver() needs no arguments when the TPU_NAME environment variable is set (always
# the case on Kaggle); this cell treats a ValueError from it as "no TPU available".
try:
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', TPU.master())
except ValueError:
    print('Running on GPU')
    TPU = None

if not TPU:
    # Default strategy: works on CPU and on a single GPU.
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    strategy = tf.distribute.TPUStrategy(TPU)

# Number of computing cores that run in sync (8 for a TPU v3-8).
N_REPLICAS = strategy.num_replicas_in_sync
print(f'N_REPLICAS: {N_REPLICAS}')

Running on GPU
N_REPLICAS: 1


![ploidy support](./assets/ploidy.jpg)

In [35]:
class DataReader:
    """
    Load a reference (training) panel and a target (test) panel and encode them as integer arrays.

    If the reference is unphased, it cannot handle phased target data, so the valid (ref, target) combinations are:
    (phased, phased), (phased, unphased), (unphased, unphased)
    If the reference is haps, the target cannot be unphased (can we merge every two haps to form unphased diploids?)
    Important note: for each case, the model should be trained separately.

    Inputs may be VCF files or delimiter-separated ("star-sv") CSV/TSV tables, optionally gzip-compressed (.gz).
    """
    def __init__(self):
        self.allele_count = 2
        self.genotype_vals = None
        self.ref_is_phased = None
        self.reference_panel = None
        self.target_set = None  # populated by assign_test_set
        self.VARIANT_COUNT = 0
        self.is_phased = False
        self.MISSING_VALUE = None
        self.ref_is_hap = False
        self.target_is_hap = False
        self.hap_map = {}  # allele string -> integer code; filled by assign_training_set in phased mode
        self.ref_n_header_lines = []
        self.ref_n_data_header = ""
        self.target_n_header_lines = []
        self.target_n_data_header = ""
        # Element-wise mappers from a phased genotype string ("a|b") to the code of one parent's allele.
        self.map_values_1_vec = np.vectorize(self.map_hap_2_ind_parent_1)
        self.map_values_2_vec = np.vectorize(self.map_hap_2_ind_parent_2)
        self.delimiter_dictionary = {"vcf": "\t", "csv": ",", "tsv": "\t", "infer": "\t"}
        self.training_file_extension = "vcf"
        self.test_file_extension = "vcf"
        self.target_is_phased = True

    @staticmethod
    def _first_sample_column(file_extension):
        """Positional index of the first sample column for a file loaded by this class."""
        # VCF: the 9 fixed columns #CHROM..FORMAT precede the samples.
        # CSV/TSV: after reset_index/rename the table has a single leading "ID" column.
        return 9 if file_extension == "vcf" else 1

    @staticmethod
    def _is_valid_range(starting_var_index, ending_var_index):
        """True when both indices are given and 0 <= start <= end."""
        if starting_var_index is None or ending_var_index is None:
            return False
        return 0 <= starting_var_index <= ending_var_index

    def read_star_sv_files(self, file_path, is_reference=False, separator="\t", first_column_is_index=True, comments="##") -> pd.DataFrame:
        """
        Read a VCF or a delimiter-separated ("star-sv") table into a DataFrame of categorical columns.

        The leading lines that start with ``comments`` are appended to ``ref_n_header_lines`` or
        ``target_n_header_lines`` (depending on ``is_reference``). The first non-comment line is the column header.
        In this form the data should not have more than one column for ids. The first column can be either
        sample ids or variant ids. In case of the latter, make sure to pass :param variants_as_columns=True
        to the assign_* methods. Example of sample input file:
        ## Comment line 0
        ## Comment line 1
        Sample_id 17392_chrI_17400_T_G ....
        HG1023               1
        HG1024               0

        :param file_path: path of the file; a ".gz" suffix means gzip-compressed.
        :param is_reference: store the comment/header lines on the reference (True) or target (False) attributes.
        :param separator: the column delimiter.
        :param first_column_is_index: use the first column as the DataFrame index.
        :param comments: prefix of the comment lines at the top of the file.
        :return: DataFrame whose columns all have dtype ``category``.
        :raises IOError: if no header line follows the comments, or the file is a zip archive.
        """
        print("Reading the file...")
        root, ext = os.path.splitext(file_path)
        if ext == ".zip":
            # gzip.open cannot read zip archives; fail with a clear message instead of BadGzipFile.
            raise IOError("Zip archives are not supported; please provide a gzip (.gz) or an uncompressed file.")
        header_lines = self.ref_n_header_lines if is_reference else self.target_n_header_lines
        data_header = None
        n_consumed_lines = 0
        with gzip.open(file_path, 'rt') if ext == '.gz' else open(file_path, 'rt') as f_in:
            # Iterating (rather than calling readline() forever) stops at EOF, so an empty or
            # comments-only file leaves data_header as None.
            for line in f_in:
                n_consumed_lines += 1
                if line.startswith(comments):
                    header_lines.append(line)
                else:
                    data_header = line
                    break
        if data_header is None:
            raise IOError("The file only contains comments!")
        if is_reference:
            self.ref_n_data_header = data_header
        else:
            self.target_n_data_header = data_header
        # Skip the comment lines AND the header line explicitly. Relying on ``comment=`` would re-read a
        # header that does not start with the comment character (CSV/TSV) as a data row, and would
        # truncate any field containing that character.
        df = pd.read_csv(file_path,
                         sep=separator,
                         skiprows=n_consumed_lines,
                         header=None,
                         index_col=0 if first_column_is_index else None,
                         dtype='category',
                         names=data_header.strip().split(separator))
        return df

    def find_file_extension(self, file_path, file_format, delimiter):
        """
        Resolve the file format and the column separator to use for ``file_path``.

        :param file_path: input path. With ``file_format="infer"``, its dot-separated tokens are scanned from
            right to left for "vcf", "csv" or "tsv" (e.g. "x.vcf.gz" -> "vcf").
        :param file_format: one of {"vcf", "csv", "tsv", "infer"}.
        :param delimiter: explicit separator. When None, the conventional separator of the resolved format is used.
        :return: tuple (file_format, separator). Falls back to "vcf" when inference finds no known token.
        :raises ValueError: if ``file_format`` is not supported.
        """
        if file_format not in {"vcf", "csv", "tsv", "infer"}:
            raise ValueError("File extension must be one of {'vcf', 'csv', 'tsv', 'infer'}.")
        found_file_format = "vcf"  # default assumption when nothing can be inferred
        if file_format == 'infer':
            for possible_extension in reversed(file_path.split(".")):
                if possible_extension in {"vcf", "csv", "tsv"}:
                    found_file_format = possible_extension
                    break
        else:
            found_file_format = file_format
        # An explicit delimiter always wins; otherwise use the format's conventional separator.
        separator = delimiter if delimiter is not None else self.delimiter_dictionary[found_file_format]
        return found_file_format, separator

    def assign_training_set(self, file_path,
                            target_is_gonna_be_phased_or_haps,
                            variants_as_columns=False,
                            delimiter=None,
                            file_format="infer",
                            first_column_is_index=True,
                            comments="##") -> None:
        """
        Load the reference panel and derive the genotype encoding used by the model.

        :param file_path: reference panel or the training file path. Currently, VCF, CSV, and TSV are supported
        :param target_is_gonna_be_phased_or_haps: Indicates whether the targets for the imputation will be phased (or haploid) rather than unphased.
        :param variants_as_columns: Whether the columns are variants and rows are samples or vice versa (CSV/TSV only).
        :param delimiter: the separator used for the file; defaults to the conventional one for the format.
        :param file_format: one of {"vcf", "csv", "tsv", "infer"}. If "infer" then the class will try to find the extension using the file name.
        :param first_column_is_index: used for csv and tsv files to indicate if the first column should be used as identifier for samples/variants.
        :param comments: The token to be used to filter out the lines indicating comments (CSV/TSV only; VCF uses "##").
        :return: None
        :raises ValueError: if the reference ploidy/phasing cannot serve the requested kind of target.
        """
        self.target_is_phased = target_is_gonna_be_phased_or_haps
        self.training_file_extension, separator = self.find_file_extension(file_path, file_format, delimiter)

        if self.training_file_extension == 'vcf':
            self.reference_panel = self.read_star_sv_files(file_path, is_reference=True, separator='\t', first_column_is_index=False, comments="##")
        else:
            self.reference_panel = self.read_star_sv_files(file_path, is_reference=True, separator=separator, first_column_is_index=first_column_is_index, comments=comments)
            if variants_as_columns:
                self.reference_panel = self.reference_panel.transpose()
            self.reference_panel.reset_index(inplace=True)
            self.reference_panel.rename(columns={self.reference_panel.columns[0]: "ID"}, inplace=True)

        first_sample_col = self._first_sample_column(self.training_file_extension)
        # Infer ploidy/phasing from the first genotype of the first sample.
        first_genotype = self.reference_panel.iloc[0, first_sample_col]
        self.ref_is_hap = not ("|" in first_genotype or "/" in first_genotype)
        self.ref_is_phased = "|" in first_genotype or self.ref_is_hap
        ## For now I won't support merging haploids into unphased data
        if self.ref_is_hap and not target_is_gonna_be_phased_or_haps:
            raise ValueError("The reference contains haploids while the target will be unphased diploids. The model cannot predict the target at this rate.")

        if not self.ref_is_phased and target_is_gonna_be_phased_or_haps:
            raise ValueError("The reference contains unphased diploids while the target will be phased or haploid data. The model cannot predict the target at this rate.")

        self.VARIANT_COUNT = self.reference_panel.shape[0]
        print(f"{self.VARIANT_COUNT} {'haplotype' if self.ref_is_hap else 'diplotype'} variants found!")

        self.is_phased = target_is_gonna_be_phased_or_haps and self.ref_is_phased

        def get_num_alleles(genotype):
            # Largest allele index in an "a|b" or "a/b" genotype, plus one.
            return max(int(a) for a in genotype.replace("|", "/").split("/")) + 1

        genotype_vals = np.unique(self.reference_panel.iloc[:, first_sample_col:].values)
        if self.ref_is_phased and not target_is_gonna_be_phased_or_haps:  # In this case ref is not haps
            # Collapse phased genotypes ("1|0") into sorted unphased ones ("0/1") so the panel uses the
            # same notation as unphased targets and the "./." missing key.
            phased_to_unphased_dict = {}
            for i in range(genotype_vals.shape[0]):
                key = genotype_vals[i]
                v1, v2 = [int(s) for s in key.split("|")]
                genotype_vals[i] = f"{min(v1, v2)}/{max(v1, v2)}"
                phased_to_unphased_dict[key] = genotype_vals[i]
            sample_part = (self.reference_panel.iloc[:, first_sample_col:]
                           .astype(object)
                           .replace(phased_to_unphased_dict)
                           .astype('category'))
            # Rebuild the frame: replace(..., inplace=True) on an iloc slice only modifies a temporary copy.
            self.reference_panel = pd.concat([self.reference_panel.iloc[:, :first_sample_col], sample_part], axis=1)

        self.genotype_vals = np.unique(genotype_vals)

        self.allele_count = max(map(get_num_alleles, self.genotype_vals)) if not self.ref_is_hap else len(self.genotype_vals)
        self.MISSING_VALUE = self.allele_count if self.is_phased else len(self.genotype_vals)

        if self.is_phased:
            self.hap_map = {str(i): i for i in range(self.allele_count)}
            self.hap_map.update({".": self.allele_count})
            self.r_hap_map = {i: k for k, i in self.hap_map.items()}
            self.map_preds_2_allele = np.vectorize(lambda x: self.r_hap_map[x])

        self.SEQ_DEPTH = self.allele_count + 1

    def assign_test_set(self, file_path,
                        variants_as_columns=False,
                        delimiter=None,
                        file_format="infer",
                        first_column_is_index=True,
                        comments="##") -> None:
        """
        Load the target (test) panel into ``self.target_set`` and build the genotype -> code dictionaries.

        :param file_path: target file path. Currently, VCF, CSV, and TSV are supported
        :param variants_as_columns: Whether the columns are variants and rows are samples or vice versa (CSV/TSV only).
        :param delimiter: the separator used for the file; defaults to the conventional one for the format.
        :param file_format: one of {"vcf", "csv", "tsv", "infer"}. If "infer" then the class will try to find the extension using the file name.
        :param first_column_is_index: used for csv and tsv files to indicate if the first column should be used as identifier for samples/variants.
        :param comments: The token to be used to filter out the lines indicating comments (CSV/TSV only; VCF uses "##").
        :return: None
        :raises RuntimeError: if no training set was assigned, or the target is phased/haploid while the reference is unphased.
        """
        if self.reference_panel is None:
            raise RuntimeError("First you need to use 'DataReader.assign_training_set(...) to assign a training set.' ")

        self.test_file_extension, separator = self.find_file_extension(file_path, file_format, delimiter)

        # Branch on the TARGET's own format, not the reference's.
        if self.test_file_extension == 'vcf':
            test_df = self.read_star_sv_files(file_path, is_reference=False, separator='\t', first_column_is_index=False, comments="##")
        else:
            test_df = self.read_star_sv_files(file_path, is_reference=False, separator=separator, first_column_is_index=first_column_is_index, comments=comments)
            if variants_as_columns:
                test_df = test_df.transpose()
            test_df.reset_index(inplace=True)
            test_df.rename(columns={test_df.columns[0]: "ID"}, inplace=True)

        first_genotype = test_df.iloc[0, self._first_sample_column(self.test_file_extension)]
        is_hap = not ("|" in first_genotype or "/" in first_genotype)
        is_phased = "|" in first_genotype
        if (is_hap or is_phased) and not self.ref_is_phased:
            raise RuntimeError("The target contains phased or haploid data while the reference is unphased; the model cannot predict the target.")
        self.target_is_hap = is_hap
        self.target_set = test_df

        ALLELE_SEP = "|" if self.ref_is_phased else "/"

        def key_gen(v1, v2):
            return f"{v1}{ALLELE_SEP}{v2}"

        self.genotype_keys = np.array([key_gen(i, j) for i in range(self.allele_count) for j in range(self.allele_count)]) if self.is_phased else self.genotype_vals
        self.genotype_keys = np.hstack([self.genotype_keys, [".|."] if self.is_phased else ["./."]])
        self.replacement_dict = {g: i for i, g in enumerate(self.genotype_keys)}
        self.reverse_replacement_dict = {i: g for g, i in self.replacement_dict.items()}

    def map_hap_2_ind_parent_1(self, x):
        """Code of the first allele of a phased genotype string such as "0|1"."""
        return self.hap_map[x.split('|')[0]]

    def map_hap_2_ind_parent_2(self, x):
        """Code of the second allele of a phased genotype string such as "0|1"."""
        return self.hap_map[x.split('|')[1]]

    def __get_forward_data(self, data: pd.DataFrame):
        """
        Encode a (variants x samples) block of genotype strings as int32.

        Phased mode returns (2 * n_samples, n_variants): rows 2*i and 2*i+1 are the two haplotypes of sample i.
        Unphased mode returns (n_samples, n_variants) via ``replacement_dict``.
        """
        if self.is_phased:
            # break it into haplotypes
            # NOTE(review): map_hap_2_ind_parent_2 indexes x.split('|')[1], so haploid entries such as "0"
            # would raise IndexError here — confirm how haploid panels are meant to be encoded.
            _x = np.empty((data.shape[1] * 2, data.shape[0]), dtype=np.int32)

            _x[0::2] = self.map_values_1_vec(data.values.T)
            _x[1::2] = self.map_values_2_vec(data.values.T)
            return _x
        else:
            # Cast to object so replace() maps plain values rather than categories.
            return data.astype(object).replace(self.replacement_dict).values.T.astype(np.int32)

    def get_ref_set(self, starting_var_index=None, ending_var_index=None):
        """Encoded reference panel for variants [start, end), or the whole panel if the range is missing/invalid."""
        first_sample_col = self._first_sample_column(self.training_file_extension)
        if self._is_valid_range(starting_var_index, ending_var_index):
            return self.__get_forward_data(self.reference_panel.iloc[starting_var_index:ending_var_index, first_sample_col:])
        print("No variant indices provided or indices not valid, using the whole sequence...")
        return self.__get_forward_data(self.reference_panel.iloc[:, first_sample_col:])

    def get_target_set(self, starting_var_index=None, ending_var_index=None):
        """Encoded target panel for variants [start, end), or the whole panel if the range is missing/invalid."""
        if self.target_set is None:
            raise RuntimeError("First you need to use 'DataReader.assign_test_set(...)' to assign a test set.")
        first_sample_col = self._first_sample_column(self.test_file_extension)
        if self._is_valid_range(starting_var_index, ending_var_index):
            return self.__get_forward_data(self.target_set.iloc[starting_var_index:ending_var_index, first_sample_col:])
        print("No variant indices provided or indices not valid, using the whole sequence...")
        return self.__get_forward_data(self.target_set.iloc[:, first_sample_col:])

    def convert_haps_to_genotypes(self, allele_probs):
        '''
        Convert per-haplotype allele scores into a VCF-shaped DataFrame with "GT:DS:GP" entries.

        :param allele_probs: array (n_haploids, n_variants, n_alleles); rows 2*i and 2*i+1 are the two haplotypes
            of sample i. Scores are softmax-normalized over the last axis. Bi-allelic only.
        :return: copy of ``target_set`` (object dtype) whose sample columns hold the imputed entries.
        :raises ValueError: if n_haploids is odd.
        '''
        FORMAT = "GT:DS:GP"
        n_haploids, n_variants, n_alleles = allele_probs.shape
        if n_haploids % 2 != 0:
            raise ValueError("Number of haploids should be even.")
        allele_probs_normalized = softmax(allele_probs, axis=-1)

        n_samples = n_haploids // 2
        genotypes = np.zeros((n_samples, n_variants), dtype=object)

        for i in tqdm(range(n_samples)):
            haploid_1 = allele_probs_normalized[2 * i]
            haploid_2 = allele_probs_normalized[2 * i + 1]

            for j in range(n_variants):
                # Joint (hap1 allele, hap2 allele) probabilities; for bi-allelic sites the flattened order is
                # [P(0,0), P(0,1), P(1,0), P(1,1)], collapsed to unphased [P(0/0), P(0/1), P(1/1)].
                phased_probs = np.multiply.outer(haploid_1[j], haploid_2[j]).flatten()
                unphased_probs = np.array([phased_probs[0], sum(phased_probs[1:3]), phased_probs[-1]])
                unphased_probs_str = ",".join([f"{v:.6f}" for v in unphased_probs])
                alt_dosage = np.dot(unphased_probs, [0, 1, 2])
                variant_genotypes = [str(np.argmax(haploid_1[j])), str(np.argmax(haploid_2[j]))]
                genotypes[i, j] = '|'.join(variant_genotypes) + f":{alt_dosage:.3f}:{unphased_probs_str}"

        # Cast to object first: assigning new strings into categorical columns raises TypeError.
        new_vcf = self.target_set.astype(object)
        new_vcf.iloc[:n_variants, self._first_sample_column(self.test_file_extension):] = genotypes.T
        new_vcf["FORMAT"] = FORMAT
        new_vcf["QUAL"] = "."
        new_vcf["FILTER"] = "."
        new_vcf["INFO"] = "IMPUTED"
        return new_vcf

    def convert_unphased_probs_to_genotypes(self, allele_probs):
        '''
        Convert per-sample unphased genotype scores into a VCF-shaped DataFrame with "GT:DS:GP" entries.

        :param allele_probs: array (n_samples, n_variants, n_classes); softmax-normalized over the last axis.
            The dosage computation assumes 3 classes (0/0, 0/1, 1/1), i.e. bi-allelic data.
        :return: copy of ``target_set`` (object dtype) whose sample columns hold the imputed entries.
        '''
        FORMAT = "GT:DS:GP"
        n_samples, n_variants, n_alleles = allele_probs.shape
        allele_probs_normalized = softmax(allele_probs, axis=-1)
        genotypes = np.zeros((n_samples, n_variants), dtype=object)

        for i in tqdm(range(n_samples)):
            for j in range(n_variants):
                unphased_probs = allele_probs_normalized[i, j]
                unphased_probs_str = ",".join([f"{v:.6f}" for v in unphased_probs])
                alt_dosage = np.dot(unphased_probs, [0, 1, 2])
                genotype_call = self.reverse_replacement_dict[int(np.argmax(unphased_probs))]
                # Field order must follow FORMAT: GT, then DS, then GP.
                genotypes[i, j] = genotype_call + f":{alt_dosage:.3f}:{unphased_probs_str}"

        # Cast to object first: assigning new strings into categorical columns raises TypeError.
        new_vcf = self.target_set.astype(object)
        new_vcf.iloc[:n_variants, self._first_sample_column(self.test_file_extension):] = genotypes.T
        new_vcf["FORMAT"] = FORMAT
        new_vcf["QUAL"] = "."
        new_vcf["FILTER"] = "."
        new_vcf["INFO"] = "IMPUTED"
        return new_vcf

    def __get_headers_for_output(self):
        """VCF meta-information lines written before the imputed table."""
        headers = ["##fileformat=VCFv4.2",
                   '''##source=STI v1.0.0''',
                   '''##INFO=<ID=IMPUTED,Number=0,Type=Flag,Description="Marker was imputed">''',
                   '''##FORMAT=<ID=GT,Number=1,Type=String,Description="Genotype">''',
                   '''##FORMAT=<ID=DS,Number=A,Type=Float,Description="Estimated Alternate Allele Dosage : [P(0/1)+2*P(1/1)]">''',
                   '''##FORMAT=<ID=GP,Number=G,Type=Float,Description="Estimated Posterior Probabilities for Genotypes 0/0, 0/1 and 1/1">''']
        return headers

    def preds_to_genotypes(self, preds):
        """
        WARNING: This only supports bi-allelic data right now!
        :param preds: numpy array of (n_haploids or n_samples, n_variants, n_alleles) model scores
        :return: DataFrame shaped like the target VCF, with "GT:DS:GP" entries in the sample columns
        """
        if self.is_phased:
            return self.convert_haps_to_genotypes(preds)
        else:
            return self.convert_unphased_probs_to_genotypes(preds)

    def write_ligated_results_to_vcf(self, df, file_name):
        """
        Write ``df`` as a VCF: the meta-information header lines followed by the tab-separated table.

        A ".gz" suffix gzip-compresses the output (pandas infers the compression for the appended table).
        """
        with gzip.open(file_name, 'wt') if file_name.endswith(".gz") else open(file_name, 'wt') as f_out:
            # write info
            f_out.write("\n".join(self.__get_headers_for_output()) + "\n")
        df.to_csv(file_name, sep="\t", mode='a', index=False)

In [36]:
# Load the chr22 deletion benchmark VCF as a table of categorical columns.
df = DataReader().read_star_sv_files(
    "./data/STI_benchmark_datasets/DELL.chr22.genotypes.full.vcf.gz",
    is_reference=False,
    separator="\t",
    first_column_is_index=False,
    comments="##",
)
df

Reading the file...


Unnamed: 0,#CHROM,POS,ID,REF,ALT,QUAL,FILTER,INFO,FORMAT,HG00096,...,NA21128,NA21129,NA21130,NA21133,NA21135,NA21137,NA21141,NA21142,NA21143,NA21144
0,22,16533236,SI_BD_17525,C,<CN0>,100,PASS,AC=125;AF=0.0249601;AFR_AF=0.09;AMR_AF=0.0086;...,GT,0|0,...,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0
1,22,16577743,YL_CN_CEU_5170,T,<CN0>,100,PASS,AC=29;AF=0.00579073;AFR_AF=0.0098;AMR_AF=0.001...,GT,0|0,...,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0
2,22,16589908,SI_BD_17528,T,<CN0>,100,PASS,AC=186;AF=0.0371406;AFR_AF=0.1021;AMR_AF=0.014...,GT,0|0,...,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0
3,22,16633635,YL_CN_STU_4360,G,<CN0>,100,PASS,AC=2;AF=0.00039936;AFR_AF=0;AMR_AF=0;AN=5008;C...,GT,0|0,...,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0
4,22,16940402,BI_GS_DEL1_B2_P2862_55,A,<CN0>,100,PASS,AC=2;AF=0.00039936;AFR_AF=0;AMR_AF=0;AN=5008;C...,GT,0|0,...,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
568,22,50808793,BI_GS_DEL1_B5_P2896_582,T,<CN0>,100,PASS,AC=3;AF=0.00059904;AFR_AF=0.0023;AMR_AF=0;AN=5...,GT,0|0,...,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0
569,22,50975825,BI_GS_DEL1_B2_P2896_693,A,<CN0>,100,PASS,AC=118;AF=0.0235623;AFR_AF=0.087;AMR_AF=0.0043...,GT,0|0,...,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0
570,22,51054942,UW_VH_22595,T,<CN0>,100,PASS,AC=1;AF=0.00019968;AFR_AF=0;AMR_AF=0;AN=5008;C...,GT,0|0,...,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0
571,22,51163690,BI_GS_DEL1_B2_P2897_127,C,<CN0>,100,PASS,AC=1;AF=0.00019968;AFR_AF=0;AMR_AF=0;AN=5008;C...,GT,0|0,...,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0,0|0


In [34]:
# Deep memory footprint of `df` in GiB (index included; the contents of object/categorical
# columns are counted because deep=True).
# NOTE: the saved output of this cell (In [34]) was produced before In [36] re-loaded `df` with
# categorical dtypes. On Restart & Run All it reports the post-categorical size, like the next cell.
df.memory_usage(deep=True).sum() / 1024**3

0.08057642728090286

In [37]:
from sys import getsizeof

# Size reported by sys.getsizeof(df), converted from bytes to GiB.
getsizeof(df) / 1024 ** 3

0.002566937357187271

In [39]:
# Python type of a single genotype cell (row 0, positional column 10).
type(df.iat[0, 10])

str