In [None]:
---
title: "Multimorbidity Data Analyses: Aim 1"
output: html_notebook
---

# Aim 1: Characterize combinations of chronic diseases found in patients with multimorbidity among subgroups.
## Part 1: Descriptives

***Note: TGA was changed in Part 2b. When running any script that involves TGA, replace TGA with TGA2, which uses the TGA_BN dataframe (the most up-to-date version).***

### The first step is to characterize the chronic diseases by comparing the control (non-multimorbidity, CG) and test (multimorbidity, TG) groups.
```{r}
#Call the packages and load the data
library(readr)
library(dplyr)
library(stats)
library(tidyr)
library(bnlearn)
library(viridisLite)
library(knitr)
library(stringr)
library(ggplot2)

# Base folder for this project's scripts/data.
# NOTE(review): machine-specific absolute path; normalizePath() only warns if it
# does not exist. datapath2 is not referenced in the visible cells — confirm it
# is still needed.
datapath2 = normalizePath("C:/Users/taiqu/Box/01-TaiR-Dissertation-FALL2024/08-0325-RScripts")
```

## Count of Chronic Diseases in the Combinations Data
```{r}
# Load necessary library
library(dplyr)

# Restrict TGA_back to the patients whose de-identified MRN (de_id_mrn) also
# appears in second_dataframe.
# NOTE(review): neither TGA_back nor second_dataframe is created in this
# notebook's visible cells — confirm where they come from.
TGA_SCD <- TGA_back %>%
  filter(is.element(de_id_mrn, second_dataframe$de_id_mrn))

# Preview the first rows of the subset
head(TGA_SCD)

```



### Descriptive Statistics (Control: Non-Multimorbidity (CG2) & Test: Multimorbidity (TG2))
```{r}
# Statistical mode: the most frequent non-missing value of `x`.
# Ties go to the value that appears first in `x`; an all-NA vector yields NA.
calculate_mode <- function(x) {
  candidates <- unique(na.omit(x))
  # Index of each element of x within `candidates`; missing entries map to NA,
  # which tabulate() does not count
  positions <- match(x, candidates)
  frequency <- tabulate(positions)
  candidates[which.max(frequency)]
}

# Function to generate descriptive statistics
#
# Summarises every column of `df` except "encdates".
#   df: a data frame (or tibble).
# Returns a list with two data frames:
#   $numeric     one row per numeric column with Missing, Mean, Median,
#                Std_Dev, Mode (from calculate_mode()), Min, Max and IQR.
#                IQR is reported as the text range "Q1-Q3", not Q3 - Q1.
#                Min/Max are NA (not Inf/-Inf) for an all-missing column.
#   $categorical one row per level of each factor/character column (an NA
#                level is included when present), followed by one "Yes" row
#                for every numeric column whose non-missing values are all
#                0/1. NULL when there is nothing to report.
# Columns of other types (logical, Date, ...) are not summarised.
generate_descriptive_stats <- function(df) {
  # Exclude "encdates" from analyses. drop = FALSE keeps a data frame even
  # when a single column remains; otherwise the column names were lost.
  df_analysis <- df[, !colnames(df) %in% "encdates", drop = FALSE]

  # vapply() always yields a logical vector (even for zero columns), whereas
  # sapply() returns list() there and breaks the column selection
  is_numeric_col <- vapply(df_analysis, is.numeric, logical(1))
  is_categorical_col <- vapply(
    df_analysis,
    function(col) is.factor(col) || is.character(col),
    logical(1)
  )
  numeric_vars <- df_analysis[is_numeric_col]
  categorical_vars <- df_analysis[is_categorical_col]

  # 0/1-coded numeric columns are reported as the count/share of "Yes" (1)
  binary_stats <- lapply(names(numeric_vars), function(var) {
    x <- numeric_vars[[var]]
    if (!all(na.omit(x) %in% c(0, 1))) {
      return(NULL)  # not binary
    }
    data.frame(
      Variable = var,
      Category = "Yes",
      Count = sum(x == 1, na.rm = TRUE),
      Proportion = round(mean(x == 1, na.rm = TRUE), 3),
      stringsAsFactors = FALSE
    )
  })
  binary_stats <- do.call(rbind, binary_stats)

  # Frequency table for each factor/character column
  categorical_stats <- lapply(names(categorical_vars), function(var) {
    freq_table <- table(categorical_vars[[var]], useNA = "ifany")
    if (length(freq_table) == 0) {
      return(NULL)  # zero rows: nothing to tabulate
    }
    data.frame(
      Variable = var,
      Category = names(freq_table),
      Count = as.numeric(freq_table),
      Proportion = round(as.numeric(prop.table(freq_table)), 3),
      stringsAsFactors = FALSE
    )
  })
  categorical_stats <- do.call(rbind, categorical_stats)

  # Categorical rows first, then the binary "Yes" rows (rbind skips NULLs)
  combined_categorical_stats <- rbind(categorical_stats, binary_stats)

  # min()/max() wrapper: NA instead of +/-Inf (and a warning) when all missing
  range_end <- function(f) {
    function(x) if (all(is.na(x))) NA_real_ else f(x, na.rm = TRUE)
  }
  # First and third quartiles formatted as "Q1-Q3"
  q1_q3 <- function(x) {
    q <- quantile(x, c(0.25, 0.75), na.rm = TRUE, names = FALSE)
    paste0(round(q[1], 2), "-", round(q[2], 2))
  }

  numeric_stats <- data.frame(
    Variable = names(numeric_vars),
    Missing = vapply(numeric_vars, function(x) sum(is.na(x)), integer(1)),
    Mean = vapply(numeric_vars, mean, numeric(1), na.rm = TRUE),
    Median = vapply(numeric_vars, median, numeric(1), na.rm = TRUE),
    Std_Dev = vapply(numeric_vars, sd, numeric(1), na.rm = TRUE),
    Mode = vapply(numeric_vars, calculate_mode, numeric(1)),
    Min = vapply(numeric_vars, range_end(min), numeric(1)),
    Max = vapply(numeric_vars, range_end(max), numeric(1)),
    IQR = vapply(numeric_vars, q1_q3, character(1)),
    stringsAsFactors = FALSE
  )

  list(numeric = numeric_stats, categorical = combined_categorical_stats)
}

# Print the numeric summary and the frequency table produced by
# generate_descriptive_stats(), each under a heading naming `dataset_name`.
print_stats <- function(stats, dataset_name) {
  # Heading text surrounds the dataset name; the table follows on new lines
  emit_table <- function(lead_in, trailer, tbl) {
    cat(lead_in, dataset_name, trailer)
    print(tbl)
  }

  emit_table("\nDescriptive Statistics for", "Numeric Variables:\n", stats$numeric)
  emit_table(
    "\nFrequency Distribution (Including Yes for Binary Variables) for",
    "Categorical Variables:\n",
    stats$categorical
  )
}

# Summarise each cohort: CGA = non-multimorbidity, TGA = multimorbidity
CGA_stats <- generate_descriptive_stats(CGA)
TGA_stats <- generate_descriptive_stats(TGA)

# Echo both summaries to the console
print_stats(CGA_stats, "CGA")
print_stats(TGA_stats, "TGA")

# Write the four summary tables to CSV in the working directory, pairing each
# table with its file name (same order as before: CGA first, then TGA)
invisible(Map(
  function(tbl, path) write.csv(tbl, path, row.names = FALSE),
  list(CGA_stats$numeric, CGA_stats$categorical, TGA_stats$numeric, TGA_stats$categorical),
  c("CGA_numeric_stats.csv", "CGA_categorical_stats.csv",
    "TGA_numeric_stats.csv", "TGA_categorical_stats.csv")
))
```
### Bar Plot: Age Groups
```{r}
# Load necessary library
library(ggplot2)

# Tag each cohort with its group label. This writes a `group` column back into
# the global CGA/TGA data frames; later chunks read CGA$group / TGA$group.
CGA$group <- "Non-Multimorbidity"
TGA$group <- "Multimorbidity"

# Stack the age-group column of both cohorts into one long data frame
combined_data <- rbind(
  data.frame(age_groups = CGA$age_groups, group = CGA$group),
  data.frame(age_groups = TGA$age_groups, group = TGA$group)
)

# Count rows per group x age group (the formula interface of aggregate() drops
# rows with a missing age group), then convert counts to within-group shares
combined_data_summary <- aggregate(
  count ~ group + age_groups,
  data = transform(combined_data, count = 1),
  FUN = length
)
combined_data_summary <- transform(
  combined_data_summary,
  proportion = ave(count, group, FUN = function(x) x / sum(x))
)

# Named colours so the group -> colour mapping does not depend on level order
group_colors <- c("Multimorbidity" = "blue", "Non-Multimorbidity" = "orange")

# Bars and their labels must be dodged by the same width (0.9 is the default
# bar width); the labels previously used 1.2 and drifted off their bars
bar_dodge <- position_dodge(width = 0.9)

# Side-by-side bars of within-group proportions, labelled with their values
plot <- ggplot(combined_data_summary, aes(x = age_groups, y = proportion, fill = group)) +
  geom_col(position = bar_dodge, alpha = 1, color = NA) +
  geom_text(
    aes(label = round(proportion, 2)), # Raw proportions rounded to 2 decimals
    position = bar_dodge,
    vjust = -0.5,
    size = 3.5
  ) +
  scale_fill_manual(values = group_colors) +
  labs(
    title = "Age Group Distribution: Patients With Multimorbidity vs Non-Multimorbidity",
    x = "Age Group (years)",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  ) +
  # Headroom above the tallest bar so its label is not clipped
  scale_y_continuous(expand = expansion(mult = c(0, 0.1)))

# Save the plot
ggsave("age_group_distribution.png", plot = plot, width = 8, height = 6, dpi = 300)
```


### Race Bar Plot
```{r}
# Load necessary library
library(ggplot2)

# Tag each cohort with its group label. This writes a `group` column back into
# the global CGA/TGA data frames; later chunks read CGA$group / TGA$group.
CGA$group <- "Non-Multimorbidity"
TGA$group <- "Multimorbidity"

# Stack the race column of both cohorts into one long data frame
combined_data <- rbind(
  data.frame(race_corrected = CGA$race_corrected, group = CGA$group),
  data.frame(race_corrected = TGA$race_corrected, group = TGA$group)
)

# Count rows per group x race (the formula interface of aggregate() drops rows
# with a missing race), then convert counts to within-group shares
combined_data_summary <- aggregate(
  count ~ group + race_corrected,
  data = transform(combined_data, count = 1),
  FUN = length
)
combined_data_summary <- transform(
  combined_data_summary,
  proportion = ave(count, group, FUN = function(x) x / sum(x))
)

# Named colours so the group -> colour mapping does not depend on level order
group_colors <- c("Multimorbidity" = "blue", "Non-Multimorbidity" = "orange")

# Bars and their labels must be dodged by the same width (0.9 is the default
# bar width); the labels previously used 1.2 and drifted off their bars
bar_dodge <- position_dodge(width = 0.9)

# Side-by-side bars of within-group proportions, labelled with their values
plot <- ggplot(combined_data_summary, aes(x = race_corrected, y = proportion, fill = group)) +
  geom_col(position = bar_dodge, alpha = 1, color = NA) +
  geom_text(
    aes(label = round(proportion, 2)), # Raw proportions rounded to 2 decimals
    position = bar_dodge,
    vjust = -0.5,
    size = 3.5
  ) +
  scale_fill_manual(values = group_colors) +
  labs(
    title = "Distribution of Race for Patients With Multimorbidity vs Non-Multimorbidity",
    x = "Race",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  ) +
  # Headroom above the tallest bar so its label is not clipped
  scale_y_continuous(expand = expansion(mult = c(0, 0.1)))

# Save the plot
ggsave("race_distribution.png", plot = plot, width = 8, height = 6, dpi = 300)
```

### Ethnicity Bar Plot
```{r}
# Tag each cohort with its group label. This writes a `group` column back into
# the global CGA/TGA data frames; later chunks read CGA$group / TGA$group.
CGA$group <- "Non-Multimorbidity"
TGA$group <- "Multimorbidity"

# Stack the ethnicity column of both cohorts into one long data frame
combined_data <- rbind(
  data.frame(eth_corrected = CGA$eth_corrected, group = CGA$group),
  data.frame(eth_corrected = TGA$eth_corrected, group = TGA$group)
)

# Count rows per group x ethnicity (the formula interface of aggregate() drops
# rows with a missing ethnicity), then convert counts to within-group shares
combined_data_summary <- aggregate(
  count ~ group + eth_corrected,
  data = transform(combined_data, count = 1),
  FUN = length
)
combined_data_summary <- transform(
  combined_data_summary,
  proportion = ave(count, group, FUN = function(x) x / sum(x))
)

# Named colours so the group -> colour mapping does not depend on level order
group_colors <- c("Multimorbidity" = "blue", "Non-Multimorbidity" = "orange")

# Bars and their labels must be dodged by the same width (0.9 is the default
# bar width); the labels previously used 1.2 and drifted off their bars
bar_dodge <- position_dodge(width = 0.9)

# Side-by-side bars of within-group proportions, labelled with their values
plot <- ggplot(combined_data_summary, aes(x = eth_corrected, y = proportion, fill = group)) +
  geom_col(position = bar_dodge, alpha = 1, color = NA) +
  geom_text(
    aes(label = round(proportion, 2)), # Raw proportions rounded to 2 decimals
    position = bar_dodge,
    vjust = -0.5,
    size = 3.5
  ) +
  scale_fill_manual(values = group_colors) +
  labs(
    title = "Distribution of Ethnicity for Patients With Multimorbidity vs Non-Multimorbidity",
    x = "Ethnicity",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  ) +
  # Headroom above the tallest bar so its label is not clipped
  scale_y_continuous(expand = expansion(mult = c(0, 0.1)))

# Save the plot
ggsave("ethnicity_distribution.png", plot = plot, width = 8, height = 6, dpi = 300)
```

### Sex Bar Plot
```{r}
# Stack the sex column of both cohorts into one long data frame. The group
# labels are written out here rather than read from CGA$group / TGA$group, so
# this chunk no longer depends on an earlier chunk having added that column.
combined_data <- rbind(
  data.frame(sex = CGA$sex, group = "Non-Multimorbidity"),
  data.frame(sex = TGA$sex, group = "Multimorbidity")
)

# Count rows per group x sex (the formula interface of aggregate() drops rows
# with a missing sex), then convert counts to within-group shares
combined_data_summary <- aggregate(
  count ~ group + sex,
  data = transform(combined_data, count = 1),
  FUN = length
)
combined_data_summary <- transform(
  combined_data_summary,
  proportion = ave(count, group, FUN = function(x) x / sum(x))
)

# Named colours so the group -> colour mapping does not depend on level order
group_colors <- c("Multimorbidity" = "blue", "Non-Multimorbidity" = "orange")

# Bars and their labels must be dodged by the same width (0.9 is the default
# bar width); the labels previously used 1.2 and drifted off their bars
bar_dodge <- position_dodge(width = 0.9)

# Side-by-side bars of within-group proportions, labelled with their values
plot <- ggplot(combined_data_summary, aes(x = sex, y = proportion, fill = group)) +
  geom_col(position = bar_dodge, alpha = 1, color = NA) +
  geom_text(
    aes(label = round(proportion, 2)), # Raw proportions rounded to 2 decimals
    position = bar_dodge,
    vjust = -0.5,
    size = 3.5
  ) +
  scale_fill_manual(values = group_colors) +
  labs(
    title = "Distribution of Sex for Patients With Multimorbidity vs Non-Multimorbidity",
    x = "Sex",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  ) +
  # Headroom above the tallest bar so its label is not clipped
  scale_y_continuous(expand = expansion(mult = c(0, 0.1)))

# Save the plot
ggsave("sex_distribution.png", plot = plot, width = 8, height = 6, dpi = 300)
```

### Comorbidity Score & Van Index Bar Plots
```{r}
# Combine the datasets for analysis: one long data frame in which `variable`
# holds the score value, `group` the cohort and `name` which score it is
combined_data <- rbind(
  data.frame(variable = CGA$comorbidity_score, group = "Non-Multimorbidity", name = "comorbidity_score"),
  data.frame(variable = TGA$comorbidity_score, group = "Multimorbidity", name = "comorbidity_score"),
  data.frame(variable = CGA$van_index, group = "Non-Multimorbidity", name = "van_index"),
  data.frame(variable = TGA$van_index, group = "Multimorbidity", name = "van_index")
)

# Proportion of each score value within its score x group, so each group's
# bars for a given score sum to 1. Previously the proportions were taken after
# .groups = "drop", i.e. over the grand total of both scores and both groups.
# Missing score values are counted as their own (NA) category.
combined_data_summary <- combined_data %>%
  group_by(name, group, variable) %>%
  summarise(count = n(), .groups = "drop") %>%
  group_by(name, group) %>%
  mutate(proportion = count / sum(count)) %>%
  ungroup()

# Define colors for groups
group_colors <- c("Non-Multimorbidity" = "orange", "Multimorbidity" = "blue")

# Plot for comorbidity_score
plot_comorbidity <- ggplot(
  filter(combined_data_summary, name == "comorbidity_score"),
  aes(x = as.factor(variable), y = proportion, fill = group)
) +
  geom_col(position = "dodge") +
  scale_fill_manual(values = group_colors) + # Apply custom colors
  labs(
    title = "Proportion of Comorbidity Scores: Multimorbidity vs Non-Multimorbidity",
    x = "Comorbidity Score",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal()

# Plot for van_index
plot_van_index <- ggplot(
  filter(combined_data_summary, name == "van_index"),
  aes(x = as.factor(variable), y = proportion, fill = group)
) +
  geom_col(position = "dodge") +
  scale_fill_manual(values = group_colors) + # Apply custom colors
  labs(
    title = "Proportion of Van Indexes: Multimorbidity vs Non-Multimorbidity",
    x = "Van Index",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal()

# Save the plots
ggsave("comorbidity_score_proportion_plot.png", plot = plot_comorbidity, width = 8, height = 6, dpi = 300)
ggsave("van_index_proportion_plot.png", plot = plot_van_index, width = 8, height = 6, dpi = 300)
```

## Part 2: Correlations
### Heat Map of Chronic Disease Proportions between Groups
```{r}
# Get list of header names
colnames(TGA)

# Generate Heat Map
# Load necessary libraries
library(ggplot2)
library(reshape2)  # For reshaping data

# List of variables representing chronic diseases
disease_variables <- c("AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF",
                       "CHRNLUNG", "COAG", "DEPRESS", "DM", "DMCX", "DRUG",
                       "HTN_C", "HYPOTHY", "LIVER", "LYMPH", "LYTES", "METS",
                       "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH", "PULMCIRC",
                       "RENLFAIL", "TUMOR", "ULCER", "VALVE", "WGHTLOSS", "ARRHYTH")

# Fail early with a clear message if a disease column is absent from either
# cohort; mean() of a missing column would otherwise give an NA tile and only
# a vague warning
local({
  absent <- setdiff(disease_variables, intersect(names(TGA), names(CGA)))
  if (length(absent) > 0) {
    stop("Disease columns missing from TGA and/or CGA: ",
         paste(absent, collapse = ", "), call. = FALSE)
  }
})

# Column means per cohort (presumably 0/1 indicators, making these the share
# of patients with each disease — TODO confirm coding)
TGA_proportions <- vapply(disease_variables, function(var) mean(TGA[[var]], na.rm = TRUE), numeric(1))
CGA_proportions <- vapply(disease_variables, function(var) mean(CGA[[var]], na.rm = TRUE), numeric(1))

# One row per disease, one column per group. check.names = FALSE keeps the
# hyphenated label used by the other plots ("Non-Multimorbidity").
combined_data <- data.frame(
  disease = disease_variables,
  "Multimorbidity" = TGA_proportions,
  "Non-Multimorbidity" = CGA_proportions,
  check.names = FALSE
)

# Reshape data for heatmap (long format)
melted_data <- melt(combined_data, id.vars = "disease", variable.name = "group", value.name = "proportion")

# Create heatmap
heatmap_plot <- ggplot(melted_data, aes(x = group, y = disease, fill = proportion)) +
  geom_tile(color = "white") +  # Create tiles
  geom_text(aes(label = round(proportion, 2)), size = 3.5) +  # Add proportion numbers
  scale_fill_gradient(low = "blue", high = "orange", name = "Proportion") +  # Gradient fill
  labs(
    title = "Comparison of Chronic Disease Proportions: Multimorbidity vs Non-Multimorbidity",
    x = "Group",
    y = "Chronic Diseases"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text.x = element_text(size = 10),
    axis.text.y = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 10)
  )

# Save the heatmap
ggsave("chronic_disease_heatmap.png", plot = heatmap_plot, width = 10, height = 8, dpi = 300)
```

### Histograms of the Combinations of Chronic Diseases Present in the Multimorbidity Group by Type (Dyad, Triad, etc.)
```{r}
# Load necessary libraries
library(ggplot2)

# Step 1: Count the "+" separators in each combination (k separators = k + 1
# diseases). str_count() returns NA for a missing combination; the previous
# gregexpr() count turned NA into 1, i.e. mislabelled it as a "Dyad".
TGA$combination_type <- str_count(as.character(TGA$combinations), fixed("+"))

# Label the categories: "Pentad" covers 5 or more diseases, "Other" a single
# disease, and a missing combination stays NA
TGA$combination_category <- case_when(
  TGA$combination_type == 1 ~ "Dyad",
  TGA$combination_type == 2 ~ "Triad",
  TGA$combination_type == 3 ~ "Tetrad",
  TGA$combination_type >= 4 ~ "Pentad",
  !is.na(TGA$combination_type) ~ "Other"
)

# Step 2: Count each unique combination (aggregate() omits rows whose grouping
# values are NA) and express it as a share of all TGA rows
combination_counts <- aggregate(
  TGA$combinations,
  by = list(category = TGA$combination_category, combination = TGA$combinations),
  FUN = length
)
colnames(combination_counts) <- c("combination_category", "combinations", "count")
combination_counts$proportion <- combination_counts$count / nrow(TGA)

# Step 3: Split data into categories
dyads <- subset(combination_counts, combination_category == "Dyad")
triads <- subset(combination_counts, combination_category == "Triad")
tetrads_and_pentads <- subset(combination_counts, combination_category %in% c("Tetrad", "Pentad"))

# Sort and keep the top 20 of each panel (the tetrad/pentad plot is titled
# "Top 20" but previously plotted every combination)
dyads_top20 <- head(dyads[order(-dyads$proportion), ], 20)
triads_top20 <- head(triads[order(-triads$proportion), ], 20)
tetrads_and_pentads_top <- head(tetrads_and_pentads[order(-tetrads_and_pentads$proportion), ], 20)

# Step 4: Create bar charts of the top combinations

# Dyads
histogram_dyads <- ggplot(dyads_top20, aes(x = reorder(combinations, -proportion), y = proportion, fill = "Dyad")) +
  geom_col(color = "black", alpha = 0.8) +
  geom_text(aes(label = round(proportion, 2)), vjust = -0.5, size = 3.5) +  # Add values above bars
  scale_fill_manual(values = c("Dyad" = "blue")) +
  labs(
    title = "Top 20 Chronic Disease Dyad Combinations in Patients with Multimorbidity",
    x = "Chronic Disease Combinations",
    y = "Proportion"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text.x = element_text(size = 10, angle = 45, hjust = 1),
    legend.position = "none"
  )

# Triads
histogram_triads <- ggplot(triads_top20, aes(x = reorder(combinations, -proportion), y = proportion, fill = "Triad")) +
  geom_col(color = "black", alpha = 0.8) +
  geom_text(aes(label = round(proportion, 2)), vjust = -0.5, size = 3.5) +  # Add values above bars
  scale_fill_manual(values = c("Triad" = "orange")) +
  labs(
    title = "Top 20 Chronic Disease Triad Combinations in Patients with Multimorbidity",
    x = "Chronic Disease Combinations",
    y = "Proportion"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text.x = element_text(size = 10, angle = 45, hjust = 1),
    legend.position = "none"
  )

# Tetrads and pentads (coloured by category)
histogram_tetrads_pentads <- ggplot(tetrads_and_pentads_top, aes(x = reorder(combinations, -proportion), y = proportion, fill = combination_category)) +
  geom_col(color = "black", alpha = 0.8) +
  geom_text(aes(label = round(proportion, 2)), vjust = -0.5, size = 3.5) +  # Add values above bars
  scale_fill_manual(values = c("Tetrad" = "blue", "Pentad" = "orange")) +
  labs(
    title = "Top 20 Chronic Disease Tetrad & Pentad Combinations in Patients with Multimorbidity",
    x = "Chronic Disease Combinations",
    y = "Proportion"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text.x = element_text(size = 10, angle = 45, hjust = 1),
    legend.title = element_text(size = 10)
  )

# Save the histograms
ggsave("top20_dyads_histogram.png", plot = histogram_dyads, width = 14, height = 8, dpi = 300)
ggsave("top20_triads_histogram.png", plot = histogram_triads, width = 14, height = 8, dpi = 300)
ggsave("tetrads_pentads_histogram.png", plot = histogram_tetrads_pentads, width = 14, height = 8, dpi = 300)
```

### HOT! Correlations Based on Subgroups (POSSIBLE IMPLICATIONS FOR FUTURE ENDOMETRIOSIS RESEARCH!)
```{r}
# R Script for Chronic Disease Combination Analysis by Demographic Subgroups
# Focuses on combination frequencies across demographic groups

library(ggplot2)
library(reshape2)
library(dplyr)
library(gridExtra)

# List of demographic subgroup columns.
# Read as a global by compare_combo_across_demographics() and
# analyze_disease_combinations_by_demographics() below.
demographic_vars = c("eth_corrected", "race_corrected", "age_groups", "sex")

# Function to generate heatmap showing combination frequencies by subgroup
#
#   data:         data frame with `combo_cat`, `combinations` and `subgroup_var`
#   category:     value of `combo_cat` to plot (e.g. "dyad")
#   subgroup_var: name of the demographic column shown on the y axis
#   top_n:        number of most frequent combinations (overall) to show
#   title:        plot title; a default is built when NULL
# Each cell is the share of the subgroup's `category` rows that have that
# combination (rows with a missing subgroup appear as "Missing").
# Returns a ggplot, or NULL (with a message) when there is nothing to plot.
generate_subgroup_heatmap <- function(data, category, subgroup_var,
                                      top_n = 20, title = NULL) {

  # Default title if none provided
  if (is.null(title)) {
    title <- paste0("Top ", top_n, " ", category, " combinations by ", subgroup_var)
  }

  # Rows of the requested category; %in% (unlike ==) drops rows whose
  # combo_cat is NA instead of turning them into all-NA rows
  subset_data <- data[data$combo_cat %in% category, , drop = FALSE]

  if (nrow(subset_data) == 0) {
    message(paste("No data for", category))
    return(NULL)
  }

  # Top N non-missing combinations overall, so every subgroup row shows the
  # same columns
  overall_top_combos <- subset_data %>%
    filter(!is.na(combinations)) %>%
    group_by(combinations) %>%
    summarise(count = n(), .groups = "drop") %>%
    arrange(desc(count)) %>%
    slice_head(n = top_n) %>%
    pull(combinations)

  if (length(overall_top_combos) == 0) {
    message(paste("No non-missing combinations for", category))
    return(NULL)
  }

  # Use character labels: indexing a matrix by a factor uses the factor's
  # integer codes, not its labels, which previously sent counts to the wrong
  # rows (or out of bounds) once the demographic columns were factors
  subgroup_labels <- as.character(subset_data[[subgroup_var]])
  subgroup_labels[is.na(subgroup_labels)] <- "Missing"
  subgroups <- unique(subgroup_labels)

  # Counts of each top combination per subgroup (zero where absent)
  in_top <- subset_data$combinations %in% overall_top_combos
  count_table <- table(
    factor(subgroup_labels[in_top], levels = subgroups),
    factor(subset_data$combinations[in_top], levels = overall_top_combos)
  )
  result_matrix <- matrix(
    as.numeric(count_table),
    nrow = length(subgroups),
    ncol = length(overall_top_combos),
    dimnames = list(subgroups, overall_top_combos)
  )

  # Divide each row by that subgroup's total number of `category` rows
  # (including combinations outside the top N)
  subgroup_totals <- as.numeric(table(factor(subgroup_labels, levels = subgroups)))
  prop_matrix <- result_matrix / subgroup_totals

  # Melt matrix for ggplot
  melted_data <- melt(prop_matrix)
  names(melted_data) <- c(subgroup_var, "Combination", "Proportion")

  # Create heatmap (.data[[...]] replaces the deprecated aes_string())
  ggplot(melted_data, aes(x = Combination, y = .data[[subgroup_var]], fill = Proportion)) +
    geom_tile(color = "white", linewidth = 0.2) +
    scale_fill_gradient(low = "white", high = "darkred",
                        name = "Proportion",
                        limits = c(0, max(melted_data$Proportion)),
                        guide = guide_colorbar(title.position = "top")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1, size = 8),
          axis.text.y = element_text(size = 10),
          axis.title = element_blank(),
          panel.grid = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          plot.title = element_text(hjust = 0.5, size = 12, face = "bold"),
          plot.margin = margin(10, 10, 10, 10)) +
    ggtitle(title)
}

# Function to generate frequency distribution of combinations by subgroup
#
#   data:         data frame with `combo_cat`, `combinations` and `subgroup_var`
#   category:     value of `combo_cat` to plot (e.g. "dyad")
#   subgroup_var: name of the demographic column used for bar colour
#   top_n:        number of most frequent combinations (overall) to show
#   title:        plot title; a default is built when NULL
# Percentages are computed within each subgroup over the top_n combinations
# only (other combinations are not in the denominator).
# Returns a ggplot, or NULL (with a message) when the category has no rows.
generate_frequency_barplot <- function(data, category, subgroup_var, top_n = 10,
                                       title = NULL) {

  # Default title if none provided
  if (is.null(title)) {
    title <- paste0("Frequency of top ", top_n, " ", category, " combinations by ", subgroup_var)
  }

  # Rows of the requested category; %in% drops rows whose combo_cat is NA
  subset_data <- data[data$combo_cat %in% category, , drop = FALSE]

  # Return NULL for an empty category so callers checking is.null() skip it,
  # instead of saving an empty plot
  if (nrow(subset_data) == 0) {
    message(paste("No data for", category))
    return(NULL)
  }

  # Top N non-missing combinations overall
  overall_top_combos <- subset_data %>%
    filter(!is.na(combinations)) %>%
    group_by(combinations) %>%
    summarise(count = n(), .groups = "drop") %>%
    arrange(desc(count)) %>%
    slice_head(n = top_n) %>%
    pull(combinations)

  # Count top combinations by subgroup, then convert to within-subgroup %
  plot_data <- subset_data %>%
    filter(combinations %in% overall_top_combos) %>%
    group_by(!!sym(subgroup_var), combinations) %>%
    summarise(count = n(), .groups = "drop") %>%
    group_by(!!sym(subgroup_var)) %>%
    mutate(percentage = count / sum(count) * 100) %>%
    ungroup()

  # Create barplot (.data[[...]] replaces the deprecated aes_string())
  ggplot(plot_data, aes(x = combinations, y = percentage, fill = .data[[subgroup_var]])) +
    geom_col(position = "dodge") +
    scale_fill_brewer(palette = "Set1") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1, size = 8),
          legend.title = element_blank(),
          plot.title = element_text(hjust = 0.5, size = 12, face = "bold")) +
    labs(title = title,
         x = "Disease Combination",
         y = "Percentage within Subgroup")
}

# Function to compare a specific combination across all demographic variables
#
#   data:        data frame with a `combinations` column and demographic columns
#   combination: the combination string to profile
#   title:       overall title; a default is built when NULL
#   demo_vars:   demographic columns to break down by (defaults to the global
#                `demographic_vars`); columns absent from `data` are skipped
# Draws (via gridExtra::grid.arrange) and returns a grid with one bar chart
# per demographic variable, showing the % of this combination's rows in each
# subgroup.
compare_combo_across_demographics <- function(data, combination, title = NULL,
                                              demo_vars = demographic_vars) {

  # Default title
  if (is.null(title)) {
    title <- paste0("Distribution of '", combination, "' across demographic subgroups")
  }

  # Rows with this combination; %in% excludes NA combinations (== kept them
  # as all-NA rows)
  subset_data <- data[data$combinations %in% combination, , drop = FALSE]

  # Only break down by variables that exist in `data`
  present_vars <- intersect(demo_vars, names(data))

  # lapply() (not a for loop) gives each plot its own `demo_var` binding,
  # which the .data[[demo_var]] aesthetics look up when the plot is drawn
  plots <- lapply(present_vars, function(demo_var) {
    counts <- subset_data %>%
      group_by(!!sym(demo_var)) %>%
      summarise(count = n(), .groups = "drop")
    counts$percentage <- counts$count / sum(counts$count) * 100

    ggplot(counts, aes(x = .data[[demo_var]], y = percentage, fill = .data[[demo_var]])) +
      geom_col() +
      scale_fill_brewer(palette = "Set2") +
      theme_minimal() +
      theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1),
            legend.position = "none",
            plot.title = element_text(size = 10)) +
      labs(title = paste0("By ", demo_var),
           y = "Percentage",
           x = NULL)
  })
  names(plots) <- present_vars

  # Arrange plots in a grid
  do.call(grid.arrange, c(plots, top = title, ncol = 2))
}

# Main function to analyze disease combinations by demographic subgroups
#
#   df:         data frame with `combinations`, `combo_cat` and (some of) the
#               columns listed in the global `demographic_vars`
#   output_dir: if not NULL, every plot is also saved there as a PNG (the
#               directory is created when missing)
#   categories: `combo_cat` values to analyse
# For each category x demographic variable, builds a heatmap and a frequency
# barplot; then profiles the 5 most common combinations across demographics.
# Returns a named list of all plot objects.
analyze_disease_combinations_by_demographics <- function(df, output_dir = NULL,
                                                         categories = c("dyad", "triad", "tetrad", "pentad")) {

  # The helpers below filter on these columns; fail clearly if absent
  absent_cols <- setdiff(c("combinations", "combo_cat"), names(df))
  if (length(absent_cols) > 0) {
    stop("Missing required column(s): ", paste(absent_cols, collapse = ", "), call. = FALSE)
  }

  # Create output directory if specified and doesn't exist
  if (!is.null(output_dir) && !dir.exists(output_dir)) {
    dir.create(output_dir, recursive = TRUE)
  }

  # Make sure the combination columns are character type
  df$combinations <- as.character(df$combinations)
  df$combo_cat <- as.character(df$combo_cat)

  # Ensure demographic variables are factors
  for (var in demographic_vars) {
    if (var %in% names(df)) {
      df[[var]] <- factor(df[[var]])
    } else {
      warning(paste("Variable", var, "not found in the dataset"))
    }
  }
  present_vars <- intersect(demographic_vars, names(df))

  # Save one plot as <output_dir>/<plot_name>.png when output_dir is set
  save_plot <- function(plot_name, p, width, height) {
    if (!is.null(output_dir)) {
      ggsave(file.path(output_dir, paste0(plot_name, ".png")),
             plot = p, width = width, height = height, dpi = 300)
    }
  }

  # Results storage
  all_plots <- list()

  # Heatmaps and frequency barplots for each category x demographic variable
  for (category in categories) {
    for (demo_var in present_vars) {
      plot_title <- paste0("Top ", category, " combinations by ", demo_var)
      p <- generate_subgroup_heatmap(df, category, demo_var, top_n = 15, title = plot_title)
      if (!is.null(p)) {
        plot_name <- paste0(category, "_by_", demo_var)
        all_plots[[plot_name]] <- p
        save_plot(plot_name, p, width = 12, height = 8)
      }

      p_bar <- generate_frequency_barplot(df, category, demo_var, top_n = 5)
      if (!is.null(p_bar)) {
        plot_name <- paste0(category, "_freq_by_", demo_var)
        all_plots[[plot_name]] <- p_bar
        save_plot(plot_name, p_bar, width = 12, height = 8)
      }
    }
  }

  # Top 5 most common non-missing combinations overall
  top_combos <- df %>%
    filter(!is.na(combinations)) %>%
    group_by(combinations) %>%
    summarise(count = n(), .groups = "drop") %>%
    arrange(desc(count)) %>%
    slice_head(n = 5) %>%
    pull(combinations)

  # Compare top combinations across all demographic variables
  for (combo in top_combos) {
    p_demo <- compare_combo_across_demographics(df, combo)

    # Replace anything unsafe in a file name ("+", "/", spaces, ...) with "_"
    safe_combo <- gsub("[^A-Za-z0-9_.-]+", "_", combo)
    plot_name <- paste0("combo_", safe_combo, "_demographics")
    all_plots[[plot_name]] <- p_demo
    save_plot(plot_name, p_demo, width = 10, height = 8)
  }

  all_plots
}

# Run the subgroup analysis on the multimorbidity data (TGA); each plot is
# saved as a PNG under ./output and all plot objects are kept in `plots`.
# NOTE(review): assumes TGA already has `combinations` and `combo_cat` columns
# (not created in the visible cells) — TODO confirm. The note at the top of the
# notebook also says TGA should be replaced by TGA2; confirm which is intended.
plots = analyze_disease_combinations_by_demographics(TGA, output_dir = "output")
```

### Correlations Based on Subgroups
```{r}
# R Script for Chronic Disease Combination Analysis by Demographic Subgroups
# Analyzes combinations by "eth_corrected", "race_corrected", "age_group", "sex"

library(ggplot2)
library(reshape2)

# List of demographic subgroups
demo_vars <- c("eth_corrected", "race_corrected", "age_groups", "sex")

# Correct Age Group Order
cor_age_order <- c("18-29", "30-39", "40-49", "50-59", "60-69", "70-79", "80-89")

# Convert Age Group to a factor so that it is ordered correctly.
# factor() silently turns any value not listed in cor_age_order into NA, so
# report such values before converting (local() keeps helpers out of the
# global environment).
local({
  observed <- unique(as.character(TGA$age_groups))
  unexpected <- setdiff(observed[!is.na(observed)], cor_age_order)
  if (length(unexpected) > 0) {
    warning("age_groups values not in cor_age_order will become NA: ",
            paste(unexpected, collapse = ", "), call. = FALSE)
  }
})
TGA$age_groups <- factor(TGA$age_groups, levels = cor_age_order)

# Function to generate heatmap showing frequency of combinations by subgroup
#
#   data:         data frame with `combo_cat`, `combinations` and `subgroup_var`
#   category:     value of `combo_cat` to plot (e.g. "dyad")
#   subgroup_var: name of the demographic column shown on the y axis
#   top_n:        if not NULL, keep only the top_n most frequent combinations
#   title:        plot title
# Cells are column-normalised: the share of a combination's occurrences that
# fall in each subgroup, so every combination's column sums to 1. Rows with a
# missing subgroup or combination are left out. Returns a ggplot, or NULL
# (with a message) when nothing can be plotted.
generate_subgroup_heatmap <- function(data, category, subgroup_var,
                                      top_n = NULL, title = "Disease Combinations Heatmap") {

  # Rows of the requested category; %in% drops rows whose combo_cat is NA
  subset_data <- data[data$combo_cat %in% category, , drop = FALSE]

  # Frequency of each combination (table() ignores NA combinations)
  combo_counts <- table(subset_data$combinations)

  # If top_n is specified, keep only rows with the top combinations
  if (!is.null(top_n) && length(combo_counts) > top_n) {
    ranked <- sort(combo_counts, decreasing = TRUE)
    top_combos <- names(ranked[seq_len(min(top_n, length(combo_counts)))])
    subset_data <- subset_data[subset_data$combinations %in% top_combos, , drop = FALSE]
  }

  # Rows missing either coordinate cannot be placed in the grid; with the
  # previous == comparisons a single NA subgroup turned every count into NA
  placeable <- !is.na(subset_data[[subgroup_var]]) & !is.na(subset_data$combinations)
  subset_data <- subset_data[placeable, , drop = FALSE]

  if (nrow(subset_data) == 0) {
    message(paste("No data for", category))
    return(NULL)
  }

  # Labels in order of first appearance (as before)
  combo_labels <- as.character(subset_data$combinations)
  subgroup_labels <- as.character(subset_data[[subgroup_var]])
  unique_combos <- unique(combo_labels)
  subgroup_values <- unique(subgroup_labels)

  # Counts of each combination within each subgroup, in one vectorised pass
  count_table <- table(
    factor(subgroup_labels, levels = subgroup_values),
    factor(combo_labels, levels = unique_combos)
  )
  combo_matrix <- matrix(
    as.numeric(count_table),
    nrow = length(subgroup_values),
    ncol = length(unique_combos),
    dimnames = list(subgroup_values, unique_combos)
  )

  # Column-normalise by each combination's total; pmax() leaves an all-zero
  # column at 0 instead of dividing by zero
  prop_matrix <- sweep(combo_matrix, 2, pmax(colSums(combo_matrix), 1), "/")

  # Melt matrix for ggplot
  melted_data <- melt(prop_matrix)
  names(melted_data) <- c("Subgroup", "Combination", "Proportion")

  # Ensure melted_data preserves the age group factor levels
  if (subgroup_var == "age_groups") {
    melted_data$Subgroup <- factor(melted_data$Subgroup, levels = cor_age_order)
  }

  # Create heatmap with values displayed
  ggplot(melted_data, aes(x = Combination, y = Subgroup, fill = Proportion)) +
    geom_tile(color = "black", linewidth = 0.1) +
    # Add text labels with proportion values
    geom_text(aes(label = sprintf("%.2f", Proportion)),
              color = ifelse(melted_data$Proportion > 0.5, "white", "black"),
              size = 3.0) +
    scale_fill_gradient(low = "blue", high = "orange",
                        name = "Proportion",
                        guide = guide_colorbar(title.position = "top")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1, size = 10),
          axis.text.y = element_text(size = 10),
          axis.title = element_blank(),
          panel.grid = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          plot.title = element_text(hjust = 0.5, size = 12, face = "bold"),
          plot.margin = margin(10, 10, 10, 10)) +
    ggtitle(title)
}

# Function to combine tetrads and pentads in one heatmap by subgroup.
#
# Pools the "tetrad" and "pentad" categories into one plot, keeps the `top_n`
# most frequent combinations across both, and shows for every subgroup the
# share of each combination's patients who fall in it (columns sum to 1).
#
# Args:
#   data         data frame with `combo_cat`, `combinations`, and the column
#                named by `subgroup_var`.
#   subgroup_var name (string) of the demographic column used for the rows.
#   title        plot title.
#   top_n        number of most frequent combinations to keep (default 20, the
#                previously hard-coded value); NULL keeps all of them.
# Returns:
#   a ggplot tile heatmap with each proportion printed in its tile. Rows with
#   an NA combination or subgroup are left out; if none remain, a warning is
#   issued and an empty ggplot carrying only the title is returned.
# Raises:
#   an error if `subgroup_var` is not a column of `data`.
generate_combined_subgroup_heatmap = function(data, subgroup_var,
                                              title = "Tetrads and Pentads by Subgroup",
                                              top_n = 20) {

  if (!subgroup_var %in% names(data)) {
    stop("Column '", subgroup_var, "' not found in data", call. = FALSE)
  }

  # Filter data for tetrads and pentads (%in% never yields NA, so rows with a
  # missing combo_cat are simply excluded)
  subset_data = data[data$combo_cat %in% c("tetrad", "pentad"), , drop = FALSE]

  # Keep only the top_n most frequent combinations overall
  combo_counts = table(subset_data$combinations)
  if (!is.null(top_n) && length(combo_counts) > top_n) {
    top_combos = names(sort(combo_counts, decreasing = TRUE))[seq_len(top_n)]
    subset_data = subset_data[subset_data$combinations %in% top_combos, , drop = FALSE]
  }

  # Drop rows that cannot be placed in the grid. A single NA subgroup used to
  # make every count NA and stop the function at the proportion step.
  keep = !is.na(subset_data$combinations) & !is.na(subset_data[[subgroup_var]])
  subset_data = subset_data[keep, , drop = FALSE]

  # An empty selection used to error inside the 1:length() counting loop
  if (nrow(subset_data) == 0) {
    warning("No tetrad/pentad combinations with a non-missing '",
            subgroup_var, "' value; returning an empty plot.", call. = FALSE)
    return(ggplot() + theme_void() + ggtitle(title))
  }

  # Count each combination within each subgroup in one cross-tabulation.
  # Levels follow order of first appearance (the same order the loops used).
  subgroups = as.character(subset_data[[subgroup_var]])
  combos = as.character(subset_data$combinations)
  combo_matrix = table(factor(subgroups, levels = unique(subgroups)),
                       factor(combos, levels = unique(combos)))

  # Proportion by total: divide each column by that combination's total count.
  # Every column has at least one patient here, so there is no division by zero.
  prop_matrix = prop.table(combo_matrix, margin = 2)

  # Long format for ggplot (one row per subgroup x combination); both keys stay
  # discrete factors so the tile axes remain categorical
  melted_data = as.data.frame(prop_matrix, stringsAsFactors = TRUE)
  names(melted_data) = c("Subgroup", "Combination", "Proportion")

  # Ensure the age groups are shown in their natural order
  if (subgroup_var == "age_groups") {
    melted_data$Subgroup = factor(melted_data$Subgroup, levels = cor_age_order)
  }

  # Create heatmap with values displayed (smaller text: more columns than the
  # single-category heatmaps)
  p = ggplot(melted_data, aes(x = Combination, y = Subgroup, fill = Proportion)) +
    geom_tile(color = "black", linewidth = 0.1) +
    geom_text(aes(label = sprintf("%.2f", Proportion)),
              color = ifelse(melted_data$Proportion > 0.5, "white", "black"),
              size = 2.5) +
    scale_fill_gradient(low = "blue", high = "orange",
                        name = "Proportion",
                        guide = guide_colorbar(title.position = "top")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1, size = 8),
          axis.text.y = element_text(size = 10),
          axis.title = element_blank(),
          panel.grid = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          plot.title = element_text(hjust = 0.5, size = 12, face = "bold"),
          plot.margin = margin(10, 10, 10, 10)) +
    ggtitle(title)

  return(p)
}

# Main function - builds, saves, and returns every subgroup heatmap for a dataframe.
#
# For each demographic column it makes a top-20 dyad heatmap, a top-20 triad
# heatmap, and a combined tetrad/pentad heatmap, writes each one to a PNG file,
# and collects the ggplot objects.
#
# Args:
#   df             data frame with `combinations`, `combo_cat`, and the
#                  demographic columns.
#   dyad_title, triad_title, combined_title
#                  title stems; " by <variable>" is appended to each.
#   subgroup_vars  demographic columns to plot (defaults to the global demo_vars,
#                  which the function previously read implicitly).
#   output_dir     directory the PNG files are written to (default: working dir).
# Returns:
#   a named list of ggplot objects, e.g. `dyad_sex`, `triad_sex`, `combined_sex`.
# Raises:
#   an error if `combinations`/`combo_cat` are missing or none of
#   `subgroup_vars` is a column of `df`.
generate_disease_subgroup_heatmaps = function(df,
                                              dyad_title = "Top 20 Disease Dyads",
                                              triad_title = "Top 20 Disease Triads",
                                              combined_title = "Disease Tetrads and Pentads",
                                              subgroup_vars = demo_vars,
                                              output_dir = ".") {

  # Both combination columns are required; fail early with a clear message
  missing_cols = setdiff(c("combinations", "combo_cat"), names(df))
  if (length(missing_cols) > 0) {
    stop("Missing required column(s): ", paste(missing_cols, collapse = ", "),
         call. = FALSE)
  }

  # Make sure the combination columns are character type
  df$combinations = as.character(df$combinations)
  df$combo_cat = as.character(df$combo_cat)

  # Check that demographic variables exist in the dataframe
  available_demo_vars = subgroup_vars[subgroup_vars %in% names(df)]
  if (length(available_demo_vars) == 0) {
    stop("None of the demographic variables found in the dataframe")
  }

  # Report requested variables that are absent instead of skipping them silently
  skipped_vars = setdiff(subgroup_vars, available_demo_vars)
  if (length(skipped_vars) > 0) {
    warning("Skipping demographic variable(s) not in the dataframe: ",
            paste(skipped_vars, collapse = ", "), call. = FALSE)
  }

  # Store all plots in a list
  all_plots = list()

  # Generate heatmaps for each demographic variable
  for (demo_var in available_demo_vars) {

    # Dyad heatmap by subgroup
    dyad_title_sg = paste0(dyad_title, " by ", demo_var)
    dyad_plot = generate_subgroup_heatmap(df, "dyad", demo_var, top_n = 20, title = dyad_title_sg)
    ggsave(file.path(output_dir, paste0("dyad_heatmap_by_", demo_var, ".png")),
           dyad_plot, width = 14, height = 8, dpi = 300)
    all_plots[[paste0("dyad_", demo_var)]] = dyad_plot

    # Triad heatmap by subgroup
    triad_title_sg = paste0(triad_title, " by ", demo_var)
    triad_plot = generate_subgroup_heatmap(df, "triad", demo_var, top_n = 20, title = triad_title_sg)
    ggsave(file.path(output_dir, paste0("triad_heatmap_by_", demo_var, ".png")),
           triad_plot, width = 14, height = 8, dpi = 300)
    all_plots[[paste0("triad_", demo_var)]] = triad_plot

    # Combined tetrad and pentad heatmap by subgroup
    combined_title_sg = paste0(combined_title, " by ", demo_var)
    combined_plot = generate_combined_subgroup_heatmap(df, demo_var, title = combined_title_sg)
    ggsave(file.path(output_dir, paste0("tetrad_pentad_heatmap_by_", demo_var, ".png")),
           combined_plot, width = 16, height = 8, dpi = 300)
    all_plots[[paste0("combined_", demo_var)]] = combined_plot
  }

  return(all_plots)
}

# Build every subgroup heatmap for the TGA data; the PNGs are written to disk
# and the ggplot objects are kept in `plots`, keyed like "dyad_sex"
plots <- generate_disease_subgroup_heatmaps(TGA)
```


### These scripts are retained for future use (if modified) but are omitted for now because of their length and code complexity. Some were run as shorter chunks upstream of this section.
#### OMIT: Descriptive Stats Bar-Chart Visualizations
```{r}
# Attach the packages this (omitted) visualization script relies on, in order
invisible(lapply(c("dplyr", "ggplot2", "tidyr", "stringr"), library, character.only = TRUE))

cat("Starting multimorbidity analysis...\n")

# Guard: both grouped cohort data frames must already be in the session
if (!(exists("TGA_grouped") && exists("CGA_grouped"))) {
  stop("Error: TGA_grouped and/or CGA_grouped not found. Please load the datasets first.")
}

# Tag every row with its cohort so the rows stay identifiable once stacked
TGA_grouped$Group = "Multimorbidity"
CGA_grouped$Group = "Non-Multimorbidity"

# Quick data check: print each cohort's ethnicity distribution (NA included)
cat("Checking ethnicity variables...\n")
if (!("ethnicity" %in% names(TGA_grouped)) || !("ethnicity" %in% names(CGA_grouped))) {
  cat("Warning: ethnicity variable not found in one or both datasets\n")
} else {
  print(table(TGA_grouped$ethnicity, useNA = "ifany"))
  print(table(CGA_grouped$ethnicity, useNA = "ifany"))
}

# Stack the two cohorts, keeping only the columns they share
common_columns = intersect(names(TGA_grouped), names(CGA_grouped))
cat("Number of common columns:", length(common_columns), "\n")
TGA_common = TGA_grouped[, common_columns]
CGA_common = CGA_grouped[, common_columns]
combined_data = rbind(TGA_common, CGA_common)

# Persist the stacked data
write.csv(combined_data, "Combined_Multimorbidity_Data.csv", row.names = FALSE)
cat("Combined dataset saved to Combined_Multimorbidity_Data.csv\n")

# Reformat multimorbidity combination variables (dyads, triads, etc.) in TGA_grouped.
# A column counts as a combination when its name contains a period; the number
# of periods sets its order (1 = dyad, 2 = triad, 3 = tetrad, 4 = pentad, any
# other count falls back to "combo_"). Each such column is copied into
# TGA_transformed under the prefixed name, e.g. "A.B" -> "dyad_A.B".
combination_cols = names(TGA_grouped)[grepl("\\.", names(TGA_grouped))]
cat("Found", length(combination_cols), "combination columns\n")

# Start from an unmodified copy; prefixed columns are appended to it below
TGA_transformed = TGA_grouped

# Number of periods in each combination column name
period_counts = str_count(combination_cols, "\\.")

# Split the combination columns by order
dyad_cols = combination_cols[period_counts == 1]
triad_cols = combination_cols[period_counts == 2]
tetrad_cols = combination_cols[period_counts == 3]
pentad_cols = combination_cols[period_counts == 4]

# Prefix per column; counts outside 1-4 get the "combo_" fallback
combo_prefixes = c("dyad_", "triad_", "tetrad_", "pentad_")[match(period_counts, 1:4)]
combo_prefixes[is.na(combo_prefixes)] = "combo_"

for (k in seq_along(combination_cols)) {
  TGA_transformed[[paste0(combo_prefixes[k], combination_cols[k])]] =
    TGA_grouped[[combination_cols[k]]]
}

# Prefixed column names for each category
dyads = paste0("dyad_", dyad_cols)
triads = paste0("triad_", triad_cols)
tetrads = paste0("tetrad_", tetrad_cols)
pentads = paste0("pentad_", pentad_cols)

cat("Number of dyads:", length(dyads), "\n")
cat("Number of triads:", length(triads), "\n")
cat("Number of tetrads:", length(tetrads), "\n")
cat("Number of pentads:", length(pentads), "\n")

# Turn a prefixed combination column name into a readable label: strip the
# leading "<prefix>_" and join the disease names with "+", e.g.
# format_display_name("dyad_HTN.DM", "dyad") gives "HTN+DM".
# Vectorized over col_name; `prefix` is inserted into a regex unescaped.
format_display_name = function(col_name, prefix) {
  unprefixed = gsub(paste0("^", prefix, "_"), "", col_name)
  chartr(".", "+", unprefixed)
}

# Persist the copy that now also carries the prefixed combination columns
write.csv(TGA_transformed, "TGA_transformed.csv", row.names = FALSE)
cat("Transformed dataset saved to TGA_transformed.csv\n")

# Open a PDF device and make it current: each plot printed while it stays the
# active device becomes a page of all_plots.pdf
pdf("all_plots.pdf")

# VISUALIZATIONS
cat("Creating visualizations...\n")

# 1. Age and Sex Distribution by Multimorbidity Status
# Needs the s_male / s_female indicator columns and age_group; otherwise the
# plot is skipped with a warning.
if("s_male" %in% names(combined_data) && "s_female" %in% names(combined_data) && "age_group" %in% names(combined_data)) {
  cat("Creating Age and Sex Distribution plot...\n")

  # Derive one Sex label from the indicators. Female is assigned last, so it
  # wins when both flags are 1; rows with neither flag are dropped.
  combined_sex_data = combined_data
  combined_sex_data$Sex = "Unknown"
  combined_sex_data$Sex[combined_data$s_male == 1] = "Male"
  combined_sex_data$Sex[combined_data$s_female == 1] = "Female"
  combined_sex_data = subset(combined_sex_data, Sex != "Unknown")

  # Check if we have data after filtering
  if(nrow(combined_sex_data) > 0) {
    # Bar counts per age group, dodged by sex, one panel per cohort
    plot1 = ggplot(combined_sex_data, aes(x = age_group, fill = Sex)) +
      geom_bar(position = "dodge") +
      facet_wrap(~ Group) +
      scale_fill_manual(values = c("Male" = "blue", "Female" = "orange")) +
      labs(title = "Age and Sex Distribution by Multimorbidity Status",
           x = "Age Group (High/Low)",
           y = "Count",
           fill = "Sex") +
      theme_minimal() +
      theme(axis.text.x = element_text(angle = 45, hjust = 1))

    print(plot1)  # Draw on the active device (all_plots.pdf while it is current)

    # Save separately. Closing the PNG with dev.off() makes the *next* open
    # device current, which is not necessarily the one active before, so
    # remember the active device and switch back to it afterwards.
    prev_dev = dev.cur()
    png("age_sex_distribution.png", width = 800, height = 600)
    print(plot1)
    dev.off()
    if (prev_dev > 1) dev.set(prev_dev)

    cat("Age and Sex Distribution plot saved\n")
  } else {
    cat("Warning: No valid data for Age and Sex Distribution plot\n")
  }
} else {
  cat("Warning: Missing variables for Age and Sex Distribution plot\n")
}

# 2. Length of Stay Distribution by Multimorbidity Status
# Bar counts of los_group, dodged by cohort; skipped if los_group is absent.
if("los_group" %in% names(combined_data)) {
  cat("Creating Length of Stay Distribution plot...\n")

  plot2 = ggplot(combined_data, aes(x = los_group, fill = Group)) +
    geom_bar(position = "dodge") +
    scale_fill_manual(values = c("Non-Multimorbidity" = "blue", "Multimorbidity" = "orange")) +
    labs(title = "Length of Stay Distribution by Multimorbidity Status",
         x = "Length of Stay (High/Low)",
         y = "Count",
         fill = "Group") +
    theme_minimal()

  print(plot2)  # Draw on the active device (all_plots.pdf while it is current)

  # Save separately, then hand control back to the previously active device
  # (dev.off() alone would make the *next* open device current instead)
  prev_dev = dev.cur()
  png("los_distribution.png", width = 800, height = 600)
  print(plot2)
  dev.off()
  if (prev_dev > 1) dev.set(prev_dev)

  cat("Length of Stay Distribution plot saved\n")
} else {
  cat("Warning: Missing los_group variable for Length of Stay Distribution plot\n")
}

# 3. Total Charges Distribution by Multimorbidity Status
# Stacked counts of charge_group within each age group, one panel per cohort.
if("age_group" %in% names(combined_data) && "charge_group" %in% names(combined_data)) {
  cat("Creating Total Charges Distribution plot...\n")

  plot3 = ggplot(combined_data, aes(x = age_group, fill = charge_group)) +
    geom_bar(position = "stack") +
    facet_wrap(~ Group) +
    scale_fill_manual(values = c("Low" = "blue", "High" = "orange")) +
    labs(title = "Total Charges by Age Group and Multimorbidity Status",
         x = "Age Group (High/Low)",
         y = "Count",
         fill = "Charge Group") +
    theme_minimal()

  print(plot3)  # Draw on the active device (all_plots.pdf while it is current)

  # Save separately, then hand control back to the previously active device
  # (dev.off() alone would make the *next* open device current instead)
  prev_dev = dev.cur()
  png("charges_distribution.png", width = 800, height = 600)
  print(plot3)
  dev.off()
  if (prev_dev > 1) dev.set(prev_dev)

  cat("Total Charges Distribution plot saved\n")
} else {
  cat("Warning: Missing variables for Total Charges Distribution plot\n")
}

# 4. Ethnicity Distribution by Multimorbidity Status
# Bar counts per ethnicity, dodged by cohort; skipped when the column is
# absent or holds only NA.
if("ethnicity" %in% names(combined_data)) {
  cat("Creating Ethnicity Distribution plot...\n")

  # Check if ethnicity has valid values
  if(length(unique(na.omit(combined_data$ethnicity))) > 0) {
    plot4 = ggplot(combined_data, aes(x = ethnicity, fill = Group)) +
      geom_bar(position = "dodge") +
      scale_fill_manual(values = c("Non-Multimorbidity" = "blue", "Multimorbidity" = "orange")) +
      labs(title = "Ethnicity Distribution by Multimorbidity Status",
           x = "Ethnicity",
           y = "Count",
           fill = "Group") +
      theme_minimal() +
      theme(axis.text.x = element_text(angle = 45, hjust = 1))

    print(plot4)  # Draw on the active device (all_plots.pdf while it is current)

    # Save separately, then hand control back to the previously active device
    # (dev.off() alone would make the *next* open device current instead)
    prev_dev = dev.cur()
    png("ethnicity_distribution.png", width = 800, height = 600)
    print(plot4)
    dev.off()
    if (prev_dev > 1) dev.set(prev_dev)

    cat("Ethnicity Distribution plot saved\n")
  } else {
    cat("Warning: Ethnicity variable has no valid values\n")
  }
} else {
  cat("Warning: Missing ethnicity variable for Ethnicity Distribution plot\n")
}

# Analysis of multimorbidity combinations (dyads, triads, etc.)
# These are only available in the TGA_grouped dataset

# 5. Top Dyad Combinations
if(length(dyads) > 0) {
  cat("Analyzing dyad combinations...\n")

  # One row per dyad column: Frequency = patients with a 1 in that column (NA
  # ignored), Display = "+"-joined label. Built in one pass rather than by
  # growing a data frame with rbind() inside a loop (quadratic copying).
  dyad_summary = data.frame(
    Dyad = dyads,
    Frequency = vapply(dyads,
                       function(col) sum(TGA_transformed[[col]] == 1, na.rm = TRUE),
                       numeric(1), USE.NAMES = FALSE),
    Display = format_display_name(dyads, "dyad")
  )

  # Sort by frequency and get top 10
  dyad_summary = dyad_summary[order(-dyad_summary$Frequency), ]
  top_dyads = head(dyad_summary, 10)

  if(nrow(top_dyads) > 0) {
    # Plot top dyads
    plot5 = ggplot(top_dyads, aes(x = reorder(Display, Frequency), y = Frequency)) +
      geom_bar(stat = "identity", fill = "blue") +
      coord_flip() +
      labs(title = "Top 10 Dyad Combinations in Multimorbidity Patients",
           subtitle = "Two chronic conditions occurring together",
           x = "Dyad Combination",
           y = "Frequency") +
      theme_minimal()

    print(plot5)  # Draw on the active device (all_plots.pdf while it is current)

    # Save separately, then hand control back to the previously active device
    # (dev.off() alone would make the *next* open device current instead)
    prev_dev = dev.cur()
    png("top_dyads.png", width = 800, height = 800)
    print(plot5)
    dev.off()
    if (prev_dev > 1) dev.set(prev_dev)

    cat("Top Dyads plot saved\n")
  } else {
    cat("Warning: No valid dyad combinations found\n")
  }
} else {
  cat("Warning: No dyad columns found\n")
}

# 6. Top Triad Combinations
if(length(triads) > 0) {
  cat("Analyzing triad combinations...\n")

  # One row per triad column: Frequency = patients with a 1 in that column (NA
  # ignored), Display = "+"-joined label. Built in one pass rather than by
  # growing a data frame with rbind() inside a loop (quadratic copying).
  triad_summary = data.frame(
    Triad = triads,
    Frequency = vapply(triads,
                       function(col) sum(TGA_transformed[[col]] == 1, na.rm = TRUE),
                       numeric(1), USE.NAMES = FALSE),
    Display = format_display_name(triads, "triad")
  )

  # Sort by frequency and get top 10
  triad_summary = triad_summary[order(-triad_summary$Frequency), ]
  top_triads = head(triad_summary, 10)

  if(nrow(top_triads) > 0) {
    # Plot top triads
    plot6 = ggplot(top_triads, aes(x = reorder(Display, Frequency), y = Frequency)) +
      geom_bar(stat = "identity", fill = "orange") +
      coord_flip() +
      labs(title = "Top 10 Triad Combinations in Multimorbidity Patients",
           subtitle = "Three chronic conditions occurring together",
           x = "Triad Combination",
           y = "Frequency") +
      theme_minimal()

    print(plot6)  # Draw on the active device (all_plots.pdf while it is current)

    # Save separately, then hand control back to the previously active device
    # (dev.off() alone would make the *next* open device current instead)
    prev_dev = dev.cur()
    png("top_triads.png", width = 800, height = 800)
    print(plot6)
    dev.off()
    if (prev_dev > 1) dev.set(prev_dev)

    cat("Top Triads plot saved\n")
  } else {
    cat("Warning: No valid triad combinations found\n")
  }
} else {
  cat("Warning: No triad columns found\n")
}

# 7. All Tetrad Combinations
if(length(tetrads) > 0) {
  cat("Analyzing tetrad combinations...\n")

  # One row per tetrad column: Frequency = patients with a 1 in that column (NA
  # ignored), Display = "+"-joined label. Built in one pass rather than by
  # growing a data frame with rbind() inside a loop (quadratic copying).
  tetrad_summary = data.frame(
    Tetrad = tetrads,
    Frequency = vapply(tetrads,
                       function(col) sum(TGA_transformed[[col]] == 1, na.rm = TRUE),
                       numeric(1), USE.NAMES = FALSE),
    Display = format_display_name(tetrads, "tetrad")
  )

  # Sort by frequency (all tetrads are plotted, not just a top subset)
  tetrad_summary = tetrad_summary[order(-tetrad_summary$Frequency), ]

  if(nrow(tetrad_summary) > 0) {
    # Plot all tetrads
    plot7 = ggplot(tetrad_summary, aes(x = reorder(Display, Frequency), y = Frequency)) +
      geom_bar(stat = "identity", fill = "#56B4E9") +
      coord_flip() +
      labs(title = "All Tetrad Combinations in Multimorbidity Patients",
           subtitle = "Four chronic conditions occurring together",
           x = "Tetrad Combination",
           y = "Frequency") +
      theme_minimal()

    print(plot7)  # Draw on the active device (all_plots.pdf while it is current)

    # Save separately, then hand control back to the previously active device
    # (dev.off() alone would make the *next* open device current instead)
    prev_dev = dev.cur()
    png("tetrad_combinations.png", width = 1000, height = 800)
    print(plot7)
    dev.off()
    if (prev_dev > 1) dev.set(prev_dev)

    cat("Tetrad Combinations plot saved\n")
  } else {
    cat("Warning: No valid tetrad combinations found\n")
  }
} else {
  cat("Warning: No tetrad columns found\n")
}

# 8. Pentad Combinations (likely just one)
if(length(pentads) > 0) {
  cat("Analyzing pentad combinations...\n")

  # One row per pentad column: Frequency = patients with a 1 in that column (NA
  # ignored), Display = "+"-joined label. Built in one pass rather than by
  # growing a data frame with rbind() inside a loop (quadratic copying).
  pentad_summary = data.frame(
    Pentad = pentads,
    Frequency = vapply(pentads,
                       function(col) sum(TGA_transformed[[col]] == 1, na.rm = TRUE),
                       numeric(1), USE.NAMES = FALSE),
    Display = format_display_name(pentads, "pentad")
  )

  # Sort by frequency
  pentad_summary = pentad_summary[order(-pentad_summary$Frequency), ]

  if(nrow(pentad_summary) > 0) {
    # Plot pentads
    plot8 = ggplot(pentad_summary, aes(x = reorder(Display, Frequency), y = Frequency)) +
      geom_bar(stat = "identity", fill = "#009E73") +
      coord_flip() +
      labs(title = "Pentad Combinations in Multimorbidity Patients",
           subtitle = "Five chronic conditions occurring together",
           x = "Pentad Combination",
           y = "Frequency") +
      theme_minimal()

    print(plot8)  # Draw on the active device (all_plots.pdf while it is current)

    # Save separately, then hand control back to the previously active device
    # (dev.off() alone would make the *next* open device current instead)
    prev_dev = dev.cur()
    png("pentad_combinations.png", width = 1000, height = 800)
    print(plot8)
    dev.off()
    if (prev_dev > 1) dev.set(prev_dev)

    cat("Pentad Combinations plot saved\n")
  } else {
    cat("Warning: No valid pentad combinations found\n")
  }
} else {
  cat("Warning: No pentad columns found\n")
}

# 9. Tetrad and Pentad Combinations Combined
# Reuses tetrad_summary / pentad_summary from steps 7-8 (only consulted when
# the matching column list is non-empty) and plots them together, coloured by type.
if(length(tetrads) > 0 || length(pentads) > 0) {
  cat("Creating combined tetrad and pentad visualization...\n")

  # Combine tetrad and pentad data
  higher_summary = data.frame(Combination = character(0), Type = character(0),
                              Frequency = numeric(0), Display = character(0))

  if(length(tetrads) > 0 && nrow(tetrad_summary) > 0) {
    tetrad_data = data.frame(
      Combination = tetrad_summary$Tetrad,
      Type = rep("Tetrad", nrow(tetrad_summary)),
      Frequency = tetrad_summary$Frequency,
      Display = tetrad_summary$Display
    )
    higher_summary = rbind(higher_summary, tetrad_data)
  }

  if(length(pentads) > 0 && nrow(pentad_summary) > 0) {
    pentad_data = data.frame(
      Combination = pentad_summary$Pentad,
      Type = rep("Pentad", nrow(pentad_summary)),
      Frequency = pentad_summary$Frequency,
      Display = pentad_summary$Display
    )
    higher_summary = rbind(higher_summary, pentad_data)
  }

  if(nrow(higher_summary) > 0) {
    # Plot combined higher-order combinations
    plot9 = ggplot(higher_summary, aes(x = reorder(Display, Frequency), y = Frequency, fill = Type)) +
      geom_bar(stat = "identity") +
      coord_flip() +
      scale_fill_manual(values = c("Tetrad" = "#56B4E9", "Pentad" = "#009E73")) +
      labs(title = "Tetrad and Pentad Combinations in Multimorbidity Patients",
           subtitle = "Four and five chronic conditions occurring together",
           x = "Combination",
           y = "Frequency",
           fill = "Type") +
      theme_minimal()

    print(plot9)  # Draw on the active device (all_plots.pdf while it is current)

    # Save separately, then hand control back to the previously active device
    # (dev.off() alone would make the *next* open device current instead)
    prev_dev = dev.cur()
    png("tetrad_pentad_combinations.png", width = 1000, height = 800)
    print(plot9)
    dev.off()
    if (prev_dev > 1) dev.set(prev_dev)

    cat("Tetrad and Pentad Combinations plot saved\n")
  } else {
    cat("Warning: No valid higher-order combinations found\n")
  }
}

# 10. Dyads by Sex Distribution
# For the 5 most frequent dyads (dyad_summary is already sorted), counts
# patients with the dyad separately for males and females.
if(length(dyads) > 0 && nrow(dyad_summary) > 0) {
  cat("Analyzing dyads by sex...\n")

  # Check if sex variables exist
  if("s_male" %in% names(TGA_transformed) && "s_female" %in% names(TGA_transformed)) {
    # Derive one Sex label (female assigned last, so it wins if both flags are
    # 1); rows with neither flag are dropped
    TGA_sex_data = TGA_transformed
    TGA_sex_data$Sex = "Unknown"
    TGA_sex_data$Sex[TGA_transformed$s_male == 1] = "Male"
    TGA_sex_data$Sex[TGA_transformed$s_female == 1] = "Female"
    TGA_sex_data = subset(TGA_sex_data, Sex != "Unknown")

    # Get top 5 most frequent dyads
    top5_dyads = head(dyad_summary$Dyad, 5)
    top5_displays = head(dyad_summary$Display, 5)

    if(length(top5_dyads) > 0) {
      # Create a data frame for plotting dyads by sex
      dyad_sex_data = data.frame(Sex = character(0), Dyad = character(0),
                                 Frequency = numeric(0), Display = character(0))

      for (i in seq_along(top5_dyads)) {
        dyad = top5_dyads[i]
        display = top5_displays[i]

        # Count for males
        male_count = sum(TGA_sex_data[[dyad]][TGA_sex_data$Sex == "Male"] == 1, na.rm = TRUE)
        dyad_sex_data = rbind(dyad_sex_data,
                              data.frame(Sex = "Male", Dyad = dyad,
                                         Frequency = male_count, Display = display))

        # Count for females
        female_count = sum(TGA_sex_data[[dyad]][TGA_sex_data$Sex == "Female"] == 1, na.rm = TRUE)
        dyad_sex_data = rbind(dyad_sex_data,
                              data.frame(Sex = "Female", Dyad = dyad,
                                         Frequency = female_count, Display = display))
      }

      # Plot dyads by sex
      plot10 = ggplot(dyad_sex_data, aes(x = Display, y = Frequency, fill = Sex)) +
        geom_bar(stat = "identity", position = "dodge") +
        scale_fill_manual(values = c("Male" = "blue", "Female" = "orange")) +
        labs(title = "Top 5 Dyad Combinations by Sex",
             subtitle = "Distribution of most common chronic disease pairs between males and females",
             x = "Dyad Combination",
             y = "Frequency",
             fill = "Sex") +
        theme_minimal() +
        theme(axis.text.x = element_text(angle = 45, hjust = 1))

      print(plot10)  # Draw on the active device (all_plots.pdf while it is current)

      # Save separately, then hand control back to the previously active
      # device (dev.off() alone would make the *next* open device current)
      prev_dev = dev.cur()
      png("dyads_by_sex.png", width = 800, height = 600)
      print(plot10)
      dev.off()
      if (prev_dev > 1) dev.set(prev_dev)

      cat("Dyads by Sex plot saved\n")
    } else {
      cat("Warning: No top dyads found for sex analysis\n")
    }
  } else {
    cat("Warning: Sex variables not found for dyad analysis\n")
  }
} else {
  cat("Warning: No dyad summary data available for sex analysis\n")
}

# 11. Dyads by Ethnicity
# For the 5 most frequent dyads, counts patients with the dyad within each
# observed (non-NA) ethnicity value.
if(length(dyads) > 0 && nrow(dyad_summary) > 0 && "ethnicity" %in% names(TGA_transformed)) {
  cat("Analyzing dyads by ethnicity...\n")

  # Check if ethnicity has valid values
  if(length(unique(na.omit(TGA_transformed$ethnicity))) > 0) {
    # Get top 5 most frequent dyads
    top5_dyads = head(dyad_summary$Dyad, 5)
    top5_displays = head(dyad_summary$Display, 5)

    if(length(top5_dyads) > 0) {
      # Create a data frame for plotting dyads by ethnicity
      dyad_ethnicity_data = data.frame(Ethnicity = character(0), Dyad = character(0),
                                       Frequency = numeric(0), Display = character(0))

      # Unique ethnicity values do not depend on the dyad: compute them once
      ethnicities = unique(na.omit(TGA_transformed$ethnicity))

      for (i in seq_along(top5_dyads)) {
        dyad = top5_dyads[i]
        display = top5_displays[i]

        for (eth in ethnicities) {
          # Count frequency for this ethnicity
          eth_count = sum(TGA_transformed[[dyad]][TGA_transformed$ethnicity == eth] == 1, na.rm = TRUE)
          dyad_ethnicity_data = rbind(dyad_ethnicity_data,
                                      data.frame(Ethnicity = eth, Dyad = dyad,
                                                 Frequency = eth_count, Display = display))
        }
      }

      # Plot dyads by ethnicity (only the two listed labels get a manual colour)
      plot11 = ggplot(dyad_ethnicity_data, aes(x = Display, y = Frequency, fill = Ethnicity)) +
        geom_bar(stat = "identity", position = "dodge") +
        scale_fill_manual(values = c("Hispanic/LatinX/Spanish" = "orange",
                                     "Non-Hispanic/LatinX/Spanish" = "blue")) +
        labs(title = "Top 5 Dyad Combinations by Ethnicity",
             subtitle = "Distribution of most common chronic disease pairs across ethnic groups",
             x = "Dyad Combination",
             y = "Frequency",
             fill = "Ethnicity") +
        theme_minimal() +
        theme(axis.text.x = element_text(angle = 45, hjust = 1))

      print(plot11)  # Draw on the active device (all_plots.pdf while it is current)

      # Save separately, then hand control back to the previously active
      # device (dev.off() alone would make the *next* open device current)
      prev_dev = dev.cur()
      png("dyads_by_ethnicity.png", width = 800, height = 600)
      print(plot11)
      dev.off()
      if (prev_dev > 1) dev.set(prev_dev)

      cat("Dyads by Ethnicity plot saved\n")
    } else {
      cat("Warning: No top dyads found for ethnicity analysis\n")
    }
  } else {
    cat("Warning: Ethnicity variable has no valid values\n")
  }
} else {
  cat("Warning: Missing data for dyad ethnicity analysis\n")
}

# 12. Age and Length of Stay Relationship
# geom_count sizes and colours each age_group x los_group point by its count.
if("age_group" %in% names(combined_data) && "los_group" %in% names(combined_data)) {
  cat("Creating Age and Length of Stay relationship plot...\n")

  # after_stat(n) replaces the deprecated ..n.. notation (the file already
  # relies on ggplot2 >= 3.4 for `linewidth`)
  plot12 = ggplot(combined_data, aes(x = age_group, y = los_group)) +
    geom_count(aes(color = after_stat(n), size = after_stat(n))) +
    scale_color_gradient(low = "blue", high = "orange") +
    facet_wrap(~ Group) +
    labs(title = "Relationship Between Age and Length of Stay",
         subtitle = "Distribution of patients by age group and length of stay category",
         x = "Age Group (High/Low)",
         y = "Length of Stay (High/Low)",
         size = "Count",
         color = "Count") +
    theme_minimal()

  print(plot12)  # Draw on the active device (all_plots.pdf while it is current)

  # Save separately, then hand control back to the previously active device
  # (dev.off() alone would make the *next* open device current instead)
  prev_dev = dev.cur()
  png("age_los_relationship.png", width = 800, height = 600)
  print(plot12)
  dev.off()
  if (prev_dev > 1) dev.set(prev_dev)

  cat("Age and Length of Stay Relationship plot saved\n")
} else {
  cat("Warning: Missing variables for Age and Length of Stay relationship plot\n")
}

# 13. Generate summary statistics for numerical variables
# Per cohort (Group): non-missing n, mean, and median of age, los_days, and
# tot_charge, printed and written to CSV.
if(all(c("age", "los_days", "tot_charge") %in% names(combined_data))) {
  cat("Generating summary statistics...\n")

  # na.action = na.pass keeps rows where only some of the three variables are
  # missing. The formula default (na.omit) drops such rows for *all* variables,
  # so each statistic used complete cases only and n was identical across
  # variables. NAs are now removed per variable inside FUN.
  summary_stats = aggregate(
    cbind(age, los_days, tot_charge) ~ Group,
    data = combined_data,
    FUN = function(x) c(
      n = sum(!is.na(x)),
      mean = mean(x, na.rm = TRUE),
      median = median(x, na.rm = TRUE)
    ),
    na.action = na.pass
  )

  print(summary_stats)

  # Write summary statistics to CSV
  write.csv(summary_stats, "multimorbidity_summary_stats.csv", row.names = FALSE)
  cat("Summary statistics saved to multimorbidity_summary_stats.csv\n")
} else {
  cat("Warning: Missing numerical variables for summary statistics\n")
}

# If tetrads or pentads exist, summarize their frequencies by patient characteristics
# Adds a 0/1 has_higher_order flag to TGA_transformed, reports how many
# patients have it, and compares age / los_days / tot_charge between flag values.
if(length(tetrads) > 0 || length(pentads) > 0) {
  cat("Analyzing higher-order combinations...\n")

  higher_order_cols = c(tetrads, pentads)

  # Flag patients with a 1 in any tetrad/pentad column. %in% treats NA as "not
  # present"; the previous ifelse(x == 1, 1, flag) chain turned the flag NA
  # whenever a column was NA, even after an earlier column had set it to 1.
  any_higher = rep(FALSE, nrow(TGA_transformed))
  for (col in higher_order_cols) {
    any_higher = any_higher | (TGA_transformed[[col]] %in% 1)
  }
  TGA_transformed$has_higher_order = as.numeric(any_higher)

  # Tabulate with both levels present so the "1" entry always exists
  # (indexing [2] returned NA when no patient had a higher-order combination)
  patients_with_higher_orders = table(factor(TGA_transformed$has_higher_order, levels = c(0, 1)))
  percent_with_higher_orders = round(prop.table(patients_with_higher_orders) * 100, 1)

  cat("Patients with higher-order combinations:",
      patients_with_higher_orders[["1"]],
      "(", percent_with_higher_orders[["1"]], "%)\n")

  # Summarize characteristics of patients with higher-order combinations.
  # na.pass + per-variable NA removal: see the note on summary statistics above.
  if(all(c("age", "los_days", "tot_charge") %in% names(TGA_transformed))) {
    patient_characteristics = aggregate(
      cbind(age, los_days, tot_charge) ~ has_higher_order,
      data = TGA_transformed,
      FUN = function(x) c(
        n = sum(!is.na(x)),
        mean = mean(x, na.rm = TRUE),
        median = median(x, na.rm = TRUE)
      ),
      na.action = na.pass
    )

    print(patient_characteristics)

    # Write this additional summary to CSV
    write.csv(patient_characteristics, "higher_order_patient_characteristics.csv", row.names = FALSE)
    cat("Higher-order patient characteristics saved to higher_order_patient_characteristics.csv\n")
  } else {
    cat("Warning: Missing numerical variables for patient characteristics\n")
  }
}

# Create a reference mapping of all combinations: one row per combination with
# its Type (Dyad/Triad/Tetrad/Pentad), original name, display label and
# frequency, stacked from whichever *_summary tables exist in the workspace.
cat("Creating combination mapping reference...\n")

# Zero-row template that fixes the column names and order for the rbind() calls below.
combination_mapping = data.frame(Type = character(0), Original = character(0),
                                Display = character(0), Frequency = numeric(0))

# Add dyads if they exist
# NOTE(review): assumes dyad_summary (built earlier in the file) has columns
# Dyad, Display and Frequency -- confirm. The scalar Type is recycled to every row.
if(exists("dyad_summary") && nrow(dyad_summary) > 0) {
  dyad_mapping = data.frame(
    Type = "Dyad",
    Original = dyad_summary$Dyad,
    Display = dyad_summary$Display,
    Frequency = dyad_summary$Frequency
  )
  combination_mapping = rbind(combination_mapping, dyad_mapping)
}

# Add triads if they exist
# NOTE(review): assumes triad_summary has columns Triad, Display, Frequency -- confirm.
if(exists("triad_summary") && nrow(triad_summary) > 0) {
  triad_mapping = data.frame(
    Type = "Triad",
    Original = triad_summary$Triad,
    Display = triad_summary$Display,
    Frequency = triad_summary$Frequency
  )
  combination_mapping = rbind(combination_mapping, triad_mapping)
}

# Add tetrads if they exist
# NOTE(review): assumes tetrad_summary has columns Tetrad, Display, Frequency -- confirm.
if(exists("tetrad_summary") && nrow(tetrad_summary) > 0) {
  tetrad_mapping = data.frame(
    Type = "Tetrad",
    Original = tetrad_summary$Tetrad,
    Display = tetrad_summary$Display,
    Frequency = tetrad_summary$Frequency
  )
  combination_mapping = rbind(combination_mapping, tetrad_mapping)
}

# Add pentads if they exist
# NOTE(review): assumes pentad_summary has columns Pentad, Display, Frequency -- confirm.
if(exists("pentad_summary") && nrow(pentad_summary) > 0) {
  pentad_mapping = data.frame(
    Type = "Pentad",
    Original = pentad_summary$Pentad,
    Display = pentad_summary$Display,
    Frequency = pentad_summary$Frequency
  )
  combination_mapping = rbind(combination_mapping, pentad_mapping)
}

# Save the mapping for reference; only written when at least one summary contributed rows.
if(nrow(combination_mapping) > 0) {
  write.csv(combination_mapping, "multimorbidity_combination_mapping.csv", row.names = FALSE)
  cat("Combination mapping saved to multimorbidity_combination_mapping.csv\n")
} else {
  cat("Warning: No combination mapping data generated\n")
}

# Close the PDF device that contains all plots.
# Guarded so the chunk still reaches "Analysis complete!" when no graphics
# device is open: dev.off() on the null device (device 1) is an error.
# NOTE(review): dev.off() closes the *current* device; this assumes the
# all_plots.pdf device opened earlier is still the current one -- confirm.
if (dev.cur() > 1) {
  dev.off()
  cat("All plots saved to all_plots.pdf\n")
} else {
  cat("Warning: No open graphics device to close; all_plots.pdf may not have been written\n")
}

cat("Analysis complete!\n")
```




#### OMIT: Plot of Hospitalization Outcomes (LOS, Charges, Mortality: Alive)
```{r}
# Combine the datasets for analysis: one long table with one row per patient
# per outcome. Alive is coded 100/0 (not 1/0) so its group mean is a
# percentage, matching the 0-100 axis and percent labels used for the Alive
# plot; with 1/0 coding the bars were drawn near zero on that axis.
combined_data <- rbind(
  data.frame(outcome = CGA$los_days, group = "Non-Multimorbidity", variable = "LOS Days"),
  data.frame(outcome = TGA$los_days, group = "Multimorbidity", variable = "LOS Days"),
  data.frame(outcome = CGA$tot_charge, group = "Non-Multimorbidity", variable = "Total Charges"),
  data.frame(outcome = TGA$tot_charge, group = "Multimorbidity", variable = "Total Charges"),
  data.frame(outcome = ifelse(CGA$alive == "Y", 100, 0), group = "Non-Multimorbidity", variable = "Alive"),
  data.frame(outcome = ifelse(TGA$alive == "Y", 100, 0), group = "Multimorbidity", variable = "Alive")
)

# Mean and SD of each outcome per group. The formula interface drops NA
# outcomes (na.action = na.omit), so mean()/sd() need no na.rm here.
summary_data <- aggregate(
  outcome ~ group + variable,
  data = combined_data,
  FUN = function(x) c(mean = mean(x), sd = sd(x))
)

# Split the matrix column returned by aggregate() into plain mean / sd columns
summary_data <- do.call(data.frame, summary_data)
names(summary_data)[names(summary_data) == "outcome.mean"] <- "mean"
names(summary_data)[names(summary_data) == "outcome.sd"] <- "sd"

# Define colors for groups
group_colors <- c("Non-Multimorbidity" = "orange", "Multimorbidity" = "blue")

# Bar plot of one outcome's per-group mean with +/- 1 SD error bars and the
# rounded mean printed above each bar. Reads the global `summary_data`
# (columns group, variable, mean, sd) built earlier in this chunk.
#
# variable_name: value of summary_data$variable to plot.
# y_breaks, y_labels: y-axis breaks / labels for scale_y_continuous().
#   NULL (the default, and what the LOS plot passes for y_labels) means
#   "use ggplot2's defaults". Passed straight through, NULL tells ggplot2 to
#   draw NO breaks/labels, which left the LOS chart without y tick labels.
# y_title: y-axis title.
create_bar_plot <- function(variable_name, y_breaks = NULL, y_labels = NULL,
                            y_title = "Mean Value (with SD as Error Bars)") {
  if (is.null(y_breaks)) y_breaks <- waiver()
  if (is.null(y_labels)) y_labels <- waiver()

  plot_data <- summary_data[summary_data$variable == variable_name, ]

  ggplot(plot_data, aes(x = group, y = mean, fill = group)) +
    geom_bar(stat = "identity", position = "dodge", color = "black") +
    geom_errorbar(
      aes(ymin = mean - sd, ymax = mean + sd),
      position = position_dodge(width = 0.9),
      width = 0.2,
      color = "black"
    ) +
    geom_text(
      aes(label = round(mean, 2)), # Show mean values directly on the bars
      position = position_dodge(width = 0.9),
      vjust = -0.5
    ) +
    scale_fill_manual(values = group_colors) +
    scale_y_continuous(breaks = y_breaks, labels = y_labels) +
    labs(
      title = paste(variable_name, "by Group"),
      x = "Group",
      y = y_title
    ) +
    theme_minimal() +
    theme(
      plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
      axis.text = element_text(size = 10),
      legend.position = "none"
    )
}

# Build one bar chart per outcome, each with y-axis ticks suited to its scale
# (1-day steps for LOS, $50,000 steps for charges, 0-100 in 10s for Alive),
# then write each chart to its own 8 x 6 inch, 300 dpi PNG.
los_mean_max <- max(summary_data$mean[summary_data$variable == "LOS Days"])
plot_los_days <- create_bar_plot(
  variable_name = "LOS Days",
  y_breaks = seq(0, los_mean_max + 1, 1),
  y_labels = NULL,
  y_title = "Mean LOS Days (in days)"
)

charge_mean_max <- max(summary_data$mean[summary_data$variable == "Total Charges"])
plot_tot_charge <- create_bar_plot(
  variable_name = "Total Charges",
  y_breaks = seq(0, charge_mean_max + 50000, 50000),
  y_labels = scales::comma_format(),
  y_title = "Mean Total Charges (in $)"
)

plot_alive <- create_bar_plot(
  variable_name = "Alive",
  y_breaks = seq(0, 100, 10),
  y_labels = scales::percent_format(scale = 1),
  y_title = "Percentage Alive (in %)"
)

ggsave(filename = "los_days_bar_plot_fixed.png", plot = plot_los_days,
       width = 8, height = 6, dpi = 300)
ggsave(filename = "total_charges_bar_plot_fixed.png", plot = plot_tot_charge,
       width = 8, height = 6, dpi = 300)
ggsave(filename = "alive_bar_plot_fixed.png", plot = plot_alive,
       width = 8, height = 6, dpi = 300)
```
#### OMIT: Compare the high/low variables (LOS, age, charges) between the control and test groups using colorblind-friendly visualizations
```{r}
# Load required libraries
library(ggplot2)
library(dplyr)
library(tidyr)
library(gridExtra)

# Function to calculate medians for a variable
#
# Returns a one-row data.frame holding the variable name and the median of
# that column in CGA (non-multimorbidity) and in TGA (multimorbidity),
# ignoring NA values.
calculate_medians <- function(CGA, TGA, variable) {
  cga_median <- median(CGA[[variable]], na.rm = TRUE)
  tga_median <- median(TGA[[variable]], na.rm = TRUE)
  data.frame(
    variable = variable,
    non_multimorbidity_median = cga_median,
    multimorbidity_median = tga_median
  )
}

# Combine data preparation into a single step
#
# Tags each dataframe with its group label, recodes the 0/1 *_hilo columns as
# Low/High factors, keeps only the plotting columns and stacks CGA on top of
# TGA.
prepare_combined_data <- function(CGA, TGA) {
  hilo_cols <- c("los_days_hilo", "age_hilo", "tot_charge_hilo")
  to_low_high <- function(v) factor(v, levels = c(0, 1), labels = c("Low", "High"))

  CGA$group <- "Non-Multimorbidity"
  TGA$group <- "Multimorbidity"

  for (col in hilo_cols) {
    CGA[[col]] <- to_low_high(CGA[[col]])
    TGA[[col]] <- to_low_high(TGA[[col]])
  }

  keep_cols <- c("los_days", "age", "tot_charge", hilo_cols, "group")
  rbind(select(CGA, all_of(keep_cols)), select(TGA, all_of(keep_cols)))
}

# Function to create grouped bar plots
#
# For each group, plots the percentage of patients at each level of the
# column named by `var_name` as dodged bars, with the percentage printed
# above each bar. `colors` fills the levels in factor-level order.
create_grouped_bar <- function(data, var_name, title, subtitle, x_label, y_label, colors) {
  count_data <- data %>%
    count(group, !!sym(var_name)) %>%
    group_by(group) %>%
    mutate(percentage = n / sum(n) * 100) %>%
    # Within-group percentages are done; drop the grouping so a plain table
    # is passed on.
    ungroup()

  ggplot(count_data, aes(x = group, y = percentage, fill = !!sym(var_name))) +
    geom_bar(stat = "identity", position = "dodge", width = 0.7) +
    geom_text(aes(label = sprintf("%.1f%%", percentage)),
              position = position_dodge(width = 0.7),
              vjust = -0.5, size = 3.5) +
    scale_fill_manual(values = colors) +
    labs(title = title, subtitle = subtitle, x = x_label, y = y_label, fill = var_name) +
    theme_minimal() +
    theme(legend.position = "bottom",
          plot.title = element_text(hjust = 0.5, size = 12, face = "bold"),
          plot.subtitle = element_text(hjust = 0.5, size = 10, face = "italic"),
          axis.text = element_text(size = 10),
          legend.title = element_text(size = 10))
}

# Main workflow
# Combine and prepare data (group labels + Low/High factors for both groups)
combined_data <- prepare_combined_data(CGA, TGA)

# Define colorblind-friendly colors, applied in factor-level order (Low, High)
colors <- c("blue", "orange")

# Calculate medians per group. Rows are named by variable so each subtitle
# looks its medians up by name rather than by row position, keeping subtitle
# and chart in sync even if the variable list is reordered.
median_values <- do.call(rbind, lapply(c("los_days", "age", "tot_charge"),
                                       calculate_medians, CGA = CGA, TGA = TGA))
rownames(median_values) <- median_values$variable

# Subtitle reporting both groups' medians for one continuous variable
median_subtitle <- function(variable) {
  sprintf("Medians: Non-Multimorbidity = %.1f, Multimorbidity = %.1f",
          median_values[variable, "non_multimorbidity_median"],
          median_values[variable, "multimorbidity_median"])
}

# Create individual plots
los_plot <- create_grouped_bar(
  combined_data, "los_days_hilo",
  "Length of Stay (High vs Low)",
  median_subtitle("los_days"),
  "", "Percentage (%)", colors
)

age_plot <- create_grouped_bar(
  combined_data, "age_hilo",
  "Age (High vs Low)",
  median_subtitle("age"),
  "", "Percentage (%)", colors
)

charge_plot <- create_grouped_bar(
  combined_data, "tot_charge_hilo",
  "Total Charge (High vs Low)",
  median_subtitle("tot_charge"),
  "", "Percentage (%)", colors
)

# Arrange plots in a grid
combined_plot <- grid.arrange(
  los_plot, age_plot, charge_plot,
  ncol = 3,
  top = grid::textGrob("Comparison of High vs Low Variables Between Multimorbidity Groups",
                       gp = grid::gpar(fontsize = 14, font = 2))
)
```
#### OMIT: Transform ethnicity (the raw values are messy and must be cleaned to classify each patient as either Hispanic/LatinX or Non-Hispanic/LatinX)
```{r}
# Inspect the raw (uncleaned) ethnicity values in TGA; note that the cleaning
# below is applied to TGA_grouped / CGA_grouped, not to TGA.
table(TGA$ethnicity)

# Script to transform ethnicity column in both CGA and TGA dataframes
#
# A single raw-value -> label mapping is shared by both groups. The two
# hand-maintained lists used before had drifted apart: the TGA list lacked
# "Puerto Rican (Mainland)", "Haitian", "Guamanian or Chamorro" and "Samoan",
# and the CGA list lacked "Alaska Indian". Any such value was left unmapped
# and became NA in the cleaned factor.
# NOTE(review): "Unknown" and "NULL" are classified as Non-Hispanic (kept
# from the original lists); consider mapping them to NA instead.
hispanic_ethnicity_values <- c(
  "Hispanic or Latino",
  "Hispanic/Latino: Not Specified/Unknown",
  "Hispanic/Latino: Other",
  "Mexican",
  "Puerto Rican (Island)",
  "Puerto Rican (Mainland)",
  "Cuban"
)

non_hispanic_ethnicity_values <- c(
  "Non-Hispanic or Latino",
  "White: Not Specified/Unknown",
  "White: Other",
  "European Descent",
  "Black or African American: Not Specified/Unknown",
  "Black or African American: Other",
  "African American",
  "African (Continental)",
  "West Indian",
  "Haitian",
  "Asian: Other",
  "Asian: Not Specified/Unknown",
  "Asian Indian/Indian Sub-Continent",
  "Chinese",
  "Filipino",
  "Korean",
  "Vietnamese",
  "American Indian",
  "Alaska Indian",
  "American Indian or Alaska Native: Not Specified/Unknown",
  "American Indian or Alaska Native: Other",
  "Native Hawaiian or Other Pacific Islander: Not Specified/Unknown",
  "Native Hawaiian or Other Pacific Islander: Other",
  "Guamanian or Chamorro",
  "Samoan",
  "Arab or Middle Eastern",
  "North African (Non-Black)",
  "Unknown",
  "NULL"
)

# Named character vector: names are raw values, values are cleaned labels
cga_ethnicity_mapping <- c(
  setNames(rep("Hispanic/LatinX/Spanish", length(hispanic_ethnicity_values)),
           hispanic_ethnicity_values),
  setNames(rep("Non-Hispanic/LatinX/Spanish", length(non_hispanic_ethnicity_values)),
           non_hispanic_ethnicity_values)
)

# TGA uses the same mapping so both groups are classified identically
tga_ethnicity_mapping <- cga_ethnicity_mapping

# Function to clean the ethnicity column into a two-level factor.
#
# df:      dataframe with an `ethnicity` column (character or factor).
# mapping: named character vector, raw value -> cleaned label.
#
# Returns a modified copy of df with:
#   ethnicity_original  the raw values. If the column already exists (e.g. the
#                       chunk is re-run), it is reused as the source instead of
#                       being overwritten with already-cleaned labels, which
#                       would all be unmapped and turn into NA.
#   ethnicity           factor with levels "Hispanic/LatinX/Spanish" and
#                       "Non-Hispanic/LatinX/Spanish". Values matching
#                       "Hispanic or Latino" (any case, extra whitespace) are
#                       Hispanic; other values not in `mapping` become NA and
#                       are reported once each.
clean_ethnicity <- function(df, mapping) {
  if (!"ethnicity_original" %in% names(df)) {
    df$ethnicity_original <- df$ethnicity
  }

  # Work on character values: indexing `mapping` with a factor uses the
  # factor's integer codes instead of its labels and returns wrong entries.
  raw <- as.character(df$ethnicity_original)

  cleaned <- unname(mapping[raw])
  is_hispanic_text <- grepl("^\\s*Hispanic\\s+or\\s+Latino\\s*$", raw, ignore.case = TRUE)
  cleaned[is_hispanic_text] <- "Hispanic/LatinX/Spanish"

  unmapped <- unique(raw[!is.na(raw) & !is_hispanic_text & !(raw %in% names(mapping))])
  for (value in unmapped) {
    cat("Warning: Unmapped ethnicity value found:", value, "\n")
  }

  df$ethnicity <- factor(cleaned,
                         levels = c("Hispanic/LatinX/Spanish", "Non-Hispanic/LatinX/Spanish"))
  df
}

# Update ethnicity in the grouped dataframes.
# clean_ethnicity() returns a modified copy; assigning it back replaces
# CGA_grouped and TGA_grouped (R does not modify its arguments in place).
CGA_grouped <- clean_ethnicity(CGA_grouped, mapping = cga_ethnicity_mapping)
TGA_grouped <- clean_ethnicity(TGA_grouped, mapping = tga_ethnicity_mapping)

# Level counts of the cleaned column; useNA = "ifany" keeps unmapped (NA) rows visible.
cat("\nSummary of transformed ethnicity values in CGA_grouped:\n")
print(table(CGA_grouped$ethnicity, useNA = "ifany"))

cat("\nSummary of transformed ethnicity values in TGA_grouped:\n")
print(table(TGA_grouped$ethnicity, useNA = "ifany"))

# Persist both cleaned dataframes to the working directory.
write.csv(CGA_grouped, file = "CGA_grouped.csv", row.names = FALSE)
write.csv(TGA_grouped, file = "TGA_grouped.csv", row.names = FALSE)

cat("\nUpdated dataframes saved as 'CGA_grouped.csv' and 'TGA_grouped.csv'\n")

# Raw ethnicity values that clean_ethnicity() could not classify.
# NA is excluded (it is left as NA on purpose), and so are values matched by
# the "Hispanic or Latino" pattern, which are classified even though they are
# not mapping names. Previously both were reported as "unmapped".
find_unmapped_ethnicity <- function(original, mapping) {
  values <- unique(as.character(original))
  values <- values[!is.na(values)]
  pattern_hit <- grepl("^\\s*Hispanic\\s+or\\s+Latino\\s*$", values, ignore.case = TRUE)
  setdiff(values[!pattern_hit], names(mapping))
}

# Check for any unmapped values in CGA
if ("ethnicity_original" %in% colnames(CGA_grouped)) {
  unmapped_cga <- find_unmapped_ethnicity(CGA_grouped$ethnicity_original, cga_ethnicity_mapping)
  if (length(unmapped_cga) > 0) {
    cat("\nUnmapped ethnicity values found in CGA_grouped:\n")
    print(unmapped_cga)
  } else {
    cat("\nAll ethnicity values in CGA_grouped were successfully mapped.\n")
  }
}

# Check for any unmapped values in TGA
if ("ethnicity_original" %in% colnames(TGA_grouped)) {
  unmapped_tga <- find_unmapped_ethnicity(TGA_grouped$ethnicity_original, tga_ethnicity_mapping)
  if (length(unmapped_tga) > 0) {
    cat("\nUnmapped ethnicity values found in TGA_grouped:\n")
    print(unmapped_tga)
  } else {
    cat("\nAll ethnicity values in TGA_grouped were successfully mapped.\n")
  }
}

# Create a visualization of ethnicity distributions
library(ggplot2)

# Function to create ethnicity distribution plot
#
# Returns a ggplot of grouped bars: for each group (Non-Multimorbidity from
# cga_df, Multimorbidity from tga_df), the percentage of patients in each
# ethnicity category. table() skips NA, so percentages are over non-NA values.
create_ethnicity_plot <- function(cga_df, tga_df) {
  # Share of `df` rows in each ethnicity category, tagged with a group label
  percent_by_ethnicity <- function(df, group_label) {
    shares <- 100 * prop.table(table(df$ethnicity))
    data.frame(
      Group = group_label,
      Ethnicity = names(shares),
      Percentage = as.numeric(shares)
    )
  }

  plot_df <- rbind(
    percent_by_ethnicity(cga_df, "Non-Multimorbidity"),
    percent_by_ethnicity(tga_df, "Multimorbidity")
  )

  bars <- geom_bar(stat = "identity", position = "dodge", width = 0.7)
  bar_labels <- geom_text(
    aes(label = sprintf("%.1f%%", Percentage)),
    position = position_dodge(width = 0.7),
    vjust = -0.5, size = 3.5
  )

  ggplot(plot_df, aes(x = Group, y = Percentage, fill = Ethnicity)) +
    bars +
    bar_labels +
    scale_fill_manual(values = c("blue", "orange")) +
    labs(title = "Ethnicity Distribution by Group", x = "", y = "Percentage (%)") +
    theme_minimal() +
    theme(
      legend.position = "bottom",
      plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
      axis.text = element_text(size = 10)
    )
}

# Build the grouped ethnicity bar chart from the cleaned dataframes and write
# it to an 8 x 6 inch, 300 dpi PNG.
ethnicity_plot <- create_ethnicity_plot(cga_df = CGA_grouped, tga_df = TGA_grouped)
ggsave(filename = "ethnicity_distribution.png", plot = ethnicity_plot,
       width = 8, height = 6, dpi = 300)

cat("\nEthnicity distribution plot saved as 'ethnicity_distribution.png'\n")
```
#### OMIT: Create age, LOS, charge, race, and sex groupings
```{r}
# CREATE AGE AND LENGTH OF STAY TIER GROUPS IN DATAFRAMES

# Both grouped dataframes must already be in the workspace; stop otherwise.
if (!(exists("TGA_grouped") && exists("CGA_grouped"))) {
  stop("Error: TGA_grouped and/or CGA_grouped not found. Please load the datasets first.")
}

cat("Creating tier groups in TGA_grouped and CGA_grouped dataframes...\n")

# Create Age Group tiers.
# Left-closed bins [18, 30), [30, 40), ..., [80, 90), [90, Inf) so every age
# falls in the bin whose label contains it. The previous breaks
# (18, 29, 39, ..., 89) placed e.g. age 29 in "30-39" and age 39 in "40-49",
# and turned every age >= 89 into NA. Ages below 18 are still NA.
age_breaks <- c(18, 30, 40, 50, 60, 70, 80, 90, Inf)
age_labels <- c("18-29", "30-39", "40-49", "50-59",
                "60-69", "70-79", "80-89", "90+")

if ("age" %in% names(TGA_grouped)) {
  TGA_grouped$age_group <- cut(TGA_grouped$age,
                               breaks = age_breaks,
                               labels = age_labels,
                               right = FALSE)
  cat("Created age_group in TGA_grouped with tiers: 18-29, 30-39, 40-49, 50-59, 60-69, 70-79, 80-89, 90+\n")
} else {
  cat("Warning: 'age' variable not found in TGA_grouped\n")
}

if ("age" %in% names(CGA_grouped)) {
  CGA_grouped$age_group <- cut(CGA_grouped$age,
                               breaks = age_breaks,
                               labels = age_labels,
                               right = FALSE)
  cat("Created age_group in CGA_grouped with tiers: 18-29, 30-39, 40-49, 50-59, 60-69, 70-79, 80-89, 90+\n")
} else {
  cat("Warning: 'age' variable not found in CGA_grouped\n")
}

# Create Length of Stay Group tiers.
# Right-closed bins: (-Inf, 0] is "<1 day", (0, 3] is "1-3 days", ...,
# (28, Inf] is "29+ days".
if (!"los_days" %in% names(TGA_grouped)) {
  cat("Warning: 'los_days' variable not found in TGA_grouped\n")
} else {
  TGA_grouped[["los_group"]] <- cut(
    x = TGA_grouped[["los_days"]],
    breaks = c(-Inf, 0, 3, 7, 14, 21, 28, Inf),
    labels = c("<1 day", "1-3 days", "4-7 days", "8-14 days",
               "15-21 days", "22-28 days", "29+ days"),
    right = TRUE
  )
  cat("Created los_group in TGA_grouped with tiers: <1 day, 1-3 days, 4-7 days, 8-14 days, 15-21 days, 22-28 days, 29+ days\n")
}

if (!"los_days" %in% names(CGA_grouped)) {
  cat("Warning: 'los_days' variable not found in CGA_grouped\n")
} else {
  CGA_grouped[["los_group"]] <- cut(
    x = CGA_grouped[["los_days"]],
    breaks = c(-Inf, 0, 3, 7, 14, 21, 28, Inf),
    labels = c("<1 day", "1-3 days", "4-7 days", "8-14 days",
               "15-21 days", "22-28 days", "29+ days"),
    right = TRUE
  )
  cat("Created los_group in CGA_grouped with tiers: <1 day, 1-3 days, 4-7 days, 8-14 days, 15-21 days, 22-28 days, 29+ days\n")
}

# Create Charge Group tiers (quintiles).
# Cut points are computed once, from every dataset that has tot_charge, so
# both groups share identical breaks. Previously the CGA branch reused any
# existing `charge_breaks` object via exists(); when TGA_grouped lacked
# tot_charge, that could be a stale value left by another chunk.
# NOTE(review): cut() errors if quantiles coincide (e.g. many identical charges).
charge_labels <- c("Q1 (Lowest)", "Q2", "Q3", "Q4", "Q5 (Highest)")
tga_has_charge <- "tot_charge" %in% names(TGA_grouped)
cga_has_charge <- "tot_charge" %in% names(CGA_grouped)

if (tga_has_charge || cga_has_charge) {
  combined_charges <- c(if (tga_has_charge) TGA_grouped$tot_charge,
                        if (cga_has_charge) CGA_grouped$tot_charge)
  charge_breaks <- quantile(combined_charges, probs = seq(0, 1, 0.2), na.rm = TRUE)
}

if (tga_has_charge) {
  TGA_grouped$charge_group <- cut(TGA_grouped$tot_charge,
                                  breaks = charge_breaks,
                                  labels = charge_labels,
                                  include.lowest = TRUE)
  cat("Created charge_group in TGA_grouped with quintiles: Q1 (Lowest), Q2, Q3, Q4, Q5 (Highest)\n")
} else {
  cat("Warning: 'tot_charge' variable not found in TGA_grouped\n")
}

if (cga_has_charge) {
  CGA_grouped$charge_group <- cut(CGA_grouped$tot_charge,
                                  breaks = charge_breaks,
                                  labels = charge_labels,
                                  include.lowest = TRUE)
  cat("Created charge_group in CGA_grouped with quintiles: Q1 (Lowest), Q2, Q3, Q4, Q5 (Highest)\n")
} else {
  cat("Warning: 'tot_charge' variable not found in CGA_grouped\n")
}

# Create Race variable by combining race flags
race_vars <- c("r_white", "r_black", "r_asian", "r_amind_alnat", "r_nahaw_opacisl", "r_oth")
race_flag_labels <- setNames(
  c("White", "Black/African American", "Asian", "American Indian/Alaska Native",
    "Native Hawaiian/Pacific Islander", "Other"),
  race_vars
)

# Add a `race` column derived from whichever 0/1 race flags df has. Flags are
# applied in race_vars order, so for a patient with several flags set the
# last one wins; patients with no flag set stay "Unknown". df is returned
# unchanged (with a warning) when it has none of the flags.
add_race_from_flags <- function(df, df_label) {
  flags_present <- intersect(race_vars, names(df))
  if (length(flags_present) == 0) {
    cat("Warning: No race variables found in ", df_label, "\n", sep = "")
    return(df)
  }
  df$race <- "Unknown"
  for (flag in flags_present) {
    df$race[df[[flag]] == 1] <- race_flag_labels[[flag]]
  }
  cat("Created Race variable in ", df_label, "\n", sep = "")
  df
}

TGA_grouped <- add_race_from_flags(TGA_grouped, "TGA_grouped")
CGA_grouped <- add_race_from_flags(CGA_grouped, "CGA_grouped")

# Create Sex variable from the s_male / s_female indicator flags, using the
# same logic for both datasets. The TGA branch previously tested for columns
# "M"/"F", wrote to `gender`, and then assigned into a not-yet-created `Sex`
# column using `s_female`, which errored or left Sex half-built.
# Patients with neither flag set stay "Unknown"; s_female wins if both are 1.
add_sex_from_flags <- function(df, df_label) {
  if (!all(c("s_male", "s_female") %in% names(df))) {
    cat("Warning: Sex flag variables not found in ", df_label, "\n", sep = "")
    return(df)
  }
  df$Sex <- "Unknown"
  df$Sex[df$s_male == 1] <- "Male"
  df$Sex[df$s_female == 1] <- "Female"
  cat("Created Sex variable in ", df_label, "\n", sep = "")
  df
}

TGA_grouped <- add_sex_from_flags(TGA_grouped, "TGA_grouped")
CGA_grouped <- add_sex_from_flags(CGA_grouped, "CGA_grouped")

# Copy the newly created variables to TGA_transformed if it exists.
# Columns are copied by position, which assumes TGA_transformed rows are in
# the same order as TGA_grouped rows -- NOTE(review): confirm. The copy is
# skipped when the row counts differ.
# The race step in this chunk writes lowercase `race`, so both spellings are
# checked (only "Race" was checked before, so race was never copied).
if (exists("TGA_transformed")) {
  if (nrow(TGA_transformed) != nrow(TGA_grouped)) {
    cat("Warning: TGA_transformed and TGA_grouped have different row counts; variables not copied\n")
  } else {
    for (copy_col in c("age_group", "los_group", "charge_group", "Race", "race", "Sex")) {
      if (copy_col %in% names(TGA_grouped)) {
        TGA_transformed[[copy_col]] <- TGA_grouped[[copy_col]]
        cat("Copied", copy_col, "to TGA_transformed\n")
      }
    }
  }
}

cat("\nTier groups creation complete!\n")
```


#### OMIT: Descriptive Statistics Heatmap Visualizations: Part A
```{r}
# COMPREHENSIVE MULTIMORBIDITY COMBINATION HEATMAPS
# Analyzes dyads, triads, tetrads, and pentads across various demographic variables
library(ggplot2)

# Define a function to create and save heatmaps
#
# Draws a tile heatmap of `fill_var` over `x_var` x `y_var` (column names as
# strings), prints it, saves it as a 1200 x 800 PNG at `filename`, and
# returns the ggplot object. Axis labels default to the column names.
create_heatmap <- function(data, x_var, y_var, fill_var, title, filename,
                           x_label = NULL, y_label = NULL, angle = 45,
                           low_color = "blue", high_color = "orange") {
  if (is.null(x_label)) x_label <- x_var
  if (is.null(y_label)) y_label <- y_var

  cat("Creating heatmap:", title, "\n")

  # .data[[...]] replaces the deprecated aes_string() and also works for
  # column names that are not syntactic R names.
  heatmap_plot <- ggplot(data, aes(x = .data[[x_var]], y = .data[[y_var]],
                                   fill = .data[[fill_var]])) +
    geom_tile(color = "white") +
    scale_fill_gradient(low = low_color, high = high_color) +
    labs(title = title,
         x = x_label,
         y = y_label,
         fill = "Proportion (%)") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = angle, hjust = 1))

  # Display the plot
  print(heatmap_plot)

  # Save the plot; the PNG device is closed even if printing fails, so it is
  # not left open capturing later plots.
  png(filename, width = 1200, height = 800)
  tryCatch(print(heatmap_plot), finally = dev.off())

  cat("Heatmap saved as:", filename, "\n")

  heatmap_plot
}

# Check and standardize variable names: for each tier variable use the first
# spelling present in TGA_transformed, or NA_character_ if none is present
# (replaces nested scalar ifelse() calls).
first_present_column <- function(candidates, data = TGA_transformed) {
  found <- intersect(candidates, names(data))
  if (length(found) > 0) found[[1]] else NA_character_
}

age_var <- first_present_column(c("age_group", "AgeGroup"))
los_var <- first_present_column(c("los_group", "LOSGroup"))
charge_var <- first_present_column(c("charge_group", "ChargeGroup"))

# Create age/LOS/charge groups if missing but source variables exist.
# Age bins are left-closed [18, 30), [30, 40), ..., [80, 90), [90, Inf) so each
# age falls in the bin its label names; the previous breaks (18, 29, 39, ...)
# put e.g. age 29 in "30-39" and age 89 in "90+". Ages below 18 are NA.
if (is.na(age_var) && "age" %in% names(TGA_transformed)) {
  TGA_transformed$age_group <- cut(TGA_transformed$age,
                                   breaks = c(18, 30, 40, 50, 60, 70, 80, 90, Inf),
                                   labels = c("18-29", "30-39", "40-49", "50-59",
                                              "60-69", "70-79", "80-89", "90+"),
                                   right = FALSE)
  age_var <- "age_group"
  cat("Created age_group variable\n")
}

if (is.na(los_var) && "los_days" %in% names(TGA_transformed)) {
  TGA_transformed$los_group <- cut(TGA_transformed$los_days,
                                   breaks = c(-Inf, 0, 3, 7, 14, 21, 28, Inf),
                                   labels = c("<1 day", "1-3 days", "4-7 days", "8-14 days",
                                              "15-21 days", "22-28 days", "29+ days"),
                                   right = TRUE)
  los_var <- "los_group"
  cat("Created los_group variable\n")
}

if (is.na(charge_var) && "tot_charge" %in% names(TGA_transformed)) {
  charge_breaks <- quantile(TGA_transformed$tot_charge, probs = seq(0, 1, 0.2), na.rm = TRUE)
  TGA_transformed$charge_group <- cut(TGA_transformed$tot_charge,
                                      breaks = charge_breaks,
                                      labels = c("Q1 (Lowest)", "Q2", "Q3", "Q4", "Q5 (Highest)"),
                                      include.lowest = TRUE)
  charge_var <- "charge_group"
  cat("Created charge_group variable\n")
}

# Check for gender variable and standardize to Sex
if ("gender" %in% names(TGA_transformed)) {
  # Copy gender as character: if gender is a factor with levels "M"/"F",
  # assigning "Male"/"Female" into it produces NA ("invalid factor level").
  TGA_transformed$Sex <- as.character(TGA_transformed$gender)

  # Map M/F to Male/Female for better readability; other values are kept as-is
  TGA_transformed$Sex[TGA_transformed$gender == "M"] <- "Male"
  TGA_transformed$Sex[TGA_transformed$gender == "F"] <- "Female"

  cat("Created Sex variable from gender (M -> Male, F -> Female)\n")
  cat("Sex distribution:\n")
  print(table(TGA_transformed$Sex))
} else if (!"Sex" %in% names(TGA_transformed) &&
           all(c("s_male", "s_female") %in% names(TGA_transformed))) {
  # Create Sex from binary indicators if gender not available
  TGA_transformed$Sex <- "Unknown"
  TGA_transformed$Sex[TGA_transformed$s_male == 1] <- "Male"
  TGA_transformed$Sex[TGA_transformed$s_female == 1] <- "Female"
  cat("Created Sex variable from s_male/s_female indicators\n")
}

# Check for race variable and transform to shorter categories
if (!"Race_Short" %in% names(TGA_transformed)) {
  if ("Race" %in% names(TGA_transformed) || "race" %in% names(TGA_transformed)) {
    # Determine which race variable exists (lowercase preferred)
    race_var <- if ("race" %in% names(TGA_transformed)) "race" else "Race"

    # Long label -> short label. Covers the raw race text and the labels the
    # flag-based race step writes (e.g. "Black/African American").
    # NOTE(review): confirm the raw spelling "Alaskan Native" against the data.
    race_mapping <- c(
      "American Indian or Alaskan Native" = "American Indian",
      "American Indian/Alaska Native" = "American Indian",
      "Asian" = "Asian",
      "Black or African American" = "Black",
      "Black/African American" = "Black",
      "Native Hawaiian or Other Pacific Islander" = "Pacific Islander",
      "Native Hawaiian/Pacific Islander" = "Pacific Islander",
      "NULL" = "Unknown",
      "Other" = "Other",
      "Unknown" = "Unknown",
      "Unknown (for use if patient refuses or fails to disclose)" = "Unknown",
      "White" = "White"
    )

    # Vectorized lookup on character values (a factor race column would turn
    # new labels into NA). Values without a mapping entry are kept as-is.
    race_values <- as.character(TGA_transformed[[race_var]])
    short_values <- unname(race_mapping[race_values])
    TGA_transformed$Race_Short <- ifelse(is.na(short_values), race_values, short_values)

    cat("Created shortened race categories:\n")
    print(table(TGA_transformed$Race_Short))
  } else if (any(c("r_white", "r_black", "r_asian", "r_amind_alnat", "r_nahaw_opacisl", "r_oth") %in% names(TGA_transformed))) {
    # Create Race from binary flags if no race variable exists; later flags
    # overwrite earlier ones when a patient has several set.
    TGA_transformed$Race_Short <- "Unknown"
    if ("r_white" %in% names(TGA_transformed))
      TGA_transformed$Race_Short[TGA_transformed$r_white == 1] <- "White"
    if ("r_black" %in% names(TGA_transformed))
      TGA_transformed$Race_Short[TGA_transformed$r_black == 1] <- "Black"
    if ("r_asian" %in% names(TGA_transformed))
      TGA_transformed$Race_Short[TGA_transformed$r_asian == 1] <- "Asian"
    if ("r_amind_alnat" %in% names(TGA_transformed))
      TGA_transformed$Race_Short[TGA_transformed$r_amind_alnat == 1] <- "American Indian"
    if ("r_nahaw_opacisl" %in% names(TGA_transformed))
      TGA_transformed$Race_Short[TGA_transformed$r_nahaw_opacisl == 1] <- "Pacific Islander"
    if ("r_oth" %in% names(TGA_transformed))
      TGA_transformed$Race_Short[TGA_transformed$r_oth == 1] <- "Other"
    cat("Created Race_Short variable from race flags\n")
  }
}

# Format display names for combinations - removing the prefixes completely.
#
# Everything up to and including the FIRST underscore is treated as the type
# prefix (e.g. "dyad_chf.htn_c" -> "chf.htn_c"). The dots that separate
# diseases then become plus signs ("chf+htn_c"). Names without an underscore
# are used as-is.
#
# Args:
#   col_name: combination column name.
#   prefix:   combination type label. It is not used in the computation and is
#             kept so that existing callers still match the signature.
# Returns:
#   the display string.
format_display_name = function(col_name, prefix) {
  # sub() strips only the leading "<prefix>_", so disease codes that contain
  # underscores themselves (e.g. "htn_c") survive intact. The previous
  # strsplit(...)[[1]][2] kept only the text up to the NEXT underscore.
  base_name = sub("^[^_]*_", "", col_name)

  # Replace dots with plus signs for readability
  gsub(".", "+", base_name, fixed = TRUE)
}

# Create variable mappings for processing.
# var_types: analysis key -> name of the TGA_transformed column holding that
# grouping. charge_var, age_var and los_var are assumed to have been set
# earlier in the notebook, possibly to NA when no matching column was found
# (the processing code checks is.na) - TODO confirm.
var_types = list(
  tot_charge = charge_var,
  ethnicity = "ethnicity",
  age_group = age_var,
  los_group = los_var,
  # Prefer the shortened race labels when they exist. The condition is a
  # single TRUE/FALSE, so if/else is used instead of the vectorised ifelse().
  Race = if ("Race_Short" %in% names(TGA_transformed)) "Race_Short" else "Race",
  Sex = "Sex",
  alive = "alive"
)

# var_names: the same keys -> human-readable labels for titles, axes and file names
var_names = list(
  tot_charge = "Charge Group",
  ethnicity = "Ethnicity",
  age_group = "Age Group",
  los_group = "Length of Stay",
  Race = "Race",
  Sex = "Sex",
  alive = "Alive"
)

# Get the top 20 dyads and triads, and all tetrads and pentads.
#
# Returns the `n` combination columns with the most patients (value == 1 in
# TGA_transformed).
#
# Args:
#   combo_list: combination column names in TGA_transformed.
#   n:          how many to keep (default 20); values <= 0 select none.
# Returns:
#   up to `n` entries of combo_list, most prevalent first (ties keep input
#   order); NULL when combo_list is empty.
get_top_combos = function(combo_list, n = 20) {
  if (length(combo_list) == 0) return(c())

  # vapply() always yields a numeric vector, whatever the input shape
  counts = vapply(combo_list, function(col) {
    sum(TGA_transformed[[col]] == 1, na.rm = TRUE)
  }, numeric(1))

  # head() copes with n larger than the list and with n = 0; the previous
  # 1:min(n, ...) indexing returned one element instead of none for n = 0
  head(combo_list[order(counts, decreasing = TRUE)], max(n, 0))
}

# Choose which combinations are analysed: the 20 most prevalent dyads and
# triads (ranked by get_top_combos) and every tetrad and pentad.
# NOTE(review): dyads/triads/tetrads/pentads are assumed to be vectors of
# combination column names created earlier in the notebook - confirm.
cat("\nIdentifying combinations for analysis...\n")
top20_dyads = get_top_combos(dyads, 20)
top20_triads = get_top_combos(triads, 20)
all_tetrads = tetrads
all_pentads = pentads

# Report how many combinations of each type were selected
cat("Selected top 20 dyads:", length(top20_dyads), "\n")
cat("Selected top 20 triads:", length(top20_triads), "\n")
cat("All tetrads:", length(all_tetrads), "\n")
cat("All pentads:", length(all_pentads), "\n")

# Helper function to process combinations by grouping variable.
#
# For every combination column in `combos` and every non-missing value of
# `group_var`, computes the percentage of patients in that group whose
# combination column equals 1. It then draws the table with create_heatmap()
# (assumed to be defined earlier in the notebook), saved as
# "<filename_prefix>_<var name>_proportion_heatmap.png".
#
# Args:
#   combos:          combination column names in TGA_transformed.
#   group_var:       name of the grouping column (NA/NULL -> nothing to do).
#   var_name:        display label of the grouping variable.
#   filename_prefix: prefix of the saved PNG file name (e.g. "dyad").
#   title_prefix:    text placed before var_name in the plot title.
# Returns:
#   data frame (Group, Combo, Display, Count, Total, Proportion), or NULL when
#   there is nothing to plot.
process_combos_by_group = function(combos, group_var, var_name, filename_prefix, title_prefix) {
  # Check the length first: is.na(NULL) is zero-length and would make `||` fail
  if (length(combos) == 0 || length(group_var) != 1 || is.na(group_var) ||
      !group_var %in% names(TGA_transformed)) {
    return(NULL)
  }

  group_values = TGA_transformed[[group_var]]
  groups = unique(na.omit(group_values))
  # A for-loop over a factor yields its labels, so use labels here as well
  if (is.factor(groups)) groups = as.character(groups)

  # Row membership per group is the same for every combination, so build the
  # masks once. NA group values are excluded, as subset() would do.
  group_masks = lapply(seq_along(groups), function(k) {
    in_group = group_values == groups[k]
    !is.na(in_group) & in_group
  })

  rows = vector("list", length(combos) * length(groups))
  n_rows = 0
  for (combo in combos) {
    display = format_display_name(combo, sub("_.*$", "", combo))
    has_combo = TGA_transformed[[combo]] == 1

    for (k in seq_along(groups)) {
      group_size = sum(group_masks[[k]])
      if (group_size > 0) {
        count = sum(has_combo[group_masks[[k]]], na.rm = TRUE)
        n_rows = n_rows + 1
        rows[[n_rows]] = data.frame(
          Group = groups[k],
          Combo = combo,
          Display = display,
          Count = count,
          Total = group_size,
          Proportion = (count / group_size) * 100
        )
      }
    }
  }

  if (n_rows == 0) return(NULL)

  # Bind once at the end instead of growing the data frame inside the loop
  result_data = do.call(rbind, rows[seq_len(n_rows)])

  create_heatmap(
    result_data,
    "Display",
    "Group",
    "Proportion",
    paste(title_prefix, var_name),
    paste0(filename_prefix, "_", tolower(gsub("[^a-zA-Z0-9]", "_", var_name)), "_proportion_heatmap.png"),
    "Combinations",
    var_name
  )
  result_data
}

# Process chronic diseases.
#
# For each combination column in `combos` and each chronic-disease indicator
# column found in TGA_transformed, computes the share (%) of patients with that
# disease (indicator == 1) who also have the combination (column == 1). The
# result is drawn with create_heatmap() (assumed to be defined earlier in the
# notebook) and saved as "<combo_type>_disease_proportion_heatmap.png".
#
# Args:
#   combos:     combination column names in TGA_transformed.
#   combo_type: label such as "dyad", used for display names and the file name.
# Returns:
#   data frame (Disease, Combo, Display, Count, Total, Proportion, DiseaseName),
#   or NULL when there are no combinations or no disease columns were found.
process_disease_heatmap = function(combos, combo_type) {
  chronic_diseases = c("aids", "alcohol", "anemdef", "arrhyth", "arthrh", "arth", "bldloss",
                      "chf", "chrnlung", "coag", "depress", "dm", "dmcx", "drug",
                      "htn_c", "hypothy", "liver", "lymph", "lytes", "mets", "neuro",
                      "obese", "para", "perivasc", "psych", "pulmcirc", "renlfail",
                      "tumor", "ulcer", "valve", "wghtloss")

  # Match columns case-insensitively. The single-disease chunk of this notebook
  # lists the same indicators in upper case ("AIDS", "CHF", ...), and an exact
  # lower-case match would then silently find nothing. The list order is kept,
  # and each column name is used exactly as it appears in the data.
  col_pos = match(chronic_diseases, tolower(names(TGA_transformed)))
  valid_diseases = names(TGA_transformed)[col_pos[!is.na(col_pos)]]

  if (length(combos) == 0 || length(valid_diseases) == 0) {
    return(NULL)
  }

  # Patients with each disease are the same for every combination, so compute
  # the masks once (NA indicator values are excluded, as subset() would do)
  disease_masks = lapply(valid_diseases, function(disease) {
    has_disease = TGA_transformed[[disease]] == 1
    !is.na(has_disease) & has_disease
  })

  rows = vector("list", length(combos) * length(valid_diseases))
  n_rows = 0
  for (combo in combos) {
    display = format_display_name(combo, combo_type)
    has_combo = TGA_transformed[[combo]] == 1

    for (k in seq_along(valid_diseases)) {
      disease_size = sum(disease_masks[[k]])
      if (disease_size > 0) {
        count = sum(has_combo[disease_masks[[k]]], na.rm = TRUE)
        n_rows = n_rows + 1
        rows[[n_rows]] = data.frame(Disease = valid_diseases[k],
                                    Combo = combo,
                                    Display = display,
                                    Count = count,
                                    Total = disease_size,
                                    Proportion = (count / disease_size) * 100)
      }
    }
  }

  if (n_rows == 0) return(NULL)

  # Bind once at the end instead of growing the data frame inside the loop
  combo_disease_data = do.call(rbind, rows[seq_len(n_rows)])

  # Readable disease labels: "_" and "." become spaces, first letter upper-cased
  readable = gsub("[_.]", " ", combo_disease_data$Disease)
  combo_disease_data$DiseaseName = paste0(toupper(substr(readable, 1, 1)),
                                          substr(readable, 2, nchar(readable)))

  create_heatmap(
    combo_disease_data,
    "Display",
    "DiseaseName",
    "Proportion",
    "Proportion of Patients with Combinations by Chronic Disease",
    paste0(combo_type, "_disease_proportion_heatmap.png"),
    "Combinations",
    "Chronic Disease",
    angle = 45
  )
  combo_disease_data
}

# Generate heatmaps for all variables and combination types.
# For every grouping key in var_types this draws proportion heatmaps for the
# selected dyads, triads, tetrads and pentads (via process_combos_by_group).
# It also draws one faceted plot that shows tetrads and pentads side by side.
# NOTE(review): the loop variables (var, group_var, var_name, *_data) are
# left behind as globals after the loop finishes.
cat("\n--- GENERATING HEATMAPS ---\n")

# Process each demographic variable
for(var in names(var_types)) {
  # Column holding this grouping, and its human-readable label
  group_var = var_types[[var]]
  var_name = var_names[[var]]

  # Skip if variable doesn't exist or is NA
  # NOTE(review): a NULL entry makes is.na(group_var) zero-length, and `||`
  # then errors in R >= 4.3.
  if(is.na(group_var) || !group_var %in% names(TGA_transformed)) {
    cat("\nSkipping", var_name, "- variable not found in data\n")
    next
  }

  cat("\nProcessing", var_name, "heatmaps...\n")

  # 1. Dyads
  cat("  Generating dyad proportions by", var_name, "...\n")
  dyad_data = process_combos_by_group(top20_dyads, group_var, var_name, "dyad",
                                     "Proportion of Patients with Combinations by")

  # 2. Triads
  cat("  Generating triad proportions by", var_name, "...\n")
  triad_data = process_combos_by_group(top20_triads, group_var, var_name, "triad",
                                      "Proportion of Patients with Combinations by")

  # 3. Tetrads
  cat("  Generating tetrad proportions by", var_name, "...\n")
  tetrad_data = process_combos_by_group(all_tetrads, group_var, var_name, "tetrad",
                                       "Proportion of Patients with Combinations by")

  # 4. Pentads
  cat("  Generating pentad proportions by", var_name, "...\n")
  pentad_data = process_combos_by_group(all_pentads, group_var, var_name, "pentad",
                                       "Proportion of Patients with Combinations by")

  # 5. Combined tetrad and pentad visualization
  if(length(all_tetrads) > 0 || length(all_pentads) > 0) {
    # Combine tetrad and pentad data if available; the added Type column
    # drives the facets below (and stays on tetrad_data/pentad_data)
    higher_data = NULL
    if(!is.null(tetrad_data) && !is.null(pentad_data)) {
      tetrad_data$Type = "Tetrad"
      pentad_data$Type = "Pentad"
      higher_data = rbind(tetrad_data, pentad_data)
    } else if(!is.null(tetrad_data)) {
      tetrad_data$Type = "Tetrad"
      higher_data = tetrad_data
    } else if(!is.null(pentad_data)) {
      pentad_data$Type = "Pentad"
      higher_data = pentad_data
    }

    if(!is.null(higher_data)) {
      # Create a faceted version that separates tetrads and pentads
      facet_plot = ggplot(higher_data, aes(x = Display, y = Group, fill = Proportion)) +
        geom_tile(color = "white") +
        facet_wrap(~ Type, scales = "free_x") +  # Separate by type
        scale_fill_gradient(low = "blue", high = "orange") +
        labs(title = paste("Proportion of Patients with Combinations by", var_name),
             x = "Combinations",
             y = var_name,
             fill = "Proportion (%)") +
        theme_minimal() +
        theme(axis.text.x = element_text(angle = 90, hjust = 1, size = 8))

      print(facet_plot)

      # Save the faceted plot (written to the current working directory)
      # NOTE(review): if print() fails here, dev.off() is never reached and
      # the png device stays open.
      png(paste0("faceted_higher_order_",
                tolower(gsub("[^a-zA-Z0-9]", "_", var_name)),
                "_proportion_heatmap.png"),
          width = 1600, height = 1000)
      print(facet_plot)
      dev.off()

      cat("  Created faceted higher-order combinations plot\n")
    }
  }
}

# Process chronic disease heatmaps: one combination-by-disease heatmap per
# combination type. Each result is the plotted data frame, or NULL when there
# was nothing to plot.
cat("\nProcessing chronic disease heatmaps...\n")
dyad_disease_data = process_disease_heatmap(top20_dyads, "dyad")
triad_disease_data = process_disease_heatmap(top20_triads, "triad")
tetrad_disease_data = process_disease_heatmap(all_tetrads, "tetrad")
pentad_disease_data = process_disease_heatmap(all_pentads, "pentad")

# Generate summary statistics (per-combination tables built below)
cat("\nGenerating summary statistics...\n")

# Build one summary row per combination with at least one patient.
#
# Args:
#   combos:     combination column names in TGA_transformed (patients with the
#               combination have value 1).
#   combo_type: combination type label, e.g. "dyad" or "triad".
# Returns:
#   data frame with Combination (display name), Type (capitalised combo_type,
#   e.g. "Triad"), Count, Proportion (% of all patients), AgeDistribution and
#   SexDistribution ("<most common value> (<share>%)" or "Not available"),
#   LOSAverage and ChargeAverage (means of the first matching numeric column,
#   or NA). It has zero rows when nothing matched.
get_summary_stats = function(combos, combo_type) {
  summary_data = data.frame(
    Combination = character(0),
    Type = character(0),
    Count = numeric(0),
    Proportion = numeric(0),
    AgeDistribution = character(0),
    SexDistribution = character(0),
    LOSAverage = numeric(0),
    ChargeAverage = numeric(0)
  )

  if (length(combos) == 0) return(summary_data)

  # Full capitalised type name ("triad" -> "Triad"). The previous first-letter
  # code made triads and tetrads indistinguishable ("T") in the combined summary.
  type_label = paste0(toupper(substr(combo_type, 1, 1)),
                      substr(combo_type, 2, nchar(combo_type)))
  n_patients = nrow(TGA_transformed)

  # "<most frequent value> (<its share, %>)", or "Not available" when the
  # vector has no non-missing values
  describe_mode = function(x) {
    counts = table(x)
    if (length(counts) == 0) return("Not available")
    shares = round(prop.table(counts) * 100, 1)
    paste0(names(counts)[which.max(counts)], " (", max(shares), "%)")
  }

  # Mean of the first candidate column that exists AND is numeric. A
  # categorical column with the same name would otherwise give NA plus a
  # warning. Returns NA when no candidate qualifies.
  first_numeric_mean = function(df, candidates) {
    for (col in candidates) {
      if (col %in% names(df) && is.numeric(df[[col]])) {
        return(mean(df[[col]], na.rm = TRUE))
      }
    }
    NA
  }

  rows = lapply(combos, function(combo) {
    has_combo = TGA_transformed[[combo]] == 1
    patients = TGA_transformed[!is.na(has_combo) & has_combo, , drop = FALSE]
    if (nrow(patients) == 0) return(NULL)

    # age_var is assumed to be a column name set earlier in the notebook.
    # isTRUE() also handles it being NA or NULL.
    age_dist = if (isTRUE(age_var %in% names(patients))) {
      describe_mode(patients[[age_var]])
    } else {
      "Not available"
    }
    sex_dist = if ("Sex" %in% names(patients)) {
      describe_mode(patients[["Sex"]])
    } else {
      "Not available"
    }

    data.frame(
      Combination = format_display_name(combo, combo_type),
      Type = type_label,
      Count = nrow(patients),
      Proportion = (nrow(patients) / n_patients) * 100,
      AgeDistribution = age_dist,
      SexDistribution = sex_dist,
      LOSAverage = first_numeric_mean(patients, c("los_days", "los")),
      ChargeAverage = first_numeric_mean(patients, c("tot_charge", "total_charge", "charge"))
    )
  })

  # Bind once; NULL entries (combinations without patients) are dropped
  do.call(rbind, c(list(summary_data), rows))
}

# Get summary for each combination type (one data frame per type; zero rows
# when a type has no selected combinations with patients)
dyad_summary = get_summary_stats(top20_dyads, "dyad")
triad_summary = get_summary_stats(top20_triads, "triad")
tetrad_summary = get_summary_stats(all_tetrads, "tetrad")
pentad_summary = get_summary_stats(all_pentads, "pentad")

# Combine all summaries, most common combinations first, and write them as CSV
# files in the current working directory
combined_summary = rbind(dyad_summary, triad_summary, tetrad_summary, pentad_summary)
if(nrow(combined_summary) > 0) {
  combined_summary = combined_summary[order(-combined_summary$Count),]

  write.csv(combined_summary, "multimorbidity_combinations_summary.csv", row.names = FALSE)
  cat("Combined summary statistics saved to multimorbidity_combinations_summary.csv\n")

  # Also create separate files for each type (skipped when a type is empty)
  if(nrow(dyad_summary) > 0)
    write.csv(dyad_summary, "dyad_combinations_summary.csv", row.names = FALSE)
  if(nrow(triad_summary) > 0)
    write.csv(triad_summary, "triad_combinations_summary.csv", row.names = FALSE)
  if(nrow(tetrad_summary) > 0)
    write.csv(tetrad_summary, "tetrad_combinations_summary.csv", row.names = FALSE)
  if(nrow(pentad_summary) > 0)
    write.csv(pentad_summary, "pentad_combinations_summary.csv", row.names = FALSE)

  # Also create a combined higher-order (tetrad + pentad) summary
  higher_summary = rbind(tetrad_summary, pentad_summary)
  if(nrow(higher_summary) > 0) {
    higher_summary = higher_summary[order(-higher_summary$Count),]
    write.csv(higher_summary, "higher_order_combinations_summary.csv", row.names = FALSE)
    cat("Higher-order combinations summary saved\n")
  }
}

cat("\nMultimorbidity combination heatmap analysis complete!\n")
```
#### OMIT: Descriptive Statistics Heatmap Visualizations: Part B (Single Chronic Diseases)
```{r}
# CHRONIC DISEASE HEATMAPS BY DEMOGRAPHIC VARIABLES
# Creates proportion-based heatmaps showing the prevalence of individual chronic diseases
# across various demographic and clinical variables

library(ggplot2)
library(reshape2)  # For melt function

# Define the function to create and save heatmaps.
#
# Draws a tile heatmap of `fill_var` over `x_var` x `y_var` (all given as
# column-name strings), prints it, and saves it as a 1200x800 PNG.
#
# Args:
#   data:                  data frame to plot.
#   x_var, y_var, fill_var: column names for the x axis, y axis and tile fill.
#   title, filename:       plot title and output PNG path.
#   x_label, y_label:      axis titles (default: the column names).
#   angle:                 rotation of the x-axis text, in degrees.
#   low_color, high_color: ends of the fill gradient.
#   fill_label:            legend title (default keeps the original text).
# Returns:
#   the ggplot object.
create_heatmap = function(data, x_var, y_var, fill_var, title, filename,
                         x_label = NULL, y_label = NULL, angle = 45,
                         low_color = "blue", high_color = "orange",
                         fill_label = "Proportion (%)") {
  if (is.null(x_label)) x_label = x_var
  if (is.null(y_label)) y_label = y_var

  cat("Creating heatmap:", title, "\n")

  # .data[[name]] is the supported way to map columns given as strings;
  # aes_string() has been deprecated since ggplot2 3.0.0
  heatmap_plot = ggplot(data, aes(x = .data[[x_var]],
                                  y = .data[[y_var]],
                                  fill = .data[[fill_var]])) +
    geom_tile(color = "white") +
    scale_fill_gradient(low = low_color, high = high_color) +
    labs(title = title,
         x = x_label,
         y = y_label,
         fill = fill_label) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = angle, hjust = 1))

  # Display the plot
  print(heatmap_plot)

  # Save the plot. The device is closed even if rendering fails, so later plots
  # are not silently redirected into this file.
  png(filename, width = 1200, height = 800)
  tryCatch(print(heatmap_plot), finally = dev.off())

  cat("Heatmap saved as:", filename, "\n")

  heatmap_plot
}

# List of chronic diseases (indicator columns) to profile, in upper-case spelling
chronic_diseases = c(
  "AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF", "CHRNLUNG",
  "COAG", "DEPRESS", "DM", "DMCX", "DRUG", "HTN_C", "HYPOTHY",
  "LIVER", "LYMPH", "LYTES", "METS", "NEURO", "OBESE", "PARA",
  "PERIVASC", "PSYCH", "PULMCIRC", "RENLFAIL", "TUMOR", "ULCER",
  "VALVE", "WGHTLOSS", "ARRHYTH"
)

# Check which chronic diseases exist in the data. Matching is case-insensitive,
# because the combination chunk of this notebook lists the same indicators in
# lower case. existing_diseases holds the column names exactly as spelled in
# TGA_transformed, so it can be used directly for [[ ]] lookups.
disease_match = match(toupper(chronic_diseases), toupper(names(TGA_transformed)))
existing_diseases = names(TGA_transformed)[disease_match[!is.na(disease_match)]]
cat("Found", length(existing_diseases), "of", length(chronic_diseases), "chronic diseases in data\n")
if (length(existing_diseases) < length(chronic_diseases)) {
  missing_diseases = chronic_diseases[is.na(disease_match)]
  cat("Missing diseases:", paste(missing_diseases, collapse=", "), "\n")
}

# Column in TGA_transformed -> display label; each is crossed against every
# chronic disease below
demographic_vars = list(
  age_group = "Age Group",
  los_group = "Length of Stay",
  charge_group = "Charge Group",
  Race_Short = "Race",
  Sex = "Sex",
  ethnicity = "Ethnicity",
  alive = "Alive"
)

# Keep only the demographic columns that are actually present (in list order)
existing_demo_vars = intersect(names(demographic_vars), names(TGA_transformed))
cat("Found", length(existing_demo_vars), "of", length(demographic_vars), "demographic variables\n")

# Function to generate data for one demographic variable and all diseases.
#
# For every disease in the global `existing_diseases` (assumed to be set above
# to disease column names) and every non-missing value of `demo_var`, computes
# the percentage of patients with that value whose disease indicator equals 1.
#
# Args:
#   demo_var:  name of the demographic column in TGA_transformed.
#   demo_name: display label. It is not used in the computation; kept for
#              callers.
# Returns:
#   data frame (DemoGroup, Disease, Count, Total, Proportion), ordered by
#   disease then demographic value. It has zero rows when demo_var is missing.
generate_data = function(demo_var, demo_name) {
  result = data.frame(
    DemoGroup = character(0),
    Disease = character(0),
    Count = numeric(0),
    Total = numeric(0),
    Proportion = numeric(0)
  )

  # Skip if variable doesn't exist
  if (!demo_var %in% names(TGA_transformed)) {
    cat("Warning: Variable", demo_var, "not found in data\n")
    return(result)
  }

  demo_column = TGA_transformed[[demo_var]]
  demo_values = unique(na.omit(demo_column))

  # Group membership is the same for every disease, so build one mask per
  # demographic value up front instead of re-subsetting inside the disease
  # loop. NA demographic values are excluded, as subset() would do.
  value_masks = lapply(seq_along(demo_values), function(k) {
    hit = demo_column == demo_values[k]
    !is.na(hit) & hit
  })

  rows = vector("list", length(existing_diseases) * length(demo_values))
  n_rows = 0
  for (disease in existing_diseases) {
    has_disease = TGA_transformed[[disease]] == 1
    for (k in seq_along(demo_values)) {
      group_size = sum(value_masks[[k]])
      if (group_size > 0) {
        count = sum(has_disease[value_masks[[k]]], na.rm = TRUE)
        n_rows = n_rows + 1
        rows[[n_rows]] = data.frame(
          DemoGroup = as.character(demo_values[k]),
          Disease = disease,
          Count = count,
          Total = group_size,
          Proportion = (count / group_size) * 100
        )
      }
    }
  }

  # Bind once; unused (NULL) slots are ignored by rbind
  do.call(rbind, c(list(result), rows))
}

# Process each demographic variable: build the disease-by-group proportion
# table, save its heatmap, and print two quick summaries
all_data = list()  # Store all data for comprehensive view

for(demo_var in existing_demo_vars) {
  demo_name = demographic_vars[[demo_var]]
  cat("\nProcessing chronic diseases by", demo_name, "...\n")

  # Generate data
  disease_data = generate_data(demo_var, demo_name)

  if(nrow(disease_data) > 0) {
    # Store for later use in comprehensive view (keyed by column name)
    all_data[[demo_var]] = disease_data

    # Create heatmap (the returned plot object is not used further here)
    heatmap = create_heatmap(
      disease_data,
      "Disease",
      "DemoGroup",
      "Proportion",
      paste("Proportion of Patients with Chronic Diseases by", demo_name),
      paste0("chronic_disease_by_", tolower(gsub(" ", "_", demo_name)), "_heatmap.png"),
      "Chronic Disease",
      demo_name,
      angle = 90  # Rotate disease names 90 degrees for better fit
    )

    # Calculate average proportions: unweighted mean of each disease's
    # per-group percentages (groups of different sizes count equally)
    avg_props = aggregate(Proportion ~ Disease, data = disease_data, FUN = mean)
    avg_props = avg_props[order(-avg_props$Proportion),]
    cat("Average proportions of chronic diseases across", demo_name, "groups:\n")
    print(head(avg_props, 10))  # Show top 10

    # Calculate demographic groups with highest disease burden: the sum of
    # per-disease percentages within each group (a relative score, not a
    # patient count, so it can exceed 100)
    demo_burden = aggregate(Proportion ~ DemoGroup, data = disease_data, FUN = sum)
    demo_burden = demo_burden[order(-demo_burden$Proportion),]
    cat("Total disease burden by", demo_name, "groups:\n")
    print(demo_burden)
  } else {
    cat("No data available for", demo_name, "\n")
  }
}

# Create a comprehensive dataset: stack every per-demographic table into one
# long data frame, tagging each row with the demographic's display label
comprehensive_data = data.frame(
  Disease = character(0),
  DemoVariable = character(0),
  DemoGroup = character(0),
  Proportion = numeric(0)
)

for(demo_var in names(all_data)) {
  demo_name = demographic_vars[[demo_var]]
  temp_data = all_data[[demo_var]]

  # Add to comprehensive dataset
  # NOTE(review): rbind inside the loop copies the accumulated frame each time;
  # fine for a handful of demographics
  comprehensive_data = rbind(comprehensive_data,
                           data.frame(
                             Disease = temp_data$Disease,
                             DemoVariable = demo_name,
                             DemoGroup = temp_data$DemoGroup,
                             Proportion = temp_data$Proportion
                           ))
}

if (nrow(comprehensive_data) > 0) {
  # Create faceted plot showing diseases by all demographic variables: one row
  # facet per demographic, each with only its own groups on the y axis
  facet_plot = ggplot(comprehensive_data,
                     aes(x = Disease, y = DemoGroup, fill = Proportion)) +
    geom_tile(color = "white") +
    facet_grid(DemoVariable ~ ., scales = "free_y", space = "free_y") +
    scale_fill_gradient(low = "blue", high = "orange") +
    labs(title = "Comprehensive View of Chronic Disease Prevalence by Patient Characteristics",
         x = "Chronic Disease",
         y = "Demographic Groups",
         fill = "Proportion (%)") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5, size = 8),
          strip.text.y = element_text(angle = 0))

  print(facet_plot)

  # Save the comprehensive plot; the device is closed even if rendering fails
  png("comprehensive_chronic_disease_prevalence.png", width = 2000, height = 1500, res = 150)
  tryCatch(print(facet_plot), finally = dev.off())

  cat("Comprehensive disease prevalence plot saved\n")

  # Create a disease correlation matrix
  cat("\nCalculating disease correlation matrix...\n")

  # One cor() call over the indicator columns gives every pairwise correlation
  # (pairwise-complete observations), the same per-pair values as the former
  # double loop, with each pair computed only once
  disease_matrix = cor(TGA_transformed[existing_diseases], use = "pairwise.complete.obs")
  # Self-correlation is defined as 1, including zero-variance columns where
  # cor() would return NA
  diag(disease_matrix) = 1

  # Long format (Disease1, Disease2, Correlation) for plotting
  disease_corr_data = melt(disease_matrix)
  colnames(disease_corr_data) = c("Disease1", "Disease2", "Correlation")

  # Create correlation heatmap (diverging scale centred on 0)
  corr_plot = ggplot(disease_corr_data,
                    aes(x = Disease1, y = Disease2, fill = Correlation)) +
    geom_tile() +
    scale_fill_gradient2(low = "blue", mid = "white", high = "orange", midpoint = 0,
                        limits = c(-1, 1)) +
    labs(title = "Correlation Matrix of Chronic Diseases",
         x = "", y = "", fill = "Correlation") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5, size = 7),
          axis.text.y = element_text(size = 7))

  print(corr_plot)

  # Save correlation matrix plot; the device is closed even if rendering fails
  png("chronic_disease_correlation_matrix.png", width = 1500, height = 1500, res = 150)
  tryCatch(print(corr_plot), finally = dev.off())

  cat("Disease correlation matrix saved\n")
}

cat("\nChronic disease analysis complete!\n")
```

#### OMIT: Stratified Demographics for Combinations
```{r}
#Here we look at the demographic composition of the various types of chronic diseases via visual explorations rather than tables.

# R Script to Visualize Demographic Composition of Chronic Disease Combinations
# This script uses blue and orange color scheme consistently across all visualizations

# Load necessary packages
library(ggplot2)
library(patchwork) # For combining multiple plots
library(scales) # For percentage scales

# Function to identify chronic disease combinations.
#
# Args:
#   data:            data frame with one row per patient.
#   disease_columns: names of the disease indicator columns (1 = present).
# Returns:
#   `data` with three added columns:
#     disease_count          - row sum of the indicator columns (NAs ignored)
#     disease_combination    - ", "-joined names of indicators equal to 1,
#                              or "None"
#     disease_count_category - factor bucketing disease_count into
#                              "No chronic diseases" ... "4+ chronic diseases",
#                              or "Unknown" for any other count
identify_disease_combinations = function(data, disease_columns) {
  # drop = FALSE keeps a single selected column as a data frame, so rowSums()
  # and the row-wise lookup below still work with only one disease column
  indicators <- data[, disease_columns, drop = FALSE]

  # Create a new column counting the number of chronic diseases per patient
  data$disease_count <- rowSums(indicators, na.rm = TRUE)

  # Logical "has disease" matrix. Missing indicators count as absent; before,
  # they leaked into the label as the literal text "NA".
  has_disease <- as.matrix(indicators) == 1
  has_disease[is.na(has_disease)] <- FALSE
  disease_names <- colnames(has_disease)

  # Create disease combination strings
  data$disease_combination <- vapply(seq_len(nrow(has_disease)), function(i) {
    present <- disease_names[has_disease[i, ]]
    if (length(present) == 0) "None" else paste(present, collapse = ", ")
  }, character(1))

  # Create disease count category
  data$disease_count_category <- "Unknown"
  data$disease_count_category[data$disease_count == 0] <- "No chronic diseases"
  data$disease_count_category[data$disease_count == 1] <- "Single chronic disease"
  data$disease_count_category[data$disease_count == 2] <- "2 chronic diseases"
  data$disease_count_category[data$disease_count == 3] <- "3 chronic diseases"
  data$disease_count_category[data$disease_count >= 4] <- "4+ chronic diseases"

  # Convert to a factor with a fixed level order for plotting
  data$disease_count_category <- factor(data$disease_count_category,
                                      levels = c("No chronic diseases",
                                                "Single chronic disease",
                                                "2 chronic diseases",
                                                "3 chronic diseases",
                                                "4+ chronic diseases",
                                                "Unknown"))

  data
}

# Bar chart of how many patients fall into each disease-count category; each
# bar is labelled with that category's share (%) of all tallied patients.
#
# Args:
#   data: data frame with a disease_count_category column, as produced by
#         identify_disease_combinations().
# Returns:
#   a ggplot object (not printed).
plot_disease_count_distribution <- function(data) {
  # One row per observed category; naming the list element "n" names the
  # count column directly
  count_summary <- aggregate(
    list(n = rep(1, nrow(data))),
    by = list(disease_count_category = data$disease_count_category),
    FUN = sum
  )

  shares <- count_summary$n / sum(count_summary$n) * 100
  count_summary$percentage <- shares
  count_summary$label <- paste0(round(shares, 1), "%")

  # Fix the x-axis order from fewest to most diseases
  category_order <- c("No chronic diseases",
                      "Single chronic disease",
                      "2 chronic diseases",
                      "3 chronic diseases",
                      "4+ chronic diseases")
  count_summary$disease_count_category <- factor(count_summary$disease_count_category,
                                                 levels = category_order)

  # Blue bars, percentage label just above each bar
  ggplot(count_summary, aes(x = disease_count_category, y = n)) +
    geom_bar(stat = "identity", fill = "blue") +
    geom_text(aes(label = label), vjust = -0.5) +
    labs(title = "Distribution of Chronic Disease Counts",
         x = "",
         y = "Number of Patients") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
}

# Horizontal bar chart of the most frequent exact disease combinations among
# patients with at least one chronic disease.
#
# Args:
#   data:  data frame with disease_count and disease_combination columns, as
#          produced by identify_disease_combinations().
#   top_n: number of combinations to show (default 10).
# Returns:
#   a ggplot object (not printed).
plot_top_disease_combinations <- function(data, top_n = 10) {
  # Only patients with at least one chronic disease
  with_disease <- data[data$disease_count > 0, ]

  # Patients per distinct combination string
  combination_counts <- aggregate(
    list(n = rep(1, nrow(with_disease))),
    by = list(disease_combination = with_disease$disease_combination),
    FUN = sum
  )

  # Largest first; head() keeps at most top_n rows
  top_combinations <- head(combination_counts[order(-combination_counts$n), ], top_n)

  # Reverse the level order so that, after coord_flip(), the most common
  # combination is drawn at the top
  top_combinations$disease_combination <- factor(
    top_combinations$disease_combination,
    levels = rev(top_combinations$disease_combination)
  )

  ggplot(top_combinations, aes(x = disease_combination, y = n)) +
    geom_col(fill = "orange") +
    coord_flip() +
    labs(title = paste0("Top ", top_n, " Chronic Disease Combinations"),
         x = "",
         y = "Number of Patients") +
    theme_minimal()
}

# Function to visualize demographic breakdown by disease count: a 100% stacked
# bar per disease-count category, split by the demographic's values.
#
# Args:
#   data:               data frame with disease_count_category and the
#                       demographic column.
#   demographic_column: name of the demographic column (string).
# Returns:
#   a ggplot object (not printed).
plot_demographic_by_disease_count <- function(data, demographic_column) {
  # [[ ]] returns a plain vector for data frames and tibbles alike. For a
  # tibble, data[, col] returns a one-column tibble, and aggregate() then
  # fails ("arguments must have same length").
  demographic_values <- data[[demographic_column]]

  # Create summary data
  summary_data <- aggregate(rep(1, nrow(data)),
                         by = list(
                           disease_count_category = data$disease_count_category,
                           demographic = demographic_values
                         ),
                         FUN = sum)
  names(summary_data)[3] <- "count"

  # Calculate percentages within each disease-count category
  totals <- aggregate(summary_data$count,
                   by = list(disease_count_category = summary_data$disease_count_category),
                   FUN = sum)
  names(totals)[2] <- "total"

  summary_data <- merge(summary_data, totals, by = "disease_count_category")
  summary_data$percentage <- summary_data$count / summary_data$total * 100

  # Remove NAs
  summary_data <- summary_data[!is.na(summary_data$demographic), ]

  # Order factors. "Unknown" is one of the categories that
  # identify_disease_combinations() can assign; leaving it out would turn
  # those rows into NA.
  summary_data$disease_count_category <- factor(summary_data$disease_count_category,
                                           levels = c("No chronic diseases",
                                                     "Single chronic disease",
                                                     "2 chronic diseases",
                                                     "3 chronic diseases",
                                                     "4+ chronic diseases",
                                                     "Unknown"))

  # Blue-orange palette, interpolated when there are more than two groups
  n_groups <- length(unique(summary_data$demographic))
  blue_orange_palette <- c("blue", "orange")
  if (n_groups > 2) {
    blue_orange_palette <- colorRampPalette(c("blue", "orange"))(n_groups)
  }

  # position = "fill" rescales each bar to 1, so the y axis shows shares
  ggplot(summary_data, aes(x = disease_count_category,
                           y = percentage,
                           fill = factor(demographic))) +
    geom_bar(stat = "identity", position = "fill") +
    scale_y_continuous(labels = percent_format()) +
    scale_fill_manual(values = blue_orange_palette) +
    labs(title = paste0("Distribution of ", demographic_column, " by Disease Count"),
         x = "",
         y = "Percentage",
         fill = demographic_column) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
}

# Function to visualize specific gender with multiple diseases: a pie chart of
# disease-count categories among patients of one gender with 2+ diseases.
#
# Args:
#   data:          data frame with disease_count, disease_count_category and
#                  the gender column.
#   gender_value:  gender value to keep (e.g. "Female").
#   gender_column: name of the gender column (string).
# Returns:
#   a ggplot object (not printed).
plot_gender_with_multiple_diseases <- function(data, gender_value, gender_column) {
  # [[ ]] yields a plain vector for data frames and tibbles alike; %in% and
  # which() treat missing values as "no match" instead of producing all-NA rows
  keep <- data[[gender_column]] %in% gender_value & data$disease_count >= 2
  gender_data <- data[which(keep), , drop = FALSE]

  # Summarize data
  gender_summary <- aggregate(rep(1, nrow(gender_data)),
                           by = list(disease_count_category = gender_data$disease_count_category),
                           FUN = sum)
  names(gender_summary)[2] <- "count"

  # Calculate percentages
  gender_summary$percentage <- gender_summary$count / sum(gender_summary$count) * 100
  gender_summary$label <- paste0(round(gender_summary$percentage, 1), "%")

  # Blue, light blue, orange for up to 3 slices; interpolated beyond that
  blue_orange_palette <- c("blue", "#6BAED6", "orange")
  if (nrow(gender_summary) > 3) {
    blue_orange_palette <- colorRampPalette(c("blue", "orange"))(nrow(gender_summary))
  }

  # Pie chart = stacked single bar in polar coordinates
  ggplot(gender_summary, aes(x = "", y = count, fill = disease_count_category)) +
    geom_bar(stat = "identity", width = 1) +
    coord_polar("y", start = 0) +
    geom_text(aes(label = label),
              position = position_stack(vjust = 0.5)) +
    scale_fill_manual(values = blue_orange_palette) +
    labs(title = paste0("Distribution of ", gender_value, "s with Multiple Chronic Diseases"),
         fill = "Disease Count") +
    theme_void() +
    theme(legend.position = "right")
}

# Function to create heatmap of demographics vs disease count. Each tile is the
# share (%) of a demographic group that falls into a disease-count category,
# so each row sums to 100.
#
# Args:
#   data:               data frame with disease_count_category and the
#                       demographic column.
#   demographic_column: name of the demographic column (string).
# Returns:
#   a ggplot object (not printed).
plot_demographic_heatmap <- function(data, demographic_column) {
  # [[ ]] returns a plain vector for data frames and tibbles alike (a tibble's
  # data[, col] is a one-column tibble that aggregate() cannot group by)
  demographic_values <- data[[demographic_column]]

  # Create summary data
  heatmap_data <- aggregate(rep(1, nrow(data)),
                         by = list(
                           disease_count_category = data$disease_count_category,
                           demographic = demographic_values
                         ),
                         FUN = sum)
  names(heatmap_data)[3] <- "n"

  # Calculate percentages by demographic group
  demo_totals <- aggregate(heatmap_data$n,
                        by = list(demographic = heatmap_data$demographic),
                        FUN = sum)
  names(demo_totals)[2] <- "demo_total"

  heatmap_data <- merge(heatmap_data, demo_totals, by = "demographic")
  heatmap_data$percentage <- heatmap_data$n / heatmap_data$demo_total * 100

  # Remove NAs
  heatmap_data <- heatmap_data[!is.na(heatmap_data$demographic), ]

  # Order factors. "Unknown" is included because
  # identify_disease_combinations() can assign it; otherwise those rows would
  # become NA.
  heatmap_data$disease_count_category <- factor(heatmap_data$disease_count_category,
                                           levels = c("No chronic diseases",
                                                     "Single chronic disease",
                                                     "2 chronic diseases",
                                                     "3 chronic diseases",
                                                     "4+ chronic diseases",
                                                     "Unknown"))

  # Blue-to-orange gradient with the percentage printed on each tile
  ggplot(heatmap_data,
         aes(x = disease_count_category,
             y = demographic,
             fill = percentage)) +
    geom_tile(color = "white") +
    scale_fill_gradient(low = "blue", high = "orange") +
    geom_text(aes(label = round(percentage, 1)), color = "black", size = 3) +
    labs(title = paste0(demographic_column, " Distribution by Disease Count (%)"),
         x = "",
         y = "",
         fill = "Percentage") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))
}

# Function to visualize age-specific disease patterns (if age is a numeric variable)
#
# Draws a boxplot of a numeric age column for each disease-count category.
#
# data       data frame or tibble with a `disease_count_category` column
# age_column name (string) of the age column
# Returns a ggplot object, or NULL when the column is absent or not numeric.
plot_age_disease_patterns <- function(data, age_column) {
  # `[[` yields a vector for both data frames and tibbles. `data[, col]` on a
  # tibble is a one-column tibble, for which is.numeric() is always FALSE.
  age_values <- data[[age_column]]
  if (!is.numeric(age_values)) {
    return(NULL)  # Skip if not numeric (or the column is missing)
  }

  n_categories <- length(unique(data$disease_count_category))

  # `.data[[age_column]]` maps the column by name inside ggplot's data mask,
  # instead of injecting a vector extracted from `data` into aes().
  p <- ggplot(data, aes(x = disease_count_category,
                        y = .data[[age_column]],
                        fill = disease_count_category)) +
    geom_boxplot() +
    scale_fill_manual(values = colorRampPalette(c("blue", "orange"))(n_categories)) +
    labs(title = paste0("Distribution of ", age_column, " by Disease Count"),
         x = "",
         y = age_column) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1),
          legend.position = "none")

  p
}

# Function to create plots for each demographic
#
# For every demographic column this builds a "<col>_by_disease" plot (via
# plot_demographic_by_disease_count, defined elsewhere in this file) and a
# "<col>_heatmap" plot. For numeric columns whose name contains "age" it also
# adds a "<col>_boxplot". Returns a named list of plots. A builder that returns
# NULL leaves no entry in the list, because `[[<-` with NULL does not add one.
create_demographic_plots <- function(data, demographic_columns) {
  plots <- list()

  for (demo in demographic_columns) {
    plots[[paste0(demo, "_by_disease")]] <- plot_demographic_by_disease_count(data, demo)
    plots[[paste0(demo, "_heatmap")]] <- plot_demographic_heatmap(data, demo)

    # `[[` (not `data[, demo]`) so that a tibble column is tested as a vector.
    # NOTE(review): grepl("age", ...) also matches names such as "stage";
    # the is.numeric() check limits this to numeric columns.
    if (grepl("age", demo, ignore.case = TRUE) && is.numeric(data[[demo]])) {
      plots[[paste0(demo, "_boxplot")]] <- plot_age_disease_patterns(data, demo)
    }
  }

  plots
}

# Function to create a comparative plot for disease counts across a demographic
#
# Bar chart of the mean `disease_count` per demographic group, sorted from the
# highest mean to the lowest, with bars alternating blue/orange.
#
# data               data frame or tibble with a numeric `disease_count` column
# demographic_column name (string) of the grouping column
# Returns a ggplot object, or NULL (with a warning) when no complete rows exist.
plot_disease_count_comparative <- function(data, demographic_column) {
  # `[[` works for data frames and tibbles alike. The previous formula
  # `disease_count ~ data[, col]` failed when `data` was a tibble.
  demographic <- data[[demographic_column]]
  disease_count <- data$disease_count

  # Same rows the formula interface kept: both values must be non-missing
  complete <- !is.na(demographic) & !is.na(disease_count)
  if (!any(complete)) {
    warning("No non-missing rows for ", demographic_column, "; comparative plot skipped.",
            call. = FALSE)
    return(NULL)
  }

  comparative_data <- aggregate(list(disease_count = disease_count[complete]),
                                by = list(demographic = demographic[complete]),
                                FUN = mean)

  # Sort by average disease count (highest first)
  comparative_data <- comparative_data[order(-comparative_data$disease_count), ]

  # Freeze that order for plotting. Build the levels from character labels so
  # that factor and character demographics behave the same way.
  group_labels <- as.character(comparative_data$demographic)
  comparative_data$demographic <- factor(group_labels, levels = group_labels)

  # Alternate blue and orange bars
  colors <- rep(c("blue", "orange"), length.out = nrow(comparative_data))

  p <- ggplot(comparative_data, aes(x = demographic, y = disease_count, fill = demographic)) +
    geom_bar(stat = "identity") +
    scale_fill_manual(values = colors) +
    labs(title = paste0("Average Number of Chronic Diseases by ", demographic_column),
         x = demographic_column,
         y = "Average Number of Diseases") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1),
          legend.position = "none")

  p
}

# Main analysis function
#
# Runs the chronic-disease visualization pipeline:
#   1. identify_disease_combinations() (defined elsewhere in this file) adds
#      the combination/count columns;
#   2. overall count and top-combination plots;
#   3. per-demographic plots (create_demographic_plots);
#   4. one "<gender>_multiple_diseases" plot per value in `gender_values`;
#   5. "<demo>_comparative" bar charts for non-numeric demographics with at
#      most `max_categories` distinct values (missing columns are skipped).
# Returns list(plots = named list of plots, processed_data = augmented data).
visualize_chronic_diseases <- function(data, disease_columns, demographic_columns, gender_column,
                                       gender_values = c("Female", "Male"),
                                       max_categories = 15) {
  # Identify disease combinations
  data <- identify_disease_combinations(data, disease_columns)

  # Basic plots
  plots <- list()
  plots[["disease_count"]] <- plot_disease_count_distribution(data)
  plots[["top_combinations"]] <- plot_top_disease_combinations(data)

  # Demographic breakdown plots
  plots <- c(plots, create_demographic_plots(data, demographic_columns))

  # Gender-specific plots
  for (gender in gender_values) {
    plots[[paste0(gender, "_multiple_diseases")]] <-
      plot_gender_with_multiple_diseases(data, gender, gender_column)
  }

  # Comparative plots. Use `[[` so that tibbles are handled: for a tibble,
  # `data[, demo]` is a one-column tibble whose unique() always has length 1
  # and for which is.numeric() is always FALSE.
  for (demo in demographic_columns) {
    demo_values <- data[[demo]]
    if (!is.null(demo_values) && !is.numeric(demo_values) &&
        length(unique(demo_values)) <= max_categories) {
      plots[[paste0(demo, "_comparative")]] <- plot_disease_count_comparative(data, demo)
    }
  }

  # Return the plots and the processed data
  list(
    plots = plots,
    processed_data = data
  )
}

# Function to display multiple plots together
#
# Combines the given plots into one patchwork layout with `ncol` columns,
# prints it once, and returns it invisibly. NULL entries (e.g. a
# results$plots[["name"]] lookup for a plot that was never created) are
# dropped. When nothing is left, a message is shown and NULL is returned
# invisibly.
display_plots <- function(plot_list, ncol = 2) {
  plot_list <- Filter(Negate(is.null), plot_list)
  if (length(plot_list) == 0) {
    message("No plots to display.")
    return(invisible(NULL))
  }

  # Namespace-qualified, so this works whether or not patchwork is attached
  combined_plot <- patchwork::wrap_plots(plot_list, ncol = ncol)

  print(combined_plot)

  # Return invisibly. A visible return made a top-level call auto-print the
  # combined plot a second time after the explicit print() above.
  invisible(combined_plot)
}

# Function to save plots
#
# Writes every non-NULL named plot to "<output_dir>/<name>.png" (width x height
# inches). It also writes two 2-column composites at 1.5x size:
#   combined_demographic_plots.png - plots whose name ends in "_by_disease"
#   combined_heatmaps.png          - plots whose name ends in "_heatmap"
# The output directory is created if needed. Unnamed plots are not saved.
# Returns the written file paths invisibly.
save_plots <- function(plot_list, output_dir = ".", width = 10, height = 8) {
  if (!dir.exists(output_dir)) {
    dir.create(output_dir, recursive = TRUE)
  }

  plot_list <- Filter(Negate(is.null), plot_list)

  # Individual plots
  individual_files <- vapply(names(plot_list), function(name) {
    path <- file.path(output_dir, paste0(name, ".png"))
    ggsave(filename = path, plot = plot_list[[name]], width = width, height = height)
    path
  }, character(1), USE.NAMES = FALSE)

  # Save a 2-column composite of the plots whose names match `pattern`.
  # Returns the path, or character(0) when nothing matches.
  save_combined <- function(pattern, file_name) {
    matched <- plot_list[grep(pattern, names(plot_list))]
    if (length(matched) == 0) {
      return(character(0))
    }
    path <- file.path(output_dir, file_name)
    ggsave(
      filename = path,
      plot = patchwork::wrap_plots(matched, ncol = 2),
      width = width * 1.5,
      height = height * 1.5
    )
    path
  }

  combined_files <- c(
    save_combined("_by_disease$", "combined_demographic_plots.png"),
    save_combined("_heatmap$", "combined_heatmaps.png")
  )

  invisible(c(individual_files, combined_files))
}


# Define your chronic disease columns (assuming binary 0/1 or TRUE/FALSE coding)
# NOTE(review): these names differ from the Elixhauser-style codes ("DM",
# "HTN_C", ...) used in Parts 2b/2c. Confirm they exist in TGA_grouped.
disease_columns <- c("diabetes", "hypertension", "heart_disease", "kidney_disease", "lung_disease",
                    "liver_disease", "hiv_aids", "cancer", "autoimmune", "mental_health")

# Define demographic columns
demographic_columns <- c("age_group", "race", "gender", "ethnicity")

# Specify which column represents gender
gender_column <- "gender"

# Run the visualization; `results` holds $plots (named list) and $processed_data
results <- visualize_chronic_diseases(TGA_grouped, disease_columns, demographic_columns, gender_column)

# Display the plots
display_plots(list(
  results$plots[["disease_count"]],
  results$plots[["top_combinations"]]
))

# Display demographic plots
display_plots(list(
  results$plots[["gender_by_disease"]],
  results$plots[["race_by_disease"]],
  results$plots[["ethnicity_by_disease"]],
  results$plots[["age_group_by_disease"]]
))

# Display heatmaps
display_plots(list(
  results$plots[["gender_heatmap"]],
  results$plots[["race_heatmap"]]
))

# Display female-specific plots
# ("gender_comparative" exists only when gender is non-numeric with at most 15
# distinct values; otherwise this lookup returns NULL.)
display_plots(list(
  results$plots[["Female_multiple_diseases"]],
  results$plots[["gender_comparative"]]
))

# Save every plot, plus combined panels, to the output directory
save_plots(results$plots, "chronic_disease_plots")

# Print a summary
cat("Visualization complete. Plots saved to 'chronic_disease_plots' directory.\n")
cat("Analyzed", nrow(results$processed_data), "patients with", length(disease_columns), "possible chronic diseases.\n")
cat("Found", sum(results$processed_data$disease_count > 0), "patients with at least one chronic disease.\n")
```

## Part 2b: Pairwise Tetrachoric Correlation Matrix of Each Chronic Disease Against Every Other (ChatGPT)
```{r}
# Load necessary libraries
library(psych)
library(ggplot2)
library(reshape2)
library(gridExtra) # For combining plots

# Function to calculate tetrachoric correlations and save heatmap as PNG
#
# Computes the tetrachoric correlation matrix (psych::tetrachoric) of the
# binary disease columns, draws it as a blue-white-orange heatmap on a fixed
# [-1, 1] colour scale, and saves it to `file_name` (14 x 10 in).
#
# data           data frame containing every column in `columns`
# title          plot title
# file_name      output path for ggsave()
# columns        binary columns to correlate (shown in alphabetical order);
#                the default is the original Elixhauser-style disease list
# bold_threshold labels with |r| > bold_threshold are drawn in bold
# Returns the ggplot object. Stops if any requested column is missing.
generate_and_save_heatmap <- function(data, title, file_name,
                                      columns = c("AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF", "CHRNLUNG",
                                                  "COAG", "DEPRESS", "DM", "DMCX", "DRUG", "HTN_C", "HYPOTHY", "LIVER",
                                                  "LYMPH", "LYTES", "METS", "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH",
                                                  "PULMCIRC", "RENLFAIL", "TUMOR", "ULCER", "VALVE", "WGHTLOSS", "ARRHYTH"),
                                      bold_threshold = 0.5) {
  # Alphabetize variables
  columns <- sort(columns)

  # Report missing columns by name instead of "undefined columns selected"
  missing_cols <- setdiff(columns, names(data))
  if (length(missing_cols) > 0) {
    stop("Columns not found in data: ", paste(missing_cols, collapse = ", "),
         call. = FALSE)
  }
  data_subset <- data[columns]

  # Compute tetrachoric correlations
  corr_matrix <- tetrachoric(data_subset)$rho

  # Melt the correlation matrix into long form for plotting
  melted_corr <- melt(corr_matrix)
  colnames(melted_corr) <- c("Var1", "Var2", "Correlation")

  # Bold strong correlations. NA correlations are drawn plain rather than
  # producing an NA fontface.
  is_strong <- !is.na(melted_corr$Correlation) &
    abs(melted_corr$Correlation) > bold_threshold
  melted_corr$Face <- ifelse(is_strong, "bold", "plain")

  heatmap_plot <- ggplot(melted_corr, aes(x = Var1, y = Var2, fill = Correlation)) +
    geom_tile(color = "white") +
    # `limits` spelled out (the original `limit` relied on partial matching)
    scale_fill_gradient2(low = "blue", high = "orange", mid = "white", midpoint = 0,
                         limits = c(-1, 1), space = "Lab", name = "Correlation") +
    geom_text(aes(label = sprintf("%.2f", Correlation), fontface = Face), size = 3.5) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1, size = 10),
          axis.text.y = element_text(size = 10),
          plot.title = element_text(size = 14, face = "bold"),
          plot.margin = margin(10, 10, 10, 10)) +
    labs(title = title, x = "", y = "") +
    coord_fixed(ratio = 0.8)  # Adjust tile aspect ratio

  # Save heatmap as PNG
  ggsave(file_name, plot = heatmap_plot, width = 14, height = 10)

  # Return the plot object (used for the side-by-side combination below)
  heatmap_plot
}

# Save individual heatmaps (one PNG per cohort)
# NOTE(review): the notebook header says scripts using TGA should switch to
# TGA2 (built from TGA_BN). TGA2 is only created in Part 2c below, so confirm
# which dataframe this section is meant to use.
heatmap_CGA <- generate_and_save_heatmap(CGA, "Correlation Heatmap Between Chronic Diseases in Patients with Non-Multimorbidity", "CGA_Heatmap_tetrachoric.png")
heatmap_TGA <- generate_and_save_heatmap(TGA, "Correlation Heatmap Between Chronic Diseases in Patients with Multimorbidity", "TGA_Heatmap_tetrachoric.png")

# Combine heatmaps for side-by-side comparison. grid.arrange() draws the
# arrangement and also returns it, so it can be passed to ggsave() below.
combined_plot <- grid.arrange(heatmap_CGA, heatmap_TGA, ncol = 2)

# Save combined heatmap
ggsave("combined_heatmap_fixed_width.png", combined_plot, width = 26, height = 10)


```

## Part 2c: Correlation Matrix of Disease Combinations vs. Individual Chronic Diseases
```{r}
# TGA2 is a copy of TGA_BN, the most up-to-date dataframe after the
# adjustments made in Part 2b; the Part 2c analysis below runs on TGA2.
TGA2 <- TGA_BN

# Load necessary libraries
library(dplyr)
library(psych)
library(reshape2)
library(ggplot2)

# Function to identify the top combinations per category
#
# Counts how often each combination occurs within each category (rows with a
# missing category or combination are not counted), ranks the counts from
# most to least frequent (ties keep aggregate()'s order), and keeps the first
# `top_n` combinations per category.
# Returns the tapply() result indexed by category name: each element is a
# character vector of combinations (simplified to an atomic array when every
# category yields exactly one).
identify_top_combinations <- function(data, combination_col, category_col, top_n) {
  combos <- as.character(data[[combination_col]])
  grouping <- list(Category = data[[category_col]], Combination = combos)

  counts <- aggregate(list(Frequency = combos), by = grouping, FUN = length)
  ranked <- counts[order(counts$Frequency, decreasing = TRUE), ]

  tapply(ranked$Combination, ranked$Category, head, n = top_n)
}

# Improved binarize_top_combinations function - returns ONLY the binary combination columns
#
# Returns a data frame (same row names as `data`) with one 0/1 integer column
# per entry of `top_combinations`. A row is 1 when its combination equals that
# entry exactly. Missing combinations are 0, matching how create_tga3 maps
# missing disease flags to 0. A "."-separated combination is rewritten with
# "+", and column names keep "+" as the separator.
binarize_top_combinations <- function(data, combination_col, top_combinations) {
  combos <- as.character(data[[combination_col]])
  result_data <- data.frame(row.names = rownames(data))

  for (combination in top_combinations) {
    # Scalar condition, so plain if/else rather than ifelse()
    adjusted_combination <- if (grepl("\\+", combination)) {
      combination
    } else {
      gsub("\\.", "+", combination)
    }
    # make.names() turns "+" into "."; map it back so names keep "+"
    column_name <- gsub("\\.", "+", make.names(adjusted_combination))
    # `!is.na(...) &` turns NA comparisons into 0 instead of NA
    result_data[[column_name]] <- as.integer(!is.na(combos) & combos == combination)
  }

  result_data
}

# Improved create_tga3 function - now only returns disease and combination columns
#
# Builds the analysis matrix for one combination category. It contains:
#   * 0/1 disease indicator columns (missing or 0 -> 0, anything else -> 1)
#     for the `disease_cols` present in `data`, followed by
#   * one 0/1 column per top-`top_n` combination of `category`.
# Columns without variation are then dropped via clean_data().
# Stops if `category` has no combinations.
create_tga3 <- function(data, combination_col, category_col, disease_cols, category, top_n) {
  # Top combinations per category, indexed by category name
  top_by_category <- identify_top_combinations(data, combination_col, category_col, top_n)

  # Check membership by name. When the tapply() result is atomic (top_n = 1),
  # indexing with an absent name yields NA rather than NULL, so the previous
  # is.null() check let unknown categories through.
  if (!(category %in% names(top_by_category))) {
    stop(paste("No top combinations found for category:", category))
  }
  top_combinations <- unlist(top_by_category[category])
  if (length(top_combinations) == 0) {
    stop(paste("No top combinations found for category:", category))
  }

  # Binary columns for the top combinations
  combinations_df <- binarize_top_combinations(data, combination_col, top_combinations)

  # Disease indicators coerced to 0/1 integers
  disease_cols_in_data <- intersect(disease_cols, colnames(data))
  disease_data <- data[, disease_cols_in_data, drop = FALSE]
  disease_data <- as.data.frame(lapply(disease_data, function(x) as.integer(!is.na(x) & x != 0)))

  # Diseases first, then combinations; drop constant columns before returning
  TGA3 <- cbind(disease_data, combinations_df)
  clean_data(TGA3)
}

# Function to compute tetrachoric correlations with confidence intervals
#
# Calls tetrachoric(data, smooth = TRUE) and returns a list with:
#   correlation - the rho matrix (all NA, named by column, if the call errors)
#   SE          - the result's SE matrix when present with matching dimensions,
#                 otherwise an all-NA matrix (and a warning is issued)
#   ci_lower / ci_upper - correlation -/+ 1.96 * SE
# NOTE(review): psych::tetrachoric() is not known to return an `SE` element.
# If so, the warning always fires and the CIs are all NA; confirm against the
# installed psych version.
compute_tetrachoric_with_ci <- function(data) {
  # Smoothing helps when the correlation matrix is not positive definite
  fit <- tryCatch(
    tetrachoric(data, smooth = TRUE),
    error = function(e) {
      warning(paste("Error in tetrachoric calculation:", e$message))
      vars <- colnames(data)
      list(rho = matrix(NA, nrow = ncol(data), ncol = ncol(data),
                        dimnames = list(vars, vars)))
    }
  )

  rho <- fit$rho
  has_usable_se <- !is.null(fit$SE) && identical(dim(fit$SE), dim(rho))

  if (has_usable_se) {
    se <- fit$SE
  } else {
    se <- matrix(NA, nrow = nrow(rho), ncol = ncol(rho))
    rownames(se) <- rownames(rho)
    colnames(se) <- colnames(rho)
    warning("Standard errors couldn't be computed for all correlations.")
  }

  margin <- 1.96 * se
  list(
    correlation = rho,
    SE = se,
    ci_lower = rho - margin,
    ci_upper = rho + margin
  )
}

# Modified function to highlight strong correlations instead of significant ones
#
# Draws the `singles` (rows) x `combinations` (columns) block of
# corr_result$correlation as a labelled heatmap. Cells with
# |r| >= threshold get a yellow border, and the count of such cells is printed
# and shown as the subtitle. The plot is saved to `file_name` (12 x 8 in) and
# returned. Stops when no requested row or column exists in the matrix.
generate_heatmap_with_borders <- function(corr_result, singles, combinations, title, file_name,
                                          threshold = 0.5) {
  corr_matrix <- corr_result$correlation

  # Keep only singles/combinations that are present in the matrix
  valid_singles <- intersect(singles, rownames(corr_matrix))
  valid_combinations <- intersect(combinations, colnames(corr_matrix))

  if (length(valid_singles) == 0 || length(valid_combinations) == 0) {
    stop("No valid rows or columns for subsetting. Please verify data.")
  }

  # Singles vs combinations block only
  corr_matrix <- corr_matrix[valid_singles, valid_combinations, drop = FALSE]
  melted_corr <- melt(corr_matrix, varnames = c("Var1", "Var2"), value.name = "Correlation")

  # NA correlations are never "strong". Without the is.na() guard, subsetting
  # with an NA logical index injected all-NA rows into strong_data.
  melted_corr$Strong <- !is.na(melted_corr$Correlation) &
    abs(melted_corr$Correlation) >= threshold

  num_strong <- sum(melted_corr$Strong)
  cat("Number of strong correlations (|r| ≥", threshold, "):", num_strong, "\n")

  # Round correlations for display
  melted_corr$Correlation_Display <- round(melted_corr$Correlation, 2)

  p <- ggplot(melted_corr, aes(Var2, Var1)) +
    geom_tile(aes(fill = Correlation)) +
    geom_text(aes(label = Correlation_Display), color = "black") +
    scale_fill_gradient2(low = "blue", high = "orange", mid = "white", midpoint = 0) +
    labs(title = title, x = "Combinations", y = "Singles") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

  if (num_strong > 0) {
    strong_data <- melted_corr[melted_corr$Strong, ]

    # An unfilled tile on the same discrete x/y grid outlines each strong
    # cell, with no hand-computed factor-level positions needed.
    p <- p +
      geom_tile(data = strong_data, fill = NA, color = "yellow", linewidth = 2) +
      labs(subtitle = paste0(num_strong,
                             " cells with yellow borders indicate strong correlations (|r| ≥ ",
                             threshold, ")"))
  } else {
    p <- p + labs(subtitle = paste0("No correlations with |r| ≥ ", threshold, " found"))
  }

  ggsave(file_name, plot = p, width = 12, height = 8)

  p
}

# Clean data by removing columns with no variance
#
# Keeps only columns with at least two distinct non-missing values. NA is not
# counted as a value, so a column that is constant apart from NAs is dropped.
# vapply() keeps the result logical even for a 0-column input; sapply() would
# return list() there, which is an invalid column subscript.
clean_data <- function(data) {
  has_variance <- vapply(data, function(col) length(unique(col[!is.na(col)])) > 1,
                         logical(1))
  data[, has_variance, drop = FALSE]
}

# Function that uses the original create_tga3 without creating duplicates
#
# For each combination category, builds the TGA3 matrix (create_tga3),
# computes tetrachoric correlations (compute_tetrachoric_with_ci), and saves a
# singles-vs-combinations heatmap named
# "<output_prefix>_<category>_single_vs_combinations_heatmap.png".
# Categories in `keep_categories` are analysed separately. All other
# non-missing categories are pooled under `other_label`. Rows with a missing
# category are left out of every group. `data` itself is not modified.
# Returns a named list of the heatmaps that were created, invisibly. A failure
# while drawing one heatmap is reported as a warning and skipped.
run_analysis_by_category <- function(data, combination_col, category_col, disease_cols, top_n, output_prefix,
                                     threshold = 0.5,
                                     keep_categories = c("Dyad", "Triad"),
                                     other_label = "Tetrad/Pentad") {
  analysis_data <- data

  # Work with character labels so factor codes are never used
  category_values <- analysis_data[[category_col]]
  if (is.factor(category_values)) {
    category_values <- as.character(category_values)
  }

  # Pool the remaining categories, but keep missing ones as NA. `%in%` is FALSE
  # for NA, so a bare ifelse() relabelled missing categories as other_label.
  # NOTE(review): every non-kept label (not only tetrads/pentads) is pooled.
  analysis_data[[category_col]] <- ifelse(
    is.na(category_values) | category_values %in% keep_categories,
    category_values,
    other_label
  )

  categories <- unique(analysis_data[[category_col]])
  categories <- categories[!is.na(categories)]

  heatmaps <- list()
  for (category in categories) {
    cat("\nProcessing category:", category, "\n")

    # Disease columns plus this category's top combinations
    TGA3 <- create_tga3(analysis_data, combination_col, category_col, disease_cols, category, top_n)

    combinations <- setdiff(colnames(TGA3), disease_cols)
    singles <- intersect(colnames(TGA3), disease_cols)

    cat("Analysis for category:", category, "\n")
    cat("Number of disease columns:", length(singles), "\n")
    cat("Number of combination columns:", length(combinations), "\n")

    corr_result <- compute_tetrachoric_with_ci(TGA3)

    heatmaps[[category]] <- tryCatch({
      p <- generate_heatmap_with_borders(
        corr_result,
        singles = singles,
        combinations = combinations,
        title = paste(category, "Tetrachoric Correlations of Chronic Diseases vs Combinations"),
        file_name = paste0(output_prefix, "_", gsub("[^a-zA-Z0-9]", "_", category), "_single_vs_combinations_heatmap.png"),
        threshold = threshold
      )
      cat("Heatmap created successfully for category:", category, "\n")
      p
    }, error = function(e) {
      warning(paste("Error generating heatmap for category", category, ":", e$message))
      NULL
    })
  }

  invisible(heatmaps)
}

# Parameters for the Part 2c analysis
combination_col <- "combinations"        # column with each patient's disease combination
category_col <- "combination_category"   # column with the combination-size label (Dyad, Triad, ...)

# NOTE(review): unlike the disease lists in Part 2b and Part 3, this list
# omits "LYMPH". Confirm whether that exclusion is intentional.
disease_cols <- c(
  "AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF", "CHRNLUNG",
  "COAG", "DEPRESS", "DM", "DMCX", "DRUG", "HTN_C", "HYPOTHY", "LIVER",
  "LYTES", "METS", "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH",
  "PULMCIRC", "RENLFAIL", "TUMOR", "ULCER", "VALVE", "WGHTLOSS", "ARRHYTH"
)
top_n <- 20               # top combinations kept per category
output_prefix <- "TGA3"   # prefix for the saved heatmap files

# Run the analysis
run_analysis_by_category(TGA2, combination_col, category_col, disease_cols, top_n, output_prefix)

```


## Part 3: Mann-Whitney U Test

### Mann-Whitney U Test for Nonparametric Data (differences in hospitalization outcomes by subgroup: age, race, sex, ethnicity, and chronic diseases)

Binary (0/1) outcomes are compared with Fisher's exact test instead.
```{r}
# Full Script: Group Comparison Analysis for Multimorbidity Project
# Aliases for the control (non-multimorbidity) and test (multimorbidity)
# Mann-Whitney input data used below.
CG_MW <- CGA_MWU
TG_MW <- TGA_MWU

# Load required libraries
library(stats)
library(graphics)
library(grDevices)
library(utils)

# P-value to significance stars
#
# Maps p-values to star labels: p < 0.0001 -> "****", < 0.001 -> "***",
# < 0.01 -> "**", < 0.05 -> "*", otherwise "". Vectorised (one label per
# input value); NA/NaN p-values map to "" instead of raising an error.
pval_to_star <- function(p) {
  stars <- c("****", "***", "**", "*", "")
  cutoffs <- c(0.0001, 0.001, 0.01, 0.05)

  labels <- rep("", length(p))
  observed <- !is.na(p)
  # findInterval() gives 0 below the first cutoff and k when
  # cutoffs[k] <= p < cutoffs[k + 1], so index + 1 picks the label.
  labels[observed] <- stars[findInterval(p[observed], cutoffs) + 1L]
  labels
}

# Main comparison function
#
# Compares the control cohort (CG_MW, labelled group_label[1]) with the test
# cohort (TG_MW, labelled group_label[2]) on each outcome, within strata of
# each predictor in demo_vars / diseases_vars / clinical_vars / score_vars:
#   * continuous predictors: one comparison over all rows ("all");
#   * binary (0/1) predictors: only rows where the predictor equals 1;
#   * categorical/discrete predictors: one comparison per observed level.
# Binary (0/1) outcomes use Fisher's exact test, skipped when any 2x2 cell is
# 0. Other outcomes use a Wilcoxon rank-sum (Mann-Whitney) test with a
# normal-approximation Z score (no tie correction).
# A stratum is skipped when it has < 10 rows, only one group, or < 5
# non-missing outcome values in either group. Outcomes that are not shared
# columns are skipped with a warning.
#
# Returns a data frame with one row per comparison, including BH-adjusted
# p-values (Adjusted_P) and star labels (P_Stars). When non-empty, it is also
# written to "hybrid_group_comparison_results.csv" in the working directory.
compare_subgroups <- function(CG_MW, TG_MW, outcomes, demo_vars, diseases_vars, clinical_vars,
                              score_vars = c("comorbidity_score", "van_index"),
                              group_label = c("Non-MM", "MM")) {
  # Stack both cohorts on their shared columns, tagged with a Group label
  common_cols <- intersect(names(CG_MW), names(TG_MW))
  CG_MW <- CG_MW[, common_cols, drop = FALSE]
  TG_MW <- TG_MW[, common_cols, drop = FALSE]
  CG_MW$Group <- group_label[1]
  TG_MW$Group <- group_label[2]
  df_all <- rbind(CG_MW, TG_MW)

  all_vars <- unique(c(demo_vars, diseases_vars, clinical_vars, score_vars))
  age_vars <- c("age", "age_groups", "age_hilo")

  # Reporting category of a predictor; the first matching group wins.
  # diseases_vars was previously ignored here, so diseases were labelled "Other".
  assign_category <- function(var) {
    if (var %in% age_vars) return("Age")
    if (var %in% demo_vars) return("Demographics")
    if (var %in% diseases_vars) return("Diseases")
    if (var %in% clinical_vars) return("Clinical")
    if (var %in% score_vars) return("Scores")
    "Other"
  }

  # Predictor type: 0/1 -> "binary"; factor/character -> "categorical";
  # <= 10 distinct values -> "discrete"; otherwise "continuous"
  classify_var <- function(values) {
    observed <- unique(na.omit(values))
    if (length(observed) <= 2 && all(observed %in% c(0, 1))) return("binary")
    if (is.factor(values) || is.character(values)) return("categorical")
    if (length(observed) <= 10) return("discrete")
    "continuous"
  }

  result_rows <- list()

  for (outcome in outcomes) {
    if (!(outcome %in% names(df_all))) {
      warning("Outcome '", outcome, "' is not a column shared by both cohorts; skipping.",
              call. = FALSE)
      next
    }
    df_all[[outcome]] <- as.numeric(df_all[[outcome]])

    for (var in all_vars) {
      if (!(var %in% names(df_all))) next
      var_data <- df_all[[var]]
      var_type <- classify_var(var_data)
      category <- assign_category(var)

      levels_to_check <- if (var_type == "continuous") {
        "all"
      } else if (var_type == "binary") {
        c(1)  # Only look at present cases
      } else {
        unique(na.omit(var_data))
      }

      for (level in levels_to_check) {
        # which() drops rows whose predictor is NA. A bare logical index
        # would insert all-NA rows into the stratum instead.
        df_sub <- if (var_type == "continuous") {
          df_all
        } else {
          df_all[which(df_all[[var]] == level), , drop = FALSE]
        }
        if (nrow(df_sub) < 10 || length(unique(df_sub$Group)) < 2) next

        cg <- df_sub[df_sub$Group == group_label[1], , drop = FALSE]
        tg <- df_sub[df_sub$Group == group_label[2], , drop = FALSE]
        cg_out <- na.omit(as.numeric(cg[[outcome]]))
        tg_out <- na.omit(as.numeric(tg[[outcome]]))
        if (length(cg_out) < 5 || length(tg_out) < 5) {
          message("Skipping ", var, " (", level, ") — insufficient data for outcome: ", outcome)
          next
        }

        is_binary_outcome <- all(na.omit(unique(c(cg_out, tg_out))) %in% c(0, 1))
        level_label <- if (var_type == "binary" && level == 1) "Present" else as.character(level)

        row <- list(
          Outcome = outcome, Subgroup = var, Level = level_label, Category = category,
          Control_n = length(cg_out), Test_n = length(tg_out),
          Median_NonMM = median(cg_out, na.rm = TRUE),
          Median_MM = median(tg_out, na.rm = TRUE),
          IQR_NonMM = IQR(cg_out, na.rm = TRUE),
          IQR_MM = IQR(tg_out, na.rm = TRUE)
        )

        if (is_binary_outcome) {
          tbl <- matrix(c(
            sum(cg_out == 1), sum(cg_out == 0),
            sum(tg_out == 1), sum(tg_out == 0)
          ), nrow = 2)
          if (any(tbl == 0)) next
          test <- fisher.test(tbl)
          row$Test_Type <- "Fisher"
          row$Statistic <- NA
          row$P_Value <- test$p.value
          row$Z_Score <- NA
          row$CI_Lower <- test$conf.int[1]
          row$CI_Upper <- test$conf.int[2]
          # Difference in event rates, in percentage points (test - control)
          row$Difference <- 100 * (mean(tg_out, na.rm = TRUE) - mean(cg_out, na.rm = TRUE))
        } else {
          test <- tryCatch(
            wilcox.test(tg_out, cg_out, conf.int = TRUE),
            error = function(e) NULL
          )
          if (is.null(test)) next
          n1 <- length(cg_out)
          n2 <- length(tg_out)
          U <- as.numeric(test$statistic)
          mu <- n1 * n2 / 2
          sigma <- sqrt((n1 * n2 * (n1 + n2 + 1)) / 12)

          row$Test_Type <- "Mann-Whitney"
          row$Statistic <- U
          row$P_Value <- test$p.value
          row$Z_Score <- (U - mu) / sigma
          row$CI_Lower <- test$conf.int[1]
          row$CI_Upper <- test$conf.int[2]
          row$Difference <- median(tg_out) - median(cg_out)
        }

        # Collect rows in a list and bind once at the end (no quadratic rbind)
        result_rows[[length(result_rows) + 1L]] <- as.data.frame(row, stringsAsFactors = FALSE)
      }
    }
  }

  results <- if (length(result_rows) > 0) do.call(rbind, result_rows) else data.frame()

  if (nrow(results) > 0) {
    results$Adjusted_P <- p.adjust(results$P_Value, method = "BH")
    results$P_Stars <- vapply(results$Adjusted_P, function(p) {
      if (is.na(p)) ""
      else if (p < 0.001) "***"
      else if (p < 0.01) "**"
      else if (p < 0.05) "*"
      else ""
    }, character(1))
    write.csv(results, "hybrid_group_comparison_results.csv", row.names = FALSE)
  }

  results
}

# Save one PDF heatmap per subgroup (Level x Outcome) showing only comparisons
# with Adjusted_P < p_threshold. Tiles are coloured by -log10(Adjusted_P) and
# labelled with stars, medians, IQRs, and the CI, each on its own line.
# Files are written as "<output_dir>/heatmap_<subgroup>.pdf". Subgroups with no
# significant rows are skipped. Returns the written paths invisibly.
create_subgroup_heatmaps <- function(results_df, output_dir = "subgroup_heatmaps", p_threshold = 0.05) {
  dir.create(output_dir, showWarnings = FALSE, recursive = TRUE)
  all_outcomes <- unique(results_df$Outcome)
  saved_files <- character(0)

  for (subgroup in unique(results_df$Subgroup)) {
    df_sub <- results_df %>%
      filter(Subgroup == subgroup, Adjusted_P < p_threshold) %>%
      mutate(
        Level = as.character(Level),
        Outcome = factor(Outcome, levels = all_outcomes),
        log10P = -log10(Adjusted_P),
        # One statistic per line. The label is not passed through str_wrap(),
        # which normalizes whitespace and so collapsed these "\n" breaks.
        Label = paste0(
          P_Stars, "\n",
          "Med: ", round(Median_NonMM, 1), "/", round(Median_MM, 1), "\n",
          "IQR: ", round(IQR_NonMM, 1), "/", round(IQR_MM, 1), "\n",
          "CI: [", round(CI_Lower, 1), ", ", round(CI_Upper, 1), "]"
        ),
        Text_Length = nchar(Label)
      )

    if (nrow(df_sub) == 0) next

    # Scale the width with the number of levels and label length (8-50 in)
    plot_width <- min(max(8, length(unique(df_sub$Level)) * max(df_sub$Text_Length) / 30), 50)

    p <- ggplot(df_sub, aes(x = Level, y = Outcome, fill = log10P)) +
      geom_tile(color = "white") +
      geom_text(aes(label = Label), size = 3.1, lineheight = 1.1) +
      scale_fill_gradient2(low = "blue", high = "orange", midpoint = -log10(0.05), name = "-log10(P)") +
      theme_minimal() +
      labs(title = paste("Subgroup:", subgroup), x = "Level", y = "Outcome") +
      theme(axis.text.x = element_text(angle = 45, hjust = 1))

    out_file <- file.path(output_dir, paste0("heatmap_", gsub("[^A-Za-z0-9]+", "_", subgroup), ".pdf"))
    ggsave(out_file, plot = p, width = plot_width, height = 8, limitsize = FALSE)
    saved_files <- c(saved_files, out_file)
  }

  invisible(saved_files)
}

# Save a single PDF heatmap (disease Subgroup x Outcome) of the comparisons
# for `diseases_vars` with Adjusted_P < p_threshold. Tiles are coloured by
# -log10(Adjusted_P) and labelled with stars, medians, IQRs, and the CI, each
# on its own line. Returns NULL when nothing is significant; otherwise returns
# ggsave()'s value.
create_combined_disease_heatmap <- function(results_df, diseases_vars, output_file = "combined_diseases_heatmap.pdf", p_threshold = 0.05) {
  outcomes_all <- unique(results_df$Outcome)

  df_sub <- results_df %>%
    filter(Subgroup %in% diseases_vars, Adjusted_P < p_threshold) %>%
    mutate(
      Subgroup = factor(Subgroup, levels = diseases_vars),
      Outcome = factor(Outcome, levels = outcomes_all),
      log10P = -log10(Adjusted_P),
      # One statistic per line. The label is not passed through str_wrap(),
      # which normalizes whitespace and so collapsed these "\n" breaks.
      Label = paste0(
        P_Stars, "\n",
        "Med: ", round(Median_NonMM, 1), "/", round(Median_MM, 1), "\n",
        "IQR: ", round(IQR_NonMM, 1), "/", round(IQR_MM, 1), "\n",
        "CI: [", round(CI_Lower, 1), ", ", round(CI_Upper, 1), "]"
      ),
      Text_Length = nchar(Label)
    )

  if (nrow(df_sub) == 0) return(NULL)

  # Scale the width with the number of diseases and label length (10-50 in)
  plot_width <- min(max(10, length(unique(df_sub$Subgroup)) * max(df_sub$Text_Length) / 30), 50)

  p <- ggplot(df_sub, aes(x = Subgroup, y = Outcome, fill = log10P)) +
    geom_tile(color = "white") +
    geom_text(aes(label = Label), size = 3.1, lineheight = 1.1) +
    scale_fill_gradient2(low = "blue", high = "orange", midpoint = -log10(0.05), name = "-log10(P)") +
    theme_minimal() +
    labs(title = "Combined Disease Heatmap", x = "Subgroup", y = "Outcome") +
    theme(axis.text.x = element_text(angle = 45, hjust = 1))

  ggsave(output_file, plot = p, width = plot_width, height = 8, limitsize = FALSE)
}

# Write the comparison results as a table to a PDF (20 in wide, numeric
# columns rounded to 3 decimals). Rows are split across pages of at most
# `rows_per_page`. The page height is 0.25 in per row on the fullest page,
# clamped to 8-50 in; with the default, every row still fits within the
# 50-inch cap that previously clipped large tables. The graphics device is
# always closed, even if drawing fails. Empty input writes nothing and shows
# a message.
generate_summary_table_pdf <- function(results_df, output_file = "hybrid_results_summary_table.pdf",
                                       rows_per_page = 200) {
  stopifnot(rows_per_page >= 1)
  if (is.null(results_df) || nrow(results_df) == 0) {
    message("No results to tabulate; ", output_file, " not written.")
    return(invisible(NULL))
  }

  df_display <- results_df %>%
    select(Outcome, Subgroup, Level, Category, Control_n, Test_n,
           Test_Type, Median_NonMM, Median_MM, IQR_NonMM, IQR_MM,
           CI_Lower, CI_Upper, Difference, P_Value, Adjusted_P, P_Stars) %>%
    mutate(across(where(is.numeric), ~round(., 3)))

  n_rows <- nrow(df_display)
  page_height <- min(50, max(8, 0.25 * min(n_rows, rows_per_page)))
  page_starts <- seq(1L, n_rows, by = rows_per_page)

  grDevices::pdf(output_file, width = 20, height = page_height)
  on.exit(grDevices::dev.off(), add = TRUE)  # close the device even on error

  for (first_row in page_starts) {
    page_rows <- first_row:min(first_row + rows_per_page - 1L, n_rows)
    table_grob <- gridExtra::tableGrob(df_display[page_rows, , drop = FALSE], rows = NULL,
                                       theme = gridExtra::ttheme_default(base_size = 8))
    grid::grid.newpage()
    grid::grid.draw(table_grob)
  }

  message("Saved summary table to ", output_file)
}

# --- 1. Define variable groups and outcomes ---

demo_vars <- c("sex", "eth_corrected", "race_corrected", "age_groups")
diseases_vars <- c("AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF",
                   "CHRNLUNG", "COAG", "DEPRESS", "DM", "DMCX", "DRUG",
                   "HTN_C", "HYPOTHY", "LIVER", "LYMPH", "LYTES", "METS",
                   "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH", "PULMCIRC",
                   "RENLFAIL", "TUMOR", "ULCER", "VALVE", "WGHTLOSS", "ARRHYTH")
# NOTE(review): clinical_vars and score_vars (below) list the same two
# columns. compare_subgroups() checks clinical_vars first, so both are
# reported under "Clinical".
clinical_vars <- c("comorbidity_score", "van_index")
# Outcomes whose values are all 0/1 are tested with Fisher's exact test;
# all others with Mann-Whitney. Which test "alive" gets depends on its coding.
outcomes <- c("alive", "los_days", "tot_charge")

# Optional: custom group labels (control first, test second)
group_labels <- c("Non-Multimorbidity", "Multimorbidity")

# --- 2. Run comparison analysis ---
# (also writes hybrid_group_comparison_results.csv when any comparison runs)
hybrid_results <- compare_subgroups(
  CG_MW, TG_MW,
  outcomes = outcomes,
  demo_vars = demo_vars,
  diseases_vars = diseases_vars,
  clinical_vars = clinical_vars,
  score_vars = c("comorbidity_score", "van_index"),
  group_label = group_labels
)

# --- 3. Generate per-subgroup heatmaps ---
create_subgroup_heatmaps(hybrid_results)

# --- 4. Generate combined disease heatmap ---
create_combined_disease_heatmap(hybrid_results, diseases_vars)

# --- 5. Generate summary table PDF ---
generate_summary_table_pdf(hybrid_results)

```


```{r}
# Copies of the control/test Mann-Whitney data for the XY line-plot analysis below
CGA_MW <- CGA_MWU
TGA_MW <- TGA_MWU

# ===============================
# Full Revised Subgroup Analysis Script (XY Line Plots Only, Discarding Extreme Outliers)
# ===============================

# Load required libraries
library(dplyr)
library(ggplot2)
library(grid)
library(gridExtra)

##########################
# Helper Functions
##########################

# Convert values to numeric safely.
# Also converts Y/N or Yes/No to 1/0.
#
# For character/factor input whose non-missing values are all "Y"/"N" (or all
# "Yes"/"No"), returns 1/0 with NA preserved. Everything else is parsed as
# numbers, and unparseable entries become NA without a warning. Factors are
# parsed from their labels, not their internal integer codes.
safe_numeric <- function(x) {
  if (is.character(x) || is.factor(x)) {
    x_char <- as.character(x)
    uv <- unique(na.omit(x_char))
    if (all(uv %in% c("Y", "N"))) return(ifelse(x_char == "Y", 1, 0))
    if (all(uv %in% c("Yes", "No"))) return(ifelse(x_char == "Yes", 1, 0))
    # as.numeric(factor(c("10", "20"))) is c(1, 2) (the level codes); parse
    # the labels instead.
    x <- x_char
  }
  suppressWarnings(as.numeric(x))
}

# Remove extreme outliers: values below the 1st percentile or above the 99th percentile are set to NA.
# The cut-offs are the `lower_quantile` / `upper_quantile` sample quantiles of
# the non-missing values; values exactly at a cut-off are kept.
remove_outliers <- function(x, lower_quantile = 0.01, upper_quantile = 0.99) {
  bounds <- quantile(x, probs = c(lower_quantile, upper_quantile), na.rm = TRUE)
  outside <- x < bounds[1] | x > bounds[2]
  replace(x, outside, NA)
}

# TRUE when `x` has at most two distinct non-missing values and every one of
# them is 0 or 1 (numeric or the strings "0"/"1").
is_binary <- function(x) {
  distinct_vals <- unique(na.omit(x))
  if (length(distinct_vals) > 2) {
    return(FALSE)
  }
  all(distinct_vals %in% c(0, 1, "0", "1"))
}

# TRUE for factor/character vectors, and for numeric vectors whose distinct
# non-missing values number at most 10 and are all whole numbers.
# Any other type (e.g. logical) is not considered categorical.
is_categorical <- function(x) {
  if (is.factor(x) || is.character(x)) {
    return(TRUE)
  }
  if (!is.numeric(x)) {
    return(FALSE)
  }
  distinct_vals <- unique(na.omit(x))
  length(distinct_vals) <= 10 && all(distinct_vals == round(distinct_vals))
}

# Classify a vector as "binary" (0/1 only), "categorical" (factor/character
# or few whole-number values), or "continuous" (everything else).
# The binary check takes precedence over the categorical one.
get_variable_type <- function(x) {
  if (is_binary(x)) {
    "binary"
  } else if (is_categorical(x)) {
    "categorical"
  } else {
    "continuous"
  }
}

# Map a single p-value to a significance marker:
# "***" when p < 0.001, "**" when p < 0.01, "*" when p < 0.05, otherwise "".
# A missing p-value also yields "".
add_significance_markers <- function(p_value) {
  if (is.na(p_value)) {
    return("")
  }
  # findInterval() gives 0..3 for the bands [-Inf,.001), [.001,.01),
  # [.01,.05), [.05,Inf]; shift by one to index the marker table.
  markers <- c("***", "**", "*", "")
  markers[findInterval(p_value, c(0.001, 0.01, 0.05)) + 1L]
}

##########################
# Plotting Function (XY Line Plot Only)
##########################

# Draw an XY line plot of the mean `outcome` for each level of
# `subgroup_var`, with one line per cohort, annotate each level with a
# significance marker, save it as a PNG and return the ggplot object.
#
# Arguments:
#   subgroup_var - column whose levels form the x-axis.
#   outcome      - column averaged on the y-axis.
#   CGA_MWU      - control cohort (plotted as "Non-Multimorbidity").
#   TGA_MWU      - test cohort (plotted as "Multimorbidity").
#   output_dir   - directory for the PNG (created if missing).
#
# Notes:
#   * Column names are matched case-insensitively (both data frames are
#     lower-cased before use); a missing column is reported with stop().
#   * Extreme outliers (outside the 1st-99th percentile, per cohort) are
#     removed only for non-binary outcomes: trimming a 0/1 outcome such as
#     "alive" would discard the rare class entirely.
#   * A level is tested only when both cohorts have >= 5 non-missing outcome
#     values. 0/1 outcomes use Fisher's exact test (skipped when a cell is
#     empty), others the Mann-Whitney U test. Markers use the raw,
#     unadjusted per-level p-value.
create_xy_line_plot <- function(subgroup_var, outcome, CGA_MWU, TGA_MWU, output_dir = "./results_xy_lineplots") {
  # Single source for the cohort labels: they are both the data values and
  # the colour-scale keys, so they must match exactly.
  label_non_multi <- "Non-Multimorbidity"
  label_multi <- "Multimorbidity"

  # Match columns case-insensitively.
  names(CGA_MWU) <- tolower(names(CGA_MWU))
  names(TGA_MWU) <- tolower(names(TGA_MWU))
  subgroup_col <- tolower(subgroup_var)
  outcome_col <- tolower(outcome)

  shared_cols <- intersect(names(CGA_MWU), names(TGA_MWU))
  missing_cols <- setdiff(c(subgroup_col, outcome_col), shared_cols)
  if (length(missing_cols) > 0) {
    stop("Column(s) not found in both datasets: ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }

  # Numeric outcome (Y/N, Yes/No and factor labels handled by safe_numeric).
  cg_y <- safe_numeric(CGA_MWU[[outcome_col]])
  tg_y <- safe_numeric(TGA_MWU[[outcome_col]])
  pooled_y <- c(cg_y, tg_y)
  outcome_is_binary <- all(pooled_y[!is.na(pooled_y)] %in% c(0, 1))
  if (!outcome_is_binary) {
    cg_y <- remove_outliers(cg_y)
    tg_y <- remove_outliers(tg_y)
  }

  # Keep only the columns the plot needs and tag each cohort.
  keep_cols <- unique(c(subgroup_col, outcome_col))
  cg_part <- as.data.frame(CGA_MWU[, keep_cols, drop = FALSE])
  cg_part[[outcome_col]] <- cg_y
  cg_part$group <- label_non_multi
  tg_part <- as.data.frame(TGA_MWU[, keep_cols, drop = FALSE])
  tg_part[[outcome_col]] <- tg_y
  tg_part$group <- label_multi
  combined_df <- rbind(cg_part, tg_part)

  # Ensure the subgroup variable is a factor.
  combined_df[[subgroup_col]] <- as.factor(combined_df[[subgroup_col]])

  # Mean outcome per subgroup level and cohort.
  aggregated <- combined_df %>%
    group_by(.data[[subgroup_col]], group) %>%
    summarise(MeanOutcome = mean(.data[[outcome_col]], na.rm = TRUE), .groups = "drop")

  # Per-level significance test between the cohorts.
  subgroup_levels <- levels(combined_df[[subgroup_col]])
  sig_data <- data.frame(
    Level = subgroup_levels,
    p_value = rep(NA_real_, length(subgroup_levels)),
    sig_marker = rep(NA_character_, length(subgroup_levels)),
    stringsAsFactors = FALSE
  )

  for (i in seq_along(subgroup_levels)) {
    # %in% treats NA subgroup values as "not in this level".
    in_level <- combined_df[[subgroup_col]] %in% subgroup_levels[i]
    y_level <- combined_df[[outcome_col]][in_level]
    grp_level <- combined_df$group[in_level]
    y_non_multi <- y_level[grp_level == label_non_multi & !is.na(y_level)]
    y_multi <- y_level[grp_level == label_multi & !is.na(y_level)]

    if (length(y_non_multi) < 5 || length(y_multi) < 5) next

    if (all(c(y_non_multi, y_multi) %in% c(0, 1))) {
      tbl <- matrix(c(
        sum(y_non_multi == 1),
        sum(y_non_multi == 0),
        sum(y_multi == 1),
        sum(y_multi == 0)
      ), nrow = 2)
      p_val <- NA_real_
      if (min(tbl) > 0) {
        test_result <- try(fisher.test(tbl), silent = TRUE)
        if (!inherits(test_result, "try-error")) p_val <- test_result$p.value
      }
    } else {
      test_result <- try(wilcox.test(y_multi, y_non_multi, exact = FALSE, correct = TRUE), silent = TRUE)
      p_val <- if (inherits(test_result, "try-error")) NA_real_ else test_result$p.value
    }
    sig_data$p_value[i] <- p_val
    sig_data$sig_marker[i] <- add_significance_markers(p_val)
  }

  # Annotation height: highest cohort mean of the level plus 5% of the
  # outcome's finite range.
  level_max <- aggregated %>%
    group_by(.data[[subgroup_col]]) %>%
    summarise(
      max_mean = if (all(is.na(MeanOutcome))) NA_real_ else max(MeanOutcome, na.rm = TRUE),
      .groups = "drop"
    )
  sig_data$max_mean <- level_max$max_mean[
    match(sig_data$Level, as.character(level_max[[subgroup_col]]))
  ]
  finite_y <- combined_df[[outcome_col]]
  finite_y <- finite_y[is.finite(finite_y)]
  y_offset <- if (length(finite_y) > 0) 0.05 * diff(range(finite_y)) else 0
  sig_data$annotation_y <- sig_data$max_mean + y_offset

  # Only levels with a visible marker and a valid position are annotated.
  sig_labels <- sig_data[
    !is.na(sig_data$sig_marker) & sig_data$sig_marker != "" & !is.na(sig_data$annotation_y),
    , drop = FALSE
  ]

  p <- ggplot(aggregated, aes(x = .data[[subgroup_col]], y = MeanOutcome, group = group, color = group)) +
    geom_line(size = 1) +
    geom_point(size = 3) +
    # Colors: Multimorbidity = blue; Non-Multimorbidity = orange.
    scale_color_manual(values = setNames(c("blue", "orange"), c(label_multi, label_non_multi))) +
    theme_minimal() +
    labs(
      title = paste("Mean", gsub("_", " ", outcome), "by", subgroup_var),
      x = subgroup_var,
      y = paste("Mean", gsub("_", " ", outcome)),
      color = "group"
    ) +
    theme(
      plot.title = element_text(size = 14, face = "bold"),
      axis.title = element_text(size = 12)
    ) +
    # inherit.aes = FALSE: sig_labels has no `group` column to map.
    geom_text(data = sig_labels, aes(x = Level, y = annotation_y, label = sig_marker),
              vjust = 0, size = 5, color = "black", inherit.aes = FALSE)

  if (!dir.exists(output_dir)) {
    dir.create(output_dir, recursive = TRUE)
  }
  file_name <- file.path(output_dir, paste0(outcome, "_xy_lineplot_", subgroup_var, ".png"))
  ggsave(file_name, p, width = 8, height = 6, dpi = 300)

  p
}

##########################
# Main Analysis Function (XY Plots Only)
##########################

# Compare each outcome between the control cohort (CGA_MWU) and the test
# cohort (TGA_MWU) within every level of each demographic / clinical
# subgroup variable. Results are written to a CSV, and XY line plots are
# written to a multi-page PDF plus individual PNGs.
#
# Arguments:
#   CGA_MWU, TGA_MWU - control and test data frames.
#   outcomes         - outcome column names to compare.
#   output_dir       - directory for all output files (created if missing).
#
# Returns the master results table (one row per outcome x subgroup level).
#
# Method:
#   * Binary subgroup variables are analysed at level 1 only.
#   * A level is tested only when each cohort has >= 5 non-missing outcome
#     values. 0/1 outcomes use Fisher's exact test (left untested when a
#     cell is empty). Other outcomes use the Mann-Whitney U test, with the
#     normal-approximation z-score (no tie correction) and Hodges-Lehmann
#     CI reported.
#   * P-values are BH-adjusted within each outcome; significance markers
#     use the adjusted values.
#   * Missing outcome/subgroup columns and failing plots are reported and
#     skipped; the overview PDF device is always closed.
analyze_subgroups <- function(CGA_MWU, TGA_MWU,
                              outcomes = c("alive", "los_days", "tot_charge"),
                              output_dir = "./results") {
  if (!dir.exists(output_dir)) dir.create(output_dir, recursive = TRUE)

  # Define subgroup variables.
  demo_vars <- c("sex", "eth_corrected", "race_corrected", "age_groups")
  clinical_vars <- c("AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF",
                     "CHRNLUNG", "COAG", "DEPRESS", "DM", "DMCX", "DRUG",
                     "HTN_C", "HYPOTHY", "LIVER", "LYMPH", "LYTES", "METS",
                     "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH", "PULMCIRC",
                     "RENLFAIL", "TUMOR", "ULCER", "VALVE", "WGHTLOSS", "ARRHYTH",
                     "comorbidity_score", "van_index")

  # Empty master results table: fixes the output's column order and types
  # (and is what is returned when no comparison qualifies).
  results_template <- data.frame(
    Outcome = character(),
    Subgroup = character(),
    Level = character(),
    Control_n = integer(),
    Control_mean = numeric(),
    Control_median = numeric(),
    Test_n = integer(),
    Test_mean = numeric(),
    Test_median = numeric(),
    Difference = numeric(),
    Test_Type = character(),
    Statistic = numeric(),
    P_Value = numeric(),
    Adjusted_P_Value = numeric(),
    Z_Score = numeric(),
    CI_Lower = numeric(),
    CI_Upper = numeric(),
    Significance = character(),
    stringsAsFactors = FALSE
  )

  # Pool one column from both cohorts. Factors are pooled via their labels
  # so c() never degrades them to integer codes (pre-R 4.1 behaviour).
  pool_values <- function(cg_col, tg_col) {
    if (is.factor(cg_col) || is.factor(tg_col)) {
      c(as.character(cg_col), as.character(tg_col))
    } else {
      c(cg_col, tg_col)
    }
  }

  # Test NA-free control values `cg_y` against NA-free test values `tg_y`.
  # Returns a list holding the test columns of one results row.
  compare_groups <- function(cg_y, tg_y) {
    out <- list(Test_Type = "Mann-Whitney", Statistic = NA_real_, P_Value = NA_real_,
                Z_Score = NA_real_, CI_Lower = NA_real_, CI_Upper = NA_real_)

    if (all(c(cg_y, tg_y) %in% c(0, 1))) {
      out$Test_Type <- "Fisher's Exact"
      tbl <- matrix(c(
        sum(cg_y == 1),
        sum(cg_y == 0),
        sum(tg_y == 1),
        sum(tg_y == 0)
      ), nrow = 2)
      if (min(tbl) > 0) {
        test_result <- try(fisher.test(tbl), silent = TRUE)
        if (!inherits(test_result, "try-error")) out$P_Value <- test_result$p.value
      }
      return(out)
    }

    test_result <- try(
      wilcox.test(tg_y, cg_y, exact = FALSE, correct = TRUE, conf.int = TRUE),
      silent = TRUE
    )
    if (inherits(test_result, "try-error")) return(out)

    # n1/n2 are the sample sizes wilcox.test actually used (inputs are NA-free).
    n1 <- length(cg_y)
    n2 <- length(tg_y)
    U <- as.numeric(test_result$statistic)
    mu_U <- n1 * n2 / 2
    sigma_U <- sqrt((n1 * n2 * (n1 + n2 + 1)) / 12)
    out$Statistic <- U
    out$P_Value <- test_result$p.value
    out$Z_Score <- (U - mu_U) / sigma_U
    out$CI_Lower <- test_result$conf.int[1]
    out$CI_Upper <- test_result$conf.int[2]
    out
  }

  # Rows are collected in a list and bound once (avoids quadratic rbind).
  result_rows <- list()

  for (outcome in outcomes) {
    if (!(outcome %in% names(CGA_MWU)) || !(outcome %in% names(TGA_MWU))) {
      cat("Warning: Outcome", outcome, "not found in both datasets. Skipping.\n")
      next
    }
    CGA_MWU[[outcome]] <- safe_numeric(CGA_MWU[[outcome]])
    TGA_MWU[[outcome]] <- safe_numeric(TGA_MWU[[outcome]])

    for (var in c(demo_vars, clinical_vars)) {
      if (!(var %in% names(CGA_MWU)) || !(var %in% names(TGA_MWU))) {
        cat("Warning: Variable", var, "not found in both datasets. Skipping.\n")
        next
      }

      pooled <- pool_values(CGA_MWU[[var]], TGA_MWU[[var]])
      var_type <- get_variable_type(pooled)
      levels_var <- unique(na.omit(pooled))

      for (lev in levels_var) {
        # For binary subgroup variables, only analyze level == 1.
        if (var_type == "binary" && as.numeric(as.character(lev)) != 1) next

        # Direct indexing (not subset()) so data columns named `var`/`lev`
        # cannot shadow the loop variables; %in% drops NA subgroup rows.
        cg_y <- safe_numeric(CGA_MWU[[outcome]][CGA_MWU[[var]] %in% lev])
        tg_y <- safe_numeric(TGA_MWU[[outcome]][TGA_MWU[[var]] %in% lev])
        cg_y <- cg_y[!is.na(cg_y)]
        tg_y <- tg_y[!is.na(tg_y)]
        if (length(cg_y) < 5 || length(tg_y) < 5) next

        cg_mean <- mean(cg_y)
        tg_mean <- mean(tg_y)
        test_cols <- compare_groups(cg_y, tg_y)

        result_rows[[length(result_rows) + 1L]] <- data.frame(
          Outcome = outcome,
          Subgroup = var,
          Level = as.character(lev),
          Control_n = length(cg_y),
          Control_mean = cg_mean,
          Control_median = median(cg_y),
          Test_n = length(tg_y),
          Test_mean = tg_mean,
          Test_median = median(tg_y),
          Difference = tg_mean - cg_mean,
          Test_Type = test_cols$Test_Type,
          Statistic = test_cols$Statistic,
          P_Value = test_cols$P_Value,
          Adjusted_P_Value = NA_real_,
          Z_Score = test_cols$Z_Score,
          CI_Lower = test_cols$CI_Lower,
          CI_Upper = test_cols$CI_Upper,
          Significance = "",
          stringsAsFactors = FALSE
        )
      }
    }
  }

  all_results <- do.call(rbind, c(list(results_template), result_rows))

  # Adjust p-values (Benjamini-Hochberg) within each outcome.
  if (nrow(all_results) > 0) {
    for (curr_outcome in unique(all_results$Outcome)) {
      valid_p <- all_results$Outcome == curr_outcome & !is.na(all_results$P_Value)
      if (any(valid_p)) {
        all_results$Adjusted_P_Value[valid_p] <- p.adjust(all_results$P_Value[valid_p], method = "BH")
      }
    }
    all_results$Significance <- vapply(all_results$Adjusted_P_Value, add_significance_markers, character(1))
  }

  # ifelse() keeps this valid for a zero-row table (paste0 would not).
  all_results$CI_95 <- ifelse(
    is.na(all_results$CI_Lower) | is.na(all_results$CI_Upper),
    NA_character_,
    paste0("[", round(all_results$CI_Lower, 2), ", ", round(all_results$CI_Upper, 2), "]")
  )

  # Save the master results table.
  write.csv(all_results, file.path(output_dir, "all_subgroup_analyses.csv"), row.names = FALSE)

  #############################
  # Create XY Line Plot Overviews with Significance Annotations
  #############################
  plot_vars <- c(demo_vars, "comorbidity_score", "van_index")
  pdf(file.path(output_dir, "xy_lineplot_overview.pdf"), width = 10, height = 8)
  overview_device <- grDevices::dev.cur()
  tryCatch({
    for (curr_outcome in outcomes) {
      for (var in plot_vars) {
        # create_xy_line_plot() matches names case-insensitively.
        needed <- tolower(c(var, curr_outcome))
        if (!all(needed %in% tolower(names(CGA_MWU))) ||
            !all(needed %in% tolower(names(TGA_MWU)))) {
          cat("Warning: Variable", var, "or outcome", curr_outcome,
              "not found in both datasets. Skipping plot.\n")
          next
        }
        tryCatch({
          p <- create_xy_line_plot(var, curr_outcome, CGA_MWU, TGA_MWU, output_dir)
          print(p)
        }, error = function(e) {
          cat("Warning: XY plot failed for", var, "/", curr_outcome, ":",
              conditionMessage(e), "\n")
        })
      }
    }
  }, finally = grDevices::dev.off(overview_device))

  all_results
}

#############################
# Usage:
#############################

# Run the full subgroup analysis on the Mann-Whitney inputs; the master
# results table is kept in `result`, and files are written under ./results.
result <- analyze_subgroups(
  CGA_MWU,
  TGA_MWU,
  outcomes = c("alive", "los_days", "tot_charge"),
  output_dir = "./results"
)

```

### OMIT: Mann-Whitney U Heat Map of Significant Results
```{r}
# Keep *_MW aliases pointing at the Mann-Whitney input data frames.
CGA_MW <- CGA_MWU
TGA_MW <- TGA_MWU

# Load necessary libraries
library(dplyr)
library(ggplot2)
library(reshape2)
library(viridis)
library(gridExtra)

# Comprehensive subgroup analysis function with focus on patients with "1" values
analyze_subgroups = function(CGA_MWU, TGA_MWU, outcomes = c("alive", "los_days", "tot_charge", "los_groups", "los_days_hilo", "tot_charge_hilo")) {
  # Demographic variables
  demo_vars = c("race", "ethnicity", "sex", "age", "eth_corrected", "race_corrected", "age_groups", "age_hilo")

  # Age-specific variables
  age_vars = c("age", "age_groups", "age_hilo")

  # Comorbidity and clinical variables
  clinical_vars = c("AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF",
                    "CHRNLUNG", "COAG", "DEPRESS", "DM", "DMCX", "DRUG",
                    "HTN_C", "HYPOTHY", "LIVER", "LYMPH", "LYTES", "METS",
                    "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH", "PULMCIRC",
                    "RENLFAIL", "TUMOR", "ULCER", "VALVE", "WGHTLOSS", "ARRHYTH")

  # Score variables
  score_vars = c("comorbidity_score", "van_index", "como_score_hilo", "van_index_hilo", "van_groups")

  # Master results table
  all_results = data.frame(
    Outcome = character(),
    Subgroup = character(),
    Level = character(),
    Category = character(),  # Add category for grouping
    Control_n = numeric(),
    Test_n = numeric(),
    Test_Type = character(),
    Statistic = numeric(),
    P_Value = numeric(),
    Z_Score = numeric(),
    CI_Lower = numeric(),
    CI_Upper = numeric(),
    stringsAsFactors = FALSE
  )

  # Helper function to convert values to numeric safely
  safe_numeric = function(x) {
    # First try direct conversion
    result = suppressWarnings(as.numeric(x))

    # Check for failed conversions
    if(all(is.na(result)) && !all(is.na(x))) {
      # Try to handle factors or strings
      if(is.factor(x) || is.character(x)) {
        # Check if we have values like "0", "1" that can be converted
        return(suppressWarnings(as.numeric(as.character(x))))
      }
    }

    return(result)
  }

  # Function to check if a variable is binary
  is_binary = function(x) {
    length(unique(na.omit(x))) <= 2
  }

  # Function to assign variable category
  assign_category = function(var) {
    if(var %in% age_vars) {
      return("Age")
    } else if(var %in% demo_vars) {
      return("Demographics")
    } else if(var %in% clinical_vars) {
      return("Clinical")
    } else if(var %in% score_vars) {
      return("Scores")
    } else {
      return("Other")
    }
  }

  # Process each outcome
  for(outcome in outcomes) {
    cat("\nAnalyzing outcome:", outcome, "\n")

    # Ensure outcome is numeric for appropriate tests
    CGA_MWU[[outcome]] = safe_numeric(CGA_MWU[[outcome]])
    TGA_MWU[[outcome]] = safe_numeric(TGA_MWU[[outcome]])

    # Process each variable
    all_vars = unique(c(demo_vars, clinical_vars, score_vars))
    for(var in all_vars) {
      cat("  Processing variable:", var, "\n")

      # Skip if variable doesn't exist in both datasets
      if(!(var %in% names(CGA_MWU)) || !(var %in% names(TGA_MWU))) {
        cat("    Warning: Variable", var, "not found in both datasets. Skipping.\n")
        next
      }

      # Assign category
      category = assign_category(var)

      # Ensure variable is in right format
      tryCatch({
        # Check if variable appears to be binary but is stored as character/factor
        if(is.character(CGA_MWU[[var]]) || is.factor(CGA_MWU[[var]])) {
          unique_vals = unique(na.omit(c(as.character(CGA_MWU[[var]]), as.character(TGA_MWU[[var]]))))
          if(all(unique_vals %in% c("0", "1"))) {
            # Convert to numeric
            CGA_MWU[[var]] = as.numeric(as.character(CGA_MWU[[var]]))
            TGA_MWU[[var]] = as.numeric(as.character(TGA_MWU[[var]]))
          }
        }

        # Determine variable type
        var_type = "unknown"
        if(is.factor(CGA_MWU[[var]]) || is.character(CGA_MWU[[var]])) {
          var_type = "categorical"
        } else if(is.numeric(CGA_MWU[[var]])) {
          unique_vals = unique(na.omit(c(CGA_MWU[[var]], TGA_MWU[[var]])))
          if(length(unique_vals) <= 2 && all(unique_vals %in% c(0,1))) {
            var_type = "binary"
          } else if(length(unique_vals) <= 10) {
            var_type = "discrete"
          } else {
            var_type = "continuous"
          }
        }

        # Handle according to variable type
        if(var_type == "binary") {
          # For binary variables, ONLY analyze patients with value = 1
          cg_subset = subset(CGA_MWU, CGA_MWU[[var]] == 1)
          tg_subset = subset(TGA_MWU, TGA_MWU[[var]] == 1)

          # Skip if either subset is too small
          if(nrow(cg_subset) < 5 || nrow(tg_subset) < 5) {
            cat("    Skipping", var, "- insufficient data for value 1\n")
            next
          }

          cg_outcome = safe_numeric(cg_subset[[outcome]])
          tg_outcome = safe_numeric(tg_subset[[outcome]])

          # Only analyze if we have sufficient data
          if(length(cg_outcome) >= 5 && length(tg_outcome) >= 5 &&
             !all(is.na(cg_outcome)) && !all(is.na(tg_outcome))) {

            # Determine if outcome is binary for this subgroup
            is_binary_outcome = all(na.omit(unique(c(cg_outcome, tg_outcome))) %in% c(0,1))

            if(is_binary_outcome) {
              # For binary outcomes like "alive", use Fisher's exact test
              tbl = matrix(c(
                sum(cg_outcome == 1, na.rm = TRUE),
                sum(cg_outcome == 0, na.rm = TRUE),
                sum(tg_outcome == 1, na.rm = TRUE),
                sum(tg_outcome == 0, na.rm = TRUE)
              ), nrow = 2)

              if(min(tbl) > 0) {  # Ensure no zero cells
                test_result = fisher.test(tbl)

                # Calculate proportions for CI
                prop_cg = sum(cg_outcome == 1, na.rm = TRUE) / length(na.omit(cg_outcome))
                prop_tg = sum(tg_outcome == 1, na.rm = TRUE) / length(na.omit(tg_outcome))

                # Calculate CI for difference in proportions
                n_cg = length(na.omit(cg_outcome))
                n_tg = length(na.omit(tg_outcome))

                # Standard error for difference in proportions
                se_diff = sqrt(prop_cg*(1-prop_cg)/n_cg + prop_tg*(1-prop_tg)/n_tg)

                # 95% CI for difference in proportions
                diff = prop_tg - prop_cg
                ci_lower = diff - 1.96*se_diff
                ci_upper = diff + 1.96*se_diff

                # Add to results
                all_results = rbind(all_results, data.frame(
                  Outcome = outcome,
                  Subgroup = var,
                  Level = "Present", # Indicating analysis of patients with condition present (value=1)
                  Category = category,
                  Control_n = length(na.omit(cg_outcome)),
                  Test_n = length(na.omit(tg_outcome)),
                  Test_Type = "Fisher's Exact",
                  Statistic = NA,  # Fisher's test doesn't have a standard test statistic
                  P_Value = test_result$p.value,
                  Z_Score = NA,
                  CI_Lower = ci_lower,
                  CI_Upper = ci_upper,
                  Difference = diff*100,
                  stringsAsFactors = FALSE
                ))
              }
            } else {
              # For continuous outcomes, use Mann-Whitney for patients with value = 1
              tryCatch({
                # Remove any NAs
                cg_outcome = na.omit(cg_outcome)
                tg_outcome = na.omit(tg_outcome)

                # Mann-Whitney U test with confidence interval
                mw_test = wilcox.test(tg_outcome, cg_outcome,
                                     exact = FALSE, correct = TRUE,
                                     conf.int = TRUE)

                # Calculate Z score (approximation)
                n1 = length(cg_outcome)
                n2 = length(tg_outcome)
                U = as.numeric(mw_test$statistic)

                # Z-score formula for Mann-Whitney U test
                mu_U = n1 * n2 / 2
                sigma_U = sqrt((n1 * n2 * (n1 + n2 + 1)) / 12)
                z_score = (U - mu_U) / sigma_U

                # Get confidence interval
                ci_lower = mw_test$conf.int[1]
                ci_upper = mw_test$conf.int[2]

                # Calculate difference in medians
                diff = median(tg_outcome, na.rm = TRUE) - median(cg_outcome, na.rm = TRUE)

                # Add to results
                all_results = rbind(all_results, data.frame(
                  Outcome = outcome,
                  Subgroup = var,
                  Level = "Present", # Indicating analysis of patients with condition present (value=1)
                  Category = category,
                  Control_n = n1,
                  Test_n = n2,
                  Test_Type = "Mann-Whitney",
                  Statistic = U,
                  P_Value = mw_test$p.value,
                  Z_Score = z_score,
                  CI_Lower = ci_lower,
                  CI_Upper = ci_upper,
                  Difference = diff,
                  stringsAsFactors = FALSE
                ))
              }, error = function(e) {
                cat("    Error in Mann-Whitney for", var, ":", e$message, "\n")
              })
            }
          }
        } else if(var_type %in% c("categorical", "discrete")) {
          # For categorical variables, analyze each level
          all_levels = unique(na.omit(c(CGA_MWU[[var]], TGA_MWU[[var]])))

          # Analyze each level
          for(level in all_levels) {
            # For categorical variables using exact matching
            cg_subset = subset(CGA_MWU, CGA_MWU[[var]] == level)
            tg_subset = subset(TGA_MWU, TGA_MWU[[var]] == level)

            cg_outcome = safe_numeric(cg_subset[[outcome]])
            tg_outcome = safe_numeric(tg_subset[[outcome]])

            # Only analyze if we have sufficient data
            if(length(cg_outcome) >= 5 && length(tg_outcome) >= 5 &&
               !all(is.na(cg_outcome)) && !all(is.na(tg_outcome))) {

              # Determine if outcome is binary for this subgroup
              is_binary_outcome = all(na.omit(unique(c(cg_outcome, tg_outcome))) %in% c(0,1))

              if(is_binary_outcome) {
                # For binary outcomes like "alive", use Fisher's exact test
                tbl = matrix(c(
                  sum(cg_outcome == 1, na.rm = TRUE),
                  sum(cg_outcome == 0, na.rm = TRUE),
                  sum(tg_outcome == 1, na.rm = TRUE),
                  sum(tg_outcome == 0, na.rm = TRUE)
                ), nrow = 2)

                if(min(tbl) > 0) {  # Ensure no zero cells
                  test_result = fisher.test(tbl)

                  # Calculate proportions for CI
                  prop_cg = sum(cg_outcome == 1, na.rm = TRUE) / length(na.omit(cg_outcome))
                  prop_tg = sum(tg_outcome == 1, na.rm = TRUE) / length(na.omit(tg_outcome))

                  # Calculate CI for difference in proportions
                  n_cg = length(na.omit(cg_outcome))
                  n_tg = length(na.omit(tg_outcome))

                  # Standard error for difference in proportions
                  se_diff = sqrt(prop_cg*(1-prop_cg)/n_cg + prop_tg*(1-prop_tg)/n_tg)

                  # 95% CI for difference in proportions
                  diff = prop_tg - prop_cg
                  ci_lower = diff - 1.96*se_diff
                  ci_upper = diff + 1.96*se_diff

                  # Add to results
                  all_results = rbind(all_results, data.frame(
                    Outcome = outcome,
                    Subgroup = var,
                    Level = as.character(level),
                    Category = category,
                    Control_n = length(na.omit(cg_outcome)),
                    Test_n = length(na.omit(tg_outcome)),
                    Test_Type = "Fisher's Exact",
                    Statistic = NA,
                    P_Value = test_result$p.value,
                    Z_Score = NA,
                    CI_Lower = ci_lower,
                    CI_Upper = ci_upper,
                    Difference = diff*100,
                    stringsAsFactors = FALSE
                  ))
                }
              } else {
                # For continuous outcomes, use Mann-Whitney
                tryCatch({
                  # Remove any NAs
                  cg_outcome = na.omit(cg_outcome)
                  tg_outcome = na.omit(tg_outcome)

                  # Mann-Whitney U test with confidence interval
                  mw_test = wilcox.test(tg_outcome, cg_outcome,
                                       exact = FALSE, correct = TRUE,
                                       conf.int = TRUE)

                  # Calculate Z score (approximation)
                  n1 = length(cg_outcome)
                  n2 = length(tg_outcome)
                  U = as.numeric(mw_test$statistic)

                  # Z-score formula for Mann-Whitney U test
                  mu_U = n1 * n2 / 2
                  sigma_U = sqrt((n1 * n2 * (n1 + n2 + 1)) / 12)
                  z_score = (U - mu_U) / sigma_U

                  # Get confidence interval
                  ci_lower = mw_test$conf.int[1]
                  ci_upper = mw_test$conf.int[2]

                  # Calculate difference in medians
                  diff = median(tg_outcome, na.rm = TRUE) - median(cg_outcome, na.rm = TRUE)

                  # Add to results
                  all_results = rbind(all_results, data.frame(
                    Outcome = outcome,
                    Subgroup = var,
                    Level = as.character(level),
                    Category = category,
                    Control_n = n1,
                    Test_n = n2,
                    Test_Type = "Mann-Whitney",
                    Statistic = U,
                    P_Value = mw_test$p.value,
                    Z_Score = z_score,
                    CI_Lower = ci_lower,
                    CI_Upper = ci_upper,
                    Difference = diff,
                    stringsAsFactors = FALSE
                  ))
                }, error = function(e) {
                  cat("    Error in Mann-Whitney for", var, "level", level, ":", e$message, "\n")
                })
              }
            }
          }
        } else if(var_type == "continuous") {
          # For continuous predictors, create quartiles
          all_vals = c(CGA_MWU[[var]], TGA_MWU[[var]])

          # Skip if all values are NA
          if(all(is.na(all_vals))) {
            cat("    Warning: All values are NA for", var, ". Skipping.\n")
            next
          }

          quart = suppressWarnings(quantile(all_vals, probs = c(0.25, 0.5, 0.75), na.rm = TRUE))

          # Create quartile labels
          quartile_labels = c(
            paste("Q1 (<", round(quart[1], 2), ")", sep=""),
            paste("Q2 (", round(quart[1], 2), "-", round(quart[2], 2), ")", sep=""),
            paste("Q3 (", round(quart[2], 2), "-", round(quart[3], 2), ")", sep=""),
            paste("Q4 (>", round(quart[3], 2), ")", sep="")
          )

          # Analyze by quartiles
          for(i in 1:4) {
            if(i == 1) {
              min_val = min(all_vals, na.rm = TRUE)
              max_val = quart[1]
            } else if(i == 2) {
              min_val = quart[1]
              max_val = quart[2]
            } else if(i == 3) {
              min_val = quart[2]
              max_val = quart[3]
            } else {
              min_val = quart[3]
              max_val = max(all_vals, na.rm = TRUE)
            }

            # Use inclusive ranges to avoid missing data points
            if(i == 1) {
              cg_subset = subset(CGA_MWU, CGA_MWU[[var]] <= max_val)
              tg_subset = subset(TGA_MWU, TGA_MWU[[var]] <= max_val)
            } else if(i == 4) {
              cg_subset = subset(CGA_MWU, CGA_MWU[[var]] > min_val)
              tg_subset = subset(TGA_MWU, TGA_MWU[[var]] > min_val)
            } else {
              cg_subset = subset(CGA_MWU, CGA_MWU[[var]] > min_val & CGA_MWU[[var]] <= max_val)
              tg_subset = subset(TGA_MWU, TGA_MWU[[var]] > min_val & TGA_MWU[[var]] <= max_val)
            }

            cg_outcome = safe_numeric(cg_subset[[outcome]])
            tg_outcome = safe_numeric(tg_subset[[outcome]])

            # Only analyze if we have sufficient data
            if(length(cg_outcome) >= 5 && length(tg_outcome) >= 5 &&
               !all(is.na(cg_outcome)) && !all(is.na(tg_outcome))) {

              # Determine if outcome is binary for this subgroup
              is_binary_outcome = all(na.omit(unique(c(cg_outcome, tg_outcome))) %in% c(0,1))

              if(is_binary_outcome) {
                # Binary outcome analysis (Fisher's Exact)
                tbl = matrix(c(
                  sum(cg_outcome == 1, na.rm = TRUE),
                  sum(cg_outcome == 0, na.rm = TRUE),
                  sum(tg_outcome == 1, na.rm = TRUE),
                  sum(tg_outcome == 0, na.rm = TRUE)
                ), nrow = 2)

                if(min(tbl) > 0) {  # Ensure no zero cells
                  test_result = fisher.test(tbl)

                  # Calculate proportions for CI
                  prop_cg = sum(cg_outcome == 1, na.rm = TRUE) / length(na.omit(cg_outcome))
                  prop_tg = sum(tg_outcome == 1, na.rm = TRUE) / length(na.omit(tg_outcome))

                  # Calculate CI for difference in proportions
                  n_cg = length(na.omit(cg_outcome))
                  n_tg = length(na.omit(tg_outcome))

                  # Standard error for difference in proportions
                  se_diff = sqrt(prop_cg*(1-prop_cg)/n_cg + prop_tg*(1-prop_tg)/n_tg)

                  # 95% CI for difference in proportions
                  diff = prop_tg - prop_cg
                  ci_lower = diff - 1.96*se_diff
                  ci_upper = diff + 1.96*se_diff

                  all_results = rbind(all_results, data.frame(
                    Outcome = outcome,
                    Subgroup = var,
                    Level = quartile_labels[i],
                    Category = category,
                    Control_n = length(na.omit(cg_outcome)),
                    Test_n = length(na.omit(tg_outcome)),
                    Test_Type = "Fisher's Exact",
                    Statistic = NA,
                    P_Value = test_result$p.value,
                    Z_Score = NA,
                    CI_Lower = ci_lower,
                    CI_Upper = ci_upper,
                    Difference = diff*100,
                    stringsAsFactors = FALSE
                  ))
                }
              } else {
                # Continuous outcome analysis (Mann-Whitney)
                tryCatch({
                  # Mann-Whitney U test
                  mw_test = wilcox.test(tg_outcome, cg_outcome,
                                       exact = FALSE, correct = TRUE,
                                       conf.int = TRUE)

                  # Calculate Z score
                  n1 = length(cg_outcome)
                  n2 = length(tg_outcome)
                  U = as.numeric(mw_test$statistic)

                  mu_U = n1 * n2 / 2
                  sigma_U = sqrt((n1 * n2 * (n1 + n2 + 1)) / 12)
                  z_score = (U - mu_U) / sigma_U

                  # Get confidence interval
                  ci_lower = mw_test$conf.int[1]
                  ci_upper = mw_test$conf.int[2]

                  # Calculate difference in medians
                  diff = median(tg_outcome, na.rm = TRUE) - median(cg_outcome, na.rm = TRUE)

                  all_results = rbind(all_results, data.frame(
                    Outcome = outcome,
                    Subgroup = var,
                    Level = quartile_labels[i],
                    Category = category,
                    Control_n = n1,
                    Test_n = n2,
                    Test_Type = "Mann-Whitney",
                    Statistic = U,
                    P_Value = mw_test$p.value,
                    Z_Score = z_score,
                    CI_Lower = ci_lower,
                    CI_Upper = ci_upper,
                    Difference = diff,
                    stringsAsFactors = FALSE
                  ))
                }, error = function(e) {
                  cat("    Error in Mann-Whitney for", var, "quartile", i, ":", e$message, "\n")
                })
              }
            }
          }
        }
      }, error = function(e) {
        cat("    Error processing variable", var, ":", e$message, "\n")
      })
    }
  }

  # Apply Benjamini-Hochberg correction for multiple testing
  if(nrow(all_results) > 0) {
    valid_indices = !is.na(all_results$P_Value)
    all_results$Adjusted_P_Value = NA
    all_results$Adjusted_P_Value[valid_indices] = p.adjust(all_results$P_Value[valid_indices], method = "BH")
  }

  # Sort by adjusted p-value (or original p-value if no adjustment)
  if("Adjusted_P_Value" %in% names(all_results)) {
    all_results = all_results[order(all_results$Adjusted_P_Value), ]
  } else {
    all_results = all_results[order(all_results$P_Value), ]
  }

  # Add formatted CI column
  all_results$CI_95 = paste0("[",
                           round(all_results$CI_Lower, 2), ", ",
                           round(all_results$CI_Upper, 2), "]")
  all_results$CI_95[is.na(all_results$CI_Lower) | is.na(all_results$CI_Upper)] = NA

  return(all_results)
}

# Create category-specific heatmap
#
# Draws two tile maps for the rows of `results` whose Category equals
# `category`:
#   * p_value_map   - fill is -log10(p); label is the p-value with stars;
#   * direction_map - which group has the larger outcome, with sample sizes.
# Rows are the results significant at `p_threshold` (adjusted p-values when
# an Adjusted_P_Value column exists). When none are significant, the rows with
# the smallest unadjusted P_Value are shown instead.
#
# Args:
#   results:     data frame of subgroup results (uses Outcome, Subgroup, Level,
#                Category, P_Value, Z_Score, Control_n, Test_n and, when
#                present, Adjusted_P_Value and Difference).
#   category:    value of results$Category to plot.
#   p_threshold: significance cut-off used to select rows.
#   output_file: optional PDF path (one page per map). A PNG with both maps
#                stacked is written alongside it.
#   max_items:   maximum number of rows plotted.
# Returns:
#   list(p_value_map, direction_map), or NULL when the category has no rows.
create_category_heatmap <- function(results, category, p_threshold = 0.05,
                                    output_file = NULL, max_items = 20) {
  # which() drops rows whose Category is NA instead of producing all-NA rows.
  cat_results <- results[which(results$Category == category), ]

  if ("Adjusted_P_Value" %in% names(cat_results)) {
    p_value_col <- "Adjusted_P_Value"
    p_label <- "Adjusted P-Value"
  } else {
    p_value_col <- "P_Value"
    p_label <- "P-Value"
  }
  sig_results <- cat_results[which(cat_results[[p_value_col]] < p_threshold), ]

  if (nrow(sig_results) == 0) {
    cat("No significant differences found for", category, "(p <", p_threshold, ").\n")
    # Fall back to the best rows by unadjusted p-value so the category is
    # still shown.
    sig_results <- cat_results[order(cat_results$P_Value), ]
    p_value_col <- "P_Value"
    p_label <- "P-Value"
  } else {
    sig_results <- sig_results[order(sig_results[[p_value_col]]), ]
  }
  # Keep the top rows for readability.
  sig_results <- head(sig_results, max_items)

  if (nrow(sig_results) == 0) {
    cat("No results available for category:", category, "\n")
    return(NULL)
  }

  # Readable row labels, truncated to 40 characters.
  labels <- paste(sig_results$Subgroup, sig_results$Level, sep = ": ")
  too_long <- nchar(labels) > 40
  labels[too_long] <- paste0(substr(labels[too_long], 1, 38), "...")
  sig_results$GroupLabel <- labels

  # Direction of effect. analyze_subgroups() runs
  # wilcox.test(tg_outcome, cg_outcome), so a POSITIVE Z_Score means the
  # multimorbidity (test) group ranks higher. Fisher rows carry no Z_Score.
  # For those rows, the sign of Difference (test minus control) is used when
  # that column exists.
  effect <- sig_results$Z_Score
  if ("Difference" %in% names(sig_results)) {
    missing_z <- is.na(effect)
    effect[missing_z] <- sig_results$Difference[missing_z]
  }
  sig_results$Direction <- ifelse(
    is.na(effect) | effect == 0, "Unknown",
    ifelse(effect > 0, "Multimorbidity > NonMultimorbidity",
           "NonMultimorbidity > Multimorbidity")
  )

  # Significance stars and display text.
  p_vals <- sig_results[[p_value_col]]
  stars <- ifelse(is.na(p_vals), "",
           ifelse(p_vals < 0.001, "***",
           ifelse(p_vals < 0.01, "**",
           ifelse(p_vals < 0.05, "*", ""))))
  sig_results$Stars <- stars
  sig_results$P_Display <- ifelse(!is.na(p_vals) & p_vals < 0.001,
                                  paste0("<0.001 ", stars),
                                  sprintf("%.3f %s", p_vals, stars))
  sig_results$Plot_P <- p_vals
  # Clamp so a p-value that underflowed to 0 does not map to an infinite fill.
  sig_results$Neg_Log10_P <- -log10(pmax(p_vals, .Machine$double.xmin))

  # P-value heatmap with an orange-blue colour scheme.
  p_value_map <- ggplot(sig_results,
                        aes(x = Outcome, y = reorder(GroupLabel, -Plot_P))) +
    geom_tile(aes(fill = Neg_Log10_P), color = "white", width = 0.9, height = 0.9) +
    geom_text(aes(label = P_Display), size = 3.5) +
    scale_fill_gradient2(name = "-log10(p-value)",
                         low = "blue", mid = "white", high = "orange",
                         midpoint = -log10(0.05)) +
    theme_minimal() +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1, size = 11, face = "bold"),
      axis.text.y = element_text(size = 10),
      axis.title = element_text(face = "bold", size = 12),
      plot.title = element_text(hjust = 0.5, face = "bold", size = 14),
      plot.subtitle = element_text(hjust = 0.5, size = 11),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      legend.position = "right",
      plot.margin = margin(20, 20, 20, 20)
    ) +
    labs(title = paste0(category, " Variables: Differences Between Groups"),
         subtitle = paste0(p_label, " with significance (* p<0.05, ** p<0.01, *** p<0.001)"),
         x = "Outcome Measure", y = "Patient Subgroup")

  # Direction-of-effect map with the matching blue/orange scheme.
  direction_map <- ggplot(sig_results,
                          aes(x = Outcome, y = reorder(GroupLabel, -Plot_P))) +
    geom_tile(aes(fill = Direction), color = "white", width = 0.9, height = 0.9) +
    geom_text(aes(label = sprintf("n: %d vs %d", Control_n, Test_n)), size = 3.2) +
    scale_fill_manual(values = c(
      "NonMultimorbidity > Multimorbidity" = "blue",
      "Multimorbidity > NonMultimorbidity" = "orange",
      "Unknown" = "#AAAAAA"
    )) +
    theme_minimal() +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1, size = 11, face = "bold"),
      axis.text.y = element_text(size = 10),
      axis.title = element_text(face = "bold", size = 12),
      plot.title = element_text(hjust = 0.5, face = "bold", size = 14),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      plot.margin = margin(20, 20, 20, 20)
    ) +
    labs(title = paste0(category, ": Direction & Sample Sizes"),
         x = "Outcome Measure", y = "Patient Subgroup")

  if (!is.null(output_file)) {
    .save_category_heatmaps(p_value_map, direction_map, output_file,
                            n_rows = nrow(sig_results))
  }

  list(p_value_map = p_value_map, direction_map = direction_map)
}

# Write both maps to `output_file` (PDF, one page each) and to a PNG with the
# two maps stacked. Each graphics device is closed even if drawing fails.
.save_category_heatmaps <- function(p_value_map, direction_map, output_file,
                                    n_rows) {
  pdf(output_file, width = 10, height = max(6, n_rows * 0.4))
  tryCatch({
    print(p_value_map)
    print(direction_map)
  }, finally = dev.off())

  png_file <- sub("\\.pdf$", ".png", output_file)
  if (identical(png_file, output_file)) {
    # No ".pdf" extension: append ".png" instead of overwriting the PDF.
    png_file <- paste0(output_file, ".png")
  }
  png(png_file, width = 1200, height = max(800, n_rows * 60), res = 100)
  tryCatch(
    gridExtra::grid.arrange(p_value_map, direction_map, ncol = 1,
                            heights = c(1, 0.8)),
    finally = dev.off()
  )
  invisible(png_file)
}

# Create all category heatmaps
#
# First pass: for every Category in `results`, writes
#   category_heatmaps/<category>_heatmap.pdf (plus its PNG).
# Second pass: for every Outcome x Category pair that has rows, writes
#   category_heatmaps/<outcome>/<category>_heatmap.pdf (plus its PNG).
# File stems are lower-cased with spaces replaced by underscores.
# Returns a list, named by category, of the plot pairs from the first pass.
create_all_category_heatmaps <- function(results, p_threshold = 0.05, max_items = 15) {
  heatmap_root <- "category_heatmaps"
  as_file_stem <- function(label) tolower(gsub(" ", "_", label))
  category_set <- unique(results$Category)

  dir.create(heatmap_root, showWarnings = FALSE)

  # Pass 1: one heatmap per category over all outcomes.
  plots_by_category <- list()
  for (this_category in category_set) {
    cat("Creating heatmap for category:", this_category, "\n")
    pdf_path <- file.path(heatmap_root,
                          paste0(as_file_stem(this_category), "_heatmap.pdf"))
    plot_pair <- create_category_heatmap(
      results = results,
      category = this_category,
      p_threshold = p_threshold,
      output_file = pdf_path,
      max_items = max_items
    )
    if (!is.null(plot_pair)) {
      plots_by_category[[this_category]] <- plot_pair
    }
  }

  # Pass 2: one heatmap per (outcome, category), each outcome in its own folder.
  for (this_outcome in unique(results$Outcome)) {
    rows_for_outcome <- results[results$Outcome == this_outcome, ]
    if (nrow(rows_for_outcome) == 0) next

    outcome_folder <- file.path(heatmap_root, as_file_stem(this_outcome))
    dir.create(outcome_folder, showWarnings = FALSE)

    for (this_category in category_set) {
      rows_for_pair <- rows_for_outcome[rows_for_outcome$Category == this_category, ]
      if (nrow(rows_for_pair) == 0) next

      cat("Creating heatmap for outcome:", this_outcome, "category:", this_category, "\n")
      create_category_heatmap(
        results = rows_for_pair,
        category = this_category,
        p_threshold = p_threshold,
        output_file = file.path(outcome_folder,
                                paste0(as_file_stem(this_category), "_heatmap.pdf")),
        max_items = max_items
      )
    }
  }

  plots_by_category
}

# Main function to run everything
#
# Runs analyze_subgroups() on the CGA_MWU (control / non-multimorbidity) and
# TGA_MWU (test / multimorbidity) data frames visible from the caller.
# It then writes:
#   * all_subgroup_outcomes_analysis.csv (complete results),
#   * <outcome>_analysis.csv for each outcome,
#   * the category heatmaps via create_all_category_heatmaps(),
# and prints, per outcome, the rows with adjusted p < p_threshold (or the five
# smallest unadjusted p-values when none qualify).
#
# Args:
#   p_threshold: significance cut-off applied to Adjusted_P_Value.
#   max_items:   maximum rows per heatmap.
# Returns:
#   The results data frame from analyze_subgroups().
run_full_analysis <- function(p_threshold = 0.05, max_items = 15) {
  if (!exists("CGA_MWU") || !exists("TGA_MWU")) {
    stop("Error: CGA_MWU and TGA_MWU dataframes must exist in the environment.")
  }

  cat("CGA_MWU dimensions:", nrow(CGA_MWU), "rows,", ncol(CGA_MWU), "columns\n")
  cat("TGA_MWU dimensions:", nrow(TGA_MWU), "rows,", ncol(TGA_MWU), "columns\n")

  cat("Running analysis with focus on 1's...\n")
  results <- analyze_subgroups(CGA_MWU, TGA_MWU)

  cat("Saving results...\n")
  write.csv(results, "all_subgroup_outcomes_analysis.csv", row.names = FALSE)

  cat("Creating category-specific heatmaps...\n")
  all_plots <- create_all_category_heatmaps(results, p_threshold, max_items)

  summary_cols <- intersect(
    c("Subgroup", "Level", "Category", "Control_n", "Test_n", "Test_Type",
      "P_Value", "Adjusted_P_Value", "Z_Score", "CI_95"),
    names(results)
  )

  cat("Creating outcome-specific result files...\n")
  for (outcome in unique(results$Outcome)) {
    outcome_results <- results[results$Outcome == outcome, ]
    write.csv(outcome_results, paste0(outcome, "_analysis.csv"), row.names = FALSE)

    cat("\nSignificant results for", outcome,
        paste0("(adjusted p < ", p_threshold, "):\n"))
    # which() drops rows whose adjusted p-value is NA. A bare logical index
    # would return them as all-NA rows and make nrow() > 0 even when nothing
    # is significant.
    sig_results <- outcome_results[which(outcome_results$Adjusted_P_Value < p_threshold), ]
    if (nrow(sig_results) > 0) {
      print(sig_results[, summary_cols])
    } else {
      cat("No significant results found after adjustment for multiple testing.\n")

      # Show top results even if none are significant.
      cat("\nTop results for", outcome, "(unadjusted):\n")
      top_results <- head(outcome_results[order(outcome_results$P_Value), ], 5)
      print(top_results[, summary_cols])
    }
  }

  cat("\nAnalysis complete.\n")
  cat("Files created:\n")
  cat("- all_subgroup_outcomes_analysis.csv (complete results)\n")
  cat("- Outcome-specific CSV files\n")
  cat("- Category-specific heatmaps in the 'category_heatmaps' directory\n")
  cat("- Outcome and category-specific heatmaps in subdirectories\n")

  # Display one category heatmap. This is a fixed preference order (Age, then
  # Clinical, then the first category), not a ranking by significance.
  categories <- unique(results$Category)
  display_category <- if ("Age" %in% categories) {
    "Age"
  } else if ("Clinical" %in% categories) {
    "Clinical"
  } else {
    categories[1]
  }
  # Guard against an empty result set (no categories at all).
  if (length(display_category) == 1 && !is.na(display_category) &&
      display_category %in% names(all_plots)) {
    cat("\nDisplaying heatmap for", display_category, "category...\n")
    print(all_plots[[display_category]]$p_value_map)
  }

  results
}

 results <- run_full_analysis()  # run the full subgroup pipeline with defaults
```

### OMIT: MWU: New Script (03/27/2025) Part 1: Analysis Only
```{r}
# Script 1: Subgroup Statistical Tests and CSV Output

# Load required library
library(dplyr)

##########################
# Helper Functions
##########################

# Remove extreme outliers: any value below the `lower_quantile` or above the
# `upper_quantile` sample quantile (defaults: 1st / 99th percentile, NAs
# ignored) is replaced by NA. Returns a vector the same length as `x`.
remove_outliers <- function(x, lower_quantile = 0.01, upper_quantile = 0.99) {
  bounds <- quantile(x, probs = c(lower_quantile, upper_quantile), na.rm = TRUE)
  outside <- x < bounds[[1]] | x > bounds[[2]]
  x[outside] <- NA
  x
}


# Check if a variable is categorical: factors and character vectors always
# are. Numeric vectors are when they have at most 10 distinct non-NA values,
# all of them whole numbers. Anything else (e.g. logical) is not.
is_categorical <- function(x) {
  if (is.factor(x) || is.character(x)) {
    return(TRUE)
  }
  if (!is.numeric(x)) {
    return(FALSE)
  }
  distinct_vals <- unique(x[!is.na(x)])
  length(distinct_vals) <= 10 && all(distinct_vals == round(distinct_vals))
}

# Determine variable type: "binary" if is_binary(x), otherwise "categorical"
# if is_categorical(x), otherwise "continuous".
get_variable_type <- function(x) {
  if (is_binary(x)) {
    "binary"
  } else if (is_categorical(x)) {
    "categorical"
  } else {
    "continuous"
  }
}

# Assign significance marker based on p-value: "***" below 0.001, "**" below
# 0.01, "*" below 0.05, and "" otherwise (including NA).
add_significance_markers <- function(p_value) {
  stars_by_cutoff <- c("***" = 0.001, "**" = 0.01, "*" = 0.05)
  if (is.na(p_value)) {
    return("")
  }
  passed <- names(stars_by_cutoff)[p_value < stars_by_cutoff]
  if (length(passed) == 0) "" else passed[1]
}

##########################
# Main Analysis Function: produce CSV output only
##########################

# Compare each outcome between the control (CGA_MWU, non-multimorbidity) and
# test (TGA_MWU, multimorbidity) groups within every level of each
# demographic / clinical subgroup variable, then write one CSV of results.
#
# Outcomes whose observed values are all 0/1 use Fisher's exact test.
# Otherwise the Mann-Whitney U test is run as wilcox.test(test, control), so
# a positive Z_Score means the test group ranks higher. P-values are
# Benjamini-Hochberg adjusted within each outcome.
#
# Args:
#   CGA_MWU, TGA_MWU: control and test data frames.
#   outcomes:         outcome column names to analyse.
#   output_dir:       directory for all_subgroup_analyses.csv (created if
#                     missing).
# Returns:
#   Data frame with one row per (outcome, subgroup level) that was tested.
analyze_subgroups_tests <- function(CGA_MWU, TGA_MWU,
                                    outcomes = c("alive", "los_days", "tot_charge"),
                                    output_dir = "./results") {
  if (!dir.exists(output_dir)) dir.create(output_dir, recursive = TRUE)

  # Define subgroup variables.
  demo_vars <- c("sex", "eth_corrected", "race_corrected", "age_groups")
  clinical_vars <- c("AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF",
                     "CHRNLUNG", "COAG", "DEPRESS", "DM", "DMCX", "DRUG",
                     "HTN_C", "HYPOTHY", "LIVER", "LYMPH", "LYTES", "METS",
                     "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH", "PULMCIRC",
                     "RENLFAIL", "TUMOR", "ULCER", "VALVE", "WGHTLOSS", "ARRHYTH",
                     "comorbidity_score", "van_index")

  # Empty template that fixes the column names, order and types of the output.
  all_results <- data.frame(
    Outcome = character(),
    Subgroup = character(),
    Level = character(),
    Control_n = integer(),
    Control_mean = numeric(),
    Control_median = numeric(),
    Test_n = integer(),
    Test_mean = numeric(),
    Test_median = numeric(),
    Difference = numeric(),
    Test_Type = character(),
    Statistic = numeric(),
    P_Value = numeric(),
    Adjusted_P_Value = numeric(),
    Z_Score = numeric(),
    CI_Lower = numeric(),
    CI_Upper = numeric(),
    Significance = character(),
    stringsAsFactors = FALSE
  )
  # Rows are collected in a list and bound once at the end.
  result_rows <- list()

  for (outcome in outcomes) {
    cg_values <- safe_numeric(CGA_MWU[[outcome]])
    tg_values <- safe_numeric(TGA_MWU[[outcome]])
    # Percentile trimming is only meaningful for continuous outcomes. On a 0/1
    # outcome (e.g. `alive`), when the rarer class makes up about 1% or less
    # of the values, the 1st/99th percentile falls inside the other class and
    # every member of the rare class would be set to NA.
    if (!all(na.omit(c(cg_values, tg_values)) %in% c(0, 1))) {
      cg_values <- remove_outliers(cg_values)
      tg_values <- remove_outliers(tg_values)
    }
    CGA_MWU[[outcome]] <- cg_values
    TGA_MWU[[outcome]] <- tg_values

    for (var in c(demo_vars, clinical_vars)) {
      if (!(var %in% names(CGA_MWU)) || !(var %in% names(TGA_MWU))) {
        cat("Warning:", var, "not found in both datasets. Skipping.\n")
        next
      }
      pooled_var <- c(CGA_MWU[[var]], TGA_MWU[[var]])
      var_type <- get_variable_type(pooled_var)
      levels_var <- unique(na.omit(pooled_var))

      for (lev in levels_var) {
        # For binary subgroup variables, only the level coded 1 is analysed.
        if (var_type == "binary" && as.numeric(as.character(lev)) != 1) next

        # which() mirrors subset(): rows whose subgroup value is NA are dropped.
        cg_outcome <- safe_numeric(CGA_MWU[[outcome]][which(CGA_MWU[[var]] == lev)])
        tg_outcome <- safe_numeric(TGA_MWU[[outcome]][which(TGA_MWU[[var]] == lev)])
        cg_valid <- cg_outcome[!is.na(cg_outcome)]
        tg_valid <- tg_outcome[!is.na(tg_outcome)]

        # Require at least 5 observed values per group. NAs, including those
        # created by outlier removal, do not count toward the minimum.
        if (length(cg_valid) < 5 || length(tg_valid) < 5) next

        cg_mean <- mean(cg_valid)
        cg_median <- median(cg_valid)
        tg_mean <- mean(tg_valid)
        tg_median <- median(tg_valid)

        statistic <- NA
        z_score <- NA
        ci_lower <- NA
        ci_upper <- NA
        p_val <- NA

        if (all(c(cg_valid, tg_valid) %in% c(0, 1))) {
          test_type <- "Fisher's Exact"
          # Columns: control, test. Rows: outcome == 1, outcome == 0.
          tbl <- matrix(c(
            sum(cg_valid == 1), sum(cg_valid == 0),
            sum(tg_valid == 1), sum(tg_valid == 0)
          ), nrow = 2)
          # Fisher's exact test is valid with zero cells, so sparse tables are
          # tested rather than reported as NA.
          test_result <- try(fisher.test(tbl), silent = TRUE)
          if (!inherits(test_result, "try-error")) p_val <- test_result$p.value
        } else {
          test_type <- "Mann-Whitney"
          test_result <- try(wilcox.test(tg_valid, cg_valid, exact = FALSE,
                                         correct = TRUE, conf.int = TRUE),
                             silent = TRUE)
          if (!inherits(test_result, "try-error")) {
            p_val <- test_result$p.value
            # Normal approximation of U (no tie correction). Use the observed
            # counts, matching the data wilcox.test() actually used.
            n1 <- length(cg_valid)
            n2 <- length(tg_valid)
            U <- as.numeric(test_result$statistic)
            mu_U <- n1 * n2 / 2
            sigma_U <- sqrt((n1 * n2 * (n1 + n2 + 1)) / 12)
            z_score <- (U - mu_U) / sigma_U
            statistic <- U
            ci_lower <- test_result$conf.int[1]
            ci_upper <- test_result$conf.int[2]
          }
        }

        new_row <- data.frame(
          Outcome = outcome,
          Subgroup = var,
          Level = as.character(lev),
          Control_n = length(cg_valid),
          Control_mean = cg_mean,
          Control_median = cg_median,
          Test_n = length(tg_valid),
          Test_mean = tg_mean,
          Test_median = tg_median,
          Difference = tg_mean - cg_mean,
          Test_Type = test_type,
          Statistic = statistic,
          P_Value = p_val,
          Adjusted_P_Value = NA,
          Z_Score = z_score,
          CI_Lower = ci_lower,
          CI_Upper = ci_upper,
          Significance = "",
          stringsAsFactors = FALSE
        )
        result_rows[[length(result_rows) + 1]] <- new_row[, names(all_results), drop = FALSE]
      }
    }
  }

  # The template goes first so its column types win during coercion.
  all_results <- do.call(rbind, c(list(all_results), result_rows))

  # Adjust p-values (Benjamini-Hochberg) within each outcome.
  if (nrow(all_results) > 0) {
    for (curr_outcome in unique(all_results$Outcome)) {
      idx <- which(all_results$Outcome == curr_outcome & !is.na(all_results$P_Value))
      if (length(idx) > 0) {
        all_results$Adjusted_P_Value[idx] <- p.adjust(all_results$P_Value[idx], method = "BH")
      }
    }
    all_results$Significance <- vapply(all_results$Adjusted_P_Value,
                                       add_significance_markers, character(1))
  }

  # sprintf() returns character(0) for zero rows, so an empty result does not
  # fail here.
  all_results$CI_95 <- sprintf("[%s, %s]",
                               round(all_results$CI_Lower, 2),
                               round(all_results$CI_Upper, 2))
  all_results$CI_95[is.na(all_results$CI_Lower) | is.na(all_results$CI_Upper)] <- NA

  # Save the master results table.
  write.csv(all_results, paste0(output_dir, "/all_subgroup_analyses.csv"), row.names = FALSE)

  all_results
}


# CGA_MWU and TGA_MWU are expected to be loaded by earlier chunks. Fail early
# with a clear message instead of an "object not found" error.
if (!exists("CGA_MWU") || !exists("TGA_MWU")) {
  stop("CGA_MWU and TGA_MWU dataframes must exist in the environment.")
}

result_tests <- analyze_subgroups_tests(
  CGA_MWU = CGA_MWU,
  TGA_MWU = TGA_MWU,
  outcomes = c("alive", "los_days", "tot_charge"),
  output_dir = "./results"
)

```

### OMIT: MWU Plot Code
```{r}
# ===============================
# Script: Statistical Analysis and Visualization for Multimorbidity Study
# ===============================

# Load required libraries
library(dplyr)
library(ggplot2)

##########################
# Helper Functions
##########################

# Function to perform Benjamini-Hochberg correction
#
# Returns Benjamini-Hochberg (FDR) adjusted p-values in the same order as
# `p_values`. NA inputs stay NA and are not counted as tests.
#
# The previous hand-rolled version divided the SORTED p-values by ranks taken
# in the ORIGINAL order and skipped the step-up cumulative minimum, so its
# output was wrong whenever the input was not already sorted. stats::p.adjust
# implements the procedure correctly.
apply_bh_correction <- function(p_values) {
  p.adjust(p_values, method = "BH")
}

# Function to add significance markers: returns "***" for p < 0.001, "**" for
# p < 0.01, "*" for p < 0.05, and "" for anything else, including NA.
add_significance_markers <- function(p_value) {
  if (is.na(p_value)) return("")
  cutoffs <- c(0.001, 0.01, 0.05)
  markers <- c("***", "**", "*")
  first_hit <- which(p_value < cutoffs)[1]
  if (is.na(first_hit)) "" else markers[first_hit]
}

# Safe min and max functions to handle empty or all-NA inputs.
# safe_min(): minimum of the non-NA values; warns and returns Inf when there
# are none.
safe_min <- function(x) {
  observed <- x[!is.na(x)]
  if (length(observed) == 0) {
    warning("safe_min: All values are NA; returning Inf.")
    return(Inf)
  }
  min(observed)
}

# safe_max(): maximum of the non-NA values; warns and returns -Inf when there
# are none.
safe_max <- function(x) {
  has_value <- !is.na(x)
  if (!any(has_value)) {
    warning("safe_max: All values are NA; returning -Inf.")
    return(-Inf)
  }
  max(x[has_value])
}

##########################
# Main Analysis and Visualization Function
##########################

analyze_and_visualize <- function(CGA_MWU, TGA_MWU, outcomes, demo_vars, clinical_vars, output_dir = "./results") {
  # Compare each demographic / clinical variable between the control
  # (CGA_MWU, "Non-Multimorbidity") and test (TGA_MWU, "Multimorbidity")
  # groups. BH-adjusted p-values are written to
  # <output_dir>/statistical_results.csv, and one boxplot per numeric outcome
  # is saved as <output_dir>/<outcome>_plot.png.
  #
  # Variables whose observed values are all 0/1 use Fisher's exact test.
  # Other numeric variables use the Mann-Whitney U test. Anything else is
  # skipped with a warning.
  #
  # Returns: data frame with Variable, RawPValue, AdjustedPValue, Significance.

  # Combine data from both groups (rbind() needs identical column names).
  CGA_MWU$group <- "Non-Multimorbidity"
  TGA_MWU$group <- "Multimorbidity"
  combined_df <- rbind(CGA_MWU, TGA_MWU)

  if (!dir.exists(output_dir)) dir.create(output_dir, recursive = TRUE)

  results <- data.frame(Variable = character(), RawPValue = numeric(), AdjustedPValue = numeric(), Significance = character(), stringsAsFactors = FALSE)

  for (var in c(demo_vars, clinical_vars)) {
    message("Processing variable: ", var)

    if (!(var %in% names(combined_df))) {
      warning("Variable '", var, "' not found in the data. Skipping.")
      next
    }
    values <- combined_df[[var]]
    if (all(is.na(values))) {
      warning("Variable '", var, "' contains only NA values. Skipping.")
      next
    }
    observed <- unique(na.omit(values))
    if (length(observed) < 2) {
      warning("Variable '", var, "' has fewer than two unique values. Skipping.")
      next
    }

    # Drop rows with NA group or variable.
    sub_df <- combined_df[!is.na(combined_df$group) & !is.na(values), ]

    # The binary check must come before the numeric one. 0/1 indicator
    # variables are numeric too, and were previously sent to Mann-Whitney
    # instead of Fisher's test.
    if (all(observed %in% c(0, 1))) {
      tbl <- table(sub_df$group, sub_df[[var]])
      # Fisher's exact test handles zero cells. It only needs both groups
      # and both levels present.
      if (all(dim(tbl) >= 2)) {
        test_result <- try(fisher.test(tbl), silent = TRUE)
        p_value <- if (inherits(test_result, "try-error")) NA else test_result$p.value
      } else {
        p_value <- NA
      }
    } else if (is.numeric(values)) {
      group1 <- sub_df[[var]][sub_df$group == "Non-Multimorbidity"]
      group2 <- sub_df[[var]][sub_df$group == "Multimorbidity"]
      test_result <- try(wilcox.test(group1, group2, exact = FALSE, correct = TRUE), silent = TRUE)
      p_value <- if (inherits(test_result, "try-error")) NA else test_result$p.value
    } else {
      warning("Variable '", var, "' is neither numeric nor binary. Skipping.")
      next
    }

    results <- rbind(results, data.frame(Variable = var, RawPValue = p_value, AdjustedPValue = NA, Significance = "", stringsAsFactors = FALSE))
  }

  # Apply Benjamini-Hochberg correction.
  results$AdjustedPValue <- apply_bh_correction(results$RawPValue)
  results$Significance <- vapply(results$AdjustedPValue, add_significance_markers, character(1))

  write.csv(results, file = paste0(output_dir, "/statistical_results.csv"), row.names = FALSE)

  # Boxplot per numeric outcome (a boxplot shows median/quartiles, not a mean).
  for (outcome in outcomes) {
    if (is.numeric(combined_df[[outcome]])) {
      p <- ggplot(combined_df, aes(x = group, y = .data[[outcome]], fill = group)) +
        geom_boxplot() +
        scale_fill_manual(values = c("Non-Multimorbidity" = "orange", "Multimorbidity" = "blue")) +
        theme_minimal() +
        labs(
          title = paste("Outcome:", outcome),
          x = "Group",
          y = outcome
        )

      ggsave(filename = paste0(output_dir, "/", outcome, "_plot.png"), plot = p, width = 8, height = 6, dpi = 300)
    }
  }

  results
}

##########################
# Example Execution
##########################

# Define input datasets and variables.
# Synthetic demo data is created ONLY when neither CGA_MWU nor TGA_MWU is
# already loaded, so running this chunk cannot overwrite the real study data
# used by the analysis chunks above. Replace with your actual data.
if (!exists("CGA_MWU") && !exists("TGA_MWU")) {
  CGA_MWU <- data.frame(sex = sample(c(0, 1), 50, replace = TRUE),
                        alive = sample(c(0, 1), 50, replace = TRUE),
                        los_days = rnorm(50), tot_charge = rnorm(50))
  TGA_MWU <- data.frame(sex = sample(c(0, 1), 50, replace = TRUE),
                        alive = sample(c(0, 1), 50, replace = TRUE),
                        los_days = rnorm(50), tot_charge = rnorm(50))
}

outcomes <- c("alive", "los_days", "tot_charge")
demo_vars <- c("sex", "eth_corrected", "race_corrected", "age_groups")
clinical_vars <- c("AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF",
                   "CHRNLUNG", "COAG", "DEPRESS", "DM", "DMCX", "DRUG",
                   "HTN_C", "HYPOTHY", "LIVER", "LYMPH", "LYTES", "METS",
                   "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH", "PULMCIRC",
                   "RENLFAIL", "TUMOR", "ULCER", "VALVE", "WGHTLOSS", "ARRHYTH",
                   "comorbidity_score", "van_index")

# Run the analysis
results <- analyze_and_visualize(CGA_MWU, TGA_MWU, outcomes, demo_vars, clinical_vars)

```


### Explanation of Libraries Used

This notebook utilizes several R libraries for data manipulation, statistical analysis, and visualization:

- **`readr`**: For efficient reading of various data formats, like CSV files.
- **`dplyr`**: A powerful package for data manipulation, providing functions for filtering, selecting, arranging, and summarizing data.
- **`stats`**: R's built-in package for statistical computations, including functions for statistical tests.
- **`tidyr`**: Used for tidying data, making it easier to reshape and organize.
- **`bnlearn`**: A package for Bayesian network learning and inference.
- **`viridisLite`**: Provides colorblind-friendly color maps.
- **`knitr`**: Used for dynamic report generation in R.
- **`stringr`**: For easy string manipulation.
- **`ggplot2`**: A widely used package for creating informative and aesthetically pleasing statistical graphics.
- **`reshape2`**: Used for reshaping data between wide and long formats.
- **`gridExtra`**: For arranging multiple ggplot2 plots on a single page.
- **`psych`**: Provides functions for personality, psychometrics, and experimental psychology, including functions for correlation analysis like `tetrachoric`.

### Part 1: Descriptives

This section focuses on characterizing the chronic diseases by comparing the control (non-multimorbidity) and test (multimorbidity) groups using descriptive statistics and visualizations.

**Count of Chronic Diseases in the Combinations Data**

This code snippet filters a dataframe (`TGA_back`) based on patient IDs present in another dataframe (`second_dataframe`) to create a subset (`TGA_SCD`).

**Descriptive Statistics (Control: Nonmultimorbidity (CG2) & Test: Multimorbidity (TG2))**

This code defines functions to calculate descriptive statistics (mean, median, mode, etc.) for numeric and categorical variables and then applies these functions to the `CGA` and `TGA` dataframes. The results are printed and saved to CSV files.

**Bar Plot: Age Groups**

This code generates a bar plot comparing the distribution of age groups between the non-multimorbidity and multimorbidity groups.

**Race Barplot**

This code generates a bar plot comparing the distribution of race between the non-multimorbidity and multimorbidity groups.

**Ethnicity Barplot**

This code generates a bar plot comparing the distribution of ethnicity between the non-multimorbidity and multimorbidity groups.

**Sex Barplot**

This code generates a bar plot comparing the distribution of sex between the non-multimorbidity and multimorbidity groups.

**Comorbidity Score & Van Index Barplots**

This code generates bar plots comparing the distribution of comorbidity scores and Van indexes between the non-multimorbidity and multimorbidity groups.

### Part 2: Correlations

This section explores the correlations and relationships between chronic diseases and combinations within the multimorbidity group, as well as correlations with demographic subgroups.

**HeatMap of Chronic Disease Proportions between Groups**

This code generates a heatmap visualizing the proportions of different chronic diseases in the non-multimorbidity and multimorbidity groups.

**Histograms of the Combinations of Chronic Diseases Present in Multimorbidity Group by type (dyad, etc.)**

This code categorizes chronic disease combinations by the number of diseases (dyad, triad, tetrad, pentad) and generates histograms to visualize the most frequent combinations within the multimorbidity group for each category.

**HOT! Correlations based on subgroups**

This section contains R scripts to analyze chronic disease combinations by demographic subgroups, generating heatmaps and bar plots to visualize the frequency and distribution of combinations across different ethnic, racial, age, and sex groups. It also includes a function to compare a specific combination across all demographic variables.

**Correlations based on subgroups**

This R script generates heatmaps showing the frequency of chronic disease combinations (dyads, triads, and combined tetrads/pentads) by demographic subgroups.

### Part 2b: Tetrachoric Correlation Matrix for Each Chronic Disease against others (singularly)

This section focuses on calculating and visualizing the tetrachoric correlations between individual chronic diseases within both the non-multimorbidity and multimorbidity groups.

This code calculates the tetrachoric correlation matrix for chronic diseases in both the `CGA` and `TGA` dataframes and generates heatmaps to visualize these correlations.

### Part 2c: Correlation Matrix for Combinations vs Chronic Diseases

This section examines the correlations between specific chronic disease combinations and individual chronic diseases within the multimorbidity group.

This code identifies the top chronic disease combinations, binarizes them, and then calculates the tetrachoric correlations between these top combinations and the individual chronic diseases within the `TGA` dataframe. Heatmaps are generated to visualize these correlations, highlighting strong relationships.

### Part 3: Mann-Whitney U

This section performs Mann-Whitney U tests to compare hospitalization outcomes (e.g., length of stay, total charges, mortality) between the non-multimorbidity and multimorbidity groups, stratified by various demographic and clinical subgroups.

**Mann-Whitney U test for nonparametric data (differences in hospitalization outcomes by subgroup: age, race, sex, ethnicity, chronic diseases)**

This code compares outcomes between the two groups across various subgroups. It uses Mann-Whitney U tests for continuous outcomes (e.g., length of stay, total charges) and Fisher's exact tests for binary (0/1) outcomes (e.g., survival). It generates heatmaps and a summary table PDF to visualize the results.

### Explanation of Libraries Used

This notebook utilizes several R libraries for data manipulation, statistical analysis, and visualization:

- **`readr`**: For efficient reading of various data formats, like CSV files.
- **`dplyr`**: A powerful package for data manipulation, providing functions for filtering, selecting, arranging, and summarizing data.
- **`stats`**: R's built-in package for statistical computations, including functions for statistical tests.
- **`tidyr`**: Used for tidying data, making it easier to reshape and organize.
- **`bnlearn`**: A package for Bayesian network learning and inference.
- **`viridisLite`**: Provides colorblind-friendly color maps.
- **`knitr`**: Used for dynamic report generation in R.
- **`stringr`**: For easy string manipulation.
- **`ggplot2`**: A widely used package for creating informative and aesthetically pleasing statistical graphics.
- **`reshape2`**: Used for reshaping data between wide and long formats.
- **`gridExtra`**: For arranging multiple ggplot2 plots on a single page.
- **`psych`**: Provides functions for personality, psychometrics, and experimental psychology, including functions for correlation analysis like `tetrachoric`.

In [None]:
#Call the packages and load the data
# (This cell only attaches packages and defines datapath2; no data file is
# read here.) The order of the library() calls decides which package's
# function wins when two packages export the same name, so keep it stable.
library(readr)
library(dplyr)
library(stats)
library(tidyr)
library(bnlearn)
library(viridisLite)
library(knitr)
library(stringr)
library(ggplot2)

# Machine-specific absolute path; normalizePath() only warns (default
# mustWork = NA) if this folder does not exist on the current machine.
datapath2 = normalizePath("C:/Users/taiqu/Box/01-TaiR-Dissertation-FALL2024/08-0325-RScripts")

### Part 1: Descriptives

This section focuses on characterizing the chronic diseases by comparing the control (non-multimorbidity) and test (multimorbidity) groups using descriptive statistics and visualizations.

**Count of Chronic Diseases in the Combinations Data**

This code snippet filters a dataframe (`TGA_back`) based on patient IDs present in another dataframe (`second_dataframe`) to create a subset (`TGA_SCD`).

In [None]:
# Load necessary library
library(dplyr)

# Keep only the TGA_back rows whose patient ID (de_id_mrn) also appears in
# second_dataframe; the result is stored as TGA_SCD.
TGA_SCD <- filter(TGA_back, de_id_mrn %in% second_dataframe$de_id_mrn)

# Preview the first rows of the subset
head(TGA_SCD)

**Descriptive Statistics (Control: Non-multimorbidity (CG2) & Test: Multimorbidity (TG2))**

This code defines functions to calculate descriptive statistics (mean, median, mode, etc.) for numeric and categorical variables and then applies these functions to the `CGA` and `TGA` dataframes. The results are printed and saved to CSV files.

In [None]:
# Function to calculate mode
#
# Returns the most frequent non-missing value of `x`. Ties go to the value
# that appears first in `x`. When `x` has no non-missing values a single NA of
# the same type as `x` is returned, so callers that expect exactly one value
# per column (e.g. sapply()/vapply() over data-frame columns) keep working.
calculate_mode = function(x) {
  uniq = unique(na.omit(x))  # Remove NAs and get unique values
  if (length(uniq) == 0L) {
    # Indexing with NA gives a length-1 NA that keeps the type/class of x
    return(x[NA_integer_])
  }
  # match() gives each element's position in `uniq` (NA for missing values,
  # which tabulate() ignores); nbins keeps the counts aligned with `uniq`
  counts = tabulate(match(x, uniq), nbins = length(uniq))
  uniq[which.max(counts)]  # Return the most frequent value
}

# Function to generate descriptive statistics
#
# Summarises every column of `df` except "encdates":
#   - Numeric columns: missing count, mean, median, SD, mode, min, max, and
#     the 25th-75th percentile range, reported in the `IQR` column as "Q1-Q3".
#   - Factor/character columns: count and proportion of every level (NA is
#     counted as its own level when present).
#   - Numeric columns that contain only 0/1 (ignoring NA) are ALSO reported in
#     the categorical table as a single "Yes" row (count/proportion of 1s).
# Relies on calculate_mode() being defined when this function is called.
#
# Returns list(numeric = <data.frame>, categorical = <data.frame or NULL>).
generate_descriptive_stats = function(df) {
  # Exclude "encdates" from analyses. drop = FALSE keeps a data frame even
  # when only one column remains (otherwise `[` collapses it to a vector and
  # all column names are lost).
  df_analysis = df[, !colnames(df) %in% "encdates", drop = FALSE]

  # Separate numeric and categorical variables. vapply() always returns a
  # logical mask, even for zero columns (sapply() would return an empty list).
  is_numeric = vapply(df_analysis, is.numeric, logical(1))
  is_categorical = vapply(
    df_analysis,
    function(x) is.factor(x) || is.character(x),
    logical(1)
  )
  numeric_vars = df_analysis[is_numeric]
  categorical_vars = df_analysis[is_categorical]

  # Identify binary numeric variables and include only "Yes" (1s) in categorical stats
  binary_stats = lapply(names(numeric_vars), function(var) {
    x = numeric_vars[[var]]
    if (all(na.omit(x) %in% c(0, 1))) {
      # Calculate stats for "Yes" (1)
      count_1s = sum(x == 1, na.rm = TRUE)
      proportion_1s = mean(x == 1, na.rm = TRUE)
      data.frame(
        Variable = var,
        Category = "Yes",
        Count = count_1s,
        Proportion = round(proportion_1s, 3),
        stringsAsFactors = FALSE
      )
    } else {
      NULL  # Exclude non-binary variables
    }
  })
  binary_stats = do.call(rbind, binary_stats)  # NULL when no binary variables

  # Frequency table for each factor/character variable
  categorical_stats = lapply(names(categorical_vars), function(var) {
    freq_table = table(categorical_vars[[var]], useNA = "ifany")
    data.frame(
      Variable = var,
      Category = names(freq_table),
      Count = as.numeric(freq_table),
      Proportion = round(as.numeric(prop.table(freq_table)), 3),
      stringsAsFactors = FALSE
    )
  })
  categorical_stats = do.call(rbind, categorical_stats)  # NULL when none

  # Combine binary stats with categorical stats if binary stats exist
  combined_categorical_stats = if (!is.null(binary_stats)) {
    rbind(categorical_stats, binary_stats)
  } else {
    categorical_stats
  }

  # Numeric statistics (vapply keeps each column a plain vector, also when
  # there are zero numeric columns)
  numeric_stats = data.frame(
    Variable = names(numeric_vars),
    Missing = vapply(numeric_vars, function(x) sum(is.na(x)), integer(1)),
    Mean = vapply(numeric_vars, mean, numeric(1), na.rm = TRUE),
    Median = vapply(numeric_vars, median, numeric(1), na.rm = TRUE),
    Std_Dev = vapply(numeric_vars, sd, numeric(1), na.rm = TRUE),
    # Guard against a zero-length mode (e.g. an all-NA column)
    Mode = vapply(numeric_vars, function(x) {
      m = calculate_mode(x)
      if (length(m) == 0L) NA_real_ else as.numeric(m)
    }, numeric(1)),
    Min = vapply(numeric_vars, min, numeric(1), na.rm = TRUE),
    Max = vapply(numeric_vars, max, numeric(1), na.rm = TRUE),
    IQR = vapply(numeric_vars, function(x) {
      Q1 = quantile(x, 0.25, na.rm = TRUE)
      Q3 = quantile(x, 0.75, na.rm = TRUE)
      paste0(round(Q1, 2), "-", round(Q3, 2))  # Return the range as "Q1-Q3"
    }, character(1))
  )

  # Return combined results
  list(numeric = numeric_stats, categorical = combined_categorical_stats)
}

# Format and print tables
#
# Prints the two tables returned by generate_descriptive_stats(), each under
# a header naming `dataset_name`: first the numeric summary, then the
# categorical frequency table.
print_stats = function(stats, dataset_name) {
  # cat() joins the pieces of each header with single spaces
  numeric_header = c("\nDescriptive Statistics for", dataset_name, "Numeric Variables:\n")
  categorical_header = c(
    "\nFrequency Distribution (Including Yes for Binary Variables) for",
    dataset_name,
    "Categorical Variables:\n"
  )

  cat(numeric_header)
  print(stats$numeric)

  cat(categorical_header)
  print(stats$categorical)
}

# Generate stats for CGA and TGA
CGA_stats = generate_descriptive_stats(CGA)
TGA_stats = generate_descriptive_stats(TGA)

# Show both summaries in the notebook output
print_stats(CGA_stats, "CGA")
print_stats(TGA_stats, "TGA")

# Write every summary table to its own CSV, in the order
# CGA numeric, CGA categorical, TGA numeric, TGA categorical
invisible(Map(
  function(stats_table, csv_path) write.csv(stats_table, csv_path, row.names = FALSE),
  list(CGA_stats$numeric, CGA_stats$categorical, TGA_stats$numeric, TGA_stats$categorical),
  c("CGA_numeric_stats.csv", "CGA_categorical_stats.csv",
    "TGA_numeric_stats.csv", "TGA_categorical_stats.csv")
))

**Bar Plot: Age Groups**

This code generates a bar plot comparing the distribution of age groups between the non-multimorbidity and multimorbidity groups.

In [None]:
# Bar plot: share of each age group within each cohort.
# Proportions are computed within each group, so each group's bars sum to 1
# (aggregate()'s formula interface drops rows whose age_groups is NA).
# NOTE: this cell adds a `group` column to the global CGA and TGA data frames.

# Load necessary library
library(ggplot2)

# Add a `group` column to differentiate the datasets
CGA$group = "Non-Multimorbidity"
TGA$group = "Multimorbidity"

# Combine the datasets into one for plotting
combined_data = rbind(
  data.frame(age_groups = CGA$age_groups, group = CGA$group),
  data.frame(age_groups = TGA$age_groups, group = TGA$group)
)

# Count patients per group x age group, then convert to within-group proportions
combined_data_summary = aggregate(
  count ~ group + age_groups,
  data = transform(combined_data, count = 1),
  FUN = length
)
combined_data_summary = transform(
  combined_data_summary,
  proportion = ave(count, group, FUN = function(x) x / sum(x))
)

# Side-by-side (dodged) bars with each proportion printed above its bar
plot = ggplot(combined_data_summary, aes(x = age_groups, y = proportion, fill = group)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 1, color = NA) +
  geom_text(
    aes(label = round(proportion, 2)), # Display raw proportions rounded to 2 decimal places
    # Must equal the bars' dodge width (0.9, geom_bar's default) so each label
    # sits over its own bar; a width of 1.2 shifted labels off their bars
    position = position_dodge(width = 0.9),
    vjust = -0.5,
    size = 3.5
  ) +
  # Named values pin each group to its colour regardless of level order
  scale_fill_manual(values = c("Multimorbidity" = "blue", "Non-Multimorbidity" = "orange")) +
  labs(
    title = "Age Group Distribution: Patients With Multimorbidity vs Non-Multimorbidity",
    x = "Age Group (years)",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  ) +
  scale_y_continuous(expand = expansion(mult = c(0, 0.05))) # Default formatting for raw proportions

# Save the plot
ggsave("age_group_distribution.png", plot = plot, width = 8, height = 6, dpi = 300)

**Race Barplot**

This code generates a bar plot comparing the distribution of race between the non-multimorbidity and multimorbidity groups.

In [None]:
# Bar plot: share of each race category within each cohort.
# Proportions are computed within each group, so each group's bars sum to 1
# (aggregate()'s formula interface drops rows whose race_corrected is NA).
# NOTE: this cell (re)sets the `group` column of the global CGA and TGA.

# Load necessary library
library(ggplot2)

# Add a `group` column to differentiate the datasets
CGA$group = "Non-Multimorbidity"
TGA$group = "Multimorbidity"

# Combine the datasets into one for plotting
combined_data = rbind(
  data.frame(race_corrected = CGA$race_corrected, group = CGA$group),
  data.frame(race_corrected = TGA$race_corrected, group = TGA$group)
)

# Count patients per group x race, then convert to within-group proportions
combined_data_summary = aggregate(
  count ~ group + race_corrected,
  data = transform(combined_data, count = 1),
  FUN = length
)
combined_data_summary = transform(
  combined_data_summary,
  proportion = ave(count, group, FUN = function(x) x / sum(x))
)

# Side-by-side (dodged) bars with each proportion printed above its bar
plot = ggplot(combined_data_summary, aes(x = race_corrected, y = proportion, fill = group)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 1, color = NA) +
  geom_text(
    aes(label = round(proportion, 2)), # Display raw proportions rounded to 2 decimal places
    # Must equal the bars' dodge width (0.9, geom_bar's default) so each label
    # sits over its own bar; a width of 1.2 shifted labels off their bars
    position = position_dodge(width = 0.9),
    vjust = -0.5,
    size = 3.5
  ) +
  # Named values pin each group to its colour regardless of level order
  scale_fill_manual(values = c("Multimorbidity" = "blue", "Non-Multimorbidity" = "orange")) +
  labs(
    title = "Distribution of Race for Patients With Multimorbidity vs Non-Multimorbidity",
    x = "Race",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  ) +
  scale_y_continuous(expand = expansion(mult = c(0, 0.05))) # Default formatting for raw proportions

# Save the plot
ggsave("race_distribution.png", plot = plot, width = 8, height = 6, dpi = 300)

**Ethnicity Barplot**

This code generates a bar plot comparing the distribution of ethnicity between the non-multimorbidity and multimorbidity groups.

In [None]:
# Bar plot: share of each ethnicity category within each cohort.
# Proportions are computed within each group, so each group's bars sum to 1
# (aggregate()'s formula interface drops rows whose eth_corrected is NA).
# NOTE: this cell (re)sets the `group` column of the global CGA and TGA.

# Add a `group` column to differentiate the datasets
CGA$group = "Non-Multimorbidity"
TGA$group = "Multimorbidity"

# Combine the datasets into one for plotting
combined_data = rbind(
  data.frame(eth_corrected = CGA$eth_corrected, group = CGA$group),
  data.frame(eth_corrected = TGA$eth_corrected, group = TGA$group)
)

# Count patients per group x ethnicity, then convert to within-group proportions
combined_data_summary = aggregate(
  count ~ group + eth_corrected,
  data = transform(combined_data, count = 1),
  FUN = length
)
combined_data_summary = transform(
  combined_data_summary,
  proportion = ave(count, group, FUN = function(x) x / sum(x))
)

# Side-by-side (dodged) bars with each proportion printed above its bar
plot = ggplot(combined_data_summary, aes(x = eth_corrected, y = proportion, fill = group)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 1, color = NA) +
  geom_text(
    aes(label = round(proportion, 2)), # Display raw proportions rounded to 2 decimal places
    # Must equal the bars' dodge width (0.9, geom_bar's default) so each label
    # sits over its own bar; a width of 1.2 shifted labels off their bars
    position = position_dodge(width = 0.9),
    vjust = -0.5,
    size = 3.5
  ) +
  # Named values pin each group to its colour regardless of level order
  scale_fill_manual(values = c("Multimorbidity" = "blue", "Non-Multimorbidity" = "orange")) +
  labs(
    title = "Distribution of Ethnicity for Patients With Multimorbidity vs Non-Multimorbidity",
    x = "Ethnicity",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  ) +
  scale_y_continuous(expand = expansion(mult = c(0, 0.05))) # Default formatting for raw proportions

# Save the plot
ggsave("ethnicity_distribution.png", plot = plot, width = 8, height = 6, dpi = 300)

**Sex Barplot**

This code generates a bar plot comparing the distribution of sex between the non-multimorbidity and multimorbidity groups.

In [None]:
# Bar plot: share of each sex within each cohort.
# Proportions are computed within each group, so each group's bars sum to 1
# (aggregate()'s formula interface drops rows whose sex is NA).

# Combine the datasets into one for plotting. The group labels are written
# here directly, so this cell no longer depends on an earlier cell having
# created CGA$group / TGA$group.
combined_data = rbind(
  data.frame(sex = CGA$sex, group = "Non-Multimorbidity"),
  data.frame(sex = TGA$sex, group = "Multimorbidity")
)

# Count patients per group x sex, then convert to within-group proportions
combined_data_summary = aggregate(
  count ~ group + sex,
  data = transform(combined_data, count = 1),
  FUN = length
)
combined_data_summary = transform(
  combined_data_summary,
  proportion = ave(count, group, FUN = function(x) x / sum(x))
)

# Side-by-side (dodged) bars with each proportion printed above its bar
plot = ggplot(combined_data_summary, aes(x = sex, y = proportion, fill = group)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 1, color = NA) +
  geom_text(
    aes(label = round(proportion, 2)), # Display raw proportions rounded to 2 decimal places
    # Must equal the bars' dodge width (0.9, geom_bar's default) so each label
    # sits over its own bar; a width of 1.2 shifted labels off their bars
    position = position_dodge(width = 0.9),
    vjust = -0.5,
    size = 3.5
  ) +
  # Named values pin each group to its colour regardless of level order
  scale_fill_manual(values = c("Multimorbidity" = "blue", "Non-Multimorbidity" = "orange")) +
  labs(
    title = "Distribution of Sex for Patients With Multimorbidity vs Non-Multimorbidity",
    x = "Sex",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 50)
  ) +
  scale_y_continuous(expand = expansion(mult = c(0, 0.05))) # Default formatting for raw proportions

# Save the plot
ggsave("sex_distribution.png", plot = plot, width = 8, height = 6, dpi = 300)

**Comorbidity Score & Van Index Barplots**

This code generates bar plots comparing the distribution of comorbidity scores and Van indexes between the non-multimorbidity and multimorbidity groups.

In [None]:
# Bar plots of the comorbidity_score and van_index distributions per cohort.
# Each bar is the proportion of that cohort's rows with the given score, so
# the bars of one group within one plot sum to 1 (NA scores form their own bar).

# Combine the datasets for analysis (long format: one row per patient x measure)
combined_data = rbind(
  data.frame(variable = CGA$comorbidity_score, group = "Non-Multimorbidity", name = "comorbidity_score"),
  data.frame(variable = TGA$comorbidity_score, group = "Multimorbidity", name = "comorbidity_score"),
  data.frame(variable = CGA$van_index, group = "Non-Multimorbidity", name = "van_index"),
  data.frame(variable = TGA$van_index, group = "Multimorbidity", name = "van_index")
)

# Count each score, then take proportions WITHIN each measure x group.
# (Previously the proportion was taken after dropping the grouping, i.e. over
# the grand total of both measures and both groups, so no group summed to 1.)
combined_data_summary = combined_data %>%
  group_by(name, group, variable) %>%
  summarise(count = n(), .groups = "drop") %>%
  group_by(name, group) %>%
  mutate(proportion = count / sum(count)) %>%
  ungroup()

# Define colors for groups
group_colors = c("Non-Multimorbidity" = "orange", "Multimorbidity" = "blue")

# Plot for comorbidity_score
plot_comorbidity = ggplot(
  filter(combined_data_summary, name == "comorbidity_score"),
  aes(x = as.factor(variable), y = proportion, fill = group)
) +
  geom_bar(stat = "identity", position = "dodge") +
  scale_fill_manual(values = group_colors) + # Apply custom colors
  labs(
    title = "Proportion of Comorbidity Scores: Multimorbidity vs Non-Multimorbidity",
    x = "Comorbidity Score",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal()

# Plot for van_index
plot_van_index = ggplot(
  filter(combined_data_summary, name == "van_index"),
  aes(x = as.factor(variable), y = proportion, fill = group)
) +
  geom_bar(stat = "identity", position = "dodge") +
  scale_fill_manual(values = group_colors) + # Apply custom colors
  labs(
    title = "Proportion of Van Indexes: Multimorbidity vs Non-Multimorbidity",
    x = "Van Index",
    y = "Proportion",
    fill = "Group"
  ) +
  theme_minimal()

# Save the plots
ggsave("comorbidity_score_proportion_plot.png", plot = plot_comorbidity, width = 8, height = 6, dpi = 300)
ggsave("van_index_proportion_plot.png", plot = plot_van_index, width = 8, height = 6, dpi = 300)

### Part 2: Correlations

This section explores the correlations and relationships between chronic diseases and combinations within the multimorbidity group, as well as correlations with demographic subgroups.

**HeatMap of Chronic Disease Proportions between Groups**

This code generates a heatmap visualizing the proportions of different chronic diseases in the non-multimorbidity and multimorbidity groups.

In [None]:
# Show the available column names of TGA
colnames(TGA)

# Heat map of the per-disease mean of each indicator column, one column of
# tiles per cohort
library(ggplot2)
library(reshape2)  # melt() for wide -> long reshaping

# Chronic-disease indicator columns to compare
disease_variables = c("AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF",
                      "CHRNLUNG", "COAG", "DEPRESS", "DM", "DMCX", "DRUG",
                      "HTN_C", "HYPOTHY", "LIVER", "LYMPH", "LYTES", "METS",
                      "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH", "PULMCIRC",
                      "RENLFAIL", "TUMOR", "ULCER", "VALVE", "WGHTLOSS", "ARRHYTH")

# Mean of each indicator (NAs ignored); presumably the columns are 0/1 flags,
# making the mean a proportion of patients — TODO confirm
TGA_proportions = vapply(disease_variables,
                         function(col) mean(TGA[[col]], na.rm = TRUE),
                         numeric(1))
CGA_proportions = vapply(disease_variables,
                         function(col) mean(CGA[[col]], na.rm = TRUE),
                         numeric(1))

# Wide table: one row per disease, one column per cohort
combined_data = data.frame(
  disease = disease_variables,
  Multimorbidity = TGA_proportions,
  Non_Multimorbidity = CGA_proportions
)

# Long table: one row per disease x cohort
melted_data = melt(
  combined_data,
  id.vars = "disease",
  variable.name = "group",
  value.name = "proportion"
)

# Tiles coloured by proportion, with the rounded value printed on each tile
heatmap_plot = ggplot(melted_data, aes(x = group, y = disease, fill = proportion)) +
  geom_tile(color = "white") +
  geom_text(aes(label = round(proportion, 2)), size = 3.5) +
  scale_fill_gradient(low = "blue", high = "orange", name = "Proportion") +
  labs(
    title = "Comparison of Chronic Disease Proportions: Multimorbidity vs Non-Multimorbidity",
    x = "Group",
    y = "Chronic Diseases"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text.x = element_text(size = 10),
    axis.text.y = element_text(size = 10),
    legend.title = element_text(size = 10),
    plot.margin = margin(t = 10, r = 10, b = 10, l = 10)
  )

# Write the heat map to disk
ggsave("chronic_disease_heatmap.png", plot = heatmap_plot, width = 10, height = 8, dpi = 300)

**Histograms of the Combinations of Chronic Diseases Present in Multimorbidity Group by type (dyad, etc.)**

This code categorizes chronic disease combinations by the number of diseases (dyad = 2, triad = 3, tetrad = 4, pentad = 5 or more) and generates histograms to visualize the most frequent combinations within the multimorbidity group for each category.

In [None]:
# Bar charts of the most frequent disease combinations in TGA, split by the
# number of diseases in the combination (number of "+" separators + 1).
# Proportions are relative to all rows of TGA.

# Load necessary libraries
library(ggplot2)

# Step 1: Categorize combinations by the number of "+" signs.
# gregexpr() returns -1 when there is no "+" and NA for a missing combination;
# missing combinations now get an NA type (previously length(NA[NA > 0]) == 1
# labelled them as dyads).
TGA$combination_type = vapply(
  gregexpr("+", TGA$combinations, fixed = TRUE),
  function(pos) if (anyNA(pos)) NA_integer_ else sum(pos > 0L),
  integer(1)
)

# Label the categories ("Pentad" covers 4 or more "+" signs, i.e. 5+ diseases;
# NA types stay NA)
TGA$combination_category = ifelse(
  TGA$combination_type == 1, "Dyad",
  ifelse(TGA$combination_type == 2, "Triad",
  ifelse(TGA$combination_type == 3, "Tetrad",
  ifelse(TGA$combination_type >= 4, "Pentad", "Other")))
)

# Step 2: Calculate proportions for each unique combination
combination_counts = aggregate(TGA$combinations, by = list(category = TGA$combination_category, combination = TGA$combinations), FUN = length)
colnames(combination_counts) = c("combination_category", "combinations", "count")
combination_counts$proportion = combination_counts$count / nrow(TGA)

# Step 3: Split data into categories
dyads = subset(combination_counts, combination_category == "Dyad")
triads = subset(combination_counts, combination_category == "Triad")
tetrads_and_pentads = subset(combination_counts, combination_category %in% c("Tetrad", "Pentad"))

# Sort and select top 20 for dyads and triads
dyads_top20 = head(dyads[order(-dyads$proportion), ], 20)
triads_top20 = head(triads[order(-triads$proportion), ], 20)

# Combine tetrads and pentads and keep the top 20, matching the plot title
# (previously every tetrad/pentad combination was plotted)
tetrads_and_pentads_top = head(tetrads_and_pentads[order(-tetrads_and_pentads$proportion), ], 20)

# Step 4: Create Histograms

# Histogram for Dyads
histogram_dyads = ggplot(dyads_top20, aes(x = reorder(combinations, -proportion), y = proportion, fill = "Dyad")) +
  geom_bar(stat = "identity", color = "black", alpha = 0.8) +
  geom_text(aes(label = round(proportion, 2)), vjust = -0.5, size = 3.5) +  # Add values above bars
  scale_fill_manual(values = c("Dyad" = "blue")) +
  labs(
    title = "Top 20 Chronic Disease Dyad Combinations in Patients with Multimorbidity",
    x = "Chronic Disease Combinations",
    y = "Proportion"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text.x = element_text(size = 10, angle = 45, hjust = 1),
    legend.position = "none"
  )

# Histogram for Triads
histogram_triads = ggplot(triads_top20, aes(x = reorder(combinations, -proportion), y = proportion, fill = "Triad")) +
  geom_bar(stat = "identity", color = "black", alpha = 0.8) +
  geom_text(aes(label = round(proportion, 2)), vjust = -0.5, size = 3.5) +  # Add values above bars
  scale_fill_manual(values = c("Triad" = "orange")) +
  labs(
    title = "Top 20 Chronic Disease Triad Combinations in Patients with Multimorbidity",
    x = "Chronic Disease Combinations",
    y = "Proportion"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text.x = element_text(size = 10, angle = 45, hjust = 1),
    legend.position = "none"
  )

# Histogram for Tetrads and Pentads (bars coloured by category)
histogram_tetrads_pentads = ggplot(tetrads_and_pentads_top, aes(x = reorder(combinations, -proportion), y = proportion, fill = combination_category)) +
  geom_bar(stat = "identity", color = "black", alpha = 0.8) +
  geom_text(aes(label = round(proportion, 2)), vjust = -0.5, size = 3.5) +  # Add values above bars
  scale_fill_manual(values = c("Tetrad" = "blue", "Pentad" = "orange")) +
  labs(
    title = "Top 20 Chronic Disease Tetrad & Pentad Combinations in Patients with Multimorbidity",
    x = "Chronic Disease Combinations",
    y = "Proportion",
    fill = "Combination Type"
  ) +
  theme_minimal() +
  theme(
    plot.title = element_text(hjust = 0.5, size = 14, face = "bold"),
    axis.text.x = element_text(size = 10, angle = 45, hjust = 1),
    legend.title = element_text(size = 10)
  )

# Save the histograms
ggsave("top20_dyads_histogram.png", plot = histogram_dyads, width = 14, height = 8, dpi = 300)
ggsave("top20_triads_histogram.png", plot = histogram_triads, width = 14, height = 8, dpi = 300)
ggsave("tetrads_pentads_histogram.png", plot = histogram_tetrads_pentads, width = 14, height = 8, dpi = 300)

**HOT! Correlations based on subgroups**

This section contains R scripts to analyze chronic disease combinations by demographic subgroups, generating heatmaps and bar plots to visualize the frequency and distribution of combinations across different ethnic, racial, age, and sex groups. It also includes a function to compare a specific combination across all demographic variables.

In [None]:
# R Script for Chronic Disease Combination Analysis by Demographic Subgroups
# Focuses on combination frequencies across demographic groups
#
# Defines three plotting helpers (subgroup heatmap, frequency bar plot,
# per-combination demographic breakdown) and a driver that runs them for the
# combo_cat values "dyad", "triad", "tetrad" and "pentad".

library(ggplot2)
library(reshape2)   # melt() used by the heatmap helper
library(dplyr)      # group_by()/summarise() pipelines and sym()
library(gridExtra)  # grid.arrange() for the multi-panel demographic plot

# List of demographic subgroup columns; read as a global by
# compare_combo_across_demographics() and by the driver function
demographic_vars = c("eth_corrected", "race_corrected", "age_groups", "sex")

# Function to generate heatmap showing combination frequencies by subgroup
#
# Arguments:
#   data         data frame with `combo_cat`, `combinations` and the column
#                named by `subgroup_var`
#   category     value of `combo_cat` to plot (e.g. "dyad")
#   subgroup_var name (string) of the demographic column shown on the y axis
#   top_n        number of most frequent combinations (counted across all
#                subgroups) used as columns, so every row shares the same set
#   title        plot title; a default is built when NULL
#
# Each tile is the share of that subgroup's `category` rows having the given
# combination. Rows whose subgroup or combination is NA get no row/column.
# Returns a ggplot object, or NULL (with a message) when there is nothing to plot.
generate_subgroup_heatmap = function(data, category, subgroup_var,
                                    top_n = 20, title = NULL) {

  # Default title if none provided
  if (is.null(title)) {
    title = paste0("Top ", top_n, " ", category, " combinations by ", subgroup_var)
  }

  # Filter data for the specific combination category; which() drops rows
  # with a missing combo_cat instead of turning them into all-NA rows
  subset_data = data[which(data$combo_cat == category), ]

  # Skip if no data for this category
  if (nrow(subset_data) == 0) {
    message(paste("No data for", category))
    return(NULL)
  }

  # Top N combinations overall so that every subgroup uses the same columns
  overall_top_combos = subset_data %>%
    filter(!is.na(combinations)) %>%
    group_by(combinations) %>%
    summarise(count = n(), .groups = "drop") %>%
    arrange(desc(count)) %>%
    slice_head(n = top_n) %>%
    pull(combinations) %>%
    as.character()

  # Count each top combination within each subgroup
  combo_counts = subset_data %>%
    filter(combinations %in% overall_top_combos) %>%
    group_by(!!sym(subgroup_var), combinations) %>%
    summarise(count = n(), .groups = "drop")

  # Work with plain character labels. The driver converts demographic columns
  # to factors, and indexing a matrix with a factor uses its integer codes,
  # not its labels, which wrote counts into the wrong rows (or out of bounds).
  sg_labels = as.character(subset_data[[subgroup_var]])
  subgroups = unique(sg_labels)
  subgroups = subgroups[!is.na(subgroups)]

  if (length(subgroups) == 0 || length(overall_top_combos) == 0) {
    message(paste("No subgroup/combination data for", category))
    return(NULL)
  }

  # Subgroup x combination count matrix (zeros for pairs that never occur)
  result_matrix = matrix(0,
                         nrow = length(subgroups),
                         ncol = length(overall_top_combos),
                         dimnames = list(subgroups, overall_top_combos))

  count_sg = as.character(combo_counts[[subgroup_var]])
  count_combo = as.character(combo_counts$combinations)
  for (i in seq_len(nrow(combo_counts))) {
    if (is.na(count_sg[i])) next  # NA subgroup has no row in the grid
    result_matrix[count_sg[i], count_combo[i]] = combo_counts$count[i]
  }

  # Divide each row by the subgroup's total number of `category` rows
  # (table() leaves out NA labels; every remaining subgroup total is > 0)
  subgroup_totals = as.numeric(table(sg_labels)[subgroups])
  prop_matrix = sweep(result_matrix, 1, subgroup_totals, "/")

  # Melt matrix for ggplot
  melted_data = melt(prop_matrix)
  names(melted_data) = c(subgroup_var, "Combination", "Proportion")

  # Create heatmap
  p = ggplot(melted_data, aes(x = Combination, y = !!sym(subgroup_var), fill = Proportion)) +
    geom_tile(color = "white", linewidth = 0.2) +
    scale_fill_gradient(low = "white", high = "darkred",
                        name = "Proportion",
                        limits = c(0, max(melted_data$Proportion)),
                        guide = guide_colorbar(title.position = "top")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1, size = 8),
          axis.text.y = element_text(size = 10),
          axis.title = element_blank(),
          panel.grid = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          plot.title = element_text(hjust = 0.5, size = 12, face = "bold"),
          plot.margin = margin(10, 10, 10, 10)) +
    ggtitle(title)

  return(p)
}

# Function to generate frequency distribution of combinations by subgroup
#
# Arguments:
#   data         data frame with `combo_cat`, `combinations` and `subgroup_var`
#   category     value of `combo_cat` to plot (e.g. "dyad")
#   subgroup_var name (string) of the demographic column used for bar colour
#   top_n        number of most frequent combinations (overall) to show
#   title        plot title; a default is built when NULL
#
# Bar height = percentage of the subgroup's rows among the top_n combinations
# (the denominator counts only rows with one of the top combinations).
# Returns a ggplot object, or NULL (with a message) when `category` has no rows,
# matching generate_subgroup_heatmap().
generate_frequency_barplot = function(data, category, subgroup_var, top_n = 10,
                                     title = NULL) {

  # Default title if none provided
  if (is.null(title)) {
    title = paste0("Frequency of top ", top_n, " ", category, " combinations by ", subgroup_var)
  }

  # Filter data for the specific combination category; which() drops rows
  # with a missing combo_cat instead of turning them into all-NA rows
  subset_data = data[which(data$combo_cat == category), ]

  # Nothing to plot for this category
  if (nrow(subset_data) == 0) {
    message(paste("No data for", category))
    return(NULL)
  }

  # Get top N combinations overall (missing combinations are not candidates)
  overall_top_combos = subset_data %>%
    filter(!is.na(combinations)) %>%
    group_by(combinations) %>%
    summarise(count = n(), .groups = "drop") %>%
    arrange(desc(count)) %>%
    slice_head(n = top_n) %>%
    pull(combinations)

  # Filter to top combinations and count by subgroup
  plot_data = subset_data %>%
    filter(combinations %in% overall_top_combos) %>%
    group_by(!!sym(subgroup_var), combinations) %>%
    summarise(count = n(), .groups = "drop")

  # Convert to percentages within each subgroup
  plot_data = plot_data %>%
    group_by(!!sym(subgroup_var)) %>%
    mutate(percentage = count / sum(count) * 100) %>%
    ungroup()

  # Create barplot (aes() with an injected symbol replaces deprecated aes_string())
  p = ggplot(plot_data, aes(x = combinations, y = percentage, fill = !!sym(subgroup_var))) +
    geom_bar(stat = "identity", position = "dodge") +
    scale_fill_brewer(palette = "Set1") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1, size = 8),
          legend.title = element_blank(),
          plot.title = element_text(hjust = 0.5, size = 12, face = "bold")) +
    labs(title = title,
         x = "Disease Combination",
         y = "Percentage within Subgroup")

  return(p)
}

# Function to compare a specific combination across all demographic variables
#
# Arguments:
#   data        data frame with a `combinations` column and demographic columns
#   combination the exact `combinations` value to profile
#   title       overall title; a default is built when NULL
#   demo_vars   demographic columns to plot (defaults to the global
#               `demographic_vars`); columns missing from `data` are skipped
#
# Draws one bar chart per demographic variable (percentage of the
# combination's rows in each subgroup), arranged in a 2-column grid via
# gridExtra::grid.arrange(), which also draws it. Returns that arranged
# object, or NULL (with a warning) when none of `demo_vars` is in `data`.
compare_combo_across_demographics = function(data, combination, title = NULL,
                                             demo_vars = demographic_vars) {

  # Default title
  if (is.null(title)) {
    title = paste0("Distribution of '", combination, "' across demographic subgroups")
  }

  # Filter data for the specific combination; which() drops rows with a
  # missing combination instead of turning them into all-NA rows
  subset_data = data[which(data$combinations == combination), ]

  # The driver only warns about absent demographic columns, so skip them here
  # rather than failing on the lookup
  demo_vars = intersect(demo_vars, names(data))

  # Create empty list to store plots
  plots = list()

  # Create a plot for each demographic variable
  for (demo_var in demo_vars) {
    # Count occurrences by subgroup
    counts = subset_data %>%
      group_by(!!sym(demo_var)) %>%
      summarise(count = n(), .groups = "drop")

    # Calculate percentages
    counts$percentage = counts$count / sum(counts$count) * 100

    # !!sym() injects the column name now. A plain reference to the loop
    # variable inside aes() would be evaluated lazily when the grid is drawn,
    # i.e. with the LAST demo_var for every panel.
    p = ggplot(counts, aes(x = !!sym(demo_var), y = percentage, fill = !!sym(demo_var))) +
      geom_bar(stat = "identity") +
      scale_fill_brewer(palette = "Set2") +
      theme_minimal() +
      theme(axis.text.x = element_text(angle = 45, vjust = 0.5, hjust = 1),
            legend.position = "none",
            plot.title = element_text(size = 10)) +
      labs(title = paste0("By ", demo_var),
           y = "Percentage",
           x = NULL)

    plots[[demo_var]] = p
  }

  if (length(plots) == 0) {
    warning("None of the demographic variables are present in the data")
    return(NULL)
  }

  # Arrange plots in a grid
  combined_plot = do.call(grid.arrange, c(plots, top = title, ncol = 2))

  return(combined_plot)
}

# Main function to analyze disease combinations by demographic subgroups
#
# Arguments:
#   df         data frame with `combinations`, `combo_cat` and the columns in
#              the global `demographic_vars`
#   output_dir directory for PNG files (created if needed); NULL = do not save
#
# For every combo_cat in "dyad", "triad", "tetrad", "pentad" and every
# demographic variable present, builds a heatmap (top 15 combinations) and a
# frequency bar plot (top 5). It then profiles the 5 most common combinations
# overall across all demographics.
# Returns a named list of all plots that were produced.
analyze_disease_combinations_by_demographics = function(df, output_dir = NULL) {

  # Create output directory if specified and doesn't exist
  if (!is.null(output_dir)) {
    if (!dir.exists(output_dir)) {
      dir.create(output_dir, recursive = TRUE)
    }
  }

  # Make sure combinations column is character type
  df$combinations = as.character(df$combinations)
  df$combo_cat = as.character(df$combo_cat)

  # Ensure demographic variables are factors
  for (var in demographic_vars) {
    if (var %in% names(df)) {
      df[[var]] = factor(df[[var]])
    } else {
      warning(paste("Variable", var, "not found in the dataset"))
    }
  }

  # Results storage
  all_plots = list()

  # Generate heatmaps for each demographic variable and combination category
  for (category in c("dyad", "triad", "tetrad", "pentad")) {
    for (demo_var in demographic_vars) {
      # Skip if demographic variable not in dataframe
      if (!(demo_var %in% names(df))) next

      # Generate heatmap
      plot_title = paste0("Top ", category, " combinations by ", demo_var)
      p = generate_subgroup_heatmap(df, category, demo_var, top_n = 15, title = plot_title)

      # Save plot if not NULL
      if (!is.null(p)) {
        plot_name = paste0(category, "_by_", demo_var)
        all_plots[[plot_name]] = p

        # Save to file if output directory specified
        if (!is.null(output_dir)) {
          ggsave(file.path(output_dir, paste0(plot_name, ".png")),
                 p, width = 12, height = 8, dpi = 300)
        }
      }

      # Generate frequency barplot
      p_bar = generate_frequency_barplot(df, category, demo_var, top_n = 5)

      # Save plot if not NULL
      if (!is.null(p_bar)) {
        plot_name = paste0(category, "_freq_by_", demo_var)
        all_plots[[plot_name]] = p_bar

        # Save to file if output directory specified (this check previously
        # called a non-existent is.isNull(), which stopped the whole run)
        if (!is.null(output_dir)) {
          ggsave(file.path(output_dir, paste0(plot_name, ".png")),
                 p_bar, width = 12, height = 8, dpi = 300)
        }
      }
    }
  }

  # Find top 5 most common (non-missing) combinations overall
  top_combos = df %>%
    filter(!is.na(combinations)) %>%
    group_by(combinations) %>%
    summarise(count = n(), .groups = "drop") %>%
    arrange(desc(count)) %>%
    slice_head(n = 5) %>%
    pull(combinations)

  # Compare top combinations across all demographic variables
  for (combo in top_combos) {
    p_demo = compare_combo_across_demographics(df, combo)
    if (is.null(p_demo)) next

    # File-safe name: each run of characters other than letters, digits,
    # ".", "_" or "-" becomes "_" ("CHF+DM" -> "CHF_DM" as before, while "/",
    # spaces, etc. can no longer produce an invalid path)
    plot_name = paste0("combo_", gsub("[^A-Za-z0-9._-]+", "_", combo), "_demographics")
    all_plots[[plot_name]] = p_demo

    # Save to file if output directory specified
    if (!is.null(output_dir)) {
      ggsave(file.path(output_dir, paste0(plot_name, ".png")),
             p_demo, width = 10, height = 8, dpi = 300)
    }
  }

  # Return all plots
  return(all_plots)
}

# Run the subgroup analysis on the multimorbidity cohort; PNGs go to ./output
plots = analyze_disease_combinations_by_demographics(df = TGA, output_dir = "output")

**Correlations based on subgroups**

This R script generates heatmaps showing the frequency of chronic disease combinations (dyads, triads, and combined tetrads/pentads) by demographic subgroups.

In [None]:
# R Script for Chronic Disease Combination Analysis by Demographic Subgroups
# Analyzes combinations by "eth_corrected", "race_corrected", "age_groups", "sex"

library(ggplot2)
library(reshape2)

# List of demographic subgroups
demo_vars = c("eth_corrected", "race_corrected", "age_groups", "sex")

# Correct Age Group Order
cor_age_order = c("18-29", "30-39", "40-49", "50-59", "60-69", "70-79", "80-89")
# Convert Age Group to a factor so that it is ordered correctly.
# factor() turns any value not listed in cor_age_order into NA, so any other
# age-group label in TGA (if present — TODO confirm) is silently lost here.
TGA$age_groups = factor(TGA$age_groups, levels = cor_age_order)

# Function to generate heatmap showing frequency of combinations by subgroup
generate_subgroup_heatmap <- function(data, category, subgroup_var,
                                      top_n = NULL, title = "Disease Combinations Heatmap") {
  # Build a heatmap of disease combinations by demographic subgroup.
  #
  # For one combination category (e.g. "dyad", "triad"), count how many rows
  # of each combination fall in each value of `subgroup_var`. Then divide
  # each column by that combination's total, so a tile is the share of a
  # combination's rows that belong to a subgroup.
  #
  # Arguments:
  #   data         - data frame with `combo_cat`, `combinations`, and the
  #                  column named by `subgroup_var`.
  #   category     - value(s) of `combo_cat` to keep.
  #   subgroup_var - name (string) of the demographic column to split by.
  #   top_n        - if not NULL, keep only the `top_n` most frequent
  #                  combinations within the category.
  #   title        - plot title.
  # Returns a ggplot object (x = combination, y = subgroup, fill = proportion).

  # Use `%in%` rather than `==` so rows whose combo_cat is NA are dropped
  # instead of being turned into all-NA rows. Rows with a missing combination
  # are dropped as well; they would otherwise yield NA counts and an error in
  # the proportion step.
  keep <- data$combo_cat %in% category & !is.na(data$combinations)
  subset_data <- data[keep, , drop = FALSE]

  # Optionally restrict to the top_n most frequent combinations
  combo_counts <- table(subset_data$combinations)
  if (!is.null(top_n) && length(combo_counts) > top_n) {
    top_combos <- head(names(sort(combo_counts, decreasing = TRUE)), top_n)
    subset_data <- subset_data[subset_data$combinations %in% top_combos, , drop = FALSE]
  }

  # Matrix rows/columns follow the order of first appearance in the data
  unique_combos <- unique(subset_data$combinations)
  subgroup_col <- subset_data[[subgroup_var]]
  subgroup_values <- unique(subgroup_col)

  combo_matrix <- matrix(0, nrow = length(subgroup_values), ncol = length(unique_combos))
  rownames(combo_matrix) <- subgroup_values
  colnames(combo_matrix) <- unique_combos

  # Count each combination within each subgroup. seq_along() avoids the
  # 1:0 trap when the category has no rows. which() treats NA comparisons
  # as "no match", so an NA subgroup row gets all zeros.
  for (i in seq_along(subgroup_values)) {
    in_group <- which(subgroup_col == subgroup_values[i])
    combo_matrix[i, ] <- tabulate(match(subset_data$combinations[in_group], unique_combos),
                                  nbins = length(unique_combos))
  }

  # Column-wise proportions: share of each combination's rows per subgroup.
  # A column with a zero total (0/0 = NaN) is reported as 0.
  prop_matrix <- prop.table(combo_matrix, margin = 2)
  prop_matrix[is.nan(prop_matrix)] <- 0

  # Melt matrix for ggplot
  melted_data <- melt(prop_matrix)
  names(melted_data) <- c("Subgroup", "Combination", "Proportion")

  # Keep the age-group axis in chronological order (cor_age_order is a
  # global defined alongside demo_vars)
  if (subgroup_var == "age_groups") {
    melted_data$Subgroup <- factor(melted_data$Subgroup, levels = cor_age_order)
  }

  # Create heatmap with values displayed; light text on dark (>0.5) tiles
  p <- ggplot(melted_data, aes(x = Combination, y = Subgroup, fill = Proportion)) +
    geom_tile(color = "black", linewidth = 0.1) +
    geom_text(aes(label = sprintf("%.2f", Proportion)),
              color = ifelse(melted_data$Proportion > 0.5, "white", "black"),
              size = 3.0) +
    scale_fill_gradient(low = "blue", high = "orange",
                        name = "Proportion",
                        guide = guide_colorbar(title.position = "top")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1, size = 10),
          axis.text.y = element_text(size = 10),
          axis.title = element_blank(),
          panel.grid = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          plot.title = element_text(hjust = 0.5, size = 12, face = "bold"),
          plot.margin = margin(10, 10, 10, 10)) +
    ggtitle(title)

  p
}

# Function to combine tetrads and pentads in one heatmap by subgroup
generate_combined_subgroup_heatmap <- function(data, subgroup_var,
                                               title = "Tetrads and Pentads by Subgroup",
                                               top_n = 20) {
  # Build one heatmap of tetrads and pentads together, split by subgroup.
  #
  # Same layout as the per-category heatmap: each tile is the share of a
  # combination's rows that fall in a given subgroup (column-wise proportion).
  #
  # Arguments:
  #   data         - data frame with `combo_cat`, `combinations`, and the
  #                  column named by `subgroup_var`.
  #   subgroup_var - name (string) of the demographic column to split by.
  #   title        - plot title.
  #   top_n        - keep only the `top_n` most frequent tetrad/pentad
  #                  combinations (default 20, the previously hard-coded
  #                  value); NULL keeps all of them.
  # Returns a ggplot object (x = combination, y = subgroup, fill = proportion).

  # Keep tetrads and pentads. Rows with a missing combination are dropped;
  # they would otherwise yield NA counts and an error in the proportion step.
  keep <- data$combo_cat %in% c("tetrad", "pentad") & !is.na(data$combinations)
  subset_data <- data[keep, , drop = FALSE]

  # Restrict to the top_n most frequent combinations overall
  combo_counts <- table(subset_data$combinations)
  if (!is.null(top_n) && length(combo_counts) > top_n) {
    top_combos <- head(names(sort(combo_counts, decreasing = TRUE)), top_n)
    subset_data <- subset_data[subset_data$combinations %in% top_combos, , drop = FALSE]
  }

  # Matrix rows/columns follow the order of first appearance in the data
  unique_combos <- unique(subset_data$combinations)
  subgroup_col <- subset_data[[subgroup_var]]
  subgroup_values <- unique(subgroup_col)

  combo_matrix <- matrix(0, nrow = length(subgroup_values), ncol = length(unique_combos))
  rownames(combo_matrix) <- subgroup_values
  colnames(combo_matrix) <- unique_combos

  # Count each combination within each subgroup. seq_along() avoids the
  # 1:0 trap when there are no rows. which() treats NA comparisons as
  # "no match", so an NA subgroup row gets all zeros.
  for (i in seq_along(subgroup_values)) {
    in_group <- which(subgroup_col == subgroup_values[i])
    combo_matrix[i, ] <- tabulate(match(subset_data$combinations[in_group], unique_combos),
                                  nbins = length(unique_combos))
  }

  # Column-wise proportions: share of each combination's rows per subgroup.
  # A column with a zero total (0/0 = NaN) is reported as 0.
  prop_matrix <- prop.table(combo_matrix, margin = 2)
  prop_matrix[is.nan(prop_matrix)] <- 0

  # Melt matrix for ggplot
  melted_data <- melt(prop_matrix)
  names(melted_data) <- c("Subgroup", "Combination", "Proportion")

  # Keep the age-group axis in chronological order (cor_age_order is a
  # global defined alongside demo_vars)
  if (subgroup_var == "age_groups") {
    melted_data$Subgroup <- factor(melted_data$Subgroup, levels = cor_age_order)
  }

  # Create heatmap with values displayed; smaller text because there are
  # usually more (and longer) combination labels than for dyads/triads
  p <- ggplot(melted_data, aes(x = Combination, y = Subgroup, fill = Proportion)) +
    geom_tile(color = "black", linewidth = 0.1) +
    geom_text(aes(label = sprintf("%.2f", Proportion)),
              color = ifelse(melted_data$Proportion > 0.5, "white", "black"),
              size = 2.5) +
    scale_fill_gradient(low = "blue", high = "orange",
                        name = "Proportion",
                        guide = guide_colorbar(title.position = "top")) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1, size = 8),
          axis.text.y = element_text(size = 10),
          axis.title = element_blank(),
          panel.grid = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          plot.title = element_text(hjust = 0.5, size = 12, face = "bold"),
          plot.margin = margin(10, 10, 10, 10)) +
    ggtitle(title)

  p
}

# Main function - assign directly to your dataframe
generate_disease_subgroup_heatmaps <- function(df,
                                               dyad_title = "Top 20 Disease Dyads",
                                               triad_title = "Top 20 Disease Triads",
                                               combined_title = "Disease Tetrads and Pentads",
                                               output_dir = ".",
                                               subgroup_vars = demo_vars) {
  # Build, save, and return the dyad, triad, and tetrad/pentad heatmaps for
  # every demographic variable present in `df`.
  #
  # Arguments:
  #   df             - data frame with `combinations`, `combo_cat`, and the
  #                    demographic columns.
  #   dyad_title, triad_title, combined_title - title prefixes; " by <var>"
  #                    is appended to each.
  #   output_dir     - directory the PNGs are written to (created if needed).
  #                    Defaults to the working directory, as before.
  #   subgroup_vars  - demographic columns to use; defaults to the global
  #                    `demo_vars`. Names not found in `df` are skipped.
  # Returns a named list of ggplot objects keyed "dyad_<var>", "triad_<var>",
  # and "combined_<var>". Errors if none of `subgroup_vars` is in `df`.

  # Make sure combination columns are character type
  df$combinations <- as.character(df$combinations)
  df$combo_cat <- as.character(df$combo_cat)

  # Keep only the demographic variables that exist in the dataframe
  available_demo_vars <- subgroup_vars[subgroup_vars %in% names(df)]

  if (length(available_demo_vars) == 0) {
    stop("None of the demographic variables found in the dataframe")
  }

  if (!dir.exists(output_dir)) {
    dir.create(output_dir, recursive = TRUE)
  }

  # Write one heatmap as "<prefix>_heatmap_by_<demo_var>.png" in output_dir
  save_heatmap <- function(plot, prefix, demo_var, width) {
    ggsave(file.path(output_dir, paste0(prefix, "_heatmap_by_", demo_var, ".png")),
           plot, width = width, height = 8, dpi = 300)
  }

  all_plots <- list()

  for (demo_var in available_demo_vars) {
    # Dyad heatmap by subgroup (top 20 dyads)
    dyad_plot <- generate_subgroup_heatmap(df, "dyad", demo_var, top_n = 20,
                                           title = paste0(dyad_title, " by ", demo_var))
    save_heatmap(dyad_plot, "dyad", demo_var, width = 14)
    all_plots[[paste0("dyad_", demo_var)]] <- dyad_plot

    # Triad heatmap by subgroup (top 20 triads)
    triad_plot <- generate_subgroup_heatmap(df, "triad", demo_var, top_n = 20,
                                            title = paste0(triad_title, " by ", demo_var))
    save_heatmap(triad_plot, "triad", demo_var, width = 14)
    all_plots[[paste0("triad_", demo_var)]] <- triad_plot

    # Combined tetrad and pentad heatmap by subgroup (wider: more labels)
    combined_plot <- generate_combined_subgroup_heatmap(df, demo_var,
                                                        title = paste0(combined_title, " by ", demo_var))
    save_heatmap(combined_plot, "tetrad_pentad", demo_var, width = 16)
    all_plots[[paste0("combined_", demo_var)]] <- combined_plot
  }

  all_plots
}

# Use this line with your dataframe
# Build and save every subgroup heatmap for the multimorbidity data (TGA);
# `plots` holds the ggplot objects keyed "dyad_<var>", "triad_<var>", "combined_<var>".
plots <- generate_disease_subgroup_heatmaps(df = TGA)

## Part 2b: Tetrachoric Correlation Matrix of Each Chronic Disease Against the Others (Pairwise)

This section focuses on calculating and visualizing the tetrachoric correlations between individual chronic diseases within both the non-multimorbidity and multimorbidity groups.

This code calculates the tetrachoric correlation matrix for chronic diseases in both the `CGA` and `TGA` dataframes. It then generates heatmaps to visualize these correlations: one for each dataframe, plus a side-by-side comparison.

In [None]:
# Load necessary libraries
library(psych)
library(ggplot2)
library(reshape2)
library(gridExtra) # For combining plots

# Function to calculate tetrachoric correlations and save heatmap as PNG
generate_and_save_heatmap <- function(data, title, file_name,
                                      columns = c("AIDS", "ALCOHOL", "ANEMDEF", "ARTH", "BLDLOSS", "CHF",
                                                  "CHRNLUNG", "COAG", "DEPRESS", "DM", "DMCX", "DRUG",
                                                  "HTN_C", "HYPOTHY", "LIVER", "LYMPH", "LYTES", "METS",
                                                  "NEURO", "OBESE", "PARA", "PERIVASC", "PSYCH",
                                                  "PULMCIRC", "RENLFAIL", "TUMOR", "ULCER", "VALVE",
                                                  "WGHTLOSS", "ARRHYTH")) {
  # Compute tetrachoric correlations between chronic-disease columns, save
  # the heatmap as a PNG, and return it.
  #
  # Arguments:
  #   data      - data frame containing every column named in `columns`.
  #   title     - plot title.
  #   file_name - output path passed to ggsave().
  #   columns   - disease columns to correlate (defaults to the 30 disease
  #               flags used throughout this notebook). psych::tetrachoric()
  #               expects dichotomous (e.g. 0/1) data.
  # Returns the ggplot heatmap object (useful for combining plots later).

  # Alphabetize variables so the axes are in a stable, alphabetical order
  columns <- sort(columns)

  # Fail early with a clear message instead of "undefined columns selected"
  missing_cols <- setdiff(columns, names(data))
  if (length(missing_cols) > 0) {
    stop("Columns not found in data: ", paste(missing_cols, collapse = ", "),
         call. = FALSE)
  }

  # Missing values are passed to tetrachoric() as-is
  data_subset <- data[columns]
  corr_matrix <- tetrachoric(data_subset)$rho

  # Long format: one row per (Var1, Var2) pair
  melted_corr <- melt(corr_matrix)
  colnames(melted_corr) <- c("Var1", "Var2", "Correlation")

  # Heatmap: diverging blue-white-orange scale fixed to [-1, 1]; |r| > 0.5
  # printed in bold
  heatmap_plot <- ggplot(melted_corr, aes(x = Var1, y = Var2, fill = Correlation)) +
    geom_tile(color = "white") +
    scale_fill_gradient2(low = "blue", high = "orange", mid = "white", midpoint = 0,
                         limits = c(-1, 1), space = "Lab", name = "Correlation") +
    geom_text(aes(label = sprintf("%.2f", Correlation),
                  fontface = ifelse(abs(Correlation) > 0.5, "bold", "plain")), size = 3.5) +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1, size = 10),
          axis.text.y = element_text(size = 10),
          plot.title = element_text(size = 14, face = "bold"),
          plot.margin = margin(10, 10, 10, 10)) +
    labs(title = title, x = "", y = "") +
    coord_fixed(ratio = 0.8)  # Adjust tile aspect ratio

  ggsave(file_name, plot = heatmap_plot, width = 14, height = 10)

  heatmap_plot
}

# Save individual heatmaps
# One tetrachoric heatmap (and PNG) per group:
# CGA = non-multimorbidity, TGA = multimorbidity
heatmap_CGA <- generate_and_save_heatmap(
  data = CGA,
  title = "Correlation Heatmap Between Chronic Diseases in Patients with Non-Multimorbidity",
  file_name = "CGA_Heatmap_tetrachoric.png"
)
heatmap_TGA <- generate_and_save_heatmap(
  data = TGA,
  title = "Correlation Heatmap Between Chronic Diseases in Patients with Multimorbidity",
  file_name = "TGA_Heatmap_tetrachoric.png"
)

# Draw the two heatmaps side by side; grid.arrange() also returns the
# arranged object, which is then written to disk
combined_plot <- grid.arrange(heatmap_CGA, heatmap_TGA, ncol = 2)
ggsave("combined_heatmap_fixed_width.png", combined_plot, width = 26, height = 10)

## Part 2c: Correlation Matrix for Combinations vs Chronic Diseases

This section examines the correlations between specific chronic disease combinations and individual chronic diseases within the multimorbidity group.

This code identifies the top chronic disease combinations, binarizes them, and then calculates the tetrachoric correlations between these top combinations and the individual chronic diseases within the `TGA` dataframe. Heatmaps are generated to visualize these correlations, highlighting strong relationships.