_must revisit later..._

In [31]:
import numpy as np
import pandas as pd
import sentencepiece as spm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from functools import partial

---

In [37]:
E2M_CONFIG = {
    "vocab_size" : 50257,
    "context_len" : 512,
    "emb_dim" : 768,
    "num_heads" : 8,
    "n_layers" : 8,
    "drop_rate" : 0.1,
    "qkv_bias" : False
}

In [2]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim = -1, keepdim = True)
        var = x.var(dim = -1, keepdim = True, unbiased = False) # unbiased so var is divided by n-1
        norm = (x - mean) / (torch.sqrt(var + self.eps)) # epsilon to prevent division by 0
        return self.scale * norm + self.shift # element wise operations - trainable parameters to learn appropriate scaling and shifting of norm values that best suits the data
    

class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))))
    
    
class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]), # expansion
            GELU(), # activation
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]), # contraction
        )
    
    def forward(self, x):
        return self.layers(x)

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        # s2
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # s3
        self.W_q = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) # to combine head outputs
        self.dropout = nn.Dropout(dropout) 
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_len, context_len), diagonal = 1)
        )

    def forward(self, x):
        b, num_tokens, d_out = x.shape # s1

        # s4
        keys = self.W_k(x)
        queries = self.W_q(x)
        values = self.W_v(x)
        
        # s5
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # s6
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # s7
        attention_scores = queries @ keys.transpose(2, 3)
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim = -1)
        attention_weights = self.dropout(attention_weights) # s8

        context_vec = (attention_weights @ values).transpose(1, 2) # s9 & s10
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) # s11
        context_vec = self.out_proj(context_vec) # optional
 
        return context_vec

In [4]:
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        # s2
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # s3
        self.W_q = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) # to combine head outputs
        self.dropout = nn.Dropout(dropout) 
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_len, context_len), diagonal = 1)
        )

    def forward(self, x):
        b, num_tokens, d_out = x.shape # s1

        # s4
        keys = self.W_k(x)
        queries = self.W_q(x)
        values = self.W_v(x)
        
        # s5
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # s6
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # s7
        attention_scores = queries @ keys.transpose(2, 3)

        # add masking for masked multi-head self attention
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores.masked_fill_(mask_bool, -torch.inf)
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim = -1)
        attention_weights = self.dropout(attention_weights) # s8

        context_vec = (attention_weights @ values).transpose(1, 2) # s9 & s10
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) # s11
        context_vec = self.out_proj(context_vec) # optional
 
        return context_vec

In [None]:
class MultiCrossAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        # s2
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # s3
        self.W_q = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias = qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) # to combine head outputs
        self.dropout = nn.Dropout(dropout) 

    def forward(self, dec_x, enc_x):
        b, dec_num_tokens, d_out = dec_x.shape
        _, enc_num_tokens, _ = enc_x.shape 

        # s4
        keys = self.W_k(enc_x)      
        queries = self.W_q(dec_x)    
        values = self.W_v(enc_x)     
        
        # s5 
        keys = keys.view(b, enc_num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, dec_num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, enc_num_tokens, self.num_heads, self.head_dim)

        # s6
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # s7
        attention_scores = queries @ keys.transpose(2, 3)
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim = -1)
        attention_weights = self.dropout(attention_weights)

        context_vec = (attention_weights @ values).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, dec_num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
 
        return context_vec

