# Simulation-based (likelihood-free) inference

## Galaxy clustering starter problem

**Siddharth Mishra-Sharma** ([smsharma@mit.edu](mailto:smsharma@mit.edu))

In [52]:
import torch
import numpy as np
import pandas as pd

In [53]:
# Uncomment to install the graph libraries into this kernel's environment
# %pip install torch_geometric torch_cluster

In [57]:
# Load galaxy point cloud data (training set) from a NumPy .npy file
x_train = np.load("../data/set_diffuser_data/train_halos.npy")

# Displayed as the cell output to confirm the array layout
x_train.shape  # (n_data, n_points, n_features); features are (x, y, z, vx, vy, vz, mass)

(1800, 5000, 7)

In [58]:
# Load the cosmological parameters of the training set
# (columns shown in the preview: Omega_m, Omega_b, h, n_s, sigma_8).
# NOTE(review): presumably row i holds the parameters for x_train[i] -- confirm against the data source.
theta_train = pd.read_csv("../data/set_diffuser_data/train_cosmology.csv")

# Preview the first rows of the table
theta_train.head()

Unnamed: 0,Omega_m,Omega_b,h,n_s,sigma_8
0,0.1755,0.06681,0.7737,0.8849,0.6641
1,0.2139,0.05557,0.8599,0.9785,0.8619
2,0.1867,0.04503,0.6189,0.8307,0.7187
3,0.3271,0.06875,0.6313,0.8135,0.8939
4,0.1433,0.06347,0.6127,1.1501,0.7699


In [59]:
import torch_geometric
from torch_geometric.data import Data
from torch_cluster import radius_graph

# Keep only the spatial coordinates (first three features) of the first four point clouds
x_pos = torch.tensor(x_train[:4, :, :3], dtype=torch.float)

# x_pos has shape (n_data, n_points, 3). The graph builder expects a flat (n_data * n_points, 3)
# node array, plus a per-node vector recording which point cloud each node belongs to.
batch = torch.arange(x_pos.size(0)).repeat_interleave(x_pos.size(1))

# Radius graph over the flattened positions, with the per-cloud assignment passed as `batch`
edge_index = radius_graph(x_pos.view(-1, 3), r=100, batch=batch)

In [62]:
from torch_geometric.nn import GCNConv  # Graph convolutional layer from PyG

class GNN(torch.nn.Module):
    """Graph neural network producing one feature vector per graph.

    Node features pass through a stack of ``GCNConv`` layers, each followed by a
    ReLU. The result is mean-pooled over the nodes of each graph and projected by
    a linear readout layer.

    Args:
        in_channels: Number of input features per node.
        hidden_channels: Output width of every graph convolutional layer.
        out_features: Size of the per-graph output vector.
        num_layers: Number of graph convolutional layers. If ``num_layers <= 0``
            no convolution is applied and the readout acts directly on the
            mean-pooled input features.
    """

    def __init__(self, in_channels=3, hidden_channels=32, out_features=64, num_layers=3):
        super().__init__()

        # Feature width before the first layer and after every graph layer.
        # Only the first layer sees the raw inputs; all later ones map hidden -> hidden.
        widths = [in_channels] + [hidden_channels] * max(num_layers, 0)

        # Graph convolutional layers
        self.graph_layers = torch.nn.ModuleList(
            [GCNConv(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        # Readout layer. Its input width follows the last graph layer that was built,
        # so a model without graph layers reads the (pooled) input features directly
        # instead of failing on a hidden_channels / in_channels mismatch.
        self.readout_layer = torch.nn.Linear(widths[-1], out_features)

    def forward(self, x, edge_index, batch):
        """Compute per-graph features for a batch of graphs.

        Args:
            x: Node features; the last dimension must equal ``in_channels``.
            edge_index: Graph connectivity passed to every ``GCNConv`` layer.
            batch: Per-node graph assignment used by the mean pooling.

        Returns:
            The readout-layer output for the pooled features, one row per graph,
            with ``out_features`` columns.
        """
        for layer in self.graph_layers:
            x = torch.nn.functional.relu(layer(x, edge_index))

        # Mean-aggregate over nodes in each graph
        pooled = torch_geometric.nn.global_mean_pool(x, batch)
        return self.readout_layer(pooled)

In [64]:
# Test instantiating GNN and running it on a small batch of data
gnn = GNN(in_channels=3, hidden_channels=32, out_features=64, num_layers=3)
# Nodes are passed flattened to (n_data * n_points, 3), the same layout that was used
# to build edge_index and batch
features = gnn(x_pos.view(-1, 3), edge_index, batch)

In [65]:
# One row of readout features per graph in the batch
features.shape

torch.Size([4, 64])