In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Reload edited project modules (imports, utils, models, ...) automatically
# before each cell runs, so changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

from imports import *
from utils import *
from constants import *
from models import *
from trains import train_model
from predicts import predict_model

# Restrict the process to physical GPU 0.
# CUDA_VISIBLE_DEVICES is only read when CUDA is initialised, so it has to be
# exported *before* the first torch.cuda call; setting it afterwards is a no-op.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Make (visible) device 0 the default CUDA device for this kernel.
torch.cuda.set_device(0)

# Preprocessing

In [None]:
# with open("/hdd/ammarinjtk/OntoBiotope_BioNLP-ST-2016.obo") as file:  
#     data = [x.split("\n") for x in file.read().split("\n\n")[2:]]
# synonym_pattern = r"synonym: \"([\w\s]+)\" EXACT"
# name_pattern = r"name: ([\w\s]+)"
# is_a_pattern = r"is_a: [\w\s\d:]+ ! ([\w ]+)"

# ontologies = []

# for x in data:
#     ontology_dict = {
#         'name': '',
#         'synonym': [],
#         'is_a': ''
#     }

#     for y in x:

#         if re.match(name_pattern, y):
#             ontology_dict['name'] = re.match(name_pattern, y).group(1)
#         elif re.match(synonym_pattern, y):
#             ontology_dict['synonym'].append(re.match(synonym_pattern, y).group(1))
#         elif re.match(is_a_pattern, y):
#             ontology_dict['is_a'] = re.match(is_a_pattern, y).group(1)
#         else:
#             continue

#     ontologies.append(ontology_dict)
# synonym_dict = {}
# for x in [x for x in ontologies if len(x['synonym']) > 0]:
#     for synonym in x['synonym']:
#         synonym_dict[synonym.lower()] = x['name'].lower()

In [None]:
# dataloaders = preprocess("/home/ammarinjtk/pytorch/Corpus_BB3/", synonym_dict)

In [None]:
# # Non-Deterministic behaviour
# for key, dataloader in dataloaders.items():
#     random.shuffle(dataloader)

In [None]:
# with open('./data/original_data/train.json', 'w') as outfile:  
#     json.dump(dataloaders['train'], outfile)
    
# with open('./data/original_data/dev.json', 'w') as outfile:  
#     json.dump(dataloaders['dev'], outfile)
    
# with open('./data/original_data/test.json', 'w') as outfile:  
#     json.dump(dataloaders['test'], outfile)

In [None]:
# Read the preprocessed train/dev/test splits from disk into one dict keyed by
# split name. Each value is whatever the JSON file holds (indexed and measured
# with len() by the following cells).
dataloaders = {}

for split_name in ("train", "dev", "test"):
    with open("./data/preprocessed_BB/{}.json".format(split_name), "r") as read_file:
        dataloaders[split_name] = json.load(read_file)

In [None]:
# Number of examples in each split, shown as (train, dev, test).
tuple(len(dataloaders[split_name]) for split_name in ('train', 'dev', 'test'))

In [None]:
# Display the first training example to inspect the structure of one record.
dataloaders['train'][0]

In [None]:
# Load the pretrained word vectors from a binary word2vec-format file.
# The result is passed to build_vocab / build_pretrain_embedding_matrix below.
w2v_model = word2vec.KeyedVectors.load_word2vec_format('/hdd/ammarinjtk/wikipedia-pubmed-and-PMC-w2v.bin', binary=True)
# Alternative kept for reference: vectors from a locally trained gensim model.
# w2v_model = gensim.models.Word2Vec.load("/hdd/ammarinjtk/li_reimplement/models/5_epochs.model").wv

In [None]:
# Global max relative distance: the largest absolute value found in
# full_dist1 / full_dist2 over every example of every split.
# NOTE(review): the values come from json.load, so `+` presumably concatenates
# two Python lists rather than adding element-wise - confirm.
per_example_max_distance = [
    np.max(np.abs(example['full_inputs']['full_dist1'] + example['full_inputs']['full_dist2']))
    for split_name in ('train', 'dev', 'test')
    for example in dataloaders[split_name]
]
max_distance = float(np.max(per_example_max_distance))

# Index maps for every categorical input, built from all splits and the w2v model.
(word_to_ix,
 pos_to_ix,
 distance_to_ix,
 dependency_to_ix,
 char_to_ix,
 in_vocab_count) = build_vocab(dataloaders, w2v_model)

# Initial embedding matrices for words and relative distances.
(pretrained_embedding_matrix,
 distance_pretrain_embedding_matrix) = build_pretrain_embedding_matrix(
    w2v_model, word_to_ix, distance_to_ix, max_distance)

