# Run MOFA analysis on TEA-seq data for CD4 naive cells
- Compare CD4 naive cells between ARI (at-risk individuals) and CON2 (healthy controls)

1. Preprocess clinical lab, Olink, scRNA-seq, and scATAC-seq data
    - extract sample-level pseudobulk data
    - for ATAC-seq data, use TF activities inferred by chromVAR
2. Run MOFA
3. Examine MOFA factor 1

# set up

In [None]:
# import libraries
# Attach a package without printing its startup banner; every argument is
# forwarded unchanged to library().
quiet_library <- function(...) suppressPackageStartupMessages(library(...))
quiet_library("tidyverse")
quiet_library("Matrix")
quiet_library("viridis")
quiet_library("scran")
quiet_library("scater")
quiet_library("MOFA2")
quiet_library("data.table")
quiet_library("jsonlite")
quiet_library("parallel")
quiet_library("Seurat")
quiet_library("ggpubr")


In [None]:
# Input/output locations and the project prefix used in every saved file name.
fig_path <- "/home/jupyter/figures/preRA_teaseq/MOFA"
data_path <- "/home/jupyter/data/preRA_teaseq/EXP-00243"
meta_path <- "/home/jupyter/data/preRA_teaseq/meta_data"
output_path <- "/home/jupyter/data/preRA_teaseq/output_results/MOFA"
# Create output directories if missing. Braces (rather than wrapping the call
# in parentheses) keep dir.create()'s logical result from being auto-printed.
if (!dir.exists(fig_path)) {
    dir.create(fig_path, recursive = TRUE)
}
if (!dir.exists(output_path)) {
    dir.create(output_path, recursive = TRUE)
}
# define a project name
proj_name <- "PreRA_teaseq_MOFA_cd4na"


In [None]:
# Colour palettes used throughout the notebook.
# The npg / nejm / jama / jco vectors are named after the ggsci journal palettes.
npg_color <- c(
    "#E64B35FF", "#4DBBD5FF", "#00A087FF", "#3C5488FF", "#F39B7FFF",
    "#8491B4FF", "#91D1C2FF", "#DC0000FF", "#7E6148FF", "#B09C85FF"
)
nejm_color <- c(
    "#BC3C29FF", "#0072B5FF", "#E18727FF", "#20854EFF",
    "#7876B1FF", "#6F99ADFF", "#FFDC91FF", "#EE4C97FF"
)
jama_color <- c(
    "#374E55FF", "#DF8F44FF", "#00A1D5FF", "#B24745FF",
    "#79AF97FF", "#6A6599FF", "#80796BFF"
)
jco_color <- c(
    "#0073C2FF", "#EFC000FF", "#868686FF", "#CD534CFF",
    "#7AA6DCFF", "#003C67FF", "#8F7700FF"
)
# 30 discrete cluster colours, plus a 36-colour interpolated extension of them
cluster_colors <- c(
    "#DC050C", "#FB8072", "#1965B0", "#7BAFDE", "#882E72",
    "#B17BA6", "#FF7F00", "#FDB462", "#E7298A", "#E78AC3",
    "#33A02C", "#B2DF8A", "#55A1B1", "#8DD3C7", "#A6761D",
    "#E6AB02", "#7570B3", "#BEAED4", "#666666", "#999999",
    "#aa8282", "#d4b7b7", "#8600bf", "#ba5ce3", "#808000",
    "#aeae5c", "#1e90ff", "#00bfff", "#56ff0d", "#ffff00"
)
cluster_colors_ext <- colorRampPalette(cluster_colors)(36)
# status colours: first = CON2, second = ARI
con_ari_colors <- c("#5AAA46", "#F59F00")
# default inline plot size in the notebook
options(repr.plot.height = 15, repr.plot.width = 20)


In [None]:
# Load the project helper functions (script lives in the Teaseq-analysis repo).
source(file = "/home/jupyter/github/Teaseq-analysis/scRNA_teaseq_ananlysis_helper_functions.r")


# downstream analysis

In [None]:
# Load the trained MOFA model (HDF5 file written to output_path).
model <- output_path %>%
    file.path("mofa_preRA_cd4_rna_adt_olink_tf_pseudobulk_model_09202023.hdf5") %>%
    load_model()


In [None]:
# MOFA data-overview plot, shown inline and saved as PNG.
model %>% plot_data_overview()
ggsave(
    filename = file.path(fig_path, paste0(proj_name, "_data_overview.png")),
    width = 4, height = 4
)


In [None]:
# Prepare sample metadata: drop sample BR2024 and derive a `status` factor
# (at_risk -> ARI, healthy -> CON2) with CON2 as the reference level.
# NOTE(review): `meta_data` is not created in this section of the notebook;
# it is presumably loaded earlier (e.g. by the sourced helper script) - confirm.
meta_data <- meta_data %>%
    filter(sample != "BR2024") %>%
    mutate(
        status = cohort %>%
            recode("at_risk" = "ARI", "healthy" = "CON2") %>%
            factor(levels = c("CON2", "ARI"))
    )
meta_data


