<a href="https://colab.research.google.com/github/Rg32601/Compact-Transformers/blob/main/NLP_Github_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# dataset settings

In [1]:
import argparse

def init_parser(dataset):
    """
    Initialize argument parser with dataset-specific defaults.

    Parameters:
        dataset (str): One of {"agnews", "dbpedia", "fewrel"}.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    allowed_datasets = {"agnews", "dbpedia", "fewrel"}
    if dataset not in allowed_datasets:
        raise ValueError(f"Invalid dataset '{dataset}'. Must be one of: {', '.join(allowed_datasets)}.")

    # Dataset-specific default values
    dataset_defaults = {
        "agnews": {
            "samples_per_label_train": 1000,
            "samples_per_label_test": 1000,
            "num_labels": 4
        },
        "dbpedia": {
            "samples_per_label_train": 100,
            "samples_per_label_test": 100,
            "num_labels": 14
        },
        "fewrel": {
            "samples_per_label_train": 630,
            "samples_per_label_test": 70,
            "num_labels": 64
        }
    }

    defaults = dataset_defaults[dataset]

    parser = argparse.ArgumentParser(description=f"Configure BERT6 pretraining on the {dataset.upper()} dataset.")

    # General training settings
    parser.add_argument("--epochs", type=int, default=50, help="Number of fine-tuning epochs")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate for fine-tuning")
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay during fine-tuning")

    # Pre-training settings
    parser.add_argument("--pre_train_epochs", type=int, default=50, help="Number of pre-training epochs")
    parser.add_argument("--pre_train_lr", type=float, default=5e-5, help="Learning rate for pre-training")
    parser.add_argument("--pre_train_weight_decay", type=float, default=1e-2, help="Weight decay during pre-training")
    parser.add_argument("--pre_train_size", type=int, default=40000, help="Number of samples used for pre-training")

    # Data and batching
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training")
    parser.add_argument("--max_length", type=int, default=128, help="Maximum sequence length")
    parser.add_argument("--samples_per_label_train", type=int, default=defaults["samples_per_label_train"],
                        help="Train samples per label")
    parser.add_argument("--samples_per_label_test", type=int, default=defaults["samples_per_label_test"],
                        help="Test samples per label")

    # Masked language modeling
    parser.add_argument("--mlm_prob", type=float, default=0.15, help="Probability of masking tokens for MLM")

    # Model architecture
    parser.add_argument("--bert_layers", type=int, default=6, help="Number of BERT encoder layers")
    parser.add_argument("--num_attention_heads", type=int, default=12, help="Number of attention heads per layer")
    parser.add_argument("--head_size", type=int, default=64, help="Dimention of each attention head")
    parser.add_argument("--conv_layers", type=int, default=0, help="Number of convolutional layers (if any)")
    parser.add_argument("--kernel", type=int, default=3, help="Kernel size for convolutional layers")
    parser.add_argument("--d", type=int, default=3, help="Number of channels for convolutional layers")

    # Tokenizer and vocabulary
    parser.add_argument("--vocab_size", type=int, default=30522, help="Size of tokenizer vocabulary")
    parser.add_argument("--pad_token_id", type=int, default=0, help="ID for [PAD] token")
    parser.add_argument("--cls_token_id", type=int, default=101, help="ID for [CLS] token")
    parser.add_argument("--sep_token_id", type=int, default=102, help="ID for [SEP] token")
    parser.add_argument("--mask_token_id", type=int, default=103, help="ID for [MASK] token")

    # Misc
    parser.add_argument("--num_labels", type=int, default=defaults["num_labels"], help="Number of classification labels")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

    return parser.parse_known_args()[0]

#BertConv

In [2]:
from torch import nn
from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertForSequenceClassification,
)

class ConvolutionLayers(nn.Module):
    """
    A modular convolutional block used to optionally replace the first layer
    of a BERT encoder. Adds spatial context through stacked 2D convolutions.
    """

    def __init__(self, num_layers, channels, kernel_size):
        """
        Initialize a stack of convolutional layers.

        Args:
            num_layers (int): Number of convolutional blocks.
            channels (List[int]): List of channel dimensions. Length = num_layers + 1.
            kernel_size (int): Kernel size for each convolution.
        """
        super().__init__()

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.extend([
                nn.Conv2d(channels[i], channels[i + 1], kernel_size=kernel_size, padding="same"),
                nn.BatchNorm2d(channels[i + 1]),
                nn.GELU()
            ])

        # Final projection to 1 channel
        self.conv = nn.Sequential(
            nn.Conv2d(channels[-1], 1, kernel_size=1, padding="same"),
            nn.BatchNorm2d(1),
            nn.GELU()
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=None,
    ):
        """
        Forward pass through convolutional layers.

        Args:
            hidden_states (Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim).

        Returns:
            Tuple[Tensor, Optional[Tensor]]:
                - Processed hidden states: (batch_size, seq_len, hidden_dim)
                - attention_mask (unchanged)
        """
        x = hidden_states.view(hidden_states.size(0), 1, hidden_states.size(1), hidden_states.size(2))

        for layer in self.layers:
            x = layer(x)

        x = self.conv(x)
        x = x.view(x.size(0), x.size(2), x.size(3))

        return x, attention_mask


def build_bert_with_optional_conv_for_pre_train(
    hidden_size,
    num_hidden_layers,
    num_conv_layers,
    kernel_size,
    num_attention_heads,
    intermediate_size,
    max_position_embeddings,
    conv_channels_dim,
    vocab_size=30522,
    cls_token_id=0,
    pad_token_id=101,
    sep_token_id=102,
    mask_token_id=103,
):
    """
    Construct a BERT model for masked language modeling with optional convolutional preprocessing.

    Args:
        hidden_size (int): Transformer hidden size.
        num_hidden_layers (int): Number of BERT transformer layers (not including CNN).
        num_conv_layers (int): If > 0, replaces first encoder layer with CNN block.
        kernel_size (int): Kernel size for convolutional layers.
        num_attention_heads (int): Number of self-attention heads per layer.
        intermediate_size (int): Size of the feedforward layer.
        max_position_embeddings (int): Max sequence length for positional embeddings.
        conv_channels_dim (int): Output channels for each convolutional layer.
        vocab_size (int): Token vocabulary size.
        cls_token_id, pad_token_id, sep_token_id, mask_token_id (int): Special token IDs.

    Returns:
        BertForMaskedLM: Configured model.
    """
    config = BertConfig(
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers + int(num_conv_layers > 0),
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        vocab_size=vocab_size,
        cls_token_id=cls_token_id,
        pad_token_id=pad_token_id,
        sep_token_id=sep_token_id,
        mask_token_id=mask_token_id,
    )

    model = BertForMaskedLM(config)

    if num_conv_layers > 0:
        conv_block = ConvolutionLayers(
            num_layers=num_conv_layers,
            channels=[1] + [conv_channels_dim] * num_conv_layers,
            kernel_size=kernel_size
        )
        model.bert.encoder.layer[0] = conv_block

    return model


def build_bert_with_optional_conv_for_classification(
    hidden_size,
    num_hidden_layers,
    num_conv_layers,
    kernel_size,
    num_attention_heads,
    intermediate_size,
    max_position_embeddings,
    conv_channels_dim,
    num_labels,
    vocab_size=30522,
    cls_token_id=0,
    pad_token_id=101,
    sep_token_id=102,
    mask_token_id=103,
):
    """
    Construct a BERT model for sequence classification with optional convolutional preprocessing.

    Args:
        hidden_size (int): Transformer hidden size.
        num_hidden_layers (int): Number of BERT transformer layers (not including CNN).
        num_conv_layers (int): If > 0, replaces first encoder layer with CNN block.
        kernel_size (int): Kernel size for convolutional layers.
        num_attention_heads (int): Number of self-attention heads per layer.
        intermediate_size (int): Size of the feedforward layer.
        max_position_embeddings (int): Max sequence length for positional embeddings.
        conv_channels_dim (int): Output channels for each convolutional layer.
        num_labels (int): Number of output classes for classification.
        vocab_size (int): Token vocabulary size.
        cls_token_id, pad_token_id, sep_token_id, mask_token_id (int): Special token IDs.

    Returns:
        BertForSequenceClassification: Configured model.
    """
    config = BertConfig(
        hidden_size=hidden_size,
        num_hidden_layers=num_hidden_layers + int(num_conv_layers > 0),
        num_attention_heads=num_attention_heads,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        vocab_size=vocab_size,
        num_labels=num_labels,
        cls_token_id=cls_token_id,
        pad_token_id=pad_token_id,
        sep_token_id=sep_token_id,
        mask_token_id=mask_token_id,
    )

    model = BertForSequenceClassification(config)

    if num_conv_layers > 0:
        conv_block = ConvolutionLayers(
            num_layers=num_conv_layers,
            channels=[1] + [conv_channels_dim] * num_conv_layers,
            kernel_size=kernel_size
        )
        model.bert.encoder.layer[0] = conv_block

    return model

# Load dataset

In [1]:
!pip install --upgrade --force-reinstall datasets

!rm -rf ~/.cache/huggingface/datasets
!rm -rf /root/.cache/huggingface/datasets

Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting filelock (from datasets)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting numpy>=1.17 (from datasets)
  Downloading numpy-2.3.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.1/62.1 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-20.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Downloading pandas-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (91 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.4-

In [3]:
import random

def sample_per_label(dataset, max_samples_per_label, seed=42):
    """
    Samples a balanced subset of the dataset with at most `max_samples_per_label` for each label.

    Args:
        dataset (datasets.Dataset): The input dataset.
        max_samples_per_label (int): Max number of samples to take per label.
        seed (int): Random seed for reproducibility.

    Returns:
        datasets.Dataset: Subsampled dataset.
    """
    random.seed(seed)
    label_to_indices = {}
    for idx, example in enumerate(dataset):
        label = example["label"]
        label_to_indices.setdefault(label, []).append(idx)

    sampled_indices = []
    for indices in label_to_indices.values():
        if len(indices) >= max_samples_per_label:
            sampled = random.sample(indices, max_samples_per_label)
        else:
            sampled = indices
        sampled_indices.extend(sampled)

    return dataset.select(sampled_indices)


def tokenize_dataset(dataset, tokenizer, input_key="content", max_length=128):
    """
    Applies tokenizer to a dataset.

    Args:
        dataset (datasets.Dataset): The input dataset.
        tokenizer (transformers.PreTrainedTokenizer): A HuggingFace tokenizer.
        input_key (str): The field in the dataset to tokenize (e.g., "text", "content").
        max_length (int): Max sequence length.

    Returns:
        datasets.Dataset: Tokenized dataset with torch format.
    """
    tokenized = dataset.map(
        lambda x: tokenizer(
            x[input_key],
            truncation=True,
            padding="max_length",
            max_length=max_length
        ),
        batched=True
    )

    # Check if 'label' column exists (e.g. in AG News, DBpedia)
    columns = ["input_ids", "attention_mask"]
    if "label" in dataset.column_names:
        tokenized = tokenized.rename_column("label", "labels")
        columns.append("labels")

    tokenized.set_format(type="torch", columns=columns)
    return tokenized

In [4]:
import numpy as np
from torch.utils.data import DataLoader
from datasets import (
    load_dataset,
    Dataset
    )
from transformers import BertTokenizer


def load_dbpedia_dataloaders(
    samples_per_label_train,
    samples_per_label_test,
    batch_size=32,
    seed=42,
    max_length=128
):
    """
    Loads, samples, tokenizes, and formats the DBpedia dataset into PyTorch DataLoaders.

    Args:
        samples_per_label_train (int): Number of training samples per label.
        samples_per_label_test (int): Number of test samples per label.
        batch_size (int): Dataloader batch size.
        seed (int): Random seed.

    Returns:
        Tuple[DataLoader, DataLoader]: train_dataloader, test_dataloader
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Load raw datasets
    train_dataset = load_dataset("dbpedia_14", split="train")
    test_dataset = load_dataset("dbpedia_14", split="test")

    # Sample balanced subsets
    small_train = sample_per_label(train_dataset, samples_per_label_train, seed)
    small_test = sample_per_label(test_dataset, samples_per_label_test, seed)

    # Tokenize and format
    tokenized_train = tokenize_dataset(small_train, tokenizer, "content", max_length)
    tokenized_test = tokenize_dataset(small_test, tokenizer, "content", max_length)

    # Create DataLoaders
    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(tokenized_test, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

def load_agnews_dataloaders(
    samples_per_label_train,
    samples_per_label_test,
    batch_size=32,
    seed=42,
    max_length=128
):
    """
    Loads, samples, tokenizes, and formats the AG News dataset into PyTorch DataLoaders.

    Args:
        samples_per_label_train (int): Number of training samples per label.
        samples_per_label_test (int): Number of test samples per label.
        batch_size (int): Dataloader batch size.
        seed (int): Random seed.
        max_length (int): Max token length for BERT tokenizer.

    Returns:
        Tuple[DataLoader, DataLoader]: train_dataloader, test_dataloader
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # Load raw datasets
    train_dataset = load_dataset("ag_news", split="train")
    test_dataset = load_dataset("ag_news", split="test")

    # Sample balanced subsets
    small_train = sample_per_label(train_dataset, samples_per_label_train, seed)
    small_test = sample_per_label(test_dataset, samples_per_label_test, seed)

    # Tokenize and format
    tokenized_train = tokenize_dataset(small_train, tokenizer, "text", max_length)
    tokenized_test = tokenize_dataset(small_test, tokenizer,"text", max_length)

    # Create DataLoaders
    train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(tokenized_test, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


def load_wikipedia_subset(pre_train_size=40000, seed=42, max_length=128 , batch_size= 32):
    """
    Load a random subset of Wikipedia dataset.

    Args:
        pre_train_size (int): Number of random paragraphs to select.
        seed (int): Random seed.
        max_length (int): Max token length for BERT tokenizer.
        batch_size (int): Dataloader batch size.

    Returns:
        Dataset: Subset of Wikipedia dataset.
    """
    # Load the Wikipedia dataset (default: English, 20220301.en)
    wiki = load_dataset("wikimedia/wikipedia", "20231101.en", split="train")

    # Shuffle and select a random subset
    wiki = wiki.shuffle(seed=seed)
    wiki_small = wiki.select(range(pre_train_size))

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    # Step 2: Tokenize
    tokenized_dataset = tokenize_dataset(wiki_small, tokenizer, input_key="text", max_length=max_length)

    # Step 3: Wrap in DataLoader
    dataloader = DataLoader(tokenized_dataset, batch_size=batch_size, shuffle=True)

    return dataloader

# Pre train

In [8]:
import torch
from torch import nn

def mask_tokens(
    inputs: torch.Tensor,
    mlm_probability: float = 0.15,
    vocab_size: int = 30522,
    pad_token_id: int = 0,
    cls_token_id: int = 101,
    sep_token_id: int = 102,
    mask_token_id: int = 103
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Prepare masked tokens inputs/labels for masked language modeling (MLM) as in BERT.

    Args:
        inputs (torch.Tensor): Input token IDs (batch_size x seq_len).
        mlm_probability (float): Probability of masking each token.
        vocab_size (int): Size of the tokenizer vocabulary.
        pad_token_id (int): Token ID used for padding.
        cls_token_id (int): Token ID for [CLS] token.
        sep_token_id (int): Token ID for [SEP] token.
        mask_token_id (int): Token ID for [MASK] token.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple of (masked_input_ids, labels),
            where labels have -100 for non-MLM positions (ignored in loss).
    """
    labels = inputs.clone()

    # Create mask for special tokens
    special_tokens_mask = (inputs == pad_token_id) | (inputs == cls_token_id) | (inputs == sep_token_id)

    # Decide which tokens to mask
    probability_matrix = torch.full(labels.shape, mlm_probability)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()

    # Labels: only compute loss on masked tokens
    labels[~masked_indices] = -100

    # 80% of the time, replace with [MASK]
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = mask_token_id

    # 10% of the time, replace with random word
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(vocab_size, labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # 10% of the time, keep the original token (no need to modify)

    return inputs, labels

def mlm_train(
    dataloader: DataLoader,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
) -> float:
    """
    Trains a masked language model for one epoch with a tqdm progress bar showing live loss.

    Args:
        dataloader (DataLoader): Unlabeled dataset loader for MLM pretraining.
        model (nn.Module): Model with MLM objective (e.g., BertForMaskedLM).
        optimizer (Optimizer): Optimizer.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    model = model.to(device)
    scaler = torch.cuda.amp.GradScaler()

    progress_bar = tqdm(dataloader, desc="MLM Training", leave=True)

    for batch in progress_bar:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        token_type_ids = batch.get('token_type_ids', None)

        masked_input_ids, mlm_labels = mask_tokens(input_ids.clone())

        masked_input_ids = masked_input_ids.to(device)
        attention_mask = attention_mask.to(device)
        mlm_labels = mlm_labels.to(device)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(device)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(
                input_ids=masked_input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=mlm_labels
            )
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        progress_bar.set_postfix(loss=f"{loss.item():.4f}")

# train + test

In [7]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

def cls_test(
    test_dataloader: DataLoader,
    model: nn.Module
) -> float:
    """
    Evaluate a classification model with accuracy metric.

    Args:
        test_dataloader (DataLoader): Test set loader.
        model (nn.Module): Model that outputs logits.

    Returns:
        float: Accuracy score on the test set.
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    correct = 0
    total = 0

    progress_bar = tqdm(test_dataloader, desc="Testing", leave=True)
    for batch in progress_bar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device, dtype=torch.float)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask).logits
        preds = torch.argmax(outputs, dim=1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

        accuracy = correct / total if total > 0 else 0.0
        progress_bar.set_postfix(acc=f"{accuracy:.4f}")

    return accuracy

def cls_train(
    train_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    model: nn.Module
) -> None:
    """
    Train a classification model using mixed precision and display progress.

    Args:
        train_dataloader (DataLoader): Training data loader.
        optimizer (Optimizer): Optimizer for model parameters.
        criterion (nn.Module): Loss function.
        model (nn.Module): Model that outputs logits.
    """
    scaler = torch.cuda.amp.GradScaler()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    model = model.to(device)

    for batch in tqdm(train_dataloader, desc="Training", leave=False):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device, dtype=torch.float)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask).logits
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

