In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import io
import os

In [None]:

def read_vcf(path):
    '''
    Load the tabular part of a VCF file into a pandas dataframe.

    input: path to vcf file
    returns: pandas df built from the column header line and the records;
             the "##" meta-information lines are ignored and the "#CHROM"
             column is renamed to "CHROM"
    #partly taken from https://gist.github.com/dceoy/99d976a2c01e7f0ba1c813778f9db744, 20.12.2021
    '''
    # Explicit types for the eight fixed columns; remaining columns are inferred.
    column_types = {
        '#CHROM': str,
        'POS': int,
        'ID': str,
        'REF': str,
        'ALT': str,
        'QUAL': float,
        'FILTER': str,
        'INFO': str,
    }

    with open(path, 'r') as handle:
        # Keep the "#CHROM ..." line: it becomes the dataframe header.
        table_text = ''.join(row for row in handle if not row.startswith('##'))

    table = pd.read_csv(io.StringIO(table_text), sep='\t', dtype=column_types)
    return table.rename(columns={'#CHROM': 'CHROM'})

In [None]:
def read_INFO(df):
    '''
    Expand the INFO column of a VCF dataframe into one column per field.

    input: pandas vcf dataframe with a string column "INFO" whose values are
           ";"-separated entries such as "KEY=value"
    returns: dataframe of strings with one column per INFO position. Columns
             are named after the keys found in the first retained row and each
             cell holds the text after the last "=" (entries without "=" are
             kept unchanged). Rows that have fewer entries than the longest
             INFO string, or no INFO at all, are dropped; the original index
             labels of the remaining rows are preserved. An empty dataframe is
             returned when no row remains.

    Note: entries are matched by position, not by key, so the retained rows
    are assumed to list the same keys in the same order.
    '''
    # expand=True pads shorter rows with missing values, which dropna removes.
    fields = df["INFO"].str.split(";", expand=True).dropna()
    if fields.empty:
        # Nothing to name the columns from; avoid an IndexError on iloc[0].
        return fields

    # Take the keys positionally from the first remaining row. Using .iloc and
    # the .str accessor keeps this independent of the row's index label, which
    # is no longer 0 when the first record was dropped above.
    keys = fields.iloc[0].str.split("=").str[0]

    # Column-wise string ops replace the deprecated DataFrame.applymap.
    values = fields.apply(lambda col: col.str.split("=").str[-1])
    values.columns = keys.tolist()
    return values

In [None]:
def number_of_heterozygots(df, column):
    '''
    Count how often each genotype occurs for one sample.

    input: pandas vcf dataframe, name of the sample column
    returns: value counts of the first ":"-separated entry of the sample
             column (e.g. 0/0, 0/1, ...); missing entries are not counted
    '''
    # Split every sample cell on ":" and keep only the leading piece.
    sample_parts = df[column].str.split(":", expand=True)
    genotypes = sample_parts[0].dropna()
    return genotypes.value_counts()

In [None]:
def numeric_INFO(INFO, columns=None):
    '''
    Convert INFO annotation columns from strings to numbers, in place.

    input:
        INFO - dataframe with one column per INFO annotation
        columns - optional name or iterable of names of the columns to
                  convert; when omitted the ten annotations listed in the
                  function body (MQ, AN, BaseQRankSum, ...) are converted
    returns: None, the dataframe passed in is modified
    raises:
        KeyError - if any requested column is absent; checked before any
                   conversion so the dataframe is left untouched
        ValueError - from pd.to_numeric if a value cannot be parsed; columns
                     converted before the failing one stay converted
    '''
    if columns is None:
        columns = ('MQ', 'AN', 'BaseQRankSum', 'ClippingRankSum', 'ExcessHet',
                   'FS', 'MQRankSum', 'QD', 'ReadPosRankSum', 'SOR')
    elif isinstance(columns, str):
        # A bare string would otherwise be iterated character by character.
        columns = (columns,)
    columns = list(columns)

    missing = [name for name in columns if name not in INFO.columns]
    if missing:
        raise KeyError(f"INFO columns not found: {missing}")

    for name in columns:
        INFO[name] = pd.to_numeric(INFO[name])
    

