In [1]:
import os
import glob
import pickle
import argparse
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler

from transformers import AutoTokenizer
from transformers import GPT2TokenizerFast
from transformers import BertTokenizerFast
from transformers import get_cosine_schedule_with_warmup
from transformers import get_linear_schedule_with_warmup

from dataset import *
from learning import *
from model import *
from utils import *

import warnings
warnings.filterwarnings(action='ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Experiment switches ------------------------------------------------------
# is_full    : picks the "full" vs "end" dataset CSVs and naming tag.
# is_avgpool : picks the 'avgpool' vs 'last' naming tag.
is_full = True
is_avgpool = True

scope_tag = "full" if is_full else "end"
pool_tag = 'avgpool' if is_avgpool else 'last'

# --- Project / model identifiers ----------------------------------------------
project_name = f'NIA_119-GPT_{scope_tag}'
model_name = 'only_Text_GPT_{}_{}'.format(scope_tag, pool_tag)
# Alternatives noted previously: 'skt/kogpt2-base-v2', 'beomi/kcbert-base'
model_link = "kykim/gpt3-kor-small_based_on_gpt2"

# --- Run arguments ------------------------------------------------------------
epochs, batch_size, lr = 15, 12, 1e-5

class_num, speaker_num = 2, 4
max_length = 768
padding = 'max_length'
save_term = 960

# --- Dataset locations --------------------------------------------------------
# All splits live in ../NIA_text_dataset. To run against the toy dataset
# instead, point all three splits at 'toy_data_json_audio_data_decoder.csv'
# (full) or 'toy_data_json_audio_end_data_decoder.csv' (end).
train_path, valid_path, test_path = [
    os.path.join('..', 'NIA_text_dataset', csv_name)
    for csv_name in (
        (
            'train_json_audio_data_decoder.csv',
            'valid_json_audio_data_decoder.csv',
            'test_json_audio_data_decoder.csv',
        )
        if is_full
        else (
            'train_json_audio_end_data_decoder.csv',
            'valid_json_audio_end_data_decoder.csv',
            'test_json_audio_end_data_decoder.csv',
        )
    )
]

# set_save_path comes from the project's star-imports and is not visible here;
# presumably it resolves the output location for this run — TODO confirm.
save_path = set_save_path(model_name, epochs, batch_size)

# --- Load the three splits (train, then valid, then test) ---------------------
train_data, valid_data, test_data = [
    pd.read_csv(csv_path) for csv_path in (train_path, valid_path, test_path)
]
valid_file_ids, test_file_ids = valid_data.id, test_data.id

# Show the test-split ids as the cell output.
test_file_ids

0        64f6b358446b19d68e338111
1        64f6bcd53d12bbf07dab09c7
2        6551f6854f4d810f409c6ba7
3        651e4a0693f80b92304ff526
4        651e50c72f06ed4a6e31e075
                   ...           
31790    651e4d9be84a30cdf982d1a9
31791    651e4f6d06957f3443ecf8f5
31792    6551f6af69cfa8bc64e69a1c
31793    651e4c5b48bae90750379aeb
31794    64ec27f62be6ffc66cc6c982
Name: id, Length: 31795, dtype: object

In [3]:
# Sanity check: True when no id occurs more than once among the test-split ids.
test_file_ids.is_unique

True

In [4]:
# Result CSVs of the four text-only GPT runs (end/full x avgpool/last), each
# stored under models/<run directory>/.
# NOTE(review): end_avgpool points at checkpoint_4_* while the other three use
# checkpoint_5_* — confirm this is intentional.
end_avgpool_path, end_last_path, full_avgpool_path, full_last_path = [
    os.path.join('models', run_dir, result_csv)
    for run_dir, result_csv in (
        ('only_Text_GPT_end_avgpool_e15_bs12', 'result_file_checkpoint_4_2826.csv'),
        ('only_Text_GPT_end_last_e15_bs12', 'result_file_checkpoint_5_4239.csv'),
        ('only_Text_GPT_full_avgpool_e15_bs12', 'result_file_checkpoint_5_2826.csv'),
        ('only_Text_GPT_full_last_e15_bs12', 'result_file_checkpoint_5_4239.csv'),
    )
]

# Read the four files in the same order as the paths above.
end_avgpool, end_last, full_avgpool, full_last = [
    pd.read_csv(csv_path)
    for csv_path in (end_avgpool_path, end_last_path, full_avgpool_path, full_last_path)
]

# Show the id column of the full/avgpool results as the cell output.
full_avgpool.id

0          64f6b358446b19d68e338111
1          64f6b358446b19d68e338111
2          64f6b358446b19d68e338111
3          64f6b358446b19d68e338111
4          64f6b358446b19d68e338111
                     ...           
7499641    651e545d85c12f9c1b91f581
7499642    651e545d85c12f9c1b91f581
7499643    651e545d85c12f9c1b91f581
7499644    651e545d85c12f9c1b91f581
7499645    651e545d85c12f9c1b91f581
Name: id, Length: 7499646, dtype: object

In [6]:
# Number of distinct ids in the full/avgpool result file (missing values, if
# any, are counted as one distinct value).
full_avgpool.id.nunique(dropna=False)

12