## Things you need to know about this script and how to run it
This script takes all the raw viral annotation data and converts it into a usable form (collects it in one table and maps the IDs to readable names). </br>
Then it makes overview figures (heatmaps, barplots) showing what the overall expression looks like. </br>
It checks which cell lines are positive for which virus, or are engineered. </br>
It correlates CRISPR screen scores with TPM values across the cell lines to find potential targets. </br>
Finally, it takes the potential target proteins and checks the difference in Chronos score between virus-positive and virus-negative cell lines.

Download the transcript quantification data from the Google bucket to your local machine with this command (run it in a terminal): </br>
gsutil -m cp -r -c gs://virus_expression_results/transcript_quant_results/* {add your DATA_DIR here}/transcripts_quant </br>
</br>
If the bucket is requester-pays, you may need to pass a billing project with -u (e.g. your user/project name such as philipp-trollmann): </br>
gsutil -u {put your Google Cloud project here} -m cp -r -c gs://virus_expression_results/transcript_quant_results/* {add your DATA_DIR here}/transcripts_quant </br>

You need a main directory (DATA_DIR) with the following subdirectories (missing ones are created by the notebook if DATA_DIR exists): </br>
DATA_DIR/transcripts_quant/ </br>
DATA_DIR/general_parts/ </br>
DATA_DIR/plots_with_decoy/ </br>
DATA_DIR/scripts/ </br>
DATA_DIR/human_genom/ </br>
DATA_DIR/viral_genes/ </br>

 </br>
These then also need certain files in them: </br>

Transcripts Quant (DATA_DIR/transcripts_quant/): </br>
    download all the directories with the TPM values in here (gsutil -m cp -r -c gs://virus_expression_results/transcript_quant_results/* {add your DATA_DIR here}/transcripts_quant) </br>
</br>
Human Genes (DATA_DIR/human_genom/): </br>
    Homo_sapiens.GRCh38.112.gtf (download Homo_sapiens.GRCh38.112.gtf.gz from https://ftp.ensembl.org/pub/release-112/gtf/homo_sapiens/ and unzip it)</br>
 </br>
Viral Genes (DATA_DIR/viral_genes/): </br>
    viral_gene_metadata.csv (created by the "research_viral_annotations.ipynb" notebook)</br>
</br>
General Data (DATA_DIR/general_parts/): </br>
    internal-24q2_v87-crisprgeneeffect (download from taiga) </br>
    internal-24q2_v87-model.csv (download from taiga) </br>

Make sure all required packages are installed, e.g. by running this command in your terminal: </br>
*pip install numpy pandas matplotlib seaborn natsort scipy joblib scanpy h5py pyarrow ipyparallel adjustText networkx*

In [None]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from natsort import natsort_keygen
from scipy.cluster.hierarchy import linkage, leaves_list
from matplotlib.backends.backend_pdf import PdfPages
from joblib import Parallel, delayed
import math
import scipy.stats
import scanpy as sc
import h5py
import csv
import pyarrow as pa
import pyarrow.csv as csv
import itertools
from scipy import stats
import ipyparallel as ipp
from adjustText import adjust_text
from matplotlib import gridspec
import networkx as nx
from scipy import stats

#DATA_DIR = "/Users/ptrollma/experiment_data/virus_genes/"  # adjust this to your directory path
DATA_DIR = "/home/ubuntu/experiments/virus_genes/"  # adjust this to your directory path

# Sub-directories of DATA_DIR. The rest of the notebook builds file paths as
# f"{SOME_DIR}file_name", so every directory constant must end with a path separator.
# os.path.join(..., "") guarantees that even if DATA_DIR is edited without a trailing "/".
QUANT_DIR = os.path.join(DATA_DIR, "transcripts_quant", "")
META_DATA_DIR = os.path.join(DATA_DIR, "viral_genes", "")
PLOT_DIR = os.path.join(DATA_DIR, "plots_with_decoy", "")
HUMAN_GENE_DIR = os.path.join(DATA_DIR, "human_genom", "")
GENERAL_DATA = os.path.join(DATA_DIR, "general_parts", "")
SCRIPTS_DATA = os.path.join(DATA_DIR, "scripts", "")

In [None]:
# Make sure the expected sub-directory layout exists below DATA_DIR.
# The paths come from the constants defined above instead of being rebuilt here,
# so the two definitions cannot drift apart.
directories = {
    "QUANT_DIR": QUANT_DIR,
    "META_DATA_DIR": META_DATA_DIR,
    "PLOT_DIR": PLOT_DIR,
    "HUMAN_GENE_DIR": HUMAN_GENE_DIR,
    "GENERAL_DATA": GENERAL_DATA,
    "SCRIPTS_DATA": SCRIPTS_DATA,
}

# DATA_DIR itself is never created; if it is missing only a message is printed.
if os.path.exists(DATA_DIR):
    for dir_name, dir_path in directories.items():
        if not os.path.exists(dir_path):
            # exist_ok: do not fail if the directory appears between the check and the call
            os.makedirs(dir_path, exist_ok=True)
            print(f"{dir_name} created: {dir_path}")
else:
    print(f"{DATA_DIR} does not exist.")

### Create a dictionary to map human transcript IDs (ENST...) to human gene names

In [None]:
# Define the function that builds the transcript ID -> gene name mapping
def get_ens_dict(file_path):
    """Map Ensembl transcript IDs to gene names using a GTF annotation file.

    The keys are the ``transcript_id`` attribute values (ENST IDs, which is what the
    human rows of quant.sf are looked up by), the values the ``gene_name`` attribute.
    The file is streamed line by line instead of being read into memory in full,
    since GTF annotation files can be very large.

    Parameters
    ----------
    file_path : str
        Path to the GTF file.

    Returns
    -------
    dict[str, str]
        ``{transcript_id: gene_name}`` for every non-comment line that carries both
        attributes. Empty if no line matched, in which case a hint is printed.
    """
    id_marker = 'transcript_id "'
    name_marker = 'gene_name "'
    mapping = {}
    with open(file_path, encoding='utf-8') as gtf_file:
        for line in gtf_file:
            # Skip header/comment lines and lines without both attributes (e.g. "gene" records)
            if line.startswith('#') or id_marker not in line or name_marker not in line:
                continue
            # Attribute value = text between the marker's opening quote and the next quote
            transcript_id = line.split(id_marker, 1)[1].split('"', 1)[0]
            gene_name = line.split(name_marker, 1)[1].split('"', 1)[0]
            mapping[transcript_id] = gene_name

    if not mapping:
        print('you need to change transcript_id " and gene_name " formats')
    return mapping

# Transcript ID -> gene name lookup built from the Ensembl GRCh38 release-112 annotation
gtf_dict = get_ens_dict(HUMAN_GENE_DIR + 'Homo_sapiens.GRCh38.112.gtf')

### Process the raw salmon output: extract the TPM values, collect them in one DataFrame, map the IDs to common gene names and create overview plots (takes some time, >1 h)

In [None]:
import concurrent.futures

# Human reference genes shown next to the viral genes for comparison
comparative_human_genes = ['UBE3A', 'TP53']

# Viral gene metadata; must provide at least Virus_id, Genomic_range, Gene_name and Virus_name
metadata_df = pd.read_csv(f"{META_DATA_DIR}viral_gene_metadata.csv")

# Strip the ".<version>" suffix from Virus_id so it matches the un-versioned accession
# extracted from the quant.sf transcript names (expand=False returns a Series, not a DataFrame)
metadata_df['Virus_id'] = metadata_df['Virus_id'].str.extract(
    r'([A-Za-z0-9_]+\d*)(?:\.\d*)?', expand=False
)

# All cell-line directories ('ACH-...') that contain a quant.sf file. Sorted because
# os.listdir order is arbitrary and this order determines the row order of every output table.
cell_lines = sorted(
    entry for entry in os.listdir(QUANT_DIR)
    if entry.startswith('ACH-') and os.path.isfile(os.path.join(QUANT_DIR, entry, 'quant.sf'))
)

# Accumulators for the per-cell-line results collected below
all_expressed_viral_genes_df = []
all_viral_genes_heatmap_df = []
all_viral_and_comparative_human_genes = []
human_avg_tmps = []

# Regex for viral transcript names in quant.sf such as "NC_001526.4:c100-200,300-400":
# group 1 = accession without its ".<version>" suffix, group 2 = the genomic range(s)
# (a leading "c" presumably marks the complementary strand). Compiled once for all workers.
_VIRAL_NAME_PATTERN = re.compile(r'([A-Za-z0-9_]+\d*)(?:\.\d*)?:(c?(?:\d+-\d+)(?:,c?\d+-\d+)*)')


def _plot_cell_line_overview(cell_line, viral_genes_df, combined_genes_df, human_median_tpm, human_mean_tpm):
    """Write three overview barplots (SVG) for one cell line into PLOT_DIR/<cell_line>/.

    1. TPM of every viral gene plus the comparative human genes,
    2. median TPM per virus, 3. mean TPM per virus -- each with the human median and
    mean TPM drawn as dashed reference lines.

    Uses matplotlib's object-oriented ``Figure`` API instead of pyplot because it is
    called from worker threads and pyplot keeps a shared "current figure" state.
    """
    from matplotlib.figure import Figure  # only needed when per-cell-line plotting is enabled

    out_dir = os.path.join(PLOT_DIR, cell_line)
    os.makedirs(out_dir, exist_ok=True)
    median_label = f'Human Median TPM: {human_median_tpm:.2f}'
    mean_label = f'Human Mean TPM: {human_mean_tpm:.2f}'

    # 1) all viral genes + comparative human genes
    fig = Figure(figsize=(40, 10))
    ax = fig.subplots()
    sns.barplot(x='Gene_Virus', y='TPM', hue='species', data=combined_genes_df, dodge=False, ax=ax)
    ax.axhline(human_median_tpm, color='red', linestyle='--', label=median_label, linewidth=2)
    ax.axhline(human_mean_tpm, color='blue', linestyle='--', label=mean_label, linewidth=2)
    ax.tick_params(axis='x', labelrotation=90, labelsize=5)
    ax.set_title(f'TPM of Human and Viral Genes with Human TPM as Reference for cell line {cell_line}', fontsize=20)
    ax.set_xlabel('Gene (Virus)', fontsize=16)
    ax.set_ylabel('TPM', fontsize=16)
    # List the two reference lines after the species entries (stable sort: False before True)
    handles, labels = ax.get_legend_handles_labels()
    order = sorted(range(len(labels)), key=lambda k: labels[k] in (median_label, mean_label))
    ax.legend([handles[k] for k in order], [labels[k] for k in order],
              fontsize=14, title='Legend', title_fontsize=16)
    fig.savefig(os.path.join(out_dir, f'{cell_line}_all-viral-genes_tpm-barplot.svg'),
                format='svg', bbox_inches='tight')

    # 2) / 3) median and mean TPM per virus
    for stat in ('median', 'mean'):
        per_virus = viral_genes_df.groupby('Virus_name')['TPM'].agg(stat).reset_index()
        per_virus = per_virus.sort_values(by='Virus_name', key=natsort_keygen())
        fig = Figure(figsize=(12, 6))
        ax = fig.subplots()
        sns.barplot(x='Virus_name', y='TPM', data=per_virus, ax=ax)
        ax.axhline(human_median_tpm, color='red', linestyle='--', label=median_label)
        ax.axhline(human_mean_tpm, color='blue', linestyle='--', label=mean_label)
        ax.tick_params(axis='x', labelrotation=90)
        ax.set_title(f'{stat.capitalize()} TPM per Virus with Human TPM as Reference for cell line {cell_line}',
                     fontsize=20)
        ax.legend()
        fig.savefig(os.path.join(out_dir, f'{cell_line}_virus-overview_{stat}-tpm-barplot.svg'),
                    format='svg', bbox_inches='tight')


def process_cell_line(cell_line, plot=False):
    """Load one cell line's quant.sf and extract its viral and human TPM data.

    Reads the module-level inputs QUANT_DIR, metadata_df, gtf_dict and
    comparative_human_genes. It does NOT append to the global result lists; the
    caller collects the returned values (this function runs in worker threads).

    Parameters
    ----------
    cell_line : str
        Name of the cell-line directory below QUANT_DIR (an 'ACH-...' ID).
    plot : bool, default False
        Also write per-cell-line overview barplots (see _plot_cell_line_overview).
        Enable e.g. with ``executor.map(functools.partial(process_cell_line, plot=True), cell_lines)``.

    Returns
    -------
    tuple
        (human_median_tpm, human_mean_tpm, expressed_viral_genes, viral_genes_df,
        tpm_per_viral_gene_df, combined_genes_df, cell_line), where

        - human_median_tpm / human_mean_tpm: median / mean TPM of the human (ENST)
          rows with TPM > 0
        - expressed_viral_genes: "Gene (Virus)" labels of viral rows with TPM > 0
        - viral_genes_df: all viral rows, natural-sorted by Virus_name, with
          Gene_Virus and Cell_line columns
        - tpm_per_viral_gene_df: one-row TPM table (index Cell_line, columns Gene_Virus)
        - combined_genes_df: comparative human gene rows followed by the viral rows
    """
    quant_df = pd.read_csv(os.path.join(QUANT_DIR, cell_line, 'quant.sf'),
                           sep='\t', usecols=['Name', 'TPM'], low_memory=False)

    # Human transcripts are the Ensembl "ENST..." entries; all other entries are viral
    quant_df['species'] = np.where(quant_df['Name'].str.startswith('ENST'), 'Human', 'Virus')

    # Split viral names into accession and genomic range, then join the viral metadata on them
    is_virus = quant_df['species'] == 'Virus'
    viral_name_parts = quant_df.loc[is_virus, 'Name'].str.extract(_VIRAL_NAME_PATTERN)
    quant_df.loc[is_virus, 'Virus_id'] = viral_name_parts[0]
    quant_df.loc[is_virus, 'Genomic_range'] = viral_name_parts[1]
    quant_df = quant_df.merge(metadata_df, how='left', on=['Virus_id', 'Genomic_range'])

    # Human rows: gene name of the transcript ID (first "|"-separated field without its
    # ".version"); the raw Name is kept when gtf_dict has no entry. Vectorized instead of a
    # row-wise DataFrame.apply. Viral rows keep the Gene_name from the metadata merge.
    is_human = quant_df['species'] == 'Human'  # recomputed: the merge returns a new index
    human_names = quant_df.loc[is_human, 'Name']
    transcript_ids = human_names.str.split('|').str[0].str.split('.').str[0]
    quant_df.loc[is_human, 'Gene_name'] = transcript_ids.map(gtf_dict).fillna(human_names)

    # Human reference levels over expressed (TPM > 0) human transcripts
    human_tpm = quant_df.loc[is_human & (quant_df['TPM'] > 0), 'TPM']
    human_median_tpm = human_tpm.median()
    human_mean_tpm = human_tpm.mean()

    # Viral rows in natural virus-name order, labelled "<gene> (<virus>)"
    viral_genes_df = (
        quant_df[quant_df['species'] == 'Virus']
        .sort_values(by='Virus_name', key=natsort_keygen())
        .copy()
    )
    viral_genes_df['Gene_Virus'] = viral_genes_df['Gene_name'] + " (" + viral_genes_df['Virus_name'] + ")"
    viral_genes_df['Cell_line'] = cell_line

    # Comparative human genes, labelled "<gene> (Human)". Cell_line is set here as well so
    # these rows stay attributable once all cell lines are concatenated.
    comparative_human_genes_df = quant_df[is_human & quant_df['Gene_name'].isin(comparative_human_genes)].copy()
    comparative_human_genes_df['Gene_Virus'] = (
        comparative_human_genes_df['Gene_name'] + " (" + comparative_human_genes_df['species'] + ")"
    )
    comparative_human_genes_df['Cell_line'] = cell_line

    combined_genes_df = pd.concat([comparative_human_genes_df, viral_genes_df])

    expressed_viral_genes = list(viral_genes_df.loc[viral_genes_df['TPM'] > 0, 'Gene_Virus'])
    print(f'In {cell_line} these viral genes are expressed: {expressed_viral_genes}')

    # One-row TPM table. pivot_table sorts its columns itself, so the natural virus order is
    # re-imposed afterwards. NaN labels (viral rows without metadata) are left out: pivot_table
    # drops them anyway, and reindexing by them would only add an all-NaN column, which
    # scipy's linkage() rejects when the heatmaps are clustered.
    gene_virus_order = viral_genes_df['Gene_Virus'].dropna().unique()
    tpm_per_viral_gene_df = viral_genes_df.pivot_table(
        index='Cell_line', columns='Gene_Virus', values='TPM', aggfunc='first'
    ).reindex(columns=gene_virus_order)

    if plot:
        _plot_cell_line_overview(cell_line, viral_genes_df, combined_genes_df, human_median_tpm, human_mean_tpm)

    return (human_median_tpm, human_mean_tpm, expressed_viral_genes, viral_genes_df,
            tpm_per_viral_gene_df, combined_genes_df, cell_line)


def _rows_with_expressed_label(genes_df):
    """Return every row whose Gene_Virus label has TPM > 0 in at least one row of genes_df."""
    expressed_labels = genes_df.loc[genes_df['TPM'] > 0, 'Gene_Virus']
    return genes_df[genes_df['Gene_Virus'].isin(expressed_labels)]


# Process all cell lines with a thread pool. The workers only return data; every append
# happens here, exactly once per cell line and in cell_lines order (executor.map keeps the
# input order). Only expressed genes (TPM > 0) go into the "expressed" tables.
with concurrent.futures.ThreadPoolExecutor() as executor:
    for (human_median_tpm, human_mean_tpm, expressed_viral_genes, viral_genes_df,
         tpm_per_viral_gene_df, combined_genes_df, cell_line) in executor.map(process_cell_line, cell_lines):
        human_avg_tmps.append([cell_line, human_median_tpm, human_mean_tpm])
        all_viral_and_comparative_human_genes.append(_rows_with_expressed_label(combined_genes_df))
        all_expressed_viral_genes_df.append(_rows_with_expressed_label(viral_genes_df))
        all_viral_genes_heatmap_df.append(tpm_per_viral_gene_df)

# pd.concat would fail with a cryptic "No objects to concatenate" otherwise
if not all_viral_genes_heatmap_df:
    raise FileNotFoundError(f'No ACH-* directory with a quant.sf file found in {QUANT_DIR}')

# Concatenate the data from all cell lines into one DataFrame per table and save them
all_viral_and_comparative_human_genes = pd.concat(all_viral_and_comparative_human_genes)
all_viral_and_comparative_human_genes = all_viral_and_comparative_human_genes.sort_values(by='Virus_name', key=natsort_keygen())
all_viral_and_comparative_human_genes.to_csv(f'{PLOT_DIR}all_expressed_viral_and_comparative_human_genes_data.csv', index=False)

all_expressed_viral_genes_df = pd.concat(all_expressed_viral_genes_df)
all_expressed_viral_genes_df = all_expressed_viral_genes_df.sort_values(by='Virus_name', key=natsort_keygen())
all_expressed_viral_genes_df.to_csv(f'{PLOT_DIR}all_expressed_viral_genes_data.csv', index=False)

all_viral_genes_heatmap_df = pd.concat(all_viral_genes_heatmap_df)
all_viral_genes_heatmap_df.to_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index=True)

human_avg_tmps_df = pd.DataFrame(human_avg_tmps, columns=['Cell_line', 'human_median_tpm', 'human_mean_tpm'])
human_avg_tmps_df.to_csv(f'{PLOT_DIR}human_avg_tpm_data.csv', index=True)


# Plot comparison of viral gene expression over cell lines:
# barplot of the TPM of expressed viral genes across all cell lines
plt.figure(figsize=(30, 10))
sns.barplot(x='Gene_Virus', y='TPM', hue='Cell_line', data=all_expressed_viral_genes_df, dodge=False)
plt.xticks(rotation=90, fontsize=5)
plt.title('TPM of Expressed Viral Genes Across All Cell Lines', fontsize=20)
plt.xlabel('Gene (Virus)', fontsize=16)
plt.ylabel('TPM', fontsize=16)
plt.legend(title='Cell Line', fontsize=12, title_fontsize=14)
plt.savefig(f'{PLOT_DIR}compare_viral-expression-over-cell-lines-tpm-barplot.svg', format='svg', bbox_inches='tight')
plt.show()

In [None]:
# If the first column is a pandas "Unnamed: N" placeholder (a CSV column without a
# header), drop it and rewrite the heatmap CSV. The isinstance check keeps this safe for
# non-string column labels (e.g. NaN), where .startswith would raise.
first_column = all_viral_genes_heatmap_df.columns[0] if len(all_viral_genes_heatmap_df.columns) else None
if isinstance(first_column, str) and first_column.startswith('Unnamed:'):
    all_viral_genes_heatmap_df = all_viral_genes_heatmap_df.drop(columns=first_column)
    all_viral_genes_heatmap_df.to_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index=True)

