In [1]:
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Computes scaled dot-product attention.

    Args:
      query: Tensor of shape (..., seq_len_q, dim_k).
      key: Tensor of shape (..., seq_len_kv, dim_k).
      value: Tensor of shape (..., seq_len_kv, dim_v).
      mask: Optional tensor broadcastable to (..., seq_len_q, seq_len_kv).
        Non-zero / True entries mark positions that must NOT be attended to.

    Returns:
      Output tensor of shape (..., seq_len_q, dim_v).
      Attention weights tensor of shape (..., seq_len_q, seq_len_kv).
    """
    # Similarity of every query with every key, scaled by sqrt(dim_k).
    # A plain Python scalar keeps the dtype and device of the inputs.
    scaled_attention_logits = torch.matmul(query, key.transpose(-2, -1)) / (key.shape[-1] ** 0.5)

    if mask is not None:
        # Fill masked positions with the most negative *finite* value of the
        # logits' dtype. A hard-coded -1e9 overflows to -inf in float16, and
        # multiplying it by a zero mask entry then yields NaN. A finite fill
        # also keeps fully masked rows NaN-free (they become uniform).
        # masked_fill is out-of-place, so no tensor is mutated.
        fill_value = torch.finfo(scaled_attention_logits.dtype).min
        scaled_attention_logits = scaled_attention_logits.masked_fill(
            mask.to(torch.bool), fill_value)

    # softmax is normalized on the last axis (seq_len_kv) so that the scores
    # for each query add up to 1.
    attention_weights = F.softmax(scaled_attention_logits, dim=-1)

    output = torch.matmul(attention_weights, value)  # (..., seq_len_q, dim_v)

    return output, attention_weights

## Implement multi-head attention




In [2]:
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head attention: project q/k/v, attend per head, merge and project.

    `d_model` must be divisible by `num_heads`; each head works on
    `depth = d_model // num_heads` features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads  # per-head feature size

        # Input projections for queries, keys and values.
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are merged.
        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) into
        (batch_size, num_heads, seq_len, depth).
        """
        per_head = x.view(batch_size, -1, self.num_heads, self.depth)
        return per_head.transpose(1, 2)

    def forward(self, v, k, q, mask):
        """Attend from `q` over `k`/`v`.

        Note the argument order: value, key, query, mask.

        Returns:
            output: (batch_size, seq_len_q, d_model)
            attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        batch_size = q.size(0)

        # Project, then split into heads: (batch_size, num_heads, seq_len, depth)
        q_heads = self.split_heads(self.wq(q), batch_size)
        k_heads = self.split_heads(self.wk(k), batch_size)
        v_heads = self.split_heads(self.wv(v), batch_size)

        # context: (batch_size, num_heads, seq_len_q, depth)
        context, attention_weights = scaled_dot_product_attention(
            q_heads, k_heads, v_heads, mask)

        # Bring heads next to depth again and flatten them back to d_model.
        merged = context.transpose(1, 2).reshape(batch_size, -1, self.d_model)

        return self.dense(merged), attention_weights

## Create positional encoding




In [3]:
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    Even feature indices hold sin(pos * freq), odd indices hold
    cos(pos * freq), with freq = 10000^(-2i / d_model) for pair index i.
    The table is stored as a non-trainable buffer named ``pe`` of shape
    (1, max_len, d_model).

    Args:
        d_model: Feature size of the embeddings (even or odd).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # the cosine table is trimmed to fit instead of raising a shape error.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)

        Returns:
            `x` plus the encodings of the first `seq_len` positions (same shape).
        """
        return x + self.pe[:, :x.size(1)]


## Build the encoder layer



In [4]:
import torch.nn as nn

