In [2]:
import os
import random
import torch
import open3d as o3d
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import trimesh
import numpy as np
from itertools import product
from torch.utils.data import DataLoader 

def read_classification_file(filename):
    """Parse a ``.cla`` classification file into model/class pairs and class sizes.

    The first two lines of the file are discarded. After that, any line with
    fewer than three whitespace-separated tokens is skipped. A three-token line
    is treated as a class header ``<class_name> <ignored> <num_models>``, and
    the ``num_models`` lines that follow it are read (stripped) as the IDs
    belonging to that class.

    Args:
        filename: Path of the classification file to read.

    Returns:
        A tuple ``(modelclass, N)`` where ``modelclass`` is a list of
        ``(model_id, class_name)`` string pairs in file order, and ``N`` is a
        list of ``(class_name, num_models)`` pairs, one per class header.

    Raises:
        ValueError: If a header line has more than three tokens or its count
            is not an integer.
        IndexError: If the file ends before all announced IDs were read.
    """
    with open(filename, "r") as handle:
        # Drop the two leading header lines; only the class blocks matter.
        body = handle.readlines()[2:]

    modelclass = []  # (model_id, class_name) pairs
    N = []  # (class_name, num_models) pairs

    total = len(body)
    cursor = 0
    while cursor < total:
        tokens = body[cursor].strip().split()

        if len(tokens) >= 3:
            class_name, _, count_text = tokens
            count = int(count_text)

            # The IDs of this class sit on the `count` lines right after the header.
            first_id_line = cursor + 1
            ids = [body[first_id_line + offset].strip() for offset in range(count)]

            N.append((class_name, count))
            modelclass.extend((model_id, class_name) for model_id in ids)

            # Jump past the header and all of its ID lines.
            cursor = first_id_line + count
        else:
            # Blank line or anything that is not a class header.
            cursor += 1

    return modelclass, N




'''Note: the sketch IDs and the model IDs are not the same.'''

'''
temp1 = []
temp2 = []
for i, (sketch_id, sketch_class) in enumerate(m_s):
    for model_id, model_class in m:
        if model_class == sketch_class:
            temp1.append(model_id)
    break  # Breaks after first sketch class is processed

# Loop through 3D model classes
for i, (model_id, model_class) in enumerate(m):
    for sketch_id, sketch_class in m_s:
        if sketch_class == model_class:
            temp2.append(sketch_id)
    break  # Breaks after first model class is processed

print(temp1, temp2)

# for sketch_id in temp2:
#     sketch_id_padded = str(sketch_id).zfill(6)  # Convert "1" → "00001"
#     print(sketch_id_padded)
#     if sketch_id_padded not in temp1:
#         print("yooooooyouoyoyoy")
#         print(sketch_id)

'''

