In [None]:
import pandas as pd
import re
import tempfile
import os
import gzip
import shutil
import subprocess
from pyspark.sql.functions import udf, col, lit
from pyspark.sql import functions as F
import seaborn as sns

"""
Import Parquet As a DataFrame
"""

# Allow PySpark to use Apache Arrow. This is a session-level setting, so it is
# configured up front, before any data is read.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Read the analyzed-sequence parquet dataset from the public S3 bucket into a
# Spark DataFrame; every query below runs against `df_spark`.
parquet_s3 = "s3://steichenetalpublicdata/analyzed_sequences/parquet"
df_spark = spark.read.parquet(parquet_s3)

# Make a query class

The `Query` class holds the parameters of a Spark query until it is time to execute it.

In [None]:
def combine_dfs(list_of_dfs):
    """Merge several per-donor result frames into a single per-donor frame.

    Each input is expected to look like the output of ``Query.get_normalized``:
    a pandas DataFrame with the columns ``ez_donor``, ``count``,
    ``total_count`` and ``NormalizedCustomerValue`` (extra columns such as a
    string ``class`` label are tolerated).

    For every donor the ``count`` values of all inputs are added together,
    ``total_count`` is taken from the first input row that mentions the donor
    (it is not summed), and ``NormalizedCustomerValue`` is recomputed as
    ``count / total_count``.

    Parameters
    ----------
    list_of_dfs : iterable of pandas.DataFrame
        Frames to combine. Any iterable is accepted; it is materialised once.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``ez_donor`` with the numeric columns that were summed
        (``count`` for frames shaped as described above), followed by
        ``total_count`` and ``NormalizedCustomerValue``.

    Raises
    ------
    ValueError
        If ``list_of_dfs`` is empty.
    """
    # Materialise once: the inputs are needed both for the summed counts and
    # for the per-donor totals, so a one-shot iterator would otherwise be
    # exhausted after the first use.
    frames = list(list_of_dfs)
    if not frames:
        raise ValueError("combine_dfs needs at least one DataFrame to combine")

    stacked = pd.concat(frames)

    # numeric_only=True is spelled out because the pandas default changed in
    # 2.0; without it, string columns such as "class" get concatenated
    # ("pgt145_34pgt145_33") instead of being left out of the sum.
    combined_counts = (
        stacked.groupby(["ez_donor"])
        .sum(numeric_only=True)
        .drop(["total_count", "NormalizedCustomerValue"], axis=1)
    )

    # A donor's total does not depend on the query, so keep the first value
    # seen for each donor rather than adding the totals up.
    combined_sums = (
        stacked[["ez_donor", "total_count"]]
        .groupby(["ez_donor"])
        .head(1)
        .set_index("ez_donor")
    )

    combined = combined_counts.join(combined_sums)
    combined["NormalizedCustomerValue"] = combined["count"] / combined["total_count"]
    return combined


