In [1]:
import random
import numpy as np
import torch
import json
import os
from tqdm import tqdm
from pathlib import Path
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import copy

# Seed pour reproductibilité
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

Device: cuda


In [2]:
ROOT = Path("Amazon_products")
TRAIN_DIR = ROOT / "train"
TEST_DIR = ROOT / "test"

TEST_CORPUS_PATH = TEST_DIR / "test_corpus.txt"
TRAIN_CORPUS_PATH = TRAIN_DIR / "train_corpus.txt"

CLASS_HIERARCHY_PATH = ROOT / "class_hierarchy.txt"
CLASS_RELATED_PATH = ROOT / "class_related_keywords.txt"
CLASS_PATH = ROOT / "classes.txt"

SUBMISSION_PATH = "Submission/submission.csv"

NUM_CLASSES = 531
MIN_LABELS = 2
MAX_LABELS = 3


In [3]:
def load_corpus(path):
    """Load corpus into {id: text} dictionary."""
    id2text = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t", 1)
            if len(parts) == 2:
                id, text = parts
                id2text[id] = text
    return id2text

def load_multilabel(path):
    """Load multi-label data into {id: [labels]} dictionary."""
    id2labels = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) == 2:
                pid, label = parts
                pid = int(pid)
                label = int(label)
                if pid not in id2labels:
                    id2labels[pid] = []
                id2labels[pid].append(label)
    return id2labels

def load_class_keywords(path):
    """Load class keywords into {class_name: [keywords]} dictionary."""
    class2keywords = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if ":" not in line:
                continue
            classname, keywords = line.strip().split(":", 1)
            keyword_list = [kw.strip() for kw in keywords.split(",") if kw.strip()]
            class2keywords[classname] = keyword_list
    return class2keywords


In [4]:
id2text_test = load_corpus(TEST_CORPUS_PATH)
id2text_train = load_corpus(TRAIN_CORPUS_PATH)

# Classes
id2class = load_corpus(CLASS_PATH)
class2hierarchy = load_multilabel(CLASS_HIERARCHY_PATH)
class2related = load_class_keywords(CLASS_RELATED_PATH)

# Silver labels (RoBERTa - les meilleurs)
with open("Silver/silver_train_roberta.json", "r") as f:
    pid2labelids_silver = json.load(f)
    

print(f"Train: {len(id2text_train)} samples")
print(f"Test: {len(id2text_test)} samples")
print(f"Classes: {len(id2class)}")

Train: 29487 samples
Test: 19658 samples
Classes: 531


In [5]:
# Embeddings
X_train = torch.load("Embeddings/X_train.pt").to(device)
X_test = torch.load("Embeddings/X_test.pt").to(device)
label_emb = torch.load("Embeddings/label_emb.pt").to(device)
test_ids = list(id2text_test.keys())
train_ids = list(id2text_train.keys())

print(f"Train embeddings: {X_train.shape}")
print(f"Test embeddings: {X_test.shape}")
print(f"Label embeddings: {label_emb.shape}")

# Index mapping
pid2idx = {pid: i for i, pid in enumerate(train_ids)}

input_dim = X_train.size(1)
num_classes = NUM_CLASSES

print(f"Input dimension: {input_dim}")
print(f"Num classes: {num_classes}")

Train embeddings: torch.Size([29487, 768])
Test embeddings: torch.Size([19658, 768])
Label embeddings: torch.Size([531, 768])
Input dimension: 768
Num classes: 531


In [6]:
class ProductCategoryDataset(Dataset):
    """Dataset using pre-calculated embeddings (train or test compatible)"""
    def __init__(self, pid2label, pid2idx, embeddings, num_classes=531):
        self.pid2label = pid2label
        self.pid2idx = pid2idx
        self.embeddings = embeddings
        self.num_classes = num_classes

        if self.pid2label is not None:
            self.pids = list(pid2label.keys())
            self.has_labels = True
        else:
            self.pids = list(pid2idx.keys())
            self.has_labels = False

        self.indices = [pid2idx[pid] for pid in self.pids]

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        emb = self.embeddings[self.indices[idx]]

        if self.has_labels:
            y = torch.zeros(self.num_classes, dtype=torch.float32)
            for label in self.pid2label[self.pids[idx]]:
                y[label] = 1.0
            return {"X": emb, "y": y}
        else:
            return {"X": emb}


