In [1]:
# Runs on the kaggle/python Docker image (https://github.com/kaggle/docker-python).
# Competition data is mounted read-only under /kaggle/input.
# /kaggle/working (up to 20GB) is kept as output by "Save & Run All";
# /kaggle/temp is scratch space that disappears with the session.

import os

# Print the full path of every input file so the dataset layout is visible.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for fname in files:
        print(os.path.join(root, fname))

/kaggle/input/jigsaw-toxic-comment-classification-challenge/train.csv.zip
/kaggle/input/jigsaw-toxic-comment-classification-challenge/sample_submission.csv.zip
/kaggle/input/jigsaw-toxic-comment-classification-challenge/test_labels.csv.zip
/kaggle/input/jigsaw-toxic-comment-classification-challenge/test.csv.zip


# 0. Importing Libraries

In [None]:
#Main
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


#Utilities
import random
import re
import pickle
from tqdm.auto import tqdm


#Augmentation
import nltk
nltk.download('wordnet', quiet=True)
nltk.download('omw-1.4', quiet=True)
nltk.download('punkt')
from nltk.corpus import wordnet


#sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score, precision_recall_curve


#torch
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
import torch.nn as nn
import torch.nn.functional as F

#Hugging Face
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import set_seed
from transformers import get_linear_schedule_with_warmup


#### Random States:

In [None]:
# Seed Python's and NumPy's global generators so stochastic steps are repeatable.
seed = 42
random.seed(seed)
np.random.seed(seed)

# Optional snapshots of the freshly seeded states; restore later with
# random.setstate(python_state) / np.random.set_state(numpy_state).
python_state, numpy_state = random.getstate(), np.random.get_state()

In [None]:
# torch: seed the CPU generator and every CUDA device, and make cuDNN pick
# deterministic kernels (slower, but repeatable).
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Snapshot the seeded generator states for later restoration.
torch_rng_state = torch.get_rng_state()
if torch.cuda.is_available():
    cuda_rng_state = torch.cuda.get_rng_state()
else:
    cuda_rng_state = None

In [None]:
# Hugging Face convenience seeding helper, using the same global seed
set_seed(seed=seed)

In [None]:
# Bundle every RNG state, plus the seed handed to train_test_split, and pickle
# it so a later session can restore the same randomness.
random_states = dict(
    python=random.getstate(),
    numpy=np.random.get_state(),
    torch_cpu=torch.get_rng_state(),
    torch_cuda=(torch.cuda.get_rng_state() if torch.cuda.is_available() else None),
    sklearn_seed=seed,  # random_state used for train_test_split
)

with open("random_states.pkl", "wb") as f:
    f.write(pickle.dumps(random_states))

# 1. Load and Preprocess the Data:

In [None]:
# Read the zipped Jigsaw CSVs directly from the Kaggle input mount
# (train, test, test labels - in that order).
df_train, df_test, df_test_labels = (
    pd.read_csv(path)
    for path in (
        '/kaggle/input/jigsaw-toxic-comment-classification-challenge/train.csv.zip',
        '/kaggle/input/jigsaw-toxic-comment-classification-challenge/test.csv.zip',
        '/kaggle/input/jigsaw-toxic-comment-classification-challenge/test_labels.csv.zip',
    )
)

In [None]:
df_train.head()  # preview only; the full frame is far too large to display

In [None]:
df_test.head()  # preview only; the full frame is far too large to display

In [None]:
df_train.loc[df_train['toxic'].eq(1), ['comment_text']].iloc[0]  # first toxic comment, for a quick look

In [None]:
# Define a function to remove punctuation using regular expressions
def remove_punctuation(text):
    """Strip every character that is neither a word character nor whitespace.

    Non-string values (e.g. NaN) are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    return re.sub(r'[^\w\s]', '', text)


def remove_special_characters(text):
    """Keep only ASCII letters, digits and whitespace.

    Stricter than ``remove_punctuation``: it also drops ``_`` and non-ASCII
    letters. Non-string values are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    # Matches any character that is not a letter, digit, or whitespace
    pattern = r'[^a-zA-Z0-9\s]'
    return re.sub(pattern, '', text)


def _clean_comments(series):
    """Apply both cleaning steps, in order, to a Series of comments."""
    return series.apply(remove_punctuation).apply(remove_special_characters)


# Clean train AND test identically: the model must see the same kind of text at
# prediction time as it was trained on (previously only train was cleaned).
df_train['comment_text'] = _clean_comments(df_train['comment_text'])
df_test['comment_text'] = _clean_comments(df_test['comment_text'])

