In [1]:
import pandas as pd
from functools import partial

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_scheduler
from transformers import AutoTokenizer, AutoModel
from transformers import AutoModelForSequenceClassification
from tqdm import trange, tqdm

from model import *

In [2]:
# dataset utils
def get_map(labels):
    l_set = set()
    for label in labels:
        l_set.update(label)
    ids2token = list(l_set)
    token2ids = {ids2token[i] : i for i in range(len(ids2token))}
    return ids2token, token2ids

def onehot(labels, token2ids):
    vec = [0 for i in token2ids]
    for label in labels.split('| '):
        vec[token2ids[label]] = 1
    return vec

def lab(labels, token2ids):
    return [token2ids[label] for label in labels.split('| ')]

def label_vectorize(data):
    data = data.rename(columns={'Title_Description' : 'Context', 'AST' : 'AST', 'FixedByID' : 'Dev', 'Name' : 'Btype'})
    data = data[['Context', 'AST', 'Dev', 'Btype']]
    # avoid NaN in dataset
    data['Context'].fillna('[UNK]', inplace=True)
    data['AST'].fillna('[UNK]', inplace=True)
    data['Dev'].fillna('unknown', inplace=True)
    data['Btype'].fillna('unknown', inplace=True)
    
    D_labels = [label.split('| ') for label in data['Dev']]
    _D_ids2token, D_token2ids = get_map(D_labels)
    data['Dev_l'] = data['Dev'].map(partial(lab, token2ids = D_token2ids))
    data['Dev_vec'] = data['Dev'].map(partial(onehot, token2ids = D_token2ids))
    
    B_labels = [label.split('| ') for label in data['Btype']]
    _B_ids2token, B_token2ids = get_map(B_labels)
    data['Btype_l'] = data['Btype'].map(partial(lab, token2ids = B_token2ids))
    data['Btype_vec'] = data['Btype'].map(partial(onehot, token2ids = B_token2ids))
    
    return data, _D_ids2token, _B_ids2token

def tokenize_function(_tokenizer, example, max_seq_len = 512):
    example = example if type(example) == str else _tokenizer.unk_token
    return _tokenizer(example, padding='max_length',
                                truncation=True, max_length=max_seq_len, return_tensors="pt")

def tensor_func(example):
    return torch.tensor(example)

class TextCodeDataset(Dataset):
    def __init__(self, data):
        super().__init__()
        self.data = data
        
    def __getitem__(self, item):
        return (self.data['x_C'][item], self.data['x_A'][item]), self.data['y'][item]
    
    def __len__(self):
        return len(self.data)

