# PCA Analysis Suite

Consolidated analysis notebook for *Decoding the physicochemical basis of taxonomy preferences in protein design models*.

## Analyses Covered

| Section | Paper Reference | Description |
|---------|----------------|-------------|
| **WT PCA** | Figure 3A-B | PCA on 4 feature sets (mixed, sequence, structure, pH) |
| **GAM Landscapes** | Figure 3C | Model preference surfaces in PC space |
| **Cosine Similarity** | Methods §Cosine Similarity | Likelihood vectors vs feature loading vectors |
| **Fine-tuning Validation** | Results §Fine-tuning | Vanilla ProteinMPNN vs AlkalineMPNN pH alignment |
| **Design Shift PCA** | Figure 4 | Designed sequences projected into WT PC space |
| **Quantitative Shifts** | Supplementary Table S12 | Centroid shifts, back-projection, model cosine similarity |

## Requirements
R packages: tidyverse, ggrepel, mgcv, patchwork, rlang, MASS, viridis

In [None]:
# Cell 2: Setup & Packages
packages <- c("tidyverse", "ggrepel", "mgcv", "patchwork", "rlang", "MASS", "viridis")

# Install only the packages that are not already in the library.
new_pkgs <- setdiff(packages, installed.packages()[, "Package"])
if (length(new_pkgs) > 0) {
  install.packages(new_pkgs, quiet = TRUE)
}

# Attach order matters: MASS is attached first so that the later tidyverse
# attach makes dplyr::select() (not MASS::select()) the unqualified `select`.
suppressPackageStartupMessages(
  invisible(lapply(
    c("MASS", "tidyverse", "ggrepel", "mgcv", "patchwork", "rlang", "viridis"),
    library,
    character.only = TRUE
  ))
)

cat("Setup complete\n")

In [None]:
# Cell 3: Configuration — EDIT PATHS HERE
#
# Paths can be edited below, or overridden without touching the notebook by
# setting environment variables before starting R (e.g. in ~/.Renviron):
#   PCA_WT_CSV      wild-type dataset CSV
#   PCA_DESIGN_DIR  directory holding the per-model design CSVs
#   PCA_OUTPUT_DIR  directory where figures and tables are written
# An unset or empty variable falls back to the default written here.

# Return environment variable `var` when it is set to a non-empty value,
# otherwise `default`.
.env_or_default <- function(var, default) {
  value <- Sys.getenv(var, unset = "")
  if (nzchar(value)) value else default
}

# Wild-type protein dataset (7843 proteins with model scores and features).
# The default is a machine-specific absolute path; set PCA_WT_CSV elsewhere.
WT_CSV <- .env_or_default(
  "PCA_WT_CSV",
  "/Users/lauradillon/PycharmProjects/inverse_fold/Cleaned_research_flow/0_Main_data/Final_Paper_Folder/Decoding_Bias_Dataset.csv"
)
if (!file.exists(WT_CSV)) {
  warning("WT_CSV does not exist: ", WT_CSV,
          " (edit the path in Cell 3 or set PCA_WT_CSV)", call. = FALSE)
}

# Design data: individual per-model CSVs with shift columns
DESIGN_DIR <- .env_or_default("PCA_DESIGN_DIR", "design_data")
DESIGN_FILES <- list(
  "ProteinMPNN" = file.path(DESIGN_DIR, "ProteinMPNN_designs_with_WT_shifts.csv"),
  "ESM-IF"      = file.path(DESIGN_DIR, "ESMIF_designs_with_WT_shifts.csv"),
  "MIF"         = file.path(DESIGN_DIR, "MIF_designs_with_WT_shifts.csv"),
  "MIF-ST"      = file.path(DESIGN_DIR, "MIFST_designs_with_WT_shifts.csv"),
  "AlkalineMPNN"= file.path(DESIGN_DIR, "Alkaline_designs_with_WT_shifts.csv")
)

OUTPUT_DIR <- .env_or_default("PCA_OUTPUT_DIR", "pca_analysis_results")
dir.create(OUTPUT_DIR, showWarnings = FALSE, recursive = TRUE)

# Color palettes
DOMAIN_COLORS <- c(
  "Bacteria"  = "#1f77b4",
  "Eukaryota" = "#2ca02c",
  "Archaea"   = "#eb0920"
)

MODEL_COLORS <- c(
  "ProteinMPNN"  = "#1F77B4",
  "AlkalineMPNN" = "#E15759",
  "ESM-IF"       = "#FF7F0E",
  "MIF"          = "#9467BD",
  "MIF-ST"       = "#2CA02C"
)

# Model score columns in the WT dataset
SCORE_COLUMNS <- c(
  "proteinmpnn_score", "esmif_score", "mif_score",
  "mifst_score", "ESM2_15B_pppl_score", "carp_640M_score",
  "AlkalineMPNN_score"
)

cat("Configuration set.\n")
cat("Output directory:", OUTPUT_DIR, "\n")

In [None]:
# Cell 4: Helper Functions

# Add reciprocal structural descriptors so that larger values mean "more
# compact" / "more centralized":
#   compactness     -> radius_of_gyration (copy), structural_compactness = 1/x
#   avg_cb_distance -> avg_cb_distance_from_centroid (copy), centralization = 1/x
# Each pair is added only when its source column exists; a zero source value
# yields Inf.
create_inverse_structural_features <- function(df) {
  cols_before <- names(df)

  if ("compactness" %in% cols_before) {
    rg <- df[["compactness"]]
    df[["radius_of_gyration"]] <- rg
    df[["structural_compactness"]] <- 1 / rg
  }

  if ("avg_cb_distance" %in% cols_before) {
    cb_dist <- df[["avg_cb_distance"]]
    df[["avg_cb_distance_from_centroid"]] <- cb_dist
    df[["centralization"]] <- 1 / cb_dist
  }

  df
}

# Normalize `v` to unit length. Returns a plain (unnamed) numeric vector and
# stops with "Zero-length vector" when the norm is NA, non-finite or < 1e-12.
.unit_vec <- function(v) {
  comps <- as.numeric(v)
  len <- sqrt(sum(comps^2))
  if (is.finite(len) && len >= 1e-12) {
    return(comps / len)
  }
  stop("Zero-length vector")
}

# Cosine of the angle between vectors `a` and `b`: the dot product of their
# unit-normalized forms (errors via .unit_vec() if either has ~zero length).
.cosine <- function(a, b) {
  ua <- .unit_vec(a)
  ub <- .unit_vec(b)
  sum(ua * ub)
}

# Vectorized cosine similarity between 2-D vectors (dx1, dy1) and (dx2, dy2),
# element by element.
#
# Returns NA wherever either vector has length <= eps or contains NA; results
# are clamped to [-1, 1] to absorb floating-point overshoot.
.cosine2 <- function(dx1, dy1, dx2, dy2, eps = 1e-12) {
  n1 <- sqrt(dx1^2 + dy1^2)
  n2 <- sqrt(dx2^2 + dy2^2)
  # which() discards NA comparisons. A logical index containing NA would make
  # the subscripted assignment below fail ("NAs are not allowed in subscripted
  # assignments") as soon as more than one element is selected.
  ok <- which(n1 > eps & n2 > eps)
  out <- rep(NA_real_, length(dx1))
  out[ok] <- (dx1[ok] * dx2[ok] + dy1[ok] * dy2[ok]) / (n1[ok] * n2[ok])
  pmin(1, pmax(-1, out))
}

# Coerce to double, turning unparseable values into NA without emitting the
# "NAs introduced by coercion" warning.
.as_num <- function(x) {
  suppressWarnings(as.double(x))
}

# Axis titles such as "PC1 (42.3%)": the PC name followed by its share of
# explained variance (taken from `explained_variance` at the PC's number).
pc_axis_labels <- function(explained_variance, pc_x = "PC1", pc_y = "PC2") {
  label_for <- function(pc) {
    pc_num <- as.integer(gsub("PC", "", pc))
    sprintf("%s (%.1f%%)", pc, 100 * explained_variance[pc_num])
  }
  list(x = label_for(pc_x), y = label_for(pc_y))
}

# Origin-centred axis limits: for each of pc_x / pc_y, +/- the largest
# absolute (non-NA) value, padded by the fraction `expand`.
pc_symmetric_limits <- function(data, pc_x = "PC1", pc_y = "PC2", expand = 0.05) {
  half_span <- function(col) {
    extremes <- range(data[[col]], na.rm = TRUE)
    max(abs(extremes)) * (1 + expand)
  }
  x_half <- half_span(pc_x)
  y_half <- half_span(pc_y)
  list(x = c(-x_half, x_half), y = c(-y_half, y_half))
}

# Ratio explained_variance[pc_y] / explained_variance[pc_x] (pc_x / pc_y are
# numeric PC indices), clamped to the interval [lo, hi].
calculate_bounded_aspect <- function(explained_variance, pc_x = 1, pc_y = 2,
                                     lo = 0.4, hi = 2.5) {
  raw_ratio <- explained_variance[pc_y] / explained_variance[pc_x]
  capped <- min(raw_ratio, hi)
  max(capped, lo)
}

cat("Helper functions loaded.\n")

In [None]:
# Cell 5: Feature Set Definitions

