In [1]:
import random
import numpy as np
import torch
import json
from tqdm import tqdm
from pathlib import Path
from utils import * 
import copy
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
import os
import csv
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [2]:
# Default paths
ROOT = Path("Amazon_products") # Root Amazon_products directory
TRAIN_DIR = ROOT / "train"
TEST_DIR = ROOT / "test"

TEST_CORPUS_PATH = os.path.join(TEST_DIR, "test_corpus.txt")  # product_id \t text
TRAIN_CORPUS_PATH = os.path.join(TRAIN_DIR, "train_corpus.txt")

CLASS_HIERARCHY_PATH = ROOT / "class_hierarchy.txt" 
CLASS_RELATED_PATH = ROOT / "class_related_keywords.txt" 
CLASS_PATH = ROOT / "classes.txt" 

SUBMISSION_PATH = "Submission/submission.csv"  # output file

# --- Constants ---
NUM_CLASSES = 531  # total number of classes (0–530)
MIN_LABELS = 1     # minimum number of labels per sample
MAX_LABELS = 3     # maximum number of labels per sample


In [3]:
# --- Load ---

""" 
1. Training corpus: 29,487 product reviews without class labels.
2. Classes: 531 product categories.
3. Class hierarchy: A taxonomy file that defines parent–child relationships among classes (each line represents one relation).
4. Class-related keywords: A list of keywords associated with each product class.
5. Test corpus: 19,658 product reviews for evaluation.
"""

def load_corpus(path):
    """Load test corpus into {id: text} dictionary."""
    id2text = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t", 1)
            if len(parts) == 2:
                id, text = parts
                id2text[id] = text
    return id2text

def load_multilabel(path):
    """Load multi-label data into {id: [labels]} dictionary."""
    id2labels = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) == 2:
                pid, label = parts
                pid = int(pid)
                label = int(label)

                if pid not in id2labels:
                    id2labels[pid] = []

                id2labels[pid].append(label)
    return id2labels

def load_class_keywords(path):
    """Load class keywords into {class_name: [keywords]} dictionary."""
    class2keywords = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if ":" not in line:
                continue
            classname, keywords = line.strip().split(":", 1)
            keyword_list = [kw.strip() for kw in keywords.split(",") if kw.strip()]
            class2keywords[classname] = keyword_list
    return class2keywords

id2text_test = load_corpus(TEST_CORPUS_PATH)
id_list_test = list(id2text_test.keys())

id2text_train = load_corpus(TRAIN_CORPUS_PATH)
id_list_train = list(id2text_train.keys())

id2class = load_corpus(CLASS_PATH)
class2hierarchy = load_multilabel(CLASS_HIERARCHY_PATH)
class2related = load_class_keywords(CLASS_RELATED_PATH)

# ======== Print ===========

print(len(id2class)) 
for i in range(10):
    print(i, ":", id2class[str(i)])

print()

print(len(id2text_test)) 
for i, (id, text) in enumerate(id2text_test.items()):
    if i >= 10: 
        break
    print(id, ":", text)

print()

print(len(id2text_train)) 
for i, (id, text) in enumerate(id2text_train.items()):
    if i >= 10: 
        break
    print(id, ":", text)

print()
print(len(class2hierarchy)) 
for i, (id, node) in enumerate(class2hierarchy.items()):
    if i >= 10: 
        break
    print(id, ":", node)


print()
print(len(class2related)) 
for i, (classp, text) in enumerate(class2related.items()):
    if i >= 10: 
        break
    print(classp, ":", text)


531
0 : grocery_gourmet_food
1 : meat_poultry
2 : jerky
3 : toys_games
4 : games
5 : puzzles
6 : jigsaw_puzzles
7 : board_games
8 : beverages
9 : juices

19658
0 : conair cs15tcs professional straight styles straightening iron woah ! sure this straightener looks like all the other crappy straightners in the world , but there 's a twist to this one ! it is my first straightner and i 've had it for about 7 months . i bought it only because i was desperate for a cheap straightener because my hair is very thick , long , wavy ! i 'm looking for a new straighner right now ... but until then this one is doing just fine . if it works for me , it will work for you !
1 : barbie ballet shoes icon doll i was looking round the toysrus website and found this cheap doll ! " wow " i said . so i got it and her body is painted on ! which is really cute ! , parents would n't you like to get a toy where you save yourself from picking up another barbie item from the floor ! well , make a cardboard danceflo

