In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import math
import time
import os

In [37]:
class Config:
    """Bundle of settings for the digit-copy Transformer demo.

    Holds the compute device, the toy vocabulary (three special tokens plus
    the digits '0'-'9'), dataset sizing, optimisation hyperparameters, model
    dimensions and the checkpoint path.
    """

    def __init__(self):
        # Use the GPU whenever PyTorch reports one, otherwise fall back to CPU.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # --- data parameters ---
        self.pad_idx, self.sos_idx, self.eos_idx = 0, 1, 2
        special_tokens = {
            '<pad>': self.pad_idx,
            '<sos>': self.sos_idx,
            '<eos>': self.eos_idx,
        }
        # Digits follow the special tokens: '0' -> 3, ..., '9' -> 12.
        digit_tokens = {str(digit): digit + 3 for digit in range(10)}
        self.dummy_vocab = {**special_tokens, **digit_tokens}
        self.id_to_token = {index: token for token, index in self.dummy_vocab.items()}
        self.dummy_vocab_size = len(self.dummy_vocab)
        self.num_samples = 10000
        self.max_sequence_length = 20  # counts <sos>, <eos> and padding

        # --- optimisation parameters ---
        self.batch_size = 64
        self.learning_rate = 0.0001
        self.num_epochs = 10
        self.clip_grad_norm = 1.0  # max gradient norm used for clipping

        # --- model dimensions ---
        self.d_model = 512
        self.num_heads = 8
        self.num_layers = 3
        self.d_ff = 2048
        self.dropout_rate = 0.1

        # --- checkpointing ---
        self.model_save_path = 'best_transformer_dummy_model.pt'

In [38]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position signals to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold sines and
    odd feature columns hold cosines of ``position * div_term``.

    Args:
        d_model: Size of the feature dimension. Both even and odd values are
            supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        # Registered as a buffer: follows .to(device) and is saved in the state dict.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed along dimension 1 as the sequence axis and must
        broadcast against a ``(1, seq_len, d_model)`` slice of the table.
        """
        return x + self.pe[:, :x.size(1)]

In [39]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected with their own linear layer,
    split into ``num_heads`` heads of width ``d_k = d_model // num_heads``,
    attended independently, then merged and passed through an output
    projection.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.d_model = d_model

        # Input projections for queries, keys, values, plus the output projection.
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Positions where ``mask == 0`` receive a score of -1e9 before the
        softmax, which drives their weight to (almost) zero.

        Returns:
            Tuple ``(attended_values, attention_weights)``.
        """
        scale = math.sqrt(self.d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        if mask is not None:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, V), attention_weights

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        per_head = projected.view(batch_size, -1, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention.

        Args:
            Q, K, V: Tensors whose first dimension is the batch and whose last
                dimension is ``d_model``.
            mask: Optional tensor broadcastable to the score tensor
                ``(batch, heads, q_len, k_len)``; zeros mark blocked positions.

        Returns:
            Tuple ``(output, attention_weights)`` where ``output`` has shape
            ``(batch, q_len, d_model)``.
        """
        batch_size = Q.size(0)
        query_heads = self._split_heads(self.wq(Q), batch_size)
        key_heads = self._split_heads(self.wk(K), batch_size)
        value_heads = self._split_heads(self.wv(V), batch_size)

        attended, attention_weights = self.scaled_dot_product_attention(
            query_heads, key_heads, value_heads, mask
        )

        # Undo the head split: (batch, heads, seq, d_k) -> (batch, seq, d_model).
        merged = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.wo(merged), attention_weights

In [40]:
class FeedForwardNetwork(nn.Module):
    """Two-layer position-wise MLP: Linear -> ReLU -> Linear.

    Expands the last dimension from ``d_model`` to ``d_ff`` and projects it
    back to ``d_model``.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the expansion, non-linearity and projection to ``x``."""
        hidden = self.linear1(x)
        activated = self.relu(hidden)
        return self.linear2(activated)

In [41]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalisation.

    Dropout is applied to the sublayer output before it is added to the
    residual input.
    """

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, sublayer_output):
        """Return ``LayerNorm(x + Dropout(sublayer_output))``."""
        residual_sum = x + self.dropout(sublayer_output)
        return self.norm(residual_sum)

