In [None]:
import pandas as pd
import torch
from tqdm.auto import tqdm
import torch.nn.init as init
import torch.nn.functional as F
from torch_geometric.nn import HANConv, GATConv
from torch_geometric.data import HeteroData
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

from utils import load_embedding_model
from constants import FULL_DATASET_PATH, E5_LARGE_INSTRUCT_CONFIG_PATH, RANDOM_STATE

In [2]:
# Seed PyTorch's random number generators from the shared project constant so that
# weight initialisation and dropout are repeatable between runs.
torch.manual_seed(RANDOM_STATE)
# Also seed the CUDA generator explicitly.
torch.cuda.manual_seed(RANDOM_STATE)

In [3]:
# Load the full product dataset from the CSV path configured in `constants`.
# Later cells read the columns "product_name", "segment", "family" and "class" from it.
df = pd.read_csv(FULL_DATASET_PATH)

In [4]:
# Keep only rows that carry a target label.
# Rebinding instead of `inplace=True` keeps the cell free of cross-cell in-place mutation,
# and resetting the index makes row labels equal to row positions, so label-based and
# positional lookups agree with the positional indices built later (train/test indices,
# `torch.arange` edge sources).
df = df.dropna(subset=["class"]).reset_index(drop=True)

In [5]:
from sklearn.preprocessing import LabelEncoder

# One fitted encoder per taxonomy level, each kept under its own name.
le_seg, le_fam, le_cls = LabelEncoder(), LabelEncoder(), LabelEncoder()

# Add the integer-encoded counterpart of every taxonomy column to the frame.
df = df.assign(
    segment_encode=le_seg.fit_transform(df["segment"]),
    family_encode=le_fam.fit_transform(df["family"]),
    class_encode=le_cls.fit_transform(df["class"]),
)

In [None]:
# Load the text-embedding model described by the E5-large-instruct config file.
embed_model = load_embedding_model(E5_LARGE_INSTRUCT_CONFIG_PATH)

# Embed every product name and keep the result as a float32 tensor.
# Assumes `get_embeddings` returns one vector per input string, in input order -- TODO confirm.
product_names = df["product_name"].tolist()
product_embeds = torch.tensor(
    embed_model.get_embeddings(product_names),
    dtype=torch.float32,
)  # shape [N_products, 1024]

In [42]:
# Integer targets for each taxonomy level, aligned with the rows of `df`;
# every tensor has shape [N_products] and dtype long.
segment_y, family_y, class_y = (
    torch.tensor(df[encoded_col].to_numpy(), dtype=torch.long)
    for encoded_col in ("segment_encode", "family_encode", "class_encode")
)

In [43]:
# Hold out 20% of the products for testing. The split is seeded with the project-wide
# RANDOM_STATE (the same constant that seeds torch) instead of a separate hard-coded
# literal, so a single constant controls the notebook's randomness.
idx = list(range(len(df)))
train_idx, test_idx = train_test_split(idx, test_size=0.2, random_state=RANDOM_STATE)

# Boolean masks over product rows (positions in `df`).
train_mask = torch.zeros(len(df), dtype=torch.bool)
test_mask  = torch.zeros(len(df), dtype=torch.bool)
train_mask[train_idx] = True
test_mask[test_idx]  = True

# Sanity checks: the two masks are disjoint and together cover every product.
assert not (train_mask & test_mask).any()
assert int(train_mask.sum()) + int(test_mask.sum()) == len(df)

In [44]:
# Node counts per type, derived from the encoded label columns.
num_products = len(df)
num_segments = df["segment_encode"].nunique()
num_families = df["family_encode"].nunique()
num_classes  = df["class_encode"].nunique()

# --- Product to Segment edges ---
# Exactly one edge per product row: product i -> its segment id.
# NOTE(review): this wires every product, including those in the test split, to its own
# segment node. Confirm that the segment is known at inference time; otherwise the graph
# exposes part of the label hierarchy for held-out products.
prod_to_seg_src = torch.arange(num_products, dtype=torch.long)
prod_to_seg_dst = torch.tensor(df["segment_encode"].values, dtype=torch.long)
prod_to_seg_edge_index = torch.stack([prod_to_seg_src, prod_to_seg_dst], dim=0)