# Sequence-derived descriptors
sequence_features <- c(
  "mw_per_residue", "isoelectric_point", "instability_index",
  "gravy", "sequence_length", "aromaticity"
)

# Structure-derived descriptors; structural_compactness and centralization
# are the reciprocal columns added by create_inverse_structural_features()
structure_features <- c(
  "helix_sheet_contrast", "ordered_percent", "rco",
  "structural_compactness", "centralization"
)

# Sequence + structure combined
mixed_features <- c(sequence_features, structure_features)

# Charge / ionization descriptors
pH_features <- c(
  "buffer_capacity", "charge_per_residue",
  "acidic_residue_fraction", "basic_residue_fraction",
  "ionizable_residue_fraction"
)

cat("Feature sets defined:\n")
cat(sprintf("  %s: %d features\n",
            c("Sequence", "Structure", "Mixed", "pH"),
            lengths(list(sequence_features, structure_features,
                         mixed_features, pH_features))),
    sep = "")

In [None]:
# Cell 6: Data Loading & Preprocessing
#
# Reads the wild-type dataset into the global `df`, adds derived structural
# columns, and reports which score / feature columns are available. Later
# cells read `df` and `available_scores` from here.

cat("Loading data from:", WT_CSV, "\n")
df <- read.csv(WT_CSV, stringsAsFactors = FALSE)
cat("Loaded", nrow(df), "proteins with", ncol(df), "columns\n")

# Create inverse structural features
# (structural_compactness = 1/compactness, centralization = 1/avg_cb_distance,
# each only when its source column exists)
df <- create_inverse_structural_features(df)

# Create derived structure columns if not present
# helix_sheet_contrast = helix% - sheet%, ordered_percent = helix% + sheet%;
# columns that already exist in the CSV are left untouched.
if (all(c("helix_percent", "sheet_percent") %in% names(df))) {
  if (!"helix_sheet_contrast" %in% names(df)) {
    df$helix_sheet_contrast <- df$helix_percent - df$sheet_percent
  }
  if (!"ordered_percent" %in% names(df)) {
    df$ordered_percent <- df$helix_percent + df$sheet_percent
  }
}

# Verify score columns
# available_scores (the SCORE_COLUMNS actually present) is reused by the GAM
# landscape and cosine-similarity cells.
available_scores <- SCORE_COLUMNS[SCORE_COLUMNS %in% names(df)]
cat("Available model score columns:", length(available_scores), "\n")
cat(" ", paste(available_scores, collapse = ", "), "\n")

# Verify feature columns
# Report only: mixed_features covers the sequence and structure sets, so these
# two checks cover every feature set. Missing columns are not an error here;
# Cell 9 intersects each set with colnames(df) before running PCA.
for (fs_name in c("mixed_features", "pH_features")) {
  fs <- get(fs_name)
  present <- fs[fs %in% names(df)]
  missing <- fs[!fs %in% names(df)]
  cat(sprintf("%s: %d/%d present", fs_name, length(present), length(fs)))
  if (length(missing)) cat(" (missing:", paste(missing, collapse = ", "), ")")
  cat("\n")
}

In [None]:
# Cell 7: Core PCA Function & Feature Contribution

# Standardized PCA on a chosen feature set.
#
# Steps:
#   1. Keep feature, SCORE_COLUMNS and metadata columns.
#   2. Turn +/-Inf in the features into NA and drop incomplete rows.
#   3. z-score each feature with its own mean and SD, then run prcomp().
#   4. For each PC named in `anchor_var`, flip its sign so the anchor feature
#      loads positively. This removes PCA's arbitrary sign, so plots are
#      oriented consistently.
#
# Args:
#   df          data.frame containing every column in `features`; score and
#               metadata columns, when present, are carried along.
#   features    character vector of feature column names (>= 2, all present).
#   anchor_var  named character vector, PC name -> feature. Entries whose
#               feature is not in `features` are ignored.
#
# Returns a list:
#   data                cleaned rows with up to 10 PC score columns appended
#   loadings            data.frame Feature, PC1, PC2 (+ PC3 when available)
#   explained_variance  proportion of total variance per PC
#   features, center, scale  inputs / standardization used (needed to project
#                       other data into the same space)
#   pca_object          the sign-anchored prcomp() result
#
# Errors with a descriptive message when features are missing, fewer than 2
# are given, fewer than 2 complete rows remain, or a feature has zero or
# undefined variance. In each of these cases, prcomp() would otherwise fail
# with an opaque error.
perform_focused_pca <- function(df, features, anchor_var = c(PC1 = "sequence_length")) {
  cat("\n=== PCA ANALYSIS ===", "\n")
  cat("Using", length(features), "features\n")

  missing_feats <- setdiff(features, colnames(df))
  if (length(missing_feats) > 0) {
    stop("Feature column(s) not found in data: ",
         paste(missing_feats, collapse = ", "), call. = FALSE)
  }
  if (length(features) < 2) {
    stop("At least 2 features are required (PC1 and PC2 are always reported).",
         call. = FALSE)
  }

  # Required metadata columns
  score_cols <- SCORE_COLUMNS[SCORE_COLUMNS %in% colnames(df)]
  meta_cols <- c("Entry", "domain", "species", "protein_name", "avg_plddt")
  meta_cols <- meta_cols[meta_cols %in% colnames(df)]

  # Prepare data
  keep_cols <- unique(c(features, score_cols, meta_cols))
  pca_data <- df[, intersect(keep_cols, names(df)), drop = FALSE]

  # Remove rows with NA/Inf in feature columns
  for (col in features) {
    if (col %in% names(pca_data)) {
      pca_data[[col]][is.infinite(pca_data[[col]])] <- NA
    }
  }
  initial_rows <- nrow(pca_data)
  pca_data <- pca_data[complete.cases(pca_data[, features, drop = FALSE]), ]
  cat("Clean data:", nrow(pca_data), "samples (removed", initial_rows - nrow(pca_data), "rows)\n")
  if (nrow(pca_data) < 2) {
    stop("Fewer than 2 complete rows remain after removing NA/Inf features.",
         call. = FALSE)
  }

  # Standardize and run PCA
  feature_matrix <- as.matrix(pca_data[, features, drop = FALSE])
  center_vals <- colMeans(feature_matrix)
  scale_vals <- apply(feature_matrix, 2, sd)

  # A zero (or undefined) SD turns the z-scores into NaN, which prcomp()
  # cannot decompose; name the offending features instead.
  degenerate <- names(scale_vals)[!is.finite(scale_vals) | scale_vals == 0]
  if (length(degenerate) > 0) {
    stop("Feature(s) with zero or undefined variance after cleaning: ",
         paste(degenerate, collapse = ", "), call. = FALSE)
  }
  scaled_matrix <- scale(feature_matrix, center = center_vals, scale = scale_vals)

  # Already centred and scaled above, so prcomp must not do it again
  pca_result <- prcomp(scaled_matrix, center = FALSE, scale. = FALSE)

  # Explained variance
  var_explained <- pca_result$sdev^2 / sum(pca_result$sdev^2)
  cat(sprintf("Variance explained (PC1+PC2): %.3f\n", sum(var_explained[1:2])))

  # Sign-flip anchoring: ensure anchor variable loads positively on its PC
  for (pc_name in names(anchor_var)) {
    pc_idx <- as.integer(gsub("PC", "", pc_name))
    var_name <- anchor_var[[pc_name]]
    if (var_name %in% features && var_name %in% rownames(pca_result$rotation)) {
      if (pca_result$rotation[var_name, pc_idx] < 0) {
        pca_result$rotation[, pc_idx] <- -pca_result$rotation[, pc_idx]
        pca_result$x[, pc_idx] <- -pca_result$x[, pc_idx]
      }
    }
  }

  # Assemble PC scores with metadata (at most the first 10 PCs)
  n_keep <- min(ncol(pca_result$x), 10)
  pc_scores <- as.data.frame(pca_result$x[, seq_len(n_keep), drop = FALSE])
  result_data <- bind_cols(pca_data, pc_scores)

  # Loadings table
  loadings_df <- data.frame(
    Feature = features,
    PC1 = pca_result$rotation[features, 1],
    PC2 = pca_result$rotation[features, 2],
    stringsAsFactors = FALSE
  )
  if (ncol(pca_result$rotation) >= 3) {
    loadings_df$PC3 <- pca_result$rotation[features, 3]
  }

  list(
    data = result_data,
    loadings = loadings_df,
    explained_variance = var_explained,
    features = features,
    center = center_vals,
    scale = scale_vals,
    pca_object = pca_result
  )
}

# Feature contribution to PC variance (Methods equation):
#   contribution_f = sum over k = 1..n_pcs of loading[f, k]^2 * explained_variance[k]
# Returned as a tibble (feature, contribution, contribution_pct), where
# contribution_pct is the share of the summed contributions. Rows are sorted
# by contribution, largest first.
calculate_feature_contributions <- function(pca_results, n_pcs = 2) {
  pcs <- 1:n_pcs
  rot <- pca_results$pca_object$rotation
  pc_weights <- pca_results$explained_variance[pcs]
  feats <- pca_results$features

  per_feature <- vapply(
    feats,
    function(feat) sum(rot[feat, pcs]^2 * pc_weights),
    numeric(1)
  )

  out <- tibble(
    feature = feats,
    contribution = per_feature,
    contribution_pct = 100 * per_feature / sum(per_feature)
  )
  arrange(out, desc(contribution))
}