class ShapeData(Dataset):
    """Dataset of (sketch image, 3D point cloud, target) pairs.

    For every class present in both the sketch mapping and the model mapping
    (and having at least one model), each sketch yields:

    * one positive pair (target ``0``): the sketch with a random model of the
      same class;
    * one negative pair (target ``1``): the sketch with a random model of a
      different class, when another class with models exists.

    Pairs are drawn once, at construction time, with the global ``random``
    module. Classes are visited in sorted order, so seeding ``random`` before
    construction gives reproducible pairs.

    Args:
        sketch_dir: Root directory of the sketches; images are read from
            ``<sketch_dir>/<class_name>/<label>/<sketch_id>.png``.
        model_dir: Directory of the meshes; files are read from
            ``<model_dir>/M<model_id>.off``.
        sketch_file: Tuple ``(sketch_models, sketch_N)``. ``sketch_models``
            maps class name -> list of sketch IDs. ``sketch_N`` is only stored.
        model_file: Tuple ``(models_3d, N_3d)``. ``models_3d`` maps class
            name -> list of model IDs. ``N_3d`` is only stored.
        label: Sub-directory name of the sketch split (default ``"train"``).
        transform: Optional callable applied to the RGB sketch image.
        num_points: Number of points sampled from each mesh (default 500).
        init_factor: ``init_factor`` forwarded to Open3D's Poisson-disk
            sampling (default 5).
    """

    def __init__(self, sketch_dir, model_dir, sketch_file, model_file, label="train", transform=None,
                 num_points=500, init_factor=5):
        self.sketch_dir = sketch_dir
        self.model_dir = model_dir
        self.transform = transform
        self.sketch_models, self.sketch_N = sketch_file
        self.models_3d, self.N_3d = model_file
        self.label = label
        self.num_points = num_points
        self.init_factor = init_factor

        # Each entry is (sketch_id, model_id, sketch_class_name, target).
        self.pairs = []

        # Sort the shared classes: iterating a set of strings is not stable
        # across interpreter runs, which would defeat seeding `random`.
        all_classes = sorted(set(self.sketch_models.keys()) & set(self.models_3d.keys()))

        # Only classes that actually own models can supply a positive or a
        # negative model; computing this once avoids drawing an empty class.
        usable_classes = [c for c in all_classes if len(self.models_3d[c]) > 0]

        for class_name in usable_classes:
            sketch_ids = self.sketch_models[class_name]
            model_ids = self.models_3d[class_name]

            # Positive pairs (target = 0): same-class model.
            for sketch_id in sketch_ids:
                pos_ind = random.choice(model_ids)
                self.pairs.append((sketch_id, pos_ind, class_name, 0))

            # Negative pairs (target = 1): model from any other usable class.
            neg_classes = [c for c in usable_classes if c != class_name]
            if not neg_classes:
                # Only one class has models: no negative can be formed.
                continue
            for sketch_id in sketch_ids:
                neg_cls = random.choice(neg_classes)
                neg_ind = random.choice(self.models_3d[neg_cls])
                self.pairs.append((sketch_id, neg_ind, class_name, 1))

    def __len__(self):
        """Return the number of (sketch, model) pairs drawn at construction."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Load the sketch and sample a point cloud for pair ``index``.

        Returns:
            ``(sketch, pcd, target)`` where ``sketch`` is the RGB image (after
            ``transform`` if one was given), ``pcd`` is the Open3D point cloud
            sampled from the mesh, and ``target`` is ``0`` (same class) or
            ``1`` (different class). If the mesh has no vertices,
            ``(None, None, None)`` is returned instead; callers must filter
            such items.
        """
        sketch_id, model_id, class_name, target = self.pairs[index]
        sketch_path = os.path.join(self.sketch_dir, f"{class_name}/{self.label}/{sketch_id}.png")
        model_path = os.path.join(self.model_dir, f"M{model_id}.off")

        # `convert` returns a fully loaded copy, so the file handle can be
        # released as soon as the context manager exits.
        with Image.open(sketch_path) as raw_sketch:
            sketch = raw_sketch.convert("RGB")

        mesh = o3d.io.read_triangle_mesh(model_path)
        if len(mesh.vertices) == 0:
            # Missing/unreadable mesh: signal it instead of sampling nothing.
            return None, None, None

        pcd = mesh.sample_points_poisson_disk(
            number_of_points=self.num_points, init_factor=self.init_factor
        )

        if self.transform:
            sketch = self.transform(sketch)

        return sketch, pcd, target

if __name__ == "__main__":
    # --- Parse the classification files -----------------------------------
    # 3D models: `m` holds (model_id, class_name) pairs, `n` holds
    # (class_name, num_models) pairs.
    file = "/nlsasfs/home/neol/rushar/scripts/img_to_pcd/shrec_data/sketches/SHREC14LSSTB_SKETCHES/SHREC14_SBR_models_train.cla"
    m, n = read_classification_file(file)

    # Sketches: same layout as above (`file` is rebound on purpose).
    file = "/nlsasfs/home/neol/rushar/scripts/img_to_pcd/shrec_data/sketches/SHREC14LSSTB_SKETCHES/SHREC14_SBR_Sketch_Train.cla"
    m_s, n_s = read_classification_file(file)

    # Sizes, in order: sketch pairs, sketch classes, model pairs, model classes.
    for collection in (m_s, n_s, m, n):
        print(len(collection))

    # --- Group IDs by class name -------------------------------------------
    m_temp = {}
    for model_id, model_class in m:
        if model_class not in m_temp:
            m_temp[model_class] = []
        m_temp[model_class].append(model_id)

    m_s_temp = {}
    for sketch_id, sketch_class in m_s:
        if sketch_class not in m_s_temp:
            m_s_temp[sketch_class] = []
        m_s_temp[sketch_class].append(sketch_id)

    # --- Split every model class in two halves ------------------------------
    # The cut point is half of the count announced in the classification file
    # (integer division), so a class announcing a single model ends up with an
    # empty first half.
    n_dict = dict(n)
    m_tr = {}
    m_te = {}
    for i, class_ids in m_temp.items():
        cut = n_dict[i] // 2
        m_tr[i] = class_ids[:cut]
        m_te[i] = class_ids[cut:]

    print("lol")


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
8550
171
8987
171
lol


In [3]:
# Build the training dataset from the grouped sketch IDs and the first half
# of every model class.
dataset = ShapeData(
    sketch_dir="/nlsasfs/home/neol/rushar/scripts/img_to_pcd/shrec_data/sketches/sketches_unzp/SHREC14LSSTB_SKETCHES/",
    model_dir="/nlsasfs/home/neol/rushar/scripts/img_to_pcd/shrec_data/target3d/SHREC14LSSTB_TARGET_MODELS/",
    sketch_file=(m_s_temp, n_s),
    model_file=(m_tr, n),
    label='train',
    transform=None,  # You can add image transformations here
)

# Smoke test: show the first (sketch, point cloud, target) item, if any.
if len(dataset) > 0:
    print(dataset[0])


(<PIL.Image.Image image mode=RGB size=1111x1111 at 0x7F634F668D00>, PointCloud with 500 points., 0)