class Query:
    """Hold the parameters of one CDR3 search and apply them to a Spark DataFrame.

    The object only stores the search criteria; nothing is executed until
    ``apply`` or ``get_normalized`` is called with a DataFrame.

    Attributes
    ----------
    query_name : str
        Label for the query (``q_name`` in the constructor).
    length : int
        Required length of the ``cdr3_aa`` column value.
    v_fam, d_gene, j_gene : str
        Optional exact-match filters on the columns of the same name. An
        empty value disables the filter.
    regular_expression : str
        Optional pattern matched against ``cdr3_aa`` with ``rlike``.
    queried_dataframe : Spark DataFrame or None
        Result of the most recent ``apply`` call; ``None`` before that.
    applied : bool
        ``True`` once ``apply`` has run.
    """

    def __init__(self, q_name, length="", v_fam="", d_gene="", j_gene="", regex=""):
        """Store the query parameters.

        Raises
        ------
        ValueError
            If ``length`` is missing or falsy. ``ValueError`` is a subclass of
            ``Exception``, so callers catching ``Exception`` still work.
        """
        if not length:
            raise ValueError("Length must be supplied")

        self.query_name = q_name
        self.length = length
        self.v_fam = v_fam
        self.d_gene = d_gene
        self.j_gene = j_gene
        self.regular_expression = regex
        # Filled in by apply(); defined here so the attribute always exists.
        self.queried_dataframe = None
        self.applied = False

    def apply(self, df):
        """Filter ``df`` by the stored parameters and return the result.

        The length filter on ``cdr3_aa`` is always applied. ``v_fam``,
        ``d_gene``, ``j_gene`` and the regular expression are applied only
        when they were supplied.

        Side effects: prints the number of matching rows (this calls
        ``count()`` on the filtered DataFrame), stores the result in
        ``self.queried_dataframe`` and sets ``self.applied`` to ``True``.
        """
        filtered = df.filter(F.length(df.cdr3_aa) == self.length)

        # Optional exact-match filters, in the same order as before.
        optional_filters = (
            ("v_fam", self.v_fam),
            ("d_gene", self.d_gene),
            ("j_gene", self.j_gene),
        )
        for column_name, wanted in optional_filters:
            if wanted:
                filtered = filtered.filter(filtered[column_name] == wanted)

        if self.regular_expression:
            filtered = filtered.filter(
                filtered.cdr3_aa.rlike(self.regular_expression)
            )

        print("Found {} sequences".format(filtered.count()))
        self.queried_dataframe = filtered
        self.applied = True
        return filtered

    def get_normalized(self, dataframe, column="ez_donor"):
        """Count matches per group and normalise by the group's total row count.

        Parameters
        ----------
        dataframe : Spark DataFrame
            Data to query; passed to ``apply`` and also used for the totals.
        column : str, default "ez_donor"
            Column to group by and join on.

        Returns
        -------
        pandas.DataFrame
            Columns ``column``, ``count`` (matching rows), ``total_count``
            (all rows in the group) and ``NormalizedCustomerValue``
            (``count / total_count``).

        Notes
        -----
        The join is called without an explicit join type, so Spark's default
        applies (inner, per the PySpark documentation); groups with no
        matching rows therefore do not appear in the result.
        """
        matched = self.apply(dataframe)

        # Group by the caller-supplied column on both sides so that it agrees
        # with the join key below.
        matched_counts = matched.groupby(column).count()
        total_counts = (
            dataframe.groupby(column)
            .count()
            .withColumnRenamed("count", "total_count")
        )

        normalized = matched_counts.join(total_counts, column).withColumn(
            "NormalizedCustomerValue", (F.col("count") / F.col("total_count"))
        )
        return pd.DataFrame(normalized.collect(), columns=normalized.columns)

# PCT64


In [None]:
# PCT64 query: CDR3 length 25, IGHV3 / IGHD3-3 / IGHJ6, plus a positional motif.
pct64_query = Query(
    q_name="PCT64",
    length=25,
    v_fam="IGHV3",
    d_gene="IGHD3-3",
    j_gene="IGHJ6",
    regex=r"^......[YRKG][DSG]FWS..............$",
)

# Run the query, pull the normalized per-donor counts into pandas and tag
# every row with the class label.
normal_query_df_pct64 = pct64_query.get_normalized(df_spark).assign(
    **{"class": "pct64"}
)

# CH01 - CH04

In [None]:
# CH04 query: CDR3 length 26, IGHV3 / IGHJ2, plus a positional motif.
ch04_query = Query(
    q_name="ch04",
    length=26,
    v_fam="IGHV3",
    j_gene="IGHJ2",
    regex=r"^..............Y[YQK]GSG.......$",
)

# Run the query, pull the normalized per-donor counts into pandas and tag
# every row with the class label.
normal_query_df_ch04 = ch04_query.get_normalized(df_spark).assign(
    **{"class": "ch04"}
)

# PG9

In [None]:
# PG9 query: CDR3 length 30, IGHV3 / IGHJ6, plus a positional motif.
pg9_query = Query(
    q_name="pg9",
    length=30,
    v_fam="IGHV3",
    j_gene="IGHJ6",
    regex=r"^...............YDF............$",
)

# Run the query, pull the normalized per-donor counts into pandas and tag
# every row with the class label.
normal_query_df_pg9 = pg9_query.get_normalized(df_spark).assign(
    **{"class": "pg9"}
)

# PGT145

In [None]:
# PGT145 query, length-33 variant: IGHV1 / IGHJ6, plus a positional motif.
pgt145_33 = Query(
    q_name="pgt33",
    length=33,
    v_fam="IGHV1",
    j_gene="IGHJ6",
    regex=r"^.............Y[GND][DEY].................$",
)

