In [2]:
import config
from model import Albert
from dataset import CustomDataset, process_data, train_test_split

import os
import datetime
from time import time 
from shutil import copyfile

import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AlbertForSequenceClassification
from transformers import get_polynomial_decay_schedule_with_warmup

from torch.utils.tensorboard import SummaryWriter

# Runtime setup: choose which GPU(s) are visible and pick the torch device for this notebook.
# The environment variable is assigned before the first torch.cuda call below.
# NOTE(review): presumably it must be set before CUDA is initialised to take effect — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES      # GPU visibility, value taken from config
# Use the GPU when torch reports one as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Ask PyTorch to release cached GPU memory (e.g. left over from an earlier run in the same kernel).
torch.cuda.empty_cache()


def multi_acc(y_pred, y_test):
    """Compute multi-class accuracy for one batch.

    Args:
        y_pred: Tensor of per-class scores (e.g. logits) with one row per
            sample; the predicted class is the argmax along dim 1.
        y_test: Tensor of integer class indices, one per sample. Labels with
            a trailing singleton dimension, shape ``(N, 1)``, are accepted.

    Returns:
        A 0-dim float tensor holding the fraction of samples whose predicted
        class equals the label (NaN for an empty batch).

    Raises:
        RuntimeError: If ``y_test`` does not contain exactly one label per
            prediction.
    """
    # log_softmax is monotonic, so the argmax of the raw scores selects the
    # same class without the extra pass over the batch.
    y_pred_tags = torch.argmax(y_pred, dim=1)
    # Give the labels the same shape as the predictions so that (N, 1) labels
    # cannot silently broadcast against (N,) predictions into an (N, N) grid.
    y_true = y_test.reshape(y_pred_tags.shape)
    return (y_pred_tags == y_true).float().mean()

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Yields ``None`` when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']



In [3]:
    
    ## Loading data
    # Either read the previously saved train/validation splits, or build the
    # train/validation/test splits from the raw input file and save them.
    print('Loading dataset...')
    if config.ALREADY_SPLIT:
        train_df = pd.read_csv(config.TRAIN_FILE) 
        val_df = pd.read_csv(config.VALIDATION_FILE)    
        print('Training set shape: '+ str(train_df.shape))
        print('Validaiton set shape: '+ str(val_df.shape))
        print('Loading finished.')
    else:
        data_df = process_data(config.INPUT_FILE, config.CLS2IDX, True)     # DataFrame, only labeled data is used
        # First carve off the test set, then split the remainder into
        # training and validation sets.
        train_df, test_df = train_test_split(
            data_df, 
            test_size=config.TEST_SIZE, 
            shuffle=True, 
            random_state=config.RANDOM_STATE)
        train_df, val_df = train_test_split(
            train_df, 
            test_size=config.VALIDATION_SIZE, 
            shuffle=True, 
            random_state=config.RANDOM_STATE)  
        # The shuffled splits presumably keep the source row labels as their
        # index. Sentences are passed on as a positional array while labels
        # are passed as a Series, so reset the index to 0..n-1 to keep the two
        # aligned and to match the frames produced by the CSV branch above.
        train_df, val_df, test_df = (
            split_df.reset_index(drop=True)
            for split_df in (train_df, val_df, test_df)
        )
        print('Training set shape: '+ str(train_df.shape))
        print('Validaiton set shape: '+ str(val_df.shape))
        print('Test set shape: '+ str(test_df.shape))
        print('Loading finished.')
        print('Saving training set & validation set & test set to local...')
        train_df.to_csv(config.TRAIN_FILE, index=False)
        val_df.to_csv(config.VALIDATION_FILE, index=False)
        test_df.to_csv(config.TEST_FILE, index=False)
        print('Saving finished.')
    

    ## Processing data
    # Wrap the text/label columns in datasets and batch them with DataLoaders.
    print('Processing dataset...')
    train_set = CustomDataset(
        sentences=train_df[config.CONTENT_FIELD].values.astype("str"),
        labels=train_df[config.LABEL_FIELD]
    )
    val_set = CustomDataset(
        sentences=val_df[config.CONTENT_FIELD].values.astype("str"),
        labels=val_df[config.LABEL_FIELD]
    )
    # Pinned host memory only benefits host-to-GPU copies; skip it when no GPU
    # is available, consistent with the notebook's CPU fallback.
    use_pinned_memory = torch.cuda.is_available()
    # Training batches are reshuffled each epoch; validation order stays fixed.
    train_dataloader = DataLoader(
        dataset=train_set, 
        batch_size=config.BATCH_SIZE, 
        shuffle=True, 
        num_workers=config.NUM_WORKERS,
        pin_memory=use_pinned_memory
    )
    val_dataloader = DataLoader(
        dataset=val_set, 
        batch_size=config.BATCH_SIZE, 
        shuffle=False, 
        num_workers=config.NUM_WORKERS,
        pin_memory=use_pinned_memory
    )
    print('Processing finished.')




    

Loading dataset...
Training set shape: (11391, 3)
Validaiton set shape: (2847, 3)
Loading finished.
Processing dataset...
Processing finished.


In [4]:
# Display the first training example to sanity-check what the dataset yields.
train_set[0]

{'token_ids': tensor([ 101, 5307, 5317, 3221,  679, 2100, 1762, 4638, 8024, 3187, 6389,  872,
         4500, 3227, 2544, 7262, 4692, 8024, 6820, 3221,  166, 1045, 4212, 8024,
         6963, 4692,  679, 6224,  511, 5307, 5317,  510, 4954,  855, 3221,  704,
         1744, 1367,  782, 1762,  679, 5543, 1059, 7481, 6371, 6399,  782,  860,
         4638, 2658, 1105,  678, 5621, 2682, 1139, 3341, 4638,  691, 6205, 8024,
         5445, 6821, 3315,  841, 4906, 2110, 4638,  741, 6820, 1762, 1920, 1297,
         8024, 1938, 1214, 1920, 2157, 6206,  676, 2590,  511,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0, 