In [None]:
from torch.utils.data import Dataset, DataLoader
#from text_manipulation import word_model
#from text_manipulation import extract_sentence_words
from nltk.tokenize import RegexpTokenizer
from pathlib2 import Path
import re
import os
import math
import gensim
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [None]:
import transformers
from transformers import AdamW
from transformers import BertModel, BertTokenizer

# Upper bound on the padded/truncated sequence-pair length (collate_fn caps at this).
MAX_TOKENS = 200

# Pretrained uncased BERT-base encoder and its matching tokenizer.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')



In [None]:
# NOTE(review): a `utils.setup_logger(__name__, 'train.log')` call used to live here, but no
# `utils` module is imported in this notebook, so no module-level `logger` is defined.

# Words dropped by extract_sentence_words(..., remove_missing_emb_words=True).
missing_stop_words = {'of', 'a', 'and', 'to'}
# Prefix of the "========,<level>,<title>." separator lines; both names hold the same value.
section_delimiter = segment_seperator = "========"

def get_list_token():
    """Return the special token used in place of a list ("***LIST***")."""
    return "***LIST***"

def get_formula_token():
    """Return the special token used in place of a formula ("***formula***")."""
    return "***formula***"

def get_codesnipet_token():
    """Return the special token used in place of a code snippet ("***codice***")."""
    return "***codice***"

def get_special_tokens():
    """Return all special tokens, in the order list, formula, code snippet."""
    return [get_list_token(), get_formula_token(), get_codesnipet_token()]

def get_seperator_foramt(levels = None):
    """Build the regex that matches a segment-separator line such as ``========,2,Title.``.

    Args:
        levels: optional ``(low, high)`` pair restricting the level field to the
            character class ``[low-high]``; None accepts any single digit.
            NOTE: the level is matched as ONE character, so e.g. ``(3, 999)``
            yields ``[3-999]``, which behaves like ``[3-9]``; multi-digit levels
            are never matched by either form.

    Returns:
        The regex pattern as a string (``segment_seperator,<level>,`` followed by
        everything up to and including the first ``.``).
    """
    # Raw strings: '\d' and '\.' in plain literals are invalid escape sequences
    # (SyntaxWarning since Python 3.12). The resulting pattern text is unchanged.
    if levels is None:
        level_format = r'\d'
    else:
        level_format = '[' + str(levels[0]) + '-' + str(levels[1]) + ']'
    return segment_seperator + ',' + level_format + r',.*?\.'

# Lazily-created, module-wide word tokenizer (see get_words_tokenizer).
words_tokenizer = None
def get_words_tokenizer():
    """Return the shared ``\\w+`` RegexpTokenizer, creating it on first call."""
    global words_tokenizer
    if not words_tokenizer:
        words_tokenizer = RegexpTokenizer(r'\w+')
    return words_tokenizer

def extract_sentence_words(sentence, remove_missing_emb_words = False, remove_special_tokens = False):
    """Split `sentence` into word tokens (runs of ``\\w`` characters).

    Args:
        sentence: raw sentence text.
        remove_missing_emb_words: drop words listed in ``missing_stop_words``.
        remove_special_tokens: delete the special tokens from the raw text first.

    Returns:
        List of word strings.
    """
    if remove_special_tokens:
        # Must happen on the raw string: the \w+ tokenizer strips the '***'
        # around a special token, leaving e.g. 'LIST' as an ordinary word.
        for special in get_special_tokens():
            sentence = sentence.replace(special, "")
    words = get_words_tokenizer().tokenize(sentence)
    if not remove_missing_emb_words:
        return words
    return [word for word in words if word not in missing_stop_words]


def word_model(word, model, embedding_dim=300):
    """Return the embedding of `word` as an array of shape ``(1, embedding_dim)``.

    Args:
        word: token to look up.
        model: mapping-like word-vector store supporting ``in`` and ``[]``
            (e.g. gensim KeyedVectors) that must contain an ``'UNK'`` entry,
            or None to get a random vector.
        embedding_dim: vector size; defaults to 300.

    Returns:
        numpy array of shape ``(1, embedding_dim)``. Words absent from `model`
        get the ``'UNK'`` vector.

    Raises:
        KeyError: if `word` is missing and `model` has no ``'UNK'`` entry.
    """
    # (A pasted Jupyter URL containing an auth token previously followed the
    # colon on the `def` line; it was a SyntaxError and leaked the token.)
    if model is None:
        # No embeddings loaded: fall back to a standard-normal random vector.
        return np.random.randn(1, embedding_dim)
    key = word if word in model else 'UNK'
    return model[key].reshape(1, embedding_dim)

