In [100]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt


## Encoder (Transformer sentence classifier on GLUE SST-2)

In [51]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention with an optional key padding mask.

  Args:
    d_k: size of each head's query/key/value projection.
    d_model: dimension of the inputs and of the output.
    n_heads: number of attention heads.

  The query sequence and the key/value sequence may have different lengths
  (e.g. cross-attention); keys and values must share a length.
  """

  def __init__(self, d_k, d_model, n_heads):
    super().__init__()
    self.d_k = d_k
    self.n_heads = n_heads

    self.key = nn.Linear(d_model, d_k * n_heads)
    self.query = nn.Linear(d_model, d_k * n_heads)
    self.value = nn.Linear(d_model, d_k * n_heads)

    # Projects the concatenated heads back to d_model.
    self.fc = nn.Linear(d_k * n_heads, d_model)

  def _split_heads(self, x):
    """(N, L, h*d_k) -> (N, h, L, d_k), using x's own sequence length L."""
    N, L = x.shape[0], x.shape[1]
    return x.view(N, L, self.n_heads, self.d_k).transpose(1, 2)

  def forward(self, q, k, v, mask=None):
    """Attend from `q` over `k`/`v`.

    Args:
      q: (N, T_q, d_model) queries.
      k, v: (N, T_k, d_model) keys and values.
      mask: optional (N, T_k) tensor; key positions where mask == 0
        (e.g. padding) receive zero attention weight.

    Returns:
      (N, T_q, d_model) tensor.
    """
    N, T_q = q.shape[0], q.shape[1]

    # Each tensor is split with its own length, so T_k may differ from T_q.
    q = self._split_heads(self.query(q))  # (N, h, T_q, d_k)
    k = self._split_heads(self.key(k))    # (N, h, T_k, d_k)
    v = self._split_heads(self.value(v))  # (N, h, T_k, d_k)

    # (N, h, T_q, d_k) x (N, h, d_k, T_k) --> (N, h, T_q, T_k)
    attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
    if mask is not None:
      # Broadcast the (N, T_k) key mask over heads and query positions.
      attn_scores = attn_scores.masked_fill(
          mask[:, None, None, :] == 0, float('-inf')
      )
    attn_weights = F.softmax(attn_scores, dim=-1)

    # (N, h, T_q, T_k) x (N, h, T_k, d_k) --> (N, h, T_q, d_k)
    A = attn_weights @ v

    # (N, h, T_q, d_k) --> (N, T_q, h*d_k)
    A = A.transpose(1, 2).contiguous().view(N, T_q, self.d_k * self.n_heads)
    return self.fc(A)


In [52]:
from torch.nn.modules import dropout
class TransformerBlock(nn.Module):
  """Post-LayerNorm encoder block.

  Self-attention and a 4x-wide GELU feed-forward network are each added back
  to their input (residual) and LayerNorm-ed; dropout is applied to the
  block's final output (the feed-forward net also ends in dropout).
  """

  def __init__(self, d_k, d_model, n_heads, dropout_prob = 0.1):
    super().__init__()

    hidden = 4 * d_model
    # Registration order kept: it fixes parameter-init RNG order and the repr.
    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.mha = MultiHeadAttention(d_k, d_model, n_heads)
    self.ann = nn.Sequential(
        nn.Linear(d_model, hidden),
        nn.GELU(),
        nn.Linear(hidden, d_model),
        nn.Dropout(dropout_prob),
    )
    self.dropout = nn.Dropout(p=dropout_prob)

  def forward(self, x, mask=None):
    """x: (N, T, d_model); mask: optional key padding mask passed to attention."""
    attended = self.mha(x, x, x, mask)
    h = self.ln1(x + attended)
    h = self.ln2(h + self.ann(h))
    return self.dropout(h)