### Load the viral gene heatmap data (needed for most things below)

In [None]:
# Reload the cell-line x viral-gene TPM table; its first CSV column becomes the index
all_viral_genes_heatmap_df = pd.read_csv(
    f'{PLOT_DIR}all_viral_genes_heatmap_data.csv',
    index_col=0,
)

### Plot the heatmap of all the data

In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

# Hierarchically cluster the cell lines (rows); both heatmaps below reuse this row order
linkage_matrix = linkage(all_viral_genes_heatmap_df.values, method='average', metric='euclidean')
row_order = leaves_list(linkage_matrix)
clustered_df = all_viral_genes_heatmap_df.iloc[row_order, :]  # reorder rows by the clustering


def _plot_log_tpm_heatmap(data, mask, title, out_file):
    """Draw `data` as a log10(TPM) heatmap with `mask`ed cells left blank, save it to PLOT_DIR as SVG and show it."""
    plt.figure(figsize=(120, 0.3 * len(data)))
    # 1e-9 only avoids log10(0); zero / filtered cells are masked anyway
    ax = sns.heatmap(np.log10(data + 1e-9), mask=mask, cmap='plasma', cbar_kws={'label': 'log10(TPM)'},
                     linewidths=0.003, linecolor="lightgray")
    cbar = ax.collections[0].colorbar  # colour bar created by the heatmap
    cbar.set_label('log10(TPM)', rotation=0, fontsize=30)  # horizontal label
    cbar.ax.yaxis.set_label_coords(0.7, 1.2)  # axes coordinates: y > 1 places the label above the bar
    cbar.ax.set_position([0.75, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    ax.set(xlabel="Viral Genes", ylabel="Cell Lines")
    ax.xaxis.label.set_size(20)
    ax.yaxis.label.set_size(20)
    plt.title(title, fontsize=60)
    plt.yticks(rotation=0, fontsize=15)
    plt.savefig(f'{PLOT_DIR}{out_file}', format='svg', bbox_inches='tight')
    plt.show()


# Heatmap of viral expression over all cell lines (only exact zeros are blank)
_plot_log_tpm_heatmap(clustered_df, clustered_df == 0,
                      'TPM of Expressed Viral Genes Across All Cell Lines',
                      'compare_viral-expression-over-cell-lines-tpm-heatmap.svg')

# Same clustering, but only cells with TPM > tpm_filter are coloured. The mask uses "<="
# so the coloured cells match the "> tpm_filter" stated in the title.
tpm_filter = 1
_plot_log_tpm_heatmap(clustered_df, clustered_df <= tpm_filter,
                      f'TPM (> {tpm_filter}) of Expressed Viral Genes Across All Cell Lines',
                      'compare_viral-expression-over-cell-lines-tpm-filtered-heatmap.svg')

### Plot the heatmap of the condensed data (only cell lines and viral genes with at least one TPM > 0)

In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

# Condense the table to cell lines (rows, axis=1) and viral genes (columns, axis=0)
# that have at least one TPM > 0
rows_with_greater_than_one, cols_with_greater_than_one = (
    all_viral_genes_heatmap_df.gt(0).any(axis=axis) for axis in (1, 0)
)
filtered_viral_genes_heatmap_df = all_viral_genes_heatmap_df.loc[rows_with_greater_than_one, cols_with_greater_than_one]
print(filtered_viral_genes_heatmap_df.shape)

# Order the cell lines by average-linkage hierarchical clustering (Euclidean distance)
linkage_matrix = linkage(filtered_viral_genes_heatmap_df.values, method='average', metric='euclidean')
row_order = leaves_list(linkage_matrix)
clustered_df = filtered_viral_genes_heatmap_df.iloc[row_order, :]

# log10(TPM) heatmap; exact zeros are masked, the 1e-9 offset only guards log10(0)
plt.figure(figsize=(120, 0.3 * len(clustered_df)))
ax = sns.heatmap(
    np.log10(clustered_df + 1e-9),
    mask=(clustered_df == 0),
    cmap='plasma',
    cbar_kws={'label': 'log10(TPM)'},
    linewidths=0.003,
    linecolor="lightgray",
)

# Colour bar: horizontal label placed just above the bar (axes coords, y > 1), bar moved right
cbar = ax.collections[0].colorbar
cbar.set_label('log10(TPM)', rotation=0, fontsize=30)
cbar.ax.yaxis.set_label_coords(0.7, 1.05)
cbar.ax.set_position([0.75, 0.15, 0.02, 0.7])  # [left, bottom, width, height]

# Axis labels, title and tick labels of the heatmap axes
ax.set_xlabel("Viral Genes", fontsize=40)
ax.set_ylabel("Cell Lines", fontsize=40)
ax.set_title('TPM of Expressed Viral Genes Across All Cell Lines', fontsize=60)
plt.setp(ax.get_xticklabels(), rotation=90, fontsize=25)
plt.setp(ax.get_yticklabels(), rotation=0, fontsize=15)

plt.savefig(f'{PLOT_DIR}compare_viral-expression-over-cell-lines-tpm-reduced-heatmap.svg', format='svg', bbox_inches='tight')
plt.show()



# Reduced heatmap split into one panel per virus. The panels keep the cell-line order of
# clustered_df (no re-clustering) and share one colour scale, so the single colour bar
# on the right is valid for every panel.

# Map each metadata label "<gene> (<virus>)" to its virus; rows with a missing gene or
# virus name yield a NaN label and are skipped.
gene_to_virus = {}
for label, virus in zip(metadata_df['Gene_name'] + " (" + metadata_df['Virus_name'] + ")",
                        metadata_df['Virus_name']):
    if isinstance(label, str):
        gene_to_virus[label] = virus

# Group the labels by virus (first-appearance order)
virus_to_genes = {}
for gene, virus in gene_to_virus.items():
    virus_to_genes.setdefault(virus, []).append(gene)

# Keep, per virus, only the genes that are columns of clustered_df
valid_viruses = []
virus_gene_counts = []  # number of genes kept for each valid virus
for virus, genes in virus_to_genes.items():
    valid_genes = [gene.strip() for gene in genes if gene.strip() in clustered_df.columns]
    if valid_genes:
        valid_viruses.append(virus)
        virus_to_genes[virus] = valid_genes
        virus_gene_counts.append(len(valid_genes))
num_viruses = len(valid_viruses)

# Viruses with fewer than two genes are not drawn. They are excluded *before* the grid
# is laid out so that they do not leave empty panels behind.
plotted_viruses = [virus for virus in valid_viruses if len(virus_to_genes[virus]) >= 2]
plotted_genes = [gene for virus in plotted_viruses for gene in virus_to_genes[virus]]

if plotted_viruses:
    # Shared colour limits: min/max log10(TPM) over the visible (non-zero) cells of all panels
    log_tpm = np.log10(clustered_df[plotted_genes] + 1e-9).to_numpy(dtype=float)
    visible = log_tpm[(clustered_df[plotted_genes] != 0).to_numpy()]
    visible = visible[np.isfinite(visible)]
    vmin, vmax = (visible.min(), visible.max()) if visible.size else (None, None)

    # Fixed width; height scales with the number of cell lines
    fig = plt.figure(figsize=(120, 0.3 * len(clustered_df)))
    # Panel widths proportional to the number of genes of each virus
    gs = gridspec.GridSpec(1, len(plotted_viruses),
                           width_ratios=[len(virus_to_genes[virus]) for virus in plotted_viruses])
    plt.suptitle('Viral Expression Heatmap by Virus', fontsize=70, fontweight='bold', y=0.92)

    for i, virus in enumerate(plotted_viruses):
        ax = plt.subplot(gs[i])
        virus_data = clustered_df[virus_to_genes[virus]]  # keeps the clustered cell-line order
        sns.heatmap(np.log10(virus_data + 1e-9), mask=(virus_data == 0), cmap='plasma',
                    vmin=vmin, vmax=vmax, cbar=False, linewidths=0.003, linecolor="lightgray", ax=ax)
        ax.set_xlabel("Viral Genes", fontsize=50)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=25)
        ax.set_title(virus, fontsize=60, fontweight='bold')
        if i == 0:
            # Cell-line labels only on the leftmost panel
            ax.set_ylabel("Cell Lines", fontsize=40)
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=12)
        else:
            ax.set_ylabel('')
            ax.set_yticklabels([])

    # One colour bar for all panels (valid because of the shared vmin/vmax)
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = plt.colorbar(ax.collections[0], cax=cbar_ax, label='log10(TPM)')
    cbar.ax.tick_params(labelsize=35)
    cbar.set_label('log10(TPM)', rotation=0, fontsize=40)
    cbar.ax.yaxis.set_label_coords(0.5, 1.04)  # label just above the colour bar

    plt.savefig(f'{PLOT_DIR}compare_viral-expression-over-cell-lines-per-virus-tpm-reduced-heatmap.svg', format='svg', bbox_inches='tight')
    plt.show()
else:
    print('No virus has at least two genes in clustered_df; per-virus heatmap skipped.')



# Heatmap restricted to cell lines / viral genes with at least one TPM > tpm_filter,
# clustered on their own. Cells with TPM <= tpm_filter are left blank, consistent with
# the .gt() row/column filter and the "> tpm_filter" in the title.
tpm_filter = 1

rows_with_greater_than_one = all_viral_genes_heatmap_df.gt(tpm_filter).any(axis=1)  # cell lines with at least one value > tpm_filter
cols_with_greater_than_one = all_viral_genes_heatmap_df.gt(tpm_filter).any(axis=0)  # viral genes with at least one value > tpm_filter
filtered_viral_genes_heatmap_df = all_viral_genes_heatmap_df.loc[rows_with_greater_than_one, cols_with_greater_than_one]
print(filtered_viral_genes_heatmap_df.shape)

# Hierarchical clustering of the remaining cell lines (rows)
linkage_matrix = linkage(filtered_viral_genes_heatmap_df.values, method='average', metric='euclidean')
row_order = leaves_list(linkage_matrix)
clustered_df = filtered_viral_genes_heatmap_df.iloc[row_order, :]  # reorder rows by the clustering

plt.figure(figsize=(120, 0.3 * len(clustered_df)))
# 1e-9 only avoids log10(0); filtered cells are masked anyway
ax = sns.heatmap(np.log10(clustered_df + 1e-9), mask=(clustered_df <= tpm_filter), cmap='plasma',
                 cbar_kws={'label': 'log10(TPM)'}, linewidths=0.003, linecolor="lightgray")
cbar = ax.collections[0].colorbar  # colour bar created by the heatmap
cbar.set_label('log10(TPM)', rotation=0, fontsize=30)  # horizontal label
cbar.ax.yaxis.set_label_coords(0.7, 1.05)  # axes coordinates: y > 1 places the label above the bar
cbar.ax.set_position([0.75, 0.15, 0.02, 0.7])  # [left, bottom, width, height]

ax.set(xlabel="Viral Genes", ylabel="Cell Lines")
ax.xaxis.label.set_size(30)  # X-axis label font size
ax.yaxis.label.set_size(30)  # Y-axis label font size
plt.title(f'TPM (> {tpm_filter}) of Expressed Viral Genes Across All Cell Lines', fontsize=60)
plt.xticks(rotation=90, fontsize=25)
plt.yticks(rotation=0, fontsize=15)
plt.savefig(f'{PLOT_DIR}compare_viral-expression-over-cell-lines-tpm-filtered-reduced-heatmap.svg', format='svg', bbox_inches='tight')
plt.show()


# Filtered reduced heatmap split into one panel per virus. Only cells with
# TPM > tpm_filter are coloured (as the title states). The panels keep the cell-line
# order of clustered_df and share one colour scale, so the single colour bar is valid
# for every panel.

# Map each metadata label "<gene> (<virus>)" to its virus; rows with a missing gene or
# virus name yield a NaN label and are skipped.
gene_to_virus = {}
for label, virus in zip(metadata_df['Gene_name'] + " (" + metadata_df['Virus_name'] + ")",
                        metadata_df['Virus_name']):
    if isinstance(label, str):
        gene_to_virus[label] = virus

# Group the labels by virus (first-appearance order)
virus_to_genes = {}
for gene, virus in gene_to_virus.items():
    virus_to_genes.setdefault(virus, []).append(gene)

# Keep, per virus, only the genes that are columns of clustered_df
valid_viruses = []
virus_gene_counts = []  # number of genes kept for each valid virus
for virus, genes in virus_to_genes.items():
    valid_genes = [gene.strip() for gene in genes if gene.strip() in clustered_df.columns]
    if valid_genes:
        valid_viruses.append(virus)
        virus_to_genes[virus] = valid_genes
        virus_gene_counts.append(len(valid_genes))
num_viruses = len(valid_viruses)

# Viruses with fewer than two genes are not drawn. They are excluded *before* the grid
# is laid out so that they do not leave empty panels behind.
plotted_viruses = [virus for virus in valid_viruses if len(virus_to_genes[virus]) >= 2]
plotted_genes = [gene for virus in plotted_viruses for gene in virus_to_genes[virus]]

