In [1]:
# Parameters
language = 'hindi'

In [2]:
# Parameters
language = "bengali"


In [3]:
print(f"Running with lang: {language}")

Running with lang: bengali


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import BertTokenizerFast, BertModel
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import numpy as np
import string
import json

from transformers import AutoTokenizer, AutoModel

2024-11-19 05:10:40.435381: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-19 05:10:40.435413: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-19 05:10:40.436752: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-19 05:10:40.442742: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.




In [None]:
# Load train data from json
with open(f'/home/ashu/Desktop/sem7/llm/CSE664-LLM-Project/Cross-Attention-Embedding-Fusion/Dataset/{language}_train.json') as f:
    train_data = json.load(f)

# Load test data from json
with open(f'/home/ashu/Desktop/sem7/llm/CSE664-LLM-Project/Cross-Attention-Embedding-Fusion/Dataset/{language}_test.json') as f:
    test_data = json.load(f)

In [6]:
print(f"Running with lang: {language}")

Running with lang: bengali


In [7]:
# Load tokeniser for MBERT
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

In [8]:
class NER_Dataset(Dataset):
    def __init__(self, data, tokenizer, max_len=128):
        self.data = data  # List of sentences and labels
        self.tokenizer = tokenizer  # mBERT tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["tokens"]
        word_labels = self.data[idx]["ner_tags"]

        # Tokenize the text and align the labels
        encoding = self.tokenizer(text,
                                  is_split_into_words=True,
                                  return_offsets_mapping=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)

        labels = [self.align_labels(encoding['offset_mapping'], word_labels)]

        # Remove the offset mapping to prevent issues during model training
        del encoding['offset_mapping']

        item = {key: torch.tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.tensor(labels[0], dtype=torch.long)

        return item
    
    # Create a function to align labels
    def align_labels(self, offset_mapping, labels):
        aligned_labels = []
        current_label_index = 0

        for offset in offset_mapping:
            # If the offset mapping is (0, 0), it's a special token ([CLS], [SEP], [PAD])
            if offset == (0, 0):
                aligned_labels.append(-100)  # -100 is used to ignore these tokens in the loss computation
            else:
                # Check if the token is the start of a new word
                if offset[0] == 0:
                    aligned_labels.append(labels[current_label_index])
                    current_label_index += 1
                else:
                    # If the token is not the first subtoken, you can decide how to label it. 
                    # For simplicity, let's use the same label as the first subtoken
                    aligned_labels.append(labels[current_label_index - 1])

        return aligned_labels


In [9]:
# Create train dataset and test dataset
train_dataset = NER_Dataset(train_data, tokenizer)
test_dataset = NER_Dataset(test_data, tokenizer)

# Create train dataloader and test dataloader
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [10]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

In [11]:
import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(CrossAttentionFusion, self).__init__()
        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, embeddings_mbert, embeddings_indicbert):
        """
        Fuse embeddings from mBERT and Indic-BERT using cross-attention.

        Args:
            embeddings_mbert (torch.Tensor): Embeddings from mBERT of shape [batch_size, seq_len_m, hidden_dim].
            embeddings_indicbert (torch.Tensor): Embeddings from Indic-BERT of shape [batch_size, seq_len_i, hidden_dim].

        Returns:
            torch.Tensor: Fused embeddings of shape [batch_size, seq_len_m, hidden_dim].
        """
        embeddings_mbert = embeddings_mbert.to(DEVICE)
        embeddings_indicbert = embeddings_indicbert.to(DEVICE)
        
        # Transpose to shape [seq_len, batch_size, hidden_dim] for nn.MultiheadAttention
        embeddings_mbert = embeddings_mbert.transpose(0, 1)
        embeddings_indicbert = embeddings_indicbert.transpose(0, 1)

        # Cross-attention: Indic BERT queries attend to m-BERT keys and values
        attn_output, _ = self.cross_attention(
            query=embeddings_indicbert,       # Queries from Indic-BERT
            key=embeddings_mbert,     # Keys from m-BERT
            value=embeddings_mbert    # Values from m-BERT
        )

        attn_output.to(DEVICE)

        # Residual connection and layer normalization
        output = self.layer_norm(attn_output + embeddings_mbert)

        # Transpose back to [batch_size, seq_len_m, hidden_dim]
        output = output.transpose(0, 1)

        return output

