In [None]:
import sys
from pathlib import Path

project_root = str(Path.cwd().parents[3])
if project_root not in sys.path:
    sys.path.append(project_root)

In [1]:
import pandas as pd
import os
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from networks import *
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report,f1_score,multilabel_confusion_matrix
from sklearn.preprocessing import MultiLabelBinarizer
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau,OneCycleLR
import torch
from torch.nn.utils import clip_grad_norm_
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.nn import BCELoss
from tqdm import tqdm
import h5py

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
H5_PATH = os.path.join(Path.cwd().parent.parent.parent.parent, 'data', 'stackexchange_embeddings_tokenized.h5')

In [4]:
emb = pd.read_pickle(os.path.join(Path.cwd().parent,'data','stackexchange_reduced_tags_embeddings.pkl'))

In [5]:
mlb = MultiLabelBinarizer()
num_tags = np.vstack(emb['tags'].apply(len).values.reshape(-1) / 5.0)
y = mlb.fit_transform(emb['tags'])
indices = emb.index.tolist()

In [6]:
full_dataset = LazyDSFDataset(
    h5_path=H5_PATH, 
    num_tags_list=num_tags,
    binary_labels=y
)

In [7]:
# with h5py.File(H5_PATH, "r") as f:
#     print(f['question_ids'][1040])
#     print(f['body_seq'][1040])


In [8]:
train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size
generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size], generator=generator)

In [9]:
train_loader = DataLoader(
    train_dataset, 
    batch_size=64, 
    shuffle=True, 
)
val_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False
)

In [10]:
all_labels = full_dataset.binary_labels[train_dataset.indices]
num_positives = np.sum(all_labels, axis=0)
num_negatives = len(all_labels) - num_positives
# pos_weights = torch.tensor(num_negatives / (num_positives + 1e-5), dtype=torch.float32).to(device)
pos_weights = torch.tensor(np.sqrt(num_negatives / (num_positives + 1e-5)), dtype=torch.float32).to(device)

In [11]:
num_epochs = 100
model = DSF_Sequence_Aware_Classifier(hidden_dim=512, dropout=0.5).to(device)
optimizer = AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = OneCycleLR(optimizer, max_lr=3e-4, 
                                          steps_per_epoch=len(train_loader), 
                                          epochs=num_epochs+1,pct_start=0.3)
# criterion = AsymmetricLoss(gamma_neg=2, gamma_pos=1)
# criterion = BCEWithLogitsLoss(pos_weight=pos_weights)
# Asymmetric loss didn't help here, it underperformed BCEWithLogitsLoss because it overemphasized negative samples
# criterion = FocalLoss(alpha=0.25, gamma=2.0)
criterion = FocalLossSmooth(alpha=0.25, gamma=2.0, smoothing=0.05)

In [12]:
mixup_alpha = 0.4

def optimize_threshold(probs, targets):
    best_t = 0.5
    best_f1 = 0.0
    
    # Search thresholds from 0.1 to 0.9
    for t in np.arange(0.1, 0.9, 0.05):
        preds = (probs > t).astype(int)
        f1 = f1_score(targets, preds, average='macro')
        if f1 > best_f1:
            best_f1 = f1
            best_t = t
    return best_t, best_f1

def mixup_data(t_emb, b_emb, n_tags, y, alpha=0.4):
    """
    Applies Mixup to the embeddings and generates paired targets.
    """
    # 1. Generate Mixing Coefficient (Lambda)
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = t_emb.size(0)
    
    # 2. Generate Permutation Indices
    # Use the device of the input tensor to ensure compatibility
    index = torch.randperm(batch_size).to(t_emb.device)

    # 3. Create Mixed Inputs
    # Formula: mixed = λ * original + (1 - λ) * shuffled
    mixed_t = lam * t_emb + (1 - lam) * t_emb[index]
    mixed_b = lam * b_emb + (1 - lam) * b_emb[index]
    mixed_n = lam * n_tags + (1 - lam) * n_tags[index]
    
    # 4. Return mixed inputs and the two target sets
    y_a, y_b = y, y[index]
    return mixed_t, mixed_b, mixed_n, y_a, y_b, lam

