In [None]:
! pip install transformers

In [None]:
! pip install scikit-learn

In [None]:
import os
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import Dataset
 
from pprint import pprint
import random
import numpy as np
from typing import List, Dict

import json

import torch

from torch.utils.data.dataloader import DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np

from transformers import BertTokenizer, BertModel, BertConfig

In [None]:
# Show the installed PyTorch version so outputs can be tied to a build.
print(f"{torch.__version__}")

In [None]:
# All data paths below are relative to the BIO-format resources folder.
path = "../resources/bio_format"
# Resolve the target once and remember it across re-runs: chdir with a relative
# path is not idempotent, so re-running this cell would otherwise descend again.
BIO_DIR = globals().get("BIO_DIR") or os.path.abspath(path)
os.chdir(BIO_DIR)

In [None]:
# Later cells move the model to CUDA with .cuda(), so fail early with a readable
# message when no GPU is available; then display the name of the current device.
if not torch.cuda.is_available():
    raise RuntimeError("This notebook requires a CUDA-capable GPU.")
torch.cuda.get_device_name(torch.cuda.current_device())

In [None]:
# Corpus to pre-annotate; its BIO file lives at <language>/<language>.tsv (lower-case).
language = "Spanish"
data_file = "{0}/{0}.tsv".format(language.lower())

In [None]:

def _close_sentence(sentence: str, expression: str) -> dict:
    """Build the record stored for one finished sentence.

    A sentence counts as idiomatic exactly when some of its tokens were collected
    into the idiom expression (a B-IDIOM, or an I-IDIOM continuing one).
    """
    expression = expression.strip()
    return {"expression": expression,
            "text": sentence.strip(),
            "idiomatic": expression != ""}


data = {}      # sentence index -> {"expression", "text", "idiomatic"}
idx = 0        # index assigned to the sentence currently being read
tokens = 0     # number of non-blank lines seen (well-formed or not)
count = 0      # number of malformed (not "token<TAB>tag") lines skipped

# Explicit UTF-8: the corpora contain non-ASCII text and the platform default
# encoding is not guaranteed to be UTF-8.
with open(data_file, encoding="utf-8") as f:
    sentence = ""
    expression = ""
    prev_tag = "O"

    for line in f:
        # A blank (or whitespace-only) line terminates the current sentence.
        if not line.strip():
            data[idx] = _close_sentence(sentence, expression)
            idx += 1
            sentence = ""
            expression = ""
            prev_tag = "O"
            continue

        tokens += 1
        fields = line.strip().split("\t")
        if len(fields) != 2:
            print("count " + str(count))
            count += 1
            continue

        token, tag = fields
        # NOTE(review): tokens are concatenated with no separator, so "text" and
        # "expression" contain no spaces for space-delimited languages, while the
        # word-count filters used later call .split(). Confirm against the TSV.
        sentence += token
        if tag == "B-IDIOM" or (tag == "I-IDIOM" and prev_tag in ("B-IDIOM", "I-IDIOM")):
            expression += token
            prev_tag = tag
        elif tag == "O":
            prev_tag = "O"
        # An I-IDIOM without a preceding B-IDIOM (or any other tag) is not added
        # to the expression and leaves prev_tag unchanged.

    # The file may not end with a blank line: keep its final sentence too.
    if sentence:
        data[idx] = _close_sentence(sentence, expression)
        idx += 1

print(len(data))
print(tokens)

# Preprocessing

In [None]:
# Seed every RNG source in use (Python, NumPy, PyTorch) so results are replicable.
SEED = 2

for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn
torch.backends.cudnn.deterministic = True

In [None]:
# Pretrained multilingual cased BERT: config (with hidden states exposed),
# encoder weights and tokenizer, all loaded from the same checkpoint name.
model_name = 'bert-base-multilingual-cased'

bert_config = BertConfig.from_pretrained(model_name, output_hidden_states=True)
bert_model = BertModel.from_pretrained(model_name, config=bert_config)
bert_tokenizer = BertTokenizer.from_pretrained(model_name)

