# Import packages and data

In [None]:
from pathlib import Path
import re

import matplotlib as mpl
import matplotlib.pyplot as plt
from natsort import natsorted
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder

# Global seaborn look for every figure below: "ticks" style, "paper" context,
# tab10 default palette, unscaled fonts.
sns.set(style="ticks", context="paper", palette="tab10", font_scale=1)

In [None]:
# Load the AnnData object analysed in this notebook. Later cells read its
# "counts"/"norm" layers and "X_umap"/"X_scVI_OSN" embeddings.
ADATA_PATH = "COVID_dataset_scvi_OSN.h5ad"
adata = sc.read_h5ad(ADATA_PATH)

# Get olfactory receptor (OR) info per cell

In [None]:
# Annotation of human olfactory receptor (OR) genes, taken from the
# supplementary table of Barnes et al. 2020.
ANNO_URL = (
    "https://static-content.springer.com/esm/art%3A10.1186%2Fs12864-020-6583-3/"
    "MediaObjects/12864_2020_6583_MOESM2_ESM.xlsx"
)
human_anno = pd.read_excel(ANNO_URL, sheet_name=0)
# Normalise headers: spaces -> underscores, lower-case.
human_anno.columns = [name.replace(" ", "_").lower() for name in human_anno.columns]

# One row per distinct gene record, indexed by gene symbol.
gene_cols = ["gene_symbol", "gene_name", "chromosome", "strand", "ensembl_gene_id"]
bm_human = (
    human_anno.loc[:, gene_cols]
    .drop_duplicates()
    .reset_index(drop=True)
    .set_index("gene_symbol")
)

# Per gene: how many annotated transcripts fall in each transcript biotype
# (genes lacking a biotype get 0 rather than NaN).
in_bm = human_anno["gene_symbol"].isin(bm_human.index)
ct = (
    human_anno.loc[in_bm]
    .groupby("gene_symbol")["transcript_biotype"]
    .value_counts()
    .unstack()
    .fillna(0)
)

In [None]:
# Attach per-gene transcript-biotype counts to the OR annotation table.
bm_human = bm_human.join(ct[["protein_coding", "unprocessed_pseudogene"]])
func_human = bm_human[["ensembl_gene_id"]].reset_index()
func_human.columns = ["gene", "Ens"]

# Mask over adata.var marking the annotated OR genes (matched on Ensembl ID).
is_olfr = adata.var.gene_ids.isin(bm_human.ensembl_gene_id)


def _row_reduce(mat, how):
    """Reduce every row of a dense or scipy-sparse matrix to one value.

    Parameters
    ----------
    mat : numpy array or scipy sparse matrix/array
    how : str
        Name of the reduction method to call with ``axis=1`` ("sum" or "max").

    Returns
    -------
    numpy.ndarray
        Flat 1-D array with one value per row.

    Notes
    -----
    Avoids ``.A``: it does not exist on dense ndarrays, and on scipy sparse
    matrices it was deprecated (SciPy 1.11) and later removed.
    """
    reduced = getattr(mat, how)(axis=1)
    if hasattr(reduced, "toarray"):  # sparse .max() returns a sparse matrix
        reduced = reduced.toarray()
    return np.asarray(reduced).ravel()


_olfr_cols = is_olfr.to_numpy()
_olfr_counts = adata.layers["counts"][:, _olfr_cols]
_olfr_norm = adata.layers["norm"][:, _olfr_cols]
# Per-cell total and maximum OR signal, from raw counts and from the "norm" layer.
adata.obs["olfr_sum"] = _row_reduce(_olfr_counts, "sum")
adata.obs["olfr_max"] = _row_reduce(_olfr_counts, "max")
adata.obs["olfr_norm_sum"] = _row_reduce(_olfr_norm, "sum")
adata.obs["olfr_norm_max"] = _row_reduce(_olfr_norm, "max")


In [None]:
# Call each cell's dominant olfactory receptor (OR) from raw counts.
# A cell counts as OR-expressing when any OR has MORE than OLFR_THRESH counts.
OLFR_THRESH = 2
olfr_cols = is_olfr.to_numpy()
olfr_raw = adata.layers["counts"][:, olfr_cols]
# Densify without `.A`: that attribute does not exist on dense ndarrays and is
# deprecated/removed on scipy sparse matrices in recent SciPy releases.
if hasattr(olfr_raw, "toarray"):
    olfr_raw = olfr_raw.toarray()
olfr_raw = np.asarray(olfr_raw)
has_any_olfr = (olfr_raw > OLFR_THRESH).any(axis=1)
olfr_names = adata.var_names[olfr_cols]