In [None]:
# Attach the prepared sample metadata (including the derived `status` column)
# to the MOFA model, so model plotting calls below can use its columns
# (e.g. color_by = "cohort").
samples_metadata(model) <- meta_data


In [None]:
# Long-format factor scores (one row per sample x factor), joined to the
# sample metadata; `cohort` is re-levelled so "healthy" comes first.
factors_values <- model %>%
  get_factors(factors = "all", as.data.frame = TRUE) %>%
  as_tibble() %>%
  left_join(samples_metadata(model), by = "sample") %>%
  mutate(cohort = factor(cohort, levels = c("healthy", "at_risk")))
factors_values %>% head()


In [None]:
# Correlation between factors (ideally close to zero off the diagonal).
model %>% plot_factor_cor()


In [None]:
# Per-factor R^2 for every view, read from the model cache for the first
# group (this model uses a single group).
variance_exp <- as_tibble(
    model@cache[["variance_explained"]][["r2_per_factor"]][[1]],
    rownames = "factor"
)
head(variance_exp)


In [None]:
# Total variance explained per modality (all factors together), as a bar chart.
total_variance <- calculate_variance_explained(model)[["r2_total"]][["single_group"]] %>%
    as_tibble(rownames = "modality") %>%
    dplyr::rename(variance = value)
ggbarplot(total_variance, x = "modality", y = "variance", fill = "steelblue")


In [None]:
model %>% calculate_variance_explained()


In [None]:
# Percentage of variance explained by Factor1 in each modality (horizontal bars).
f1_variance <- calculate_variance_explained(model)[["r2_per_factor"]][["single_group"]] %>%
    as_tibble(rownames = "factor") %>%
    filter(factor == "Factor1") %>%
    pivot_longer(-factor, names_to = "modality", values_to = "variance") %>%
    mutate(modality = modality %>% factor(levels = c("olink", "rna", "tf", "adt")))
ggbarplot(
    f1_variance,
    x = "modality", y = "variance",
    fill = "modality", palette = npg_color,
    title = "% variance \nexplained by F1", xlab = "", ylab = "",
    legend = "none", rotate = TRUE
)
ggsave(
    filename = file.path(fig_path, paste0(proj_name, "_variance_decomposition_f1.pdf")),
    width = 3, height = 3
)


In [None]:
# Heatmap of variance explained (view x factor) with short axis labels.
p1 <- plot_variance_explained(model, x = "view", y = "factor") +
    scale_x_discrete(labels = c(adt = "ADT", olink = "OLINK", rna = "RNA", tf = "ATAC (tf)")) +
    scale_y_discrete(labels = setNames(paste0("F", 1:6), paste0("Factor", 1:6))) +
    theme(
        axis.text.x = element_text(size = 16, angle = 45, hjust = 1),
        axis.text.y = element_text(size = 16)
    )
p1
ggsave(
    filename = file.path(fig_path, paste0(proj_name, "_variance_decomposition_factor_heatmap.pdf")),
    width = 3, height = 3
)


In [None]:
# Variance-explained heatmap with axis titles only.
# Note: axis.text.x is blanked below, so the positional x labels are not drawn.
p1 <- plot_variance_explained(model, x = "view", y = "factor")
p1 +
    xlab("Modalities") +
    ylab("Factors") +
    scale_x_discrete(labels = c("ADT", "Plasma\nProtein", "RNA", "TF")) +
    theme(
        axis.title = element_text(size = 24),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        axis.text.x = element_blank()
    )
ggsave(
    filename = file.path(fig_path, paste0(proj_name, "_variance_decomposition_facoter_modality.pdf")),
    width = 4, height = 4
)


In [None]:
# With plot_total = TRUE a list of two plots is returned; show and save the second.
plot_variance_explained(model, x = "group", y = "factor", plot_total = TRUE)[[2]]
ggsave(filename = file.path(fig_path, paste0(proj_name, "_total_variance_explained.pdf")), width = 4, height = 4)


In [None]:
# Test every MOFA factor for an ARI vs CON2 difference, adjusting for age.
#
# glm_test(): fit one Gaussian GLM on `data` with the given formula string and
# return broom's tidy coefficient table (term, estimate, std.error,
# statistic, p.value).
glm_test <- function(data, formula) {
    fit <- stats::glm(as.formula(formula), data = data)
    broom::tidy(fit)
}
stats_glm <- factors_values %>%
    mutate(status = factor(status, levels = c("CON2", "ARI"))) %>%
    group_by(factor) %>%
    group_modify(~ glm_test(.x, formula = "value ~ status + age")) %>%
    ungroup()
# Multiple-testing correction across factors for the status effect only.
# Restrict to the `statusARI` term on an ungrouped table *before* adjusting;
# adjusting the grouped table (intercept removed) would mix in the age
# coefficients and, for grouped input, adjust within each factor separately
# instead of across factors.
stats_glm %>%
    filter(term == "statusARI") %>%
    rstatix::adjust_pvalue(p.col = "p.value", method = "BH") %>%
    arrange(p.value.adj)
