In [1]:
import numpy as np
import pandas as pd


def load_data(file_path='../data/FIREbox_z=0.txt'):
    """
    Load the FIREbox z=0 catalogue from a text file and process it.

    Parameters:
    -----------
    file_path : str
        Path to the whitespace-delimited FIREbox_z=0.txt file. Text following
        a '#' character is treated as a comment and ignored.

    Returns:
    --------
    pd.DataFrame
        Processed dataframe (an independent copy that keeps the row index of
        the file) with:
        - Only rows where hostHaloID = -1
        - Xc, Yc, Zc columns renamed to pos_x, pos_y, pos_z
        - VXc, VYc, VZc columns renamed to vel_x, vel_y, vel_z
        - Position values divided by 10^2 (per the original author's note:
          to convert from dividing by H0/100 to H0)

    Raises:
    -------
    KeyError
        If the file has no 'hostHaloID' column, or no position columns
        (Xc, Yc, Zc) to rename and rescale.
    """
    # Raw string: '\s' inside a normal string literal is an invalid escape
    # sequence and makes Python emit a warning when the cell is compiled.
    df = pd.read_csv(file_path, sep=r'\s+', comment='#')

    column_map = {
        'Xc': 'pos_x',
        'Yc': 'pos_y',
        'Zc': 'pos_z',
        'VXc': 'vel_x',
        'VYc': 'vel_y',
        'VZc': 'vel_z',
    }
    position_cols = ['pos_x', 'pos_y', 'pos_z']

    # Keep only rows with hostHaloID == -1; copy so the later column
    # assignments do not write into a view of the raw frame.
    df_filtered = df[df['hostHaloID'] == -1].copy()
    df_filtered = df_filtered.rename(columns=column_map)

    # Divide position values by 10^2 (velocities are left untouched).
    for col in position_cols:
        df_filtered[col] = df_filtered[col] / 100

    return df_filtered

  df = pd.read_csv(file_path, sep='\s+', comment='#')


In [2]:
# Load the processed catalogue from the default path and print a summary of
# its columns (dtypes and non-null counts).
df = load_data()
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 160502 entries, 0 to 199491
Data columns (total 13 columns):
 #   Column              Non-Null Count   Dtype  
---  ------              --------------   -----  
 0   haloID              160502 non-null  int64  
 1   hostHaloID          160502 non-null  int64  
 2   lg_Mhalo            160502 non-null  float64
 3   Rhalo               160502 non-null  float64
 4   pos_x               160502 non-null  float64
 5   pos_y               160502 non-null  float64
 6   pos_z               160502 non-null  float64
 7   vel_x               160502 non-null  float64
 8   vel_y               160502 non-null  float64
 9   vel_z               160502 non-null  float64
 10  lg_Mstar_<Rhalo     160502 non-null  float64
 11  lg_Mstar_<0.1Rhalo  160502 non-null  float64
 12  lg_Mstar_<10kpc     160502 non-null  float64
dtypes: float64(11), int64(2)
memory usage: 17.1 MB


In [3]:
# summary statistics
# NOTE: the lg_Mstar_* columns contain -inf values (see their min/mean in the
# output), so their std and quantiles come out empty; this is presumably also
# what triggers the runtime warnings printed with this cell.
df.describe()

  sqr = _ensure_numeric((avg - values) ** 2)
  diff_b_a = subtract(b, a)
  sqr = _ensure_numeric((avg - values) ** 2)
  diff_b_a = subtract(b, a)
  sqr = _ensure_numeric((avg - values) ** 2)
  diff_b_a = subtract(b, a)


Unnamed: 0,haloID,hostHaloID,lg_Mhalo,Rhalo,pos_x,pos_y,pos_z,vel_x,vel_y,vel_z,lg_Mstar_<Rhalo,lg_Mstar_<0.1Rhalo,lg_Mstar_<10kpc
count,160502.0,160502.0,160502.0,160502.0,160502.0,160502.0,160502.0,160502.0,160502.0,160502.0,160502.0,160502.0,160502.0
mean,101313.178577,-1.0,7.855367,9.38059,77.907889,87.94908,79.482393,2.27518,-0.378612,-0.453479,-inf,-inf,-inf
std,57122.66908,0.0,0.498449,7.205899,43.084758,47.718787,43.581347,75.868411,99.528694,72.376967,,,
min,0.0,-1.0,7.1255,0.29,0.00114,0.00458,0.00229,-468.37,-429.62,-499.67,-inf,-inf,-inf
25%,52257.25,-1.0,7.5023,6.51,43.26668,44.2131,42.57889,-37.18,-69.14,-40.31,,,
50%,101804.5,-1.0,7.7044,7.58,80.690575,101.89133,82.88956,3.125,-4.53,-0.61,,,
75%,150768.75,-1.0,8.049,9.82,112.936875,130.50156,119.58847,44.04,70.98,37.74,,,
max,199491.0,-1.0,13.078,464.03,149.99886,149.99542,149.99771,524.89,533.28,543.27,11.809,11.565,11.325


