In [2]:
source("utils/plot.R")

In [3]:
# Explicit column types for the simulation-parameter table. Only the columns
# named here are pinned; the spec is passed to read_tsv() via `col_types`.
col_spec <- cols(
    # Numeric columns.
    actual_frequency_at_selection = col_number(),
    sgv_selection_generation = col_number(),
    num_starting_lineages = col_number(),
    num_surviving_lineages = col_number(),
    frequency_at_selection = col_number(),
    adaptive_mutation_rate = col_number(),
    selection_region_size = col_number(),
    # Columns kept as raw text.
    swept_mutations = col_character(),
    swept_frequencies = col_character()
)

# Load the simulation-parameter table named by the Snakemake input
# "sim_params" and drop every row whose regime is "neutral".
sim_data <- snakemake@input[["sim_params"]] %>%
    read_tsv(col_types = col_spec) %>%
    filter(regime != "neutral")

### Sweep timings

In [4]:
# Predicted parameters supplied by the Snakemake rule, coerced to numeric.
predicted_sel <- as.numeric(snakemake@params[["predicted_s"]])
predicted_f0 <- as.numeric(snakemake@params[["predicted_f0"]])

# Window of +/-10% around the predicted selection coefficient.
predsel_min <- predicted_sel - predicted_sel * 0.1
predsel_max <- predicted_sel + predicted_sel * 0.1

In [5]:
# Sweep durations for "sgv (true)" simulations whose selection coefficient
# lies strictly inside the +/-10% window and whose actual frequency at
# selection exceeds half of the predicted f0.
timings <- sim_data %>%
    filter(
        sweep_mode == "sgv (true)",
        selection_coefficient > predsel_min,
        selection_coefficient < predsel_max,
        actual_frequency_at_selection > predicted_f0 / 2
    ) %>%
    # Generations elapsed between the selection generation and the end of
    # the simulation.
    mutate(fixtime = slim_generations - sgv_selection_generation) %>%
    select(fixtime, selection_coefficient)



In [6]:
# Write the six-number summary of the sweep durations to the
# "timings_table" output. Because the separator is a newline, the header
# string occupies the first line and each summary value follows on its own
# line.
cat(
    "min 25% 50% mean 75% max",
    summary(timings$fixtime),
    sep = "\n",
    file = snakemake@output[["timings_table"]]
)

In [7]:
# Per sweep mode, bin simulations by log selection coefficient (20
# equal-width bins) and summarise the sweep duration within each bin.
sweeps <- sim_data %>%
    filter(sweep_mode %in% c("hard", "sgv (true)", "rnm (true)")) %>%
    select(
        log_selection_coefficient,
        slim_generations,
        sweep_mode,
        sgv_selection_generation
    ) %>%
    mutate(
        # Duration counted from the selection generation; when that
        # difference is NA, fall back to the total number of generations.
        fixtime = coalesce(
            slim_generations - sgv_selection_generation,
            slim_generations
        ),
        # Human-readable labels for the legend.
        sweep_mode = str_replace_all(
            sweep_mode,
            c(
                "hard"="Hard sweep",
                "rnm \\(true\\)"="Recurrent mutation",
                "sgv \\(true\\)"="Standing variation"
            )
        ),
        # Integer bin index (1..20) rather than interval labels.
        selbin = cut(log_selection_coefficient, breaks = 20, labels = FALSE)
    ) %>%
    group_by(sweep_mode, selbin) %>%
    summarise(
        mean_fixtime = mean(fixtime),
        max_fixtime = max(fixtime),
        min_fixtime = min(fixtime),
        selmean = mean(log_selection_coefficient)
    )

In [8]:
# Median sweep duration among the simulations selected in `timings`; used as
# the reference value highlighted in the figure.
num_gens <- summary(timings$fixtime)["Median"]

