In [1]:
# Parameters
language = 'hindi'

In [2]:
# Parameters
language = "tamil"


In [3]:
print(f"Running with lang: {language}")

Running with lang: tamil


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import BertTokenizerFast, BertModel
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import numpy as np
import string
import json

from transformers import AutoTokenizer, AutoModel

2024-11-19 05:04:36.898904: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-19 05:04:36.898937: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-19 05:04:36.900264: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-19 05:04:36.906743: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.




In [5]:
# Load train data from json
with open(f'/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{language}_train.json') as f:
    train_data = json.load(f)

# Load test data from json
with open(f'/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{language}_test.json') as f:
    test_data = json.load(f)

In [6]:
print(f"Running with lang: {language}")

Running with lang: tamil


In [7]:
# Load tokeniser for MBERT
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

In [8]:
class NER_Dataset(Dataset):
    def __init__(self, data, tokenizer, max_len=128):
        self.data = data  # List of sentences and labels
        self.tokenizer = tokenizer  # mBERT tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["tokens"]
        word_labels = self.data[idx]["ner_tags"]

        # Tokenize the text and align the labels
        encoding = self.tokenizer(text,
                                  is_split_into_words=True,
                                  return_offsets_mapping=True,
                                  padding='max_length',
                                  truncation=True,
                                  max_length=self.max_len)

        labels = [self.align_labels(encoding['offset_mapping'], word_labels)]

        # Remove the offset mapping to prevent issues during model training
        del encoding['offset_mapping']

        item = {key: torch.tensor(val) for key, val in encoding.items()}
        item['labels'] = torch.tensor(labels[0], dtype=torch.long)

        return item
    
    # Create a function to align labels
    def align_labels(self, offset_mapping, labels):
        aligned_labels = []
        current_label_index = 0

        for offset in offset_mapping:
            # If the offset mapping is (0, 0), it's a special token ([CLS], [SEP], [PAD])
            if offset == (0, 0):
                aligned_labels.append(-100)  # -100 is used to ignore these tokens in the loss computation
            else:
                # Check if the token is the start of a new word
                if offset[0] == 0:
                    aligned_labels.append(labels[current_label_index])
                    current_label_index += 1
                else:
                    # If the token is not the first subtoken, you can decide how to label it. 
                    # For simplicity, let's use the same label as the first subtoken
                    aligned_labels.append(labels[current_label_index - 1])

        return aligned_labels


In [9]:
# Create train dataset and test dataset
train_dataset = NER_Dataset(train_data, tokenizer)
test_dataset = NER_Dataset(test_data, tokenizer)

# Create train dataloader and test dataloader
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [10]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE

'cuda'

In [11]:
import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(CrossAttentionFusion, self).__init__()
        self.cross_attention = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, embeddings_mbert, embeddings_indicbert):
        """
        Fuse embeddings from mBERT and Indic-BERT using cross-attention.

        Args:
            embeddings_mbert (torch.Tensor): Embeddings from mBERT of shape [batch_size, seq_len_m, hidden_dim].
            embeddings_indicbert (torch.Tensor): Embeddings from Indic-BERT of shape [batch_size, seq_len_i, hidden_dim].

        Returns:
            torch.Tensor: Fused embeddings of shape [batch_size, seq_len_m, hidden_dim].
        """
        embeddings_mbert = embeddings_mbert.to(DEVICE)
        embeddings_indicbert = embeddings_indicbert.to(DEVICE)
        
        # Transpose to shape [seq_len, batch_size, hidden_dim] for nn.MultiheadAttention
        embeddings_mbert = embeddings_mbert.transpose(0, 1)
        embeddings_indicbert = embeddings_indicbert.transpose(0, 1)

        # Cross-attention: Indic BERT queries attend to m-BERT keys and values
        attn_output, _ = self.cross_attention(
            query=embeddings_indicbert,       # Queries from Indic-BERT
            key=embeddings_mbert,     # Keys from m-BERT
            value=embeddings_mbert    # Values from m-BERT
        )

        attn_output.to(DEVICE)

        # Residual connection and layer normalization
        output = self.layer_norm(attn_output + embeddings_mbert)

        # Transpose back to [batch_size, seq_len_m, hidden_dim]
        output = output.transpose(0, 1)

        return output

