In [1]:
from tqdm.notebook import tqdm
import numpy as np
from scipy import stats
import torch
import random
import json
import copy
import time

class DataHandler():
    """Load a JSON corpus and keep the 'sents' entry of every record.

    Attributes:
        data: list holding, for each loaded record, the value stored under 'sents'.
    """

    def __init__(self, data_src):
        """Resolve ``data_src`` to a file, parse it and extract each record's 'sents'."""
        records = self.load_data(self.get_path(data_src))
        self.data = [record['sents'] for record in records]

    def get_path(self, data_src):
        """Return the JSON path registered for ``data_src`` ('wiki' or 'WSJ').

        Raises:
            KeyError: if ``data_src`` is not a registered corpus name.
        """
        base_dir = '/home/alta/Conversational/OET/al826/2021'
        known_sources = {
            'wiki': f'{base_dir}/data/unlabeled/wiki_100000.json',
            'WSJ': f'{base_dir}/data/coherence/WSJ_train.json',
        }
        return known_sources[data_src]

    def load_data(self, path):
        """Parse and return the JSON document stored at ``path``."""
        with open(path) as handle:
            parsed = json.load(handle)
        return parsed


In [2]:
#utils.py
from transformers import BertTokenizerFast, RobertaTokenizerFast

class UtilClass:
    """Bundle the tokenizer, special-token ids and optional static embeddings of a system.

    Attributes:
        tokenizer: a Hugging Face fast tokenizer for 'bert' / 'electra' / 'roberta',
            or a ``FakeTokenizer`` over the embedding vocabulary for 'glove' / 'word2vec'.
        CLS, SEP: single-element id lists for the transformer systems, empty lists
            for the static-embedding systems.
        embeddings: ``torch.Tensor`` of word vectors for 'glove' / 'word2vec', else ``None``.
    """

    def __init__(self, system, lim=300_000):
        """Set up the tokenizer (and embeddings) for ``system``.

        Args:
            system: one of 'bert', 'electra', 'roberta', 'glove', 'word2vec'.
            lim: maximum number of embedding-file lines to read (static embeddings only).

        Raises:
            ValueError: if ``system`` is not one of the supported names.
        """
        if system in ['bert', 'electra']:
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
            self.CLS, self.SEP = [101], [102]
            self.embeddings = None
            
        elif system == 'roberta':
            self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
            self.CLS, self.SEP = [0], [2]
            self.embeddings = None

        elif system in ['glove', 'word2vec']:
            path = self.get_embedding_path(system)
            tok_dict, embed_matrix = self.read_embeddings(path, lim)
            self.tokenizer = FakeTokenizer(tok_dict)
            self.embeddings = torch.Tensor(embed_matrix)
            self.CLS, self.SEP = [], []
        
        else:
            raise ValueError('invalid system')
            
    def get_embedding_path(self, name):
        """Return the embedding text file for ``name`` ('glove' or 'word2vec').

        Raises:
            ValueError: for any other name.
        """
        base_dir = '/home/alta/Conversational/OET/al826/2021'
        if name == 'glove': path = f'{base_dir}/data/embeddings/glove.840B.300d.txt'
        elif name == 'word2vec': path = f'{base_dir}/data/embeddings/word2vec.txt'
        else: raise ValueError('invalid word embedding system') 
        return path 

    def read_embeddings(self, path, limit=300_000):
        """Read up to ``limit`` lines of a text embedding file.

        Each kept line is ``word v1 ... v300``. Lines with a different number of
        values, blank lines and repeated words are skipped.

        Args:
            path: text file with one whitespace-separated embedding per line.
            limit: maximum number of lines read after the first (skipped) line.

        Returns:
            ``(tok_dict, embed_matrix)`` where ``tok_dict`` maps word -> row index and
            ``embed_matrix`` is a list of 300-float rows. ``embed_matrix[tok_dict[w]]``
            is the vector of ``w``; row 0 is an all-zero vector for '[UNK]'.
        """
        embed_dim = 300
        tok_dict = {'[UNK]':0}
        # '[UNK]' owns id 0, so it needs its own row. Without it every word id
        # pointed one row past its vector and the last id was out of range.
        embed_matrix = [[0.0] * embed_dim]
        with open(path, 'r', encoding='utf-8') as file:
            # The first line is always discarded.
            # NOTE(review): presumably a "<count> <dim>" header; confirm the GloVe file has one.
            _ = next(file, None)
            for _, line in tqdm(zip(range(limit), file), total=limit):
                fields = line.split()
                if not fields:
                    continue  # blank line: nothing to unpack
                word, *embedding = fields
                if len(embedding) == embed_dim and word not in tok_dict:
                    embed_matrix.append([float(i) for i in embedding])
                    tok_dict[word] = len(tok_dict)
        return tok_dict, embed_matrix