# Recorded observation: Factor1 differs significantly between ARI and healthy
# (re-check against the corrected adjusted p-values above).


In [None]:
# Factor1 scores by status (CON2 vs ARI): boxplot with jittered points.
factors_values %>%
    filter(factor == "Factor1") %>%
    ggpubr::ggboxplot(
        x = "status", y = "value",
        color = "status", palette = con_ari_colors,
        add = "jitter"
    ) +
    # ggpubr::stat_compare_means() +
    NoLegend() +
    labs(title = "F1", x = "", y = "Factor score") +
    theme(
        plot.title = element_text(hjust = 0.5, size = 16),
        axis.title.y = element_text(size = 16),
        axis.text.x = element_text(size = 16),
        axis.text.y = element_text(size = 16)
    )

ggsave(filename = file.path(fig_path, paste0(proj_name, "_factor1_status.pdf")), width = 3, height = 3)


In [None]:
# Scores of every factor except Factor1 by status, one violin panel per factor.
factors_values %>%
  filter(factor != "Factor1") %>%
  ggpubr::ggviolin(
    x = "status", y = "value",
    color = "status", palette = con_ari_colors,
    add = "dotplot"
  ) +
  # ggpubr::stat_compare_means() +
  NoLegend() +
  facet_wrap(vars(factor)) +
  labs(x = "", y = "Factor score") +
  theme(plot.title = element_text(hjust = 0.5))
ggsave(filename = file.path(fig_path, paste0(proj_name, "_notsig_factor_status.pdf")), width = 4, height = 4)


## analyze Factor1

In [None]:
# Factor1 weights for each view: top 10 features highlighted, weights scaled
# to [-1, 1], sign kept (abs = FALSE).
# NOTE(review): this assigns a global *value* named `factor`; calls such as
# factor(...) still reach base::factor because R skips non-function bindings
# when looking up a called name, but a different name would be clearer.
factor <- 1
p1 <- plot_weights(model, view = "rna", factor = factor, nfeatures = 10, scale = TRUE, abs = FALSE)
p2 <- plot_weights(model, view = "adt", factor = factor, nfeatures = 10, scale = TRUE, abs = FALSE)
p3 <- plot_weights(model, view = "olink", factor = factor, nfeatures = 10, scale = TRUE, abs = FALSE)
p4 <- plot_weights(model, view = "tf", factor = factor, nfeatures = 10, scale = TRUE, abs = FALSE)
cowplot::plot_grid(p1, p2, p3, p4, nrow = 2)
ggsave(
  filename = file.path(fig_path, paste0(proj_name, "_cd4na_factor1_weights.png")),
  width = 12, height = 8
)


In [None]:
# Long-format feature weights for all views; `direction` is "up" for positive
# weights and "down" otherwise.
weight <- model %>%
    get_weights(views = "all", as.data.frame = TRUE) %>%
    mutate(direction = if_else(value > 0, "up", "down"))
head(weight)


In [None]:
# Persist all feature weights as TSV.
write_tsv(weight, file.path(output_path, "Mofa_preRA_cd4na_09202023_factor_weights.tsv"))


In [None]:
# Dot plot of the 10 largest (most positive) Factor1 weights in each view.
# slice_max() orders on the signed weight, so only positive weights are kept.
f1_top_weights <- weight %>%
  filter(factor == "Factor1") %>%
  group_by(view) %>%
  slice_max(order_by = value, n = 10) %>%
  arrange(desc(value)) %>%
  # ungroup before relabelling `view`, which was the grouping column
  ungroup() %>%
  mutate(
    view = recode(view,
      "rna" = "RNA", "tf" = "Transcription Factor",
      "adt" = "Surface protein", "olink" = "plasma protein"
    ),
    # Strip only the leading "<view>_" prefix. Splitting on every "_" and
    # keeping the second piece truncated names that themselves contain "_".
    feature = str_remove(feature, "^[^_]*_")
  )
ggdotchart(f1_top_weights,
  x = "feature", y = "value",
  color = "view", # Color by groups
  palette = npg_color, # Custom color palette
  sorting = "descending", # Sort value in descending order
  rotate = TRUE, # Rotate vertically
  dot.size = 3, # Large dot size
  y.text.col = TRUE, # Color y text by groups
  ggtheme = theme_pubr() # ggplot2 theme
) +
  theme_cleveland()
ggsave(file.path(fig_path, paste0(proj_name, "_cd4na_factor1_weights_dotplot.pdf")),
  width = 6, height = 8
)


### Check TF activity associated with Factor 1