class EncoderLayer(nn.Module):
    """One encoder block: self-attention and a position-wise feed-forward
    network, each followed by dropout, a residual connection and LayerNorm.

    Args:
        d_model: Feature size of the layer's input and output.
        num_heads: Number of attention heads.
        dff: Hidden size of the feed-forward network.
        rate: Dropout probability.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()

        self.mha = MultiHeadAttention(d_model, num_heads)

        # Position-wise feed-forward network: d_model -> dff -> d_model.
        expand = nn.Linear(d_model, dff)
        contract = nn.Linear(dff, d_model)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)

        self.layernorm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(d_model, eps=1e-6)

        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, x, mask):
        """Apply the block to `x`, using `mask` for the self-attention."""
        # Sub-layer 1: self-attention (value, key and query are all `x`).
        attended, _ = self.mha(x, x, x, mask)
        out1 = self.layernorm1(x + self.dropout1(attended))

        # Sub-layer 2: feed-forward network.
        transformed = self.ffn(out1)
        return self.layernorm2(out1 + self.dropout2(transformed))

## Construct the decoder layer




In [5]:
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention and
    a feed-forward network, each followed by dropout, a residual connection
    and LayerNorm.

    Args:
        d_model: Feature size of the layer's input and output.
        num_heads: Number of attention heads.
        dff: Hidden size of the feed-forward network.
        rate: Dropout probability.
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super().__init__()

        self.mha1 = MultiHeadAttention(d_model, num_heads)  # masked self-attention
        self.mha2 = MultiHeadAttention(d_model, num_heads)  # encoder-decoder attention

        # Position-wise feed-forward network: d_model -> dff -> d_model.
        expand = nn.Linear(d_model, dff)
        contract = nn.Linear(dff, d_model)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)

        self.layernorm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.layernorm2 = nn.LayerNorm(d_model, eps=1e-6)
        self.layernorm3 = nn.LayerNorm(d_model, eps=1e-6)

        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)
        self.dropout3 = nn.Dropout(rate)

    def forward(self, x, enc_output, look_ahead_mask, padding_mask):
        """Apply the block to the decoder input `x`.

        `enc_output` has shape (batch_size, input_seq_len, d_model).

        Returns:
            The block output and the attention weights of the self-attention
            and encoder-decoder attention sub-layers, in that order.
        """
        # Sub-layer 1: masked self-attention over the decoder input.
        self_attn, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        out1 = self.layernorm1(x + self.dropout1(self_attn))

        # Sub-layer 2: queries come from the decoder, keys/values from the
        # encoder (MultiHeadAttention takes value, key, query, mask).
        cross_attn, attn_weights_block2 = self.mha2(
            enc_output, enc_output, out1, padding_mask)
        out2 = self.layernorm2(out1 + self.dropout2(cross_attn))

        # Sub-layer 3: feed-forward network.
        transformed = self.ffn(out2)
        out3 = self.layernorm3(out2 + self.dropout3(transformed))

        return out3, attn_weights_block1, attn_weights_block2

## Assemble the transformer model



In [6]:
import torch.nn as nn

class Transformer(nn.Module):
    """Encoder-decoder transformer for sequence-to-sequence modelling.

    Args:
        num_layers: Number of encoder layers and of decoder layers.
        d_model: Embedding / hidden feature size.
        num_heads: Number of attention heads per attention module.
        dff: Hidden size of the position-wise feed-forward networks.
        input_vocab_size: Size of the source vocabulary.
        target_vocab_size: Size of the target vocabulary.
        pe_input: Maximum source length covered by the positional encoding.
        pe_target: Maximum target length covered by the positional encoding.
        rate: Dropout probability.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
                 target_vocab_size, pe_input, pe_target, rate=0.1):
        super(Transformer, self).__init__()

        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.input_vocab_size = input_vocab_size
        self.target_vocab_size = target_vocab_size
        self.pe_input = pe_input
        self.pe_target = pe_target
        self.rate = rate

        self.embedding_input = nn.Embedding(input_vocab_size, d_model)
        self.embedding_target = nn.Embedding(target_vocab_size, d_model)
        # One shared table, long enough for both source and target sequences.
        self.pos_encoding = PositionalEncoding(d_model, max_len=max(pe_input, pe_target))

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, dff, rate)
                                             for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, dff, rate)
                                             for _ in range(num_layers)])

        self.final_layer = nn.Linear(d_model, target_vocab_size)

        self.dropout = nn.Dropout(rate)

    def forward(self, inp, tar, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        """Run the encoder on `inp` and the decoder on `tar`.

        Args:
            inp: Source token ids, (batch_size, seq_len_in).
            tar: Target token ids, (batch_size, seq_len_out).
            enc_padding_mask: (batch_size, 1, 1, seq_len_in).
            look_ahead_mask: (batch_size, 1, seq_len_out, seq_len_out).
            dec_padding_mask: (batch_size, 1, 1, seq_len_in).

        Returns:
            final_output: Logits of shape (batch_size, seq_len_out, target_vocab_size).
            attention_weights: Dict mapping 'decoder_layer{i}_block{1,2}' to the
                attention weights of decoder layer i (1-based).
        """
        # Encoder
        enc_output = self.embedding_input(inp)  # (batch_size, seq_len_in, d_model)
        # pos_encoding already returns `x + pe`; assigning (rather than `+=`)
        # avoids adding the embedding to itself a second time.
        enc_output = self.pos_encoding(enc_output)
        enc_output = self.dropout(enc_output)

        for i in range(self.num_layers):
            enc_output = self.encoder_layers[i](enc_output, enc_padding_mask)

        # Decoder
        dec_output = self.embedding_target(tar)  # (batch_size, seq_len_out, d_model)
        dec_output = self.pos_encoding(dec_output)
        dec_output = self.dropout(dec_output)

        attention_weights = {}
        for i in range(self.num_layers):
            dec_output, attn1, attn2 = self.decoder_layers[i](
                dec_output, enc_output, look_ahead_mask, dec_padding_mask)
            attention_weights[f'decoder_layer{i+1}_block1'] = attn1
            attention_weights[f'decoder_layer{i+1}_block2'] = attn2

        final_output = self.final_layer(dec_output)  # (batch_size, seq_len_out, target_vocab_size)

        return final_output, attention_weights

## Implement training and evaluation



In [7]:
import torch.nn as nn
import torch.optim as optim

# 1. Define the loss function
# Per-token cross-entropy; token id 0 is padding and is excluded from the loss.
loss_object = nn.CrossEntropyLoss(ignore_index=0, reduction='none')

def loss_function(real, pred):
    """Mean cross-entropy over the non-padding target tokens.

    Args:
        real: Target token ids of shape (batch_size, seq_len); 0 marks padding.
        pred: Logits of shape (batch_size, seq_len, vocab_size).

    Returns:
        Scalar tensor: summed token loss divided by the number of non-padding
        tokens. A batch with no real tokens yields 0 instead of NaN.
    """
    mask = torch.logical_not(torch.eq(real, 0)).float()
    # CrossEntropyLoss expects the class dimension second: (batch, vocab, seq).
    loss_ = loss_object(pred.transpose(-1, -2), real) * mask
    # Clamp the token count so an all-padding batch does not divide 0 by 0.
    return torch.sum(loss_) / torch.sum(mask).clamp(min=1.0)

# Build the model. Requires the Transformer class from an earlier cell.
# Architecture hyperparameters.
num_layers, d_model, num_heads, dff = 6, 512, 8, 2048
rate = 0.1

# Placeholders: replace with the real vocabulary sizes and the maximum
# sequence lengths the positional encoding must cover.
input_vocab_size = 10000
target_vocab_size = 10000
pe_input = 1000
pe_target = 1000

transformer = Transformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
    input_vocab_size=input_vocab_size,
    target_vocab_size=target_vocab_size,
    pe_input=pe_input,
    pe_target=pe_target,
    rate=rate,
)