cat("PCA functions loaded.\n")

In [None]:
# Cell 8: Plotting Functions

# Figure 3A: Domain-colored scatter in PC space.
#
# One point per protein at (pc_x, pc_y), coloured by the `domain` column using
# DOMAIN_COLORS, with dashed reference lines through the origin. The y:x unit
# ratio given to coord_fixed() is the ratio of the two PCs' explained-variance
# shares. Optionally overlays 2-D density contours.
create_protein_scatter <- function(pca_results, pc_x = "PC1", pc_y = "PC2",
                                   show_density_contours = FALSE) {
  ev <- pca_results$explained_variance
  axis_titles <- pc_axis_labels(ev, pc_x, pc_y)
  x_idx <- as.integer(gsub("PC", "", pc_x))
  y_idx <- as.integer(gsub("PC", "", pc_y))

  panel_theme <- theme(legend.position = "bottom",
                       panel.border = element_rect(color = "grey60", fill = NA),
                       axis.title = element_text(face = "bold"))
  origin_lines <- list(
    geom_hline(yintercept = 0, linetype = "dashed", color = "grey70"),
    geom_vline(xintercept = 0, linetype = "dashed", color = "grey70")
  )

  plt <- ggplot(pca_results$data,
                aes(x = .data[[pc_x]], y = .data[[pc_y]], color = domain)) +
    geom_point(alpha = 0.5, size = 1.2) +
    scale_color_manual(values = DOMAIN_COLORS, name = "Domain") +
    labs(x = axis_titles$x, y = axis_titles$y,
         title = "Proteins in biophysical PC space") +
    theme_minimal(base_size = 12) +
    panel_theme +
    coord_fixed(ratio = ev[y_idx] / ev[x_idx]) +
    origin_lines

  if (!show_density_contours) {
    return(plt)
  }
  plt + geom_density_2d(color = "grey40", alpha = 0.3)
}

# Figure 3B: Feature loading biplot
#
# Draws loading arrows from the origin for the `top_n_features` features with
# the longest arrows in the (pc_x, pc_y) plane. All arrows are divided by 1.1x
# the largest absolute loading on either axis, so the longest component
# reaches about 0.91. Labels use readable feature names.
#
# pc_x / pc_y must be columns of pca_results$loadings; perform_focused_pca()
# stores PC1, PC2 and, when available, PC3. Other names raise a clear error.
create_feature_biplot <- function(pca_results, pc_x = "PC1", pc_y = "PC2",
                                   top_n_features = 8) {
  loadings <- pca_results$loadings
  absent_pcs <- setdiff(c(pc_x, pc_y), names(loadings))
  if (length(absent_pcs) > 0) {
    stop("Loadings table has no column(s): ",
         paste(absent_pcs, collapse = ", "), call. = FALSE)
  }
  labs_xy <- pc_axis_labels(pca_results$explained_variance, pc_x, pc_y)

  # Scale arrows for visibility
  arrow_scale <- max(abs(c(loadings[[pc_x]], loadings[[pc_y]]))) * 1.1
  loadings$x_scaled <- loadings[[pc_x]] / arrow_scale
  loadings$y_scaled <- loadings[[pc_y]] / arrow_scale
  loadings$magnitude <- sqrt(loadings$x_scaled^2 + loadings$y_scaled^2)

  top_loadings <- loadings %>%
    arrange(desc(magnitude)) %>%
    slice_head(n = top_n_features)

  # Pretty feature names. gsub() and tools::toTitleCase() are vectorized, so
  # the whole column is converted in one call (no per-element sapply()).
  pretty_name <- function(x) {
    x <- gsub("_", " ", x)
    x <- gsub("mw per residue", "MW/residue", x)
    x <- gsub("structural compactness", "compactness (1/Rg)", x)
    x <- gsub("centralization", "centralization (1/dCB)", x)
    tools::toTitleCase(x)
  }
  top_loadings$label <- pretty_name(top_loadings$Feature)

  ggplot() +
    geom_segment(data = top_loadings,
                 aes(x = 0, y = 0, xend = x_scaled, yend = y_scaled),
                 arrow = arrow(length = unit(0.15, "inches"), type = "closed"),
                 color = "steelblue", linewidth = 0.8, alpha = 0.8) +
    geom_text_repel(data = top_loadings,
                    aes(x = x_scaled, y = y_scaled, label = label),
                    size = 3.5, fontface = "bold", color = "grey20",
                    box.padding = 0.4, max.overlaps = 20) +
    labs(x = labs_xy$x, y = labs_xy$y, title = "Feature loadings") +
    theme_minimal(base_size = 12) +
    theme(panel.border = element_rect(color = "grey60", fill = NA),
          axis.title = element_text(face = "bold")) +
    coord_fixed() +
    geom_hline(yintercept = 0, linetype = "dashed", color = "grey70") +
    geom_vline(xintercept = 0, linetype = "dashed", color = "grey70")
}

# Scree plot: per-PC explained variance (labelled bars) for the first
# `n_components` PCs, with the cumulative share overlaid as a red line.
create_variance_plot <- function(pca_results, n_components = 6) {
  n <- min(n_components, length(pca_results$explained_variance))
  # seq_len() rather than 1:n, so that n == 0 yields an empty table instead of
  # indexing c(1, 0).
  idx <- seq_len(n)
  pct <- 100 * pca_results$explained_variance[idx]
  var_df <- tibble(
    PC = factor(paste0("PC", idx), levels = paste0("PC", idx)),
    Variance = pct,
    Cumulative = cumsum(pct)
  )

  ggplot(var_df, aes(x = PC, y = Variance)) +
    geom_col(fill = "steelblue", alpha = 0.7) +
    geom_text(aes(label = sprintf("%.1f%%", Variance)), vjust = -0.3, size = 3.5) +
    geom_line(aes(y = Cumulative, group = 1), color = "red", linewidth = 0.8) +
    geom_point(aes(y = Cumulative), color = "red", size = 2) +
    labs(x = "Principal Component", y = "Variance Explained (%)",
         title = "Variance explained by principal components") +
    theme_minimal(base_size = 12)
}

# Figure 3C: GAM preference surface (matches original CLEAN_PCA-2 styling)
#
# Fits `score_col ~ s(pc_x, pc_y)` by REML on rows complete in those three
# columns. Predicts on a grid_size x grid_size lattice spanning origin-centred
# limits, and draws the prediction as a filled raster with white contours.
#
# Args:
#   score_label  legend title (defaults to score_col)
#   k_basis      smooth basis size; NULL picks max(12, min(50, floor(N / 300)))
#
# The panel aspect comes from the variance explained by the PCs actually
# plotted (pc_x / pc_y), not always PC1/PC2. The ratio is bounded by
# calculate_bounded_aspect() (defaults [0.4, 2.5]).
create_gam_contour_plot <- function(pca_results, score_col, score_label = NULL,
                                     pc_x = "PC1", pc_y = "PC2",
                                     grid_size = 120, n_contours = 10, k_basis = NULL) {
  if (is.null(score_label)) score_label <- score_col
  var_exp <- pca_results$explained_variance
  labs_xy <- pc_axis_labels(var_exp, pc_x, pc_y)
  ix <- as.integer(gsub("PC", "", pc_x))
  iy <- as.integer(gsub("PC", "", pc_y))

  data_clean <- pca_results$data %>%
    dplyr::select(all_of(c(pc_x, pc_y, score_col))) %>%
    drop_na()

  N <- nrow(data_clean)
  if (is.null(k_basis)) k_basis <- max(12, min(50, floor(N / 300)))

  # Backticks protect score names that are not syntactic R names
  form <- as.formula(paste0("`", score_col, "` ~ s(", pc_x, ", ", pc_y, ", k = ", k_basis, ")"))
  gam_fit <- gam(form, data = data_clean, method = "REML")

  # Symmetric axis limits centered on 0
  lims <- pc_symmetric_limits(data_clean, pc_x, pc_y, expand = 0.05)

  grid <- expand.grid(
    V1 = seq(lims$x[1], lims$x[2], length.out = grid_size),
    V2 = seq(lims$y[1], lims$y[2], length.out = grid_size)
  )
  names(grid) <- c(pc_x, pc_y)
  grid$pred <- predict(gam_fit, newdata = grid)

  # Bounded aspect ratio from the variance explained by the plotted PCs
  aspect_ratio <- calculate_bounded_aspect(var_exp, pc_x = ix, pc_y = iy)

  ggplot(grid, aes(x = .data[[pc_x]], y = .data[[pc_y]], z = pred)) +
    geom_raster(aes(fill = pred), interpolate = TRUE) +
    geom_contour(color = "white", alpha = 0.5, linewidth = 0.3, bins = n_contours) +
    scale_fill_viridis_c(name = score_label, option = "viridis") +
    labs(title = NULL, x = labs_xy$x, y = labs_xy$y) +
    coord_cartesian(xlim = lims$x, ylim = lims$y, expand = FALSE) +
    theme_minimal(base_size = 20) +
    theme(
      panel.grid = element_blank(),
      axis.title = element_text(face = "bold", size = 20),
      axis.text = element_text(size = 20),
      aspect.ratio = 1 / aspect_ratio
    )
}