def mixup_criterion(criterion, logits, y_a, y_b, lam):
    """
    Calculates the loss as the weighted sum of losses for both targets.
    """
    return lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)

In [None]:
patience_best_val_f1 = 0.0
patience = 10
trigger_times = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")

    for t_emb, b_emb, padding_mask, n_tags, labels in train_bar:
        t_emb, b_emb = t_emb.to(device), b_emb.to(device)
        padding_mask = padding_mask.to(device).bool()
        n_tags, labels = n_tags.to(device), labels.to(device)

        optimizer.zero_grad()

        # --- Mixup Augmentation ---
        if epoch < (num_epochs - 5) or trigger_times <= 5:
            mixed_t, mixed_b, mixed_n, y_a, y_b, lam = mixup_data(
                t_emb, b_emb, n_tags, labels, alpha=mixup_alpha
            )
            logits, attn_weights = model(mixed_t, mixed_b, padding_mask, mixed_n)
            loss = mixup_criterion(criterion, logits, y_a, y_b, lam)
        else:          
            logits, attn_weights = model(t_emb, b_emb, padding_mask, n_tags)
            
            loss = criterion(logits, labels)

        loss.backward()
        
        clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step() 
        
        running_loss += loss.item()
        train_bar.set_postfix(loss=loss.item())

    avg_train_loss = running_loss / len(train_loader)
    
    # --- Validation ---
    model.eval()
    val_targets = []
    val_probs = []
    
    with torch.no_grad():
        for t_emb, b_emb,padding_mask, n_tags, labels in val_loader:
            t_emb = t_emb.to(device)
            b_emb = b_emb.to(device)
            padding_mask = padding_mask.to(device)
            n_tags = n_tags.to(device)
            labels = labels.cpu().numpy()
            
            logits, attn_weights = model(t_emb, b_emb, padding_mask, n_tags)
            
            probs = torch.sigmoid(logits).cpu().numpy()
            
            val_targets.append(labels)
            val_probs.append(probs)
            
    val_targets = np.vstack(val_targets)
    val_probs = np.vstack(val_probs)

    best_threshold, best_val_f1 = optimize_threshold(val_probs, val_targets)
    
    val_preds_final = (val_probs > best_threshold).astype(int)
    
    val_f1_micro = f1_score(val_targets, val_preds_final, average='micro')
    val_f1_macro = f1_score(val_targets, val_preds_final, average='macro')
    
    print(f"Epoch {epoch+1} Results:")
    print(f"Optimal Threshold: {best_threshold:.2f} -> New Val F1: {best_val_f1:.4f}")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val F1 (Micro): {val_f1_micro:.4f}")
    print(f"  Val F1 (Macro): {val_f1_macro:.4f}")

    if val_f1_micro > patience_best_val_f1:
        patience_best_val_f1 = val_f1_micro
        torch.save(model.state_dict(), "best_dsf_model.pth")
        trigger_times = 0
        print("  -> New Best Model Saved!")
    else:
        trigger_times += 1
        print(f"  -> No improvement. Patience: {trigger_times}/{patience}")
        if trigger_times >= patience:
            print("Early stopping triggered!")
            break
    

Epoch 1/100 [Train]: 100%|██████████| 1407/1407 [02:25<00:00,  9.66it/s, loss=0.00692]


Epoch 1 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.4107
  Train Loss: 0.0175
  Val F1 (Micro): 0.5139
  Val F1 (Macro): 0.4107
  -> New Best Model Saved!


Epoch 2/100 [Train]: 100%|██████████| 1407/1407 [02:25<00:00,  9.67it/s, loss=0.00664]


Epoch 2 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.5250
  Train Loss: 0.0066
  Val F1 (Micro): 0.5890
  Val F1 (Macro): 0.5250
  -> New Best Model Saved!


Epoch 3/100 [Train]: 100%|██████████| 1407/1407 [02:20<00:00, 10.02it/s, loss=0.00601]


Epoch 3 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.5596
  Train Loss: 0.0058
  Val F1 (Micro): 0.6080
  Val F1 (Macro): 0.5596
  -> New Best Model Saved!


Epoch 4/100 [Train]: 100%|██████████| 1407/1407 [02:10<00:00, 10.82it/s, loss=0.00303]