In [None]:
# Factor1 weights for all TFs, with labels used for plotting:
#   `Transcription factors`          "NFATs" for NFATC1-4, otherwise "Other TFs"
#   `Transcription Factor Activity`  "Enriched in ARI" when direction == "up"
#                                    (positive weight), else "Enriched in Controls"
#   rank                             ascending rank of the weight *within* each
#                                    direction group (data stay grouped by direction)
f1_tf_weights <- weight %>%
    filter(factor == "Factor1", view == "tf") %>%
    group_by(direction) %>%
    arrange(desc(value)) %>%
    mutate(
        `Transcription factors` = if_else(feature %in% paste0("tf_NFATC", 1:4), "NFATs", "Other TFs"),
        view = recode(view,
            "rna" = "RNA", "tf" = "Transcription Factor",
            "adt" = "Surface protein", "olink" = "plasma protein"
        ),
        feature = str_split(feature, pattern = "_", simplify = TRUE)[, 2],
        `Transcription Factor Activity` = factor(
            if_else(direction == "up", "Enriched in ARI", "Enriched in Controls"),
            levels = c("Enriched in Controls", "Enriched in ARI")
        ),
        rank = rank(value)
    )
head(f1_tf_weights)
# f1_tf_weights%>%ggplot(aes(x=value, y=))


#### Plot figure: top Factor 1 TF weights (dot plot)

In [None]:
# Dot plot of the strongest Factor1 TF weights: top 15 by |weight| within each
# direction group (more if tied), coloured by whether the weight is labelled
# "Enriched in ARI" (direction == "up") or "Enriched in Controls".
options(repr.plot.width = 5, repr.plot.height = 5)
f1_top_tf_weights <- weight %>%
  filter(factor == "Factor1", view == "tf") %>%
  group_by(direction) %>%
  slice_max(order_by = abs(value), n = 15) %>%
  arrange(desc(value)) %>%
  mutate(
    `Transcription factors` = if_else(feature %in% paste0("tf_NFATC", 1:4), "NFATs", "Other TFs"),
    view = recode(view,
      "rna" = "RNA", "tf" = "Transcription Factor",
      "adt" = "Surface protein", "olink" = "plasma protein"
    ),
    feature = str_split(feature, pattern = "_", simplify = TRUE)[, 2],
    `TF Activity` = factor(
      if_else(direction == "up", "Enriched in ARI", "Enriched in Controls"),
      levels = c("Enriched in Controls", "Enriched in ARI")
    )
  )

ggdotchart(
  f1_top_tf_weights,
  x = "feature", y = "value",
  color = "TF Activity", palette = con_ari_colors,
  sorting = "descending", rotate = TRUE,
  dot.size = 3, y.text.col = TRUE,
  ylab = "Weight associated with Factor 1"
) +
  theme_cleveland() +
  geom_hline(yintercept = 0, linetype = 2) +
  theme(
    legend.position = "top", legend.box = "vertical", legend.margin = margin(),
    axis.text.y = element_text(color = "black")
  ) +
  guides(color = guide_legend(nrow = 2))
ggsave(
  filename = file.path(fig_path, paste0(proj_name, "_cd4na_factor1_tf_weights_dotplot.pdf")),
  width = 5, height = 5
)


In [None]:
### plot fig S7D

In [None]:
# chromVAR z-scores of NFATC1-3 and FOXP3 by status, with Wilcoxon p-values.
# Subjects whose ID contains "CU" are treated as ARI, all others as CON2.
# NOTE(review): `chromvar_tf` is not created in this section - confirm its source.
p1 <- chromvar_tf %>%
    mutate(
        status = if_else(str_detect(subject_id, "CU"), "ARI", "CON2") %>%
            factor(levels = c("CON2", "ARI"))
    ) %>%
    filter(tf %in% c("NFATC3", "NFATC1", "NFATC2", "FOXP3")) %>%
    ggpubr::ggboxplot(
        x = "status", y = "value",
        color = "status", palette = con_ari_colors, add = "jitter",
        xlab = "", ylab = "Chromvar Z scores"
    ) +
    ggpubr::stat_compare_means(
        mapping = aes(label = paste0("p = ", after_stat(p.format))),
        method = "wilcox.test", label.y = 2.5
    ) +
    facet_wrap(vars(tf), ncol = 2)
p1
ggsave(filename = file.path(fig_path, paste0(proj_name, "_NFATs_Chromvar_zscores.pdf")), width = 4, height = 4)


In [None]:
# Rank plot of all Factor1 TF weights (rank 1 = most positive weight), with
# NFATC1-4, FOXP3, BATF3 and EGR2 labelled and a reference line at weight 0.4.
f1_tf_values <- weight %>%
  filter(view == "tf" & factor == "Factor1") %>%
  arrange(desc(value)) %>%
  mutate(
    # row_number() yields the same 1..n ranks as 1:length(value) but stays
    # correct when no rows match (1:0 would give c(1, 0) and fail).
    rank = row_number(),
    label_feature = if_else(str_remove(feature, "tf_") %in% c("NFATC3", "NFATC2", "NFATC1", "NFATC4", "FOXP3", "BATF3", "EGR2"),
      feature, NA
    )
  )
f1_tf_values %>% ggplot(aes(x = value, y = rank, label = label_feature, color = label_feature)) +
  ggrepel::geom_text_repel(size = 6, nudge_x = 0.1) +
  geom_vline(xintercept = 0.4) +
  scale_y_reverse() +
  geom_point(size = 0.8) +
  theme_bw() +
  theme(
    legend.position = "none",
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  )
ggsave(
  file.path(fig_path, paste0(
    proj_name,
    "_cd4na_factor1_TF_weights_NFAT.pdf"
  )),
  width = 4, height = 4
)


