In [None]:
%load_ext autoreload
%autoreload 2
import notebook_setup
from src.config import INTERIM_DATA_DIR, PROCESSED_DATA_DIR, REPORTS_DIR, EXTERNAL_DATA_DIR, MODELS_DIR
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"

In [None]:
import numpy as np
from scipy.spatial import KDTree
import matplotlib.pyplot as plt

# OCC
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.IFSelect import IFSelect_RetDone
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_IN
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
from OCC.Core.BRepClass import BRepClass_FaceClassifier
from OCC.Core.gp import gp_Pnt2d

def ensure_uv_2d(x):
    x = np.asarray(x, dtype=np.float32)
    if x.size == 0:
        return np.zeros((0, 2), dtype=np.float32)
    if x.ndim == 2 and x.shape[1] == 2:
        return x
    if x.ndim == 1 and x.shape[0] == 2:
        return x.reshape(1, 2)
    return x.reshape(-1, 2)

def sample_uv_extended(resolution, extend=0.1):
    u = np.linspace(-extend, 1 + extend, resolution)
    v = np.linspace(-extend, 1 + extend, resolution)
    uu, vv = np.meshgrid(u, v, indexing='ij')
    return np.stack([uu.flatten(), vv.flatten()], axis=-1)

def compute_sdf(inside_points, outside_points):
    inside = ensure_uv_2d(inside_points)
    outside = ensure_uv_2d(outside_points)

    if inside.shape[0] == 0 and outside.shape[0] == 0:
        return np.zeros((0, 2), dtype=np.float32), np.zeros((0,), dtype=np.float32)
    if inside.shape[0] == 0:
        return outside, -np.zeros((outside.shape[0],), dtype=np.float32)
    if outside.shape[0] == 0:
        return inside, np.zeros((inside.shape[0],), dtype=np.float32)

    inside_tree = KDTree(outside)
    outside_tree = KDTree(inside)
    d_inside, _ = inside_tree.query(inside)
    d_outside, _ = outside_tree.query(outside)

    sdf_inside = d_inside.astype(np.float32)
    sdf_outside = -d_outside.astype(np.float32)
    sdf_points = np.concatenate([inside, outside], axis=0)
    sdf_values = np.concatenate([sdf_inside, sdf_outside], axis=0)
    return ensure_uv_2d(sdf_points), sdf_values.astype(np.float32)

def bias_sample_sdf(sdf_points, sdf_values, n_samples, boundary_ratio=0.4):
    pts = ensure_uv_2d(sdf_points)
    vals = np.asarray(sdf_values, dtype=np.float32).reshape(-1)
    if pts.shape[0] == 0:
        return pts, vals
    idx = np.argsort(np.abs(vals))
    nb = int(n_samples * boundary_ratio)
    nb = max(0, min(nb, idx.size))
    i_boundary = idx[:nb]
    i_pool = idx[nb:]
    if i_pool.size:
        i_pool = np.random.permutation(i_pool)
    need_rand = max(0, n_samples - nb)
    i_sel = np.concatenate([i_boundary, i_pool[:min(need_rand, i_pool.size)]], axis=0)
    i_sel = i_sel.astype(int)
    return pts[i_sel], vals[i_sel]

# ---------- привязка к реальной грани ----------
def query_cad_kernel_face(face, uv_samples):
    """
    Делим UV на inside/outside относительно ТРИМОВ грани через CAD-ядро.
    """
    uv = ensure_uv_2d(uv_samples)
    surf = BRepAdaptor_Surface(face)
    u0, u1 = surf.FirstUParameter(), surf.LastUParameter()
    v0, v1 = surf.FirstVParameter(), surf.LastVParameter()
    clf = BRepClass_FaceClassifier()

    mask = []
    for uv_ in uv:
        uu = float(u0 + float(uv_[0]) * (u1 - u0))
        vv = float(v0 + float(uv_[1]) * (v1 - v0))
        p2d = gp_Pnt2d(uu, vv)
        clf.Perform(face, p2d, 1e-9)
        mask.append(clf.State() == TopAbs_IN)

    mask = np.array(mask, dtype=bool)
    inside  = uv[mask]
    outside = uv[~mask]
    return ensure_uv_2d(inside), ensure_uv_2d(outside)