In [12]:
hidden_dim = 768
num_heads = 4

In [13]:
# Model for NER

class MBERT_NER(nn.Module):
    def __init__(self, num_labels, gru_hidden_size, num_gru_layers, freeze_bert=False):
        super(MBERT_NER, self).__init__()

        self.bert = BertModel.from_pretrained("bert-base-multilingual-cased")
        self.indic_bert = AutoModel.from_pretrained("ai4bharat/indic-bert")
        
        # Initialize the Fusion Attention
        self.attention = CrossAttentionFusion(hidden_dim=hidden_dim, num_heads=num_heads).to(DEVICE)
        
        self.gru = nn.GRU(input_size=self.bert.config.hidden_size * 3,
                          hidden_size=gru_hidden_size,
                          num_layers=num_gru_layers,
                          batch_first=True)
        self.classifier = nn.Linear(gru_hidden_size, num_labels)
        self.dropout = nn.Dropout(0.1)
        self.batch_norm = nn.BatchNorm1d(gru_hidden_size)

        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False
            for param in self.indic_bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)
        
        # Embeddings
        outputs1 = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        outputs2 = self.indic_bert(input_ids = input_ids, attention_mask = attention_mask)        
        
        embeddings_mbert = outputs1.last_hidden_state.to(DEVICE)
        embeddings_indicbert = outputs2.last_hidden_state.to(DEVICE)
        
        # Fused embeddings
        embeddings = self.attention(embeddings_mbert, embeddings_indicbert).to(DEVICE)
        
        sequence_output = torch.cat((embeddings_mbert, embeddings_indicbert, embeddings), dim=-1)
        
        gru_output, _ = self.gru(sequence_output)
        gru_output = self.batch_norm(gru_output)
        gru_output = self.dropout(gru_output)
        
        logits = self.classifier(gru_output)
        return logits

# Create the NER model
NUM_LABELS = 7 # Number of NER tags
GRU_HIDDEN_SIZE = 128 # Hidden size of the GRU
NUM_GRU_LAYERS = 1 # Number of layers in the GRU
FREEZE_BERT = False # Whether to freeze the BERT model
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MBERT_NER(num_labels=NUM_LABELS,
                    gru_hidden_size=GRU_HIDDEN_SIZE,
                    num_gru_layers=NUM_GRU_LAYERS,
                    freeze_bert=FREEZE_BERT)

model.to(DEVICE)

MBERT_NER(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [14]:
torch.cuda.empty_cache()

In [15]:
from sklearn.metrics import f1_score, accuracy_score

# Optimizer and loss function
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# Training loop
EPOCHS = 6

train_losses = []
train_f1_scores = []
train_acc_scores = []
val_losses = []
val_f1_scores = []
val_acc_scores = []


for epoch in range(EPOCHS):
    model.train()
    train_loss = 0
    train_predictions, train_labels = [], []

    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(DEVICE)
        attention_mask = batch['attention_mask'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)

        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        logits = logits.view(-1, NUM_LABELS)
        labels = labels.view(-1)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Get predictions and filter out ignored indices for metric calculations
        predictions = torch.argmax(logits, dim=-1)
        active_indices = labels != -100
        train_predictions.extend(predictions[active_indices].cpu().numpy())
        train_labels.extend(labels[active_indices].cpu().numpy())

    train_loss /= len(train_dataloader)
    train_losses.append(train_loss)

    train_f1 = f1_score(train_labels, train_predictions, average='macro')
    train_f1_scores.append(train_f1)

    train_acc = accuracy_score(train_labels, train_predictions)
    train_acc_scores.append(train_acc)

    # Validation loop
    model.eval()
    val_loss = 0
    val_predictions, val_labels = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            logits = model(input_ids, attention_mask)
            logits = logits.view(-1, NUM_LABELS)
            labels = labels.view(-1)

            loss = criterion(logits, labels)
            val_loss += loss.item()

            # Filter predictions and labels
            predictions = torch.argmax(logits, dim=-1)
            active_indices = labels != -100
            val_predictions.extend(predictions[active_indices].cpu().numpy())
            val_labels.extend(labels[active_indices].cpu().numpy())

    val_loss /= len(test_dataloader)
    val_losses.append(val_loss)

    val_f1 = f1_score(val_labels, val_predictions, average='macro')
    val_f1_scores.append(val_f1)

    val_acc = accuracy_score(val_labels, val_predictions)
    val_acc_scores.append(val_acc)

    print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}, Val Acc: {val_acc:.4f}")