In [4]:
# Silver Labeling

# TF-IDF -> 1st Baseline (Complete)

from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
import re
import csv
from tqdm import tqdm

import re

def preprocess_text(text):
    """
    Clean text by removing special characters, collapsing spaces,
    converting to lowercase, and removing short tokens (<2 letters).
    """
    cleaned = re.sub(r"[>&]", " ", text)
    cleaned = re.sub(r"[^a-zA-Z0-9 ]", " ", cleaned)
    cleaned = cleaned.lower()
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    tokens = [t for t in cleaned.split() if len(t) >= 2]
    return " ".join(tokens)

def build_tfidf_vectorizer(label_texts):
    """
    Build and fit a TF-IDF vectorizer on label texts.

    Args:
        label_texts (list of str): A list of strings, each describing a label/category.

    Returns:
        vectorizer (TfidfVectorizer): The fitted TF-IDF vectorizer.
        label_tfidf (scipy.sparse.csr_matrix): TF-IDF matrix representation of label_texts.
    """
    vectorizer = TfidfVectorizer(
        max_features=5000,
        ngram_range=(1, 2),
        min_df=2,
        max_df=0.8,
        stop_words='english'
    )
    label_tfidf = vectorizer.fit_transform(label_texts)
    return vectorizer, label_tfidf

def compute_lexical_similarity(doc_text, vectorizer, label_tfidf):
    """
    Compute lexical similarity between a document and label texts using TF-IDF.

    Args:
        doc_text (str): The document (e.g., product description) as a string.
        vectorizer (TfidfVectorizer): The fitted TF-IDF vectorizer.
        label_tfidf (scipy.sparse.csr_matrix): TF-IDF matrix for label_texts.

    Returns:
        sims (numpy.ndarray): A 1D array of similarity scores for each label 
                              (e.g., cosine similarity values).
    """
    doc_vec = vectorizer.transform([doc_text])
    sims = cosine_similarity(doc_vec, label_tfidf)[0]
    return sims


def get_parents(class_id, class_hierarchy):
    """
    Get all parent classes of a given class.
    """
    parents = []
    for parent_id, children in class_hierarchy.items():
        if class_id in children:
            parents.append(parent_id)
    return parents


def select_labels(similarities, max_labels=3, margin=0.05):
    """
    Select 2 or 3 labels depending on the relative gap:
    
    Rule:
    - Look at top-3 scores
    - If score3 is much lower → only 2 labels
    - Else → 3 labels
    """
    ranked = np.argsort(similarities)[::-1]  # sort high→low
    top3 = ranked[:max_labels]
    
    s1, s2, s3 = similarities[top3[0]], similarities[top3[1]], similarities[top3[2]]

    if abs(s2 - s3) >= margin:
        selected = [top3[0], top3[1]]
    else:
        selected = list(top3)

    return [int(x) for x in selected]

def get_all_ancestors(cid, class_hierarchy):
    """Return all ancestors of a class ID using the class hierarchy."""
    ancestors = set()
    for parent, children in class_hierarchy.items():
        if cid in children:
            ancestors.add(parent)
            ancestors.update(get_all_ancestors(parent, class_hierarchy))
    return list(ancestors)


def select_labels_with_hierarchy(similarities, class_hierarchy, min_labels=2, max_labels=3, threshold=0.05):
    """
    Select 2-3 labels using TF-IDF scores + hierarchy expansion.
    
    Strategy:
    1. Find top-k most similar classes (k=1 or 2)
    2. Add their ancestors to respect hierarchy
    3. Keep only 2-3 most relevant labels
    """
    # Get top classes above threshold
    top_indices = np.argsort(similarities)[::-1]
    
    # Select core classes (top 1-2 with score > threshold)
    core_classes = []
    for idx in top_indices[:5]:  # Check top 5
        if similarities[idx] > threshold:
            core_classes.append(idx)
            if len(core_classes) >= 2:
                break
    
    # If no class above threshold, take top 2
    if len(core_classes) == 0:
        core_classes = list(top_indices[:2])
    
    # Expand with ancestors
    all_labels = set(core_classes)
    for core_id in core_classes:
        ancestors = get_all_ancestors(core_id, class_hierarchy)
        all_labels.update(ancestors)
    
    # Convert to list and score
    all_labels = list(all_labels)
    label_scores = [(label, similarities[label]) for label in all_labels]
    label_scores.sort(key=lambda x: x[1], reverse=True)
    
    # Select top 2-3
    selected = [label for label, _ in label_scores[:max_labels]]
    
    # Ensure minimum labels
    while len(selected) < min_labels:
        for idx in top_indices:
            if idx not in selected:
                selected.append(idx)
                break
    
    return sorted(selected[:max_labels])