# FakeTokenizer mirrors the Hugging Face tokenizer call interface (`tokenizer(text).input_ids`) so the rest of the code can treat both kinds of tokenizer alike.
class FakeTokenizer:
    """Whitespace tokenizer over a fixed vocabulary, callable like a Hugging Face tokenizer.

    Attributes:
        tok_dict: mapping token -> id.
        reverse_dict: mapping id -> token, derived from ``tok_dict``.
    """

    def __init__(self, tok_dict):
        self.tok_dict = tok_dict
        # Inverse lookup used by decode().
        self.reverse_dict = dict((idx, tok) for tok, idx in self.tok_dict.items())

    def tokenize_word(self, w):
        """Return the id of ``w``, or 0 when ``w`` is not in the vocabulary."""
        return self.tok_dict[w] if w in self.tok_dict else 0

    def tokenize(self, x):
        """Split ``x`` on whitespace and return an object whose ``input_ids`` lists the ids."""
        ids = list(map(self.tokenize_word, x.split()))
        # Throw-away class so the result exposes `.input_ids` as an attribute.
        holder_cls = type('TokenizedInput', (), {})
        encoded = holder_cls()
        encoded.input_ids = ids
        return encoded

    def decode(self, x):
        """Map the ids in ``x`` back to tokens and join them with single spaces."""
        tokens = []
        for idx in x:
            tokens.append(self.reverse_dict[idx])
        return ' '.join(tokens)

    def __call__(self, x):
        """Alias for :meth:`tokenize`."""
        return self.tokenize(x)

In [3]:
#corrupter.py
import time

def create_corrupted_set(coherent, num_fake, schemes=[1], args=[1]):
    """Collect up to ``num_fake`` distinct corruptions of ``coherent``.

    Candidates come from ``create_incoherent``. A candidate that equals
    ``coherent`` or repeats an earlier one counts as a failure, and the search
    gives up after 50 failures.

    Returns:
        The list of accepted corruptions (possibly shorter than ``num_fake``).
    """
    corrupted_set, fail = [], 0
    while True:
        if len(corrupted_set) >= num_fake or fail >= 50:
            break
        _, candidate = create_incoherent(coherent, schemes, args)
        if candidate in corrupted_set or candidate == coherent:
            fail += 1
        else:
            corrupted_set.append(candidate)
    return corrupted_set

def create_incoherent(coherent, schemes, args):
    """Corrupt ``coherent`` with one scheme drawn at random from ``schemes``.

    Scheme ids: 1 full shuffle, 2 random swaps, 3 neighbour swaps, 4 deletion
    (all three take ``args[0]``), 5 word swaps inside sentences (takes ``*args``).

    Returns:
        ``(coherent, incoherent)``; ``coherent`` is the object that was passed in.

    Raises:
        ValueError: if the drawn scheme id is not one of 1-5.
    """
    r = random.choice(schemes)
    if   r == 1: incoherent = random_shuffle(coherent)
    elif r == 2: incoherent = random_swaps(coherent, args[0])
    elif r == 3: incoherent = random_neighbour_swaps(coherent, args[0])
    elif r == 4: incoherent = random_deletion(coherent, args[0])
    elif r == 5: incoherent = local_word_swaps(coherent, *args)
    else:
        # ValueError subclasses Exception, so existing `except Exception` handlers still match.
        raise ValueError(f'invalid corruption scheme {r!r}; expected an integer from 1 to 5')
    return coherent, incoherent

def random_shuffle(conversation):
    """Return a shuffled copy of ``conversation``; the input list is not modified."""
    shuffled = conversation.copy()
    random.shuffle(shuffled)
    return shuffled