### Data Augmentation

In [None]:
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Positive count per label, and the count of the most frequent label
label_counts = df_train.loc[:, label_cols].sum(axis=0)
max_class_count = label_counts.max()

print("Original class distribution:\n", label_counts)

In [None]:
def setup_nltk():
    """Make WordNet (and OMW 1.4) available, downloading them when missing.

    Downloads go to /kaggle/working/nltk_data, which is registered on NLTK's
    search path before the lookup.

    Returns:
        True if WordNet loaded, False otherwise (the error is printed).
    """
    try:
        nltk_data_dir = '/kaggle/working/nltk_data'
        os.makedirs(nltk_data_dir, exist_ok=True)

        # Register the writable directory first so the lookup below (and the
        # later ensure_loaded) also search it.
        if nltk_data_dir not in nltk.data.path:
            nltk.data.path.append(nltk_data_dir)

        # nltk.data.find raises LookupError when the resource is missing; it
        # never returns a falsy value, so an `if not find(...)` check can't work.
        try:
            nltk.data.find('corpora/wordnet')
        except LookupError:
            nltk.download('wordnet', download_dir=nltk_data_dir)
            nltk.download('omw-1.4', download_dir=nltk_data_dir)

        wordnet.ensure_loaded()
        return True
    except Exception as e:
        print(f"Failed to setup NLTK: {str(e)}")
        return False

def safe_synonym_replacement(text, n=2):
    """Replace up to ``n`` distinct words of ``text`` with a random WordNet synonym.

    Whole whitespace-delimited tokens are swapped in place. A word that also
    occurs inside a longer word (e.g. "a" in "cat") is therefore never
    corrupted, and the original whitespace is kept. Underscores in multi-word
    lemmas become spaces.

    Args:
        text: comment to augment. Non-strings and blank strings are returned as-is.
        n: maximum number of token positions to replace.

    Returns:
        The augmented string, or ``text`` unchanged when nothing is replaceable or
        an unexpected error occurs (the error is printed).
    """
    if not isinstance(text, str) or not text.strip():
        return text

    try:
        # Split on whitespace but keep the separators (odd indices), so the text
        # can be reassembled exactly; words sit at the even indices.
        tokens = re.split(r'(\s+)', text)

        replaceable = []
        for idx in range(0, len(tokens), 2):
            word = tokens[idx]
            if not word:
                continue  # empty piece from leading/trailing whitespace
            try:
                if wordnet.synsets(word):
                    replaceable.append(idx)
            except Exception:
                continue

        if not replaceable:
            return text

        for idx in random.sample(replaceable, min(n, len(replaceable))):
            word = tokens[idx]
            try:
                # sorted(): iteration order of a set of strings depends on hash
                # randomization, which would make random.choice irreproducible
                # across sessions even with a fixed seed.
                synonyms = sorted({
                    lemma.name().replace('_', ' ')
                    for syn in wordnet.synsets(word)
                    for lemma in syn.lemmas()
                    if lemma.name().replace('_', ' ').lower() != word.lower()
                })
            except Exception:
                continue
            if synonyms:
                tokens[idx] = random.choice(synonyms)

        return ''.join(tokens)
    except Exception as e:
        print(f"Error in synonym replacement: {str(e)}")
        return text