Epoch 4 Results:
Optimal Threshold: 0.35 -> New Val F1: 0.5758
  Train Loss: 0.0054
  Val F1 (Micro): 0.6004
  Val F1 (Macro): 0.5758
  -> No improvement. Patience: 1/10


Epoch 5/100 [Train]: 100%|██████████| 1407/1407 [02:12<00:00, 10.63it/s, loss=0.00353]


Epoch 5 Results:
Optimal Threshold: 0.35 -> New Val F1: 0.5864
  Train Loss: 0.0052
  Val F1 (Micro): 0.6104
  Val F1 (Macro): 0.5864
  -> New Best Model Saved!


Epoch 6/100 [Train]: 100%|██████████| 1407/1407 [02:09<00:00, 10.84it/s, loss=0.00461]


Epoch 6 Results:
Optimal Threshold: 0.35 -> New Val F1: 0.5920
  Train Loss: 0.0051
  Val F1 (Micro): 0.6147
  Val F1 (Macro): 0.5920
  -> New Best Model Saved!


Epoch 7/100 [Train]: 100%|██████████| 1407/1407 [02:12<00:00, 10.59it/s, loss=0.00716]


Epoch 7 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.5986
  Train Loss: 0.0050
  Val F1 (Micro): 0.6364
  Val F1 (Macro): 0.5986
  -> New Best Model Saved!


Epoch 8/100 [Train]: 100%|██████████| 1407/1407 [02:10<00:00, 10.79it/s, loss=0.00491]


Epoch 8 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.5979
  Train Loss: 0.0049
  Val F1 (Micro): 0.6374
  Val F1 (Macro): 0.5979
  -> New Best Model Saved!


Epoch 9/100 [Train]: 100%|██████████| 1407/1407 [02:11<00:00, 10.72it/s, loss=0.00422]


Epoch 9 Results:
Optimal Threshold: 0.35 -> New Val F1: 0.6004
  Train Loss: 0.0048
  Val F1 (Micro): 0.6224
  Val F1 (Macro): 0.6004
  -> No improvement. Patience: 1/10


Epoch 10/100 [Train]: 100%|██████████| 1407/1407 [02:12<00:00, 10.59it/s, loss=0.00696]


Epoch 10 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6066
  Train Loss: 0.0048
  Val F1 (Micro): 0.6422
  Val F1 (Macro): 0.6066
  -> New Best Model Saved!


Epoch 11/100 [Train]: 100%|██████████| 1407/1407 [02:14<00:00, 10.44it/s, loss=0.0034] 


Epoch 11 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6011
  Train Loss: 0.0048
  Val F1 (Micro): 0.6391
  Val F1 (Macro): 0.6011
  -> No improvement. Patience: 1/10


Epoch 12/100 [Train]: 100%|██████████| 1407/1407 [02:10<00:00, 10.78it/s, loss=0.00565]


Epoch 12 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6067
  Train Loss: 0.0047
  Val F1 (Micro): 0.6426
  Val F1 (Macro): 0.6067
  -> New Best Model Saved!


Epoch 13/100 [Train]: 100%|██████████| 1407/1407 [02:20<00:00, 10.04it/s, loss=0.00627]


Epoch 13 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6094
  Train Loss: 0.0046
  Val F1 (Micro): 0.6432
  Val F1 (Macro): 0.6094
  -> New Best Model Saved!


Epoch 14/100 [Train]: 100%|██████████| 1407/1407 [02:18<00:00, 10.17it/s, loss=0.00488]


Epoch 14 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6085
  Train Loss: 0.0046
  Val F1 (Micro): 0.6434
  Val F1 (Macro): 0.6085
  -> New Best Model Saved!


Epoch 15/100 [Train]: 100%|██████████| 1407/1407 [02:19<00:00, 10.07it/s, loss=0.00645]


Epoch 15 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6088
  Train Loss: 0.0046
  Val F1 (Micro): 0.6437
  Val F1 (Macro): 0.6088
  -> New Best Model Saved!


Epoch 16/100 [Train]: 100%|██████████| 1407/1407 [02:25<00:00,  9.64it/s, loss=0.00333]


