In [47]:
import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import linkage, fcluster, dendrogram
from scipy.spatial.distance import pdist
import matplotlib.pyplot as plt
from collections import defaultdict

# Set random seed for reproducibility
np.random.seed(42)

# Generate sample data - 10000 samples with 5 features
# In a real scenario, you would use your actual data
n_samples = 1000
n_features = 2
X = np.random.randn(n_samples, n_features)

print(f"Generated {n_samples} samples with {n_features} features")

# Calculate the distance matrix
print("Computing distance matrix...")
distances = pdist(X, metric='euclidean')

# Perform hierarchical clustering using Ward's method
print("Performing hierarchical clustering...")
Z = linkage(distances, method='ward')

# Function to extract cluster information at different levels
def extract_hierarchical_structure(Z, n_samples):
    """
    Extract the hierarchical structure from the linkage matrix Z.
    
    Parameters:
    - Z: scipy.cluster.hierarchy.linkage output matrix
    - n_samples: number of original samples
    
    Returns:
    - clusters: dictionary with cluster information
    """
    n_clusters = len(Z) + 1
    clusters = {}
    
    # Initialize with original samples (leaves)
    for i in range(n_samples):
        clusters[i] = {
            'parent': None,  # Will be updated when this cluster gets merged
            'children': [],  # Original samples have no children
            'samples': [i]  # Sample index
        }
    
    # Process each merge step from the linkage matrix
    for i, row in enumerate(Z):
        cluster_id = n_samples + i  # New cluster ID
        left_child = int(row[0])
        right_child = int(row[1])
        
        # Update parent references for children
        clusters[left_child]['parent'] = cluster_id
        clusters[right_child]['parent'] = cluster_id
        
        # Create new cluster
        new_samples = clusters[left_child]['samples'] + clusters[right_child]['samples']
        clusters[cluster_id] = {
            'parent': None,  # Will be updated if this cluster gets merged
            'children': [left_child, right_child],
            'samples': new_samples
        }
    
    return clusters

# Extract hierarchical structure
print("Extracting hierarchical structure...")
hierarchical_structure = extract_hierarchical_structure(Z, n_samples)

Generated 1000 samples with 2 features
Computing distance matrix...
Performing hierarchical clustering...
Extracting hierarchical structure...


In [4]:
# next step: implement to the ppi data
# create a map dict for easier storage
# clustering based on above
# for each disease, get pos nodes and mask nodes
# calculate and selec pseodo negatives
# also the hard negative implementation

In [48]:
import numpy as np
import random
from typing import Dict, List, Set, Tuple, Optional


def hierarchical_negative_sampling(
    clusters: Dict,
    positive_genes: Set[int],
    neg_sample_ratio: float = 5.0
) -> List[int]:
    """
    Hierarchical negative sampling algorithm that allocates negative samples based on cluster structure.
    
    Args:
        clusters: Dictionary representing hierarchical clustering results
        positive_genes: Set of positive gene indices
        neg_sample_ratio: Ratio of negative to positive samples (default: 5.0)
    
    Returns:
        List of selected negative sample indices
    """
    # Find the root cluster (the one with no parent)
    root_id = None
    for cluster_id, cluster_info in clusters.items():
        if cluster_info['parent'] is None:
            root_id = cluster_id
            break
    
    if root_id is None:
        raise ValueError("No root cluster found (cluster with parent=None)")
    
    # Calculate total number of positive genes
    num_positives = len(positive_genes)
    neg_sample_size = int(num_positives * neg_sample_ratio)
    
    # Store the selected negative samples
    selected_negatives = []
    
    # Begin recursive negative sampling
    def sample_negatives_recursive(
        cluster_id: int, 
        remaining_neg_samples: int
    ) -> List[int]:
        """
        Recursively sample negative genes based on cluster structure.
        
        Args:
            cluster_id: Current cluster ID
            remaining_neg_samples: Number of negative samples to allocate
        
        Returns:
            List of selected negative samples from this branch
        """
        cluster = clusters[cluster_id]
        
        # Check if this is a leaf node (original sample) or has <= 3 positives
        cluster_samples = set(cluster['samples'])
        cluster_positives = cluster_samples.intersection(positive_genes)
        num_cluster_positives = len(cluster_positives)
        
        # If cluster has 3 or fewer positives or is a leaf, sample from this cluster
        if num_cluster_positives <= 3:
            # Get negative samples in this cluster
            cluster_negatives = list(cluster_samples - cluster_positives)
            
            # If no negatives in this cluster, return empty list
            if not cluster_negatives:
                print('no negative in cluster', cluster_id)
                return []
            
            # Randomly sample from negative samples in this cluster
            if remaining_neg_samples > 0:
                # Make sure we're not sampling more elements than are available
                samples_to_take = min(remaining_neg_samples, len(cluster_negatives))
                print(f'sample {samples_to_take} from cluster f{cluster_id}, pos/size {len(cluster_samples.intersection(positive_genes))}/{len(cluster_samples)}')
                return random.sample(cluster_negatives, samples_to_take)
            else:
                print('cluster', cluster_id, 'remaining_neg_samples == 0')
            return []
        
        # If this cluster has more than 3 positives, continue recursion
        # Calculate distribution factors for each child
        pos_pro = []
        sizes = []
        # Count positives in each child
        for child_id in cluster['children']:
            child = clusters[child_id]
            child_samples = set(child['samples'])
            child_positives = child_samples.intersection(positive_genes)
            pos_pro.append(len(child_positives))
            sizes.append(len(child_samples))
        
        # Calculate negative sample allocation for each child using the formula
        # N_i = ΣN_i × A_i / (ΣA_i)
        # A_i = S_i × (1 - P_i / (ΣP_i))
        
        # First calculate A_i for each child
        child_factors = []
        
        for i, child_id in enumerate(cluster['children']):
            child_factors.append(sizes[i] * (1 - pos_pro[i]/sum(pos_pro)))
        
        # Allocate negative samples to children
        child_neg_samples = {}
        selected_samples = []
        
        for i, child_id in enumerate(cluster['children']):
            # N_i = ΣN_i × A_i / (ΣA_i)
            child_neg_samples[child_id] = int(remaining_neg_samples * child_factors[i]/sum(child_factors))
            
            # Recursive call to sample from this child
            child_selected = sample_negatives_recursive(
                child_id, child_neg_samples[child_id]
            )
            selected_samples.extend(child_selected)
        
        return selected_samples
    
    # Begin sampling from the root
    selected_negatives = sample_negatives_recursive(root_id, neg_sample_size)
    
    return selected_negatives


