In [None]:
import scanpy as sc
import anndata as ann
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
import glob
from matplotlib import rcParams
from matplotlib import colors

import seaborn as sb

# scanpy logging level 3 (most verbose of the commonly used levels).
sc.settings.verbosity = 3

# Configure scanpy's figure defaults first: set_figure_params() rewrites
# matplotlib's rcParams (figure size included), so a size assigned before
# this call would be silently overwritten.
sc.set_figure_params(dpi=200, dpi_save=300,
                     vector_friendly=False,
                     format='pdf')
plt.rcParams['figure.figsize'] = (8, 8)  # rescale figures (must follow set_figure_params)
sc.logging.print_versions()

In [None]:
import warnings
# Suppress FutureWarning messages from here on; other warning categories
# are still reported.
warnings.simplefilter(action="ignore", category=FutureWarning)

In [None]:
# Set size for plots: use seaborn's 'paper' plotting context.
sb.set_context(context='paper')

In [None]:
# Analysis version tag; it is prefixed to the input file name when the
# annotated object is loaded.
version = "V1"

# Root directory for this analysis' files and its figures sub-directory.
output_files_path = "/Sunshine_DeRisi_RSV_files/"
fig_path = output_files_path + "figures/"

In [None]:
# Point scanpy's figure output directory at fig_path.
sc.settings.figdir = fig_path

In [None]:
# Assemble "<output_files_path><version>_<name>.h5ad" and read the object.
name = "2024_RSV_annotated_unfiltered_human_virus"
preprocessed_path = f"{output_files_path}{version}_{name}.h5ad"

adata_human_virus = sc.read_h5ad(preprocessed_path)

In [None]:
# Display the summary of the loaded object.
adata_human_virus #note multiseq doublets have been removed

In [None]:
# Recalculate per-cell QC metrics and store them in .obs:
#   n_counts   - sum of X over genes for each cell
#   log_counts - natural log of n_counts
#   n_genes    - number of genes with a value > 0 in each cell
# When X is a scipy sparse matrix, a row reduction comes back as an
# (n_cells, 1) np.matrix; flatten it to a plain 1-D array so the column
# assignment does not depend on how pandas treats 2-D input.
adata_human_virus.obs['n_counts'] = np.asarray(adata_human_virus.X.sum(1)).ravel()
adata_human_virus.obs['log_counts'] = np.log(adata_human_virus.obs['n_counts'])
adata_human_virus.obs['n_genes'] = np.asarray((adata_human_virus.X > 0).sum(1)).ravel()
#check to make sure new observations are there
adata_human_virus

In [None]:
# Drop cells whose multiseq call is "Negative" or "Doublet"
# [nothing should be filtered for doublets].
adata_human_virus = adata_human_virus[adata_human_virus.obs['multiseq_id'] != "Negative"]
print('Number of cells after multiseq negative cell filter: {:d}'.format(adata_human_virus.n_obs))
# Boolean indexing returns a view of the parent AnnData. Later cells filter
# this object in place and assign new .obs columns, so materialise the final
# subset once with .copy() instead of relying on implicit view conversion.
adata_human_virus = adata_human_virus[adata_human_virus.obs['multiseq_id'] != "Doublet"].copy()
print('Number of cells after multiseq doublets cell filter: {:d}'.format(adata_human_virus.n_obs))

In [None]:
# Histograms of total counts per cell in the first three slots of a 1x4 grid:
# all cells, cells with < 6000 counts, and cells with > 10000 counts.
# kde=False draws the bars only (no density curve overlay).
rcParams['figure.figsize']=(30,5)
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.6)

total_counts = adata_human_virus.obs['n_counts']
ax_all, ax_low, ax_high = [fig.add_subplot(1, 4, slot) for slot in (1, 2, 3)]

p3_adata_human_virus = sb.histplot(total_counts, kde=False, ax=ax_all)
p4_adata_human_virus = sb.histplot(total_counts[total_counts < 6000],
                                   kde=False, bins=60, ax=ax_low)
p5_adata_human_virus = sb.histplot(total_counts[total_counts > 10000],
                                   kde=False, bins=60, ax=ax_high)
plt.show()

In [None]:
#Thresholding decision: genes
# Histograms of genes detected per cell in the first two slots of a 1x3 grid:
# all cells (left) and cells with < 2500 genes (right).
rcParams['figure.figsize']=(20,5)
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.6)

genes_per_cell = adata_human_virus.obs['n_genes']
ax_all, ax_low = [fig.add_subplot(1, 3, slot) for slot in (1, 2)]