In [5]:
# Reminder -> id2text_train, id_list_train, id2class, class2hierarchy, class2related
# TF-IDF BASELINE PIPELINE

def generate_tfidf_baseline(id2text_train, id2text_test, id2class, class2related):
    """Generate baseline predictions using TF-IDF similarity."""
    
    print(f"\nTrain: {len(id2text_train)} | Test: {len(id2text_test)} | Classes: {len(id2class)}")
    
    # 1. Prepare class descriptions -> add more context words
    class_texts = []
    for i in range(len(id2class)):
        class_name = id2class[str(i)]
        if class2related and class_name in class2related:
            text = class_name + " " + " ".join(class2related[class_name])
        else:
            text = class_name
        class_texts.append(preprocess_text(text))

    print(class_texts)
    
    # 2. Preprocess documents
    train_texts = [preprocess_text(txt) for txt in id2text_train.values()]
    test_texts = [preprocess_text(txt) for txt in id2text_test.values()]
    test_ids = list(id2text_test.keys())
    train_ids = list(id2text_train.keys())

    all_texts = train_texts + test_texts + class_texts # we take all texts for generalization

    vectorizer = TfidfVectorizer(
        max_features=5000,
        ngram_range=(1, 2),
        min_df=2,
        max_df=0.8,
        stop_words='english'
    )

    tfidf_all = vectorizer.fit_transform(all_texts)
    n_train = len(train_texts)
    n_test  = len(test_texts)

    train_tfidf = tfidf_all[:n_train]
    test_tfidf  = tfidf_all[n_train:n_train+n_test]
    class_tfidf = tfidf_all[n_train+n_test:]
    
    silver_train_sim = cosine_similarity(train_tfidf, class_tfidf)
    print("Lexical similarity train matrix:", silver_train_sim.shape)

    silver_test_sim = cosine_similarity(test_tfidf, class_tfidf)
    print("Lexical similarity test matrix:", silver_test_sim.shape)

    silver_train = {
        train_ids[i]: select_labels(silver_train_sim[i])
        for i in range(len(train_ids))
    }

    silver_test = {
        test_ids[i]: select_labels(silver_test_sim[i])
        for i in range(len(test_ids))
    }

    return silver_test, silver_train

# Main loop
silver_test1, silver_train1 = generate_tfidf_baseline(
    id2text_train, 
    id2text_test, 
    id2class, 
    class2related
)

# ===== Example =====
print("\nExample Silver Labels - TRAIN:")
for i, (pid, labels) in enumerate(list(silver_train1.items())[:5]):
    print(f"  {pid} -> {labels}")
    print((f"      {pid} -> {id2text_train[pid]}")) 
    for label in labels:
        print((f"      {label} -> {id2class[str(label)]}"))

print("\nExample Silver Labels - TEST:")
for i, (pid, labels) in enumerate(list(silver_test1.items())[:5]):
    print(f"  {pid} -> {labels}")



