In [1]:
from functools import partial
from tqdm import trange, tqdm

import pandas as pd

from sklearn.model_selection import KFold

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from transformers import AdamW, get_scheduler
from transformers import AutoTokenizer, AutoModel
from transformers import AutoModelForSequenceClassification



from datasets import *
from model import *
from losses import *
from metrics import *

In [2]:
def dataset_text2tensor(_path, _ckpt):
    """Load a CSV dataset and turn it into a TextCodeDataset ready for training.

    Args:
        _path: path of the CSV file, read with pandas.
        _ckpt: checkpoint name handed to AutoTokenizer.from_pretrained.

    Returns:
        A 3-tuple (dataset, n_classes, (D_ids2token, B_ids2token)) where
        n_classes is [len(D_ids2token), len(B_ids2token)] and the two mappings
        are whatever label_vectorize returned.
    """
    # Label vectorisation (label_vectorize is a project helper; presumably it is
    # what provides the 'Dev_vec' / 'Btype_vec' columns read below -- TODO confirm).
    frame = pd.read_csv(_path)
    print('dataset shape:{}'.format(frame.shape))
    frame, D_ids2token, B_ids2token = label_vectorize(frame)
    n_classes = [len(D_ids2token), len(B_ids2token)]
    print('n_classes: ', n_classes)

    # Tokenise both text columns with the same tokenizer.
    encode = partial(tokenize_function, AutoTokenizer.from_pretrained(_ckpt))
    for target_col, source_col in (('x_C', 'Context'), ('x_A', 'AST')):
        frame[target_col] = frame[source_col].map(encode)

    # Element-wise '+' of the two label columns, then per-row tensor conversion.
    frame['y'] = (frame['Dev_vec'] + frame['Btype_vec']).map(tensor_func)

    return TextCodeDataset(frame), n_classes, (D_ids2token, B_ids2token)
    
# todo: finish 'dataset_split_load' func
def dataset_split_load(dataset, trn_ids, tst_ids, batch_size=10):
    """Wrap one dataset into a train and a test DataLoader restricted to given indices.

    Args:
        dataset: dataset object accepted by torch.utils.data.DataLoader.
        trn_ids: indices of the samples used for training.
        tst_ids: indices of the held-out samples.
        batch_size: number of samples per batch for both loaders. Defaults to
            10, the value that used to be hard-coded here.

    Returns:
        (trainloader, testloader). Both loaders share the same underlying
        dataset and pick their own indices through a SubsetRandomSampler, so
        the order inside each subset is randomised on every pass.
    """
    # One sampler per subset; the dataset itself is never copied or sliced.
    train_subsampler = torch.utils.data.SubsetRandomSampler(trn_ids)
    test_subsampler = torch.utils.data.SubsetRandomSampler(tst_ids)

    trainloader = torch.utils.data.DataLoader(
                      dataset,
                      batch_size=batch_size, sampler=train_subsampler)
    testloader = torch.utils.data.DataLoader(
                      dataset,
                      batch_size=batch_size, sampler=test_subsampler)

    return trainloader, testloader

In [3]:
def one_forward(model, loss_fn, x, y, device, train_loss, l):
    """Run one batch through the model and add its loss to a running average.

    Args:
        model: callable invoked as model(x_C, x_A).
        loss_fn: callable invoked as loss_fn(outputs, targets_as_float).
        x: pair whose two items are dicts of tensors (x[0] and x[1]).
        y: target tensor for the batch.
        device: device every tensor is moved to before the forward pass.
        train_loss: running loss so far (also used for the validation loss by
            the caller, despite the name).
        l: divisor applied to this batch's loss before accumulation; the
            caller passes the number of batches in the loader.

    Returns:
        (outputs, loss, train_loss) with train_loss increased by loss.item() / l.
    """
    def _on_device(tensors):
        # Move every tensor of one input dict to the target device.
        return {name: value.to(device) for name, value in tensors.items()}

    first_inputs = _on_device(x[0])
    second_inputs = _on_device(x[1])
    targets = y.to(device)

    outputs = model(first_inputs, second_inputs)
    loss = loss_fn(outputs, targets.float())

    train_loss += loss.item() / l
    return outputs, loss, train_loss

def one_backward(optimizer, loss, lr_scheduler):
    """Apply one optimisation step for an already computed loss.

    Args:
        optimizer: object exposing zero_grad() and step().
        loss: object exposing backward() (the loss of the current batch).
        lr_scheduler: object exposing step(); it is advanced once per call.

    Returns:
        None.
    """
    # Clear gradients left over from the previous step before back-propagating.
    optimizer.zero_grad()
    loss.backward()
    # Update the parameters first, then advance the learning-rate schedule.
    optimizer.step()
    lr_scheduler.step()