Epoch 16 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6138
  Train Loss: 0.0045
  Val F1 (Micro): 0.6434
  Val F1 (Macro): 0.6138
  -> No improvement. Patience: 1/10


Epoch 17/100 [Train]: 100%|██████████| 1407/1407 [02:28<00:00,  9.46it/s, loss=0.00423]


Epoch 17 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6103
  Train Loss: 0.0045
  Val F1 (Micro): 0.6443
  Val F1 (Macro): 0.6103
  -> New Best Model Saved!


Epoch 18/100 [Train]: 100%|██████████| 1407/1407 [02:30<00:00,  9.33it/s, loss=0.00891]


Epoch 18 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6171
  Train Loss: 0.0045
  Val F1 (Micro): 0.6456
  Val F1 (Macro): 0.6171
  -> New Best Model Saved!


Epoch 19/100 [Train]: 100%|██████████| 1407/1407 [02:27<00:00,  9.53it/s, loss=0.00402]


Epoch 19 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6143
  Train Loss: 0.0044
  Val F1 (Micro): 0.6460
  Val F1 (Macro): 0.6143
  -> New Best Model Saved!


Epoch 20/100 [Train]: 100%|██████████| 1407/1407 [02:16<00:00, 10.29it/s, loss=0.00476]


Epoch 20 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6176
  Train Loss: 0.0044
  Val F1 (Micro): 0.6471
  Val F1 (Macro): 0.6176
  -> New Best Model Saved!


Epoch 21/100 [Train]: 100%|██████████| 1407/1407 [02:19<00:00, 10.08it/s, loss=0.00524]


Epoch 21 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6152
  Train Loss: 0.0044
  Val F1 (Micro): 0.6469
  Val F1 (Macro): 0.6152
  -> No improvement. Patience: 1/10


Epoch 22/100 [Train]: 100%|██████████| 1407/1407 [02:14<00:00, 10.46it/s, loss=0.00609]


Epoch 22 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6131
  Train Loss: 0.0044
  Val F1 (Micro): 0.6452
  Val F1 (Macro): 0.6131
  -> No improvement. Patience: 2/10


Epoch 23/100 [Train]: 100%|██████████| 1407/1407 [02:21<00:00,  9.92it/s, loss=0.00297]


Epoch 23 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6157
  Train Loss: 0.0043
  Val F1 (Micro): 0.6444
  Val F1 (Macro): 0.6157
  -> No improvement. Patience: 3/10


Epoch 24/100 [Train]: 100%|██████████| 1407/1407 [02:20<00:00, 10.02it/s, loss=0.00733]


Epoch 24 Results:
Optimal Threshold: 0.40 -> New Val F1: 0.6150
  Train Loss: 0.0043
  Val F1 (Micro): 0.6441
  Val F1 (Macro): 0.6150
  -> No improvement. Patience: 4/10


Epoch 25/100 [Train]:  12%|█▏        | 166/1407 [00:16<02:06,  9.80it/s, loss=0.00417]


KeyboardInterrupt: 

The model is too strong for our dataset, effectively memorizing the entire dataset.
First the ASL was tried, but it underperformed compared to BCEWithLogitsLoss with class weights.
Then Focal Loss was tried, which yielded slightly better results than BCEWithLogitsLoss. 
- Focal loss shows real improvemnt in the threshold optimization step, earlier the optimal threshold was around 0.9, indicating very confident predictions. Now it is around 0.4, indicating more balanced predictions.  

Essentially, we've hit *Generalization Ceiling*:
- Train Loss (0.002) is near zero. The model has effectively memorized the training data.
-  Validation F1 (0.64) is stuck.
---
Now we see the need to further increase the difficulty of the task by  **Label Smoothing** techniques.
- For the *Label Smoothing* we modified the *FocalLoss*

After that the loss is 0.003 but f1 didn't improved that much.

**Manifold Mixum** resulted in the same performance as before, model memorized everything.

In [None]:
f1_score(val_targets, val_preds, average='micro')

0.0

Fine tuning the model should help. This time with no mixup for easier task to adapt to the 'pure' data.

In [None]:

optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) 
criterion = FocalLossSmooth(alpha=0.25, gamma=2.0, smoothing=0.05)