In [None]:
class IdiomDataset(Dataset):
    """Torch dataset of (expression, context) pairs for idiomatic sentences.

    Each item is ``(idx, expression, text, expression_ids, context_ids)`` where
    ``context_ids`` encodes ``text`` with the expression cut out. Sentences with an
    empty expression, or whose context has ``max_context_chars`` characters or
    more, are skipped.
    """

    def __init__(self,
                 dataset,
                 tokenizer: "BertTokenizer",
                 device="cuda",
                 max_context_chars=300,
                ) -> None:
        """
        Args:
            dataset: mapping from sentence id to a dict with at least the
                "expression" and "text" keys.
            tokenizer: object with ``encode(text, add_special_tokens=...)``
                returning a list of token ids.
            device: stored as ``self.device``; not used by this class.
            max_context_chars: exclusive upper bound on the context length in
                characters.
        """
        self.encoded_data = []

        self.dataset = dataset
        self.tokenizer = tokenizer
        self.device = device
        self.max_context_chars = max_context_chars
        self.__init_encoded_data(self.dataset)

    @staticmethod
    def _remove_expression(text, expression):
        """Return ``text`` with the first occurrence of ``expression`` cut out.

        The text is returned unchanged when it *is* the expression (nothing would
        be left to encode) or when the expression does not occur verbatim in it
        (e.g. a discontinuous idiom), instead of slicing with ``find() == -1``.
        """
        if expression == text:
            return text
        start = text.find(expression)
        if start == -1:
            return text
        return text[:start] + text[start + len(expression):]

    def __init_encoded_data(self, dataset):
        """Fill ``self.encoded_data`` from ``dataset`` (see the class docstring)."""
        for idx in tqdm(dataset):
            expression = dataset[idx]["expression"]
            text = dataset[idx]["text"]
            # Non-idiomatic sentences are never kept: skip them before tokenizing.
            if expression == "":
                continue

            context = self._remove_expression(text, expression)
            if len(context) >= self.max_context_chars:
                continue

            tokenized_e = torch.tensor(self.tokenize_mention(expression, self.tokenizer, True))
            tokenized_context = torch.tensor(self.tokenize_mention(context, self.tokenizer, True))
            self.encoded_data.append((idx,
                                      expression,
                                      text,
                                      tokenized_e,
                                      tokenized_context))

    def tokenize_mention(self, sent, tokenizer, special_tokens):
        """Encode ``sent`` to token ids, optionally adding special tokens."""
        encoded_sentence = tokenizer.encode(sent, add_special_tokens = special_tokens)
        return encoded_sentence

    def tokenize_description(self, sent, tokenizer, window):
        """Encode ``sent`` with special tokens; ``window`` is currently ignored."""
        encoded_sentence = tokenizer.encode(sent, add_special_tokens = True)
        return encoded_sentence

    def __len__(self):
        return len(self.encoded_data)

    def __getitem__(self, idx: int):
        return self.encoded_data[idx]


# Models

In [None]:
class BERT(nn.Module):
    """Frozen dual encoder scoring how similar an expression is to its context.

    Both towers are the same encoder instance (weights are shared), and all of
    its parameters are frozen: this module is only used for inference.
    """

    def __init__(self, hparams, encoder=None):
        """
        Args:
            hparams: object exposing a ``dropout`` attribute.
            encoder: module called as ``encoder(input_ids, attention_mask=...)``
                whose output's first element is the last hidden state. Defaults
                to the notebook-level ``bert_model``.
        """
        super(BERT, self).__init__()
        pprint(hparams)

        self.hparams = hparams

        if encoder is None:
            encoder = bert_model
        self.expression_encoder = encoder
        self.context_encoder = encoder

        self.cosine_similarity = nn.CosineSimilarity(dim=-1, eps=1e-6)

        # Currently not applied in forward().
        self.dropout = nn.Dropout(hparams.dropout)

        # Freezes the shared encoder (parameters() de-duplicates the alias).
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, expression, context, mask1, mask2):
        """Return the cosine similarity between expression and context embeddings.

        Args:
            expression: padded token ids of the expressions, ``(batch, seq_e)``.
            context: padded token ids of the contexts, ``(batch, seq_c)``.
            mask1: attention mask for ``expression`` (1 = real token, 0 = padding).
            mask2: attention mask for ``context``.

        Returns:
            Tensor of shape ``(batch,)`` with one similarity per pair.
        """
        # Send inputs to wherever the model lives instead of hard-wiring CUDA.
        device = next(self.parameters()).device
        expression, mask1 = expression.to(device), mask1.to(device)
        context, mask2 = context.to(device), mask2.to(device)

        # Expression: sum of token states over real (non-padding) positions, so a
        # pair's score does not depend on how much it was padded in its batch.
        hidden_expression = self.expression_encoder(expression, attention_mask=mask1)[0]
        weights = mask1.unsqueeze(-1).to(hidden_expression.dtype)
        embedding_expression = (hidden_expression * weights).sum(dim=1)

        # Context: the state at position 0 ([CLS] when special tokens are added).
        hidden_context = self.context_encoder(context, attention_mask=mask2)[0]
        embedding_context = hidden_context[:, 0, :]

        return self.cosine_similarity(embedding_expression, embedding_context)

# Prediction (inference only, no training)

In [None]:
# Inference only: score candidates with the frozen model (nothing is trained here).

import math 