In [None]:
# Longest shortest-path sequence (token / POS / dependency) over all splits.
shortest_path_inputs = [
    example['shortest_inputs']
    for split_name in ('train', 'dev', 'test')
    for example in dataloaders[split_name]
]
glob_shortest_max_sentence_length = np.max([
    max(len(path['shortest_token']), len(path['shortest_pos']), len(path['shortest_dep']))
    for path in shortest_path_inputs
])

In [None]:
# Longest full-sentence sequence (token / POS / dependency) over all splits;
# passed to the model constructor as the maximum sentence length.
full_sentence_inputs = [
    example['full_inputs']
    for split_name in ('train', 'dev', 'test')
    for example in dataloaders[split_name]
]
glob_max_sentence_length = np.max([
    max(len(sentence['full_token']), len(sentence['full_pos']), len(sentence['full_dep']))
    for sentence in full_sentence_inputs
])

In [None]:
batch_size = 4

# Seed every random number generator in use with the same value so that
# weight initialisation and shuffling are repeatable.
SEED = 523454
for seed_everything in (torch.cuda.manual_seed_all, torch.manual_seed, random.seed, np.random.seed):
    seed_everything(SEED)

# Model sized from the vocabularies and the maximum sentence length computed above.
model = Frankenstein(
    len(word_to_ix), len(pos_to_ix), len(distance_to_ix), len(dependency_to_ix),
    glob_max_sentence_length,
    pretrained_embedding_matrix, distance_pretrain_embedding_matrix,
    batch_size,
    drop=0.5, wdrop=0.3, edrop=0.3, idrop=0.3,
    hidden_dim=64, window_sizes=[3, 5, 7], h=1, multihead_sizes=3,
)

criterion = nn.CrossEntropyLoss()

# Only parameters that require gradients are handed to the optimizer, so any
# frozen weights are left untouched.
trainable_parameters = [param for param in model.parameters() if param.requires_grad]
optimizer_ft = optim.Adam(trainable_parameters, lr=1e-3, weight_decay=1e-7)

# Using ELMo and BERT models

In [None]:
import json_lines

options_file = "/hdd/ammarinjtk/ELMO_model/revised_bacteria_pubmed/options.json"
weight_file = "/hdd/ammarinjtk/ELMO_model/revised_bacteria_pubmed/headentity_finetune_weights.hdf5"

# Fine-tuned ELMo encoder, later handed to train_model / predict_model.
# NOTE(review): the positional `1` is presumably the number of output
# representations - confirm against the Elmo class in use.
elmo_model = Elmo(options_file, weight_file, 1, dropout=0)

# Pre-extracted BERT features, one JSON record per line.
with open('/hdd/ammarinjtk/BERT_features/synonym_revised_headentity/finetune_bert.jsonl', 'rb') as f:  # json_lines expects binary mode
    finetune_berts = list(json_lines.reader(f))

# The sentences that were fed to BERT, one space-joined token string per line.
with open('/hdd/ammarinjtk/BERT_features/synonym_revised_headentity/tkn.txt', 'r') as f:
    full_bert_tkns = f.read()

# Index every example by its space-joined token string once, so each line of
# tkn.txt is a dictionary lookup instead of a rescan of all examples.
examples_by_sentence = {}
for example in dataloaders['train'] + dataloaders['dev'] + dataloaders['test']:
    sentence_key = " ".join(example['full_inputs']['full_token'])
    examples_by_sentence.setdefault(sentence_key, []).append(example)

# Number of (line, example) matches; useful to verify that every example got features.
dataloader_count = 0
for idx, full_bert_tkn in enumerate(full_bert_tkns.split('\n')):
    matched_examples = examples_by_sentence.get(full_bert_tkn)
    if not matched_examples:
        continue

    # Use the BERT record belonging to THIS line. The previous code always read
    # finetune_berts[0], giving every example the features of the first sentence.
    # Assumes finetune_bert.jsonl holds one record per tkn.txt line, in the same
    # order - TODO confirm.
    # NOTE(review): only features[0] is used (presumably the [CLS] token) - confirm.
    layers = finetune_berts[idx]['features'][0]['layers']

    for example in matched_examples:
        dataloader_count += 1
        # Sum the extracted layers into one vector; a fresh array per example
        # so no two examples share storage.
        example['bert_features'] = np.sum([np.array(layer['values']) for layer in layers], axis=0)

# Training

In [None]:
# Train on the train split and validate on the dev split.
train_out = train_model(
    model, elmo_model,
    dataloaders['train'], dataloaders['dev'],
    word_to_ix, pos_to_ix, distance_to_ix, dependency_to_ix,
    criterion, optimizer_ft,
    lr_scheduler=None,
    num_epochs=5,
    early_stopped_patience=1,
    batch_size=batch_size,
)

# train_model returns a 4-tuple; rebind `model` to the returned model.
model, train_f1, val_f1, history = train_out

# Prediction

