In [None]:
#load rpy2 magic
%load_ext rpy2.ipython

# to switch off warning messages
import warnings
warnings.filterwarnings("ignore")

# make default cell width 85% of available screen
# (display/HTML are imported from the public IPython.display module; importing
# them from IPython.core.display is deprecated)
from IPython.display import display, HTML
display(HTML("<style>.container { width:85% !important; }</style>"))

# load R libraries & functions
%R options(warn=-1)
%R library(RColorBrewer)
%R library(ComplexHeatmap)
%R library(circlize)
%R library(dendextend)
%R library(Rtsne)
%R library(ggplot2)
%R library(gridExtra)
%R source("/gfs/devel/tkhoyratty/my_scripts/R/pca.R")
%R library(wesanderson)
%R Palette <- wes_palette("Cavalcanti1")

# load python modules
import glob
import re
import sys
import os
import rpy2.robjects as robjects
import CGAT.Database as DB
import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
%matplotlib inline

db = "./csvdb"

# ATAC Pipeline Report

In [None]:
# get peak/readcounts to use from pipeline.yml
import yaml

# safe_load only constructs plain Python objects. A bare yaml.load(stream)
# without an explicit Loader is deprecated and is rejected by PyYAML 6.
with open("pipeline.yml") as o:
    opts = yaml.safe_load(o)

# value of the macs2 -> peaks option; get_counts() below accepts
# "size_filt" or "all"
filt = opts["macs2"]["peaks"]

In [None]:
# import counts & upper quantile normalise
def get_counts(filt=filt, db=db):
    '''Fetch width-normalised RPM counts and upper-quartile normalise them.

    Parameters
    ----------
    filt : str
        Which peak set to use: "size_filt" selects rows of all_norm_counts
        whose size_filt column is "<150bp", "all" selects "all_fragments".
    db : str
        Database passed straight to DB.fetch_DataFrame.

    Returns
    -------
    pandas.DataFrame
        One row per peak_id, one column per sample_id; every column is
        divided by its own 75th percentile.

    Raises
    ------
    ValueError
        If filt is neither "size_filt" nor "all" (previously this surfaced
        as a confusing KeyError on the unbound name `peaks`).
    '''
    # get counts, filter on fragment size
    peak_sets = {"size_filt": "<150bp", "all": "all_fragments"}
    if filt not in peak_sets:
        raise ValueError(
            "Unsupported peak filter %r; expected one of %s"
            % (filt, sorted(peak_sets)))
    peaks = peak_sets[filt]

    statement = '''select sample_id, peak_id, RPM_width_norm *1000 as RPM 
                from all_norm_counts where size_filt == "%(peaks)s" ''' % {"peaks": peaks}

    df = DB.fetch_DataFrame(statement, db)

    # keyword arguments: positional index/columns/values are not accepted by
    # DataFrame.pivot in pandas >= 2.0
    df = df.pivot(index="sample_id", columns="peak_id", values="RPM").transpose()
    df.index.name = None

    # normalise to upper quantiles for between sample comparison
    df = df.div(df.quantile(0.75, axis=0), axis=1)

    return df

counts = get_counts()

In [None]:
# get sample_info table
# Load the sample annotations and index them by sample_id. The sample_id
# column itself is kept, because later merges join on it.
sample_info = DB.fetch_DataFrame(
    '''select distinct * from sample_info''', db
).set_index("sample_id", drop=False)
sample_info.index.name = None

# show every row of the annotation table
sample_info.head(len(sample_info))

***
<br>

## Sample Quality 
- Contamination with mitochondrial (chrM) reads is a common problem for ATAC datasets
- PCR duplicates introduced during library preparation are another
- High values of either indicate low-quality libraries

