In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class CausalSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention restricted by a causal mask.

    Each position may attend only to itself and earlier positions; an
    optional padding mask additionally hides individual key positions.

    Args:
        d_k: Size of each head's query/key/value projection.
        d_model: Feature size of the module's input and output.
        n_heads: Number of attention heads.
        max_len: Side length of the pre-built lower-triangular causal mask.
    """

    def __init__(self, d_k, d_model, n_heads, max_len):
        super().__init__()
        self.d_k = d_k
        self.n_heads = n_heads

        self.key = nn.Linear(d_model, d_k*n_heads)
        self.query = nn.Linear(d_model, d_k*n_heads)
        self.value = nn.Linear(d_model, d_k*n_heads)

        self.fc = nn.Linear(d_k*n_heads, d_model)

        # Lower-triangular ones: entry (i, j) is 1 when query i may see key j.
        # Stored as a buffer so it moves with the module across devices.
        cm = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer(
            'causal_mask',
            cm.view(1, 1, max_len, max_len)
        )

    def forward(self, q, k, v, pad_mask=None):
        """Apply causal multi-head attention.

        Args:
            q, k, v: Tensors of shape (N, T, d_model). All three are reshaped
                with the batch size and length taken from ``q``.
            pad_mask: Optional (N, T) tensor; key positions whose entry equals
                0 are hidden from every query.

        Returns:
            Tensor of shape (N, T, d_model).
        """
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        N = q.shape[0]
        T = q.shape[1]

        # (N, T, h*d_k) -> (N, h, T, d_k)
        q = q.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T, self.n_heads, self.d_k).transpose(1, 2)

        # (N, h, T, T) scaled similarity scores.
        attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

        # True where a query is allowed to look at a key.
        visible = self.causal_mask[:, :, :T, :T] != 0
        if pad_mask is not None:
            visible = visible & (pad_mask[:, None, None, :] != 0)

        # Fill with the dtype's finite minimum rather than -inf: masked keys
        # still receive exactly zero weight whenever a row has any visible key,
        # but a row with no visible key no longer turns the softmax into NaN.
        attn_scores = attn_scores.masked_fill(
            ~visible, torch.finfo(attn_scores.dtype).min
        )
        attn_weights = F.softmax(attn_scores, dim=-1)
        # Rows with no visible key (e.g. left-padded positions) attend to nothing.
        attn_weights = attn_weights.masked_fill(
            ~visible.any(dim=-1, keepdim=True), 0.0
        )
        A = attn_weights @ v

        # (N, h, T, d_k) -> (N, T, h*d_k)
        A = A.transpose(1, 2)
        A = A.contiguous().view(N, T, self.d_k * self.n_heads)

        return self.fc(A)

In [3]:
class TransformerBlock(nn.Module):
    """One decoder layer: causal self-attention followed by a GELU feed-forward net.

    Each sub-layer is added to its input (residual) and the sum is
    layer-normalised; dropout is applied to the block's final output.
    """

    def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob=0.1):
        super().__init__()
        hidden_size = d_model * 4  # feed-forward expansion width
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = CausalSelfAttention(d_k, d_model, n_heads, max_len)
        self.ann = nn.Sequential(
            nn.Linear(d_model, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, d_model),
            nn.Dropout(dropout_prob),
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x, pad_mask=None):
        """Run the block on ``x``; ``pad_mask`` is forwarded to the attention layer."""
        attended = self.mha(x, x, x, pad_mask)
        after_attention = self.ln1(x + attended)
        transformed = self.ann(after_attention)
        after_ffn = self.ln2(after_attention + transformed)
        return self.dropout(after_ffn)

In [4]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    Even feature columns hold sines and odd columns hold cosines of
    ``position * 10000^(-2i / d_model)``.

    Args:
        d_model: Feature size of the embeddings (even or odd).
        max_len: Number of positions for which encodings are pre-computed.
        dropout_prob: Dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len=2048, dropout_prob=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)
        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        exp_term = torch.arange(0, d_model, 2)
        div_term = torch.exp(exp_term*(-math.log(10000.0)/ d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so drop the last frequency for the cosine half (no-op when even).
        pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: saved and moved with the module, but not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` (N, T, d_model) plus the first T encodings, with dropout."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

In [5]:
class Decoder(nn.Module):
    """Decoder-only transformer mapping token ids to per-position vocabulary logits.

    Pipeline: token embedding -> positional encoding -> ``n_layers``
    transformer blocks -> final LayerNorm -> linear projection to the vocabulary.
    """

    def __init__(self, vocab_size, max_len, d_k, d_model, n_heads, n_layers, dropout_prob):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
        self.transformer_blocks = nn.Sequential(*(
            TransformerBlock(d_k, d_model, n_heads, max_len, dropout_prob)
            for _ in range(n_layers)
        ))
        self.ln = nn.LayerNorm(d_model)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        """Return logits of shape (N, T, vocab_size) for token ids ``x`` of shape (N, T)."""
        hidden = self.pos_encoding(self.embedding(x))
        # Iterate manually (instead of calling the Sequential) because every
        # block also needs the padding mask.
        for layer in self.transformer_blocks:
            hidden = layer(hidden, pad_mask)
        return self.fc(self.ln(hidden))

In [6]:
# Small model used only to smoke-test the architecture on random data.
model = Decoder(
    vocab_size=20_000,
    max_len=1024,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_prob=0.1,
)

In [7]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)
model.to(device)

cuda:0


Decoder(
  (embedding): Embedding(20000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): CausalSelfAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05,

In [8]:
# Random token ids in [0, 20_000) for a batch of 8 sequences of length 512 (smoke-test input).
x = np.random.randint(0, 20_000, size=(8, 512))
x_t = torch.tensor(x).to(device)

In [9]:
# Forward pass without a padding mask.
y = model(x_t)

In [10]:
y.shape

torch.Size([8, 512, 20000])

In [11]:
# Padding mask with the same (8, 512) shape as the input batch: 1 = real token, 0 = padding.
# Here the second half of every sequence is marked as padding.
mask = np.ones((8, 512))
mask[:, 256:] = 0
mask_t = torch.tensor(mask).to(device)

In [12]:
# Same forward pass, this time passing the padding mask through to the attention layers.
y = model(x_t, mask_t)

In [13]:
y.shape

torch.Size([8, 512, 20000])

In [14]:
from transformers import AutoTokenizer, DataCollatorWithPadding

In [15]:
# Load the tokenizer published with the 'distilbert-base-cased' checkpoint.
checkpoint = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [17]:
!pip3 install datasets
from datasets import load_dataset

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [18]:
# Download the 'sst2' configuration of the 'glue' dataset (train / validation / test splits).
raw_datasets = load_dataset('glue', 'sst2')

README.md:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [19]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [20]:
def tokenizer_fn(batch):
    """Tokenize the 'sentence' column of a batch with truncation enabled."""
    sentences = batch['sentence']
    return tokenizer(sentences, truncation=True)

In [21]:
# Tokenize every split in batches. No padding is requested here; the collator
# below is what pads the examples when they are assembled into a batch.
tokenized_datasets = raw_datasets.map(tokenizer_fn, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [22]:
# Drop the raw text, index and label columns so that only the tokenizer outputs remain.
# NOTE: this rebinds `tokenized_datasets`, so re-running this cell alone will fail
# because the columns are already gone.
tokenized_datasets = tokenized_datasets.remove_columns(['sentence', 'idx', 'label'])

In [23]:
from torch.utils.data import DataLoader

# Shuffled training batches of 32 examples, assembled (and padded) by the collator.
train_loader = DataLoader(
    tokenized_datasets['train'],
    shuffle=True,
    batch_size=32,
    collate_fn=data_collator
)

In [24]:
# for batch in train_loader:
#     for k, v in batch.items():
#         print('k : ',k, 'v.shape : ',v.shape)

In [25]:
tokenizer.pad_token_id

0

In [26]:
# Rebuild the model sized to the tokenizer: vocabulary size and maximum sequence
# length both come from the tokenizer. This replaces the smoke-test model above.
model = Decoder(vocab_size=tokenizer.vocab_size, max_len=tokenizer.model_max_length, d_k=16, d_model=64, n_heads=4, n_layers=2, dropout_prob=0.1)
model.to(device)

Decoder(
  (embedding): Embedding(28996, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): CausalSelfAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05,

In [27]:
# Targets equal to the pad token id are ignored, so padded positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# Adam with its default hyperparameters.
optimizer = torch.optim.Adam(model.parameters())

In [28]:
from datetime import datetime

In [29]:
def train(model, criterion, optimizer, train_loader, epochs):
    """Train ``model`` as a next-token predictor and return the mean loss per epoch.

    Relies on the notebook-level ``device`` and ``tokenizer``.

    Args:
        model: Called as ``model(input_ids, attention_mask)``; must return
            logits of shape (N, T, vocab).
        criterion: Loss called with logits transposed to (N, vocab, T) and
            integer targets of shape (N, T).
        optimizer: Optimizer stepped once per batch.
        train_loader: Iterable of dict batches with 'input_ids' and
            'attention_mask' tensors.
        epochs: Number of passes over ``train_loader``.

    Returns:
        NumPy array of length ``epochs`` holding each epoch's mean batch loss.
    """
    train_losses = np.zeros(epochs)

    for epoch in range(epochs):
        model.train()
        started_at = datetime.now()
        batch_losses = []

        for batch in train_loader:
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            input_ids = batch['input_ids']
            optimizer.zero_grad()

            # Targets are the inputs shifted one step left; the final position
            # has no successor, so it is set to the pad id.
            targets = torch.roll(input_ids.clone().detach(), shifts=-1, dims=-1)
            targets[:, -1] = tokenizer.pad_token_id

            logits = model(input_ids, batch['attention_mask'])
            # The loss expects the class dimension second: (N, vocab, T).
            loss = criterion(logits.transpose(2, 1), targets)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        epoch_loss = np.mean(batch_losses)
        train_losses[epoch] = epoch_loss

        elapsed = datetime.now() - started_at
        print(f'Epoch {epoch+1}/{epochs}, Train Loss : {epoch_loss:.4f}, Duration : {elapsed}')

    return train_losses

In [30]:
# Train for 15 epochs; the returned array holds the mean training loss of each epoch.
train_losses = train(model, criterion, optimizer, train_loader, epochs=15)

Epoch 1/15, Train Loss : 5.9574, Duration : 0:00:58.246116
Epoch 2/15, Train Loss : 5.0099, Duration : 0:00:58.511642
Epoch 3/15, Train Loss : 4.6770, Duration : 0:00:58.396384
Epoch 4/15, Train Loss : 4.4903, Duration : 0:00:58.910589
Epoch 5/15, Train Loss : 4.3579, Duration : 0:00:59.209705
Epoch 6/15, Train Loss : 4.2522, Duration : 0:01:04.682949
Epoch 7/15, Train Loss : 4.1671, Duration : 0:01:13.993897
Epoch 8/15, Train Loss : 4.0906, Duration : 0:01:07.050961
Epoch 9/15, Train Loss : 4.0240, Duration : 0:01:02.638657
Epoch 10/15, Train Loss : 3.9661, Duration : 0:01:05.579269
Epoch 11/15, Train Loss : 3.9091, Duration : 0:01:06.552824
Epoch 12/15, Train Loss : 3.8571, Duration : 0:01:08.767135
Epoch 13/15, Train Loss : 3.8130, Duration : 0:01:09.809238
Epoch 14/15, Train Loss : 3.7697, Duration : 0:01:12.829010
Epoch 15/15, Train Loss : 3.7319, Duration : 0:01:11.445359


In [31]:
# One validation sentence per batch, in dataset order (no shuffling).
valid_loader = DataLoader(
    tokenized_datasets['validation'],
    batch_size=1,
    collate_fn=data_collator
)

In [32]:
# Run the model on the first validation example only.
model.eval()  # disable dropout for inference
batch = next(iter(valid_loader))
batch = {k: v.to(device) for k, v in batch.items()}
# Inference needs no gradients; skipping autograd saves memory and time.
with torch.no_grad():
    outputs = model(batch['input_ids'], batch['attention_mask'])

In [33]:
outputs.shape

torch.Size([1, 12, 28996])

In [34]:
torch.argmax(outputs, axis=-1)

tensor([[ 170,  112,  188,  170, 1363, 1105, 6276, 6276, 2025,  102,  102,  117]],
       device='cuda:0')

In [35]:
predictions_ids = torch.argmax(outputs, axis=-1)

In [36]:
tokenizer.decode(predictions_ids[0])

"a ' s a good and funny funny study [SEP] [SEP],"

In [37]:
tokenizer.decode(batch['input_ids'][0])

"[CLS] it ' s a charming and often affecting journey. [SEP]"

In [38]:
# Show the first five input tokens followed by the model's prediction made at
# position 4, i.e. its guess for the sixth token.
tokenizer.decode(torch.concat((batch['input_ids'][0, :5], predictions_ids[:, 4])))

"[CLS] it ' s a good"

In [52]:
prompt = "it's"

In [53]:
tokenized_prompt = tokenizer(prompt, return_tensors='pt')

In [54]:
tokenized_prompt

{'input_ids': tensor([[ 101, 1122,  112,  188,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [55]:
# Drop the trailing [SEP] token (and its mask entry) so the model is asked to
# continue the prompt. Gradients are not needed for inference.
with torch.no_grad():
    outputs = model(tokenized_prompt['input_ids'][:, :-1].to(device), tokenized_prompt['attention_mask'][:, :-1].to(device))

In [57]:
# Greedy choice: the highest-scoring token according to the logits at the last position.
prediction_ids = torch.argmax(outputs[:, -1, :], axis=-1)

In [58]:
tokenizer.decode(prediction_ids[0])

'a'

In [64]:
prompt = "it's a"
tokenized_prompt = tokenizer(prompt, return_tensors='pt')
# Drop the trailing [SEP] so the model continues the prompt instead of seeing a finished sentence.
input_ids = tokenized_prompt['input_ids'][:, :-1].to(device)
mask = tokenized_prompt['attention_mask'][:, :-1].to(device)

# Greedy decoding: append the most likely next token until [SEP] or 20 new tokens.
model.eval()  # make the cell self-sufficient: dropout must be off while generating
with torch.no_grad():  # no autograd graph is needed (or wanted) during generation
    for _ in range(20):
        outputs = model(input_ids, mask)
        prediction_id = torch.argmax(outputs[:, -1, :], axis=-1)
        input_ids = torch.hstack((input_ids, prediction_id.view(1, 1)))
        # Every token generated so far is real, so the mask is all ones.
        mask = torch.ones_like(input_ids)
        if prediction_id.item() == tokenizer.sep_token_id:
            break


In [65]:
tokenizer.decode(input_ids[0])

"[CLS] it ' s a good time [SEP]"