In [26]:
import random
import numpy as np
import torch
import json
from tqdm import tqdm
from pathlib import Path
from utils import * 
import copy
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
import os
import csv
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

In [27]:
# Default paths
ROOT = Path("Amazon_products") # Root Amazon_products directory
TRAIN_DIR = ROOT / "train"
TEST_DIR = ROOT / "test"

TEST_CORPUS_PATH = os.path.join(TEST_DIR, "test_corpus.txt")  # product_id \t text
TRAIN_CORPUS_PATH = os.path.join(TRAIN_DIR, "train_corpus.txt")

CLASS_HIERARCHY_PATH = ROOT / "class_hierarchy.txt" 
CLASS_RELATED_PATH = ROOT / "class_related_keywords.txt" 
CLASS_PATH = ROOT / "classes.txt" 

SUBMISSION_PATH = "Submission/submission.csv"  # output file

# --- Constants ---
NUM_CLASSES = 531  # total number of classes (0–530)
MIN_LABELS = 1     # minimum number of labels per sample
MAX_LABELS = 3     # maximum number of labels per sample


In [28]:
# Load Data
# Default paths
ROOT = Path("Amazon_products") # Root Amazon_products directory
TRAIN_DIR = ROOT / "train"
TEST_DIR = ROOT / "test"

TEST_CORPUS_PATH = os.path.join(TEST_DIR, "test_corpus.txt")  # product_id \t text
TRAIN_CORPUS_PATH = os.path.join(TRAIN_DIR, "train_corpus.txt")

CLASS_HIERARCHY_PATH = ROOT / "class_hierarchy.txt" 
CLASS_RELATED_PATH = ROOT / "class_related_keywords.txt" 
CLASS_PATH = ROOT / "classes.txt" 

SUBMISSION_PATH = "Submission/submission.csv"  # output file

# --- Constants ---
NUM_CLASSES = 531  # total number of classes (0–530)
MIN_LABELS = 1     # minimum number of labels per sample
MAX_LABELS = 3     # maximum number of labels per sample

# Load corpus
def load_corpus(path):
    """Load test corpus into {id: text} dictionary."""
    id2text = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t", 1)
            if len(parts) == 2:
                id, text = parts
                id2text[id] = text
    return id2text

def load_multilabel(path):
    """Load multi-label data into {id: [labels]} dictionary."""
    id2labels = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t")
            if len(parts) == 2:
                pid, label = parts
                pid = int(pid)
                label = int(label)

                if pid not in id2labels:
                    id2labels[pid] = []

                id2labels[pid].append(label)
    return id2labels

def load_class_keywords(path):
    """Load class keywords into {class_name: [keywords]} dictionary."""
    class2keywords = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if ":" not in line:
                continue
            classname, keywords = line.strip().split(":", 1)
            keyword_list = [kw.strip() for kw in keywords.split(",") if kw.strip()]
            class2keywords[classname] = keyword_list
    return class2keywords

id2text_test = load_corpus(TEST_CORPUS_PATH)
id2text_train = load_corpus(TRAIN_CORPUS_PATH)

id2class = load_corpus(CLASS_PATH)
class2hierarchy = load_multilabel(CLASS_HIERARCHY_PATH)
class2related = load_class_keywords(CLASS_RELATED_PATH)

# Load silver labels
with open("Silver/silver_train_roberta.json", "r") as f:
    pid2labelids_silver = json.load(f)

with open("Silver/silver_test_roberta.json", "r") as f:
    pid2labelids_test = json.load(f)

print(f"Train: {len(id2text_train)} samples")
print(f"Test: {len(id2text_test)} samples")
print(f"Classes: {len(id2class)}")


Train: 29487 samples
Test: 19658 samples
Classes: 531


In [29]:
# === Paths ===
X_TRAIN_PATH = "Embeddings/X_train.pt"
Y_TRAIN_PATH = "Embeddings/y_train.pt"
X_TEST_PATH  = "Embeddings/X_test.pt"
TRAIN_IDS_PATH = "Embeddings/train_ids.pt"
TEST_IDS_PATH  = "Embeddings/test_ids.pt"
LABEL_EMB_PATH = "Embeddings/label_emb.pt"