In [None]:
def mapping_stats(paired=True, db=db, sample_info=sample_info):
    '''Collect all mapping stats & return df for plotting.

    Parameters
    ----------
    paired : bool
        Must be truthy; only paired-end data is supported.
    db : str
        Database passed straight to DB.fetch_DataFrame.
    sample_info : pandas.DataFrame
        Sample annotations; must contain a "sample_id" column.

    Returns
    -------
    pandas.DataFrame
        Rows of picardAlignmentSummary (CATEGORY "PAIR") annotated with
        sample_info and a "Filter" column, outer-merged with the percentage
        of mapped reads on chrM (pct_chrM, only set for Filter == "genome").

    Raises
    ------
    NotImplementedError
        If paired is falsy. Previously a message was printed and the function
        then failed with a NameError on the undefined `reads`.
    '''
    if not paired:
        raise NotImplementedError("Update function for non-paired data")

    reads = DB.fetch_DataFrame('''select READS_ALIGNED_IN_PAIRS/2 as MAPPED_PAIRS, PCT_READS_ALIGNED_IN_PAIRS, 
                                  TOTAL_READS, PCT_ADAPTER, sample_id from picardAlignmentSummary 
                                  where CATEGORY = "PAIR" ''', db)

    # Format mapping qc df
    # Filter is the text after the last "_" of the track name, except for
    # size-filtered tracks which are labelled "prep<150bp".
    reads["Filter"] = reads["sample_id"].apply(lambda x: x.split("_")[-1] if "size_filt_prep" not in x else "prep<150bp")
    # sample_id: drop the last "_"-separated token, then anything after the first "."
    reads["sample_id"] = reads["sample_id"].apply(lambda x: '_'.join(x.split("_")[0:-1]))
    reads["sample_id"] = reads["sample_id"].apply(lambda x: x.split(".")[0])

    if len(sample_info)==0:
        print("Provide sample_info df with sample annotations")

    reads = pd.merge(reads, sample_info, on="sample_id", how="inner")

    # get no. reads mapping to chrM
    chrm = DB.fetch_DataFrame('''select * from allContig''', db)

    # reformat df: one row per sample, one column per contig
    # (keyword arguments: positional ones are not accepted by pandas >= 2.0)
    chrm = chrm.pivot(index="sample_id", columns="contig", values="mapped_reads")
    chrm["total_mapped_reads"] = chrm.sum(axis=1)
    chrm["sample_id"] = chrm.index.values
    chrm.index.name = None
    # .copy() so the column assignments below write to an independent frame
    chrm = chrm[["chrM", "total_mapped_reads", "sample_id"]].copy()
    chrm["pct_chrM"] = chrm["chrM"] / chrm["total_mapped_reads"] *100 # % reads mapping to chrM

    # annotate df
    chrm["sample_id"] = chrm["sample_id"].apply(lambda x: str(x).split(".")[0])
    chrm = pd.merge(chrm, sample_info, on="sample_id", how="inner")
    chrm["Filter"] = "genome" # chrM only in genomic reads as filtered out after, others not tested

    df = pd.merge(reads, chrm[["pct_chrM", "sample_id", "Filter"]], how="outer", on=["sample_id", "Filter"])

    return df

mapping_qc = mapping_stats()

In [None]:
# get no. reads mapping to chrM
chrm = DB.fetch_DataFrame('''select * from allContig''', db)

# reformat df: one row per sample, one column per contig
# (keyword arguments: positional ones are not accepted by pandas >= 2.0)
chrm = chrm.pivot(index="sample_id", columns="contig", values="mapped_reads")
chrm["total_mapped_reads"] = chrm.sum(axis=1)
chrm["sample_id"] = chrm.index.values
chrm.index.name = None
# .copy() so the column assignments below write to an independent frame
chrm = chrm[["chrM", "total_mapped_reads", "sample_id"]].copy()
chrm["pct_chrM"] = chrm["chrM"] / chrm["total_mapped_reads"] *100 # % reads mapping to chrM

# annotate df
chrm["sample_id"] = chrm["sample_id"].apply(lambda x: str(x).split(".")[0])
chrm = pd.merge(chrm, sample_info, on="sample_id", how="inner")
chrm["Filter"] = "genome" # chrM only in genomic reads as filtered out after, others not tested

# chrm.head(30)

In [None]:
%%R -i chrm -h 400 -w 450

# Palette <- c("darkorange3", "slateblue2", "darkgray", "seagreen4")

# Horizontal bar chart: percentage of mapped reads on chrM for each sample,
# bars filled by condition using the shared Palette.
ggplot(data=chrm, mapping=aes(x=sample_id, y=pct_chrM, fill=condition)) +
    geom_bar(colour="black", stat="identity") +
    labs(title="chrM Contamination", x="") +
    scale_fill_manual(values=Palette) +
    theme_Publication() +
    coord_flip()

In [None]:
# One row per sample, one column of mapped pairs per Filter level
# (keyword arguments: positional ones are not accepted by pandas >= 2.0).
dups = mapping_qc.pivot(index="sample_id", columns="Filter", values="MAPPED_PAIRS")
# read pairs present in "filt" but no longer present in "prep"
dups["duplicates"] = dups["filt"] - dups["prep"]
# flatten the pivot (sample_id becomes a regular column again) and annotate
dups = pd.DataFrame(dups.to_records())
dups = pd.merge(dups, sample_info, how="inner", on="sample_id")
# dups.head()

In [None]:
%%R -i dups -h 400 -w 550

# Horizontal bar chart: number of read pairs counted as duplicates for each
# sample, bars filled by condition using the shared Palette.
ggplot(data=dups, mapping=aes(x=sample_id, y=duplicates, fill=condition)) +
    geom_bar(colour="black", stat="identity") +
    labs(title="PCR Duplicates", x="", y="No. Read Pairs") +
    scale_fill_manual(values=Palette) +
    theme_Publication() +
    coord_flip()