# Run the query, pull the normalized per-donor counts into pandas and tag
# every row with the class label.
normal_query_df_pgt145_33 = pgt145_33.get_normalized(df_spark).assign(
    **{"class": "pgt145_33"}
)

In [None]:
# PGT145 query, length-34 variant: IGHV1 / IGHJ6, plus a positional motif.
pgt145_34 = Query(
    q_name="pgt34",
    length=34,
    v_fam="IGHV1",
    j_gene="IGHJ6",
    regex=r"^.............Y[GND][DEY]..................$",
)

# Run the query, pull the normalized per-donor counts into pandas and tag
# every row with the class label.
normal_query_df_pgt145_34 = pgt145_34.get_normalized(df_spark).assign(
    **{"class": "pgt145_34"}
)

In [None]:
# Fold the two PGT145 length variants into one per-donor frame and label it.
pgt145_df = combine_dfs(
    [normal_query_df_pgt145_34, normal_query_df_pgt145_33]
).assign(**{"class": "pgt145"})

# CAP256

In [None]:
# CAP256 query, length-37 variant: IGHV3 / IGHJ3, plus a positional motif.
cap256_37 = Query(
    q_name="cap256_37",
    length=37,
    v_fam="IGHV3",
    j_gene="IGHJ3",
    regex=r"^................YD[FIL]..................$",
)

# Run the query, pull the normalized per-donor counts into pandas and tag
# every row with the class label.
normal_query_df_cap256_37 = cap256_37.get_normalized(df_spark).assign(
    **{"class": "cap256_37"}
)

In [None]:
# CAP256 query, length-38 variant: IGHV3 / IGHJ3, plus a positional motif.
cap256_38 = Query(
    q_name="cap256_38",
    length=38,
    v_fam="IGHV3",
    j_gene="IGHJ3",
    regex=r"^.................YD[FIL]..................$",
)

# Run the query, pull the normalized per-donor counts into pandas and tag
# every row with the class label.
normal_query_df_cap256_38 = cap256_38.get_normalized(df_spark).assign(
    **{"class": "cap256_38"}
)

In [None]:
# CAP256 query, length-39 variant: IGHV3 / IGHJ3, plus a positional motif.
cap256_39 = Query(
    q_name="cap256_39",
    length=39,
    v_fam="IGHV3",
    j_gene="IGHJ3",
    regex=r"^................YD[FIL]....................$",
)

# Run the query, pull the normalized per-donor counts into pandas and tag
# every row with the class label.
normal_query_df_cap256_39 = cap256_39.get_normalized(df_spark).assign(
    **{"class": "cap256_39"}
)

In [None]:
# Fold all three CAP256 length variants (37, 38 and 39) into one per-donor
# frame. The length-39 result is computed in the cell above but was previously
# left out of this list, so its matches never reached the "cap256" class.
cap256_df = combine_dfs(
    [
        normal_query_df_cap256_37,
        normal_query_df_cap256_38,
        normal_query_df_cap256_39,
    ]
)
cap256_df["class"] = "cap256"

# Combine all dfs

In [None]:
# Stack the per-class results into one long frame (one row per donor and class).
# The combined frames are indexed by ez_donor, so reset_index() turns that
# index back into a regular column to line up with the single-query frames.
# ignore_index=True gives the result a fresh 0..n-1 index; without it every
# piece contributes its own labels and the index ends up with duplicates.
final_df = pd.concat(
    [
        normal_query_df_pct64,
        normal_query_df_ch04,
        normal_query_df_pg9,
        cap256_df.reset_index(),
        pgt145_df.reset_index(),
    ],
    ignore_index=True,
)

In [None]:
# Shorten the normalized-value column name for plotting.
final_df = final_df.rename(columns={"NormalizedCustomerValue": "normal"})

In [None]:
# Point plot of the normalized per-donor values, one point per query class.
# The returned Axes is kept so the figure can carry a title and descriptive
# axis labels instead of the raw column names.
ax = sns.pointplot(data=final_df, x="class", y="normal", hue="class")
ax.set(
    title="Normalized query matches per donor, by query class",
    xlabel="Query class",
    ylabel="Matching sequences / total sequences per donor",
);