In [None]:
# %load ../snippets/basic_settings.py
import os
import subprocess
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import yaml


# Notebook-wide display defaults for seaborn, pandas and matplotlib.
sns.set_context("notebook", font_scale=1.1)

# pandas: show up to 100 rows/columns and format floats with
# thousands separators and two decimals.
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)
pd.set_option('display.float_format', '{:,.2f}'.format)

# matplotlib: figure size, export resolution, font sizes and line styling.
plt.rcParams.update({
    "figure.figsize": (16, 12),
    'savefig.dpi': 200,
    'figure.autolayout': False,
    'axes.labelsize': 18,
    'axes.titlesize': 20,
    'font.size': 16,
    'lines.linewidth': 2.0,
    'lines.markersize': 8,
    'legend.fontsize': 14,
    'text.usetex': False,  # True activates latex output in fonts!
    'font.family': "serif",
    'font.serif': "cm",
})

In [None]:
from typing import Optional, Tuple, List, Union

In [None]:
config_file = "../nguyenb_config.yaml"
with open(config_file) as file:
    # The FullLoader parameter handles the conversion from YAML
    # scalar values to the Python dictionary format.
    # NOTE(review): FullLoader is acceptable for a trusted local config;
    # prefer yaml.safe_load if this file could ever come from an untrusted source.
    configs = yaml.load(file, Loader=yaml.FullLoader)

In [None]:
# Run on server: both paths are taken from the 'server' entries of the config.
root = Path(configs['root']['server'])
# NOTE(review): unlike `root`, scratchDir is kept as the raw config value (not wrapped in Path).
scratchDir = configs['scratchDir']['server']

In [None]:
# Resolve the locations named in the config relative to `root`.
mapDir = root/configs['mapDir']
countDir = root/configs['libraryCountsDir']
resultDir = root/configs['resultDir']
# Table read from the CSV referenced by the 'sampleData' config entry.
sampleData = pd.read_csv(root/configs['sampleData'])

# Cleaning the data 
- Using data from library_11_1 as a test case

In [None]:
# Load the merged barcode counts for library_11_1.
counts = pd.read_csv(countDir/"library_11_1_mbarq_merged_counts.csv")

# Keep only the rows whose summed numeric counts exceed 10.
row_totals = counts.sum(axis=1, numeric_only=True)
countsFilt = counts[row_totals > 10]

# The first two columns are treated as annotations, the rest as samples.
annotation_cols = list(countsFilt.columns[:2])
sampleIDs = list(countsFilt.columns[2:])

# Normalise for library depth (counts per million), add a 0.5 pseudocount
# and log2-transform; annotations are restored as regular columns afterwards.
count_matrix = countsFilt.set_index(annotation_cols)
cpms = np.log2(count_matrix / count_matrix.sum() * 1000000 + 0.5).reset_index()

In [None]:
# Inspect the normalised, log2-transformed table built in the previous cell.
cpms

In [None]:
# Preview five random rows of the filtered counts; the fixed seed keeps the
# displayed rows identical across Restart & Run All.
countsFilt.sample(5, random_state=0)

## Extract counts for control barcodes