def update_metrics(y, outputs, n_classes, val_acc, val_f1, l):
    """Fold one batch's accuracy and F1 into the running validation averages.

    Args:
        y: targets of the batch.
        outputs: model outputs for the batch.
        n_classes: forwarded to metrics() as split_pos.
        val_acc: two-element list of running accuracies; updated in place.
        val_f1: two-element list of running F1 scores; updated in place.
        l: divisor applied to each batch value before accumulation; the caller
            passes the number of batches in the loader.

    Returns:
        (val_acc, val_f1), the same two list objects that were passed in.
    """
    batch_scores = metrics(y, outputs, split_pos = n_classes)
    # Same order as before: acc[0], acc[1], F1[0], F1[1].
    for running, key in ((val_acc, 'acc'), (val_f1, 'F1')):
        for slot in (0, 1):
            running[slot] += batch_scores[key][slot]/l
    return val_acc, val_f1

In [4]:
def train_imm(_path, _logname, _loss_fn, _use_ast = True, _is_textcnn = False, _num_epochs = 50, _bsz = 8,
              _lr = 3e-5, _ckpt = 'bert-base-uncased', _k_folds = 10, device = 'cuda' if torch.cuda.is_available() else 'cpu'):
    """Train and evaluate a model with K-fold cross-validation and write a text log.

    For every fold a fresh model, optimizer and linear LR scheduler are
    created, the model is trained for _num_epochs epochs and validated after
    each epoch. The validation numbers of the final epoch are kept as the
    fold result; the per-fold results and their mean are printed and written
    to '../res_log/<_logname>.txt'.

    Args:
        _path: CSV file handed to dataset_text2tensor.
        _logname: bare run name; the '../res_log/' prefix and '.txt' suffix
            are added here, so callers should not add them.
        _loss_fn: loss module; moved to `device` and called as
            loss_fn(outputs, targets).
        _use_ast: forwarded to the model constructor as use_AST.
        _is_textcnn: if true build MetaModel, otherwise PretrainModel.
        _num_epochs: epochs per fold; must be at least 1.
        _bsz: accepted for interface compatibility but not forwarded to
            dataset_split_load, so it has no effect here.
        _lr: learning rate for AdamW.
        _ckpt: checkpoint name used for the tokenizer and for PretrainModel.
        _k_folds: number of cross-validation folds.
        device: device string the model, loss and batches are moved to.

    Returns:
        None. Results are printed and written to the log file.

    Raises:
        ValueError: if _num_epochs is smaller than 1 (no validation result
            would exist to report).
    """
    import os

    if _num_epochs < 1:
        raise ValueError('_num_epochs must be >= 1, got {}'.format(_num_epochs))

    logstr = _logname + '\n' + '-'*60 + '\n'
    logname = '../res_log/' + _logname + '.txt'
    # Create the log directory before training so a missing folder cannot make
    # the final write fail after all folds have already been computed.
    os.makedirs(os.path.dirname(logname), exist_ok=True)
    torch.manual_seed(42)

    # tensorize dataset
    dataset_tensor, n_classes, _ = dataset_text2tensor(_path, _ckpt)
    logstr += 'dataset shape:{}\n n_classes:{}\n'.format(len(dataset_tensor), n_classes) + '-'*60 + '\n'

    # K-Fold. torch.manual_seed does not seed scikit-learn's shuffling, so the
    # split gets its own fixed seed to make the folds reproducible.
    results = []
    kfold = KFold(n_splits=_k_folds, shuffle=True, random_state=42)

    for fold, (trn_ids, tst_ids) in enumerate(kfold.split(dataset_tensor)):
        print('-'*40 + '  FOLD{}  '.format(fold) + '_'*40)

        # split train/test & wrap into dataloaders according to (trn_ids, tst_ids)
        train_dataloader, val_dataloader = dataset_split_load(dataset_tensor, trn_ids, tst_ids)

        # model (re-created for every fold so no weights leak between folds)
        if _is_textcnn:
            model = MetaModel(n_classes = n_classes, use_AST=_use_ast)
        else:
            model = PretrainModel(text_ckpt=_ckpt, code_ckpt=_ckpt, n_classes=n_classes, use_AST=_use_ast)
        model = model.to(device)

        # loss, optimizer, lr_scheduler
        loss_fn = _loss_fn.to(device)
        # TODO: compare different optimizer
        optimizer = torch.optim.AdamW(model.parameters(), lr=_lr)
        lr_scheduler = get_scheduler(
            "linear",
            optimizer=optimizer,
            num_warmup_steps=0,
            num_training_steps=_num_epochs * len(train_dataloader),
        )

        # train process
        for epoch in trange(_num_epochs):
            # train
            model.train()
            train_loss = 0.0
            for x, y in train_dataloader:
                # forward
                outputs, loss, train_loss = one_forward(model, loss_fn, x, y, device, train_loss, len(train_dataloader))
                # backward
                one_backward(optimizer, loss, lr_scheduler)

            # val: the running values are reset every epoch, so after the loop
            # they hold the numbers of the last epoch only.
            model.eval()
            val_loss, val_acc, val_f1 = 0.0, [0.0, 0.0], [0.0, 0.0]
            # No gradients are needed for evaluation; skipping autograd avoids
            # building a graph for every validation batch.
            with torch.no_grad():
                for x, y in val_dataloader:
                    # forward
                    outputs, loss, val_loss = one_forward(model, loss_fn, x, y, device, val_loss, len(val_dataloader))
                    # update metric
                    val_acc, val_f1 = update_metrics(y.to(device), outputs, n_classes, val_acc, val_f1, len(val_dataloader))

        # update n-Fold result
        results.append([val_loss, val_acc[0], val_acc[1], val_f1[0], val_f1[1]])
        tmpstr = '-'*30 + '{}TH FOLD RESULT'.format(fold) + '-'*30 + \
              '\n val_loss: {}\n val_acc:{}\n val_f1: {}\n'.format(results[-1][0], results[-1][1:3], results[-1][3:])
        logstr += tmpstr
        print(tmpstr)

    # Column-wise mean over folds: [loss, acc0, acc1, f1_0, f1_1].
    ava_result = torch.mean(torch.tensor(results),dim=0)
    # Report the real fold count instead of a hard-coded 10.
    tmp_str = '-'*30 + '{}-FOLD AVA-RESULT'.format(_k_folds) + '-'*30 + \
              '\n val_loss: {}\n val_acc:{}\n val_f1: {}\n'.format(ava_result[0], ava_result[1:3], ava_result[3:])
    logstr += tmp_str
    print(tmp_str)

    with open(logname, 'w') as f:
        f.write(logstr)

