In [None]:
!pip install transformers



In [None]:
# imports 
import re
import numpy as np
import torch
import pandas as pd

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [None]:
import os
import pandas as pd

directory = '/content/hindi_coref_data'

datasets = []
for file in os.listdir(directory):
    if file.endswith(".csv"):
      df = pd.read_csv(os.path.join(directory, file))
      datasets.append(df)

print(len(datasets))

275


In [None]:
data = datasets[10]
data.columns = ["word", "cref", "crefHead", "acrefmod", "acrefmodHead", "crefmod", "creftype", "Chainhead"]
print(data)

       word     cref   crefHead acrefmod acrefmodHead crefmod creftype  \
0        की        _          _        _            _       _        _   
1    भूमिका        _          _        _            _       _        _   
2    अपनाते        _          _        _            _       _        _   
3       हुए        _          _        _            _       _        _   
4    झारखंड  i2%1:t2  झारखंड:i2        _            _       _        _   
..      ...      ...        ...      ...          ...     ...      ...   
325     हाथ        _          _        _            _       _        _   
326    धोना        _          _        _            _       _        _   
327    पड़ता        _          _        _            _       _        _   
328      है        _          _        _            _       _        _   
329       ।        _          _        _            _       _        _   

    Chainhead  
0           _  
1           _  
2           _  
3           _  
4           _  
..        ...  

Our coreference annotation scheme includes all-
together 7 fields. They are cref : This field represents
the unique index for a mention, the unique index for a
chain to which a mention belongs and the textual span
of a mention [template- MentionId%(0/1):chainId],
crefHead : This field represents the linguistic head of
a mention, acrefmod : This field specifies the unique
index for a modifier and its textual span, crefmod
: This field is used to link a mention and its modi-
fier with the unique modifier index, crefmodHead
: This field represents the linguistic head of a modi-
fier, crefType : This field specifies the type relation
between mentions of the same chain, and crefChain-
Head : This field is used to mark the head mention
of the chain.

https://arxiv.org/abs/2103.10730 muril dataset

https://aclanthology.org/L16-1025/ hindi dataset for anaphora

# hindi 



In [None]:
from transformers import AutoModel, AutoTokenizer

# load the murilbert model
path = 'google/muril-base-cased'

tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModel.from_pretrained(path,
                                  output_hidden_states=True # Whether the model returns all hidden-states.
                                  ) 
model.to(device)

Downloading:   0%|          | 0.00/181 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/411 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.02M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/113 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/909M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/muril-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def prepare_data(data, start):
  # data is taken sentence wise, and stored in the texts array
  texts = []
  text = ""
  tags = []
  input_tokens = []

  for i in range(start, len(data)):
    if str(data["word"][i]) == "।" :
      text += str(data["word"][i])
      texts.append(text)
      input_tokens += (tokenizer.convert_ids_to_tokens(tokenizer.encode(text)))
      text = ""
    else:
      text += str(data["word"][i]) + " "
      tags.append(data["cref"][i])

  return texts, tags, input_tokens

def get_word_vectors(texts):
  " function to get hindi word vectors from murilbert "

  outputs = []

  for text in texts:
    # encoded input with input ids, token type ids and attention mask
    input_encoded = tokenizer.encode_plus(text, return_tensors="pt")
    input_encoded.to(device)

    # obtain and take the sum of all 13 states of BERT output
    with torch.no_grad():
            states = model(**input_encoded).hidden_states

    output = torch.stack([states[i] for i in range(len(states))]).sum(dim = 0)
    output = output.squeeze()
    outputs.append(output)

  return torch.cat(outputs, dim = 0)

texts, tags, input_tokens = prepare_data(data, 0)

output = get_word_vectors(texts)
output.to(device)
output.shape

torch.Size([372, 768])

In [None]:
def getvec(mention):
  " function to get the vector of a mention, it takes the average of all word vectors in the mention "
  vec = torch.zeros(768)
  tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(mention))
  count = 1
  for i in range(1, len(tokens) - 1):
    try: 
      idx = input_tokens.index(tokens[i])
    except:
      idx = 1
    vec = vec.to(device)
    vec = torch.add(vec, output[idx].to(device))
    count += 1
  
  return torch.div(vec, count).to(device)

