In [1]:
# Parameters
lang1 = "hindi"
lang2 = "bengali"

In [2]:
# Parameters
lang1 = "telugu"
lang2 = "bengali"


In [3]:
print(f"Running with lang1: {lang1}, lang2: {lang2}")

Running with lang1: telugu, lang2: bengali


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import BertTokenizerFast, BertModel
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import numpy as np
import string
import json

2024-11-18 20:56:31.416586: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-18 20:56:31.416613: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-18 20:56:31.417975: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-18 20:56:31.424525: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.




In [5]:
# Load train data from json (L1)
with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang1}_train.json") as f:
    train_data_1 = json.load(f)

# Load test data from json
with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang1}_test.json") as f:
    test_data_1 = json.load(f)

# Load train data from json (L2)
with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang2}_train.json") as f:
    train_data_2 = json.load(f)

# Load test data from json
with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang2}_test.json") as f:
    test_data_2 = json.load(f)


In [6]:
print(f"Running with lang1: {lang1}, lang2: {lang2}")

Running with lang1: telugu, lang2: bengali


In [7]:
# Merge the two datasets
train_data = train_data_1 + train_data_2
test_data = test_data_1 + test_data_2

In [8]:
from transformers import AutoTokenizer, AutoModel

# Load tokeniser
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

In [9]:
class NER_Dataset(Dataset):
    def __init__(self, data, tokenizer, max_len=128):
        self.data = data  # List of sentences and labels
        self.tokenizer = tokenizer  # mBERT tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["tokens"]
        word_labels = self.data[idx]["ner_tags"]

        # Tokenize the text and align the labels
        encoding = self.tokenizer(text,
                                  is_split_into_words=True,
                                  return_offsets_mapping=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)

        labels = [self.align_labels(encoding['offset_mapping'], word_labels)]

        # Remove the offset mapping to prevent issues during model training
        del encoding['offset_mapping']

        item = {key: torch.tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.tensor(labels[0], dtype=torch.long)

        return item
    
    # Create a function to align labels
    def align_labels(self, offset_mapping, labels):
        aligned_labels = []
        current_label_index = 0

        for offset in offset_mapping:
            # If the offset mapping is (0, 0), it's a special token ([CLS], [SEP], [PAD])
            if offset == (0, 0):
                aligned_labels.append(-100)  # -100 is used to ignore these tokens in the loss computation
            else:
                # Check if the token is the start of a new word
                if offset[0] == 0:
                    aligned_labels.append(labels[current_label_index])
                    current_label_index += 1
                else:
                    # If the token is not the first subtoken, you can decide how to label it. 
                    # For simplicity, let's use the same label as the first subtoken
                    aligned_labels.append(labels[current_label_index - 1])

        return aligned_labels

In [10]:
# Create train dataset and test dataset
train_dataset = NER_Dataset(train_data, tokenizer)
test_dataset = NER_Dataset(test_data, tokenizer)

# Create train dataloader and test dataloader
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [11]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

In [12]:
import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(CrossAttentionFusion, self).__init__()
        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, embeddings_mbert, embeddings_indicbert):
        """
        Fuse embeddings from mBERT and Indic-BERT using cross-attention.

        Args:
            embeddings_mbert (torch.Tensor): Embeddings from mBERT of shape [batch_size, seq_len_m, hidden_dim].
            embeddings_indicbert (torch.Tensor): Embeddings from Indic-BERT of shape [batch_size, seq_len_i, hidden_dim].

        Returns:
            torch.Tensor: Fused embeddings of shape [batch_size, seq_len_m, hidden_dim].
        """
        embeddings_mbert = embeddings_mbert.to(DEVICE)
        embeddings_indicbert = embeddings_indicbert.to(DEVICE)
        
        # Transpose to shape [seq_len, batch_size, hidden_dim] for nn.MultiheadAttention
        embeddings_mbert = embeddings_mbert.transpose(0, 1)
        embeddings_indicbert = embeddings_indicbert.transpose(0, 1)

        # Cross-attention: Indic BERT queries attend to m-BERT keys and values
        attn_output, _ = self.cross_attention(
            query=embeddings_indicbert,       # Queries from Indic-BERT
            key=embeddings_mbert,     # Keys from m-BERT
            value=embeddings_mbert    # Values from m-BERT
        )

        attn_output.to(DEVICE)

        # Residual connection and layer normalization
        output = self.layer_norm(attn_output + embeddings_mbert)

        # Transpose back to [batch_size, seq_len_m, hidden_dim]
        output = output.transpose(0, 1)

        return output