if plotted_viruses:
    # Shared colour limits: min/max log10(TPM) over the visible (TPM > tpm_filter) cells of all panels
    log_tpm = np.log10(clustered_df[plotted_genes] + 1e-9).to_numpy(dtype=float)
    visible = log_tpm[(clustered_df[plotted_genes] > tpm_filter).to_numpy()]
    visible = visible[np.isfinite(visible)]
    vmin, vmax = (visible.min(), visible.max()) if visible.size else (None, None)

    # Fixed width; height scales with the number of cell lines
    fig = plt.figure(figsize=(120, 0.3 * len(clustered_df)))
    # Panel widths proportional to the number of genes of each virus
    gs = gridspec.GridSpec(1, len(plotted_viruses),
                           width_ratios=[len(virus_to_genes[virus]) for virus in plotted_viruses])
    plt.suptitle(f'Viral Expression Heatmap by Virus (TPM > {tpm_filter}) ', fontsize=70, fontweight='bold', y=0.92)

    for i, virus in enumerate(plotted_viruses):
        ax = plt.subplot(gs[i])
        virus_data = clustered_df[virus_to_genes[virus]]  # keeps the clustered cell-line order
        # Mask everything at or below the filter so the panels match the title
        sns.heatmap(np.log10(virus_data + 1e-9), mask=(virus_data <= tpm_filter), cmap='plasma',
                    vmin=vmin, vmax=vmax, cbar=False, linewidths=0.003, linecolor="lightgray", ax=ax)
        ax.set_xlabel("Viral Genes", fontsize=50)
        ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=25)
        ax.set_title(virus, fontsize=60, fontweight='bold')
        if i == 0:
            # Cell-line labels only on the leftmost panel
            ax.set_ylabel("Cell Lines", fontsize=40)
            ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=12)
        else:
            ax.set_ylabel('')
            ax.set_yticklabels([])

    # One colour bar for all panels (valid because of the shared vmin/vmax)
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    cbar = plt.colorbar(ax.collections[0], cax=cbar_ax, label='log10(TPM)')
    cbar.ax.tick_params(labelsize=35)
    cbar.set_label('log10(TPM)', rotation=0, fontsize=40)
    cbar.ax.yaxis.set_label_coords(0.5, 1.04)  # label just above the colour bar

    plt.savefig(f'{PLOT_DIR}compare_viral-expression-over-cell-lines-per-virus-tpm-filtered-reduced-heatmap.svg', format='svg', bbox_inches='tight')
    plt.show()
else:
    print('No virus has at least two genes in clustered_df; per-virus heatmap skipped.')

### Get an overview of how often each viral gene is detected across all cell lines (barplot)

In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

# Bar chart of how many cell lines express each viral gene, drawn twice:
# once for any expression (> 0) and once for TPM > 0.5.
# Each entry is (threshold, title, output file name).
_gene_count_plots = [
    (0, 'Distribution of Cell Lines Expressing Viral Genes',
     'viral_gene_expression_histogram.svg'),
    (0.5, 'Distribution of Cell Lines Expressing Viral Genes (TPM > 0.5)',
     'viral_gene_expression_histogram_tpm_filtered.svg'),
]

for _threshold, _title, _file_name in _gene_count_plots:
    # Number of cell lines whose value for each gene exceeds the threshold
    gene_counts = all_viral_genes_heatmap_df.gt(_threshold).sum(axis=0)

    plt.figure(figsize=(120, 10))
    bar_plot = sns.barplot(x=gene_counts.index, y=gene_counts.values)
    plt.xlabel('Viral Genes', fontsize=30)
    plt.ylabel('# Cell Lines with Gene Expression', fontsize=30)
    plt.title(_title, fontsize=60)
    plt.xticks(rotation=90, ha='center', fontsize=6)

    # Write each bar's count just above it, rotated to fit the narrow bars
    for patch in bar_plot.patches:
        bar_height = patch.get_height()
        bar_center = patch.get_x() + patch.get_width() / 2
        bar_plot.annotate(f'{int(bar_height)}', (bar_center, bar_height + 1.5),
                          ha='center', va='bottom', fontsize=12, rotation=90)

    plt.savefig(f'{PLOT_DIR}{_file_name}', format='svg', bbox_inches='tight')
    plt.show()

### Get a list of all Cell lines that are positive for each Virus

In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

# Load the metadata for viral genes
metadata_df = pd.read_csv(f"{META_DATA_DIR}viral_gene_metadata.csv")

# Map "<Gene_name> (<Virus_name>)" -> Virus_name. This key format is what the
# heatmap column labels are looked up with below.
gene_to_virus = {
    f"{gene_name} ({virus_name})": virus_name
    for gene_name, virus_name in zip(metadata_df['Gene_name'], metadata_df['Virus_name'])
}

# Positivity criteria (TPM). A cell line is positive for a virus when at least
# POSITIVE_MIN_DETECTED_GENES of that virus's genes are above
# POSITIVE_DETECTION_TPM AND at least one gene is above POSITIVE_HIGH_TPM.
POSITIVE_DETECTION_TPM = 0.01
POSITIVE_HIGH_TPM = 1
POSITIVE_MIN_DETECTED_GENES = 2

# virus name -> list of positive cell lines (in heatmap row order)
positive_cell_lines = {}

for cell_line, expression_row in all_viral_genes_heatmap_df.iterrows():
    # Per-row counters use their own names so they do not overwrite
    # `virus_gene_counts`, which the heatmap cell uses as GridSpec width ratios.
    detected_gene_counts = {}  # virus -> number of genes above POSITIVE_DETECTION_TPM
    high_gene_counts = {}      # virus -> number of genes above POSITIVE_HIGH_TPM

    for gene, expression_level in expression_row.items():
        virus_name = gene_to_virus.get(gene)
        if not virus_name:
            continue  # column is not a known viral gene

        if expression_level > POSITIVE_DETECTION_TPM:
            detected_gene_counts[virus_name] = detected_gene_counts.get(virus_name, 0) + 1
        if expression_level > POSITIVE_HIGH_TPM:
            high_gene_counts[virus_name] = high_gene_counts.get(virus_name, 0) + 1

    for virus_name, n_detected in detected_gene_counts.items():
        if n_detected >= POSITIVE_MIN_DETECTED_GENES and high_gene_counts.get(virus_name, 0) >= 1:
            positive_cell_lines.setdefault(virus_name, []).append(cell_line)

# One row per virus, positive cell lines spread across the columns
positive_cell_lines_df = pd.DataFrame({k: pd.Series(v) for k, v in positive_cell_lines.items()})
positive_cell_lines_df = positive_cell_lines_df.transpose()

# Save the results
positive_cell_lines_df.to_csv(f'{PLOT_DIR}positive_cell_lines_per_virus.csv', index=True)

### Get a list of all Cell lines that are negative for each Virus

In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

# All cell lines in the expression table, in table order
all_cell_lines = all_viral_genes_heatmap_df.index.tolist()

# virus name -> cell lines that are NOT positive for it. Only viruses that have
# an entry in positive_cell_lines (i.e. at least one positive line) appear here.
negative_cell_lines = {}
for virus_name, positive_lines in positive_cell_lines.items():
    positive_set = set(positive_lines)  # O(1) membership instead of scanning the list
    negative_cell_lines[virus_name] = [line for line in all_cell_lines if line not in positive_set]

# One row per virus, negative cell lines spread across the columns
negative_cell_lines_df = pd.DataFrame({k: pd.Series(v) for k, v in negative_cell_lines.items()})
negative_cell_lines_df = negative_cell_lines_df.transpose()

# Save the results
negative_cell_lines_df.to_csv(f'{PLOT_DIR}negative_cell_lines_per_virus.csv', index=True)

### Get a list of engineered cell lines

In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

# Load the metadata for viral genes
metadata_df = pd.read_csv(f"{META_DATA_DIR}viral_gene_metadata.csv")

# Map "<Gene_name> (<Virus_name>)" -> Virus_name. This key format is what the
# heatmap column labels are looked up with below.
gene_to_virus = {
    f"{gene_name} ({virus_name})": virus_name
    for gene_name, virus_name in zip(metadata_df['Gene_name'], metadata_df['Virus_name'])
}

# Engineered-line criteria (TPM). For a given virus, exactly ONE of its genes is
# above ENGINEERED_DETECTION_TPM, and that gene is above ENGINEERED_HIGH_TPM
# (a single strongly expressed transgene rather than a whole viral genome).
ENGINEERED_DETECTION_TPM = 0.01
ENGINEERED_HIGH_TPM = 100

# viral gene label -> list of cell lines in which it is the lone, highly expressed gene
engineered_cell_lines = {}

for cell_line, expression_row in all_viral_genes_heatmap_df.iterrows():
    detected_genes = {}    # virus -> gene labels above ENGINEERED_DETECTION_TPM
    high_gene_counts = {}  # virus -> number of genes above ENGINEERED_HIGH_TPM

    for gene, expression_level in expression_row.items():
        virus_name = gene_to_virus.get(gene)
        if not virus_name:
            continue  # column is not a known viral gene

        if expression_level > ENGINEERED_DETECTION_TPM:
            detected_genes.setdefault(virus_name, []).append(gene)
        if expression_level > ENGINEERED_HIGH_TPM:
            high_gene_counts[virus_name] = high_gene_counts.get(virus_name, 0) + 1

    for virus_name, genes in detected_genes.items():
        if len(genes) == 1 and high_gene_counts.get(virus_name, 0) == 1:
            engineered_cell_lines.setdefault(genes[0], []).append(cell_line)

# One row per gene, engineered cell lines spread across the columns
engineered_cell_lines_df = pd.DataFrame({k: pd.Series(v) for k, v in engineered_cell_lines.items()})
engineered_cell_lines_df = engineered_cell_lines_df.transpose()

# Save the results
engineered_cell_lines_df.to_csv(f'{PLOT_DIR}engineered_cell-lines.csv', index=True)

### Get an overview of how many positive cell lines there are for each virus

In [None]:
# Number of positive cell lines per virus
positive_cell_lines_count = {}
for virus, lines in positive_cell_lines.items():
    positive_cell_lines_count[virus] = len(lines)

positive_cell_lines_count_df = pd.DataFrame(
    [(virus, count) for virus, count in positive_cell_lines_count.items()],
    columns=['Virus_name', 'Positive_count'],
)

# Left-join onto every virus in the metadata so viruses without positives get a 0 bar
all_viruses = metadata_df['Virus_name'].unique()
full_df = (
    pd.DataFrame({'Virus_name': all_viruses})
    .merge(positive_cell_lines_count_df, on='Virus_name', how='left')
    .fillna(0)
)

plt.figure(figsize=(24, 6))
bars = plt.bar(full_df['Virus_name'], full_df['Positive_count'], color='skyblue')

# Label each bar with its count
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width() / 2, height, int(height), ha='center', va='bottom')

# Clamp the x-range to the bars themselves (no padding at either end)
plt.xlim(-0.5, len(full_df['Virus_name']) - 0.5)

plt.xlabel('Virus Name', fontsize=12)
plt.ylabel('Positive Cell Line Count', fontsize=12)
plt.title('Virus Positive Cell Line Counts', fontsize=18)
plt.xticks(rotation=90, ha='center')
plt.tight_layout()
plt.savefig(f'{PLOT_DIR}viral_positive_cell_line_counts.svg', format='svg', bbox_inches='tight')
plt.show()

### Get an overview of how many engineered cell lines there are for each viral gene

In [None]:
# Count engineered cell lines per key of `engineered_cell_lines`. Those keys are
# single viral gene labels (the one highly expressed gene), not virus names.
engineered_cell_lines_count = {gene: len(cell_lines) for gene, cell_lines in engineered_cell_lines.items()}

# NOTE: the 'Virus_name' column actually holds gene labels; the column names are
# kept unchanged for compatibility with any code that reads this DataFrame.
engineered_cell_lines_count_df = pd.DataFrame(list(engineered_cell_lines_count.items()), columns=['Virus_name', 'Positive_count'])

# Plot as a barplot (only genes with at least one engineered cell line appear)
plt.figure(figsize=(4, 6))
bars = plt.bar(engineered_cell_lines_count_df['Virus_name'], engineered_cell_lines_count_df['Positive_count'], color='skyblue')

# Add text above the bars
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width() / 2, yval, int(yval), ha='center', va='bottom')

# Clamp the x-range to the bars; max(..., 1) keeps the limits distinct when
# there are no bars (identical limits trigger a matplotlib warning).
plt.xlim(-0.5, max(len(engineered_cell_lines_count_df), 1) - 0.5)

plt.xlabel('Gene Name', fontsize=12)
plt.ylabel('Engineered Cell Line Count', fontsize=12)
plt.title('Virus Engineered Cell Line Counts', fontsize=18)
plt.xticks(rotation=90, ha='center')
plt.tight_layout()
plt.savefig(f'{PLOT_DIR}viral_engineered_cell_line_counts.svg', format='svg', bbox_inches='tight')
plt.show()

### Get an overview of how the virus-positive cell lines are distributed across tissues

In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)
modeldata_df = pd.read_csv(f"{GENERAL_DATA}/internal-24q2_v87-model.csv")

# ModelID -> OncotreeLineage, built once instead of scanning modeldata_df for
# every cell line. keep='first' uses the first row when a ModelID repeats.
model_to_lineage = (
    modeldata_df.drop_duplicates(subset='ModelID', keep='first')
    .set_index('ModelID')['OncotreeLineage']
    .to_dict()
)

# virus -> {lineage: number of positive cell lines}. Cell lines missing from the
# model metadata are skipped.
virus_type_counts = {}
for virus_name, positive_lines in positive_cell_lines.items():
    type_counts = {}
    for cell_line in positive_lines:
        if cell_line in model_to_lineage:
            lineage_type = model_to_lineage[cell_line]
            type_counts[lineage_type] = type_counts.get(lineage_type, 0) + 1
    virus_type_counts[virus_name] = type_counts

# Rows = viruses, columns = lineages, missing combinations = 0
virus_type_counts_df = pd.DataFrame.from_dict(virus_type_counts, orient='index').fillna(0)
virus_type_counts_df = virus_type_counts_df.astype(int)

# Save the results to a CSV file
virus_type_counts_df.to_csv(f'{PLOT_DIR}virus_type_counts.csv', index=True)

# Total number of profiled cell lines (present in the heatmap data) per tissue
filtered_modeldata_df = modeldata_df[modeldata_df['ModelID'].isin(all_viral_genes_heatmap_df.index)]
tissue_totals = filtered_modeldata_df.groupby('OncotreeLineage')['ModelID'].count()

# Distinct positive cell lines per tissue. A cell line positive for several
# viruses is counted once, so the "(positive/total)" label cannot exceed the total.
any_positive_lines = set().union(*positive_cell_lines.values())
unique_positive_per_tissue = pd.Series(
    [model_to_lineage[c] for c in any_positive_lines if c in model_to_lineage],
    dtype=object,
).value_counts()

# Percentage of each tissue's cell lines that are positive, per virus
virus_percentage_df = virus_type_counts_df.astype(float)
tissue_positive_counts = {}  # tissue -> "(positive/total)" label suffix
tissues_without_total = []
for tissue in virus_percentage_df.columns:
    if tissue in tissue_totals:
        total_cell_lines = tissue_totals[tissue]
        virus_percentage_df[tissue] = (virus_percentage_df[tissue] / total_cell_lines) * 100
        tissue_positive_counts[tissue] = f"({unique_positive_per_tissue.get(tissue, 0)}/{total_cell_lines})"
    else:
        tissues_without_total.append(tissue)

# Lineages without a total (e.g. a missing lineage) cannot be expressed as a
# percentage; drop them instead of plotting raw counts on a percentage axis.
virus_percentage_df = virus_percentage_df.drop(columns=tissues_without_total)

# Plot the stacked bar chart (one bar per tissue, one segment per virus)
fig, ax = plt.subplots(figsize=(12, 8))
virus_percentage_df.T.plot(kind='bar', stacked=True, ax=ax, colormap='tab20')

# Customize the x-axis labels with the absolute counts
tissue_labels = [f"{tissue} {tissue_positive_counts.get(tissue, '(0/0)')}" for tissue in virus_percentage_df.columns]
ax.set_xticklabels(tissue_labels)

ax.set_xlabel('Tissues', fontsize=14)
ax.set_ylabel('Percentage of Positive Cell Lines', fontsize=14)
ax.set_title('Distribution of Virus Positive Cell Lines Across Tissues', fontsize=16)

plt.xticks(rotation=45, ha='right')
ax.legend(title="Viruses", bbox_to_anchor=(1.05, 1), loc='upper left')

# Save the plot
plt.tight_layout()
plt.savefig(f'{PLOT_DIR}virus_positive_distribution_by_tissue.svg', format='svg')
plt.show()


In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)
# Load model data
modeldata_df = pd.read_csv(f"{GENERAL_DATA}/internal-24q2_v87-model.csv")

# ModelID -> OncotreeLineage (first row per ModelID), built once
model_to_lineage = (
    modeldata_df.drop_duplicates(subset='ModelID', keep='first')
    .set_index('ModelID')['OncotreeLineage']
    .to_dict()
)

# virus -> {lineage: number of positive cell lines}; viruses whose positive
# cell lines all lack model metadata are left out
virus_type_counts = {}
for virus_name, positive_lines in positive_cell_lines.items():
    type_counts = {}
    for cell_line in positive_lines:
        if cell_line in model_to_lineage:
            lineage_type = model_to_lineage[cell_line]
            type_counts[lineage_type] = type_counts.get(lineage_type, 0) + 1
    if type_counts:
        virus_type_counts[virus_name] = type_counts

# Rows = viruses, columns = lineages, missing combinations = 0
virus_type_counts_df = pd.DataFrame.from_dict(virus_type_counts, orient='index').fillna(0)
virus_type_counts_df = virus_type_counts_df.astype(int)

# Filter model data to include only the cell lines present in the viral genes heatmap data
filtered_modeldata_df = modeldata_df[modeldata_df['ModelID'].isin(all_viral_genes_heatmap_df.index)]