cat("Plotting functions loaded.\n")

In [None]:
# Cell 9: Run WT PCA (all 4 feature sets)
#
# For each feature set: run PCA on the columns present in `df`, print and save
# feature contributions, loadings, scatter and biplot. Results are collected in
# the global `all_results`, keyed by set name, which later cells read.

feature_sets <- list(
  "Mixed Features"     = mixed_features,
  "Sequence Features"  = sequence_features,
  "Structure Features" = structure_features,
  "pH Features"        = pH_features
)

all_results <- list()

for (set_name in names(feature_sets)) {
  cat("\n=== ANALYZING:", set_name, "===\n")

  # Only the features actually present in the data are used
  features <- feature_sets[[set_name]]
  available <- intersect(features, colnames(df))

  # Sets with fewer than 3 usable features are skipped (no entry in all_results)
  if (length(available) < 3) {
    cat("Skipping", set_name, ": insufficient features\n")
    next
  }

  pca_results <- perform_focused_pca(df, available)

  # Feature contributions
  contributions <- calculate_feature_contributions(pca_results)
  cat("\nFeature contributions to total PC variance:\n")
  print(contributions, n = nrow(contributions))

  # Create plots
  p_scatter <- create_protein_scatter(pca_results)
  p_biplot <- create_feature_biplot(pca_results, top_n_features = min(8, length(available)))
  p_variance <- create_variance_plot(pca_results)

  # Store
  all_results[[set_name]] <- list(
    pca_results = pca_results,
    contributions = contributions,
    plots = list(scatter = p_scatter, biplot = p_biplot, variance = p_variance)
  )

  # Display
  print(p_scatter + p_biplot + plot_layout(widths = c(1.2, 1)))
  print(p_variance)

  # Save
  # (the scree plot is displayed and stored above but not written to disk)
  prefix <- gsub(" ", "_", set_name)
  ggsave(file.path(OUTPUT_DIR, paste0(prefix, "_scatter.png")), p_scatter,
         width = 8, height = 7, dpi = 300, bg = "white")
  ggsave(file.path(OUTPUT_DIR, paste0(prefix, "_biplot.png")), p_biplot,
         width = 7, height = 6, dpi = 300, bg = "white")

  # Save loadings table
  write_csv(pca_results$loadings,
            file.path(OUTPUT_DIR, paste0(prefix, "_loadings.csv")))
  write_csv(contributions,
            file.path(OUTPUT_DIR, paste0(prefix, "_contributions.csv")))
}

cat("\n=== WT PCA COMPLETE ===")
cat("\nAnalyses completed:", paste(names(all_results), collapse = ", "), "\n")

In [None]:
# Cell 10: GAM Preference Landscapes (Figure 3C)
#
# For every feature set produced by Cell 9, fit one GAM surface per available
# model score and show them in a grid (up to 3 per row). Save the grid as
# <set>_GAM_landscapes.png and keep the individual plots in
# all_results[[set]]$gam_plots.

# Score display names
score_labels <- c(
  proteinmpnn_score    = "ProteinMPNN",
  esmif_score          = "ESM-IF",
  mif_score            = "MIF",
  mifst_score          = "MIF-ST",
  ESM2_15B_pppl_score  = "ESM2-15B",
  carp_640M_score      = "CARP-640M",
  AlkalineMPNN_score   = "AlkalineMPNN"
)

# Generate GAM surfaces for each feature set
for (set_name in names(all_results)) {
  cat("\n=== GAM LANDSCAPES FOR:", set_name, "===\n")
  pca_res <- all_results[[set_name]]$pca_results
  scores_present <- intersect(available_scores, names(pca_res$data))

  gam_plots <- list()
  for (sc in scores_present) {
    # Scalar lookup with fallback to the raw column name. A plain if/else is
    # the scalar form; ifelse() is meant for vectors and also carried the
    # lookup's name attribute into the label.
    label <- if (sc %in% names(score_labels)) score_labels[[sc]] else sc
    cat("  Fitting GAM for:", label, "\n")
    p <- create_gam_contour_plot(pca_res, sc, label)
    gam_plots[[sc]] <- p
  }

  # Display as grid
  if (length(gam_plots) > 0) {
    combined <- wrap_plots(gam_plots, ncol = min(3, length(gam_plots))) +
      plot_annotation(title = paste("GAM Preference Landscapes -", set_name),
                      theme = theme(plot.title = element_text(size = 14, face = "bold")))
    print(combined)

    # Height grows by 5 inches per row of 3 panels
    prefix <- gsub(" ", "_", set_name)
    ggsave(file.path(OUTPUT_DIR, paste0(prefix, "_GAM_landscapes.png")), combined,
           width = 16, height = max(5, 5 * ceiling(length(gam_plots) / 3)),
           dpi = 300, bg = "white", limitsize = FALSE)
  }

  all_results[[set_name]]$gam_plots <- gam_plots
}

cat("\nGAM landscape generation complete.\n")

In [None]:
# Cell 11: Cosine Similarity — Likelihood Vectors vs Feature Directions

# Find GAM argmax preference vector for each model score.
#
# For each requested score column present in the data:
#   - fit score ~ s(PC1, PC2, k = k_basis) by REML;
#   - evaluate it on a grid_size x grid_size lattice spanning the observed
#     PC1/PC2 ranges;
#   - return the unit vector from the origin to the best-predicted grid point.
#
# NOTE(review): rows are dropped when ANY requested score is NA, so all models
# are fit on the same (possibly reduced) row set. The Figure 3C landscapes drop
# NAs per score instead. Confirm this difference is intended.
#
# Returns a list keyed by score column. Each element holds model, best_point
# (named PC1/PC2), best_vec, grid, predictions, and Xp/beta/Vb (lpmatrix,
# coefficients, covariance).
calculate_likelihood_vectors_gam <- function(pca_results, model_scores,
                                              grid_size = 60, k_basis = 12) {
  data_clean <- pca_results$data %>%
    dplyr::select(PC1, PC2, any_of(model_scores)) %>%
    drop_na()

  lattice <- expand.grid(
    PC1 = seq(min(data_clean$PC1), max(data_clean$PC1), length.out = grid_size),
    PC2 = seq(min(data_clean$PC2), max(data_clean$PC2), length.out = grid_size)
  )

  vectors <- list()
  for (score in intersect(model_scores, names(data_clean))) {
    form <- as.formula(paste0("`", score, "` ~ s(PC1, PC2, k = ", k_basis, ")"))
    mod <- gam(form, data = data_clean, method = "REML")

    surface <- predict(mod, newdata = lattice)
    peak_point <- unlist(lattice[which.max(surface), c("PC1", "PC2")])
    peak_dir <- .unit_vec(peak_point)

    vectors[[score]] <- list(
      model = mod,
      best_point = peak_point,
      best_vec = peak_dir,
      grid = lattice,
      predictions = surface,
      Xp = predict(mod, newdata = lattice, type = "lpmatrix"),
      beta = coef(mod),
      Vb = vcov(mod)
    )
    cat(sprintf("  %s: argmax at (%.3f, %.3f), unit vec (%.3f, %.3f)\n",
                score, peak_point[1], peak_point[2], peak_dir[1], peak_dir[2]))
  }
  vectors
}

# Composite feature loading vector (e.g., pH features)
#
# For each named group of feature names:
#   - average the PC1/PC2 loadings of the members present in the PCA;
#   - normalize that mean to unit length.
# Groups with no member in the PCA are skipped. .unit_vec() raises
# "Zero-length vector" if a group's loadings cancel out.
calculate_feature_directions <- function(pca_results, feature_groups) {
  load_tbl <- pca_results$loadings
  directions <- list()

  for (grp in names(feature_groups)) {
    members <- intersect(feature_groups[[grp]], load_tbl$Feature)
    if (length(members) > 0) {
      rows <- load_tbl[load_tbl$Feature %in% members, , drop = FALSE]
      unit <- .unit_vec(c(PC1 = mean(rows$PC1), PC2 = mean(rows$PC2)))
      directions[[grp]] <- unit
      cat(sprintf("  %s direction: PC1=%.3f, PC2=%.3f\n",
                  grp, unit[1], unit[2]))
    }
  }
  directions
}

# Cosine similarity matrix
#
# Long-format table with one row per (model, feature_group) pair. `cosine` is
# the cosine between the model's GAM argmax unit vector (`best_vec`) and the
# feature-group direction.
#
# Returns a data.frame with columns model, feature_group and cosine (numeric).
# It has zero rows when either input list is empty.
compare_likelihood_to_features <- function(likelihood_results, feature_directions) {
  model_names <- names(likelihood_results)
  group_names <- names(feature_directions)
  if (length(model_names) == 0 || length(group_names) == 0) {
    return(data.frame(model = character(0), feature_group = character(0),
                      cosine = numeric(0), stringsAsFactors = FALSE))
  }

  results <- expand.grid(
    model = model_names,
    feature_group = group_names,
    stringsAsFactors = FALSE
  )
  # vapply() guarantees a plain numeric column; mapply() is type-unstable
  # (a list on empty input) and attaches model names to the values.
  results$cosine <- vapply(
    seq_len(nrow(results)),
    function(i) {
      .cosine(likelihood_results[[results$model[i]]]$best_vec,
              feature_directions[[results$feature_group[i]]])
    },
    numeric(1)
  )

  results
}