def get_files(path):
    """Return every regular file below `path`, recursively, as string paths."""
    return [str(entry) for entry in Path(path).rglob('*') if entry.is_file()]


def get_cache_path(wiki_folder):
    """Return the location of the file-list cache (``paths_cache``) inside `wiki_folder` (a Path)."""
    return wiki_folder / 'paths_cache'


def cache_wiki_filenames(wiki_folder):
    """Write the list of document files under `wiki_folder` to its path cache.

    Every regular file exactly four directory levels below `wiki_folder`
    (glob ``*/*/*/*``) is written, one path per line and in sorted order, to
    ``get_cache_path(wiki_folder)``. WikipediaDataSet reads the cache back with
    ``splitlines()``.

    Args:
        wiki_folder: corpus root, as a Path or a string.
    """
    wiki_folder = Path(wiki_folder)  # get_cache_path uses the `/` operator
    cache_file_path = get_cache_path(wiki_folder)
    # Materialize and sort before the cache file is created: deterministic
    # order across filesystems, and directories are skipped.
    files = sorted(p for p in wiki_folder.glob('*/*/*/*') if p.is_file())

    with cache_file_path.open('w') as f:
        for file in files:
            # glob() yields Path objects; `Path + str` raises TypeError.
            f.write(str(file) + '\n')


def clean_section(section):
    """Return `section` with leading and trailing newline characters removed (other whitespace kept)."""
    return section.strip('\n')


def get_scections_from_text(txt, high_granularity=True):
    """Split document text into sections at separator lines.

    Args:
        txt: full document text.
        high_granularity: if true, split at separators of every (single-digit)
            level; otherwise split only at levels 1-2 after first deleting the
            level 3+ separators, so deeper sub-sections merge into their parent.

    Returns:
        List of non-empty section strings (separator lines themselves removed).
    """
    if high_granularity:
        keep_pattern = get_seperator_foramt()
    else:
        keep_pattern = get_seperator_foramt((1, 2))
        # Flatten segments nested below level 2 by deleting their separators.
        txt = re.sub(get_seperator_foramt((3, 999)), "", txt)
        # The deleted separators leave blank lines behind; drop them.
        kept_lines = [line for line in txt.strip().split("\n") if line and line != "\n"]
        txt = '\n'.join(kept_lines).strip('\n')

    return [section for section in re.split(keep_pattern, txt) if section]


def get_sections(path, high_granularity=True):
    """Read the document at `path` and return its sections.

    Args:
        path: document path (anything `str()` turns into a filesystem path).
        high_granularity: forwarded to ``get_scections_from_text``.

    Returns:
        List of section strings, each with surrounding newlines stripped.
    """
    # `with` closes the file even if reading raises (the handle used to leak).
    with open(str(path), "r") as file:
        raw_content = file.read()

    sections = get_scections_from_text(raw_content.strip(), high_granularity)
    return [clean_section(s) for s in sections]

def read_wiki_file(path, n_context_sent = 1, remove_preface_segment=True, high_granularity=True):
    """Turn one document into boundary-classification examples.

    Every line of every section is treated as a sentence. Lines equal to the
    list placeholder (``get_list_token() + "."``) are dropped. For each position
    ``i`` from ``n_context_sent`` onward, one example is produced:
    ``[previous n_context_sent sentences, sentences i .. i+n_context_sent-1]``,
    each side joined with spaces (the second side is shorter near the end).
    Its target is 1 when sentence ``i - 1`` is the last sentence of its section,
    i.e. a section boundary lies between the two sides, else 0.

    Args:
        path: document path; also returned unchanged.
        n_context_sent: number of sentences on each side of a candidate boundary.
        remove_preface_segment: drop the first section returned by ``get_sections``.
        high_granularity: forwarded to ``get_sections``.

    Returns:
        ``(data, targets, path)``: ``data`` is a list of two-string lists and
        ``targets`` the matching 0/1 labels.
    """
    sections = get_sections(path, high_granularity)
    if remove_preface_segment and len(sections) > 0:
        sections = sections[1:]
    sections = [section for section in sections if section and section != "\n"]

    list_placeholder = get_list_token() + "."
    sentences = []
    is_last_in_section = []
    for section in sections:
        lines = [line for line in section.split('\n') if line != list_placeholder]
        if not lines:
            continue
        sentences.extend(lines)
        is_last_in_section.extend([0] * (len(lines) - 1))
        is_last_in_section.append(1)

    data = []
    targets = []
    if len(sentences) > n_context_sent:
        for boundary in range(n_context_sent, len(sentences)):
            before = sentences[boundary - n_context_sent:boundary]
            after = sentences[boundary:boundary + n_context_sent]  # slicing clamps at the end
            data.append([" ".join(before), " ".join(after)])
            targets.append(is_last_in_section[boundary - 1])

    return data, targets, path


