In [None]:
import argparse
import logging
import os
import random

import numpy as np
import pytorch_pretrained_bert
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Adam, AdamW
from torch.optim.lr_scheduler import LambdaLR
from tqdm import trange
from transformers import BertTokenizer, BertModel, BertForTokenClassification, BertConfig

In [None]:
%run ./data_loader.ipynb

In [None]:
%run ./options.ipynb

In [None]:
%run ./utils.ipynb

In [None]:
%run ./model.ipynb

In [None]:
%run ./metrics.ipynb

In [None]:
def make_bert_mask(x, pad_id):
    """Return a float mask for `x`: 1.0 where an entry differs from `pad_id`, 0.0 where it equals it."""
    return x.ne(pad_id).float()

In [None]:
def evaluate(model, data_iterator, params, mark='Eval', verbose=False):
    """Evaluate the model on `params.eval_steps` batches taken from `data_iterator`.

    Args:
        model: network to evaluate. An `nn.DataParallel` wrapper is unwrapped so
            the underlying module (and its `crf` attribute) is used directly.
        data_iterator: iterator yielding `(batch_data, batch_tags)` tensor pairs.
        params: options object; `idx2tag`, `eval_steps` and `decoder` are read.
        mark: label printed in front of the metrics line.
        verbose: when true, the classification report is printed as well.

    Returns:
        `(metrics, report)`: `metrics` is a dict with the keys `loss`, `f1`,
        `prec` and `rec`; `report` is whatever
        `classification_report(true_tags, pred_tags)` returns.

    Raises:
        ValueError: if `params.decoder` is neither 'crf' nor 'linear'.
        AssertionError: if predicted and gold tag sequences differ in length.
    """
    # set model to evaluation mode
    model.eval()

    # The CRF layer is accessed as `model.crf`, so operate on the wrapped
    # module rather than on the DataParallel container.
    if isinstance(model, nn.DataParallel):
        model = model.module

    idx2tag = params.idx2tag

    true_tags = []
    pred_tags = []
    # a running average object for loss
    loss_avg = RunningAverage()

    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for _ in range(params.eval_steps):
            # fetch the next evaluation batch
            batch_data, batch_tags = next(data_iterator)
            # positions whose input id is greater than 0 are attended to
            batch_masks = batch_data.gt(0)

            outputs = model(input_ids=batch_data,
                            attention_mask=batch_masks,
                            labels=batch_tags)
            batch_loss, logits = outputs[0], outputs[1]

            if params.decoder == 'crf':
                tags = model.crf.decode(logits, batch_masks)
                batch_output = torch.squeeze(tags, 0).detach().cpu().numpy()
            elif params.decoder == 'linear':
                # highest-scoring label per position
                batch_output = np.argmax(logits.detach().cpu().numpy(), axis=2).tolist()
            else:
                raise ValueError(
                    "unsupported decoder for evaluation: {!r} "
                    "(expected 'crf' or 'linear')".format(params.decoder))

            loss_avg.update(batch_loss.mean().item())

            batch_tags = batch_tags.detach().cpu().numpy()
            # Flatten every sequence of the batch into one long tag list.
            pred_tags.extend(idx2tag.get(idx) for indices in batch_output for idx in indices)
            true_tags.extend(idx2tag.get(idx) for indices in batch_tags for idx in indices)

    assert len(pred_tags) == len(true_tags)

    # loss, f1 and report
    p, r, f1 = eval_score(true_tags, pred_tags)
    metrics = {
        'loss': loss_avg(),
        'f1': f1,
        'prec': p,
        'rec': r,
    }
    rp = classification_report(true_tags, pred_tags)
    metrics_str = "; ".join("{}: {:05.2f}".format(k, v) for k, v in metrics.items())
    print("- {} metrics: ".format(mark) + metrics_str)

    if verbose:
        print(rp)
    return metrics, rp


