In [29]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt

In [30]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected to ``n_heads`` heads of size
    ``d_k`` (values use the same size as keys), attention is computed per head,
    and the concatenated heads are projected back to ``d_model``.

    Args:
        d_k: Size of each head's key/query/value vectors.
        d_model: Size of the input and output feature dimension.
        n_heads: Number of attention heads.
    """

    def __init__(self, d_k, d_model, n_heads):
        super().__init__()

        # assume d_v = d_k
        self.d_k = d_k
        self.n_heads = n_heads

        self.key = nn.Linear(d_model, d_k * n_heads)
        self.query = nn.Linear(d_model, d_k * n_heads)
        self.value = nn.Linear(d_model, d_k * n_heads)

        # final linear layer
        self.fc = nn.Linear(d_k * n_heads, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input of shape (N, T_q, d_model).
            k: Key input of shape (N, T_k, d_model).
            v: Value input of shape (N, T_k, d_model).
            mask: Optional tensor of shape (N, T_k); key positions where the
                mask equals 0 are excluded from attention. A row whose keys
                are all masked produces NaN weights (softmax over only -inf).

        Returns:
            Tensor of shape (N, T_q, d_model).
        """
        q = self.query(q)  # N x T_q x (h*d_k)
        k = self.key(k)    # N x T_k x (h*d_k)
        v = self.value(v)  # N x T_k x (h*d_k)

        # Store batch size and sequence lengths. Keys/values may have a
        # different length than the queries, so track the lengths separately.
        N = q.shape[0]
        T_q = q.shape[1]
        T_k = k.shape[1]

        # change shape to:
        # (N, T, h, d_k) --> (N, h, T, d_k)
        # in order for matrix multiply to work properly
        q = q.view(N, T_q, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T_k, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T_k, self.n_heads, self.d_k).transpose(1, 2)

        # compute attention weights
        # (N, h, T_q, d_k) x (N, h, d_k, T_k) --> (N, h, T_q, T_k)
        attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            # broadcast the (N, T_k) mask over heads and query positions
            attn_scores = attn_scores.masked_fill(
                mask[:, None, None, :] == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)

        # compute attention-weighted values
        # (N, h, T_q, T_k) x (N, h, T_k, d_k) --> (N, h, T_q, d_k)
        A = attn_weights @ v

        # reshape it back before final linear layer
        A = A.transpose(1, 2)  # (N, T_q, h, d_k)
        A = A.contiguous().view(N, T_q, self.d_k * self.n_heads)  # (N, T_q, h*d_k)

        # projection
        return self.fc(A)

        

In [31]:
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention and a feed-forward net, each with a
    residual connection followed by layer normalisation (post-norm)."""

    def __init__(self, d_k, d_model, n_heads, dropout_prob=0.1):
        super().__init__()

        hidden_width = d_model * 4  # feed-forward expansion size

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_k, d_model, n_heads)
        self.ann = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.GELU(),
            nn.Linear(hidden_width, d_model),
            nn.Dropout(dropout_prob),
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is forwarded to the attention layer."""
        # self-attention sub-layer: residual add, then normalise
        attended = self.mha(x, x, x, mask)
        hidden = self.ln1(x + attended)

        # feed-forward sub-layer: residual add, then normalise
        transformed = self.ann(hidden)
        hidden = self.ln2(hidden + transformed)

        # dropout on the block output
        return self.dropout(hidden)
    

In [32]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to the input, then applies dropout.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``. The table is precomputed for ``max_len`` positions and
    stored as a non-trainable buffer named ``pe`` of shape (1, max_len, d_model).

    Args:
        d_model: Feature dimension of the input (even or odd).
        max_len: Number of positions to precompute.
        dropout_prob: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model, max_len=2048, dropout_prob=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)

        position = torch.arange(max_len).unsqueeze(1)
        exp_term = torch.arange(0, d_model, 2)
        div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd-indexed column than
        # even-indexed columns, so drop the surplus frequency for cos.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` of shape N x T x D (T <= max_len)."""
        # x.shape: N x T x D
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

