In [1]:
import json
import copy
import time
import numpy as np

class DataHandler():
    def __init__(self, data_src):
        path = self.get_path(data_src)
        data = self.load_data(path)
        self.data = [i['sents'] for i in data]
    
    def get_path(self, data_src):
        base_dir = '/home/alta/Conversational/OET/al826/2021'
        path_dict = {'wiki':f'{base_dir}/data/unlabeled/wiki_100000.json',
                    'WSJ':f'{base_dir}/data/coherence/WSJ_train.json'}
        return path_dict[data_src]
    
    def load_data(self, path):
        with open(path) as jsonFile:
            return json.load(jsonFile)


In [2]:
#utils.py
from transformers import BertTokenizerFast, RobertaTokenizerFast
from tqdm.notebook import tqdm

class UtilClass:
    def __init__(self, system, lim=300_000):
        if system in ['bert', 'electra']:
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
            self.CLS, self.SEP = [101], [102]
            self.embeddings = None
            
        elif system == 'roberta':
            self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
            self.CLS, self.SEP = [0], [2]
            self.embeddings = None

        elif system in ['glove', 'word2vec']:
            path = self.get_embedding_path(system)
            tok_dict, embed_matrix = self.read_embeddings(path, lim)
            self.tokenizer = FakeTokenizer(tok_dict)
            self.embeddings = torch.Tensor(embed_matrix)
            self.CLS, self.SEP = [], []

        else:
            raise ValueError('invalid system')
            
    def get_embedding_path(self, name):
        base_dir = '/home/alta/Conversational/OET/al826/2021'
        if name == 'glove': path = f'{base_dir}/data/embeddings/glove.840B.300d.txt'
        elif name == 'word2vec': path = f'{base_dir}/data/embeddings/word2vec.txt'
        else: raise ValueError('invalid word embedding system') 
        return path 

    def read_embeddings(self, path, limit=300_000):
        with open(path, 'r') as file:
            _ = next(file)
            tok_dict = {'[UNK]':0}
            embed_matrix = []
            for index, line in tqdm(zip(range(limit), file), total=limit):
                word, *embedding = line.split()
                if len(embedding) == 300 and word not in tok_dict:
                    embed_matrix.append([float(i) for i in embedding])
                    tok_dict[word] = len(tok_dict)
        return tok_dict, embed_matrix

#Making the tokenizer the same format as huggingface to better interface with code
class FakeTokenizer:
    def __init__(self, tok_dict):
        self.tok_dict = tok_dict
        self.reverse_dict = {v:k for k,v in self.tok_dict.items()}

    def tokenize_word(self, w):
        if w in self.tok_dict:  output = self.tok_dict[w]
        else: output = 0
        return output

    def tokenize(self, x):
        tokenized_words = [self.tokenize_word(i) for i in x.split()]
        x = type('TokenizedInput', (), {})()
        setattr(x, 'input_ids', tokenized_words)
        return x

    def decode(self, x):
        return ' '.join([self.reverse_dict[i] for i in x])

    def __call__(self, x):
        return self.tokenize(x)

In [3]:
#corrupter.py
import time
import random

def create_corrupted_set(coherent, num_fake, schemes=[1], args=[1]):
    corrupted_set = []
    fail = 0
    while len(corrupted_set) < num_fake and fail<50:
        _, incoherent = create_incoherent(coherent, schemes, args)
        if (incoherent not in corrupted_set) and (incoherent != coherent): corrupted_set.append(incoherent)
        else: fail += 1
    return corrupted_set

def create_incoherent(coherent, schemes, args):
    r = random.choice(schemes)
    if   r == 1: incoherent = random_shuffle(coherent)
    elif r == 2: incoherent = random_swaps(coherent, args[0])
    elif r == 3: incoherent = random_neighbour_swaps(coherent, args[0])
    elif r == 4: incoherent = random_deletion(coherent, args[0])
    elif r == 5: incoherent = local_word_swaps(coherent, *args)
    else: raise Exception
    return coherent, incoherent

def random_shuffle(conversation):
    incoherent = conversation.copy()
    random.shuffle(incoherent)
    return incoherent

def random_swaps(conversation, num_swaps=1):
    incoherent = conversation.copy()
    indices = random.sample(range(0, len(incoherent)), min(2*num_swaps, 2*int(len(incoherent)/2)))

    for i in range(0, len(indices), 2):
        ind_1, ind_2 = indices[i], indices[i+1]
        incoherent[ind_1], incoherent[ind_2] = incoherent[ind_2], incoherent[ind_1]
    return incoherent

def random_neighbour_swaps(conversation, num_swaps=1):
    incoherent = conversation.copy()
    indices = random.sample(range(1, len(incoherent)), num_swaps)
    for i in indices:
        incoherent[i], incoherent[i-1] = incoherent[i-1], incoherent[i]
    return incoherent