# === Load pre-computed embeddings ===
print("Loading pre-trained embeddings...")
X_train = torch.load(X_TRAIN_PATH).to(device)
y_train = torch.load(Y_TRAIN_PATH).to(device)
X_test  = torch.load(X_TEST_PATH).to(device)
label_emb  = torch.load(LABEL_EMB_PATH).to(device)

test_ids = list(id2text_test.keys())
train_ids = list(id2text_train.keys())

print(f"Train embeddings: {X_train.shape}")
print(f"Test embeddings:  {X_test.shape}")
print(f"Train labels:     {y_train.shape}")
print(f"Label Emb:     {label_emb.shape}")

# === Build index for convenience ===
pid2idx = {pid: i for i, pid in enumerate(train_ids)}

# === Some useful info ===
input_dim = X_train.size(1)
num_classes = y_train.size(1)

print("Input dimension:", input_dim)
print("Num classes:", num_classes)
print("Device:", device)

Loading pre-trained embeddings...
Train embeddings: torch.Size([29487, 768])
Test embeddings:  torch.Size([19658, 768])
Train labels:     torch.Size([29487, 531])
Label Emb:     torch.Size([531, 768])
Input dimension: 768
Num classes: 531
Device: cuda


In [30]:
# Classifier that uses label embeddings to make predictions
class InnerProductClassifier(nn.Module):
    def __init__(self, input_dim, label_embeddings, dropout=0.2, trainable_label_emb=False):
        super().__init__()
        self.dropout = nn.Dropout(dropout) # => for consistency
        # Project input features into the same dimension as label embeddings
        self.proj = nn.Linear(input_dim, label_embeddings.size(1))

        if trainable_label_emb:
            # Label embeddings are trainable parameters
            self.label_emb = nn.Parameter(label_embeddings.clone())
        else:
            # Label embeddings are fixed (not updated during training)
            self.register_buffer("label_emb", label_embeddings.clone())

    def forward(self, x, use_dropout=True):
        if use_dropout:
            x = self.dropout(x)
        # Project input feature vectors
        x_proj = self.proj(x)
        # Compute logits as similarity with each label embedding
        logits = torch.matmul(x_proj, self.label_emb.T)
        return logits
    
class ProductCategoryEmbeddingDataset(Dataset):
    def __init__(self, pid2label, pid2idx, embeddings, num_classes=531):
        self.pids = list(pid2label.keys())
        self.labels = [pid2label[pid] for pid in self.pids]
        self.indices = [pid2idx[pid] for pid in self.pids]
        self.embeddings = embeddings
        self.num_classes = num_classes

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        emb = self.embeddings[self.indices[idx]]
        
        y = torch.zeros(self.num_classes, dtype=torch.float32)
        for label in self.labels[idx]:
            y[label] = 1.0

        return {"X": emb, "y": y}
    
class TensorDatasetFromVectors(Dataset):
    def __init__(self, X_list, y_list):
        self.X = torch.stack(X_list)                     # list of embeddings -> tensor
        self.y = torch.stack(y_list).float()             # list of binary vectors -> tensor

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return {"X": self.X[idx], "y": self.y[idx]}      # embedding + multi-label

class UnlabeledEmbeddingDataset(Dataset):
    def __init__(self, pids, pid2idx, embeddings):
        self.pids = pids
        self.indices = [pid2idx[pid] for pid in self.pids]
        self.embeddings = embeddings

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        emb = self.embeddings[self.indices[idx]]
        pid = self.pids[idx]
        return {"X": emb, "pid": pid}
    
from sklearn.metrics import f1_score, accuracy_score
import torch

from sklearn.metrics import f1_score

