In [None]:
import scanpy as sc
import anndata as ann
import numpy as np
import scipy as sp
import pandas as pd
import matplotlib.pyplot as plt
import glob
from matplotlib import rcParams
from matplotlib import colors
import logging

import seaborn as sb

# Verbosity 3 makes scanpy print errors, warnings, info and hints.
sc.settings.verbosity = 3

# `sc.set_figure_params` resets matplotlib's rcParams to scanpy's defaults,
# including `figure.figsize`. A figsize written to rcParams beforehand is
# therefore overwritten, so the desired default size is passed to scanpy
# directly instead.
sc.set_figure_params(dpi=200, dpi_save=300, figsize=(8, 8))
sc.logging.print_versions()

In [None]:
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)

In [None]:
# Use seaborn's 'paper' context, which controls the scaling of plot elements.
sb.set_context('paper')

In [None]:
# Analysis version tag; it is used as a prefix of the input file name.
version = "V1"

# Directory holding the analysis files, and the sub-directory for figures.
output_files_path = "/Sunshine_DeRisi_RSV_files/"
fig_path = "/Sunshine_DeRisi_RSV_files/figures/"

In [None]:
# Point scanpy's figure directory at the figure path defined for this analysis.
sc.settings.figdir = fig_path

In [None]:
# Base name of the preprocessed dataset (doublets are already removed).
name = "2024_RSV_annotated_filtered_human_virus"

# Full path has the form <output dir><version>_<name>.h5ad
preprocessed_path = f"{output_files_path}{version}_{name}.h5ad"
adata_human_virus = sc.read_h5ad(preprocessed_path)

In [None]:
# Classify features by read origin using their name prefix.
# Features named 'RSV...' and features named 'genome_RSV...' are kept as two
# separate lists of names; human features are kept as a boolean mask over
# var_names.
RSV = [feature for feature in adata_human_virus.var_names
       if feature.startswith('RSV')]
RSV_genome = [feature for feature in adata_human_virus.var_names
              if feature.startswith('genome_RSV')]
human_genes = adata_human_virus.var_names.str.startswith('GRCh38_')

# Every viral feature name, 'RSV...' entries first.
virus_genes = [*RSV, *RSV_genome]

## Gene list for downstream analysis of the top Cas9 screening results

In [None]:
# Gene symbols of the top Cas9 screening results.
rsv_cas9_genes = [
    'TMEM165', 'SYS1', 'ARFRP1', 'HS6ST1', 'NDST1', 'TM9SF2', 'SLC39A9',
    'HS3ST6', 'COG3', 'COG4', 'COG5', 'COG6', 'B3GAT3', 'B3GALT6', 'SLC35B2',
    'EXT1', 'EXTL3', 'EXT2', 'B4GALT7', 'UNC50', 'ATP6V1B2', 'RAB4A',
]
# Feature names in the AnnData carry the 'GRCh38_' prefix.
cas9_gene_list = [f'GRCh38_{symbol}' for symbol in rsv_cas9_genes]

# Select the screen hits that are expressed well enough for downstream
# analysis: a gene is kept when MORE THAN 10 cells have a value of at least 5
# for it. The values are whatever .X holds at this point, i.e. before the
# normalisation below (presumably raw counts; confirm).

# Restrict the data to the screen hits present in var_names and densify it.
adata_human_virus_cas9gene_subset = adata_human_virus[
    :, adata_human_virus.var_names.isin(cas9_gene_list)
]
cell_ids = adata_human_virus_cas9gene_subset.obs.index
count_matrix = adata_human_virus_cas9gene_subset.X.toarray()
counts_df = pd.DataFrame(
    count_matrix,
    index=cell_ids,
    columns=adata_human_virus_cas9gene_subset.var_names,
)

# Per gene: number of cells with a value >= 5, then keep genes with > 10 such cells.
cells_with_counts = counts_df.ge(5).sum(axis=0)
genes_above_10 = cells_with_counts.loc[cells_with_counts > 10]
cas9_gene_subset = genes_above_10.index.tolist()
cas9_gene_subset

# Normalize, filter genes, log1p

In [None]:
# Per-cell normalisation of adata_human_virus.
# NOTE(review): `sc.pp.normalize_per_cell` is deprecated in scanpy in favour of
# `sc.pp.normalize_total`; the two are not guaranteed to be drop-in identical,
# so the original call is kept.
sc.pp.normalize_per_cell(adata_human_virus)