In [3]:
# loss & metrics
class CustomizedBCELoss(nn.Module):
    """
    a flexible version of BCE,
    which enable the loss to focus more on the performance of positive samples' prediction
    """

    def __init__(self, weight_pos=0.8, weight_neg=0.2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_pos = weight_pos
        self.weight_neg = weight_neg

    def forward(self, x, y):
        x = nn.Sigmoid()(x)
        loss_pos = y * torch.log(x)
        loss_neg = (1 - y) * torch.log(1 - x)
        # loss = 0.8*loss_pos + 0.2*loss_neg
        loss = self.weight_pos * loss_pos + self.weight_neg * loss_neg
        return -torch.sum(loss)

class AsymmetricLossOptimized(nn.Module):
    """
    AsymmetricLoss from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py

    Notice - optimized version, minimizes memory allocation and gpu uploading,
    favors inplace operations
    """

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                          self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()

def metrics(y: torch.Tensor, pred: torch.Tensor, split_pos: list, threshold: float = 0.5, from_logits=True):
    if from_logits:
        pred = nn.Sigmoid()(pred)
    pred = torch.where(pred > threshold, 1, 0)

    y_d, y_b = torch.split(y, split_pos, dim=1)
    pred_d, pred_b = torch.split(pred, split_pos, dim=1)

    TPd, TPb = torch.sum(y_d * pred_d, dim=1), torch.sum(y_b * pred_b, dim=1)
    TNd, TNb = torch.sum((1 - y_d) * (1 - pred_d), dim=1), torch.sum((1 - y_b) * (1 - pred_b), dim=1)
    FPd, FPb = torch.sum((1 - y_d) * pred_d, dim=1), torch.sum((1 - y_b) * pred_b, dim=1)
    FNd, FNb = torch.sum(y_d * (1 - pred_d), dim=1), torch.sum(y_b * (1 - pred_b), dim=1)

    acc = torch.mean((TPd + TNd) / (TPd + TNd + FPd + FNd + 1e-6)).item(), torch.mean(
        (TPb + TNb) / (TPb + TNb + FPb + FNb + 1e-6)).item()
    recall = torch.mean(TPd / (TPd + FNd + 1e-6)).item(), torch.mean(TPb / (TPb + FNb + 1e-6)).item()
    precision = torch.mean(TPd / (TPd + FPd + 1e-6)).item(), torch.mean(TPb / (TPb + FPb + 1e-6)).item()
    F1 = 2 * recall[0] * precision[0] / (recall[0] + precision[0] + 1e-6), 2 * recall[1] * precision[1] / (
            recall[1] + precision[1] + 1e-6)

    return {
        'acc': acc,
        'precision': precision,
        'recall': recall,
        'F1': F1
    }

In [4]:
def train_imm(_path, _logname, _loss_fn, _use_ast = True, _is_textcnn = False, _num_epochs = 100, _bsz = 8,
              _lr = 3e-5, _ckpt = 'bert-base-uncased', device = 'cuda' if torch.cuda.is_available() else 'cpu'):
    logname = '../res_log/' + _logname + '.txt'
    logstr = _logname + '\n' + '-'*60 + '\n'
    
    # dataset label vectorize
    dataset = pd.read_csv(_path)
    logstr += 'dataset shape:{}\n'.format(dataset.shape)
    print('dataset shape:{}'.format(dataset.shape))
    dataset, D_ids2token, B_ids2token = label_vectorize(dataset)
    n_classes = [len(D_ids2token), len(B_ids2token)]
    logstr += 'n_classes:{}\n'.format(n_classes) + '-'*60 + '\n'
    print('n_classes: ', n_classes)

    check_point = _ckpt
    tokenizer = AutoTokenizer.from_pretrained(check_point)
    # datset tensorize
    dataset['x_C'] = dataset['Context'].map(partial(tokenize_function, tokenizer))
    dataset['x_A'] = dataset['AST'].map(partial(tokenize_function, tokenizer))
    dataset['y'] = dataset['Dev_vec'] + dataset['Btype_vec']
    dataset['y'] = dataset['y'].map(tensor_func)

    # split datset
    t_dataset = dataset[:int(0.8*len(dataset))].reset_index(drop=True)
    train_dataset = t_dataset.sample(frac=0.8,random_state=0,axis=0).reset_index(drop=True)
    val_dataset = t_dataset[~t_dataset.index.isin(train_dataset.index)].reset_index(drop=True)
    test_dataset = dataset[int(0.8*len(dataset)):].reset_index(drop=True)

    # wrap dataset & dataloader
    train_dataset = TextCodeDataset(train_dataset)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=_bsz, drop_last=True)
    val_dataset = TextCodeDataset(val_dataset)
    val_dataloader = DataLoader(val_dataset, shuffle=True, batch_size=_bsz, drop_last=True)
    test_dataset = TextCodeDataset(test_dataset)
    test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=_bsz, drop_last=True)

    # load model
    if _is_textcnn:
        model = MetaModel(n_classes = n_classes, use_AST=_use_ast)
    else:
        # TODO: seperate ckpt
        model = PretrainModel(text_ckpt=_ckpt, code_ckpt=_ckpt, n_classes=n_classes, use_AST=_use_ast)
    model = model.to(device)

    # loss
    loss_fn = _loss_fn.to(device)

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=_lr)

    # lr_scheduler
    num_epochs = _num_epochs
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    # train process
    val_loss_min, val_down = 100000.0, [1, 1, 1, 1, 1]
    for epoch in trange(num_epochs):
        # train
        model.train()
        train_loss = 0.0
        for x, y in train_dataloader:
            x_C = {k: v.to(device) for k, v in x[0].items()}
            x_A = {k: v.to(device) for k, v in x[1].items()}
            y = y.to(device)
            
            outputs = model(x_C, x_A)

            loss = loss_fn(outputs, y.float())
            train_loss += loss.item()/len(train_dataloader)
        
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
        logstr += '{}th epoch\n train_loss: {}\n'.format(epoch, train_loss)
        # print('{}th epoch\n train_loss: {}\n'.format(epoch, train_loss))
    
        # val
        model.eval()
        val_loss, val_acc, val_f1 = 0.0, [0.0, 0.0], [0.0, 0.0]
        for x, y in val_dataloader:
            x_C = {k: v.to(device) for k, v in x[0].items()}
            x_A = {k: v.to(device) for k, v in x[1].items()}
            y = y.to(device)

            outputs = model(x_C, x_A)
        
            loss = loss_fn(outputs, y.float())
            val_loss += loss.item()/len(val_dataloader)
            metric = metrics(y, outputs, split_pos = n_classes)
            val_acc[0] += metric['acc'][0]/len(val_dataloader)
            val_acc[1] += metric['acc'][1]/len(val_dataloader)
            val_f1[0] += metric['F1'][0]/len(val_dataloader)
            val_f1[1] += metric['F1'][1]/len(val_dataloader)
        logstr += '{}th epoch\n val_loss: {}\n val_acc:{}\n val_f1: {}\n'.format(epoch, val_loss, val_acc, val_f1)
        # print('{}th epoch\n val_loss: {}\n val_acc:{}\n val_f1: {}'.format(epoch, val_loss, val_acc, val_f1))

        val_down.append(1 if val_loss_min - val_loss > 1e-10 else 0)
        val_loss_min = min(val_loss_min, val_loss)
        if val_down[-1] + val_down[-2] + val_down[-3] + val_down[-4] + val_down[-5] == 0:
            break

    # test
    model.eval()
    test_loss, test_acc, test_f1 = 0.0, [0.0, 0.0], [0.0, 0.0]
    for x, y in tqdm(test_dataloader):
        x_C = {k: v.to(device) for k, v in x[0].items()}
        x_A = {k: v.to(device) for k, v in x[1].items()}
        y = y.to(device)
            
        outputs = model(x_C, x_A)
                
        loss = loss_fn(outputs, y.float())
        test_loss += loss.item()/len(test_dataloader)
        metric = metrics(y, outputs, split_pos = n_classes)
        test_acc[0] += metric['acc'][0]/len(test_dataloader)
        test_acc[1] += metric['acc'][1]/len(test_dataloader)
        test_f1[0] += metric['F1'][0]/len(test_dataloader)
        test_f1[1] += metric['F1'][1]/len(test_dataloader)
    logstr += '-' * 60 + '\ntest_loss: {}\n test_acc:{}\n test_f1: {}'.format(test_loss, test_acc, test_f1)
    print('test_loss: {}\n test_acc:{}\n test_f1: {}'.format(test_loss, test_acc, test_f1))

    with open(logname, 'w') as f:
        f.write(logstr)