# Split OR symbols into family / subfamily / member,
# e.g. "OR2AE1" -> ("OR2", "AE", "1").
# Subfamilies can have two letters, so the pattern uses `[A-Z]+`; a single
# `[A-Z]` would mis-split them. Symbols that do not match give a NaN row
# instead of a ragged row.
df_olfr_family = pd.Series(olfr_names, index=olfr_names).str.extract(
    r"(?P<family>OR\d+)(?P<sub_family>[A-Z]+)(?P<family_number>.*)"
)

# Top-expressed OR per cell, kept only for OR-expressing cells, plus its family.
df_olfr = pd.DataFrame(
    olfr_names[olfr_raw.argmax(axis=1)], index=adata.obs_names, columns=["top_Olfr"]
)[has_any_olfr].merge(df_olfr_family, left_on="top_Olfr", right_index=True, how="left")

In [None]:
# Per-cell plotting table: obs metadata (categoricals cast to plain str)
# joined with the UMAP coordinates.
df = adata.obs.copy()
category_columns = df.select_dtypes(include="category").columns
df[category_columns] = df[category_columns].astype(str)
df_umap = pd.DataFrame(
    adata.obsm["X_umap"], columns=["UMAP1", "UMAP2"], index=adata.obs_names
)
df = df.join(df_umap)
# 0/100 indicator, so a barplot mean reads directly as a percentage.
df["has_OR"] = has_any_olfr * 100

In [None]:
# Condition / patient annotations used for plotting.
# These are example values: adapt them to the dataset's labels.

# Raw condition labels as stored in obs["cond"] ...
COVID = "covid"
CONTROL = "normosmic"
CONDS = (CONTROL, COVID)
# ... and their display names, in the same order as CONDS.
N_COVID = "COVID"
N_CONTROL = "control"
N_CONDS = (N_CONTROL, N_COVID)

# Display name -> Set1 colour (control first, then COVID).
N_COLOR_MAP = {name: color for name, color in zip(N_CONDS, plt.cm.Set1.colors)}
# Raw condition label -> display name.
cond_mapping = {raw: shown for raw, shown in zip(CONDS, N_CONDS)}

# Values of obs["orig_ident"] whose cells are kept for the patient-level plots.
PATIENTS = ["COVID_1", "COVID_2", "Normosmic_1", "Normosmic_2"]
is_patient = df["orig_ident"].isin(PATIENTS)

df["cond2"] = df["cond"].map(cond_mapping)
df_patient = df.loc[is_patient].copy()

# Number of cells (non-null total_counts) per (condition, patient).
cells_per_patient = df.groupby(["cond", "orig_patients"], as_index=False)[
    "total_counts"
].count()
df_anno = cells_per_patient.set_index("orig_patients")

# Make plots

In [None]:
# Patient cells flagged as OR-expressing, and their UMAP coordinates as arrays.
has_OR = df_patient["has_OR"].gt(0)
xu = df_patient["UMAP1"].to_numpy()
yu = df_patient["UMAP2"].to_numpy()

In [None]:
# UMAP of patient cells: grey = no OR call, black = OR-expressing.
fig, ax = plt.subplots(figsize=(1.5, 1.5))
point_kws = {"lw": 0, "s": 3}
# Non-expressing cells are drawn first so OR-expressing cells end up on top.
no_or = ~has_OR
ax.scatter(xu[no_or], yu[no_or], color="0.5", alpha=0.5, label="no OR", **point_kws)
ax.scatter(
    xu[has_OR], yu[has_OR], color="k", alpha=0.85, label="OR-expressing", **point_kws
)
ax.axis("off")
# Two-column legend sitting just above the axes.
legend_kws = {
    "frameon": True,
    "loc": "lower center",
    "bbox_to_anchor": (0.5, 0.975),
    "handletextpad": 0,
    "ncol": 2,
    "columnspacing": 0,
    "markerscale": 2,
}
leg = ax.legend(**legend_kws)
leg.get_frame().set_linewidth(0.5)

In [None]:
# Percentage of OR-expressing cells per cluster, split by condition.
# `order` was referenced here without ever being defined (NameError); natural
# sort of the cluster labels gives e.g. GBC, INP, iOSN, mOSN.
order = natsorted(df_patient["osn_clusters"].unique())

fig, ax = plt.subplots(figsize=(1.5, 1.5))
sns.barplot(
    data=df_patient,
    x="osn_clusters",
    y="has_OR",  # 0/100 per cell, so the bar mean is a percentage
    order=order,
    hue="cond2",
    # Pin hue order and colours to N_COLOR_MAP so each condition keeps the
    # same colour as in the other figures (instead of appearance order + Set1).
    hue_order=list(N_CONDS),
    palette=N_COLOR_MAP,
    errwidth=0.75,
    ax=ax,
)
sns.despine()
ax.set_ylim(0, 100)
ax.legend(loc="upper left", bbox_to_anchor=(0, 1.05))
ax.set_xlabel("Cluster")
ax.set_ylabel("OR-expressing cells (%)")

