In [4]:
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import KFold, LeaveOneOut,GridSearchCV
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import cross_val_score, KFold
from sklearn.metrics import mean_squared_error
from joblib import dump
import matplotlib.pyplot as plt
import seaborn as sns
import xgboost as xgb
import os
from statistics import mean, stdev
import re
from sklearn.pipeline import make_pipeline
import warnings
warnings.filterwarnings('ignore')
from sklearn.preprocessing import LabelEncoder

  from .autonotebook import tqdm as notebook_tqdm


In [None]:

def mean_df(resdat):
    dat = pd.DataFrame(columns=['Cancer' , 'actual_mean',  'actual_std',  'predicted_mean',  'predicted_std'])
    print(resdat)
    drugs = resdat['Drug'].unique()
    for drug in drugs:
        # Filter data for the current drug
        drug_data = resdat[resdat['Drug'] == drug]
        # Group by 'Cancer' and calculate mean and std
        summary = drug_data.groupby('Cancer').agg(
            actual_mean=('Actual', 'mean'),
            actual_std=('Actual', 'std'),
            predicted_mean=('Predicted', 'mean'),
            predicted_std=('Predicted', 'std')
        ).reset_index()
        dat = pd.concat([dat,summary])
    return dat
        

In [18]:
def plot_mean_std(data):
    # Ensure 'Cancer', 'Actual', 'Predicted', and 'Drug' columns are available
    required_columns = ['Cancer', 'Actual', 'Predicted', 'Drug']
    if not all(col in data.columns for col in required_columns):
        raise ValueError(f"Data must contain columns: {', '.join(required_columns)}")
    
    # Check if the 'results/plots/' directory exists, create it if it doesn't
    import os
    output_dir = 'results/plots/'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Get the unique drugs
    drugs = data['Drug'].unique()
    
    for drug in drugs:
        # Filter data for the current drug
        drug_data = data[data['Drug'] == drug]
        
        if drug_data.empty:
            print(f"No data for drug: {drug}")
            continue
        
        # Group by 'Cancer' and calculate mean and std
        summary = drug_data.groupby('Cancer').agg(
            actual_mean=('Actual', 'mean'),
            actual_std=('Actual', 'std'),
            predicted_mean=('Predicted', 'mean'),
            predicted_std=('Predicted', 'std')
        ).reset_index()
        
        summary['mean_diff'] = abs(summary['actual_mean'] - summary['predicted_mean'])
        avg_mean_diff = summary['mean_diff'].mean()
        
        # Plot
        plt.figure(figsize=(14, 8))
        
        # Plot means with error bars
        plt.errorbar(summary['Cancer'], summary['actual_mean'], yerr=summary['actual_std'], fmt='o', label='Actual Mean ± Std', color='blue', capsize=5)
        plt.errorbar(summary['Cancer'], summary['predicted_mean'], yerr=summary['predicted_std'], fmt='o', label='Predicted Mean ± Std', color='red', capsize=5)
        
        # Customize plot
        plt.title(f'Drug: {drug}; Mean difference: {avg_mean_diff:.2f}')
        plt.xlabel('Cancer Type')
        plt.ylabel('Drug Sensitivity')
        plt.xticks(rotation=45, ha='right')
        plt.legend()
        
        # Save plot
        plt_path = os.path.join(output_dir, f'{drug}_res_plot.png')
        plt.savefig(plt_path, bbox_inches='tight')
        plt.close()

In [33]:
def plot_cor(resdat):
    # Calculate the correlation between 'Actual' and 'Predicted'
    correlation = resdat[['Actual', 'Predicted']].corr().iloc[0, 1]
    data = mean_df(resdat)
    data.columns = ['Cancer', 'Actual', 'Actual_std', 'Predicted', 'Predicted_std']

    # Create the scatter plot
    plt.figure(figsize=(14, 8))
    
    # Get unique cancer types and assign colors
    cancer_types = data['Cancer'].unique()
    palette = sns.color_palette('Set1', n_colors=len(cancer_types))
    color_mapping = {cancer: palette[i] for i, cancer in enumerate(cancer_types)}
    
    # Create a scatter plot with different colors for cancer types
    scatter = sns.scatterplot(
        x='Actual', 
        y='Predicted', 
        data=data, 
        hue='Cancer',        # Color by cancer type
        palette=color_mapping, # Use custom color mapping
        s=100, 
        edgecolor='w', 
        alpha=0.7
    )

    # Add a line of equality
    max_val = max(data['Actual'].max(), data['Predicted'].max())
    min_val = min(data['Actual'].min(), data['Predicted'].min())
    plt.plot([min_val, max_val], [min_val, max_val], color='grey', linestyle='--', label='Perfect Prediction')

    # Customize plot
    plt.title(f'Drug Sensitivity by Cancer Type (Correlation: {correlation:.4f})')
    plt.xlabel('Actual Drug Sensitivity')
    plt.ylabel('Predicted Drug Sensitivity')

    # Adjust legend to show only cancer types
    plt.legend(title='Cancer Type', loc='center left', bbox_to_anchor=(1, 0.5), ncol=1, frameon=False)
    plt.show()
    # Save plot
    plt_path = 'results/plots/resdat_plot.png'
    plt.savefig(plt_path, bbox_inches='tight')
   

In [37]:
resdat = pd.read_csv("/home/surabhi/Documents/PatStrat-Personalized-Health-Technologies-Conference-2024/results/logs/all_results.csv", index_col=0)
resdat

Unnamed: 0,Actual,Predicted,Cancer,Drug
0,5.064569,4.603163,ESCA,WZ4003
1,4.809691,3.642198,STAD,WZ4003
2,4.331533,4.993625,PAAD,WZ4003
3,4.792402,4.464000,HNSC,WZ4003
4,4.891858,4.264158,LUAD,WZ4003
...,...,...,...,...
332,4.383515,4.140968,COREAD,XMD11_85h
333,4.113227,3.694937,SKCM,XMD11_85h
334,4.532307,3.681673,SKCM,XMD11_85h
335,2.525772,2.653384,ALL,XMD11_85h


In [5]:
plot_mean_std(resdat)

In [13]:
dat

Unnamed: 0,Cancer,actual_mean,actual_std,predicted_mean,predicted_std
0,ALL,2.872412,0.980282,3.035061,0.381278
1,BLCA,4.268171,1.377704,4.067990,0.226710
2,BRCA,4.795914,1.085439,4.858865,0.171860
3,COREAD,4.721900,1.219109,4.666740,0.195051
4,DLBC,3.592602,1.182986,3.629573,0.217063
...,...,...,...,...,...
13,OV,3.840038,0.828701,3.782971,0.066907
14,PAAD,4.125338,0.805505,3.865344,0.036610
15,SCLC,4.154941,1.059941,4.146363,0.112943
16,SKCM,3.774774,0.897972,3.716462,0.057214


In [36]:
plot_cor(dat)

KeyError: 'Drug'