In [1]:
from optimization_openai import OpenAIAdam
import math
import time
from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig
import torch
import torch.nn as nn
import optimizers
from train_process import *
import os
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from DataBunch import *
from network import BigModel
from torch.utils import data
from torch.autograd import Variable
import torch.nn.functional as F
from Optim import ScheduledOptim
from tqdm import tqdm
from Config_File import Config_base
import argparse

parser = argparse.ArgumentParser(description='CNN text classificer')
parser.add_argument('--config', type=str, default='Config_base')
parser.add_argument('--batch_size', type=int, default=8, help='batch_size')
parser.add_argument('--causal_ratio', type=float,
                    default=0.1, help='batch_size')
parser.add_argument('--batch_size_test', type=int,
                    default=None, help='batch_size_test')
parser.add_argument('--epoch', type=int, default=3, help='epoch')
parser.add_argument('--dataset', type=str, default='e-snli', help='dataset')
parser.add_argument('--grad_loss_func', type=str,
                    default='argmax_loss', help='grad_loss_func')
parser.add_argument('--saliancy_method', type=str,
                    default='compute_saliancy_batch', help='saliancy_method')
parser.add_argument('--train_process', type=str,
                    default='train_cause_word', help='train_process')
parser.add_argument('--model_name_or_path', type=str,
                    default='bert-base-uncased', help='model_name_or_path')
parser.add_argument('--databunch_method', type=str,
                    default='DataBunch_e_snli_marked', help='databunch_method')
parser.add_argument('--use_custom_bert', action='store_true',
                    default=False, help='use_custom_bert')
#parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased', help='model_name_or_path')
parser.add_argument('--load_few', action='store_true',
                    default=False, help='load few')
parser.add_argument('--grad_clamp', action='store_true',
                    default=False, help='grad_clamp')
parser.add_argument('--test_mode', action='store_true',
                    default=False, help='test_mode or not')
parser.add_argument('--no_use_pre_train_parameters', action='store_true', default=False,
                    help='no_use_pre_train_parameters or not')

args = parser.parse_args(args=['--load_few'])
# Config_File.ComputeConfig(args.config)
#Config_File.Config = Config_File.ComputeConfig(args)


In [2]:
config = Config_base(args)


# from train_augment_process import *
# from train_process_MM import *

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

torch.manual_seed(0)
torch.backends.cudnn.benchmark = True


# from Augmentation import Analogy_Auger
# from train_process_MM import BigModel_two_bert

params_trainset = {'batch_size': config.batch_size,
                   'shuffle': False,
                   'num_workers': 0}
params_testset = {'batch_size': config.batch_size_test,
                  'shuffle': False,
                  'num_workers': 0}
torch.set_printoptions(threshold=10000)


def predict(model, test_generator, outputFile):
    pred = evaluate(model, criterion, test_generator, False)
    with open(outputFile, 'w', encoding='utf-8') as writer:
        writer.write('Id,Expected\n')
        for term in pred:
            pred_ret = term[1]
            pred_label = None
            if config.dataset_train == 'RTE':
                if pred_ret == 0:
                    pred_label = 'not_entailment'
                else:
                    pred_label = 'entailment'
            if config.dataset_train == 'MSRP':
                pred_label = pred_ret
            if config.dataset_train in ['SNLI', 'mini-SNLI', 'e-snli']:
                if pred_ret == 0:
                    pred_label = 'neutral'
                elif pred_ret == 1:
                    pred_label = 'entailment'
                else:
                    pred_label = 'contradiction'

            writer.write('{},{}\n'.format(term[0], pred_label))
        writer.close()


record = {}


