# Visualize Data

* **Project:** ADRD Genetic Diversity in Biobanks
* **Version:** Python/3.10
* **Last Updated:** 24-FEB-2025

## Notebook Overview
This notebook visualizes the results of the analyses:
- protective/resilience variants (heatmap);
- allele frequencies (heatmap);
- APOE genotypes (bar plots);
- mutation sites on predicted protein structures;
- beta values (UpSet plot);
- ancestry (PCA plot).

## Variables used 
`${COHORT}` = AD, Dementia, Control

`${Biobank}` = UKB, AoU, ADSP, AMP PD, 100KGP

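`${Value}` = biobank-specific numeric values (e.g., the minimum and maximum of the allele-frequency colour scale)

`${WORK_DIR}` = working directory containing the input files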

### Using a Heatmap to Visualize Protective/Resilient Variants

In [None]:
import pandas as pd

# Preview the protective/resilient-variant table for the chosen cohort.
# ${COHORT} is a template placeholder (AD, Dementia or Control). pandas is imported
# here because this cell runs before the next cell's import.
df_z = pd.read_csv("Heatmap_protective_${COHORT}.csv")
df_z.head()

In [None]:
import pandas as pd

# Read the protective/resilient-variant table for the chosen cohort.
raw_table = pd.read_csv('Heatmap_protective_${COHORT}.csv')

# Keep only the rows that actually name a variant.
df_z = raw_table.dropna(subset=['Variants'])

# Every column after the first holds values; reduce per column, then overall.
per_column_max = df_z.iloc[:, 1:].max()
max_value = per_column_max.max()

print("The maximum value in the dataset is:", max_value)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize, LinearSegmentedColormap
import pandas as pd
import matplotlib.patches as patches

# Plot the protective/resilient-variant heatmaps for AD, related dementias and
# controls side by side and save them together as 'combined_heatmap.png'.

# Create a figure with three subplots (one row, three columns)
fig, axs = plt.subplots(1, 3, figsize=(30, 8))  

# Define file names, titles, and norms for each heatmap
# NOTE(review): these names ('HeatmapPlot_protective_*', lowercase 'control') differ
# from 'Heatmap_protective_${COHORT}.csv' read in the earlier cells — confirm the
# intended inputs (file names are case-sensitive on most Linux file systems).
file_names = ['HeatmapPlot_protective_AD.csv', 'HeatmapPlot_protective_Dementia.csv', 'HeatmapPlot_protective_control.csv']
titles = ['AD_e4 ', 'Related Dementias_e4', 'Controls_e4 ']
# TEMPLATE: vmin/vmax are blank placeholders to be filled in per cohort; until they
# are, this line is a SyntaxError.
norms = [Normalize(vmin=, vmax=), Normalize(vmin=, vmax=), Normalize(vmin=, vmax=)]
# x-positions, in colorbar axes coordinates, of each title drawn above its colorbar
text_positions = [0.8, 3.7, 1.9]  # Text positions for AD, Related Dementia, Controls

# Custom colormap to set 0 as white
# White is prepended to Spectral_r, so white marks the bottom of each norm's range
# (vmin); it corresponds to 0 only when vmin is 0.
colors = [(1, 1, 1), *sns.color_palette("Spectral_r", 256)]
custom_cmap = LinearSegmentedColormap.from_list("CustomSpectral", colors, N=256)

# Iterate over the files, titles, and subplots to create each heatmap
for i, (file_name, title, norm, text_position) in enumerate(zip(file_names, titles, norms, text_positions)):
    # Load the data and drop rows that do not name a variant
    df_z = pd.read_csv(file_name)
    df_z = df_z[df_z['Variants'].notna()]

    # List of ancestry columns (all columns except the first one)
    cols = list(df_z.columns)[1:]

    # Create heatmap with custom colormap and no internal borders
    sns.heatmap(
        df_z[cols],
        cmap=custom_cmap,
        norm=norm,
        cbar_kws={"shrink": 0.8},
        ax=axs[i],
        linewidths=0,  # No internal borders between cells
    )

    # Set axis labels
    axs[i].set_ylabel('SNPs', fontsize=13, fontweight='bold')
    axs[i].set_xlabel('Ancestry', fontsize=13, fontweight='bold')

    # Adjust y-ticks to be centered (heatmap rows span [k, k+1], so centres are k + 0.5)
    y_ticks = range(len(df_z))  
    axs[i].set_yticks([tick + 0.5 for tick in y_ticks])
    axs[i].set_yticklabels(df_z['Variants'], rotation=0, fontsize=13, ha='right', va='center')

    # Keep x-ticks as default (top-aligned)
    axs[i].set_xticklabels(axs[i].get_xticklabels(), rotation=45, ha='right', fontsize=11, va='top')

    # Draw a border around the outer edge of the heatmap (4 sides)
    border = patches.Rectangle((0, 0), 1, 1, transform=axs[i].transAxes, color='black', linewidth=2, fill=False)
    axs[i].add_patch(border)

    # Get the colorbar from the heatmap
    colorbar = axs[i].collections[0].colorbar

    # Adjust colorbar tick label size
    colorbar.ax.tick_params(labelsize=13)

    # Adjust text label position above the colorbar
    colorbar.ax.text(text_position, 1.05, title, ha='center', va='center', fontsize=13, fontweight='bold', transform=colorbar.ax.transAxes)

    # Add a label to the colorbar to indicate percentage
    colorbar.ax.set_ylabel('Percentage', fontsize=12)

# Adjust layout to prevent overlap
plt.tight_layout()

# Save the combined plot
plt.savefig('combined_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


def _log2_ratio_with_sentinels(df_ratio):
    """Convert a case/control frequency-ratio table into log2 values for plotting.

    Cells with ratio +inf (allele seen only in cases) become +sentinel. Cells with
    ratio 0 (seen only in controls) become -sentinel. The sentinel is one more than
    the largest finite |log2 ratio|, or 1 when no finite non-zero ratio exists.
    Undefined cells (NaN, e.g. 0/0 or labels present in only one table) are
    returned in a mask so the heatmap can hide them.

    Returns:
        (arr_log, mask_na, sentinel)
    """
    arr_ratio = df_ratio.to_numpy(dtype=float)
    mask_inf = np.isposinf(arr_ratio)
    mask_zero = arr_ratio == 0
    mask_na = np.isnan(arr_ratio)
    mask_other = ~(mask_inf | mask_zero | mask_na)

    arr_log = np.empty(arr_ratio.shape)
    # Take log2 of the finite non-zero ratios only, so log2(0) warnings never occur.
    arr_log[mask_other] = np.log2(arr_ratio[mask_other])
    # np.max of an empty selection raises, so fall back to 0 when no finite ratio exists.
    max_abs = np.abs(arr_log[mask_other]).max() if mask_other.any() else 0.0
    sentinel = max_abs + 1
    arr_log[mask_inf] = sentinel
    arr_log[mask_zero] = -sentinel
    # NaN cells are hidden through the heatmap mask; fill them so the array has no NaNs.
    arr_log[mask_na] = -sentinel
    return arr_log, mask_na, sentinel


cmap = sns.color_palette("Spectral_r", as_cmap=True)

for grouping in ["all", "e4e4", "e4"]:
    df_control = pd.read_csv(f"Control_{grouping}.csv", index_col=0)
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    for i, pheno in enumerate(["AD", "RD"]):
        df_case = pd.read_csv(f"{pheno}_{grouping}.csv", index_col=0)
        # pandas aligns on index/columns; labels missing from either table give NaN.
        df_ratio = df_case / df_control
        arr_log, mask_na, sentinel = _log2_ratio_with_sentinels(df_ratio)

        # Pin the colour range to the sentinels so both "Only ..." ticks always lie on
        # the colorbar. Otherwise seaborn derives the range from the unmasked data only,
        # and the end labels can land on the wrong ticks.
        ax = sns.heatmap(
            arr_log,
            cmap=cmap,
            vmin=-sentinel,
            vmax=sentinel,
            cbar=True,
            xticklabels=df_ratio.columns.values,
            yticklabels=df_ratio.index.values,
            ax=axes[i],
            mask=mask_na,
        )

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color("black")
            spine.set_linewidth(1)

        cbar = ax.collections[0].colorbar
        # Integer log2 ticks strictly between the sentinels (never duplicating an end
        # tick), labelled as plain case/control ratios (2**tick).
        inner_ticks = np.arange(np.floor(-sentinel) + 1, np.ceil(sentinel))
        cbar.set_ticks(np.concatenate(([-sentinel], inner_ticks, [sentinel])))
        cbar.set_ticklabels(
            ["Only controls"]
            + [f"{2**tick:.3f}" for tick in inner_ticks]
            + ["Only cases"]
        )

        ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
        ax.set_ylabel('SNPs', fontweight="bold")
        ax.set_xlabel('Ancestry', fontweight="bold")

        if pheno == "AD":
            ax.set_title(f"{pheno} vs Control Frequency Ratio")
        else:
            ax.set_title("Related Dementias vs Control Frequency Ratio")

    plt.tight_layout()
    plt.savefig(f"Merged_{grouping}.png")
    plt.close(fig)


### Using a Heatmap to Visualize Allele Frequencies Across All Biobanks

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
import pandas as pd
import numpy as np

# Plot a heatmap of variant allele frequencies per ancestry for one biobank.
# TEMPLATE: ${WORK_DIR}, ${Biobank} and ${Value} must be substituted before running;
# the `min_value = ${Value}` / `max_value = ${Value}` lines are not valid Python
# until they are.

# Load your data into a DataFrame from the specified CSV file
df_z = pd.read_csv('${WORK_DIR}/${Biobank}.csv')

# Filter out rows where the 'SNP' column is empty
df_z = df_z[df_z['SNP'].notna()]

# List of ancestry columns (all columns except the first one)
cols = list(df_z.columns)[1:]

# Create a colormap with a slightly darker color for zero and a linear scale for non-zero values
zero_color = '#E0E0E0'  

# Generate a colormap for non-zero values using the viridis color palette
nonzero_colors = sns.color_palette("viridis", n_colors=100)
# 101 colours in total: index 0 is zero_color, indices 1-100 are viridis.
colors = [zero_color] + nonzero_colors
custom_cmap = ListedColormap(colors)

# Define boundaries for the colormap and normalization
min_value = ${Value}
max_value = ${Value}
# 100 boundaries define 99 bins, fewer than the 101 colours.
# NOTE(review): BoundaryNorm maps the lowest bin to colour index 0, so zero_color is
# used for every value in [min_value, first boundary), and with clip=True also for
# anything below min_value. It is not reserved for exact zeros — confirm this is intended.
boundaries = np.linspace(min_value, max_value, 100)
norm = BoundaryNorm(boundaries, custom_cmap.N, clip=True)

# Increase the figure size 
plt.figure(figsize=(3, 6))  

# Create a heatmap with custom color palette and normalization
ax = sns.heatmap(df_z[cols], cmap=custom_cmap, norm=norm, cbar_kws={"shrink": 0.8})

# Set axis labels with bold font
ax.set_ylabel('SNPs', fontsize=6, fontweight='bold') 
ax.set_xlabel('Ancestry', fontsize=6, fontweight='bold')  

# Add title 
plt.title('Allele frequencies of variants across all ancestries in ${Biobank}', fontsize=6, fontweight='bold')

# Set y-ticks and y-tick labels to match the number of variants (centred on each row)
ax.set_yticks(np.arange(len(df_z)) + 0.5)  
ax.set_yticklabels(df_z['SNP'], rotation=0, fontsize=5, ha='right')  

# Adjust the y-axis label 
ax.yaxis.labelpad = 10  

# Rotate x-axis labels 
plt.xticks(rotation=45, ha='right', fontsize=5)  

# Adjust colorbar ticks and labels to handle zero properly and make the colorbar less cluttered
colorbar = ax.collections[0].colorbar
# Set fewer ticks to avoid clutter
colorbar.set_ticks(np.linspace(min_value, max_value, 6))
colorbar.set_ticklabels([f'{tick:.4f}' for tick in np.linspace(min_value, max_value, 6)])

# Adjust colorbar font size 
# NOTE: the next three calls overlap (labelsize=3, then labelsize=2 with width=0.1,
# then width=1). The loop after them finally sets every tick label to size 3, bold.
colorbar.ax.tick_params(labelsize=3)  
colorbar.ax.yaxis.set_tick_params(labelsize=2, width=0.1, color='black', labelcolor='black')
colorbar.ax.yaxis.set_tick_params(width=1) 
for label in colorbar.ax.get_yticklabels():
    label.set_fontsize(3)  
    label.set_fontweight('bold')  

# Save the plot 
plt.savefig('${Biobank}_Final.png', dpi=300, bbox_inches='tight')
plt.show()

### Using a Barplot to Visualize the Results of APOE Genotypes

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Percentage of APOE e4/e4 carriers per ancestry, one row per cohort
data = {
    'Genotypes': ['AD_e4/e4', 'Related Dementias_e4/e4', 'Controls_e4/e4'],
    'EUR': [13.03,7.53,1.7],
    'AFR': [13.56,6.91,4.29],
    'AMR': [4.58,4.09,0.96],
    'EAS': [9.38,0,1.04],
    'SAS': [2.64,5.21,1.19],
    'MDE': [8.11,0,0.71],
    'AJ': [8.02,8.88,1],
    'FIN': [36.36,0,1.67],
    'AAC': [13.08,7.74,3.31],
    'CAS': [6.25,6.67,1.59],
    'CAH': [7.56,3.7,2.07]
}

# Create a DataFrame
df = pd.DataFrame(data)
ancestries = df.columns[1:]
# Numeric-only view (without the 'Genotypes' label column), cast to float. Rows then
# hold floats whatever dtype pandas inferred per column, and .is_integer() is
# always available (Python ints lack it before 3.12).
proportions = df[ancestries].astype(float)

# Define bar width and x locations for the groups
bar_width = 0.25
x = np.arange(len(ancestries))

# Define the space between bars and data labels
label_offset = 0.2

# Plot
plt.figure(figsize=(10, 10))

# Plot bars for each genotype, labelling each bar with its value
for i, genotype in enumerate(df['Genotypes']):
    heights = proportions.iloc[i]
    bars = plt.bar(x + i * bar_width, heights, width=bar_width, label=genotype)
    for bar, value in zip(bars, heights):
        # Whole numbers print without decimals (e.g. "0"), others with two.
        text = f'{int(value)}' if float(value).is_integer() else f'{value:.2f}'
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + label_offset, text, ha='center', va='bottom', color='black', fontsize=7, fontweight='bold', rotation=90)

# Set labels and title
plt.xlabel('Ancestries', fontweight='bold')
plt.ylabel('Proportion (%)', fontweight='bold')
plt.title('Proportions of APOE e4/e4 Across Different Ancestries')

# Centre each tick under its group of bars (valid for any number of cohorts)
plt.xticks(x + bar_width * (len(df) - 1) / 2, ancestries, rotation=45)

# Add legend
plt.legend(title='Genotype')

# Show plot
plt.tight_layout()
plt.savefig('proportions_e4e4.png', dpi=300)
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Data for the first plot (AD)
genotypes_first_plot = ['e2/e2', 'e2/e3', 'e2/e4 or e1/e3', 'e3/e3', 'e3/e4', 'e4/e4']
populations = ['EUR', 'AFR', 'AMR', 'EAS', 'SAS', 'MDE', 'AJ', 'FIN', 'AAC', 'CAS', 'CAH']
data_first_plot = {
    'EUR': [0.17, 4.04, 2.29, 35.25, 45.22, 13.03],
    'AFR': [0.67, 7.25, 5.23, 31.41, 41.88, 13.56],
    'AMR': [0, 3.63, 0.86, 63.27, 27.66, 4.58],
    'EAS': [0, 7.81, 1.56, 48.44, 32.81, 9.38],
    'SAS': [0, 7.05, 1.76, 60.79, 27.75, 2.64],
    'MDE': [2.7, 2.7, 0, 62.16, 24.32, 8.11],
    'AJ': [0.33, 3.76, 1.96, 45.34, 40.59, 8.02],
    'FIN': [0, 0, 0, 27.27, 36.36, 36.36],
    'AAC': [0.43, 9.91, 3.68, 34.02, 38.89, 13.08],
    'CAS': [0, 6.25, 3.13, 50, 34.38, 6.25],
    'CAH': [0.28, 6.54, 2.15, 51.73, 31.75, 7.56]
}

# Data for the second plot (Related Dementias)
genotypes_second_plot = ['e2/e2', 'e2/e3', 'e2/e4 or e1/e3', 'e3/e3', 'e3/e4', 'e4/e4']
data_second_plot = {
    'EUR': [0.42,8.22,2.94,46.87,34.03,7.53],
    'AFR': [0.84,12.60,6.50,41.06,32.11,6.91],
    'AMR': [0.92,8.64,0.45,60.91,25,4.09],
    'EAS': [0,20,3.33,56.67,20,0],
    'SAS': [0,8.33,1.04,54.17,31.25,5.21],
    'MDE': [0,8,0,76,16,0],
    'AJ': [0.39,8.49,3.47,51.74,27.03,8.88],
    'FIN': [0,0,0,62.5,37.5,0],
    'AAC': [0.6,11.31,5.36,44.05,30.95,7.74],
    'CAS': [0,3.33,0,60,30,6.67],
    'CAH': [0,0,7.41,48.15,40.74,3.70]
}

# Data for the third plot (Controls); has an extra leading genotype, e1/e4
genotypes_third_plot = ['e1/e4', 'e2/e2', 'e2/e3', 'e2/e4 or e1/e3', 'e3/e3', 'e3/e4', 'e4/e4']
data_third_plot = {
    'EUR': [0,0.63,12.82,2.14,61.36,21.35,1.7],
    'AFR': [0.03,1.25,15.38,5.11,45.44,28.5,4.29],
    'AMR': [0,0.18,5.78,0.76,74.53,17.79,0.96],
    'EAS': [0,0.69,15.45,0.87,66.67,15.28,1.04],
    'SAS': [0,0.15,7.75,1.13,73.05,16.72,1.19],
    'MDE': [0,0,13.48,1.42,75.18,9.22,0.71],
    'AJ': [0,0.5,11.5,1.61,64.01,21.39,1],
    'FIN': [0,0,15,0,66.67,16.67,1.67],
    'AAC': [0,0.91,14.3,4.03,49.45,27.1,3.31],
    'CAS': [0,0.8,10.34,1.33,71.09,14.85,1.59],
    'CAH': [0,0.33,10,1.79,64.4,21.41,2.07]
}

# Define plot parameters
barWidth = 0.13
new_color = '#333333'


def _plot_genotype_panel(ax, genotypes, data, title, colors=None):
    """Draw one grouped bar chart of APOE genotype percentages per ancestry on *ax*.

    Bars for genotype i sit at each ancestry position offset by i * barWidth and
    are annotated with their value. When *colors* is None, matplotlib's default
    colour cycle is used. Returns the list of bar containers, one per genotype.
    """
    positions = np.arange(len(populations))
    containers = []
    for i, genotype in enumerate(genotypes):
        heights = [data[pop][i] for pop in populations]
        color = None if colors is None else colors[i]
        containers.append(
            ax.bar(positions + i * barWidth, heights, width=barWidth,
                   edgecolor='grey', color=color, label=genotype)
        )
        for j, height in enumerate(heights):
            # Whole numbers print without decimals; others are rounded to 2 places.
            value = int(height) if float(height).is_integer() else round(height, 2)
            ax.annotate(str(value), xy=(positions[j] + i * barWidth, height), xytext=(0, 3),
                        textcoords='offset points', ha='center', va='bottom', color='black',
                        fontsize=8, fontweight='bold', rotation=90)
    ax.set_xlabel('Ancestries', fontweight='bold')
    ax.set_ylabel('Proportion (%)', fontweight='bold')
    ax.set_title(title)
    # Centre each tick under its group of bars.
    ax.set_xticks(positions + barWidth * (len(genotypes) - 1) / 2)
    ax.set_xticklabels(populations, rotation=45)
    ax.legend()
    return containers


# Plot
fig, axs = plt.subplots(3, 1, figsize=(16, 30))

bars_list = _plot_genotype_panel(
    axs[0], genotypes_first_plot, data_first_plot,
    'Proportions of APOE Genotypes Across Different Ancestries in AD',
)

# Reuse the AD panel's colours so each genotype looks the same in every panel
colors_first_plot = [container[0].get_facecolor() for container in bars_list]

_plot_genotype_panel(
    axs[1], genotypes_second_plot, data_second_plot,
    'Proportions of APOE Genotypes Across Different Ancestries in Related Dementias',
    colors=colors_first_plot,
)

# The controls' extra leading genotype (e1/e4) gets new_color; the rest keep the AD colours
colors_third_plot = [new_color] + colors_first_plot
_plot_genotype_panel(
    axs[2], genotypes_third_plot, data_third_plot,
    'Proportions of APOE Genotypes Across Different Ancestries in Controls',
    colors=colors_third_plot,
)

plt.tight_layout()
plt.savefig('APOE_genotypes.png', dpi=300)
plt.show()


### Visualizing mutation sites on predicted protein structures

In [None]:
# Render the predicted APP structure (af-app.cif) as a cartoon and highlight each
# listed variant position as a labelled cyan sphere.
# This script is shown for APP; the same script was used for PSEN1, PSEN2, TREM2, MAPT, GRN, SNCA, GBA1, TBK1, TARDBP, and APOE.
# Load structure
load af-app.cif

# Modify the representation of the system
hide all
show cartoon

# Customize the cartoon representation
# Colour by secondary structure: helices (ss h) red, strands (ss s) yellow, loops green
set cartoon_fancy_helices
set cartoon_highlight_color, grey75
color red, ss h
color yellow, ss s
color green, ss l+''

# Select the residues of interest
# Each selection is the C-alpha atom of one residue, named after the variant
select P132S, resi 132 and name CA
select A209T, resi 209 and name CA
select V227L, resi 227 and name CA
select I248V, resi 248 and name CA
select P251S, resi 251 and name CA
select S271G, resi 271 and name CA
select L364F, resi 364 and name CA
select T371A, resi 371 and name CA
select V375I, resi 375 and name CA
select S407T, resi 407 and name CA
select R409C, resi 409 and name CA
select G458R, resi 458 and name CA
select D460N, resi 460 and name CA
select D516N, resi 516 and name CA
select R517G, resi 517 and name CA
select Y538H, resi 538 and name CA
select L597W, resi 597 and name CA
select I600L, resi 600 and name CA
select E693Q, resi 693 and name CA
select A713T, resi 713 and name CA
select V717L, resi 717 and name CA
select A740V, resi 740 and name CA

# Color the residues of interest
color cyan, P132S
color cyan, A209T
color cyan, V227L
color cyan, I248V
color cyan, P251S
color cyan, S271G
color cyan, L364F
color cyan, T371A
color cyan, V375I
color cyan, S407T
color cyan, R409C
color cyan, G458R
color cyan, D460N
color cyan, D516N
color cyan, R517G
color cyan, Y538H
color cyan, L597W
color cyan, I600L
color cyan, E693Q
color cyan, A713T
color cyan, V717L
color cyan, A740V

# Change the representation of the residues of interest
show spheres, P132S
show spheres, A209T
show spheres, V227L
show spheres, I248V
show spheres, P251S
show spheres, S271G
show spheres, L364F
show spheres, T371A
show spheres, V375I
show spheres, S407T
show spheres, R409C
show spheres, G458R
show spheres, D460N
show spheres, D516N
show spheres, R517G
show spheres, Y538H
show spheres, L597W
show spheres, I600L
show spheres, E693Q
show spheres, A713T
show spheres, V717L
show spheres, A740V

# Label the residues of interest
label P132S, "P132S" 
label A209T, "A209T"
label V227L, "V227L"
label I248V, "I248V"
label P251S, "P251S"
label S271G, "S271G"
label L364F, "L364F"
label T371A, "T371A"
label V375I, "V375I"
label S407T, "S407T"
label R409C, "R409C"
label G458R, "G458R"
label D460N, "D460N"
label D516N, "D516N"
label R517G, "R517G"
label Y538H, "Y538H"
label L597W, "L597W"
label I600L, "I600L"
label E693Q, "E693Q"
label A713T, "A713T"
label V717L, "V717L"
label A740V, "A740V"

# Adjust the size of the label
# NOTE(review): PyMOL documents `set name, value`; confirm the `name=value` form used
# in these `set` lines is accepted by your PyMOL version.
set label_size=12
set label_position=[7.0, 0.0, 30.0]
set label_font_id=7

# Change the background color
bg_color white
set cartoon_transparency=0.4

# Render image and save
set depth_cue=0
orient
zoom complete=1
ray 3000, 3000
# NOTE(review): the png command below is commented out, so this script does not write
# the ray-traced image to disk; uncomment it (or save manually) to export the figure.
#png app_labels.png, dpi=500



### Using an UpSet Plot to Visualize Beta Values

In [None]:
import os
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
#import warnings
from matplotlib import rcParams
from matplotlib.ticker import FormatStrFormatter
from matplotlib.collections import PatchCollection
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.gridspec as gridspec
#warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Per-SNP beta estimates from the protective model (one column per ancestry)
beta_table_path = 'Protective_model_beta.csv'
Protective = pd.read_csv(beta_table_path)
Protective.head()

In [None]:
# Reshape the wide beta table (one column per ancestry) into long format: one row
# per (SNP, ancestry) pair. The order is row-major: all ancestries for the first SNP,
# then the next SNP. It matches the original loop because downstream plotting takes
# its axis order from order of appearance.
ancestry_columns = ['EUR', 'AFR', 'AMR', 'AJ', 'AAC', 'CAH']
records = [
    {'ANC': anc, 'SNP': row['SNP'], 'BETA': row[anc]}
    for _, row in Protective.iterrows()
    for anc in ancestry_columns
]
# Building the frame once avoids the quadratic cost of concatenating inside the loop.
# It does not depend on an index label 0 existing, and an empty table yields an
# empty frame instead of leaving inputDF undefined.
inputDF = pd.DataFrame(records, columns=['ANC', 'SNP', 'BETA'])

In [None]:
import pandas as pd
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Coerce BETA to numbers (unparsable entries become NaN), then discard those rows.
inputDF = (
    inputDF.assign(BETA=pd.to_numeric(inputDF['BETA'], errors='coerce'))
    .dropna(subset=['BETA'])
)

# Colour nodes for the beta colour map, ordered from the lowest to the highest node.
pal = ["#6B1414", "#C14E4E", "#CF9FFF", "#1E90FF", "#000000"]

# Define NormalizeData function
def NormalizeData(data):
    """Linearly rescale *data* so its minimum maps to 0 and its maximum to 1.

    Accepts anything NumPy can reduce (lists, tuples, arrays, Series). The result has
    the type produced by ``data - scalar``, which is an ndarray for list/tuple input.
    A constant input has zero span and therefore yields NaN values.
    """
    lowest = np.min(data)
    span = np.max(data) - lowest
    return (data - lowest) / span

# Calculate the beta quantiles (min, 25%, median, 75%, max); they become the
# positions of the five `pal` colours along the colour map.
quantiles = inputDF["BETA"].describe().loc[['min', '25%', '50%', '75%', 'max']].tolist()
# Put the middle colour at beta == 0 when possible. The value is clamped into
# [25%, 75%] so the nodes stay non-decreasing. An unclamped 0 outside that range
# (e.g. when every beta has the same sign) makes from_list raise ValueError.
quantiles[2] = min(max(0.0, quantiles[1]), quantiles[3])
quantiles_norm = NormalizeData(quantiles)

# Create colormap
cmp = LinearSegmentedColormap.from_list("", list(zip(quantiles_norm, pal)))


In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from matplotlib.cm import ScalarMappable

# Draw the protective-model betas as a dot plot (SNPs on x, ancestries on y, dot
# colour = beta) with a separate colorbar axis, and save it as 'Protective.png'.

# Extract only the rs part of SNP identifiers
# (rows whose SNP string contains no 'rs<digits>' get NaN here)
inputDF['SNP_rs'] = inputDF['SNP'].str.extract(r'(rs\d+)', expand=False)

# Set up the figure and axes
# The axes from plt.subplots is discarded immediately by fig.clf(). Only the figure is
# kept, and the layout comes from the GridSpec below (`ax` is not used afterwards).
fig, ax = plt.subplots(figsize=(2, 3), sharex='col')
fig.clf()
# Two columns: the dot plot (ax1) and a narrow colorbar axis (ax2)
gs = gridspec.GridSpec(1, 2, width_ratios=(8, 0.7))

ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])

# NOTE(review): both style calls run after ax1/ax2 already exist, and
# set_theme(style="white") supersedes the set_style call before it. 'font_scale' is
# a set_theme/set_context argument rather than a matplotlib rcParam — confirm the
# installed seaborn version accepts this dict.
sns.set_style({'font_scale': 1, 'axes.facecolor': 'white', 'grid.color': '.8'})
sns.set_theme(style="white", palette=None)

# Set margins and spacing
# NOTE(review): plt.margins acts on the current axes, presumably ax2 (the last one
# added) — confirm it was meant for ax1.
plt.margins(x=0.01, y=0.05)
plt.subplots_adjust(wspace=0.1, hspace=0)

# Create scatter plot using SNP_rs for x-axis labels
# Dot colour is BETA mapped through `cmp`, the colormap built in an earlier cell
plot_bottom = sns.scatterplot(data=inputDF, x='SNP_rs', y='ANC', hue='BETA', s=70,
                              edgecolor='black', palette=cmp, legend='brief', ax=ax1)

# Set title and remove x-axis label
plot_bottom.set_title('Protective model', fontsize=8 ,fontweight='bold')
plot_bottom.set_xlabel('')  

# Customize plot appearance
plot_bottom.spines["top"].set_visible(False) 
plot_bottom.spines["bottom"].set_visible(False) 
plot_bottom.spines["left"].set_visible(False) 
plot_bottom.spines["right"].set_visible(False) 
# NOTE(review): labels are set without fixing tick positions. They line up only if the
# SNP categories are ordered by first appearance, the same order drop_duplicates keeps.
plot_bottom.set_xticklabels(inputDF.drop_duplicates(subset='SNP_rs')['SNP_rs'], rotation=45, ha='right')
plot_bottom.get_legend().remove()

# Adjust the font size of tick labels on x and y axes
plot_bottom.tick_params(axis='x', labelsize=8)  
plot_bottom.tick_params(axis='y', labelsize=8)  

# Shade specific y-ranges
# Bands behind categorical rows 1, 3 and 5 (rows sit at integer positions); this
# assumes six ancestry rows
plot_bottom.axhspan(ymin=0.5, ymax=1.5, color='steelblue', alpha=0.1, lw=0)
plot_bottom.axhspan(ymin=2.5, ymax=3.5, color='steelblue', alpha=0.1, lw=0)
plot_bottom.axhspan(ymin=4.5, ymax=5.5, color='steelblue', alpha=0.1, lw=0)
plot_bottom.set_ylabel('')

# Set up color bar with ScalarMappable
# The mappable spans the observed beta range, presumably matching the hue scaling
# seaborn applied to the dots
sm = ScalarMappable(cmap=cmp)
sm.set_array([])
sm.set_clim(inputDF["BETA"].min(), inputDF["BETA"].max())

min_beta = inputDF['BETA'].min()
max_beta = inputDF['BETA'].max()

# Ticks at the minimum beta, zero and the maximum beta
cb_ticks = [inputDF['BETA'].min(), 0, inputDF['BETA'].max()]
cb = plt.colorbar(sm, cax=ax2, ticks=[min_beta, 0, max_beta])
# NOTE(review): the raw floats are used as tick labels and print at full precision;
# a format such as f'{t:.2f}' may be intended.
cb.ax.set_yticklabels(cb_ticks)   
cb.outline.set_color('black')
cb.outline.set_linewidth(1)
cb.ax.tick_params(labelsize=6, pad=1, width=0)
cb.ax.set_title('Beta', size=8, pad=6)

# Save the figure
fig.savefig('Protective.png', facecolor='white', dpi=300, bbox_inches="tight")


### Generating a PCA plot

In [None]:
def plot_3d(labeled_df, color, symbol=None, x='PC1', y='PC2', z='PC3', title=None, x_range=None, y_range=None, z_range=None, plot_out=None):
    '''
    Show an interactive 3-D scatter plot of ancestry principal components.

    Parameters: 
    labeled_df (Pandas dataframe): labeled ancestry dataframe; must contain the x/y/z
        columns, the `color` column and an 'IID' column (used as the hover label)
    color (string): name of the column in labeled_df holding the ancestry labels used to colour points
    symbol (string): column for a secondary label drawn as marker symbol (for example, predicted vs reference ancestry). default: None
    x (string): column name of x-dimension
    y (string): column name of y-dimension
    z (string): column name of z-dimension
    title (string, optional): title of output scatterplot
    x_range (list of floats [min, max], optional): range for x-axis
    y_range (list of floats [min, max], optional): range for y-axis
    z_range (list of floats [min, max], optional): range for z-axis
    plot_out (string, optional): output path prefix; when given, the interactive
        figure is also written to f'{plot_out}.html' (no static .png is written)

    Returns:
    None. The plotly.express.scatter_3d figure is displayed with fig.show().
    '''
    # Imported here so the function works even when it is defined before the cell
    # that imports plotly runs.
    import plotly.express as px

    fig = px.scatter_3d(
                labeled_df,
                x=x,
                y=y,
                z=z,
                color=color,
                symbol=symbol,
                title=title,
                # Fallback colours for labels not listed in color_discrete_map
                color_discrete_sequence=px.colors.qualitative.Bold,
                range_x=x_range,
                range_y=y_range,
                range_z=z_range,
                hover_name="IID",
                # Fixed colour per ancestry label so plots are comparable across biobanks
                color_discrete_map={'AFR': "#88CCEE",
                                    'SAS': "#CC6677",
                                    'EAS': "#DDCC77",
                                    'EUR':"#117733",
                                    'AMR':"#332288",
                                    'AJ': "#D55E00",
                                    'AAC':"#999933",
                                    'CAS':"#882255",
                                    'MDE':"#661100",
                                    'FIN':"#F0E442",
                                    'CAH':"#40B0A6",
                                    'new':"#ababab"}
            )

    fig.update_traces(marker={'size': 3})

    fig.show()

    if plot_out:
        fig.write_html(f'{plot_out}.html')

In [None]:
import pandas as pd
import os
import plotly.express as px

# Both inputs are GenoTools outputs (see https://github.com/dvitale199/GenoTools):
# the projected PCs and the predicted ancestry labels, joined per sample on FID/IID.
new_pcs = pd.read_csv('${WORK_DIR}/FILTERED.merged_biallelic_ancestry_projected_new_pca.txt', sep='\t')
labels = pd.read_csv('${WORK_DIR}/FILTERED.merged_biallelic_ancestry_umap_linearsvc_predicted_labels.txt', sep='\t')

merged_plot = new_pcs.merge(labels, on=['FID', 'IID'])

# Both tables carry a 'label' column, so the merge suffixes them (_x/_y). Keep the
# one from `labels` under the plain name 'label'.
merged_plot_final = (
    merged_plot[['FID', 'IID', 'PC1', 'PC2', 'PC3', 'label_y']]
    .rename(columns={'label_y': 'label'})
)

plot_3d(merged_plot_final, color='label', plot_out='3d_plot_output_ADSP')