def compute_xyz_from_uv_face(face, uv_coords):
    uv = ensure_uv_2d(uv_coords)
    if uv.shape[0] == 0:
        return np.zeros((0, 3), dtype=np.float32)
    surf = BRepAdaptor_Surface(face)
    u0, u1 = surf.FirstUParameter(), surf.LastUParameter()
    v0, v1 = surf.FirstVParameter(), surf.LastVParameter()
    uu = u0 + uv[:, 0] * (u1 - u0)
    vv = v0 + uv[:, 1] * (v1 - v0)
    out = np.zeros((uv.shape[0], 3), dtype=np.float32)
    for i in range(uv.shape[0]):
        p = surf.Value(float(uu[i]), float(vv[i]))
        out[i, 0] = p.X(); out[i, 1] = p.Y(); out[i, 2] = p.Z()
    return out

# ---------- загрузка STEP и обход граней ----------
def load_shape(step_path: str):
    r = STEPControl_Reader()
    assert r.ReadFile(step_path) == IFSelect_RetDone, "STEP read failed"
    r.TransferRoots()
    return r.OneShape()
from OCC.Extend import TopologyUtils
def iter_faces(shape):
    top_exp = TopologyUtils.TopologyExplorer(shape, ignore_orientation=True)
    for face in top_exp.faces():
        yield face

# ---------- быстрая проверка на 1-й грани ----------
def quick_sdf_check_all(step_path: str, res=128, extend=0.1, n_samples=500, show_first_k=2):
    shape = load_shape(step_path)
    faces = list(iter_faces(shape))
    assert len(faces) > 0, "Нет граней в модели"

    all_sdf_uv   = []   # список [ [n_i,2], ... ]
    all_sdf_vals = []   # список [ [n_i],   ... ]
    all_samp_uv  = []   # список [ [n_samples,2], ... ]
    all_samp_sdf = []   # список [ [n_samples],   ... ]
    all_targ_xyz = []   # список [ [n_samples,3], ... ]

    for i, face in enumerate(faces):
        if i >= 3:
            break
        uv = sample_uv_extended(resolution=res, extend=extend)            # [M,2] torch
        inside, outside = query_cad_kernel_face(face, uv)                 # [*,2]
        sdf_pts, sdf_vals = compute_sdf(inside, outside)                  # [M',2],[M']
        samp_uv, samp_sdf = bias_sample_sdf(sdf_pts, sdf_vals,
                                            n_samples=n_samples, boundary_ratio=0.4)
        targ_xyz = compute_xyz_from_uv_face(face, samp_uv)                # [n_samples,3]

        all_sdf_uv.append(sdf_pts)
        all_sdf_vals.append(sdf_vals)
        all_samp_uv.append(samp_uv)
        all_samp_sdf.append(samp_sdf)
        all_targ_xyz.append(targ_xyz)

        # необязательная визуализация для первых k граней
        # if i < show_first_k:
        fig, ax = plt.subplots(1, 2, figsize=(10,4))
        sc = ax[0].scatter(sdf_pts[:,0], sdf_pts[:,1],
                            c=sdf_vals, s=4, cmap="coolwarm")
        fig.colorbar(sc, ax=ax[0], label="SDF (UV)")
        ax[0].set_title(f"Face {i}: SDF на UV-сетке")
        ax[0].set_xlabel("u"); ax[0].set_ylabel("v")

        sc2 = ax[1].scatter(samp_uv[:,0], samp_uv[:,1],
                                c=samp_sdf, s=8, cmap="coolwarm")
        fig.colorbar(sc2, ax=ax[1], label="SDF (выборка у границы)")
        ax[1].set_title(f"Face {i}: выборка у границы")
        ax[1].set_xlabel("u"); ax[1].set_ylabel("v")
        plt.tight_layout(); plt.show()

    # сводка
    total_points = sum(v for v in all_sdf_vals)
    print(f"Граней: {len(faces)}; всего SDF-точек на сетках: {total_points}; "
          f"выборок у границы на грань: {n_samples}")

    return {
        "faces_count": len(faces),
        "sdf_grid_uv_list": all_sdf_uv,        # список torch тензоров [n_i,2]
        "sdf_grid_vals_list": all_sdf_vals,    # список torch тензоров [n_i]
        "sampled_uv_list": all_samp_uv,        # список torch тензоров [n_samples,2]
        "sampled_sdf_list": all_samp_sdf,      # список torch тензоров [n_samples]
        "target_xyz_list": all_targ_xyz        # список torch тензоров [n_samples,3]
    }

STEPS_DIR = PROCESSED_DATA_DIR / "dataset_129" / "stp"

step_files = stems = {p for p in STEPS_DIR.glob("*.stp")}

