# Imports

In [None]:
import os
from Bio import SeqIO
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import pandas as pd
from scipy.stats import wilcoxon, spearmanr
from matplotlib.ticker import FuncFormatter, ScalarFormatter, LogLocator, LogFormatter
import numpy as np
from sklearn.metrics import f1_score
import statsmodels.formula.api as smf
from statsmodels.stats.multitest import multipletests

# Constants

In [None]:
# Constants
# Identifiers of the simulated samples (1 through 20).
SAMPLES = list(range(1, 21))
# Abundance levels; these appear in the simulated read file names as "ab<value>".
ABUNDANCES = [1, *range(10, 101, 10)]
# Display label for every method, keyed by the method's internal name.
# The insertion order of this dict defines the order of METHODS below.
METHOD_LABELS = {
    "all": "All",
    "centroid": "Centroid",
    "ggrasp": "GGRaSP",
    "vlq": "VLQ",
    "meshclust_0.95": "MC-0.95",
    "meshclust_0.99": "MC-0.99",
    "gclust_0.95": "GC-0.95",
    "gclust_0.99": "GC-0.99",
    "gclust_0.999": "GC-0.999",
    "vsearch_0.95": "VS-0.95",
    "vsearch_0.99": "VS-0.99",
    "vsearch_0.999": "VS-0.999",
    "single-linkage_1": r"SL-$P_{1}$",
    "single-linkage_5": r"SL-$P_{5}$",
    "single-linkage_10": r"SL-$P_{10}$",
    "single-linkage_25": r"SL-$P_{25}$",
    "single-linkage_50": r"SL-$P_{50}$",
    "single-linkage_90": r"SL-$P_{90}$",
    "single-linkage_99": r"SL-$P_{99}$",
    "complete-linkage_1": r"CL-$P_{1}$",
    "complete-linkage_5": r"CL-$P_{5}$",
    "complete-linkage_10": r"CL-$P_{10}$",
    "complete-linkage_25": r"CL-$P_{25}$",
    "complete-linkage_50": r"CL-$P_{50}$",
    "complete-linkage_90": r"CL-$P_{90}$",
    "complete-linkage_99": r"CL-$P_{99}$",
}
# Internal method names, in the same order as the labels above.
METHODS = list(METHOD_LABELS)
EXPERIMENTS = ["global", "country", "state"]
ALPHA = 0.05
OUTPUT_DIR = "figures/viruses"

# Helper functions

In [None]:
def create_mapping(experiment):
    """
    Build a two-way index over every lineage seen in the reference set of
    `experiment` and in the simulation metadata of sample 1.

    Lineages are numbered in order of first appearance: first those in
    root/reference_sets/<experiment>/all.tsv (column 0), then any new ones in
    root/samples/1/metadata.tsv (column 11). Both files are tab-separated and
    their first line is skipped as a header. The running number of unique
    lineages is printed after each file.

    Returns a tuple (idx2lineage, lineage2idx): a list mapping index to
    lineage name and a dict mapping lineage name to index.
    """
    lineage2idx = {}
    idx2lineage = []

    # (path, column holding the lineage, message printed after the file is read)
    sources = (
        (f"root/reference_sets/{experiment}/all.tsv", 0, "unique lineages found in reference"),
        ("root/samples/1/metadata.tsv", 11, "unique lineages found in total"),
    )
    for path, column, message in sources:
        with open(path, "r") as handle:
            next(handle)  # skip header
            for row in handle:
                name = row.strip().split("\t")[column]
                if name in lineage2idx:
                    continue
                lineage2idx[name] = len(idx2lineage)
                idx2lineage.append(name)
        print(len(lineage2idx), message)

    return idx2lineage, lineage2idx

def _add_fastq_nucleotides(fastq_path, nucleotides):
    """
    Add the read lengths found in one FASTQ file to the per-genome totals.

    Every record is assumed to span exactly four lines. The genome id is the
    text of the header line after its first character, up to the first "|".
    `nucleotides` (dict: genome id -> nucleotide count) is updated in place;
    a KeyError is raised for a read whose genome id is not already a key.
    """
    with open(fastq_path, "r") as f_in:
        seq_id = None
        for linecount, line in enumerate(f_in):
            position = linecount % 4
            if position == 0:
                # Header line: drop the leading marker character, keep the id.
                seq_id = line[1:].strip().split("|")[0]
            elif position == 1:
                # Sequence line: its length is the number of nucleotides read.
                nucleotides[seq_id] += len(line.strip())


def generate_groundtruth(idx2lineage, lineage2idx):
    """
    Compute the true relative lineage abundances of every simulated sample.

    For each sample in SAMPLES and each abundance level in ABUNDANCES, the
    nucleotides of both paired-end FASTQ files are summed per source genome,
    divided by that genome's length, accumulated per lineage, and finally
    normalised so that the row sums to 1.

    Parameters:
        idx2lineage: list of lineage names; only its length is used (number
            of columns of the result).
        lineage2idx: dict mapping lineage name to column index.

    Returns:
        A float64 array of shape (len(SAMPLES) * len(ABUNDANCES),
        len(idx2lineage)). The row for the sample at position s in SAMPLES
        and the abundance at position i in ABUNDANCES is s + i * len(SAMPLES).
        A row whose total is zero becomes NaN, since it is divided by its sum.

    Raises:
        KeyError: if a read refers to a genome id with no known length, or a
            lineage is missing from lineage2idx.
    """
    # Map sequence ids to lineages
    id2lineage = {}
    with open("root/samples/1/metadata.tsv", "r") as f_in:
        next(f_in)  # skip header
        for line in f_in:
            fields = line.strip().split("\t")
            id2lineage[fields[0]] = fields[11]

    # The B.1.1.7 reference genome is identical for all samples, so its
    # length is read once instead of once per sample.
    b117_lengths = {
        record.id.strip().split("|")[0]: len(record.seq)
        for record in SeqIO.parse("root/B.1.1.7_sequence.fasta", "fasta")
    }

    n_samples = len(SAMPLES)
    abundances = np.zeros((n_samples * len(ABUNDANCES), len(idx2lineage)), dtype=np.float64)
    # The enumerate position (not `sample - 1`) selects the row, so the layout
    # stays dense even if SAMPLES is not exactly 1..N.
    for sample_pos, sample in enumerate(SAMPLES):
        # Map sequence ids to genome length
        # NOTE(review): the file is named .tsv but is parsed as FASTA - confirm.
        id2length = {
            record.id.strip().split("|")[0]: len(record.seq)
            for record in SeqIO.parse(f"root/samples/{sample}/sequences.tsv", "fasta")
        }
        id2length.update(b117_lengths)

        for i, abundance in enumerate(ABUNDANCES):
            row = sample_pos + i * n_samples
            nucleotides = dict.fromkeys(id2length, 0)
            for mate in (1, 2):
                _add_fastq_nucleotides(
                    f"root/samples/{sample}/wwsim_B.1.1.7_sequence_ab{abundance}_{mate}.fastq",
                    nucleotides,
                )

            for seq_id, n_bases in nucleotides.items():
                # Ids absent from the metadata file are attributed to B.1.1.7.
                lineage = id2lineage.get(seq_id, "B.1.1.7")
                abundances[row, lineage2idx[lineage]] += n_bases / id2length[seq_id]
            abundances[row, :] = abundances[row, :] / np.sum(abundances[row, :])

    return abundances



# Generate groundtruth