In [None]:
# Factor1 TF weights (top 20 highlighted, NFATC1-4 labelled manually), with a
# candidate cutoff line at 0.4 and repelled labels for the features.
p1 <- plot_weights(
  model,
  view = "tf", factor = 1,
  nfeatures = 20, scale = TRUE, abs = FALSE,
  manual = c("tf_NFATC3", "tf_NFATC2", "tf_NFATC1", "tf_NFATC4"),
  text_size = 4
)

p1 +
  geom_vline(xintercept = 0.4) +
  ggrepel::geom_label_repel(aes(label = feature)) #+ scale_color_manual(values = cluster_colors)
ggsave(
  filename = file.path(fig_path, paste0(proj_name, "_cd4na_factor1_TF_weights.pdf")),
  width = 4, height = 4
)


In [None]:
# Factor1 TF weights, excluding motifs whose name contains ENSG or DUX, sorted
# by |weight|. `direction` labels positive weights "healthy", others "at-risk".
# NOTE(review): the TF dot-plot cells in this notebook label positive Factor1
# weights "Enriched in ARI", the opposite of this labelling - confirm which
# sign corresponds to ARI (e.g. from the Factor1 score boxplot).
factor1_tf <- weight %>%
    filter(factor == "Factor1", view == "tf", !str_detect(feature, "ENSG|DUX")) %>%
    mutate(direction = if_else(value > 0, "healthy", "at-risk")) %>%
    arrange(desc(abs(value)))
# factor1_tf %>% group_by(direction) %>% slice_max(order_by = abs(value),n = 60)


In [None]:
# Factor1 score vs activity of NFATC3 and selected partner TFs, coloured by cohort.
plot_data_scatter(
     model,
     factor = 1,
     view = "tf",
     features = paste0("tf_", c("NFATC3", "STAT5A", "STAT3", "JUND", "BCL6", "FOS", "JUNB", "BATF")),
     color_by = "cohort"
)
ggsave(
     filename = file.path(fig_path, paste0(proj_name, "_facoter1_tf_correlation_cohort_nfat_partners.png")),
     width = 8, height = 4
)


In [None]:
# Per-sample TF activity (tf view data, "tf_" prefix removed) joined to each
# sample's Factor1 score and metadata.
f1_tf_data <- model %>%
    get_data(view = "tf", as.data.frame = TRUE) %>%
    mutate(feature = str_remove(feature, "tf_")) %>%
    dplyr::rename(TF_activity = value) %>%
    left_join(
        factors_values %>%
            filter(factor == "Factor1") %>%
            dplyr::rename(Factor_score = value),
        by = "sample"
    )
head(f1_tf_data)


In [None]:
# Factor1 score vs chromVAR activity for FOXP3 and NFATC1-3, with a regression
# line, confidence band and Spearman correlation per panel.
options(repr.plot.width = 5, repr.plot.height = 5)
f1_tf_data_nfat <- filter(f1_tf_data, feature %in% c("FOXP3", "NFATC1", "NFATC2", "NFATC3"))
f1_tf_data_nfat %>%
    ggpubr::ggscatter(
        x = "Factor_score", y = "TF_activity",
        xlab = "Factor1 score", ylab = "ChromVAR activity",
        color = "status", palette = con_ari_colors, facet.by = "feature",
        add = "reg.line", add.params = list(color = "blue", fill = "lightgray"),
        conf.int = TRUE,
        cor.coef = TRUE,
        cor.coeff.args = list(method = "spearman", label.x = -0.8, label.y = 1.5, label.sep = "\n"),
        ggtheme = theme_classic2()
    ) +
    scale_fill_manual(values = con_ari_colors) +
    theme(legend.position = "top")
ggsave(
    filename = file.path(fig_path, paste0(proj_name, "_facoter1_tf_correlation_cohort_nfat_foxp3.pdf")),
    width = 5, height = 5
)


### check RNA

In [None]:
# Factor1 RNA weights sorted by |weight|, plus the top 50 per direction.
# NOTE(review): positive weights are labelled "healthy" here, while the TF
# dot-plot cells label positive weights "Enriched in ARI" - confirm the sign.
factor1_rna <- weight %>%
    filter(factor == "Factor1", view == "rna") %>%
    mutate(direction = if_else(value > 0, "healthy", "at-risk")) %>%
    arrange(desc(abs(value)))
factor1_rna_sel <- factor1_rna %>%
    group_by(direction) %>%
    slice_max(abs(value), n = 50)


In [None]:
# Factor1 score vs STIM1 / STIM2 expression, coloured by cohort.
p1 <- plot_data_scatter(
      model,
      factor = 1,
      view = "rna",
      features = paste0("rna_", c("STIM1", "STIM2")),
      color_by = "cohort"
)
p1
ggsave(
      filename = file.path(fig_path, paste0(proj_name, "_facoter1_rna_correlation_Ca_channel.png")),
      width = 6, height = 3
)