In [12]:
hidden_dim = 768
num_heads = 4

In [13]:
# Model for NER

class MBERT_NER(nn.Module):
    def __init__(self, num_labels, gru_hidden_size, num_gru_layers, freeze_bert=False):
        super(MBERT_NER, self).__init__()

        self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.indic_bert = AutoModel.from_pretrained("ai4bharat/indic-bert")
        
        # Initialize the Fusion Attention
        self.attention = CrossAttentionFusion(hidden_dim=hidden_dim, num_heads=num_heads).to(DEVICE)
        
        self.gru = nn.GRU(input_size=self.bert.config.hidden_size * 3,
                          hidden_size=gru_hidden_size,
                          num_layers=num_gru_layers,
                          batch_first=True)
        self.classifier = nn.Linear(gru_hidden_size, num_labels)
        self.dropout = nn.Dropout(0.1)
        self.batch_norm = nn.BatchNorm1d(gru_hidden_size)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
            for param in self.indic_bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)
        
        # Embeddings
        outputs1 = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        outputs2 = self.indic_bert(input_ids = input_ids, attention_mask = attention_mask)        
        
        embeddings_mbert = outputs1.last_hidden_state.to(DEVICE)
        embeddings_indicbert = outputs2.last_hidden_state.to(DEVICE)
        
        # Fused embeddings
        embeddings = self.attention(embeddings_mbert, embeddings_indicbert).to(DEVICE)
        
        sequence_output = torch.cat((embeddings_mbert, embeddings_indicbert, embeddings), dim=-1)
        
        gru_output, _ = self.gru(sequence_output)
        gru_output = self.batch_norm(gru_output)
        gru_output = self.dropout(gru_output)
        
        logits = self.classifier(gru_output)
        return logits

# Create the NER model
NUM_LABELS = 7 # Number of NER tags
GRU_HIDDEN_SIZE = 128 # Hidden size of the GRU
NUM_GRU_LAYERS = 1 # Number of layers in the GRU
FREEZE_BERT = False # Whether to freeze the BERT model
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MBERT_NER(num_labels=NUM_LABELS,
                    gru_hidden_size=GRU_HIDDEN_SIZE,
                    num_gru_layers=NUM_GRU_LAYERS,
                    freeze_bert=FREEZE_BERT)

model.to(DEVICE)

MBERT_NER(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [14]:
torch.cuda.empty_cache()

In [15]:
from sklearn.metrics import f1_score, accuracy_score

# Optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# Training loop
EPOCHS = 6

train_losses = []
train_f1_scores = []
train_acc_scores = []
val_losses = []
val_f1_scores = []
val_acc_scores = []


for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    train_predictions, train_labels = [], []

    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        logits = logits.view(-1, NUM_LABELS)
        labels = labels.view(-1)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Get predictions and filter out ignored indices for metric calculations
        predictions = torch.argmax(logits, dim=-1)
        active_indices = labels != -100
        train_predictions.extend(predictions[active_indices].cpu().numpy())
        train_labels.extend(labels[active_indices].cpu().numpy())

    train_loss /= len(train_dataloader)
    train_losses.append(train_loss)

    train_f1 = f1_score(train_labels, train_predictions, average='macro')
    train_f1_scores.append(train_f1)

    train_acc = accuracy_score(train_labels, train_predictions)
    train_acc_scores.append(train_acc)

    # Validation loop
    model.eval()
    val_loss = 0
    val_predictions, val_labels = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            logits = model(input_ids, attention_mask)
            logits = logits.view(-1, NUM_LABELS)
            labels = labels.view(-1)

            loss = criterion(logits, labels)
            val_loss += loss.item()

            # Filter predictions and labels
            predictions = torch.argmax(logits, dim=-1)
            active_indices = labels != -100
            val_predictions.extend(predictions[active_indices].cpu().numpy())
            val_labels.extend(labels[active_indices].cpu().numpy())

    val_loss /= len(test_dataloader)
    val_losses.append(val_loss)

    val_f1 = f1_score(val_labels, val_predictions, average='macro')
    val_f1_scores.append(val_f1)

    val_acc = accuracy_score(val_labels, val_predictions)
    val_acc_scores.append(val_acc)

    print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}, Val Acc: {val_acc:.4f}")