In [10]:
# Preview the first rows of the processed catalogue.
df.head()

Unnamed: 0,haloID,hostHaloID,lg_Mhalo,Rhalo,pos_x,pos_y,pos_z,vel_x,vel_y,vel_z,lg_Mstar_<Rhalo,lg_Mstar_<0.1Rhalo,lg_Mstar_<10kpc
0,0,-1,13.078,464.03,147.11058,129.86948,76.14448,-20.5,18.82,4.31,11.779,11.474,11.245
1,1,-1,12.963,425.03,79.08847,118.16418,61.88743,-19.79,39.51,33.28,11.738,11.565,11.325
2,2,-1,12.924,412.25,148.07982,146.23561,3.6222,81.54,-1.85,51.89,11.809,11.554,11.196
3,3,-1,12.685,343.34,8.73864,147.24276,31.67646,23.02,-48.64,-29.93,11.622,11.518,11.272
4,4,-1,12.619,326.36,106.70664,149.3047,87.7744,15.68,-28.43,-22.17,11.445,11.231,11.08


In [25]:
# number of galaxies with lg_Mstar_<Rhalo > 0
# The comparison also excludes rows where the column is -inf. count() reports
# non-null entries per column of the filtered frame, so the same row total is
# repeated once for every column.
df[df['lg_Mstar_<Rhalo'] > 0].count()

haloID                6213
hostHaloID            6213
lg_Mhalo              6213
Rhalo                 6213
pos_x                 6213
pos_y                 6213
pos_z                 6213
vel_x                 6213
vel_y                 6213
vel_z                 6213
lg_Mstar_<Rhalo       6213
lg_Mstar_<0.1Rhalo    6213
lg_Mstar_<10kpc       6213
dtype: int64

In [4]:
import torch
import numpy as np
import torch_geometric
from torch_geometric.data import Data
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

In [18]:
def _build_knn_graph(features, targets, spatial_cols, k):
    """Build one PyTorch Geometric Data object with directed k-NN edges over the spatial columns."""
    coords = features[:, spatial_cols]
    n_nodes = coords.shape[0]

    if k > 0:
        # Calling kneighbors() without a query asks for the neighbours of the
        # fitted points themselves, and scikit-learn then leaves each point out
        # of its own neighbour list. This avoids assuming that "the first
        # neighbour is the point itself", which breaks when several points
        # share the same coordinates.
        nbrs = NearestNeighbors(n_neighbors=k, algorithm='kd_tree').fit(coords)
        neighbor_idx = nbrs.kneighbors(return_distance=False)

        # Row 0 holds the source node i, row 1 one of its k neighbours j;
        # edges are grouped by source node, nearest neighbour first.
        sources = np.repeat(np.arange(n_nodes), k)
        edge_array = np.stack([sources, neighbor_idx.reshape(-1)])
    else:
        # k == 0: a graph without edges, still with the (2, E) layout.
        edge_array = np.empty((2, 0), dtype=np.int64)

    return Data(
        x=torch.tensor(features, dtype=torch.float),
        edge_index=torch.tensor(edge_array, dtype=torch.long).contiguous(),
        y=torch.tensor(targets, dtype=torch.float).unsqueeze(1),
        pos=torch.tensor(coords, dtype=torch.float),
    )