In [None]:
def get_control_counts(control_file: Union[str, Path], counts_df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Attach sample counts to the control barcodes listed in ``control_file``.

    Control file: no header, one to three columns

        [barcode],[conc],[phenotype]

    Expected content (not validated yet):
    1. The first column contains barcodes
    2. The second column is numeric -> should contain concentrations
    3. If there is a third column it should be string, and at least one
       value should be one of ['wt', 'WT', 'wildtype']

    counts_df layout:
        barcode,geneName,sample1,sample2

    The control table is left-merged with ``counts_df`` on 'barcode' and
    every resulting NA is replaced by 0.

    Parameters
    ----------
    control_file : path or file-like object accepted by ``pd.read_csv``.
    counts_df : data frame with a 'barcode' column plus count columns.

    Returns
    -------
    (wt_df, fdf)
        fdf   : all control barcodes merged with their counts.
        wt_df : the rows of ``fdf`` whose genotype is wild type; an empty
                DataFrame when the file has no genotype column or no
                wild-type entry.

    Raises
    ------
    ValueError
        If the control file has more than three columns.
    """
    col_names = ['barcode', 'concentration', 'genotype']
    wt_labels = ['wt', 'WT', 'wildtype']

    cntrl_df = pd.read_csv(control_file, header=None)
    num_cols = cntrl_df.shape[1]
    if num_cols > len(col_names):
        raise ValueError(
            f"Control file has {num_cols} columns, expected at most {len(col_names)}: "
            f"{', '.join(col_names)}"
        )

    # Add column validation code here
    cntrl_df.columns = col_names[0:num_cols]

    # Add column validation for counts_df
    fdf = cntrl_df.merge(counts_df, how='left', on='barcode').fillna(0)

    # Default to an empty frame so the function also returns cleanly for
    # 1- and 2-column control files, which carry no genotype information.
    wt_df = pd.DataFrame()
    if num_cols == 3:
        is_wt = fdf.genotype.isin(wt_labels)
        if is_wt.any():
            wt_df = fdf[is_wt]

    return wt_df, fdf
    

In [None]:
def calculate_correlation(control_df: pd.DataFrame, sampleIDs: List, cutoff: float=0.8) -> Tuple[pd.DataFrame, pd.Index]:
    """
    Correlate each sample's (normalised) control counts with log2 concentration.

    Parameters
    ----------
    control_df : data frame with a 'concentration' column and one column
        per entry of ``sampleIDs``.
    sampleIDs : names of the sample columns to evaluate.
    cutoff : a sample is considered good when its R2 is strictly greater
        than this value.

    Returns
    -------
    (corr_df, good_samples)
        corr_df      : data frame indexed by sample ID with columns 'R'
                       (correlation with log2 concentration) and 'R2'.
        good_samples : index of the sample IDs passing the cutoff.

    Raises
    ------
    KeyError
        If 'concentration' or any of ``sampleIDs`` is not a column of
        ``control_df``.
    """
    if 'concentration' not in control_df.columns:
        raise KeyError("control_df must contain a 'concentration' column")
    missing = [sample for sample in sampleIDs if sample not in control_df.columns]
    if missing:
        raise KeyError(f"sampleIDs not found in control_df columns: {missing}")

    concentrations = np.log2(control_df['concentration'])
    corr_df = control_df[sampleIDs].corrwith(concentrations).to_frame('R')
    corr_df["R2"] = corr_df.R**2
    good_samples = corr_df[corr_df.R2 > cutoff].index
    return corr_df, good_samples

In [None]:
def draw_correlation_plots(control_df, sampleIDs):
    """
    Placeholder: given a complete control_df, draw a correlation plot for
    each sampleID for each genotype.

    Not implemented yet - the body is a no-op and the function returns None.
    """
    pass

In [None]:
# Barcodes that are filtered out of the wild-type controls before the
# correlation is calculated in the next cell.
# NOTE(review): 'GTGTATAGCAGGAACCC' is listed twice - harmless for `isin`,
# but confirm it is not a typo for a different barcode.
not_found_in_dnaid1315 = ['AACAACACGGTAAGCAA', 'AGAATGACCCGGAGGCT', 'AGTCATCGATGCTATAT', 'CCGACGACTGATTGTCC',
           'CTACGACAGGGACTTAA', 'GTGTATAGCAGGAACCC', 'GTGTATAGCAGGAACCC', 'TAAGTCCGGGCTAAGTC',
           'TATAACACCCCCGATTC', 'TCTCACGCAGCGTTTCG']

In [None]:
control_file = root/"controls_3col.csv"
# Merge the control barcodes with the log2-CPM table:
# wt_df holds the wild-type rows, cdf all control rows.
wt_df, cdf = get_control_counts(control_file, cpms)

# Drop the barcodes listed in not_found_in_dnaid1315 before correlating.
wt_df = wt_df[~wt_df.barcode.isin(not_found_in_dnaid1315)]
# df: per-sample R / R2 table; gs: samples whose R2 exceeds the default cutoff.
df, gs = calculate_correlation(wt_df, sampleIDs)

In [None]:
# NOTE(review): disabled scratch code - it appears to have written the
# 1-, 2- and 3-column control files from `cdf`; confirm before deleting.
# cdf1 = cdf[[1]]
# cdf1.to_csv(root/"controls_1col.csv", index=False, header=None)
# cdf2 = cdf[[1,3]]
# cdf2.to_csv(root/"controls_2col.csv", index=False,header=None)
# cdf3 = cdf[[1,3,2]]
# cdf3.to_csv(root/"controls_3col.csv", index=False,header=None)

# Run MAGeCK batch correction

In [None]:
# Inspect the table loaded from the 'sampleData' config entry.
sampleData

In [None]:
def read_in_sample_data(sample_data_file):
    """Placeholder for reading the sample data file; not implemented yet (returns None)."""
    pass

In [None]:
def prepare_mageck_dataset(clean_df, meta_df, library, out_dir=None):
    """
    Write the batch and count tables for one library as tab-separated files.

    Parameters
    ----------
    clean_df : long-format data frame with columns 'library', 'barcode',
        'ShortName', 'barcode_cnt' and 'sampleID'.
    meta_df : data frame with columns 'library', 'sampleID', 'batch'
        and 'day'.
    library : library name used to filter both frames and to name the files.
    out_dir : directory the files are written to. When omitted, the
        notebook-level ``outDir`` variable is used, which keeps existing
        three-argument calls working.

    Returns
    -------
    (batch_file, count_file) : paths of ``{library}_batch.txt`` and
        ``{library}_count.txt``.
    """
    # Only look up the global when no directory was passed explicitly.
    target_dir = Path(outDir if out_dir is None else out_dir)
    batch_file = target_dir/f"{library}_batch.txt"
    count_file = target_dir/f"{library}_count.txt"

    # Batch table: one row per sample of this library, ordered by day then batch.
    batch_df = (meta_df.loc[meta_df.library == library, ['sampleID', 'batch', 'day']]
                .sort_values(['day', 'batch']))
    batch_df.to_csv(batch_file, index=False, sep='\t')

    # Count table: barcodes x samples; combinations absent from clean_df become 0.
    library_counts = clean_df.loc[clean_df.library == library,
                                  ['barcode', 'ShortName', 'barcode_cnt', 'sampleID']]
    count_df = (library_counts
                .pivot(index=['barcode', 'ShortName'], columns='sampleID', values='barcode_cnt')
                .reset_index()
                .rename({'ShortName': 'gene'}, axis=1)
                .fillna(0))
    count_df.to_csv(count_file, index=False, sep='\t')
    return batch_file, count_file

In [None]:
def run_command(args):
    """
    Run an external command and fail loudly on a non-zero exit status.

    The child's stdout/stderr are not captured.

    Parameters
    ----------
    args : sequence of program arguments, passed to ``subprocess.run`` as is.

    Returns
    -------
    subprocess.CompletedProcess for the finished command.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero return code.
    """
    # check=True raises CalledProcessError itself, so no manual
    # check_returncode()/re-raise is needed.
    return subprocess.run(args, check=True)
        

def batch_correct(outDir, library,  r_path="./batchCorrect.R"):
    """
    Run the batch-correction R script for one library.

    The script is invoked as
    ``Rscript <r_path> <count file> <batch file> <library> <outDir>`` where the
    count and batch files are ``{library}_count.txt`` and
    ``{library}_batch.txt`` inside ``outDir``.

    Parameters
    ----------
    outDir : pathlib.Path of the directory holding the input files.
    library : library name.
    r_path : path to the R script.

    Raises
    ------
    subprocess.CalledProcessError
        Propagated from ``run_command`` when the script fails.

    NOTE(review): a later cell defines another ``batch_correct`` that takes no
    parameters; once that cell has run it replaces this function.
    """
    count_path = outDir / f"{library}_count.txt"
    batch_path = outDir / f"{library}_batch.txt"
    # Build the argument list directly instead of splitting a formatted
    # string, so paths containing whitespace stay a single argument.
    args = ['Rscript', str(r_path), str(count_path), str(batch_path), str(library), str(outDir)]
    print(" ".join(args))
    run_command(args)


def get_contrast_samples(library_df, treat_col = 'day', treatment='d1', control='d0', sampleID = 'sampleID'):
    """
    Collect the sample IDs of the control and treatment groups.

    Rows of ``library_df`` are selected by comparing ``treat_col`` with
    ``control`` / ``treatment``; the unique values of the ``sampleID`` column
    in each selection are joined with commas.

    Returns
    -------
    (controls, treats) : two comma-separated strings; a string is empty when
        no row matches.
    """
    def _joined_ids(level):
        # Unique sample IDs of the rows at the given treatment level.
        selected = library_df[treat_col] == level
        return ",".join(library_df.loc[selected, sampleID].unique())

    return _joined_ids(control), _joined_ids(treatment)


def run_mageck(count_file, treated, controls, out_prefix, control_barcode_file):
    """
    Run ``mageck test`` for one contrast.

    Parameters
    ----------
    count_file : count table passed with ``-k``.
    treated : treatment sample IDs passed with ``-t``.
    controls : control sample IDs passed with ``-c``.
    out_prefix : output prefix passed with ``-n``.
    control_barcode_file : file passed with ``--control-sgrna``.

    Raises
    ------
    subprocess.CalledProcessError
        Propagated from ``run_command`` when mageck fails.
    """
    # Build the argument list directly instead of splitting a formatted
    # string, so paths containing whitespace stay a single argument.
    args = [
        "mageck", "test",
        "-k", str(count_file),
        "-t", str(treated),
        "-c", str(controls),
        "-n", str(out_prefix),
        "--control-sgrna", str(control_barcode_file),
        "--normcounts-to-file",
    ]
    print(" ".join(args))
    run_command(args)

In [None]:
# For each library:
library='library_12_2'
# NOTE(review): `meat` is not defined in the visible part of this notebook -
# possibly a typo for the sample metadata table; this line fails on a fresh kernel.
meta = meat[meat.library == library]
# NOTE(review): `all_contrasts` is not defined in the visible part of this notebook either.
c = all_contrasts[library]


def mageck_library(library, meta, outDir, contrasts, control_barcode_file, batch_corr=True, batch_col='batch',
                   r_path="./batchCorrect.R"):
    """
    Run MAGeCK for every contrast of one library and collect the gene summaries.

    1. Check if batch correction is needed, run it if yes -> a different
       count file is used as input for mageck.
    2. For each contrast check if there are samples, then run mageck.
    3. Concatenate the results of all contrasts.

    Parameters
    ----------
    library : library name; used to build the input and output file names.
    meta : data frame for this library containing ``batch_col``.
    outDir : pathlib.Path holding ``{library}_count.txt`` (and the batch
        file); mageck output is written here as well.
    contrasts : mapping ``contrast name -> (controls, treated)``.
        Contrasts with an empty entry on either side are skipped.
    control_barcode_file : file passed to mageck as ``--control-sgrna``.
    batch_corr : enable batch correction; it only runs when ``meta`` holds
        more than one distinct batch.
    batch_col : name of the batch column in ``meta``.
    r_path : path to the batch-correction R script.

    Returns
    -------
    Data frame with the concatenated ``gene_summary`` tables plus 'contrast'
    and 'library' columns. When every contrast was skipped, an empty data
    frame with only the 'library' column is returned.
    """
    n_batches = meta[batch_col].nunique()
    print(n_batches)
    if batch_corr and n_batches > 1:
        batch_correct(outDir, library, r_path=r_path)
        count_file = outDir/f"{library}_count_batchcorrected.txt"
    else:
        count_file = outDir/f"{library}_count.txt"

    result_dfs = []
    for contrast, samples in contrasts.items():
        print(contrast)
        controls, treated = samples[0], samples[1]
        if len(controls) == 0 or len(treated) == 0:
            continue
        out_prefix = outDir/f"{library}-{contrast}"
        run_mageck(count_file, treated, controls, out_prefix, control_barcode_file)
        res = pd.read_table(f'{out_prefix}.gene_summary.txt').assign(contrast=contrast)
        result_dfs.append(res)

    # pd.concat raises on an empty list, so handle "no contrast had samples" explicitly.
    if not result_dfs:
        return pd.DataFrame().assign(library=library)
    return pd.concat(result_dfs).assign(library=library)



In [None]:
def batch_correct():
    """
    Placeholder for a reworked batch correction (not implemented; returns None).

    Planned inputs:
    - count df containing only the good samples
    - sample data df (read in and validated somewhere else) with information about batches etc.
    - batch column name

    NOTE(review): this definition reuses the name of the earlier
    ``batch_correct(outDir, library, r_path)``. Once this cell runs, the
    call inside ``mageck_library`` will fail because this version takes no
    arguments - rename or merge before running the notebook top to bottom.
    """
    
    