In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class CausalSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention restricted by a causal mask.

    Args:
        d_k: Size of each head's query/key projection (values use the same size).
        d_model: Feature size of the inputs and of the returned tensor.
        n_heads: Number of attention heads.
        max_len: Side length of the pre-built causal mask buffer.
    """

    def __init__(self, d_k, d_model, n_heads, max_len):
        super().__init__()

        # Assume d_v = d_k
        self.d_k = d_k
        self.n_heads = n_heads

        self.key = nn.Linear(d_model, d_k * n_heads)
        self.query = nn.Linear(d_model, d_k * n_heads)
        self.value = nn.Linear(d_model, d_k * n_heads)

        # final linear layer
        self.fc = nn.Linear(d_k * n_heads, d_model)

        # causal mask: lower-triangular ones with the diagonal INCLUDED, so
        # position t may attend to positions 0..t and never to later ones
        cm = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer("causal_mask", cm.view(1, 1, max_len, max_len))

    def forward(self, q, k, v, pad_mask=None):
        """Apply causal multi-head attention.

        Args:
            q, k, v: Tensors of shape N x T x d_model. All three are reshaped
                with the sequence length of ``q``.
            pad_mask: Optional N x T tensor; key positions where it equals 0
                are excluded from attention.

        Returns:
            Tensor of shape N x T x d_model.
        """
        q = self.query(q)  # N x T x (hd_k)
        k = self.key(k)    # N x T x (hd_k)
        v = self.value(v)  # N x T x (hd_v)

        N = q.shape[0]
        T = q.shape[1]

        # change shape to:
        # (N, T, h, d_k) -> (N, h, T, d_k)
        # in order for matrix multiply to work correctly
        q = q.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T, self.n_heads, self.d_k).transpose(1, 2)

        # compute attention scores
        # (N, h, T, d_k) x (N, h, d_k, T) --> (N, h, T, T)
        attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

        # Mask with the most negative finite value rather than -inf. Masked
        # entries still receive exactly zero weight whenever a row has at
        # least one visible key, but a row whose keys are ALL masked (e.g. a
        # left-padded sequence) now yields finite weights instead of NaN.
        neg_fill = torch.finfo(attn_scores.dtype).min
        if pad_mask is not None:
            attn_scores = attn_scores.masked_fill(
                pad_mask[:, None, None, :] == 0, neg_fill)

        attn_scores = attn_scores.masked_fill(
            self.causal_mask[:, :, :T, :T] == 0, neg_fill)
        attn_weights = F.softmax(attn_scores, dim=-1)

        # compute attention-weighted value
        # (N, h, T, T) x (N, h, T, d_k) --> (N, h, T, d_k)
        A = attn_weights @ v

        # reshape it back before final linear layer
        A = A.transpose(1, 2)  # (N, T, h, d_k)
        A = A.contiguous().view(N, T, self.d_k * self.n_heads)  # (N, T, h*d_k)

        # projection
        return self.fc(A)

In [3]:
class TransformerBlock(nn.Module):
    """Decoder block: causal self-attention followed by a GELU feed-forward net.

    Each sub-layer is wrapped in a residual connection and followed by its own
    LayerNorm; dropout is applied to the feed-forward output and to the block
    output.

    Args:
        d_k: Per-head projection size passed to the attention layer.
        d_model: Feature size of the block's input and output.
        n_heads: Number of attention heads.
        max_len: Maximum sequence length passed to the attention layer.
        dropout_prob: Dropout probability used by both dropout layers.
    """

    def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
        super().__init__()

        hidden_dim = d_model * 4  # width of the feed-forward hidden layer

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = CausalSelfAttention(d_k, d_model, n_heads, max_len)
        self.ann = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
            nn.Dropout(dropout_prob),
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x, pad_mask=None):
        """Run the block on ``x``; ``pad_mask`` is forwarded to the attention layer."""
        # attention sub-layer: residual add, then normalise
        attended = self.mha(x, x, x, pad_mask)
        normed = self.ln1(x + attended)

        # feed-forward sub-layer: residual add, then normalise
        fed_forward = self.ann(normed)
        normed = self.ln2(normed + fed_forward)

        return self.dropout(normed)
        

In [4]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch, then apply dropout.

    Even feature columns hold sines and odd columns hold cosines of
    ``position * 10000 ** (-i / d_model)`` for ``i = 0, 2, 4, ...``.

    Args:
        d_model: Feature size of the inputs (odd sizes are supported).
        max_len: Number of positions pre-computed in the ``pe`` buffer.
        dropout_prob: Dropout probability applied after adding the encodings.
    """

    def __init__(self, d_model, max_len=2048, dropout_prob=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)

        position = torch.arange(max_len).unsqueeze(1)
        exp_term = torch.arange(0, d_model, 2)
        div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd-indexed column than even-indexed
        # ones, so trim the angles to fit; for even d_model this is a no-op.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` (N x T x D) plus the first T position encodings, after dropout."""
        # x.shape: N x T x D
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
    

