This script aggregates the mixed-effects models

### Import


In [None]:
library(rstatix)
library(feather)
library(tidyverse)
library(rjson)
library(lme4)
library(reticulate)
library(ggeffects)
library(broom)
library(glue)
library(progress)
library(ggforce)
library(patchwork)
library(ggpubr)

# Default figure options for every knitr chunk: hold plots until the end of
# the chunk, and size them by width plus aspect ratio (height / width).
knitr::opts_chunk$set(
  fig.show = "hold",
  fig.width = 7,
  fig.asp = 0.6
)


### Load data

In [None]:
# Load the previously saved data (all_lme.rds).
# The path literal was missing its closing quote, which made the cell unparsable.
data_merged <- readRDS("/mnt/datastore/Teris/CurrentBiology_2022/all_lme.rds")

Load peak and ramp score data

In [None]:
# Read the exported ramp-score coefficient table.
# NOTE(review): absolute, machine-specific path - confirm it exists where this
# notebook is run.
ramp_score <- read_csv('E:/in_vivo_vr/sarah_glm_202006/ramp_score_coeff_export.csv')

In [None]:
# Keep only the ramp scores of beaconed trials in the outbound region.
ramp_score2merge <- ramp_score %>%
  filter(trial_type == "beaconed", ramp_region == "outbound")

# Convert the neuron id to numeric before it is used as a join key.
data_merged$neuron <- as.numeric(data_merged$neuron)

# Display the joined table; the result is shown only and not stored.
inner_join(
  data_merged,
  ramp_score2merge,
  by = c("session_id" = "session_id", "neuron" = "cluster_id")
)


In [None]:

### Calculate normalized firing rate
# getMaxFr <- function(data) {
#   max(data$firingRate)
# }
# 
# # Get max firing rate of each cells
# fr_max <- data_merged %>%
#   filter(trial_length_type=="All") %>%
#   mutate(max_fr = map(data,getMaxFr)) %>%
#   select(session_id, cluster_id, max_fr)
# 
# # Match back to the original 
# 
# data_merged_norm <- data_merged %>%
#   inner_join(fr_max, by=c("session_id","cluster_id")) %>%
#   select(-session_id) %>%
#   unnest(data) 

# saveRDS(data_merged_norm,'E:/in_vivo_vr/sarah_glm_202006/data_merged_norm.rds', compress = FALSE)
#  

In [None]:
# Drop the outer session_id column, then expand the nested `data` column into
# regular rows.
data_merged_norm <- data_merged %>%
  select(-session_id) %>%
  unnest(cols = data)

### Histogram of trial length (unfiltered)

In [None]:
# Return the first entry of the `trial_length` column of `data`.
extractTrialLength <- function(data) {
  trial_lengths <- data$trial_length
  trial_lengths[[1]]
}


# One row per (session, trial length type, trial): the remaining columns are
# nested in `data`, and the trial length is pulled out as a numeric column.
d <- data_merged_norm %>%
  # filter(lm_result_outbound == "Positive") %>%
  filter(trial_length_type != "All") %>%
  group_by(session_id, trial_length_type, trial_number) %>%
  nest() %>%
  mutate(
    trial_length = map_dbl(data, function(trial) extractTrialLength(trial))
  )

# Histogram (bin width 1) of trial lengths below 100, one panel per trial
# length type.
d %>%
  ungroup() %>%
  mutate(trial_length_time = trial_length) %>%
  filter(trial_length_time < 100) %>%
  ggplot(mapping = aes(x = trial_length_time)) +
  geom_histogram(binwidth = 1) +
  facet_col(~trial_length_type)
  # theme_classic() +
  # labs(x='Time to reward (s)')

# ggsave('paper_figures/3A.pdf')


Count the number of ramp cells in each session

In [None]:
# Count how many neurons fall into each lm_result_outbound category.
#
# data: table with `neuron` and `lm_result_outbound` columns.
# Returns one row per category with the neuron count in `n`.
countCellType <- function(data) {
  # Each neuron contributes a single label: the first one seen for it.
  label_per_neuron <- data %>%
    group_by(neuron) %>%
    summarize(lm_result_outbound = first(lm_result_outbound))

  label_per_neuron %>%
    group_by(lm_result_outbound) %>%
    summarize(n = n())
}

# Total number of ramp cells, i.e. cells labelled "Positive" or "Negative".
#
# celltype: table with `lm_result_outbound` and `n` columns (one row per
#           category), as produced by countCellType().
# Returns a single number. A category that is absent contributes 0; the
# previous row-indexing form returned a zero-length vector whenever either
# category was missing, which cannot be coerced by as.numeric() on a list
# column.
sumRampCol <- function(celltype) {
  is_ramp <- celltype$lm_result_outbound %in% c("Positive", "Negative")
  sum(celltype$n[is_ramp])
}


# Attach the per-category cell counts and the ramp-cell total to every row.
d <- d %>%
  mutate(
    celltype = map(data, countCellType),
    ramp_cell_number = map(celltype, sumRampCol)
  )

# map() produced a list column; flatten it to a plain numeric vector.
d$ramp_cell_number <- as.numeric(d$ramp_cell_number)

