In [1]:
# Parameters
lang1 = "hindi"
lang2 = "bengali"

In [2]:
# Parameters
lang1 = "hindi"
lang2 = "telugu"


In [3]:
print(f"Running with lang1: {lang1}, lang2: {lang2}")

Running with lang1: hindi, lang2: telugu


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import BertTokenizerFast, BertModel
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import numpy as np
import string
import json

2024-11-18 20:15:33.471233: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-18 20:15:33.471260: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-18 20:15:33.472610: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-18 20:15:33.479152: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.




In [5]:
# Load train data from json (L1)
with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang1}_train.json") as f:
    train_data_1 = json.load(f)

# Load test data from json
with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang1}_test.json") as f:
    test_data_1 = json.load(f)

# Load train data from json (L2)
with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang2}_train.json") as f:
    train_data_2 = json.load(f)

# Load test data from json
with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang2}_test.json") as f:
    test_data_2 = json.load(f)


In [6]:
print(f"Running with lang1: {lang1}, lang2: {lang2}")

Running with lang1: hindi, lang2: telugu


In [7]:
# Merge the two datasets
train_data = train_data_1 + train_data_2
test_data = test_data_1 + test_data_2

In [8]:
from transformers import AutoTokenizer, AutoModel

# Load tokeniser
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

In [9]:
class NER_Dataset(Dataset):
    def __init__(self, data, tokenizer, max_len=128):
        self.data = data  # List of sentences and labels
        self.tokenizer = tokenizer  # mBERT tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["tokens"]
        word_labels = self.data[idx]["ner_tags"]

        # Tokenize the text and align the labels
        encoding = self.tokenizer(text,
                                  is_split_into_words=True,
                                  return_offsets_mapping=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)

        labels = [self.align_labels(encoding['offset_mapping'], word_labels)]

        # Remove the offset mapping to prevent issues during model training
        del encoding['offset_mapping']

        item = {key: torch.tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.tensor(labels[0], dtype=torch.long)

        return item
    
    # Create a function to align labels
    def align_labels(self, offset_mapping, labels):
        aligned_labels = []
        current_label_index = 0

        for offset in offset_mapping:
            # If the offset mapping is (0, 0), it's a special token ([CLS], [SEP], [PAD])
            if offset == (0, 0):
                aligned_labels.append(-100)  # -100 is used to ignore these tokens in the loss computation
            else:
                # Check if the token is the start of a new word
                if offset[0] == 0:
                    aligned_labels.append(labels[current_label_index])
                    current_label_index += 1
                else:
                    # If the token is not the first subtoken, you can decide how to label it. 
                    # For simplicity, let's use the same label as the first subtoken
                    aligned_labels.append(labels[current_label_index - 1])

        return aligned_labels

In [10]:
# Create train dataset and test dataset
train_dataset = NER_Dataset(train_data, tokenizer)
test_dataset = NER_Dataset(test_data, tokenizer)

# Create train dataloader and test dataloader
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [11]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

In [12]:
import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(CrossAttentionFusion, self).__init__()
        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, embeddings_mbert, embeddings_indicbert):
        """
        Fuse embeddings from mBERT and Indic-BERT using cross-attention.

        Args:
            embeddings_mbert (torch.Tensor): Embeddings from mBERT of shape [batch_size, seq_len_m, hidden_dim].
            embeddings_indicbert (torch.Tensor): Embeddings from Indic-BERT of shape [batch_size, seq_len_i, hidden_dim].

        Returns:
            torch.Tensor: Fused embeddings of shape [batch_size, seq_len_m, hidden_dim].
        """
        embeddings_mbert = embeddings_mbert.to(DEVICE)
        embeddings_indicbert = embeddings_indicbert.to(DEVICE)
        
        # Transpose to shape [seq_len, batch_size, hidden_dim] for nn.MultiheadAttention
        embeddings_mbert = embeddings_mbert.transpose(0, 1)
        embeddings_indicbert = embeddings_indicbert.transpose(0, 1)

        # Cross-attention: Indic BERT queries attend to m-BERT keys and values
        attn_output, _ = self.cross_attention(
            query=embeddings_indicbert,       # Queries from Indic-BERT
            key=embeddings_mbert,     # Keys from m-BERT
            value=embeddings_mbert    # Values from m-BERT
        )

        attn_output.to(DEVICE)

        # Residual connection and layer normalization
        output = self.layer_norm(attn_output + embeddings_mbert)

        # Transpose back to [batch_size, seq_len_m, hidden_dim]
        output = output.transpose(0, 1)

        return output