In [None]:
# do this file wise
def map_mentions(data, start):
  " function to make a list of mentions, along with a list of their corresponding cluster ids"
  mentions = []
  mention_ids = [1]
  mention = ""
  count = 1
  for i in range(start, len(data)):
    tag = data["cref"][i]
    if (tag[0] == "i"):
      idx = (re.search('i(\d*)%', tag)).group(1)
      mention_idx = (re.search('t(.*)', tag)).group(1) 
      if int(idx) == count:
        mention += str(data["word"][i]) + " "
      else:
        mentions.append(mention)
        mention_ids.append(mention_idx)
        mention = str(data["word"][i]) + " "
        count += 1

  return mentions, mention_ids

def make_mention_pairs(mentions, mention_ids):
  " function to make a list of mention pairs and their true/false values "
  x_train = torch.empty(0).to(device)
  y_train = torch.empty(0).to(device)

  for i in range(len(mentions)):
    for j in range(i + 1, len(mentions)):
      x_train = torch.cat((x_train, (getvec(mentions[i]) + getvec(mentions[j]))) , 0)
      if mention_ids[i] == mention_ids[j]:
        y_train = torch.cat((y_train, torch.tensor([1]).to(device)))
      else:
        y_train = torch.cat((y_train, torch.tensor([0]).to(device)))

  x_train = x_train.reshape(-1, 768)

  return x_train, y_train

mentions, mention_ids = map_mentions(data, 0)

x, y = make_mention_pairs(mentions, mention_ids)
print(x.shape, y.shape)

# x_train = x
# y_train = y

torch.Size([3003, 768]) torch.Size([3003])


In [None]:
x_train = torch.empty(0).to(device)
y_train = torch.empty(0).to(device)

for data in datasets[:224]:
  try:
    data.columns = ["word", "cref", "crefHead", "acrefmod", "acrefmodHead", "crefmod", "creftype", "Chainhead"]

    texts, tags, input_tokens = prepare_data(data, 0)
    output = get_word_vectors(texts)
    mentions, mention_ids = map_mentions(data, 0)

    x, y = make_mention_pairs(mentions, mention_ids)
    print(x.shape, y.shape)

    x_train = torch.cat((x_train, x.to(device)))
    y_train = torch.cat((y_train, y.to(device)))
  except:
    continue

x_train = x_train.reshape(-1, 768)

In [None]:
print(x_train.shape, y_train.shape)

torch.Size([493881, 768]) torch.Size([493881])


In [None]:
# # downsampling
# neg_count = 0
# pos_count = 0
# for i in range(len(y_train)):
#   # print(int(y_train[i].item()))
#   if(int(y_train[i].item()) == 0):
#     neg_count += 1
#     if(neg_count %20):
#       x_train = np.delete(x_train, i)
#       y_train = np.delete(y_train, i)

# print(len(y_train))
# print(neg_count)

In [None]:
#defining dataset class
from torch.utils.data import Dataset, DataLoader
class dataset(Dataset):
  def __init__(self,x,y):
    self.x = torch.tensor(x,dtype=torch.float32)
    self.y = torch.tensor(y,dtype=torch.float32)
    self.length = self.x.shape[0]
 
  def __getitem__(self,idx):
    return self.x[idx],self.y[idx]  

  def __len__(self):
    return self.length
  