In [13]:
hidden_dim = 768
num_heads = 4

In [14]:
# Model for NER

class MBERT_NER(nn.Module):
    def __init__(self, num_labels, gru_hidden_size, num_gru_layers, freeze_bert=False):
        super(MBERT_NER, self).__init__()

        self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.indic_bert = AutoModel.from_pretrained("ai4bharat/indic-bert")
        
        # Initialize the Fusion Attention
        self.attention = CrossAttentionFusion(hidden_dim=hidden_dim, num_heads=num_heads).to(DEVICE)
        
        self.gru = nn.GRU(input_size=self.bert.config.hidden_size * 3,
                          hidden_size=gru_hidden_size,
                          num_layers=num_gru_layers,
                          batch_first=True)
        self.classifier = nn.Linear(gru_hidden_size, num_labels)
        self.dropout = nn.Dropout(0.1)
        self.batch_norm = nn.BatchNorm1d(gru_hidden_size)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
            for param in self.indic_bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)
        
        # Embeddings
        outputs1 = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        outputs2 = self.indic_bert(input_ids = input_ids, attention_mask = attention_mask)        
        
        embeddings_mbert = outputs1.last_hidden_state.to(DEVICE)
        embeddings_indicbert = outputs2.last_hidden_state.to(DEVICE)
        
        # Fused embeddings
        embeddings = self.attention(embeddings_mbert, embeddings_indicbert).to(DEVICE)
        
        sequence_output = torch.cat((embeddings_mbert, embeddings_indicbert, embeddings), dim=-1)
        
        gru_output, _ = self.gru(sequence_output)
        gru_output = self.batch_norm(gru_output)
        gru_output = self.dropout(gru_output)
        
        logits = self.classifier(gru_output)
        return logits

# Create the NER model
NUM_LABELS = 7 # Number of NER tags
GRU_HIDDEN_SIZE = 128 # Hidden size of the GRU
NUM_GRU_LAYERS = 1 # Number of layers in the GRU
FREEZE_BERT = False # Whether to freeze the BERT model
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MBERT_NER(num_labels=NUM_LABELS,
                    gru_hidden_size=GRU_HIDDEN_SIZE,
                    num_gru_layers=NUM_GRU_LAYERS,
                    freeze_bert=FREEZE_BERT)

model.to(DEVICE)

MBERT_NER(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [15]:
torch.cuda.empty_cache()

In [16]:
from sklearn.metrics import f1_score, accuracy_score

# Optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# Training loop
EPOCHS = 6

train_losses = []
train_f1_scores = []
train_acc_scores = []
val_losses = []
val_f1_scores = []
val_acc_scores = []


for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    train_predictions, train_labels = [], []

    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        logits = logits.view(-1, NUM_LABELS)
        labels = labels.view(-1)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Get predictions and filter out ignored indices for metric calculations
        predictions = torch.argmax(logits, dim=-1)
        active_indices = labels != -100
        train_predictions.extend(predictions[active_indices].cpu().numpy())
        train_labels.extend(labels[active_indices].cpu().numpy())

    train_loss /= len(train_dataloader)
    train_losses.append(train_loss)

    train_f1 = f1_score(train_labels, train_predictions, average='macro')
    train_f1_scores.append(train_f1)

    train_acc = accuracy_score(train_labels, train_predictions)
    train_acc_scores.append(train_acc)

    # Validation loop
    model.eval()
    val_loss = 0
    val_predictions, val_labels = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            logits = model(input_ids, attention_mask)
            logits = logits.view(-1, NUM_LABELS)
            labels = labels.view(-1)

            loss = criterion(logits, labels)
            val_loss += loss.item()

            # Filter predictions and labels
            predictions = torch.argmax(logits, dim=-1)
            active_indices = labels != -100
            val_predictions.extend(predictions[active_indices].cpu().numpy())
            val_labels.extend(labels[active_indices].cpu().numpy())

    val_loss /= len(test_dataloader)
    val_losses.append(val_loss)

    val_f1 = f1_score(val_labels, val_predictions, average='macro')
    val_f1_scores.append(val_f1)

    val_acc = accuracy_score(val_labels, val_predictions)
    val_acc_scores.append(val_acc)
                                                
    print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}, Val Acc: {val_acc:.4f}")