# Main

In [None]:
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup


args = init_parser('agnews')
train_dataloader, test_dataloader = load_agnews_dataloaders(args.samples_per_label_train, args.samples_per_label_test)
wiki_dataloader = load_wikipedia_subset(args.pre_train_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pre_train_model = build_bert_with_optional_conv_for_pre_train(
    hidden_size = args.head_size * args.num_attention_heads,
    num_hidden_layers = args.bert_layers,
    num_conv_layers = args.conv_layers,
    kernel_size = args.kernel,
    num_attention_heads = args.num_attention_heads,
    intermediate_size = args.head_size * args.num_attention_heads * 4,
    max_position_embeddings = args.max_length,
    conv_channels_dim = args.d,
    vocab_size= args.vocab_size,
    cls_token_id= args.cls_token_id,
    pad_token_id=args.pad_token_id,
    sep_token_id=args.sep_token_id,
    mask_token_id=args.mask_token_id,
)
optimizer = AdamW(pre_train_model.parameters(), lr= args.pre_train_lr, weight_decay= args.pre_train_weight_decay)
epochs = args.pre_train_epochs
num_warmup_steps = epochs//100
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=epochs,
)
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    mlm_train(wiki_dataloader, pre_train_model, optimizer)
    scheduler.step()


model = build_bert_with_optional_conv_for_classification(
    hidden_size =  args.head_size * args.num_attention_heads,
    num_hidden_layers = args.bert_layers,
    num_conv_layers = args.conv_layers,
    kernel_size = args.kernel,
    num_attention_heads = args.num_attention_heads,
    intermediate_size = args.head_size * args.num_attention_heads * 4,
    max_position_embeddings = args.max_length,
    conv_channels_dim = args.d,
    num_labels = args.num_labels,
    vocab_size= args.vocab_size,
    cls_token_id= args.cls_token_id,
    pad_token_id=args.pad_token_id,
    sep_token_id=args.sep_token_id,
    mask_token_id=args.mask_token_id,
)