In [6]:
class Encoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention( # converts input to context vectors  
            d_in = cfg["emb_dim"],
            d_out = cfg["emb_dim"],
            context_len = cfg["context_len"],
            num_heads = cfg["num_heads"],
            dropout = cfg["drop_rate"],
            qkv_bias = cfg["qkv_bias"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
    
    def forward(self, x):
        # MHA
        shortcut = x
        x = self.norm1(x)
        x = self.att(x) # shape: [batch size, num tokens, emb size]
        x = self.drop_shortcut(x)
        x = x + shortcut # f(x) + x

        # FCL
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_shortcut(x)
        x = x + shortcut # f(x) + x

        return x

In [7]:
class Decoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.masked_att = MaskedMultiHeadAttention( # converts input to context vectors  
            d_in = cfg["emb_dim"],
            d_out = cfg["emb_dim"],
            context_len = cfg["context_len"],
            num_heads = cfg["num_heads"],
            dropout = cfg["drop_rate"],
            qkv_bias = cfg["qkv_bias"]
        )
        self.cross_att = MultiCrossAttention( # converts input to context vectors  
            d_in = cfg["emb_dim"],
            d_out = cfg["emb_dim"],
            context_len = cfg["context_len"],
            num_heads = cfg["num_heads"],
            dropout = cfg["drop_rate"],
            qkv_bias = cfg["qkv_bias"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.norm3 = LayerNorm(cfg["emb_dim"])
        self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
    
    def forward(self, dec_x, enc_x): # input for the decoder and context from the encoder
        
        # MASKED MULTIHEAD ATTENTION
        shortcut = dec_x
        dec_x = self.norm1(dec_x)
        dec_x = self.masked_att(dec_x) # shape: [batch size, num tokens, emb size]
        dec_x = self.drop_shortcut(dec_x)
        dec_x = dec_x + shortcut # f(x) + x 

        # MULTIHEAD CROSS ATTENTION
        shortcut = dec_x
        dec_x = self.norm2(dec_x)
        dec_x = self.cross_att(dec_x, enc_x)
        dec_x = self.drop_shortcut(dec_x)
        dec_x = dec_x + shortcut # f(x) + x

        # FCL
        shortcut = dec_x
        dec_x = self.norm3(dec_x)
        dec_x = self.ff(dec_x)
        dec_x = self.drop_shortcut(dec_x)
        dec_x = dec_x + shortcut # f(x) + x

        return dec_x

In [None]:
class E2M(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        
        self.enc_tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.enc_pos_emb = nn.Embedding(cfg["context_len"], cfg["emb_dim"])     
        self.dec_tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.dec_pos_emb = nn.Embedding(cfg["context_len"], cfg["emb_dim"])

        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        self.encoder_blocks = nn.Sequential(
            *[Encoder(cfg) for _ in range(cfg["n_layers"])]
        )
        # we change this becuase nn.Sequential cannot handle 2 variables at a time
        self.decoder_blocks = nn.ModuleList( 
            [Decoder(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(
            cfg["emb_dim"], cfg["vocab_size"], bias = False
        )
        
    def encode(self, enc_in_idx):
        _, enc_seq_len = enc_in_idx.shape
        enc_tok_embeds = self.enc_tok_emb(enc_in_idx)

        pos_indices = torch.arange(enc_seq_len, device=enc_in_idx.device).unsqueeze(0)
        enc_pos_embeds = self.enc_pos_emb(pos_indices)
        enc_x = enc_tok_embeds + enc_pos_embeds
        enc_x = self.drop_emb(enc_x)
        enc_x = self.encoder_blocks(enc_x)
        return enc_x

    def decode(self, dec_in_idx, enc_x):
        _, dec_seq_len = dec_in_idx.shape
        dec_tok_embeds = self.dec_tok_emb(dec_in_idx)

        pos_indices = torch.arange(dec_seq_len, device=dec_in_idx.device).unsqueeze(0)
        dec_pos_embeds = self.dec_pos_emb(pos_indices)
        dec_x = dec_tok_embeds + dec_pos_embeds
        dec_x = self.drop_emb(dec_x)

        for block in self.decoder_blocks:
            dec_x = block(dec_x, enc_x)
            
        x = self.final_norm(dec_x)
        logits = self.out_head(x)
        return logits

    def forward(self, enc_in_idx, dec_in_idx):
        enc_x = self.encode(enc_in_idx)
        logits = self.decode(dec_in_idx, enc_x)
        return logits

---

In [9]:
eng = "eng-mal/train.en"
mal = "eng-mal/train.ml"
output_csv_path = "eng-mal.csv"

with open(eng, 'r', encoding = 'utf-8') as f:
    english_sentences = [line.strip() for line in f if line.strip()]  

with open(mal, 'r', encoding = 'utf-8') as f:
    malayalam_sentences = [line.strip() for line in f if line.strip()] 

df = pd.DataFrame({
    "english": english_sentences,
    "malayalam": malayalam_sentences
})

assert len(english_sentences) == len(malayalam_sentences), "Sentence count mismatch!"

In [10]:
df = pd.DataFrame({
    "english": english_sentences,
    "malayalam": malayalam_sentences
})

df.to_csv(output_csv_path, index = False, encoding = "utf-8")

In [11]:
combined_file = "combined.txt"
with open(combined_file, "w", encoding="utf-8") as f:
    for e, m in zip(df["english"], df["malayalam"]):
        f.write(e + "\n")
        f.write(m + "\n")

In [12]:
spm.SentencePieceTrainer.train(
    input = combined_file,
    model_prefix = "eng_mal_spm",
    vocab_size = 32000,
    model_type = "bpe",
    character_coverage = 1.0  # ensures malayalam characters are fully captured
)

In [13]:
sp = spm.SentencePieceProcessor()
sp.load("eng_mal_spm.model")
vocab_size = sp.get_piece_size() 
E2M_CONFIG["vocab_size"] = vocab_size 

df["english_ids"] = df["english"].apply(lambda x: sp.encode(x, out_type = int))
df["malayalam_ids"] = df["malayalam"].apply(lambda x: sp.encode(x, out_type = int))

print("Sample encoding:")
print("English:", df['english'].iloc[0])
print("Encoded:", df['english_ids'].iloc[0])
print("Malayalam:", df['malayalam'].iloc[0])
print("Encoded:", df['malayalam_ids'].iloc[0])

Sample encoding:
English: The plot of the movie revolves around the life of two cancer patients Kizie and Manny.
Encoded: [91, 18279, 72, 15, 1534, 4046, 125, 2071, 3092, 15, 1452, 72, 768, 9997, 6148, 206, 988, 479, 84, 1036, 7822, 29646]
Malayalam: ക്യാന്‍സറിനോട് പോരാടുന്ന കിസി, മാനി എന്നിവരുടെ ജീവിതമാണ് ചിത്രം പറയുന്നത്.
Encoded: [27958, 17480, 6452, 4387, 14561, 1927, 363, 29672, 2333, 29629, 9922, 2409, 428, 1637, 3032, 29646]


In [15]:
df.to_pickle("eng_mal_tokenized.pkl")

---

In [18]:
df_eng = df['english_ids']
df_mal = df['malayalam_ids']

In [21]:
train_df_eng = int(len(df_eng) * 0.85)
val_df_eng = int(len(df_eng) * 0.10)
test_df_eng = len(df_eng) - (train_df_eng + val_df_eng)

print(train_df_eng, val_df_eng, test_df_eng)

5035762 592442 296222


In [22]:
train_data_eng = df_eng.iloc[:train_df_eng]
val_data_eng = df_eng.iloc[train_df_eng:train_df_eng + val_df_eng]
test_data_eng = df_eng.iloc[train_df_eng + val_df_eng:]

In [23]:
train_df_mal = int(len(df_mal) * 0.85)
val_df_mal = int(len(df_mal) * 0.10)
test_df_mal = len(df_mal) - (train_df_mal + val_df_mal)

print(train_df_mal, val_df_mal, test_df_mal)

5035762 592442 296222


In [24]:
train_data_mal = df_mal.iloc[:train_df_mal]
val_data_mal = df_mal.iloc[train_df_mal:train_df_mal + val_df_mal]
test_data_mal = df_mal.iloc[train_df_mal + val_df_mal:]

In [28]:
def custom_collate_fn(batch, pad_token_id = 50256, ignore_index = -100, allowed_max_len = None, device = "cpu"):
    batch_max_len = max(len(item) + 1 for item in batch)
    inputs_lst, targets_lst = [], []

    for item in batch:
        new_item = item.copy()
        new_item += [pad_token_id]
        padded = (
            new_item + [pad_token_id] *
            (batch_max_len - len(new_item))
        )
        inputs = torch.tensor(padded[:-1])  
        targets = torch.tensor(padded[1:])

        mask = targets == pad_token_id
        indices = torch.nonzero(mask).squeeze()
        if indices.numel() > 1:
            targets[indices[1:]] = ignore_index

        if allowed_max_len is not None:
            inputs = inputs[:allowed_max_len]
            targets = targets[:allowed_max_len]

        inputs_lst.append(inputs)
        targets_lst.append(targets)

    inputs_tensor = torch.stack(inputs_lst).to(device)
    targets_tensor = torch.stack(targets_lst).to(device)

    return inputs_tensor, targets_tensor

In [32]:
device = torch.device("cuda")
customized_collate_fn = partial(custom_collate_fn, device = device, allowed_max_len = 1024) # does seperation of tasks when model training

In [None]:
class TranslationDataset(Dataset):
    def __init__(self, eng_data, mal_data):
        self.eng_data = eng_data
        self.mal_data = mal_data

    def __getitem__(self, index):
        return self.eng_data.iloc[index], self.mal_data.iloc[index]
    
    def __len__(self):
        return len(self.eng_data)

train_dataset = TranslationDataset(train_data_eng, train_data_mal)
val_dataset = TranslationDataset(val_data_eng, val_data_mal)
test_dataset = TranslationDataset(test_data_eng, test_data_mal)

def paired_collate_fn(batch, pad_token_id = -1, device="cpu"): 
    eng_batch, mal_batch = zip(*batch)
    
    eng_max_len = max(len(item) for item in eng_batch)
    eng_padded = [item + [pad_token_id] * (eng_max_len - len(item)) for item in eng_batch]
    eng_tensor = torch.tensor(eng_padded, device=device)

    mal_max_len = max(len(item) for item in mal_batch)
    
    # decoder input: <bos> + sentence
    dec_inputs = [[1] + item for item in mal_batch] # 1 is <bos>
    dec_padded_inputs = [item + [pad_token_id] * (mal_max_len + 1 - len(item)) for item in dec_inputs]
    dec_inputs_tensor = torch.tensor(dec_padded_inputs, device=device)

    # target: sentence + <eos>
    targets = [item + [2] for item in mal_batch] # 2 is <eos>
    targets_padded = [item + [pad_token_id] * (mal_max_len + 1 - len(item)) for item in targets]
    targets_tensor = torch.tensor(targets_padded, device=device)

    return eng_tensor, dec_inputs_tensor, targets_tensor

In [None]:
train_loader = DataLoader(
    train_dataset,
    batch_size = 8,
    collate_fn=partial(paired_collate_fn, device=device),
    shuffle=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size = 8,
    collate_fn=partial(paired_collate_fn, device=device),
    shuffle=False
)

---

In [39]:
model = E2M(E2M_CONFIG).to(device)
model.eval()

E2M(
  (enc_tok_emb): Embedding(50257, 768)
  (enc_pos_emb): Embedding(512, 768)
  (dec_tok_emb): Embedding(50257, 768)
  (dec_pos_emb): Embedding(512, 768)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (encoder_blocks): Sequential(
    (0): Encoder(
      (att): MultiHeadAttention(
        (W_q): Linear(in_features=768, out_features=768, bias=False)
        (W_k): Linear(in_features=768, out_features=768, bias=False)
        (W_v): Linear(in_features=768, out_features=768, bias=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (drop_shortcut): Dropout(p=0.1, inplace=False)
    )
    (1): Encoder(
      (att): Mu

In [43]:
def text_to_token_ids(text, sp):
    token_ids = sp.encode(text, out_type=int)
    return torch.tensor(token_ids).unsqueeze(0)  # shape: (1, seq_len)

def token_ids_to_text(token_ids, sp):
    flat_ids = token_ids.squeeze(0).tolist()
    return sp.decode(flat_ids)

In [None]:
def generate(model, enc_in_idx, max_new_tokens, context_size, temperature = 1.0, top_k = None):
    model.eval()
    
    with torch.no_grad():
        enc_x = model.encode(enc_in_idx)
        
        # start the decoder with a start-of-sequence token (id : 1)
        dec_in_idx = torch.tensor([[1]], device = enc_in_idx.device)

        # autoregressive generation loop
        for _ in range(max_new_tokens):
            dec_in_idx_cond = dec_in_idx[:, -context_size:]
            
            logits = model.decode(dec_in_idx_cond, enc_x)
            logits = logits[:, -1, :] # Get logits for the last token

            if top_k is not None:
                top_logits, _ = torch.topk(logits, top_k)
                min_val = top_logits[:, -1]
                logits = torch.where(
                    logits < min_val, 
                    torch.tensor(float("-inf")).to(logits.device),
                    logits
                )

            if temperature > 0.0:
                probs = torch.softmax(logits / temperature, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            
            # sentencepiece model also has an EOS id 2
            if idx_next.item() == 2:
                break
                
            dec_in_idx = torch.cat((dec_in_idx, idx_next), dim=1)
    
    return dec_in_idx

In [None]:
torch.manual_seed(123)

token_ids = generate(
    model = model, 
    idx = text_to_token_ids("Everytime I see you", sp).to(device),
    max_new_tokens = 25,
    context_size = E2M_CONFIG["context_len"],
    top_k = 50,
    temperature = 1.5
)

print("Output:", token_ids_to_text(token_ids, sp))

---

In [None]:
def generate_text_simple(model, idx, max_new_tokens, context_size): # idx is the input batch
    for _ in range(max_new_tokens):
        # crop current context
        idx_cond = idx[:, -context_size:]
        # get predictions
        with torch.no_grad():
            logits = model(idx_cond) # batch_size x tokens_num x vocab_size
        # get the last time step (last set of logits)
        logits = logits[:, -1, :]
        # apply softmax
        probs = torch.softmax(logits, dim = -1)
        # get id of max
        idx_next = torch.argmax(probs, dim = -1, keepdim = True)
        # append id to running sequence
        idx = torch.cat((idx, idx_next), dim = -1)
    return idx

In [None]:
def calc_loss_batch(eng_batch, dec_inputs_batch, targets_batch, model, loss_fn, device):
    eng_batch = eng_batch.to(device)
    dec_inputs_batch = dec_inputs_batch.to(device)
    targets_batch = targets_batch.to(device)

    logits = model(eng_batch, dec_inputs_batch)
    
    loss = loss_fn(logits.view(-1, E2M_CONFIG["vocab_size"]), targets_batch.view(-1))
    return loss


def calc_loss_loader(data_loader, model, device, num_batches = None): # this will show the loss of the LM
    total_loss = 0
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
    
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    
    return total_loss / num_batches # mean loss per batch


def evaluate_model(model, data_loader, loss_fn, device):
    model.eval() 
    total_loss = 0.
    with torch.no_grad():
        for eng_batch, dec_inputs_batch, targets_batch in data_loader:
            loss = calc_loss_batch(eng_batch, dec_inputs_batch, targets_batch, model, loss_fn, device)
            total_loss += loss.item()
    
    return total_loss / len(data_loader)