In [5]:
# Experiment grid consumed by the driver loop below.
# Datasets as (csv_path, short_name) pairs: 8 entries in total, of which 6 are
# active and 2 are currently commented out.
pathlist = [
    ('../Data/aspnet/aspnet_2.csv', 'aspnet'),
    ('../Data/efcore/efcore_2.csv', 'efcore'),
    ('../Data/elasticSearch/elasticSearch_2.csv', 'elasticSearch'),
    # ('../Data/mixedRealityToolUnity/mixedRealityToolUnity_2.csv', 'mixedRealityToolUnity'),
    # ('../Data/monoGame/monoGame_2.csv', 'monoGame'),
    ('../Data/powershell/powerShell_2.csv', 'powerShell'),
    ('../Data/realmJava/realmJava_2.csv', 'realmJava'),
    ('../Data/roslyn/roslyn_2.csv', 'roslyn'),
]
# Loss functions as (loss_module, short_name) pairs; only BCE is active.
losslist = [
    (nn.BCEWithLogitsLoss(), 'BCE'),
    # (CustomizedBCELoss(), 'CBCE'),
    # (AsymmetricLossOptimized(), 'ASL'),
]

# Checkpoints as (checkpoint_name, model_label) pairs. The driver loop compares
# the label against 'Multi-triage' to decide whether to set _is_textcnn.
ckptlist = [
    ('bert-base-uncased', 'Multi-triage'),  # checkpoint is only used for tokenization in this setting
    # ('bert-base-uncased', ' Bert'),
    # ('roberta-base', 'Robert'),
]

# AST settings; the driver loop turns 'use_AST' into _use_ast=True and anything else into False.
astlist = [
    'no_AST',
    'use_AST',
]

In [6]:
# Run every combination of dataset, checkpoint, loss and AST setting.
for path in pathlist:
    for ckpt in ckptlist:
        for loss in losslist:
            for ast in astlist:
                # The 'Multi-triage' label selects the text-CNN model branch in
                # train_imm; 'use_AST' switches the AST input on.
                is_t, u_ast = (ckpt[1] == 'Multi-triage'), (ast == 'use_AST')

                logname = ' '.join([path[1], ckpt[1], loss[1], ast])
                print('-'*100, logname, '-'*100, sep='\n')

                # Pass the bare run name: train_imm itself prepends '../res_log/'
                # and appends '.txt'. Adding them here as well produced
                # '../res_log/../res_log/<name>.txt.txt' and put a file path
                # into the log header instead of the run name.
                train_imm(_path = path[0], _logname = logname,
                      _loss_fn = loss[0], _use_ast = u_ast, _is_textcnn = is_t, _ckpt = ckpt[0])