p6_adata_human_virus = sb.histplot(genes_per_cell, kde=False, bins=60, ax=ax_all)
p7_adata_human_virus = sb.histplot(genes_per_cell[genes_per_cell < 2500],
                                   kde=False, bins=60, ax=ax_low)

plt.show()

In [None]:
# Summary statistics of the per-cell human counts (.obs['human_n_counts']).
human_n_counts_only = adata_human_virus.obs['human_n_counts']
# A bare expression is only rendered when it is the last line of a cell, so
# print the describe() table explicitly; otherwise its result is discarded.
print(human_n_counts_only.describe())
human_n_counts_only_mean = np.mean(human_n_counts_only)
print(human_n_counts_only_mean)
human_n_counts_only_median = np.median(human_n_counts_only)
print(human_n_counts_only_median)
# Range = max - min.
human_n_counts_only_range = np.ptp(human_n_counts_only)
print(human_n_counts_only_range)
# Square root of the variance, computed with np.std's defaults.
# NOTE(review): this may differ slightly from the 'std' row of describe(),
# which pandas computes with its own default degrees of freedom.
human_n_counts_only_standard_deviation = np.std(human_n_counts_only)
print(human_n_counts_only_standard_deviation)

In [None]:
# Summary statistics of the number of genes detected per cell (.obs['n_genes']).
n_genes_only = adata_human_virus.obs['n_genes']
# A bare expression is only rendered when it is the last line of a cell, so
# print the describe() table explicitly; otherwise its result is discarded.
print(n_genes_only.describe())
n_genes_only_mean = np.mean(n_genes_only)
print(n_genes_only_mean)
n_genes_only_median = np.median(n_genes_only)
print(n_genes_only_median)
# Range = max - min.
n_genes_only_range = np.ptp(n_genes_only)
print(n_genes_only_range)
# Square root of the variance, computed with np.std's defaults.
# NOTE(review): this may differ slightly from the 'std' row of describe(),
# which pandas computes with its own default degrees of freedom.
n_genes_only_standard_deviation = np.std(n_genes_only)
print(n_genes_only_standard_deviation)

In [None]:
# Upper and lower limits for n_genes at 1, 2 and 3 standard deviations around
# the mean. Printed in order: mean, then upper/lower for 1 SD, 2 SD and 3 SD.
print (n_genes_only_mean)

genes_sd = n_genes_only_standard_deviation
n_genes_only_1SD_upper = n_genes_only_mean + genes_sd
n_genes_only_1SD_lower = n_genes_only_mean - genes_sd
n_genes_only_2SD_upper = n_genes_only_mean + (2*genes_sd)
n_genes_only_2SD_lower = n_genes_only_mean - (2*genes_sd)
n_genes_only_3SD_upper = n_genes_only_mean + (3*genes_sd)
n_genes_only_3SD_lower = n_genes_only_mean - (3*genes_sd)

for genes_limit in (n_genes_only_1SD_upper, n_genes_only_1SD_lower,
                    n_genes_only_2SD_upper, n_genes_only_2SD_lower,
                    n_genes_only_3SD_upper, n_genes_only_3SD_lower):
    print(genes_limit)

In [None]:
# Upper and lower limits for human_n_counts at 1, 2 and 3 standard deviations
# around the mean. Printed in order: upper/lower for 1 SD, 2 SD and 3 SD.
counts_sd = human_n_counts_only_standard_deviation
human_n_counts_only_1SD_upper = human_n_counts_only_mean + counts_sd
human_n_counts_only_1SD_lower = human_n_counts_only_mean - counts_sd
human_n_counts_only_2SD_upper = human_n_counts_only_mean + (2*counts_sd)
human_n_counts_only_2SD_lower = human_n_counts_only_mean - (2*counts_sd)
human_n_counts_only_3SD_upper = human_n_counts_only_mean + (3*counts_sd)
human_n_counts_only_3SD_lower = human_n_counts_only_mean - (3*counts_sd)

for counts_limit in (human_n_counts_only_1SD_upper, human_n_counts_only_1SD_lower,
                     human_n_counts_only_2SD_upper, human_n_counts_only_2SD_lower,
                     human_n_counts_only_3SD_upper, human_n_counts_only_3SD_lower):
    print(counts_limit)

In [None]:
#Thresholding decision: genes
# n_genes histograms overlaid with the mean +/- k*SD limits
# (green = 3 SD, blue = 2 SD, red = 1 SD).
rcParams['figure.figsize']=(20,5)
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.6)

genes_per_cell = adata_human_virus.obs['n_genes']