In [None]:
# UMAP of patient cells coloured by log1p of each cell's maximum OR value in
# the "norm" layer.
fig, ax = plt.subplots(figsize=(1.5, 1.5))
cbar_ax = fig.add_axes([0.95, 0.2, 0.05, 0.6])
df_patient["olfr_log"] = np.log1p(df_patient["olfr_norm_max"])
c = df_patient["olfr_log"]
# Cap the colour scale at the 99th percentile so a few outliers don't flatten it.
vmax = np.percentile(c, 99)
# "cmo.matter" only exists once the `cmocean` package has been imported, which
# this notebook never does. Use it when registered, otherwise fall back to a
# built-in sequential colormap instead of raising ValueError.
or_cmap = "cmo.matter" if "cmo.matter" in plt.colormaps() else "magma_r"
im = ax.scatter(
    xu,
    yu,
    c=c,
    s=3,
    cmap=or_cmap,
    lw=0,
    vmax=vmax,
)
cbar = plt.colorbar(im, cax=cbar_ax)
cbar.set_label("OR expression\n(log-normalized)")
# Rasterise the colour gradient to avoid banding/seams in vector output.
cbar.solids.set_rasterized(True)
cbar.solids.set_edgecolor("face")
cbar.outline.set_visible(False)
ax.axis("off")

In [None]:
# Share of iOSN/mOSN cells whose top OR falls in each OR family, per condition.
df_patient_osn = df_patient[df_patient["osn_clusters"].isin(("iOSN", "mOSN"))].join(
    df_olfr
)
# Cells without an OR call have NaN family and are dropped by crosstab.
df_family_counts = pd.crosstab(df_patient_osn["family"], df_patient_osn["cond"])
# Column-normalise: each condition's column sums to 100 %.
df_pivot_norm = df_family_counts.div(df_family_counts.sum(axis=0), axis=1) * 100
# Select/rename via the shared condition constants rather than re-hard-coding
# "normosmic"/"covid" and their display names (those are example values).
df_to_plot = (
    df_pivot_norm[list(CONDS)].rename(columns=cond_mapping).rename_axis(columns=None)
)
# Natural order of families, e.g. OR2 before OR10.
or_fam_order = natsorted(df_to_plot.index.tolist())
df_to_plot = df_to_plot.loc[or_fam_order]

In [None]:
# Heatmap: OR family (rows) x condition (columns), % of OSNs per family.
fig, ax = plt.subplots(figsize=(1, 2))
cbar_ax = fig.add_axes([1, 0.2, 0.08, 0.6])
sns.heatmap(
    df_to_plot,
    ax=ax,
    cmap="viridis",
    cbar_ax=cbar_ax,
    cbar_kws=dict(label="% of OSNs from OR family"),
)
# White rule separating the two condition columns.
ax.axvline(1, color="white", lw=0.75)
ax.set_xlabel(None)
ax.set_ylabel("OR family")
ax.yaxis.set_tick_params(pad=1)
# NOTE(review): these target the colorbar's x-axis; with the default vertical
# colorbar that axis carries no label or ticks, so they look like no-ops.
# Confirm the intent.
cbar_xaxis = cbar_ax.xaxis
cbar_xaxis.set_label_position("top")
cbar_xaxis.set_ticks_position("top")
cbar_xaxis.set_tick_params(pad=0)

In [None]:
# Pairwise correlation distance (1 - Pearson r) between all patient cells in
# the "X_scVI_OSN" embedding, reshaped to one row per ordered cell pair.
# NOTE: this materialises n_cells**2 pairs; memory grows quadratically.
patient_mask = is_patient.to_numpy()
patient_cells = adata.obs_names[patient_mask]
X_pca = adata.obsm["X_scVI_OSN"][patient_mask]  # scVI latent, name kept as-is
dmat = metrics.pairwise_distances(X_pca, metric="correlation")
df_dmat = pd.DataFrame(dmat, index=patient_cells, columns=patient_cells)

# Mask self-pairs on a private float copy BEFORE wrapping it in a DataFrame.
# Writing through `DataFrame.values` is unreliable under pandas Copy-on-Write,
# where it can return a read-only view.
dmat_offdiag = np.array(dmat, dtype=float, copy=True)
np.fill_diagonal(dmat_offdiag, np.nan)
df_corr_nan = pd.DataFrame(
    dmat_offdiag, index=patient_cells, columns=patient_cells
).rename_axis(index="cell_1", columns="cell_2")
# Explicit dropna(): the new stack() implementation (pandas 3 /
# future_stack=True) keeps NaN entries, so self-pairs would otherwise remain.
# The "corr" column holds the distance; the name is kept for downstream cells.
df_corr_stack = df_corr_nan.stack().dropna().rename("corr").reset_index()