In [42]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each wrapped in AddNorm."""

    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForwardNetwork(d_model, d_ff)
        self.add_norm1 = AddNorm(d_model, dropout_rate)
        self.add_norm2 = AddNorm(d_model, dropout_rate)

    def forward(self, x, src_mask):
        """Run ``x`` through the block.

        Args:
            x: Input passed as query, key and value to the self-attention.
            src_mask: Mask forwarded to the attention; zeros mark blocked keys.

        Returns:
            The transformed tensor after both residual/normalisation steps.
        """
        # The attention weights are not needed here, only the attended values.
        attended = self.self_attention(x, x, x, mask=src_mask)[0]
        after_attention = self.add_norm1(x, attended)
        transformed = self.feed_forward(after_attention)
        return self.add_norm2(after_attention, transformed)


In [43]:
class Encoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of EncoderLayer blocks."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len=5000)
        encoder_blocks = []
        for _ in range(num_layers):
            encoder_blocks.append(EncoderLayer(d_model, num_heads, d_ff, dropout_rate))
        self.layers = nn.ModuleList(encoder_blocks)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, src, src_mask):
        """Encode token ids ``src``.

        Args:
            src: Token-id tensor fed to the embedding layer.
            src_mask: Mask handed unchanged to every encoder layer.

        Returns:
            Output of the final encoder layer.
        """
        hidden = self.dropout(self.positional_encoding(self.embedding(src)))
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return hidden

In [44]:
# NOTE(review): `Encoder` is also defined in the previous notebook cell. This
# later definition is the one bound to the name after a top-to-bottom run.
class Encoder(nn.Module):
    """Embeds source tokens, adds positional encodings and applies the encoder stack."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len=5000)
        stack = (
            EncoderLayer(d_model, num_heads, d_ff, dropout_rate)
            for _ in range(num_layers)
        )
        self.layers = nn.ModuleList(stack)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, src, src_mask):
        """Return the encoder representation of ``src``.

        ``src`` is a tensor of token ids; ``src_mask`` is passed as-is to
        each layer's self-attention.
        """
        embedded = self.embedding(src)
        positioned = self.positional_encoding(embedded)
        state = self.dropout(positioned)
        for encoder_layer in self.layers:
            state = encoder_layer(state, src_mask)
        return state