***
<br>

## Insert sizes
- nucleosomal phasing should be apparent in library insert size distributions

In [None]:
def fragment_stats(db=db, sample_info=sample_info):
    '''Collect insert size stats & format df for plotting.

    Parameters
    ----------
    db : str
        Database passed straight to DB.fetch_DataFrame.
    sample_info : pandas.DataFrame
        Sample annotations; must contain a "sample_id" column.

    Returns
    -------
    list
        [size_stats, insert_sizes]: the picardInsertSizeMetrics and
        picardInsertSizeHistogram rows whose sample_id ends in "prep", each
        with a "Filter" column added and merged with sample_info.
    '''
    suffix = "_prep"

    def clean(df):
        # Filter is the text after the last "_" of the track name, except for
        # size-filtered tracks which are labelled "prep<150bp".
        df["Filter"] = df["sample_id"].apply(lambda x: x.split("_")[-1] if "size_filt_prep" not in x else "prep<150bp")
        # Remove the literal "_prep" suffix. str.rstrip("_prep") would instead
        # strip any trailing run of the characters "_", "p", "r", "e" and can
        # eat into the sample name itself ("mouse_prep" -> "mous").
        df["sample_id"] = df["sample_id"].apply(lambda x: x[:-len(suffix)] if x.endswith(suffix) else x)
        df["sample_id"] = df["sample_id"].apply(lambda x: x.split(".")[0])
        df = pd.merge(df, sample_info, on="sample_id", how="inner")
        return df

    insert_sizes = DB.fetch_DataFrame('''select * from picardInsertSizeHistogram where sample_id like "%prep"''', db)
    size_stats = DB.fetch_DataFrame('''select * from picardInsertSizeMetrics where sample_id like "%prep"''', db)

    insert_sizes = clean(insert_sizes)
    size_stats = clean(size_stats)

    return [size_stats, insert_sizes]

(size_stats, insert_sizes) = fragment_stats()

In [None]:
%%bash

# -p: do not fail when the directory already exists, so the cell can be re-run
mkdir -p QC_plots

In [None]:
def clean(df):
    '''Annotate a picard insert-size table with Filter and sample metadata.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a "sample_id" column of track names (strings).

    Returns
    -------
    pandas.DataFrame
        A new frame (the input is not modified) with a "Filter" column
        ("all", "<150bp", or the text after the last "_" of the track name),
        "sample_id" reduced to the bare sample name, inner-merged with the
        notebook-level `sample_info` table on "sample_id".
    '''
    suffix = "_prep"
    df = df.copy()  # do not write the new columns into the caller's frame

    # size-filtered tracks get "prep<150bp"; everything else the last "_" token
    df["Filter"] = df["sample_id"].apply(lambda x: x.split("_")[-1] if "size_filt_prep" not in x else "prep<150bp")
    df["Filter"] = df["Filter"].apply(lambda x: x.replace("prep<150bp", "<150bp"))
    df["Filter"] = df["Filter"].apply(lambda x: x.replace("prep", "all"))

    # Remove the literal "_prep" suffix. str.rstrip("_prep") would instead
    # strip any trailing run of the characters "_", "p", "r", "e" and can eat
    # into the sample name itself ("mouse_prep" -> "mous").
    df["sample_id"] = df["sample_id"].apply(lambda x: x[:-len(suffix)] if x.endswith(suffix) else x)
    df["sample_id"] = df["sample_id"].apply(lambda x: x.split(".")[0])

    df = pd.merge(df, sample_info, on="sample_id", how="inner")
    return df
    
def insertSizePlots(outfiles):
    '''Collect insert size metrics & generate plots.

    Reads the picardInsertSizeHistogram rows whose sample_id ends in "prep",
    annotates them with clean(), then writes one table and three figures.

    Parameters
    ----------
    outfiles : sequence of str
        Output paths, by position:
        [0] box plot of insert sizes per sample,
        [1] line plot of fragment counts for Filter == "all",
        [2] the same line plot on a log2 scale,
        [3] tab-separated dump of the annotated table.
    '''
    insert_sizes = DB.fetch_DataFrame('''select * from picardInsertSizeHistogram where sample_id like "%prep"''', db)
    insert_sizes = clean(insert_sizes)
    insert_sizes.to_csv(outfiles[3], sep="\t", header=True, index=False)

    # plots
    sns.set(style="whitegrid", palette="muted")

    # all
    plt.figure(figsize=(10,6))
    sns.lineplot(data=insert_sizes[insert_sizes["Filter"]=="all"], ci=None,
                 x="insert_size", y="All_Reads.fr_count", hue="condition")
    plt.title("All insert sizes")
    plt.savefig(outfiles[1])
    plt.show()
    plt.close()

    # log2 scales
    plt.figure(figsize=(10,6))
    insert_sizes["log2_All_Reads_fr_count"] = np.log2(insert_sizes["All_Reads.fr_count"])
    sns.lineplot(data=insert_sizes[insert_sizes["Filter"]=="all"], ci=None,
                 x="insert_size", y="log2_All_Reads_fr_count", hue="condition")
    plt.title("log2(All insert sizes)")
    plt.savefig(outfiles[2])
    plt.show()
    plt.close()

    # box
    sns_fragment_box = sns.catplot(data=insert_sizes, x="sample_id", y="insert_size", hue="Filter",
                                   kind="box", height=5, aspect=2)
    plt.title("Insert sizes per sample")
    # rotate the tick labels before saving so the file on disk matches the
    # figure shown in the notebook
    sns_fragment_box.set_xticklabels(rotation=30, ha="right")
    sns_fragment_box.savefig(outfiles[0])