class WikipediaDataSet(Dataset):
    """Document-level dataset: one item per text file.

    Each item is the ``(data, targets, path)`` triple built by ``read_wiki_file``:
    ``data`` holds ``[previous_context, following_context]`` string pairs and
    ``targets`` is 1 where a section boundary lies between the two contexts.
    """

    def __init__(self, root, n_context_sent=2, train=True, manifesto=False, folder=False, high_granularity=False):
        """
        Args:
            root: corpus directory.
            n_context_sent: sentences on each side of a candidate boundary.
                Defaults to 2, the value ``__getitem__`` used to hard-code while
                ignoring this argument, so existing callers see the same data.
            train: stored on the instance; not used by this class.
            manifesto: treat every direct child of `root` as a document.
            folder: treat every file below `root` (recursively) as a document.
            high_granularity: keep every separator level (True) or only levels 1-2.

        Raises:
            RuntimeError: if no documents are found.
        """
        if manifesto:
            self.textfiles = list(Path(root).glob('*'))
        elif folder:
            self.textfiles = get_files(root)
        else:
            # Default layout: file list cached in `root/paths_cache`, built on first use.
            root_path = Path(root)
            cache_path = get_cache_path(root_path)
            if not cache_path.exists():
                print('Building file-list cache at {}'.format(cache_path))
                cache_wiki_filenames(root_path)
            self.textfiles = cache_path.read_text().splitlines()

        if len(self.textfiles) == 0:
            raise RuntimeError('Found 0 text files in subfolders of: {}'.format(root))
        self.train = train
        self.root = root
        self.high_granularity = high_granularity
        self.n_context_sent = n_context_sent

    def __getitem__(self, index):
        path = self.textfiles[index]
        # Honour the configured context size (it used to be hard-coded to 2 here).
        return read_wiki_file(Path(path), n_context_sent=self.n_context_sent, high_granularity=self.high_granularity)

    def __len__(self):
        return len(self.textfiles)


In [None]:

def collate_fn(batch):
    """Flatten a batch of documents into one BERT-tokenized batch of sentence pairs.

    Args:
        batch: list of ``(data, targets, path)`` triples from
            ``WikipediaDataSet.__getitem__``; ``data`` holds
            ``[previous_context, following_context]`` string pairs and
            ``targets`` the matching 0/1 labels.

    Returns:
        ``(input_ids, attention_mask, targets, paths)`` as plain Python lists with
        one entry per context pair across the whole batch; ``paths`` repeats
        each document's path once per pair.

    A document whose ``data`` and ``targets`` lengths differ is skipped with a
    message. (The old handler referenced an undefined ``logger`` and could
    leave data/targets misaligned.) Any other error propagates.
    """
    batched_data = []
    batched_targets = []
    batched_paths = []

    # Padding/truncation length: at least 100, grown to the longest pair measured
    # in whitespace words (a rough proxy for WordPiece tokens), capped at MAX_TOKENS.
    max_tokens = 100
    for data, targets, path in batch:
        if len(data) != len(targets):
            print('Skipping "{}": {} context pairs but {} targets'.format(path, len(data), len(targets)))
            continue
        for pair, target in zip(data, targets):
            n_words = len(pair[0].split()) + len(pair[1].split())
            max_tokens = max(max_tokens, n_words)
            batched_data.append(pair)
            batched_targets.append(target)
            batched_paths.append(path)

    max_tokens = min(MAX_TOKENS, max_tokens)
    # A list of [text_a, text_b] items is encoded as sentence pairs.
    tokens = tokenizer(
        batched_data,
        padding=True,
        max_length=max_tokens,
        truncation=True)

    return tokens['input_ids'], tokens['attention_mask'], batched_targets, batched_paths


# Corpus root; set the WIKI727_PATH environment variable instead of editing this path.
dataset_path = os.environ.get("WIKI727_PATH", "/home/aakash/amagi/data/wiki/wiki_727")
dataset = WikipediaDataSet(os.path.join(dataset_path, 'dev'), high_granularity=False)
dl = DataLoader(dataset, batch_size=12, collate_fn=collate_fn, shuffle=True)