def ReadDataset(args):
    global trainset, testset, devset
    db_class = globals()[config.databunch_method]
    dataset = config.dataset
    if config.do_train:
        trainset[dataset] = db_class(config, config.train_file_dict[dataset], config.sent_token_dict[dataset], config.label_token_dict[dataset],
                                     config.tokenizer, config.sent2_token_dict[dataset],
                                     dataset=dataset, id_token=config.id_token_dict[dataset], load_few=args.load_few)
        devset[dataset] = db_class(config, config.dev_file_dict[dataset], config.sent_token_dict[dataset],
                                   config.label_token_dict[dataset], config.tokenizer, config.sent2_token_dict[dataset],
                                   dataset=dataset, id_token=config.id_token_dict[dataset])
    if config.do_test:
        testset[dataset] = db_class(config, config.test_file_dict[dataset], config.sent_token_dict[dataset],
                                    None, config.tokenizer, config.sent2_token_dict[dataset], dataset=dataset, id_token=config.id_token_dict[dataset])



In [3]:
trainset = {}
testset = {}
devset = {}
ReadDataset(config)

mode_result = []
# print(config.__dict__)

# model = network.__dict__[Config.big_model](w2id[Config.dataset_train]).to(Config.device)
model = globals()[config.big_model](config)
if config.multiple_gpu:
    model = nn.DataParallel(model)
    model = model.cuda()
else:
    model = model.to(config.device)
if config.continue_train:
    model.load_state_dict(torch.load(config.model_save_path))
    print('continue the training model in {}'.format(config.model_save_path))

criterion = nn.CrossEntropyLoss()

start reading file../data/e-snli/esnli_train_1.tsv
Loaded Id with size (9998,)
Loaded cause_mask with size (9998, 80)
Loaded txt with size (9998,)
Loaded x_mask with size (9998, 80)
Loaded x_pos with size (9998, 80)
Loaded x_sent with size (9998, 80)
Loaded x_typeid with size (9998, 80)
Loaded y with size (9998, 1)
Loaded h5 from ../data/e-snli/esnli_train_1.tsv_Pretrain_DataBunch_e_snli_markedbert-base-uncased_loadfew.h5
../data/e-snli/esnli_train_1.tsv loaded from h5
start reading file../data/e-snli/esnli_dev.tsv
Loaded Id with size (9842,)
Loaded cause_mask with size (9842, 80)
Loaded txt with size (9842,)
Loaded x_mask with size (9842, 80)
Loaded x_pos with size (9842, 80)
Loaded x_sent with size (9842, 80)
Loaded x_typeid with size (9842, 80)
Loaded y with size (9842, 1)
Loaded h5 from ../data/e-snli/esnli_dev.tsv_Pretrain_DataBunch_e_snli_markedbert-base-uncased_loadfew.h5
../data/e-snli/esnli_dev.tsv loaded from h5


In [4]:
config

<Config_File.Config_base at 0x7f5257749790>

In [5]:
train_generator = data.DataLoader(
    trainset[config.dataset_train], **params_trainset)
dev_generator = data.DataLoader(
    devset[config.dataset_train], **params_trainset)

optimizer = optimizers.__dict__[config.optimizer](model, int(
    len(trainset[config.dataset_train]) / params_trainset['batch_size']) * config.epoch)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0,
                                            num_training_steps=int(len(train_generator) * config.epoch))

In [6]:
globals()[args.train_process]

<function train_process_causality.train_cause_word(args, model, optimizer, scheduler, criterion, train_generator, test_generator)>

In [7]:
args = config

In [8]:
best_accu = 0
best_f1 = 0
accu_min_train_loss = 0
last_update_epoch = 0

train_ratios_log, eval_ratios_log = [], []

In [9]:
from train_process_causality import *
test_generator = dev_generator