insertSizePlots(["QC_plots/fragment_box.png",
              "QC_plots/fragment_hist_all.png",
              "QC_plots/fragment_hist_all_log2.png",
              "QC_plots/insert_sizes.txt"])

In [None]:
# NOTE(review): this rebinds the notebook-level sample_info (now without the
# sample_id index set earlier, and without "distinct"); clean() below reads it.
sample_info = DB.fetch_DataFrame('''select * from sample_info''', db)

insert_sizes = DB.fetch_DataFrame('''select * from picardInsertSizeHistogram where sample_id like "%prep"''', db)
insert_sizes = clean(insert_sizes)
insert_sizes = insert_sizes.sort_values("sample_id")

# plots
sns.set(style="whitegrid", palette="muted")

# One figure per condition, in sorted order so the report is laid out the same
# way on every run (iterating a set() gives an arbitrary order). Rows without
# a condition are skipped instead of producing an empty figure.
for s in sorted(insert_sizes["condition"].dropna().unique()):

    # all
    plt.figure(figsize=(10,6))

    condition_sizes = insert_sizes[insert_sizes["condition"]==s]
    sns.lineplot(data=condition_sizes[condition_sizes["Filter"]=="all"], ci=None,
                 x="insert_size", y="All_Reads.fr_count", hue="sample_id")
    plt.title(s)
    plt.show()
    plt.close()

***
<br>

## Mapping QC
* All mapped reads are contained in "genome"
* "filt" contains reads with MAPQ scores >10 & has reads mapping to chrM subtracted
* Reads in "prep" have PCR duplicates subtracted
* "prep<150bp" files are further filtered for reads with insert sizes < 150 b.p.

In [None]:
%%R

# R functions

# Pull the legend grob out of a ggplot object so it can be drawn on its own
# (e.g. as a shared legend underneath a grid of panels).
#   a.gplot: a ggplot object
# Returns the grob of the built plot whose name is "guide-box".
get_legend <- function(a.gplot){
  plot_table <- ggplot_gtable(ggplot_build(a.gplot))
  grob_names <- sapply(plot_table$grobs, function(grob) grob$name)
  legend_index <- which(grob_names == "guide-box")
  return(plot_table$grobs[[legend_index]])
}

In [None]:
%%R -i mapping_qc -w 800 -h 750

# Four-panel mapping QC figure (a-d) with one shared legend underneath.

# NOTE(review): `test` is not referenced again in the visible notebook - confirm it is still needed.
test <- ggplot(mapping_qc, aes(y=MAPPED_PAIRS, x=sample_id, colour=condition, shape=Filter)) + 
        geom_point(size=6) + 
        theme_Publication()
        
# a: mapped read pairs per sample, one point per Filter level; y axis limited to
# 0 - 65,000,000 with a dashed reference line at 25,000,000.
a <- ggplot(mapping_qc, aes(y=MAPPED_PAIRS, x=sample_id, colour=condition, shape=Filter)) + 
        geom_point(size=6) + 
        theme_Publication() +           
        theme(axis.text.x=element_text(angle=45, hjust=1)) + 
        scale_y_continuous(limits=c(0, 65000000)) +
        scale_colour_manual(values=Palette) +
        labs(x="", y="Mapped Pairs") +
        geom_hline(yintercept=25000000, lty="dashed", col="black")

# b: PCT_READS_ALIGNED_IN_PAIRS (multiplied by 100) for the Filter == "genome" rows only.
b <- ggplot(subset(mapping_qc, Filter=="genome"), 
            aes(y=PCT_READS_ALIGNED_IN_PAIRS*100, x=sample_id, colour=condition)) + 
        geom_point(size=6, shape=17) + 
        theme_Publication() +           
        theme(axis.text.x=element_text(angle=45, hjust=1)) + 
        labs(y="% Reads in Pairs", x="") +
        scale_y_continuous(limits=c(0, 100)) +
        scale_colour_manual(values=Palette)