In [13]:
hidden_dim = 768
num_heads = 4

In [14]:
# Model for NER

class MBERT_NER(nn.Module):
    def __init__(self, num_labels, gru_hidden_size, num_gru_layers, freeze_bert=False):
        super(MBERT_NER, self).__init__()

        self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.indic_bert = AutoModel.from_pretrained("ai4bharat/indic-bert")
        
        # Initialize the Fusion Attention
        self.attention = CrossAttentionFusion(hidden_dim=hidden_dim, num_heads=num_heads).to(DEVICE)
        
        self.gru = nn.GRU(input_size=self.bert.config.hidden_size * 3,
                          hidden_size=gru_hidden_size,
                          num_layers=num_gru_layers,
                          batch_first=True)
        self.classifier = nn.Linear(gru_hidden_size, num_labels)
        self.dropout = nn.Dropout(0.1)
        self.batch_norm = nn.BatchNorm1d(gru_hidden_size)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
            for param in self.indic_bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)
        
        # Embeddings
        outputs1 = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        outputs2 = self.indic_bert(input_ids = input_ids, attention_mask = attention_mask)        
        
        embeddings_mbert = outputs1.last_hidden_state.to(DEVICE)
        embeddings_indicbert = outputs2.last_hidden_state.to(DEVICE)
        
        # Fused embeddings
        embeddings = self.attention(embeddings_mbert, embeddings_indicbert).to(DEVICE)
        
        sequence_output = torch.cat((embeddings_mbert, embeddings_indicbert, embeddings), dim=-1)
        
        gru_output, _ = self.gru(sequence_output)
        gru_output = self.batch_norm(gru_output)
        gru_output = self.dropout(gru_output)
        
        logits = self.classifier(gru_output)
        return logits

# Create the NER model
NUM_LABELS = 7 # Number of NER tags
GRU_HIDDEN_SIZE = 128 # Hidden size of the GRU
NUM_GRU_LAYERS = 1 # Number of layers in the GRU
FREEZE_BERT = False # Whether to freeze the BERT model
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MBERT_NER(num_labels=NUM_LABELS,
                    gru_hidden_size=GRU_HIDDEN_SIZE,
                    num_gru_layers=NUM_GRU_LAYERS,
                    freeze_bert=FREEZE_BERT)

model.to(DEVICE)

MBERT_NER(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [15]:
torch.cuda.empty_cache()

In [16]:
from sklearn.metrics import f1_score, accuracy_score

# Optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# Training loop
EPOCHS = 6

train_losses = []
train_f1_scores = []
train_acc_scores = []
val_losses = []
val_f1_scores = []
val_acc_scores = []


for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    train_predictions, train_labels = [], []

    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        logits = logits.view(-1, NUM_LABELS)
        labels = labels.view(-1)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Get predictions and filter out ignored indices for metric calculations
        predictions = torch.argmax(logits, dim=-1)
        active_indices = labels != -100
        train_predictions.extend(predictions[active_indices].cpu().numpy())
        train_labels.extend(labels[active_indices].cpu().numpy())

    train_loss /= len(train_dataloader)
    train_losses.append(train_loss)

    train_f1 = f1_score(train_labels, train_predictions, average='macro')
    train_f1_scores.append(train_f1)

    train_acc = accuracy_score(train_labels, train_predictions)
    train_acc_scores.append(train_acc)

    # Validation loop
    model.eval()
    val_loss = 0
    val_predictions, val_labels = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            logits = model(input_ids, attention_mask)
            logits = logits.view(-1, NUM_LABELS)
            labels = labels.view(-1)

            loss = criterion(logits, labels)
            val_loss += loss.item()

            # Filter predictions and labels
            predictions = torch.argmax(logits, dim=-1)
            active_indices = labels != -100
            val_predictions.extend(predictions[active_indices].cpu().numpy())
            val_labels.extend(labels[active_indices].cpu().numpy())

    val_loss /= len(test_dataloader)
    val_losses.append(val_loss)

    val_f1 = f1_score(val_labels, val_predictions, average='macro')
    val_f1_scores.append(val_f1)

    val_acc = accuracy_score(val_labels, val_predictions)
    val_acc_scores.append(val_acc)
                                                
    print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}, Val Acc: {val_acc:.4f}")