# Calculate the total number of cell lines per tissue
tissue_totals = filtered_modeldata_df.groupby('OncotreeLineage')['ModelID'].count()

# Preferred plotting order of viruses
desired_order = [
    "SV40", "pLenti", "MCV",
    "HPV16", "HPV18", "HBV",
    "EBV (HHV-4)", "HHV-8, KSHV", "HTLV-1"
]

# Viruses from desired_order that have counts (a missing label would make .loc
# raise KeyError), followed by any remaining viruses so none is silently dropped.
plot_order = [virus for virus in desired_order if virus in virus_type_counts_df.index]
plot_order += [virus for virus in virus_type_counts_df.index if virus not in desired_order]
ordered_virus_type_counts_df = virus_type_counts_df.loc[plot_order]

# Number of plots per row
plots_per_row = 3

# Grid sized for the viruses that are plotted (at least one row so plt.subplots is valid)
num_viruses = len(ordered_virus_type_counts_df.index)
num_rows = max(math.ceil(num_viruses / plots_per_row), 1)

# squeeze=False always returns a 2-D array, so flatten() works for any grid shape
fig, axes = plt.subplots(num_rows, plots_per_row, figsize=(plots_per_row * 4.5, num_rows * 5),
                         sharey=True, squeeze=False)

# Ensure the y-axis labels (tissues) are strings before plotting
tissue_labels = virus_type_counts_df.columns.astype(str)
cmap = plt.get_cmap('tab20c')

axes = axes.flatten()

for i, virus in enumerate(ordered_virus_type_counts_df.index):
    ax = axes[i]
    tissue_values = ordered_virus_type_counts_df.loc[virus]  # absolute counts

    # Horizontal bar chart, one bar per tissue
    bars = ax.barh(tissue_labels, tissue_values, color=cmap.colors[:len(tissue_values)])

    ax.set_title(virus, fontsize=12)
    ax.set_xlabel('# pos. Cell Lines', fontsize=14)

    # Y-axis title only on the first plot of each row; hide tick labels on the others
    if i % plots_per_row == 0:
        ax.set_ylabel('Tissues', fontsize=14)
    else:
        plt.setp(ax.get_yticklabels(), visible=False)

    ax.set_yticklabels(tissue_labels, fontsize=12, fontweight='bold')

    # Colour each tissue label like its bar
    for j, label in enumerate(ax.get_yticklabels()):
        label.set_color(bars[j].get_facecolor())
        label.set_bbox(dict(facecolor='none', edgecolor=bars[j].get_edgecolor(), pad=0.1))

# Remove the grid cells left over after the last virus (all of them if nothing was plotted)
for unused_ax in axes[num_viruses:]:
    fig.delaxes(unused_ax)

fig.suptitle('Tissue Distribution for Each Virus (Absolute Counts)', fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.95])

# Save the plot
plt.savefig(f'{PLOT_DIR}virus_tissue_distribution_grid.svg', format='svg')
plt.show()

### Get an overview of which cell lines are positive for which virus, and which tissue each belongs to

In [None]:
# Cell line metadata (ModelID, OncotreeLineage, ...)
modeldata_df = pd.read_csv(f"{GENERAL_DATA}/internal-24q2_v87-model.csv")

# For every virus-positive cell line, collect the viruses it is positive for and
# its lineage (taken from the first matching row of modeldata_df). Cell lines
# absent from modeldata_df are left out.
virus_cell_line_info = {}
for virus_name, positive_lines in positive_cell_lines.items():
    for cell_line in positive_lines:
        lineage_matches = modeldata_df.loc[modeldata_df['ModelID'] == cell_line, 'OncotreeLineage']
        if lineage_matches.empty:
            continue
        record = virus_cell_line_info.setdefault(
            cell_line, {'Viruses': [], 'Tissue': lineage_matches.iloc[0]}
        )
        record['Viruses'].append(virus_name)

# One row per cell line: comma-separated virus names plus tissue
overview_records = {}
for cell_line, record in virus_cell_line_info.items():
    overview_records[cell_line] = {
        'Viruses': ', '.join(record['Viruses']),
        'Tissue': record['Tissue'],
    }

overview_df = (
    pd.DataFrame.from_dict(overview_records, orient='index')
    .reset_index()
    .rename(columns={'index': 'Cell Line'})
)

overview_df.to_csv(f'{PLOT_DIR}virus_cell_line_overview.csv', index=False)

### Get an overview of which cell lines are negative for which virus, and which tissue each belongs to

In [None]:
# Load model data
modeldata_df = pd.read_csv(f"{GENERAL_DATA}/internal-24q2_v87-model.csv")

# ModelID -> OncotreeLineage (first row per ModelID), built once. Looking the
# lineage up by scanning modeldata_df for every (virus, negative cell line)
# pair is needlessly slow because almost every cell line is negative.
model_to_lineage = (
    modeldata_df.drop_duplicates(subset='ModelID', keep='first')
    .set_index('ModelID')['OncotreeLineage']
    .to_dict()
)

# cell line -> {'Viruses': viruses it is NEGATIVE for, 'Tissue': lineage};
# cell lines missing from the model metadata are skipped
virus_cell_line_info = {}
for virus_name, negative_lines in negative_cell_lines.items():
    for cell_line in negative_lines:
        if cell_line not in model_to_lineage:
            continue
        info = virus_cell_line_info.setdefault(
            cell_line, {'Viruses': [], 'Tissue': model_to_lineage[cell_line]}
        )
        info['Viruses'].append(virus_name)

# One row per cell line: comma-separated virus names plus tissue
overview_df_ = pd.DataFrame.from_dict(
    {
        cell_line: {
            'Viruses negative for': ', '.join(info['Viruses']),
            'Tissue': info['Tissue']
        }
        for cell_line, info in virus_cell_line_info.items()
    },
    orient='index'
).reset_index().rename(columns={'index': 'Cell Line'})

# Save the DataFrame to a CSV file
overview_df_.to_csv(f'{PLOT_DIR}virus_negative_cell_line_overview.csv', index=False)

### Get a list of positive cell lines that are missing CRISPR screens

In [None]:
# Load the CRISPR gene-effect matrix. Row/column labels are read from the HDF5
# datasets "dim_0" (cell line IDs) and "dim_1" (gene labels, cut at the first space).
crispr_path = GENERAL_DATA + "internal-24q2_v87-crisprgeneeffect"
# Open read-only in a context manager so the file handle is closed again
with h5py.File(crispr_path, "r") as f:
    crispr_cell_line_ids = [cl.decode("utf-8") for cl in f["dim_0"][()]]
    crispr_gene_names = [gene.decode("utf-8").split(" ")[0] for gene in f["dim_1"][()]]

data = sc.read_hdf(crispr_path, key="data")
data.obs_names = crispr_cell_line_ids
data.var_names = crispr_gene_names
human_df = data.to_df()

modeldata_df = pd.read_csv(f"{GENERAL_DATA}internal-24q2_v87-model.csv")
cell_line_id_to_name = modeldata_df.set_index('ModelID')['CellLineName'].to_dict()

overview_df['CellLineName'] = overview_df['Cell Line'].map(cell_line_id_to_name)
overview_df.to_csv(f'{PLOT_DIR}virus_positive_cell_line_overview.csv', index=False)

# Virus-positive cell lines without a row in the CRISPR matrix. .copy() makes
# this an independent frame; it already carries CellLineName from overview_df.
overview_df_filtered = overview_df[~overview_df['Cell Line'].isin(human_df.index)].copy()

overview_df_filtered.to_csv(f'{PLOT_DIR}virus_cell_lines_missing_crispr_overview.csv', index=False)

In [None]:
# Split the "missing CRISPR" list into one file per virus family (substring
# match on the comma-separated 'Viruses' column).
_missing_crispr_outputs = {
    'EBV': f'{PLOT_DIR}virus_cell_lines_missing_EBV_crispr_overview.csv',
    'HPV': f'{PLOT_DIR}virus_cell_lines_missing_HPV_crispr_overview.csv',
}
for _virus_tag, _output_path in _missing_crispr_outputs.items():
    _matches_tag = overview_df_filtered['Viruses'].astype(str).str.contains(_virus_tag)
    overview_df_filtered_ = overview_df_filtered[_matches_tag]
    overview_df_filtered_.to_csv(_output_path, index=False)

### Make a Heatmap to cluster the tissues

In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)
modeldata_df = pd.read_csv(f"{GENERAL_DATA}internal-24q2_v87-model.csv")

# Lineage (tissue) -> list of ModelIDs in that lineage, lineages in sorted
# order; read by plot_heatmap_by_tissue.
tissue_to_cell_lines = {
    lineage: list(model_ids)
    for lineage, model_ids in modeldata_df.groupby('OncotreeLineage')['ModelID']
}

def plot_heatmap_by_tissue(viral_genes_heatmap_df, tpm_filter=0.5):
    """Plot one viral-expression heatmap panel per tissue and save it as SVG.

    Tissues and their cell lines come from the module-level
    ``tissue_to_cell_lines`` dict. In each panel the rows (cell lines) are
    ordered by average-linkage hierarchical clustering on Euclidean distance,
    and values below ``tpm_filter`` are masked. All panels share one colour
    scale, so the single colorbar on the right applies to every panel.

    Args:
        viral_genes_heatmap_df: DataFrame indexed by cell line with one
            column per viral gene (TPM values). It is not modified.
        tpm_filter: TPM threshold; smaller values are set to 0 and masked.

    The figure is written to ``{PLOT_DIR}viral_expression_by_tissue_heatmap.svg``.
    """
    unique_tissues = list(tissue_to_cell_lines.keys())
    num_tissues = len(unique_tissues)
    if num_tissues == 0:
        return  # GridSpec needs at least one row

    # Work on a copy so the caller's DataFrame is not zeroed in place
    expression_df = viral_genes_heatmap_df.copy()
    expression_df[expression_df < tpm_filter] = 0

    # First pass: clustered data for every tissue that can be drawn, keyed by the
    # tissue's position in unique_tissues (which is also its GridSpec row).
    panels = {}
    for i, tissue in enumerate(unique_tissues):
        valid_cell_lines = [cell_line for cell_line in tissue_to_cell_lines[tissue]
                            if cell_line in expression_df.index]
        if not valid_cell_lines:
            continue  # no profiled cell lines for this tissue

        tissue_data = expression_df.loc[valid_cell_lines].dropna()
        if tissue_data.empty or len(tissue_data) < 2:
            continue  # linkage needs at least two rows

        linkage_matrix = linkage(tissue_data.values, method='average', metric='euclidean')
        panels[i] = tissue_data.iloc[leaves_list(linkage_matrix), :]

    # Shared colour limits over every unmasked cell that will be drawn
    shown = [panel.to_numpy(dtype=float).ravel() for panel in panels.values()]
    shown = np.concatenate(shown) if shown else np.array([])
    shown = shown[shown >= tpm_filter]
    vmin = np.log10(shown.min() + 1e-9) if shown.size else None
    vmax = np.log10(shown.max() + 1e-9) if shown.size else None

    # Gene names go under the bottom-most panel that is actually drawn
    last_panel = max(panels) if panels else None

    fig = plt.figure(figsize=(120, 5 * num_tissues))
    gs = gridspec.GridSpec(num_tissues, 1,
                           height_ratios=[len(tissue_to_cell_lines[tissue]) for tissue in unique_tissues])
    plt.suptitle('Viral Expression Heatmap by Tissue', fontsize=30, fontweight='bold', y=1.02)

    mesh = None  # QuadMesh of the last drawn panel, used for the colorbar
    for i, tissue in enumerate(unique_tissues):
        ax = plt.subplot(gs[i])
        if i not in panels:
            ax.set_axis_off()  # nothing to draw for this tissue
            continue

        clustered_df = panels[i]
        sns.heatmap(np.log10(clustered_df + 1e-9), mask=(clustered_df < tpm_filter), cmap='plasma',
                    vmin=vmin, vmax=vmax, cbar=False, linewidths=0.003, linecolor="lightgray", ax=ax)
        mesh = ax.collections[0]

        # Tissue name as a horizontal y-label to the left of the panel
        ax.set_ylabel(tissue, fontsize=20, rotation=0, labelpad=100, va="center")

        if i == last_panel:
            ax.set_xlabel("Viral Genes", fontsize=20)
        else:
            ax.set_xticks([])

        ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=15)

    # Add a colorbar on the right side of the figure if anything was drawn
    if mesh is not None:
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        plt.colorbar(mesh, cax=cbar_ax, label='log10(TPM)')

    # Save the plot
    plt.savefig(f'{PLOT_DIR}viral_expression_by_tissue_heatmap.svg', format='svg', bbox_inches='tight')
    plt.show()

# Draw and save the per-tissue heatmaps; values below 0.5 TPM are masked
plot_heatmap_by_tissue(all_viral_genes_heatmap_df, tpm_filter=0.5)

### Make a heatmap for specified viruses (e.g. EBV) and cluster the positive cell lines by their expression to find tissue-specific viral gene expression

In [None]:
def plot_viral_gene_clustering(viral_names):
    """Plot a clustered heatmap of the given viruses' genes in their positive cell lines.

    Rows are the cell lines positive for any virus in ``viral_names`` (from
    the module-level ``positive_cell_lines``), ordered by average-linkage
    hierarchical clustering on Euclidean distance. Columns are those viruses'
    genes (module-level ``virus_to_genes``). Y tick labels are coloured by
    tissue (OncotreeLineage from the module-level ``modeldata_df``), and zero
    values are masked. The clustered matrix (with an extra 'Tissue' column) is
    saved as CSV and the figure as SVG under ``PLOT_DIR``.

    Args:
        viral_names: list of virus names, e.g. ``['EBV (HHV-4)']``.
    """
    # Genes of all requested viruses, in order
    selected_genes = []
    for virus in viral_names:
        selected_genes.extend(virus_to_genes.get(virus, []))

    all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)
    selected_virus_matrix = all_viral_genes_heatmap_df[selected_genes]

    # Positive cell lines of any requested virus. dict.fromkeys drops lines that
    # are positive for several requested viruses (they would otherwise appear as
    # duplicate rows) while keeping first-seen order.
    selected_cell_lines = list(dict.fromkeys(
        cell_line
        for virus in viral_names
        for cell_line in positive_cell_lines.get(virus, [])
    ))

    viral_names_str = ", ".join(viral_names)  # used in title and file names
    if not selected_cell_lines:
        print(f"No positive cell lines for {viral_names_str}; skipping plot.")
        return

    # .copy(): the 'Tissue' column below must be added to an independent frame,
    # not to a slice of all_viral_genes_heatmap_df
    selected_virus_matrix = selected_virus_matrix.loc[selected_cell_lines].copy()

    # Map each cell line to its tissue ('OncotreeLineage')
    filtered_modeldata_df = modeldata_df[modeldata_df['ModelID'].isin(selected_virus_matrix.index)]
    cell_line_to_tissue = filtered_modeldata_df.set_index('ModelID')['OncotreeLineage'].to_dict()
    selected_virus_matrix['Tissue'] = selected_virus_matrix.index.map(cell_line_to_tissue)

    # Hierarchical clustering of the cell lines (needs at least two rows)
    expression = selected_virus_matrix.drop(columns='Tissue')
    if len(expression) > 1:
        linkage_matrix = linkage(expression.values, method='average', metric='euclidean')
        row_order = leaves_list(linkage_matrix)
    else:
        row_order = [0]
    clustered_virus_matrix = selected_virus_matrix.iloc[row_order, :]
    clustered_expression = clustered_virus_matrix.drop(columns='Tissue')

    # Figure size scales with the number of genes and cell lines
    num_genes = clustered_expression.shape[1]
    num_cell_lines = len(clustered_virus_matrix)
    fig, ax = plt.subplots(figsize=(0.3 * num_genes, 0.2 * num_cell_lines))

    # yticklabels=True forces one label per row so the tissue colouring below
    # stays aligned with the rows (seaborn's default 'auto' may thin labels out).
    # Zeros are masked; +1e-9 avoids log10(0), as in the other heatmaps.
    sns.heatmap(np.log10(clustered_expression + 1e-9),
                mask=(clustered_expression == 0),
                cmap='plasma',
                cbar_kws={'label': 'log10(TPM)'},
                linewidths=0.003,
                linecolor="lightgray",
                yticklabels=True,
                ax=ax)

    # Tissue -> colour. sorted() makes the assignment stable across sessions
    # (iteration order of a set of strings depends on hash randomisation).
    # Missing lineages are not in the legend and use the default colour.
    default_tissue_color = 'gray'
    tissues = sorted({t for t in cell_line_to_tissue.values() if pd.notna(t)})
    tissue_color_mapping = dict(zip(tissues, sns.color_palette("Set2", len(tissues))))

    # Colour each cell-line label by its tissue
    for ytick, tissue in zip(ax.get_yticklabels(), clustered_virus_matrix['Tissue']):
        if pd.isna(tissue):
            ytick.set_color(default_tissue_color)
        else:
            ytick.set_color(tissue_color_mapping.get(tissue, default_tissue_color))

    # Legend for the tissue colours
    handles = [plt.Line2D([0], [0], color=color, lw=4) for color in tissue_color_mapping.values()]
    ax.legend(handles, list(tissue_color_mapping.keys()), title="Tissue",
              bbox_to_anchor=(1.5, 1), loc='upper left', borderaxespad=0.)

    ax.set_title(f'Viral Gene Expression for {viral_names_str}', fontsize=20)
    ax.set(xlabel="Viral Genes", ylabel="Cell Lines")
    ax.xaxis.label.set_size(14)
    ax.yaxis.label.set_size(14)
    plt.setp(ax.get_yticklabels(), rotation=0)

    # Save the clustered matrix and the plot
    clustered_virus_matrix.to_csv(f'{PLOT_DIR}/cluster-viral-gene-expression-{viral_names_str}.csv')
    plt.savefig(f'{PLOT_DIR}/cluster-viral-gene-expression-{viral_names_str}.svg', format='svg', bbox_inches='tight')
    plt.show()