In [33]:
class Encoder(nn.Module):
    """Transformer encoder for sequence classification.

    Token ids are embedded, position-encoded and passed through ``n_layers``
    transformer blocks; the hidden state at sequence position 0 is layer-normed
    and mapped to ``n_classes`` outputs.
    """

    def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers, n_classes, dropout_prob):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)

        # build the stack of identical encoder layers
        layers = []
        for _ in range(n_layers):
            layers.append(TransformerBlock(d_k, d_model, n_heads, dropout_prob))
        self.transformer_blocks = nn.Sequential(*layers)

        self.ln = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, n_classes)

    def forward(self, x, mask=None):
        """Map token ids ``x`` to per-class outputs; ``mask`` is passed to every block."""
        hidden = self.pos_encoding(self.embedding(x))

        # the container is iterated manually because each layer also needs the mask
        for layer in self.transformer_blocks:
            hidden = layer(hidden, mask)

        # many-to-one (hidden has the shape N x T x D): keep only position 0
        first_token = hidden[:, 0, :]

        return self.fc(self.ln(first_token))

In [34]:
# Test encoder: build a small throwaway model to smoke-test the classes above
model = Encoder(
    vocab_size=20_000,
    max_len=1024,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    n_classes=5,
    dropout_prob=0.1,
)

In [35]:
# use the first CUDA GPU if one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
model.to(device)

cpu