In [None]:
# Rank sessions by their number of rows in `d`, and show an index that
# combines that row count with the ramp-cell count: 2 * n * rampN / (n + rampN).
d %>%
  group_by(session_id) %>%
  summarize(
    n = n(),
    rampN = first(ramp_cell_number)
  ) %>%
  mutate(selIdx = 2 * n * rampN / (n + rampN)) %>%
  arrange(desc(n))


Plot trial time and firing-rate-time plot for a particular session

### Figure 3A and 3B
Histogram and firing rate over time

In [None]:

# Histogram of trial lengths (bin width 2, values below 100 only) for one
# session. Reads the global table `d`.
#
# session_id: session identifier matched against d$session_id.
# Returns a ggplot object.
plot_trial_time <- function(session_id) {
  session_trials <- d %>%
    ungroup() %>%
    filter(session_id == !!session_id) %>%
    mutate(trial_length_time = trial_length) %>%
    filter(trial_length_time < 100)

  ggplot(session_trials, aes(x = trial_length_time)) +
    geom_histogram(binwidth = 2) +
    labs(x = "Time to reward (s)") +
    theme_minimal(base_size = 16)
}

# Smoothed firing rate against time for the ramp cells of one session.
# Reads the global table `d`.
#
# session_id: session identifier matched against d$session_id.
# neurons:    optional vector of neuron ids to keep; NULL (default) keeps all.
#
# Returns a ggplot with one free-scale panel per neuron (2 columns), coloured
# by trial length type. Trials of type 'Middle' are excluded, and only cells
# labelled "Positive" or "Negative" are drawn.
#
# Changes from the previous version: the first, 3-column copy of the plot was
# built and then discarded, so it has been removed; `scales` is now spelled
# out instead of relying on partial matching of `scale`.
plot_fr_vs_time <- function(session_id, neurons = NULL) {
  sel_session <- d %>%
    ungroup() %>%
    filter(session_id == !!session_id) %>%
    filter(trial_length_type != "Middle")

  sel_session_expand <- sel_session %>%
    select(-session_id, -trial_length) %>% # remove duplicated columns
    unnest(data)

  if (!is.null(neurons)) {
    sel_session_expand <- sel_session_expand %>%
      filter(neuron %in% neurons)
  }

  sel_session_expand %>%
    # only include ramp cells
    filter(lm_result_outbound %in% c("Positive", "Negative")) %>%
    ggplot() +
    geom_smooth(
      aes(x = time_relative_outbound, y = firingRate, color = trial_length_type)
    ) +
    facet_wrap(~neuron, scales = "free", ncol = 2) +
    labs(
      x = "Time from start of track (s)",
      y = "Firing Rate(Hz)",
      color = "Trial length"
    ) +
    theme_minimal(base_size = 16)
}


# Figure 3A: trial-time histogram for the example session.
# repr.plot.* only set the notebook display size; ggsave() sets the PDF size.
options(repr.plot.width=4, repr.plot.height=4)
plot_trial_time('M1_D31_2018-11-01_12-28-25')
# No plot is passed, so ggsave() writes the most recent plot.
ggsave('paper_figures/3A.pdf',width=4,height=4)

# Figure 3B: firing rate over time for neurons 7 and 9 of the same session.
# NOTE(review): display size (6x3) differs from the saved size (8x4) - confirm intended.
options(repr.plot.width=6, repr.plot.height=3)
plot_fr_vs_time('M1_D31_2018-11-01_12-28-25',c(7,9))
ggsave('paper_figures/3B.pdf',width=8,height=4)


### Figure 3C
Extrapolate firing rate at reward zone

In [None]:
# Largest value in the `firingRate` column of `data`.
getMaxFr <- function(data) {
  rates <- data$firingRate
  max(rates)
}

# Max firing rate of each cell, taken from the rows whose trial_length_type
# is "All".
fr_max <- data_merged %>%
  filter(trial_length_type == "All") %>%
  ungroup() %>%
  mutate(
    max_fr = map_dbl(data, function(cell_data) max(cell_data$firingRate))
  ) %>%
  select(session_id, cluster_id, max_fr, -trial_length_type)

# Divide the `firingRate` column of `data` by `max_fr` and return the table.
norm_data <- function(data, max_fr) {
  scaled_rate <- data$firingRate / max_fr
  data$firingRate <- scaled_rate
  data
}

# Match the per-cell maximum back to the original table, then add a
# normalised copy of each nested data set in `data_norm`.
data_merged_norm <- data_merged %>%
  inner_join(fr_max, by = c("session_id", "cluster_id")) %>%
  mutate(data_norm = map2(data, max_fr, norm_data))


In [None]:
# Pick out neuron 7 of the example session and show its rows.
data_merged_filt <- filter(
  data_merged_norm,
  session_id == "M1_D31_2018-11-01_12-28-25",
  neuron == 7
)
print(data_merged_filt)

In [None]:
# Fit firingRate ~ time_relative_outbound on `data` and return the tidied
# coefficient table.
time_model <- function(data) {
  fit <- lm(firingRate ~ time_relative_outbound, data = data)
  tidy(fit)
}