ft_epochs = 10
# Early stopping parameter took from previous best 
best_ft_f1 = patience_best_val_f1 

for epoch in range(ft_epochs):
    model.train()
    running_loss = 0.0
    
    train_bar = tqdm(train_loader, desc=f"Fine-Tune Epoch {epoch+1}/{ft_epochs}")

    for t_emb, b_emb, padding_mask, n_tags, labels in train_bar:
        t_emb, b_emb = t_emb.to(device), b_emb.to(device)
        padding_mask = padding_mask.to(device).bool()
        n_tags, labels = n_tags.to(device), labels.to(device)

        optimizer.zero_grad()

        logits, _ = model(t_emb, b_emb, padding_mask, n_tags)
        loss = criterion(logits, labels)

        loss.backward()
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        running_loss += loss.item()
        train_bar.set_postfix(loss=loss.item())

    model.eval()
    val_targets = []
    val_probs = []
    
    with torch.no_grad():
        for t_emb, b_emb, padding_mask, n_tags, labels in val_loader:
            t_emb, b_emb = t_emb.to(device), b_emb.to(device)
            padding_mask = padding_mask.to(device)
            n_tags, labels = n_tags.to(device), labels.to(device)
            labels = labels.cpu().numpy()
            
            logits, _ = model(t_emb, b_emb, padding_mask, n_tags)
            probs = torch.sigmoid(logits).cpu().numpy()
            
            val_targets.append(labels)
            val_probs.append(probs)

    val_targets = np.vstack(val_targets)
    val_probs = np.vstack(val_probs)

    best_threshold, current_val_f1 = optimize_threshold(val_probs, val_targets)
    
    print(f"FT Epoch {epoch+1}: Threshold {best_threshold:.2f} | F1: {current_val_f1:.4f}")
    
    if current_val_f1 > best_ft_f1:
        best_ft_f1 = current_val_f1
        torch.save(model.state_dict(), "best_dsf_model_finetuned.pth")
        print("-> Improved!")

Starting Fine-Tuning Phase (No Mixup)...


Fine-Tune Epoch 1/10: 100%|██████████| 1407/1407 [02:44<00:00,  8.57it/s, loss=0.00276]


FT Epoch 1: Threshold 0.30 | F1: 0.6486


Fine-Tune Epoch 2/10: 100%|██████████| 1407/1407 [02:42<00:00,  8.65it/s, loss=0.0021] 


FT Epoch 2: Threshold 0.35 | F1: 0.6502
-> Improved!


Fine-Tune Epoch 3/10: 100%|██████████| 1407/1407 [02:37<00:00,  8.95it/s, loss=0.0031] 


FT Epoch 3: Threshold 0.35 | F1: 0.6500


Fine-Tune Epoch 4/10: 100%|██████████| 1407/1407 [02:37<00:00,  8.93it/s, loss=0.00319]


FT Epoch 4: Threshold 0.35 | F1: 0.6503
-> Improved!


Fine-Tune Epoch 5/10: 100%|██████████| 1407/1407 [02:28<00:00,  9.45it/s, loss=0.00172]


FT Epoch 5: Threshold 0.40 | F1: 0.6505
-> Improved!


Fine-Tune Epoch 6/10: 100%|██████████| 1407/1407 [02:32<00:00,  9.25it/s, loss=0.00221]


FT Epoch 6: Threshold 0.35 | F1: 0.6494


Fine-Tune Epoch 7/10: 100%|██████████| 1407/1407 [02:28<00:00,  9.48it/s, loss=0.00269]


FT Epoch 7: Threshold 0.35 | F1: 0.6498


Fine-Tune Epoch 8/10: 100%|██████████| 1407/1407 [02:29<00:00,  9.39it/s, loss=0.00241]


FT Epoch 8: Threshold 0.35 | F1: 0.6499


Fine-Tune Epoch 9/10: 100%|██████████| 1407/1407 [02:30<00:00,  9.36it/s, loss=0.00241]


FT Epoch 9: Threshold 0.35 | F1: 0.6488


Fine-Tune Epoch 10/10: 100%|██████████| 1407/1407 [02:32<00:00,  9.25it/s, loss=0.00135]


FT Epoch 10: Threshold 0.40 | F1: 0.6488
