<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/_Transformer_Model_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch transformers scikit-learn --upgrade

In [None]:
pip install nlpaug

In [None]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import numpy as np
import nlpaug.augmenter.word as naw

# Define TextDataset class with augmentation
class TextDataset(Dataset):
    """Map-style dataset that tokenizes raw text on access for classification.

    Args:
        texts: Indexable collection of input strings.
        labels: Indexable collection of integer class ids aligned with
            ``texts``.
        tokenizer: Object exposing a Hugging Face style ``encode_plus``
            method that returns ``input_ids`` and ``attention_mask`` tensors.
        max_len: Length every encoded sequence is padded or truncated to.
        aug: Optional augmenter exposing ``augment(text)``. It is applied on
            every access, so repeated reads of one index may differ.
    """

    def __init__(self, texts, labels, tokenizer, max_len, aug=None):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.aug = aug

    def __len__(self):
        """Return the number of samples."""
        return len(self.texts)

    def _augment(self, text):
        """Return a single augmented string, falling back to ``text``."""
        augmented = self.aug.augment(text)
        # nlpaug augmenters may hand back a list of strings instead of a bare
        # string. Passing that list to the tokenizer would make it treat the
        # whole sentence as pre-split tokens, so unwrap it here.
        if isinstance(augmented, (list, tuple)):
            augmented = augmented[0] if augmented else None
        return text if augmented is None else augmented

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (1-D tensors of
            length ``max_len``) and ``labels`` (0-D ``torch.long`` tensor).
        """
        text = self.texts[idx]
        if self.aug:
            text = self._augment(text)
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            # Drop only the leading batch axis; a bare squeeze() would also
            # collapse the sequence axis when max_len == 1.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Define train_model function
def train_model(model, train_loader, optimizer, criterion, device):
    """Run one training epoch and return ``(mean_batch_loss, accuracy)``.

    Args:
        model: Module called as ``model(input_ids, attention_mask=...,
            labels=...)`` whose output exposes ``.loss`` and ``.logits``.
        train_loader: Iterable of batches; each batch is a mapping with
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.
        optimizer: Optimizer stepped once per batch.
        criterion: Unused; the loss is taken from the model output. Kept so
            existing callers keep working.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple of Python floats ``(mean loss over batches, fraction of
        correctly classified samples)``. ``(0.0, 0.0)`` is returned when the
        loader yields no batches.
    """
    model.train()
    loss_sum = 0.0
    n_batches = 0
    n_correct = 0
    n_seen = 0

    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Count batches here instead of calling len(), so loaders without a
        # length are supported and an empty loader cannot divide by zero.
        loss_sum += loss.item()
        n_batches += 1

        # The logits were computed before the step, so this is the accuracy
        # of the weights that produced the loss above.
        preds = outputs.logits.argmax(dim=1)
        n_correct += (preds == labels).sum().item()
        n_seen += labels.size(0)

    avg_loss = loss_sum / n_batches if n_batches else 0.0
    accuracy = n_correct / n_seen if n_seen else 0.0
    return avg_loss, accuracy

# Define evaluate_model function
def evaluate_model(model, val_loader, device, criterion):
    """Evaluate ``model`` on ``val_loader`` without updating weights.

    Args:
        model: Module called as ``model(input_ids, attention_mask=...,
            labels=...)`` whose output exposes ``.loss`` and ``.logits``.
        val_loader: Iterable of batches; each batch is a mapping with
            ``'input_ids'``, ``'attention_mask'`` and ``'labels'`` tensors.
        device: Device the batch tensors are moved to.
        criterion: Unused; the loss is taken from the model output. Kept so
            existing callers keep working.

    Returns:
        Tuple ``(mean loss over batches, classification report dict,
        predictions)`` where predictions is a list of plain ints in loader
        order. ``(0.0, {}, [])`` is returned when the loader yields no
        batches.
    """
    model.eval()
    loss_sum = 0.0
    n_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss_sum += outputs.loss.item()
            n_batches += 1

            # tolist() yields plain Python ints rather than numpy scalars.
            preds = outputs.logits.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if n_batches == 0:
        # Nothing to average or report on; avoid dividing by zero.
        return 0.0, {}, []

    avg_loss = loss_sum / n_batches
    # zero_division=0 keeps the default numeric result for undefined metrics
    # while suppressing the warning emitted on every epoch.
    report = classification_report(all_labels, all_preds, output_dict=True, zero_division=0)
    return avg_loss, report, all_preds

# Cross-validation function with early stopping
def cross_validate_with_scheduler(model, dataset, n_splits, tokenizer, device, max_len, batch_size, learning_rate, num_epochs, aug=None):
    """Fine-tune a fresh BERT classifier on repeated random train/validation splits.

    Each of the ``n_splits`` rounds draws its own 80/20 split (seeded with the
    round index), so validation sets of different rounds may overlap. This is
    repeated random hold-out rather than disjoint K-fold. Every round uses
    ``ReduceLROnPlateau`` on the validation loss and stops early after three
    epochs without improvement. A combined classification report over the
    best-epoch predictions of all rounds is printed at the end.

    Args:
        model: Ignored. A new ``bert-base-uncased`` model with two labels is
            loaded for every round so no weights leak between rounds.
        dataset: Object exposing ``texts`` and ``labels`` attributes.
        n_splits: Number of split/train/evaluate rounds.
        tokenizer: Tokenizer passed through to ``TextDataset``.
        device: Device the model and batches are moved to.
        max_len: Sequence length passed through to ``TextDataset``.
        batch_size: Batch size for both data loaders.
        learning_rate: Initial learning rate of the optimizer.
        num_epochs: Maximum number of epochs per round.
        aug: Optional augmenter applied to the training split only.

    Returns:
        None. Progress and reports are printed.
    """
    texts = dataset.texts
    labels = dataset.labels
    all_val_labels = []
    all_val_preds = []
    patience = 3

    for fold in range(n_splits):
        print(f"Training fold {fold + 1}/{n_splits}")

        train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2, random_state=fold)

        train_dataset = TextDataset(train_texts, train_labels, tokenizer, max_len, aug=aug)
        val_dataset = TextDataset(val_texts, val_labels, tokenizer, max_len)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        # No shuffling here: predictions must stay aligned with val_labels.
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

        fold_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
        fold_model.to(device)

        # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
        # weight_decay are set to the values the transformers class defaulted to.
        optimizer = torch.optim.AdamW(fold_model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)
        # The scheduler's `verbose` flag is deprecated in recent PyTorch
        # releases, so learning-rate changes are reported explicitly below.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

        best_val_loss = float('inf')
        best_val_preds = None
        last_val_preds = None
        early_stop_counter = 0

        # Training and evaluation with early stopping
        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")
            train_loss, train_accuracy = train_model(fold_model, train_loader, optimizer, F.cross_entropy, device)
            print(f"Train loss: {train_loss:.4f}, accuracy: {train_accuracy:.4f}")

            val_loss, val_report, last_val_preds = evaluate_model(fold_model, val_loader, device, F.cross_entropy)
            print(f"Validation loss: {val_loss:.4f}")
            print(val_report)

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                print(f"Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_val_preds = last_val_preds
                early_stop_counter = 0
                # NOTE: the same path is reused by every round, so the file
                # ends up holding the best state of the last round only.
                torch.save(fold_model.state_dict(), 'best_model_state.bin')
                print("Model saved!")
            else:
                early_stop_counter += 1
                if early_stop_counter >= patience:
                    print("Early stopping triggered")
                    break

        # Report the predictions of the best epoch (the state that was saved),
        # falling back to the last epoch if the loss never improved. A round
        # with num_epochs == 0 produces no predictions and is skipped.
        fold_preds = best_val_preds if best_val_preds is not None else last_val_preds
        if fold_preds is not None:
            all_val_labels.extend(val_labels)
            all_val_preds.extend(fold_preds)

        # Drop this round's model before the next one is loaded so two models
        # are not held in accelerator memory at once.
        del fold_model, optimizer, scheduler
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not all_val_labels:
        print("No validation predictions were collected.")
        return

    # Final classification report across all folds
    print("Final classification report:")
    print(classification_report(all_val_labels, all_val_preds, zero_division=1))

# Main function
def main():
    """Run the cross-validation demo on a synthetic two-class corpus."""
    # Placeholder corpus: 500 identical positive sentences followed by 500
    # identical negative ones. Replace with a real dataset.
    per_class = 500
    texts = ["This is a positive example"] * per_class + ["This is a negative example"] * per_class
    labels = [1] * per_class + [0] * per_class

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Hyper-parameters forwarded unchanged to the cross-validation routine.
    run_config = {
        'max_len': 128,
        'batch_size': 8,
        'learning_rate': 2e-5,
        'num_epochs': 5,
        'n_splits': 5,
    }

    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    # Word-level augmenter that randomly swaps words in a sentence.
    aug = naw.RandomWordAug(action="swap")

    dataset = TextDataset(texts, labels, tokenizer, run_config['max_len'])

    # model=None: the routine loads its own model for every round.
    cross_validate_with_scheduler(
        model=None,
        dataset=dataset,
        tokenizer=tokenizer,
        device=device,
        aug=aug,
        **run_config,
    )

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()