In [None]:
def plot_dist(df, column, threshold):
    '''
    Plot the distribution of the metric with the threshold

    input:
        df - dataframe that contains the metric
        column - name of the column to plot
        threshold - x position at which a red vertical line is drawn
    returns: None, the figure is created through seaborn

    Default thresholds GATK hard filtering (GATK 4.0 VariantFiltration):
        For indels:
            QD < 2.0
            QUAL < 30.0
            FS > 200.0
            ReadPosRankSum < -20.0
            
        For SNPs:
            QD < 2.0
            QUAL < 30.0
            SOR > 3.0
            FS > 60.0
            MQ < 40.0
            MQRankSum < -12.5
            ReadPosRankSum < -8.0

    '''
    # Long-form call: the dataframe is the data source and `column` names the
    # variable, so the x axis is labelled with the metric name.
    grid = sns.displot(data=df, x=column, kind="kde")

    # Draw the threshold on the axes that displot created instead of relying
    # on matplotlib's notion of the "current" axes.
    for ax in grid.axes.flat:
        ax.axvline(threshold, color='red')

In [None]:
def change_vcf_sample_names(old_vcf_path, new_vcf_path, new_names):
    """
    Change the names of samples in vcf: the last columns of the #CHROM header
    line are renamed to the ones given in the names file; every other line is
    copied unchanged.

    input:
        old_vcf_path - path to the vcf file, which samples need to be renamed
        new_vcf_path - path to the new vcf file with renamed samples
        new_names - path to the file with new names, one name per line
                    (surrounding whitespace is stripped, blank lines ignored)

    return: old name samples (list of the header entries that were replaced)

    raises:
        ValueError - if the names file contains no names
        IndexError - if there are at least as many new names as columns in
                     the #CHROM header line

    A warning is printed when the column in front of the renamed ones is not
    FORMAT, i.e. when the number of names differs from the number of samples.

    ! UPD: similar thing can be done in shell by "bcftools reheader -s new_names.list -o new_vcf_path old_vcf_path"

    """
    with open(new_names, 'r') as sample_names:
        stripped = (line.strip() for line in sample_names)
        samples = [name for name in stripped if name]
    if not samples:
        # With zero names the slices below would cover the whole header line.
        raise ValueError('FILE WITH NEW NAMES IS EMPTY!')

    n_samples = len(samples)
    old_names = []
    with open(old_vcf_path, 'r') as f, open(new_vcf_path, 'w') as new_vcf:
        for line in f:
            if line.startswith('#CHROM'):
                # Strip the line ending so the last sample name is clean.
                headers = line.rstrip('\r\n').split('\t')
                if headers[-(n_samples + 1)] != "FORMAT":
                    print('MAKE SURE THAT NUMBER OF NAMES IS EQUAL TO NUMBER OF SAMPLES IN THE VCF!')
                old_names = headers[-n_samples:]
                headers[-n_samples:] = samples
                new_vcf.write('\t'.join(headers) + '\n')
            else:
                # Lines already end with a newline; write them verbatim.
                new_vcf.write(line)
    return old_names

In [None]:
def clean_vcf(path_in, path_out):
    '''
    Delete all the rows that did not pass one of the filters, i.e. keep only
    the records whose FILTER value is exactly "PASS".

    input:
        path_in - path to input vcf
        path_out - path to output cleaned vcf; receives the "##" lines of the
                   input unchanged, then the "#CHROM" header line and the
                   retained records as tab-separated text
    returns: dataframe of the retained records, with the first column named
             "CHROM" (without the leading "#")

    Note: records are round-tripped through pandas, so numeric columns are
    re-formatted on output (e.g. a QUAL of 50 is written as 50.0).
    '''
    comments = []
    rows = []
    with open(path_in, 'r') as f:
        for line in f:
            if line.startswith('##'):
                comments.append(line)
            else:
                rows.append(line)

    df = pd.read_csv(io.StringIO(''.join(rows)),
                     dtype={'#CHROM': str, 'POS': int, 'ID': str, 'REF': str, 'ALT': str,
                            'QUAL': float, 'FILTER': str, 'INFO': str}, sep='\t').rename(columns={'#CHROM': 'CHROM'})
    df_cleaned = df[df['FILTER'] == 'PASS']

    with open(path_out, 'w') as output:
        # Comment lines keep their own newlines; writelines adds nothing, so
        # no blank line ends up between the meta lines and the header.
        output.writelines(comments)
        # Restore the "#" on the header line for the file only; the returned
        # dataframe keeps the "CHROM" column name.
        df_cleaned.rename(columns={'CHROM': '#CHROM'}).to_csv(
            output, sep='\t', index=False, header=True)
    return df_cleaned