In [1]:
pwd

'/home/yjkim/new'

In [2]:

%%capture

# Package Import and Installation
import torch
import torch.nn as nn
import pandas as pd
import os

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from tqdm.notebook import tqdm


!pip3 install kobert-transformers
!pip3 install git+https://git@github.com/SKTBrain/KoBERT.git@master


In [3]:
from kobert_transformers import get_distilkobert_model, get_tokenizer

# Import Data


# CUDA
print("CUDA AVAILABLE: ", torch.cuda.is_available())
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

tokenizer = get_tokenizer()
KoBERT    = get_distilkobert_model()

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


CUDA AVAILABLE:  True


In [4]:
# Need PAD_ID for data preprocessing
PAD_ID = tokenizer.pad_token_id
SEP_ID = tokenizer.sep_token_id
CLS_ID = tokenizer.cls_token_id

# Hyperparameters
hyperparams = {'num_epochs': 3,
               'batch_size': 32,
               'seq_len'   : 200,
               'learn_rate': 2e-5}

In [5]:
# Demonstration of KoBERT Tokenizer
ex_sentence = "본 모델은 영화 리뷰를 입력으로 하여 평점을 예측합니다."
ex_tokenize = tokenizer.tokenize(ex_sentence)
tokenized   = tokenizer.convert_tokens_to_ids(ex_tokenize)
untokenized = tokenizer.convert_ids_to_tokens(tokenized)
print(tokenized, '\n', untokenized)

[2408, 2046, 7086, 3394, 1900, 6431, 6116, 3840, 7078, 517, 7815, 4841, 7223, 3415, 7843, 54] 
 ['▁본', '▁모델', '은', '▁영화', '▁리', '뷰', '를', '▁입력', '으로', '▁', '하여', '▁평', '점을', '▁예측', '합니다', '.']


In [6]:
tokenizer.vocab_size

8002

In [7]:
class MovieDataset(Dataset):
    def __init__(self, tokenizer, input_file, label_file, seq_len=64):
        super().__init__()
        
        self.label_file = label_file
        self.seq_len = seq_len
        self.inputs  = open(input_file, 'r').read().splitlines()
        if label_file is not None:
            self.labels = open(label_file, 'r').read().splitlines()

        

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # Convert to tokens then augment tags
        tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(self.inputs[idx]))
        padded = [CLS_ID] + tokens[:self.seq_len-2] + [SEP_ID] + [PAD_ID] * (self.seq_len - len(tokens) - 2)
        length = len(tokens[:self.seq_len-2])
        if self.label_file is not None:
            labels = int(self.labels[idx])
            return {'inputs': torch.tensor(padded),
                    'labels': torch.tensor(labels),
                    'length': torch.tensor(length)}
        else:
            return {'inputs': torch.tensor(padded),
                    'length': torch.tensor(length)}

In [8]:
train_ds = MovieDataset(tokenizer, 'train_data', 'train_label', seq_len=hyperparams['seq_len'])
valid_ds = MovieDataset(tokenizer, 'valid_data', 'valid_label', seq_len=hyperparams['seq_len'])
test_ds  = MovieDataset(tokenizer, 'test_data', None, seq_len=hyperparams['seq_len'])

train_loader = DataLoader(train_ds, batch_size=hyperparams['batch_size'], shuffle=True)
valid_loader = DataLoader(valid_ds, batch_size=128)
test_loader  = DataLoader(test_ds,  batch_size=128)

In [8]:
train_ds.inputs[1]

'불멸의 명작. 영화인을 꿈꾸는 사람이라면 반드시 봐야할 영화'

In [9]:
train_ds.labels[0]

'10'

In [10]:
train_ds[0]