Train: 29487 | Test: 19658 | Classes: 531
['grocery gourmet food snacks condiments beverages specialty foods spices cooking oils baking ingredients gourmet chocolates artisanal cheeses organic foods', 'meat poultry butcher cuts marination grilling roasting seasoning halal organic deli marbling', 'jerky beef turkey chicken venison buffalo kangaroo elk ostrich bison spicy', 'toys games board games puzzles action figures building blocks dolls outdoor toys educational toys card games remote control toys plush toys', 'games board games card games tabletop games party games roleplaying games video games strategy games family games word games dice games', 'puzzles jigsaw puzzles brain teasers puzzle accessories puzzle storage puzzle mats puzzle glue puzzle organizers puzzle books puzzle magazines puzzle competitions', 'jigsaw puzzles interlocking pieces puzzle boards puzzle glue puzzle storage puzzle frames puzzle rolls puzzle organizers puzzle tables puzzle sleeves puzzle sorting trays', 'b

In [7]:
# TF-IDF BASELINE + HIERARCHY PIPELINE

def generate_tfidf_baseline_hierarchy(id2text_train, id2text_test, id2class, class2related, class_hierarchy):
    """Generate TF-IDF + hierarchy silver labels."""
    
    print(f"\nTrain: {len(id2text_train)} | Test: {len(id2text_test)} | Classes: {len(id2class)}")

    # Class descriptions with keywords
    class_texts = []
    for i in range(len(id2class)):
        cname = id2class[str(i)]
        if class2related and cname in class2related:
            desc = cname + " " + " ".join(class2related[cname])
        else:
            desc = cname
        class_texts.append(preprocess_text(desc))

    train_texts = [preprocess_text(txt) for txt in id2text_train.values()]
    test_texts = [preprocess_text(txt) for txt in id2text_test.values()]
    train_ids = list(id2text_train.keys())
    test_ids = list(id2text_test.keys())

    all_texts = train_texts + test_texts + class_texts

    vectorizer = TfidfVectorizer(
        max_features=5000,
        ngram_range=(1, 2),
        min_df=2,
        max_df=0.8,
        stop_words='english'
    )

    tfidf_all = vectorizer.fit_transform(all_texts)
    n_train, n_test = len(train_texts), len(test_texts)

    train_tfidf = tfidf_all[:n_train]
    test_tfidf = tfidf_all[n_train:n_train+n_test]
    class_tfidf = tfidf_all[n_train+n_test:]

    sim_train = cosine_similarity(train_tfidf, class_tfidf)
    sim_test = cosine_similarity(test_tfidf, class_tfidf)

    print("\nGenerating silver labels (train + test) ...")
    silver_train = {
        train_ids[i]: select_labels_with_hierarchy(sim_train[i], class_hierarchy)
        for i in tqdm(range(len(train_ids)))
    }
    silver_test = {
        test_ids[i]: select_labels_with_hierarchy(sim_test[i], class_hierarchy)
        for i in tqdm(range(len(test_ids)))
    }

    return silver_test, silver_train


# Main exec
silver_test2, silver_train2 = generate_tfidf_baseline_hierarchy(
    id2text_train,
    id2text_test,
    id2class,
    class2related,
    class2hierarchy
)

print("\nSilver Labels - TRAIN examples:")
for pid in list(silver_train2.keys())[:5]:
    labels = silver_train2[pid]
    names = [id2class[str(l)] for l in labels]
    print(f"{pid}: {names}")

print("\nSilver Labels - TEST examples:")
for pid in list(silver_test2.keys())[:5]:
    labels = silver_test2[pid]
    names = [id2class[str(l)] for l in labels]
    print(f"{pid}: {names}")


Train: 29487 | Test: 19658 | Classes: 531

Generating silver labels (train + test) ...


100%|██████████| 29487/29487 [00:02<00:00, 14491.98it/s]
100%|██████████| 19658/19658 [00:01<00:00, 14447.18it/s]



Silver Labels - TRAIN examples:
0: ['health_care', 'stress_reduction', 'women_s_health']
1: ['candy_chocolate', 'chocolate_bars', 'chocolate_gifts']
2: ['snack_food', 'bars', 'granola_bars']
3: ['beauty', 'hair_care', 'styling_products']
4: ['snack_food', 'bars', 'granola_bars']

Silver Labels - TEST examples:
0: ['hair_care', 'styling_products', 'hair_relaxers']
1: ['toys_games', 'dolls_accessories', 'doll_accessories']
2: ['baby_products', 'aquarium_lights', 'diaper_stackers_caddies']
3: ['gourmet_gifts', 'oils', 'cheese_gifts']
4: ['beauty', 'fragrance', 'children_s']


In [None]:
# Stats

def label_stats(name, silver):
    counts = [len(v) for v in silver.values()]
    print(f"\n{name}")
    print(f"  Documents: {len(counts)}")
    print(f"  Avg labels/doc: {np.mean(counts):.2f}")
    print(f"  Min labels: {np.min(counts)}")
    print(f"  Max labels: {np.max(counts)}")

def compute_overlap(silver_A, silver_B):
    same = sum(1 for k in silver_A if set(silver_A[k]) == set(silver_B[k]))
    return 100 * same / len(silver_A)

def hierarchy_consistency(silver, hierarchy):
    ok = total = 0
    for labels in silver.values():
        L = set(labels)
        for parent, children in hierarchy.items():
            for child in children:
                if child in L:
                    total += 1
                    if parent in L:
                        ok += 1
    return ok / total if total > 0 else 0


print("\n================= TRAIN SET =================")

label_stats("Baseline Train (no hierarchy)", silver_train1)
label_stats("Hierarchy Train", silver_train2)

train_overlap = compute_overlap(silver_train1, silver_train2)
print(f"\nOverlap between the 2 methods for Train: {train_overlap:.2f}%")

train_consistency_no = hierarchy_consistency(silver_train1, class2hierarchy)
train_consistency_yes = hierarchy_consistency(silver_train2, class2hierarchy)
print(f"Hierarchy Consistency Train:")
print(f"  No Hierarchy: {train_consistency_no:.4f}")
print(f"  With Hierarchy: {train_consistency_yes:.4f}")




Baseline Train (no hierarchy)
  Documents: 29487
  Avg labels/doc: 2.79
  Min labels: 2
  Max labels: 3

Hierarchy Train
  Documents: 29487
  Avg labels/doc: 2.99
  Min labels: 2
  Max labels: 3

Overlap Train: 6.23%
Hierarchy Consistency Train:
  No Hierarchy: 0.0839
  With Hierarchy: 0.3468


Integrating the hierarchy into the generation of silver labels significantly improves the taxonomic consistency of predictions (from ~8% to ~34%) while maintaining a number of labels per document consistent with real data, making it a more reliable basis for self-supervised learning.

In [None]:
# BERT fine-tuning

""" Our TaxoClass framework consists of four major
 steps: (1) document-class similarity calculation, (2)
 document core class mining, (3) core class guided
 classifier training, and (4) multi-label self-training.
 Fig. 2 shows our framework overview and below
 sections discuss each step in more details.

 weuseRoberta-Large-MNLI1as
 ourtextualentailmentmodelwhichutilizesthepre
trainedRoberta-Largeasitsbackboneandis
 fine-tunedontheMNLIdataset.

  DocumentEncoder. In thiswork, we instan
tiateour document encoder gdoc
 () tobeapre
trainedBERT-base-uncased

Forclassencodergclass
 (),wefol
low(Shenetal.,2020)anduseagraphneuralnet
work(GNN)

✅ Utiliser les données fournies uniquement
✅ Faire du Hierarchical Multi-Label Classification
✅ Exploiter l’hiérarchie des classes
✅ Tu peux utiliser silver labels générés automatiquement
✅ Tu peux utiliser un modèle pré-entraîné (BERT)
✅ Utiliser maximum 1000 appels LLM (si tu veux aider avec LLM, pas obligatoire)

We fine-tune BERT using the silver labels produced by TF-IDF similarity with hierarchical expansion.
This enables learning a semantic multi-label classifier without ground-truth labels.

BERT-based classifier

The paper uses LLMs only for these parts:

Step	Purpose	Requires LLM?	Paper Section
(A) Enrich class information	Generate keywords for each class	✅ Yes (few prompts)	Sec. 3.1 improvement
(B) Get more accurate labels for hard cases	Filter candidate labels	✅ Yes but few calls	Sec. 3.4 (Self-training help)
(C) Optional iterative correction	Improve hard samples gradually	✅ Yes (but limited)	Sec. 3.3–3.4

✅ Quand est-ce qu’on utilise le LLM ?

Tu ne l’utilises pas pour tout classer.

Tu l’utilises uniquement pour deux choses :

1️⃣ Améliorer les descriptions des classes
→ Générer des mots-clés supplémentaires pour chaque catégorie
(Ex : “baby cereal → céréale bébé, sans gluten, iron-fortified…”)

2️⃣ Corriger les cas difficiles
→ Quand ton modèle (BERT) est peu confiant, tu demandes au LLM de vérifier/améliorer les labels

✅ Donc usage limité & intelligent
✅ Reste bien dans le quota de 1000 appels
"""

In [None]:
"""# --- Generate random predictions ---
all_pids, all_labels = [], []
for pid in tqdm(id_list_test, desc="Generating dummy predictions"):
    n_labels = random.randint(MIN_LABELS, MAX_LABELS)
    labels = random.sample(range(NUM_CLASSES), n_labels)
    labels = sorted(labels)
    all_pids.append(pid)
    all_labels.append(labels)

# --- Save submission file ---
with open(SUBMISSION_PATH, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["pid", "labels"])
    for pid, labels in zip(all_pids, all_labels):
        writer.writerow([pid, ",".join(map(str, labels))])

print(f"Dummy submission file saved to: {SUBMISSION_PATH}")
print(f"Total samples: {len(all_pids)}, Classes per sample: {MIN_LABELS}-{MAX_LABELS}")"""