# Left panel: all cells, with upper and lower limits.
ax_all = fig.add_subplot(1, 3, 1)
p6_adata_human_virus = sb.histplot(genes_per_cell, kde=False, bins=60, ax=ax_all)
for genes_limit, line_colour in ((n_genes_only_3SD_upper, 'g'),
                                 (n_genes_only_3SD_lower, 'g'),
                                 (n_genes_only_2SD_upper, 'b'),
                                 (n_genes_only_2SD_lower, 'b'),
                                 (n_genes_only_1SD_upper, 'r'),
                                 (n_genes_only_1SD_lower, 'r')):
    ax_all.axvline(genes_limit, color=line_colour)

# Right panel: cells with < 2500 genes, lower limits only.
ax_low = fig.add_subplot(1, 3, 2)
p7_adata_human_virus = sb.histplot(genes_per_cell[genes_per_cell < 2500],
                                   kde=False, bins=60, ax=ax_low)
for genes_limit, line_colour in ((n_genes_only_3SD_lower, 'g'),
                                 (n_genes_only_2SD_lower, 'b'),
                                 (n_genes_only_1SD_lower, 'r')):
    ax_low.axvline(genes_limit, color=line_colour)
plt.show()

In [None]:
# Human-only counts per cell (.obs['human_n_counts']): all cells, cells with
# < 6000 counts and cells with > 10000 counts, overlaid with the
# mean +/- 2 SD (blue) and mean +/- 1 SD (red) limits.
rcParams['figure.figsize']=(20,5)
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.6)

human_counts = adata_human_virus.obs['human_n_counts']

# Panel 1: all cells, with upper and lower limits.
ax_all = fig.add_subplot(1, 3, 1)
p3_adata_human_virus = sb.histplot(human_counts, kde=False, ax=ax_all)
for counts_limit, line_colour in ((human_n_counts_only_2SD_upper, 'b'),
                                  (human_n_counts_only_2SD_lower, 'b'),
                                  (human_n_counts_only_1SD_upper, 'r'),
                                  (human_n_counts_only_1SD_lower, 'r')):
    ax_all.axvline(counts_limit, color=line_colour)

# Panel 2: low tail, lower limits only.
ax_low = fig.add_subplot(1, 3, 2)
p4_adata_human_virus = sb.histplot(human_counts[human_counts < 6000],
                                   kde=False, bins=60, ax=ax_low)
for counts_limit, line_colour in ((human_n_counts_only_2SD_lower, 'b'),
                                  (human_n_counts_only_1SD_lower, 'r')):
    ax_low.axvline(counts_limit, color=line_colour)

# Panel 3: high tail, upper limits only.
ax_high = fig.add_subplot(1, 3, 3)
p5_adata_human_virus = sb.histplot(human_counts[human_counts > 10000],
                                   kde=False, bins=60, ax=ax_high)
for counts_limit, line_colour in ((human_n_counts_only_2SD_upper, 'b'),
                                  (human_n_counts_only_1SD_upper, 'r')):
    ax_high.axvline(counts_limit, color=line_colour)

plt.show()

In [None]:
# Lower 2-SD limits; these are the values passed as min_genes / min_counts
# in the cell-filtering step.
print(n_genes_only_2SD_lower)
print(human_n_counts_only_2SD_lower)

In [None]:
# Filter cells according to identified QC thresholds:
print('Total number of cells: {:d}'.format(adata_human_virus.n_obs))

# Keep genes detected in at least 3 cells. This step removes genes, not
# cells, so the figure to report is n_vars (n_obs is unchanged by it).
sc.pp.filter_genes(adata_human_virus, min_cells=3)
print('Number of genes after min cell filter: {:d}'.format(adata_human_virus.n_vars))

# NOTE(review): the threshold was derived from .obs['human_n_counts'], while
# filter_cells(min_counts=...) presumably thresholds each cell's total counts
# in X (human + viral) -- confirm this is the intended comparison.
sc.pp.filter_cells(adata_human_virus, min_counts = human_n_counts_only_2SD_lower)
print('Number of cells after min count filter: {:d}'.format(adata_human_virus.n_obs))

sc.pp.filter_cells(adata_human_virus, min_genes = n_genes_only_2SD_lower)
print('Number of cells after gene filter: {:d}'.format(adata_human_virus.n_obs))

In [None]:
# Re-calculate n_counts, log_counts, n_genes since genes were filtered.
# When X is a scipy sparse matrix, a row reduction comes back as an
# (n_cells, 1) np.matrix; flatten it to a plain 1-D array before storing it.
adata_human_virus.obs['n_counts'] = np.asarray(adata_human_virus.X.sum(1)).ravel()
adata_human_virus.obs['log_counts'] = np.log(adata_human_virus.obs['n_counts']) #natural log
adata_human_virus.obs['n_genes'] = np.asarray((adata_human_virus.X > 0).sum(1)).ravel()