Epoch 1/6, Train Loss: 0.5785, Train F1: 0.7126, Train Acc: 0.8149, Val Loss: 0.3273, Val F1: 0.8284, Val Acc: 0.8984


Epoch 2/6, Train Loss: 0.2748, Train F1: 0.8696, Train Acc: 0.9182, Val Loss: 0.2678, Val F1: 0.8668, Val Acc: 0.9221


Epoch 3/6, Train Loss: 0.1890, Train F1: 0.9107, Train Acc: 0.9443, Val Loss: 0.2634, Val F1: 0.8773, Val Acc: 0.9265


Epoch 4/6, Train Loss: 0.1343, Train F1: 0.9407, Train Acc: 0.9620, Val Loss: 0.2455, Val F1: 0.8885, Val Acc: 0.9305


Epoch 5/6, Train Loss: 0.0993, Train F1: 0.9568, Train Acc: 0.9731, Val Loss: 0.2557, Val F1: 0.8937, Val Acc: 0.9363


Epoch 6/6, Train Loss: 0.0830, Train F1: 0.9643, Train Acc: 0.9774, Val Loss: 0.2570, Val F1: 0.8916, Val Acc: 0.9344


In [17]:
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

# Load the test data from json for all 5 languages
languages = ['hindi', 'bengali', 'marathi', 'tamil', 'telugu']