def evaluate(model, dataloader, device="cpu", threshold=0.5):
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            X = batch["X"].to(device)
            y = batch["y"].cpu().numpy()
            logits = model(X)
            probs = torch.sigmoid(logits).cpu().numpy()
            preds = (probs > threshold).astype(int)
            all_preds.append(preds)
            all_labels.append(y)

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    f1_macro = f1_score(all_labels, all_preds, average="macro", zero_division=0)
    f1_micro = f1_score(all_labels, all_preds, average="micro", zero_division=0)
    return {"f1_macro": f1_macro, "f1_micro": f1_micro}



In [31]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# --- Split des silver labels en train/val (85/15) ---
silver_pids = list(pid2labelids_silver.keys())  # tu as déjà chargé pid2labelids_silver
train_pids, val_pids = train_test_split(
    silver_pids, test_size=0.15, random_state=42
)

train_labels = {pid: pid2labelids_silver[pid] for pid in train_pids}
val_labels   = {pid: pid2labelids_silver[pid] for pid in val_pids}

print(f"Train: {len(train_labels)} | Val: {len(val_labels)}")

# --- Datasets (basés sur tes embeddings Roberta déjà chargés) ---
train_dataset = ProductCategoryEmbeddingDataset(train_labels, pid2idx, X_train, num_classes=NUM_CLASSES)
val_dataset   = ProductCategoryEmbeddingDataset(val_labels, pid2idx, X_train, num_classes=NUM_CLASSES)

# --- DataLoaders ---
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=64)

# --- Test set ---
test_dataset = ProductCategoryEmbeddingDataset(pid2labelids_test, pid2idx, X_test, num_classes=NUM_CLASSES)
test_loader  = DataLoader(test_dataset, batch_size=64)


Train: 25063 | Val: 4424


In [32]:
import copy
import torch.nn.functional as F
from tqdm import tqdm

# === Initialisation du modèle ===
student = InnerProductClassifier(input_dim, label_emb).to(device)

optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=60)
criterion = nn.BCEWithLogitsLoss()

EPOCHS_BASE = 10
patience = 5
wait = 0
best_f1 = 0.0
best_model = copy.deepcopy(student.state_dict())

# === Training loop ===
def train_epoch(loader):
    student.train()
    total_loss = 0.0

    for batch in tqdm(loader, desc="Train", leave=False):
        X = batch["X"].to(device)
        y = batch["y"].to(device)

        logits = student(X)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)

def val_f1():
    student.eval()
    return evaluate(student, val_loader, device)['f1_macro']

print("\n=== Baseline Training (InnerProductClassifier + BCE) ===")
for epoch in range(1, EPOCHS_BASE + 1):
    train_loss = train_epoch(train_loader)
    scheduler.step()

    f1_val = val_f1()
    print(f"[Epoch {epoch}] train_loss={train_loss:.4f} | val_f1_macro={f1_val:.4f}")

    if f1_val > best_f1:
        best_f1 = f1_val
        wait = 0
        best_model = copy.deepcopy(student.state_dict())
        print(f"Best model updated (val_f1={best_f1:.4f})")
    else:
        wait += 1
        print(f"No improvement: {wait}/{patience}")

    if wait >= patience:
        print("\nEarly stopping triggered!")
        break

# === Test final ===
student.load_state_dict(best_model)
print("\n=== Final Test (Best Model) ===")
test_result = evaluate(student, test_loader, device)
print(test_result)



=== Baseline Training (InnerProductClassifier + BCE) ===


Train:   0%|          | 0/392 [00:00<?, ?it/s]

                                                         

[Epoch 1] train_loss=0.0385 | val_f1_macro=0.0016
Best model updated (val_f1=0.0016)


                                                         

[Epoch 2] train_loss=0.0183 | val_f1_macro=0.0050
Best model updated (val_f1=0.0050)


                                                         

[Epoch 3] train_loss=0.0157 | val_f1_macro=0.0068
Best model updated (val_f1=0.0068)


                                                         

[Epoch 4] train_loss=0.0145 | val_f1_macro=0.0083
Best model updated (val_f1=0.0083)


                                                         

[Epoch 5] train_loss=0.0137 | val_f1_macro=0.0089
Best model updated (val_f1=0.0089)


                                                         