class Predict():
    """Flag idiom candidates that the frozen model judges to be used literally.

    For each (expression, context) pair the model returns a cosine similarity;
    pairs scoring above a language-dependent threshold are reported as
    "not idiomatic".
    """

    # Languages scored without the whitespace word-count filter and with their
    # own similarity threshold.
    CJK_LANGUAGES = ("Chinese", "Japanese")

    def __init__(self,
                model:nn.Module, 
                tokenizer,
                language=None,
                threshold=0.4,
                cjk_threshold=0.55,
                word_count_floor=4):
        """
        Args:
            model: module called as ``model(expressions, contexts, mask1, mask2)``
                returning one similarity per batch row.
            tokenizer: stored on the instance; not used by this class.
            language: corpus language; when None, the notebook-level ``language``
                is read at prediction time.
            threshold: similarity above which a non-CJK pair is flagged.
            cjk_threshold: similarity above which a CJK pair is flagged.
            word_count_floor: a non-CJK sentence is flagged only if it has
                strictly more whitespace-separated words than this.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.language = language
        self.threshold = threshold
        self.cjk_threshold = cjk_threshold
        self.word_count_floor = word_count_floor

    def padding_mask(self, batch):
        """Return an int64 mask: 0 where ``batch`` holds padding id 0, else 1."""
        return (batch != 0).to(torch.int64)

    def normalize(self, m):
        """Min-max scale each row of ``m`` to [0, 1] (not used by ``predict``).

        NOTE: a constant row has row_max == row_min and divides by zero.
        """
        row_min, _ = m.min(dim=1, keepdim=True)
        row_max, _ = m.max(dim=1, keepdim=True)
        return (m - row_min) / (row_max - row_min)

    def _is_literal(self, score, text, is_cjk):
        """Decide whether one scored pair should be flagged as not idiomatic."""
        if is_cjk:
            return score > self.cjk_threshold
        return score > self.threshold and len(text.split()) > self.word_count_floor

    def predict(self,
            dataset:Dataset):
        """Score every batch and return the ids of the flagged sentences.

        Args:
            dataset: iterable of collated batches
                ``(ids, expression_texts, texts, expression_ids, context_ids)``,
                e.g. a DataLoader using ``collate``.

        Returns:
            List of 1-tuples ``(sentence_id,)``, one per flagged pair, in
            iteration order (the same shape batch_size=1 collation yields, so
            callers indexing ``item[0]`` keep working).
        """
        print("\nPredicting...")
        self.model.eval()

        lang = language if self.language is None else self.language
        is_cjk = lang in self.CJK_LANGUAGES

        not_idiomatic = []
        for ids, exprs, texts, expressions, contexts in tqdm(dataset):
            mask1 = self.padding_mask(expressions)
            mask2 = self.padding_mask(contexts)

            with torch.no_grad():
                similarities = self.model(expressions, contexts, mask1, mask2)

            # Evaluate every row of the batch, not only the first one.
            scores = similarities.reshape(-1).tolist()
            for sent_id, e, text, score in zip(ids, exprs, texts, scores):
                if self._is_literal(score, text, is_cjk):
                    not_idiomatic.append((sent_id,))
                    print(e, "\n", text)
                    print(score)
                    print("\n\n\n")

        return not_idiomatic

# Dataset and DataLoader

In [None]:
# Encode every idiomatic sentence as (expression, context) token ids.
dataset = IdiomDataset(dataset=data, tokenizer=bert_tokenizer)
print(len(dataset))

In [None]:
def collate(elems: tuple) -> tuple:
    """Merge dataset items into one batch, right-padding the token-id tensors with 0.

    Returns ``(ids, expressions_text, texts, padded_expression_ids,
    padded_context_ids)``; the first three are tuples, the last two are int64
    tensors of shape ``(batch, longest_sequence)``.
    """
    ids, e, texts, expressions, contexts = zip(*elems)

    def _pad(sequences):
        # batch_first so rows are items; 0 is the padding id checked by the masks.
        return pad_sequence(sequences, batch_first=True, padding_value=0).to(torch.int64)

    return ids, e, texts, _pad(expressions), _pad(contexts)


dataloader = DataLoader(
    dataset,
    batch_size=1,        # one sentence per forward pass
    shuffle=False,       # iterate in dataset order
    collate_fn=collate,
)

print(len(dataloader))

# Hyperparameters, Model Setup and Prediction

In [None]:
class HParams:
    """Hyper-parameters handed to the BERT module."""

    dropout: float = 0.25  # rate for nn.Dropout


params = HParams()

In [None]:
# Build the frozen dual encoder on the GPU and display its module tree.
el_model = BERT(hparams=params)
el_model = el_model.cuda()
el_model

In [None]:
# Wrap the model in the inference helper.
predictor = Predict(el_model, bert_tokenizer)

In [None]:
# Ids (as 1-tuples) of sentences whose expression looks literal to the model.
not_idiomatic = predictor.predict(dataset=dataloader)

In [None]:
# Mark the sentences flagged by the predictor as non-idiomatic, then save all records.
is_cjk = language in ["Chinese", "Japanese"]
for flagged in not_idiomatic:
    sent_id = flagged[0]
    # For space-delimited languages only override sentences with more than 8 words;
    # .split() (unlike .split(" ")) does not count empty strings from repeated spaces.
    if is_cjk or len(data[sent_id]["text"].split()) > 8:
        data[sent_id]["idiomatic"] = False

print(len(data))

out_path = f"../json_format/{language.lower()}/{language}_preannotations_dual.json"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
# ensure_ascii=False writes raw non-ASCII characters, so the file must be opened
# as UTF-8 explicitly rather than with the platform default encoding.
with open(out_path, "w", encoding="utf-8") as f:
    json.dump(data, f, ensure_ascii=False)