### Building a Transformer

### Imports

In [None]:
from datasets import load_dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers.tokenization_utils_base import BatchEncoding

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List

### Load Arabic Dataset

In [None]:
# Alternative corpus, left disabled because it is large and requires access
# permission:
# dataset = load_dataset("oscar", "unshuffled_deduplicated_ar") # Large & requires permission
# Load the ArSAS dataset from the Hugging Face Hub instead.
dataset = load_dataset("arbml/ArSAS")

In [None]:
# Bare expression: the notebook renders the loaded dataset object as cell output.
dataset

In [None]:
# Convert the "train" split with the dataset's own Arrow-backed `to_pandas()`
# instead of handing the split to `pd.DataFrame(...)`, then keep only the
# "Tweet_text" column (double brackets keep the result a DataFrame).
df_train = dataset["train"].to_pandas()[["Tweet_text"]]

In [None]:
# Preview the first rows of the single-column frame as cell output.
df_train.head()

In [None]:
# Pull the "Tweet_text" column out of the frame as a plain Python list.
text_corpus = df_train["Tweet_text"].to_list()

### Load BERT Tokenizer

In [None]:
from transformers import AutoTokenizer, AutoModel

# NOTE(review): "bert-base-uncased" appears to be an English checkpoint, while
# the corpus loaded in the "Load Arabic Dataset" section is Arabic tweets.
# Confirm this tokenizer is the intended one for Arabic text; presumably its
# vocabulary size is also what `config.vocab_size` must match.
model_name = "bert-base-uncased"
# Tokenizer and model weights are both fetched for the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): `bert_model` is not referenced in the cells visible here —
# confirm it is needed, since loading it downloads the full model weights.
bert_model = AutoModel.from_pretrained(model_name)

In [None]:
# Tokenization: placeholder to be filled in (presumably a call to `tokenizer`
# on `text_corpus` — confirm). Later cells read `text_tokens["input_ids"]` as
# one sequence per example and pass `text_tokens` to `NanoDataset`, whose
# constructor is annotated to take a `BatchEncoding`.
text_tokens = ...

In [None]:
# Per-example token counts. These equal the true (unpadded) sequence lengths
# only if tokenization was done without padding.
# NOTE(review): confirm against the tokenization cell; if padding is applied
# there, every length collapses to the padded length.
lengths = [len(ids) for ids in text_tokens["input_ids"]]

# One row, three panels; each seaborn call draws on its own explicit Axes.
fig, (ax_hist, ax_kde, ax_cdf) = plt.subplots(1, 3, figsize=(16, 5))

# --- Histogram ---
sns.histplot(lengths, bins=50, kde=True, color="royalblue", ax=ax_hist)
ax_hist.set_title("Token Length Histogram")
ax_hist.set_xlabel("Sequence Length")
ax_hist.set_ylabel("Count")

# --- KDE Density ---
sns.kdeplot(lengths, fill=True, color="green", ax=ax_kde)
ax_kde.set_title("Token Length Density (KDE)")
ax_kde.set_xlabel("Sequence Length")
ax_kde.set_ylabel("Density")

# --- CDF Plot ---
sns.ecdfplot(lengths, color="purple", ax=ax_cdf)
ax_cdf.set_title("Cumulative Distribution (CDF)")
ax_cdf.set_xlabel("Sequence Length")
# The label previously contained mis-encoded bytes; it is meant to read "≤".
ax_cdf.set_ylabel("Proportion ≤ length")

fig.tight_layout()
plt.show()


### Config

In [None]:
class config:
    """Namespace of hyperparameters for the notebook.

    Every value is still a ``...`` placeholder and must be filled in before
    use. The per-attribute notes below are inferred from the attribute names
    only — confirm each one when assigning real values.
    """
    batch_size = ...  # presumably examples per batch for the DataLoader
    hidden_size = ...  # presumably the model/embedding width
    n_heads = ...  # presumably the number of attention heads
    max_seq_len = ...  # presumably the maximum tokens per sequence
    vocab_size = ...  # presumably the tokenizer vocabulary size
    base = ...  # NOTE(review): meaning not visible here (possibly a positional-encoding frequency base) — confirm
    pad_id = ...  # presumably the tokenizer's padding token id
    ignored_index = ...  # presumably the label value the loss should skip on padded positions

### Dataset

In [None]:
class NanoDataset(Dataset):
    """Map-style ``Dataset`` skeleton over a tokenized corpus.

    All method bodies are ``...`` placeholders that still need implementing.
    """

    # `corpus` is the tokenizer output (annotated as a `BatchEncoding`); how it
    # is stored on the instance is not implemented yet.
    def __init__(self, corpus: BatchEncoding): ...

    # Placeholder: should report how many examples the dataset holds.
    def __len__(self): ...

    # Returns one example as a dict with "input_ids" and "labels" entries
    # (both values are placeholders for now).
    # NOTE(review): presumably both are tensors derived from the example at
    # `idx`, with labels arranged for next-token prediction — confirm.
    def __getitem__(self, idx):
        ...
        return {"input_ids": ..., "labels": ...}

In [None]:
def collate_fn(batch: List[Dict[str, torch.Tensor]]):
    """Collate a list of dataset items into a single batch dict.

    ``batch`` is a list of items of the dataset, i.e.:

    >>> ds = NanoDataset(tokenized_corpus)
    >>> batch = [ds[i] for i in range(len(ds))]

    Returns a dict with "input_ids", "labels" and "attention_mask" entries.
    The values are still ``...`` placeholders.
    NOTE(review): presumably the implementation pads the items to a common
    length and builds the mask from the padding — confirm when filling in.
    """

    return {"input_ids": ..., "labels": ..., "attention_mask": ...}

In [None]:
# Wrap the tokenized corpus in the dataset class defined above.
ds = NanoDataset(text_tokens)

In [None]:
# Placeholder for the training DataLoader.
# NOTE(review): presumably built from `ds` with `config.batch_size` and
# `collate_fn` — confirm when filling in.
train_dataloader = ...