# Iterate over all languages and evaluate the model
for lang in languages:
    with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang}_test.json") as f:
        test_data = json.load(f)

    test_dataset = NER_Dataset(test_data, tokenizer)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    model.eval()
    test_predictions, test_labels = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            logits = model(input_ids, attention_mask)
            logits = logits.view(-1, NUM_LABELS)
            labels = labels.view(-1)

            # Filter predictions and labels
            predictions = torch.argmax(logits, dim=-1)
            active_indices = labels != -100
            test_predictions.extend(predictions[active_indices].cpu().numpy())
            test_labels.extend(labels[active_indices].cpu().numpy())

    weighted_f1 = f1_score(test_labels, test_predictions, average='weighted')
    macro_f1 = f1_score(test_labels, test_predictions, average='macro')
    accuracy = accuracy_score(test_labels, test_predictions)

    if lang == 'hindi':
        LANG = "Hindi"
    elif lang == 'bengali':
        LANG = "Bengali"
    elif lang == 'marathi':
        LANG = "Marathi"
    elif lang == 'tamil':
        LANG = "Tamil"
    elif lang == 'telugu':
        LANG = "Telugu"

    print(f"Language: {LANG}")
    print()
    print(f"Weighted F1: {weighted_f1:.4f}")
    print(f"Macro F1: {macro_f1:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print()
    print("CLASSIFICATION REPORT: ")
    print(classification_report(test_labels, test_predictions))
    print()
    print("CONFUSION MATRIX: ")
    print(confusion_matrix(test_labels, test_predictions))
    print("--------------------------------------------------")

Language: Hindi

Weighted F1: 0.9351
Macro F1: 0.9155
Accuracy: 0.9350

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.98      0.96      0.97      5277
           1       0.90      0.95      0.92      1481
           2       0.90      0.97      0.93      1874
           3       0.93      0.86      0.89       986
           4       0.96      0.91      0.93      2602
           5       0.89      0.91      0.90      1195
           6       0.85      0.87      0.86       867

    accuracy                           0.94     14282
   macro avg       0.91      0.92      0.92     14282
weighted avg       0.94      0.94      0.94     14282


CONFUSION MATRIX: 
[[5067   66   37   22   19   38   28]
 [  26 1413   16    8    0   18    0]
 [   0   18 1825    0   16    1   14]
 [  19   70    2  846    5   43    1]
 [  34    0  115    0 2366   20   67]
 [  27   10    5   36    3 1087   27]
 [  22    0   31    0   51   13  750]]
--------------------

Language: Bengali

Weighted F1: 0.8002
Macro F1: 0.8020
Accuracy: 0.8039

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.84      0.95      0.89      1835
           1       0.75      0.89      0.82      1343
           2       0.76      0.88      0.82      2490
           3       0.83      0.79      0.81      1239
           4       0.81      0.79      0.80      3068
           5       0.81      0.76      0.79      1510
           6       0.84      0.59      0.69      2080

    accuracy                           0.80     13565
   macro avg       0.81      0.81      0.80     13565
weighted avg       0.81      0.80      0.80     13565


CONFUSION MATRIX: 
[[1744    9   23   12   23   21    3]
 [  36 1197   26   36    0   48    0]
 [  99   35 2187    0   85    4   80]
 [  23  142    0  980    2   91    1]
 [ 129    7  352   13 2428   18  121]
 [  12  182    1  139    6 1147   23]
 [  38   15  282    1  442   80 1222]]
------------------

Language: Marathi

Weighted F1: 0.8368
Macro F1: 0.7409
Accuracy: 0.8412

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.93      0.94      0.93     10408
           1       0.69      0.92      0.79      1342
           2       0.76      0.90      0.82      1992
           3       0.80      0.45      0.58      1172
           4       0.88      0.63      0.73      2535
           5       0.71      0.81      0.76      1829
           6       0.58      0.57      0.57       820

    accuracy                           0.84     20098
   macro avg       0.76      0.75      0.74     20098
weighted avg       0.85      0.84      0.84     20098


CONFUSION MATRIX: 
[[9803  102   50   68   72  239   74]
 [  34 1235   28   11    0   34    0]
 [  42   46 1785    0   28   51   40]
 [  54  292   11  528   32  249    6]
 [ 246   47  416    5 1601   12  208]
 [ 200   60   13   52    0 1490   14]
 [ 186    0   51    0   96   22  465]]
------------------

Language: Tamil

Weighted F1: 0.8218
Macro F1: 0.7267
Accuracy: 0.8233

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.91      0.91      0.91     11461
           1       0.69      0.72      0.70      1643
           2       0.83      0.80      0.82      2295
           3       0.60      0.52      0.55      1512
           4       0.75      0.76      0.75      2532
           5       0.75      0.83      0.79      1881
           6       0.61      0.52      0.56       677

    accuracy                           0.82     22001
   macro avg       0.73      0.72      0.73     22001
weighted avg       0.82      0.82      0.82     22001


CONFUSION MATRIX: 
[[10483   306    44   320   115   150    43]
 [  130  1178   211    70    12    42     0]
 [  229   129  1835    15    54     8    25]
 [  115    75    21   781   278   234     8]
 [  329     9    82    23  1929    48   112]
 [  139     7     3   101    36  1554    41]
 [  120     0    

Language: Telugu

Weighted F1: 0.9341
Macro F1: 0.8672
Accuracy: 0.9340

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.98      0.97      0.97     17291
           1       0.84      0.88      0.86      1810
           2       0.87      0.94      0.90      1353
           3       0.85      0.73      0.79      1375
           4       0.93      0.89      0.91      2269
           5       0.83      0.92      0.87      2072
           6       0.73      0.81      0.77       455

    accuracy                           0.93     26625
   macro avg       0.86      0.88      0.87     26625
weighted avg       0.94      0.93      0.93     26625


CONFUSION MATRIX: 
[[16716   135    56    77    83   185    39]
 [  107  1596    30    36     0    41     0]
 [   34    20  1269     0    10     0    20]
 [  109   116     0  1002    18   130     0]
 [   56     4    95     4  2015    33    62]
 [   77    22     0    55     0  1900    18]
 [   40     0   

In [18]:
# Save the model
torch.save(model.state_dict(), f'NER_FineTune_{lang1[:2]}_{lang2[:2]}_cross_attn.pth')

-------------------------------------------------------------------------------------------------------------------------

# DUMP

--------------------------------------------------------------------------------------------------------------

In [19]:
# import torch
# import torch.nn as nn

# class FusionModule(nn.Module):
#     def __init__(self, hidden_size):
#         super(FusionModule, self).__init__()
#         # Linear layer to compute attention scores
#         self.attention_layer = nn.Linear(hidden_size * 2, 2)

#     def forward(self, embeddings_mbert, embeddings_indicbert):
#         """
#         Args:
#             embeddings_mbert (torch.Tensor): [batch_size, seq_len, hidden_size]
#             embeddings_indicbert (torch.Tensor): [batch_size, seq_len, hidden_size]
#         Returns:
#             torch.Tensor: Fused embeddings [batch_size, seq_len, hidden_size]
#         """
#         # Concatenate embeddings along the last dimension
#         concat_embeddings = torch.cat((embeddings_mbert, embeddings_indicbert), dim=-1)  # [batch_size, seq_len, hidden_size * 2]

#         # Compute attention scores
#         attn_scores = self.attention_layer(concat_embeddings)  # [batch_size, seq_len, 2]

#         # Apply softmax to get attention weights
#         attn_weights = torch.softmax(attn_scores, dim=-1)  # [batch_size, seq_len, 2]

#         # Split attention weights
#         attn_weights_mbert = attn_weights[..., 0].unsqueeze(-1)  # [batch_size, seq_len, 1]
#         attn_weights_indicbert = attn_weights[..., 1].unsqueeze(-1)  # [batch_size, seq_len, 1]

#         # Compute weighted sum of embeddings
#         fused_embeddings = attn_weights_mbert * embeddings_mbert + attn_weights_indicbert * embeddings_indicbert  # [batch_size, seq_len, hidden_size]

#         return fused_embeddings

In [20]:
# import torch
# import torch.nn as nn

# class SelfAttentionLayer(nn.Module):
#     def __init__(self, hidden_dim, num_heads, dropout=0.1):
#         super(SelfAttentionLayer, self).__init__()
#         self.self_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout)
#         self.layer_norm = nn.LayerNorm(hidden_dim)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, embeddings, attention_mask=None):
#         """
#         Args:
#             embeddings (torch.Tensor): [batch_size, seq_len, hidden_dim]
#             attention_mask (torch.Tensor): [batch_size, seq_len], with 1 for tokens to attend to and 0 for padding
#         Returns:
#             torch.Tensor: [batch_size, seq_len, hidden_dim]
#         """
#         embeddings = embeddings.to(DEVICE)
        
        
#         # Transpose embeddings to [seq_len, batch_size, hidden_dim] for nn.MultiheadAttention
#         embeddings = embeddings.transpose(0, 1)

#         # Prepare attention mask for nn.MultiheadAttention (optional)
#         if attention_mask is not None:
#     # nn.MultiheadAttention expects key_padding_mask of shape [batch_size, seq_len]
#             key_padding_mask = attention_mask == 0  # True for padding tokens
#         else:
#             key_padding_mask = None


#         # Self-attention
#         attn_output, _ = self.self_attention(
#             query=embeddings,
#             key=embeddings,
#             value=embeddings,
#             key_padding_mask=key_padding_mask
#         )

#         # Residual connection and layer normalization
#         attn_output = self.dropout(attn_output)
#         attn_output = self.layer_norm(attn_output + embeddings)

#         # Transpose back to [batch_size, seq_len, hidden_dim]
#         attn_output = attn_output.transpose(0, 1)

#         return attn_output