In [None]:
# Summary table of the recomputed total counts per cell.
adata_human_virus.obs['n_counts'].describe()

# Assess basic characteristics of viral reads per treatment

In [None]:
# Display the object summary after filtering.
adata_human_virus

In [None]:
# Subset of .obs columns used by the viral-read plots.
viral_plot_columns = ['viral_transcript_frac', 'batch', 'treatment',
                      'viral_transcript_log_counts', 'viral_transcript_n_counts',
                      'new_multiseq_id']
human_virus_plot_df = pd.DataFrame(adata_human_virus.obs[viral_plot_columns])

In [None]:
# Strip plot of the per-cell viral read fraction: one x position per batch,
# with points split (dodged) and coloured by treatment.
rcParams['figure.figsize']=(3,3)

batch_order = ['3','2','1','0']
treatment_order = ['Vehicle_Control','Heat_Killed_RSV','RSV_infected']
ax = sb.stripplot(data=human_virus_plot_df,
                  x='batch', y='viral_transcript_frac',
                  order=batch_order,
                  hue='treatment', hue_order=treatment_order,
                  dodge=True, size=1,
                  palette=['black','gray','firebrick'])

ax.set(xlabel='Time Point (hours)', ylabel='Viral Read Fraction per Cell')
# Tick labels are applied by position, so batches '3','2','1','0' are shown as
# 0, 4, 8 and 12 hours respectively.
# NOTE(review): confirm this batch-to-time mapping against the sample sheet.
ax.set_xticklabels(['0','4','8','12'])
ax.grid(False)
sb.despine(right=True)

plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 12.0,
    'legend.fontsize': 12.0,
})
legend = plt.legend(bbox_to_anchor=(1.6,0.8), loc='right', borderaxespad=0)
legend.set_frame_on(False)
for legend_text in legend.get_texts():
    legend_text.set_fontsize(8)

## Figure S1B

In [None]:
# Violin plot of the per-cell viral fraction for each multiseq sample.
rcParams['figure.figsize']=(5,5)

# x-axis order: the four time points of each condition suffix in turn
# ('0hr_VC' ... '12hr_VC', then the HK samples, then the RSV samples).
sample_order = [f'{hours}hr_{condition}'
                for condition in ('VC', 'HK', 'RSV')
                for hours in (0, 4, 8, 12)]

sc.pl.violin(
    adata_human_virus,
    keys='viral_transcript_frac',
    groupby='new_multiseq_id',
    order=sample_order,
    xlabel='Treatment',
    ylabel='Viral Fraction per Cell',
    size=1,
    rotation=90,
    palette='Blues',
    save='violin_viralfracpercell_bytreatment.pdf',
)

## Evaluate dynamic range of raw viral transcripts per cell

In [None]:
from collections import Counter

# Order cells by sample and, within each sample, by ascending viral
# transcript count.
sorted_df = human_virus_plot_df.sort_values(['new_multiseq_id','viral_transcript_n_counts'])
# Position of each cell in the whole sorted table (0-based).
sorted_df['index_num'] = range(len(sorted_df))

grouped = sorted_df.groupby('new_multiseq_id')
# 1-based position of each cell within its own sample. cumcount() numbers the
# rows of every group in their current order and returns the result aligned
# to the rows, so it does not rely on the groupby iteration order matching
# the row order of the sorted table.
sorted_df['cell_num'] = grouped.cumcount() + 1

# Tally of how many samples have each group size (kept for inspection).
cnt = Counter(len(group) for _, group in grouped)

In [None]:
# Line plot of viral transcript counts against the global cell rank
# (index_num): one facet per treatment, one line colour per batch.
treatment_order = ['Vehicle_Control','Heat_Killed_RSV','RSV_infected']
hue_colors = {'0': 'firebrick', '1': 'darkorange', '2': 'wheat', '3': 'gray'}

facet_grid = sb.relplot(
    data=sorted_df,
    kind="line",
    x="index_num", y="viral_transcript_n_counts",
    hue="batch", palette=hue_colors,
    col="treatment", col_order=treatment_order,
    aspect=1,
    linewidth=5,
)

sb.despine(right=True)

plt.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 12.0,
    'legend.fontsize': 12.0,
})
legend = plt.legend(bbox_to_anchor=(1.6,0.8), loc='right', borderaxespad=0)


In [None]:
# One sub-table of the ranked cells per treatment.
treatment_labels = sorted_df['treatment']
sorted_df_hk = sorted_df.loc[treatment_labels == "Heat_Killed_RSV"]
sorted_df_vc = sorted_df.loc[treatment_labels == "Vehicle_Control"]
sorted_df_inf = sorted_df.loc[treatment_labels == "RSV_infected"]