Epoch 1/6, Train Loss: 0.8006, Train F1: 0.7277, Train Acc: 0.7375, Val Loss: 0.3744, Val F1: 0.8816, Val Acc: 0.8820


Epoch 2/6, Train Loss: 0.2955, Train F1: 0.9109, Train Acc: 0.9136, Val Loss: 0.3230, Val F1: 0.9082, Val Acc: 0.9080


Epoch 3/6, Train Loss: 0.1803, Train F1: 0.9492, Train Acc: 0.9519, Val Loss: 0.2280, Val F1: 0.9419, Val Acc: 0.9431


Epoch 4/6, Train Loss: 0.1451, Train F1: 0.9604, Train Acc: 0.9620, Val Loss: 0.2760, Val F1: 0.9317, Val Acc: 0.9309


Epoch 5/6, Train Loss: 0.1086, Train F1: 0.9718, Train Acc: 0.9724, Val Loss: 0.2298, Val F1: 0.9443, Val Acc: 0.9457


Epoch 6/6, Train Loss: 0.0809, Train F1: 0.9788, Train Acc: 0.9795, Val Loss: 0.2459, Val F1: 0.9468, Val Acc: 0.9467


In [16]:
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

# Load the test data from json for all 5 languages
languages = ['hindi', 'bengali', 'marathi', 'tamil', 'telugu']