In [11]:
for epoch in range(args.epoch):
    if (args.early_stop is not None) and (epoch - last_update_epoch > args.early_stop):
        break
    train_loss = 0
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    ori_losses = AverageMeter()
    top1 = AverageMeter()
    prec5 = AverageMeter()
    grad_loss = AverageMeter()
    grad0_loss = AverageMeter()
    priores = AverageMeter()
    varies = AverageMeter()
    print('')
    bar = Bar('Processing', max=len(train_generator))
    end = time.time()
    with torch.autograd.set_detect_anomaly(True):
        for iter, batch_data0 in enumerate(train_generator):
            batch_data = {}
            for k in batch_data0:
                v = batch_data0[k]
                if isinstance(v, torch.LongTensor) or isinstance(v, torch.FloatTensor):
                    batch_data[k] = v.cuda()
                else:
                    batch_data[k] = v
            if 'cause_mask' in batch_data:
                cause_mask = batch_data['cause_mask']
            else:
                cause_mask = torch.zeros(
                    batch_data['x_sent'].size()).cuda().int()

            data_time.update(time.time() - end)
            model.train()
            optimizer.zero_grad()
            model.zero_grad()
            loss = 0.0

            if True:
                loss_gradspred = compute_saliancy(
                    args, model, batch_data, retain_graph=True)
                #loss_gradspred = nn.LayerNorm(loss_gradspred.size()[1:]).cuda()(loss_gradspred)
                '''if iter%max(len(train_generator)//5,1)==0:
                    prec5_this = visualize(args, epoch, iter, batch_data, loss_gradspred, write_label='w' if (epoch+iter==0) else 'a')
                    prec5.update(prec5_this,batch_data['x_sent'].size(0))'''
                loss_g0 = torch.sum(loss_gradspred * (1 - cause_mask) * batch_data['x_mask']) / torch.sum(
                    (1 - cause_mask) * batch_data['x_mask'])
                if torch.sum(cause_mask) == 0:
                    loss_g = loss_g0 * 0.0
                else:
                    loss_g = torch.sum(
                        loss_gradspred * cause_mask) / torch.sum(cause_mask)
                grad_loss.update(
                    loss_g.item(), batch_data['x_sent'].size(0))
                grad0_loss.update(
                    loss_g0.item(), batch_data['x_sent'].size(0))
                loss_g *= args.causal_ratio
                loss_g0 *= args.causal_ratio
                if args.grad_clamp:
                    loss = -torch.sum(torch.clamp(loss_gradspred, min=1.0) * cause_mask) / torch.sum(cause_mask)
                else:
                    loss = -loss_g + loss_g0
                    #loss = (-torch.sum(loss_gradspred.pow(0.5) * cause_mask) + torch.sum(loss_gradspred.pow(2) * (1-cause_mask))) *args.causal_ratio
                #loss = - torch.sum(torch.clamp(torch.abs(loss_gradspred*cause_mask/loss_gradspred_old),max=2.0)) * args.causal_ratio*0.02

            local_labels = batch_data['y'].to(args.device).squeeze()

            pred_y, deep_repre, seq_repre = model(batch_data)
            ce_loss = criterion(
                pred_y.reshape(-1, pred_y.shape[-1]), local_labels.reshape(-1))
            loss += ce_loss

            loss.backward()

            losses.update(loss.item(), batch_data['x_sent'].size(0))
            ori_losses.update(ce_loss.item(), batch_data['x_sent'].size(0))

            nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            train_loss += loss.item()
            optimizer.step()
            scheduler.step()

            if iter % 10 == 0 or iter == len(train_generator)-1:
                top1.update(get_acc(args, pred_y, local_labels),
                            batch_data['y'].size(0))

                train_accuracy, f1, train_loss, train_ratio = evaluate_causal_word(args, model, criterion, train_generator,
                                                                                   count_limit=3000)
                val_accuracy, f1, val_loss, eval_ratio = evaluate_causal_word(args, model, criterion, test_generator,
                                                                              count_limit=3000)
                train_ratios_log.append(
                    (train_ratio, train_accuracy, train_loss))
                eval_ratios_log.append(
                    (eval_ratio, val_accuracy, val_loss))

            batch_time.update(time.time() - end)
            end = time.time()

            bar.suffix = '({batch}/{size}) Batch:{bt:.3f}s|Total:{total:}|ETA:{eta:}|Loss:{loss:.4f}|Loss_ce:{loss_ce:.4f}|Grad:{grad_loss:.4f}|Grad0:{grad0_loss:.4f}|top1:{accu:.4f}|grad_ratio:{ratio:.4f}'.format(
                batch=iter + 1, size=len(train_generator), bt=batch_time.avg,
                total=bar.elapsed_td, eta=bar.eta_td, loss=losses.avg, grad_loss=grad_loss.avg,
                grad0_loss=grad0_loss.avg, accu=top1.avg, ratio=grad_loss.avg/grad0_loss.avg, loss_ce=ori_losses.avg)
            bar.next()
            print(bar.suffix, flush=True)

            # evaluate_causal_word(args, model, criterion, test_generator, True)

        bar.finish()

    print("epoch:{} train_loss:{}".format(
        epoch, train_loss / len(train_generator)))

    accuracy, f1, val_loss, eval_ratio = evaluate_causal_word(
        args, model, criterion, test_generator, True)

    if best_accu < accuracy:
        best_accu = accuracy
        last_update_epoch = epoch
        best_f1 = f1
        torch.save(model.state_dict(), args.model_save_path)
        print('update new model at {}'.format(args.model_save_path))

    resultStr = "mode:{} epoch:{} val_loss:{:.4f} accu:{:.4f}, f1:{}, best accu:{:.4f}, minaccu:{:.4f}".format(
        args.config_name, epoch, val_loss, accuracy, f1, best_accu, accu_min_train_loss)
    print(resultStr)