[Epoch 6] train_loss=0.0132 | val_f1_macro=0.0106
Best model updated (val_f1=0.0106)


                                                         

[Epoch 7] train_loss=0.0128 | val_f1_macro=0.0109
Best model updated (val_f1=0.0109)


                                                         

[Epoch 8] train_loss=0.0125 | val_f1_macro=0.0118
Best model updated (val_f1=0.0118)


                                                         

[Epoch 9] train_loss=0.0122 | val_f1_macro=0.0123
Best model updated (val_f1=0.0123)


                                                         

[Epoch 10] train_loss=0.0120 | val_f1_macro=0.0140
Best model updated (val_f1=0.0140)

=== Final Test (Best Model) ===
{'f1_macro': 0.014158797269227327, 'f1_micro': 0.5700567714159087}


In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# === MLP Classifier ===
class MLPClassifier(nn.Module):
    def __init__(self, input_dim, num_classes, hidden_dim=512, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        return self.net(x)

# === Init ===
input_dim = X_train.size(1)
num_classes = y_train.size(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MLPClassifier(input_dim, num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-3)
criterion = nn.BCEWithLogitsLoss()

EPOCHS = 30

# === Training ===
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        X = batch["X"].to(device)
        y = batch["y"].to(device)
        optimizer.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}: loss={avg_loss:.4f}")

    # === Validation ===
    val_metrics = evaluate(model, val_loader, device)
    print(f"F1-macro={val_metrics['f1_macro']:.4f} | F1-micro={val_metrics['f1_micro']:.4f}")

# === Test ===
print("\nFinal Test:")
test_metrics = evaluate(model, test_loader, device)
print(test_metrics)


Epoch 1/30: 100%|██████████| 392/392 [00:01<00:00, 248.02it/s]


Epoch 1: loss=0.0736
F1-macro=0.0011 | F1-micro=0.1561


Epoch 2/30: 100%|██████████| 392/392 [00:01<00:00, 336.38it/s]


Epoch 2: loss=0.0165
F1-macro=0.0030 | F1-micro=0.3271


Epoch 3/30: 100%|██████████| 392/392 [00:01<00:00, 321.81it/s]


Epoch 3: loss=0.0152
F1-macro=0.0052 | F1-micro=0.4130


Epoch 4/30: 100%|██████████| 392/392 [00:01<00:00, 301.48it/s]


Epoch 4: loss=0.0142
F1-macro=0.0061 | F1-micro=0.4776


Epoch 5/30: 100%|██████████| 392/392 [00:01<00:00, 302.68it/s]


Epoch 5: loss=0.0136
F1-macro=0.0069 | F1-micro=0.5007


Epoch 6/30: 100%|██████████| 392/392 [00:01<00:00, 302.28it/s]


Epoch 6: loss=0.0132
F1-macro=0.0073 | F1-micro=0.5092


Epoch 7/30: 100%|██████████| 392/392 [00:01<00:00, 293.30it/s]


Epoch 7: loss=0.0129
F1-macro=0.0078 | F1-micro=0.5158


Epoch 8/30: 100%|██████████| 392/392 [00:01<00:00, 311.96it/s]


Epoch 8: loss=0.0126
F1-macro=0.0082 | F1-micro=0.5242


Epoch 9/30: 100%|██████████| 392/392 [00:01<00:00, 298.83it/s]


Epoch 9: loss=0.0123
F1-macro=0.0089 | F1-micro=0.5334


Epoch 10/30: 100%|██████████| 392/392 [00:01<00:00, 299.62it/s]


Epoch 10: loss=0.0121
F1-macro=0.0096 | F1-micro=0.5440


Epoch 11/30: 100%|██████████| 392/392 [00:01<00:00, 301.33it/s]


Epoch 11: loss=0.0119
F1-macro=0.0101 | F1-micro=0.5512


Epoch 12/30: 100%|██████████| 392/392 [00:01<00:00, 298.20it/s]


Epoch 12: loss=0.0117
F1-macro=0.0106 | F1-micro=0.5531


Epoch 13/30: 100%|██████████| 392/392 [00:01<00:00, 296.01it/s]


Epoch 13: loss=0.0115
F1-macro=0.0113 | F1-micro=0.5586


Epoch 14/30: 100%|██████████| 392/392 [00:01<00:00, 299.91it/s]


Epoch 14: loss=0.0113
F1-macro=0.0117 | F1-micro=0.5609


Epoch 15/30: 100%|██████████| 392/392 [00:01<00:00, 296.49it/s]


Epoch 15: loss=0.0112
F1-macro=0.0123 | F1-micro=0.5650


Epoch 16/30: 100%|██████████| 392/392 [00:01<00:00, 292.05it/s]


Epoch 16: loss=0.0111
F1-macro=0.0125 | F1-micro=0.5698


Epoch 17/30: 100%|██████████| 392/392 [00:01<00:00, 296.29it/s]


Epoch 17: loss=0.0110
F1-macro=0.0130 | F1-micro=0.5726


Epoch 18/30: 100%|██████████| 392/392 [00:01<00:00, 261.92it/s]


Epoch 18: loss=0.0108
F1-macro=0.0140 | F1-micro=0.5782


Epoch 19/30: 100%|██████████| 392/392 [00:01<00:00, 277.81it/s]


Epoch 19: loss=0.0107
F1-macro=0.0143 | F1-micro=0.5822


Epoch 20/30: 100%|██████████| 392/392 [00:01<00:00, 292.44it/s]


Epoch 20: loss=0.0106
F1-macro=0.0149 | F1-micro=0.5834


Epoch 21/30: 100%|██████████| 392/392 [00:01<00:00, 297.70it/s]


Epoch 21: loss=0.0105
F1-macro=0.0156 | F1-micro=0.5883


Epoch 22/30: 100%|██████████| 392/392 [00:01<00:00, 280.87it/s]


Epoch 22: loss=0.0104
F1-macro=0.0155 | F1-micro=0.5892


Epoch 23/30: 100%|██████████| 392/392 [00:01<00:00, 301.68it/s]


Epoch 23: loss=0.0103
F1-macro=0.0165 | F1-micro=0.5908


Epoch 24/30: 100%|██████████| 392/392 [00:01<00:00, 305.06it/s]


Epoch 24: loss=0.0102
F1-macro=0.0174 | F1-micro=0.5920


Epoch 25/30: 100%|██████████| 392/392 [00:01<00:00, 298.90it/s]


Epoch 25: loss=0.0102
F1-macro=0.0178 | F1-micro=0.5967


Epoch 26/30: 100%|██████████| 392/392 [00:01<00:00, 297.77it/s]


Epoch 26: loss=0.0101
F1-macro=0.0178 | F1-micro=0.5974


Epoch 27/30: 100%|██████████| 392/392 [00:01<00:00, 303.70it/s]


Epoch 27: loss=0.0100
F1-macro=0.0186 | F1-micro=0.6033


Epoch 28/30: 100%|██████████| 392/392 [00:01<00:00, 305.01it/s]


Epoch 28: loss=0.0100
F1-macro=0.0190 | F1-micro=0.6005


Epoch 29/30: 100%|██████████| 392/392 [00:01<00:00, 299.99it/s]


Epoch 29: loss=0.0099
F1-macro=0.0196 | F1-micro=0.6060


Epoch 30/30: 100%|██████████| 392/392 [00:01<00:00, 302.47it/s]


Epoch 30: loss=0.0099
F1-macro=0.0208 | F1-micro=0.6077

Final Test:
{'f1_macro': 0.021519004340856133, 'f1_micro': 0.6079575054905766}


In [40]:
import csv, os
from tqdm import tqdm
import torch
import numpy as np

THRESHOLD = 0.5
OUTPUT_PATH = "Submission/submission.csv"
os.makedirs("Submission", exist_ok=True)

all_pids, all_pred_labels = [], []

with torch.no_grad():
    for start in tqdm(range(0, len(X_test), 64), desc="Generating predictions"):
        end = start + 64
        batch = X_test[start:end]
        batch_pids = test_ids[start:end]

        logits = model(batch)
        probs = torch.sigmoid(logits).cpu().numpy()

        for pid, prob in zip(batch_pids, probs):
            pred_row = (prob > THRESHOLD).astype(int)

            if pred_row.sum() == 0:
                pred_row[prob.argmax()] = 1

            labels = [str(j) for j, v in enumerate(pred_row) if v == 1]

            all_pids.append(pid)
            all_pred_labels.append(labels)

print(f"{len(all_pids)} samples generated.")

with open(OUTPUT_PATH, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["id", "label"])
    for pid, labels in zip(all_pids, all_pred_labels):
        writer.writerow([pid, ",".join(labels)])

print(f"Submission file saved: {OUTPUT_PATH}")


Generating predictions: 100%|██████████| 308/308 [00:01<00:00, 229.76it/s]

19658 samples generated.
Submission file saved: Submission/submission.csv





In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import json
import numpy as np
from sklearn.metrics import f1_score

# ============================================
# 1. DATASET AVEC TEXTE BRUT (pas embeddings)
# ============================================
class TextClassificationDataset(Dataset):
    """Dataset qui prend le TEXTE brut, pas les embeddings"""
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        # Tokenize le texte
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.float32)
        }