Epoch 1/6, Train Loss: 0.5972, Train F1: 0.7575, Train Acc: 0.8070, Val Loss: 0.3205, Val F1: 0.8704, Val Acc: 0.9055


Epoch 2/6, Train Loss: 0.2635, Train F1: 0.9050, Train Acc: 0.9235, Val Loss: 0.2598, Val F1: 0.8999, Val Acc: 0.9240


Epoch 3/6, Train Loss: 0.1701, Train F1: 0.9396, Train Acc: 0.9513, Val Loss: 0.2688, Val F1: 0.8937, Val Acc: 0.9186


Epoch 4/6, Train Loss: 0.1312, Train F1: 0.9548, Train Acc: 0.9636, Val Loss: 0.2120, Val F1: 0.9184, Val Acc: 0.9391


Epoch 5/6, Train Loss: 0.0920, Train F1: 0.9694, Train Acc: 0.9759, Val Loss: 0.2178, Val F1: 0.9152, Val Acc: 0.9362


Epoch 6/6, Train Loss: 0.0703, Train F1: 0.9764, Train Acc: 0.9812, Val Loss: 0.2334, Val F1: 0.9139, Val Acc: 0.9337


In [17]:
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

# Load the test data from json for all 5 languages
languages = ['hindi', 'bengali', 'marathi', 'tamil', 'telugu']

