In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import DynamicEdgeConv, global_max_pool, knn_graph
import torch.optim as optim
from torch_geometric.loader import DataLoader
from tqdm import tqdm


from torch_geometric.data import Data

import os
import shutil
import numpy as np

import random

# Set seed
# SEED=42
# random.seed(SEED)
# torch.manual_seed(SEED)
# np.random.seed(SEED)


In [2]:

K = 4

In [3]:
class PointNetInstanceSeg(nn.Module):
    def __init__(self):
        super(PointNetInstanceSeg, self).__init__()
        self.edge_conv1 = DynamicEdgeConv(nn.Sequential(
            nn.Linear(14, 64),
            nn.SiLU(),
            nn.Linear(64, 128),
            nn.SiLU()
        ), k=K)
        self.fc = nn.Linear(128, 21)  # Predicting instance mask for each point

    def forward(self, data):
        x, edge_index = data.pos, data.edge_index
        x = self.edge_conv1(x, edge_index)
        x = self.fc(x)
        return x

# Calc number of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(PointNetInstanceSeg())

11989

In [4]:
# Assuming you have point cloud data stored as a list of positions
data = "/home/group10/deephalo_gnn/Imbalance_Resampled_for_mulltilabel/train"
files = os.listdir(data)
point_cloud_data = [(np.load(data+"/"+f)) for f in files if f.endswith(".npy")] # List of point cloud data, each element is a list of point coordinates

# Convert each point cloud data into a Data object
data_list = []
for point_cloud in point_cloud_data:
    # Create a Data object with the positions of points
    pos = torch.tensor(point_cloud[:,:-1], dtype=torch.float)
    # Recentering positions per halo
    pos[:3] = pos[:3] - pos[:3].mean(dim=1, keepdim=True)
    data = Data(
        pos=pos,
        y = torch.eye(21)[torch.tensor(point_cloud[:,-1]+1, dtype=torch.long)],
        # edge_index=knn_graph(pos, k=K)
    )
    # data.y = torch.tensor(point_cloud[:,-1], dtype=torch.long)
    # Dynamically generate edge_index based on the positions of points
    # You can use a method like k-NN to construct the edges
    # For example, using knn_graph from torch_geometric.transforms:
    # from torch_geometric.transforms import knn_graph
    # data.edge_index = knn_graph(data.pos, k=K)  # Construct edges based on 6 nearest neighbors
    # Add other necessary attributes to the Data object if needed
    # For example, data.y for ground truth labels
    data_list.append(data)

# Now you can use DataLoader with this list of Data objects
loader = DataLoader(data_list, batch_size=1, shuffle=True)


In [5]:
data.y
print(data.y)
class_labels = torch.argmax(data.y, dim=1)
print(class_labels)

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
tensor([0, 0, 0,  ..., 5, 5, 5])


In [6]:
labels = []
for point_clouds in point_cloud_data:
    label = point_clouds[:,-1]
    labels.append(label)
labels = torch.tensor(np.concatenate(labels))
# Convert list of tensors to a single tensor
# labels = torch.cat(labels)

# Calculate unique labels and counts
unique_labels, counts = torch.unique(labels, return_counts=True)

# Calculate frequencies
frequencies = counts.float() / labels.numel()

# Calculate weights
weight_vec = 1.0 / torch.log(torch.tensor(1.2) + frequencies)

In [7]:
weight_vec[0]=0.01
weight_vec
weight_vec.shape

torch.Size([21])

In [8]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1., gamma=2.):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce_logits = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        BCE_loss = self.bce_logits(inputs, targets)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return F_loss.mean()

In [9]:
DEVICE = torch.device('cuda:0')

In [10]:
# Define your dataset and DataLoader
# dataloader = DataLoader(SHDataSet("train"), batch_size=1, shuffle=True)
# Initialize the model
model = PointNetInstanceSeg().to(DEVICE)
weights = torch.FloatTensor(weight_vec).to(DEVICE)
# Define your loss function and optimizer
# Assuming instance masks are represented as class labels

criterion = FocalLoss(alpha=weights)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=5e-4)



