In [1]:
import torch
from transformers import AlbertForTokenClassification, AlbertTokenizer
import numpy as np
import inference_params
import json
import torch.nn as nn

In [2]:
label_encoder_path = inference_params.LABEL_ENCODER_PATH
tag_values = ['none ', 'none', 'exclamation', 'comma', 'viram', 'PAD']
with open(label_encoder_path) as label_encoder:
    train_encoder = json.load(label_encoder)

In [3]:
tokenizer = AlbertTokenizer.from_pretrained('ai4bharat/indic-bert')

In [23]:
test_sentence = "विकिपीडिया सभी विषयों पर प्रामाणिक और उपयोग परिवर्तन व पुनर्वितरण के लिए स्वतन्त्र ज्ञानकोश बनाने का एक बहुभाषीय प्रकल्प है"
#test_sentence = "अमेरिका समेत अन्य देशों से जो मदद भारत पहुंची उसका क्या हुआ"
tokenized_sentence = tokenizer.encode(test_sentence)
input_ids = torch.tensor([tokenized_sentence]).cuda()
input_ids

tensor([[     2, 171759,    516,  27926,     37,  59251, 126136,   1825,     31,
           3557,  17137,    555,  41281,    130, 175206,   7571,  71080,     10,
             66,   1301,  90529,   9327,   4765,  73061,  11867, 125491,   1342,
             29,     45,  13015,  52974,   2266,     37,  17132,   1551,     16,
              3]], device='cuda:0')

In [5]:
model = AlbertForTokenClassification.from_pretrained('ai4bharat/indic-bert',
                                                     num_labels=len(train_encoder),
                                                     output_attentions=False,
                                                     output_hidden_states=False)
#if torch.cuda.device_count() > 1:
#    print("Using ", torch.cuda.device_count(), "GPUs")
model = nn.DataParallel(model)

Some weights of the model checkpoint at ai4bharat/indic-bert were not used when initializing AlbertForTokenClassification: ['predictions.bias', 'predictions.LayerNorm.weight', 'predictions.LayerNorm.bias', 'predictions.dense.weight', 'predictions.dense.bias', 'predictions.decoder.weight', 'predictions.decoder.bias', 'sop_classifier.classifier.weight', 'sop_classifier.classifier.bias']
- This IS expected if you are initializing AlbertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AlbertForTokenClassification were not initialized from the model checkpoint at ai4bharat/indic-bert and a

In [6]:
checkpoint = torch.load(inference_params.CHECKPOINT_PATH)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [7]:
model.eval()
model.cuda()

DataParallel(
  (module): AlbertForTokenClassification(
    (albert): AlbertModel(
      (embeddings): AlbertEmbeddings(
        (word_embeddings): Embedding(200000, 128, padding_idx=0)
        (position_embeddings): Embedding(512, 128)
        (token_type_embeddings): Embedding(2, 128)
        (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0, inplace=False)
      )
      (encoder): AlbertTransformer(
        (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
        (albert_layer_groups): ModuleList(
          (0): AlbertLayerGroup(
            (albert_layers): ModuleList(
              (0): AlbertLayer(
                (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (attention): AlbertAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
      

In [24]:
with torch.no_grad():
    output = model(input_ids)
label_indices = np.argmax(output[0].to('cpu').numpy(), axis=2)

In [25]:
label_indices

array([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4]])

In [26]:
train_encoder

{'hyp': 0, 'qm': 1, 'comma': 2, 'end': 3, 'blank': 4, 'ex': 5, 'PAD': 6}

In [27]:
tokens = tokenizer.convert_ids_to_tokens(input_ids.to('cpu').numpy()[0])
tokens

['[CLS]',
 '▁विकिपीडिया',
 '▁सभी',
 '▁विषयों',
 '▁पर',
 '▁परा',
 'माण',
 'िक',
 '▁और',
 '▁उपयोग',
 '▁परि',
 'वर',
 'तन',
 '▁व',
 '▁पुनर',
 'वि',
 'तरण',
 '▁के',
 '▁लिए',
 '▁स',
 'वतन',
 'तर',
 '▁ज',
 'ञ',
 'ान',
 'कोश',
 '▁बनाने',
 '▁का',
 '▁एक',
 '▁बहु',
 'भाष',
 'ीय',
 '▁पर',
 'कल',
 'प',
 '▁है',
 '[SEP]']

In [28]:
print(len(tokens))
print(len(label_indices[0]))

37
37


In [29]:
new_tokens, new_labels = [], []
for token, label_idx in zip(tokens, label_indices[0]):
    if token.startswith("▁"):
        new_tokens[-1] = new_tokens[-1] + token[1:]
    else:
        new_labels.append(list(train_encoder.keys())[list(train_encoder.values()).index(label_idx)])
        new_tokens.append(token)

In [30]:
for token, label in zip(new_tokens, new_labels):
    print("{}\t{}".format(label, token))

blank	[CLS]विकिपीडियासभीविषयोंपरपरा
blank	माण
blank	िकऔरउपयोगपरि
blank	वर
blank	तनवपुनर
blank	वि
blank	तरणकेलिएस
blank	वतन
blank	तरज
blank	ञ
blank	ान
blank	कोशबनानेकाएकबहु
blank	भाष
blank	ीयपर
blank	कल
blank	पहै
blank	[SEP]


In [31]:
new_tokens = []
new_labels = []
for i in range(1, len(tokens)-1):
    if tokens[i].startswith("▁"):
        current_word = tokens[i][1:]
        new_labels.append(list(train_encoder.keys())[list(train_encoder.values()).index(label_indices[0][i])])
        for j in range(i+1, len(tokens)-1):
            if not tokens[j].startswith("▁"):
                current_word = current_word + tokens[j]
            if tokens[j].startswith("▁"):
                break
        new_tokens.append(current_word)
print(new_tokens)
print(new_labels)

['विकिपीडिया', 'सभी', 'विषयों', 'पर', 'परामाणिक', 'और', 'उपयोग', 'परिवरतन', 'व', 'पुनरवितरण', 'के', 'लिए', 'सवतनतर', 'जञानकोश', 'बनाने', 'का', 'एक', 'बहुभाषीय', 'परकलप', 'है']
['blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'blank', 'end']


In [32]:
sentence = test_sentence.split(' ')

In [33]:
assert len(sentence)==len(new_tokens)==len(new_labels)

In [34]:
print(len(sentence))
print(len(new_tokens))
print(len(new_tokens))

20
20
20


In [35]:
punctuation_dict = {'hyp': '-', 'qm': '? ', 'comma': ', ', 'end': '। ', 'blank': ' ', 'ex': '! '}

In [36]:
s = ''
for word, punctuation in zip(sentence, new_labels):
    s = s + word + punctuation_dict[punctuation]
s

'विकिपीडिया सभी विषयों पर प्रामाणिक और उपयोग परिवर्तन व पुनर्वितरण के लिए स्वतन्त्र ज्ञानकोश बनाने का एक बहुभाषीय प्रकल्प है। '