In [None]:
# Cytokine-receptor / co-stimulation genes (IL*, IFN*R, TGF*, CD28, CD69) with
# |Factor1 weight| > 0.6 (PPIL4 excluded), then their scatter vs Factor1.
il_rna <- factor1_rna %>%
       filter(
              str_detect(feature, "IL\\d|IFN.R|TGF|CD28|CD69"),
              abs(value) > 0.6,
              feature != "rna_PPIL4"
       )
il_rna
p1 <- plot_data_scatter(
       model,
       factor = 1,
       view = "rna",
       features = as.character(il_rna$feature),
       color_by = "cohort"
)
p1
# il_rna <- factor1_rna %>%
#     filter(str_detect(feature, 'IL\\d|IFN.R|TGF|CD28|CD69')&
#            abs(value)>0.5)
# p2 <- plot_data_scatter(model, factor=1,
#                   features = il_rna$feature %>% as.character(),
#                         view='rna', color_by='cohort')
ggsave(
       filename = file.path(fig_path, paste0(proj_name, "_facoter1_rna_correlation_cohort_il_receptor.png")),
       width = 8, height = 6
)


In [None]:
# Check exhaustion-associated genes against Factor1.
# The previous version of this cell copied the IL-receptor cell: it filtered
# with the IL regex, printed/plotted `il_rna`, and overwrote the IL-receptor
# figure. It now uses its own gene set and output file.
# NOTE(review): the marker list is an assumed exhaustion / inhibitory-receptor
# panel (TIGIT, PDCD1 and ICOS mirror the ADT exhaustion check; HAVCR2 appears
# in commented code there) - TODO confirm with the analyst.
exh_markers <- paste0("rna_", c("TIGIT", "PDCD1", "ICOS", "HAVCR2", "LAG3", "CTLA4", "TOX"))
exh_rna <- factor1_rna %>%
       filter(feature %in% exh_markers)
exh_rna
# plot_data_scatter() needs at least one feature present in the rna view
if (nrow(exh_rna) > 0) {
       p1 <- plot_data_scatter(model,
              factor = 1,
              features = exh_rna$feature %>% as.character(),
              view = "rna", color_by = "cohort"
       )
       print(p1)
       ggsave(file.path(fig_path, paste0(proj_name, "_facoter1_rna_correlation_cohort_exhaustion.png")),
              plot = p1, width = 8, height = 6
       )
}


In [None]:
# TCR / calcium-signalling genes vs Factor1.
# `tcr_gene` holds the RNA-view feature names (with "rna_" prefix).
tcr_gene <- paste0("rna_", c(
       "VAV1", "GRB2", "GRAP2", "NFATC2", "NFAT5", "CBL",
       "CARD11", "LCK", "IKBKB", "CD4", "PIK3R1", "MALT1",
       "RAF1", "SOS1", "NCK2", "CD3G", "LCP2", "LAT", "NCK1", "CD3D", "CD3E", "RASGRP1", "MAP2K1",
       "PTPRC", "NFATC3", "MAP3K7", "PPP3CC", "PPP3CA", "GSK3B"
))
# genes in the list (or matching PPP3CC/PLCG2/CALM1) with |weight| > 0.6
factor1_rna %>%
       filter(
              str_detect(feature, "PPP3CC|PLCG2|CALM1") | feature %in% tcr_gene,
              abs(value) > 0.6
       )

# genes plotted as the "pos" panel
p1 <- plot_data_scatter(
       model,
       factor = 1,
       view = "rna",
       features = paste0("rna_", c("PPP3CC", "PLCG2", "PPP3CA", "LCK", "CD4", "CD3G", "CD3D")),
       color_by = "cohort"
)
p1
ggsave(
       filename = file.path(fig_path, paste0(proj_name, "_facoter1_rna_correlation_cohort_ca_signaling_pos.png")),
       width = 8, height = 6
)
# genes plotted as the "neg" panel
p2 <- plot_data_scatter(
       model,
       factor = 1,
       view = "rna",
       features = paste0("rna_", c("CABIN1", "CSNK1A1")),
       color_by = "cohort"
)
p2
ggsave(
       filename = file.path(fig_path, paste0(proj_name, "_facoter1_rna_correlation_cohort_ca_signaling_neg.png")),
       width = 6, height = 3
)


In [None]:
# Factor1 score vs expression of every RNA feature whose name matches NFAT or FOXP3.
# factor1_rna %>%  filter(str_detect(feature,'NFAT|FOXP3'))

# c('rna_STAT3',  'rna_NFATC3', 'rna_BATF') %in% weight$feature
plot_data_scatter(
      model,
      factor = 1,
      view = "rna",
      features = factor1_rna %>%
            filter(str_detect(feature, "NFAT|FOXP3")) %>%
            pull(feature) %>%
            as.character(),
      color_by = "cohort"
)
ggsave(
      filename = file.path(fig_path, paste0(proj_name, "_facoter1_rna_correlation_cohort_NFATs.png")),
      width = 8, height = 4
)