# c: PCT_ADAPTER per sample, one point per Filter level.
c <- ggplot(mapping_qc, aes(y=PCT_ADAPTER, x=sample_id, colour=condition, shape=Filter)) + 
        geom_point(size=6) + 
        theme_Publication() +           
        theme(axis.text.x=element_text(angle=45, hjust=1)) + 
        scale_colour_manual(values=Palette) +
        labs(x="", y="% Adaptor")

# d: pct_chrM per sample (mapping_stats() only sets it on the Filter == "genome" rows).
d <- ggplot(mapping_qc, aes(y=pct_chrM, x=sample_id, colour=condition)) + 
        geom_point(size=6, shape=17) +
        theme_Publication() +
        labs(x="", y= "% chrM Reads") +
        scale_y_continuous(limits=c(0,100)) +          
        theme(axis.text.x=element_text(angle=45, hjust=1)) + 
        scale_colour_manual(values=Palette)

# The legend of panel a is extracted once and reused for the whole figure.
legend <- get_legend(a)

# 2 x 2 grid of the panels with their own legends switched off.
grid.arrange(a + theme(legend.position="none"), b + theme(legend.position="none"), 
             c + theme(legend.position="none"), d + theme(legend.position="none"), 
             ncol=2, nrow=2, bottom=legend)

***
<br>

## TSS Enrichment
### Normalised ATAC signal over all TSSs

All_fragments | Size_filt
- | -
![alternate text](deeptools.dir/TSS.all.heatmap.png) | ![alternate text](deeptools.dir/TSS.size_filt.heatmap.png)

### Called peaks relative to nearest TSS

In [None]:
def get_tss_dist(db=db):
    '''Fetch the TSS distance of every counted peak as a float.

    The query joins all_norm_counts to merged_peaks_GREAT_annotated (on
    peak_id) and ensemblGeneset (on gene_id) and returns the gene strand
    concatenated in front of the annotated TSSdist value.

    Returns a DataFrame with columns sample_id, peak_id, size_filt and
    TSSdist, where TSSdist has had any "+" removed and been cast to float.
    '''

    query = '''select distinct a.sample_id, a.peak_id, a.size_filt, c.strand || "" || b.TSSdist as TSSdist
                from all_norm_counts a, merged_peaks_GREAT_annotated b,
                ensemblGeneset c where a.peak_id = b.peak_id 
                and b.gene_id = c.gene_id '''

    def to_distance(text):
        # drop the "+" strand symbol, then parse the remainder as a number
        return float(text.replace("+", ""))

    tss = DB.fetch_DataFrame(query, db)
    tss["TSSdist"] = tss["TSSdist"].apply(to_distance)

    return tss

df = get_tss_dist()
# df.head()

In [None]:
%%R -i df -w 600 -h 400

# Histogram of peak-to-TSS distance (in kb), bars dodged and filled by size_filt.
# The y aesthetic rescales the bin counts of group 1 and group 2 to percentages
# (count / sum(count) * 100).
# NOTE(review): the expression indexes ..group..==1 and ..group..==2 explicitly,
# so it appears to assume exactly two size_filt levels - confirm.
p2 <- ggplot(df, aes(TSSdist/1000, fill=size_filt)) + 
            geom_histogram(position="dodge", aes((y=c(..count..[..group..==1]/sum(..count..),
                                 ..count..[..group..==2]/sum(..count..))*100), x=TSSdist/1000)) + 
            labs(y="Percent Total Peaks", x="TSS distance (kb)", title="") + 
        theme_Publication() +
        scale_fill_manual(values=Palette)

# evaluate the plot object so it is rendered as the cell output
p2

***
<br>

## Peakcalling QC
### All peaks
* All detected peaks with & without filtering reads by size & fraction of reads in peaks

In [None]:
# Number of peaks and FRIP per sample and size_filt level (no_peaks joined to
# frip_table on sample_id and size_filt), annotated with the sample metadata.
peak_stats = DB.fetch_DataFrame(
    '''select a.no_peaks, a.size_filt, b.FRIP, b.sample_id from no_peaks a, 
                                frip_table b where a.sample_id=b.sample_id and a.size_filt=b.size_filt ''',
    db,
).merge(sample_info, how="inner", on="sample_id")

# show every row
peak_stats.head(len(peak_stats))

In [None]:
%%R -i peak_stats  -w 800 -h 350

# Two side-by-side bar charts with one shared legend underneath:
#   no_peaks  - number of called peaks per sample
#   frip_plot - fraction of reads in peaks (FRIP) per sample
# Bars are filled by condition, dodged by size_filt (mapped to alpha) and
# labelled "r<replicate>".