In [5]:
# 8 datasets
pathlist = [
    ('../Data/aspnet/aspnet_2.csv', 'aspnet'),
    ('../Data/efcore/efcore_2.csv', 'efcore'),
    ('../Data/elasticSearch/elasticSearch_2.csv', 'elasticSearch'),
    # ('../Data/mixedRealityToolUnity/mixedRealityToolUnity_2.csv', 'mixedRealityToolUnity'),
    # ('../Data/monoGame/monoGame_2.csv', 'monoGame'),
    ('../Data/powershell/powerShell_2.csv', 'powerShell'),
    ('../Data/realmJava/realmJava_2.csv', 'realmJava'),
    ('../Data/roslyn/roslyn_2.csv', 'roslyn'),
]
losslist = [
    # (nn.BCEWithLogitsLoss(), 'BCE'),
    # (CustomizedBCELoss(), 'CBCE'),
    (AsymmetricLossOptimized(), 'ASL'),
]

ckptlist = [
    ('bert-base-uncased', 'Multi-triage'),  # just for tokenize
    # ('bert-base-uncased', ' Bert'),
]

In [6]:
for path in pathlist:
    for ckpt in ckptlist:
        for loss in losslist:
            is_t = (ckpt[1] == 'Multi-triage')       
            logname = ' '.join([path[1], ckpt[1], loss[1], 'no_AST'])
            print('-'*100, logname, '-'*100, sep='\n')
            train_imm(_path = path[0], _logname = logname, _loss_fn = loss[0], 
                      _use_ast = False, _is_textcnn = is_t, _ckpt = ckpt[0])
            
            logname = ' '.join([path[1], ckpt[1], loss[1], 'use_AST'])
            print('-'*100, logname, '-'*100, sep='\n')
            train_imm(_path = path[0], _logname = logname, _loss_fn = loss[0], _is_textcnn = is_t, _ckpt = ckpt[0])

----------------------------------------------------------------------------------------------------
aspnet Multi-triage ASL no_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1006, 7)
n_classes:  [32, 88]


  return F.conv1d(input, weight, bias, self.stride,
100%|██████████| 100/100 [00:30<00:00,  3.24it/s]
100%|██████████| 25/25 [00:00<00:00, 393.79it/s]


test_loss: 14.834178695678714
 test_acc:[0.9817187500000001, 0.9507386350631711]
 test_f1: [0.7364400908384727, 0.2955426034800301]
----------------------------------------------------------------------------------------------------
aspnet Multi-triage ASL use_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1006, 7)
n_classes:  [32, 88]


 98%|█████████▊| 98/100 [00:51<00:01,  1.89it/s]
100%|██████████| 25/25 [00:00<00:00, 345.64it/s]


test_loss: 14.474770202636716
 test_acc:[0.9801562500000001, 0.9521590852737424]
 test_f1: [0.7231755649094918, 0.30623835983731007]
----------------------------------------------------------------------------------------------------
efcore Multi-triage ASL no_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1740, 7)
n_classes:  [15, 49]


 59%|█████▉    | 59/100 [00:31<00:21,  1.89it/s]