# 2. Define the optimizer
optimizer = optim.Adam(
    transformer.parameters(),
    lr=0.001,
    betas=(0.9, 0.98),
    eps=1e-9,
)

# 3. Implement a custom learning rate scheduler (basic example)
class CustomSchedule(torch.optim.lr_scheduler._LRScheduler):
    """Warm-up then inverse-square-root decay learning-rate schedule.

    lr(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The rate grows linearly for `warmup_steps` steps and then decays with the
    inverse square root of the step number. The schedule is absolute: it
    replaces the optimizer's configured learning rate.

    Args:
        optimizer: Wrapped optimizer.
        d_model: Model feature size used to scale the rate.
        warmup_steps: Number of warm-up steps.
    """

    def __init__(self, optimizer, d_model, warmup_steps=4000):
        # These must exist before super().__init__(), because the base
        # constructor performs an initial step() that calls get_lr().
        self.optimizer = optimizer
        self.d_model = float(d_model)
        self.warmup_steps = warmup_steps
        self.step_num = 0
        super().__init__(optimizer)

    def get_lr(self):
        """Return the learning rate for the current step, one per param group."""
        # Derive the step from the base-class counter instead of incrementing
        # here, so calling get_lr() more than once per step cannot advance the
        # schedule. last_epoch is 0 after the initial step, hence the +1.
        self.step_num = self.last_epoch + 1
        step = float(self.step_num)
        # Plain Python floats: optimizers expect a float lr, not a 0-dim tensor.
        lr = self.d_model ** -0.5 * min(step ** -0.5, step * self.warmup_steps ** -1.5)
        return [lr for _ in self.optimizer.param_groups]