In [None]:
# Viral transcript counts per cell for the RSV_infected cells, one panel per
# batch; cells are placed on the x axis by their within-sample rank (cell_num).
rcParams['figure.figsize']=(20,4)
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for panel, timepoint in enumerate(['3','2','1','0'], start=1):
    ax = plt.subplot(1, 4, panel)
    batch_cells = sorted_df_inf[sorted_df_inf['batch'] == timepoint]
    _ = sb.lineplot(data=batch_cells,
                    x="cell_num", y="viral_transcript_n_counts",
                    ax=ax,
                    linewidth=5, linestyle='-',
                    color='darkred', alpha=0.5)
    # Same y range in every panel so the batches are directly comparable.
    ax.set_ylim(-100, 26000)
    ax.set(xticklabels=[], xlabel=None)
    ax.tick_params(bottom=False)
    sb.despine(right=True, bottom=True, left=True)

#plt.savefig('/lineplot_infected_alltp_viral_transcripts_percell.pdf')

In [None]:
# Viral transcript counts per cell for the Heat_Killed_RSV cells, one panel
# per batch; cells are placed on the x axis by their within-sample rank.
rcParams['figure.figsize']=(20,4)
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for panel, timepoint in enumerate(['3','2','1','0'], start=1):
    ax = plt.subplot(1, 4, panel)
    batch_cells = sorted_df_hk[sorted_df_hk['batch'] == timepoint]
    _ = sb.lineplot(data=batch_cells,
                    x="cell_num", y="viral_transcript_n_counts",
                    ax=ax,
                    linewidth=5, linestyle='--',
                    color='darkorange', alpha=1)
    # Same y range in every panel so the batches are directly comparable.
    ax.set_ylim(-100, 26000)
    ax.set(xticklabels=[], xlabel=None)
    ax.tick_params(bottom=False)
    sb.despine(right=True, bottom=True, left=True)

#plt.savefig('/lineplot_heatkilled_alltp_viral_transcripts_percell.pdf')

In [None]:
# Viral transcript counts per cell for the Vehicle_Control cells, one panel
# per batch; cells are placed on the x axis by their within-sample rank.
rcParams['figure.figsize']=(20,4)
fig = plt.figure()
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for panel, timepoint in enumerate(['3','2','1','0'], start=1):
    ax = plt.subplot(1, 4, panel)
    batch_cells = sorted_df_vc[sorted_df_vc['batch'] == timepoint]
    _ = sb.lineplot(data=batch_cells,
                    x="cell_num", y="viral_transcript_n_counts",
                    ax=ax,
                    linewidth=5, linestyle='--',
                    color='blue', alpha=1)
    # Same y range in every panel so the batches are directly comparable.
    ax.set_ylim(-100, 26000)
    ax.set(xticklabels=[], xlabel=None)
    ax.tick_params(bottom=False)
    sb.despine(right=True, bottom=True, left=True)

#plt.savefig('/lineplot_vehiclecontrol_alltp_viral_transcripts_percell.pdf')

# Investigate distribution of counts

In [None]:
# Annotate cells by percent viral transcripts.
# pd.cut assigns each cell to a left-closed bin of the viral fraction:
#   (-inf, 1%)  [1%, 10%)  [10%, 20%)  [20%, 30%)  [30%, +inf)
# and leaves a missing fraction as a real missing value. The previous
# np.select(..., default=np.nan) mixed string labels with a float default,
# which either stores the text 'nan' or fails to find a common dtype,
# depending on the NumPy version.
col = 'viral_transcript_frac'
infection_frac_edges = [-np.inf, 0.01, 0.1, 0.2, 0.3, np.inf]
# The top label reads ">30%" but the bin includes exactly 30%; the label text
# is kept unchanged because the plotting cells select cells by these strings.
infection_fraction = ["<1%virus", "1=<x<10%virus", "10=<x<20%virus",
                      "20=<x<30%virus", ">30%virus"]
adata_human_virus.obs['infection_frac'] = pd.cut(adata_human_virus.obs[col],
                                                 bins=infection_frac_edges,
                                                 labels=infection_fraction,
                                                 right=False)
adata_human_virus.obs['infection_frac'].value_counts()

In [None]:
# Wrap .obs in a DataFrame for the plotting cells that follow.
# NOTE(review): pd.DataFrame(obs) may share data with .obs rather than copy
# it -- confirm if the two need to be independent.
adata_human_virus_df = pd.DataFrame(adata_human_virus.obs)

## Figure S1 C & D