def _row_sums(matrix):
    """Return the per-row sums of a 2-D matrix as a flat 1-D numpy array.

    Works for scipy sparse matrices (whose ``sum`` returns an ``np.matrix``),
    sparse arrays and dense arrays alike, unlike the ``.A1`` attribute which
    only exists on ``np.matrix``.
    """
    return np.asarray(matrix.sum(axis=1)).ravel()


# Total normalised signal per cell, and its ln(x + 1).
adata_human_virus.obs['n_counts_norm'] = _row_sums(adata_human_virus.X)
adata_human_virus.obs['n_counts_norm_log'] = np.log1p(adata_human_virus.obs['n_counts_norm'])

# Sum the human and the viral transcript signal per cell, post normalisation.
# Only the 'RSV...' features are summed for the viral column; the
# 'genome_RSV...' features are not included.
adata_human_virus.obs['human_n_counts_norm'] = _row_sums(adata_human_virus[:, human_genes].X)
adata_human_virus.obs['viral_transcript_n_counts_norm'] = _row_sums(adata_human_virus[:, RSV].X)

# ln(x + 1) of the two sums above, derived from the stored columns instead of
# slicing and summing the matrix a second time.
adata_human_virus.obs['viral_transcript_n_counts_norm_log'] = np.log1p(
    adata_human_virus.obs['viral_transcript_n_counts_norm']
)
adata_human_virus.obs['human_n_counts_norm_log'] = np.log1p(
    adata_human_virus.obs['human_n_counts_norm']
)

In [None]:
# Drop genes that are detected in fewer than 3 cells of this dataset.
# NOTE(review): the name lists built earlier (RSV, RSV_genome, cas9_gene_list)
# are not refreshed here and may still contain genes removed by this filter.
sc.pp.filter_genes(adata_human_virus, min_cells=3)
# The gene axis changed, so rebuild the boolean human-gene mask to match the
# new var_names.
human_genes = adata_human_virus.var_names.str.startswith('GRCh38_')

In [None]:
# Apply log(x + 1) to the normalised matrix of adata_human_virus.
sc.pp.log1p(adata_human_virus)
# NOTE(review): these messages go to the standard-library root logger at INFO
# level; no logging configuration is visible in this notebook, so they are
# probably not displayed. Confirm whether that is intended.
logging.info('Log transforming data')
# Keep the log-transformed values reachable through `.raw`; the gene-expression
# table built further down reads from it.
adata_human_virus.raw = adata_human_virus
logging.info('Saving log(counts)+1 in .raw')

# Subset to cells whose infection_status is not 'buffer' (infected and bystander cells at every time point)

In [None]:
# Keep every cell whose infection_status differs from 'buffer'; `.copy()`
# turns the resulting view into an independent AnnData object.
adata_human_virus_subset = adata_human_virus[
    adata_human_virus.obs['infection_status'] != 'buffer', :
].copy()
adata_human_virus_subset

In [None]:
# Preview the first rows of the per-cell metadata of the subset.
adata_human_virus_subset.obs.head()

# Create a DataFrame combining cell metadata with the expression of the genes of interest

In [None]:
# Combined label '<treatment>_<infection_status>' per cell, stored as a
# categorical column.
adata_human_virus_subset.obs['treatment_infectionstatus'] = (
    adata_human_virus_subset.obs['treatment'].astype(str)
    + '_'
    + adata_human_virus_subset.obs['infection_status'].astype(str)
).astype("category")

In [None]:
# Metadata columns carried into the plotting table.
cols_of_interest = [
    'batch',
    'new_multiseq_id',
    'treatment',
    'infection_status',
    'treatment_infectionstatus',
]
# Independent copy, so later edits do not touch adata_human_virus_subset.obs.
adata_human_virus_subset_df = adata_human_virus_subset.obs.loc[:, cols_of_interest].copy()
adata_human_virus_subset_df

In [None]:
# Read the expression of the Cas9 screen hits from `.raw`.
# The min_cells gene filter ran before `.raw` was assigned, so some hits may no
# longer exist there, and indexing `.raw` with an absent name raises a KeyError.
# Only genes still present are requested, in the order of `cas9_gene_list`, and
# any skipped gene is reported.
raw_var_names = set(adata_human_virus_subset.raw.var_names)
cas9_genes_in_raw = [gene for gene in cas9_gene_list if gene in raw_var_names]
cas9_genes_missing = [gene for gene in cas9_gene_list if gene not in raw_var_names]
if cas9_genes_missing:
    logging.warning('Cas9 genes absent from .raw and skipped: %s',
                    ', '.join(cas9_genes_missing))