# Scheduler (not a number) that drives `optimizer`; train_step calls its step().
learning_rate = CustomSchedule(optimizer=optimizer, d_model=d_model)


In [9]:
import torch


class DummyDataLoader:
    """Stand-in data loader that yields random (input, target) token batches.

    Token ids are drawn from [1, vocab_size); the first five positions of
    every sequence are then set to 0 to simulate padding.
    """

    def __init__(self, num_batches=100, batch_size=64, input_seq_len=50, target_seq_len=50, input_vocab_size=10000, target_vocab_size=10000):
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.input_seq_len = input_seq_len
        self.target_seq_len = target_seq_len
        self.input_vocab_size = input_vocab_size
        self.target_vocab_size = target_vocab_size

    def __len__(self):
        """Number of batches produced per pass."""
        return self.num_batches

    def _random_tokens(self, seq_len, vocab_size):
        """Draw a (batch_size, seq_len) batch of ids and zero its first five columns."""
        tokens = torch.randint(low=1, high=vocab_size, size=(self.batch_size, seq_len))
        tokens[:, :5] = 0  # simulate padding with 0
        return tokens

    def __iter__(self):
        """Yield `num_batches` freshly sampled (inp, tar) pairs."""
        for _ in range(self.num_batches):
            inp = self._random_tokens(self.input_seq_len, self.input_vocab_size)
            tar = self._random_tokens(self.target_seq_len, self.target_vocab_size)
            yield inp, tar


train_dataloader = DummyDataLoader()
val_dataloader = DummyDataLoader(num_batches=10)


# Helper function to create masks
def create_masks(inp, tar, device):
    """Build the boolean attention masks for one batch (True = do not attend).

    Args:
        inp: Source token ids, (batch_size, seq_len_in); 0 marks padding.
        tar: Decoder input token ids, (batch_size, seq_len_out); 0 marks padding.
        device: Device on which the masks are placed.

    Returns:
        enc_padding_mask: (batch_size, 1, 1, seq_len_in), padding in `inp`.
        look_ahead_mask: (batch_size, 1, seq_len_out, seq_len_out), future
            positions combined with padding in `tar`.
        dec_padding_mask: (batch_size, 1, 1, seq_len_in), padding in `inp`,
            used by the decoder's attention over the encoder output.
    """
    def padding_mask(tokens):
        # (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
        return tokens.eq(0)[:, None, None, :].to(device)

    enc_padding_mask = padding_mask(inp)
    dec_padding_mask = padding_mask(inp)

    # Entry [i, j] is True when key position j lies after query position i.
    positions = torch.arange(tar.shape[1], device=device)
    future = positions[None, :] > positions[:, None]
    future = future[None, None, :, :]  # (1, 1, seq_len, seq_len)

    # A target position is blocked if it is in the future or is padding.
    look_ahead_mask = future | padding_mask(tar)

    return enc_padding_mask, look_ahead_mask, dec_padding_mask

# Pick the GPU when one is available, otherwise stay on the CPU, and move the
# model's parameters and buffers there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
transformer.to(device)


# 4. Define the training step function
def train_step(inp, tar):
    """Run one optimisation step on a batch and return its loss as a float.

    The decoder is fed `tar` without its last token and is trained to predict
    `tar` without its first token (teacher forcing).
    """
    decoder_input, decoder_target = tar[:, :-1], tar[:, 1:]

    # (enc_padding_mask, look_ahead_mask, dec_padding_mask)
    masks = create_masks(inp, decoder_input, device)

    optimizer.zero_grad()

    predictions, _ = transformer(inp, decoder_input, *masks)
    loss = loss_function(decoder_target.to(device), predictions)

    loss.backward()
    optimizer.step()
    learning_rate.step()  # advance the learning-rate schedule once per batch

    return loss.item()