In [None]:
# Density of total UMIs per cell, one curve per viral-fraction category.
rcParams['figure.figsize']=(3,3)
cats = ['<1%virus','1=<x<10%virus','10=<x<20%virus','20=<x<30%virus','>30%virus']
# Named cat_colors so the list does not shadow the `colors` module imported
# from matplotlib at the top of the notebook.
cat_colors = ['dimgray','darkblue','blue','cornflowerblue','turquoise']

fig, ax = plt.subplots()
for line_color, inf_frac_cat in zip(cat_colors, cats):
    df = adata_human_virus_df[adata_human_virus_df.infection_frac == inf_frac_cat]
    # kdeplot is the replacement for the deprecated distplot(hist=False, kde=True).
    sb.kdeplot(x=df['n_counts'], label=inf_frac_cat, color=line_color, ax=ax)
ax.grid(False)
sb.despine()
ax.legend(prop={'size': 20}, bbox_to_anchor=(1.05, 1))
ax.set_title('Total UMIs per Cell')
ax.set_xlabel('UMIs per Cell')
ax.set_ylabel('Density of cells')

# The legend is anchored outside the axes; bbox_inches='tight' keeps it from
# being cropped out of the saved file.
plt.savefig(fig_path+'n_counts_total_infectfrac_dist.pdf', bbox_inches='tight')

In [None]:
# Density of human UMIs per cell, one curve per viral-fraction category.
rcParams['figure.figsize']=(3,3)
cats = ['<1%virus','1=<x<10%virus','10=<x<20%virus','20=<x<30%virus','>30%virus']
# Named cat_colors so the list does not shadow the `colors` module imported
# from matplotlib at the top of the notebook.
cat_colors = ['dimgray','darkblue','blue','cornflowerblue','turquoise']

fig, ax = plt.subplots()
for line_color, inf_frac_cat in zip(cat_colors, cats):
    df = adata_human_virus_df[adata_human_virus_df.infection_frac == inf_frac_cat]
    # kdeplot is the replacement for the deprecated distplot(hist=False, kde=True).
    sb.kdeplot(x=df['human_n_counts'], label=inf_frac_cat, color=line_color, ax=ax)
ax.grid(False)
sb.despine()
ax.legend(prop={'size': 20}, bbox_to_anchor=(1.05, 1))
ax.set_title('Human UMIs per Cell')
ax.set_xlabel('Human UMIs per Cell')
ax.set_ylabel('Density of cells')

# The legend is anchored outside the axes; bbox_inches='tight' keeps it from
# being cropped out of the saved file.
plt.savefig(fig_path+'n_counts_human_only_infectfrac_dist.pdf', bbox_inches='tight')

In [None]:
import scipy.stats as stats

In [None]:
# Pearson correlation between total counts per cell and the viral
# transcript fraction.
cell_obs = adata_human_virus.obs
corr_coef, pval = stats.pearsonr(cell_obs['n_counts'],
                                 cell_obs['viral_transcript_frac'])
print("Correlation coefficient : ", corr_coef)
print("pval : ", pval)

In [None]:
# Pearson correlation between human counts per cell and the viral
# transcript fraction.
cell_obs = adata_human_virus.obs
corr_coef, pval = stats.pearsonr(cell_obs['human_n_counts'],
                                 cell_obs['viral_transcript_frac'])
print("Correlation coefficient : ", corr_coef)
print("pval : ", pval)

## confirm that removal of low counts and low genes doesn't change these results

In [None]:
# Per-cell sum of counts over human genes only (var names starting with
# 'GRCh38_'), computed on the filtered object.
human_genes = [name for name in adata_human_virus.var_names if name.startswith('GRCh38_')]
# np.asarray(...).ravel() flattens the row sums whether they come back as an
# (n_cells, 1) np.matrix or as a plain array; `.A1` exists only on np.matrix.
adata_human_virus.obs['human_n_counts_postfilt'] = np.asarray(
    adata_human_virus[:, human_genes].X.sum(axis=1)
).ravel()
adata_human_virus_df = pd.DataFrame(adata_human_virus.obs)

In [None]:
# Density of post-filter human UMIs per cell, one curve per viral-fraction
# category.
rcParams['figure.figsize']=(3,3)
cats = ['<1%virus','1=<x<10%virus','10=<x<20%virus','20=<x<30%virus','>30%virus']
# Named cat_colors so the list does not shadow the `colors` module imported
# from matplotlib at the top of the notebook.
cat_colors = ['dimgray','darkblue','blue','cornflowerblue','turquoise']