In [5]:
class Decoder(nn.Module):
    """Decoder-only transformer mapping token ids to per-position vocabulary logits.

    Args:
        vocab_size: Number of token ids; sizes the embedding and output layers.
        max_len: Maximum sequence length for the positional and causal buffers.
        d_k: Per-head projection size inside every attention layer.
        d_model: Width of the embeddings and of every block.
        n_heads: Number of attention heads per block.
        n_layers: Number of stacked transformer blocks.
        dropout_prob: Dropout probability shared by all sub-modules.
    """

    def __init__(self,
                 vocab_size,
                 max_len,
                 d_k,
                 d_model,
                 n_heads,
                 n_layers,
                 dropout_prob):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
        # The blocks live in a Sequential container but are invoked one by one
        # in forward(), because each of them also takes the padding mask.
        self.transformer_blocks = nn.Sequential(*(
            TransformerBlock(d_k, d_model, n_heads, max_len, dropout_prob)
            for _ in range(n_layers)
        ))
        self.ln = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        """Return logits for every position of the token-id batch ``x``."""
        hidden = self.pos_encoding(self.embedding(x))
        for layer in self.transformer_blocks:
            hidden = layer(hidden, pad_mask)
        # many to many: one vocabulary distribution per input position
        return self.fc(self.ln(hidden))

In [6]:
# Small decoder instance used for the shape checks in the next few cells.
model = Decoder(
    vocab_size=20_000,
    max_len=1024,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_prob=0.1,
)

In [7]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)
model.to(device)

cpu