all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

# One clustered heatmap per virus of interest
for _viral_names in (['HPV18'], ['EBV (HHV-4)'], ['pLenti']):
    plot_viral_gene_clustering(_viral_names)

### Check for possible overlaps of positive cell lines between the different viruses

In [None]:
from upsetplot import UpSet, generate_counts
import warnings

# Suppress FutureWarnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Convert each virus' positive cell lines to a set once (fast membership tests)
positive_sets = {virus: set(lines) for virus, lines in positive_cell_lines.items()}

# All cell lines that are positive for at least one virus
all_cell_lines = set().union(*positive_sets.values())

# Boolean membership matrix: one row per cell line, one column per virus
# (True = the cell line is positive for that virus). Column order follows
# positive_cell_lines.
ordered_cell_lines = sorted(all_cell_lines)
df = pd.DataFrame(
    {virus: [cell_line in members for cell_line in ordered_cell_lines]
     for virus, members in positive_sets.items()},
    index=ordered_cell_lines,
)

# Compress the matrix into the number of cell lines per combination of viruses
counts = df.groupby(list(df.columns)).size()

# Create and draw the UpSet plot
upset = UpSet(counts, subset_size='count', show_counts='%d')
upset.plot()
# The number of viruses is derived from the data instead of being hard-coded
plt.suptitle(f"UpSet Plot of Cell Line Overlap Across {len(positive_sets)} Viruses")
plt.savefig(f'{PLOT_DIR}overlap_different_viruses_positive_celllines_upset-plot.png', format='png', bbox_inches='tight')
plt.show()

In [None]:
# Function to create the graph with all viruses as nodes
def create_graph_with_all_viruses(cell_lines_dict):
    """Build an undirected graph of positive-cell-line overlaps between viruses.

    Every virus in ``cell_lines_dict`` becomes a node, even if it shares no cell
    line with any other virus. The node attributes ``weight`` and
    ``pos_cell_lines`` both hold the number of distinct positive cell lines.
    Two viruses are connected by an edge whose ``weight`` is the number of cell
    lines positive for both. Pairs without a shared cell line get no edge.

    Parameters
    ----------
    cell_lines_dict : dict
        Mapping virus name -> iterable of positive cell line IDs.

    Returns
    -------
    networkx.Graph
    """
    G = nx.Graph()  # Create an empty graph

    # Convert each collection to a set once, instead of once per compared pair
    line_sets = {virus: set(lines) for virus, lines in cell_lines_dict.items()}

    # Add all nodes first, so node order follows the input dict order
    for virus, lines_set in line_sets.items():
        G.add_node(virus, weight=len(lines_set), pos_cell_lines=len(lines_set))

    # The graph is undirected: each unordered pair of viruses is visited exactly once
    items = list(line_sets.items())
    for idx, (virus_a, set_a) in enumerate(items):
        for virus_b, set_b in items[idx + 1:]:
            overlap_count = len(set_a & set_b)
            if overlap_count > 0:  # Only add edges for overlapping viruses
                G.add_edge(virus_a, virus_b, weight=overlap_count)
    return G

# Create the graph
G = create_graph_with_all_viruses(positive_cell_lines)

# Draw the graph
plt.figure(figsize=(14, 12))
pos = nx.spring_layout(G, seed=50, k=1.8)  # seeded spring layout; larger k spreads the nodes further apart

# Node area proportional to the number of positive cell lines (scaled for visibility)
node_sizes = [G.nodes[virus]['weight'] * 200 for virus in G.nodes()]

# One colour per node from 'tab10'. pyplot.get_cmap is used because
# matplotlib.cm.get_cmap (what plt.cm.get_cmap resolves to) was removed in Matplotlib 3.9.
colors = plt.get_cmap('tab10', len(G.nodes()))
node_colors = [colors(i) for i in range(len(G.nodes()))]

# Draw nodes and edges (labels are drawn separately below)
nx.draw(G, pos, with_labels=False, node_color=node_colors, node_size=node_sizes, edge_color='gray')

# Edge labels: number of cell lines shared by the two viruses
edge_labels = nx.get_edge_attributes(G, 'weight')
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)

# Node labels "<virus>\nn=<count>", drawn at the node positions
node_labels = {virus: f"{virus}\nn={attrs['pos_cell_lines']}" for virus, attrs in G.nodes(data=True)}
nx.draw_networkx_labels(G, pos, labels=node_labels, font_color='black', font_size=14, font_weight='bold')

# Title
plt.title("Graph of Virus Overlaps with Cell Line Counts", fontsize=16)
plt.show()

### Check a list of cell lines: were they tested for viral gene expression, and which viruses were found positive or negative?

In [None]:
def check_cell_line_status(file_name):
    """Annotate a list of cell lines with viral/engineered status and tissue.

    Reads ``{GENERAL_DATA}{file_name}.csv`` (read without a header; its single
    column is treated as ModelIDs). Each cell line is looked up in the
    module-level dicts ``positive_cell_lines``, ``engineered_cell_lines`` and
    ``negative_cell_lines`` (name -> collection of ModelIDs).

    Output columns:
      Positive / Engineered / Negative : comma-separated names containing the line
      Tissue : OncotreeLineage from the model table ('not in list' if the
               ModelID is absent from the model table)
      Status : 'engineered' if any engineered entry matches, else 'positive'
               if any virus is positive, else 'negative' if any virus is
               negative, else 'not in list'

    The annotated table is written to ``{PLOT_DIR}{file_name}.csv``.
    """
    modeldata_df = pd.read_csv(f"{GENERAL_DATA}internal-24q2_v87-model.csv")
    # ModelID -> OncotreeLineage lookup built once; for duplicated IDs the first row wins
    unique_models = modeldata_df.drop_duplicates(subset='ModelID', keep='first')
    tissue_by_model = dict(zip(unique_models['ModelID'], unique_models['OncotreeLineage']))

    new_cell_lines_df = pd.read_csv(f"{GENERAL_DATA}{file_name}.csv", header=None)
    new_cell_lines_df.columns = ['ModelID']  # Rename the column for easier access

    def _names_containing(groups, cell_line):
        """Return the keys of ``groups`` whose cell line collection contains ``cell_line``."""
        return [name for name, lines in groups.items() if cell_line in lines]

    positive_col, engineered_col, negative_col, tissue_col, status_col = [], [], [], [], []
    for cell_line in new_cell_lines_df['ModelID']:
        positive_viruses = _names_containing(positive_cell_lines, cell_line)
        engineered_genes = _names_containing(engineered_cell_lines, cell_line)
        negative_viruses = _names_containing(negative_cell_lines, cell_line)

        # Precedence: engineered > positive > negative > not in list
        if engineered_genes:
            status = 'engineered'
        elif positive_viruses:
            status = 'positive'
        elif negative_viruses:
            status = 'negative'
        else:
            status = 'not in list'

        positive_col.append(', '.join(positive_viruses))
        engineered_col.append(', '.join(engineered_genes))
        negative_col.append(', '.join(negative_viruses))
        # Unknown ModelIDs get a placeholder instead of raising IndexError
        tissue_col.append(tissue_by_model.get(cell_line, 'not in list'))
        status_col.append(status)

    new_cell_lines_df['Positive'] = positive_col
    new_cell_lines_df['Engineered'] = engineered_col
    new_cell_lines_df['Negative'] = negative_col
    new_cell_lines_df['Tissue'] = tissue_col
    new_cell_lines_df['Status'] = status_col

    # Save the updated DataFrame to a new CSV file
    new_cell_lines_df.to_csv(f'{PLOT_DIR}{file_name}.csv', index=False)

In [None]:
check_cell_line_status(file_name="priority_cell_lines")  # annotate the priority cell line list

In [None]:
check_cell_line_status(file_name="HPV_positive")  # annotate the HPV-positive cell line list

In [None]:
check_cell_line_status(file_name="EBV_potential")  # annotate the potential EBV cell line list

### Check which cell lines were not annotated yet

In [None]:
# Cell line indices of the public and the private Google bucket
bucket_index_files = ['cell_line_google_bucket_index.csv', 'cell_line_google_bucket_index2.csv']
CL_public, CL_private = (pd.read_csv(SCRIPTS_DATA + name, index_col=0) for name in bucket_index_files)
CL_all = pd.concat([CL_public, CL_private])

# A cell line counts as annotated when QUANT_DIR contains a directory named after it
CL_annotated = [entry for entry in os.listdir(QUANT_DIR) if os.path.isdir(os.path.join(QUANT_DIR, entry))]

# Everything in the bucket indices without such a directory
CL_not_annotated = CL_all.loc[~CL_all.index.isin(CL_annotated)]
CL_not_annotated.to_csv(SCRIPTS_DATA + 'cell_line_google_bucket_index_missing.csv')

### Compare my mapping data with Sowmya's to check how well they agree (looks good)

In [None]:
all_viral_genes_heatmap_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

# Sowmya's HPV mapping: first column is the cell line, every other column one HPV gene
mapped_cell_lines = pd.read_csv(f"{GENERAL_DATA}HPV_mapped_transposed.csv")
mapped_cell_lines.columns = ['Cell_line', 'E6 (HPV18)', 'E7 (HPV18)', 'E1 (HPV18)', 'E2 (HPV18)', 'HPV18.5', 'HPV18.6', 'HPV18.7', 'L1 (HPV18)', 
                             'E6 (HPV16)', 'E7 (HPV16)', 'E1 (HPV16)', 'E1^E4 (HPV16)', 'HPV16.5', 'HPV16.6', 'HPV16.7', 'HPV16.8', 'HPV16.9']
numeric_cols = mapped_cell_lines.columns[1:]

# Our (salmon) HPV16/HPV18 gene columns in the heatmap
hpv_columns = all_viral_genes_heatmap_df.columns[all_viral_genes_heatmap_df.columns.str.contains('HPV16|HPV18')]

# One comparison row per cell line of Sowmya's table
compare_viral_hits = pd.DataFrame(columns= ['Cell_line', 'viral_genes_salmon', 'viral_genes_sowmya', 'comment'])
compare_viral_hits['Cell_line'] = mapped_cell_lines['Cell_line']

# Sowmya's hits: comma-separated names of the genes with a value > 0
for _, sowmya_row in mapped_cell_lines.iterrows():
    sowmya_genes = numeric_cols[sowmya_row[numeric_cols] > 0].tolist()
    rows_for_line = compare_viral_hits['Cell_line'] == sowmya_row['Cell_line']
    compare_viral_hits.loc[rows_for_line, 'viral_genes_sowmya'] = ', '.join(sowmya_genes)

# Our hits: HPV genes with TPM > 0; cell lines absent from the heatmap are flagged instead
for hit_idx, line_id in list(compare_viral_hits['Cell_line'].items()):
    if line_id not in all_viral_genes_heatmap_df.index:
        compare_viral_hits.loc[hit_idx, 'comment'] = 'not in list'
        continue
    salmon_genes = hpv_columns[all_viral_genes_heatmap_df.loc[line_id, hpv_columns] > 0].tolist()
    compare_viral_hits.loc[hit_idx, 'viral_genes_salmon'] = ', '.join(salmon_genes)

compare_viral_hits.to_csv(f'{PLOT_DIR}compare_viral_gene_finds.csv', index=False)

# Function to compare unordered viral gene lists
def compare_unordered_entries(row):
    """Return True if both methods report the same set of viral genes.

    ``row['viral_genes_salmon']`` and ``row['viral_genes_sowmya']`` are
    ', '-joined gene lists and are compared as sets, so order does not matter.
    Missing values (NaN/None) and empty strings both mean "no genes", and empty
    tokens are ignored. A line with no hits in one source and a missing entry
    in the other therefore compares as equal.
    """
    def _to_gene_set(value):
        # NaN/None -> no genes; '' (join of an empty list) -> no genes as well
        if pd.isna(value):
            return set()
        return {gene for gene in str(value).split(', ') if gene}

    # Compare the sets instead of strings
    return _to_gene_set(row['viral_genes_salmon']) == _to_gene_set(row['viral_genes_sowmya'])

# Filter out rows where 'comment' is 'not in list' (cell line missing from our heatmap)
filtered_hits = compare_viral_hits[compare_viral_hits['comment'] != 'not in list']

# Keep only rows where at least one method reported a gene. The gene columns
# hold '' (not NaN) when nothing was found, so test for non-empty text;
# a plain notna() check would keep every row.
has_salmon_genes = filtered_hits['viral_genes_salmon'].fillna('').astype(str).str.strip() != ''
has_sowmya_genes = filtered_hits['viral_genes_sowmya'].fillna('').astype(str).str.strip() != ''
filtered_hits = filtered_hits[has_salmon_genes | has_sowmya_genes]

# Remove the 'comment' column
filtered_hits = filtered_hits.drop(columns=['comment'])

# Attach raw data columns from both dataframes (matched on Cell_line). Gene
# columns present in both tables get pandas' default merge suffixes:
# '_x' from mapped_cell_lines and '_y' from the heatmap.
filtered_hits = filtered_hits.merge(mapped_cell_lines, on='Cell_line', how='left')
filtered_hits = filtered_hits.merge(all_viral_genes_heatmap_df[hpv_columns], left_on='Cell_line', right_index=True, how='left')

# Remove duplicate rows based on 'Cell_line'; .copy() makes the column
# assignment below act on an independent frame (no SettingWithCopyWarning)
filtered_hits = filtered_hits.drop_duplicates(subset='Cell_line').copy()

# Compare unordered viral_genes_salmon and viral_genes_sowmya. Building an
# explicit bool Series keeps this working when no rows are left (DataFrame.apply
# returns a DataFrame, not a Series, on an empty frame).
filtered_hits['same_entries'] = pd.Series(
    [compare_unordered_entries(row) for _, row in filtered_hits.iterrows()],
    index=filtered_hits.index,
    dtype=bool,
)

# Create separate dataframes for same and different entries
same_entries = filtered_hits[filtered_hits['same_entries']]
different_entries = filtered_hits[~filtered_hits['same_entries']]

# Drop the 'same_entries' column as it's not needed anymore
same_entries = same_entries.drop(columns=['same_entries'])
different_entries = different_entries.drop(columns=['same_entries'])

# Save the results
same_entries.to_csv(f'{PLOT_DIR}same_viral_gene_finds.csv', index=False)
different_entries.to_csv(f'{PLOT_DIR}different_viral_gene_finds.csv', index=False)

# Function to zero out matching viral genes and raw data columns with _x and _y suffixes
def zero_out_raw_data(row, raw_data_columns_x, raw_data_columns_y):
    """Remove the genes both methods agree on from one comparison row.

    Genes listed in both ``row['viral_genes_salmon']`` and
    ``row['viral_genes_sowmya']`` (', '-joined lists) are removed from both
    lists. A list that ends up empty is replaced by '0'. Remaining genes are
    written back sorted, so the output does not depend on set iteration order.
    For every shared gene, the raw value columns ``<gene>_x`` and ``<gene>_y``
    are set to 0 when they appear in ``raw_data_columns_x`` /
    ``raw_data_columns_y``.

    Missing values and empty strings are treated as "no genes".

    Returns a modified copy of ``row``; the input row is left unchanged.
    """
    def _to_gene_set(value):
        if pd.isna(value):
            return set()
        return {gene for gene in str(value).split(', ') if gene}

    row = row.copy()
    viral_salmon = _to_gene_set(row['viral_genes_salmon'])
    viral_sowmya = _to_gene_set(row['viral_genes_sowmya'])

    # Find the matching genes between the two columns and remove them from both
    common_genes = viral_salmon & viral_sowmya
    viral_salmon -= common_genes
    viral_sowmya -= common_genes

    # If no viral genes are left after removing matches, assign '0'
    row['viral_genes_salmon'] = ', '.join(sorted(viral_salmon)) if viral_salmon else '0'
    row['viral_genes_sowmya'] = ', '.join(sorted(viral_sowmya)) if viral_sowmya else '0'

    # Zero out raw data columns for matching genes
    x_columns = set(raw_data_columns_x)
    y_columns = set(raw_data_columns_y)
    for gene in common_genes:
        if gene + '_x' in x_columns:
            row[gene + '_x'] = 0
        if gene + '_y' in y_columns:
            row[gene + '_y'] = 0
    return row

# Raw value columns that carry pandas merge suffixes ('_x' and '_y')
raw_data_columns_x, raw_data_columns_y = (
    [col for col in different_entries.columns if col.endswith(suffix)]
    for suffix in ('_x', '_y')
)

# Per row: drop genes both methods agree on and zero their raw value columns
different_entries = different_entries.apply(
    zero_out_raw_data, axis=1, args=(raw_data_columns_x, raw_data_columns_y)
)