fig, ax = plt.subplots()
for line_color, inf_frac_cat in zip(cat_colors, cats):
    df = adata_human_virus_df[adata_human_virus_df.infection_frac == inf_frac_cat]
    # kdeplot is the replacement for the deprecated distplot(hist=False, kde=True).
    sb.kdeplot(x=df['human_n_counts_postfilt'], label=inf_frac_cat,
               color=line_color, ax=ax)
ax.grid(False)
sb.despine()
ax.legend(prop={'size': 20}, bbox_to_anchor=(1.05, 1))
ax.set_title('Human UMIs per Cell')
ax.set_xlabel('Human UMIs per Cell')
ax.set_ylabel('Density of cells')

#plt.savefig('human_n_counts_infectfrac_dist.pdf')

In [None]:
# Pearson correlation between post-filter human counts per cell and the viral
# transcript fraction.
cell_obs = adata_human_virus.obs
corr_coef, pval = stats.pearsonr(cell_obs['human_n_counts_postfilt'],
                                 cell_obs['viral_transcript_frac'])
print("Correlation coefficient : ", corr_coef)
print("pval : ", pval)

In [None]:
# Two-sample Kolmogorov-Smirnov test on human counts per cell:
# '<1%virus' cells versus '1=<x<10%virus' cells.
ks_reference_counts = adata_human_virus_df.loc[
    adata_human_virus_df.infection_frac == '<1%virus', 'human_n_counts']
ks_compared_counts = adata_human_virus_df.loc[
    adata_human_virus_df.infection_frac == '1=<x<10%virus', 'human_n_counts']
sp.stats.ks_2samp(ks_reference_counts, ks_compared_counts)

In [None]:
# Two-sample Kolmogorov-Smirnov test on human counts per cell:
# '<1%virus' cells versus '10=<x<20%virus' cells.
ks_reference_counts = adata_human_virus_df.loc[
    adata_human_virus_df.infection_frac == '<1%virus', 'human_n_counts']
ks_compared_counts = adata_human_virus_df.loc[
    adata_human_virus_df.infection_frac == '10=<x<20%virus', 'human_n_counts']
sp.stats.ks_2samp(ks_reference_counts, ks_compared_counts)

In [None]:
# Two-sample Kolmogorov-Smirnov test on human counts per cell:
# '<1%virus' cells versus '20=<x<30%virus' cells.
ks_reference_counts = adata_human_virus_df.loc[
    adata_human_virus_df.infection_frac == '<1%virus', 'human_n_counts']
ks_compared_counts = adata_human_virus_df.loc[
    adata_human_virus_df.infection_frac == '20=<x<30%virus', 'human_n_counts']
sp.stats.ks_2samp(ks_reference_counts, ks_compared_counts)

# Plot percent infection 

## Call Infected cells

In [None]:
##Add annotation for RSV viral transcripts.
    #Call infected vs uninfected using raw counts.
    #add buffer region
# pd.cut assigns each cell to a left-closed bin of its raw viral count:
#   (-inf, 30) -> "uninfected",  [30, 40) -> "buffer",  [40, +inf) -> "infected"
# and leaves a missing count as a real missing value. The previous
# np.select(..., default=np.nan) mixed string labels with a float default,
# which either stores the text 'nan' or fails to find a common dtype,
# depending on the NumPy version.
col = 'viral_transcript_n_counts'
infection_status_edges = [-np.inf, 30, 40, np.inf]
infection_status = ["uninfected", "buffer", "infected"]
adata_human_virus.obs['infection_status'] = pd.cut(adata_human_virus.obs[col],
                                                   bins=infection_status_edges,
                                                   labels=infection_status,
                                                   right=False)
adata_human_virus.obs['infection_status'].value_counts()

In [None]:
# Split new_multiseq_id (e.g. '4hr_RSV') at '_' into a time-point part and a
# condition part, then strip the 'hr' text from the time point.
adata_human_virus_df = pd.DataFrame(adata_human_virus.obs)
multiseq_id_parts = adata_human_virus_df['new_multiseq_id'].str.split('_', expand=True)
adata_human_virus_df[['tp','multi_cond']] = multiseq_id_parts
adata_human_virus_df['tp'] = adata_human_virus_df['tp'].str.replace("hr","")

In [None]:
# Display the annotated per-cell table (now including 'tp' and 'multi_cond').
adata_human_virus_df

In [None]:
# Lookup table indexed by time point (0, 4, 8, 12); its 'index' column holds
# each time point's position (0-3). Not referenced by the other cells visible
# in this notebook.
categories_order = pd.DataFrame([0,4,8,12]).reset_index().set_index(0)
categories_order