In [None]:
# Run the trained model on the test split. predict_model returns the gold
# labels, predicted labels, raw predictions and three sets of attention scores.
# NOTE(review): optimizer_ft is passed to a prediction routine - confirm that
# predict_model really needs it.
y_true, y_pred, predictions, self_attn_scores, ent_attn_scores, multihead_attn_scores = predict_model(model, elmo_model, 
                                  dataloaders['test'], word_to_ix, pos_to_ix, distance_to_ix,
                                  dependency_to_ix, char_to_ix, batch_size, optimizer_ft)

In [None]:
# len(y_pred), len(dataloaders['test'])

# Prediction (.a2) file generation

In [None]:
import os

# Name of the directory (under /hdd/ammarinjtk/) that receives the .a2 files.
model_dir_name = "test_prediction"

# os.mkdir raises FileExistsError when the directory already exists, which
# keeps a previous run's predictions from being silently overwritten.
os.mkdir('/hdd/ammarinjtk/{}'.format(model_dir_name))

In [None]:
# Alias for the test split; y_pred from the prediction cell is indexed in the
# same order as this list when the .a2 lines are generated.
test_data = dataloaders['test']

In [None]:
# Parse the test-set XML and collect the origId attribute of every <document>
# element, i.e. the full list of documents that need an .a2 file.
file = minidom.parse("/home/ammarinjtk/pytorch/Corpus_BB3/BioNLP-ST-2016_BB-event_{}.xml".format('test'))
docs = file.getElementsByTagName("document")
all_test_files = [doc.getAttribute("origId") for doc in docs]

In [None]:
# Turn every positive prediction into one "Lives_In" relation line, grouped by
# the .a2 file it belongs to.
write_dict = {}          # .a2 file path -> set of relation lines for that document
relation_idx_dict = {}   # document id   -> number of relations emitted so far
pred_test_files = set()  # document ids with at least one positive prediction

# Entity origIds look like "<BB-event-N><sep><Tn>"; group 2 is the entity id.
# NOTE(review): the "." separator is unescaped and therefore matches any character.
entity_id_pattern = re.compile(r'(BB-event-\d+).(T\d+)')

for idx, input_dict in enumerate(test_data):
    # Only examples predicted as the positive class (1) produce a relation.
    if y_pred[idx] != 1:
        continue

    document_idx = input_dict['document_id']
    pred_test_files.add(document_idx)

    entity_tag = input_dict['entity_pair']
    entity_idx_to_origId = input_dict['entity_idx_to_origId']
    first_entity = entity_id_pattern.match(entity_idx_to_origId[entity_tag[0]]).group(2).upper()
    second_entity = entity_id_pattern.match(entity_idx_to_origId[entity_tag[1]]).group(2).upper()

    # Relation ids are numbered from 1 within each document.
    relation_idx_dict[document_idx] = relation_idx_dict.get(document_idx, 0) + 1

    a2_path = "/hdd/ammarinjtk/{}/{}.a2".format(model_dir_name, document_idx)
    # The first entity of the pair is written as Bacteria, the second as Location.
    write_dict.setdefault(a2_path, set()).add(
        "R{}\tLives_In Bacteria:{} Location:{}\n".format(
            relation_idx_dict[document_idx], first_entity, second_entity))

In [None]:
# Write one .a2 file per document, with relation lines ordered by their numeric
# relation id (the integer after the leading "R").
for key, value in write_dict.items():
    ordered_lines = sorted(value, key=lambda line: int(line.split('\t')[0][1:]))
    write_str = "".join(ordered_lines)

    # `with` closes the handle even if the write fails.
    with open(key, "w") as f:
        f.write(write_str)

In [None]:
# (documents listed in the test XML, documents with at least one positive prediction)
len(all_test_files), len(pred_test_files)

In [None]:
# Every test document needs an .a2 file, so create an empty one for each
# document that received no positive prediction.
pred_test_files = list(pred_test_files)
for test_file in all_test_files:
    if test_file not in pred_test_files:
        print(test_file)
        # Append mode creates the file when missing and never truncates an
        # existing one; `with` guarantees the handle is closed.
        with open("/hdd/ammarinjtk/{}/{}.a2".format(model_dir_name, test_file), "a+") as f:
            f.write("")

In [None]:
import glob
# Count the .a2 files now present in the output directory.
# NOTE(review): expected to equal len(all_test_files) once every document has a file - verify.
len(glob.glob(f"/hdd/ammarinjtk/{model_dir_name}/*.a2"))

In [None]:
import shutil
# Zip the whole output directory. make_archive appends ".zip" to the base name,
# so the archive is written next to the directory as <model_dir_name>.zip.
shutil.make_archive(f"/hdd/ammarinjtk/{model_dir_name}",
                    'zip',
                    f"/hdd/ammarinjtk/{model_dir_name}")