Epoch 1/6, Train Loss: 0.7022, Train F1: 0.6223, Train Acc: 0.7730, Val Loss: 0.4232, Val F1: 0.7815, Val Acc: 0.8712


Epoch 2/6, Train Loss: 0.3464, Train F1: 0.8245, Train Acc: 0.8959, Val Loss: 0.3429, Val F1: 0.8421, Val Acc: 0.9022


Epoch 3/6, Train Loss: 0.2486, Train F1: 0.8782, Train Acc: 0.9276, Val Loss: 0.3463, Val F1: 0.8518, Val Acc: 0.9031


Epoch 4/6, Train Loss: 0.1771, Train F1: 0.9131, Train Acc: 0.9487, Val Loss: 0.3410, Val F1: 0.8502, Val Acc: 0.9025


Epoch 5/6, Train Loss: 0.1445, Train F1: 0.9285, Train Acc: 0.9579, Val Loss: 0.3863, Val F1: 0.8529, Val Acc: 0.9011


Epoch 6/6, Train Loss: 0.1193, Train F1: 0.9428, Train Acc: 0.9664, Val Loss: 0.3614, Val F1: 0.8370, Val Acc: 0.9013


In [16]:
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

# Load the test data from json for all 5 languages
languages = ['hindi', 'bengali', 'marathi', 'tamil', 'telugu']

