# GNN Model

In [None]:
# Imports
import io
import sys
import toml
import pprint

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from jarvis.db.figshare import data


import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import importlib

## Get Model Configuration

In [None]:
# Configuration: load the project settings from the TOML file.
CONFIG_PATH = "config.toml"
config = toml.load(CONFIG_PATH)

# Pretty print configuration so the exact settings of this run are visible in the output.
print("Project Configuration:")
pprint.pprint(config)

# Set up system path so the project-local helper modules imported below can be found.
# Falls back to the current directory when the config has no [system].path entry.
SYS_PATH = config.get('system', {}).get('path', './')
if SYS_PATH not in sys.path:  # keep the cell idempotent: re-running must not stack duplicates
    sys.path.append(SYS_PATH)  # .../code/jarvis/


Project Configuration:
{'data': {'dataset_name': 'dft_3d', 'store_dir': '/shared/data/jarvis'},
 'features': {'bag_of_elements': True,
              'derived': ['eps_mean', 'eps_std'],
              'use_columns': ['ehull',
                              'formation_energy_peratom',
                              'avg_elec_mass',
                              'avg_hole_mass',
                              'effective_masses_297K',
                              'epsx',
                              'epsy',
                              'epsz',
                              'natoms']},
 'filters': {'bandgap_column': 'optb88vdw_bandgap',
             'max_eps': 10.0,
             'min_eps': 1.0,
             'semiconductor_max': 4.0,
             'semiconductor_min': 0.5,
             'toxic_elements': ['Pb', 'Cd', 'Hg', 'As', 'Se'],
             'transparent_min': 3.0},
 'known': {'transparent_formulas': ['In2O3',
                                    'ZnO',
                                   

## Setup Logger and Import Data

In [None]:
# Custom Imports and Configurations
from jarvis_utils import load_or_fetch_dataset
from logger_utils import setup_logger, flush_logger
from filter_utils import apply_filters
from featurizer import Featurizer

# Setup logger (configured from the project configuration)
logger = setup_logger(config)

# Record which dataset this run uses and where it is stored.
logger.info("Project configuration loaded.")
logger.info(f"Dataset: {config['data']['dataset_name']}")
logger.info(f"Store directory: {config['data']['store_dir']}")

# Load dataset: the helper receives the dataset name, the JARVIS figshare `data`
# function imported above, and the store directory.
dataset_name = config["data"]["dataset_name"]
store_dir = config["data"]["store_dir"]
df = load_or_fetch_dataset(dataset_name, data, store_dir)
logger.info(f"Dataset shape: {df.shape}")

2025-12-03 21:16:28,184 - jarvis_project - INFO - Project configuration loaded.
2025-12-03 21:16:28,184 - jarvis_project - INFO - Dataset: dft_3d
2025-12-03 21:16:28,185 - jarvis_project - INFO - Store directory: /shared/data/jarvis
2025-12-03 21:16:29,391 - jarvis_project - INFO - Dataset shape: (75993, 64)


Dataset shape: (75993, 64)


In [None]:
# Snapshot the raw column names and write them to the log for reference.
features = list(df.columns)
logger.info(f"Features: {features}")

2025-12-03 21:16:29,400 - jarvis_project - INFO - Features: ['jid', 'spg_number', 'spg_symbol', 'formula', 'formation_energy_peratom', 'func', 'optb88vdw_bandgap', 'atoms', 'slme', 'magmom_oszicar', 'spillage', 'elastic_tensor', 'effective_masses_300K', 'kpoint_length_unit', 'maxdiff_mesh', 'maxdiff_bz', 'encut', 'optb88vdw_total_energy', 'epsx', 'epsy', 'epsz', 'mepsx', 'mepsy', 'mepsz', 'modes', 'magmom_outcar', 'max_efg', 'avg_elec_mass', 'avg_hole_mass', 'icsd', 'dfpt_piezo_max_eij', 'dfpt_piezo_max_dij', 'dfpt_piezo_max_dielectric', 'dfpt_piezo_max_dielectric_electronic', 'dfpt_piezo_max_dielectric_ionic', 'max_ir_mode', 'min_ir_mode', 'n-Seebeck', 'p-Seebeck', 'n-powerfact', 'p-powerfact', 'ncond', 'pcond', 'nkappa', 'pkappa', 'ehull', 'Tc_supercon', 'dimensionality', 'efg', 'xml_data_link', 'typ', 'exfoliation_energy', 'spg', 'crys', 'density', 'poisson', 'raw_files', 'nat', 'bulk_modulus_kv', 'shear_modulus_gv', 'mbj_bandgap', 'hse_gap', 'reference', 'search']


## Configuration and feature schemas

In [None]:
# Numeric columns passed straight to the fusion MLP; their count sets part of its input width.
# NOTE(review): "optb88vdw_bandgap" and "ehull" are also the inputs from which
# add_candidate_column derives the label, so using them as features likely leaks the
# target — confirm this is intended.
num_cols = [
    "formation_energy_peratom", "optb88vdw_bandgap",
    "slme", "magmom_oszicar", "spillage", "kpoint_length_unit",
    "optb88vdw_total_energy", "epsx", "epsy", "epsz", "density",
    "poisson", "nat", "bulk_modulus_kv", "shear_modulus_gv",
    "mbj_bandgap", "hse_gap", "ehull", "Tc_supercon",
    # add other numeric features as needed
]

# Configuration for feature dimensions and embeddings
model_config = {
    "devices": {"model": "cuda"},
    "threshold": 0.94,  # deployment cutoff, if you want to keep it here

    "numeric_cols": num_cols,  # list of numeric feature column names

    # Categorical vocab sizes (from our dataset preprocessing)
    "categorical": {
        "spg_number": {"vocab_size": 214, "embed_dim": 50},
        "func": {"vocab_size": 2, "embed_dim": 1},
        "dimensionality": {"vocab_size": 8, "embed_dim": 4},
        "typ": {"vocab_size": 2, "embed_dim": 1},
        "crys": {"vocab_size": 8, "embed_dim": 4},
    },

    # Formula embedding
    "formula": {"num_elements": 89, "embed_dim": 32},

    # MLP head sizes
    "mlp": {"hidden1": 128, "hidden2": 64, "dropout": 0.3},

    # Graph settings
    "graph": {
        "atom_embed_dim": 64,   # node embedding size
        "edge_feat_dim": 4,     # e.g., distance, 1/r, angle or dummy padding
        "radius": 5.0,          # neighbor cutoff (Å)
        "max_neighbors": 24,    # cap neighbors per atom (optional)
    },
}

# Merge the model settings into the project configuration loaded from TOML instead of
# replacing it, so sections read by earlier cells (config["data"], config["filters"], ...)
# stay available when those cells are re-run. On a key clash the model settings win,
# which matches what this cell produced before for every key it defines.
config = {**config, **model_config}


## Add candidate label (target column)

In [None]:
def add_candidate_column(df: pd.DataFrame, config: dict) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a binary ``is_candidate`` column.

    A row is flagged 1 when all of the following hold:
    - its bandgap lies in [semiconductor_min, semiconductor_max] (both ends inclusive),
    - its bandgap is strictly greater than transparent_min,
    - its ``ehull`` is below the stability threshold (skipped if the column is absent),
    - its formula contains none of the toxic elements (skipped if the column is absent).

    Args:
        df: Input table; must contain the configured bandgap column.
        config: Mapping with a "filters" section providing "bandgap_column",
            "semiconductor_min", "semiconductor_max", "transparent_min" and
            "toxic_elements". The optional "ehull_max" key sets the stability
            threshold and defaults to 0.1.

    Returns:
        A new DataFrame. The bandgap and ``ehull`` columns are converted to numeric,
        with unparseable values becoming NaN (such rows are never candidates).

    Raises:
        KeyError: If a required filter key or the bandgap column is missing.
    """
    filters = config["filters"]
    bandgap_col = filters["bandgap_column"]
    sem_min = filters["semiconductor_min"]
    sem_max = filters["semiconductor_max"]
    trans_min = filters["transparent_min"]
    toxic_elements = set(filters["toxic_elements"])
    ehull_max = filters.get("ehull_max", 0.1)  # previously a hard-coded constant

    df = df.copy()

    df[bandgap_col] = pd.to_numeric(df[bandgap_col], errors="coerce")
    in_semiconductor_range = df[bandgap_col].between(sem_min, sem_max)
    is_transparent = df[bandgap_col] > trans_min

    # Missing optional columns fall back to all-True / all-False Series (not Python
    # scalars) so the boolean algebra below behaves the same in every branch.
    if "ehull" in df.columns:
        df["ehull"] = pd.to_numeric(df["ehull"], errors="coerce")
        is_stable = df["ehull"] < ehull_max
    else:
        is_stable = pd.Series(True, index=df.index)

    if "formula" in df.columns:
        # Split each formula into element symbols (capital letter + optional lowercase).
        tokens = df["formula"].fillna("").astype(str).str.findall(r"[A-Z][a-z]?")
        has_toxic = tokens.apply(lambda symbols: not toxic_elements.isdisjoint(symbols))
    else:
        has_toxic = pd.Series(False, index=df.index)
    has_toxic = has_toxic.astype(bool)  # an empty frame would otherwise yield object dtype

    df["is_candidate"] = (
        in_semiconductor_range & is_transparent & is_stable & ~has_toxic
    ).astype(int)

    return df

# Thresholds used to label candidate rows for this model.
# NOTE(review): these differ from the [filters] values printed from config.toml earlier
# (semiconductor_max 5.0 vs 4.0, transparent_min 2.5 vs 3.0, no "Se" among the toxic
# elements) — confirm the difference is intended.
config_cand = {
    "filters": {
        "bandgap_column": "optb88vdw_bandgap",
        "semiconductor_min": 0.5,
        "semiconductor_max": 5.0,
        "transparent_min": 2.5,
        "toxic_elements": ["Pb", "Cd", "As", "Hg"]
    }
}

# Label every row and expose the flag under the name the training code reads ("target").
# A "target" column left over from a previous run of this cell is dropped first;
# otherwise the rename would create a second column with the same name and
# `df.target` would stop being a single Series.
df = (
    add_candidate_column(df.drop(columns=["target"], errors="ignore"), config_cand)
    .rename(columns={"is_candidate": "target"})
)

## Model components and atomic graph construction from atoms_obj

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE(review): a class with this name is defined again in a later cell; when the
# notebook runs top to bottom the later definition is the one in effect.
class FormulaEmbedder(nn.Module):
    """
    Composition-level embedding: the count-weighted sum of learned per-element vectors.

    Args:
        num_elements: Number of rows in the element embedding table.
        embed_dim: Size of each element vector (and of the returned embedding).
    """
    def __init__(self, num_elements, embed_dim):
        super().__init__()
        self.emb = nn.Embedding(num_elements, embed_dim)

    def forward(self, element_counts):
        """
        Embed one formula or a batch of formulas.

        Args:
            element_counts: Tensor of per-element counts with the element axis last,
                i.e. shape [E] for one formula or [..., E] for a batch. Position k is
                weighted against embedding row k, so E must not exceed ``num_elements``.

        Returns:
            Tensor of shape [D] for a 1-D input, or [..., D] for batched input.
        """
        # Look up the first E embedding rows, one per count position.
        idx = torch.arange(element_counts.shape[-1], device=element_counts.device)
        weights = element_counts.float().unsqueeze(-1)       # [..., E, 1]
        vecs = self.emb(idx)                                 # [E, D]
        # Weighted sum over the element axis. Reducing dim=-2 (rather than dim=0) gives
        # the same result for a 1-D input and also keeps batched input per-sample.
        return (vecs * weights).sum(dim=-2)                  # [..., D]


# NOTE(review): a class with this name is defined again in a later cell; when the
# notebook runs top to bottom the later definition is the one in effect.
class SimpleGraphEncoder(nn.Module):
    """
    Small message-passing encoder that turns one atomic graph into a single vector.

    Nodes start from a learned embedding of their element index. Two rounds of
    message passing follow, each building messages from the sender state joined with
    the edge features, summing them per receiving node, and updating the node state.
    The node states are finally mean-pooled.
    """
    def __init__(self, num_elements, atom_embed_dim, edge_feat_dim):
        super().__init__()
        self.atom_emb = nn.Embedding(num_elements, atom_embed_dim)

        # Messages see the sender state concatenated with the edge features.
        message_in_dim = atom_embed_dim + edge_feat_dim
        self.msg1 = nn.Linear(message_in_dim, atom_embed_dim)
        self.msg2 = nn.Linear(message_in_dim, atom_embed_dim)
        self.node_up1 = nn.Linear(atom_embed_dim, atom_embed_dim)
        self.node_up2 = nn.Linear(atom_embed_dim, atom_embed_dim)

    def message_pass(self, x, edge_index, edge_attr):
        """Run both message-passing rounds; node states pass through unchanged if there are no edges."""
        if edge_index.numel() == 0:
            return x
        senders, receivers = edge_index
        rounds = ((self.msg1, self.node_up1), (self.msg2, self.node_up2))
        for message_layer, update_layer in rounds:
            edge_inputs = torch.cat([x[senders], edge_attr], dim=-1)   # [E, D+F]
            messages = F.relu(message_layer(edge_inputs))              # [E, D]
            # Sum the incoming messages of every receiving node.
            inbox = torch.zeros_like(x).index_add_(0, receivers, messages)
            x = F.relu(update_layer(x + inbox))
        return x

    def forward(self, graph):
        """
        Encode a graph dict with keys "x" (element indices), "edge_index" and "edge_attr".

        Returns the mean of the final node states, shape [D]; a graph without nodes
        yields a zero vector of the same size.
        """
        node_states = self.message_pass(
            self.atom_emb(graph["x"]), graph["edge_index"], graph["edge_attr"]
        )
        if node_states.shape[0] > 0:
            return node_states.mean(dim=0)
        return torch.zeros(self.atom_emb.embedding_dim, device=node_states.device)


# NOTE(review): a class with this name is defined again in a later cell; when the
# notebook runs top to bottom the later definition is the one in effect.
class CategoricalEmbeddings(nn.Module):
    """
    One embedding table per categorical feature, concatenated along the last axis.

    Args:
        cat_cfg: Mapping of feature name to a dict with "vocab_size" and "embed_dim".
    """
    def __init__(self, cat_cfg):
        super().__init__()
        tables = {}
        for name, params in cat_cfg.items():
            tables[name] = nn.Embedding(params["vocab_size"], params["embed_dim"])
        self.embeddings = nn.ModuleDict(tables)

    def forward(self, X_cat_dict):
        """
        Embed every configured feature and join the results.

        Args:
            X_cat_dict: Mapping of feature name to a Long tensor of category indices.

        Returns:
            The per-feature embeddings concatenated on the last axis. An embedding
            that comes back with fewer than two dimensions gets a leading axis of
            size 1 first, so the result is always at least 2-D.
        """
        pieces = []
        for name, table in self.embeddings.items():
            vec = table(X_cat_dict[name])
            if vec.dim() <= 1:
                vec = vec.unsqueeze(0)
            pieces.append(vec)
        return torch.cat(pieces, dim=-1)


# NOTE(review): a class with this name is defined again in a later cell; when the
# notebook runs top to bottom the later definition is the one in effect.
class CandidateNetMultimodal(nn.Module):
    """
    Fusion classifier that emits one logit per sample from four inputs: raw numeric
    features, categorical embeddings, a formula embedding and a pooled graph embedding.

    Args:
        config: Mapping with the sections "numeric_cols", "categorical", "formula"
            ("num_elements", "embed_dim"), "graph" ("atom_embed_dim", "edge_feat_dim")
            and "mlp" ("hidden1", "hidden2", "dropout").

    Note:
        The MLP contains ``nn.BatchNorm1d``, which PyTorch refuses to run on a batch
        of one row while the module is in training mode. Train with batches of at
        least two samples, or switch to ``eval()`` for single-sample inference.
    """
    def __init__(self, config):
        super().__init__()
        self.cfg = config

        # Numeric projection (optional normalization outside)
        self.num_proj = nn.Identity()

        # Categorical embeddings
        self.cat_emb = CategoricalEmbeddings(config["categorical"])
        cat_total_dim = sum(p["embed_dim"] for p in config["categorical"].values())

        # Formula embedding
        self.formula_emb = FormulaEmbedder(
            num_elements=config["formula"]["num_elements"],
            embed_dim=config["formula"]["embed_dim"]
        )

        # Graph encoder (shares the element vocabulary size with the formula embedding)
        self.graph_enc = SimpleGraphEncoder(
            num_elements=config["formula"]["num_elements"],
            atom_embed_dim=config["graph"]["atom_embed_dim"],
            edge_feat_dim=config["graph"]["edge_feat_dim"]
        )

        # Fusion MLP: its input is the concatenation of the four feature groups.
        in_dim = (
            len(config["numeric_cols"])
            + cat_total_dim
            + config["formula"]["embed_dim"]
            + config["graph"]["atom_embed_dim"]
        )
        self.fc = nn.Sequential(
            nn.Linear(in_dim, config["mlp"]["hidden1"]),
            nn.BatchNorm1d(config["mlp"]["hidden1"]),
            nn.ReLU(),
            nn.Dropout(p=config["mlp"]["dropout"]),
            nn.Linear(config["mlp"]["hidden1"], config["mlp"]["hidden2"]),
            nn.ReLU(),
            nn.Linear(config["mlp"]["hidden2"], 1)
        )

    def forward(self, X_num, X_cat_dict, formula_counts, graph):
        """
        Compute logits.

        Args:
            X_num: Numeric features, shape [B, num_numeric] or [num_numeric].
            X_cat_dict: Mapping of categorical feature name to Long index tensor.
            formula_counts: Element counts, shape [B, E] or [E].
            graph: One graph dict (keys "x", "edge_index", "edge_attr"), or a
                list/tuple with one such dict per sample for batched input.

        Returns:
            Logits of shape [B].
        """
        # Numeric: promote a single unbatched row to [1, num_numeric] so it can be
        # concatenated with the other (always 2-D) feature groups.
        x_num = self.num_proj(X_num.float())
        if x_num.dim() == 1:
            x_num = x_num.unsqueeze(0)

        # Categorical
        x_cat = self.cat_emb(X_cat_dict)      # [B, cat_dim]

        # Formula (composition-level; if batch, apply per-sample)
        if formula_counts.dim() == 2:  # [B, E]
            x_formula = torch.stack([self.formula_emb(fc) for fc in formula_counts])
        else:  # [E]
            x_formula = self.formula_emb(formula_counts).unsqueeze(0)

        # Graph: each graph is encoded and pooled on its own; a sequence of graphs
        # yields one row per sample so batches larger than one can be fused.
        if isinstance(graph, (list, tuple)):
            x_graph = torch.stack([self.graph_enc(g) for g in graph])   # [B, D_graph]
        else:
            x_graph = self.graph_enc(graph)                             # [D_graph]
            if x_graph.dim() == 1:
                x_graph = x_graph.unsqueeze(0)

        # Concatenate
        x = torch.cat([x_num, x_cat, x_formula, x_graph], dim=-1)

        # Logits
        return self.fc(x).squeeze(-1)          # [B]


In [None]:
import re

def parse_formula(formula):
    """
    Count the atoms of each element in a flat chemical formula string.

    Each match is an element symbol (one capital letter plus an optional lowercase
    letter) followed by an optional integer count; a missing count means 1, and
    repeated symbols are accumulated. Parentheses and fractional counts are not
    interpreted.

    Args:
        formula: Formula such as "In2O3". Non-string values (e.g. None or NaN from a
            missing cell) are treated as an empty formula.

    Returns:
        dict mapping element symbol to its integer count; empty for missing input.
    """
    if not isinstance(formula, str):
        # Missing cells would otherwise make re.findall raise TypeError.
        return {}
    comp = {}
    for elem, count in re.findall(r"([A-Z][a-z]?)(\d*)", formula):
        comp[elem] = comp.get(elem, 0) + (int(count) if count else 1)
    return comp


# Build element vocabulary from the whole DataFrame: collect every symbol that
# occurs in any formula, then number the symbols in sorted order.
all_elements = set()
for formula_str in df["formula"]:
    all_elements.update(parse_formula(formula_str))

element_vocab = dict(zip(sorted(all_elements), range(len(all_elements))))

In [None]:
import torch
import numpy as np

# Simple element-to-index mapping for formula embedding
# (Assuming element_vocab is a dict {symbol: idx})
# Inverse lookup (index -> symbol), kept for decoding indices back to symbols.
element_vocab_inv = {v: k for k, v in element_vocab.items()}  # if needed
# Size of the element embedding tables, taken from the config rather than from
# len(element_vocab).
# NOTE(review): every index in element_vocab presumably has to stay below this
# value for the embedding lookups to be valid — confirm the two agree.
num_elements = config["formula"]["num_elements"]

def element_symbol_to_index(symbol: str, vocab: dict) -> int:
    """
    Look up the integer index of an element symbol.

    Unknown symbols map to the vocabulary's "UNK" entry when it exists and to
    index 0 otherwise.
    """
    if symbol in vocab:
        return vocab[symbol]
    return vocab.get("UNK", 0)  # handle unknowns

def atoms_to_graph(atoms_obj, radius=5.0, max_neighbors=None, element_vocab=None, device="cpu"):
    """
    Build a graph dict from ``atoms_obj``.

    - nodes: element indices (embedded later by the model)
    - edges: a directed edge i -> j for every ordered pair of distinct atoms whose
      Cartesian distance is at most ``radius`` (so each close pair appears in both
      directions before any capping)
    - edge_attr: four geometric features per edge,
      [distance, 1/distance, x and y components of the unit vector from i to j]

    Only the coordinates in ``atoms_obj.cart_coords`` are used; periodic images
    across cell boundaries are not generated here.

    Args:
        atoms_obj: Object exposing ``cart_coords`` (N x 3) and ``elements``
            (N element symbols).
        radius: Neighbor cutoff distance, in the units of ``cart_coords``.
        max_neighbors: If given (non-negative), keep only the nearest
            ``max_neighbors`` outgoing edges of every source atom.
        element_vocab: Required mapping {symbol: index}.
        device: Torch device for the returned tensors.

    Returns:
        dict with "x" (Long, shape [N]), "edge_index" (Long, shape [2, E]) and
        "edge_attr" (float32, shape [E, 4]).

    Raises:
        ValueError: If ``element_vocab`` is not provided.
    """
    if element_vocab is None:
        raise ValueError("atoms_to_graph requires element_vocab (a {symbol: index} mapping)")

    cart_coords = np.asarray(atoms_obj.cart_coords, dtype=float).reshape(-1, 3)  # (N, 3)
    species = atoms_obj.elements  # list of symbols length N

    # Node features: element indices
    node_idx = np.array(
        [element_symbol_to_index(sym, element_vocab) for sym in species], dtype=np.int64
    )

    # All pairwise displacement vectors at once: diff[i, j] = r_j - r_i.
    diff = cart_coords[None, :, :] - cart_coords[:, None, :]        # (N, N, 3)
    dist = np.linalg.norm(diff, axis=-1)                             # (N, N)
    within = (dist <= radius) & ~np.eye(cart_coords.shape[0], dtype=bool)  # no self loops

    # np.nonzero walks row by row, so edges come out ordered by source, then target.
    src, dst = np.nonzero(within)
    edge_dist = dist[src, dst]                                       # (E,)
    edge_vec = diff[src, dst]                                        # (E, 3)

    # Coincident atoms (distance ~ 0) get 0 for 1/distance and a zero unit vector.
    nonzero = edge_dist > 1e-8
    inv_dist = np.divide(1.0, edge_dist, out=np.zeros_like(edge_dist), where=nonzero)
    unit = np.divide(
        edge_vec, edge_dist[:, None], out=np.zeros_like(edge_vec), where=nonzero[:, None]
    )
    edge_feats = np.column_stack([edge_dist, inv_dist, unit[:, 0], unit[:, 1]])  # (E, 4)

    # Optional: cap neighbors per node, keeping the nearest ones for each source.
    if max_neighbors is not None and src.size > 0:
        # Stable sort by source, then distance; equal distances keep target order.
        order = np.lexsort((edge_dist, src))
        src_sorted = src[order]
        # Rank of each edge inside its source group = position - first position of group.
        group_start = np.searchsorted(src_sorted, src_sorted, side="left")
        keep = order[(np.arange(order.size) - group_start) < max_neighbors]
        src, dst, edge_feats = src[keep], dst[keep], edge_feats[keep]

    # Convert to tensors. With no edges the shapes are (2, 0) and (0, 4): the feature
    # width always equals the number of features built above instead of being read
    # from a global config in the empty case only.
    x = torch.tensor(node_idx, dtype=torch.long, device=device)                         # (N,)
    edge_index = torch.tensor(np.stack([src, dst]), dtype=torch.long, device=device)    # (2, E)
    edge_attr = torch.tensor(edge_feats, dtype=torch.float32, device=device)            # (E, 4)

    return {"x": x, "edge_index": edge_index, "edge_attr": edge_attr}


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE(review): a class with this name is also defined in an earlier cell; this later
# definition is the one in effect when the notebook runs top to bottom.
class FormulaEmbedder(nn.Module):
    """
    Composition-level embedding: the count-weighted sum of learned per-element vectors.

    Args:
        num_elements: Number of rows in the element embedding table.
        embed_dim: Size of each element vector (and of the returned embedding).
    """
    def __init__(self, num_elements, embed_dim):
        super().__init__()
        self.emb = nn.Embedding(num_elements, embed_dim)

    def forward(self, element_counts):
        """
        Embed one formula or a batch of formulas.

        Args:
            element_counts: Tensor of per-element counts with the element axis last,
                i.e. shape [E] for one formula or [..., E] for a batch. Position k is
                weighted against embedding row k, so E must not exceed ``num_elements``.

        Returns:
            Tensor of shape [D] for a 1-D input, or [..., D] for batched input.
        """
        # Look up the first E embedding rows, one per count position.
        idx = torch.arange(element_counts.shape[-1], device=element_counts.device)
        weights = element_counts.float().unsqueeze(-1)       # [..., E, 1]
        vecs = self.emb(idx)                                 # [E, D]
        # Weighted sum over the element axis. Reducing dim=-2 (rather than dim=0) gives
        # the same result for a 1-D input and also keeps batched input per-sample.
        return (vecs * weights).sum(dim=-2)                  # [..., D]


# NOTE(review): a class with this name is also defined in an earlier cell; this later
# definition is the one in effect when the notebook runs top to bottom.
class SimpleGraphEncoder(nn.Module):
    """
    Small message-passing encoder that turns one atomic graph into a single vector.

    Nodes start from a learned embedding of their element index. Two rounds of
    message passing follow, each building messages from the sender state joined with
    the edge features, summing them per receiving node, and updating the node state.
    The node states are finally mean-pooled.
    """
    def __init__(self, num_elements, atom_embed_dim, edge_feat_dim):
        super().__init__()
        self.atom_emb = nn.Embedding(num_elements, atom_embed_dim)

        # Messages see the sender state concatenated with the edge features.
        message_in_dim = atom_embed_dim + edge_feat_dim
        self.msg1 = nn.Linear(message_in_dim, atom_embed_dim)
        self.msg2 = nn.Linear(message_in_dim, atom_embed_dim)
        self.node_up1 = nn.Linear(atom_embed_dim, atom_embed_dim)
        self.node_up2 = nn.Linear(atom_embed_dim, atom_embed_dim)

    def message_pass(self, x, edge_index, edge_attr):
        """Run both message-passing rounds; node states pass through unchanged if there are no edges."""
        if edge_index.numel() == 0:
            return x
        senders, receivers = edge_index
        rounds = ((self.msg1, self.node_up1), (self.msg2, self.node_up2))
        for message_layer, update_layer in rounds:
            edge_inputs = torch.cat([x[senders], edge_attr], dim=-1)   # [E, D+F]
            messages = F.relu(message_layer(edge_inputs))              # [E, D]
            # Sum the incoming messages of every receiving node.
            inbox = torch.zeros_like(x).index_add_(0, receivers, messages)
            x = F.relu(update_layer(x + inbox))
        return x

    def forward(self, graph):
        """
        Encode a graph dict with keys "x" (element indices), "edge_index" and "edge_attr".

        Returns the mean of the final node states, shape [D]; a graph without nodes
        yields a zero vector of the same size.
        """
        node_states = self.message_pass(
            self.atom_emb(graph["x"]), graph["edge_index"], graph["edge_attr"]
        )
        if node_states.shape[0] > 0:
            return node_states.mean(dim=0)
        return torch.zeros(self.atom_emb.embedding_dim, device=node_states.device)


# NOTE(review): a class with this name is also defined in an earlier cell; this later
# definition is the one in effect when the notebook runs top to bottom.
class CategoricalEmbeddings(nn.Module):
    """
    One embedding table per categorical feature, concatenated along the last axis.

    Args:
        cat_cfg: Mapping of feature name to a dict with "vocab_size" and "embed_dim".
    """
    def __init__(self, cat_cfg):
        super().__init__()
        tables = {}
        for name, params in cat_cfg.items():
            tables[name] = nn.Embedding(params["vocab_size"], params["embed_dim"])
        self.embeddings = nn.ModuleDict(tables)

    def forward(self, X_cat_dict):
        """
        Embed every configured feature and join the results.

        Args:
            X_cat_dict: Mapping of feature name to a Long tensor of category indices.

        Returns:
            The per-feature embeddings concatenated on the last axis. An embedding
            that comes back with fewer than two dimensions gets a leading axis of
            size 1 first, so the result is always at least 2-D.
        """
        pieces = []
        for name, table in self.embeddings.items():
            vec = table(X_cat_dict[name])
            if vec.dim() <= 1:
                vec = vec.unsqueeze(0)
            pieces.append(vec)
        return torch.cat(pieces, dim=-1)


# NOTE(review): a class with this name is also defined in an earlier cell; this later
# definition is the one in effect when the notebook runs top to bottom.
class CandidateNetMultimodal(nn.Module):
    """
    Fusion classifier that emits one logit per sample from four inputs: raw numeric
    features, categorical embeddings, a formula embedding and a pooled graph embedding.

    Args:
        config: Mapping with the sections "numeric_cols", "categorical", "formula"
            ("num_elements", "embed_dim"), "graph" ("atom_embed_dim", "edge_feat_dim")
            and "mlp" ("hidden1", "hidden2", "dropout").

    Note:
        The MLP contains ``nn.BatchNorm1d``, which PyTorch refuses to run on a batch
        of one row while the module is in training mode. Train with batches of at
        least two samples, or switch to ``eval()`` for single-sample inference.
    """
    def __init__(self, config):
        super().__init__()
        self.cfg = config

        # Numeric projection (optional normalization outside)
        self.num_proj = nn.Identity()

        # Categorical embeddings
        self.cat_emb = CategoricalEmbeddings(config["categorical"])
        cat_total_dim = sum(p["embed_dim"] for p in config["categorical"].values())

        # Formula embedding
        self.formula_emb = FormulaEmbedder(
            num_elements=config["formula"]["num_elements"],
            embed_dim=config["formula"]["embed_dim"]
        )

        # Graph encoder (shares the element vocabulary size with the formula embedding)
        self.graph_enc = SimpleGraphEncoder(
            num_elements=config["formula"]["num_elements"],
            atom_embed_dim=config["graph"]["atom_embed_dim"],
            edge_feat_dim=config["graph"]["edge_feat_dim"]
        )

        # Fusion MLP: its input is the concatenation of the four feature groups.
        in_dim = (
            len(config["numeric_cols"])
            + cat_total_dim
            + config["formula"]["embed_dim"]
            + config["graph"]["atom_embed_dim"]
        )
        self.fc = nn.Sequential(
            nn.Linear(in_dim, config["mlp"]["hidden1"]),
            nn.BatchNorm1d(config["mlp"]["hidden1"]),
            nn.ReLU(),
            nn.Dropout(p=config["mlp"]["dropout"]),
            nn.Linear(config["mlp"]["hidden1"], config["mlp"]["hidden2"]),
            nn.ReLU(),
            nn.Linear(config["mlp"]["hidden2"], 1)
        )

    def forward(self, X_num, X_cat_dict, formula_counts, graph):
        """
        Compute logits.

        Args:
            X_num: Numeric features, shape [B, num_numeric] or [num_numeric].
            X_cat_dict: Mapping of categorical feature name to Long index tensor.
            formula_counts: Element counts, shape [B, E] or [E].
            graph: One graph dict (keys "x", "edge_index", "edge_attr"), or a
                list/tuple with one such dict per sample for batched input.

        Returns:
            Logits of shape [B].
        """
        # Numeric: promote a single unbatched row to [1, num_numeric] so it can be
        # concatenated with the other (always 2-D) feature groups.
        x_num = self.num_proj(X_num.float())
        if x_num.dim() == 1:
            x_num = x_num.unsqueeze(0)

        # Categorical
        x_cat = self.cat_emb(X_cat_dict)      # [B, cat_dim]

        # Formula (composition-level; if batch, apply per-sample)
        if formula_counts.dim() == 2:  # [B, E]
            x_formula = torch.stack([self.formula_emb(fc) for fc in formula_counts])
        else:  # [E]
            x_formula = self.formula_emb(formula_counts).unsqueeze(0)

        # Graph: each graph is encoded and pooled on its own; a sequence of graphs
        # yields one row per sample so batches larger than one can be fused.
        if isinstance(graph, (list, tuple)):
            x_graph = torch.stack([self.graph_enc(g) for g in graph])   # [B, D_graph]
        else:
            x_graph = self.graph_enc(graph)                             # [D_graph]
            if x_graph.dim() == 1:
                x_graph = x_graph.unsqueeze(0)

        # Concatenate
        x = torch.cat([x_num, x_cat, x_formula, x_graph], dim=-1)

        # Logits
        return self.fc(x).squeeze(-1)          # [B]


In [None]:
# Example single-sample preparation (extend to batches)
# Use the configured device, but fall back to the CPU when CUDA is requested on a
# machine without a usable GPU, so the cells below can still run.
_requested_device = config["devices"]["model"]
if str(_requested_device).startswith("cuda") and not torch.cuda.is_available():
    print(f"CUDA is not available; using CPU instead of '{_requested_device}'.")
    _requested_device = "cpu"
device = torch.device(_requested_device)

# Graph construction settings forwarded to atoms_to_graph.
radius = config["graph"]["radius"]
max_neighbors = config["graph"]["max_neighbors"]

# Prepare one sample (pseudo-code; wire this into your Dataset __getitem__)
def prepare_sample(row):
    """
    Convert one DataFrame row into the tensors the model's forward pass expects.

    Reads the notebook-level ``config``, ``device``, ``radius``, ``max_neighbors``
    and ``element_vocab``.

    Args:
        row: A row (pandas Series) holding the numeric columns listed in
            config["numeric_cols"], the index columns "spg_number_idx", "func_idx",
            "dimensionality_idx", "typ_idx" and "crys_idx", plus
            "composition_counts", "atoms_obj" and "target".
            NOTE(review): none of the "*_idx", "composition_counts" or "atoms_obj"
            columns appear in the raw column list logged earlier; they presumably
            come from a preprocessing step not shown here — confirm.

    Returns:
        Tuple (X_num, X_cat_dict, formula_counts, graph, y).
    """
    # Numeric: a row sliced from a mixed-type frame has object dtype, which
    # torch.tensor rejects, so convert the values to float32 explicitly.
    # Unparseable entries become NaN and are then replaced by 0.0 so a single missing
    # value cannot turn the loss into NaN.
    # NOTE(review): zero-imputation is a placeholder; swap in the project's preferred
    # imputation/normalization if there is one.
    numeric = pd.to_numeric(row[config["numeric_cols"]], errors="coerce").to_numpy(dtype=np.float32)
    numeric = np.where(np.isnan(numeric), np.float32(0.0), numeric)
    X_num = torch.tensor(numeric, dtype=torch.float32, device=device).unsqueeze(0)

    # Categorical (Long tensors)
    X_cat_dict = {
        "spg_number": torch.tensor([row["spg_number_idx"]], dtype=torch.long, device=device),
        "func": torch.tensor([row["func_idx"]], dtype=torch.long, device=device),
        "dimensionality": torch.tensor([row["dimensionality_idx"]], dtype=torch.long, device=device),
        "typ": torch.tensor([row["typ_idx"]], dtype=torch.long, device=device),
        "crys": torch.tensor([row["crys_idx"]], dtype=torch.long, device=device),
    }

    # Formula counts vector [E]
    # Build from parsed composition or your existing formula pipeline
    counts = np.zeros(config["formula"]["num_elements"], dtype=np.float32)
    for elem, cnt in row["composition_counts"].items():  # e.g., {"Na":1, "I":1}
        # Raises KeyError for a symbol missing from element_vocab, and IndexError if
        # its index is not below config["formula"]["num_elements"].
        idx = element_vocab[elem]
        counts[idx] = cnt
    formula_counts = torch.tensor(counts, dtype=torch.float32, device=device)

    # Graph from atoms_obj
    graph = atoms_to_graph(
        row["atoms_obj"],
        radius=radius,
        max_neighbors=max_neighbors,
        element_vocab=element_vocab,
        device=device
    )

    # Target
    y = torch.tensor([row["target"]], dtype=torch.float32, device=device)

    return X_num, X_cat_dict, formula_counts, graph, y

# Instantiate model
model = CandidateNetMultimodal(config).to(device)

# Weight the positive class by the negative/positive ratio to counter class imbalance.
# The denominator is clamped to 1 so a frame without positives does not divide by zero.
n_pos = int((df["target"] == 1).sum())
n_neg = int((df["target"] == 0).sum())
pos_weight = torch.tensor([n_neg / max(n_pos, 1)], dtype=torch.float32, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Example training step (single sample for clarity)
# NOTE(review): df_candidates is not defined in any cell shown here; it must be
# created (with the columns prepare_sample reads) before this cell can run.
# NOTE(review): the model's MLP contains BatchNorm1d, which PyTorch rejects for a
# batch of one row in training mode — use a batch of at least two samples for real
# training.
model.train()
X_num, X_cat_dict, formula_counts, graph, y = prepare_sample(df_candidates.iloc[0])
optimizer.zero_grad()  # clear gradients left over from any previous step
logits = model(X_num, X_cat_dict, formula_counts, graph)
loss = criterion(logits, y)
loss.backward()
optimizer.step()


TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.