class DecoderLayer(nn.Module):
    """One decoder block.

    Applies masked self-attention, attention over the encoder output, and a
    feed-forward network, each followed by its own AddNorm step.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        self.masked_self_attention = MultiHeadAttention(d_model, num_heads)
        self.encoder_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForwardNetwork(d_model, d_ff)
        self.add_norm1 = AddNorm(d_model, dropout_rate)
        self.add_norm2 = AddNorm(d_model, dropout_rate)
        self.add_norm3 = AddNorm(d_model, dropout_rate)

    def forward(self, tgt, encoder_output, tgt_mask, src_mask):
        """Run the target states through the block.

        Args:
            tgt: Decoder-side hidden states.
            encoder_output: Encoder states used as keys and values for the
                second attention.
            tgt_mask: Mask for the self-attention over ``tgt``.
            src_mask: Mask for the attention over ``encoder_output``.

        Returns:
            Updated decoder hidden states.
        """
        # 1) self-attention over the target, restricted by tgt_mask
        self_attended, _ = self.masked_self_attention(tgt, tgt, tgt, mask=tgt_mask)
        state = self.add_norm1(tgt, self_attended)

        # 2) attention from the target onto the encoder output
        cross_attended, _ = self.encoder_attention(
            state, encoder_output, encoder_output, mask=src_mask
        )
        state = self.add_norm2(state, cross_attended)

        # 3) position-wise feed-forward
        return self.add_norm3(state, self.feed_forward(state))

class Decoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of DecoderLayer blocks."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len=5000)
        decoder_blocks = []
        for _ in range(num_layers):
            decoder_blocks.append(DecoderLayer(d_model, num_heads, d_ff, dropout_rate))
        self.layers = nn.ModuleList(decoder_blocks)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, tgt, encoder_output, tgt_mask, src_mask):
        """Decode token ids ``tgt`` against ``encoder_output``.

        Args:
            tgt: Target token-id tensor fed to the embedding layer.
            encoder_output: Encoder states passed to every decoder layer.
            tgt_mask: Mask for each layer's self-attention.
            src_mask: Mask for each layer's attention over the encoder output.

        Returns:
            Output of the final decoder layer (before any vocabulary projection).
        """
        hidden = self.dropout(self.positional_encoding(self.embedding(tgt)))
        for block in self.layers:
            hidden = block(hidden, encoder_output, tgt_mask, src_mask)
        return hidden

In [45]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a linear projection onto the target vocabulary.

    Args:
        src_vocab_size: Number of source token ids.
        tgt_vocab_size: Number of target token ids; also the width of the
            output logits.
        d_model: Hidden size shared by encoder and decoder.
        num_layers: Number of layers in each of the encoder and decoder.
        num_heads: Attention heads per layer.
        d_ff: Inner width of the feed-forward networks.
        dropout_rate: Dropout probability forwarded to encoder and decoder.
        max_len: Accepted for interface compatibility but currently unused;
            it is not forwarded to the encoder or decoder.
        pad_idx: Token id treated as padding when building the attention
            masks. Defaults to 0, the value that used to be hard-coded.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_layers, num_heads, d_ff,
                 dropout_rate=0.1, max_len=5000, pad_idx=0):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size, d_model, num_layers, num_heads, d_ff, dropout_rate)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_layers, num_heads, d_ff, dropout_rate)
        self.output_linear = nn.Linear(d_model, tgt_vocab_size)
        self.pad_idx = pad_idx

    def generate_src_mask(self, src):
        """Return a boolean mask of shape (batch, 1, 1, src_len); True marks non-pad tokens."""
        src_mask = (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask

    def generate_tgt_mask(self, tgt):
        """Return a boolean mask of shape (batch, 1, tgt_len, tgt_len).

        Entry ``[b, 0, i, j]`` is True when key position ``j`` is not in the
        future of query position ``i`` (``j <= i``) and token ``j`` of sample
        ``b`` is not padding.
        """
        seq_len = tgt.size(1)
        # Lower-triangular (including the diagonal) = positions that may be attended to.
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=tgt.device)).bool()
        padding_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(2)
        # (seq, seq) & (batch, 1, 1, seq) broadcasts to (batch, 1, seq, seq).
        tgt_mask = causal_mask & padding_mask
        return tgt_mask

    def forward(self, src, tgt):
        """Compute target-vocabulary logits for every position of ``tgt``.

        Args:
            src: Source token ids, indexed as (batch, src_len).
            tgt: Target-side input token ids, indexed as (batch, tgt_len).

        Returns:
            Logits whose last dimension is ``tgt_vocab_size``.
        """
        src_mask = self.generate_src_mask(src)
        tgt_mask = self.generate_tgt_mask(tgt)
        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(tgt, encoder_output, tgt_mask, src_mask)
        output = self.output_linear(decoder_output)
        return output

In [46]:
def tokenize_sequence(sequence_str, vocab, add_sos_eos=False, sos_idx=1, eos_idx=2):
    """Map each character of ``sequence_str`` to its id in ``vocab``.

    Args:
        sequence_str: Iterable of characters to encode.
        vocab: Mapping from character to integer id.
        add_sos_eos: When truthy, wrap the ids in ``sos_idx`` ... ``eos_idx``.
        sos_idx: Id placed at the front when wrapping.
        eos_idx: Id placed at the end when wrapping.

    Returns:
        A new list of integer ids.

    Raises:
        KeyError: If a character is missing from ``vocab``.
    """
    body_ids = [vocab[symbol] for symbol in sequence_str]
    if not add_sos_eos:
        return body_ids
    return [sos_idx, *body_ids, eos_idx]

def detokenize_sequence(sequence_ids, id_to_token, pad_idx=0, sos_idx=1, eos_idx=2):
    """Turn a sequence of ids back into a string, dropping special tokens.

    Ids equal to ``pad_idx``, ``sos_idx`` or ``eos_idx`` are skipped; every
    other id is looked up in ``id_to_token`` and the tokens are concatenated.

    Raises:
        KeyError: If a non-special id is missing from ``id_to_token``.
    """
    special_ids = [pad_idx, sos_idx, eos_idx]
    pieces = []
    for token_id in sequence_ids:
        if token_id in special_ids:
            continue
        pieces.append(id_to_token[token_id])
    return ''.join(pieces)

In [47]:
class DummyTranslationDataset(Dataset):
    """Synthetic copy task: each sample is a random digit string paired with itself.

    Both source and target are tokenized with <sos>/<eos> markers and
    right-padded with ``pad_idx`` to ``max_len``. All samples are generated
    eagerly in the constructor.

    NOTE(review): a class with the same name is defined again in a later
    notebook cell; that later definition replaces this one at runtime.

    Args:
        num_samples: Number of pairs to generate. Coerced with ``int()`` so
            that fractional split sizes such as ``10000 * 0.8`` are accepted.
        max_len: Padded length of every sequence. The random digit count is
            drawn from ``[2, max_len - 2)``, so this range must be non-empty.
        vocab: Mapping from digit character to token id.
        id_to_token: Inverse mapping, stored on the instance.
        pad_idx, sos_idx, eos_idx: Special token ids.
    """

    def __init__(self, num_samples, max_len, vocab, id_to_token, pad_idx, sos_idx, eos_idx):
        # range() and __len__ both need an int; callers may pass a float count.
        self.num_samples = int(num_samples)
        self.max_len = max_len
        self.vocab = vocab
        self.id_to_token = id_to_token
        self.pad_idx = pad_idx
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.data = self._generate_data()

    def _generate_data(self):
        """Build the list of ``(src, tgt)`` LongTensor pairs, each of length ``max_len``."""
        data = []
        for _ in range(self.num_samples):
            src_len = torch.randint(2, self.max_len - 2, (1,)).item()
            src_sequence_str = ''.join(str(torch.randint(0, 10, (1,)).item()) for _ in range(src_len))
            # Copy task: the target is the source itself.
            tgt_sequence_str = src_sequence_str

            src_tokens = tokenize_sequence(src_sequence_str, self.vocab, add_sos_eos=True,
                                            sos_idx=self.sos_idx, eos_idx=self.eos_idx)
            tgt_tokens = tokenize_sequence(tgt_sequence_str, self.vocab, add_sos_eos=True,
                                            sos_idx=self.sos_idx, eos_idx=self.eos_idx)

            src_padded = src_tokens + [self.pad_idx] * (self.max_len - len(src_tokens))
            tgt_padded = tgt_tokens + [self.pad_idx] * (self.max_len - len(tgt_tokens))

            # Defensive truncation in case a token list ever exceeds max_len.
            src_padded = src_padded[:self.max_len]
            tgt_padded = tgt_padded[:self.max_len]

            data.append((torch.tensor(src_padded, dtype=torch.long),
                         torch.tensor(tgt_padded, dtype=torch.long)))
        return data

    def __len__(self):
        """Number of generated samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(src, tgt)`` tensor pair stored at position ``idx``."""
        return self.data[idx]

In [48]:
def train_epoch(model, dataloader, optimizer, criterion, config):
    """Train ``model`` for one pass over ``dataloader`` using teacher forcing.

    For every ``(src, tgt)`` batch the decoder input is ``tgt`` without its
    last token and the prediction target is ``tgt`` without its first token.
    Gradients are clipped to ``config.clip_grad_norm`` before each optimizer
    step, and progress is printed every 100 batches.

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    model.train()
    running_loss = 0
    step = 0
    for src, tgt in dataloader:
        src = src.to(config.device)
        tgt = tgt.to(config.device)

        # Shift by one position: predict token t+1 from tokens <= t.
        decoder_input, expected = tgt[:, :-1], tgt[:, 1:]

        optimizer.zero_grad()
        logits = model(src, decoder_input)

        # Flatten to (batch * seq, vocab) vs (batch * seq,) for the criterion.
        flat_logits = logits.contiguous().view(-1, logits.size(-1))
        flat_expected = expected.contiguous().view(-1)
        loss = criterion(flat_logits, flat_expected)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad_norm)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        if step % 100 == 0:
            print(f"  Batch {step}/{len(dataloader)}, Loss: {batch_loss:.4f}")
        step += 1

    return running_loss / len(dataloader)


