# Malaria Geographic Origin Data Exploration

This notebook performs exploratory data analysis (EDA) on the genomic sequences of malaria parasites to understand patterns related to their geographic origins before model development.

In [None]:
# Import necessary libraries
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from Bio import SeqIO
import random
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import plotly.express as px
import plotly.graph_objects as go

# Add project root to path so the `src.*` imports below resolve.
# Store it as an absolute path (stable even if the working directory changes
# later) and only append once, so re-running this cell doesn't grow sys.path.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Set plotting style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('viridis')

# Load the genomic sequence dataset


In [None]:
from src.data.genomic_sequences import GenomicSequenceDataset

# Every split uses the same windowing and caching configuration; only the
# split name differs between the three datasets.
_dataset_kwargs = dict(
    split_dir="../data/split",
    window_size=1000,
    stride=500,
    cache_size=128,
)
train_dataset = GenomicSequenceDataset(split_type="train", **_dataset_kwargs)
val_dataset = GenomicSequenceDataset(split_type="val", **_dataset_kwargs)
test_dataset = GenomicSequenceDataset(split_type="test", **_dataset_kwargs)

# Class names (geographic regions) come from the training split's label encoder
class_names = train_dataset.encoder.classes_
print(f"Number of classes: {len(class_names)}")
print(f"Class names (geographic regions): {class_names}")

# Report how many items each split holds, plus the overall total
n_train, n_val, n_test = len(train_dataset), len(val_dataset), len(test_dataset)
print(f"\nTraining set size: {n_train}")
print(f"Validation set size: {n_val}")
print(f"Test set size: {n_test}")
print(f"Total samples: {n_train + n_val + n_test}")

# Dataset Class Distribution

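Let's examine the distribution of samples across geographic regions to check for class imbalance. Note that each dataset item appears to be a sliding sequence window (`window_size=1000`, `stride=500`), so these counts reflect windows rather than independent parasite isolates.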

In [None]:
# Extract the label of every item in each split, in index order.
# NOTE(review): this goes through dataset[i], which presumably also loads each
# sequence window just to read its label; it may be slow on large splits.
train_labels, val_labels, test_labels = (
    [ds[idx]['label'] for idx in range(len(ds))]
    for ds in (train_dataset, val_dataset, test_dataset)
)

# Convert numeric labels to geographic region names
def convert_to_class_names(labels, class_names):
    """Map encoded class indices to their region names.

    Parameters
    ----------
    labels : iterable
        Encoded labels. Plain ints, NumPy integer scalars and single-element
        integer tensors are accepted; each is normalised with ``int()``.
    class_names : sequence
        Region names indexed by encoded label (e.g. ``encoder.classes_``).

    Returns
    -------
    list
        The region name for each label, in input order.
    """
    # int() guarantees indexing with a plain Python integer; indexing a NumPy
    # array with a non-scalar tensor could otherwise trigger fancy indexing and
    # return an array instead of a single name.
    return [class_names[int(label)] for label in labels]

train_regions = convert_to_class_names(train_labels, class_names)
val_regions = convert_to_class_names(val_labels, class_names)
test_regions = convert_to_class_names(test_labels, class_names)

# Count occurrences of each region
train_counts = Counter(train_regions)
val_counts = Counter(val_regions)
test_counts = Counter(test_regions)

# Combine all counts. Iterating over class_names (instead of adding the three
# Counters, which drops non-positive entries) keeps every region as a key,
# including regions with zero items.
all_counts = Counter({
    region: train_counts[region] + val_counts[region] + test_counts[region]
    for region in class_names
})

# Create DataFrames for visualization
train_df = pd.DataFrame({'Region': train_regions})
val_df = pd.DataFrame({'Region': val_regions})
test_df = pd.DataFrame({'Region': test_regions})
all_df = pd.DataFrame({
    'Region': list(all_counts.keys()),
    'Count': list(all_counts.values())
})

# Sort by count for better visualization
all_df = all_df.sort_values('Count', ascending=False)

# Plot class distribution
plt.figure(figsize=(14, 8))
ax = sns.barplot(x='Region', y='Count', data=all_df)
plt.title('Distribution of Samples Across Geographic Regions', fontsize=15)
plt.xlabel('Geographic Region', fontsize=12)
plt.ylabel('Number of Samples', fontsize=12)
plt.xticks(rotation=45, ha='right')

# Add count labels on top of bars
for p in ax.patches:
    ax.annotate(f'{p.get_height():.0f}',
                (p.get_x() + p.get_width() / 2., p.get_height()),
                ha='center', va='bottom',
                fontsize=10)

# Run the layout pass after the labels exist so it sees the final set of artists
plt.tight_layout()
plt.show()

# Check for class imbalance. most_common() sorts by count (descending), so its
# first and last entries are the most and least frequent regions.
# NOTE(review): items are dataset windows, so these are window counts per region.
ranked_regions = all_counts.most_common()
max_region, max_count = ranked_regions[0]
min_region, min_count = ranked_regions[-1]
imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')
print(f"Class imbalance ratio (max/min): {imbalance_ratio:.2f}")
print(f"Most common region: {max_region} with {max_count} samples")
print(f"Least common region: {min_region} with {min_count} samples")

# Sequence Characteristics Analysis

Let's analyze the characteristics of a random sample of sequence windows drawn from the training set, starting with their lengths.

In [None]:
# Randomly sample some sequences for analysis
def sample_sequences(dataset, n=100, seed=None):
    """Draw up to ``n`` distinct items from ``dataset`` without replacement.

    Parameters
    ----------
    dataset : indexable supporting ``len()``
        Each item is a mapping with ``'sequence'`` and ``'label'`` keys.
    n : int
        Maximum number of items to draw; capped at ``len(dataset)``.
    seed : int, optional
        If given, indices are drawn with a private ``random.Random(seed)``.
        The draw is then reproducible and the global ``random`` state is left
        untouched. If ``None`` (the default), the module-level ``random``
        generator is used, as before.

    Returns
    -------
    tuple of (list, list)
        ``(sequences, labels)`` in draw order. Each sequence is a NumPy array;
        labels are returned exactly as the dataset provides them.
    """
    rng = random if seed is None else random.Random(seed)
    indices = rng.sample(range(len(dataset)), min(n, len(dataset)))

    sequences = []
    labels = []
    for idx in indices:
        sample = dataset[idx]
        seq = sample['sequence']
        # Tensors (possibly on an accelerator) are moved to CPU before
        # conversion; other array-likes go straight through np.asarray.
        if hasattr(seq, 'cpu'):
            seq = seq.cpu().numpy()
        else:
            seq = np.asarray(seq)
        sequences.append(seq)
        labels.append(sample['label'])

    return sequences, labels

# Sample sequences from training set.
# Seed the global RNG first so the sampled subset (and every statistic derived
# from it below) is reproducible across notebook runs.
random.seed(42)
sampled_sequences, sampled_labels = sample_sequences(train_dataset, n=500)


def nucleotide_length(seq):
    """Return the number of nucleotide positions in a one-hot encoded sequence.

    A 2-D array is treated as (positions, 4), so its length is the row count.
    A 1-D array is treated as the flattened form with 4 values per position;
    a trailing partial group counts as one position.
    NOTE(review): if the dataset is channel-first (4, positions), use shape[1].
    """
    arr = np.asarray(seq)
    if arr.ndim == 1:
        return -(-arr.size // 4)  # ceil(size / 4)
    return arr.shape[0]


# Analyze sequence lengths in nucleotides. Flattening the one-hot array would
# count 4 values per position and overstate the length fourfold.
seq_lengths = [nucleotide_length(seq) for seq in sampled_sequences]
seq_lengths_df = pd.DataFrame({'Length': seq_lengths})

# Plot sequence length distribution
plt.figure(figsize=(10, 6))
sns.histplot(data=seq_lengths_df, x='Length', kde=True)
plt.title('Distribution of Sequence Lengths', fontsize=15)
plt.xlabel('Sequence Length (nucleotides)', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()

print(f"Average sequence length: {np.mean(seq_lengths):.2f} nucleotides")
print(f"Min sequence length: {min(seq_lengths)} nucleotides")
print(f"Max sequence length: {max(seq_lengths)} nucleotides")

# Nucleotide Composition Analysis

Let's look at the distribution of nucleotides (A, C, G, T) across the sampled sequences, along with their GC content.

In [None]:
# Convert one-hot encoded sequences back to nucleotide sequences
def one_hot_to_nucleotides(one_hot_seq):
    """Decode a one-hot encoded DNA sequence into a nucleotide string.

    The encoding order is assumed to be [A, C, G, T]. Two layouts are accepted:

    * 2-D, one row of 4 values per position, i.e. (positions, 4);
    * 1-D, the flattened form, read in consecutive groups of 4 values.

    A position whose values contain no 1 (e.g. an all-zero padding or unknown
    row) decodes to ``'N'`` in both layouts. Previously the 2-D path decoded
    such rows as ``'A'``.

    Parameters
    ----------
    one_hot_seq : array-like
        NumPy array (or anything ``np.asarray`` accepts) in one of the layouts above.

    Returns
    -------
    str
        One character from ``'ACGTN'`` per position.
    """
    nucleotide_map = ['A', 'C', 'G', 'T']
    arr = np.asarray(one_hot_seq)

    if arr.ndim == 1:
        # Flattened layout: consecutive groups of 4 values form one position
        # (a trailing partial group is still decoded as one position).
        rows = [arr[i:i + 4] for i in range(0, len(arr), 4)]
    else:
        # NOTE(review): assumes (positions, 4); a channel-first (4, positions)
        # array would need transposing first.
        rows = list(arr)

    nucleotides = []
    for row in rows:
        if np.any(row == 1):
            nucleotides.append(nucleotide_map[int(np.argmax(row))])
        else:
            nucleotides.append('N')  # Unknown nucleotide / no base called

    return ''.join(nucleotides)

# Convert sampled sequences to nucleotides
nucleotide_sequences = [one_hot_to_nucleotides(seq) for seq in sampled_sequences]

# Count nucleotides across all sequences
nucleotide_counts = Counter()
for seq in nucleotide_sequences:
    nucleotide_counts.update(seq)


def gc_fraction_called(seq):
    """Return the G+C fraction of ``seq`` over called bases (A, C, G, T).

    Unknown positions ('N' or any other character) are excluded from the
    denominator, so gap-rich windows are not pulled toward zero GC. Returns
    NaN when ``seq`` contains no A/C/G/T at all (empty or all-unknown). This
    keeps the result list aligned one-to-one with the sampled sequences and
    avoids a division by zero.
    """
    base_counts = Counter(seq)
    called = sum(base_counts[base] for base in 'ACGT')
    if called == 0:
        return float('nan')
    return (base_counts['G'] + base_counts['C']) / called


# Calculate GC content for each sequence (NaN where no base is called)
gc_contents = [gc_fraction_called(seq) for seq in nucleotide_sequences]
valid_gc = [gc for gc in gc_contents if not np.isnan(gc)]
mean_gc = np.mean(valid_gc) if valid_gc else float('nan')

# Prepare data for visualization. Sort the bases so the bar order is stable
# rather than depending on which base happened to be seen first.
nucleotide_order = sorted(nucleotide_counts)
nucleotide_df = pd.DataFrame({
    'Nucleotide': nucleotide_order,
    'Count': [nucleotide_counts[base] for base in nucleotide_order],
})

# Plot nucleotide distribution
plt.figure(figsize=(10, 6))
sns.barplot(x='Nucleotide', y='Count', data=nucleotide_df)
plt.title('Nucleotide Distribution Across Sampled Sequences', fontsize=15)
plt.xlabel('Nucleotide', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()

# Plot GC content distribution (undefined/NaN values are left out)
plt.figure(figsize=(10, 6))
sns.histplot(valid_gc, bins=30, kde=True)
plt.title('GC Content Distribution', fontsize=15)
plt.xlabel('GC Content', fontsize=12)
plt.ylabel('Count', fontsize=12)
if valid_gc:
    plt.axvline(x=mean_gc, color='red', linestyle='--',
                label=f'Mean GC: {mean_gc:.2f}')
    plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print(f"Average GC content: {mean_gc:.2f}")
print(f"Nucleotide distribution: {dict(nucleotide_counts)}")

# Summary Statistics by Geographic Region


In [None]:
# Attach a region name to each sampled sequence
region_names = [class_names[lbl] for lbl in sampled_labels]

# One row per sampled sequence: its region plus the per-sequence metrics
seq_characteristics = pd.DataFrame(
    {'Region': region_names, 'GC_Content': gc_contents, 'Length': seq_lengths}
)

# Per-region summary; named aggregation yields flat "<column>_<stat>" names
# directly, so no MultiIndex flattening is needed.
region_summary = (
    seq_characteristics
    .groupby('Region')
    .agg(
        GC_Content_mean=('GC_Content', 'mean'),
        GC_Content_std=('GC_Content', 'std'),
        GC_Content_min=('GC_Content', 'min'),
        GC_Content_max=('GC_Content', 'max'),
        Length_mean=('Length', 'mean'),
        Length_std=('Length', 'std'),
        Length_min=('Length', 'min'),
        Length_max=('Length', 'max'),
        Length_count=('Length', 'count'),
    )
    .reset_index()
)

# Display summary table, regions with the most sampled sequences first
region_summary.sort_values('Length_count', ascending=False)

# Conclusion and Key Findings


In [None]:
# Summarize key findings
print("## Key Findings from Exploratory Data Analysis")
print("\n### Dataset Composition")
print(f"- Total samples: {len(train_dataset) + len(val_dataset) + len(test_dataset)}")
print(f"- Number of geographic regions: {len(class_names)}")
print(f"- Class imbalance ratio: {imbalance_ratio:.2f}")

print("\n### Sequence Characteristics")
print(f"- Average sequence length: {np.mean(seq_lengths):.2f} nucleotides")
# nanmean skips undefined (NaN) GC values instead of propagating them; it is
# identical to np.mean when none are present.
print(f"- Average GC content: {np.nanmean(gc_contents):.2f}")
# most_common(1) is empty when no sequences were decoded; avoid an IndexError.
top_nucleotide = nucleotide_counts.most_common(1)
print(f"- Most common nucleotide: {top_nucleotide[0][0] if top_nucleotide else 'n/a'}")

# Save Key Findings

In [None]:
# Save key figures for later use
figures_dir = '../reports/figures'
# savefig does not create missing directories, so ensure the target exists.
os.makedirs(figures_dir, exist_ok=True)

# Save geographic distribution, with the same per-bar count labels as the
# plot displayed earlier
geo_fig = plt.figure(figsize=(14, 8))
ax = sns.barplot(x='Region', y='Count', data=all_df)
plt.title('Distribution of Samples Across Geographic Regions', fontsize=15)
plt.xlabel('Geographic Region', fontsize=12)
plt.ylabel('Number of Samples', fontsize=12)
plt.xticks(rotation=45, ha='right')
for p in ax.patches:
    ax.annotate(f'{p.get_height():.0f}',
                (p.get_x() + p.get_width() / 2., p.get_height()),
                ha='center', va='bottom',
                fontsize=10)
plt.tight_layout()
plt.savefig(os.path.join(figures_dir, 'geographic_distribution.png'), dpi=300)
# Close after saving: these are duplicates of figures already shown above,
# and open figures otherwise accumulate in memory.
plt.close(geo_fig)

# Save nucleotide distribution
nuc_fig = plt.figure(figsize=(10, 6))
sns.barplot(x='Nucleotide', y='Count', data=nucleotide_df)
plt.title('Nucleotide Distribution Across Sampled Sequences', fontsize=15)
plt.xlabel('Nucleotide', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.join(figures_dir, 'nucleotide_distribution.png'), dpi=300)
plt.close(nuc_fig)

print("Figures saved to ../reports/figures/")