no_peaks <- ggplot(peak_stats, 
                   aes(y=no_peaks, x=sample_id, fill=condition, 
                       shape=factor(replicate), alpha=size_filt)) + 
                geom_text(aes(label=paste0("r", replicate)), 
                          position=position_dodge(width=1), vjust=-0.5, fontface="bold", size=4.5, show.legend=F) +
                geom_bar(stat="identity", position="dodge", colour="black") + theme_Publication() +
                theme(axis.text.x=element_text(angle=45, hjust=1)) + 
                scale_alpha_discrete(range=c(0.4, 1)) +
                # `y` used to be passed twice (y="No. Peaks", y=""); only one label is kept
                labs(y="No. Peaks", x="") +
                scale_y_continuous(limits=c(0, max(peak_stats$no_peaks))) +
                scale_fill_manual(values=Palette) 
                
frip_plot <- ggplot(peak_stats, 
                    aes(y=FRIP, x=sample_id, fill=condition,
                        alpha=size_filt)) + 
                geom_text(aes(label=paste0("r", replicate)), 
                          position=position_dodge(width=1), vjust=-0.5, fontface="bold", size=4.5, show.legend=F) +
                geom_bar(stat="identity", position="dodge", colour="black") + theme_Publication() +          
                theme(axis.text.x=element_text(angle=45, hjust=1)) + 
                scale_alpha_discrete(range=c(0.4, 1)) +
                labs(y="FRIP", x="") +
                scale_y_continuous(limits=c(0,1)) + 
                scale_fill_manual(values=Palette) +
                guides(colour=guide_legend(override.aes=list(size=6)), 
                       alpha=guide_legend(override.aes=list(size=6))) +
                # dashed reference line at FRIP = 0.2
                geom_hline(yintercept=0.2, lty="dashed", colour="black") 

# the legend of the FRIP plot is reused for both panels
key <- get_legend(frip_plot)

grid.arrange(no_peaks + theme(legend.position="none"), frip_plot + theme(legend.position="none"),
             ncol=2, nrow=1, bottom=key)

### High confidence peaks
* Peaks that are consistent between biological replicates, from both the size-filtered and the non-size-filtered peak sets

In [None]:
# Rows of the no_peaks table whose `merged` column ends in "merged"; plotted in
# the next cell (no_peaks per sample_id, split by size_filt).
merged_peaks = DB.fetch_DataFrame('''select * from no_peaks where merged like "%merged" ''', db)

In [None]:
%%R -i merged_peaks -w 600 -h 400

# Number of peaks in each merged peak set; bars are filled per sample_id and
# dodged by size_filt (mapped to alpha).
ggplot(merged_peaks, aes(y=no_peaks, x=sample_id, alpha=size_filt, fill=sample_id)) +
    geom_bar(stat="identity", position="dodge", colour="black") +
    scale_alpha_discrete(range=c(0.4, 1)) +
    # guide="none" hides the fill legend (the x axis already names the samples);
    # the logical form guide=FALSE is deprecated in ggplot2
    scale_fill_manual(values=Palette, guide="none") +
    theme_Publication() +
    labs(x="")

***
<br>


## Data Exploration 
* based on counts over merged peakset
* merged peakset consists of all detected peaks 
* counts are normalised for sequencing depth, peak width, and upper quantile normalised for between sample comparison

### Pearson correlation between normalised counts in consensus peakset
* clustering with Ward method on Manhattan distances

In [None]:
%R colnames(counts) <- sample_info$sample_id

In [None]:
%%R -i counts

# Sample-by-sample Pearson correlation of log2(count + 1) values.
# NOTE(review): use="all" presumably partial-matches cor()'s "all.obs" - confirm.
cm <- data.matrix(log2(counts +1))
m <- cor(cm, method="pearson", use="all")

# Row dendrogram: Ward (ward.D2) clustering on Manhattan distances between the
# rows of the correlation matrix, then sorted by label.
distr <- dist(m, method="manhattan")
clustr <- hclust(distr, method="ward.D2")
dendr <- as.dendrogram(clustr)
dendr <- dendr %>% sort(type="labels")

# Column dendrogram: same procedure on the transposed matrix; the dendrogram is
# reversed before being sorted by label.
distc <- dist(t(m), method="manhattan")
clustc <- hclust(distc, method="ward.D2")
dendc <- as.dendrogram(clustc)
dendc <- rev(dendc) %>% sort(type="labels")

