In [2]:
import os
import random
import shutil
import time

import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import DynamicEdgeConv, global_max_pool, knn_graph
from tqdm import tqdm

# Set seed
# SEED=42
# random.seed(SEED)
# torch.manual_seed(SEED)
# np.random.seed(SEED)


In [2]:

# Neighbour count `k` handed to every DynamicEdgeConv layer of the model below.
K: int = 4

In [None]:
class PointNetInstanceSeg(nn.Module):
    """Per-point classifier built from two stacked ``DynamicEdgeConv`` layers.

    Despite the name, this is a DGCNN-style network: each ``DynamicEdgeConv``
    rebuilds a k-nearest-neighbour graph from its current input features, so
    no precomputed ``edge_index`` is needed. No pooling is applied; the model
    returns one logit vector per point.

    Args:
        in_channels: Features per point in ``data.pos`` (default 7, matching
            the original hard-coded first layer width of 14 = 2 * 7).
        num_classes: Logits produced per point (default 21).
        k: Neighbours per point in each dynamically built kNN graph.

    Layer names and shapes are unchanged for the defaults, so previously saved
    state_dicts still load.
    """

    def __init__(self, in_channels=7, num_classes=21, k=K):
        super().__init__()
        # Per the PyG docs, EdgeConv feeds its MLP the concatenation
        # [x_i, x_j - x_i], hence the 2 * <feature width> input sizes.
        self.edge_conv1 = DynamicEdgeConv(nn.Sequential(
            nn.Linear(2 * in_channels, 64),
            nn.SiLU(),
            nn.Linear(64, 128),
            nn.SiLU(),
        ), k=k)
        self.edge_conv2 = DynamicEdgeConv(nn.Sequential(
            nn.Linear(2 * 128, 128),
            nn.SiLU(),
            nn.Linear(128, 64),
            nn.SiLU(),
        ), k=k)
        self.fc = nn.Linear(64, num_classes)  # one logit per class, per point

    def forward(self, data):
        """Return per-point logits of shape ``(num_points, num_classes)``."""
        x = data.pos
        # DynamicEdgeConv's second argument is the *batch* vector (which graph
        # each point belongs to), not an edge_index. Passing it keeps kNN
        # neighbourhoods inside each graph when batch_size > 1; it is None for
        # a lone Data object, which treats all points as one graph.
        batch = getattr(data, 'batch', None)
        x = self.edge_conv1(x, batch)
        x = self.edge_conv2(x, batch)
        return self.fc(x)