trainset = dataset(x_train,y_train)
#DataLoader
trainloader = DataLoader(trainset,batch_size=64,shuffle=False)

  """
  


In [None]:
#defining the network
from torch import nn
from torch.nn import functional as F

class Net(nn.Module):
  def __init__(self,input_shape):
    super(Net,self).__init__()
    self.fc1 = nn.Linear(input_shape,32)
    self.fc2 = nn.Linear(32,64)
    self.fc3 = nn.Linear(64,1)  
  
  def forward(self,x): 
    # print(x.shape)
    x = torch.relu(self.fc1(x))
    x = torch.relu(self.fc2(x))
    x = torch.sigmoid(self.fc3(x))
    return x

In [None]:
#hyper parameters
learning_rate = 0.01
epochs = 200

# Model , Optimizer, Loss
be_model = Net(input_shape=768)
be_model = be_model.to(device)
optimizer = torch.optim.SGD(be_model.parameters(),lr=learning_rate)
loss_fn = nn.BCELoss()

In [None]:
# training the classifier

#forward 

for i in range(epochs):
  running_loss = 0
  last_loss = 0
  for j,(x_train,y_train) in enumerate(trainloader):
    
    #calculate output
    output = be_model(x_train)
 
    #calculate loss
    loss = loss_fn(output,y_train.reshape(-1,1))

    #backprop
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    # if j%50 == 0:
  avg_loss = running_loss/ len(trainloader)
  # print(len(trainloader))
    # losses.append(loss)
  #accur.append(acc)
  print("epoch {}\tloss : {}".format(i,avg_loss))

epoch 0	loss : 0.20751315575091583
epoch 1	loss : 0.17927496880016072
epoch 2	loss : 0.16547574068988088
epoch 3	loss : 0.15663476091836287
epoch 4	loss : 0.15017050436853704
epoch 5	loss : 0.1451404525407392
epoch 6	loss : 0.14116747736457855
epoch 7	loss : 0.1376792424584613
epoch 8	loss : 0.13461878363061994
epoch 9	loss : 0.13188903391547557
epoch 10	loss : 0.129440069366375
epoch 11	loss : 0.12724080872904053
epoch 12	loss : 0.125277282829557
epoch 13	loss : 0.12360000407649625
epoch 14	loss : 0.12181802535830756
epoch 15	loss : 0.12016650789164066
epoch 16	loss : 0.11863081135178506
epoch 17	loss : 0.11732511590804585
epoch 18	loss : 0.11602537488284431
epoch 19	loss : 0.11491116986668652
epoch 20	loss : 0.11363750083519654
epoch 21	loss : 0.11284147803541054
epoch 22	loss : 0.11182657714966919
epoch 23	loss : 0.11080008170527562
epoch 24	loss : 0.10990207674273801
epoch 25	loss : 0.10905674833429692
epoch 26	loss : 0.10815382810799937
epoch 27	loss : 0.10730067835475336
epoch 28

testing

In [None]:
x_test = torch.empty(0).to(device)
y_test = torch.empty(0).to(device)

for data2 in datasets[224:249]:
  try:
    data2.columns = ["word", "cref", "crefHead", "acrefmod", "acrefmodHead", "crefmod", "creftype", "Chainhead"]

    texts, tags, input_tokens = prepare_data(data2, 0)

    # print(len(texts), texts)
    output = get_word_vectors(texts)
    output.shape

    mentions, mention_ids = map_mentions(data2, 0)

    x, y = make_mention_pairs(mentions, mention_ids)
    print(x.shape, y.shape)

    x_test = torch.cat((x_test, x.to(device)))
    y_test = torch.cat((y_test, y.to(device)))
  except:
    continue

x_test = x_test.reshape(-1, 768)

In [None]:
testset = dataset(x_test,y_test)
testloader = DataLoader(testset,batch_size=1,shuffle=False)

  """
  


In [None]:
predicted_vals = []
actual_vals = []

for i,(x_test,y_test) in enumerate(testloader):
  
  #calculate output
  output = be_model(x_test)

  if output >= 0.1:
    predicted_vals.append(1)
  else:
    predicted_vals.append(0)
  actual_vals.append( int(y_test.item()) )

predicted_vals = np.array(predicted_vals)
actual_vals = np.array(actual_vals)

In [None]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix

# predicted_vals
# actual_vals
# accuracy: (tp + tn) / (p + n)
accuracy = accuracy_score(actual_vals, predicted_vals)
print('Accuracy: %f' % accuracy)
# precision tp / (tp + fp)
precision = precision_score(actual_vals, predicted_vals, average='weighted')
print('Precision: %f' % precision)
# recall: tp / (tp + fn)
recall = recall_score(actual_vals, predicted_vals, average='weighted')
print('Recall: %f' % recall)
# f1: 2 tp / (2 tp + fp + fn)
f1 = f1_score(actual_vals, predicted_vals, average='weighted')
print('F1 score: %f' % f1)
 
# confusion matrix
matrix = confusion_matrix(actual_vals, predicted_vals)
print(matrix)   

Accuracy: 0.770283
Precision: 0.910407
Recall: 0.770283
F1 score: 0.818875
[[36799 10794]
 [ 1101  3087]]


In [None]:
print(roc_auc_score(actual_vals, predicted_vals))

0.7551539793269784


In [None]:
torch.save(be_model, "hin_coref_2")

In [None]:
bmodel = torch.load("hin_coref_2")
bmodel.eval()

Net(
  (fc1): Linear(in_features=768, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=1, bias=True)
)

### Issues: relies on only local information to make decisions
next step: Mention ranking

## lstm