# --- Run for each feature set ---
# Feature groups whose mean loading direction is compared with each model's
# GAM argmax vector. Groups with no member in a given PCA are skipped by
# calculate_feature_directions().
feature_group_defs <- list(
  pH = pH_features,
  compactness = c("rco", "structural_compactness", "centralization"),
  stability = c("instability_index", "gravy"),
  size = c("sequence_length", "mw_per_residue"),
  aromaticity = c("aromaticity"),
  secondary_structure = c("helix_sheet_contrast", "ordered_percent")
)

for (set_name in names(all_results)) {
  cat("\n=== COSINE SIMILARITY ANALYSIS:", set_name, "===\n")
  pca_res <- all_results[[set_name]]$pca_results
  scores_present <- intersect(available_scores, names(pca_res$data))

  cat("Computing likelihood vectors...\n")
  lik_vecs <- calculate_likelihood_vectors_gam(pca_res, scores_present)

  cat("Computing feature directions...\n")
  feat_dirs <- calculate_feature_directions(pca_res, feature_group_defs)

  cosine_df <- compare_likelihood_to_features(lik_vecs, feat_dirs)

  # Heatmap
  # The diverging fill is white near 0 (scale_fill_gradient2's default mid
  # colour), so labels switch colour by tile darkness: black for
  # |cosine| <= 0.5, white on the darker tiles. Previously all labels were
  # white and unreadable near 0.
  p_heat <- ggplot(cosine_df, aes(x = feature_group, y = model, fill = cosine)) +
    geom_tile() +
    geom_text(aes(label = sprintf("%.2f", cosine),
                  color = ifelse(abs(cosine) > 0.5, "white", "black")),
              fontface = "bold", size = 3) +
    scale_color_identity() +
    scale_fill_gradient2(limits = c(-1, 1), midpoint = 0, na.value = "grey90",
                         low = "#2166ac", high = "#b2182b") +
    labs(title = paste("Likelihood-Feature Alignment -", set_name),
         x = "Feature Group", y = "Model", fill = "Cosine") +
    theme_minimal(base_size = 12) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
  print(p_heat)

  # Vector direction plot
  vec_data <- bind_rows(
    tibble(x = 0, y = 0,
           xend = sapply(lik_vecs, function(v) v$best_vec[1]),
           yend = sapply(lik_vecs, function(v) v$best_vec[2]),
           name = names(lik_vecs), type = "Model"),
    tibble(x = 0, y = 0,
           xend = sapply(feat_dirs, `[`, 1),
           yend = sapply(feat_dirs, `[`, 2),
           name = names(feat_dirs), type = "Feature")
  )

  p_vec <- ggplot(vec_data, aes(x = x, y = y, xend = xend, yend = yend, color = name)) +
    geom_segment(arrow = arrow(length = unit(0.2, "inches"), type = "closed"),
                 linewidth = 1.2, alpha = 0.8) +
    geom_text_repel(aes(x = xend, y = yend, label = name), size = 3) +
    facet_wrap(~ type) +
    labs(title = paste("Preference & Feature Vectors -", set_name)) +
    theme_minimal(base_size = 12) +
    coord_fixed(xlim = c(-1.2, 1.2), ylim = c(-1.2, 1.2)) +
    geom_hline(yintercept = 0, linetype = "dashed", color = "grey70") +
    geom_vline(xintercept = 0, linetype = "dashed", color = "grey70") +
    theme(legend.position = "none")
  print(p_vec)

  # Store & save
  all_results[[set_name]]$cosine_analysis <- list(
    likelihood_vectors = lik_vecs, feature_directions = feat_dirs,
    cosine_df = cosine_df
  )
  prefix <- gsub(" ", "_", set_name)
  write_csv(cosine_df, file.path(OUTPUT_DIR, paste0(prefix, "_cosine_similarities.csv")))
  ggsave(file.path(OUTPUT_DIR, paste0(prefix, "_cosine_heatmap.png")), p_heat,
         width = 10, height = 6, dpi = 300, bg = "white")
}

cat("\nCosine similarity analysis complete.\n")

In [None]:
# Cell 12: Fine-tuning Validation
# Compares vanilla ProteinMPNN vs AlkalineMPNN using pH loading vector alignment

# Test whether fine-tuning moved a model's PC-space preference toward the pH
# direction.
#
# pH direction: the unit-normalized mean PC1/PC2 loading of the pH_features
# present in this PCA. If none are present, all features are used instead.
#
# Model direction: a GAM score ~ s(PC1, PC2) is fitted for each of the two
# score columns. Its argmax on a grid_size x grid_size lattice spanning the
# observed PC ranges gives the preference direction. The cosine of this
# direction with the pH direction is the alignment score.
#
# Uncertainty: each model's GAM coefficients are drawn n_bootstrap times from
# N(beta, Vb). The redrawn argmax cosines give 95% percentile CIs and a
# two-sided p-value for (finetuned - vanilla).
#
# Returns NULL when either score column is missing. Otherwise returns a list
# with cosine_sims, conf_intervals, p_value, ph_vector and the two plots. The
# combined figure is also printed and saved to OUTPUT_DIR.
validate_finetuning_effects <- function(pca_results, analysis_name = "pH Features",
                                         vanilla_score = "proteinmpnn_score",
                                         finetuned_score = "AlkalineMPNN_score",
                                         n_bootstrap = 300, grid_size = 60, k_basis = 12) {
  cat("\n=== FINE-TUNING VALIDATION ===", analysis_name, "\n")

  required_scores <- c(vanilla_score, finetuned_score)
  avail <- required_scores[required_scores %in% colnames(pca_results$data)]
  if (length(avail) < 2) {
    cat("Missing required score columns:", setdiff(required_scores, avail), "\n")
    return(NULL)
  }

  # pH loading vector
  avail_pH <- intersect(pH_features, pca_results$loadings$Feature)
  if (length(avail_pH) == 0) {
    cat("No pH features in loadings; using all features\n")
    load_use <- pca_results$loadings
  } else {
    load_use <- pca_results$loadings %>% filter(Feature %in% avail_pH)
  }
  ph_vector <- .unit_vec(c(PC1 = mean(load_use$PC1), PC2 = mean(load_use$PC2)))
  cat(sprintf("pH loading vector: PC1=%.3f, PC2=%.3f\n", ph_vector[1], ph_vector[2]))

  # Fit GAMs and find argmax
  data_clean <- pca_results$data %>%
    dplyr::select(PC1, PC2, all_of(avail)) %>% drop_na()
  if (nrow(data_clean) < 200) cat("Note: <200 rows; CIs may be wide.\n")

  grid <- expand.grid(
    PC1 = seq(min(data_clean$PC1), max(data_clean$PC1), length.out = grid_size),
    PC2 = seq(min(data_clean$PC2), max(data_clean$PC2), length.out = grid_size)
  )

  gam_fits <- list()
  for (sc in avail) {
    form <- as.formula(paste0("`", sc, "` ~ s(PC1, PC2, k = ", k_basis, ")"))
    mod <- gam(form, data = data_clean, method = "REML")
    # Keep the linear-predictor matrix so posterior draws of beta can be
    # turned into grid predictions with a single matrix product.
    Xp <- predict(mod, newdata = grid, type = "lpmatrix")
    beta <- coef(mod)
    Vb <- vcov(mod)
    pred_hat <- as.numeric(Xp %*% beta)
    best_idx <- which.max(pred_hat)
    best_vec <- .unit_vec(unlist(grid[best_idx, c("PC1", "PC2")]))
    gam_fits[[sc]] <- list(mod = mod, Xp = Xp, beta = beta, Vb = Vb, best_vec = best_vec)
    cat(sprintf("  %s likelihood vector: PC1=%.3f, PC2=%.3f\n", sc, best_vec[1], best_vec[2]))
  }

  # Point estimate cosine similarities
  cosine_sims <- sapply(gam_fits, function(fit) .cosine(ph_vector, fit$best_vec))
  cat("\nPoint estimate cosine similarities:\n")
  for (sc in names(cosine_sims)) {
    cat(sprintf("  %s: %.4f\n", sc, cosine_sims[sc]))
  }

  # Posterior simulation bootstrap
  cat("\nRunning posterior bootstrap (", n_bootstrap, "iterations)...\n")
  boot_sims <- matrix(NA_real_, nrow = n_bootstrap, ncol = length(avail))
  colnames(boot_sims) <- avail

  for (b in seq_len(n_bootstrap)) {
    for (sc in avail) {
      fit <- gam_fits[[sc]]
      beta_b <- MASS::mvrnorm(1, mu = fit$beta, Sigma = fit$Vb)
      pred_b <- as.numeric(fit$Xp %*% beta_b)
      idx_b <- which.max(pred_b)
      # A draw whose argmax is the origin has no direction; leave it NA
      v_b <- tryCatch(.unit_vec(unlist(grid[idx_b, c("PC1", "PC2")])),
                      error = function(e) c(NA, NA))
      if (!any(is.na(v_b))) boot_sims[b, sc] <- .cosine(ph_vector, v_b)
    }
  }

  # Confidence intervals
  ci <- function(x) quantile(x, c(0.025, 0.975), na.rm = TRUE)
  conf_intervals <- lapply(avail, function(sc) ci(boot_sims[, sc]))
  names(conf_intervals) <- avail
  for (sc in avail) {
    cat(sprintf("  %s: 95%% CI [%.4f, %.4f]\n", sc,
                conf_intervals[[sc]][1], conf_intervals[[sc]][2]))
  }

  # Significance test: two-sided bootstrap p-value for finetuned - vanilla.
  # Exact ties (diff == 0, e.g. both argmaxes on the same grid cell) are
  # counted in both tails. When ties are common this pushes 2 * min(...)
  # above 1, so the result is capped at 1.
  diff_boot <- boot_sims[, finetuned_score] - boot_sims[, vanilla_score]
  p_value <- min(1, 2 * min(mean(diff_boot <= 0, na.rm = TRUE),
                            mean(diff_boot >= 0, na.rm = TRUE)))
  cat(sprintf("\nFine-tuned vs Vanilla difference: %.4f, p-value: %.4f\n",
              cosine_sims[finetuned_score] - cosine_sims[vanilla_score], p_value))

  # Plots
  sim_df <- data.frame(
    Model = c("Vanilla ProteinMPNN", "AlkalineMPNN"),
    Cosine = c(cosine_sims[vanilla_score], cosine_sims[finetuned_score]),
    CI_Lo = c(conf_intervals[[vanilla_score]][1], conf_intervals[[finetuned_score]][1]),
    CI_Hi = c(conf_intervals[[vanilla_score]][2], conf_intervals[[finetuned_score]][2])
  )

  p_bar <- ggplot(sim_df, aes(x = Model, y = Cosine, fill = Model)) +
    geom_col(alpha = 0.8, width = 0.6) +
    geom_errorbar(aes(ymin = CI_Lo, ymax = CI_Hi), width = 0.2, linewidth = 1) +
    scale_fill_manual(values = c("Vanilla ProteinMPNN" = "#e74c3c", "AlkalineMPNN" = "#3498db")) +
    labs(title = paste("pH Alignment:", analysis_name),
         subtitle = "Cosine similarity with pH loading vector (GAM argmax)",
         y = "Cosine Similarity", x = NULL) +
    theme_minimal(base_size = 12) +
    theme(legend.position = "none", axis.text.x = element_text(angle = 30, hjust = 1))

  var_x <- round(pca_results$explained_variance[1] * 100, 1)
  var_y <- round(pca_results$explained_variance[2] * 100, 1)
  vec_df <- data.frame(
    x = 0, y = 0,
    xend = c(gam_fits[[vanilla_score]]$best_vec[1],
             gam_fits[[finetuned_score]]$best_vec[1], ph_vector[1]),
    yend = c(gam_fits[[vanilla_score]]$best_vec[2],
             gam_fits[[finetuned_score]]$best_vec[2], ph_vector[2]),
    Label = c("Vanilla ProteinMPNN", "AlkalineMPNN", "pH Direction")
  )

  p_vec <- ggplot(vec_df, aes(x = x, y = y, xend = xend, yend = yend, color = Label)) +
    geom_segment(arrow = arrow(length = unit(0.3, "inches"), type = "closed"),
                 linewidth = 1.5, alpha = 0.8) +
    scale_color_manual(values = c("Vanilla ProteinMPNN" = "#e74c3c",
                                  "AlkalineMPNN" = "#3498db",
                                  "pH Direction" = "#9b59b6")) +
    labs(title = "Likelihood vs pH Direction Vectors",
         x = paste0("PC1 (", var_x, "%)"), y = paste0("PC2 (", var_y, "%)")) +
    theme_minimal(base_size = 12) +
    theme(legend.position = "bottom", aspect.ratio = 1) +
    coord_fixed(xlim = c(-1.2, 1.2), ylim = c(-1.2, 1.2)) +
    geom_hline(yintercept = 0, linetype = "dashed", color = "grey70") +
    geom_vline(xintercept = 0, linetype = "dashed", color = "grey70")

  print(p_bar + p_vec)
  ggsave(file.path(OUTPUT_DIR, paste0(gsub(" ", "_", analysis_name), "_finetuning_validation.png")),
         p_bar + p_vec, width = 14, height = 6, dpi = 300, bg = "white")

  list(cosine_sims = cosine_sims, conf_intervals = conf_intervals,
       p_value = p_value, ph_vector = ph_vector,
       plots = list(bar = p_bar, vectors = p_vec))
}