# Heatmap of the correlation matrix using the dendrograms built above; the
# colour ramp runs from white at min(m) to dark green at max(m).
p2 <- Heatmap(m,
       col = colorRamp2(c(min(m), max(m)), c("white", "#02401B")),
       cluster_rows=dendr,
       cluster_columns=dendc,
       column_dend_reorder = FALSE,
       column_dend_height = unit(2, "cm"),
       row_dend_width = unit(2, "cm"),
       row_names_gp=gpar(fontsize=16),
       column_names_gp=gpar(fontsize=16),
       name="Pearson Correlation:",
       heatmap_legend_param=list(legend_direction="horizontal", 
#                                   at=c(0.9, 1), 
                                  color_bar = "continuous",
                                  legend_width = unit(5, "cm"), 
                                  title_position = "lefttop",
                                  title_gp=gpar(fontsize=18),
                                  labels_gp=gpar(fontsize=14)),
       )

# draw with the colour legend underneath the heatmap
draw(p2, heatmap_legend_side = "bottom")

### Dimensionality Reduction

In [None]:
%%R  -w 1200 -h 300 -i sample_info

# Build PCA scatter plots (one per requested pair of components) and a scree plot.
#
# Arguments:
#   prcomp_object       result of prcomp(); its $x scores and the "Proportion of
#                       Variance" row of summary() are used
#   plots               named list; each element is c(<x component>, <y component>)
#   sample_information  data frame merged onto the scores by row name, so its
#                       row names must match rownames(prcomp_object$x)
#   fill, color         column names (strings) handed to aes_string(); the
#                       default "c()" evaluates to NULL, i.e. no mapping
#   shape, alpha        kept for backward compatibility; not used by the body
#   label               column used to label the points, or "none" for no labels
#   size                point size
#   nudge_scale_factor  kept for backward compatibility; not used by the body
#
# Returns a named list of ggplot objects: one entry per element of `plots`
# plus "scree".
ggplot_prcomp <- function(prcomp_object,
                          plots=list("A"=c("PC1","PC2"), "B"=c("PC3","PC4"), "C"=c("PC5","PC6")),
                          sample_information="none",
                          fill="c()",
                          shape="c()",
                          label="none",
                          color="c()",
                          alpha="c()",
                          size=3,
                          nudge_scale_factor=40){
    require(gridExtra)
    pca = prcomp_object

    # sample_information should have the same rownames as pca$x
    pvs <- summary(pca)$importance["Proportion of Variance",]

    #scree plot
    fs <- data.frame(x=c(1:length(names(pvs))), y=as.vector(pvs))

    pcdf <- as.data.frame(pca$x)

    # by=0 merges on row names; all=T keeps samples missing from either side
    pcdf <- merge(pcdf, sample_information, by=0, all=T)

    gps = list()

    scree <- ggplot(fs, aes(x,y)) +
                geom_point(size=4) +
                xlab("principal component") +
                ylab("proportion of variance") +
                ggtitle("scree plot") +
                theme_Publication()

    # axis label of the form "PC1 (42%)"
    c_lab <- function(props, C){
        return(paste(C, " (", props[[C]]*100,"%)",sep=""))
    }

    for(plot in names(plots)){

        comps <- plots[[plot]]

        PCX <- comps[1]
        PCY <- comps[2]

        gp <- ggplot(pcdf, aes_string(PCX, PCY, fill=fill, colour=color)) +
                scale_alpha_discrete(range=c(0.5, 1)) +
                theme_Publication()

        # shape 21 is a filled circle, so the `fill` mapping is visible;
        # geom_point has no height/width parameters, so they are not passed
        gp <- gp + geom_point(size=size, shape=21)

        if(label!="none"){
            # ggrepel is not attached by this notebook, hence the explicit
            # namespace; check_overlap is not a geom_text_repel parameter
            gp <- gp + ggrepel::geom_text_repel(aes_string(label=label), color="black", force=5)
        }

        gp <- gp + xlab(c_lab(pvs,PCX)) + ylab(c_lab(pvs,PCY))

        gps[[plot]] <- gp

    }

    gps[["scree"]] <- scree

    return(gps)
}

# PCA of log2(count + 1) values with samples as observations (hence t()).
df <- as.data.frame(log2(counts+1))

pca <- prcomp(t(df), scale=FALSE)

# ggplot_prcomp() merges the annotations onto the PCA scores by row name, and
# the scores are named after the columns of `df` (the samples), so the
# annotation table has to be keyed by sample_id. `category` is not a sample
# identifier (the tSNE cell merges on sample_id and maps category to only two
# shapes), so using it here would not line up with the scores.
rownames(sample_info) <- sample_info$sample_id

pca_plots <- ggplot_prcomp(pca, 
             plots=list("A"=c("PC1","PC2"), "B"=c("PC3", "PC4"), "C"=c("PC5", "PC6")),
             sample_information=sample_info, 
             fill="condition",
             size=8,
             nudge_scale_factor=30) 