stp = list(step_files)[15]
print(f"Using STEP file: {stp}")
out = quick_sdf_check_all(str(stp), res=128, extend=0.1, n_samples=500)



In [None]:
from src.modeling.SSL_BrepNet import extract_features

FEATURES_LIST_PATH = EXTERNAL_DATA_DIR / "feature_lists" / "all.json"
STEPS_DIR = PROCESSED_DATA_DIR / "dataset_129" / "stp"
extract_features.run(
    step_path_dir=STEPS_DIR,
    feature_list_path=FEATURES_LIST_PATH,
    num_workers=0,
    force_regeneration=True
)

In [None]:
from src.modeling.SSL_BrepNet import build_dataset_file
BREPNET_NPZ_DIR = PROCESSED_DATA_DIR / "dataset_129" / "features" / "brep"
STATS_BREPNET = PROCESSED_DATA_DIR / "dataset_129" / "dataset_brepnet_stats.json"
os.makedirs(BREPNET_NPZ_DIR, exist_ok=True)
build_dataset_file.run(
    brepnet_dir=BREPNET_NPZ_DIR,
    output_file=STATS_BREPNET,
    validation_split=0.1,
    test_split=0.1,
    random_seed=42,
)

In [None]:
BREPNET_NPZ_DIR = PROCESSED_DATA_DIR / "dataset_129" / "features" / "brep"
import numpy as np
npz_files = list(BREPNET_NPZ_DIR.glob("*.npz"))
print(f"Всего .npz файлов: {len(npz_files)}")

with np.load(npz_files[0]) as data:
    for k, v in data.items():
        print(f"{k}: {v.shape}, dtype={v.dtype}, min={v.min() if v.size>0 else 'N/A'}, max={v.max() if v.size>0 else 'N/A'}")

# BREPNET_NPZ_DIR = PROCESSED_DATA_DIR / "dataset_129" / "features" / "all_sdf_with_normals"
# npz_files = list(BREPNET_NPZ_DIR.glob("*.npz"))
# print(f"Всего .npz файлов: {len(npz_files)}")

# with np.load(npz_files[0]) as data:
#     for k, v in data.items():
#         print(f"{k}: {v.shape}, dtype={v.dtype}, min={v.min() if v.size>0 else 'N/A'}, max={v.max() if v.size>0 else 'N/A'}")

# os.makedirs(BREPNET_NPZ_DIR, exist_ok=True)

# for npz_path in npz_files:
#     with np.load(npz_path) as data:
#         uv_faces = data['uv_faces']          # [n,2]
#         sdf_faces = data['sdf_faces']    # [n]
#     np.savez_compressed(BREPNET_NPZ_DIR / npz_path.name, uv_faces=uv_faces, sdf_faces=sdf_faces)
#     print(f"Обработан файл {npz_path}, добавлены нормали.")

In [None]:
import numpy as np
import torch
from types import SimpleNamespace
from src.modeling.SSL_BrepNet.model.encoder import CustomBRepEncoder
from src.modeling.SSL_BrepNet.model.decoder import ConditionalDecoder

BREPNET_NPZ_DIR = PROCESSED_DATA_DIR / "dataset_129" / "features" / "brep"
SDF_NPZ_DIR = PROCESSED_DATA_DIR / "dataset_129" / "features" / "all_sdf_with_normals"
npz_brep_files = list(BREPNET_NPZ_DIR.glob("*.npz"))
npz_sdf_files = list(SDF_NPZ_DIR.glob("*.npz"))



D = np.load(npz_files[0])
SDF =np.load(npz_files[0])

data = SimpleNamespace(
    vertices=torch.from_numpy(D["vertex"].astype(np.float32)),         # [n_v, 3]
    edges=torch.from_numpy(D["edge_features"].astype(np.float32)),     # [n_e, edge_feat_dim]
    faces=torch.from_numpy(D["face_features"].astype(np.float32)),     # [n_f, face_feat_dim]
    edge_to_vertex=torch.from_numpy(D["edge_to_vertex"].astype(np.int64)), # [2, n_e]
    face_to_edge=torch.from_numpy(D["face_to_edge"][::-1].astype(np.int64)),   # [2, n_f]
    face_to_face=torch.from_numpy(D["face_to_face"].astype(np.int64)),     # [2, M]
)


encoder = CustomBRepEncoder(
    v_in_width=data.vertices.shape[1],             
    e_in_width=data.edges.shape[1], 
    f_in_width=data.faces.shape[1], 
    out_width=64,
    num_layers=2,
    use_attention=True
).eval()