# Run on the pH feature set (primary), plus the mixed and sequence feature
# sets for comparison. Sets absent from all_results (skipped in Cell 9) are
# ignored. A NULL result (missing score columns) leaves no entry in
# ft_results.
# NOTE: the sequence set contains none of pH_features, so for that set
# validate_finetuning_effects() falls back to the mean loading of all its
# features as the "pH direction".
ft_results <- list()
for (set_name in c("pH Features", "Mixed Features", "Sequence Features")) {
  if (set_name %in% names(all_results)) {
    ft_results[[set_name]] <- validate_finetuning_effects(
      all_results[[set_name]]$pca_results, analysis_name = set_name
    )
  }
}
cat("\nFine-tuning validation complete.\n")

In [None]:
# Cell 13: Load Design Data
#
# Reads every per-model CSV listed in DESIGN_FILES, normalises the model
# column and the key/shift column types per file, and row-binds the results
# into `designs`. Missing files are reported and skipped. `shift_cols` lists
# every `_shift` column found.

cat("=== LOADING DESIGN DATA ===\n")

# Key columns are compared as text and shift columns are parsed as numbers.
# This happens per file, BEFORE bind_rows(): readr guesses column types
# independently for each CSV (e.g. an all-empty `domain` column comes back as
# logical), and bind_rows() errors when one column has incompatible types
# across files.
design_id_cols <- c("Entry", "design_id", "model", "domain")

# Load and combine all per-model design CSVs
designs_list <- list()
for (model_name in names(DESIGN_FILES)) {
  fpath <- DESIGN_FILES[[model_name]]
  if (!file.exists(fpath)) {
    cat("  WARNING: File not found:", fpath, "\n")
    next
  }
  d <- read_csv(fpath, show_col_types = FALSE)

  # Standardize model column name. If both `Model` and `model` exist, keep
  # `model` as-is: renaming would create a duplicate column name (error).
  if ("Model" %in% names(d) && !"model" %in% names(d)) {
    d <- d %>% rename(model = Model)
  }
  if (!"model" %in% names(d)) {
    d$model <- model_name
  }

  d <- d %>%
    mutate(across(any_of(design_id_cols), as.character),
           across(ends_with("_shift"), ~ .as_num(.x)))

  designs_list[[model_name]] <- d
  cat(sprintf("  %s: %d rows loaded\n", model_name, nrow(d)))
}

if (length(designs_list) == 0) {
  cat("ERROR: No design files found. Skipping design analysis.\n")
  cat("Please check DESIGN_DIR and DESIGN_FILES paths in Cell 3.\n")
} else {
  # Column types were already harmonised per file above.
  designs <- bind_rows(designs_list)

  cat("\nCombined designs:", nrow(designs), "total rows\n")
  cat("Models:", paste(unique(designs$model), collapse = ", "), "\n")
  cat("Unique proteins:", n_distinct(designs$Entry), "\n")

  # Identify available shift columns
  shift_cols <- grep("_shift$", names(designs), value = TRUE)
  cat("Shift columns:", length(shift_cols), "\n")
}

In [None]:
# Cell 14: Design Shift PCA — Project into WT Sequence-PCA Space (Figure 4)
#
# Fits a PCA on the wild-type (WT) sequence features and projects every design
# into that PC1/PC2 space. It then plots (i) all designs coloured by model,
# (ii) one arrow per model from the WT centroid along the mean per-protein
# shift, and (iii) four example proteins. `wt_seq_pca`, `designs_in_PCA` and
# `per_entry_means` are reused by Cell 15.

