In [2]:
# Modified imports with different arrangement
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mysql.connector
import pymysql
import seaborn as sns
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.feature_extraction.text import TfidfVectorizer
import scipy.stats as stats
import warnings
import os
import logging
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple, Optional, Any
from dotenv import load_dotenv

# Silence every warning category for the remainder of the session.
warnings.filterwarnings(action="ignore")

# Emit INFO-and-above records as "timestamp - logger name - level - message".
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("Data Pipelining, Analysis and Visualization")

In [3]:
import pymysql
import pandas as pd
from typing import List, Dict, Any, Optional, Union, Tuple
import logging

# Configure root logging (INFO and above) and create the logger used by
# DataPipelineManager.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("DataPipelineManager")

class DataPipelineManager:
    """
    Manage a pymysql connection to a MySQL database and provide helper
    methods for storing, querying and transforming data.

    Table, column and index names passed to the methods are wrapped in
    backticks (embedded backticks are doubled) before being placed into SQL
    text. Row values are always sent as query parameters.

    The instance can be used as a context manager; the connection is closed
    when the ``with`` block exits.
    """

    def __init__(self, host: str = "127.0.0.1", user: str = "root",
                 password: str = "", database: str = "data_analytics",
                 port: int = 3306):
        """
        Initialize the DataPipelineManager with a MySQL database connection.

        Args:
            host: MySQL server hostname
            user: MySQL username
            password: MySQL password
            database: MySQL database name
            port: MySQL server port

        Raises:
            ConnectionError: If the connection cannot be established
        """
        self.connection_params = {
            'host': host,
            'user': user,
            'password': password,
            'database': database,
            'port': port
        }
        self.logger = logger
        self.conn = self._create_connection()
        self.logger.info(f"Initialized DataPipelineManager with MySQL database at {host}:{port}/{database}")

    def __enter__(self) -> "DataPipelineManager":
        """Return the manager itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the connection on leaving a ``with`` block; exceptions propagate."""
        self.close()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Wrap an identifier in backticks, doubling any backtick it contains."""
        return "`" + str(name).replace("`", "``") + "`"

    def _create_connection(self) -> "pymysql.connections.Connection":
        """
        Create and return a MySQL database connection.

        If the server reports error 1049 (unknown database), the database is
        created and the connection is attempted once more.

        Raises:
            ConnectionError: If connecting or creating the database fails
        """
        try:
            return pymysql.connect(**self.connection_params)
        except pymysql.err.OperationalError as e:
            error_code = e.args[0] if e.args else None
            # Only error 1049 (unknown database) is recoverable here
            if error_code != 1049:
                raise ConnectionError(f"Database connection error: {e}") from e
            try:
                # Connect without specifying a database so it can be created
                temp_params = dict(self.connection_params)
                database_name = temp_params.pop('database')
                temp_conn = pymysql.connect(**temp_params)
                try:
                    with temp_conn.cursor() as cursor:
                        cursor.execute(
                            f"CREATE DATABASE IF NOT EXISTS {self._quote_identifier(database_name)}"
                        )
                finally:
                    # Release the temporary connection even if creation failed
                    temp_conn.close()

                # Try connecting again now that the database exists
                return pymysql.connect(**self.connection_params)
            except pymysql.Error as create_error:
                raise ConnectionError(f"Failed to create database: {create_error}") from create_error

    def _ensure_connection(self) -> None:
        """Re-create the connection when the current one does not respond."""
        if not self.is_connected():
            self.conn = self._create_connection()

    def _safe_rollback(self) -> None:
        """
        Roll back the current transaction without raising.

        This is called from error handlers, where the connection itself may be
        the thing that failed; a failing rollback is logged, not propagated,
        so the calling method can still report failure through its return value.
        """
        conn = getattr(self, 'conn', None)
        if conn is None:
            return
        try:
            conn.rollback()
        except Exception as rollback_error:
            self.logger.error(f"Rollback failed: {rollback_error}")

    def _clean_data(self, data: List[Dict[str, Any]], columns: List[str]) -> List[Tuple]:
        """
        Clean and prepare data for database insertion.

        Args:
            data: List of dictionaries containing data
            columns: List of column names to extract from data

        Returns:
            List of tuples ready for database insertion; a key missing from a
            row yields None in that position
        """
        cleaned_data = [tuple(row.get(col) for col in columns) for row in data]
        self.logger.info(f"Cleaned {len(cleaned_data)} rows of data")
        return cleaned_data

    def is_connected(self) -> bool:
        """Check if the connection is active (ping without auto-reconnect)."""
        conn = getattr(self, 'conn', None)
        if conn is None:
            return False
        try:
            conn.ping(reconnect=False)
            return True
        except Exception:
            # Any failure to ping is treated as "not connected"
            return False

    def save_data(self, table_name: str, columns: List[str], data: List[Dict[str, Any]]) -> bool:
        """
        Save data to a MySQL database table. Create the table if it doesn't exist.

        Every listed column is created as VARCHAR(255); the table also gets an
        auto-increment `id` and a `created_at` timestamp.

        Args:
            table_name: Name of the table to save data to
            columns: List of column names (must not be empty)
            data: List of dictionaries containing data

        Returns:
            Boolean indicating success
        """
        try:
            if not columns:
                raise ValueError("At least one column name is required")

            self._ensure_connection()

            quoted_table = self._quote_identifier(table_name)
            quoted_columns = [self._quote_identifier(col) for col in columns]
            column_definitions = ', '.join(f"{col} VARCHAR(255)" for col in quoted_columns)

            create_table_query = f"""
            CREATE TABLE IF NOT EXISTS {quoted_table} (
                `id` INT AUTO_INCREMENT PRIMARY KEY,
                {column_definitions},
                `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            """

            placeholders = ', '.join(['%s'] * len(columns))
            insert_query = f"""
            INSERT INTO {quoted_table} ({', '.join(quoted_columns)})
            VALUES ({placeholders})
            """

            # The context manager closes the cursor even when a statement fails
            with self.conn.cursor() as cursor:
                cursor.execute(create_table_query)
                cleaned_data = self._clean_data(data, columns)
                cursor.executemany(insert_query, cleaned_data)
            self.conn.commit()

            self.logger.info(f"Successfully saved {len(cleaned_data)} rows to table '{table_name}'")
            return True

        except Exception as e:
            self._safe_rollback()
            self.logger.error(f"Error saving data to table '{table_name}': {e}")
            return False

    def get_data(self,
                 table_name: str,
                 columns: Optional[List[str]] = None,
                 limit: Optional[int] = None,
                 where_clause: Optional[str] = None,
                 order_by: Optional[str] = None) -> pd.DataFrame:
        """
        Query data from a specified table.

        Args:
            table_name: Name of the table to query
            columns: List of columns to retrieve (None for all columns)
            limit: Maximum number of rows to return
            where_clause: SQL WHERE condition. Inserted into the query text
                verbatim, so it must never be built from untrusted input.
            order_by: SQL ORDER BY clause. Inserted verbatim, same caveat.

        Returns:
            DataFrame containing the query results, or an empty DataFrame if
            the query fails (the error is logged)
        """
        try:
            self._ensure_connection()

            # Build the SELECT part of the query
            if columns is None:
                cols_str = "*"
            else:
                cols_str = ', '.join(self._quote_identifier(col) for col in columns)
            query = f"SELECT {cols_str} FROM {self._quote_identifier(table_name)}"

            if where_clause:
                query += f" WHERE {where_clause}"

            if order_by:
                query += f" ORDER BY {order_by}"

            if limit is not None:
                # int() rejects anything that is not a plain integer
                query += f" LIMIT {int(limit)}"

            df = pd.read_sql(query, self.conn)
            self.logger.info(f"Retrieved {len(df)} rows from table '{table_name}'")
            return df

        except Exception as e:
            self.logger.error(f"Error retrieving data from table '{table_name}': {e}")
            return pd.DataFrame()

    def execute_query(self, query: str, params: tuple = ()) -> pd.DataFrame:
        """
        Execute a custom SQL query.

        Args:
            query: SQL query to execute
            params: Parameters for the query

        Returns:
            DataFrame containing the query results, or an empty DataFrame if
            the query fails (the error is logged)
        """
        try:
            self._ensure_connection()
            df = pd.read_sql(query, self.conn, params=params)
            self.logger.info(f"Custom query executed successfully, returned {len(df)} rows")
            return df
        except Exception as e:
            self.logger.error(f"Error executing custom query: {e}")
            return pd.DataFrame()

    def get_table_schema(self, table_name: str) -> pd.DataFrame:
        """
        Get the schema information for a table.

        Args:
            table_name: Name of the table

        Returns:
            DataFrame containing the output of DESCRIBE for the table, or an
            empty DataFrame if the query fails (the error is logged)
        """
        try:
            self._ensure_connection()
            query = f"DESCRIBE {self._quote_identifier(table_name)}"
            return pd.read_sql(query, self.conn)
        except Exception as e:
            self.logger.error(f"Error retrieving schema for table '{table_name}': {e}")
            return pd.DataFrame()

    def list_tables(self) -> List[str]:
        """
        Get a list of all tables in the database.

        Returns:
            List of table names, or an empty list if the query fails
        """
        try:
            self._ensure_connection()
            with self.conn.cursor() as cursor:
                cursor.execute("SHOW TABLES")
                return [row[0] for row in cursor.fetchall()]
        except Exception as e:
            self.logger.error(f"Error listing tables: {e}")
            return []

    def create_index(self, table_name: str, columns: List[str], index_name: Optional[str] = None) -> bool:
        """
        Create an index on specified columns.

        Args:
            table_name: Name of the table
            columns: List of columns to include in the index
            index_name: Name of the index (optional); defaults to
                ``idx_<table>_<col1>_<col2>...``

        Returns:
            Boolean indicating success
        """
        if not index_name:
            index_name = f"idx_{table_name}_{'_'.join(columns)}"

        try:
            self._ensure_connection()
            column_list = ', '.join(self._quote_identifier(col) for col in columns)
            query = (
                f"CREATE INDEX {self._quote_identifier(index_name)} "
                f"ON {self._quote_identifier(table_name)} ({column_list})"
            )
            with self.conn.cursor() as cursor:
                cursor.execute(query)
            self.conn.commit()
            self.logger.info(f"Created index '{index_name}' on table '{table_name}'")
            return True
        except Exception as e:
            self._safe_rollback()
            self.logger.error(f"Error creating index: {e}")
            return False

    def perform_data_transformation(self, df: pd.DataFrame, transformations: Dict[str, callable]) -> pd.DataFrame:
        """
        Apply a series of transformations to a DataFrame.

        The input frame is not modified; columns named in `transformations`
        that are absent from the frame are skipped.

        Args:
            df: Input DataFrame
            transformations: Dictionary mapping column names to transformation
                functions, each applied element-wise to its column

        Returns:
            Transformed copy of the DataFrame
        """
        transformed_df = df.copy()

        for column, transform_func in transformations.items():
            if column in transformed_df.columns:
                transformed_df[column] = transformed_df[column].apply(transform_func)

        return transformed_df

    def execute_transaction(self, queries: List[Tuple[str, tuple]]) -> bool:
        """
        Execute multiple queries as a single transaction.

        All queries are committed together; if any of them raises, a rollback
        is attempted and False is returned.

        Args:
            queries: List of (query, params) tuples

        Returns:
            Boolean indicating success
        """
        try:
            self._ensure_connection()
            with self.conn.cursor() as cursor:
                for query, params in queries:
                    cursor.execute(query, params)
            self.conn.commit()
            self.logger.info(f"Transaction with {len(queries)} queries executed successfully")
            return True
        except Exception as e:
            self._safe_rollback()
            self.logger.error(f"Transaction failed: {e}")
            return False

    def backup_table(self, table_name: str, backup_suffix: str = "_backup") -> bool:
        """
        Create a backup of a table.

        Any existing table with the backup name is dropped first, then the
        backup is re-created with the source table's structure and rows.

        Args:
            table_name: Name of the table to backup
            backup_suffix: Suffix to append to the original table name

        Returns:
            Boolean indicating success
        """
        backup_table = f"{table_name}{backup_suffix}"
        try:
            self._ensure_connection()
            source = self._quote_identifier(table_name)
            target = self._quote_identifier(backup_table)

            with self.conn.cursor() as cursor:
                # Drop the backup table if it exists
                cursor.execute(f"DROP TABLE IF EXISTS {target}")

                # Create the backup table and copy every row into it
                cursor.execute(f"CREATE TABLE {target} LIKE {source}")
                cursor.execute(f"INSERT INTO {target} SELECT * FROM {source}")
            self.conn.commit()
            self.logger.info(f"Created backup of table '{table_name}' as '{backup_table}'")
            return True
        except Exception as e:
            self._safe_rollback()
            self.logger.error(f"Error backing up table: {e}")
            return False

    def close(self) -> None:
        """Close the database connection; calling it again is harmless."""
        conn = getattr(self, 'conn', None)
        if not conn:
            return
        try:
            conn.close()
        except pymysql.Error as close_error:
            # pymysql raises when the connection has already been closed
            self.logger.warning(f"Database connection could not be closed: {close_error}")
            return
        self.logger.info("Database connection closed")

In [8]:
# Load database credentials from the .env file.
# override=True lets values from .env replace variables that already exist in
# the process environment. Without it, a pre-existing variable with the same
# name (USER is normally set to the login name on Unix systems) would be used
# instead of the .env entry.
load_dotenv(override=True)

HOST = os.getenv("HOST")
USER = os.getenv("USER")
PASSWORD = os.getenv("PASSWORD")
DATABASE = os.getenv("DATABASE")
PORT = os.getenv("PORT")

# int(None) raises TypeError, so fall back to the MySQL default port (also the
# DataPipelineManager default) when PORT is not configured.
db = DataPipelineManager(
    host=HOST,
    user=USER,
    password=PASSWORD,
    database=DATABASE,
    port=int(PORT) if PORT else 3306,
)

2025-04-25 18:41:27,571 - DataPipelineManager - INFO - Initialized DataPipelineManager with MySQL database at 127.0.0.1:3306/data_analytics


In [5]:
def retrieve_csv_data() -> dict:
    """
    Locate and load CSV files from the 'datasets' directory under the current
    working directory.

    Files are loaded in the order returned by os.listdir, which is arbitrary,
    so callers should look datasets up by name rather than by position.

    Returns: Dictionary with dataset names (file name without the '.csv'
        extension) as keys and DataFrames as values
    Raises: FileNotFoundError if the 'datasets' directory does not exist
    """
    dataset_collection = {}

    # The datasets live next to the notebook, in <cwd>/datasets
    data_folder = os.path.join(os.getcwd(), 'datasets')
    print(data_folder)

    # Find all CSV files
    csv_file_names = [name for name in os.listdir(data_folder) if name.endswith('.csv')]

    # Process each dataset
    for position, csv_name in enumerate(csv_file_names, start=1):
        # Read from the same folder that was listed above
        file_data = pd.read_csv(os.path.join(data_folder, csv_name))
        # splitext removes only the final extension, so dotted names stay intact
        file_name = os.path.splitext(csv_name)[0]
        print(f"{position}: {file_name}")
        dataset_collection[file_name] = file_data

    return dataset_collection

# Load datasets into dictionary
data_dict = retrieve_csv_data()

# Assign individual datasets to variables.
# NOTE(review): the assignment is positional, so it depends on the order in
# which the CSV files were loaded - confirm that the printed order matches the
# variable names below, or switch to lookups by dataset name.
dataset_names = list(data_dict)
if len(dataset_names) < 4:
    # Fail with a message that says what was actually found
    raise IndexError(
        f"Expected at least 4 CSV datasets but found {len(dataset_names)}: {dataset_names}"
    )

disease_dataset, aids_trials, symptom_data, disease_info = (
    data_dict[name] for name in dataset_names[:4]
)

/home/onfon/Desktop/isaac/freelance/creekpen/datasets


IndexError: list index out of range

In [None]:
def prepare_dataset(data: pd.DataFrame, target_col: str, excluded_patterns: Optional[list] = None) -> pd.DataFrame:
    """
    Prepare dataset for analysis by removing unwanted columns and duplicate
    rows, and by reporting missing values.

    Missing values are only printed, never imputed or dropped. The input
    frame is left untouched and the original row index is kept.

    Args:
        data: Input DataFrame
        target_col: Target column name; must still be present after the
            exclusion patterns have been applied
        excluded_patterns: List of substrings; every column whose name
            contains one of them is removed

    Returns: Cleaned DataFrame

    Raises: KeyError if target_col is not a column of the cleaned frame
    """
    if excluded_patterns is None:
        excluded_patterns = []

    # Remove columns matching exclusion patterns
    unwanted_columns = [
        col for col in data.columns
        if any(pattern in str(col) for pattern in excluded_patterns)
    ]
    filtered_df = data.drop(columns=unwanted_columns)

    # Remove duplicate rows (first occurrence is kept); returns a new frame,
    # so callers can add columns without touching the input
    filtered_df = filtered_df.drop_duplicates()

    if target_col not in filtered_df.columns:
        raise KeyError(f"Target column '{target_col}' not found in dataset")

    # Check for missing values
    missing_count = filtered_df.isnull().sum()
    if missing_count.sum() > 0:
        print(missing_count)
    else:
        print("No missing values detected")

    # The train/test split and feature scaling that used to follow were never
    # returned or stored, and could fail on non-numeric features, so they were
    # removed.
    return filtered_df

In [None]:
# Clean and prepare first dataset
# (prepare_dataset names its first parameter `data`)
clean_disease_data = prepare_dataset(data=disease_dataset, target_col='disease', excluded_patterns=['Unnamed'])

# Generate disease distribution visualization
plt.figure(figsize=(10, 6))
disease_distribution = clean_disease_data['disease'].value_counts()
sns.barplot(x=disease_distribution.values, y=disease_distribution.index)
plt.title('Distribution of Disease Cases')
plt.xlabel('Number of Cases')
plt.ylabel('Disease Type')
plt.tight_layout()
plt.show()

# Analyze symptom correlations
# Every column except the target is a symptom feature, wherever the target sits
symptom_features = clean_disease_data.columns.drop('disease')
plt.figure(figsize=(15, 12))
sns.heatmap(clean_disease_data[symptom_features].corr(), cmap='viridis', center=0, annot=False)
plt.title('Symptom Correlation Analysis')
plt.tight_layout()
plt.show()

# Create symptom distribution boxplot by disease
plt.figure(figsize=(15, 6))
reshaped_data = clean_disease_data.melt(id_vars=['disease'], value_vars=symptom_features)
sns.boxplot(x='disease', y='value', data=reshaped_data)
plt.xticks(rotation=45)
plt.title('Symptom Intensity Across Disease Categories')
plt.tight_layout()
plt.show()

# Analyze top symptoms per disease
leading_symptoms = clean_disease_data[symptom_features].mean().sort_values(ascending=False).head(10).index
# One column per disease holding the mean of each leading symptom
disease_symptom_profile = pd.DataFrame({
    disease_type: clean_disease_data.loc[clean_disease_data['disease'] == disease_type, leading_symptoms].mean()
    for disease_type in disease_distribution.index
})

# DataFrame.plot opens its own figure, so the size is passed here rather than
# through a separate plt.figure call (which would leave an empty figure behind)
disease_symptom_profile.T.plot(kind='bar', stacked=True, figsize=(12, 6))
plt.title('Main Symptom Distribution by Disease')
plt.xlabel('Disease Category')
plt.ylabel('Symptom Frequency')
plt.legend(title='Symptoms', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

In [None]:
from sklearn.decomposition import PCA

# Clean and prepare symptom dataset
# (prepare_dataset names its first parameter `data`)
clean_symptom_data = prepare_dataset(data=symptom_data, target_col='prognosis')

# Visualize prognosis distribution
plt.figure(figsize=(12, 6))
prognosis_distribution = clean_symptom_data['prognosis'].value_counts()
sns.barplot(x=prognosis_distribution.index, y=prognosis_distribution.values)
plt.title('Prognosis Case Distribution')
plt.xlabel('Prognosis Category')
plt.ylabel('Case Frequency')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

# Symptom importance analysis using chi-square
# Every column except the target is a symptom variable
symptom_variables = clean_symptom_data.columns.drop('prognosis')
importance_scores = {}

for symptom in symptom_variables:
    contingency = pd.crosstab(clean_symptom_data[symptom], clean_symptom_data['prognosis'])
    chi2_val, p_value, _, _ = stats.chi2_contingency(contingency)
    # The chi-square statistic itself is used as the importance score
    importance_scores[symptom] = chi2_val

# Prepare importance dataframe
symptom_importance = pd.DataFrame(list(importance_scores.items()),
                                  columns=['Symptom_Name', 'Statistical_Significance'])
symptom_importance = symptom_importance.sort_values('Statistical_Significance', ascending=False)

# Plot top significant symptoms
plt.figure(figsize=(12, 8))
sns.barplot(data=symptom_importance.head(20), x='Statistical_Significance', y='Symptom_Name')
plt.title('Most Significant Symptoms for Prognosis Assessment')
plt.tight_layout()
plt.show()

# Create symptom co-occurrence visualization
plt.figure(figsize=(15, 15))
correlation_data = clean_symptom_data[symptom_variables].corr()
# Hide the upper triangle; it mirrors the lower one
upper_mask = np.triu(np.ones_like(correlation_data, dtype=bool))
sns.heatmap(correlation_data, mask=upper_mask, center=0, cmap='RdBu_r',
            square=True, linewidths=.5, cbar_kws={"shrink": .5})
plt.title('Symptom Co-occurrence Pattern Analysis')
plt.tight_layout()
plt.show()

# Prognosis-specific symptom analysis
top_prognoses = prognosis_distribution.head().index
fig, axes = plt.subplots(2, 3, figsize=(20, 12))
flattened_axes = axes.ravel()

for i, prognosis_type in enumerate(top_prognoses):
    prognosis_subset = clean_symptom_data[clean_symptom_data['prognosis'] == prognosis_type][symptom_variables].mean()
    primary_symptoms = prognosis_subset.sort_values(ascending=False)[:10]
    sns.barplot(x=primary_symptoms.values, y=primary_symptoms.index, ax=flattened_axes[i])
    flattened_axes[i].set_title(f'Primary Symptoms for {prognosis_type}')
    flattened_axes[i].set_xlabel('Prevalence')

# The grid has six slots but at most five prognoses are plotted
for unused_axis in flattened_axes[len(top_prognoses):]:
    unused_axis.set_visible(False)

plt.tight_layout()
plt.show()

# PCA analysis for symptom clustering
feature_normalizer = StandardScaler()
# Keep the cleaned frame's index: de-duplication leaves gaps in it, and a
# fresh 0..n-1 index would misalign the prognosis column assigned below
normalized_symptoms = pd.DataFrame(
    feature_normalizer.fit_transform(clean_symptom_data[symptom_variables]),
    columns=symptom_variables,
    index=clean_symptom_data.index,
)
normalized_symptoms['prognosis'] = clean_symptom_data['prognosis']

# Fit PCA on the standardized features computed above
reducer = PCA(n_components=2)
reduced_data = reducer.fit_transform(normalized_symptoms[symptom_variables])

# factorize returns integer codes plus the labels in matching code order
prognosis_codes, prognosis_labels = pd.factorize(clean_symptom_data['prognosis'])

plt.figure(figsize=(12, 8))
scatter_plot = plt.scatter(reduced_data[:, 0], reduced_data[:, 1],
                           c=prognosis_codes,
                           alpha=0.6)
plt.title('Symptom Pattern Clusters by Prognosis (Dimensionality Reduction)')
plt.xlabel('First Principal Component')
plt.ylabel('Second Principal Component')
# num=None requests one legend handle per distinct code, so handles and labels
# always have the same length
plt.legend(handles=scatter_plot.legend_elements(num=None)[0],
           labels=list(prognosis_labels),
           title='Prognosis',
           bbox_to_anchor=(1.05, 1),
           loc='upper left')
plt.tight_layout()
plt.show()

# Summary statistics
analysis_summary = {
    'total_prognosis_categories': len(prognosis_distribution),
    'predominant_prognosis': prognosis_distribution.index[0],
    'key_indicator_symptom': symptom_importance.iloc[0]['Symptom_Name'],
    'total_tracked_symptoms': len(symptom_variables)
}

print("\nKey Analysis Findings:")
for metric, value in analysis_summary.items():
    print(f"{metric}: {value}")

In [None]:
# Clean and analyze AIDS clinical trial data
# (prepare_dataset names its first parameter `data`)
processed_aids_data = prepare_dataset(data=aids_trials, target_col='label', excluded_patterns=['Unnamed'])

# AIDS case distribution visualization
plt.figure(figsize=(10, 6))
sns.countplot(data=processed_aids_data, x='label')
plt.title('AIDS Diagnosis Distribution')
plt.xlabel('AIDS Status (0: Negative, 1: Positive)')
plt.ylabel('Patient Count')
plt.tight_layout()
plt.show()

# Demographic analysis by AIDS status
fig, (axis1, axis2) = plt.subplots(1, 2, figsize=(15, 6))

# Age distribution
sns.boxplot(data=processed_aids_data, x='label', y='age', ax=axis1)
axis1.set_title('Patient Age Distribution by AIDS Status')
axis1.set_xlabel('AIDS Status')
axis1.set_ylabel('Age (years)')

# Weight distribution
sns.boxplot(data=processed_aids_data, x='label', y='wtkg', ax=axis2)
axis2.set_title('Patient Weight Distribution by AIDS Status')
axis2.set_xlabel('AIDS Status')
axis2.set_ylabel('Weight (kg)')

plt.tight_layout()
plt.show()

# CD4 count analysis
plt.figure(figsize=(12, 6))
sns.scatterplot(data=processed_aids_data, x='cd40', y='cd420', hue='label', alpha=0.6)
plt.title('CD4 Count Progression: Initial vs Week 20')
plt.xlabel('Baseline CD4 Count')
plt.ylabel('Week 20 CD4 Count')
plt.tight_layout()
plt.show()

# Risk factor analysis
risk_categories = ['homo', 'drugs']
fig, plot_axes = plt.subplots(1, len(risk_categories), figsize=(15, 6))

for idx, factor in enumerate(risk_categories):
    # normalize='index' turns each row into proportions across AIDS status
    sns.heatmap(pd.crosstab(processed_aids_data[factor], processed_aids_data['label'], normalize='index'),
                annot=True, fmt='.2%', cmap='Oranges', ax=plot_axes[idx])
    plot_axes[idx].set_title(f'{factor.capitalize()} Status vs AIDS Diagnosis')

plt.tight_layout()
plt.show()

# Treatment effect analysis
treatment_outcomes = processed_aids_data.groupby(['treat', 'label']).size().unstack()
# DataFrame.plot opens its own figure, so the size is passed here rather than
# through a separate plt.figure call (which would leave an empty figure behind)
treatment_outcomes.plot(kind='bar', stacked=True, figsize=(12, 6))
plt.title('AIDS Status Distribution Across Treatment Groups')
plt.xlabel('Treatment Group')
plt.ylabel('Patient Count')
plt.legend(title='AIDS Status')
plt.tight_layout()
plt.show()

# Correlation analysis of clinical parameters
numeric_variables = ['age', 'wtkg', 'cd40', 'cd420', 'cd80', 'cd820', 'karnof']
plt.figure(figsize=(12, 10))
parameter_correlations = processed_aids_data[numeric_variables + ['label']].corr()
sns.heatmap(parameter_correlations, annot=True, cmap='RdBu', center=0)
plt.title('Clinical Parameter Correlation Matrix')
plt.tight_layout()
plt.show()

# Symptom severity analysis
plt.figure(figsize=(10, 6))
sns.countplot(data=processed_aids_data, x='symptom', hue='label')
plt.title('Symptom Severity by AIDS Status')
plt.xlabel('Symptom Level')
plt.ylabel('Patient Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Immune system ratio analysis
# A zero CD8 count would make the ratio infinite; treat it as missing so the
# affected rows are left out of the box plots
processed_aids_data['baseline_cd4_cd8_ratio'] = (
    processed_aids_data['cd40'] / processed_aids_data['cd80'].replace(0, np.nan)
)
processed_aids_data['followup_cd4_cd8_ratio'] = (
    processed_aids_data['cd420'] / processed_aids_data['cd820'].replace(0, np.nan)
)

fig, (axis1, axis2) = plt.subplots(1, 2, figsize=(15, 6))
sns.boxplot(data=processed_aids_data, x='label', y='baseline_cd4_cd8_ratio', ax=axis1)
axis1.set_title('Baseline CD4/CD8 Ratio by AIDS Status')
axis1.set_xlabel('AIDS Status')
axis1.set_ylabel('CD4/CD8 Ratio')

sns.boxplot(data=processed_aids_data, x='label', y='followup_cd4_cd8_ratio', ax=axis2)
axis2.set_title('Week 20 CD4/CD8 Ratio by AIDS Status')
axis2.set_xlabel('AIDS Status')
axis2.set_ylabel('CD4/CD8 Ratio')

plt.tight_layout()
plt.show()

# Statistical significance testing
print("\nKey Statistical Results:")
print("-" * 50)

# Row masks for the two outcome groups, reused by both t-tests
positive_cases = processed_aids_data['label'] == 1
negative_cases = processed_aids_data['label'] == 0

# Age difference test
age_tstat, age_pval = stats.ttest_ind(processed_aids_data.loc[positive_cases, 'age'],
                                      processed_aids_data.loc[negative_cases, 'age'])
print(f"Age difference significance (t-test p-value): {age_pval:.4f}")

# Treatment effect test (first two results: chi-square statistic and p-value)
treat_chi2, treat_pval = stats.chi2_contingency(pd.crosstab(processed_aids_data['treat'],
                                                            processed_aids_data['label']))[0:2]
print(f"Treatment effect significance (chi-square p-value): {treat_pval:.4f}")

# CD4 count change
processed_aids_data['cd4_delta'] = processed_aids_data['cd420'] - processed_aids_data['cd40']
cd4_tstat, cd4_pval = stats.ttest_ind(processed_aids_data.loc[positive_cases, 'cd4_delta'],
                                      processed_aids_data.loc[negative_cases, 'cd4_delta'])
print(f"CD4 count change significance (t-test p-value): {cd4_pval:.4f}")