def count_parameters(model):
    """Return the total number of scalar values in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Show the trainable-parameter count of a freshly initialised model as cell output.
count_parameters(model=PointNetInstanceSeg())

In [4]:
# Each .npy file is one halo: the last column is the integer class label
# (shifted by +1 for one-hot encoding, so raw label -1 -> column 0); the
# remaining columns are per-point features whose first three are x, y, z.
train_dir = "/home/group10/deephalo_gnn/Imbalance_Resampled_for_mulltilabel/train"
files = sorted(os.listdir(train_dir))
point_cloud_data = [np.load(os.path.join(train_dir, f)) for f in files if f.endswith(".npy")]

# Convert each point cloud into a PyG Data object.
data_list = []
for point_cloud in point_cloud_data:
    pos = torch.tensor(point_cloud[:, :-1], dtype=torch.float)
    # Recentre the xyz *columns* on this halo's centroid (mean over all points).
    # The previous `pos[:3] ... mean(dim=1)` shifted the first three *rows*
    # by their own feature means instead.
    pos[:, :3] -= pos[:, :3].mean(dim=0, keepdim=True)
    data = Data(
        pos=pos,
        y=torch.eye(21)[torch.tensor(point_cloud[:, -1] + 1, dtype=torch.long)],
    )
    data_list.append(data)

loader = DataLoader(data_list, batch_size=1, shuffle=True)


In [None]:
# Inspect the one-hot targets of the last training halo, then recover the
# integer class column for every point.
class_labels = data.y.argmax(dim=1)
print(data.y)
print(class_labels)

In [6]:
# Raw per-point labels (last column) of every training halo, flattened.
labels = torch.tensor(np.concatenate([pc[:, -1] for pc in point_cloud_data]))

# Distinct raw labels seen in training (kept for inspection).
unique_labels = torch.unique(labels)

# Count points per one-hot *column* (raw label + 1, exactly as the targets are
# built), so weight_vec[i] always lines up with logit i, even when a class
# never occurs. The old torch.unique-based counts were indexed by rank among
# the labels present and silently misaligned in that case.
NUM_CLASSES = 21
counts = torch.bincount((labels + 1).long(), minlength=NUM_CLASSES)
assert counts.numel() == NUM_CLASSES, "found labels outside the expected -1..19 range"

frequencies = counts.float() / labels.numel()

# Log-smoothed inverse-frequency weights: rarer classes weigh more, bounded
# above by 1 / log(1.2) (about 5.48) for a class with zero frequency.
weight_vec = 1.0 / torch.log(torch.tensor(1.2) + frequencies)

In [None]:
# Strongly down-weight class column 0 (presumably raw label -1, e.g. unassigned
# points — confirm) in the focal loss, then display the weight-vector shape.
weight_vec[0] = 0.01
weight_vec.shape

In [8]:
class FocalLoss(nn.Module):
    """Binary focal loss over independent per-class logits.

    Computes ``alpha * (1 - p_t) ** gamma * BCE`` element-wise, where ``BCE`` is
    the unreduced binary cross-entropy with logits and ``p_t = exp(-BCE)``,
    then averages over every element.

    Args:
        alpha: Scalar or tensor broadcastable against the element-wise loss
            (e.g. one weight per class column).
        gamma: Focusing exponent; ``gamma=0`` gives plain (weighted) BCE.
    """

    def __init__(self, alpha=1., gamma=2.):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce_logits = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        elementwise_bce = self.bce_logits(inputs, targets)
        # (1 - p_t)^gamma shrinks the contribution of already well-classified entries.
        modulator = (1 - torch.exp(-elementwise_bce)) ** self.gamma
        return (self.alpha * modulator * elementwise_bce).mean()

In [9]:
# First GPU when available; otherwise fall back to CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing at `.to(DEVICE)`.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
model = PointNetInstanceSeg().to(DEVICE)
# Per-class focal-loss alpha, one weight per logit column. torch.as_tensor
# accepts any float dtype, unlike the legacy torch.FloatTensor(tensor) constructor.
weights = torch.as_tensor(weight_vec, dtype=torch.float32, device=DEVICE)

criterion = FocalLoss(alpha=weights)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=5e-4)


def train_one_epoch(model, loader, criterion, optimizer, device, desc=None):
    """Run one optimisation pass over ``loader``; return the mean per-graph loss.

    Each batch loss is weighted by its ``num_graphs`` and the sum is divided by
    the dataset size, so the result is a per-graph average for any batch size.
    """
    model.train()
    running_loss = 0.0
    for batch in tqdm(loader, desc=desc):
        batch = batch.to(device)
        optimizer.zero_grad()
        loss = criterion(model(batch), batch.y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * batch.num_graphs
    return running_loss / len(loader.dataset)


# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(
        model, loader, criterion, optimizer, DEVICE,
        desc=f'Epoch {epoch + 1}/{num_epochs}',
    )
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

# Save the trained weights under a timestamped name in ./ckpts (created if missing).
os.makedirs("ckpts", exist_ok=True)
curr_time = time.strftime("%Y%m%d-%H%M%S")
model_name = f"{curr_time}_pointnet_instance_seg"
torch.save(model.state_dict(), f"./ckpts/{model_name}_model.pth")

In [None]:
test_dir = "/home/group10/deephalo_gnn/Labeled subhalo matrices of haloes/test"
# Sorted so that per-halo indices into the evaluation result lists (e.g. the
# plotted `hid`) are stable across runs and filesystems.
files = sorted(os.listdir(test_dir))
# Keep halos whose numeric file id (name assumed to be "<id>.npy") exceeds 50.
# Stored under its own name so the training `point_cloud_data` is not clobbered.
test_point_cloud_data = [
    np.load(os.path.join(test_dir, f))
    for f in files
    if f.endswith(".npy") and int(f[:-4]) > 50
]

# Convert each point cloud into a Data object, preprocessed like the training set.
data_test_list = []
for point_cloud in test_point_cloud_data:
    pos = torch.tensor(point_cloud[:, :-1], dtype=torch.float)
    # Per-halo recentring of the xyz columns, matching the "Recentering
    # positions per halo" step of the training preprocessing; without it the
    # model sees differently distributed coordinates at test time.
    pos[:, :3] -= pos[:, :3].mean(dim=0, keepdim=True)
    data_test_list.append(Data(
        pos=pos,
        y=torch.eye(21)[torch.tensor(point_cloud[:, -1] + 1, dtype=torch.long)],
    ))


In [None]:

test_loader = DataLoader(data_test_list, batch_size=1, shuffle=False)

model.eval()

# Per-halo results, one numpy array per test point cloud, in test_loader order.
ground_truth_labels = []
predictions = []
pos_list = []

with torch.no_grad():
    for data in tqdm(test_loader, desc='Testing'):
        data = data.to(DEVICE)
        outputs = model(data)

        # Highest-scoring class column per point, for the logits and the one-hot targets.
        predicted_labels = outputs.argmax(dim=1)
        ground_truth = data.y.argmax(dim=1)

        ground_truth_labels.append(ground_truth.cpu().numpy())
        predictions.append(predicted_labels.cpu().numpy())
        # Only the xyz columns are kept, for plotting.
        pos_list.append(data.pos.cpu().numpy()[:, 0:3])

# `predictions` and `ground_truth_labels` now hold one integer label array per
# test halo; they are compared below to compute test metrics.

In [28]:
def multi_label_iou(pred, target, ignore_empty=False):
    """Mean intersection-over-union of binary masks, one IoU per leading index.

    Both inputs are flattened to ``(S, -1)`` with ``S = pred.shape[0]``; row ``s``
    is treated as one binary mask and gets IoU = |pred & target| / |pred | target|.

    Args:
        pred: Binary mask tensor (bool or numeric 0/1), at least 1-D.
        target: Binary mask tensor with the same shape as ``pred``.
        ignore_empty: If True, rows where both masks are empty (union 0) are
            left out of the mean instead of scoring 0. If every row is empty
            the result is NaN.

    Returns:
        0-d float tensor with the mean IoU over the kept rows.
    """
    # reshape (not view) so non-contiguous inputs such as transposed masks work.
    pred = pred.float().reshape(pred.shape[0], -1)
    target = target.float().reshape(target.shape[0], -1)

    # Intersection and union per row; clamp turns the sum of two 0/1 masks into OR.
    intersection = (pred * target).sum(dim=1)
    union = (pred + target).clamp(0, 1).sum(dim=1)

    # Epsilon avoids 0/0 for empty rows (they score 0 unless ignore_empty).
    iou = intersection / (union + 1e-8)
    if ignore_empty:
        iou = iou[union > 0]

    return iou.mean()

In [None]:
# Index (in test_loader order) of the test halo to visualise.
hid = 14

fig = go.Figure(data=[
    go.Scatter3d(
        x=pos_list[hid][:, 0],
        y=pos_list[hid][:, 1],
        z=pos_list[hid][:, 2],
        mode='markers',
        marker=dict(
            size=1,  # small markers so dense regions stay readable
            color=ground_truth_labels[hid],
            opacity=0.75,
            showscale=True,
        ),
    )
])
# Title reports the particle *count*; formatting `.shape` printed a tuple like "(N,)".
fig.update_layout(
    title=f'Halo {hid}: {len(pos_list[hid])} particles (coloured by ground-truth class)',
    title_x=0.5,
)
fig

In [None]:
num_classes = 21  # width of the one-hot targets / model logits

iou_score = []
for idx, (gt, pred) in enumerate(zip(ground_truth_labels, predictions)):
    # Per-halo diagnostic: point-count shape, predicted classes, true classes.
    print(idx, "\t", gt.shape, "\t", np.unique(pred), "\t", np.unique(gt))
    print()

    gt_t = torch.as_tensor(gt, dtype=torch.long)
    pred_t = torch.as_tensor(pred, dtype=torch.long)
    # multi_label_iou expects binary masks, not class indices (with raw indices
    # it multiplies label ids and can exceed 1). Build one mask row per class
    # that occurs in either prediction or ground truth, so classes absent from
    # both do not add a spurious IoU of 0.
    present = torch.unique(torch.cat([gt_t, pred_t]))
    pred_masks = F.one_hot(pred_t, num_classes).T[present]
    gt_masks = F.one_hot(gt_t, num_classes).T[present]
    iou_score.append(multi_label_iou(pred_masks, gt_masks).item())

# Escaped backslash: "\p" in a normal string is an invalid escape sequence.
print(f"Mean IoU: {np.mean(iou_score):.4f} \\pm {np.std(iou_score):.4f}")