# ============================================
# 2. MODÈLE : BERT + Classifier Head
# ============================================
class BERTMultiLabelClassifier(nn.Module):
    """
    Architecture du papier :
    - BERT encoder (frozen ou fine-tuné)
    - Classification head
    """
    def __init__(self, model_name='distilroberta-base', num_labels=531, dropout=0.1):
        super().__init__()
        
        # Encoder BERT
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_labels)
        )
    
    def forward(self, input_ids, attention_mask):
        # Encode avec BERT
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # Prend le [CLS] token (première position)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        
        # Classification
        logits = self.classifier(pooled_output)
        return logits

# ============================================
# 3. TRAINING LOOP
# ============================================
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    
    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

# ============================================
# 4. EVALUATION
# ============================================
def evaluate(model, dataloader, device, threshold=0.5):
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].cpu().numpy()
            
            # Prédictions
            logits = model(input_ids, attention_mask)
            probs = torch.sigmoid(logits).cpu().numpy()
            preds = (probs > threshold).astype(int)
            
            all_preds.append(preds)
            all_labels.append(labels)
    
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    
    # Calcul F1
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_micro = f1_score(all_labels, all_preds, average='micro', zero_division=0)
    
    return {'f1_macro': f1_macro, 'f1_micro': f1_micro}