# проброс признаков
with torch.no_grad():
    emb = encoder(data)

print("OK. encoder output shape:", tuple(emb.shape) if isinstance(emb, torch.Tensor) else type(emb))

decoder = ConditionalDecoder(latent_size=64, hidden_dims=[1024, 1024, 1024, 1024])

data_sdf = np.load(npz_sdf_files[0])
# uv_faces: (9, 500, 2), dtype=float32, min=-0.10000000149011612, max=1.100000023841858
# sdf_faces: (9, 500), dtype=float32, min=-0.39380156993865967, max=0.4913385510444641

sdf_uv = torch.from_numpy(data_sdf["uv_faces"].astype(np.float32))        # [n_faces, n_samples, 2]
sdf_vals = torch.from_numpy(data_sdf["sdf_faces"].astype(np.float32))      # [n_faces, n_samples]
print(f"emb shape: {emb.shape}, sdf_uv shape: {sdf_uv.shape}, sdf_vals shape: {sdf_vals.shape}")
with torch.no_grad():
    pred_sdf = decoder(sdf_uv[0], emb[0])   # [n_faces, n_samples]

print("OK. decoder output shape:", tuple(pred_sdf.shape) if isinstance(pred_sdf, torch.Tensor) else type(pred_sdf))


In [None]:
import numpy as np
import torch

BREPNET_NPZ_DIR = PROCESSED_DATA_DIR / "dataset_129" / "features" / "brep"
SDF_NPZ_DIR = PROCESSED_DATA_DIR / "dataset_129" / "features" / "all_sdf_with_normals"
DT_PATH = PROCESSED_DATA_DIR / "dataset_129" / "preprocessed_data.pt"

brep_files = {p.name: p for p in BREPNET_NPZ_DIR.glob("*.npz")}
sdf_files = {p.name: p for p in SDF_NPZ_DIR.glob("*.npz")}

common_names = sorted(set(brep_files.keys()) & set(sdf_files.keys()))

preprocessed_data = []
for fname in common_names:
    D = np.load(brep_files[fname])
    S = np.load(sdf_files[fname])

    # Признаки для encoder
    data = SimpleNamespace(
        vertices=torch.from_numpy(D["vertex"].astype(np.float32)),
        edges=torch.from_numpy(D["edge_features"].astype(np.float32)),
        faces=torch.from_numpy(D["face_features"].astype(np.float32)),
        edge_to_vertex=torch.from_numpy(D["edge_to_vertex"].astype(np.int64)),
        face_to_edge=torch.from_numpy(D["face_to_edge"][::-1].astype(np.int64)),
        face_to_face=torch.from_numpy(D["face_to_face"].astype(np.int64)),
    )
    # Получаем эмбеддинг
    with torch.no_grad():
        emb = encoder(data)
        if emb.ndim == 2:
            emb = emb.mean(dim=0)

    # SDF-выборки
    sampled_points = torch.from_numpy(S["uv_faces"].astype(np.float32))  # [n, 2]
    sampled_sdf = torch.from_numpy(S["sdf_faces"].astype(np.float32))    # [n]

    preprocessed_data.append((fname, emb, sampled_points, sampled_sdf))

# Сохраняем датасет для обучения
torch.save(preprocessed_data, DT_PATH)
print(f"Датасет собран: {len(preprocessed_data)} моделей")

In [None]:
from src.modeling.SSL_BrepNet.dataset import BrepNetDataset
STATS_BREPNET = PROCESSED_DATA_DIR / "dataset_129" / "dataset_brepnet_stats.json"
SDF_NPZ_DIR = PROCESSED_DATA_DIR / "dataset_129" / "features" / "all_sdf_with_normals"
BREP_NPZ_DIR = PROCESSED_DATA_DIR / "dataset_129" / "features" / "brep"

train_dataset = BrepNetDataset(STATS_BREPNET, BREP_NPZ_DIR, SDF_NPZ_DIR, split="training_set")
val_dataset = BrepNetDataset(STATS_BREPNET, BREP_NPZ_DIR, SDF_NPZ_DIR, split="validation_set")
test_dataset = BrepNetDataset(STATS_BREPNET, BREP_NPZ_DIR, SDF_NPZ_DIR, split="test_set")

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}, Test samples: {len(test_dataset)}")
print(f"Example data keys: {list(vars(train_dataset[0]).keys())}")


In [None]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from src.modeling.SSL_BrepNet.sdf_compute import SDFComputer