# Iterate over all languages and evaluate the model
for lang in languages:
    with open(f'/home/vijay/slim_sense/SHARPax/Python_code/test/NER-Experiments/Dataset/{lang}_test.json') as f:
        test_data = json.load(f)

    test_dataset = NER_Dataset(test_data, tokenizer)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    model.eval()
    test_predictions, test_labels = [], []

    with torch.no_grad():
        for batch in test_dataloader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            labels = batch['labels'].to(DEVICE)

            logits = model(input_ids, attention_mask)
            logits = logits.view(-1, NUM_LABELS)
            labels = labels.view(-1)

            # Filter predictions and labels
            predictions = torch.argmax(logits, dim=-1)
            active_indices = labels != -100
            test_predictions.extend(predictions[active_indices].cpu().numpy())
            test_labels.extend(labels[active_indices].cpu().numpy())

    weighted_f1 = f1_score(test_labels, test_predictions, average='weighted')
    macro_f1 = f1_score(test_labels, test_predictions, average='macro')
    accuracy = accuracy_score(test_labels, test_predictions)

    if lang == 'hindi':
        LANG = "Hindi"
    elif lang == 'bengali':
        LANG = "Bengali"
    elif lang == 'marathi':
        LANG = "Marathi"
    elif lang == 'tamil':
        LANG = "Tamil"
    elif lang == 'telugu':
        LANG = "Telugu"

    print(f"Language: {LANG}")
    print()
    print(f"Weighted F1: {weighted_f1:.4f}")
    print(f"Macro F1: {macro_f1:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print()
    print("CLASSIFICATION REPORT: ")
    print(classification_report(test_labels, test_predictions))
    print()
    print("CONFUSION MATRIX: ")
    print(confusion_matrix(test_labels, test_predictions))
    print("--------------------------------------------------")

Language: Hindi

Weighted F1: 0.7928
Macro F1: 0.7375
Accuracy: 0.7898

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.91      0.82      0.87      5277
           1       0.75      0.81      0.78      1481
           2       0.83      0.90      0.86      1874
           3       0.53      0.59      0.56       986
           4       0.83      0.77      0.80      2602
           5       0.64      0.76      0.70      1195
           6       0.58      0.62      0.60       867

    accuracy                           0.79     14282
   macro avg       0.72      0.75      0.74     14282
weighted avg       0.80      0.79      0.79     14282


CONFUSION MATRIX: 
[[4347  260   41  349   76  158   46]
 [  38 1206  138   32   11   54    2]
 [  45   22 1694    5   51    7   50]
 [  39   79    8  583  119  151    7]
 [ 195    7  118   34 2009   36  203]
 [  38   31    7   87   41  906   85]
 [  69    3   37    2  124   97  535]]
--------------------

Language: Bengali

Weighted F1: 0.7775
Macro F1: 0.7834
Accuracy: 0.7810

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.75      0.94      0.83      1835
           1       0.79      0.86      0.82      1343
           2       0.78      0.84      0.81      2490
           3       0.84      0.77      0.80      1239
           4       0.80      0.72      0.76      3068
           5       0.78      0.82      0.80      1510
           6       0.76      0.59      0.66      2080

    accuracy                           0.78     13565
   macro avg       0.78      0.79      0.78     13565
weighted avg       0.78      0.78      0.78     13565


CONFUSION MATRIX: 
[[1726   13    5   12   46   14   19]
 [  45 1154   24   64    0   53    3]
 [ 210   22 2086    0   88    6   78]
 [  16  130    4  955    4  127    3]
 [ 134   20  373   22 2217   20  282]
 [  48  120    3   86    8 1234   11]
 [ 128    8  187    2  412  121 1222]]
------------------

Language: Marathi

Weighted F1: 0.8321
Macro F1: 0.7334
Accuracy: 0.8354

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.92      0.94      0.93     10408
           1       0.73      0.90      0.80      1342
           2       0.78      0.86      0.82      1992
           3       0.64      0.51      0.57      1172
           4       0.81      0.65      0.72      2535
           5       0.75      0.75      0.75      1829
           6       0.54      0.54      0.54       820

    accuracy                           0.84     20098
   macro avg       0.74      0.74      0.73     20098
weighted avg       0.83      0.84      0.83     20098


CONFUSION MATRIX: 
[[9805   96   61   98   97  135  116]
 [  51 1211   21   31    0   28    0]
 [  70   34 1722    0   40   55   71]
 [  67  268   10  596   62  162    7]
 [ 312   16  348   58 1636   16  149]
 [ 204   44    5  150   15 1379   32]
 [ 111    0   49    1  166   52  441]]
------------------

Language: Tamil

Weighted F1: 0.9028
Macro F1: 0.8370
Accuracy: 0.9013

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.97      0.95      0.96     11461
           1       0.82      0.89      0.86      1643
           2       0.90      0.94      0.92      2295
           3       0.82      0.74      0.78      1512
           4       0.87      0.84      0.85      2532
           5       0.81      0.85      0.83      1881
           6       0.58      0.77      0.66       677

    accuracy                           0.90     22001
   macro avg       0.82      0.85      0.84     22001
weighted avg       0.91      0.90      0.90     22001


CONFUSION MATRIX: 
[[10834   137    92    78   135   141    44]
 [   37  1468    59    36     4    39     0]
 [   36    22  2148     0    36     0    53]
 [   47   110     3  1125    36   187     4]
 [   38    10    79    10  2129    11   255]
 [   80    38     0   122    16  1602    23]
 [   44     2    

Language: Telugu

Weighted F1: 0.8402
Macro F1: 0.6906
Accuracy: 0.8368

CLASSIFICATION REPORT: 
              precision    recall  f1-score   support

           0       0.96      0.90      0.93     17291
           1       0.56      0.73      0.63      1810
           2       0.65      0.92      0.76      1353
           3       0.50      0.36      0.42      1375
           4       0.78      0.71      0.75      2269
           5       0.70      0.85      0.77      2072
           6       0.48      0.74      0.58       455

    accuracy                           0.84     26625
   macro avg       0.66      0.74      0.69     26625
weighted avg       0.85      0.84      0.84     26625


CONFUSION MATRIX: 
[[15523   601   183   221   313   318   132]
 [  123  1313   234    51     6    82     1]
 [   32    25  1239     3    12     0    42]
 [  161   351    48   495    74   226    20]
 [  142     0   191   123  1617    76   120]
 [   87    59     3   104    12  1757    50]
 [   31     0   

In [17]:
# Save the model
torch.save(model.state_dict(), f'NER_FineTune_{language[:2]}_cross_attn.pth')