model.load_state_dict(pre_train_model.state_dict(), strict = False)
model.to(device)


optimizer = torch.optim.AdamW(model.parameters(), lr= args.lr, weight_decay= args.weight_decay)
criterion = torch.nn.CrossEntropyLoss()

num_epochs = args.epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_epochs
)

for epoch in range(num_epochs):
    cls_train(train_dataloader, optimizer, criterion, model)
    scheduler.step()
    acc = cls_test(test_dataloader, model)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/41 [00:00<?, ?files/s]

train-00000-of-00041.parquet:   0%|          | 0.00/420M [00:00<?, ?B/s]

train-00001-of-00041.parquet:   0%|          | 0.00/351M [00:00<?, ?B/s]

train-00002-of-00041.parquet:   0%|          | 0.00/329M [00:00<?, ?B/s]

train-00003-of-00041.parquet:   0%|          | 0.00/331M [00:00<?, ?B/s]

train-00004-of-00041.parquet:   0%|          | 0.00/307M [00:00<?, ?B/s]

train-00005-of-00041.parquet:   0%|          | 0.00/244M [00:00<?, ?B/s]

train-00006-of-00041.parquet:   0%|          | 0.00/266M [00:00<?, ?B/s]

train-00007-of-00041.parquet:   0%|          | 0.00/228M [00:00<?, ?B/s]