class BRepAutoEncoderModule(pl.LightningModule):
    def __init__(self, encoder, decoder, lr=1e-3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr
        self.save_hyperparameters(ignore=['encoder', 'decoder'])

    def compute_loss(self, predicted, target_xyz, target_sdf):
        pred_xyz, pred_sdf = predicted[:, :3], predicted[:, 3]
        xyz_loss = torch.nn.functional.mse_loss(pred_xyz, target_xyz)
        sdf_loss = torch.nn.functional.mse_loss(pred_sdf, target_sdf)
        return xyz_loss + sdf_loss

    def training_step(self, batch, batch_idx):
        device = self.device if hasattr(self, "device") else self.encoder.parameters().__next__().device
        for key, value in vars(batch).items():
            if isinstance(value, torch.Tensor):
                setattr(batch, key, value.to(device))

        data = batch
        embedding = self.encoder(data)
        sampled_points = data.sdf_uv.squeeze(0)
        sampled_sdf = data.sdf_vals.squeeze(0)

        total_loss = 0.0
        total_sdf_loss = 0.0
        total_xyz_loss = 0.0

        for i in range(sampled_points.shape[0]):
            uv = sampled_points[i]
            sdf = sampled_sdf[i]
            emb_face = embedding[i]
            # Исправление: добавляем batch-ось если нужно
            if uv.ndim == 1:
                uv = uv.unsqueeze(0)
            predicted = self.decoder(uv, emb_face)
            target_xyz = self.compute_xyz_from_uv(uv).float()
            pred_xyz, pred_sdf = predicted[:, :3], predicted[:, 3]

            print("uv:", uv)
            print("sdf:", sdf)
            print("emb_face:", emb_face)
            print("predicted:", predicted)
            print("target_xyz:", target_xyz)

            if torch.isnan(pred_xyz).any() or torch.isnan(pred_sdf).any() or torch.isnan(target_xyz).any() or torch.isnan(sdf).any():
                print("NaN detected in prediction or target!")
            if torch.isinf(pred_xyz).any() or torch.isinf(pred_sdf).any() or torch.isinf(target_xyz).any() or torch.isinf(sdf).any():
                print("Inf detected in prediction or target!")

            xyz_loss = torch.nn.functional.mse_loss(pred_xyz, target_xyz)
            sdf_loss = torch.nn.functional.mse_loss(pred_sdf, sdf)
            total_loss += xyz_loss + sdf_loss
            total_xyz_loss += xyz_loss
            total_sdf_loss += sdf_loss

        total_loss = total_loss / sampled_points.shape[0]
        total_xyz_loss = total_xyz_loss / sampled_points.shape[0]
        total_sdf_loss = total_sdf_loss / sampled_points.shape[0]

        self.log('train_loss', total_loss, batch_size=1)
        self.log('train_xyz_loss', total_xyz_loss, batch_size=1)
        self.log('train_sdf_loss', total_sdf_loss, batch_size=1)
        return total_loss
    

    def validation_step(self, batch, batch_idx):
        device = self.device if hasattr(self, "device") else self.encoder.parameters().__next__().device
        for key, value in vars(batch).items():
            if isinstance(value, torch.Tensor):
                setattr(batch, key, value.to(device))

        data = batch
        embedding = self.encoder(data)
        sampled_points = data.sdf_uv.squeeze(0)
        sampled_sdf = data.sdf_vals.squeeze(0)

        total_loss = 0.0
        total_sdf_loss = 0.0
        total_xyz_loss = 0.0

        for i in range(sampled_points.shape[0]):
            uv = sampled_points[i]
            sdf = sampled_sdf[i]
            emb_face = embedding[i]
            if uv.ndim == 1:
                uv = uv.unsqueeze(0)
            predicted = self.decoder(uv, emb_face)
            target_xyz = self.compute_xyz_from_uv(uv).float()
            pred_xyz, pred_sdf = predicted[:, :3], predicted[:, 3]
            xyz_loss = torch.nn.functional.mse_loss(pred_xyz, target_xyz)
            sdf_loss = torch.nn.functional.mse_loss(pred_sdf, sdf)
            total_loss += xyz_loss + sdf_loss
            total_xyz_loss += xyz_loss
            total_sdf_loss += sdf_loss

        total_loss = total_loss / sampled_points.shape[0]
        total_xyz_loss = total_xyz_loss / sampled_points.shape[0]
        total_sdf_loss = total_sdf_loss / sampled_points.shape[0]

        self.log('val_loss', total_loss, prog_bar=True, batch_size=1)
        self.log('val_xyz_loss', total_xyz_loss, prog_bar=True, batch_size=1)
        self.log('val_sdf_loss', total_sdf_loss, prog_bar=True, batch_size=1)
        return total_loss

    def compute_xyz_from_uv(self, uv_coords):
        """ Простейшая проекция UV в 3D пространство (z=0).
            спиздили у китайцев
        """
        x = uv_coords[:, 0]  # x координата
        y = uv_coords[:, 1]  # y координата
        z = torch.zeros_like(x)  # z координата

        return torch.stack([x, y, z], dim=-1)

    def configure_optimizers(self):
        return torch.optim.Adam(list(self.encoder.parameters()) + list(self.decoder.parameters()), lr=self.lr)
    

In [None]:
from src.modeling.SSL_BrepNet.model.encoder import CustomBRepEncoder
from src.modeling.SSL_BrepNet.model.decoder import ConditionalDecoder
from pytorch_lightning.loggers import CSVLogger

encoder = CustomBRepEncoder(
    v_in_width=train_dataset[0].vertices.shape[1],
    e_in_width=train_dataset[0].edges.shape[1],
    f_in_width=train_dataset[0].faces.shape[1],
    out_width=64,
    num_layers=2,
    use_attention=True
)
decoder = ConditionalDecoder(latent_size=64, hidden_dims=[1024, 1024, 1024, 1024])
module = BRepAutoEncoderModule(encoder, decoder)

def brepnet_collate(batch):
    out = {}
    for key in vars(batch[0]).keys():
        values = [getattr(b, key) for b in batch]
        if key in ["edge_to_vertex", "face_to_edge", "face_to_face"]:
            out[key] = values[0]
        elif isinstance(values[0], torch.Tensor):
            # Если размерности совпадают — stack, иначе — список
            shapes = [v.shape for v in values]
            if all(s == shapes[0] for s in shapes):
                stacked = torch.stack(values)
                if stacked.shape[0] == 1 and key in ["vertices", "edges", "faces"]:
                    out[key] = stacked.squeeze(0)
                else:
                    out[key] = stacked
            else:
                out[key] = values  # список тензоров разной длины
        elif isinstance(values[0], np.ndarray):
            shapes = [v.shape for v in values]
            if all(s == shapes[0] for s in shapes):
                out[key] = np.stack(values)
            else:
                out[key] = values
        else:
            out[key] = values
    return SimpleNamespace(**out)

val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=brepnet_collate)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=brepnet_collate)
csv_logger = CSVLogger(save_dir=REPORTS_DIR, name="ssl_autoencoder_logs")
    