# Raw value columns that are zero in every remaining row carry no information
all_zero = different_entries[raw_data_columns_x + raw_data_columns_y].eq(0).all()
different_entries = different_entries.drop(columns=all_zero[all_zero].index.tolist())

# Save the final result
different_entries.to_csv(f'{PLOT_DIR}filtered_different_viral_gene_finds.csv', index=False)

### Compare the TPM distribution between different genes for each virus as 3D histograms.

In [None]:
# Viral gene TPM table (index = cell line) used by the 3D histograms below
all_viral_genes_heatmap_df = pd.read_csv(filepath_or_buffer=f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

def load_data(meta_data_dir, general_data_dir):
    """Load the viral gene metadata and the cell line model table.

    Parameters
    ----------
    meta_data_dir : str
        Path prefix of ``viral_gene_metadata.csv`` (concatenated directly).
    general_data_dir : str
        Path prefix of ``internal-24q2_v87-model.csv`` (concatenated directly).

    Returns
    -------
    (metadata_df, modeldata_df)
        ``modeldata_df['OncotreeLineage']`` has missing lineages replaced by
        'Unknown' and is converted to str.
    """
    metadata_df = pd.read_csv(f"{meta_data_dir}viral_gene_metadata.csv")
    modeldata_df = pd.read_csv(f"{general_data_dir}internal-24q2_v87-model.csv")
    # fillna must run before astype(str): astype(str) turns NaN into the string
    # 'nan', after which fillna has nothing left to replace.
    modeldata_df['OncotreeLineage'] = modeldata_df['OncotreeLineage'].fillna('Unknown').astype(str)
    return metadata_df, modeldata_df

def filter_zero_tpm_rows(heatmap_df):
    """Return only the rows (cell lines) whose TPM values sum to more than zero.

    All columns are kept; NaNs are skipped when summing.
    """
    row_totals = heatmap_df.sum(axis=1)
    return heatmap_df[row_totals.gt(0)]

def map_genes_to_viruses(metadata_df):
    """Map each display label "<Gene_name> (<Virus_name>)" to its Virus_name."""
    gene_names = metadata_df['Gene_name']
    virus_names = metadata_df['Virus_name']
    return {
        gene + " (" + virus + ")": virus
        for gene, virus in zip(gene_names, virus_names)
    }

def group_genes_by_virus(genes, gene_to_virus):
    """Group gene labels by virus.

    Genes without a (truthy) virus in ``gene_to_virus`` are skipped. Viruses
    appear in the order of their first gene; genes keep their input order.
    """
    grouped = {}
    for gene in genes:
        virus_name = gene_to_virus.get(gene)
        if not virus_name:
            continue
        grouped.setdefault(virus_name, []).append(gene)
    return grouped

def prepare_gene_data(gene, filtered_gene_data, modeldata_df, tpm_bins):
    """Build a per-cell-line TPM table for one gene.

    Parameters
    ----------
    gene : str
        Column of ``filtered_gene_data`` to use.
    filtered_gene_data : pandas.DataFrame
        TPM table indexed by cell line (ModelID), one column per gene.
    modeldata_df : pandas.DataFrame
        Model table with 'ModelID' and 'OncotreeLineage' columns.
    tpm_bins : array-like
        Bin edges passed to ``pd.cut``.

    Returns
    -------
    pandas.DataFrame
        Columns CellLine, TPM, Tissue, Gene, TPM_Bin; one row per cell line
        with TPM > 0. Tissue is 'not in list' for cell lines missing from
        ``modeldata_df``. TPM_Bin uses right-closed bins (pd.cut default) with
        the lowest edge included.
    """
    gene_data = filtered_gene_data[gene]
    gene_data = gene_data.loc[gene_data > 0]  # Remove rows with TPM of 0

    # ModelID -> lineage lookup built once instead of scanning modeldata_df for
    # every cell line; for duplicated IDs the first row wins.
    first_models = modeldata_df.drop_duplicates(subset='ModelID', keep='first')
    tissue_by_model = dict(zip(first_models['ModelID'], first_models['OncotreeLineage']))

    cell_lines = list(gene_data.index)
    gene_df = pd.DataFrame({
        'CellLine': cell_lines,
        'TPM': gene_data.to_numpy(dtype=float),
        'Tissue': [tissue_by_model.get(cell_line, 'not in list') for cell_line in cell_lines],
        'Gene': [gene] * len(cell_lines),
    })
    gene_df['TPM_Bin'] = pd.cut(gene_df['TPM'], bins=tpm_bins, include_lowest=True)
    return gene_df

def generate_bins(filtered_gene_data):
    """Build TPM bin edges and the matching bin widths.

    Edges: 0.0-0.9 in steps of 0.1, 1-9 in steps of 1, then 10, 20, ... up to
    the first multiple of 10 that is >= the maximum TPM. When the maximum is
    <= 10, the edge 10 is still included, so values in (9, 10] fall into a bin.

    Returns
    -------
    (tpm_bins, tpm_bins_dim)
        ``tpm_bins`` are the edges; ``tpm_bins_dim`` has one width per bin
        (``len(tpm_bins) - 1`` entries), equal to ``np.diff(tpm_bins)``.
    """
    bins_0_to_1 = np.arange(0, 1, 0.1)    # 0.0, 0.1, ..., 0.9
    bins_1_to_10 = np.arange(1, 10, 1)    # 1, 2, ..., 9
    max_tpm_value = filtered_gene_data.max().max()
    if max_tpm_value > 10:
        # Last value L of this range satisfies max <= L < max + 10
        bins_10_onwards = np.arange(10, max_tpm_value + 10, 10)
    else:
        bins_10_onwards = np.array([10.0])
    tpm_bins = np.concatenate([bins_0_to_1, bins_1_to_10, bins_10_onwards])
    # Width of every bin, derived from the edges so each width matches its bin
    # (e.g. the 0.9-1.0 bin is 0.1 wide and the 9-10 bin is 1 wide).
    tpm_bins_dim = np.diff(tpm_bins)
    return tpm_bins, tpm_bins_dim

def plot_3d_distribution(ax, tissue_binned_counts_per_gene, tpm_bins, tpm_bins_dim, genes, tissue_color_map, 
                         xlim=None, title="", show_gene_labels=True, y_stretch=1.5):
    """Draw stacked 3D TPM histograms, one row of bars per gene.

    Parameters
    ----------
    ax : 3D axes (created with ``projection='3d'`` by the caller).
    tissue_binned_counts_per_gene : list of dict
        One dict per gene (same order as ``genes``), mapping tissue -> array
        of cell line counts per TPM bin. Tissues are stacked on top of each other.
    tpm_bins : array-like
        Bin edges (number of bins + 1 values).
    tpm_bins_dim : array-like
        Bar width per bin (one value per bin).
    genes : list of str
        Gene labels; gene i is drawn at y = i.
    tissue_color_map : dict
        Tissue -> colour.
    xlim : (float, float) or None
        When given, only bins lying inside this TPM range are drawn and the
        x-axis is limited to it.
    title : str
        Panel title.
    show_gene_labels : bool
        Write the gene names next to the y positions.
    y_stretch : float
        Unused; kept for backward compatibility.

    Returns
    -------
    set
        Tissues with at least one bar drawn in this panel (for the legend).
    """
    ax.set_title(title, fontsize=28)
    edges = np.asarray(tpm_bins, dtype=float)
    bin_left_edges = edges[:-1]
    bin_right_edges = edges[1:]
    bar_widths = np.asarray(tpm_bins_dim, dtype=float)

    # Bins are selected by their TPM position (edges), not by their width.
    # The small tolerance absorbs float edges such as 0.30000000000000004.
    if xlim:
        eps = 1e-9
        in_range = (bin_left_edges >= xlim[0] - eps) & (bin_right_edges <= xlim[1] + eps)
    else:
        in_range = np.ones(len(bin_left_edges), dtype=bool)

    tissues_in_plot = set()  # To keep track of which tissues are actually plotted

    for gene_index, tissue_binned_counts in enumerate(tissue_binned_counts_per_gene):
        bottom_bars = np.zeros(len(bin_left_edges))  # running stack height per bin
        for tissue, counts in tissue_binned_counts.items():
            counts = np.asarray(counts)
            valid_indices = (counts > 0) & in_range
            if np.any(valid_indices):
                # bar3d anchors each bar at (x, y, z) and extends it by dx, so x is the left bin edge
                ax.bar3d(bin_left_edges[valid_indices], gene_index, bottom_bars[valid_indices],
                         bar_widths[valid_indices], 0.3, counts[valid_indices],
                         color=tissue_color_map[tissue])
                tissues_in_plot.add(tissue)  # Record the tissue for the legend
            bottom_bars += counts

    if xlim:
        ax.set_xlim(xlim)

    ax.set_xlabel('TPM', fontsize=20)
    ax.set_ylabel('Gene', fontsize=20)
    ax.set_zlabel('Number of Cell Lines', fontsize=20)
    ax.xaxis.labelpad = 10
    ax.yaxis.labelpad = -3
    ax.zaxis.labelpad = 10
    ax.set_yticks(range(len(genes)))
    ax.set_yticklabels([])  # Clear default y-tick labels; gene names are drawn as text below

    if show_gene_labels:
        for i in range(min(len(ax.get_yticks()), len(genes))):
            ax.text(1.35 * (ax.get_xlim()[1] - ax.get_xlim()[0]), i, 0, genes[i], ha='right', va='center')

    return tissues_in_plot  # Return the set of tissues actually plotted


def main(meta_data_dir, general_data_dir, all_viral_genes_heatmap_df, plot_dir):
    """Write 3D TPM histograms for every virus to a multi-page PDF.

    One page per virus is written to ``{plot_dir}virus_tpm_histograms_3D.pdf``.
    Each page has three 3D panels (full TPM range, 0-1 TPM zoom, >0.2 TPM) with
    bars stacked by tissue (OncotreeLineage), plus a legend of the tissues that
    appear.

    Parameters
    ----------
    meta_data_dir : str
        Path prefix of ``viral_gene_metadata.csv``.
    general_data_dir : str
        Path prefix of ``internal-24q2_v87-model.csv``.
    all_viral_genes_heatmap_df : pandas.DataFrame
        TPM table indexed by cell line, one column per viral gene label.
    plot_dir : str
        Path prefix the PDF is written to.
    """
    metadata_df, modeldata_df = load_data(meta_data_dir, general_data_dir)
    filtered_gene_data = filter_zero_tpm_rows(all_viral_genes_heatmap_df)
    gene_to_virus = map_genes_to_viruses(metadata_df)
    virus_to_genes = group_genes_by_virus(filtered_gene_data.columns, gene_to_virus)

    # The tissue palette does not depend on the virus, so it is built once
    tissue_types = np.unique(np.append(modeldata_df['OncotreeLineage'].values, 'not in list'))
    tissue_color_map = dict(zip(tissue_types, sns.color_palette("Set2", n_colors=len(tissue_types))))

    pdf_filename = f"{plot_dir}virus_tpm_histograms_3D.pdf"
    with PdfPages(pdf_filename) as pdf:
        for virus, genes in virus_to_genes.items():
            filtered_gene_data_subset = filtered_gene_data[genes]
            tpm_bins, tpm_bins_dim = generate_bins(filtered_gene_data_subset)
            n_bins = len(tpm_bins) - 1

            all_gene_dfs = Parallel(n_jobs=8)(delayed(prepare_gene_data)(gene, filtered_gene_data_subset, modeldata_df, tpm_bins) for gene in genes)
            combined_df = pd.concat(all_gene_dfs, ignore_index=True)

            tissue_binned_counts_per_gene = []
            for gene in genes:
                gene_df = combined_df[combined_df['Gene'] == gene]
                tissue_binned_counts = {tissue: np.zeros(n_bins) for tissue in tissue_types}
                # right=True gives right-closed bins (a, b], consistent with pd.cut in
                # prepare_gene_data. With left-closed bins, a TPM equal to the last
                # edge (which can be the maximum) fell outside every bin and was lost.
                bin_idx = np.digitize(gene_df['TPM'].to_numpy(dtype=float), tpm_bins, right=True) - 1
                for tissue, idx in zip(gene_df['Tissue'], bin_idx):
                    if 0 <= idx < n_bins:
                        tissue_binned_counts[tissue][idx] += 1
                tissue_binned_counts_per_gene.append(tissue_binned_counts)

            fig = plt.figure(figsize=(60, 16))
            fig.suptitle(f"3D TPM Distribution for {virus}", fontsize=40)
            gs = fig.add_gridspec(1, 4, width_ratios=[6, 6, 6, 1])

            # Main plot
            ax_main = fig.add_subplot(gs[0], projection='3d')
            tissues_in_main_plot = plot_3d_distribution(ax_main, tissue_binned_counts_per_gene, tpm_bins, tpm_bins_dim, genes, tissue_color_map, xlim=(0, np.max(tpm_bins)+10), title="Full TPM Distribution")

            # Zoomed plot 1 (0-1 TPM)
            ax_zoom1 = fig.add_subplot(gs[1], projection='3d')
            tissues_in_zoom1 = plot_3d_distribution(ax_zoom1, tissue_binned_counts_per_gene, tpm_bins, tpm_bins_dim, genes, tissue_color_map, xlim=(0, 1), title="Zoomed TPM Distribution (0-1)")

            # Zoomed plot 2 (>0.2 TPM)
            ax_zoom2 = fig.add_subplot(gs[2], projection='3d')
            tissues_in_zoom2 = plot_3d_distribution(ax_zoom2, tissue_binned_counts_per_gene, tpm_bins, tpm_bins_dim, genes, tissue_color_map, xlim=(0.2, np.max(tpm_bins)+10), title="TPM Distribution (>0.2)")

            # Tissues drawn in any subplot, sorted for a stable legend order
            legend_tissues = sorted(tissues_in_main_plot | tissues_in_zoom1 | tissues_in_zoom2)

            # Legend
            ax_legend = fig.add_subplot(gs[3])
            ax_legend.axis('off')
            ax_legend.legend(handles=[plt.Rectangle((0, 0), 1, 1, color=tissue_color_map[tissue]) for tissue in legend_tissues],
                             labels=legend_tissues, title="Tissues in Plot", fontsize=14, title_fontsize=20)

            pdf.savefig(fig)
            plt.close(fig)


main(meta_data_dir=META_DATA_DIR, general_data_dir=GENERAL_DATA,
     all_viral_genes_heatmap_df=all_viral_genes_heatmap_df, plot_dir=PLOT_DIR)

### Correlate two matrices with "paired" cell lines (human gene matrix of Chronos scores vs. viral gene matrix of TPM values across cancer cell lines)

In [None]:
# Import CRISPR cancer cell line data (Chronos gene effect; rows = cell lines, columns = genes)
crispr_path = GENERAL_DATA + "internal-24q2_v87-crisprgeneeffect"
data = sc.read_hdf(crispr_path, key="data")
# The row/column labels are stored as byte strings in the same HDF5 file.
# Read them inside a context manager so the file handle is closed afterwards.
with h5py.File(crispr_path, "r") as f:
    data.obs_names = [cl.decode("utf-8") for cl in f["dim_0"][()]]
    # Keep only the gene symbol (the part before the first space)
    data.var_names = [gene.decode("utf-8").split(" ")[0] for gene in f["dim_1"][()]]
human_df = data.to_df()

# Import viral annotation cancer cell line data
viral_df = pd.read_csv(f'{PLOT_DIR}all_viral_genes_heatmap_data.csv', index_col=0)

# Remove low TPMs as they are probably false positive counts by setting them to 0
viral_df[viral_df < 0.1] = 0

# Function to compute Pearson correlations between two different matrices
def pearsonr_with_nan_two_matrices(df_arr_viral, df_arr_human):
    """
    Compute pairwise Pearson correlations between the columns of two matrices,
    ignoring NaNs pair by pair.

    df_arr_viral: Viral genes matrix (cell lines x viral genes), 2-D numpy array
    df_arr_human: Human genes matrix (cell lines x human genes), 2-D numpy array
        Both matrices must list the same cell lines in the same row order.

    For every (viral gene i, human gene j) pair, only rows where both values are
    non-NaN are used ("paired" correlation on the same cell lines).

    Returns: (corrs_arr, P), both shaped (viral genes, human genes)
        corrs_arr: Pearson r; NaN when fewer than 2 paired values exist (a
            message is printed) or a column is constant over the paired rows.
        P: two-sided p-value computed with that pair's own number of paired
            values n; NaN when n < 3 or r is NaN.
    """
    num_viral_genes = df_arr_viral.shape[1]
    num_human_genes = df_arr_human.shape[1]

    corrs_arr = np.full((num_viral_genes, num_human_genes), np.nan)
    # Number of paired (non-NaN in both) observations for each gene pair.
    # The p-value must use this n rather than the total number of rows;
    # otherwise pairs with many missing values look far more significant.
    n_pairs = np.zeros((num_viral_genes, num_human_genes), dtype=np.int64)

    viral_valid = ~np.isnan(df_arr_viral)
    human_valid = ~np.isnan(df_arr_human)

    for i in range(num_viral_genes):
        for j in range(num_human_genes):
            # Rows usable for this pair: non-NaN in both columns
            mask = viral_valid[:, i] & human_valid[:, j]
            n = int(mask.sum())
            n_pairs[i, j] = n

            # Only compute if there are at least 2 non-NaN values
            if n > 1:
                corrs_arr[i, j] = np.corrcoef(df_arr_viral[mask, i], df_arr_human[mask, j])[0, 1]
            else:
                print(f'NA, {i}, {j}')

    # p-values from the exact null distribution of r (method taken from: https://stackoverflow.com/a/77628581/27306413):
    # under H0, r ~ Beta(n/2 - 1, n/2 - 1) rescaled to [-1, 1], which needs n >= 3.
    # The tail probability at -|r| is doubled for a two-sided test.
    P = np.full_like(corrs_arr, np.nan)
    testable = (n_pairs > 2) & ~np.isnan(corrs_arr)
    if np.any(testable):
        shape_param = n_pairs[testable] / 2 - 1
        P[testable] = 2 * scipy.stats.beta.cdf(-np.abs(corrs_arr[testable]), shape_param, shape_param, loc=-1, scale=2)

    return corrs_arr, P

# Function to align matrices by common cell lines
def align_cell_lines(df1, df2):
    """
    Restrict two DataFrames to the row labels (cell lines) they share.

    Both returned frames contain exactly the shared labels, in the same order,
    so their rows can be compared position by position.
    """
    shared_lines = df1.index.intersection(df2.index)
    return df1.loc[shared_lines], df2.loc[shared_lines]


# "Pair" the matrices: keep only cell lines present in both, in one shared row order
viral_df_aligned, human_df_aligned = align_cell_lines(viral_df, human_df)

# Plain NumPy arrays for the pairwise correlation loop
viral_matrix, human_matrix = viral_df_aligned.values, human_df_aligned.values

# Correlation and p-value matrices, each (viral genes x human genes)
corr_matrix, p_val_matrix = pearsonr_with_nan_two_matrices(viral_matrix, human_matrix)


In [None]:
# Shapes as (cell lines, genes): full viral and human matrices first, then the
# aligned versions, which keep only the cell lines shared by both (same order).
for matrix in (viral_df, human_df, viral_df_aligned, human_df_aligned):
    print(matrix.shape)

In [None]:
# Code to structure the correlation and write it to a file

def pyarrow_write_csv(csv_filename, data_matrix):
    """
    Write a pandas DataFrame to a CSV file using PyArrow.

    csv_filename: output path
    data_matrix: pandas DataFrame to write. Its index is not written.
    """
    # preserve_index=False: otherwise a DataFrame whose index is not a plain
    # RangeIndex (e.g. the result of boolean filtering) has its index written as
    # an extra "__index_level_0__" column.
    table = pa.Table.from_pandas(data_matrix, preserve_index=False)

    # Write the table to a CSV file using PyArrow
    with pa.OSFile(csv_filename, 'wb') as sink:
        csv.write_csv(table, sink)


# Create a mask to filter out values where p_val is NaN
valid_mask = ~np.isnan(p_val_matrix.flatten())

# Apply false discovery rate control (Benjamini-Hochberg) to the valid p-values
p_val_adj_bh = scipy.stats.false_discovery_control(p_val_matrix.flatten()[valid_mask], method="bh")

# The result matrices are (viral genes x human genes) and are flattened
# row-major: flat position k belongs to viral gene k // n_human_genes
# (np.repeat order) and human gene k % n_human_genes (np.tile order).
n_viral_genes = len(viral_df_aligned.columns)
n_human_genes = len(human_df_aligned.columns)

# Chronos statistics (median, min, max) for each human gene
human_gene_medians = human_df_aligned.median(axis=0)
human_gene_min = human_df_aligned.min(axis=0)
human_gene_max = human_df_aligned.max(axis=0)

# TPM statistics for each viral gene, over non-zero values only
non_zero_viral_gene_df = viral_df_aligned[viral_df_aligned > 0]
viral_gene_medians = non_zero_viral_gene_df.median(axis=0)
viral_gene_min = non_zero_viral_gene_df.min(axis=0)
viral_gene_max = non_zero_viral_gene_df.max(axis=0)

# Per gene pair: number of cell lines where the viral gene has TPM > 1 AND a
# Chronos score exists for the human gene. A matrix product of the two 0/1
# indicator matrices yields the same counts as summing a broadcast
# (cell lines x viral genes x human genes) boolean array, without building it.
valid_human_mask = human_df_aligned.notna().to_numpy(dtype=float)   # (cell_lines, human_genes)
valid_viral_mask = (viral_df_aligned > 1).to_numpy(dtype=float)     # (cell_lines, viral_genes)
counts_per_gene_pair = np.rint(valid_viral_mask.T @ valid_human_mask).astype(np.int64)  # (viral_genes, human_genes)
viral_gene_counts = counts_per_gene_pair.flatten()[valid_mask]

# Long format: one row per gene pair with a valid p-value. Per-gene statistics
# are expanded positionally with np.tile / np.repeat, matching the flattening
# order above. This replaces positional Series[int] lookups, which pandas has
# deprecated.
df_long = pd.DataFrame({
    'human_gene': np.tile(human_df_aligned.columns.to_numpy(), n_viral_genes)[valid_mask],
    'viral_gene': np.repeat(viral_df_aligned.columns.to_numpy(), n_human_genes)[valid_mask],
    'corr': corr_matrix.flatten()[valid_mask],
    'p_val': p_val_matrix.flatten()[valid_mask],
    'p_val_adj_bh': p_val_adj_bh,
    'chronos_median': np.tile(human_gene_medians.to_numpy(), n_viral_genes)[valid_mask],
    'chronos_min': np.tile(human_gene_min.to_numpy(), n_viral_genes)[valid_mask],
    'chronos_max': np.tile(human_gene_max.to_numpy(), n_viral_genes)[valid_mask],
    'TPM_median': np.repeat(viral_gene_medians.to_numpy(), n_human_genes)[valid_mask],
    'TPM_min': np.repeat(viral_gene_min.to_numpy(), n_human_genes)[valid_mask],
    'TPM_max': np.repeat(viral_gene_max.to_numpy(), n_human_genes)[valid_mask],
    'count_celllines_with_gene': viral_gene_counts,
})
pyarrow_write_csv(f'{PLOT_DIR}human-gene-to-viral-gene-correlation_long_df_low_TPM_removed.csv', df_long)

# Keep only significant pairs (BH-adjusted p < 0.05) to make later steps easier to handle
df_filtered = df_long[df_long['p_val_adj_bh'] < 0.05]
pyarrow_write_csv(f"{PLOT_DIR}human-gene-to-viral-gene-correlation_long_df_adj-pval-smaller-0.05_low_TPM_removed.csv", df_filtered)
    

In [None]:
# Reload the full long-format correlation table and its adj. p < 0.05 subset from disk
df_long = pd.read_csv(filepath_or_buffer=f"{PLOT_DIR}human-gene-to-viral-gene-correlation_long_df_low_TPM_removed.csv")
df_filtered = pd.read_csv(filepath_or_buffer=f"{PLOT_DIR}human-gene-to-viral-gene-correlation_long_df_adj-pval-smaller-0.05_low_TPM_removed.csv")

In [None]:
# Count how often each human gene appears among the top hits of each virus

# Load the metadata for viral genes
metadata_df = pd.read_csv(f"{META_DATA_DIR}viral_gene_metadata.csv")

# Map each "<gene> (<virus>)" label to its virus
gene_to_virus = {
    gene + " (" + virus + ")": virus
    for gene, virus in zip(metadata_df['Gene_name'], metadata_df['Virus_name'])
}

# Top hits: BH-adjusted p-value below 1e-10. .assign returns a new frame, so the
# virus column is not added to a filtered slice (avoids SettingWithCopyWarning).
df_long_tophits = df_filtered.loc[df_filtered['p_val_adj_bh'] < 1e-10].assign(
    virus=lambda hits: hits['viral_gene'].map(gene_to_virus)
)

# Count how often certain human genes appear in each virus
counts = df_long_tophits.groupby(['human_gene', 'virus']).size().reset_index(name='count')
counts.to_csv(f'{PLOT_DIR}gene_counts_by_virus_low_TPM_removed.csv', index=False)

In [None]:
long_format_df = df_long[(df_long['p_val_adj_bh'] < 0.05) &
                         (df_long['chronos_min'] < -0.5)]
# Optional extra filter (currently disabled): & (df_long['TPM_median'] > 0.1)

pyarrow_write_csv(f"{PLOT_DIR}human-gene-to-viral-gene-correlation_long_df_adj-pval-smaller-0.05_min-chronos-smaller-neg-0.5.csv", long_format_df)

# -log10 of the adjusted p-values. Adjusted p-values can underflow to exactly 0
# for very strong correlations; clip them to the smallest positive float so
# those points get a large finite y-value instead of +inf (which can be neither
# placed nor labelled).
neg_log10_p = -np.log10(long_format_df['p_val_adj_bh'].clip(lower=np.finfo(float).tiny))

# The 25 most significant points get a text label
top_hits = long_format_df.nsmallest(25, "p_val_adj_bh")

plt.figure(figsize=(15,15))
plt.scatter(long_format_df['corr'], neg_log10_p, alpha=0.7)
plt.axhline(y=-np.log10(0.05), color='r', linestyle='--', label="Significance threshold (adj. p-value = 0.05)")

texts = []
# Highlight one pair of interest (UBE3A vs. E6 of HPV16) if it passed the filters
specific_point = long_format_df[(long_format_df['human_gene'] == 'UBE3A') & 
                                (long_format_df['viral_gene'] == 'E6 (HPV16)')]
if not specific_point.empty:
    plt.scatter(specific_point['corr'], neg_log10_p.loc[specific_point.index], color='red', s=100)
    # plt.text needs scalar coordinates, so use the first matching row
    first_idx = specific_point.index[0]
    specific_label = f"{specific_point['human_gene'].values[0]}, {specific_point['viral_gene'].values[0]}"
    texts.append(plt.text(specific_point['corr'].iloc[0], neg_log10_p.loc[first_idx], specific_label, fontsize=8, color='red'))

# Labels "<human gene>, <viral gene>" for the top hits
for idx, row in top_hits.iterrows():
    label = f"{row['human_gene']}, {row['viral_gene']}"
    texts.append(plt.text(row['corr'], neg_log10_p.loc[idx], label, fontsize=8))

# Adjust text positions to prevent overlaps
adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

plt.title('Correlation of CRISPR screens to viral gene annotations', fontsize=14)
plt.xlabel('Correlation', fontsize=12)
plt.ylabel('-log10(adj. p-value)', fontsize=12)
plt.legend()
plt.savefig(f'{PLOT_DIR}chronos_to_tpm_correlation_over_2k_celllines_low_TPM_removed.svg', format='svg', bbox_inches='tight')
plt.show()

In [None]:
# Write one PDF page per pair for the 100 pairs in df_filtered with the smallest
# adjusted p-value: Chronos score of the human gene (x) vs TPM of the viral gene (y).
tophits_df = df_filtered.nsmallest(100, 'p_val_adj_bh')

pdf_filename = f"{PLOT_DIR}tophits-human-gene-to-viral-gene-correlations.pdf"
with PdfPages(pdf_filename) as pdf:
    gene_pairs = zip(tophits_df['human_gene'], tophits_df['viral_gene'])
    for human_gene, viral_gene in gene_pairs:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.scatter(human_df_aligned[human_gene], viral_df_aligned[viral_gene], alpha=0.7)
        ax.set_title(f'Correlation of {human_gene} Chronos Scores to {viral_gene} TPM', fontsize=14)
        ax.set_xlabel('Chronos Score', fontsize=12)
        ax.set_ylabel('TPM', fontsize=12)
        # Black reference lines through the origin
        ax.axvline(0, c="black")
        ax.axhline(0, c="black")

        pdf.savefig(fig)
        plt.close(fig)  # free the figure; one is created per page

In [None]:
# Volcano plot of significant human-gene / viral-gene correlations, keeping pairs
# with BH-adjusted p < 0.05, chronos_min < -0.5 and count_celllines_with_gene > 7.
long_format_df = df_long[(df_long['p_val_adj_bh'] < 0.05) &
                         (df_long['chronos_min'] < -0.5) &
                         (df_long['count_celllines_with_gene'] > 7)]

# NOTE(review): the filename says "2-celllines-TPM-bigger-0" but the filter above
# uses count_celllines_with_gene > 7 -- confirm which is intended.
pyarrow_write_csv(f"{PLOT_DIR}human-gene-to-viral-gene-correlation_long_df_adj-pval-smaller-0.05_2-celllines-TPM-bigger-0.csv", long_format_df)

# The 25 most significant pairs (smallest adjusted p-value) get text labels.
top_10 = long_format_df.nsmallest(25, "p_val_adj_bh")

plt.figure(figsize=(15,15))
plt.scatter(long_format_df['corr'], -np.log10(long_format_df['p_val_adj_bh']), alpha=0.7)
plt.axhline(y=-np.log10(0.05), color='r', linestyle='--', label="Significance threshold (adj. p-value = 0.05)")

texts = []
# Highlight the UBE3A / E6 (HPV16) pair in red (presumably a reference pair -- confirm).
reference_pair = ('UBE3A', 'E6 (HPV16)')
specific_point = long_format_df[(long_format_df['human_gene'] == reference_pair[0]) &
                                (long_format_df['viral_gene'] == reference_pair[1])]
if not specific_point.empty:
    plt.scatter(specific_point['corr'], -np.log10(specific_point['p_val_adj_bh']), color='red', s=100)
    # plt.text needs scalar coordinates; a one-element Series only works through
    # pandas' deprecated float(Series) conversion.
    ref_x = specific_point['corr'].iloc[0]
    ref_y = -np.log10(specific_point['p_val_adj_bh'].iloc[0])
    specific_label = f"{reference_pair[0]}, {reference_pair[1]}"
    texts.append(plt.text(ref_x, ref_y, specific_label, fontsize=8, color='red'))

# Label the top hits. The reference pair is skipped because, if it is among them,
# it has already been labelled in red above.
for _, row in top_10.iterrows():
    if (row['human_gene'], row['viral_gene']) == reference_pair:
        continue
    label = f"{row['human_gene']}, {row['viral_gene']}"
    texts.append(plt.text(row['corr'], -np.log10(row['p_val_adj_bh']), label, fontsize=8))

# Adjust text positions to prevent overlaps
adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

plt.title('Correlation of CRISPR screens to viral gene annotations', fontsize=14)
plt.xlabel('Correlation', fontsize=12)
plt.ylabel('-log10(adj. p-value)', fontsize=12)
plt.legend()
plt.savefig(f'{PLOT_DIR}chronos_to_tpm_correlation_over_2k_celllines_low_TPM_removed_sparse_celllines_removed.svg', format='svg', bbox_inches='tight')
plt.show()

In [None]:
# Write one PDF page per pair for the 100 most significant pairs of long_format_df
# (the sparse-cell-line-filtered table): human-gene Chronos score vs viral-gene TPM.

tophits_df = long_format_df.nsmallest(100, 'p_val_adj_bh')

pdf_filename = f"{PLOT_DIR}tophits-human-gene-to-viral-gene-correlations_sparse_celllines_removed.pdf"
with PdfPages(pdf_filename) as pdf:
    for human_gene, viral_gene in zip(tophits_df['human_gene'], tophits_df['viral_gene']):
        chronos_scores = human_df_aligned[human_gene]
        viral_tpm = viral_df_aligned[viral_gene]

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.scatter(chronos_scores, viral_tpm, alpha=0.7)
        ax.set_title(f'Correlation of {human_gene} Chronos Scores to {viral_gene} TPM', fontsize=14)
        ax.set_xlabel('Chronos Score', fontsize=12)
        ax.set_ylabel('TPM', fontsize=12)
        # Axes through the origin for orientation
        ax.axvline(0, c="black")
        ax.axhline(0, c="black")

        pdf.savefig(fig)
        plt.close(fig)  # release each page's figure after saving

### Find tissue-specific gene-to-gene combinations: test all human-gene pairs from the volcano plots above for Chronos-score differences between virus-positive and virus-negative cell lines within each tissue

In [None]:
# Candidate human genes: hits with BH-adjusted p < 0.005, chronos_min < -0.5,
# count_celllines_with_gene > 7 and a negative correlation (corr < 0).
long_format_df_filtered = df_long[(df_long['p_val_adj_bh'] < 0.005) &
                                  (df_long['chronos_min'] < -0.5) &
                                  (df_long['count_celllines_with_gene'] > 7) &
                                  (df_long['corr'] < 0)]

# Load dataframes
metadata_df = pd.read_csv(f"{META_DATA_DIR}viral_gene_metadata.csv")
modeldata_df = pd.read_csv(f"{GENERAL_DATA}internal-24q2_v87-model.csv")

# Map "<gene> (<virus>)" labels (the format of the viral_gene column) to virus names
gene_to_virus = {
    f"{gene} ({virus})": virus
    for gene, virus in zip(metadata_df['Gene_name'], metadata_df['Virus_name'])
}

# Group the candidate human genes by virus; viral genes that are missing from the
# metadata are skipped instead of raising KeyError.
virus_to_human_genes = {}
for viral_gene, human_gene in zip(long_format_df_filtered['viral_gene'],
                                  long_format_df_filtered['human_gene']):
    virus_name = gene_to_virus.get(viral_gene)
    if virus_name is not None:
        virus_to_human_genes.setdefault(virus_name, set()).add(human_gene)

# ModelID -> OncotreeLineage, built once and pushed to the engines so tasks don't
# repeat the lookup. If a ModelID is duplicated, its first row wins.
model_lineage = (modeldata_df.drop_duplicates('ModelID', keep='first')
                 .set_index('ModelID')['OncotreeLineage'])


def analyze_genes(combo):
    """Test whether two genes' Chronos scores differ between virus-positive and
    virus-negative cell lines of a single tissue.

    Runs on the ipyparallel engines and reads the globals pushed to them:
    human_df_aligned, positive_cell_lines and model_lineage.

    combo: (virus_name, gene1, gene2, tissue)

    Returns a dict with scipy.stats.ttest_ind statistics and the group means for
    both genes. Returns an empty dict when either gene is not a column of
    human_df_aligned, or when either group has fewer than 3 cell lines.
    """
    from scipy import stats

    virus_name, gene1, gene2, tissue = combo
    if gene1 not in human_df_aligned.columns or gene2 not in human_df_aligned.columns:
        return {}

    gene_data = human_df_aligned[[gene1, gene2]]
    # Cell lines of the requested tissue. Lines without a model entry map to NaN
    # and therefore never match.
    in_tissue = gene_data.index.map(model_lineage) == tissue
    # Every in-tissue line has a model entry, so "not positive" means negative here
    is_positive = gene_data.index.isin(positive_cell_lines.get(virus_name, []))
    tissue_positive_df = gene_data[in_tissue & is_positive]
    tissue_negative_df = gene_data[in_tissue & ~is_positive]

    # Require at least 3 cell lines in both groups
    if tissue_positive_df.shape[0] < 3 or tissue_negative_df.shape[0] < 3:
        return {}

    t_stat_gene1, p_value_gene1 = stats.ttest_ind(tissue_positive_df[gene1], tissue_negative_df[gene1], nan_policy='omit')
    t_stat_gene2, p_value_gene2 = stats.ttest_ind(tissue_positive_df[gene2], tissue_negative_df[gene2], nan_policy='omit')

    mean_positive = tissue_positive_df[[gene1, gene2]].mean()
    mean_negative = tissue_negative_df[[gene1, gene2]].mean()

    return {
        'tissue': tissue,
        'virus': virus_name,
        'gene1': gene1,
        'gene2': gene2,
        't_stat_gene1': t_stat_gene1,
        'p_value_gene1': p_value_gene1,
        't_stat_gene2': t_stat_gene2,
        'p_value_gene2': p_value_gene2,
        'mean_positive_gene1': mean_positive[gene1],
        'mean_negative_gene1': mean_negative[gene1],
        'mean_positive_gene2': mean_positive[gene2],
        'mean_negative_gene2': mean_negative[gene2],
    }


# One task per (virus, gene pair, tissue). Only tissues whose virus_type_counts_df
# entry for the virus exceeds 2 are tested (presumably the number of virus-positive
# lines in that tissue -- confirm). Genes are sorted so that the gene1/gene2
# assignment is reproducible; set iteration order is not.
all_results = []
for virus_name, human_genes in virus_to_human_genes.items():
    if len(human_genes) < 2:
        continue  # no pairs to test
    eligible_tissues = [tissue for tissue in virus_type_counts_df.columns
                        if virus_type_counts_df.loc[virus_name, tissue] > 2]
    for gene1, gene2 in itertools.combinations(sorted(human_genes, key=str), 2):
        all_results.extend((virus_name, gene1, gene2, tissue) for tissue in eligible_tissues)

# Start the ipyparallel cluster. try/finally guarantees the engines are shut down
# even if pushing data or a task fails.
rc = ipp.Cluster(n=8).start_and_connect_sync()
try:
    dview = rc[:]
    # Globals that analyze_genes reads on the engines
    dview.push({
        'human_df_aligned': human_df_aligned,
        'positive_cell_lines': positive_cell_lines,
        'model_lineage': model_lineage,
    })
    results = dview.map_sync(analyze_genes, all_results)
finally:
    rc.shutdown()  # shutdown cluster

# Drop skipped combinations (empty dicts). The explicit columns keep the filters
# below working even when no combination produced a result.
pair_result_columns = [
    'tissue', 'virus', 'gene1', 'gene2',
    't_stat_gene1', 'p_value_gene1', 't_stat_gene2', 'p_value_gene2',
    'mean_positive_gene1', 'mean_negative_gene1', 'mean_positive_gene2', 'mean_negative_gene2',
]
filtered_results = [res for res in results if res]
results_df = pd.DataFrame(filtered_results, columns=pair_result_columns)

results_df.to_csv(f'{PLOT_DIR}all_results_statistical_difference_test.csv', index=False)

# Unadjusted p-value filters. .copy() so later cells can add columns without
# writing into a slice of results_df.
significant_results_df = results_df[(results_df['p_value_gene1'] < 0.05) | (results_df['p_value_gene2'] < 0.05)].copy()
significant_results_df.to_csv(f'{PLOT_DIR}significant_results_statistical_difference_one_gene.csv')

more_significant_results_df = results_df[(results_df['p_value_gene1'] < 0.05) & (results_df['p_value_gene2'] < 0.05)]
more_significant_results_df.to_csv(f'{PLOT_DIR}significant_results_statistical_difference_both_gene.csv')

really_significant_results_df = results_df[(results_df['p_value_gene1'] < 1E-5) & (results_df['p_value_gene2'] < 1E-5)]
really_significant_results_df.to_csv(f'{PLOT_DIR}really_significant_results_statistical_difference_both_gene.csv')

In [None]:
# Plot of the gene-pair t-tests: -log10 p-value of gene1 (x) vs gene2 (y), coloured
# by tissue. Work on a copy so the helper column is not written into a slice of
# results_df (SettingWithCopyWarning).
significant_results_df = significant_results_df.copy()
significant_results_df['p_value_multipy'] = significant_results_df['p_value_gene1'] * significant_results_df['p_value_gene2']
# The 10 rows with the smallest product of both p-values get text labels
top_10 = significant_results_df.nsmallest(10, "p_value_multipy")

unique_tissues = significant_results_df['tissue'].unique()
# plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the supported
# call. 'tab20' is resampled to one color per tissue (at least 1 to stay valid).
tissue_palette = plt.get_cmap('tab20', max(len(unique_tissues), 1))
tissue_colors = dict(zip(unique_tissues, tissue_palette.colors))
colors = significant_results_df['tissue'].map(tissue_colors)

plt.figure(figsize=(15,15))
plt.scatter(-np.log10(significant_results_df['p_value_gene1']), -np.log10(significant_results_df['p_value_gene2']), alpha=0.7, c=colors)
# These are raw (unadjusted) t-test p-values, so the label does not say "adj."
threshold_line = plt.axhline(y=-np.log10(0.05), color='r', linestyle='--', label="Significance threshold (p-value = 0.05)")
plt.axvline(x=-np.log10(0.05), color='r', linestyle='--')

texts = []
for _, row in top_10.iterrows():
    label = f"{row['tissue']}, {row['virus']}, {row['gene1']}-{row['gene2']}"
    texts.append(plt.text(-np.log10(row['p_value_gene1']), -np.log10(row['p_value_gene2']), label, fontsize=8))

# Adjust text positions to prevent overlaps
adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

# Build one legend containing the threshold line and the tissue colors. A second
# plt.legend() call would replace this legend, so legend() is called only once.
handles = [threshold_line] + [
    plt.Line2D([0], [0], marker='o', color='w', label=tissue, markerfacecolor=color, markersize=10)
    for tissue, color in tissue_colors.items()
]
plt.legend(handles=handles, title="Tissues")

plt.title('Diff. of Chronos Scores of viral pos. and neg. cell lines per tissue', fontsize=14)
plt.xlabel('-log10(p-value gene1)', fontsize=12)
plt.ylabel('-log10(p-value gene2)', fontsize=12)
# Distinct filename: the single-gene plot further down writes chronos_diff_pos-neg-tissue.svg
plt.savefig(f'{PLOT_DIR}chronos_diff_pos-neg-tissue_gene-pairs.svg', format='svg', bbox_inches='tight')
plt.show()

### Find virus-specific genes that show a difference between virus-positive and virus-negative cell lines

In [None]:
# Candidate human genes: hits with BH-adjusted p < 0.005, chronos_min < -0.5,
# count_celllines_with_gene > 7 and a negative correlation (corr < 0).
long_format_df_filtered = df_long[(df_long['p_val_adj_bh'] < 0.005) &
                                  (df_long['chronos_min'] < -0.5) &
                                  (df_long['count_celllines_with_gene'] > 7) &
                                  (df_long['corr'] < 0)]

# Load dataframes
metadata_df = pd.read_csv(f"{META_DATA_DIR}viral_gene_metadata.csv")
modeldata_df = pd.read_csv(f"{GENERAL_DATA}internal-24q2_v87-model.csv")

# Map "<gene> (<virus>)" labels (the format of the viral_gene column) to virus names
gene_to_virus = {
    f"{gene} ({virus})": virus
    for gene, virus in zip(metadata_df['Gene_name'], metadata_df['Virus_name'])
}

# Group the candidate human genes by virus; viral genes that are missing from the
# metadata are skipped instead of raising KeyError.
virus_to_human_genes = {}
for viral_gene, human_gene in zip(long_format_df_filtered['viral_gene'],
                                  long_format_df_filtered['human_gene']):
    virus_name = gene_to_virus.get(viral_gene)
    if virus_name is not None:
        virus_to_human_genes.setdefault(virus_name, set()).add(human_gene)

# ModelID -> OncotreeLineage (for a duplicated ModelID the last row wins). It is
# built once and pushed to the engines instead of being rebuilt inside every task.
# Its keys, in modeldata_df order, are also the universe of cell lines.
model_lineage = modeldata_df.set_index('ModelID')['OncotreeLineage'].to_dict()


def analyze_gene(combo):
    """Compare one gene's Chronos scores between virus-positive and virus-negative
    cell lines, across all cell lines and within a single tissue.

    Runs on the ipyparallel engines and reads the pushed globals
    human_df_aligned, positive_cell_lines and model_lineage.

    combo: (virus_name, gene, tissue)

    Returns:
    - an empty dict when the gene is not a column of human_df_aligned;
    - None when fewer than 3 positive or 3 negative cell lines have data;
    - otherwise a dict of scipy.stats.ttest_ind results and group means.
    The tissue-specific fields are NaN (not None, which keeps the result columns
    numeric) when the tissue has fewer than 3 lines in either group.
    """
    from scipy import stats

    virus_name, gene, tissue = combo
    if gene not in human_df_aligned.columns:
        return {}

    gene_data = human_df_aligned[gene]
    available = gene_data.index

    positive_lines = positive_cell_lines.get(virus_name, [])
    positive_set = set(positive_lines)
    positive_lines_filtered = [line for line in positive_lines if line in available]
    # Negatives are every model not listed as positive. They are iterated in
    # model_lineage order, so the order is deterministic.
    negative_lines_filtered = [line for line in model_lineage
                               if line not in positive_set and line in available]

    if len(positive_lines_filtered) < 3 or len(negative_lines_filtered) < 3:
        return None

    # General analysis over all cell lines
    t_stat_general, p_value_general = stats.ttest_ind(
        gene_data[positive_lines_filtered],
        gene_data[negative_lines_filtered],
        nan_policy='omit'
    )
    mean_positive_general = gene_data[positive_lines_filtered].mean()
    mean_negative_general = gene_data[negative_lines_filtered].mean()

    # Tissue-specific analysis
    tissue_positive = gene_data.loc[[line for line in positive_lines_filtered if model_lineage.get(line) == tissue]]
    tissue_negative = gene_data.loc[[line for line in negative_lines_filtered if model_lineage.get(line) == tissue]]

    if len(tissue_positive) >= 3 and len(tissue_negative) >= 3:
        t_stat_tissue_spec, p_value_tissue_spec = stats.ttest_ind(
            tissue_positive, tissue_negative, nan_policy='omit'
        )
        mean_positive = tissue_positive.mean()
        mean_negative = tissue_negative.mean()
    else:
        missing = float('nan')
        t_stat_tissue_spec = p_value_tissue_spec = mean_positive = mean_negative = missing

    return {
        'virus': virus_name,
        'gene': gene,
        'tissue': tissue,
        't_stat_general': t_stat_general,
        'p_value_general': p_value_general,
        't_stat_tissue_spec': t_stat_tissue_spec,
        'p_value_tissue_spec': p_value_tissue_spec,
        'mean_positive': mean_positive_general,
        'mean_negative': mean_negative_general,
        'mean_positive_tissue': mean_positive,
        'mean_negative_tissue': mean_negative
    }


# One task per (virus, gene, tissue); genes are sorted for a reproducible order
tissues = modeldata_df['OncotreeLineage'].unique()
all_results = [
    (virus_name, gene, tissue)
    for virus_name, human_genes in virus_to_human_genes.items()
    for gene in sorted(human_genes, key=str)
    for tissue in tissues
]

# Start the ipyparallel cluster. try/finally guarantees the engines are shut down
# even if pushing data or a task fails.
rc = ipp.Cluster(n=8).start_and_connect_sync()
try:
    dview = rc[:]
    # Globals that analyze_gene reads on the engines
    dview.push({
        'human_df_aligned': human_df_aligned,
        'positive_cell_lines': positive_cell_lines,
        'model_lineage': model_lineage,
    })
    results = dview.map_sync(analyze_gene, all_results)
finally:
    rc.shutdown()  # shutdown cluster

# Drop skipped tasks (None / empty dict). The explicit columns keep the filters
# below working even when nothing was analysed.
gene_result_columns = [
    'virus', 'gene', 'tissue', 't_stat_general', 'p_value_general',
    't_stat_tissue_spec', 'p_value_tissue_spec', 'mean_positive', 'mean_negative',
    'mean_positive_tissue', 'mean_negative_tissue',
]
filtered_results = [res for res in results if res]
results_df = pd.DataFrame(filtered_results, columns=gene_result_columns)

results_df.to_csv(f'{PLOT_DIR}virus_specific_human_gene_change_analysis.csv', index=False)

# Unadjusted p-value filters. .copy() so later cells can add columns without
# writing into a slice of results_df.
significant_results_df = results_df[results_df['p_value_general'] < 0.05].copy()
significant_results_df.to_csv(f'{PLOT_DIR}significant_virus_specific_human_gene_change_analysis.csv', index=False)

more_significant_results_df = results_df[(results_df['p_value_general'] < 0.05) & (results_df['p_value_tissue_spec'] < 0.05)]
more_significant_results_df.to_csv(f'{PLOT_DIR}significant_virus_and_tissue_specific_human_gene_change_analysis.csv')

really_significant_results_df = results_df[(results_df['p_value_general'] < 1E-5) & (results_df['p_value_tissue_spec'] < 1E-5)]
really_significant_results_df.to_csv(f'{PLOT_DIR}really_significant_virus_and_tissue_specific_human_gene_change_analysis.csv')

In [None]:
# Plot of the single-gene t-tests: -log10 p-value across all cell lines (x) vs
# within one tissue (y), coloured by tissue. Work on a copy so the helper column is
# not written into a slice of results_df (SettingWithCopyWarning).
significant_results_df = significant_results_df.copy()
significant_results_df['p_value_multipy'] = significant_results_df['p_value_general'] * significant_results_df['p_value_tissue_spec']
# The 30 rows with the smallest product of both p-values get text labels
top_10 = significant_results_df.nsmallest(30, "p_value_multipy")

tissues_in_plot = significant_results_df['tissue'].unique()
# plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the supported
# call. 'tab20' is resampled to one color per tissue (at least 1 to stay valid).
tissue_palette = plt.get_cmap('tab20', max(len(tissues_in_plot), 1))
tissue_colors = dict(zip(tissues_in_plot, tissue_palette.colors))
colors = significant_results_df['tissue'].map(tissue_colors)

plt.figure(figsize=(15,15))
plt.scatter(-np.log10(significant_results_df['p_value_general']), -np.log10(significant_results_df['p_value_tissue_spec']), alpha=0.7, c=colors)
# These are raw (unadjusted) t-test p-values, so the label does not say "adj."
threshold_line = plt.axhline(y=-np.log10(0.05), color='r', linestyle='--', label="Significance threshold (p-value = 0.05)")
plt.axvline(x=-np.log10(0.05), color='r', linestyle='--')

texts = []
for _, row in top_10.iterrows():
    label = f"{row['tissue']}, {row['virus']}, {row['gene']}"
    texts.append(plt.text(-np.log10(row['p_value_general']), -np.log10(row['p_value_tissue_spec']), label, fontsize=8))

# Adjust text positions to prevent overlaps
adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

# Legend with the threshold line plus one marker per tissue, placed outside the axes
handles = [threshold_line] + [
    plt.Line2D([0], [0], marker='o', color='w', label=tissue, markerfacecolor=tissue_colors[tissue], markersize=10)
    for tissue in tissues_in_plot
]
plt.legend(handles=handles, title="Tissues", bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

plt.title('Diff. of Chronos Scores of viral pos. and neg. cell lines per tissue', fontsize=14)
plt.xlabel('-log10(p-value gene general)', fontsize=12)
plt.ylabel('-log10(p-value gene tissue spec.)', fontsize=12)
plt.savefig(f'{PLOT_DIR}chronos_diff_pos-neg-tissue.svg', format='svg', bbox_inches='tight')
plt.show()