train-00008-of-00041.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

train-00009-of-00041.parquet:   0%|          | 0.00/227M [00:00<?, ?B/s]

train-00010-of-00041.parquet:   0%|          | 0.00/234M [00:00<?, ?B/s]

train-00011-of-00041.parquet:   0%|          | 0.00/232M [00:00<?, ?B/s]

train-00012-of-00041.parquet:   0%|          | 0.00/239M [00:00<?, ?B/s]

train-00013-of-00041.parquet:   0%|          | 0.00/241M [00:00<?, ?B/s]

train-00014-of-00041.parquet:   0%|          | 0.00/223M [00:00<?, ?B/s]

train-00015-of-00041.parquet:   0%|          | 0.00/235M [00:00<?, ?B/s]

train-00016-of-00041.parquet:   0%|          | 0.00/503M [00:00<?, ?B/s]

train-00017-of-00041.parquet:   0%|          | 0.00/231M [00:00<?, ?B/s]

train-00018-of-00041.parquet:   0%|          | 0.00/231M [00:00<?, ?B/s]

train-00019-of-00041.parquet:   0%|          | 0.00/195M [00:00<?, ?B/s]

train-00020-of-00041.parquet:   0%|          | 0.00/225M [00:00<?, ?B/s]

train-00021-of-00041.parquet:   0%|          | 0.00/216M [00:00<?, ?B/s]

