In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
from IPython.display import display



In [2]:
def _plot_kmeans_data(all_points, centers, initial_centroids, N, K, figure_path=None):
    """Plot a histogram of the points with the true centers and initial centroids marked.

    Parameters
    ----------
    all_points : np.ndarray
        1D array of generated data points.
    centers : np.ndarray
        True cluster centers (drawn as red dashed vertical lines).
    initial_centroids : np.ndarray
        Initial centroids (drawn as green solid vertical lines).
    N, K : int
        Number of points / clusters, used only in the figure title.
    figure_path : str or None
        When given, the figure is also written to this path.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    # Labels are attached to the artists themselves so the legend entries are
    # guaranteed to match what they describe (a positional label list would be
    # assigned to the first few vertical lines instead).
    ax.hist(all_points, bins=50, alpha=0.7, label='Data Points')

    # Only the first line of each group gets a legend entry; the rest are hidden
    # from the legend with the '_nolegend_' label.
    for idx, center in enumerate(centers):
        ax.axvline(center, color='r', linestyle='--', alpha=0.5,
                   label='True Centers' if idx == 0 else '_nolegend_')
    for idx, centroid in enumerate(initial_centroids):
        ax.axvline(centroid, color='g', linestyle='-', alpha=0.8,
                   label='Initial Centroids' if idx == 0 else '_nolegend_')

    ax.set_title(f'K-means 1D Synthetic Data (N={N}, K={K})')
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')
    ax.legend()

    if figure_path is not None:
        fig.savefig(figure_path)
        print(f"Visualization saved to {figure_path}")

    # Display directly in the notebook, then release the figure's memory
    plt.show()
    plt.close(fig)


def generate_kmeans_data(N, K, output_dir=None, seed=42, visualize=True, save_files=True,
                         max_plot_points=100000):
    """Generate 1D synthetic data for K-means together with random initial centroids.

    Points are drawn from K normal distributions whose means are spaced 10 apart
    (0, 10, ..., 10*(K-1)) and whose standard deviations are drawn uniformly from
    [1.0, 1.5). The N points are split as evenly as possible across the clusters
    and then shuffled. Initial centroids are drawn uniformly from the observed
    data range and sorted in ascending order.

    Parameters
    ----------
    N : int
        Total number of data points (must be >= 1).
    K : int
        Number of clusters / centroids (must be >= 1).
    output_dir : str or None
        Directory for the output files; required when ``save_files`` is True.
        It is created if it does not exist.
    seed : int or None
        Seed passed to ``np.random.seed``. Note that this reseeds NumPy's global
        random state. ``None`` leaves the global state untouched.
    visualize : bool
        Whether to plot a histogram of the data with the centers marked.
    save_files : bool
        Whether to write ``points_{N}.csv``, ``centers_{K}.csv`` and (when a plot
        is drawn) ``data_visualization.png`` into ``output_dir``.
    max_plot_points : int
        The plot is skipped when ``N`` exceeds this value.

    Returns
    -------
    tuple of np.ndarray
        ``(all_points, initial_centroids, centers)``: the shuffled data points,
        the sorted initial centroids and the true cluster centers.

    Raises
    ------
    ValueError
        If ``N`` or ``K`` is smaller than 1, or if ``save_files`` is True but
        ``output_dir`` is None.
    """
    # Validate up front so bad arguments fail with a clear message instead of a
    # ZeroDivisionError / TypeError from deep inside the function.
    if K < 1:
        raise ValueError(f"K must be at least 1, got {K}")
    if N < 1:
        raise ValueError(f"N must be at least 1, got {N}")
    if save_files and output_dir is None:
        raise ValueError("output_dir must be provided when save_files=True")

    # Set random seed if specified
    if seed is not None:
        np.random.seed(seed)

    # Create output directory if it doesn't exist
    if save_files:
        os.makedirs(output_dir, exist_ok=True)

    # Define centers for the clusters (spaced apart for visual verification)
    centers = np.linspace(0, 10 * (K - 1), K)

    # Distribute N points over K clusters; the first `remaining` clusters get one extra
    points_per_cluster, remaining = divmod(N, K)

    points = []
    std_devs = []  # Tracked for reporting
    for i in range(K):
        # Add some randomness to standard deviation for variety
        std_dev = 1.0 + 0.5 * np.random.rand()
        std_devs.append(std_dev)

        # Generate points for this cluster
        cluster_size = points_per_cluster + (1 if i < remaining else 0)
        points.append(np.random.normal(centers[i], std_dev, cluster_size))

    # Combine all points and shuffle
    all_points = np.concatenate(points)
    np.random.shuffle(all_points)

    # Generate initial centroids randomly within the data range
    min_val, max_val = all_points.min(), all_points.max()
    initial_centroids = np.random.uniform(min_val, max_val, K)
    initial_centroids.sort()  # Sort for better interpretability

    # Save to CSV files if requested
    if save_files:
        points_path = os.path.join(output_dir, f"points_{N}.csv")
        centroids_path = os.path.join(output_dir, f"centers_{K}.csv")
        np.savetxt(points_path, all_points, fmt='%.6f')
        np.savetxt(centroids_path, initial_centroids, fmt='%.6f')
        # Report the paths that were actually written
        print(f"Files saved to {points_path} and {centroids_path}")

    print(f"Generated {N} data points and {K} initial centroids")
    print(f"Data range: [{min_val:.2f}, {max_val:.2f}]")
    print(f"True cluster centers: {centers}")
    print(f"Cluster standard deviations: {[f'{sd:.2f}' for sd in std_devs]}")

    # Create visualization if requested (limited to reasonable sizes)
    if visualize:
        if N <= max_plot_points:
            figure_path = (
                os.path.join(output_dir, "data_visualization.png") if save_files else None
            )
            _plot_kmeans_data(all_points, centers, initial_centroids, N, K, figure_path)
        else:
            print(f"Skipping visualization: N={N} exceeds max_plot_points={max_plot_points}")

    return all_points, initial_centroids, centers

In [5]:
# Build the 1,000,000-point / 16-cluster dataset and write it under ./data
generate_kmeans_data(N=1_000_000, K=16, output_dir="./data")

Files saved to ./data/dados.csv and ./data/centroides_iniciais.csv
Generated 1000000 data points and 16 initial centroids
Data range: [-4.97, 155.09]
True cluster centers: [  0.  10.  20.  30.  40.  50.  60.  70.  80.  90. 100. 110. 120. 130.
 140. 150.]
Cluster standard deviations: ['1.19', '1.49', '1.17', '1.28', '1.35', '1.08', '1.49', '1.08', '1.29', '1.06', '1.06', '1.31', '1.43', '1.39', '1.16', '1.23']


(array([ 58.21610916,  79.91939636, 118.92240982, ..., 101.39083308,
         99.85401875,   1.41031896]),
 array([ 22.27423596,  27.74828219,  29.74914949,  34.10879683,
         43.4789237 ,  49.20444817,  81.16410664,  86.3997171 ,
        109.38511023, 127.69826044, 134.19347059, 146.44657725,
        147.85982176, 148.81319643, 148.89002678, 150.57429602]),
 array([  0.,  10.,  20.,  30.,  40.,  50.,  60.,  70.,  80.,  90., 100.,
        110., 120., 130., 140., 150.]))