In [53]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to the input, then applies dropout.

  Args:
    d_model: embedding dimension. Odd values are supported: the last sin
      column has no cos partner.
    max_len: longest sequence length covered by the precomputed table.
    dropout_prob: dropout probability applied after the addition.
  """
  def __init__(self, d_model, max_len = 2048, dropout_prob=.1):
    super().__init__()
    self.dropout = nn.Dropout(p = dropout_prob)

    position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
    exp_term = torch.arange(0, d_model, 2)          # even dimension indices
    div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))  # (ceil(d_model/2),)
    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(position * div_term)
    # There are only floor(d_model/2) odd columns; trim div_term so odd d_model works.
    pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
    # Buffer: follows .to(device) and is saved in state_dict, but is not trained.
    self.register_buffer('pe', pe)

  def forward(self, x):
    """x: (N, T, d_model) with T <= max_len. Returns a new tensor; x is left unchanged."""
    # Out-of-place add: `x += ...` would overwrite the caller's tensor.
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

In [54]:
class Encoder(nn.Module):
  """Transformer encoder for sequence classification.

  Token embeddings plus sinusoidal positions pass through `n_layers`
  TransformerBlocks. The hidden state at position 0 (e.g. a [CLS] token) is
  then layer-normed and projected to `n_classes` logits.
  """

  def __init__(self,
               vocab_size,
               max_len,
               d_k,
               d_model,
               n_heads,
               n_layers,
               n_classes,
               dropout_prob):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
    self.transformer_blocks = nn.Sequential(
        *(TransformerBlock(d_k, d_model, n_heads, dropout_prob)
          for _ in range(n_layers))
    )
    self.ln = nn.LayerNorm(d_model)
    self.fc = nn.Linear(d_model, n_classes)

  def forward(self, x, mask = None):
    """x: (N, T) token ids; mask: optional (N, T), 0 marks padding.

    Returns (N, n_classes) logits.
    """
    h = self.pos_encoding(self.embedding(x))
    # Blocks are called one by one because nn.Sequential cannot forward `mask`.
    for block in self.transformer_blocks:
      h = block(h, mask)

    # Many-to-one: keep only position 0 -> (N, d_model).
    first_token = h[:, 0, :]
    return self.fc(self.ln(first_token))

In [55]:
model = Encoder(
    vocab_size=20000, max_len=1024, d_k=16, d_model=64,
    n_heads=4, n_layers=2, n_classes=5, dropout_prob=0.1,
)

In [56]:
# Prefer the first GPU when one is visible; otherwise use the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)
model.to(device)

cuda:0


Encoder(
  (embedding): Embedding(20000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, 

In [57]:
# Random token ids in [0, 20000): a batch of 8 sequences of length 512.
x = np.random.randint(low=0, high=20000, size=(8, 512))
x_t = torch.tensor(x, device=device)

In [58]:
# Key padding mask: the first 256 positions are real tokens (1), the rest padding (0).
mask = np.concatenate([np.ones((8, 256)), np.zeros((8, 256))], axis=1)
mask_t = torch.tensor(mask, device=device)

In [59]:
# Shape-only smoke test: no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
  y = model(x_t, mask_t)

In [60]:
y.shape  # (batch, n_classes) logits

torch.Size([8, 5])

In [None]:
!pip install transformers datasets

In [62]:
from transformers import AutoTokenizer, DataCollatorWithPadding

In [63]:
checkpoint = "distilbert-base-cased"  # only its tokenizer/vocabulary is reused here
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [64]:
from datasets import load_dataset

In [65]:
# GLUE SST-2: single sentences with a sentiment label.
raw_datasets = load_dataset(path="glue", name="sst2")

Downloading builder script:   0%|          | 0.00/28.8k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/27.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [66]:
raw_datasets  # DatasetDict: train / validation / test splits

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [69]:
def tokenize_fn(batch):
  """Tokenize a batch of rows' 'sentence' field with truncation enabled.

  No padding here: the DataCollatorWithPadding pads per batch later.
  """
  sentences = batch['sentence']
  return tokenizer(sentences, truncation=True)

In [70]:
# Pad dynamically per batch in the collator rather than to a global length.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenized_datasets = raw_datasets.map(tokenize_fn, batched=True)

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [71]:
data_collator  # inspect the collator's padding configuration

DataCollatorWithPadding(tokenizer=DistilBertTokenizerFast(name_or_path='distilbert-base-cased', vocab_size=28996, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True), padding=True, max_length=None, pad_to_multiple_of=None, return_tensors='pt')

In [72]:
tokenized_datasets  # now also carries input_ids and attention_mask columns

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [73]:
# Keep only model inputs and the target; the training loop reads batch['labels'].
tokenized_datasets = (
    tokenized_datasets
    .remove_columns(['sentence', 'idx'])
    .rename_column("label", "labels")
)

In [74]:
tokenized_datasets  # columns: labels, input_ids, attention_mask

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [76]:
from torch.utils.data import DataLoader

In [77]:
# Shuffle only the training split; the collator pads each batch to its longest sequence.
train_loader = DataLoader(tokenized_datasets['train'], batch_size=32,
                          shuffle=True, collate_fn=data_collator)
valid_loader = DataLoader(tokenized_datasets['validation'], batch_size=32,
                          collate_fn=data_collator)

In [81]:
# Peek at one collated batch to confirm tensor shapes.
batch = next(iter(train_loader))
for k, v in batch.items():
  print("k:", k, "v.shape:", v.shape)

k: labels v.shape: torch.Size([32])
k: input_ids v.shape: torch.Size([32, 61])
k: attention_mask v.shape: torch.Size([32, 61])


In [82]:
set(tokenized_datasets['train']['labels'])  # distinct labels -> n_classes for the model

{0, 1}

In [86]:
tokenizer.vocab_size  # used as the embedding vocabulary size below

28996

In [85]:
tokenizer.max_model_input_sizes  # per-checkpoint max sequence length; used as max_len below

{'distilbert-base-uncased': 512,
 'distilbert-base-uncased-distilled-squad': 512,
 'distilbert-base-cased': 512,
 'distilbert-base-cased-distilled-squad': 512,
 'distilbert-base-german-cased': 512,
 'distilbert-base-multilingual-cased': 512}

In [88]:
# Small classifier: 2 layers of 4 heads (16 dims each) over 64-dim embeddings,
# with 2 output classes for the {0, 1} SST-2 labels.
model = Encoder(
    d_model=64,
    n_heads=4,
    d_k=16,
    n_layers=2,
    dropout_prob=.1,
    vocab_size=tokenizer.vocab_size,
    max_len=tokenizer.max_model_input_sizes[checkpoint],
    n_classes=2,
)
model.to(device)

Encoder(
  (embedding): Embedding(28996, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, 

In [90]:
# Adam with its default learning rate; cross-entropy over the class logits.
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss()

In [91]:
from datetime import datetime

In [95]:
def train(model, criterion, optimizer, train_loader, valide_loader, epochs, device=None):
  """Train a sequence classifier; record per-epoch train and validation loss.

  Each batch is a dict with 'input_ids', 'attention_mask' and 'labels'
  tensors, and the model is called as model(input_ids, attention_mask).

  Args:
    model: module returning (N, n_classes) logits.
    criterion: loss taking (logits, labels).
    optimizer: optimizer over the model's parameters.
    train_loader: iterable of training batches.
    valide_loader: iterable of validation batches. The misspelled name is
      kept so existing keyword callers keep working.
    epochs: number of passes over train_loader.
    device: where batches are moved. Defaults to the device of the model's
      first parameter.

  Returns:
    (train_losses, test_losses): numpy arrays of length `epochs` holding the
    example-weighted mean loss of each epoch (nan if a loader was empty).
  """
  if device is None:
    device = next(model.parameters()).device

  train_losses = np.zeros(epochs)
  test_losses = np.zeros(epochs)

  for it in range(epochs):
    t0 = datetime.now()

    model.train()
    train_loss, n_train = 0.0, 0
    for batch in train_loader:
      batch = {k: v.to(device) for k, v in batch.items()}

      optimizer.zero_grad()
      outputs = model(batch['input_ids'], batch['attention_mask'])
      loss = criterion(outputs, batch['labels'])
      loss.backward()
      optimizer.step()

      # Weight by batch size so the epoch value is a per-example mean
      # (assuming the criterion averages over the batch).
      n = batch['input_ids'].size(0)
      train_loss += loss.item() * n
      n_train += n

    model.eval()
    test_loss, n_test = 0.0, 0
    with torch.no_grad():  # evaluation only: no autograd graph needed
      for batch in valide_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(batch['input_ids'], batch['attention_mask'])
        loss = criterion(outputs, batch['labels'])
        n = batch['input_ids'].size(0)
        test_loss += loss.item() * n
        n_test += n

    train_loss = train_loss / n_train if n_train else float('nan')
    test_loss = test_loss / n_test if n_test else float('nan')
    train_losses[it] = train_loss
    test_losses[it] = test_loss

    dt = datetime.now() - t0
    print(f'Epoch {it+1}/{epochs}, Train_Loss: {train_loss:.4f}, '
          f'Test_Loss: {test_loss:.4f}, Duration: {dt}')

  return train_losses, test_losses

In [96]:
# 4 epochs; returns per-epoch mean train and validation losses.
train_losses, test_losses = train(model, criterion, optimizer,
                                  train_loader, valid_loader, epochs=4)

Epoch 1/4, Train_Loss: 0.3731,      Test_Loss: 0.5140, Duration: 0:00:27.576877
Epoch 2/4, Train_Loss: 0.3051,      Test_Loss: 0.4582, Duration: 0:00:19.039151
Epoch 3/4, Train_Loss: 0.2618,      Test_Loss: 0.5078, Duration: 0:00:19.032908
Epoch 4/4, Train_Loss: 0.2323,      Test_Loss: 0.5007, Duration: 0:00:20.689512


In [99]:
def compute_accuracy(model, loader, device):
  """Fraction of examples in `loader` whose argmax logit equals batch['labels'].

  Puts the model in eval mode and disables autograd, since this is inference only.
  """
  model.eval()
  n_correct, n_total = 0, 0
  with torch.no_grad():
    for batch in loader:
      batch = {k: v.to(device) for k, v in batch.items()}
      outputs = model(batch['input_ids'], batch['attention_mask'])
      predictions = outputs.argmax(dim=1)
      n_correct += (predictions == batch['labels']).sum().item()
      n_total += batch['labels'].shape[0]
  return n_correct / n_total


train_acc = compute_accuracy(model, train_loader, device)
test_acc = compute_accuracy(model, valid_loader, device)

print(f" Train acc : {train_acc:.4f}, Test acc: {test_acc:.4f}")


 Train acc : 0.9428, Test acc: 0.7936


## Decoder (causal Transformer language model trained on SST-2 sentences)

In [148]:
class CausalSelfAttention(nn.Module):
  """Multi-head self-attention with a causal mask and an optional key padding mask.

  Position t may attend only to positions <= t. The lower-triangular mask is
  precomputed for `max_len` positions and registered as a buffer.

  Args:
    d_k: per-head query/key/value size.
    d_model: dimension of the inputs and of the output.
    n_heads: number of attention heads.
    max_len: longest sequence the causal mask supports.
  """
  def __init__(self, d_k, d_model, n_heads, max_len):
    super().__init__()

    self.d_k = d_k
    self.n_heads = n_heads

    self.key = nn.Linear(d_model, d_k * n_heads)
    self.query = nn.Linear(d_model, d_k * n_heads)
    self.value = nn.Linear(d_model, d_k * n_heads)

    self.fc = nn.Linear(d_k * n_heads, d_model)

    # 1 on and below the diagonal (allowed), 0 above it (future positions).
    cm = torch.tril(torch.ones(max_len, max_len))
    self.register_buffer(
        "causal_mask",
        cm.view(1, 1, max_len, max_len)
    )

  def forward(self, q, k, v, pad_mask = None):
    """q, k, v: (N, T, d_model), sharing the same T <= max_len.

    pad_mask: optional (N, T); key positions where it is 0 are ignored.
    Returns (N, T, d_model). Raises ValueError if T exceeds max_len.
    """
    N = q.shape[0]
    T = q.shape[1]
    max_len = self.causal_mask.size(-1)
    if T > max_len:
      raise ValueError(
          f"sequence length {T} exceeds the causal mask's max_len {max_len}"
      )

    q = self.query(q)
    k = self.key(k)
    v = self.value(v)

    # (N, T, h*d_k) -> (N, h, T, d_k)
    q = q.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
    k = k.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
    v = v.view(N, T, self.n_heads, self.d_k).transpose(1, 2)

    # (N, h, T, d_k) x (N, h, d_k, T) --> (N, h, T, T)
    attn_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

    if pad_mask is not None:
      attn_scores = attn_scores.masked_fill(
          pad_mask[:, None, None, :] == 0, float('-inf')
      )

    attn_scores = attn_scores.masked_fill(
        self.causal_mask[:, :, :T, :T] == 0, float('-inf')
    )

    # A query row whose every allowed key is masked (e.g. a leading pad token
    # with left padding) is all -inf, and softmax returns NaN for it. That NaN
    # would then spread through later attention matmuls, so give such rows
    # zero weight instead.
    fully_masked = (attn_scores == float('-inf')).all(dim=-1, keepdim=True)
    attn_weights = F.softmax(attn_scores, dim=-1)
    attn_weights = attn_weights.masked_fill(fully_masked, 0.0)

    # (N, h, T, T) x (N, h, T, d_k) --> (N, h, T, d_k)
    A = attn_weights @ v

    A = A.transpose(1, 2)  # --> (N, T, h, d_k)
    A = A.contiguous().view(N, T, self.d_k * self.n_heads)  # --> (N, T, h*d_k)

    return self.fc(A)



In [149]:
from torch.nn.modules import dropout

class TransformerBlock(nn.Module):
  """Post-LayerNorm decoder block.

  Causal self-attention and a 4x-wide GELU feed-forward network are each
  added back to their input (residual) and LayerNorm-ed; dropout is applied
  to the block's final output.

  NOTE(review): this rebinds the global name `TransformerBlock` with an extra
  `max_len` argument. Encoder.__init__ looks the name up when it runs, so
  constructing an Encoder after this cell would pass dropout_prob as max_len.
  Consider giving the two blocks distinct names.
  """
  def __init__(self, d_k, d_model, n_heads, max_len, dropout_prob = 0.1):
    super().__init__()

    hidden = 4 * d_model
    # Registration order kept: it fixes parameter-init RNG order and the repr.
    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.mha = CausalSelfAttention(d_k, d_model, n_heads, max_len)
    self.ann = nn.Sequential(
        nn.Linear(d_model, hidden),
        nn.GELU(),
        nn.Linear(hidden, d_model),
        nn.Dropout(dropout_prob),
    )
    self.dropout = nn.Dropout(p=dropout_prob)

  def forward(self, x, pad_mask=None):
    """x: (N, T, d_model); pad_mask: optional key padding mask for attention."""
    attended = self.mha(x, x, x, pad_mask)
    h = self.ln1(x + attended)
    h = self.ln2(h + self.ann(h))
    return self.dropout(h)

In [150]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position encodings to the input, then applies dropout.

  Args:
    d_model: embedding dimension. Odd values are supported: the last sin
      column has no cos partner.
    max_len: longest sequence length covered by the precomputed table.
    dropout_prob: dropout probability applied after the addition.
  """
  def __init__(self, d_model, max_len = 2048, dropout_prob=.1):
    super().__init__()
    self.dropout = nn.Dropout(p = dropout_prob)

    position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
    exp_term = torch.arange(0, d_model, 2)          # even dimension indices
    div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))  # (ceil(d_model/2),)
    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(position * div_term)
    # There are only floor(d_model/2) odd columns; trim div_term so odd d_model works.
    pe[0, :, 1::2] = torch.cos(position * div_term[: d_model // 2])
    # Buffer: follows .to(device) and is saved in state_dict, but is not trained.
    self.register_buffer('pe', pe)

  def forward(self, x):
    """x: (N, T, d_model) with T <= max_len. Returns a new tensor; x is left unchanged."""
    # Out-of-place add: `x += ...` would overwrite the caller's tensor.
    x = x + self.pe[:, :x.size(1), :]
    return self.dropout(x)

In [151]:
class Decoder(nn.Module):
  """Decoder-only Transformer language model.

  Maps (N, T) token ids to (N, T, vocab_size) logits: token embeddings plus
  sinusoidal positions pass through `n_layers` causal TransformerBlocks, then
  a final LayerNorm and a projection onto the vocabulary.
  """
  def __init__(self,
               vocab_size,
               max_len,
               d_k,
               d_model,
               n_heads,
               n_layers,
               dropout_prob):
    super().__init__()

    self.embedding = nn.Embedding(vocab_size, d_model)
    self.pos_encoding = PositionalEncoding(d_model, max_len, dropout_prob)
    self.transformer_blocks = nn.Sequential(
        *(TransformerBlock(d_k, d_model, n_heads, max_len, dropout_prob)
          for _ in range(n_layers))
    )
    self.ln = nn.LayerNorm(d_model)
    self.fc = nn.Linear(d_model, vocab_size)

  def forward(self, x, pad_mask=None):
    """x: (N, T) token ids; pad_mask: optional (N, T), 0 marks padding."""
    h = self.pos_encoding(self.embedding(x))
    # Blocks are called one by one because nn.Sequential cannot forward `pad_mask`.
    for block in self.transformer_blocks:
      h = block(h, pad_mask)
    # Many-to-many: a vocabulary distribution at every position.
    return self.fc(self.ln(h))


In [152]:
model = Decoder(
    vocab_size=20000, max_len=1024, d_k=16, d_model=64,
    n_heads=4, n_layers=2, dropout_prob=.1,
)

In [153]:
# Use the first GPU if available, else fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(device)
model.to(device)

cuda:0


Decoder(
  (embedding): Embedding(20000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): CausalSelfAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05,

In [154]:
# Random token ids in [0, 20000): a batch of 8 sequences of length 512.
x = np.random.randint(low=0, high=20000, size=(8, 512))
x_t = torch.tensor(x, device=device)

In [155]:
# Smoke test without a padding mask; no_grad because only the output shape is inspected.
with torch.no_grad():
  y = model(x_t)
y.shape

torch.Size([8, 512, 20000])

In [156]:
# Key padding mask: the first 256 positions are real tokens (1), the rest padding (0).
mask = np.concatenate([np.ones((8, 256)), np.zeros((8, 256))], axis=1)
mask_t = torch.tensor(mask, device=device)

In [157]:
# Same smoke test with padding masked out; no autograd graph is needed.
with torch.no_grad():
  y = model(x_t, mask_t)
y.shape

torch.Size([8, 512, 20000])

In [None]:
!pip install transformers datasets

In [159]:
from transformers import AutoTokenizer, DataCollatorWithPadding

In [160]:
checkpoint = "distilbert-base-cased"  # only its tokenizer/vocabulary is reused here
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [161]:
from datasets import load_dataset

In [162]:
# The SST-2 sentences serve as a small language-modeling corpus.
raw_datasets = load_dataset(path="glue", name="sst2")

In [163]:
raw_datasets  # DatasetDict: train / validation / test splits

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [164]:
def tokenize_fn(batch):
  """Tokenize a batch of rows' 'sentence' field with truncation enabled.

  No padding here: the DataCollatorWithPadding pads per batch later.
  """
  sentences = batch['sentence']
  return tokenizer(sentences, truncation=True)

In [165]:
# Pad dynamically per batch in the collator rather than to a global length.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
tokenized_datasets = raw_datasets.map(tokenize_fn, batched=True)

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

In [166]:
# Language modeling needs only input_ids/attention_mask: drop text, index and sentiment label.
tokenized_datasets = tokenized_datasets.remove_columns(
    column_names=['sentence', 'idx', 'label']
)


In [167]:
from torch.utils.data import DataLoader

In [168]:
# Shuffle only the training split; the collator pads each batch to its longest sequence.
train_loader = DataLoader(tokenized_datasets['train'], batch_size=32,
                          shuffle=True, collate_fn=data_collator)
valid_loader = DataLoader(tokenized_datasets['validation'], batch_size=32,
                          collate_fn=data_collator)

In [170]:
tokenized_datasets  # columns: input_ids, attention_mask

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [171]:
# Peek at one collated batch to confirm tensor shapes.
batch = next(iter(train_loader))
for k, v in batch.items():
  print("k:", k, "v.shape:", v.shape)

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


k: input_ids v.shape: torch.Size([32, 38])
k: attention_mask v.shape: torch.Size([32, 38])


In [172]:
tokenizer.pad_token_id  # also used as the loss ignore_index below

0

In [173]:
# Same small size as the encoder, but the output layer spans the whole vocabulary.
model = Decoder(
    d_model=64,
    n_heads=4,
    d_k=16,
    n_layers=2,
    dropout_prob=.1,
    vocab_size=tokenizer.vocab_size,
    max_len=tokenizer.max_model_input_sizes[checkpoint],
)
model.to(device)

Decoder(
  (embedding): Embedding(28996, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): CausalSelfAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((64,), eps=1e-05,

In [175]:
# Many-to-many LM loss: targets equal to pad_token_id contribute nothing.
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [176]:
from datetime import datetime

In [180]:
def train(model, criterion, optimizer, train_loader, epochs, device=None, pad_token_id=None):
  """Train a decoder-only language model with next-token prediction.

  Targets are the inputs shifted one step left. The final position has no
  next token, so it is filled with `pad_token_id`; a criterion built with
  ignore_index=pad_token_id then skips it, just as it skips padding.

  Args:
    model: called as model(input_ids, attention_mask) -> (N, T, V) logits.
    criterion: loss taking (N, V, T) logits and (N, T) targets.
    optimizer: optimizer over the model's parameters.
    train_loader: iterable of dicts with 'input_ids' and 'attention_mask'.
    epochs: number of passes over train_loader.
    device: where batches are moved. Defaults to the device of the model's
      first parameter.
    pad_token_id: filler for the last target. Defaults to the notebook's
      global `tokenizer.pad_token_id`.

  Returns:
    numpy array of length `epochs`: the mean of per-batch losses in each epoch.
  """
  if device is None:
    device = next(model.parameters()).device
  if pad_token_id is None:
    pad_token_id = tokenizer.pad_token_id

  train_losses = np.zeros(epochs)

  for it in range(epochs):
    model.train()
    t0 = datetime.now()
    batch_losses = []

    for batch in train_loader:
      batch = {k: v.to(device) for k, v in batch.items()}

      optimizer.zero_grad()

      # Shift left by one. torch.roll returns a new tensor, so input_ids is
      # untouched. The first token wraps into the last slot and is overwritten.
      targets = torch.roll(batch['input_ids'], shifts=-1, dims=1)
      targets[:, -1] = pad_token_id

      outputs = model(batch['input_ids'], batch['attention_mask'])
      # CrossEntropyLoss expects the class dim second: (N, T, V) -> (N, V, T).
      loss = criterion(outputs.transpose(2, 1), targets)

      loss.backward()
      optimizer.step()

      batch_losses.append(loss.item())

    train_loss = np.mean(batch_losses)
    train_losses[it] = train_loss

    dt = datetime.now() - t0
    print(f'Epoch {it+1}/{epochs}, Train_Loss: {train_loss:.4f}, Duration: {dt}')

  return train_losses

In [181]:
# 4 epochs of next-token training; returns the per-epoch mean batch loss.
train_losses = train(model, criterion, optimizer, train_loader, epochs=4)

Epoch 1/4, Train_Loss: 4.8892,     Duration: 0:01:04.975051
Epoch 2/4, Train_Loss: 4.6119,     Duration: 0:01:09.376107
Epoch 3/4, Train_Loss: 4.4460,     Duration: 0:01:06.395798
Epoch 4/4, Train_Loss: 4.3221,     Duration: 0:01:00.104286


In [182]:
# One example per batch, for inspecting predictions qualitatively.
valid_loader = DataLoader(tokenized_datasets["validation"], batch_size=1,
                          collate_fn=data_collator)

In [183]:
model.eval()
# Logits for the first validation example; inference only, so skip autograd.
batch = {k: v.to(device) for k, v in next(iter(valid_loader)).items()}
with torch.no_grad():
  outputs = model(batch['input_ids'], batch['attention_mask'])

In [184]:
outputs.shape  # (1, T, vocab_size): next-token logits at each position

torch.Size([1, 12, 28996])

In [185]:
prediction_ids = outputs.argmax(dim=-1)  # most likely next token at every position

In [186]:
tokenizer.decode(prediction_ids[0])  # the model's next-token guess at each position

"the's a good, a surprising portrait [SEP] [SEP]."

In [187]:
tokenizer.decode(batch['input_ids'][0])  # the actual input, for comparison

"[CLS] it's a charming and often affecting journey. [SEP]"

In [188]:
# First 5 input tokens, followed by the prediction made at position 4 (the model's 6th token).
tokenizer.decode(torch.cat([batch['input_ids'][0, :5], prediction_ids[:, 4]]))

"[CLS] it's a good"

In [189]:
prompt = "it's"

# The tokenizer wraps the prompt as [CLS] ... [SEP].
tokenized_prompt = tokenizer(prompt, return_tensors="pt")
tokenized_prompt

{'input_ids': tensor([[ 101, 1122,  112,  188,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [190]:
with torch.no_grad():
  # [:, :-1] cuts off the trailing [SEP] so the model continues the prompt.
  outputs = model(tokenized_prompt['input_ids'][:, :-1].to(device),
                  tokenized_prompt['attention_mask'][:, :-1].to(device))

outputs.shape

torch.Size([1, 4, 28996])

In [191]:
prediction_ids = outputs[:, -1, :].argmax(dim=-1)  # greedy choice from the last position only

In [192]:
tokenizer.decode(prediction_ids[0])  # predicted next token after the prompt

'a'

In [193]:
# Greedy decoding: keep appending the most likely next token until [SEP]
# appears or MAX_NEW_TOKENS tokens have been generated.
MAX_NEW_TOKENS = 20
prompt = "it's"

tokenized_prompt = tokenizer(prompt, return_tensors='pt')
# Drop the trailing [SEP] so generation continues the prompt.
input_ids = tokenized_prompt['input_ids'][:, :-1].to(device)
mask = tokenized_prompt['attention_mask'][:, :-1].to(device)

model.eval()  # disable dropout here rather than relying on an earlier cell
with torch.no_grad():
  for _ in range(MAX_NEW_TOKENS):
    outputs = model(input_ids, mask)
    prediction_id = torch.argmax(outputs[:, -1, :], dim=-1)  # logits at the final time step

    input_ids = torch.hstack((input_ids, prediction_id.view(1, 1)))
    mask = torch.ones_like(input_ids)

    if prediction_id.item() == tokenizer.sep_token_id:
      break

In [194]:
tokenizer.decode(input_ids[0])  # prompt plus the greedily generated tokens

"[CLS] it's a good time [SEP]"