# Sweep duration against selection coefficient: a ribbon (min-max) and a line
# (mean) per sweep mode, plus reference markers at the predicted selection
# coefficient and the median duration.
#
# The reference markers are constants, so they are NOT mapped through aes()
# against `sweeps`: doing so makes ggplot2 draw the same line/label/point once
# per row of `sweeps`, which overplots the text. Lines take their intercept as
# a plain parameter; the labels and the point get a one-row data frame so the
# original nudge_x/nudge_y offsets still apply.
timing_fig <- ggplot(sweeps) +
    geom_vline(xintercept=log10(predicted_sel), linetype='dashed', colour='darkgrey') +
    geom_hline(yintercept=num_gens, linetype='dashed', colour='darkgrey') +
    geom_text(
        data=data.frame(x=-2, y=num_gens),
        aes(x=x, y=y),
        label=paste0('t=', num_gens, ' generations'),
        hjust=0, vjust=0, colour='darkgrey', nudge_y=0.1, nudge_x=0.1, fontface='italic'
    ) +
    geom_text(
        data=data.frame(x=log10(predicted_sel), y=1000),
        aes(x=x, y=y),
        label=paste0('s=', predicted_sel),
        hjust=1, vjust=0, colour='darkgrey', fontface='italic', nudge_x=-0.05, nudge_y=0.1
    ) +
    geom_ribbon(aes(x=selmean, ymin=min_fixtime, ymax=max_fixtime, fill=sweep_mode), alpha=0.2) +
    geom_line(aes(x=selmean, y=mean_fixtime, colour=sweep_mode), size=0.5) +
    geom_point(
        data=data.frame(x=log10(predicted_sel), y=num_gens),
        aes(x=x, y=y),
        colour="#7570b3", size=2
    ) +
    # The x axis holds log10(s); tick labels are shown back-transformed.
    # `labels` is spelled out in full rather than relying on partial matching.
    scale_x_continuous(labels=function(x) round(10^x, 2), n.breaks=3) +
    scale_y_log10() +
    labs(
        x='Sel. coefficient',
        y='Generations to 80% frequency'
    ) +
    turkana_colour +
    turkana_fill +
    turkana_theme +
    theme(legend.title=element_blank())

turkana_save(snakemake@output[["timing"]], timing_fig, asp=1.618)

### Neural network learning curves

In [9]:
# Read one learning-curve TSV and tag its rows with metadata parsed from the
# file name.
#
# filename: path to a TSV file. Its base name is split on "_" into four
#   pieces: `target`, `dataset`, a discarded piece, and `replicate`.
# Returns the table read from the file with those three columns added (the
# temporary `filename` column is consumed by the split).
read_learning_curves <- function(filename) {
    curves <- read_tsv(filename, col_types = cols())
    curves <- mutate(curves, filename = basename(filename))
    separate(
        curves,
        'filename',
        sep = '_',
        into = c('target', 'dataset', NA, 'replicate')
    )
}

In [10]:
# Stack the learning curves of every file in the model-fitting output
# directory, keep the codominant dataset and the four targets of interest,
# and reshape to one row per (epoch, loss type).
learning0 <- list.files('output/model-fitting/', full.names = TRUE) %>%
    lapply(read_learning_curves) %>%
    bind_rows() %>%
    filter(
        dataset == "codominant",
        target %in% c('log-sel-strength', 'sweep-mode', 'sgv-f0', 'sweep-age')
    ) %>%
    pivot_longer(
        c('train_loss', 'valid_loss'),
        names_to = 'loss_name',
        values_to = 'loss'
    ) %>%
    mutate(
        # target_factor() is not defined in this file; presumably it comes
        # from utils/plot.R.
        target = target_factor(target),
        # Legend-friendly names for the two loss columns.
        loss_name = str_replace_all(
            loss_name,
            c('valid_loss'='Validation loss', 'train_loss'='Training loss')
        )
    )

In [11]:
# Average the loss over all rows sharing a target, epoch and loss type.
learning <- learning0 %>%
    group_by(target, epoch, loss_name) %>%
    summarise(loss = mean(loss))

In [12]:
# Learning curves: mean loss per epoch, one panel per target with its own
# y axis, coloured by loss type. Only epochs after the fourth are plotted.
learning_fig <- learning %>%
    filter(epoch > 4) %>%
    ggplot() +
    geom_line(aes(x = epoch, y = loss, colour = loss_name)) +
    facet_wrap(vars(target), scales = 'free_y') +
    turkana_colour +
    labs(x = 'Epoch', y = 'Loss') +
    turkana_theme +
    # Tweaks applied on top of the shared theme: legend above the panels,
    # untitled, with tighter spacing.
    theme(
        legend.position = 'top',
        legend.title = element_blank(),
        legend.justification = c(0, 0),
        legend.box.spacing = unit(0.1, "cm"),
        panel.spacing = unit(0.2, "cm"),
        axis.title.y = element_text(hjust = 1)
    )