In [None]:
# Strip the "rna_" prefix from the RNA-view row names stored in model@data.
# NOTE(review): this edits the model object in place, and only model@data is
# renamed; any RNA-view MOFA plot run after this cell may see mismatched
# feature names - verify this is intended.
rownames(model@data$rna$single_group) <- str_remove(rownames(model@data$rna$single_group), "rna_")


In [None]:
# Heatmap of pseudobulk expression for a set of genes, annotated by ARI/CON2 status.
#
# The assay matrix (genes x samples) is transposed before plotting, so in the
# drawn heatmap samples are ROWS and genes are COLUMNS.
#
# Arguments:
#   pseudo           SummarizedExperiment-like object with an assay named
#                    `assay` and a `status` column in colData ("CON2"/"ARI").
#   genes            row names of `pseudo` to plot; genes absent from `pseudo`
#                    are dropped with a warning.
#   gene_meta        unused; kept for interface compatibility.
#   genes_hightlight optional subset of `genes` to label with callout marks
#                    along the gene (column) axis.
#   assay            name of the assay to plot.
#   celltype_col, celltype_colors, batch_colors
#                    unused; kept for interface compatibility.
#   scale            if TRUE, z-score each gene across samples; genes with zero
#                    variance (NaN after scaling) are shown as 0.
# Returns a ComplexHeatmap Heatmap object; render it with draw().
PlotDegHeatmap <- function(pseudo, genes, gene_meta = NULL, genes_hightlight = NULL,
                           assay = "normalized_counts", celltype_col,
                           celltype_colors = cluster_colors, batch_colors = cluster_colors_ext, scale = TRUE) {
    # library() fails loudly if the package is missing (require() only warns).
    # The package stays attached because later cells call draw() unqualified.
    library("ComplexHeatmap")

    # keep only genes present in `pseudo` instead of failing with
    # "subscript out of bounds"
    missing_genes <- setdiff(genes, rownames(pseudo))
    if (length(missing_genes) > 0) {
        warning("genes not found in `pseudo` are dropped: ",
                paste(missing_genes, collapse = ", "), call. = FALSE)
    }
    genes <- intersect(genes, rownames(pseudo))
    if (length(genes) == 0) {
        stop("none of the requested genes are present in `pseudo`.", call. = FALSE)
    }

    # genes x samples expression matrix
    gex_matrix <- assay(pseudo, assay)[genes, , drop = FALSE] %>% as.matrix()

    # sample annotation, drawn on the left (samples are rows after the transpose)
    metadata <- colData(pseudo) %>% as.data.frame()
    sample_anno <- rowAnnotation(
        df = metadata %>% dplyr::select(c(status)) %>%
            as.data.frame(),
        col = list(status = c("CON2" = con_ari_colors[1], "ARI" = con_ari_colors[2])),
        annotation_legend_param = list(status = list(direction = "horizontal"))
    )

    # Optional gene callouts. Genes are the columns of the plotted matrix, so
    # the marks must be a column annotation.
    genename_anno <- NULL
    if (!is.null(genes_hightlight)) {
        if (!all(genes_hightlight %in% genes)) {
            stop("gene(s) to hightlight are not in the deg list.")
        }
        highlight_index <- which(rownames(gex_matrix) %in% genes_hightlight)
        genename_anno <- columnAnnotation(
            genes = anno_mark(
                at = highlight_index,
                labels = rownames(gex_matrix)[highlight_index],
                which = "column",
                labels_gp = gpar(fontsize = 12)
            )
        )
    }

    set.seed(1221)
    if (scale) {
        gex_matrix <- t(scale(t(gex_matrix)))
        # zero-variance genes become NaN after scaling; show them as 0
        gex_matrix[is.nan(gex_matrix)] <- 0
    }

    ComplexHeatmap::Heatmap(t(gex_matrix),
        left_annotation = sample_anno,
        top_annotation = genename_anno,
        cluster_rows = TRUE,
        col = colorRampPalette(c(nejm_color[2], "white", nejm_color[1]))(100),
        row_names_max_width = unit(10, "cm"),
        show_column_names = TRUE, show_row_names = TRUE,
        column_names_gp = gpar(fontsize = 12),
        row_names_gp = gpar(fontsize = 12),
        heatmap_legend_param = list(
            title = "Scaled\nexpression", title_gp = gpar(fontsize = 12),
            legend_height = unit(6, "cm"), direction = "horizontal"
        )
    )
}


In [None]:
# Inspect cohort labels, then derive a `status` factor (at_risk -> ARI,
# everything else -> CON2) that PlotDegHeatmap uses for its annotation.
unique(colData(cd4na_rna_psudo)$cohort)
colData(cd4na_rna_psudo)$status <- (colData(cd4na_rna_psudo)$cohort == "at_risk") %>%
    if_else("ARI", "CON2") %>%
    factor(levels = c("CON2", "ARI"))
unique(colData(cd4na_rna_psudo)$status)