# 5. Define the evaluation step function
def eval_step(inp, tar):
    """Compute the loss of one batch without tracking gradients.

    Uses the same shifted decoder input / target split as training.
    """
    decoder_input, decoder_target = tar[:, :-1], tar[:, 1:]

    # (enc_padding_mask, look_ahead_mask, dec_padding_mask)
    masks = create_masks(inp, decoder_input, device)

    with torch.no_grad():
        predictions, _ = transformer(inp, decoder_input, *masks)
        loss = loss_function(decoder_target.to(device), predictions)

    return loss.item()

# 6. Implement the main training loop
EPOCHS = 10 # Placeholder

for epoch in range(EPOCHS):
    # ---- Training phase: dropout active, weights updated per batch ----
    transformer.train()
    total_train_loss = 0
    for batch, (inp, tar) in enumerate(train_dataloader):
        inp = inp.to(device)
        tar = tar.to(device)
        batch_loss = train_step(inp, tar)
        total_train_loss += batch_loss

        # Progress report every 50 batches.
        if not batch % 50:
            print(f'Epoch {epoch+1} Batch {batch} Loss {batch_loss:.4f}')

    print(f'Epoch {epoch+1} Train Loss {total_train_loss / len(train_dataloader):.4f}')

    # ---- Evaluation phase: dropout disabled, no weight updates ----
    transformer.eval()
    total_eval_loss = 0
    for batch, (inp, tar) in enumerate(val_dataloader):
        inp = inp.to(device)
        tar = tar.to(device)
        batch_loss = eval_step(inp, tar)
        total_eval_loss += batch_loss

    print(f'Epoch {epoch+1} Eval Loss {total_eval_loss / len(val_dataloader):.4f}')

    # 7. Include functionality to save model checkpoints
    # Placeholder: only the model weights are saved, into the working
    # directory, after every 5th epoch. Replace with desired path and logic.
    if epoch % 5 == 4:
        torch.save(transformer.state_dict(), f'transformer_epoch_{epoch+1}.pt')
        print(f'Model checkpoint saved for epoch {epoch+1}')

Epoch 1 Batch 0 Loss 9.3952
Epoch 1 Batch 50 Loss 9.3708
Epoch 1 Train Loss 9.3624
Epoch 1 Eval Loss 9.3234
Epoch 2 Batch 0 Loss 9.3272
Epoch 2 Batch 50 Loss 9.2821
Epoch 2 Train Loss 9.2869
Epoch 2 Eval Loss 9.2560
Epoch 3 Batch 0 Loss 9.2604
Epoch 3 Batch 50 Loss 9.2415
Epoch 3 Train Loss 9.2425
Epoch 3 Eval Loss 9.2301
Epoch 4 Batch 0 Loss 9.2315
Epoch 4 Batch 50 Loss 9.2363
Epoch 4 Train Loss 9.2320
Epoch 4 Eval Loss 9.2280
Epoch 5 Batch 0 Loss 9.2343
Epoch 5 Batch 50 Loss 9.2361
Epoch 5 Train Loss 9.2332
Epoch 5 Eval Loss 9.2308
Model checkpoint saved for epoch 5
Epoch 6 Batch 0 Loss 9.2353
Epoch 6 Batch 50 Loss 9.2409
Epoch 6 Train Loss 9.2360
Epoch 6 Eval Loss 9.2323
Epoch 7 Batch 0 Loss 9.2427
Epoch 7 Batch 50 Loss 9.2346
Epoch 7 Train Loss 9.2380
Epoch 7 Eval Loss 9.2334
Epoch 8 Batch 0 Loss 9.2339
Epoch 8 Batch 50 Loss 9.2390
Epoch 8 Train Loss 9.2385
Epoch 8 Eval Loss 9.2341
Epoch 9 Batch 0 Loss 9.2399
Epoch 9 Batch 50 Loss 9.2348
Epoch 9 Train Loss 9.2386
Epoch 9 Eval Loss 