In [None]:
def train_and_evaluate_and_test(model, test_data, optimizer, scheduler, args, model_dir, restore_file=None):
    """Evaluate `model` on the test split once per epoch and print the results.

    Despite the name, no parameter update is performed here: each epoch only
    runs `evaluate` over the test data.

    Relies on a module-level `data_loader` object (created in the main cell)
    to build the test iterator.

    Args:
        model: network to evaluate.
        test_data: test split passed to `data_loader.data_iterator`.
        optimizer: forwarded to `load_checkpoint` when restoring; may be None.
        scheduler: not used by this function.
        args: options object; reads `n_epochs`, `test_size` and `batch_size`,
            and writes `epoch_record`, `test_steps` and `eval_steps`.
        model_dir: directory containing the checkpoint to restore.
        restore_file: checkpoint name without the '.pth' suffix; nothing is
            restored when it is None.
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(model_dir, restore_file + '.pth')
        print("Restoring parameters from {}".format(restore_path))
        load_checkpoint(restore_path, model, optimizer)

    for epoch in range(1, args.n_epochs + 1):
        args.epoch_record = epoch
        print("Epoch {}/{}".format(epoch, args.n_epochs))

        # Integer division: only whole batches are counted as steps.
        # NOTE(review): a trailing partial batch is therefore not evaluated,
        # and test_size < batch_size gives zero steps - confirm this is intended.
        args.test_steps = args.test_size // args.batch_size

        test_data_iterator = data_loader.data_iterator(test_data, shuffle=False)

        args.eval_steps = args.test_steps
        test_metrics, test_rp = evaluate(model, test_data_iterator, args, mark='Test')

        print(test_metrics)
        print(test_rp)

In [None]:
if __name__ == '__main__':
  # Entry point: parse the options, load the test split, build the model that
  # matches `args.decoder` / `args.pretrain`, load saved weights and evaluate.
  args = args_parser()

  args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  args.gpus = torch.cuda.device_count()
  args.n_epochs = 1  # one pass over the test data

  # Seed the Python and torch RNGs.
  random.seed(args.seed)
  torch.manual_seed(args.seed)
  if args.gpus > 0:
    torch.cuda.manual_seed_all(args.seed)  # set random seed for all GPUs
  print("device: {}, gpus: {}".format(args.device, args.gpus))

  # Create the input data pipeline
  print("Loading the datasets...")
  data_loader = DataLoader(args.data_dir, args.bert_model_dir, args, token_pad_idx=0)
  test_data = data_loader.load_data('test')
  args.test_size = test_data['size']

  # Model configuration: one output label per entry of `args.tag2idx`.
  config = BertConfig.from_pretrained(args.pretrained_model, num_labels=len(args.tag2idx))
  if args.pretrain == True:
    config.pretrain = args.pretrain
    config.pretrain_path = args.pretrain_path
  config.loss_type = args.loss_type

  # Pick the architecture. The explicit `== False` comparisons are kept on
  # purpose: any flag value that does not compare equal to False (None, for
  # instance) selects the `*_pre` variant.
  if args.decoder == 'linear':
    if args.pretrain == False:
      model = BERT_Linear(config)
    else:
      model = BERT_Linear_pre(config)
  elif args.decoder == 'crf':
    if args.pretrain == False:
      model = BERT_CRF(config)
    else:
      model = BERT_CRF_pre(config)
  elif args.decoder == 'lan':
    config.drop_rate = 0
    config.head_num = 1
    model = BERT_LAN(config)
  else:
    raise RuntimeError("wrong model name")
  model.to(args.device)
  if args.gpus > 1:
    model = torch.nn.DataParallel(model)

  # NOTE(review): the checkpoint path is hard-coded and does not depend on
  # `args.decoder` - confirm it matches the architecture selected above.
  model = load_model(model, '../linear_no_pretrain_lsr_kir_oper/best.pth.tar')

  # Nothing is optimised in this notebook; both stay None.
  optimizer = None
  scheduler = None

  print("Starting training for {} epoch(s)".format(args.n_epochs))
  train_and_evaluate_and_test(model, test_data, optimizer, scheduler, args, args.model_dir, args.restore_file)