In [None]:
n_documents = len(dataset)  # number of text files in the exploratory split
print(n_documents)

In [None]:
# One document: its context pairs (`sent`), boundary labels (`target`) and file path.
sent, target, path = dataset[1]
print(sent[0][0])
print(len(sent), len(target))

In [None]:
print(len(sent[0]))  # each example is a [previous_context, following_context] pair
for example_index in (1, 2):
    print(sent[example_index])

In [None]:
# Tokenize the sample document's context pairs with the notebook-wide length cap
# (MAX_TOKENS instead of a duplicated literal 200) and decode one example to
# check how the two contexts are joined.
tokens = tokenizer(
    sent,
    padding=True,
    max_length=MAX_TOKENS,
    truncation=True)
print(tokenizer.decode(tokens['input_ids'][2]))

In [None]:
print(len(dl))  # batches per epoch
batch = next(iter(dl))  # first batch, as assembled by collate_fn
print(len(batch))
input_, mask, targets, paths = batch
print(len(input_), len(mask), len(targets))

In [None]:
# Freeze the entire pretrained encoder; top layers are unfrozen again later.
bert.requires_grad_(False);

In [None]:
class Encoder_Classifier(nn.Module):
    """Sentence-pair boundary classifier on top of a BERT-style encoder.

    The encoder's pooled output passes through dropout and one linear layer.
    The forward pass returns per-class log-probabilities.
    """

    def __init__(self, bert, n_classes):
        """
        Args:
            bert: encoder called as ``bert(ids, attention_mask=..., return_dict=False)``
                and returning ``(sequence_output, pooled_output)``. Its
                ``config.hidden_size`` sizes the head (768 if absent).
            n_classes: number of output classes.
        """
        super(Encoder_Classifier, self).__init__()
        self.bert = bert

        # dropout on the pooled encoder output (active only in train mode)
        self.dropout = nn.Dropout(0.1)

        # NOTE(review): unused by forward(); kept so existing attribute references still work.
        self.relu = nn.ReLU()

        # Size the head from the encoder config (768 for bert-base) rather than hard-coding it.
        hidden_size = getattr(getattr(bert, "config", None), "hidden_size", 768)
        self.fc1 = nn.Linear(hidden_size, n_classes)

        # Log-probabilities. Feeding them to nn.CrossEntropyLoss gives the same loss as
        # NLLLoss, because log_softmax applied twice equals log_softmax once.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Return ``(batch, n_classes)`` log-probabilities for token ids `sent_id` and `mask`."""
        _, cls_hs = self.bert(sent_id, attention_mask=mask, return_dict=False)

        # The dropout layer was constructed but never applied; apply it to the pooled output.
        x = self.fc1(self.dropout(cls_hs))

        return self.softmax(x)

In [None]:
def train(train_dataloader):
    """Run one training epoch and return the mean per-batch loss.

    Uses the notebook globals `model`, `optimizer`, `CELoss` and `device`.
    Batches are ``(input_ids, attention_mask, targets, paths)`` tuples from
    `collate_fn`; the paths are ignored. Prints progress every 100 batches.
    """
    model.train()
    running_loss = 0

    for step, batch in enumerate(train_dataloader):
        # Labeled data -> cross-entropy. The first three fields become device tensors.
        input_ids, attention_mask, labels = (torch.tensor(field).to(device) for field in batch[:3])

        model.zero_grad()
        loss = CELoss(model(input_ids, attention_mask), labels)
        loss.backward()

        # Bound the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss

        if step > 0 and step % 100 == 0:
            print('  Batch {:>5,}  of  {:>5,}.'.format(step, len(train_dataloader)))
            print("loss", batch_loss)

    return running_loss / len(train_dataloader)