a <- pca_plots$A + scale_fill_manual(values=Palette)
b <- pca_plots$B + scale_fill_manual(values=Palette)
c <- pca_plots$C + scale_fill_manual(values=Palette)
s <- pca_plots$scree

# horizontal legend taken from panel a and shared by the whole figure
a <- a + theme(legend.direction="horizontal", legend.box="horizontal")
legend <- get_legend(a)
blank <- grid.rect(gp=gpar(col="white"))                      

# three component plots and the scree plot in one row
grid.arrange(a + theme(legend.position="none"), b + theme(legend.position="none"), c + theme(legend.position="none"), s, 
             bottom=legend, ncol=4, nrow=1)

In [None]:
%%R -h 400 -w 500
                           
# tSNE of the log2(count + 1) matrix `df` built in the PCA cell (samples as rows).

# Rtsne is stochastic; fix the seed so the embedding is the same on every run.
set.seed(1)

# Rtsne requires 3 * perplexity <= (number of samples - 1). Keep the original
# perplexity of 4 whenever there are enough samples, otherwise use the largest
# value that still satisfies the constraint.
n_samples <- ncol(df)
tsne_perplexity <- min(4, floor((n_samples - 1) / 3))

tsne_out <- Rtsne(t(df), pca=T, perplexity=tsne_perplexity)
tsne_df <- as.data.frame(tsne_out$Y)
# label the embedding with the sample names of the matrix that was embedded
# (`log2counts`, used here before, is not defined anywhere in this notebook)
rownames(tsne_df) <- colnames(df)
colnames(tsne_df) <- c("tSNE1", "tSNE2")
tsne_df$category <- rownames(tsne_df)

# after this merge the key column keeps the name `category` (sample ids) and
# sample_info's own category column becomes `category.y`
tsne_df <- merge(tsne_df, sample_info, by.x="category", by.y="sample_id")

p <- ggplot(tsne_df, aes(y=tSNE1, x=tSNE2, shape=category.y, fill=condition, colour=stimuli)) + 
        geom_point(size=7, stroke=2) + 
        theme_Publication()  + 
        scale_colour_manual(values=c("black", "red3")) +
        scale_fill_manual(values=Palette) + 
        scale_shape_manual(values=c(21,22)) +
        labs(title="tSNE of Normalised Read Counts \nOver All Peaks") +
        theme(legend.position="right", legend.direction="vertical") +
        guides(fill=guide_legend(override.aes=list(shape=21)))

grid.arrange(p, ncol=1, nrow=1)

### Replicate correlation

In [None]:
import warnings
warnings.simplefilter('ignore')

# use sample information to get no. replicates & conditions
# (keyword arguments: positional ones are not accepted by pandas >= 2.0)
rep_pairs = sample_info.pivot(index="sample_group", columns="replicate", values="sample_id").transpose()
rep_pairs.columns.name = None
rep_pairs.index.name = None

# report replicates to dict: sample_group -> [first replicate, second replicate]
reps = {}
for col in rep_pairs.columns:
    group_samples = rep_pairs[col].dropna()
    if len(group_samples) < 2:
        # a pairwise correlation needs two replicates
        print("Skipping %s: fewer than two replicates" % col)
        continue
    reps[col] = [group_samples.iloc[0], group_samples.iloc[1]]

# get palette
colours = ["light red", "windows blue", "dusty purple", "greyish", "amber", "faded green"]
pal = sns.xkcd_palette(colours)

if len(rep_pairs.columns) > len(pal):
    import random

    extra_colours = list(sns.xkcd_rgb.keys()) # all 954 colours
    pal2 = sns.xkcd_palette(extra_colours)
    # seeded shuffle so the extra colours are the same on every run
    pal = pal + random.Random(0).sample(list(pal2), len(pal2))

sns.set(style="whitegrid", palette="muted")# set seaborn theme

from scipy import stats
def pearsonr(x, y):
    '''Return only the Pearson correlation coefficient of x and y.'''
    return stats.pearsonr(x, y)[0]

# use dict to subset df of normalised counts & plot rep correlations
for c, key in enumerate(reps):
    # .copy() so renaming the columns does not write into a slice of `counts`
    pair = counts[reps[key]].copy()
    pair.columns = ["Rep1", "Rep2"]
    log_pair = np.log2(pair + 1)

    # jointplot's stat_func argument no longer exists in current seaborn, so
    # the coefficient is computed here and reported in the figure title
    r = pearsonr(log_pair["Rep1"], log_pair["Rep2"])

    # `height` is the name seaborn gives to the former `size` argument
    p = sns.jointplot(data=log_pair, y="Rep1", x="Rep2", kind="reg", height=7, color=pal[c])
    plt.subplots_adjust(top=0.9)
    p.fig.suptitle("%s (Pearson r = %.2f)" % (key, r)) # add title
    plt.show()