with open('ratio{}_log.txt'.format(args.causal_ratio), 'w', encoding='utf-8') as writer:
    for i in range(len(train_ratios_log)):
        writer.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(
            train_ratios_log[i][0], train_ratios_log[i][1], train_ratios_log[i][2], eval_ratios_log[i][0], eval_ratios_log[i][1], eval_ratios_log[i][2]))
    writer.close()

# if args.class_num == 2:
#     return best_accu, best_f1
# else:
#     return best_accu



(1/1250) Batch:428.699s|Total:0:07:08|ETA:0:00:00|Loss:0.9181|Loss_ce:1.1076|Grad:7.9031|Grad0:6.0083|top1:0.2500|grad_ratio:1.3154
(2/1250) Batch:214.853s|Total:0:07:09|ETA:6 days, 4:44:05|Loss:0.7917|Loss_ce:1.1105|Grad:8.1347|Grad0:4.9461|top1:0.2500|grad_ratio:1.6446
(3/1250) Batch:143.549s|Total:0:07:10|ETA:3 days, 2:28:57|Loss:0.8885|Loss_ce:1.1148|Grad:8.2555|Grad0:5.9925|top1:0.2500|grad_ratio:1.3776
(4/1250) Batch:107.897s|Total:0:07:11|ETA:2 days, 1:43:27|Loss:3.5734|Loss_ce:1.1137|Grad:28.3100|Grad0:52.9069|top1:0.2500|grad_ratio:0.5351
(5/1250) Batch:86.506s|Total:0:07:12|ETA:1 day, 13:20:41|Loss:3.0354|Loss_ce:1.1335|Grad:24.3626|Grad0:43.3812|top1:0.2500|grad_ratio:0.5616
(6/1250) Batch:72.246s|Total:0:07:13|ETA:1 day, 5:55:00|Loss:2.6187|Loss_ce:1.1179|Grad:22.3881|Grad0:37.3960|top1:0.2500|grad_ratio:0.5987
(7/1250) Batch:62.059s|Total:0:07:14|ETA:1 day, 0:57:54|Loss:2.2929|Loss_ce:1.1293|Grad:21.2709|Grad0:32.9068|top1:0.2500|grad_ratio:0.6464
(8/1250) Batch:54.420s|T

KeyboardInterrupt: 

In [None]:
mode_result.append(
    config.config_name + '\t' + str(globals()[args.train_process](config, model, optimizer, scheduler, criterion, train_generator, dev_generator)))
for mode in mode_result:
    print(mode)

In [None]:
config.do_test

In [None]:
if config.do_test:
    model.load_state_dict(torch.load(config.model_save_path))
    test_generator = data.DataLoader(
        testset[config.dataset_test], **params_testset)
    predict(model, test_generator, config.output_test_file)