----------------------------------------------------------------------------------------------------
aspnet Multi-triage BCE no_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1006, 7)
n_classes:  [32, 88]
----------------------------------------  FOLD0  ________________________________________


  return F.conv1d(input, weight, bias, self.stride,
100%|██████████| 50/50 [00:13<00:00,  3.59it/s]


------------------------------0TH FOLD RESULT------------------------------
 val_loss: 0.2643410915678198
 val_acc:[0.9869318387725138, 0.9790289185263894]
 val_f1: [0.790907836639731, 0.006060561984218199]

----------------------------------------  FOLD1  ________________________________________


100%|██████████| 50/50 [00:12<00:00,  3.89it/s]


------------------------------1TH FOLD RESULT------------------------------
 val_loss: 0.2595035135746002
 val_acc:[0.9877841039137408, 0.9776859716935591]
 val_f1: [0.8090896374820052, 0.021211980891983243]

----------------------------------------  FOLD2  ________________________________________


100%|██████████| 50/50 [00:13<00:00,  3.77it/s]


------------------------------2TH FOLD RESULT------------------------------
 val_loss: 0.26049606095660816
 val_acc:[0.9832386482845653, 0.9776859662749551]
 val_f1: [0.7225623402825693, 0.0]

----------------------------------------  FOLD3  ________________________________________


100%|██████████| 50/50 [00:13<00:00,  3.76it/s]


------------------------------3TH FOLD RESULT------------------------------
 val_loss: 0.25390764732252463
 val_acc:[0.9781250140883706, 0.978719023140994]
 val_f1: [0.6545443757771479, 0.004545418405879658]

----------------------------------------  FOLD4  ________________________________________


100%|██████████| 50/50 [00:12<00:00,  3.94it/s]


------------------------------4TH FOLD RESULT------------------------------
 val_loss: 0.26582675088535657
 val_acc:[0.9747159264304421, 0.9739669669758191]
 val_f1: [0.59760663416215, 0.008181751651088625]

----------------------------------------  FOLD5  ________________________________________


100%|██████████| 50/50 [00:13<00:00,  3.82it/s]


------------------------------5TH FOLD RESULT------------------------------
 val_loss: 0.30496192384849896
 val_acc:[0.9880681850693445, 0.9776859879493716]
 val_f1: [0.8090896374820074, 0.0036363332452089667]

----------------------------------------  FOLD6  ________________________________________


100%|██████████| 50/50 [00:13<00:00,  3.80it/s]


------------------------------6TH FOLD RESULT------------------------------
 val_loss: 0.2951791480183601
 val_acc:[0.9868750214576721, 0.9753409087657928]
 val_f1: [0.787418107125055, 0.013999907033236697]

----------------------------------------  FOLD7  ________________________________________


100%|██████████| 50/50 [00:12<00:00,  4.02it/s]


------------------------------7TH FOLD RESULT------------------------------
 val_loss: 0.291102945804596
 val_acc:[0.9800000250339507, 0.9765909314155578]
 val_f1: [0.6799988515018538, 0.0]

----------------------------------------  FOLD8  ________________________________________


100%|██████████| 50/50 [00:13<00:00,  3.72it/s]


------------------------------8TH FOLD RESULT------------------------------
 val_loss: 0.26519374400377277
 val_acc:[0.986250013113022, 0.9760227620601652]
 val_f1: [0.7874273236722782, 0.0]

----------------------------------------  FOLD9  ________________________________________


100%|██████████| 50/50 [00:13<00:00,  3.80it/s]


------------------------------9TH FOLD RESULT------------------------------
 val_loss: 0.2668348804116249
 val_acc:[0.9809375107288362, 0.9777272820472717]
 val_f1: [0.7148375344097367, 0.0]

------------------------------10-FOLD AVA-RESULT------------------------------
 val_loss: 0.27273476123809814
 val_acc:tensor([0.9833, 0.9770])
 val_f1: tensor([0.7353, 0.0058])

----------------------------------------------------------------------------------------------------
aspnet Multi-triage BCE use_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1006, 7)
n_classes:  [32, 88]
----------------------------------------  FOLD0  ________________________________________


100%|██████████| 50/50 [00:24<00:00,  2.02it/s]


------------------------------0TH FOLD RESULT------------------------------
 val_loss: 0.26347632028839807
 val_acc:[0.9838068214329808, 0.9745867848396301]
 val_f1: [0.7454533345340496, 0.013636273372704052]

----------------------------------------  FOLD1  ________________________________________


100%|██████████| 50/50 [00:23<00:00,  2.10it/s]


------------------------------1TH FOLD RESULT------------------------------
 val_loss: 0.26419164782220667
 val_acc:[0.9792613712224094, 0.9779958833347668]
 val_f1: [0.6430819391516377, 0.009090854966824395]

----------------------------------------  FOLD2  ________________________________________


 26%|██▌       | 13/50 [00:06<00:19,  1.92it/s]


KeyboardInterrupt: 