100%|██████████| 43/43 [00:00<00:00, 478.10it/s]


test_loss: 10.691873993984489
 test_acc:[0.9281007190083348, 0.9319530462109764]
 test_f1: [0.6238718662398647, 0.38134451645918155]
----------------------------------------------------------------------------------------------------
efcore Multi-triage ASL use_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1740, 7)
n_classes:  [15, 49]


 80%|████████  | 80/100 [01:09<00:17,  1.16it/s]
100%|██████████| 43/43 [00:00<00:00, 346.76it/s]


test_loss: 10.386687378550684
 test_acc:[0.9255813429521959, 0.9287494379420611]
 test_f1: [0.6177459288151886, 0.3993598219173758]
----------------------------------------------------------------------------------------------------
elasticSearch Multi-triage ASL no_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1634, 7)
n_classes:  [86, 188]


 82%|████████▏ | 82/100 [00:42<00:09,  1.91it/s]
100%|██████████| 40/40 [00:00<00:00, 399.59it/s]


test_loss: 17.80180823802948
 test_acc:[0.9868822619318962, 0.9843915998935694]
 test_f1: [0.44955986896692823, 0.3866683318388676]
----------------------------------------------------------------------------------------------------
elasticSearch Multi-triage ASL use_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1634, 7)
n_classes:  [86, 188]


 81%|████████  | 81/100 [01:13<00:17,  1.10it/s]
100%|██████████| 40/40 [00:00<00:00, 322.57it/s]


test_loss: 18.001556611061098
 test_acc:[0.9870276063680653, 0.9843749806284899]
 test_f1: [0.4381392580493889, 0.3586912095279595]
----------------------------------------------------------------------------------------------------
powerShell Multi-triage ASL no_AST
----------------------------------------------------------------------------------------------------
dataset shape:(312, 7)
n_classes:  [144, 90]


 85%|████████▌ | 85/100 [00:08<00:01, 10.26it/s]
100%|██████████| 7/7 [00:00<00:00, 357.09it/s]


test_loss: 36.66150883265904
 test_acc:[0.865823405129569, 0.8331349066325597]
 test_f1: [0.019181429644476063, 0.11382812210539411]
----------------------------------------------------------------------------------------------------
powerShell Multi-triage ASL use_AST
----------------------------------------------------------------------------------------------------
dataset shape:(312, 7)
n_classes:  [144, 90]


  6%|▌         | 6/100 [00:01<00:15,  5.89it/s]
100%|██████████| 7/7 [00:00<00:00, 347.15it/s]


test_loss: 62.55499485560826
 test_acc:[0.5307539531162806, 0.5418650848524912]
 test_f1: [0.021118453230275367, 0.04941333283908968]
----------------------------------------------------------------------------------------------------
realmJava Multi-triage ASL no_AST
----------------------------------------------------------------------------------------------------
dataset shape:(340, 7)
n_classes:  [11, 18]


 81%|████████  | 81/100 [00:09<00:02,  8.87it/s]
100%|██████████| 8/8 [00:00<00:00, 381.52it/s]


test_loss: 8.333059787750244
 test_acc:[0.9176135584712029, 0.760416567325592]
 test_f1: [0.7455773750184688, 0.2459116555090667]
----------------------------------------------------------------------------------------------------
realmJava Multi-triage ASL use_AST
----------------------------------------------------------------------------------------------------
dataset shape:(340, 7)
n_classes:  [11, 18]


 73%|███████▎  | 73/100 [00:13<00:05,  5.31it/s]
100%|██████████| 8/8 [00:00<00:00, 317.84it/s]


test_loss: 8.750559389591217
 test_acc:[0.8778408244252205, 0.762152686715126]
 test_f1: [0.6599005121643212, 0.2394443346822738]
----------------------------------------------------------------------------------------------------
roslyn Multi-triage ASL no_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1316, 7)
n_classes:  [58, 97]


 93%|█████████▎| 93/100 [00:38<00:02,  2.44it/s]
100%|██████████| 33/33 [00:00<00:00, 460.95it/s]


test_loss: 18.020436460321598
 test_acc:[0.9799503720167914, 0.9527881849895822]
 test_f1: [0.4839983090578555, 0.39450672330491543]
----------------------------------------------------------------------------------------------------
roslyn Multi-triage ASL use_AST
----------------------------------------------------------------------------------------------------
dataset shape:(1316, 7)
n_classes:  [58, 97]


 75%|███████▌  | 75/100 [00:50<00:16,  1.48it/s]
100%|██████████| 33/33 [00:00<00:00, 335.80it/s]


test_loss: 18.356628475767195
 test_acc:[0.9793625835216407, 0.9517728866952839]
 test_f1: [0.4942357072722464, 0.37503145194570287]