# Iterate over all languages and evaluate the model
for lang in languages:
    with open(f"/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang}_test.json") as f:
        test_data = json.load(f)

    test_dataset = NER_Dataset(test_data, tokenizer)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    model.eval()
    test_predictions, test_labels = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            logits = model(input_ids, attention_mask)
            logits = logits.view(-1, NUM_LABELS)
            labels = labels.view(-1)

            # Filter predictions and labels
            predictions = torch.argmax(logits, dim=-1)
            active_indices = labels != -100
            test_predictions.extend(predictions[active_indices].cpu().numpy())
            test_labels.extend(labels[active_indices].cpu().numpy())

    weighted_f1 = f1_score(test_labels, test_predictions, average='weighted')
    macro_f1 = f1_score(test_labels, test_predictions, average='macro')
    accuracy = accuracy_score(test_labels, test_predictions)

    if lang == 'hindi':
        LANG = "Hindi"
    elif lang == 'bengali':
        LANG = "Bengali"
    elif lang == 'marathi':
        LANG = "Marathi"
    elif lang == 'tamil':
        LANG = "Tamil"
    elif lang == 'telugu':
        LANG = "Telugu"

    print(f"Language: {LANG}")
    print()
    print(f"Weighted F1: {weighted_f1:.4f}")
    print(f"Macro F1: {macro_f1:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print()
    print("CLASSIFICATION REPORT: ")
    print(classification_report(test_labels, test_predictions))
    print()
    print("CONFUSION MATRIX: ")
    print(confusion_matrix(test_labels, test_predictions))
    print("--------------------------------------------------")

Language: Hindi

Weighted F1: 0.7959
Macro F1: 0.7474
Accuracy: 0.7925

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.94      0.78      0.85      5277
           1       0.74      0.82      0.78      1481
           2       0.83      0.89      0.86      1874
           3       0.53      0.73      0.61       986
           4       0.77      0.87      0.82      2602
           5       0.73      0.67      0.70      1195
           6       0.62      0.60      0.61       867

    accuracy                           0.79     14282
   macro avg       0.74      0.77      0.75     14282
weighted avg       0.81      0.79      0.80     14282


CONFUSION MATRIX: 
[[4122  300   71  442   97  153   92]
 [  23 1215  142   42   19   34    6]
 [  19   38 1666    7  101    0   43]
 [  22   61    0  724  132   39    8]
 [ 113    8  110   15 2267   23   66]
 [  38   10    1  144   97  801  104]
 [  35    7   28    3  225   46  523]]
--------------------

Language: Bengali

Weighted F1: 0.9512
Macro F1: 0.9503
Accuracy: 0.9512

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.98      0.90      0.94      1835
           1       0.94      0.95      0.94      1343
           2       0.96      0.96      0.96      2490
           3       0.93      0.96      0.94      1239
           4       0.93      0.97      0.95      3068
           5       0.97      0.94      0.95      1510
           6       0.96      0.96      0.96      2080

    accuracy                           0.95     13565
   macro avg       0.95      0.95      0.95     13565
weighted avg       0.95      0.95      0.95     13565


CONFUSION MATRIX: 
[[1652   31   17   19   72   16   28]
 [   4 1277   11   37    0   14    0]
 [   7   13 2384    0   69    0   17]
 [   4   25    0 1193    6   11    0]
 [   3    0   59    1 2988    0   17]
 [   4   15    3   37   19 1419   13]
 [   4    0   17    0   67    2 1990]]
------------------

Language: Marathi

Weighted F1: 0.8238
Macro F1: 0.7203
Accuracy: 0.8183

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.96      0.88      0.92     10408
           1       0.76      0.88      0.82      1342
           2       0.83      0.84      0.83      1992
           3       0.41      0.49      0.45      1172
           4       0.72      0.80      0.76      2535
           5       0.78      0.73      0.76      1829
           6       0.47      0.57      0.52       820

    accuracy                           0.82     20098
   macro avg       0.70      0.74      0.72     20098
weighted avg       0.83      0.82      0.82     20098


CONFUSION MATRIX: 
[[9199  119   64  522  174  168  162]
 [  55 1178   26   69    2   11    1]
 [  67   67 1673    8   88   43   46]
 [  61  128    2  572  257  112   40]
 [  53   10  213   19 2017    3  220]
 [ 141   42   16  203   34 1339   54]
 [  56    0   31    4  225   36  468]]
------------------

Language: Tamil

Weighted F1: 0.8135
Macro F1: 0.7077
Accuracy: 0.8088

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.95      0.88      0.91     11461
           1       0.68      0.73      0.70      1643
           2       0.85      0.80      0.82      2295
           3       0.53      0.52      0.52      1512
           4       0.65      0.82      0.72      2532
           5       0.73      0.77      0.75      1881
           6       0.47      0.58      0.52       677

    accuracy                           0.81     22001
   macro avg       0.69      0.73      0.71     22001
weighted avg       0.82      0.81      0.81     22001


CONFUSION MATRIX: 
[[10076   315    46   377   298   220   129]
 [   66  1199   192   119    35    28     4]
 [   97   171  1834     4   161     0    28]
 [   52    72     6   780   350   242    10]
 [  152     8    68     7  2064    16   217]
 [  109     0     4   196    62  1449    61]
 [   48     0    

Language: Telugu

Weighted F1: 0.9262
Macro F1: 0.8581
Accuracy: 0.9247

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.98      0.95      0.96     17291
           1       0.85      0.87      0.86      1810
           2       0.91      0.90      0.90      1353
           3       0.74      0.81      0.78      1375
           4       0.83      0.92      0.87      2269
           5       0.85      0.91      0.88      2072
           6       0.70      0.82      0.75       455

    accuracy                           0.92     26625
   macro avg       0.84      0.88      0.86     26625
weighted avg       0.93      0.92      0.93     26625


CONFUSION MATRIX: 
[[16377   181    30   233   240   155    75]
 [   70  1567    28    80     3    62     0]
 [   25    23  1214     4    71     0    16]
 [   75    49     0  1119    21   109     2]
 [   56     0    56     5  2083    14    55]
 [   60    16     0    67    28  1890    11]
 [    8     0   

In [18]:
# Save the model
torch.save(model.state_dict(), f'NER_FineTune_{lang1[:2]}_{lang2[:2]}_cross_attn.pth')

-------------------------------------------------------------------------------------------------------------------------

# DUMP

--------------------------------------------------------------------------------------------------------------

In [19]:
# import torch
# import torch.nn as nn

# class FusionModule(nn.Module):
#     def __init__(self, hidden_size):
#         super(FusionModule, self).__init__()
#         # Linear layer to compute attention scores
#         self.attention_layer = nn.Linear(hidden_size * 2, 2)

#     def forward(self, embeddings_mbert, embeddings_indicbert):
#         """
#         Args:
#             embeddings_mbert (torch.Tensor): [batch_size, seq_len, hidden_size]
#             embeddings_indicbert (torch.Tensor): [batch_size, seq_len, hidden_size]
#         Returns:
#             torch.Tensor: Fused embeddings [batch_size, seq_len, hidden_size]
#         """
#         # Concatenate embeddings along the last dimension
#         concat_embeddings = torch.cat((embeddings_mbert, embeddings_indicbert), dim=-1)  # [batch_size, seq_len, hidden_size * 2]

#         # Compute attention scores
#         attn_scores = self.attention_layer(concat_embeddings)  # [batch_size, seq_len, 2]

#         # Apply softmax to get attention weights
#         attn_weights = torch.softmax(attn_scores, dim=-1)  # [batch_size, seq_len, 2]

#         # Split attention weights
#         attn_weights_mbert = attn_weights[..., 0].unsqueeze(-1)  # [batch_size, seq_len, 1]
#         attn_weights_indicbert = attn_weights[..., 1].unsqueeze(-1)  # [batch_size, seq_len, 1]

#         # Compute weighted sum of embeddings
#         fused_embeddings = attn_weights_mbert * embeddings_mbert + attn_weights_indicbert * embeddings_indicbert  # [batch_size, seq_len, hidden_size]

#         return fused_embeddings

In [20]:
# import torch
# import torch.nn as nn

# class SelfAttentionLayer(nn.Module):
#     def __init__(self, hidden_dim, num_heads, dropout=0.1):
#         super(SelfAttentionLayer, self).__init__()
#         self.self_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout)
#         self.layer_norm = nn.LayerNorm(hidden_dim)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, embeddings, attention_mask=None):
#         """
#         Args:
#             embeddings (torch.Tensor): [batch_size, seq_len, hidden_dim]
#             attention_mask (torch.Tensor): [batch_size, seq_len], with 1 for tokens to attend to and 0 for padding
#         Returns:
#             torch.Tensor: [batch_size, seq_len, hidden_dim]
#         """
#         embeddings = embeddings.to(DEVICE)
        
        
#         # Transpose embeddings to [seq_len, batch_size, hidden_dim] for nn.MultiheadAttention
#         embeddings = embeddings.transpose(0, 1)

#         # Prepare attention mask for nn.MultiheadAttention (optional)
#         if attention_mask is not None:
#     # nn.MultiheadAttention expects key_padding_mask of shape [batch_size, seq_len]
#             key_padding_mask = attention_mask == 0  # True for padding tokens
#         else:
#             key_padding_mask = None


#         # Self-attention
#         attn_output, _ = self.self_attention(
#             query=embeddings,
#             key=embeddings,
#             value=embeddings,
#             key_padding_mask=key_padding_mask
#         )

#         # Residual connection and layer normalization
#         attn_output = self.dropout(attn_output)
#         attn_output = self.layer_norm(attn_output + embeddings)

#         # Transpose back to [batch_size, seq_len, hidden_dim]
#         attn_output = attn_output.transpose(0, 1)

#         return attn_output