turkana_save(
    snakemake@output[["learning_curves"]],
    learning_fig,
    asp = 1.618
)

### Selection strength validation

In [13]:
# Selection-strength predictions joined to the simulation parameters.
#
# dplyr joins take their key through `by`; an `on` argument is not part of
# the interface (depending on the dplyr version it is either silently ignored,
# giving a natural join on every shared column, or rejected with an error).
# NOTE(review): assumes `uuid` is the intended key and that the two tables
# share no other column used below (e.g. `sweep_mode`); otherwise dplyr will
# suffix the duplicates with .x/.y — confirm against the input files.
selstrength <- read_tsv(snakemake@input[["selstrength"]], col_types=cols()) %>%
    inner_join(sim_data, by="uuid") %>%
    # sweepmode_factor() is not defined in this file; presumably it comes from
    # utils/plot.R.
    mutate(sweep_mode=sweepmode_factor(sweep_mode))

In [14]:
# Predicted against true selection coefficient, one panel per sweep mode,
# with the dashed identity line as reference. Both axes hold log10 values;
# tick labels are back-transformed and rounded to two decimals.
selstrength_fig <- ggplot(selstrength) +
    geom_point(
        aes(
            x = true_log_selection_coefficient,
            y = predicted_log_selection_coefficient,
            colour = sweep_mode
        ),
        size = 0.5
    ) +
    geom_abline(linetype = 'dashed') +
    facet_wrap(vars(sweep_mode)) +
    scale_x_continuous(labels = function(x) round(10^x, 2), n.breaks = 3) +
    scale_y_continuous(labels = function(x) round(10^x, 2), n.breaks = 3) +
    labs(
        x = "True sel. coefficient",
        y = "Predicted sel. coefficient"
    ) +
    turkana_colour +
    turkana_theme +
    # The facets already name the sweep mode, so the colour legend is hidden.
    theme(
        legend.position = "none",
        axis.text.x = element_text(angle = 45, hjust = 1)
    )

turkana_save(snakemake@output[["selstrength"]], selstrength_fig, asp = 2)

### Sweep mode validation

In [15]:
# Confusion matrix of sweep-mode predictions in long form: one row per
# (true label, predicted label) pair with its count `n`.
sweepmode_confmat <- snakemake@input[["sweepmode"]] %>%
    read_tsv(col_types = cols()) %>%
    select(true_label, predicted_label) %>%
    table() %>%
    as_tibble() %>%
    mutate(
        # sweepmode_factor_short() is not defined in this file; presumably it
        # comes from utils/plot.R.
        true_label = sweepmode_factor_short(true_label),
        predicted_label = sweepmode_factor_short(predicted_label)
    ) %>%
    # Normalise within each predicted label: `percent` is the fraction (0-1)
    # of that label's predictions falling on each true label.
    group_by(predicted_label) %>%
    mutate(
        percent = n / sum(n),
        percent_label = paste0(round(percent * 100, 2), '%')
    )

In [16]:
# Confusion-matrix heatmap: tiles shaded by the per-predicted-label fraction,
# each annotated with its percentage.
confmat_fig <- ggplot(sweepmode_confmat) +
    geom_tile(aes(x=true_label, y=predicted_label, fill=percent)) +
    geom_text(aes(x=true_label, y=predicted_label, label=percent_label, colour=percent<0.5)) +
    # Black text on light tiles (fraction < 0.5), white text on dark tiles.
    # The values are named so the mapping stays correct even when only one of
    # TRUE/FALSE occurs in the data; unnamed values are assigned by position
    # and would then give the single level the wrong colour.
    scale_colour_manual(values=c('FALSE'='white', 'TRUE'='black')) +
    # Reverse the y axis so the diagonal runs from top-left to bottom-right.
    scale_y_discrete(limits=rev) +
    scale_fill_distiller(palette=3, direction=1) +
    turkana_theme +
    labs(x='True', y='Predicted') +
    theme(
        legend.position='none',
        panel.grid=element_blank(),
        panel.spacing=unit(0.3, "in")
    )

turkana_save(snakemake@output[["sweepmode"]], confmat_fig, width=3, asp=1)