In [11]:
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for data in tqdm(loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        #try:
        # print(data.pos.shape, data.pos.size(0), data.pos.numel())
        # assert 2 == 3
        # x.size(0) == batch_x.numel()
        data.to(DEVICE)
        # print(data.y.shape)
        optimizer.zero_grad()
        outputs = model(data)
        # Assuming instance masks are represented as class labels and provided in data.y
        # target = torch.argmax(data.y, dim=1)  # Convert one-hot encoded target to class labels
        # Resize or reshape tensors to match sizes
        loss = criterion(outputs, data.y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * data.num_graphs
        

    epoch_loss = running_loss / len(loader.dataset)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

# Create the "ckpts" directory if it doesn't exist
import os, time

os.makedirs("ckpts", exist_ok=True)
curr_time = time.strftime("%Y%m%d-%H%M%S")
model_name = f"{curr_time}_pointnet_instance_seg"

# Save the model
torch.save(model.state_dict(), f"./ckpts/{model_name}_model.pth")

Epoch 1/10: 100%|██████████| 102/102 [00:27<00:00,  3.70it/s]


Epoch 1/10, Loss: 248.6879


Epoch 2/10: 100%|██████████| 102/102 [00:27<00:00,  3.70it/s]


Epoch 2/10, Loss: 86.4936


Epoch 3/10: 100%|██████████| 102/102 [00:27<00:00,  3.68it/s]


Epoch 3/10, Loss: 64.9624


Epoch 4/10: 100%|██████████| 102/102 [00:27<00:00,  3.67it/s]


Epoch 4/10, Loss: 47.9186


Epoch 5/10: 100%|██████████| 102/102 [00:27<00:00,  3.68it/s]


Epoch 5/10, Loss: 50.4105


Epoch 6/10: 100%|██████████| 102/102 [00:27<00:00,  3.64it/s]


Epoch 6/10, Loss: 39.7103


Epoch 7/10: 100%|██████████| 102/102 [00:27<00:00,  3.65it/s]


Epoch 7/10, Loss: 30.3453


Epoch 8/10: 100%|██████████| 102/102 [00:27<00:00,  3.65it/s]


Epoch 8/10, Loss: 23.6815


Epoch 9/10: 100%|██████████| 102/102 [00:27<00:00,  3.67it/s]


Epoch 9/10, Loss: 21.3767


Epoch 10/10: 100%|██████████| 102/102 [00:27<00:00,  3.65it/s]

Epoch 10/10, Loss: 15.9767





In [12]:
data_test = "/home/group10/deephalo_gnn/New test"
files = os.listdir(data_test)
point_cloud_data = [(np.load(data_test+"/"+f)) for f in files if f.endswith(".npy") and (int(f[:-4])>50)] # List of point cloud data, each element is a list of point coordinates

# Convert each point cloud data into a Data object
data_test_list = []
for point_cloud in point_cloud_data:
    # Create a Data object with the positions of points
    data_test = Data(pos=torch.tensor(point_cloud[:,:-1], dtype=torch.float), y = torch.eye(21)[torch.tensor(point_cloud[:,-1]+1, dtype=torch.long)])
    # data.y = torch.tensor(point_cloud[:,-1], dtype=torch.long)
    # Dynamically generate edge_index based on the positions of points
    # You can use a method like k-NN to construct the edges
    # For example, using knn_graph from torch_geometric.transforms:
    # from torch_geometric.transforms import knn_graph
    # data.edge_index = knn_graph(data.pos, k=6)  # Construct edges based on 6 nearest neighbors
    # Add other necessary attributes to the Data object if needed
    # For example, data.y for ground truth labels
    data_test_list.append(data_test)


In [13]:
model.load_state_dict(torch.load(f"./ckpts/{model_name}_model.pth"))

<All keys matched successfully>

In [14]:
# Assuming you have a DataLoader `test_loader` for your test data
test_loader = DataLoader(data_test_list, batch_size=1, shuffle=False)

# Put the model in evaluation mode

# if not model_name:

model.eval()

# Initialize a list to store the predictions
ground_truth_labels = []
predictions = []
pos_list = []

# Loop over the test data
with torch.no_grad():
    for data in tqdm(test_loader, desc='Testing'):
        # Move data to the device
        data = data.to(DEVICE)

        # Pass the data through the model
        outputs = model(data)
        pos = data.pos.cpu().numpy()

        pos_list.append(pos)
        # Get the predicted labels
        _, predicted_labels = torch.max(outputs, 1)
        _, ground_truth = torch.max(data.y, 1)
        # Store the predictions
        ground_truth_labels.append(ground_truth.cpu().numpy())
        predictions.append(predicted_labels.cpu().numpy())


# At this point, `predictions` is a list of numpy arrays with the predicted labels for each point cloud in the test set
# You can now compare these predictions to the actual labels to compute your test metrics

Testing: 100%|██████████| 23/23 [00:00<00:00, 100.47it/s]


In [15]:
def multi_label_iou(pred, target):
    # Ensure the tensors are float type
    pred = pred.float()
    target = target.float()

    # Reshape the tensors to a 2D format
    pred = pred.view(pred.shape[0], -1)
    target = target.view(target.shape[0], -1)

    # Calculate intersection and union for each sample
    intersection = (pred * target).sum(dim=1)
    union = (pred + target).clamp(0, 1).sum(dim=1)

    # Calculate IoU and avoid division by zero
    iou = intersection / (union + 1e-8)

    return iou.mean()

# Usage:
# Assume `output` is the output of your model and `target` is your ground truth
# Both `output` and `target` should be one-hot encoded and have the same shape


In [16]:
iou_score = []
for idx, (gt, pred) in enumerate(zip(ground_truth_labels, predictions)):
    # Checking if the model predicts different labels for different points in the same point cloud
    # if np.unique(pred).shape[0] != 1:
    print(idx, "\t", gt.shape, "\t", np.unique(pred), "\t", np.unique(gt))
    print()
    
    iou_score.append(multi_label_iou(torch.tensor(pred), torch.tensor(gt)))
    # accs.append(accuracy_score(gt, pred))
    # f1s.append(f1_score(gt, pred, average='weighted'))

print(f"Mean iou score: {np.median(iou_score):.4f} \pm {np.std(iou_score):.4f}")
#print(f"Mean F!: {np.mean(f1s):.4f} \pm {np.std(f1s):.4f}")

0 	 (10065,) 	 [0 3 4 5] 	 [0 1 2 3]

1 	 (5704,) 	 [0 3 4] 	 [0 1 2]

2 	 (14113,) 	 [0] 	 [ 0  1  2  3  4  5  6  7  8  9 10 11 12]

3 	 (7083,) 	 [0 3 4 6] 	 [0 1 2 3 4 5]

4 	 (10423,) 	 [ 0  4  5  7 15] 	 [0 1 2]

5 	 (6830,) 	 [0] 	 [0 1 2 3 4 5 6]

6 	 (6355,) 	 [0] 	 [0]

7 	 (13648,) 	 [0] 	 [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13]

8 	 (12394,) 	 [0] 	 [0 1 2 3 4 5 6]

9 	 (7559,) 	 [0 2 3 5 6 7] 	 [0 1 2 3 4 5]

10 	 (16444,) 	 [0 6] 	 [ 0  1  2  3  4  5  6  7  8  9 10]

11 	 (6586,) 	 [0 3 4 5 7] 	 [0 1 2 3]

12 	 (9898,) 	 [0 6] 	 [0 1 2 3]

13 	 (8716,) 	 [0 3] 	 [0 1 2 3 4 5 6 7 8 9]

14 	 (6429,) 	 [ 0  3  6 11] 	 [0 1 2 3 4 5 6 7 8]

15 	 (15265,) 	 [0 4 5 6 7] 	 [0 1 2 3 4 5 6 7 8 9]

16 	 (18205,) 	 [0 3 6] 	 [ 0  1  2  3  4  5  6  7  8  9 10 11]

17 	 (13278,) 	 [0 3 4 5] 	 [0 1 2 3 4 5 6 7 8 9]

18 	 (15318,) 	 [0] 	 [0 1 2 3 4 5 6 7]

19 	 (6788,) 	 [ 0  3  4  5  7 11] 	 [0 1 2 3 4 5]

20 	 (9654,) 	 [0] 	 [0 1 2]

21 	 (5874,) 	 [ 0  1  2  3  4  5  6  7 13] 	 [

In [17]:
iou_score

[tensor(0.4688),
 tensor(0.),
 tensor(0.),
 tensor(0.6599),
 tensor(0.0941),
 tensor(0.),
 tensor(0.),
 tensor(0.),
 tensor(0.),
 tensor(0.4552),
 tensor(0.),
 tensor(0.3246),
 tensor(0.),
 tensor(0.1332),
 tensor(2.2552),
 tensor(0.6352),
 tensor(0.0522),
 tensor(0.3759),
 tensor(0.),
 tensor(1.6120),
 tensor(0.),
 tensor(0.1180),
 tensor(1.2191)]

In [18]:
pos_list[14].shape

(6429, 7)

In [19]:
ground_truth_labels[14]

array([0, 0, 0, ..., 0, 0, 0])

In [20]:
import plotly.graph_objects as go

i=13
points = pos_list[i][ground_truth_labels[i] != 0]
labels = ground_truth_labels[i][ground_truth_labels[i] != 0]

In [21]:
# s = label != -1
# s0 = label == -1

# x = halopartinfo[:, 0]
# y = halopartinfo[:, 1]
# z = halopartinfo[:, 2]
# label = halopartinfo[:, -1]
# # print(x.shape)

# fig = go.Figure(data=[go.Scatter3d(
#     x=x[s],
#     y=y[s],
#     z=z[s],
#     mode='markers',
#     marker=dict(
#         size=4,
#         color=label[s],
#         colorscale='Rainbow',
#         opacity=0.8
#     )
# ),
#                       go.Scatter3d(
#     x=x[s0],
#     y=y[s0],
#     z=z[s0],
#     mode='markers',
#     marker=dict(
#         size=1,
#         color=label[s0],
#         colorscale='Viridis',
#         opacity=0.5
#     )
# )
# ])

# fig.update_layout(margin=dict(l=0, r=0, b=0, t=0), width=1200, height=1000)
# # fig.show()

In [None]:
fig = go.Figure(data=[
    go.Scatter3d(
        x=points[:,0],
        y=points[:,1],
        z=points[:,2],
        mode='markers',
        marker=dict(
            size=1, # Larger than surrounding data-points
            color=labels,
            opacity=0.75,
            showscale=True,
        ))
])
fig.update_layout(
    title=f"9898 points | BG: TODO: points", title_x=0.5,
)

In [None]:
# import plotly.graph_objects as go
fig = go.Figure(data=[
    go.Scatter3d(
        x=pos_list[13][:,0],
        y=pos_list[13][:,1],
        z=pos_list[13][:,2],
        mode='markers',
        marker=dict(
            size=1, # Larger than surrounding data-points
            color=predictions[13],
            opacity=0.75,
            showscale=True,
        ))
])
fig.update_layout(
    title=f"{9898} points | BG: TODO: points", title_x=0.5,
)

In [29]:
# np.unique(ground_truth_labels[0], return_counts=True)
np.unique(ground_truth_labels[0]).shape
raise NotImplementedError("stop here")

NotImplementedError: stop here

In [27]:
# np.unique(predictions[0], return_counts=True)
np.unique(predictions[0]).shape

(4,)

In [None]:
this is meant to error out if left uncommented

In [28]:
class SHDataSet(torch.utils.data.Dataset):
    def __init__(self, set):
        if set == "train":
            data = "/home/group10/ml/Labeled subhalo matrices of haloes/train"
        elif set == "val":
            data = "/home/group10/ml/Labeled subhalo matrices of haloes/train/val"
        
        self.length = len(data)

        self.set = set
        
        files = os.listdir(data)
        files = [torch.tensor(np.load(data+"/"+f), dtype=torch.float64) for f in files if f.endswith(".npy")]
        files1=[f[:,:-1] for f in files]
        self.files = files1
        
        labels = [f[:,-1] for f in files]
        
        labels = [torch.nn.functional.one_hot(j, 14) for j in labels]
        
        self.labels = labels
        
        
    def __getitem__(self,index):
        return torch.tensor(self.files[index], dtype = torch.float64), torch.tensor(self.labels[index], dtype = torch.long)
        
        
    def __len__(self):
        return self.length

def train_accuracy(
    model,
    data_generator,
    GPU = torch.device("cuda:2"),
):
    model.eval()
    with torch.no_grad():
        accs = []
        for batch_x, batch_y in data_generator:
            batch_x, batch_y = batch_x.to(GPU), batch_y
            y_true = batch_y.argmax(1).numpy()
            y_pred = model(batch_x).argmax(1).cpu().numpy()
            acc = accuracy_score(y_true, y_pred)
            accs.append(acc*100)
    model.train()
    return np.array(accs).mean()

In [None]:
def __getitem__(self, idx):
    data = np.load(os.path.join(self.directory, self.files[idx]))
    points = torch.tensor(data[:, :3], dtype=torch.float32)  # Assuming the first 3 columns are the point coordinates
    labels = torch.tensor(data[:, 3:], dtype=torch.long)  # Assuming the rest of the columns are the labels
    return points, labels

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def visualize_point_cloud(points, labels):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2], c=labels, cmap='jet')
    plt.show()

# Get a batch of data
data = next(iter(loader))

# Run the model and get the outputs
model.eval()
with torch.no_grad():
    outputs = model(data)

# Get the predicted labels
_, predicted_labels = torch.max(outputs, 1)

# Visualize the point cloud with the predicted labels
visualize_point_cloud(data.pos.cpu().numpy(), predicted_labels.cpu().numpy())

In [None]:
import time

In [None]:
for data in loader:
    print(data)
    print(data.y.shape)
    print(data.pos.shape)
    break

In [None]:
# Define your loss function and optimizer
criterion = nn.CrossEntropyLoss()  # Loss function expects dynamic number of classes
# optimizer = optim.Adam(model.parameters(), lr=0.001)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
max_class_index = 0  # Initialize maximum class index
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for data in tqdm(loader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        optimizer.zero_grad()
        outputs = model(data)
        
        # Determine the maximum class index encountered
        max_class_index = max(max_class_index, torch.max(data.y).item())
        
        # Adjust the fully connected layer dynamically based on the maximum class index
        model.adjust_fc(max_class_index + 1)  # Add 1 to account for zero-based indexing
        
        loss = criterion(outputs, data.y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * data.num_graphs
    
    epoch_loss = running_loss / len(loader.dataset)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

# After training, you can use the model for inference

In [None]:
(outputs.shape, data.y.shape)