# Iterate over all languages and evaluate the model
for lang in languages:
    with open(f'/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang}_test.json') as f:
        test_data = json.load(f)

    test_dataset = NER_Dataset(test_data, tokenizer)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    model.eval()
    test_predictions, test_labels = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            logits = model(input_ids, attention_mask)
            logits = logits.view(-1, NUM_LABELS)
            labels = labels.view(-1)

            # Filter predictions and labels
            predictions = torch.argmax(logits, dim=-1)
            active_indices = labels != -100
            test_predictions.extend(predictions[active_indices].cpu().numpy())
            test_labels.extend(labels[active_indices].cpu().numpy())

    weighted_f1 = f1_score(test_labels, test_predictions, average='weighted')
    macro_f1 = f1_score(test_labels, test_predictions, average='macro')
    accuracy = accuracy_score(test_labels, test_predictions)

    if lang == 'hindi':
        LANG = "Hindi"
    elif lang == 'bengali':
        LANG = "Bengali"
    elif lang == 'marathi':
        LANG = "Marathi"
    elif lang == 'tamil':
        LANG = "Tamil"
    elif lang == 'telugu':
        LANG = "Telugu"

    print(f"Language: {LANG}")
    print()
    print(f"Weighted F1: {weighted_f1:.4f}")
    print(f"Macro F1: {macro_f1:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print()
    print("CLASSIFICATION REPORT: ")
    print(classification_report(test_labels, test_predictions))
    print()
    print("CONFUSION MATRIX: ")
    print(confusion_matrix(test_labels, test_predictions))
    print("--------------------------------------------------")

Language: Hindi

Weighted F1: 0.7586
Macro F1: 0.7019
Accuracy: 0.7569

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.92      0.76      0.83      5277
           1       0.65      0.81      0.73      1481
           2       0.70      0.96      0.81      1874
           3       0.50      0.50      0.50       986
           4       0.82      0.77      0.80      2602
           5       0.65      0.69      0.67      1195
           6       0.60      0.57      0.59       867

    accuracy                           0.76     14282
   macro avg       0.69      0.72      0.70     14282
weighted avg       0.77      0.76      0.76     14282


CONFUSION MATRIX: 
[[3988  440   95  403   74  190   87]
 [  38 1204  209   17    7    5    1]
 [  16   16 1799    4   35    0    4]
 [  75  109   13  492  136  160    1]
 [ 111    9  291   14 2006   27  144]
 [  81   54   34   58   55  824   89]
 [  42    7  124    2  127   68  497]]
--------------------

Language: Bengali

Weighted F1: 0.9468
Macro F1: 0.9468
Accuracy: 0.9467

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.97      0.92      0.94      1835
           1       0.90      0.97      0.94      1343
           2       0.91      0.98      0.94      2490
           3       0.97      0.94      0.95      1239
           4       0.95      0.94      0.95      3068
           5       0.97      0.93      0.95      1510
           6       0.98      0.94      0.96      2080

    accuracy                           0.95     13565
   macro avg       0.95      0.95      0.95     13565
weighted avg       0.95      0.95      0.95     13565


CONFUSION MATRIX: 
[[1690   32   30    5   55    8   15]
 [  14 1305   12    8    0    4    0]
 [  12   11 2441    0   21    0    5]
 [   8   44    2 1159    6   20    0]
 [   6    3  153    0 2884    1   21]
 [  18   48    0   28    5 1408    3]
 [   3    0   49    0   63   10 1955]]
------------------

Language: Marathi

Weighted F1: 0.7510
Macro F1: 0.6173
Accuracy: 0.7467

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.94      0.83      0.88     10408
           1       0.67      0.85      0.75      1342
           2       0.59      0.90      0.71      1992
           3       0.16      0.15      0.15      1172
           4       0.68      0.56      0.62      2535
           5       0.72      0.74      0.73      1829
           6       0.41      0.58      0.48       820

    accuracy                           0.75     20098
   macro avg       0.60      0.66      0.62     20098
weighted avg       0.77      0.75      0.75     20098


CONFUSION MATRIX: 
[[8670  192  127  827  163  210  219]
 [  38 1134  156    1    5    8    0]
 [  67   52 1784    0   17   59   13]
 [  80  242  119  172  339  163   57]
 [  52   16  699    8 1423    9  328]
 [ 204   63   46   60   44 1346   66]
 [  77    0   81    0  100   84  478]]
------------------

Language: Tamil

Weighted F1: 0.7694
Macro F1: 0.6465
Accuracy: 0.7684

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.94      0.83      0.89     11461
           1       0.59      0.72      0.65      1643
           2       0.65      0.94      0.77      2295
           3       0.57      0.30      0.39      1512
           4       0.70      0.73      0.71      2532
           5       0.61      0.72      0.66      1881
           6       0.40      0.54      0.46       677

    accuracy                           0.77     22001
   macro avg       0.64      0.68      0.65     22001
weighted avg       0.79      0.77      0.77     22001


CONFUSION MATRIX: 
[[9556  466  421  237  290  311  180]
 [  80 1175  329   14   13   32    0]
 [  69   27 2160    0   34    0    5]
 [ 100  120  103  455  262  448   24]
 [ 157   18  223    2 1846   27  259]
 [ 112  193   38   84   30 1350   74]
 [  41    0   59    0  163   51  363]]
--------------------

Language: Telugu

Weighted F1: 0.7971
Macro F1: 0.6131
Accuracy: 0.7897

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.96      0.85      0.90     17291
           1       0.51      0.70      0.59      1810
           2       0.51      0.97      0.67      1353
           3       0.43      0.24      0.31      1375
           4       0.71      0.69      0.70      2269
           5       0.68      0.80      0.73      2072
           6       0.29      0.56      0.38       455

    accuracy                           0.79     26625
   macro avg       0.58      0.69      0.61     26625
weighted avg       0.82      0.79      0.80     26625


CONFUSION MATRIX: 
[[14623   803   526   287   317   432   303]
 [  158  1271   314    13     8    46     0]
 [   23     2  1313     0     8     0     7]
 [  211   366    71   334   180   161    52]
 [   78     3   286    96  1574    70   162]
 [  172    53    11    48    45  1659    84]
 [   12     0   

In [17]:
# Save the model
torch.save(model.state_dict(), f'NER_FineTune_{language[:2]}_cross_attn.pth')