def random_swaps(conversation, num_swaps=1):
    """Return a copy of ``conversation`` with up to ``num_swaps`` disjoint pairs exchanged.

    Positions are drawn without replacement, so no element moves twice. The
    number of swaps is capped at ``len(conversation) // 2``.
    """
    swapped = conversation.copy()
    size = len(swapped)
    # Two distinct positions per swap, never more than the list can supply.
    how_many = min(2*num_swaps, 2*int(size/2))
    picked = random.sample(range(0, size), how_many)

    for first, second in zip(picked[0::2], picked[1::2]):
        swapped[first], swapped[second] = swapped[second], swapped[first]
    return swapped

def random_neighbour_swaps(conversation, num_swaps=1):
    """Return a copy of ``conversation`` with randomly chosen adjacent pairs exchanged.

    ``num_swaps`` distinct positions ``i >= 1`` are drawn and element ``i`` is
    exchanged with element ``i - 1``, in the order drawn. The count is capped at
    the number of adjacent pairs, so documents with fewer than two items come
    back unchanged instead of raising.
    """
    incoherent = conversation.copy()
    candidates = range(1, len(incoherent))
    # random.sample raises if asked for more items than exist; clamp like random_swaps does.
    num_swaps = min(num_swaps, len(candidates))
    indices = random.sample(candidates, num_swaps)
    for i in indices:
        incoherent[i], incoherent[i-1] = incoherent[i-1], incoherent[i]
    return incoherent

def random_deletion(conversation, num_delete=1):
    """Return a copy of ``conversation`` with ``num_delete`` interior items removed.

    The first and last items are never deleted. ``num_delete`` is capped at the
    number of interior items, so a document that is too short is returned as an
    unchanged copy and the input is left untouched (previously the input was
    partly truncated and then ``random.sample`` raised).

    Side effect:
        The same number of items is removed from the END of the caller's
        ``conversation`` list, in place, so both versions end up the same length.
        NOTE(review): this looks deliberate, since ``create_incoherent`` returns
        ``coherent``. But repeated calls keep shrinking the caller's document -
        confirm this is wanted.
    """
    incoherent = conversation.copy()
    interior = range(1, len(incoherent)-1)
    # Decide the count before touching the input so a failure can't leave it half-modified.
    num_delete = min(num_delete, len(interior))
    if num_delete > 0:
        del conversation[-num_delete:]
    indices = random.sample(interior, num_delete)
    # Delete from the back so earlier indices stay valid.
    indices.sort(reverse=True)
    for i in indices:
        incoherent.pop(i)
    return incoherent