# Example usage
if __name__ == "__main__":
    
    # Example positive genes (for demonstration)
    positive_genes = random.sample(range(1, 1000), 20)
    
    # Get negative samples using hierarchical sampling
    negative_samples = hierarchical_negative_sampling(hierarchical_structure, positive_genes)
    
    # print(f"Positive genes: {positive_genes}")
    print(f"Selected negative samples ({len(negative_samples)})")
    
    # Verify no overlap between positives and selected negatives
    assert len(set(negative_samples).intersection(positive_genes)) == 0, "Overlap between positives and negatives!"

sample 2 from cluster f1963, pos/size 1/41
cluster 1938 remaining_neg_samples == 0
cluster 1944 remaining_neg_samples == 0
sample 6 from cluster f1989, pos/size 3/114
sample 23 from cluster f1993, pos/size 3/192
sample 25 from cluster f1990, pos/size 3/154
sample 39 from cluster f1984, pos/size 0/60
cluster 1983 remaining_neg_samples == 0
cluster 1982 remaining_neg_samples == 0
cluster 1985 remaining_neg_samples == 0
Selected negative samples (95)


In [52]:
list(reversed(hierarchical_structure.keys()))[:10]

[1998, 1997, 1996, 1995, 1994, 1993, 1992, 1991, 1990, 1989]

In [63]:
hierarchical_structure[1990]

{'parent': 1997,
 'children': [1976, 1987],
 'samples': [95,
  257,
  863,
  650,
  688,
  816,
  146,
  119,
  991,
  606,
  707,
  426,
  634,
  350,
  112,
  23,
  441,
  762,
  345,
  809,
  33,
  225,
  522,
  930,
  211,
  567,
  823,
  878,
  155,
  984,
  439,
  528,
  357,
  477,
  199,
  451,
  790,
  874,
  519,
  559,
  45,
  871,
  314,
  409,
  87,
  625,
  715,
  481,
  72,
  288,
  408,
  368,
  742,
  29,
  997,
  517,
  977,
  410,
  412,
  685,
  320,
  718,
  948,
  813,
  570,
  607,
  333,
  666,
  482,
  800,
  894,
  590,
  599,
  70,
  828,
  725,
  456,
  35,
  175,
  36,
  754,
  820,
  88,
  245,
  437,
  383,
  770,
  803,
  285,
  885,
  465,
  554,
  483,
  925,
  15,
  384,
  689,
  377,
  578,
  161,
  609,
  104,
  807,
  89,
  745,
  541,
  511,
  847,
  726,
  423,
  645,
  56,
  616,
  720,
  774,
  671,
  392,
  710,
  851,
  105,
  444,
  1,
  180,
  619,
  212,
  480,
  271,
  877,
  869,
  940,
  339,
  32,
  301,
  302,
  193,
  864,
  280,
  7

In [None]:
def show_tree_info(i, dict):
    print(f"cluster {i}: child {dict[i]['children']}, size {len(dict[i]['samples'])}")

show_tree_info(1990,hierarchical_structure)

cluster 1990: child [1976, 1987], size 154


In [36]:
len(set(hierarchical_structure[19852]['samples']).intersection(positive_genes)), len(hierarchical_structure[19852]['samples'])

(0, 163)

In [43]:
len(set(hierarchical_structure[19945]['samples']).intersection(positive_genes)), len(hierarchical_structure[19945]['samples'])

(4, 462)

In [40]:
list(hierarchical_structure.keys())[-1]

19998

In [46]:
pos = [0,5]
size = [10,10]
neg = 25
a = []
for i, pos_i in enumerate(pos):
    a.append(size[i]*(1-pos_i/sum(pos)))
for i, a_i in enumerate(a):
    print(neg*a_i/sum(a))

25.0
0.0