def balance_dataset(df, label_cols, ratio=0.5, random_state=None):
    """Oversample minority labels with synonym-replaced copies of their positive rows.

    For each label, ``int(ratio * (max_count - count))`` source rows are drawn
    with replacement from the rows where that label is 1. Each copy's
    ``comment_text`` is passed through ``safe_synonym_replacement``; all other
    columns are copied unchanged.

    Args:
        df: DataFrame with a ``comment_text`` column and every column in ``label_cols``.
        label_cols: names of the 0/1 label columns.
        ratio: fraction of the gap to the most frequent label to fill
            (0.5 was the previously hard-coded value).
        random_state: optional int seed for the row draws. None uses NumPy's
            global generator via ``DataFrame.sample``.

    Returns:
        A new DataFrame: the original rows, then the augmented rows, with a fresh
        RangeIndex. On an unexpected error the error is printed and a copy of
        ``df`` is returned.

    Raises:
        ValueError: if ``df`` is not a DataFrame or required columns are missing.
    """
    if not isinstance(df, pd.DataFrame):
        raise ValueError("Input must be a pandas DataFrame")

    label_cols = list(label_cols)
    missing_cols = [col for col in ['comment_text'] + label_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    try:
        label_counts = df[label_cols].sum()
        max_count = label_counts.max()
        print("Original distribution:\n", label_counts)

        # One generator shared by all labels, so different labels get different draws.
        rng = None if random_state is None else np.random.RandomState(random_state)

        augmented = []
        for label in label_cols:
            needed = max(0, int(ratio * (max_count - label_counts[label])))
            if needed <= 0:
                continue

            samples = df[df[label] == 1]
            if len(samples) == 0:
                continue

            print(f"Augmenting {label} (+{needed})")

            # Draw every source row at once instead of one DataFrame.sample call per row.
            new_rows = samples.sample(n=needed, replace=True, random_state=rng).copy()
            new_rows['comment_text'] = [
                safe_synonym_replacement(text) for text in new_rows['comment_text']
            ]
            augmented.append(new_rows)

        if augmented:
            return pd.concat([df, *augmented], ignore_index=True)
        return df.copy()
    except Exception as e:
        print(f"Fatal error in balancing: {str(e)}")
        return df.copy()

# ===== MAIN EXECUTION =====
# WordNet must be loadable for synonym augmentation to change anything.
if not setup_nltk():
    print("Warning: Proceeding without WordNet - augmentation will be limited")

try:
    label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    # Oversample under-represented labels with synonym-replaced copies.
    # NOTE: the train/val split further down is built from df_train, not df_balanced.
    df_balanced = balance_dataset(df_train, label_cols)

    print("\nNew distribution:")
    print(df_balanced[label_cols].sum())
    # To persist: df_balanced.to_csv('balanced_data.csv', index=False)
except Exception as e:
    # Fall back to the unaugmented data if anything above fails.
    print(f"Fatal error in main execution: {str(e)}")
    df_balanced = df_train.copy()


In [None]:
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Before augmentation
label_counts = df_train.loc[:, label_cols].sum(axis=0)
max_class_count = label_counts.max()
print("Original class distribution:\n", label_counts)

# After augmentation (df_balanced comes from the augmentation cell)
label_counts = df_balanced.loc[:, label_cols].sum(axis=0)
max_class_count = label_counts.max()
print("Augmented class distribution:\n", label_counts)

In [None]:
df_train.duplicated(subset='comment_text').sum()  # repeated comment texts (first copy not counted)

In [None]:
df_train.loc[df_train.duplicated(subset='comment_text')]  # the repeated copies themselves

In [None]:
def find_duplicates_with_word(df, word, label=None, show_samples=5, label_columns=None):
    """
    Find duplicated comments containing a word/phrase, optionally filtered by label.

    Parameters:
    - df: DataFrame with a 'comment_text' column (and the label columns)
    - word: word/phrase to search for (plain substring, case-insensitive)
    - label: optional label column that must equal 1 (e.g. 'threat')
    - show_samples: number of matching rows to display
    - label_columns: label columns displayed next to the text; defaults to the
      notebook-level `label_cols`

    Returns:
        DataFrame with every copy (keep=False) of the duplicated rows that match.
    """
    if label_columns is None:
        label_columns = label_cols

    # All rows whose text occurs more than once (every copy, not only the repeats)
    duplicates = df[df['comment_text'].duplicated(keep=False)]

    # na=False: a missing comment counts as "no match" instead of yielding NaN,
    # which would make the mask unusable for boolean indexing.
    mask = duplicates['comment_text'].str.contains(word, case=False, regex=False, na=False)

    # Optional label filter
    if label:
        mask &= (duplicates[label] == 1)

    results = duplicates[mask]

    print(f"\nFound {len(results)} duplicates containing '{word}'",
          f"(with {label}=1)" if label else "")

    if not results.empty:
        # `display` is the IPython/Jupyter rich-output builtin
        display(results[['comment_text'] + list(label_columns)].head(show_samples))

    return results

# Example usage: inspect duplicated "barnstar" messages that carry positive labels
threat_dupes = find_duplicates_with_word(
    df_train,
    "\n\n A barnstar for you \n\n The Original Bar",
    label='threat',
)
toxic_dupes = find_duplicates_with_word(df_train, "barnstar", label='toxic')

In [None]:
# Drop repeated comment texts. Deduplicating on every column could not catch them,
# because each Jigsaw row has its own `id`. keep='first' keeps the first copy's
# labels when copies disagree.
df_train = df_train.drop_duplicates(subset='comment_text', keep='first').reset_index(drop=True)

In [None]:
df_keep = df_balanced.copy(deep=True)  # independent snapshot of the augmented frame

In [None]:
label_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

# Label distribution of the (deduplicated) training frame
label_counts = df_train.loc[:, label_cols].sum(axis=0)
max_class_count = label_counts.max()

print("Original class distribution:\n", label_counts)

In [None]:
# Only the first N_SAMPLES rows are split into train/val (a quick debug run).
# Set to None to train on the full training set. With 50 rows the 10-row
# validation split will likely have no positives for the rarer labels.
N_SAMPLES = 50

comments = df_train["comment_text"].tolist()
labels = df_train[label_cols].values

subset = slice(N_SAMPLES)  # slice(None) keeps every row
train_comments, val_comments, train_labels, val_labels = train_test_split(
    comments[subset], labels[subset], test_size=0.2, random_state=seed
)

# 2. Tokenization:

In [None]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # BERT (uncased) tokenizer


def tokenize(texts):
    """Encode a list of strings as PyTorch tensors.

    Sequences are truncated to 128 tokens and padded to the longest sequence
    in `texts`.
    """
    return tokenizer(
        texts,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding=True,
    )


# Encode both splits with identical settings
train_encodings, val_encodings = tokenize(train_comments), tokenize(val_comments)


# 3. PyTorch Dataset:

In [None]:
class ToxicDataset(Dataset):
    """Map-style dataset pairing tokenized encodings with multi-hot label rows.

    Args:
        encodings: mapping with "input_ids" and "attention_mask", each indexable
            by example (e.g. the output of `tokenize`).
        labels: indexable collection of per-example label rows; its length is
            the dataset length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        ids = self.encodings["input_ids"][idx]
        mask = self.encodings["attention_mask"][idx]
        # Float targets, as expected by BCE-style multi-label losses
        target = torch.as_tensor(self.labels[idx], dtype=torch.float32)
        return {"input_ids": ids, "attention_mask": mask, "labels": target}

# Wrap the encoded splits; only the training loader shuffles.
train_dataset, val_dataset = (
    ToxicDataset(train_encodings, train_labels),
    ToxicDataset(val_encodings, val_labels),
)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
val_loader = DataLoader(val_dataset, batch_size=16)

# 4. BERT Model:

### a. Load Pre-trained BERT Model:

In [None]:
# Pretrained bert-base-uncased with a 6-output classification head configured for
# multi-label targets, placed on the GPU when one is available.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    problem_type="multi_label_classification",
    num_labels=6,
)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")


# Remove this (and the loss_fct assignment below) to go back to the previous loss.
class FocalLoss(nn.Module):
    """Binary focal loss on logits, averaged over every element.

    loss = mean((1 - p_t) ** gamma * BCE), where p_t = exp(-BCE) is the
    predicted probability of the true class. gamma=0 gives mean BCE-with-logits.
    """

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, inputs, targets):
        per_element = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        p_true = per_element.neg().exp()
        modulating = (1 - p_true).pow(self.gamma)
        return (modulating * per_element).mean()

# Attach Focal Loss to the model under the `loss_fct` attribute.
# NOTE(review): the Hugging Face forward computes its own loss and does not appear
# to read this attribute, so `outputs.loss` stays plain BCE. The training loop has
# to apply this criterion to `outputs.logits` explicitly - confirm.
model.loss_fct = FocalLoss(gamma=2).to(model.device)


### b. Training Loop

In [None]:
# Hyperparameters
epochs = 10  # upper bound; early stopping usually ends training sooner
lr = 2e-5  # BERT fine-tuning normally uses 2e-5..5e-5; 1e-3 tends to diverge
warmup_steps = 100
max_grad_norm = 1.0
patience = 3  # Number of epochs without val-loss improvement before stopping

# Training Setup
optimizer = AdamW(model.parameters(), lr=lr)

# The Hugging Face forward builds its own BCE loss and ignores `model.loss_fct`,
# so apply the configured criterion (FocalLoss when attached) to the logits here.
criterion = getattr(model, "loss_fct", None)
if criterion is None:
    criterion = nn.BCEWithLogitsLoss()


def lr_lambda(current_step):
    """Linear warmup over `warmup_steps` steps, then a constant multiplier of 1."""
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))
    return 1.0


scheduler = LambdaLR(optimizer, lr_lambda)


def _batch_loss(batch):
    """Run the model on one DataLoader batch and return criterion(logits, labels)."""
    inputs = {k: v.to(model.device) for k, v in batch.items() if k != "labels"}
    labels = batch["labels"].to(model.device)
    return criterion(model(**inputs).logits, labels)


# Tracking
best_metrics = {
    'val_loss': float('inf'),
    'weights': None,
    'epoch': -1
}
history = []
epochs_without_improvement = 0  # Early stopping counter

for epoch in range(epochs):
    # --- Training Phase ---
    model.train()
    train_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs} [Train]", leave=False)

    for batch in progress_bar:
        optimizer.zero_grad()
        loss = _batch_loss(batch)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        train_loss += loss.item()
        progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

    avg_train_loss = train_loss / len(train_loader)

    # --- Validation Phase ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating", leave=False):
            val_loss += _batch_loss(batch).item()
    avg_val_loss = val_loss / len(val_loader)

    # --- Best-model Tracking ---
    if avg_val_loss < best_metrics['val_loss']:
        # state_dict() holds references to the live parameters (and dict.copy() is
        # shallow), so clone the tensors to keep a real snapshot of this epoch.
        best_weights = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        best_metrics.update({
            'val_loss': avg_val_loss,
            'weights': best_weights,
            'epoch': epoch + 1
        })
        torch.save(best_weights, "best_model.pt")
        epochs_without_improvement = 0  # Reset counter
        print(f"↳ New best model saved! (Loss: {avg_val_loss:.4f})")
    else:
        epochs_without_improvement += 1
        print(f"↳ No improvement ({epochs_without_improvement}/{patience})")

    # --- Progress Tracking (recorded before a possible early stop) ---
    history.append({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'early_stop_counter': epochs_without_improvement
    })

    print(f"\nEpoch {epoch + 1} Results:")
    print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # --- Early Stopping Check ---
    if epochs_without_improvement >= patience:
        print(f"\nEarly stopping triggered at epoch {epoch + 1}!")
        print(f"Best model was from epoch {best_metrics['epoch']} with val_loss {best_metrics['val_loss']:.4f}")
        break

# Evaluate and predict with the best epoch's weights, not the last epoch's.
if best_metrics['weights'] is not None:
    model.load_state_dict(best_metrics['weights'])
    print(f"Restored best weights from epoch {best_metrics['epoch']}")
else:
    # Val loss never beat +inf (e.g. it was NaN): this run saved no best model,
    # regardless of any stale best_model.pt on disk.
    torch.save(model.state_dict(), "final_model.pt")
    print("Saved final model weights (no improvement during training)")

# 5. Model Evaluation:

### a. Calculating Metrics

In [None]:
# Score the validation split: sigmoid probabilities alongside the true label rows.
model.eval()  # evaluation mode (no dropout)
val_preds, val_true = [], []

with torch.no_grad():
    for batch in val_loader:
        inputs = {
            key: tensor.to(model.device)
            for key, tensor in batch.items()
            if key != "labels"
        }
        labels = batch["labels"].cpu().numpy()
        outputs = model(**inputs)
        probs = torch.sigmoid(outputs.logits).cpu().numpy()
        val_true.extend(labels)
        val_preds.extend(probs)

val_preds, val_true = np.array(val_preds), np.array(val_true)

In [None]:
# Per-label report at a fixed 0.5 cut-off (per-class thresholds are tuned later)
print(
    classification_report(
        val_true,
        (val_preds > 0.5),
        target_names=[
            'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate',
        ],
    )
)

### b. Visualizing:

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))
for i, label in enumerate(['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']):
    # A PR curve needs at least one positive example; on a small validation split
    # rare classes often have none and would only draw a degenerate line.
    if not np.any(val_true[:, i]):
        print(f"Skipping {label}: no positive examples in the validation set")
        continue
    precision, recall, _ = precision_recall_curve(val_true[:, i], val_preds[:, i])
    ax.plot(recall, precision, label=label)
ax.set(xlabel="Recall", ylabel="Precision", title="Precision-Recall Curves")
ax.legend()
plt.show()

### c. Finding Optimal Threshold:

In [None]:
# Per-class threshold that maximizes F1 on the validation split.
optimal_thresholds = []
for i in range(6):
    if not np.any(val_true[:, i]):
        # No positives: F1 is 0 at every threshold, and argmax would pick the lowest
        # score (flagging almost everything). Keep the default cut-off instead.
        optimal_thresholds.append(0.5)
        continue
    precision, recall, thresholds = precision_recall_curve(val_true[:, i], val_preds[:, i])
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-9)
    # precision/recall have one extra final point (P=1, R=0) with no threshold.
    optimal_thresholds.append(thresholds[np.argmax(f1_scores[:-1])])
print(f"Optimal Thresholds: {optimal_thresholds}")

# 6. Making Predictions:

In [None]:
# 1. Tokenize test data with exactly the settings used for train/val
#    (padding=True, truncation=True, max_length=128, PyTorch tensors).
test_encodings = tokenize(df_test["comment_text"].tolist())

In [None]:
# 2. Define Dataset class
class TestDataset(Dataset):
    """Unlabeled dataset over tokenized test comments (input_ids + attention_mask)."""

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        return {key: self.encodings[key][idx] for key in ("input_ids", "attention_mask")}

In [None]:
# 3. Create Dataset and Loader (no shuffling, so predictions stay in df_test order)
test_dataset = TestDataset(test_encodings)
loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# 4. Run inference
model.eval()
all_probs = []  # one (batch, n_labels) probability array per batch

with torch.no_grad():
    for batch in tqdm(loader, desc="Processing"):
        inputs = {k: v.to(model.device) for k, v in batch.items()}
        outputs = model(**inputs)
        all_probs.append(torch.sigmoid(outputs.logits).cpu().numpy())

        # Drop this batch's device tensors before the next iteration
        del inputs, outputs, batch

# Release cached GPU memory once at the end. Emptying the cache on every batch
# only forces the allocator to re-acquire the same blocks and slows inference.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 5. Final predictions array: one row per test comment, in df_test order
probs = np.vstack(all_probs)
assert len(probs) == len(df_test), "prediction count does not match df_test"

In [None]:
def find_optimal_thresholds(val_true, val_preds, toxicity_classes, default_threshold=0.5):
    """
    Calculate, for each class, the probability threshold that maximizes F1.

    Args:
        val_true: Array of true labels (n_samples × n_classes)
        val_preds: Array of predicted probabilities (n_samples × n_classes)
        toxicity_classes: List of class names, in column order
        default_threshold: Used for a class with no positive validation examples.
            There F1 is 0 for every threshold and argmax would pick the lowest
            score, flagging nearly everything as positive.

    Returns:
        Dictionary of {class_name: optimal_threshold}
    """
    optimal_thresholds = {}

    for i, class_name in enumerate(toxicity_classes):
        y_true = val_true[:, i]
        if not np.any(y_true):
            optimal_thresholds[class_name] = default_threshold
            continue

        # Get precision-recall curve for this class
        precision, recall, thresholds = precision_recall_curve(y_true, val_preds[:, i])

        # Calculating F1 scores
        f1_scores = 2 * (precision * recall) / (precision + recall + 1e-9)

        # precision/recall have one more point than thresholds (the final P=1, R=0
        # point has no threshold), so exclude it before taking argmax.
        optimal_idx = np.argmax(f1_scores[:-1])
        optimal_thresholds[class_name] = thresholds[optimal_idx]

    return optimal_thresholds

# Tune one threshold per class on the validation predictions
toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
optimal_thresholds = find_optimal_thresholds(
    val_true=val_true,
    val_preds=val_preds,
    toxicity_classes=toxicity_classes,
)

# Resulting dictionary: {class_name: threshold}
print(optimal_thresholds)

In [None]:
# Per-class thresholds tuned on the validation set
class_thresholds = optimal_thresholds

# Add probability columns
for i, col in enumerate(toxicity_classes):
    df_test[f"{col}_prob"] = probs[:, i]

# Binary predictions. precision_recall_curve counts score >= threshold as positive,
# so use >= to reproduce the cut-off that was optimised on validation.
for col in toxicity_classes:
    df_test[f"{col}_pred"] = (df_test[f"{col}_prob"] >= class_thresholds[col]).astype(int)

# Check class distribution after thresholding
print("\nPredicted Positive Rates:")
for col in toxicity_classes:
    rate = df_test[f"{col}_pred"].mean()
    print(f"{col}: {rate:.2%} (threshold={class_thresholds[col]:.4f})")

In [None]:
probs.T[i]  # test probabilities for the class index `i` left over from the probability-column loop

In [None]:
df_test.iloc[:10]  # first ten test rows with their probability and prediction columns

In [None]:
# The competition submission is `id` plus one probability column per class, named
# exactly after the class (as in sample_submission.csv). Submitting thresholded
# `*_pred` flags under different column names would not match that format.
submission = df_test[["id"]].copy()
for col in toxicity_classes:
    submission[col] = df_test[f"{col}_prob"]

submission.to_csv('submission.csv', index=False)

In [None]:
# Save the model and its tokenizer together so the directory can be reloaded
# with from_pretrained() on its own.
model.save_pretrained('bert_model_weights')
tokenizer.save_pretrained('bert_model_weights')

In [None]:
submission.head()  # preview; the full file was written to submission.csv

# 99. Try It Out: Ad-hoc Predictions

#### Single Sentence Prediction:

In [None]:
# To run this section in a fresh session, reload the trained weights first:
# model = AutoModelForSequenceClassification.from_pretrained(
#     "bert-base-uncased", num_labels=6, problem_type="multi_label_classification")
# model.load_state_dict(torch.load("best_model.pt"))  # or "final_model.pt"
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Move model to device (GPU if available) and switch to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Your input sentence
sentence = "SLUT"

# Clean the text the same way the training comments were cleaned before tokenizing
cleaned_sentence = remove_special_characters(remove_punctuation(sentence))
inputs = tokenizer(cleaned_sentence,
                   padding=True,
                   truncation=True,
                   max_length=128,
                   return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)
    # Separate name so the test-set `probs` array is not overwritten by this cell
    sentence_probs = torch.sigmoid(outputs.logits).cpu().numpy()[0]

toxicity_classes = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# results
print(sentence)

print("Predicted probabilities:")
for cls, prob in zip(toxicity_classes, sentence_probs):
    print(f"{cls}: {prob:.4f}")

# binary predictions
binary_preds = (sentence_probs > 0.5).astype(int)
print("\nBinary predictions (threshold=0.5):")
for cls, pred in zip(toxicity_classes, binary_preds):
    print(f"{cls}: {'✅' if pred else '❌'}")

#### Multi-Sentence Prediction:

In [None]:
sentences = [
    "You're stupid!",
    "Thanks for your help",
    "Go back to your country"
]

# Use the model's own device (not `device` from the previous cell) and eval mode,
# so this cell also works when run on its own.
model.eval()

# Clean the texts the same way the training comments were cleaned, then tokenize as a batch
cleaned_sentences = [remove_special_characters(remove_punctuation(s)) for s in sentences]
inputs = tokenizer(cleaned_sentences,
                   padding=True,
                   truncation=True,
                   max_length=128,
                   return_tensors="pt").to(model.device)

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    # Own name, so the test-set `all_probs` list from the inference cell is not clobbered
    sentence_batch_probs = torch.sigmoid(outputs.logits).cpu().numpy()

# Display results (original, uncleaned sentence shown)
for sentence, row_probs in zip(sentences, sentence_batch_probs):
    print(f"\nSentence: '{sentence}'")
    for cls, prob in zip(toxicity_classes, row_probs):
        print(f"{cls}: {prob:.4f}")

## Experiments for the Rare `threat` Class (disabled):

### a. Focal Loss

In [None]:
# Disabled experiment: the whole cell is a string literal, so nothing below runs.
# It defines a focal loss with an `alpha` weight and wires it into a
# BertForSequenceClassification subclass.
# NOTE(review): BertForSequenceClassification is not imported in this notebook (only
# the Auto* classes are), so enabling this needs that import first. Its forward
# returns a `(loss, outputs)` tuple when labels are given, whereas the training loop
# reads `.loss`/`.logits` attributes - reconcile before switching to it.
'''
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, num_classes=6):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.num_classes = num_classes
        self.ce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        loss = self.ce_loss(inputs, targets)
        p_t = torch.exp(-loss)
        focal_loss = self.alpha * (1 - p_t) ** self.gamma * loss
        return focal_loss.mean()

class CustomBERTForSequenceClassificationWithFocalLoss(BertForSequenceClassification):
    def __init__(self, config, focal_loss_alpha=0.25, focal_loss_gamma=2):
        super().__init__(config)
        self.focal_loss = FocalLoss(alpha=focal_loss_alpha, gamma=focal_loss_gamma)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, **kwargs):
        # Call the parent model's forward method
        outputs = super().forward(input_ids=input_ids, 
                                  attention_mask=attention_mask, 
                                  token_type_ids=token_type_ids, 
                                  **kwargs)
        logits = outputs.logits

        # Compute loss if labels are provided
        if labels is not None:
            loss = self.focal_loss(logits, labels)
            return (loss, outputs)
        else:
            return outputs

# Example usage:
model = CustomBERTForSequenceClassificationWithFocalLoss.from_pretrained("bert-base-uncased", num_labels=6).to("cuda" if torch.cuda.is_available() else "cpu")
'''

### b. Class Weights

In [None]:
# Disabled experiment (string literal, never executed): BCE with a per-class
# `pos_weight` inversely proportional to each label's positive count.
# NOTE(review): needs `BertForSequenceClassification` imported and `device` defined
# before this cell. Assigning `model.loss_fct` alone may not change the loss the
# HF forward computes - confirm the training loop applies it to the logits.
'''
# Calculate class weights (inverse of class frequencies)
class_counts = np.array([sum(train_labels[:, i]) for i in range(6)])  # Count per class
class_weights = torch.tensor(
    (1.0 / (class_counts + 1e-6)) * (len(train_labels)/6),  # Normalize
    dtype=torch.float32,
    device=device
)

# Modify your model initialization
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=6,
    problem_type="multi_label_classification"
)
model.loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
'''

### c. Focal Loss (without `alpha`)

In [None]:
# Disabled (string literal, never executed): same gamma-only FocalLoss as the live
# class defined next to the model, moved with `.to(device)`; `device` must exist first.
'''
class FocalLoss(nn.Module):
    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma
        
    def forward(self, inputs, targets):
        bce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-bce_loss)
        loss = ((1-pt)**self.gamma * bce_loss).mean()
        return loss

model.loss_fct = FocalLoss(gamma=2).to(device)
'''

### d. Weighted Sampling in the Data Loader

In [None]:
# Disabled experiment (string literal, never executed): oversample rows with
# `threat` (column index 3) 50x through a WeightedRandomSampler.
# NOTE(review): `train_labels` comes from train_test_split on a NumPy array, so
# `train_labels[:, 3] == 1` is a NumPy bool array; torch.where likely needs a tensor
# condition (wrap it with torch.as_tensor). batch_size=32 here also differs from
# the 16 used by the live train_loader.
'''
from torch.utils.data import WeightedRandomSampler

# Calculate sample weights (higher for threat-containing samples)
sample_weights = torch.where(
    train_labels[:, 3] == 1,  # Threat is index 3
    torch.tensor(50.0),       # 50x higher sampling for threats
    torch.tensor(1.0)
)

sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

# Modify your DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=sampler,  # Replaces shuffle=True
    num_workers=4
)
'''

In [None]:
# Augmentation (disabled: the whole cell is a string literal and never runs)
# Alternative augmentation that generates a fixed number of new, unique comments
# for each rare label and appends them to df_train.
# NOTE(review): `while generated < count` never terminates if augment_text keeps
# returning None or already-seen texts (e.g. very few distinct source rows); cap
# the number of attempts before enabling.
# NOTE(review): the final set-based "deduplication" removes nothing (existing_texts
# already guarantees distinct texts), but iterating a set makes the row order
# irreproducible across sessions.
'''
# Your settings
TARGET_LABELS = {
    'severe_toxic': 5000,
    'threat': 10000, 
    'identity_hate': 8000
}

# Track existing texts to prevent duplicates
existing_texts = set(df_train['comment_text'].tolist())

def augment_text(text, n=5):
    """Your original augmentation function with duplicate check"""
    original_text = text
    for _ in range(n):
        try:
            words = text.split()
            replaceable = [w for w in words if wordnet.synsets(w)]
            if not replaceable:
                break
                
            word = random.choice(replaceable)
            synonyms = wordnet.synsets(word)
            if synonyms:
                lemma = random.choice([l for s in synonyms for l in s.lemmas()])
                new_text = text.replace(word, lemma.name().replace('_', ' '), 1)
                if new_text != original_text:  # Only keep meaningful changes
                    text = new_text
        except:
            continue
    return text if text != original_text else None  # Return None if no changes made

# Main augmentation
augmented = []
for label, count in tqdm(TARGET_LABELS.items(), desc="Augmenting"):
    samples = df_train[df_train[label] == 1]
    if len(samples) == 0:
        continue
        
    with tqdm(total=count, desc=f"{label}") as pbar:
        generated = 0
        while generated < count:
            original = samples.sample(1).iloc[0]
            new_text = augment_text(original['comment_text'])
            
            if new_text and new_text not in existing_texts:
                new_row = original.copy()
                new_row['comment_text'] = new_text
                augmented.append(new_row)
                existing_texts.add(new_text)
                generated += 1
                pbar.update(1)

# Final deduplication (just in case)
augmented = [dict(t) for t in {tuple(d.items()) for d in augmented}]

# Combine results
if augmented:
    df_train = pd.concat([
        df_train,
        pd.DataFrame(augmented)
    ], ignore_index=True)
'''