def create_graph_data(df, k=5, test_size=0.2, random_state=42, include_lg_Mstar=False):
    """
    Create PyTorch Geometric train/test graphs from the FIREbox data.

    The rows are split into a training and a test set, the features are
    standardized with statistics computed on the training rows only (so no
    information from the test rows leaks into the normalization), and a
    separate k-nearest-neighbour graph is built inside each split.

    Parameters:
    -----------
    df : pd.DataFrame
        The processed FIREbox dataframe. Must contain 'lg_Mhalo', 'Rhalo',
        'pos_x', 'pos_y', 'pos_z', 'vel_x', 'vel_y', 'vel_z' and, when
        include_lg_Mstar is True, 'lg_Mstar_<Rhalo'.
    k : int
        Number of nearest neighbors for graph connectivity (k = 0 gives
        graphs without edges).
    test_size : float
        Fraction of data to use for testing
    random_state : int
        Random seed for reproducibility
    include_lg_Mstar : bool
        Whether to include lg_Mstar_<Rhalo in the feature set.
        NOTE(review): the summary statistics in this notebook show -inf in
        this column; StandardScaler rejects non-finite input, so such rows
        presumably have to be cleaned before enabling this flag - confirm.

    Returns:
    --------
    tuple
        (train_graph, test_graph, scaler) where:
        - train_graph: PyTorch Geometric Data object for training
        - test_graph: PyTorch Geometric Data object for testing
        - scaler: StandardScaler fitted on the training features
        Each graph has x (standardized features), y (lg_Mhalo, shape
        [num_nodes, 1]), pos (the standardized pos_x/pos_y/pos_z columns of x)
        and edge_index (shape [2, num_nodes * k], directed node -> neighbour,
        no self-loops).

    Raises:
    -------
    ValueError
        If k is negative. scikit-learn also raises ValueError when a split
        has too few rows to provide k neighbours per node.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")

    # Define feature columns and target
    feature_cols = ['Rhalo', 'pos_x', 'pos_y', 'pos_z', 'vel_x', 'vel_y', 'vel_z']
    if include_lg_Mstar:
        feature_cols.append('lg_Mstar_<Rhalo')
    target_col = 'lg_Mhalo'

    # Locate the spatial coordinates by name rather than by a hard-coded
    # slice, so reordering feature_cols cannot silently pick wrong columns.
    spatial_cols = [feature_cols.index(name) for name in ('pos_x', 'pos_y', 'pos_z')]

    # Extract features and target
    X = df[feature_cols].values
    y = df[target_col].values

    # Split first, then normalize: the scaler must only see training rows.
    X_train_raw, X_test_raw, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train_raw)
    X_test = scaler.transform(X_test_raw)

    # k-NN connectivity is computed on the standardized spatial coordinates.
    # NOTE(review): each axis is rescaled by its own standard deviation, so
    # neighbour distances are not exactly isotropic in physical units.
    train_graph = _build_knn_graph(X_train, y_train, spatial_cols, k)
    test_graph = _build_knn_graph(X_test, y_test, spatial_cols, k)

    return train_graph, test_graph, scaler


In [26]:
# Smoke-check create_graph_data with k=6 and report the size of each split.
train_graph, test_graph, scaler = create_graph_data(df, k=6)

# One print call per report section; the emitted text is unchanged.
print(
    "Graph Creation Results:\n"
    "Training graph:\n"
    f"  - Number of nodes: {train_graph.num_nodes}\n"
    f"  - Number of edges: {train_graph.num_edges}\n"
    f"  - Node features shape: {train_graph.x.shape}\n"
    f"  - Target shape: {train_graph.y.shape}"
)

print(
    "\nTest graph:\n"
    f"  - Number of nodes: {test_graph.num_nodes}\n"
    f"  - Number of edges: {test_graph.num_edges}\n"
    f"  - Node features shape: {test_graph.x.shape}\n"
    f"  - Target shape: {test_graph.y.shape}"
)

print(
    "\nTarget statistics:\n"
    f"  - Training target range: {train_graph.y.min():.3f} to {train_graph.y.max():.3f}\n"
    f"  - Test target range: {test_graph.y.min():.3f} to {test_graph.y.max():.3f}"
)


Graph Creation Results:
Training graph:
  - Number of nodes: 128401
  - Number of edges: 770406
  - Node features shape: torch.Size([128401, 7])
  - Target shape: torch.Size([128401, 1])

Test graph:
  - Number of nodes: 32101
  - Number of edges: 192606
  - Node features shape: torch.Size([32101, 7])
  - Target shape: torch.Size([32101, 1])

Target statistics:
  - Training target range: 7.126 to 12.963
  - Test target range: 7.356 to 13.078