In [7]:
def evaluate(model, dataloader, device="cpu", threshold=0.5):
    """Évalue le modèle sur un dataloader"""
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            X = batch["X"].to(device)
            y = batch["y"].cpu().numpy()
            
            logits = model(X)
            probs = torch.sigmoid(logits).cpu().numpy()
            preds = (probs > threshold).astype(int)
            
            all_preds.append(preds)
            all_labels.append(y)

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    f1_macro = f1_score(all_labels, all_preds, average="macro", zero_division=0)
    f1_micro = f1_score(all_labels, all_preds, average="micro", zero_division=0)
    
    return {"f1_macro": f1_macro, "f1_micro": f1_micro}

In [8]:
class MLPClassifier(nn.Module):
    """Simple MLP for classification multi-label"""
    def __init__(self, input_dim, num_classes, hidden_dim=512, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        return self.net(x)


In [9]:
# Split train/val
silver_pids = list(pid2labelids_silver.keys())
train_pids, val_pids = train_test_split(silver_pids, test_size=0.2, random_state=42)

train_labels = {pid: pid2labelids_silver[pid] for pid in train_pids}
val_labels = {pid: pid2labelids_silver[pid] for pid in val_pids}

print(f"Train: {len(train_labels)} | Val: {len(val_labels)}")

# Datasets
train_dataset = ProductCategoryDataset(train_labels, pid2idx, X_train, num_classes=NUM_CLASSES)
val_dataset = ProductCategoryDataset(val_labels, pid2idx, X_train, num_classes=NUM_CLASSES)

test_dataset = ProductCategoryDataset(
    None,
    {pid: i for i, pid in enumerate(test_ids)},
    X_test,
    num_classes=NUM_CLASSES
)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

Train: 23589 | Val: 5898


In [10]:
import copy
import torch.nn.functional as F
from tqdm import tqdm

# === Init Models ===
student = MLPClassifier(input_dim, num_classes).to(device)
teacher = copy.deepcopy(student).to(device)

optimizer = torch.optim.AdamW(student.parameters(), lr=2e-4, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
criterion = nn.BCEWithLogitsLoss()

# === Self-training parameters ===
alpha_ema = 0.995                    # EMA update factor
lambda_cons = 0.5                    # weight for consistency loss
pseudo_update_freq = 3               # update pseudo-labels every N epochs
pseudo_threshold = 0.9               # threshold for teacher pseudo-labels
patience = 5
EPOCHS = 50

best_f1 = 0.0
best_model = copy.deepcopy(student.state_dict())
wait = 0

def ema_update(teacher, student, alpha):
    """Exponential Moving Average update for teacher"""
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.data = alpha * t_param.data + (1 - alpha) * s_param.data

def consistency_loss(logits_s, logits_t):
    """MSE between student and teacher predictions"""
    return F.mse_loss(torch.sigmoid(logits_s), torch.sigmoid(logits_t))

def generate_pseudo_labels(model, loader, threshold=0.7):
    """Generate high-confidence pseudo-labels from teacher"""
    model.eval()
    pseudo_dict = {}
    
    with torch.no_grad():
        for batch in tqdm(loader, desc="Generating pseudo-labels", leave=False):
            X = batch["X"].to(device)
            
            logits = model(X)
            probs = torch.sigmoid(logits).cpu().numpy()
            
            # For each sample in batch
            for i, prob in enumerate(probs):
                # Only keep high-confidence predictions
                confident_labels = [j for j, p in enumerate(prob) if p > threshold]
                
                # Only add if we have 2-3 labels (reasonable for this task)
                if 2 <= len(confident_labels) <= 3:
                    pass 
    
    return pseudo_dict

# === TRAINING LOOP ===
for epoch in range(1, EPOCHS + 1):
    print(f"\n=== Epoch {epoch}/{EPOCHS} ===")

    student.train()
    teacher.eval()
    total_loss = 0.0
    total_sup = 0.0
    total_cons = 0.0

    for batch in tqdm(train_loader, desc=f"Train Epoch {epoch}", leave=False):
        X = batch["X"].to(device)
        y = batch["y"].to(device)

        # --- Student forward ---
        logits_s = student(X)
        
        # --- Teacher forward (for consistency) ---
        with torch.no_grad():
            logits_t = teacher(X)

        # --- Loss ---
        loss_sup = criterion(logits_s, y)
        loss_cons = consistency_loss(logits_s, logits_t)
        loss = loss_sup + lambda_cons * loss_cons

        # --- Backprop ---
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # --- Update teacher with EMA ---
        ema_update(teacher, student, alpha_ema)

        total_loss += loss.item()
        total_sup += loss_sup.item()
        total_cons += loss_cons.item()

    scheduler.step()

    # --- Print training stats ---
    avg_loss = total_loss / len(train_loader)
    avg_sup = total_sup / len(train_loader)
    avg_cons = total_cons / len(train_loader)
    
    print(f"Loss: {avg_loss:.4f} (sup: {avg_sup:.4f}, cons: {avg_cons:.4f})")

    # --- Validation ---
    val_metrics = evaluate(teacher, val_loader, device)
    print(f"Val F1-macro={val_metrics['f1_macro']:.4f} | "
          f"F1-micro={val_metrics['f1_micro']:.4f}")

    # --- Save best model ---
    if val_metrics['f1_macro'] > best_f1:
        best_f1 = val_metrics['f1_macro']
        best_model = copy.deepcopy(teacher.state_dict())
        print(f"New best model (F1-macro={best_f1:.4f})")
        wait = 0
    else:
        wait += 1
        print(f"No improvement: {wait}/{patience}")
        if wait >= patience:
            print("Early stopping triggered")
            break

# === Load best model ===
teacher.load_state_dict(best_model)

print("\n" + "="*50)
print("FINAL EVALUATION")
print("="*50)

# Test on validation set (since we don't have true test labels)
test_result = evaluate(teacher, val_loader, device)
print(f"Best Val F1-macro: {test_result['f1_macro']:.4f}")
print(f"Best Val F1-micro: {test_result['f1_micro']:.4f}")


=== Epoch 1/50 ===


                                                                 

Loss: 0.1097 (sup: 0.0847, cons: 0.0499)
Val F1-macro=0.0013 | F1-micro=0.2609
New best model (F1-macro=0.0013)

=== Epoch 2/50 ===


                                                                 

Loss: 0.0173 (sup: 0.0172, cons: 0.0002)
Val F1-macro=0.0014 | F1-micro=0.2490
New best model (F1-macro=0.0014)

=== Epoch 3/50 ===


                                                                 

Loss: 0.0165 (sup: 0.0164, cons: 0.0001)
Val F1-macro=0.0017 | F1-micro=0.2518
New best model (F1-macro=0.0017)

=== Epoch 4/50 ===


                                                                 

Loss: 0.0159 (sup: 0.0159, cons: 0.0001)
Val F1-macro=0.0026 | F1-micro=0.3057
New best model (F1-macro=0.0026)

=== Epoch 5/50 ===


                                                                 

Loss: 0.0154 (sup: 0.0153, cons: 0.0001)
Val F1-macro=0.0033 | F1-micro=0.3470
New best model (F1-macro=0.0033)

=== Epoch 6/50 ===


                                                                 

Loss: 0.0149 (sup: 0.0149, cons: 0.0001)
Val F1-macro=0.0042 | F1-micro=0.3867
New best model (F1-macro=0.0042)

=== Epoch 7/50 ===


                                                                 

Loss: 0.0145 (sup: 0.0145, cons: 0.0001)
Val F1-macro=0.0052 | F1-micro=0.4249
New best model (F1-macro=0.0052)

=== Epoch 8/50 ===


                                                                 

Loss: 0.0141 (sup: 0.0141, cons: 0.0001)
Val F1-macro=0.0059 | F1-micro=0.4569
New best model (F1-macro=0.0059)

=== Epoch 9/50 ===


                                                                 

Loss: 0.0137 (sup: 0.0137, cons: 0.0001)
Val F1-macro=0.0064 | F1-micro=0.4784
New best model (F1-macro=0.0064)

=== Epoch 10/50 ===


                                                                  

Loss: 0.0135 (sup: 0.0134, cons: 0.0001)
Val F1-macro=0.0068 | F1-micro=0.4930
New best model (F1-macro=0.0068)

=== Epoch 11/50 ===


                                                                  

Loss: 0.0132 (sup: 0.0132, cons: 0.0001)
Val F1-macro=0.0073 | F1-micro=0.5034
New best model (F1-macro=0.0073)

=== Epoch 12/50 ===


                                                                  

Loss: 0.0131 (sup: 0.0130, cons: 0.0001)
Val F1-macro=0.0075 | F1-micro=0.5088
New best model (F1-macro=0.0075)

=== Epoch 13/50 ===


                                                                  

Loss: 0.0129 (sup: 0.0129, cons: 0.0001)
Val F1-macro=0.0078 | F1-micro=0.5139
New best model (F1-macro=0.0078)

=== Epoch 14/50 ===


                                                                  

Loss: 0.0128 (sup: 0.0128, cons: 0.0001)
Val F1-macro=0.0079 | F1-micro=0.5179
New best model (F1-macro=0.0079)

=== Epoch 15/50 ===


                                                                  

Loss: 0.0127 (sup: 0.0127, cons: 0.0001)
Val F1-macro=0.0080 | F1-micro=0.5198
New best model (F1-macro=0.0080)

=== Epoch 16/50 ===


                                                                  

Loss: 0.0126 (sup: 0.0126, cons: 0.0001)
Val F1-macro=0.0081 | F1-micro=0.5216
New best model (F1-macro=0.0081)

=== Epoch 17/50 ===


                                                                  

Loss: 0.0126 (sup: 0.0125, cons: 0.0001)
Val F1-macro=0.0082 | F1-micro=0.5230
New best model (F1-macro=0.0082)

=== Epoch 18/50 ===


                                                                  

Loss: 0.0126 (sup: 0.0125, cons: 0.0001)
Val F1-macro=0.0082 | F1-micro=0.5244
New best model (F1-macro=0.0082)

=== Epoch 19/50 ===


                                                                  

Loss: 0.0125 (sup: 0.0125, cons: 0.0001)
Val F1-macro=0.0082 | F1-micro=0.5244
No improvement: 1/5

=== Epoch 20/50 ===


                                                                  

Loss: 0.0125 (sup: 0.0125, cons: 0.0001)
Val F1-macro=0.0082 | F1-micro=0.5242
New best model (F1-macro=0.0082)

=== Epoch 21/50 ===


                                                                  

Loss: 0.0125 (sup: 0.0125, cons: 0.0001)
Val F1-macro=0.0082 | F1-micro=0.5243
No improvement: 1/5

=== Epoch 22/50 ===


                                                                  

Loss: 0.0125 (sup: 0.0125, cons: 0.0001)
Val F1-macro=0.0082 | F1-micro=0.5243
No improvement: 2/5

=== Epoch 23/50 ===


                                                                  

Loss: 0.0125 (sup: 0.0125, cons: 0.0001)
Val F1-macro=0.0083 | F1-micro=0.5246
New best model (F1-macro=0.0083)

=== Epoch 24/50 ===


                                                                  

Loss: 0.0125 (sup: 0.0125, cons: 0.0001)
Val F1-macro=0.0083 | F1-micro=0.5249
New best model (F1-macro=0.0083)

=== Epoch 25/50 ===


                                                                  

Loss: 0.0125 (sup: 0.0124, cons: 0.0001)
Val F1-macro=0.0083 | F1-micro=0.5257
New best model (F1-macro=0.0083)

=== Epoch 26/50 ===


                                                                  

Loss: 0.0124 (sup: 0.0124, cons: 0.0001)
Val F1-macro=0.0085 | F1-micro=0.5271
New best model (F1-macro=0.0085)

=== Epoch 27/50 ===


                                                                  

Loss: 0.0124 (sup: 0.0123, cons: 0.0001)
Val F1-macro=0.0086 | F1-micro=0.5290
New best model (F1-macro=0.0086)

=== Epoch 28/50 ===


                                                                  

Loss: 0.0123 (sup: 0.0122, cons: 0.0001)
Val F1-macro=0.0087 | F1-micro=0.5312
New best model (F1-macro=0.0087)

=== Epoch 29/50 ===


                                                                  

Loss: 0.0122 (sup: 0.0121, cons: 0.0001)
Val F1-macro=0.0088 | F1-micro=0.5343
New best model (F1-macro=0.0088)

=== Epoch 30/50 ===


                                                                  

Loss: 0.0120 (sup: 0.0120, cons: 0.0001)
Val F1-macro=0.0090 | F1-micro=0.5374
New best model (F1-macro=0.0090)

=== Epoch 31/50 ===


                                                                  

Loss: 0.0119 (sup: 0.0119, cons: 0.0001)
Val F1-macro=0.0094 | F1-micro=0.5417
New best model (F1-macro=0.0094)

=== Epoch 32/50 ===


                                                                  

Loss: 0.0117 (sup: 0.0117, cons: 0.0001)
Val F1-macro=0.0098 | F1-micro=0.5460
New best model (F1-macro=0.0098)

=== Epoch 33/50 ===


                                                                  

Loss: 0.0116 (sup: 0.0115, cons: 0.0001)
Val F1-macro=0.0102 | F1-micro=0.5509
New best model (F1-macro=0.0102)

=== Epoch 34/50 ===


                                                                  

Loss: 0.0114 (sup: 0.0114, cons: 0.0001)
Val F1-macro=0.0106 | F1-micro=0.5565
New best model (F1-macro=0.0106)

=== Epoch 35/50 ===


                                                                  

Loss: 0.0112 (sup: 0.0112, cons: 0.0001)
Val F1-macro=0.0111 | F1-micro=0.5613
New best model (F1-macro=0.0111)

=== Epoch 36/50 ===


                                                                  

Loss: 0.0111 (sup: 0.0111, cons: 0.0001)
Val F1-macro=0.0116 | F1-micro=0.5662
New best model (F1-macro=0.0116)

=== Epoch 37/50 ===


                                                                  

Loss: 0.0109 (sup: 0.0109, cons: 0.0001)
Val F1-macro=0.0122 | F1-micro=0.5715
New best model (F1-macro=0.0122)

=== Epoch 38/50 ===


                                                                  

Loss: 0.0108 (sup: 0.0108, cons: 0.0001)
Val F1-macro=0.0127 | F1-micro=0.5757
New best model (F1-macro=0.0127)

=== Epoch 39/50 ===


                                                                  

Loss: 0.0106 (sup: 0.0106, cons: 0.0001)
Val F1-macro=0.0132 | F1-micro=0.5800
New best model (F1-macro=0.0132)

=== Epoch 40/50 ===


                                                                  

Loss: 0.0105 (sup: 0.0105, cons: 0.0001)
Val F1-macro=0.0137 | F1-micro=0.5843
New best model (F1-macro=0.0137)

=== Epoch 41/50 ===


                                                                  

Loss: 0.0104 (sup: 0.0104, cons: 0.0001)
Val F1-macro=0.0141 | F1-micro=0.5873
New best model (F1-macro=0.0141)

=== Epoch 42/50 ===


                                                                  

Loss: 0.0103 (sup: 0.0103, cons: 0.0001)
Val F1-macro=0.0147 | F1-micro=0.5913
New best model (F1-macro=0.0147)

=== Epoch 43/50 ===


                                                                  

Loss: 0.0102 (sup: 0.0102, cons: 0.0001)
Val F1-macro=0.0153 | F1-micro=0.5947
New best model (F1-macro=0.0153)

=== Epoch 44/50 ===


                                                                  

Loss: 0.0101 (sup: 0.0101, cons: 0.0001)
Val F1-macro=0.0159 | F1-micro=0.5973
New best model (F1-macro=0.0159)

=== Epoch 45/50 ===


                                                                  

Loss: 0.0100 (sup: 0.0100, cons: 0.0001)
Val F1-macro=0.0165 | F1-micro=0.6007
New best model (F1-macro=0.0165)

=== Epoch 46/50 ===


                                                                  

Loss: 0.0100 (sup: 0.0099, cons: 0.0001)
Val F1-macro=0.0169 | F1-micro=0.6032
New best model (F1-macro=0.0169)

=== Epoch 47/50 ===


                                                                  

Loss: 0.0099 (sup: 0.0099, cons: 0.0001)
Val F1-macro=0.0173 | F1-micro=0.6049
New best model (F1-macro=0.0173)

=== Epoch 48/50 ===


                                                                  

Loss: 0.0098 (sup: 0.0098, cons: 0.0001)
Val F1-macro=0.0176 | F1-micro=0.6069
New best model (F1-macro=0.0176)

=== Epoch 49/50 ===


                                                                  

Loss: 0.0098 (sup: 0.0098, cons: 0.0001)
Val F1-macro=0.0176 | F1-micro=0.6079
New best model (F1-macro=0.0176)

=== Epoch 50/50 ===


                                                                  

Loss: 0.0098 (sup: 0.0097, cons: 0.0001)
Val F1-macro=0.0179 | F1-micro=0.6083
New best model (F1-macro=0.0179)

FINAL EVALUATION
Best Val F1-macro: 0.0179
Best Val F1-micro: 0.6083


In [11]:
# Save the best model
torch.save(teacher.state_dict(), "Model/mlp_selftraining_best.pt")
print("\nBest model saved to Models/mlp_selftraining_best.pt")


Best model saved to Models/mlp_selftraining_best.pt


In [13]:
import csv, os
from tqdm import tqdm
import torch
import numpy as np

THRESHOLD = 0.5
OUTPUT_PATH = "Submission/MLP.csv"
os.makedirs("Submission", exist_ok=True)

all_pids, all_pred_labels = [], []

with torch.no_grad():
    for start in tqdm(range(0, len(X_test), 64), desc="Generating predictions"):
        end = start + 64
        batch = X_test[start:end]
        batch_pids = test_ids[start:end]

        logits = teacher(batch)
        probs = torch.sigmoid(logits).cpu().numpy()

        for pid, prob in zip(batch_pids, probs):
            pred_row = (prob > THRESHOLD).astype(int)

            if pred_row.sum() == 0:
                pred_row[prob.argmax()] = 1

            labels = [str(j) for j, v in enumerate(pred_row) if v == 1]

            all_pids.append(pid)
            all_pred_labels.append(labels)

print(f"{len(all_pids)} samples generated.")

with open(OUTPUT_PATH, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["id", "label"])
    for pid, labels in zip(all_pids, all_pred_labels):
        writer.writerow([pid, ",".join(labels)])

print(f"Submission file saved: {OUTPUT_PATH}")


Generating predictions: 100%|██████████| 308/308 [00:01<00:00, 237.38it/s]

19658 samples generated.
Submission file saved: Submission/MLP.csv