In [49]:
def evaluate(model, dataloader, criterion, config):
    """Compute the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and no gradients are tracked. Inputs
    and targets are formed the same way as in training: the decoder sees
    ``tgt`` without its last token and is scored against ``tgt`` without its
    first token.

    Returns:
        Sum of batch losses divided by the number of batches.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for src, tgt in dataloader:
            src = src.to(config.device)
            tgt = tgt.to(config.device)

            decoder_input, expected = tgt[:, :-1], tgt[:, 1:]
            logits = model(src, decoder_input)

            flat_logits = logits.contiguous().view(-1, logits.size(-1))
            flat_expected = expected.contiguous().view(-1)
            batch_losses.append(criterion(flat_logits, flat_expected).item())
    # sum() starts from 0, matching a running total initialised to 0.
    return sum(batch_losses) / len(dataloader)

In [50]:
def generate_sequence(model, src_sequence, config):
    """Greedily decode an output string for ``src_sequence``.

    The source string is tokenized with <sos>/<eos>, then the target is grown
    one token at a time starting from <sos>: at each step the model is run on
    the tokens so far and the arg-max of the last position's logits is
    appended. Decoding stops once <eos> is produced or after
    ``config.max_sequence_length`` steps.

    Args:
        model: Callable as ``model(src_tensor, tgt_tensor)`` returning logits.
            It is put into eval mode.
        src_sequence: String whose characters are all in ``config.dummy_vocab``.
        config: Provides the vocabulary, special ids, device and step limit.

    Returns:
        The decoded string with pad/sos/eos tokens removed.
    """
    model.eval()
    device = config.device

    encoded_src = tokenize_sequence(src_sequence, config.dummy_vocab, add_sos_eos=True,
                                    sos_idx=config.sos_idx, eos_idx=config.eos_idx)
    src_tensor = torch.tensor(encoded_src, dtype=torch.long, device=device).unsqueeze(0)

    tgt_tokens = [config.sos_idx]

    with torch.no_grad():
        for _ in range(config.max_sequence_length):
            # Re-encode the whole prefix each step; there is no cache.
            tgt_tensor = torch.tensor(tgt_tokens, dtype=torch.long, device=device).unsqueeze(0)
            logits = model(src_tensor, tgt_tensor)
            predicted_id = logits[:, -1, :].argmax(dim=-1).item()
            tgt_tokens.append(predicted_id)
            if predicted_id == config.eos_idx:
                break

    return detokenize_sequence(tgt_tokens, config.id_to_token,
                               pad_idx=config.pad_idx, sos_idx=config.sos_idx, eos_idx=config.eos_idx)


In [51]:
class DummyTranslationDataset(Dataset):
    """In-memory copy-task dataset of random digit strings.

    Every sample is a ``(src, tgt)`` pair of LongTensors of length
    ``max_len``; both hold the same digit string wrapped in <sos>/<eos> and
    right-padded with ``pad_idx``. ``num_samples`` is truncated to an int.

    NOTE(review): this name is also defined in an earlier notebook cell; the
    definition here is the one in effect after a top-to-bottom run.
    """

    def __init__(self, num_samples, max_len, vocab, id_to_token, pad_idx, sos_idx, eos_idx):
        self.num_samples = int(num_samples)
        self.max_len = max_len
        self.vocab = vocab
        self.id_to_token = id_to_token
        self.pad_idx = pad_idx
        self.sos_idx = sos_idx
        self.eos_idx = eos_idx
        self.data = self._generate_data()

    def _encode_and_pad(self, text):
        """Tokenize ``text`` with <sos>/<eos> and pad/trim it to ``max_len`` ids."""
        token_ids = tokenize_sequence(text, self.vocab, add_sos_eos=True,
                                      sos_idx=self.sos_idx, eos_idx=self.eos_idx)
        padding = [self.pad_idx] * (self.max_len - len(token_ids))
        return (token_ids + padding)[:self.max_len]

    def _random_digit_string(self):
        """Draw a length in [2, max_len - 2) and then that many random digits."""
        length = torch.randint(2, self.max_len - 2, (1,)).item()
        digits = [str(torch.randint(0, 10, (1,)).item()) for _ in range(length)]
        return ''.join(digits)

    def _generate_data(self):
        """Create all ``num_samples`` source/target tensor pairs."""
        samples = []
        for _ in range(self.num_samples):
            source_text = self._random_digit_string()
            target_text = source_text  # the task is to reproduce the input
            source_ids = self._encode_and_pad(source_text)
            target_ids = self._encode_and_pad(target_text)
            samples.append((torch.tensor(source_ids, dtype=torch.long),
                            torch.tensor(target_ids, dtype=torch.long)))
        return samples

    def __len__(self):
        """Number of generated samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the stored ``(src, tgt)`` pair at ``idx``."""
        return self.data[idx]

In [52]:
def main():
    """Train the Transformer on the synthetic copy task and run a greedy-decoding demo.

    Steps:
      1. Build the config and generate separate train/validation datasets
         (80% / 20% of ``config.num_samples``).
      2. Train for ``config.num_epochs`` epochs, saving the weights whenever
         the validation loss improves.
      3. Reload the best checkpoint (if present) and print the decoded output
         for a handful of sample inputs.
    """
    config = Config()

    print(f"Using device: {config.device}")

    print("generating dummy dataset")
    # Use whole-number sample counts so the dataset never receives a float.
    train_count = int(config.num_samples * 0.8)
    val_count = int(config.num_samples) - train_count
    train_dataset = DummyTranslationDataset(
        train_count, config.max_sequence_length,
        config.dummy_vocab, config.id_to_token, config.pad_idx,
        config.sos_idx, config.eos_idx
    )
    val_dataset = DummyTranslationDataset(
        val_count, config.max_sequence_length,
        config.dummy_vocab, config.id_to_token, config.pad_idx,
        config.sos_idx, config.eos_idx
    )

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    use_pinned_memory = config.device.type == "cuda"
    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,
                                  pin_memory=use_pinned_memory)
    val_dataloader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False,
                                pin_memory=use_pinned_memory)
    print(f"Dataset generated. Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    model = Transformer(
        src_vocab_size=config.dummy_vocab_size,
        tgt_vocab_size=config.dummy_vocab_size,
        d_model=config.d_model,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        d_ff=config.d_ff,
        dropout_rate=config.dropout_rate,
        max_len=config.max_sequence_length
    ).to(config.device)

    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, betas=(0.9, 0.98), eps=1e-9)
    # Padding positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=config.pad_idx)

    print("starting training")
    best_val_loss = float('inf')

    for epoch in range(config.num_epochs):
        start_time = time.time()
        train_loss = train_epoch(model, train_dataloader, optimizer, criterion, config)
        val_loss = evaluate(model, val_dataloader, criterion, config)
        end_time = time.time()
        epoch_mins, epoch_secs = divmod(end_time - start_time, 60)

        print(f"Epoch: {epoch+1:02} | Time: {epoch_mins:.0f}m {epoch_secs:.0f}s")
        print(f"\tTrain Loss: {train_loss:.3f}")
        print(f"\tVal Loss: {val_loss:.3f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), config.model_save_path)
            print(f"\t--> Model saved to {config.model_save_path}: New best validation loss {best_val_loss:.3f}")

    print("training complete")


    print(" testing Inference")
    if os.path.exists(config.model_save_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model.load_state_dict(torch.load(config.model_save_path, map_location=config.device))
        print("Loaded best model for inference.")
    else:
        print("No saved model found. Inference might not be optimal.")

    model.eval()

    test_src_sequences = ["123", "98765", "001", "5", "42", "73"]
    for src_seq in test_src_sequences:
        generated_tgt = generate_sequence(model, src_seq, config)
        print(f"Source: '{src_seq}' | Generated Target: '{generated_tgt}'")

In [53]:
# Run the full demo: dataset generation, training with checkpointing, then sample decoding.
main()

Using device: cuda
generating dummy dataset
Dataset generated. Train samples: 8000, Val samples: 2000
starting training
  Batch 0/125, Loss: 2.8255
  Batch 100/125, Loss: 0.3366
Epoch: 01 | Time: 0m 9s
	Train Loss: 1.039
	Val Loss: 0.170
	--> Model saved to best_transformer_dummy_model.pt: New best validation loss 0.170
  Batch 0/125, Loss: 0.2305
  Batch 100/125, Loss: 0.0881
Epoch: 02 | Time: 0m 9s
	Train Loss: 0.123
	Val Loss: 0.022
	--> Model saved to best_transformer_dummy_model.pt: New best validation loss 0.022
  Batch 0/125, Loss: 0.0858
  Batch 100/125, Loss: 0.0874
Epoch: 03 | Time: 0m 10s
	Train Loss: 0.046
	Val Loss: 0.004
	--> Model saved to best_transformer_dummy_model.pt: New best validation loss 0.004
  Batch 0/125, Loss: 0.0618
  Batch 100/125, Loss: 0.0108
Epoch: 04 | Time: 0m 10s
	Train Loss: 0.031
	Val Loss: 0.018
  Batch 0/125, Loss: 0.0488
  Batch 100/125, Loss: 0.0154
Epoch: 05 | Time: 0m 10s
	Train Loss: 0.023
	Val Loss: 0.002
	--> Model saved to best_transforme