trainer = pl.Trainer(max_epochs=1, logger=[csv_logger])
trainer.fit(module, train_loader, val_loader)

In [49]:
import torch
from src.modeling.SSL_BrepNet.model.encoder import CustomBRepEncoder
from src.modeling.SSL_BrepNet.model.decoder import ConditionalDecoder

# Пути к сохранённым весам
ENCODER_PATH = MODELS_DIR / "epoch=4-step=520.ckpt"

ckpt = torch.load(ENCODER_PATH, map_location="cpu")
encoder_state_dict = ckpt["state_dict"]
# Оставляем только ключи, относящиеся к encoder
encoder_state_dict = {k.replace("encoder.", ""): v for k, v in encoder_state_dict.items() if k.startswith("encoder.")}

encoder = CustomBRepEncoder(
    v_in_width=train_dataset[0].vertices.shape[1],
    e_in_width=train_dataset[0].edges.shape[1],
    f_in_width=train_dataset[0].faces.shape[1],
    out_width=64,
    num_layers=2,
    use_attention=True
)
encoder.load_state_dict(encoder_state_dict)
encoder.eval()

# decoder = ConditionalDecoder(latent_size=64, hidden_dims=[1024, 1024, 1024, 1024])
# decoder.load_state_dict(torch.load(DECODER_PATH, map_location="cpu"))
# decoder.eval()

# Поиск: прогон нового образца
def search(sample):
    with torch.no_grad():
        emb = encoder(sample)
        results = []
        for i in range(sample.sdf_uv.shape[0]):
            uv = sample.sdf_uv[i]
            emb_face = emb[i]
            pred = decoder(uv, emb_face)
            results.append(pred)
        return results

# Пример использования
sample = train_dataset[0]
results = search(sample)
print("Search results:", results)

Search results: [tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]]), tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]]), tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]]), tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]]), tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        ...,
        [nan, nan, nan, nan],
        [nan, nan, nan, nan],
        [nan, nan, nan, nan]])]