Encoder(
  (embedding): Embedding(20000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, 

In [36]:
# create a random batch of token ids: 8 sequences of 512 tokens each
x = np.random.randint(low=0, high=20_000, size=(8, 512))
x_t = torch.tensor(x, device=device)

In [37]:
# attention mask: first 256 positions are kept (1), last 256 are masked out (0)
mask = np.concatenate([np.ones((8, 256)), np.zeros((8, 256))], axis=1)
mask_t = torch.tensor(mask, device=device)

In [38]:
# Forward pass of the random batch through the encoder, using the mask built above
y = model(x_t, mask_t)

In [39]:
# Should be 8 x 5:
# one batch of 8 sequences, 5 class outputs per sequence
y.shape

torch.Size([8, 5])

In [41]:
# Start training and evaluating using a real data set
# `%pip` (rather than `!pip`) installs into the environment of the running kernel;
# versions are pinned to the ones this notebook was originally run with.
%pip install transformers==4.30.2 datasets==2.12.0

Collecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m34.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting filelock (from transformers)
  Using cached filelock-3.12.2-py3-none-any.whl (10 kB)
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Using cached huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
Collecting regex!=2019.12.17 (from transformers)
  Downloading regex-2023.6.3-cp311-cp311-macosx_10_9_x86_64.whl (294 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.7/294.7 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp311-cp31

In [43]:
from transformers import AutoTokenizer, DataCollatorWithPadding

In [49]:
# Use DistilBERT rather than BERT: BERT uses segment embeddings to distinguish between
# the two sentences when a sentence pair is passed in; DistilBERT does not use segment embeddings.
checkpoint = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [45]:
from datasets import load_dataset

In [47]:
# Load the SST-2 configuration of the GLUE dataset (train / validation / test splits)
raw_datasets = load_dataset("glue", "sst2")

Downloading builder script: 100%|██████████| 28.8k/28.8k [00:00<00:00, 19.7MB/s]
Downloading metadata: 100%|██████████| 28.7k/28.7k [00:00<00:00, 13.7MB/s]
Downloading readme: 100%|██████████| 27.9k/27.9k [00:00<00:00, 16.7MB/s]


Downloading and preparing dataset glue/sst2 to /Users/alex/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading data: 100%|██████████| 7.44M/7.44M [00:00<00:00, 29.9MB/s]
                                                                                       

Dataset glue downloaded and prepared to /Users/alex/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:00<00:00, 518.78it/s]


In [48]:
# Inspect the splits, their columns and their row counts
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [51]:
# Apply the tokenizer to the sentence column of the data
def tokenize_fn(batch):
    """Tokenize the 'sentence' entries of a batch, truncating over-long inputs."""
    sentences = batch['sentence']
    return tokenizer(sentences, truncation=True)

In [52]:
# Tokenize every split in batches; padding is deferred to the collator,
# which pads the examples of each batch when the batch is assembled.
tokenized_datasets = raw_datasets.map(tokenize_fn, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

                                                                    

In [53]:
# Inspect the collator configuration (padding settings and the tokenizer it wraps)
data_collator

DataCollatorWithPadding(tokenizer=DistilBertTokenizerFast(name_or_path='distilbert-base-cased', vocab_size=28996, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True), padding=True, max_length=None, pad_to_multiple_of=None, return_tensors='pt')

In [54]:
# Tokenization added two columns: input_ids are the token indices,
# attention_mask marks which positions are real tokens vs. padding
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [55]:
# Drop the raw-text and index columns, which the model does not consume, and
# rename label -> labels (the training loop reads batch['labels']).
tokenized_datasets = (
    tokenized_datasets
    .remove_columns(["sentence", "idx"])
    .rename_column("label", "labels")
)

In [56]:
# Confirm that only labels, input_ids and attention_mask remain
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [58]:
from torch.utils.data import DataLoader

# Both loaders batch 32 examples and let the collator pad each batch;
# only the training data is shuffled.
train_loader = DataLoader(
    dataset=tokenized_datasets["train"],
    batch_size=32,
    collate_fn=data_collator,
    shuffle=True,
)
valid_loader = DataLoader(
    dataset=tokenized_datasets["validation"],
    batch_size=32,
    collate_fn=data_collator,
)

In [59]:
# check how it works: print the name and tensor shape of every field
# in the first training batch, then stop
for batch in train_loader:
    for k, v in batch.items():
        print("k:", k, "v.shape", v.shape)
    break

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


k: labels v.shape torch.Size([32])
k: input_ids v.shape torch.Size([32, 46])
k: attention_mask v.shape torch.Size([32, 46])


In [63]:
# Check some of the values we'll be passing to the Encoder when building the model:
# the distinct label values determine n_classes
set(tokenized_datasets['train']['labels'])

{0, 1}

In [61]:
# Vocabulary size of the tokenizer -> vocab_size of the embedding layer
tokenizer.vocab_size

28996

In [62]:
# Maximum input length per checkpoint -> max_len of the positional encoding
tokenizer.max_model_input_sizes

{'distilbert-base-uncased': 512,
 'distilbert-base-uncased-distilled-squad': 512,
 'distilbert-base-cased': 512,
 'distilbert-base-cased-distilled-squad': 512,
 'distilbert-base-german-cased': 512,
 'distilbert-base-multilingual-cased': 512}

In [64]:
# Build the classifier for SST-2; vocabulary size and maximum sequence length
# come from the tokenizer so the embedding and positional tables cover every input.
# NOTE: this rebinds `model`, replacing the smoke-test model created earlier.
model = Encoder(
    vocab_size=tokenizer.vocab_size,
    max_len=tokenizer.max_model_input_sizes[checkpoint],
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    n_classes=2,
    dropout_prob=0.1,
)
model.to(device)

Encoder(
  (embedding): Embedding(28996, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, 

In [66]:
# Loss and optimizer objects:
# cross-entropy over the model's class outputs, Adam with its default hyperparameters
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

In [67]:
from datetime import datetime

In [69]:
# Function to encapsulate the training loop
def train(model, criterion, optimizer, train_loader, valid_loader, epochs):
    """Train ``model`` for ``epochs`` epochs and evaluate it after each one.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        train_loader: Iterable of dict batches with the keys 'input_ids',
            'attention_mask' and 'labels' (tensor values).
        valid_loader: Iterable of batches in the same format, used for evaluation.
        epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(train_losses, test_losses)`` of NumPy arrays of length
        ``epochs`` holding the sample-weighted mean loss per epoch. An entry is
        NaN if the corresponding loader yielded no samples.
    """
    # Run on the device the model already lives on, rather than depending on a
    # notebook-global `device` variable.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    train_losses = np.zeros(epochs)
    test_losses = np.zeros(epochs)

    for it in range(epochs):
        model.train()
        t0 = datetime.now()
        train_loss = 0
        n_train = 0
        for batch in train_loader:
            # move data to the model's device
            batch = {k: v.to(run_device) for k, v in batch.items()}

            # zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(batch['input_ids'], batch['attention_mask'])
            loss = criterion(outputs, batch['labels'])

            # Backward and optimize
            loss.backward()
            optimizer.step()

            # weight by batch size so a smaller final batch is averaged correctly
            batch_size = batch['input_ids'].size(0)
            train_loss += loss.item() * batch_size
            n_train += batch_size

        # Get average training loss
        train_loss = train_loss / n_train if n_train else float('nan')

        model.eval()
        test_loss = 0
        n_test = 0
        # no gradients are needed for evaluation; skipping the autograd graph
        # saves memory and time
        with torch.no_grad():
            for batch in valid_loader:
                batch = {k: v.to(run_device) for k, v in batch.items()}
                outputs = model(batch['input_ids'], batch['attention_mask'])
                loss = criterion(outputs, batch['labels'])
                batch_size = batch['input_ids'].size(0)
                test_loss += loss.item() * batch_size
                n_test += batch_size

        test_loss = test_loss / n_test if n_test else float('nan')

        # Save losses
        train_losses[it] = train_loss
        test_losses[it] = test_loss

        # duration covers both the training and the evaluation pass of this epoch
        dt = datetime.now() - t0
        print(f'Epoch {it+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Duration: {dt}')

    return train_losses, test_losses

In [71]:
# Train for 4 epochs; per-epoch training and validation losses are returned for later inspection
train_losses, test_losses = train(model, criterion, optimizer, train_loader, valid_loader, epochs=4)

Epoch 1/4, Train Loss: 0.4832, Test Loss: 0.4725, Duration: 0:02:04.399448
Epoch 2/4, Train Loss: 0.3496, Test Loss: 0.5049, Duration: 0:02:07.416352
Epoch 3/4, Train Loss: 0.2880, Test Loss: 0.4793, Duration: 0:02:10.305197
Epoch 4/4, Train Loss: 0.2487, Test Loss: 0.5126, Duration: 0:02:12.891374


In [73]:
# Accuracy

def compute_accuracy(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    ``loader`` yields dict batches with the keys 'input_ids', 'attention_mask'
    and 'labels'. The predicted class is the index of the largest model output.
    """
    n_correct = 0.
    n_total = 0.

    # inference only: no autograd graph is needed
    with torch.no_grad():
        for batch in loader:
            # move to GPU
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward pass
            outputs = model(batch['input_ids'], batch['attention_mask'])

            # Get predictions
            # torch.max returns both max and argmax
            _, predictions = torch.max(outputs, 1)

            # update counts
            n_correct += (predictions == batch['labels']).sum().item()
            n_total += batch['labels'].shape[0]

    return n_correct / n_total


# evaluation mode disables dropout
model.eval()
train_acc = compute_accuracy(model, train_loader)
test_acc = compute_accuracy(model, valid_loader)
print(f"Train acc: {train_acc:.4f}, Test acc: {test_acc:.4f}")

Train acc: 0.9353, Test acc: 0.7661


In [None]:
# TODO: also compute F1 and AUC, since accuracy alone can be misleading when the data set is imbalanced