# ============================================
# 5. MAIN PIPELINE
# ============================================
def main():
    # Configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Charger les données
    with open('Amazon_products/train/train_corpus.txt', 'r', encoding='utf-8') as f:
        train_texts = [line.strip().split('\t')[1] for line in f]
    
    with open('Silver/silver_train_roberta.json', 'r') as f:
        silver_labels = json.load(f)
    
    # Préparer les labels
    train_labels = []
    for pid in silver_labels.keys():
        labels_binary = np.zeros(531)
        for label in silver_labels[pid]:
            labels_binary[int(label)] = 1
        train_labels.append(labels_binary)
    
    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
    
    # Dataset
    train_dataset = TextClassificationDataset(
        train_texts, 
        train_labels, 
        tokenizer
    )
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    
    # Modèle
    model = BERTMultiLabelClassifier().to(device)
    
    # Optimizer et loss
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    criterion = nn.BCEWithLogitsLoss()
    
    # Training
    for epoch in range(5):
        loss = train_epoch(model, train_loader, optimizer, criterion, device)
        print(f"Epoch {epoch+1}: Loss = {loss:.4f}")
    
    # Sauvegarder
    torch.save(model.state_dict(), 'best_model.pt')

if __name__ == '__main__':
    main()

Training: 100%|██████████| 1843/1843 [09:47<00:00,  3.14it/s]


Epoch 1: Loss = 0.0419


Training:  90%|████████▉ | 1652/1843 [09:32<01:02,  3.08it/s]