In [None]:
# Per (treatment, time point) group: non-null entry counts for every column,
# and the number of cells called 'infected'.
cells_by_treatment_tp = adata_human_virus_df.groupby(['treatment','tp'])
total_cells_by_group = cells_by_treatment_tp.count()
count_infected_by_group = cells_by_treatment_tp.apply(
    lambda x: (x['infection_status'] == 'infected').sum()
)

total_cells_by_group

In [None]:
def _percent_infected(group):
    """Return the percentage of rows in `group` whose infection_status is 'infected'."""
    return (group['infection_status'] == "infected").mean() * 100 #confirmed w/ unit test


# One row per (treatment, time point) with the percentage of infected cells.
cells_by_treatment_tp = adata_human_virus_df.groupby(['treatment','tp'])
percent_inf_df = cells_by_treatment_tp.apply(_percent_infected).reset_index(name="test_percent")
percent_inf_df

In [None]:
# NOTE(review): 0.061996 is a hard-coded number, presumably copied from the
# percent_inf_df output -- confirm which row it refers to. The value computed
# here is the complement (100 minus that percentage).
infected_hk = 100-0.061996
infected_hk

## Figure 1B

In [None]:
# Line plot of the percentage of infected cells over time, one line per
# treatment.
custom_order = ['0','4','8','12']
# 'tp' holds strings here, so sorting it would be lexicographic ('12' before
# '4'); selecting the rows by custom_order gives the intended time order
# directly, which makes a prior sort_values() unnecessary.
df_sorted = percent_inf_df.set_index('tp').loc[custom_order].reset_index()
colors ={'RSV_infected':'Green','Heat_Killed_RSV' : 'Blue','Vehicle_Control':'Gray'}

plt.rcParams['figure.figsize'] = (4,4)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'
plt.rcParams['font.size'] = 20.0
plt.rcParams['legend.fontsize'] = 20.0

for treatment, data in df_sorted.groupby('treatment'):
    plt.plot(data['tp'], data['test_percent'], marker='o',
             linestyle='-',
             color=colors[treatment],
             label=treatment)

sb.despine()
plt.legend(prop={'size': 12},bbox_to_anchor=(1.05, 1))
plt.grid(False)
plt.xlabel('Timepoint (hrs)')
plt.ylabel('Percent Infection (%)')

#plt.savefig(fig_path+'lineplot_infectionpercent_by_treatment.pdf')
#plt.savefig(fig_path+'lineplot_infectionpercent_by_treatment.pdf')

In [None]:
# Bar plot of the percentage of infected cells per time point, one bar per
# treatment.
plt.rcParams['figure.figsize'] = (4,4)
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'
plt.rcParams['font.size'] = 20.0
plt.rcParams['legend.fontsize'] = 20.0

colors ={'Heat_Killed_RSV' : 'Blue', 'RSV_infected':'Green', 'Vehicle_Control':'Grey'}
# Pass the treatment -> colour mapping to the plot; it was defined here but
# never used, so the bars fell back to seaborn's default palette.
# NOTE(review): confirm the bars are meant to share the line plot's colours.
sb.barplot(df_sorted, x="tp", y="test_percent", hue="treatment", palette=colors)

sb.despine()
plt.legend(prop={'size': 12},bbox_to_anchor=(1.05, 1))
plt.grid(False)
plt.xlabel('Timepoint (hrs)')
plt.ylabel('Percent Infection (%)')

## Figure 1B

In [None]:
# Bar plot of the log10 number of cells per time point, stacked by treatment.
# Time points are converted to numbers so they sort as 0, 4, 8, 12.
adata_human_virus_df['tp'] = pd.to_numeric(adata_human_virus_df['tp'])
adata_human_virus_df = adata_human_virus_df.sort_values(by='tp', ascending=True)

cells_per_group = adata_human_virus_df.groupby(['tp','treatment']).size().unstack(fill_value=0)
# log10(0) is -inf, which cannot be drawn and corrupts the stacked bar
# offsets; mask empty groups so they become missing values instead.
adata_human_virus_df_grouped = np.log10(cells_per_group.where(cells_per_group > 0))

# NOTE(review): the bars stack log10 values, so a bar's total height is the
# sum of the logs, not the log of the total cell number -- confirm that this
# is the intended presentation.
ax = adata_human_virus_df_grouped.plot(kind='bar', stacked=True,
                                      color=colors
                                      )

sb.despine()
plt.legend(prop={'size': 12},bbox_to_anchor=(1.05, 1))
plt.grid(False)
plt.xlabel('Timepoint (hrs)')
plt.ylabel('Total Cell Number (log10)')
plt.xticks(rotation=0)

#plt.savefig(fig_path+'totalcellnumber_barplot_by_treatment.pdf')