cas9_expression_data = adata_human_virus_subset.raw[:, cas9_genes_in_raw].X.toarray()

# Create a dataframe with expression data: one row per cell, one column per gene.
cas9_gex_df = pd.DataFrame(cas9_expression_data,
                           columns=cas9_genes_in_raw,
                           index=adata_human_virus_subset.obs_names)

In [None]:
# Place the metadata columns and the gene-expression columns side by side;
# rows are matched on the cell index.
gex_metadata_df = pd.concat(
    [adata_human_virus_subset_df, cas9_gex_df],
    axis='columns',
)

In [None]:
# Number of cells per sample (rows) and infection status (columns).
pd.crosstab(
    index=adata_human_virus_subset_df['new_multiseq_id'],
    columns=adata_human_virus_subset_df['infection_status'],
)

In [None]:
# Hand-ordered list of genes to plot, in display order (top to bottom).
# Presumably these are the genes that passed the expression filter computed
# earlier (cas9_gene_subset); confirm the two stay in sync.
cas9_gene_subset_re_ordered = [
    'GRCh38_TMEM165',
    'GRCh38_TM9SF2',
    'GRCh38_ARFRP1',
    'GRCh38_UNC50',
    'GRCh38_B4GALT7',
    'GRCh38_SLC35B2',
    'GRCh38_B3GAT3',
    'GRCh38_ATP6V1B2',
    'GRCh38_RAB4A',
]

# Figure S5

In [None]:
# Font settings come first: matplotlib reads most rcParams when a text artist
# is created, so setting them after plotting would only affect later figures
# (or a second run of this cell).
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = 'Arial'
plt.rcParams['font.size'] = 12.0
plt.rcParams['legend.fontsize'] = 12.0

gene_list = cas9_gene_subset_re_ordered
palette = {'infected': 'lightcoral', 'uninfected': 'gainsboro'}
# Left-to-right order of the samples on the shared x axis.
order = ['0hr_VC', '4hr_VC', '8hr_VC', '12hr_VC',
         '0hr_HK', '4hr_HK', '8hr_HK', '12hr_HK',
         '0hr_RSV', '4hr_RSV', '8hr_RSV', '12hr_RSV',
         ]

# One row per gene; all rows share the sample axis.
fig, axes = plt.subplots(nrows=len(gene_list), ncols=1,
                         figsize=(16, 2 * len(gene_list)), sharex=True)

# With a single row, plt.subplots returns one Axes instead of an array.
if len(gene_list) == 1:
    axes = [axes]

for ax, gene in zip(axes, gene_list):
    # Split violins per sample, one half per infection status.
    # `density_norm='count'` replaces the deprecated `scale='count'` (same
    # meaning), so the call no longer relies on a deprecated argument.
    sb.violinplot(x='new_multiseq_id', y=gene, hue='infection_status',
                  data=gex_metadata_df,
                  order=order,
                  split=True,
                  dodge=False,
                  gap=5,
                  inner=None,
                  density_norm='count',
                  palette=palette,
                  saturation=0.9,
                  alpha=0.1,
                  linewidth=0.5,
                  ax=ax)
    # Individual cells drawn on top of the violins; rasterized so that vector
    # output stays small.
    sb.stripplot(x='new_multiseq_id', y=gene, hue='infection_status',
                 data=gex_metadata_df,
                 order=order,
                 dodge=True,
                 marker='o', alpha=1, size=0.2,
                 palette=palette,
                 edgecolor=['black'],
                 linewidth=0.1,
                 legend=False,
                 ax=ax,
                 rasterized=True)

    ax.set_ylim(bottom=0)
    # Show the bare gene symbol on the y axis.
    ax.set_ylabel(gene.replace('GRCh38_', ''))
    ax.set_xlabel('')
    ax.set_xticks([])
    # Per-panel legends are dropped; a single figure-level legend is added below.
    if ax.legend_ is not None:
        ax.legend_.remove()

# Set x-label and tick labels only on the last subplot.
axes[-1].set_xlabel("Time Point - Treatment")
axes[-1].set_xticks(range(len(order)))
axes[-1].set_xticklabels(order, rotation=90)

# Common legend, placed outside of the subplots.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, title='Infection Status', loc='upper right', bbox_to_anchor=(1.1, 1.0))

sb.despine()
plt.tight_layout(rect=[0, 0, 0.95, 1])  # Leave room on the right for the legend

#plt.savefig(f"{fig_path}/cas9_subset_genes_combined_violin_plots.pdf",dpi=300, bbox_inches='tight')