In [None]:
from sklearn.metrics import f1_score
def evaluate(dev_dataloader, log_every=50):
    """Evaluate the global `model` on `dev_dataloader` without updating it.

    Uses the notebook globals `model`, `device` and `CELoss`, plus sklearn's
    `f1_score`. Batches are ``(input_ids, attention_mask, targets, paths)``
    tuples from `collate_fn`; the paths are ignored.

    Args:
        dev_dataloader: DataLoader to evaluate on.
        log_every: print running accuracy and macro-F1 every `log_every` batches.

    Returns:
        ``(avg_loss, accuracy)``: mean per-batch loss, and the fraction of
        examples whose arg-max prediction equals the target.
    """
    print("\nEvaluating...")

    # deactivate dropout layers
    model.eval()
    total_loss = 0.0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for step, batch in enumerate(dev_dataloader):
            # Running metrics over the batches processed so far. F1 is computed only
            # every `log_every` batches; recomputing it every batch was O(n^2).
            if step and step % log_every == 0 and all_labels:
                n_correct = sum(y == p for y, p in zip(all_labels, all_preds))
                print('  Batch {:>5,}  of  {:>5,} accuracy {}.'.format(step, len(dev_dataloader), n_correct / len(all_labels)))
                print("F1 score {}".format(f1_score(all_labels, all_preds, average="macro")))

            sent_id = torch.tensor(batch[0]).to(device)
            mask = torch.tensor(batch[1]).to(device)
            labels = torch.tensor(batch[2]).to(device)

            preds = model(sent_id, mask)
            total_loss += CELoss(preds, labels).item()

            # Predicted class = arg-max over the class log-probabilities.
            all_preds.extend(preds.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    avg_loss = total_loss / len(dev_dataloader)
    # Normalise by the number of evaluated examples (was the undefined `dev_data`).
    n_correct = sum(y == p for y, p in zip(all_labels, all_preds))
    accuracy = n_correct / len(all_labels) if all_labels else 0.0
    print("Validation accuracy", accuracy)
    if all_labels:
        print("F1 score {}".format(f1_score(all_labels, all_preds, average="macro")))

    return avg_loss, accuracy

In [None]:
best_valid_loss = float('inf')
train_losses = []
valid_losses = []
accuracy_list = []
learning_rate = .0005
freezed_epochs = 1
unfreezed_epochs = 1
n_classes = 2
n_layer_unfreeze = 2
weight = .3  # NOTE(review): not used by any cell in this notebook
# NOTE(review): 'cuda:3' assumes at least four visible GPUs whenever CUDA is available.
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
CELoss = nn.CrossEntropyLoss()
model = Encoder_Classifier(bert, n_classes).to(device)
# transformers.AdamW is deprecated (removed in recent transformers releases). Use
# PyTorch's AdamW with the old class's defaults (eps=1e-6, no weight decay)
# instead of torch's own defaults (eps=1e-8, weight_decay=0.01).
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)
print(device)

# Corpus root; set the WIKI727_PATH environment variable instead of editing this path.
dataset_path = os.environ.get("WIKI727_PATH", "/home/aakash/amagi/data/wiki/wiki_727")
# NOTE(review): training reads the 'dev' split and validation the 'test' split; confirm this is intended.
dataset_train = WikipediaDataSet(os.path.join(dataset_path, 'dev'), high_granularity=False)
train_dataloader = DataLoader(dataset_train, batch_size=2, collate_fn=collate_fn, shuffle=True)

dataset_dev = WikipediaDataSet(os.path.join(dataset_path, 'test'), high_granularity=False)
dev_dataloader = DataLoader(dataset_dev, batch_size=2, collate_fn=collate_fn, shuffle=True)

In [None]:
# Train only the classification head while the BERT encoder is frozen.
for epoch in range(freezed_epochs):

    print('\n Epoch {:} / {:}'.format(epoch + 1, freezed_epochs))

    train_loss = train(train_dataloader)
    valid_loss, accuracy = evaluate(dev_dataloader)

    # Keep the checkpoint with the lowest validation loss seen so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'saved_weights.pt')

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    accuracy_list.append(accuracy)

    # train() returns a plain float; indexing it (`train_loss[0]`) raised TypeError.
    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')

In [None]:
# Save the current weights next to the corpus (derived from dataset_path, not a second hard-coded path).
torch.save(model.state_dict(), os.path.join(dataset_path, 'saved_weights.pt'))

In [None]:
##### unfreezing layers #######
for iter in range(n_layer_unfreeze):
    print(str(iter+1)+" unfreeze")
    for param in model.bert.encoder.layer._modules[str(11-iter)].parameters():
        param.requires_grad=True
    for epoch in range(unfreezed_epochs):
     
        print('\n Epoch {:} / {:}'.format(epoch+1 , ))
        
        #train model
        train_loss = train(train_dataloader)
    
        #evaluate model
        valid_loss, accuracy = evaluate(dev_dataloader)
        
        #save the best model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'saved_weights.pt')
        
        # append training and validation loss
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        accuracy_list.append(accuracy)
        
        print(f'\nTraining Loss: {train_loss[0]:.3f}')
        print(f'Validation Loss: {valid_loss:.3f}')