Decoder(
  (embedding): Embedding(20000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): CausalSelfAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05,

In [8]:
# Smoke-test input: random token ids in [0, 20_000) for a batch of 8 sequences
# of length 512, converted to a tensor on the selected device.
x = np.random.randint(0, 20_000, size=(8,512))
x_t = torch.tensor(x).to(device)

In [9]:
# Forward pass without a padding mask; the result holds one vocabulary-sized
# logit vector per input position.
y = model(x_t)
y.shape

torch.Size([8, 512, 20000])

In [10]:
# Padding mask for the smoke test: 1 marks a real token, 0 marks padding.
# Positions 256 onward of every sequence are flagged as padding.
mask = np.ones((8, 512))
mask[:, 256:] = 0
mask_t = torch.tensor(mask).to(device)

In [11]:
# Same forward pass, this time with the padding mask supplied.
y = model(x_t, mask_t)
y.shape

torch.Size([8, 512, 20000])

In [12]:
!pip install transformers datasets



In [13]:
from transformers import AutoTokenizer, DataCollatorWithPadding

2023-06-14 18:24:18.698637: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [14]:
# Load the tokenizer published with this checkpoint name. Only the tokenizer is
# loaded here; the Decoder defined above is instantiated and trained separately.
checkpoint = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [15]:
from datasets import load_dataset

In [16]:
# Load the GLUE "sst2" dataset. Only the sentences are needed for language
# modelling, so the labels are ignored (and removed in a later cell).
raw_datasets = load_dataset("glue", "sst2")

Found cached dataset glue (/Users/alex/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
100%|██████████| 3/3 [00:00<00:00, 364.83it/s]


In [17]:
# Inspect the available splits, their columns and their row counts.
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [18]:
def tokenize_fn(batch):
    """Tokenize the ``'sentence'`` entries of ``batch`` with truncation enabled."""
    sentences = batch['sentence']
    return tokenizer(sentences, truncation=True)


In [19]:
# Tokenize every split; batched=True hands tokenize_fn a batch of examples per call.
tokenized_datasets = raw_datasets.map(tokenize_fn, batched=True)
# Collator that pads the examples of a batch with this tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Loading cached processed dataset at /Users/alex/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-b711e37212625ef4.arrow
Loading cached processed dataset at /Users/alex/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-3b297f82c267f528.arrow
Loading cached processed dataset at /Users/alex/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e4fb4dac4dd3e7d2.arrow


In [20]:
# Confirm that tokenization added the 'input_ids' and 'attention_mask' columns.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [21]:
# Keep only the tokenizer outputs by dropping the raw text, row index and label.
# Filtering against the columns that are still present makes this cell safe to
# re-run: asking to remove a column that is already gone would fail.
columns_to_drop = [
    name for name in ("sentence", "idx", "label")
    if name in tokenized_datasets["train"].column_names
]
if columns_to_drop:
    tokenized_datasets = tokenized_datasets.remove_columns(columns_to_drop)

In [22]:
from torch.utils.data import DataLoader

# Shuffled training batches of 32 examples; data_collator pads each batch so
# its sequences share one length.
train_loader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=32,
    collate_fn=data_collator
)

In [23]:
# Sanity check: fetch a single collated batch and print the name and shape of
# every tensor it contains, then stop.
for batch in train_loader:
    for k, v in batch.items():
        print("k:", k, "v.shape:", v.shape)
    break

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


k: input_ids v.shape: torch.Size([32, 49])
k: attention_mask v.shape: torch.Size([32, 49])


In [24]:
# Id of the padding token; later cells use it as the loss ignore_index and as
# the filler value for the last target position.
tokenizer.pad_token_id

0

In [25]:
# Size the model from the tokenizer: its vocabulary size for the embedding and
# output layers, and its maximum input length for the positional/causal buffers.
# `model_max_length` is read instead of `max_model_input_sizes[checkpoint]`,
# a legacy per-checkpoint lookup table that newer transformers releases deprecate.
model = Decoder(
    vocab_size=tokenizer.vocab_size,
    max_len=tokenizer.model_max_length,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_prob=0.1,
)
model.to(device)

Decoder(
  (embedding): Embedding(28996, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): CausalSelfAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05,

In [26]:
# Loss and optimizer
# Targets equal to the pad token id are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# Adam with its default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())

In [27]:
from datetime import datetime

In [28]:
# a function to encapsulate the training loop
def train(model, criterion, optimizer, train_loader, epochs,
          device=None, pad_token_id=None):
    """Train ``model`` for next-token prediction and return the mean loss per epoch.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns logits of shape N x T x V.
        criterion: Loss called with N x V x T logits and N x T integer targets.
        optimizer: Optimizer updating the model's parameters.
        train_loader: Iterable of dict-like batches of tensors containing
            ``'input_ids'`` and ``'attention_mask'``.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU when the model has no parameters).
        pad_token_id: Id written into the last target position of every
            sequence. Defaults to the notebook-level ``tokenizer.pad_token_id``.

    Returns:
        NumPy array of length ``epochs`` with the mean training loss per epoch
        (NaN for an epoch in which the loader produced no batches).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    train_losses = np.zeros(epochs)

    for it in range(epochs):
        model.train()
        t0 = datetime.now()
        batch_losses = []
        for batch in train_loader:
            # move data to the training device
            batch = {k: v.to(device) for k, v in batch.items()}

            # zero the parameter gradients
            optimizer.zero_grad()

            # Next-token targets: shift the ids one step to the left so that
            # position t is trained to predict token t + 1. torch.roll returns
            # a new tensor, so the inputs are left untouched. The last column
            # (which wrapped around) has no successor and is filled with the
            # pad id, which the criterion is expected to ignore.
            targets = torch.roll(batch['input_ids'], shifts=-1, dims=1)
            targets[:, -1] = pad_token_id

            # Forward pass
            outputs = model(batch['input_ids'], batch['attention_mask'])
            # outputs are N x T x V
            # PyTorch is expecting N x V x T
            loss = criterion(outputs.transpose(2, 1), targets)

            # Backward and optimize
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        # Mean training loss of this epoch
        train_loss = float(np.mean(batch_losses)) if batch_losses else float('nan')

        # Save losses
        train_losses[it] = train_loss

        dt = datetime.now() - t0
        print(f'Epoch {it+1}/{epochs}, Train Loss: {train_loss:.4f}, Duration: {dt}')

    return train_losses

In [29]:
# Long-running cell: in the recorded run on CPU every epoch took tens of minutes.
train_losses = train(model, criterion, optimizer, train_loader, epochs=15)

Epoch 1/15, Train Loss: 5.9552, Duration: 0:23:33.733117
Epoch 2/15, Train Loss: 5.0078, Duration: 0:47:23.716784
Epoch 3/15, Train Loss: 4.6802, Duration: 1:27:40.672581
Epoch 4/15, Train Loss: 4.4877, Duration: 1:49:54.850989
Epoch 5/15, Train Loss: 4.3570, Duration: 1:56:49.864811


KeyboardInterrupt: 

In [33]:
# Validation loader that yields one sentence per batch so individual
# predictions can be inspected.
valid_loader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=1,
    collate_fn=data_collator
)

In [35]:
# Run the first validation example through the model in inference mode.
model.eval()
with torch.no_grad():  # inference only: do not build an autograd graph
    for batch in valid_loader:
        # move data to the model's device
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(batch['input_ids'], batch['attention_mask'])
        break

In [36]:
# Logits for the single validation sentence: (batch, sequence length, vocabulary size).
outputs.shape

torch.Size([1, 12, 28996])

In [37]:
# Highest-scoring token id at every position.
outputs.argmax(dim=-1)

tensor([[ 170,  112,  188,  170, 1363,  117, 8362, 5827, 2523,  102,  102, 1106]])

In [38]:
# Greedy prediction: keep the highest-scoring token id at every position.
prediction_ids = outputs.argmax(dim=-1)

In [39]:
# Decode the per-position predictions of the first (only) sequence back to text.
tokenizer.decode(prediction_ids[0])

"a's a good, un intense movie [SEP] [SEP] to"

In [40]:
# Decode the input sentence itself, for comparison with the predictions.
tokenizer.decode(batch['input_ids'][0])

"[CLS] it's a charming and often affecting journey. [SEP]"

In [42]:
# Join the first five input tokens with the model's prediction at index 4,
# i.e. its guess for the token that follows those five, and decode the result.
tokenizer.decode(torch.concat((batch['input_ids'][0, :5], prediction_ids[:, 4])))

"[CLS] it's a good"

In [43]:
# generate something
prompt = "it's"

tokenized_prompt = tokenizer(prompt, return_tensors='pt')
tokenized_prompt

{'input_ids': tensor([[ 101, 1122,  112,  188,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [44]:
# Drop the final token (the [SEP] appended by the tokenizer) so the model is
# asked what comes after the prompt. Inference only, so skip autograd.
with torch.no_grad():
    outputs = model(
        tokenized_prompt['input_ids'][:, :-1].to(device),
        tokenized_prompt['attention_mask'][:, :-1].to(device))

outputs.shape

torch.Size([1, 4, 28996])

In [45]:
# Greedy choice for the next token: the best-scoring id at the last position.
last_step_logits = outputs[:, -1, :]
prediction_ids = last_step_logits.argmax(dim=-1)

In [46]:
# Decode the predicted next token.
tokenizer.decode(prediction_ids[0])

'a'

In [47]:
# generate something
prompt = "it's a"

tokenized_prompt = tokenizer(prompt, return_tensors='pt')

# prepare inputs + get rid of SEP token at the end
input_ids = tokenized_prompt['input_ids'][:, :-1].to(device)
mask = tokenized_prompt['attention_mask'][:, :-1].to(device)

for _ in range(20):
  outputs = model(input_ids, mask)
  prediction_id = torch.argmax(outputs[:, -1, :], axis=-1)

  input_ids = torch.hstack((input_ids, prediction_id.view(1, 1)))
  mask = torch.ones_like(input_ids)

  if prediction_id == tokenizer.sep_token_id:
    break

In [48]:
# Decode the prompt together with the generated continuation.
tokenizer.decode(input_ids[0])

"[CLS] it's a good movie [SEP]"