# Pool GBC and INP cells into a single group.
df["merge_clust"] = df.osn_clusters.replace({"GBC": "GBC/INP", "INP": "GBC/INP"})
cell_meta = df[["cond", "merge_clust"]]
# Annotate both cells of each pair; the second merge adds the _1/_2 suffixes.
df_corr_stack = df_corr_stack.merge(
    cell_meta, left_on="cell_1", right_index=True
).merge(
    cell_meta,
    left_on="cell_2",
    right_index=True,
    suffixes=("_1", "_2"),
)

In [None]:
# Median pairwise distance for every (cluster, condition) x (cluster, condition)
# combination, as a square table with (cluster, condition) MultiIndex axes.
df_piv = (
    df_corr_stack.groupby(["cond_1", "merge_clust_1", "cond_2", "merge_clust_2"])[
        "corr"
    ]
    .median()
    .reset_index()
    .pivot(
        index=["merge_clust_1", "cond_1"],
        columns=["merge_clust_2", "cond_2"],
        values="corr",
    )
)
new_order = ["GBC/INP", "iOSN", "mOSN"]
# Order BOTH axes explicitly; the heatmap assumes rows and columns follow
# new_order. Previously only the columns were reordered and the row order
# relied on the pivot's lexicographic sort happening to match.
df_pivot = df_piv[new_order].loc[new_order].copy()

In [None]:
# Heatmap of median pairwise transcriptome distance between (cluster, condition)
# groups, with a cluster colour strip on the left and a condition strip on top.
row_clusters = df_pivot.index.get_level_values(0)
o_color = LabelEncoder().fit_transform(row_clusters)[:, np.newaxis]
# One y tick per cluster, centred on that cluster's block of rows. Derived from
# the data instead of hard-coded positions and labels.
cluster_labels = list(pd.unique(row_clusters))
cluster_centers = [np.flatnonzero(row_clusters == lab).mean() for lab in cluster_labels]
# Condition strip coloured through cond_mapping -> N_COLOR_MAP so it agrees with
# the legend below. The previous LabelEncoder over raw labels sorted "covid"
# first and painted COVID columns Set1 red, the colour the legend assigns to
# control. Unmapped conditions fall back to grey.
col_conds = df_pivot.columns.get_level_values(1)
top_rgb = np.array(
    [[mpl.colors.to_rgb(N_COLOR_MAP.get(cond_mapping.get(c), "0.5")) for c in col_conds]]
)

width_ratios = [0.1, 1]
height_ratios = [0.1, 1]
fig = plt.figure(figsize=(sum(width_ratios), sum(height_ratios)))
gs = mpl.gridspec.GridSpec(
    len(height_ratios),
    len(width_ratios),
    figure=fig,
    left=0.0,
    bottom=0.0,
    right=1,
    top=1,
    height_ratios=height_ratios,
    width_ratios=width_ratios,
    wspace=0,
    hspace=0,
)
dist_ax = fig.add_subplot(gs[1, 1], xticks=[], yticks=[], xlabel=None, frameon=False)
# Correlation distance lies in [0, 2].
dist_im = dist_ax.imshow(df_pivot.values, vmin=0, vmax=2, cmap="PRGn")
cmap = mpl.colors.LinearSegmentedColormap.from_list(
    "cmap", ("#776eb8", "#f1ad0c", "#666666")
)

left_ax = fig.add_subplot(gs[1, 0], xticks=[], yticks=[], frameon=False)
left_ax.imshow(
    o_color,
    interpolation="none",
    cmap=cmap,
    aspect="auto",
    rasterized=True,
)
left_ax.set_yticks(cluster_centers)
left_ax.set_yticklabels(cluster_labels)

top_ax = fig.add_subplot(gs[0, 1], xticks=[], yticks=[], frameon=False)
top_ax.imshow(
    top_rgb,
    interpolation="none",
    aspect="auto",
    rasterized=True,
)

cb_ax = fig.add_axes([1.1, 0.09, 0.07, 0.75])
cbar = fig.colorbar(dist_im, cax=cb_ax)
cbar.outline.set_visible(False)
cbar.set_label(r"Transcriptome Distance")
cbar.set_ticks(mpl.ticker.MultipleLocator(0.5))
legend_TN = [mpl.patches.Patch(color=c, label=l) for l, c in N_COLOR_MAP.items()]
leg = top_ax.legend(
    handles=legend_TN,
    frameon=True,
    loc="lower center",
    bbox_to_anchor=(0.5, 1),
    ncol=2,
    columnspacing=0.5,
    handletextpad=0.2,
)
leg.get_frame().set_linewidth(0.5)