train-00022-of-00041.parquet:   0%|          | 0.00/202M [00:00<?, ?B/s]

train-00023-of-00041.parquet:   0%|          | 0.00/213M [00:00<?, ?B/s]

train-00024-of-00041.parquet:   0%|          | 0.00/221M [00:00<?, ?B/s]

train-00025-of-00041.parquet:   0%|          | 0.00/221M [00:00<?, ?B/s]

train-00026-of-00041.parquet:   0%|          | 0.00/208M [00:00<?, ?B/s]

train-00027-of-00041.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

train-00028-of-00041.parquet:   0%|          | 0.00/188M [00:00<?, ?B/s]

train-00029-of-00041.parquet:   0%|          | 0.00/218M [00:00<?, ?B/s]

train-00030-of-00041.parquet:   0%|          | 0.00/204M [00:00<?, ?B/s]

train-00031-of-00041.parquet:   0%|          | 0.00/215M [00:00<?, ?B/s]

train-00032-of-00041.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

train-00033-of-00041.parquet:   0%|          | 0.00/203M [00:00<?, ?B/s]

train-00034-of-00041.parquet:   0%|          | 0.00/219M [00:00<?, ?B/s]

train-00035-of-00041.parquet:   0%|          | 0.00/224M [00:00<?, ?B/s]