# Value of the fitted line at `trial_time`: intercept + slope * trial_time.
predict_reward_firingRate <- function(intercept, slope, trial_time) {
  intercept + slope * trial_time
}

# Last value of the `firingRate` column of `data`.
find_reward_fringRate <- function(data) {
  rates <- data$firingRate
  last(rates)
}

# Fit a line to firing rate over time for every trial, then average the
# per-trial results.
#
# row: table with trial_number, firingRate and time_relative_outbound columns.
# Returns a one-row table with the mean intercept, mean slope, mean fitted
# firing rate at the last time point of each trial (mean_reward_fr) and the
# mean of the last observed firing rate of each trial (final_reward_fr).
fitSlopePeak <- function(row) {
  per_trial <- row %>%
    group_by(trial_number) %>%
    nest() %>%
    mutate(time_model = map(data, time_model)) %>%
    # Row 1 of the tidied fit is the intercept, row 2 the slope.
    mutate(
      intercept = map_dbl(time_model, function(coefs) coefs[[1, "estimate"]])
    ) %>%
    mutate(
      slope = map_dbl(time_model, function(coefs) coefs[[2, "estimate"]])
    ) %>%
    mutate(
      trial_time = map_dbl(
        data,
        function(trial) last(trial$time_relative_outbound)
      )
    ) %>%
    mutate(
      reward_fr = pmap_dbl(
        list(intercept, slope, trial_time),
        predict_reward_firingRate
      )
    ) %>%
    mutate(final_reward_fr = map_dbl(data, find_reward_fringRate))

  per_trial %>%
    ungroup() %>%
    summarize(
      mean_intercept = mean(intercept),
      mean_slope = mean(slope),
      mean_reward_fr = mean(reward_fr),
      final_reward_fr = mean(final_reward_fr)
    )
}

# Try the fit on the normalised data of the first selected row.
x2 <- data_merged_filt[1, ][["data_norm"]][[1]]
fitSlopePeak(x2)

In [None]:
# Run the per-trial fit on every normalised data set, then spread the
# one-row summaries into regular columns.
data_merged_reward <- data_merged_norm %>%
  mutate(reward_fr_data = map(data_norm, fitSlopePeak)) %>%
  unnest_wider(reward_fr_data)

In [None]:
# Keep only the columns needed for the analysis, and merge the 'NoSlope' and
# 'None' labels into a single 'Unclassified' label.
data_merged_reward_sel <- data_merged_reward %>%
  select(
    neuron, session_id, trial_length_type, lm_result_outbound,
    final_reward_fr, starts_with("mean")
  ) %>%
  mutate(
    lm_result_outbound = recode(
      lm_result_outbound,
      "NoSlope" = "Unclassified",
      "None" = "Unclassified"
    )
  )

saveRDS(
  data_merged_reward_sel,
  "E:/in_vivo_vr/sarah_glm_202006/data_merged_reward_sel.rds",
  compress = FALSE
)
print(data_merged_reward_sel, n = 3, width = 300)

In [None]:
# Restrict to Long and Short trials, fix the factor order, and build a
# per-cell identifier from the session and neuron ids.
data2plot <- filter(
  data_merged_reward_sel,
  trial_length_type %in% c("Long", "Short")
)
data2plot$trial_length_type <- factor(
  data2plot$trial_length_type,
  levels = c("Long", "Short")
)
data2plot <- mutate(data2plot, cell_id = glue("{session_id}_{neuron}"))

In [None]:
# Figure 3C: box plots comparing Long and Short trials, one panel per
# lm_result_outbound category.
options(repr.plot.width=12, repr.plot.height=8)

# The single pairwise comparison tested in both panels.
comp = list(c('Long','Short'))

# Top row: normalised firing rate at the end of the trial.
# NOTE(review): this comparison is unpaired while the slope comparison below
# uses paired = TRUE on the same cells - confirm the difference is intended.
p1 <- ggboxplot(data2plot, y='final_reward_fr',
          x='trial_length_type',facet.by='lm_result_outbound', id = 'cell_id',
          nrow=1, scales='free', xlab='Trial length', ylab='Normalized Reward firing rate',
              fill = "trial_length_type") +
    stat_compare_means(comparisons=comp,label = "p.signif", method='wilcox.test') +
    theme_minimal(base_size=16)

# Bottom row: mean fitted slope, compared with a paired Wilcoxon test.
p2 <- ggboxplot(data2plot, y='mean_slope',
               fill = "trial_length_type",
          x='trial_length_type',facet.by='lm_result_outbound',
               nrow=1,scales='free',xlab='Trial length', 
               ylab='Slope',id='cell_id') +
    stat_compare_means(comparisons=comp,label = "p.signif", paired= TRUE,method='wilcox.test') +
        theme_minimal(base_size=16)

# patchwork: stack p1 above p2, then save the combined figure.
p1 / p2
ggsave('paper_figures/3C.pdf',width=12,height=8)
# ggboxplot(data2plot, y='mean_slope',color='trial_length_type',x='lm_result_outbound')