In [None]:
"""
    DistilBERT training via knowledge distillation from BERT using PyTorch and Hugging Face Transformers.
"""
# !pip install torch transformers datasets

In [None]:
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AdamW 
from tqdm import tqdm

In [None]:
"""
    AdamW is a popular optimization algorithm in deep learning, especially well-suited for training models like Transformers (e.g., BERT, GPT, etc.). 
    It is a variant of the Adam optimizer that introduces a correct way to apply weight decay (L2 regularization).
    AdamW helps prevent overfitting while maintaining the benefits of Adam (adaptive learning rates, momentum).
    It is the default optimizer in Hugging Face Transformers and many other frameworks for fine-tuning pre-trained language models.
"""

In [None]:
# Load and Preprocess the Dataset
dataset = load_dataset("imdb", split='train[:1%]').train_test_split(test_size=0.2) # Loads 1% of the IMDb dataset's training set for a quick demo or test
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") # Loads the tokenizer that corresponds to bert-base-uncased
"""
    The IMDB dataset is a popular dataset used for binary sentiment classification—determining whether a movie review is positive or negative.
"""

In [None]:
# Tokenizing the Text
def encode_batch(batch):
    return tokenizer(batch['text'], truncation=True, padding='max_length', max_length=256)
"""
    encode_batch(): Tokenizes batches of texts, with:
        - truncation: Cuts long reviews down to max_length.
        - padding='max_length': Ensures uniform tensor sizes.
        - max_length=256: Keeps sequences to 256 tokens max.
"""

dataset = dataset.map(encode_batch, batched=True) # Applies tokenizer to all samples 
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label']) # Converts dataset to PyTorch tensors
"""
    It ensures that each time you access a sample from the dataset, it returns a dictionary like:
        {
            'input_ids': tensor(...),
            'attention_mask': tensor(...),
            'label': tensor(...)
        }
"""

print(">>> first sample from dataset: ")
print(dataset['train'][0])
print(">>> shape of input_ids: ", dataset['train'][0]['input_ids'].shape)
print(">>> shape of attention_mask: ", dataset['train'][0]['attention_mask'].shape)
print(">>> shape of label: ", dataset['train'][0]['label'].shape)

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

>>> first sample from dataset: 
{'label': tensor(0), 'input_ids': tensor([  101,  2023,  2003,  2009,  1012,  2023,  2003,  1996,  2028,  1012,
         2023,  2003,  1996,  5409,  3185,  2412,  2081,  1012,  2412,  1012,
         2009, 10299,  2673,  1012,  1045,  2031,  2196,  2464,  4788,  1012,
        11036,  1996,  5384,  1998,  2507,  2009,  2000,  2122,  2111,  1012,
         1012,  1012,  1012,  1012,  2045,  1005,  1055,  2074,  2053,  7831,
         1012,  1026,  7987,  1013,  1028,  1026,  7987,  1013,  1028,  2130,
         2093,  2420,  2044,  3666,  2023,  1006,  2005,  2070,  3114,  1045,
         2145,  2123,  1005,  1056,  2113,  2339,  1007,  1045,  3685,  2903,
         2129,  9577,  2135, 23512,  2023,  3185,  2003,  1013,  2001,  1012,
         2049,  2061,  2919,  1012,  2061,  2521,  2013,  2505,  2008,  2071,
         2022,  2641,  1037,  3185,  1010,  1037,  2466,  2030,  2505,  2008,
         2323,  2031,  2412,  2042,  2580,  1998,  2716,  2046,  2256,  4598

In [35]:
# Apply DataLoader
train_loader = DataLoader(dataset['train'], batch_size=8, shuffle=True) # Prepares data loader for batching
first_batch = next(iter(train_loader)) # Get first batch

print(">>> shape of input_ids in first_batch: ", first_batch['input_ids'].shape)
print(">>> shape of attention_mask in first_batch: ", first_batch['attention_mask'].shape)
print(">>> shape of label in first_batch: ", first_batch['label'].shape)

>>> shape of input_ids in first_batch:  torch.Size([8, 256])
>>> shape of attention_mask in first_batch:  torch.Size([8, 256])
>>> shape of label in first_batch:  torch.Size([8])