# --- Segment to Family edges ---
# The (segment, family) pairs are read from product rows, so the same pair shows up once
# per product. Keep each distinct pair once: duplicated multi-edges only inflate memory and
# compute and make neighbourhood aggregation implicitly weighted by product frequency.
seg_to_fam_pairs = torch.stack(
    [
        torch.tensor(df["segment_encode"].values, dtype=torch.long),
        torch.tensor(df["family_encode"].values, dtype=torch.long),
    ],
    dim=0,
)
seg_to_fam_edge_index = torch.unique(seg_to_fam_pairs, dim=1)
seg_to_fam_src, seg_to_fam_dst = seg_to_fam_edge_index

# --- Family to Class edges ---
# Same de-duplication for the (family, class) pairs.
fam_to_cls_pairs = torch.stack(
    [
        torch.tensor(df["family_encode"].values, dtype=torch.long),
        torch.tensor(df["class_encode"].values, dtype=torch.long),
    ],
    dim=0,
)
fam_to_cls_edge_index = torch.unique(fam_to_cls_pairs, dim=1)
fam_to_cls_src, fam_to_cls_dst = fam_to_cls_edge_index


In [45]:
data = HeteroData()

# Product nodes: embedding features, split masks and the class target.
product_store = data["product"]
product_store.x = product_embeds
product_store.train_mask = train_mask
product_store.test_mask = test_mask
product_store.y = class_y

# Segment / Family / Class nodes don't have features (yet).
# They'll get embeddings via the NodeFeatureEncoder, so only their counts are recorded.
for node_type, node_count in (
    ("segment", num_segments),
    ("family", num_families),
    ("class", num_classes),
):
    data[node_type].num_nodes = node_count

# (source type, destination type, edge_index) for each level of the hierarchy.
hierarchy_edges = (
    ("product", "segment", prod_to_seg_edge_index),
    ("segment", "family", seg_to_fam_edge_index),
    ("family", "class", fam_to_cls_edge_index),
)

# Edges
for src_type, dst_type, edge_index in hierarchy_edges:
    data[src_type, "to", dst_type].edge_index = edge_index

# Reverse edges: same pairs with the source and destination rows swapped.
for src_type, dst_type, edge_index in hierarchy_edges:
    data[dst_type, "rev_to", src_type].edge_index = edge_index.flip(0)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, SAGEConv
from typing import Optional, Dict