def random_deletion(conversation, num_delete=1):
    incoherent = conversation.copy()
    _ = [conversation.pop(-1) for i in range(num_delete)]
    indices = random.sample(range(1, len(incoherent)-1), num_delete)
    indices.sort(reverse=True)
    for i in indices:
        incoherent.pop(i)
    return incoherent

def local_word_swaps(conversation, num_sents=1, num_word_swaps=1):
    incoherent = conversation.copy()
    indices = random.sample(range(0, len(incoherent)), num_sents)
    for i in indices:
        words = incoherent[i].split()
        positions = random.sample(range(0, len(words)), min(2*num_word_swaps, 2*(len(words)//2)))
        for j in range(0, len(positions), 2):
            ind_1, ind_2 = positions[j], positions[j+1]
            words[ind_1], words[ind_2] = words[ind_2], words[ind_1]
        sentence = ' '.join(words)
        incoherent[i] = sentence
    return incoherent


In [4]:
#batcher.py
from collections import namedtuple
from tqdm.notebook import tqdm
import torch

flatten = lambda doc: [word for sent in doc for word in sent]
Batch = namedtuple('Batch', ['ids', 'mask'])

class Batcher:
    def __init__(self, bsz=8, schemes=[1], args=None, max_len=512, U=None):
        self.bsz = bsz
        self.schemes = schemes
        self.args = args
        self.max_len = max_len
        
        self.tokenizer = U.tokenizer
        self.CLS = U.CLS
        self.SEP = U.SEP
        
        self.device = torch.device('cpu')
        
    def make_batches(self, documents, c_num, hier=False):
        coherent = documents.copy()
        random.shuffle(coherent)
        coherent = [self.tokenize_doc(doc) for doc in tqdm(coherent)]
        
        if not hier: coherent = [doc for doc in coherent if len(self.flatten_doc(doc)) < self.max_len] 
        else:  coherent = [doc for doc in coherent if max([len(i) for i in doc]) < self.max_len] 
                    
        cor_pairs = self.corupt_pairs(coherent, c_num)
        batches = [cor_pairs[i:i+self.bsz] for i in range(0,len(cor_pairs), self.bsz)]
        batches = [self.prep_batch(batch, hier) for batch in batches]
        return batches        
    
    def corupt_pairs(self, coherent, c_num):
        incoherent = [create_corrupted_set(doc, c_num, self.schemes, self.args) for doc in coherent]
        examples = []
        for pos, neg_set in zip(coherent, incoherent):
            for neg in neg_set:
                examples.append([pos, neg])
        return examples

    def prep_batch(self, pairs, hier=False):
        if hier == False:
            coherent, incoherent = zip(*pairs)
            coherent = [self.flatten_doc(doc) for doc in coherent]
            incoherent = [self.flatten_doc(doc) for doc in incoherent]

            pos_batch = self.batchify(coherent)
            neg_batch = self.batchify(incoherent)
            return pos_batch, neg_batch
        
        else:
            return [[self.batchify(coh), self.batchify(inc)] for coh, inc in pairs]
        
    def tokenize_doc(self, document):
        return [self.tokenizer(sent).input_ids for sent in document]
    
    def flatten_doc(self, document):
        ids = self.CLS + flatten([sent[1:-1] for sent in document]) + self.SEP
        return ids
    
    def batchify(self, batch):
        max_len = max([len(i) for i in batch])
        ids = [doc + [0]*(max_len-len(doc)) for doc in batch]
        mask = [[1]*len(doc) + [0]*(max_len-len(doc)) for doc in batch]
        ids = torch.LongTensor(ids).to(self.device)
        mask = torch.FloatTensor(mask).to(self.device)
        return Batch(ids, mask)

    def to(self, device):
        self.device = device

In [5]:
#models.py

import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, RobertaModel, ElectraModel

class DocumentClassifier(nn.Module):
    def __init__(self, class_num=1, system=None, hier=None, embeds=None):  
        super().__init__()
        self.classifier = nn.Linear(300, class_num)
        
        if system in ['bert','roberta','electra']: self.sent_encoder = TransEncoder(system) 
        elif system in ['glove', 'word2vec']:      self.sent_encoder = BilstmEncoder(embeds) 

        self.hier = hier
        if hier == 'transformer': self.doc_encoder = HierTransEncoder(hier)
        elif hier == 'bilstm': self.doc_encoder = HierTransEncoder(hier)

    def forward(self, x, mask):
        y = self.sent_encoder(x, mask)
        if self.hier:
            y = y.unsqueeze(0)
            y = self.doc_encoder(y)
        y = self.classifier(y)
        return y

class TransEncoder(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.transformer = self.get_transformer(name)
        self.linear = nn.Linear(768, 300)
        
    def forward(self, x, mask):
        h1 = self.transformer(input_ids=x, attention_mask=mask).last_hidden_state[:,0]
        h1 = self.linear(h1)
        return h1 

    def get_transformer(self, name):
        if name == 'bert':      transformer = BertModel.from_pretrained('bert-base-uncased', return_dict=True)
        elif name == 'roberta': transformer = RobertaModel.from_pretrained('roberta-base', return_dict=True)
        elif name == 'electra': transformer = ElectraModel.from_pretrained('google/electra-base-discriminator')
        elif name == 'hier': transformer = BertModel()
        else: raise Exception
        return transformer

class BilstmEncoder(nn.Module):
    def __init__(self, embeddings):
        super().__init__()
        self.bilstm = nn.LSTM(input_size=300, hidden_size=150, num_layers=2, 
                                    bias=True, batch_first=True, dropout=0, bidirectional=True)
        self.embeddings = nn.Embedding.from_pretrained(embeddings)
        
    def forward(self, x, mask):  
        x = self.embeddings(x)
        mask_lens = torch.sum(mask, dim=-1).cpu()
        x_padded = torch.nn.utils.rnn.pack_padded_sequence(x, mask_lens, batch_first=True, enforce_sorted=False)
        output, _ = self.bilstm(x_padded) 
        h1, unpacked_len = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        h1 = torch.mean(h1, dim=1)
        return h1

class HierTransEncoder(nn.Module):
    def __init__(self, name, hsz=300):
        super().__init__()
        config = BertConfig(hidden_size=hsz, num_hidden_layers=12, num_attention_heads=12, intermediate_size=4*hsz)
        self.transformer = BertModel(config)
    
    def forward(self, x):
        hidden_vectors = self.transformer(inputs_embeds=x, return_dict=True).last_hidden_state[:,0]
        return hidden_vectors 

class HierBilstmEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.bilstm = nn.LSTM(input_size=300, hidden_size=150, num_layers=1, bias=True, 
                              batch_first=True, dropout=0, bidirectional=True)
        
    def forward(self, x):  
        h1, _ = self.bilstm(x) 
        h1 = torch.mean(h1, dim=1)
        return output

    


In [6]:
"""
D = DataHandler('wiki')
U = UtilClass('bert', 100_000)
model = DocumentClassifier(1, 'bert', embeds=U.embeddings)
B = Batcher(4, [1], [1], 200, U)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
B.to(device)

for k, batch in enumerate(B.make_batches(D.data[:1000], c_num=1, hier='transformer')):
    for pos, neg in batch:
        print(pos.ids.shape)
        y = model(pos.ids, pos.mask)
        print(y.shape)
    time.sleep(5)
    print('--'*20)

"""

"\nD = DataHandler('wiki')\nU = UtilClass('bert', 100_000)\nmodel = DocumentClassifier(1, 'bert', embeds=U.embeddings)\nB = Batcher(4, [1], [1], 200, U)\n\ndevice = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\nmodel.to(device)\nB.to(device)\n\nfor k, batch in enumerate(B.make_batches(D.data[:1000], c_num=1, hier='transformer')):\n    for pos, neg in batch:\n        print(pos.ids.shape)\n        y = model(pos.ids, pos.mask)\n        print(y.shape)\n    time.sleep(5)\n    print('--'*20)\n\n"

In [7]:
#TEMP
"""
model = DocumentClassifier(1, 'bert', 'transformer')

#TEMP
D = DataHandler('wiki')
U = UtilClass('bert', 100_000)
B = Batcher(4, [1], [1], 200, U)

for k, batch in enumerate(B.make_batches(D.data[:1000], c_num=1, hier='bert')):
    for pos, neg in batch:
        y = model(pos.ids, pos.mask)
        print(y.shape)
    time.sleep(5)
    print('--'*20)

"""


"\nmodel = DocumentClassifier(1, 'bert', 'transformer')\n\n#TEMP\nD = DataHandler('wiki')\nU = UtilClass('bert', 100_000)\nB = Batcher(4, [1], [1], 200, U)\n\nfor k, batch in enumerate(B.make_batches(D.data[:1000], c_num=1, hier='bert')):\n    for pos, neg in batch:\n        y = model(pos.ids, pos.mask)\n        print(y.shape)\n    time.sleep(5)\n    print('--'*20)\n\n"

In [8]:
from torch.optim.lr_scheduler import LambdaLR

class log_sigmoid_loss(nn.Module):
        def __init__(self):
            super().__init__()
            self.log_sigmoid = nn.LogSigmoid()

        def forward(self, inputs):
            log_likelihood = self.log_sigmoid(inputs)
            loss =  -1 * torch.mean(log_likelihood)
            return loss

class ExperimentHandler:
    def __init__(self):
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.cross_loss = nn.CrossEntropyLoss()
        self.log_sigmoid_loss = log_sigmoid_loss()

    def train(self, config):
        D = DataHandler(config.data_src)
        U = UtilClass(config.system, config.embed_lim)
        B = Batcher(config.bsz, config.schemes, config.args, config.max_len, U)
        
        model = DocumentClassifier(1, config.system, config.hier, embeds=U.embeddings)
        self.model = model
        
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
        if config.scheduling:
              SGD_steps = (len(train_data)*config.epochs)/self.bsz
              lambda1 = lambda i: 10*i/SGD_steps if i <= SGD_steps/10 else 1 - ((i - 0.1*SGD_steps)/(0.9*SGD_steps))
              scheduler = LambdaLR(optimizer, lr_lambda=lambda1)
        
        model.to(self.device)
        B.to(self.device) 

        for epoch in range(config.epochs):
            analysis_loss, acc = 0, np.zeros(2)
            
            for k, batch in enumerate(B.make_batches(D.data[:10000], config.c_num, config.hier)):
                if config.hier in ['transformer', 'bilstm']:
                    loss = 0
                    for pos, neg in batch:
                        y_pos = model(pos.ids, pos.mask)
                        y_neg = model(neg.ids, neg.mask)
                        loss += self.log_sigmoid_loss(y_pos - y_neg)/len(batch)
                        acc += [(y_pos>y_neg).item(), 1]
                else:
                    pos, neg = batch
                    y_pos = model(pos.ids, pos.mask)
                    y_neg = model(neg.ids, neg.mask)
                    loss = self.log_sigmoid_loss(y_pos - y_neg)
                    acc += [sum(y_pos - y_neg > 0).item(), len(y_pos)]
            
                analysis_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if config.scheduling: scheduler.step()

                if k%config.debug_sz==0 and k!=0:
                    print(f'{k:<5} {analysis_loss/config.debug_sz:.3f}   {acc[0]/acc[1]:.3f}')
                    analysis_loss, acc = 0, np.zeros(2)


In [None]:
config_dict = {'bsz':6, 'lr':1e-5, 'epochs':5, 'scheduling':False, 
               'system':'bert', 'embed_lim':100000, 'hier':'bilstm', 
               'data_src':'wiki', 'c_num':1, 'schemes':[1], 'args':None, 'max_len':512,
               'debug_sz':100}

ConfigTruple = namedtuple('Config', config_dict)
config = ConfigTruple(**config_dict)

E = ExperimentHandler()
E.train(config)

  0%|          | 0/10000 [00:00<?, ?it/s]

100   0.497   0.744
200   0.400   0.818
300   0.453   0.782
400   0.383   0.813
500   0.395   0.783
600   0.380   0.823
700   0.359   0.813
800   0.327   0.848
900   0.391   0.810
1000  0.376   0.813
1100  0.359   0.810
1200  0.355   0.827
1300  0.365   0.813
1400  0.369   0.810
1500  0.378   0.812
1600  0.385   0.810


  0%|          | 0/10000 [00:00<?, ?it/s]

100   0.259   0.866
200   0.292   0.852
300   0.292   0.845
400   0.299   0.862
500   0.295   0.823
600   0.254   0.870
700   0.260   0.865
800   0.300   0.847
900   0.317   0.843
1000  0.291   0.832
1100  0.279   0.863
1200  0.269   0.852
1300  0.309   0.855
1400  0.286   0.857
1500  0.307   0.847
1600  0.287   0.848


  0%|          | 0/10000 [00:00<?, ?it/s]

100   0.246   0.893
200   0.203   0.895
300   0.226   0.880
400   0.217   0.882
500   0.245   0.872
600   0.239   0.895
700   0.233   0.870
800   0.241   0.883
900   0.251   0.877
1000  0.214   0.882
1100  0.228   0.882
1200  0.208   0.885
1300  0.233   0.867
1400  0.253   0.887
1500  0.180   0.905
1600  0.250   0.867


  0%|          | 0/10000 [00:00<?, ?it/s]

100   0.181   0.896
200   0.177   0.905


In [None]:
model = DocumentClassifier(c_num=1, system='glove', hier=False, embeds=U.embeddings)

In [None]:
ids = torch.ones([8,32], dtype=torch.long)

print(model(ids, ids).shape)

In [None]:
U = UtilClass('glove')
D = DataHandler('wiki')
B = Batcher(U)
model = DocumentClassifier(1, system='bert', hier=False)