if (!exists("designs")) {
  cat("No design data loaded. Skipping.\n")
} else {

cat("=== DESIGN SHIFT PCA (Figure 4) ===\n")

# 1) Fit WT sequence-PCA (6 features)
wt_seq_pca <- perform_focused_pca(df, sequence_features)
wt_seq_coords <- wt_seq_pca$data %>%
  transmute(Entry, domain, PC1_wt = PC1, PC2_wt = PC2)

# 2) Reconstruct absolute feature values for designs from WT + shift
#
# For every feature in `raw_features` that also exists in `wt_df`:
#   value = coalesce(design raw value, WT value) + coalesce(<feature>_shift, 0)
# The exception is `sequence_length`, which is always copied from the WT row.
# WT values are joined on `key`; the helper columns `<feature>__WT` remain on
# the returned frame.
# NOTE(review): if a design CSV holds both an absolute feature column and its
# `_shift` column, the shift is added on top of that absolute value. Confirm
# those raw columns contain WT values, otherwise the shift is double-counted.
# NOTE(review): this join is on `key` only, whereas WT PC coordinates below
# are joined on (Entry, domain). Duplicate keys in `wt_df` would duplicate
# design rows.
reconstruct_design_raw <- function(designs_df, wt_df, raw_features, key = "Entry") {
  feats <- intersect(raw_features, names(wt_df))
  if (!length(feats)) return(designs_df)

  wt_base <- wt_df %>%
    dplyr::select(all_of(c(key, feats))) %>%
    rename_with(~ paste0(.x, "__WT"), .cols = -all_of(key))

  d <- designs_df %>% left_join(wt_base, by = key)

  for (f in feats) {
    wt_col <- .as_num(d[[paste0(f, "__WT")]])
    if (f == "sequence_length") { d[[f]] <- wt_col; next }

    raw_col <- if (f %in% names(d)) .as_num(d[[f]]) else rep(NA_real_, nrow(d))
    shift_nm <- paste0(f, "_shift")
    shift_col <- if (shift_nm %in% names(d)) .as_num(d[[shift_nm]]) else rep(0, nrow(d))
    d[[f]] <- dplyr::coalesce(raw_col, wt_col) + dplyr::coalesce(shift_col, 0)
  }
  d
}

designs_raw <- reconstruct_design_raw(designs, df, sequence_features)

# 3) Project designs into WT PC space: standardise with the WT centre/scale,
#    then multiply by the WT loadings of PC1 and PC2.
recon_feats <- intersect(sequence_features, names(designs_raw))
num_ok <- designs_raw %>%
  mutate(across(all_of(recon_feats), .as_num)) %>%
  filter(if_all(all_of(recon_feats), ~ is.finite(.x)))

X <- as.matrix(num_ok[, recon_feats, drop = FALSE])
Xs <- sweep(sweep(X, 2, wt_seq_pca$center[recon_feats], "-"),
            2, wt_seq_pca$scale[recon_feats], "/")
# X is finite after the filter above, so any non-finite value here comes from
# the WT centre/scale. Place such entries at the WT mean (z = 0).
Xs[!is.finite(Xs)] <- 0
R <- wt_seq_pca$pca_object$rotation[recon_feats, 1:2, drop = FALSE]
Z <- Xs %*% R

designs_in_PCA <- num_ok %>%
  mutate(PC1 = Z[, 1], PC2 = Z[, 2]) %>%
  left_join(wt_seq_coords, by = c("Entry", "domain")) %>%
  filter(is.finite(PC1), is.finite(PC2))

cat("Projected", nrow(designs_in_PCA), "designs into WT sequence-PCA space\n")

# Designs whose (Entry, domain) is absent from the WT PCA have no WT position,
# so their shift is undefined. Report how many there are.
n_no_wt <- sum(!is.finite(designs_in_PCA$PC1_wt) | !is.finite(designs_in_PCA$PC2_wt))
if (n_no_wt > 0) {
  cat("  NOTE:", n_no_wt, "designs have no matching WT PCA coordinates",
      "(Entry/domain not in WT PCA); they are excluded from shift summaries\n")
}

# 4) Overall scatter by model
var_pc <- wt_seq_pca$explained_variance[1:2]
labs_xy <- pc_axis_labels(wt_seq_pca$explained_variance)
ratio_yx <- var_pc[2] / var_pc[1]

pal_model <- MODEL_COLORS[names(MODEL_COLORS) %in% unique(designs_in_PCA$model)]

p_all_models <- ggplot() +
  geom_point(data = wt_seq_pca$data, aes(PC1, PC2), color = "grey90", alpha = 0.35, size = 0.8) +
  geom_point(data = designs_in_PCA,
             aes(PC1, PC2, color = model, fill = model),
             size = 1.8, alpha = 0.7, shape = 21, stroke = 0.2) +
  scale_color_manual(values = pal_model) +
  scale_fill_manual(values = pal_model) +
  labs(x = labs_xy$x, y = labs_xy$y,
       title = "Designs in WT sequence-PCA space") +
  theme_minimal(base_size = 12) +
  theme(legend.position = "bottom",
        panel.border = element_rect(color = "grey60", fill = NA)) +
  coord_fixed(ratio = ratio_yx) +
  geom_hline(yintercept = 0, linetype = "dashed", color = "grey70") +
  geom_vline(xintercept = 0, linetype = "dashed", color = "grey70")
print(p_all_models)

# 5) Centroid shift arrows (Figure 4B left)
# Per (protein, model): mean design position minus the WT position. Rows
# without WT coordinates give NA shifts and are dropped. Otherwise a single
# unmatched protein makes the whole model centroid NA, and that model's arrow
# (and its Cell 15 statistics) silently disappears.
per_entry_means <- designs_in_PCA %>%
  group_by(Entry, model) %>%
  summarise(PC1_mean = mean(PC1), PC2_mean = mean(PC2),
            PC1_wt = first(PC1_wt), PC2_wt = first(PC2_wt), .groups = "drop") %>%
  mutate(dPC1 = PC1_mean - PC1_wt, dPC2 = PC2_mean - PC2_wt) %>%
  filter(is.finite(dPC1), is.finite(dPC2))

model_centroids <- per_entry_means %>%
  group_by(model) %>%
  summarise(dPC1 = mean(dPC1), dPC2 = mean(dPC2), .groups = "drop")

wt_mean <- wt_seq_pca$data %>%
  filter(Entry %in% unique(designs_in_PCA$Entry)) %>%
  summarise(PC1 = mean(PC1), PC2 = mean(PC2))

# Arrow start/end stored as data columns rather than external vectors in aes().
arrow_df <- model_centroids %>%
  mutate(x0 = wt_mean$PC1, y0 = wt_mean$PC2,
         x1 = x0 + dPC1, y1 = y0 + dPC2)

p_arrows <- ggplot() +
  geom_segment(data = arrow_df,
               aes(x = x0, y = y0, xend = x1, yend = y1, color = model),
               arrow = arrow(length = unit(0.2, "inches"), type = "closed"),
               linewidth = 1.2) +
  geom_point(data = wt_mean, aes(PC1, PC2), shape = 4, size = 4, stroke = 1.5) +
  scale_color_manual(values = pal_model) +
  labs(x = labs_xy$x, y = labs_xy$y,
       title = "Model centroid shifts (WT mean to design mean)") +
  theme_minimal(base_size = 12) +
  theme(legend.position = "bottom",
        panel.border = element_rect(color = "grey60", fill = NA)) +
  coord_fixed(ratio = ratio_yx) +
  geom_hline(yintercept = 0, linetype = "dashed", color = "grey70") +
  geom_vline(xintercept = 0, linetype = "dashed", color = "grey70")
print(p_arrows)

# 6) Per-protein scatter examples (4 representative proteins)
# The first four Entries in data order are used.
example_entries <- designs_in_PCA %>%
  distinct(Entry) %>%
  slice_head(n = 4) %>%
  pull(Entry)

per_prot_plots <- list()
for (eid in example_entries) {
  d_one <- designs_in_PCA %>% filter(Entry == eid)
  wt_pt <- d_one %>% distinct(PC1_wt, PC2_wt)

  per_prot_plots[[eid]] <- ggplot() +
    geom_point(data = designs_in_PCA, aes(PC1, PC2), color = "grey85", alpha = 0.3, size = 0.8) +
    geom_point(data = d_one, aes(PC1, PC2, color = model), size = 2, alpha = 0.85) +
    geom_point(data = wt_pt, aes(PC1_wt, PC2_wt), shape = 4, size = 3, stroke = 1) +
    scale_color_manual(values = pal_model) +
    labs(title = eid, x = labs_xy$x, y = labs_xy$y) +
    theme_minimal(base_size = 10) +
    theme(legend.position = "none") +
    coord_fixed(ratio = ratio_yx)
}

if (length(per_prot_plots) > 0) {
  print(wrap_plots(per_prot_plots, ncol = 2) +
    plot_annotation(title = "Per-protein designs in WT sequence-PCA",
                    subtitle = "X = wild-type position"))
}

# Save
ggsave(file.path(OUTPUT_DIR, "design_shift_overall_by_model.png"), p_all_models,
       width = 9, height = 7, dpi = 300, bg = "white")
ggsave(file.path(OUTPUT_DIR, "design_shift_centroid_arrows.png"), p_arrows,
       width = 8, height = 7, dpi = 300, bg = "white")

cat("Design shift PCA visualization complete.\n")
}

In [None]:
# Cell 15: Design Shift — Quantitative Analysis
#
# Quantifies the Cell 14 projections:
#   1) per-model one-sample t-tests on the mean per-protein PC shifts,
#   2) back-projection of model centroid shifts into feature space,
#   3) per-feature shift t-tests with BH correction (from Cell 13 `designs`),
#   4) pairwise cosine similarity between models' per-protein shift vectors.
# Tables and the cosine heatmap are written to OUTPUT_DIR.