def local_word_swaps(conversation, num_sents=1, num_word_swaps=1):
    """Return a copy of ``conversation`` with words exchanged inside some sentences.

    Sentences must be strings. ``num_sents`` of them are picked at random
    (capped at the number of sentences, so short or empty documents no longer
    raise). In each picked sentence up to ``num_word_swaps`` disjoint word
    pairs are exchanged. A picked sentence is rebuilt with single spaces
    between words, even when it has too few words to swap.
    """
    incoherent = conversation.copy()
    # random.sample raises when asked for more sentences than exist.
    num_sents = min(num_sents, len(incoherent))
    indices = random.sample(range(0, len(incoherent)), num_sents)
    for i in indices:
        words = incoherent[i].split()
        # Two distinct positions per swap, never more than the sentence can supply.
        positions = random.sample(range(0, len(words)), min(2*num_word_swaps, 2*(len(words)//2)))
        for ind_1, ind_2 in zip(positions[0::2], positions[1::2]):
            words[ind_1], words[ind_2] = words[ind_2], words[ind_1]
        incoherent[i] = ' '.join(words)
    return incoherent


In [26]:
#batcher.py
from collections import namedtuple
from tqdm.notebook import tqdm

def flatten(doc):
    """Concatenate the items of every inner iterable of ``doc`` into one flat list."""
    flat = []
    for sent in doc:
        flat.extend(sent)
    return flat


# One padded batch: `ids` and `mask` are filled in by Batcher.batchify.
Batch = namedtuple('Batch', 'ids mask')

class Batcher:
    """Turn raw documents into padded (coherent, corrupted) training batches.

    Args:
        U: object exposing ``tokenizer``, ``CLS`` and ``SEP`` (e.g. a ``UtilClass``).
        bsz: number of (coherent, corrupted) pairs per batch.
        schemes: corruption scheme ids passed to ``create_corrupted_set``.
        args: corruption arguments passed to ``create_corrupted_set``.
    """

    def __init__(self, U, bsz=8, schemes=(2,), args=(1,)):
        self.bsz = bsz
        self.schemes = list(schemes)
        self.args = list(args)
        
        self.tokenizer = U.tokenizer
        self.CLS = U.CLS
        self.SEP = U.SEP
        
    def make_batches(self, documents, c_num, hier=False):
        """Shuffle, tokenize and corrupt ``documents``, then batch the resulting pairs.

        Args:
            documents: list of documents, each a list of sentence strings.
            c_num: number of corruptions requested per document.
            hier: if true, keep the sentence structure (see ``prep_batch``).

        Returns:
            A list with one ``prep_batch`` result per group of ``bsz`` pairs.
        """
        coherent = documents.copy()
        random.shuffle(coherent)
        coherent = [self.tokenize_doc(doc) for doc in tqdm(coherent)]
        cor_pairs = self.corupt_pairs(coherent, c_num)
        batches = [cor_pairs[i:i+self.bsz] for i in range(0,len(cor_pairs), self.bsz)]
        batches = [self.prep_batch(batch, hier) for batch in batches]
        return batches        
    
    def corupt_pairs(self, coherent, c_num):
        """Return ``[coherent_doc, corrupted_doc]`` pairs, up to ``c_num`` per document."""
        incoherent = [create_corrupted_set(doc, c_num, self.schemes, self.args) for doc in coherent]
        examples = []
        for pos, neg_set in zip(coherent, incoherent):
            for neg in neg_set:
                examples.append([pos, neg])
        return examples

    def prep_batch(self, pairs, hier=False):
        """Convert document pairs into ``Batch`` objects.

        Args:
            pairs: list of ``[coherent_doc, corrupted_doc]``; each document is a
                list of token-id lists (one per sentence).
            hier: falsy -> every document is flattened to one sequence and the
                result is ``(pos_batch, neg_batch)``. Truthy -> one
                ``[coherent_batch, corrupted_batch]`` entry per pair, where each
                batch holds that document's sentences as rows.
        """
        if hier:
            return [[self.batchify(coh), self.batchify(inc)] for coh, inc in pairs]

        coherent, incoherent = zip(*pairs)
        coherent = [self.flatten_doc(doc) for doc in coherent]
        incoherent = [self.flatten_doc(doc) for doc in incoherent]

        pos_batch = self.batchify(coherent)
        neg_batch = self.batchify(incoherent)
        return pos_batch, neg_batch

    def tokenize_doc(self, document):
        """Return the ``input_ids`` of every sentence in ``document``."""
        return [self.tokenizer(sent).input_ids for sent in document]
    
    def flatten_doc(self, document):
        """Merge tokenized sentences into one sequence wrapped in a single CLS ... SEP.

        As many tokens as ``CLS`` / ``SEP`` hold are stripped from each sentence's
        start / end. When both are empty (static-embedding systems) nothing is
        stripped, so real words are no longer dropped.
        """
        lead, trail = len(self.CLS), len(self.SEP)
        body = flatten([sent[lead:len(sent) - trail] for sent in document])
        return self.CLS + body + self.SEP
    
    def batchify(self, batch):
        """Right-pad the id lists in ``batch`` with 0 and build the matching mask.

        Returns:
            ``Batch(ids, mask)`` with ``ids`` a LongTensor and ``mask`` a
            FloatTensor holding 1 for real tokens and 0 for padding.
        """
        max_len = max([len(i) for i in batch], default=0)
        ids = [doc + [0]*(max_len-len(doc)) for doc in batch]
        mask = [[1]*len(doc) + [0]*(max_len-len(doc)) for doc in batch]
        ids = torch.LongTensor(ids) #.to(self.device)
        mask = torch.FloatTensor(mask) #.to(self.device)
        return Batch(ids, mask)
    

# Notebook-level instances: utilities for the 'bert' system, the 'wiki' corpus,
# and a batcher that uses the tokenizer and special-token ids held by U.
U = UtilClass('bert')
D = DataHandler('wiki')
B = Batcher(U)

In [None]:
#models.py

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, RobertaModel, ElectraModel

class DocumentClassifier(nn.Module):
    """Sentence encoder, optional document-level encoder, and a linear output head.

    Args:
        class_number: size of the classifier output.
        system: sentence encoder - 'bert', 'roberta' or 'electra' (``TransEncoder``),
            or 'glove' / 'word2vec' (``BilstmEncoder`` over ``embeds``).
        hier: document-level encoder - 'bert', 'roberta' or 'electra'
            (``HierTransEncoder``) or 'glove' / 'wiki' (``HierBilstmEncoder``).
            Any falsy value (None, False) disables the document level.
        embeds: pretrained embedding tensor, used only by the BiLSTM sentence encoder.

    Raises:
        ValueError: for an unknown ``system`` or ``hier``, or when the sentence
            vectors do not fit the BiLSTM document encoder's input size.
    """

    def __init__(self, class_number=1, system=None, hier=None, embeds=None):  
        super().__init__()
        
        if system in ['bert','roberta','electra']: self.sent_encoder = TransEncoder(system) 
        elif system in ['glove', 'word2vec']:      self.sent_encoder = BilstmEncoder(embeds) 
        else: raise ValueError(f'invalid system {system!r}')

        feature_dim = self._output_dim(self.sent_encoder)

        # forward() reads self.hier, so it has to be stored.
        self.hier = hier
        if not hier:
            self.doc_encoder = None
        elif hier in ['bert','roberta','electra']:
            # The Hier* encoders take vectors (forward(x)); TransEncoder needs ids and a mask.
            self.doc_encoder = HierTransEncoder(hier, hsz=feature_dim)
        elif hier in ['glove', 'wiki']:
            self.doc_encoder = HierBilstmEncoder()
            expected = self.doc_encoder.bilstm.input_size
            if feature_dim != expected:
                raise ValueError(f'sentence vectors have size {feature_dim} but the '
                                 f'BiLSTM document encoder expects {expected}')
        else:
            raise ValueError(f'invalid hier {hier!r}')

        if self.doc_encoder is not None:
            feature_dim = self._output_dim(self.doc_encoder)
        self.classifier = nn.Linear(feature_dim, class_number)

    @staticmethod
    def _output_dim(encoder):
        """Return the size of the last dimension produced by one of this file's encoders."""
        if hasattr(encoder, 'transformer'):
            return encoder.transformer.config.hidden_size
        lstm = encoder.bilstm
        return lstm.hidden_size * (2 if lstm.bidirectional else 1)

    def forward(self, x, mask):
        """Encode ``x`` (token ids) under ``mask`` and apply the classifier.

        NOTE(review): the BiLSTM encoders return one vector per position rather
        than a pooled vector, so with them the classifier output is per position.
        Confirm the intended pooling.
        """
        y = self.sent_encoder(x, mask)
        if self.hier:
            # Assumes the rows of x are the sentences of ONE document, as produced by
            # Batcher.prep_batch(hier=True): add the batch axis the doc encoder needs.
            if y.dim() == 2:
                y = y.unsqueeze(0)
            y = self.doc_encoder(y)
        y = self.classifier(y)
        return y

class TransEncoder(nn.Module):
    """Transformer encoder that returns the hidden state at position 0 of each sequence."""

    def __init__(self, name):
        """Build the transformer selected by ``name`` (see ``get_transformer``)."""
        super().__init__()
        self.transformer = self.get_transformer(name)

    def forward(self, x, mask):
        """Return ``last_hidden_state[:, 0]`` for input ids ``x`` and attention mask ``mask``."""
        hidden_vectors = self.transformer(input_ids=x, attention_mask=mask).last_hidden_state[:,0]
        return hidden_vectors 

    def get_transformer(self, name):
        """Return the model for ``name``: 'bert', 'roberta', 'electra' or 'hier'.

        Raises:
            ValueError: if ``name`` is not one of the supported values.
        """
        if name == 'bert':      transformer = BertModel.from_pretrained('bert-base-uncased', return_dict=True)
        elif name == 'roberta': transformer = RobertaModel.from_pretrained('roberta-base', return_dict=True)
        # return_dict=True like the other branches: forward() relies on `.last_hidden_state`.
        elif name == 'electra': transformer = ElectraModel.from_pretrained('google/electra-base-discriminator', return_dict=True)
        # BertModel cannot be built without a config; use an untrained default-config model.
        # NOTE(review): confirm the default BertConfig sizes are what 'hier' is meant to use.
        elif name == 'hier': transformer = BertModel(BertConfig())
        else: raise ValueError(f'invalid transformer name {name!r}')
        return transformer

class BilstmEncoder(nn.Module):
    """Single-layer bidirectional LSTM over pretrained word embeddings.

    Args:
        embeddings: float tensor of shape (vocab, 300); row ``i`` is the vector of token id ``i``.
    """

    def __init__(self, embeddings):
        super().__init__()
        self.bilstm = nn.LSTM(input_size=300, hidden_size=150, num_layers=1, 
                                    bias=True, batch_first=True, dropout=0, bidirectional=True)
        self.embeddings = nn.Embedding.from_pretrained(embeddings)
        
    def forward(self, x, mask):  
        """Run the BiLSTM over the unpadded part of every sequence.

        Args:
            x: LongTensor of token ids, shape (batch, seq).
            mask: tensor of the same shape, 1 for real tokens and 0 for padding.

        Returns:
            Tensor of shape (batch, longest_sequence, 300) with zeros at padded positions.
        """
        embeds = self.embeddings(x)
        # pack_padded_sequence needs integer lengths on the CPU; the mask is a
        # float tensor and may sit on another device.
        mask_lens = torch.sum(mask, dim=-1).long().cpu()
        x_padded = torch.nn.utils.rnn.pack_padded_sequence(embeds, mask_lens, batch_first=True, enforce_sorted=False)
        output, _ = self.bilstm(x_padded) 
        # Unpack batch-first, matching how the sequence was packed. Without this
        # the result came back time-major, i.e. (seq, batch, 300).
        h1, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        return h1

class HierTransEncoder(nn.Module):
    """BERT-architecture encoder, built from a config (no pretrained weights), fed with vectors.

    Args:
        name: accepted but not used by this class.
        hsz: hidden size of the transformer; the feed-forward size is ``4 * hsz``.
    """

    def __init__(self, name, hsz=300):
        super().__init__()
        bert_settings = dict(
            hidden_size=hsz,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=4*hsz,
        )
        self.transformer = BertModel(BertConfig(**bert_settings))

    def forward(self, x):
        """Feed ``x`` as ``inputs_embeds`` and return the hidden state at position 0."""
        encoded = self.transformer(inputs_embeds=x)
        return encoded.last_hidden_state[:, 0]

class HierBilstmEncoder(nn.Module):
    """Single-layer bidirectional LSTM (300 -> 2 x 150) applied directly to input vectors."""

    def __init__(self):
        super().__init__()
        lstm_options = {
            'input_size': 300,
            'hidden_size': 150,
            'num_layers': 1,
            'bias': True,
            'batch_first': True,
            'dropout': 0,
            'bidirectional': True,
        }
        self.bilstm = nn.LSTM(**lstm_options)

    def forward(self, x):
        """Return the LSTM output for every position of ``x``; the final states are discarded."""
        sequence_output = self.bilstm(x)[0]
        return sequence_output

# Disabled usage sketch, kept as a bare string so it never runs.
# NOTE(review): it does not match the classes in this notebook. DataHandler() is
# called without its required data_src, and pair_corupt / prep_batches are not
# defined on DataHandler / UtilClass here. Update it before re-enabling.
"""
U = UtilClass('glove', 5000)
model = DocumentClassifier(1, system='glove', hier=False, embeds=U.embeddings)

D = DataHandler()
d = D.pair_corupt('wiki_50000', [1], [1])

for batch in U.prep_batches(d, 2):
    y = [model(i.ids, i.mask) for i in batch]
"""
pass


In [5]:
# Scratch check: zip(*a) transposes the nested list, so flattening the result
# reads `a` column by column.
a = [[1,4,7],[2,5,8],[3,6,9]]

# NOTE(review): re-defines the `flatten` helper already defined in the batcher cell.
flatten = lambda doc: [word for sent in doc for word in sent]

print(flatten(zip(*a)))

[1, 2, 3, 4, 5, 6, 7, 8, 9]
