<a href="https://colab.research.google.com/github/Vanilaks/Autoencoder-for-semantic-communication/blob/main/Autoencoder_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook combines an autoencoder with a Transformer and trains the two jointly (unlike the paper, where the autoencoder is pretrained).
The autoencoder compresses each d_model = 512 length vector to a bottleneck with N = 64; in the code below the bottleneck layer is nn.Linear(M, 2 * N), so it emits 2N = 128 values per token. The evaluation is done with no noise.
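
To make the compression concrete, here is a minimal shape-bookkeeping sketch in plain Python. It is not the notebook's model code: the helper name `trace_shapes` is hypothetical, and the batch size and sequence length are arbitrary illustrative values. The sizes d_model = 512, M = 256 and N = 64 are the ones passed to `Autoencoder` in `build_transformer`.

```python
def trace_shapes(batch, seq_len, d_model=512, M=256, N=64):
    """Return the tensor shapes seen on the way through the autoencoder."""
    flat = batch * seq_len  # tokens are flattened before the autoencoder
    return {
        "encoder_output": (batch, seq_len, d_model),
        "flattened": (flat, d_model),
        "hidden": (flat, M),            # after the first Linear + ReLU
        "bottleneck": (flat, 2 * N),    # what Transformer.encode returns
        "reconstructed": (batch, seq_len, d_model),
    }


shapes = trace_shapes(batch=8, seq_len=350)  # illustrative sizes only
ratio = shapes["flattened"][1] / shapes["bottleneck"][1]
print(shapes["bottleneck"], ratio)
```

With these sizes each token goes from 512 values to 128, a 4x reduction per token.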

In [None]:
!pip3 install torchtext==0.15.2

Collecting torchtext==0.15.2
  Downloading torchtext-0.15.2-cp311-cp311-manylinux1_x86_64.whl.metadata (7.4 kB)
Collecting torch==2.0.1 (from torchtext==0.15.2)
  Downloading torch-2.0.1-cp311-cp311-manylinux1_x86_64.whl.metadata (24 kB)
Collecting torchdata==0.6.1 (from torchtext==0.15.2)
  Downloading torchdata-0.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.7.99 (from torch==2.0.1->torchtext==0.15.2)
  Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu11==11.7.99 (from torch==2.0.1->torchtext==0.15.2)
  Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cuda-cupti-cu11==11.7.101 (from torch==2.0.1->torchtext==0.15.2)
  Downloading nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu11==8.5.0.96 (from torch==2.0.1->torchtex

In [None]:
!pip3 install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
Downloading dill-0.3.8-py3-none-any.whl (116 kB)
Downloading fsspec-2024.12.0-py3-none-any.

In [None]:
!pip3 install --force-reinstall "numpy<2"

Collecting numpy<2
  Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.2.4
    Uninstalling numpy-2.2.4:
      Successfully uninstalled numpy-2.2.4
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.21.0+cu124 requires torch==2.6.0, but you have torch 2.0.1 which is incompatible.
Successfully installed numpy-1.26.4


In [None]:
!pip install sacrebleu



In [None]:
import torch
import torch.nn as nn
import math

class LayerNormalization(nn.Module):

    def __init__(self, features: int, eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features)) # alpha is a learnable parameter
        self.bias = nn.Parameter(torch.zeros(features)) # bias is a learnable parameter

    def forward(self, x):
        # x: (batch, seq_len, hidden_size)
        # Keep the dimension for broadcasting
        mean = x.mean(dim = -1, keepdim = True) # (batch, seq_len, 1)
        # Keep the dimension for broadcasting
        std = x.std(dim = -1, keepdim = True) # (batch, seq_len, 1)
        # eps is to prevent dividing by zero or when std is very small
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

class InputEmbeddings(nn.Module):

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) --> (batch, seq_len, d_model)
        # Multiply by sqrt(d_model) to scale the embeddings according to the paper
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # Create a matrix of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # Create a vector of shape (seq_len)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
        # Create a vector of shape (d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term) # sin(position / 10000 ** (2i / d_model))
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term) # cos(position / 10000 ** (2i / d_model))
        # Add a batch dimension to the positional encoding
        pe = pe.unsqueeze(0) # (1, seq_len, d_model)
        # Register the positional encoding as a buffer
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
        return self.dropout(x)

class ResidualConnection(nn.Module):

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        # Pre-norm residual: normalize, apply the sublayer, then add the input back
        return x + self.dropout(sublayer(self.norm(x)))

class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model # Embedding vector size
        self.h = h # Number of heads
        # Make sure d_model is divisible by h
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h # Dimension of vector seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False) # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False) # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False) # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False) # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Just apply the formula from the paper
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Write a very low value (indicating -inf) to the positions where mask == 0
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq_len, seq_len) # Apply softmax
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
        # return attention scores which can be used for visualization
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        query = self.w_q(q) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        key = self.w_k(k) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        value = self.w_v(v) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)

        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # Calculate attention
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.w_o(x)

class EncoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

class Encoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x

class Decoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

class ProjectionLayer(nn.Module):

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return self.proj(x)

class Autoencoder(nn.Module):
    def __init__(self, input_dim, M, N):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, M),
            nn.ReLU(),
            nn.Linear(M, 2 * N),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(2 * N, M),
            nn.ReLU(),
            nn.Linear(M, input_dim)
        )

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.decoder(x)

class Transformer(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer, autoencoder: Autoencoder) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
        self.autoencoder = autoencoder  # Autoencoder is integrated here

    def encode(self, src, src_mask):
        # (batch, seq_len, d_model)
        src = self.src_embed(src)
        src = self.src_pos(src)
        encoder_output = self.encoder(src, src_mask)

        # Reshape encoder output to match the input shape of the autoencoder
        # Assuming the encoder output has shape (batch, seq_len, d_model)
        # You need to flatten this into (batch * seq_len, d_model) to match the input of autoencoder

        batch_size, seq_len, d_model = encoder_output.shape
        encoder_output_flat = encoder_output.view(batch_size * seq_len, d_model)  # (batch * seq_len, d_model)

        # Pass the flattened encoder output through the autoencoder's encoder (bottleneck)
        autoencoded_output = self.autoencoder.encode(encoder_output_flat)  # (batch * seq_len, 2 * N)

        return autoencoded_output

    def decode(self, autoencoded_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        # (batch, seq_len, d_model)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)

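        # Decode the bottleneck back to d_model and reshape to (batch, src_seq_len, d_model).
        # The source length is inferred with -1 so the reshape also works when tgt is
        # shorter than the source, as in greedy decoding.
        batch_size, _, d_model = tgt.shape
        encoder_output_reconstr = self.autoencoder.decode(autoencoded_output)
        encoder_output = encoder_output_reconstr.view(batch_size, -1, d_model)  # (batch, src_seq_len, d_model)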

        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        # (batch, seq_len, vocab_size)
        return self.projection_layer(x)


def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positional encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)

    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    # Create the encoder and decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    # Create the autoencoder; it is trained jointly with the transformer
    # (loading pretrained weights is left disabled below)
    autoencoder = Autoencoder(input_dim=d_model, M=256, N=64)
    # autoencoder.load_state_dict(torch.load("autoencoder_awgn.pth"))
    # autoencoder.eval()  # Set the autoencoder to evaluation mode

    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer, autoencoder)

    # Initialize the parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
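
As a cross-check on `PositionalEncoding`, the sketch below rebuilds the same sinusoidal table in plain Python, without torch. The function name `positional_encoding` and the tiny sizes are illustrative assumptions. It relies on the identity exp(i * -ln(10000) / d_model) = 10000 ** (-i / d_model), which is what `div_term` computes for the even indices i.

```python
import math


def positional_encoding(seq_len, d_model):
    """Sinusoidal table of shape (seq_len, d_model) as nested lists."""
    pe = [[0.0] * d_model for _ in range(seq_len)]
    for pos in range(seq_len):
        for i in range(0, d_model, 2):
            # Same quantity as div_term in the class above
            div_term = math.exp(i * (-math.log(10000.0) / d_model))
            pe[pos][i] = math.sin(pos * div_term)          # even index
            if i + 1 < d_model:
                pe[pos][i + 1] = math.cos(pos * div_term)  # odd index
    return pe


pe = positional_encoding(seq_len=4, d_model=8)
print([round(v, 4) for v in pe[1][:4]])
```

Position 0 is all zeros at even indices and all ones at odd indices, and the wavelength grows with the feature index.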

In [None]:
'''
Train.py with special characters removed from the dataset
'''

from model import build_transformer
from dataset import BilingualDataset, causal_mask, clean_text
from config import get_config, get_weights_file_path, latest_weights_file_path

import torchtext.datasets as datasets
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR

import warnings
from tqdm import tqdm
import os
from pathlib import Path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# import torchmetrics
from torch.utils.tensorboard import SummaryWriter

import sys
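
The `causal_mask` helper imported above lives in `dataset.py`, which is not shown in this notebook. The following is a minimal sketch, assuming it builds the usual lower-triangular mask in which a target position may attend only to itself and earlier positions. The name `causal_mask_sketch` is hypothetical, and the real helper presumably returns a torch tensor rather than nested lists.

```python
def causal_mask_sketch(size):
    """Boolean (size x size) mask: True where query i may attend to key j."""
    return [[j <= i for j in range(size)] for i in range(size)]


mask = causal_mask_sketch(3)
for row in mask:
    # "x" marks an allowed position, "." a masked one
    print("".join("x" if allowed else "." for allowed in row))
```

In `MultiHeadAttentionBlock.attention`, positions where the mask equals 0 are filled with -1e9 before the softmax, so the masked entries receive near-zero attention weight.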

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import warnings
import pickle  # To save losses, BLEU scores, and CHRF scores
import sys
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sacrebleu import corpus_chrf

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')
    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break
        decoder_mask = None
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat([decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1)
        if next_word == eos_idx:
            break
    return decoder_input.squeeze(0)


import numpy as np

def calculate_snr(clean_signal, noise_signal):
    """
    Calculate the Signal-to-Noise Ratio (SNR).
    - clean_signal: The ground truth signal (target).
    - noise_signal: The predicted signal (model output).
    """
    # Reshape the signals to have the same number of dimensions
    clean_signal = clean_signal.flatten()
    noise_signal = noise_signal.flatten()

    signal_power = np.sum(clean_signal ** 2)
    noise_power = np.sum((clean_signal - noise_signal) ** 2)

    # Avoid division by zero
    if noise_power == 0:
        return float('inf')

    snr = 10 * np.log10(signal_power / noise_power)
    return snr


def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    model.eval()
    total_val_loss = 0
    count = 0
    bleu_scores = []
    chrf_scores = []
    snr_scores = []  # List to store SNR scores
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
    smooth_fn = SmoothingFunction().method1

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)
            encoder_mask = batch["encoder_mask"].to(device)
            label = batch["label"].to(device)
            decoder_input = batch["decoder_input"].to(device)
            decoder_mask = batch["decoder_mask"].to(device)

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
            proj_output = model.project(model.decode(model.encode(encoder_input, encoder_mask), encoder_mask, decoder_input, decoder_mask))
            val_loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            total_val_loss += val_loss.item()

            predicted_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
            target_text = batch["tgt_text"][0]

            # BLEU calculation
            reference = [target_text.split()]
            hypothesis = predicted_text.split()
            bleu = sentence_bleu(reference, hypothesis, smoothing_function=smooth_fn)
            bleu_scores.append(bleu)

            # CHRF calculation
            chrf = corpus_chrf([predicted_text], [[target_text]]).score
            chrf_scores.append(chrf)

            # # SNR calculation (Assuming 'encoder_input' is the clean signal and 'model_out' is the noisy signal)
            # snr = calculate_snr(encoder_input.cpu().numpy(), model_out.cpu().numpy())
            # snr_scores.append(snr)

            if count <= num_examples:
                source_text = batch["src_text"][0]
                print_msg('-' * 80)
                print_msg(f"SOURCE: {source_text}")
                print_msg(f"TARGET: {target_text}")
                print_msg(f"PREDICTED: {predicted_text}")
                print_msg(f"BLEU Score: {bleu:.4f}")
                print_msg(f"CHRF Score: {chrf:.4f}")
                # print_msg(f"SNR Score: {snr:.4f}")

        print_msg('-' * 80)

    avg_val_loss = total_val_loss / count if count > 0 else 0
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
    avg_chrf = sum(chrf_scores) / len(chrf_scores) if chrf_scores else 0

    print(f"Validation Loss: {avg_val_loss:.4f}, BLEU Score: {avg_bleu:.4f}, CHRF Score: {avg_chrf:.4f}")
    writer.add_scalar('Validation Loss', avg_val_loss, global_step)
    writer.add_scalar('Validation BLEU', avg_bleu, global_step)
    writer.add_scalar('Validation CHRF', avg_chrf, global_step)
    writer.flush()

    # snr_scores stays empty while the SNR calculation above is commented out
    return avg_val_loss, avg_bleu, avg_chrf, snr_scores

def get_all_sentences(ds, lang):
    for item in ds:
        yield clean_text(item['translation'][lang])

def get_or_build_tokenizer(config, ds, lang):
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer

def get_ds(config):
    # It only has the train split, so we divide it ourselves
    ds_raw = load_dataset(f"{config['datasource']}", f"{config['ds_lang_src']}-{config['ds_lang_tgt']}", split='train')

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Find the maximum tokenized length of the source and target sentences
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    # Keep 90% for training, 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

def get_model(config, vocab_src_len, vocab_tgt_len):
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
    return model

def train_model(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
    writer = SummaryWriter(config['experiment_name'])

    train_losses = []
    val_losses = []
    bleu_scores = []
    chrf_scores = []
    # snr_scores_all = []  # List to store SNR scores over all epochs
    global_step = 0

    for epoch in range(config['num_epochs']):
        model.train()
        total_train_loss = 0
        count = 0

        batch_iterator = tqdm(train_dataloader, desc=f"Epoch {epoch:02d}")
        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)
            decoder_input = batch['decoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)
            label = batch['label'].to(device)

            optimizer.zero_grad()
            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            count += 1
            batch_iterator.set_postfix({"loss": f"{loss.item():.4f}"})
            writer.add_scalar('Train Loss', loss.item(), global_step)
            writer.flush()
            global_step += 1

        avg_train_loss = total_train_loss / count if count > 0 else 0
        train_losses.append(avg_train_loss)

        # Run validation (snr_scores is an empty list while the SNR calculation is disabled)
        avg_val_loss, avg_bleu, avg_chrf, snr_scores = run_validation(
            model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device,
            lambda msg: batch_iterator.write(msg), global_step, writer
        )
        val_losses.append(avg_val_loss)
        bleu_scores.append(avg_bleu)
        chrf_scores.append(avg_chrf)
        # snr_scores_all.extend(snr_scores)  # Add SNR scores from this epoch

        # Save training metrics (SNR scores are currently excluded)
        with open('training_metrics.pkl', 'wb') as f:
            pickle.dump({
                'train_losses': train_losses,
                'val_losses': val_losses,
                'bleu_scores': bleu_scores,
                'chrf_scores': chrf_scores,
                # 'snr_scores': snr_scores_all  # Save all SNR scores
            }, f)

    print("Training complete. Metrics saved to training_metrics.pkl")


if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    config = get_config()
    config['num_epochs'] = 20
    train_model(config)
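
The `calculate_snr` helper defined above is not called in this run, since its call site in `run_validation` is commented out. For reference, here is a small worked example of the same formula, 10 * log10(signal power / noise power), in plain Python. The function name `snr_db` and the toy values are illustrative assumptions and are not taken from the experiment.

```python
import math


def snr_db(clean, estimate):
    """SNR in dB, treating (clean - estimate) as the noise."""
    signal_power = sum(c * c for c in clean)
    noise_power = sum((c - e) ** 2 for c, e in zip(clean, estimate))
    if noise_power == 0:
        return float("inf")  # identical signals: no noise at all
    return 10 * math.log10(signal_power / noise_power)


clean = [1.0, 2.0, 3.0]      # toy reference signal
estimate = [1.1, 1.9, 3.2]   # toy reconstruction
print(round(snr_db(clean, estimate), 2))
```

Here the signal power is 14 and the noise power is about 0.06, which gives roughly 23.7 dB.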


Using device: cuda
Max length of source sentence: 204
Max length of target sentence: 204


Epoch 00: 100%|██████████| 421/421 [01:04<00:00,  6.49it/s, loss=5.7686]


--------------------------------------------------------------------------------
SOURCE: the cook threw a fryingpan after her as she went out but it just missed her
TARGET: the cook threw a fryingpan after her as she went out but it just missed her
PREDICTED: she a a a
BLEU Score: 0.0048
CHRF Score: 3.0783
--------------------------------------------------------------------------------
SOURCE: i see said the queen who had meanwhile been examining the roses
TARGET: i see said the queen who had meanwhile been examining the roses
PREDICTED: i said the
BLEU Score: 0.0132
CHRF Score: 11.7025
--------------------------------------------------------------------------------
Validation Loss: 5.3514, BLEU Score: 0.0189, CHRF Score: 10.0628


Epoch 01: 100%|██████████| 421/421 [01:03<00:00,  6.58it/s, loss=5.0971]


--------------------------------------------------------------------------------
SOURCE: pray what is the reason of that
TARGET: pray what is the reason of that
PREDICTED: what you know the king
BLEU Score: 0.0428
CHRF Score: 16.0237
--------------------------------------------------------------------------------
SOURCE: consider your verdict he said to the jury in a low trembling voice
TARGET: consider your verdict he said to the jury in a low trembling voice
PREDICTED: i dont dont said the mock turtle
BLEU Score: 0.0167
CHRF Score: 14.3277
--------------------------------------------------------------------------------
Validation Loss: 4.9145, BLEU Score: 0.0305, CHRF Score: 14.2032


Epoch 02: 100%|██████████| 421/421 [01:03<00:00,  6.59it/s, loss=4.2140]


--------------------------------------------------------------------------------
SOURCE: on this the white rabbit blew three blasts on the trumpet and then unrolled the parchment scroll and read as follows
TARGET: on this the white rabbit blew three blasts on the trumpet and then unrolled the parchment scroll and read as follows
PREDICTED: on the queen on the gryphon and the gryphon and the gryphon and the gryphon
BLEU Score: 0.0256
CHRF Score: 17.2253
--------------------------------------------------------------------------------
SOURCE: who is it directed to said one of the jurymen
TARGET: who is it directed to said one of the jurymen
PREDICTED: it said the mock turtle to the gryphon
BLEU Score: 0.0306
CHRF Score: 19.3469
--------------------------------------------------------------------------------
Validation Loss: 4.5723, BLEU Score: 0.0396, CHRF Score: 18.1373


Epoch 03: 100%|██████████| 421/421 [01:04<00:00,  6.57it/s, loss=3.4513]


--------------------------------------------------------------------------------
SOURCE: you shant be beheaded said alice and she put them into a large flowerpot that stood near
TARGET: you shant be beheaded said alice and she put them into a large flowerpot that stood near
PREDICTED: you be said alice and she that
BLEU Score: 0.1133
CHRF Score: 25.9497
--------------------------------------------------------------------------------
SOURCE: how queer it seems alice said to herself to be going messages for a rabbit
TARGET: how queer it seems alice said to herself to be going messages for a rabbit
PREDICTED: how it would be said alice to herself for a little
BLEU Score: 0.0480
CHRF Score: 29.6382
--------------------------------------------------------------------------------
Validation Loss: 4.1072, BLEU Score: 0.0773, CHRF Score: 24.7615


Epoch 04: 100%|██████████| 421/421 [01:04<00:00,  6.53it/s, loss=3.7420]


--------------------------------------------------------------------------------
SOURCE: what will become of me
TARGET: what will become of me
PREDICTED: what are you are me
BLEU Score: 0.0639
CHRF Score: 16.0080
--------------------------------------------------------------------------------
SOURCE: i told you butter wouldnt suit the works he added looking angrily at the march hare
TARGET: i told you butter wouldnt suit the works he added looking angrily at the march hare
PREDICTED: i must you were say added the he said the gryphon
BLEU Score: 0.0187
CHRF Score: 16.2420
--------------------------------------------------------------------------------
Validation Loss: 3.6261, BLEU Score: 0.1358, CHRF Score: 31.7090


Epoch 05: 100%|██████████| 421/421 [01:03<00:00,  6.60it/s, loss=3.5859]


--------------------------------------------------------------------------------
SOURCE: as she said this she looked up and there was the cat again sitting on a branch of a tree
TARGET: as she said this she looked up and there was the cat again sitting on a branch of a tree
PREDICTED: as she said this she looked down and was the cat was too much a a of a
BLEU Score: 0.3416
CHRF Score: 47.0772
--------------------------------------------------------------------------------
SOURCE: its always six oclock now
TARGET: its always six oclock now
PREDICTED: its no use now
BLEU Score: 0.0744
CHRF Score: 12.6837
--------------------------------------------------------------------------------
Validation Loss: 3.2425, BLEU Score: 0.2030, CHRF Score: 39.5127


Epoch 06: 100%|██████████| 421/421 [01:03<00:00,  6.58it/s, loss=2.6655]


--------------------------------------------------------------------------------
SOURCE: she said it to the knave of hearts who only bowed and smiled in reply
TARGET: she said it to the knave of hearts who only bowed and smiled in reply
PREDICTED: she said to it of the knave who who only and in another moment
BLEU Score: 0.0567
CHRF Score: 38.6091
--------------------------------------------------------------------------------
SOURCE: its always six oclock now
TARGET: its always six oclock now
PREDICTED: its sure im afraid now
BLEU Score: 0.0639
CHRF Score: 14.2771
--------------------------------------------------------------------------------
Validation Loss: 2.9204, BLEU Score: 0.2611, CHRF Score: 44.6021


Epoch 07: 100%|██████████| 421/421 [01:04<00:00,  6.55it/s, loss=2.8927]


--------------------------------------------------------------------------------
SOURCE: how puzzling all these changes are
TARGET: how puzzling all these changes are
PREDICTED: how all all these are are
BLEU Score: 0.1027
CHRF Score: 31.2780
--------------------------------------------------------------------------------
SOURCE: i didnt the march hare interrupted in a great hurry
TARGET: i didnt the march hare interrupted in a great hurry
PREDICTED: i didnt march the march hare in a great hurry
BLEU Score: 0.4234
CHRF Score: 65.2872
--------------------------------------------------------------------------------
Validation Loss: 2.6835, BLEU Score: 0.3319, CHRF Score: 50.6358


Epoch 08: 100%|██████████| 421/421 [01:04<00:00,  6.57it/s, loss=2.7538]


--------------------------------------------------------------------------------
SOURCE: swim after them screamed the gryphon
TARGET: swim after them screamed the gryphon
PREDICTED: after them them the gryphon
BLEU Score: 0.1316
CHRF Score: 52.5814
--------------------------------------------------------------------------------
SOURCE: the caterpillar and alice looked at each other for some time in silence at last the caterpillar took the hookah out of its mouth and addressed her in a languid sleepy voice
TARGET: the caterpillar and alice looked at each other for some time in silence at last the caterpillar took the hookah out of its mouth and addressed her in a languid sleepy voice
PREDICTED: the caterpillar and alice looked at once for some time in silence in at last the caterpillar took the caterpillar took its mouth and the in a voice
BLEU Score: 0.4557
CHRF Score: 64.3933
--------------------------------------------------------------------------------
Validation Loss: 2.4821, BLEU

Epoch 09: 100%|██████████| 421/421 [01:04<00:00,  6.58it/s, loss=1.5716]


--------------------------------------------------------------------------------
SOURCE: everybody looked at alice
TARGET: everybody looked at alice
PREDICTED: everybody looked at alice
BLEU Score: 1.0000
CHRF Score: 100.0000
--------------------------------------------------------------------------------
SOURCE: the executioners argument was that you couldnt cut off a head unless there was a body to cut it off from that he had never had to do such a thing before and he wasnt going to begin at his time of life
TARGET: the executioners argument was that you couldnt cut off a head unless there was a body to cut it off from that he had never had to do such a thing before and he wasnt going to begin at his time of life
PREDICTED: the argument was that you couldnt answer off a head there was a large cat to explain it had off that he had never to do that he was going to do at his eye and and he was at his head
BLEU Score: 0.2337
CHRF Score: 47.8594
-------------------------------------------

Epoch 10: 100%|██████████| 421/421 [01:03<00:00,  6.61it/s, loss=1.9075]


--------------------------------------------------------------------------------
SOURCE: if they had any sense theyd take the roof off
TARGET: if they had any sense theyd take the roof off
PREDICTED: if they had any rate take the sky off
BLEU Score: 0.3301
CHRF Score: 49.0229
--------------------------------------------------------------------------------
SOURCE: swim after them screamed the gryphon
TARGET: swim after them screamed the gryphon
PREDICTED: after them them screamed the gryphon
BLEU Score: 0.5774
CHRF Score: 85.8108
--------------------------------------------------------------------------------
Validation Loss: 2.1645, BLEU Score: 0.4676, CHRF Score: 63.9841


Epoch 11: 100%|██████████| 421/421 [01:03<00:00,  6.58it/s, loss=1.3685]


--------------------------------------------------------------------------------
SOURCE: where are you
TARGET: where are you
PREDICTED: where are you
BLEU Score: 0.5623
CHRF Score: 100.0000
--------------------------------------------------------------------------------
SOURCE: they were indeed a queerlooking party that assembled on the bankthe birds with draggled feathers the animals with their fur clinging close to them and all dripping wet cross and uncomfortable
TARGET: they were indeed a queerlooking party that assembled on the bankthe birds with draggled feathers the animals with their fur clinging close to them and all dripping wet cross and uncomfortable
PREDICTED: they were indeed a party that on the birds with with the birds with their slates and and all all all
BLEU Score: 0.1276
CHRF Score: 33.9466
--------------------------------------------------------------------------------
Validation Loss: 2.1055, BLEU Score: 0.4906, CHRF Score: 66.0926


Epoch 12: 100%|██████████| 421/421 [01:03<00:00,  6.58it/s, loss=1.3149]


--------------------------------------------------------------------------------
SOURCE: youre a serpent and theres no use denying it
TARGET: youre a serpent and theres no use denying it
PREDICTED: youre a serpent and theres no use it
BLEU Score: 0.7673
CHRF Score: 77.7180
--------------------------------------------------------------------------------
SOURCE: she got up and went to the table to measure herself by it and found that as nearly as she could guess she was now about two feet high and was going on shrinking rapidly she soon found out that the cause of this was the fan she was holding and she dropped it hastily just in time to avoid shrinking away altogether
TARGET: she got up and went to the table to measure herself by it and found that as nearly as she could guess she was now about two feet high and was going on shrinking rapidly she soon found out that the cause of this was the fan she was holding and she dropped it hastily just in time to avoid shrinking away altogether
P

Epoch 13: 100%|██████████| 421/421 [01:04<00:00,  6.54it/s, loss=1.4829]


--------------------------------------------------------------------------------
SOURCE: nobody moved
TARGET: nobody moved
PREDICTED: nobody moved
BLEU Score: 0.3162
CHRF Score: 100.0000
--------------------------------------------------------------------------------
SOURCE: just then her head struck against the roof of the hall in fact she was now more than nine feet high and she at once took up the little golden key and hurried off to the garden door
TARGET: just then her head struck against the roof of the hall in fact she was now more than nine feet high and she at once took up the little golden key and hurried off to the garden door
PREDICTED: just then her head finished the roof of the hall in fact she was now was more than feet high she high she took up and took the little golden key and the door
BLEU Score: 0.4478
CHRF Score: 60.6708
--------------------------------------------------------------------------------
Validation Loss: 1.9812, BLEU Score: 0.5287, CHRF Score: 69.0091


Epoch 14: 100%|██████████| 421/421 [01:04<00:00,  6.53it/s, loss=1.7022]


--------------------------------------------------------------------------------
SOURCE: sounds of more broken glass
TARGET: sounds of more broken glass
PREDICTED: stuff of more glass glass
BLEU Score: 0.1257
CHRF Score: 30.8195
--------------------------------------------------------------------------------
SOURCE: yes we went to school in the sea though you maynt believe it
TARGET: yes we went to school in the sea though you maynt believe it
PREDICTED: yes we went to school in the sea you believe it
BLEU Score: 0.6335
CHRF Score: 69.3567
--------------------------------------------------------------------------------
Validation Loss: 1.9292, BLEU Score: 0.5424, CHRF Score: 69.0447


Epoch 15: 100%|██████████| 421/421 [01:04<00:00,  6.52it/s, loss=1.3189]


--------------------------------------------------------------------------------
SOURCE: i dare say you never even spoke to time
TARGET: i dare say you never even spoke to time
PREDICTED: i dare say you never even to time
BLEU Score: 0.6753
CHRF Score: 75.8853
--------------------------------------------------------------------------------
SOURCE: they all can said the duchess and most of em do
TARGET: they all can said the duchess and most of em do
PREDICTED: they all can said the duchess and most chorus of tea
BLEU Score: 0.6989
CHRF Score: 82.4413
--------------------------------------------------------------------------------
Validation Loss: 1.9689, BLEU Score: 0.5446, CHRF Score: 70.2681


Epoch 16: 100%|██████████| 421/421 [01:05<00:00,  6.46it/s, loss=1.6136]


--------------------------------------------------------------------------------
SOURCE: but there seemed to be no chance of this so she began looking at everything about her to pass away the time
TARGET: but there seemed to be no chance of this so she began looking at everything about her to pass away the time
PREDICTED: but there seemed to be no chance of this so she began looking at everything her everything to the time
BLEU Score: 0.7115
CHRF Score: 78.6957
--------------------------------------------------------------------------------
SOURCE: what sort of people live about here
TARGET: what sort of people live about here
PREDICTED: what sort of people live about here
BLEU Score: 1.0000
CHRF Score: 100.0000
--------------------------------------------------------------------------------
Validation Loss: 1.8793, BLEU Score: 0.5837, CHRF Score: 72.1770


Epoch 17: 100%|██████████| 421/421 [01:04<00:00,  6.54it/s, loss=1.2481]


--------------------------------------------------------------------------------
SOURCE: just then her head struck against the roof of the hall in fact she was now more than nine feet high and she at once took up the little golden key and hurried off to the garden door
TARGET: just then her head struck against the roof of the hall in fact she was now more than nine feet high and she at once took up the little golden key and hurried off to the garden door
PREDICTED: just then her head carrying against the roof of the hall in fact she was now more than saying high and she found she took up at the little golden key and hurried off
BLEU Score: 0.5825
CHRF Score: 72.6763
--------------------------------------------------------------------------------
SOURCE: nobody moved
TARGET: nobody moved
PREDICTED: nobody moved moved
BLEU Score: 0.2403
CHRF Score: 89.2257
--------------------------------------------------------------------------------
Validation Loss: 1.8686, BLEU Score: 0.5786, CHRF Sc

Epoch 18: 100%|██████████| 421/421 [01:04<00:00,  6.52it/s, loss=1.2530]


--------------------------------------------------------------------------------
SOURCE: i told you butter wouldnt suit the works he added looking angrily at the march hare
TARGET: i told you butter wouldnt suit the works he added looking angrily at the march hare
PREDICTED: i told you wouldnt wouldnt wouldnt the he added looking at the march hare
BLEU Score: 0.2981
CHRF Score: 57.2701
--------------------------------------------------------------------------------
SOURCE: thinking again the duchess asked with another dig of her sharp little chin
TARGET: thinking again the duchess asked with another dig of her sharp little chin
PREDICTED: thinking again the duchess asked another of her flamingo of little chin
BLEU Score: 0.3839
CHRF Score: 68.4907
--------------------------------------------------------------------------------
Validation Loss: 1.7878, BLEU Score: 0.6081, CHRF Score: 74.6599


Epoch 19: 100%|██████████| 421/421 [01:04<00:00,  6.53it/s, loss=1.2638]


--------------------------------------------------------------------------------
SOURCE: if they had any sense theyd take the roof off
TARGET: if they had any sense theyd take the roof off
PREDICTED: if they had any rate take less the roof off
BLEU Score: 0.3928
CHRF Score: 58.6056
--------------------------------------------------------------------------------
SOURCE: alice had been looking over his shoulder with some curiosity
TARGET: alice had been looking over his shoulder with some curiosity
PREDICTED: alice had been over his notebook somebody with some curiosity
BLEU Score: 0.1996
CHRF Score: 61.8309
--------------------------------------------------------------------------------
Validation Loss: 1.7829, BLEU Score: 0.5709, CHRF Score: 71.8806
Training complete. Metrics saved to training_metrics.pkl
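
A note on reading the per-sentence BLEU values above: several exact matches score below 1.0 ("where are you" gives 0.5623, "nobody moved" gives 0.3162), while the four-word exact match "everybody looked at alice" gives 1.0000. This pattern is consistent with a 4-gram BLEU in which n-gram orders with zero matches are smoothed with a small constant, so sentences shorter than four tokens can never reach 1.0. The sketch below is not the notebook's metric code. The function name `smoothed_bleu` and the value `epsilon=0.1` are assumptions, chosen because they reproduce the logged numbers.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Return the list of n-grams (as tuples) in a token list."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def smoothed_bleu(reference, hypothesis, max_n=4, epsilon=0.1):
    """Sentence-level BLEU with uniform weights over 1..max_n grams.

    Assumption: when an n-gram order has zero clipped matches, its
    precision is replaced by epsilon / max(1, number of hypothesis n-grams).
    """
    ref, hyp = reference.split(), hypothesis.split()
    if not hyp:
        return 0.0

    log_precision_sum = 0.0
    for n in range(1, max_n + 1):
        hyp_counts = Counter(ngrams(hyp, n))
        ref_counts = Counter(ngrams(ref, n))
        # Clip each hypothesis n-gram count by its count in the reference.
        clipped = sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
        total = max(1, sum(hyp_counts.values()))
        precision = (clipped if clipped > 0 else epsilon) / total
        log_precision_sum += math.log(precision)

    # Brevity penalty: penalize hypotheses shorter than the reference.
    bp = 1.0 if len(hyp) >= len(ref) else math.exp(1 - len(ref) / len(hyp))
    return bp * math.exp(log_precision_sum / max_n)


for ref, hyp in [
    ("where are you", "where are you"),
    ("nobody moved", "nobody moved"),
    ("nobody moved", "nobody moved moved"),
    ("everybody looked at alice", "everybody looked at alice"),
]:
    print(f"{hyp!r}: {smoothed_bleu(ref, hyp):.4f}")
```

Under these assumptions, a three-token exact match scores 0.1^(1/4) ≈ 0.5623 and a two-token exact match scores 0.01^(1/4) ≈ 0.3162. Short validation sentences therefore pull the averaged BLEU down even when they are reconstructed perfectly, which makes the CHRF column the more reliable signal for very short sentences.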