In [None]:
# TCR- and calcium-signalling genes for the pseudobulk heatmap (plain gene
# symbols, no "rna_" prefix), de-duplicated with first-occurrence order kept.
tcr_gene <- unique(c(
    # TCR pathway genes
    "VAV1", "GRB2", "GRAP2", "NFATC2", "NFAT5", "CBL",
    "CARD11", "LCK", "IKBKB", "CD4", "PIK3R1", "MALT1",
    "RAF1", "SOS1", "NCK2", "LCP2", "LAT", "NCK1", "RASGRP1", "MAP2K1",
    "PTPRC", "NFATC3", "MAP3K7", "PPP3CC", "PPP3CA", "GSK3B",
    # additional calcium / calcineurin-NFAT related genes
    "PLCG2", "CABIN1", "STIM1", "STIM2", "CSNK1A1", "NFATC1",
    "NFATC2IP"
))


In [None]:
# Draw the TCR / calcium gene heatmap into a PDF.
# The device is closed in `finally`, so an error while building or drawing the
# heatmap no longer leaves the PDF device open (which would capture later plots).
pdf(file.path(fig_path, paste0(proj_name, "_facoter1_rna_tcr_ca_heatmap.pdf")),
    width = 6, height = 4
)
tryCatch(
    {
        # assigned at top level so the next cell can redraw p1 inline
        p1 <- PlotDegHeatmap(cd4na_rna_psudo, genes = tcr_gene)
        draw(p1, heatmap_legend_side = "top")
    },
    finally = dev.off()
)


In [None]:
# show the heatmap inline at 6 x 4
options(repr.plot.height = 4, repr.plot.width = 6)
draw(p1, heatmap_legend_side = "top")


### check ADT

In [None]:
# Factor1 ADT weights (top 30 highlighted) with a candidate cutoff line at 0.5.
p1 <- plot_weights(model, view = "adt", factor = 1, nfeatures = 30, scale = TRUE, abs = FALSE)
p1 + geom_vline(xintercept = 0.5)
ggsave(
  filename = file.path(fig_path, paste0(proj_name, "_facoter1_adts.png")),
  width = 5, height = 5
)


In [None]:
# Factor1 ADT weights sorted by |weight|, the top 25 per direction, and the
# 20 most positive of those.
# NOTE(review): positive weights are labelled "healthy" here, while the TF
# dot-plot cells label positive weights "Enriched in ARI" - confirm the sign.
factor1_adt <- weight %>%
    filter(factor == "Factor1", view == "adt") %>%
    mutate(direction = if_else(value > 0, "healthy", "at-risk")) %>%
    arrange(desc(abs(value)))
factor1_adt_sel <- factor1_adt %>%
    group_by(direction) %>%
    slice_max(abs(value), n = 25)
factor1_adt_sel %>% arrange(desc(value)) %>% head(20)


In [None]:
# Exhaustion-related surface proteins (TIGIT, CD278, CD279) vs Factor1 score.
p1 <- plot_data_scatter(
      model,
      factor = 1,
      view = "adt",
      features = paste0("adt_", c("TIGIT", "CD278", "CD279")),
      color_by = "cohort"
)
# p2 <- plot_data_scatter(model, factor=1,
#                   features = c('rna_TIGIT', 'rna_HAVCR2', 'adt_PDCD1'),
#                         view='rna', color_by='cohort')
# cowplot::plot_grid(p1, p2, nrow = 2)
p1
ggsave(
      filename = file.path(fig_path, paste0(proj_name, "_facoter1_adt_correlation_cohort_cd278_cd279_TIGIT.png")),
      width = 9, height = 3
)


In [None]:
# Factor1 score vs CX3CR1, CD64 and Ig kappa light chain ADT signal, by cohort.
plot_data_scatter(
       model,
       factor = 1,
       view = "adt",
       features = c("adt_CX3CR1", "adt_CD64", "adt_Ig-light-chain-k"),
       color_by = "cohort"
)
ggsave(
       filename = file.path(fig_path, paste0(proj_name, "_facoter1_adt_correlation_cohort_CX3CR1_fc.png")),
       width = 8, height = 4
)


In [None]:
# Heatmap of the Olink view data for the top 20 Factor1 features
# (features = 20, undenoised, row-scaled, rows and columns clustered),
# written to a 300-dpi PNG.
p1 <- plot_data_heatmap(model, # max.value = 3,
  view = "olink",
  factor = 1, main = "olink", fontsize_row = 6,
  features = 20, denoise = FALSE,
  cluster_rows = TRUE, cluster_cols = TRUE,
  show_rownames = TRUE, show_colnames = TRUE,
  scale = "row"
)
png(file.path(fig_path, paste0(proj_name, "_facoter1_olink_top20_heatmap.png")),
  units = "in", res = 300, width = 5, height = 5
)
# Close the PNG device even if drawing fails, so later plots are not
# silently written into this file.
tryCatch(print(p1), finally = dev.off())


In [None]:
# Factor1 score vs plasma IL17D (Olink), coloured by cohort.
plot_data_scatter(
       model,
       factor = 1,
       view = "olink",
       features = "olink_IL17D",
       color_by = "cohort"
)
ggsave(
       filename = file.path(fig_path, paste0(proj_name, "_facoter1_olink_correlation_cohort_IL17D.pdf")),
       width = 5, height = 4
)