# -------------------------
# NodeFeatureEncoder: projects every node type into a shared hidden_dim space
# -------------------------
class NodeFeatureEncoder(nn.Module):
    """Produce ``hidden_dim``-sized input features for every node type.

    Products are passed through a linear projection. Each category type
    (``segment``, ``family``, ``class``) starts from one of two sources and is
    then passed through its own linear projection:

    * a fixed matrix taken from ``pretrained_category_embeddings`` and
      registered as a buffer named ``<type>_pre``, or
    * a learnable ``nn.Embedding`` table named ``<type>_embedding`` with
      ``hidden_dim`` columns, which requires the matching ``num_*`` count.

    Args:
        prod_in_dim: Width of the incoming product feature vectors.
        hidden_dim: Output width for every node type.
        num_families: Number of family nodes; needed when no pretrained
            family matrix is supplied.
        num_segments: Number of segment nodes; needed when no pretrained
            segment matrix is supplied.
        num_classes: Number of class nodes; needed when no pretrained
            class matrix is supplied.
        pretrained_category_embeddings: Optional mapping from a category type
            name to a 2-D tensor whose second dimension is the feature width.
    """

    def __init__(
        self,
        prod_in_dim: int,
        hidden_dim: int,
        num_families: Optional[int] = None,
        num_segments: Optional[int] = None,
        num_classes: Optional[int] = None,
        pretrained_category_embeddings: Optional[Dict[str, torch.Tensor]] = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.product_proj = nn.Linear(prod_in_dim, hidden_dim)
        self.pretrained = pretrained_category_embeddings or {}

        # (category type, node count, argument name quoted in the assertion text)
        category_specs = (
            ('segment', num_segments, 'num_segments'),
            ('family', num_families, 'num_families'),
            ('class', num_classes, 'num_classes'),
        )

        # Register all feature sources first and all projections afterwards, so
        # submodules are created in the order segment/family/class sources,
        # then segment/family/class projections.
        source_dims = {
            name: self._register_category_source(name, count, count_arg)
            for name, count, count_arg in category_specs
        }

        # project category dims -> hidden_dim
        for name, _, _ in category_specs:
            setattr(self, f'{name}_proj', nn.Linear(source_dims[name], hidden_dim))

    def _register_category_source(self, name: str, count: Optional[int], count_arg: str) -> int:
        """Attach the feature source for one category type and return its width."""
        if name in self.pretrained:
            fixed = self.pretrained[name]
            self.register_buffer(f'{name}_pre', fixed)
            setattr(self, f'{name}_embedding', None)
            return fixed.shape[1]

        assert count is not None, f"{count_arg} required if no pretrained {name} embeddings"
        setattr(self, f'{name}_embedding', nn.Embedding(count, self.hidden_dim))
        return self.hidden_dim

    def forward(self, product_x, segment_idx_or_none=None, family_idx_or_none=None, class_idx_or_none=None):
        """Return a dict with keys 'product', 'segment', 'family' and 'class'.

        For a category type backed by a pretrained buffer, the index argument is
        ignored and the whole buffer is projected. Otherwise the index tensor is
        looked up in that type's embedding table before projection.
        """
        encoded = {'product': self.product_proj(product_x)}

        lookups = (
            ('segment', segment_idx_or_none),
            ('family', family_idx_or_none),
            ('class', class_idx_or_none),
        )
        for name, indices in lookups:
            fixed = getattr(self, f'{name}_pre', None)
            if fixed is not None:
                raw = fixed
            else:
                raw = getattr(self, f'{name}_embedding')(indices)
            encoded[name] = getattr(self, f'{name}_proj')(raw)

        return encoded


# -------------------------
# Heterogeneous GNNs over the product -> segment -> family -> class relations and their reverses
# -------------------------
class HeteroSAGENet(nn.Module):
    """Stack of heterogeneous GraphSAGE layers followed by a product classifier.

    Every layer is a ``HeteroConv`` holding one lazily-sized ``SAGEConv`` per
    relation (three hierarchy relations plus their reverses); the results of
    different relations are grouped with ``aggr='mean'``. After each layer every
    node type goes through BatchNorm, ReLU and dropout. The same four BatchNorm
    modules are applied after every layer.

    Args:
        hidden_dim: Output width of every convolution and of the BatchNorms.
        out_classes: Number of logits produced per product.
        num_layers: Number of ``HeteroConv`` layers.
        dropout: Dropout probability used after each layer and in the classifier.
    """

    def __init__(self, hidden_dim: int, out_classes: int, num_layers: int = 2, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        # Forward hierarchy relations first, then the reverse ones.
        relations = (
            ('product', 'to', 'segment'),
            ('segment', 'to', 'family'),
            ('family', 'to', 'class'),
            ('segment', 'rev_to', 'product'),
            ('family', 'rev_to', 'segment'),
            ('class', 'rev_to', 'family'),
        )
        self.convs = nn.ModuleList([
            HeteroConv(
                {relation: SAGEConv((-1, -1), hidden_dim) for relation in relations},
                aggr='mean',
            )
            for _ in range(num_layers)
        ])

        # One BatchNorm per node type.
        self.bn_product = nn.BatchNorm1d(hidden_dim)
        self.bn_segment = nn.BatchNorm1d(hidden_dim)
        self.bn_family  = nn.BatchNorm1d(hidden_dim)
        self.bn_class   = nn.BatchNorm1d(hidden_dim)

        # Two-layer MLP head applied to the product representations.
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x_dict, edge_index_dict):
        """Run message passing and classify products.

        Returns:
            ``(logits, h)`` where ``logits`` is the classifier output for the
            'product' entries and ``h`` is the per-node-type feature dict after
            the last layer (the input dict itself when there are no layers).
        """
        norms = {
            'product': self.bn_product,
            'segment': self.bn_segment,
            'family': self.bn_family,
            'class': self.bn_class,
        }

        h = x_dict
        for layer in self.convs:
            h = layer(h, edge_index_dict)
            for node_type, norm in norms.items():
                activated = F.relu(norm(h[node_type]))
                h[node_type] = F.dropout(activated, p=self.dropout, training=self.training)

        return self.classifier(h['product']), h
    
class HeteroAttnNet(nn.Module):
    """Stack of heterogeneous graph-attention layers followed by a product classifier.

    Every layer is a ``HeteroConv`` holding one lazily-sized ``GATConv`` per
    relation (three hierarchy relations plus their reverses). Each ``GATConv``
    uses ``heads`` attention heads with ``concat=False`` and no self loops; the
    results of different relations are grouped with ``aggr='mean'``. After each
    layer every node type goes through BatchNorm, ReLU and dropout. The same
    four BatchNorm modules are applied after every layer.

    Args:
        hidden_dim: Output width of every convolution and of the BatchNorms.
        out_classes: Number of logits produced per product.
        num_layers: Number of ``HeteroConv`` layers.
        dropout: Dropout probability used after each layer and in the classifier.
        heads: Number of attention heads per ``GATConv``.
    """

    def __init__(self, hidden_dim: int, out_classes: int, num_layers: int = 2, dropout: float = 0.1, heads: int = 2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        # Forward hierarchy relations first, then the reverse ones.
        relations = (
            ('product', 'to', 'segment'),
            ('segment', 'to', 'family'),
            ('family', 'to', 'class'),
            ('segment', 'rev_to', 'product'),
            ('family', 'rev_to', 'segment'),
            ('class', 'rev_to', 'family'),
        )
        # Settings shared by every attention convolution.
        gat_kwargs = dict(heads=heads, concat=False, add_self_loops=False)

        self.convs = nn.ModuleList([
            HeteroConv(
                {relation: GATConv((-1, -1), hidden_dim, **gat_kwargs) for relation in relations},
                aggr='mean',
            )
            for _ in range(num_layers)
        ])

        # batchnorms, one per node type
        self.bn_product = nn.BatchNorm1d(hidden_dim)
        self.bn_segment = nn.BatchNorm1d(hidden_dim)
        self.bn_family  = nn.BatchNorm1d(hidden_dim)
        self.bn_class   = nn.BatchNorm1d(hidden_dim)

        # classifier: two-layer MLP head applied to the product representations
        head_layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x_dict, edge_index_dict):
        """Run attention-based message passing and classify products.

        Returns:
            ``(logits, h)`` where ``logits`` is the classifier output for the
            'product' entries and ``h`` is the per-node-type feature dict after
            the last layer (the input dict itself when there are no layers).
        """
        norms = {
            'product': self.bn_product,
            'segment': self.bn_segment,
            'family': self.bn_family,
            'class': self.bn_class,
        }

        h = x_dict
        for layer in self.convs:
            h = layer(h, edge_index_dict)
            for node_type, norm in norms.items():
                activated = F.relu(norm(h[node_type]))
                h[node_type] = F.dropout(activated, p=self.dropout, training=self.training)

        return self.classifier(h['product']), h


# -------------------------
# Training wrapper: builds the graph, trains full-batch, and evaluates periodically
# -------------------------
def train_model(
    product_embeddings,                # tensor [N_p, 1024]
    product_to_segment_edge_index,     # long tensor [2, E_ps]
    segment_to_family_edge_index,      # long tensor [2, E_sf]
    family_to_class_edge_index,        # long tensor [2, E_fc]
    product_y,                         # long tensor [N_p]
    product_train_mask,                # bool tensor [N_p]
    product_test_mask,                 # bool tensor [N_p]
    num_families=None,
    num_segments=None,
    num_classes=None,
    pretrained_category_embeddings: Optional[Dict[str, torch.Tensor]] = None,
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    hidden_dim=128,
    num_layers=2,
    heads=None,
    lr=1e-3,
    weight_decay=1e-5,
    epochs=100,
    verbose=True
):
    """Train a heterogeneous GNN product classifier with full-batch gradient steps.

    A ``HeteroData`` graph is assembled from the given edge indices (reverse
    relations are added by flipping each edge index), node features come from a
    ``NodeFeatureEncoder``, and the network is ``HeteroAttnNet`` when ``heads``
    is given, otherwise ``HeteroSAGENet``. Encoder and network are optimised
    jointly with Adam on the cross-entropy of the training-mask products.

    Roughly ten times per run (every ``max(1, epochs // 10)`` epochs) the model is
    evaluated in eval mode; whenever the test accuracy improves, an independent
    copy of both state dicts is stored.

    Args:
        product_embeddings: Product feature matrix.
        product_to_segment_edge_index: Edge index for product -> segment.
        segment_to_family_edge_index: Edge index for segment -> family.
        family_to_class_edge_index: Edge index for family -> class.
        product_y: Class label per product.
        product_train_mask: Boolean mask of training products.
        product_test_mask: Boolean mask of test products.
        num_families, num_segments, num_classes: Node counts; ``num_classes`` is
            also the number of output logits.
        pretrained_category_embeddings: Optional fixed feature matrices per
            category type, forwarded to ``NodeFeatureEncoder``.
        device: Device used for the modules and all tensors.
        hidden_dim: Hidden width of encoder and network.
        num_layers: Number of message-passing layers.
        heads: Number of attention heads; ``None`` selects the GraphSAGE network.
        lr, weight_decay: Adam settings.
        epochs: Number of full-batch training steps.
        verbose: When true, print loss and accuracies at every evaluation.
            Evaluation and best-state tracking happen regardless of this flag.

    Returns:
        ``(encoder, model, best_state)``. ``encoder`` and ``model`` hold the
        weights of the final epoch. ``best_state`` is ``None`` if no evaluation
        ever improved on an accuracy of 0.0; otherwise it is a dict with keys
        'encoder', 'model', 'epoch' and 'test_acc'.
    """
    import copy  # function-scope import: only needed for the best-state snapshot

    data = HeteroData()
    data['product'].x = product_embeddings
    data['product'].y = product_y
    data['product'].train_mask = product_train_mask
    data['product'].test_mask = product_test_mask

    # Hierarchy edges
    data['product', 'to', 'segment'].edge_index = product_to_segment_edge_index
    data['segment', 'to', 'family'].edge_index = segment_to_family_edge_index
    data['family', 'to', 'class'].edge_index = family_to_class_edge_index

    # reverse edges (flip rows)
    data['segment', 'rev_to', 'product'].edge_index = product_to_segment_edge_index.flip(0)
    data['family', 'rev_to', 'segment'].edge_index = segment_to_family_edge_index.flip(0)
    data['class', 'rev_to', 'family'].edge_index = family_to_class_edge_index.flip(0)

    encoder = NodeFeatureEncoder(
        prod_in_dim=product_embeddings.size(1),
        hidden_dim=hidden_dim,
        num_families=num_families,
        num_segments=num_segments,
        num_classes=num_classes,
        pretrained_category_embeddings=pretrained_category_embeddings
    ).to(device)

    if heads is not None:
        model = HeteroAttnNet(hidden_dim=hidden_dim, out_classes=num_classes, num_layers=num_layers, heads=heads).to(device)
    else:
        model = HeteroSAGENet(hidden_dim=hidden_dim, out_classes=num_classes, num_layers=num_layers).to(device)

    params = list(encoder.parameters()) + list(model.parameters())
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)

    # Category indices for the learnable-embedding path; a type backed by a
    # pretrained matrix needs no index tensor.
    pretrained = pretrained_category_embeddings or {}
    segment_idx = None if 'segment' in pretrained else torch.arange(num_segments, dtype=torch.long, device=device)
    family_idx  = None if 'family' in pretrained else torch.arange(num_families, dtype=torch.long, device=device)
    class_idx   = None if 'class' in pretrained else torch.arange(num_classes, dtype=torch.long, device=device)

    data = data.to(device)
    product_embeddings = product_embeddings.to(device)
    product_y = product_y.to(device)
    product_train_mask = product_train_mask.to(device)
    product_test_mask = product_test_mask.to(device)

    best_test_acc = 0.0
    best_state = None
    eval_every = max(1, epochs // 10)

    for epoch in tqdm(range(1, epochs+1)):
        encoder.train(); model.train()
        optimizer.zero_grad()

        x_dict = encoder(product_embeddings, segment_idx, family_idx, class_idx)
        logits, _ = model(x_dict, data.edge_index_dict)

        loss = F.cross_entropy(logits[product_train_mask], product_y[product_train_mask])
        loss.backward()
        optimizer.step()

        # Evaluate on a fixed schedule whether or not progress is printed, so that
        # best_state is also populated when verbose=False.
        if epoch % eval_every == 0:
            encoder.eval(); model.eval()
            with torch.inference_mode():
                x_eval = encoder(product_embeddings, segment_idx, family_idx, class_idx)
                logits_eval, _ = model(x_eval, data.edge_index_dict)
                pred = logits_eval.argmax(dim=-1)
                train_acc = (pred[product_train_mask] == product_y[product_train_mask]).sum().item() / max(1, int(product_train_mask.sum().item()))
                test_acc = (pred[product_test_mask] == product_y[product_test_mask]).sum().item() / max(1, int(product_test_mask.sum().item()))
            # NOTE(review): the checkpoint is selected on the test split, so the reported
            # best test accuracy is optimistic; consider a separate validation mask.
            if test_acc > best_test_acc:
                best_test_acc = test_acc
                # state_dict() returns references to the live parameter tensors, which
                # later optimizer steps overwrite in place. Deep-copy them so the stored
                # snapshot really is the state at this epoch.
                best_state = {
                    'encoder': copy.deepcopy(encoder.state_dict()),
                    'model': copy.deepcopy(model.state_dict()),
                    'epoch': epoch,
                    'test_acc': test_acc,
                }
            if verbose:
                print(f"\nEpoch {epoch:03d} | Loss: {loss.item():.4f} | Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

    print("\nTraining finished. Best test acc:", best_test_acc)
    return encoder, model, best_state

In [62]:
# Reset CUDA memory counters (presumably so that statistics read after the training run
# below cover only that run). Guarded because these calls raise on hosts without a usable
# CUDA device, whereas train_model itself falls back to the CPU.
if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()

In [None]:
# Graph tensors and node counts, read back from the HeteroData object assembled earlier.
graph_inputs = dict(
    product_embeddings=data["product"].x,                  # [N_products, 1024]
    product_to_segment_edge_index=data["product", "to", "segment"].edge_index,
    segment_to_family_edge_index=data["segment", "to", "family"].edge_index,
    family_to_class_edge_index=data["family", "to", "class"].edge_index,
    product_y=data["product"].y,                           # class labels
    product_train_mask=data["product"].train_mask,         # boolean mask
    product_test_mask=data["product"].test_mask,           # boolean mask
    num_segments=data["segment"].num_nodes,
    num_families=data["family"].num_nodes,
    num_classes=data["class"].num_nodes,
)

# Tunable settings. Adding `heads=<int>` here makes train_model build the attention
# network (HeteroAttnNet) instead of the GraphSAGE one.
hyperparams = dict(
    hidden_dim=128,       # hidden width
    num_layers=3,         # number of GNN layers
    lr=1e-3,
    weight_decay=1e-5,
    epochs=200,           # number of training epochs
    verbose=True,
)

encoder, model, best_state = train_model(**graph_inputs, **hyperparams)

In [None]:
# Show only the bookkeeping fields of the best checkpoint. The 'encoder' and 'model'
# entries hold complete weight tensors and would flood the cell output.
{key: best_state[key] for key in ("epoch", "test_acc")} if best_state is not None else None