if (!exists("designs_in_PCA")) {
  cat("No design PCA data. Skipping.\n")
} else {

cat("=== QUANTITATIVE DESIGN SHIFT ANALYSIS ===\n")

# One-sample t-test of `x` against 0. Returns c(statistic, p_value) as plain
# doubles. Both are NA_real_ when t.test() errors (e.g. fewer than two
# observations or essentially constant data), so summary columns keep a
# stable numeric type.
.safe_t_test <- function(x) {
  tt <- tryCatch(t.test(x), error = function(e) NULL)
  if (is.null(tt)) {
    return(c(statistic = NA_real_, p_value = NA_real_))
  }
  c(statistic = unname(tt$statistic), p_value = tt$p.value)
}

# Proteins without WT coordinates have NA shifts. t.test() drops NAs but
# n()/mean()/sd() do not, so filter once to keep n, means, SDs, t and d
# computed on the same rows.
entry_shifts <- per_entry_means %>%
  filter(is.finite(dPC1), is.finite(dPC2))

# 1) Per-model centroids and one-sample t-tests
per_model_tests <- entry_shifts %>%
  group_by(model) %>%
  summarise(
    n = n(),
    mean_dPC1 = mean(dPC1), sd_dPC1 = sd(dPC1),
    mean_dPC2 = mean(dPC2), sd_dPC2 = sd(dPC2),
    t_PC1 = .safe_t_test(dPC1)[["statistic"]],
    p_PC1 = .safe_t_test(dPC1)[["p_value"]],
    t_PC2 = .safe_t_test(dPC2)[["statistic"]],
    p_PC2 = .safe_t_test(dPC2)[["p_value"]],
    d_PC1 = mean(dPC1) / sd(dPC1),  # Cohen's d (one-sample: mean / SD)
    d_PC2 = mean(dPC2) / sd(dPC2),
    .groups = "drop"
  )

cat("\nPer-model centroid shift tests:\n")
print(per_model_tests)
write_csv(per_model_tests, file.path(OUTPUT_DIR, "design_shift_model_tests.csv"))

# 2) Back-project centroid shifts into feature space.
#    dz = R12 %*% (dPC1, dPC2) is the per-feature shift in standardised units;
#    d_raw multiplies it by the WT PCA scale of that feature.
R12 <- wt_seq_pca$pca_object$rotation[wt_seq_pca$features, 1:2, drop = FALSE]
scales <- wt_seq_pca$scale[wt_seq_pca$features]

model_centroids_pc <- entry_shifts %>%
  group_by(model) %>%
  summarise(dPC1_centroid = mean(dPC1), dPC2_centroid = mean(dPC2), .groups = "drop")

model_feature_deltas <- model_centroids_pc %>%
  rowwise() %>%
  mutate(bp = list({
    dz <- as.numeric(R12 %*% c(dPC1_centroid, dPC2_centroid))
    tibble(feature = wt_seq_pca$features, dz = dz, d_raw = dz * scales[feature])
  })) %>%
  unnest(bp) %>%
  ungroup() %>%
  arrange(model, desc(abs(dz)))

cat("\nBack-projected feature deltas (top per model):\n")
model_feature_deltas %>%
  group_by(model) %>%
  slice_head(n = 3) %>%
  print(n = 20)
write_csv(model_feature_deltas, file.path(OUTPUT_DIR, "design_shift_feature_deltas.csv"))

# 3) Per-feature t-tests with BH FDR correction (within each model)
shift_feat_cols <- paste0(sequence_features, "_shift")
avail_shift <- intersect(shift_feat_cols, names(designs))

if (length(avail_shift) > 0) {
  feat_tests <- designs %>%
    dplyr::select(model, all_of(avail_shift)) %>%
    pivot_longer(cols = all_of(avail_shift), names_to = "feature", values_to = "shift") %>%
    group_by(model, feature) %>%
    summarise(
      n = sum(!is.na(shift)),
      mean_shift = mean(shift, na.rm = TRUE),
      sd_shift = sd(shift, na.rm = TRUE),
      t_stat = .safe_t_test(shift)[["statistic"]],
      p_value = .safe_t_test(shift)[["p_value"]],
      cohens_d = mean(shift, na.rm = TRUE) / sd(shift, na.rm = TRUE),
      .groups = "drop"
    ) %>%
    group_by(model) %>%
    mutate(p_adj = p.adjust(p_value, method = "BH")) %>%
    ungroup()

  cat("\nPer-feature shift tests (BH-corrected):\n")
  feat_tests %>% filter(p_adj < 0.05) %>% print(n = 30)
  write_csv(feat_tests, file.path(OUTPUT_DIR, "design_shift_feature_tests.csv"))
}

# 4) Cosine similarity between models (per-protein WT->design vectors)
#
# For every protein designed by >= 2 models, computes the cosine between each
# unordered pair of models' mean shift vectors (dPC1, dPC2) via `.cosine2`.
# Returns list(pairs = per-protein rows, summary = per-pair n / mean / sd).
compute_model_cosine_similarity <- function(pca_df) {
  mean_coords <- pca_df %>%
    group_by(Entry, model) %>%
    summarise(PC1_mean = mean(PC1), PC2_mean = mean(PC2),
              PC1_wt = first(PC1_wt), PC2_wt = first(PC2_wt), .groups = "drop") %>%
    mutate(dPC1 = PC1_mean - PC1_wt, dPC2 = PC2_mean - PC2_wt)

  keep <- mean_coords %>% count(Entry) %>% filter(n >= 2) %>% pull(Entry)
  mc <- mean_coords %>% filter(Entry %in% keep)

  pairs <- mc %>%
    group_by(Entry) %>%
    summarise(d = list(expand_grid(model_i = model, model_j = model)), .groups = "drop") %>%
    unnest(d) %>%
    filter(model_i < model_j) %>%
    left_join(mc %>% dplyr::select(Entry, model, dPC1, dPC2), by = c("Entry", "model_i" = "model")) %>%
    rename(dPC1_i = dPC1, dPC2_i = dPC2) %>%
    left_join(mc %>% dplyr::select(Entry, model, dPC1, dPC2), by = c("Entry", "model_j" = "model")) %>%
    rename(dPC1_j = dPC1, dPC2_j = dPC2) %>%
    mutate(cosine = .cosine2(dPC1_i, dPC2_i, dPC1_j, dPC2_j))

  summary <- pairs %>%
    group_by(model_i, model_j) %>%
    summarise(n = sum(!is.na(cosine)),
              mean_cos = mean(cosine, na.rm = TRUE),
              sd_cos = sd(cosine, na.rm = TRUE), .groups = "drop")

  list(pairs = pairs, summary = summary)
}

cos_result <- compute_model_cosine_similarity(designs_in_PCA)
cat("\nPairwise model cosine similarity (mean across proteins):\n")
print(cos_result$summary)

# Cosine heatmap: mirror the upper-triangle summary and add a diagonal of 1s
# (diagonal tiles get no label because their n is NA).
rev_df <- cos_result$summary %>% rename(model_j = model_i, model_i = model_j)
diag_models <- unique(c(cos_result$summary$model_i, cos_result$summary$model_j))
sq <- bind_rows(cos_result$summary, rev_df) %>%
  bind_rows(tibble(model_i = diag_models, model_j = diag_models,
                   n = NA_integer_, mean_cos = 1, sd_cos = NA_real_)) %>%
  distinct(model_i, model_j, .keep_all = TRUE)

p_cos_heat <- ggplot(sq, aes(model_i, model_j, fill = mean_cos)) +
  geom_tile() +
  # The diverging fill is white near 0, where white labels were invisible.
  # Use white text on strongly coloured tiles and black text otherwise.
  geom_text(aes(label = ifelse(is.na(n), "", sprintf("%.2f", mean_cos)),
                color = ifelse(!is.na(mean_cos) & abs(mean_cos) > 0.5,
                               "white", "black")),
            fontface = "bold", size = 3.5) +
  scale_color_identity() +
  scale_fill_gradient2(limits = c(-1, 1), midpoint = 0) +
  labs(title = "Mean cosine similarity between model shift vectors",
       x = NULL, y = NULL, fill = "Cosine") +
  theme_minimal(base_size = 12) +
  theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
  coord_equal()
print(p_cos_heat)

write_csv(cos_result$summary, file.path(OUTPUT_DIR, "model_cosine_summary.csv"))
write_csv(cos_result$pairs, file.path(OUTPUT_DIR, "model_cosine_per_protein.csv"))
ggsave(file.path(OUTPUT_DIR, "model_cosine_heatmap.png"), p_cos_heat,
       width = 7, height = 6, dpi = 300, bg = "white")

cat("\nQuantitative design shift analysis complete.\n")
}

In [None]:
# Cell 16: Save & Export
#
# Lists every file under OUTPUT_DIR with its size, then bundles the directory
# into "<OUTPUT_DIR>.zip" next to it (useful for downloading from Colab).

cat("=== SAVING ALL OUTPUTS ===\n")

# List all output files (vectorised; prints nothing when the directory is empty)
output_files <- list.files(OUTPUT_DIR, recursive = TRUE, full.names = TRUE)
cat("Output files generated:\n")
size_kb <- round(file.size(output_files) / 1024, 1)
cat(sprintf("  %s (%s KB)\n", basename(output_files), size_kb), sep = "")

# Zip for download (useful on Colab)
zip_path <- paste0(OUTPUT_DIR, ".zip")
if (length(output_files) > 0) {
  # utils::zip() updates an existing archive instead of replacing it, so
  # remove any zip from a previous run to avoid shipping stale files.
  unlink(zip_path)
  # zip() shells out to an external program (R_ZIPCMD, default "zip") and
  # returns its exit status invisibly; 0 means success.
  zip_status <- zip(zip_path, OUTPUT_DIR)
  if (isTRUE(zip_status == 0)) {
    cat("\nZipped to:", zip_path, "\n")
  } else {
    cat("\nWARNING: zip failed (status", zip_status,
        "); is a 'zip' program available (see R_ZIPCMD)?\n")
  }
}

cat("\n=== ANALYSIS COMPLETE ===\n")