{'inputs': tensor([   2, 2288, 7089,  517,  422,  453,  389,  708,  423,  389, 7095,  791,
         7206, 2734, 7069, 5468, 3376, 6493, 5561, 6903, 1390, 1597, 7543, 7828,
         4694, 3960, 5938, 2355, 6197, 4737, 7628, 7095, 5152, 6527, 6896,  517,
         7233, 7086, 5943, 7828, 3194, 7172, 7828, 4883, 4299, 5859,  517, 7145,
         7088, 6630, 3273,   54,    3,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    

In [9]:
# Model
class MovieReviewModel(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.CrossEntropyLoss in torch has built-in softmax. No need for activation.
        self.pretrained = KoBERT.to(DEVICE)
        self.linear     = nn.Linear(768, 10).to(DEVICE)
        
    def forward(self, x):
        # Pretrained Model Output -> (Embedding, Pooled)
        x, = self.pretrained(x)
        x  = self.linear(x[:,0])

        return x

model = MovieReviewModel()

# Loss
loss_func = nn.CrossEntropyLoss(ignore_index=PAD_ID)

# Optimizer
optimizer = torch.optim.Adam(params=model.parameters(), lr=hyperparams['learn_rate'])

In [10]:
# Validating
def validate(model, valid_loader):
    model.eval()
    s, t = 0, 0
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            x = batch['inputs'].to(DEVICE)
            y = batch['labels'].cpu()
            q = batch['length']

            if type(model) == MovieReviewModel:
                z = model(x).cpu()
            else:
                z = model(x, q).cpu()
            p = torch.argmax(z, dim=-1).cpu() + 1

            s += torch.sum(y==p, [-1,0]).item()
            t += len(x)

        print('VALIDATION ACCURACY: %.2f %%' % (s / t * 100))

# Finetuning
def finetune(model, train_loader):
    model.train()
    t = 0
    for i, batch in enumerate(tqdm(train_loader)):
        x = batch['inputs'].to(DEVICE)
        y = batch['labels'].cpu() - 1
        z = model(x).cpu()

        l = loss_func(z, y)
        t += l.item()
        
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
        
        # Print train loss, save model, and validate every 1000 steps
        if not (i % 1000):
            print('BATCH %d of %d TRAIN LOSS: %.3f' % (i, len(train_loader), l))
            print('SAVING MODEL ...')
            torch.save(model.state_dict(), 'checkpoint.pt')
            validate(model, valid_loader)
            model.train()

    print('TRAIN LOSS: %.3f' % t)



In [None]:
# if os.path.exists('checkpoint.pt'):
#     model.load_state_dict(torch.load('checkpoint.pt'))
#     model.to(DEVICE)
# else:
 

for e in range(hyperparams['num_epochs']):
    print('******* EPOCH %d / %d ********' % (e+1, hyperparams['num_epochs']))
    finetune(model, train_loader)

******* EPOCH 1 / 3 ********


HBox(children=(FloatProgress(value=0.0, max=281250.0), HTML(value='')))

BATCH 0 of 281250 TRAIN LOSS: 2.370
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 60.52 %
BATCH 1000 of 281250 TRAIN LOSS: 1.094
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 76.94 %
BATCH 2000 of 281250 TRAIN LOSS: 1.585
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 78.15 %
BATCH 3000 of 281250 TRAIN LOSS: 1.323
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 79.13 %
BATCH 4000 of 281250 TRAIN LOSS: 1.289
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 79.47 %
BATCH 5000 of 281250 TRAIN LOSS: 1.093
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 79.94 %
BATCH 6000 of 281250 TRAIN LOSS: 1.140
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.31 %
BATCH 7000 of 281250 TRAIN LOSS: 1.198
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.10 %
BATCH 8000 of 281250 TRAIN LOSS: 1.496
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.39 %
BATCH 9000 of 281250 TRAIN LOSS: 1.263
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.17 %
BATCH 10000 of 281250 TRAIN LOSS: 1.046
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.16 %
BATCH 11000 of 281250 TRAIN LOSS: 1.382
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.94 %
BATCH 12000 of 281250 TRAIN LOSS: 1.126
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.42 %
BATCH 13000 of 281250 TRAIN LOSS: 1.280
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.65 %
BATCH 14000 of 281250 TRAIN LOSS: 1.065
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.81 %
BATCH 15000 of 281250 TRAIN LOSS: 0.854
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.56 %
BATCH 16000 of 281250 TRAIN LOSS: 0.847
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.06 %
BATCH 17000 of 281250 TRAIN LOSS: 1.047
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.73 %
BATCH 18000 of 281250 TRAIN LOSS: 1.269
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 80.97 %
BATCH 19000 of 281250 TRAIN LOSS: 1.438
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.20 %
BATCH 20000 of 281250 TRAIN LOSS: 1.325
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.43 %
BATCH 21000 of 281250 TRAIN LOSS: 1.335
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.09 %
BATCH 22000 of 281250 TRAIN LOSS: 1.444
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.67 %
BATCH 23000 of 281250 TRAIN LOSS: 0.847
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.72 %
BATCH 24000 of 281250 TRAIN LOSS: 1.140
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.53 %
BATCH 25000 of 281250 TRAIN LOSS: 1.303
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.65 %
BATCH 26000 of 281250 TRAIN LOSS: 0.893
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.38 %
BATCH 27000 of 281250 TRAIN LOSS: 1.302
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.83 %
BATCH 28000 of 281250 TRAIN LOSS: 1.328
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.44 %
BATCH 29000 of 281250 TRAIN LOSS: 1.025
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.81 %
BATCH 30000 of 281250 TRAIN LOSS: 1.244
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.90 %
BATCH 31000 of 281250 TRAIN LOSS: 1.037
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.64 %
BATCH 32000 of 281250 TRAIN LOSS: 1.124
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.54 %
BATCH 33000 of 281250 TRAIN LOSS: 0.951
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.82 %
BATCH 34000 of 281250 TRAIN LOSS: 1.419
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.77 %
BATCH 35000 of 281250 TRAIN LOSS: 1.108
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.78 %
BATCH 36000 of 281250 TRAIN LOSS: 0.893
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.01 %
BATCH 37000 of 281250 TRAIN LOSS: 1.207
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.92 %
BATCH 38000 of 281250 TRAIN LOSS: 1.432
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.97 %
BATCH 39000 of 281250 TRAIN LOSS: 1.120
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.92 %
BATCH 40000 of 281250 TRAIN LOSS: 0.994
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.13 %
BATCH 41000 of 281250 TRAIN LOSS: 1.061
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.26 %
BATCH 42000 of 281250 TRAIN LOSS: 1.100
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.73 %
BATCH 43000 of 281250 TRAIN LOSS: 1.177
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.09 %
BATCH 44000 of 281250 TRAIN LOSS: 1.218
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.04 %
BATCH 45000 of 281250 TRAIN LOSS: 0.915
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.22 %
BATCH 46000 of 281250 TRAIN LOSS: 1.279
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 81.92 %
BATCH 47000 of 281250 TRAIN LOSS: 1.392
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.22 %
BATCH 48000 of 281250 TRAIN LOSS: 0.809
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.10 %
BATCH 49000 of 281250 TRAIN LOSS: 1.153
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.15 %
BATCH 50000 of 281250 TRAIN LOSS: 1.507
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.17 %
BATCH 51000 of 281250 TRAIN LOSS: 1.408
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.21 %
BATCH 52000 of 281250 TRAIN LOSS: 1.479
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.09 %
BATCH 53000 of 281250 TRAIN LOSS: 1.069
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.44 %
BATCH 54000 of 281250 TRAIN LOSS: 1.394
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.37 %
BATCH 55000 of 281250 TRAIN LOSS: 1.056
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.31 %
BATCH 56000 of 281250 TRAIN LOSS: 0.862
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.28 %
BATCH 57000 of 281250 TRAIN LOSS: 1.028
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.38 %
BATCH 58000 of 281250 TRAIN LOSS: 0.930
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.38 %
BATCH 59000 of 281250 TRAIN LOSS: 0.746
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.03 %
BATCH 60000 of 281250 TRAIN LOSS: 1.759
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.35 %
BATCH 61000 of 281250 TRAIN LOSS: 0.916
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.08 %
BATCH 62000 of 281250 TRAIN LOSS: 1.074
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.47 %
BATCH 63000 of 281250 TRAIN LOSS: 1.135
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.38 %
BATCH 64000 of 281250 TRAIN LOSS: 1.153
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.17 %
BATCH 65000 of 281250 TRAIN LOSS: 1.149
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.30 %
BATCH 66000 of 281250 TRAIN LOSS: 1.449
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.42 %
BATCH 67000 of 281250 TRAIN LOSS: 0.978
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.34 %
BATCH 68000 of 281250 TRAIN LOSS: 1.172
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.50 %
BATCH 69000 of 281250 TRAIN LOSS: 0.919
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.26 %
BATCH 70000 of 281250 TRAIN LOSS: 1.256
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.63 %
BATCH 71000 of 281250 TRAIN LOSS: 0.966
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.39 %
BATCH 72000 of 281250 TRAIN LOSS: 1.050
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.55 %
BATCH 73000 of 281250 TRAIN LOSS: 1.002
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.61 %
BATCH 74000 of 281250 TRAIN LOSS: 0.989
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.39 %
BATCH 75000 of 281250 TRAIN LOSS: 0.928
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.57 %
BATCH 76000 of 281250 TRAIN LOSS: 0.992
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.50 %
BATCH 77000 of 281250 TRAIN LOSS: 1.188
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.76 %
BATCH 78000 of 281250 TRAIN LOSS: 1.022
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.49 %
BATCH 79000 of 281250 TRAIN LOSS: 0.684
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.57 %
BATCH 80000 of 281250 TRAIN LOSS: 1.094
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.71 %
BATCH 81000 of 281250 TRAIN LOSS: 0.934
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.74 %
BATCH 82000 of 281250 TRAIN LOSS: 0.998
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.71 %
BATCH 83000 of 281250 TRAIN LOSS: 1.090
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.54 %
BATCH 84000 of 281250 TRAIN LOSS: 1.222
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.36 %
BATCH 85000 of 281250 TRAIN LOSS: 1.281
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.58 %
BATCH 86000 of 281250 TRAIN LOSS: 1.227
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.71 %
BATCH 87000 of 281250 TRAIN LOSS: 0.917
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.75 %
BATCH 88000 of 281250 TRAIN LOSS: 0.874
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.57 %
BATCH 89000 of 281250 TRAIN LOSS: 1.266
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.52 %
BATCH 90000 of 281250 TRAIN LOSS: 1.242
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.49 %
BATCH 91000 of 281250 TRAIN LOSS: 0.827
SAVING MODEL ...


HBox(children=(FloatProgress(value=0.0, max=79.0), HTML(value='')))


VALIDATION ACCURACY: 82.69 %


In [11]:
def test(model, test_loader):
    model.eval()
    output = []
    for i, batch in enumerate(tqdm(test_loader)):
        x = batch['inputs'].to(DEVICE)
        z = model(x)
        p = torch.argmax(z, dim=-1).cpu() + 1
        l = [i.item() for i in p.flatten()]
        output.extend(l)

    return output

In [14]:
# Load model and run test
model.load_state_dict(torch.load('checkpoint.pt'))
model.to(DEVICE)
l = test(model, test_loader)

dic = {'ID': [i for i in range(len(l))], 'Prediction': l}
df = pd.DataFrame(dic)
df.to_csv('submission.csv', index=False)

HBox(children=(FloatProgress(value=0.0, max=200000.0), HTML(value='')))




# 테스트 배치사이즈 안맞을시

In [13]:
test_loader  = DataLoader(test_ds,  batch_size=2)