train-00036-of-00041.parquet:   0%|          | 0.00/610M [00:00<?, ?B/s]

train-00037-of-00041.parquet:   0%|          | 0.00/674M [00:00<?, ?B/s]

train-00038-of-00041.parquet:   0%|          | 0.00/538M [00:00<?, ?B/s]

train-00039-of-00041.parquet:   0%|          | 0.00/465M [00:00<?, ?B/s]

train-00040-of-00041.parquet:   0%|          | 0.00/422M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/6407814 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

Map:   0%|          | 0/40000 [00:00<?, ? examples/s]

Epoch 1/50


  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
MLM Training: 100%|██████████| 1250/1250 [00:42<00:00, 29.35it/s, loss=6.9130]


Epoch 2/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.19it/s, loss=6.3651]


Epoch 3/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.24it/s, loss=5.9437]


Epoch 4/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.20it/s, loss=6.0050]


Epoch 5/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.25it/s, loss=5.2769]


Epoch 6/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.02it/s, loss=4.7138]


Epoch 7/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.23it/s, loss=4.4232]


Epoch 8/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.24it/s, loss=4.2131]


Epoch 9/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.22it/s, loss=3.9966]


Epoch 10/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.22it/s, loss=4.2136]


Epoch 11/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.26it/s, loss=3.7658]


