In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [2]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, mask=None):
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.d_k, dtype=torch.float32, device=Q.device)
        )
        if mask is not None:
            while mask.dim() < scores.dim():
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        return output, attn

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model //num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(self.d_k)
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # 1. Linear projections: (batch, seq_len, d_model) → (batch, seq_len, num_heads * d_k)
        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)

         # 2. Split into heads: (batch, seq_len, num_heads, d_k)

        Q = Q.view(batch_size , -1, self.num_heads, self.d_k).transpose(1,2)
        K = K.view(batch_size, -1, self.num_heads , self.d_k).transpose(1,2)
        V = V.view(batch_size, -1, self.num_heads , self.d_k).transpose(1,2)

        # 3. Apply attention on each head
        # Now Q, K, V: (batch, heads, seq_len, d_k)
        
        atten_output, atten_weights = self.attention(Q,K,V,mask = mask)

        # 4. Concatenate heads
        # (batch, heads, seq_len, d_k) → (batch, seq_len, d_model)
        
        concat =  atten_output.transpose(1,2).contiguous().view(batch_size , -1, self.d_model)

        # 5. Final linear projection
        
        output = self.W_o(concat)

        return output , atten_weights

In [6]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model , d_ff , dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
    def forward(self,x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

In [7]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model , d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    def forward(self, x, mask=None):
        attn_out, _ = self.attn(x,x,x,mask)
        x = self.norm1(self.dropout1(attn_out))

        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout2(ffn_out))
        return x

In [8]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_out, tgt_mask=None, memory_mask = None):
        _x, _ = self.self_attn(x,x,x,tgt_mask)
        x = self.norm1(x+self.dropout1(_x))

        _x, _ = self.cross_attn(x,enc_out, enc_out, memory_mask)
        x = self.norm2(x + self.dropout2(_x))

        ffn_out = self.ffn(x)
        x = self.norm3(x+self.dropout3(ffn_out))

        return x

In [9]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, d_ff, 
                 num_encoder_layers, num_decoder_layers, dropout=0.1):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([
            EncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        self.final_linear = nn.Linear(d_model, tgt_vocab_size)
    def encode(self, src, src_mask=None):
        x = self.pos_enc(self.src_embed(src))
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x
    def decode(self, tgt, memory, tgt_mask=None, memory_mask=None):
        x = self.pos_enc(self.tgt_embed(tgt))
        for layer in self.decoder_layers:
            x = layer(x, memory, tgt_mask, memory_mask)
        return x
    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        memory = self.encode(src, src_mask)
        out = self.decode(tgt, memory, tgt_mask, memory_mask)
        logits = self.final_linear(out)
        return logits

In [10]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

In [11]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm

In [12]:
dataset = load_dataset('wmt14','de-en',split='train[:1%]')
tokenizer = AutoTokenizer.from_pretrained('t5-small')
MAX_LEN = 64

In [13]:
print(dataset[0])

{'translation': {'de': 'Wiederaufnahme der Sitzungsperiode', 'en': 'Resumption of the session'}}


In [14]:
def preprocess(example):
    en_text = example['translation']['en']
    de_text = example['translation']['de']

    inputs = tokenizer(en_text, max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="pt")
    targets = tokenizer(de_text, max_length=MAX_LEN, padding="max_length", truncation=True, return_tensors="pt")

    return {
        'src_input_ids': inputs.input_ids.squeeze(0),
        'tgt_input_ids': targets.input_ids.squeeze(0)
    }

processed_dataset = dataset.map(preprocess,remove_columns=dataset.column_names)

In [15]:
processed_dataset.set_format(type="torch", columns=["src_input_ids", "tgt_input_ids"])

In [16]:
batch_size = 32
def collate_fn(batch):
    src = torch.stack([item['src_input_ids'] for item in batch])
    tgt = torch.stack([item['tgt_input_ids'] for item in batch])
    return src,tgt

train_loader = DataLoader(processed_dataset, batch_size=batch_size, shuffle=True, collate_fn = collate_fn)

In [17]:
def create_src_mask(src):
    mask = (src != 0).unsqueeze(1).unsqueeze(2)
    return mask
def create_tgt_mask(tgt):
    batch_size, tgt_len = tgt.shape
    pad_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
    subsequent_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
    subsequent_mask = subsequent_mask.unsqueeze(0).unsqueeze(1) 
    return pad_mask & subsequent_mask

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Transformer(
    src_vocab_size=tokenizer.vocab_size,
    tgt_vocab_size=tokenizer.vocab_size,
    d_model=512,
    num_heads=8,
    d_ff=2048,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dropout=0.1
).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

In [19]:
EPOCHS = 5

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for src, tgt in loop:
        src = src.to(device)
        tgt = tgt.to(device)
        
        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        src_mask = create_src_mask(src).to(device)
        tgt_mask = create_tgt_mask(tgt_input).to(device)

        logits = model(src, tgt_input, src_mask=src_mask, tgt_mask=tgt_mask)

        logits = logits.reshape(-1, logits.size(-1))
        tgt_output = tgt_output.reshape(-1)

        loss = criterion(logits, tgt_output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    print(f"Epoch {epoch+1} finished with loss: {epoch_loss / len(train_loader):.4f}")

Epoch 1: 100%|██████████████████████████████████████████████████████████| 1409/1409 [06:33<00:00,  3.58it/s, loss=4.06]


Epoch 1 finished with loss: 4.8163


Epoch 2: 100%|███████████████████████████████████████████████████████████| 1409/1409 [06:34<00:00,  3.57it/s, loss=3.7]


Epoch 2 finished with loss: 3.7411


Epoch 3: 100%|██████████████████████████████████████████████████████████| 1409/1409 [06:34<00:00,  3.57it/s, loss=3.09]


Epoch 3 finished with loss: 3.3754


Epoch 4: 100%|██████████████████████████████████████████████████████████| 1409/1409 [06:35<00:00,  3.57it/s, loss=3.32]


Epoch 4 finished with loss: 3.1504


Epoch 5: 100%|██████████████████████████████████████████████████████████| 1409/1409 [06:34<00:00,  3.57it/s, loss=2.91]

Epoch 5 finished with loss: 2.9790