Epoch 12/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.12it/s, loss=3.6593]


Epoch 13/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 29.87it/s, loss=3.4924]


Epoch 14/50


MLM Training: 100%|██████████| 1250/1250 [00:42<00:00, 29.39it/s, loss=3.4435]


Epoch 15/50


MLM Training: 100%|██████████| 1250/1250 [00:42<00:00, 29.49it/s, loss=3.5375]


Epoch 16/50


MLM Training: 100%|██████████| 1250/1250 [00:42<00:00, 29.67it/s, loss=3.1728]


Epoch 17/50


MLM Training: 100%|██████████| 1250/1250 [00:42<00:00, 29.39it/s, loss=3.4850]


Epoch 18/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 29.94it/s, loss=2.9271]


Epoch 19/50


MLM Training: 100%|██████████| 1250/1250 [00:42<00:00, 29.61it/s, loss=3.5867]


Epoch 20/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 29.83it/s, loss=3.3751]


Epoch 21/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.01it/s, loss=3.2074]


Epoch 22/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.00it/s, loss=3.0778]


Epoch 23/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 29.88it/s, loss=3.0153]


Epoch 24/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 29.86it/s, loss=2.9626]


Epoch 25/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.30it/s, loss=2.9801]


Epoch 26/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.12it/s, loss=2.7010]


Epoch 27/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.17it/s, loss=2.7861]


Epoch 28/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.24it/s, loss=2.8134]


Epoch 29/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.19it/s, loss=2.9326]


Epoch 30/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.18it/s, loss=2.5100]


Epoch 31/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.09it/s, loss=2.5349]


Epoch 32/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.25it/s, loss=2.7322]


Epoch 33/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.05it/s, loss=2.6002]


Epoch 34/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.05it/s, loss=2.3744]


Epoch 35/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.02it/s, loss=3.0088]


Epoch 36/50


MLM Training: 100%|██████████| 1250/1250 [00:41<00:00, 30.14it/s, loss=2.4454]


Epoch 37/50


MLM Training:  40%|████      | 503/1250 [00:16<00:24, 29.93it/s, loss=2.7091]