Most of the code is based on the PyTorch implementation of KoBERT-based sentiment analysis,

which can be found at https://github.com/SKTBrain/KoBERT/blob/master/scripts/NSMC/naver_review_classifications_pytorch_kobert.ipynb.

In [None]:
import os
import time
from tqdm import tqdm
import numpy as np
import pandas as pd

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import gluonnlp as nlp
from kobert import get_pytorch_kobert_model
from kobert import get_tokenizer

from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [None]:
# Load the pre-processed train/test TSV files from the current working
# directory. They were created from heuristically-classified Naver Finance
# data and the NSMC corpus. Only fields 1 and 2 of each row are kept, and
# the first sample of each file is discarded (presumably a header row).

dataset_train = nlp.data.TSVDataset(
    os.path.join(os.getcwd(), 'train_all.txt'),
    field_indices=[1, 2],
    num_discard_samples=1,
)
dataset_test = nlp.data.TSVDataset(
    os.path.join(os.getcwd(), 'test_all.txt'),
    field_indices=[1, 2],
    num_discard_samples=1,
)

In [None]:
# Select the compute device: use the first CUDA GPU when one is available,
# otherwise fall back to the CPU so the notebook still runs without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Training hyperparameters

# Data / batching
max_len = 48            # max sequence length; set to be the largest length in the given dataset
batch_size = 64         # samples per DataLoader batch

# Optimization
num_epochs = 5          # number of passes over the training data
learning_rate = 5e-5    # learning rate handed to the optimizer
warmup_ratio = 0.1      # fraction of all training steps used as scheduler warmup
max_grad_norm = 1       # threshold used for gradient-norm clipping

# Logging
log_interval = 200      # print training loss/accuracy every this many batches

In [None]:
# Get the pre-trained KoBERT model together with its vocabulary.
# ".cache" is passed as the cache directory for the downloaded files.

bertmodel, vocab = get_pytorch_kobert_model(cachedir=".cache")

/content/.cache/kobert_v1.zip[██████████████████████████████████████████████████]
/content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece[██████████████████████████████████████████████████]


In [None]:
# Get the KoBERT tokenizer and initialize a BERTSPTokenizer from it and the
# model vocabulary (lower=False is passed explicitly).

tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model. /content/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [None]:
# Preprocessing class for BERT modeling, based on tokenizer

class BERTDataset(Dataset):
    """Dataset pairing BERT-transformed sentences with integer labels.

    Every row of ``dataset`` is run through
    ``nlp.data.BERTSentenceTransform`` once, at construction time. Indexing
    returns the transform's output for that row with the label appended as
    the last tuple element.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        """Transform all sentences and collect all labels of ``dataset``.

        Args:
            dataset: Iterable of rows; it is iterated twice (once for the
                sentences, once for the labels).
            sent_idx: Index of the sentence field within a row.
            label_idx: Index of the label field within a row.
            bert_tokenizer: Tokenizer handed to ``BERTSentenceTransform``.
            max_len: Forwarded to the transform as ``max_seq_length``.
            pad: Forwarded to the transform as ``pad``.
            pair: Forwarded to the transform as ``pair``.
        """
        to_bert_input = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # tqdm only wraps the (slow) transform pass to show progress.
        self.sentences = [to_bert_input([row[sent_idx]]) for row in tqdm(dataset)]
        self.labels = [np.int32(row[label_idx]) for row in dataset]

    def __getitem__(self, i):
        """Return the transformed sentence tuple extended by its label."""
        features = self.sentences[i]
        label = self.labels[i]
        return features + (label,)

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

In [None]:
# Preprocess the train/test data into the form that the model requires.
# In each row, field 0 is the sentence and field 1 is the label.

data_train = BERTDataset(dataset_train, sent_idx=0, label_idx=1,
                         bert_tokenizer=tok, max_len=max_len,
                         pad=True, pair=False)
data_test = BERTDataset(dataset_test, sent_idx=0, label_idx=1,
                        bert_tokenizer=tok, max_len=max_len,
                        pad=True, pair=False)

100%|██████████| 268999/268999 [00:16<00:00, 16106.94it/s]
100%|██████████| 89665/89665 [00:05<00:00, 16270.52it/s]


In [None]:
# Put the preprocessed train/test data into DataLoaders.
# The training loader reshuffles the samples every epoch so that mini-batches
# do not follow the row order of the input file; the test loader keeps the
# original order.

train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, shuffle=True, num_workers=4)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=4)



In [None]:
# The classifier network is made up of the base BERT architecture and
# one linear classifier layer on top of it

class BERTClassifier(nn.Module):
    """BERT encoder followed by optional dropout and a linear classifier.

    Args:
        bert: Encoder module. It is called with ``input_ids``,
            ``token_type_ids`` and ``attention_mask`` keyword arguments and
            must return a 2-tuple whose second element is the pooled output.
        hidden_size: Width of the pooled output fed to the linear layer.
        num_classes: Number of output logits.
        dr_rate: Dropout probability applied to the pooled output. A falsy
            value (``None`` or ``0``) disables dropout.
        params: Unused; kept so existing callers that pass it keep working.
    """

    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        super().__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size, num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Return a float mask with 1 on valid positions and 0 elsewhere.

        Args:
            token_ids: 2-D tensor of token ids, one row per sample.
            valid_length: Per-sample count of valid leading tokens (tensor or
                sequence of ints); it may live on a different device than
                ``token_ids``.

        Returns:
            Float tensor on ``token_ids.device`` where row ``i`` has ones in
            its first ``valid_length[i]`` columns.
        """
        # Build the mask with one broadcast comparison instead of a Python
        # loop of per-row in-place writes.
        lengths = torch.as_tensor(
            valid_length, dtype=torch.long, device=token_ids.device).view(-1, 1)
        positions = torch.arange(
            token_ids.size(1), device=token_ids.device).unsqueeze(0)
        return (positions < lengths).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Compute class logits for a batch.

        Args:
            token_ids: Token id tensor passed to the encoder as ``input_ids``.
            valid_length: Per-sample valid token counts used for the mask.
            segment_ids: Segment ids; cast to long and passed as
                ``token_type_ids``.

        Returns:
            Output of the linear classifier applied to the (optionally
            dropped-out) pooled encoder output.
        """
        # The mask is already a float tensor on token_ids' device.
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids=token_ids,
                              token_type_ids=segment_ids.long(),
                              attention_mask=attention_mask)
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)

In [None]:
# Build the classifier on top of the pre-trained KoBERT model (dropout 0.5)
# and move it to the selected device.

model = BERTClassifier(bert=bertmodel, dr_rate=0.5)
model = model.to(device)

In [None]:
# Split the model parameters into two optimizer groups:
#   1. parameters whose name contains none of the `no_decay` markers
#      -> weight decay 0.01
#   2. parameters whose name contains 'bias' or 'LayerNorm.weight'
#      -> no weight decay

no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [
            param
            for name, param in model.named_parameters()
            if all(marker not in name for marker in no_decay)
        ],
        'weight_decay': 0.01,
    },
    {
        'params': [
            param
            for name, param in model.named_parameters()
            if any(marker in name for marker in no_decay)
        ],
        'weight_decay': 0.0,
    },
]

In [None]:
# Use the AdamW optimizer (imported from transformers) over the grouped
# parameters, and cross-entropy as the loss function

optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Scheduler configuration to change the learning rate as the step increases
# (cosine schedule with warmup)

# Total number of training steps: batches per epoch times number of epochs
t_total = len(train_dataloader) * num_epochs
# Number of warmup steps: the warmup_ratio share of all steps, truncated to int
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)

In [None]:
# Accuracy metric for the classification (simple and intuitive)

def calc_accuracy(X,Y):
    """Return the fraction of rows of X whose arg-max along dim 1 equals Y.

    Args:
        X: 2-D tensor of per-class scores, one row per sample.
        Y: Tensor of target class indices, one per row of X.

    Returns:
        Number of matching predictions divided by the number of rows,
        as a NumPy scalar.
    """
    _, predicted = torch.max(X, dim=1)
    num_correct = (predicted == Y).sum().data.cpu().numpy()
    batch_len = predicted.size()[0]
    return num_correct / batch_len

In [None]:
# Train the model for Naver Finance sentiment classification,
# based on the train/validation(test) dataset, which consists of
# heuristically-classified Naver Finance threads, and the NSMC corpus,
# by leveraging the pre-trained KoBERT architecture & parameters.
#
# Per epoch: one training pass (with gradient clipping and a scheduler step
# after every batch), then one evaluation pass over the test DataLoader.
# Reported accuracies are the mean of the per-batch accuracies.

loss_list = []        # loss at every logged training step
train_acc_list = []   # running train accuracy at every logged step and at the end of each epoch
test_acc_list = []    # test accuracy at the end of each epoch

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0

    # ---- training pass ----
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        optimizer.zero_grad()
        # valid_length is left as delivered by the DataLoader (it is not
        # moved to `device` here).
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            # Copy the loss to the host once and reuse it for printing and logging.
            loss_value = loss.detach().cpu().numpy()
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss_value, train_acc / (batch_id+1)))
            loss_list.append(loss_value)
            train_acc_list.append(train_acc / (batch_id+1))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    train_acc_list.append(train_acc / (batch_id+1))

    # ---- evaluation pass ----
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building the computation graph for every test batch.
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    test_acc_list.append(test_acc / (batch_id+1))

  0%|          | 1/4204 [00:00<45:40,  1.53it/s]

epoch 1 batch id 1 loss 0.7518500089645386 train acc 0.53125


  5%|▍         | 201/4204 [01:46<35:21,  1.89it/s]

epoch 1 batch id 201 loss 0.7024027109146118 train acc 0.5032649253731343


 10%|▉         | 401/4204 [03:34<34:00,  1.86it/s]

epoch 1 batch id 401 loss 0.3296284079551697 train acc 0.6082839775561097


 14%|█▍        | 601/4204 [05:21<32:26,  1.85it/s]

epoch 1 batch id 601 loss 0.19690009951591492 train acc 0.6962354409317804


 19%|█▉        | 801/4204 [07:08<30:26,  1.86it/s]

epoch 1 batch id 801 loss 0.1867685168981552 train acc 0.746313202247191


 24%|██▍       | 1001/4204 [08:56<28:38,  1.86it/s]

epoch 1 batch id 1001 loss 0.2463875114917755 train acc 0.7778471528471529


 29%|██▊       | 1201/4204 [10:43<26:51,  1.86it/s]

epoch 1 batch id 1201 loss 0.33933523297309875 train acc 0.8004917776852623


 33%|███▎      | 1401/4204 [12:30<25:02,  1.87it/s]

epoch 1 batch id 1401 loss 0.2703348696231842 train acc 0.8160242683797287


 38%|███▊      | 1601/4204 [14:17<23:17,  1.86it/s]

epoch 1 batch id 1601 loss 0.16353772580623627 train acc 0.8281933166770769


 43%|████▎     | 1801/4204 [16:04<21:26,  1.87it/s]

epoch 1 batch id 1801 loss 0.08808691799640656 train acc 0.8380760688506386


 48%|████▊     | 2001/4204 [17:51<19:41,  1.86it/s]

epoch 1 batch id 2001 loss 0.23847873508930206 train acc 0.8460769615192404


 52%|█████▏    | 2201/4204 [19:39<17:54,  1.86it/s]

epoch 1 batch id 2201 loss 0.1727137714624405 train acc 0.8529503634711495


 57%|█████▋    | 2401/4204 [21:26<16:03,  1.87it/s]

epoch 1 batch id 2401 loss 0.2230997234582901 train acc 0.8585485214493961


 62%|██████▏   | 2601/4204 [23:13<14:20,  1.86it/s]

epoch 1 batch id 2601 loss 0.2420775592327118 train acc 0.8634599673202614


 67%|██████▋   | 2801/4204 [25:00<12:33,  1.86it/s]

epoch 1 batch id 2801 loss 0.1946702003479004 train acc 0.8676477151017494


 71%|███████▏  | 3001/4204 [26:47<10:43,  1.87it/s]

epoch 1 batch id 3001 loss 0.10933242738246918 train acc 0.8713605881372876


 76%|███████▌  | 3201/4204 [28:34<08:57,  1.87it/s]

epoch 1 batch id 3201 loss 0.0666903406381607 train acc 0.8746729537644486


 81%|████████  | 3401/4204 [30:21<07:10,  1.86it/s]

epoch 1 batch id 3401 loss 0.3035421073436737 train acc 0.8775589900029404


 86%|████████▌ | 3601/4204 [32:08<05:23,  1.86it/s]

epoch 1 batch id 3601 loss 0.1190512403845787 train acc 0.8802936684254373


 90%|█████████ | 3801/4204 [33:56<03:35,  1.87it/s]

epoch 1 batch id 3801 loss 0.16725970804691315 train acc 0.8828268876611418


 95%|█████████▌| 4001/4204 [35:43<01:48,  1.87it/s]

epoch 1 batch id 4001 loss 0.1133481115102768 train acc 0.8850014058985254


100%|█████████▉| 4201/4204 [37:30<00:01,  1.88it/s]

epoch 1 batch id 4201 loss 0.22325681149959564 train acc 0.8873370923589622


100%|██████████| 4204/4204 [37:31<00:00,  1.87it/s]

epoch 1 train acc 0.8873914724072313



100%|██████████| 1402/1402 [04:10<00:00,  5.59it/s]

epoch 1 test acc 0.9305456490727532



  0%|          | 1/4204 [00:00<36:28,  1.92it/s]

epoch 2 batch id 1 loss 0.35756170749664307 train acc 0.859375


  5%|▍         | 201/4204 [01:47<35:50,  1.86it/s]

epoch 2 batch id 201 loss 0.34262320399284363 train acc 0.9251399253731343


 10%|▉         | 401/4204 [03:34<33:59,  1.86it/s]

epoch 2 batch id 401 loss 0.06666381657123566 train acc 0.9269404613466334


 14%|█▍        | 601/4204 [05:22<32:10,  1.87it/s]

epoch 2 batch id 601 loss 0.10525047034025192 train acc 0.9287645590682196


 19%|█▉        | 801/4204 [07:09<30:21,  1.87it/s]

epoch 2 batch id 801 loss 0.18274588882923126 train acc 0.9302239388264669


 24%|██▍       | 1001/4204 [08:56<28:40,  1.86it/s]

epoch 2 batch id 1001 loss 0.15293416380882263 train acc 0.9317401348651349


 29%|██▊       | 1201/4204 [10:44<26:50,  1.86it/s]

epoch 2 batch id 1201 loss 0.20816847681999207 train acc 0.9331416527893422


 33%|███▎      | 1401/4204 [12:31<25:01,  1.87it/s]

epoch 2 batch id 1401 loss 0.23310773074626923 train acc 0.933953426124197


 38%|███▊      | 1601/4204 [14:17<23:12,  1.87it/s]

epoch 2 batch id 1601 loss 0.05080941319465637 train acc 0.9347478138663335


 43%|████▎     | 1801/4204 [16:04<21:23,  1.87it/s]

epoch 2 batch id 1801 loss 0.03408567234873772 train acc 0.9357041227096058


 48%|████▊     | 2001/4204 [17:51<19:40,  1.87it/s]

epoch 2 batch id 2001 loss 0.1424965113401413 train acc 0.93655515992004


 52%|█████▏    | 2201/4204 [19:38<17:50,  1.87it/s]

epoch 2 batch id 2201 loss 0.16334424912929535 train acc 0.9373580190822354


 57%|█████▋    | 2401/4204 [21:25<16:07,  1.86it/s]

epoch 2 batch id 2401 loss 0.13382838666439056 train acc 0.9382353706788839


 62%|██████▏   | 2601/4204 [23:12<14:21,  1.86it/s]

epoch 2 batch id 2601 loss 0.25854936242103577 train acc 0.9391580161476355


 67%|██████▋   | 2801/4204 [24:59<12:32,  1.87it/s]

epoch 2 batch id 2801 loss 0.15013687312602997 train acc 0.9398763834344876


 71%|███████▏  | 3001/4204 [26:46<10:42,  1.87it/s]

epoch 2 batch id 3001 loss 0.10279393196105957 train acc 0.9404104881706098


 76%|███████▌  | 3201/4204 [28:33<08:56,  1.87it/s]

epoch 2 batch id 3201 loss 0.06728874891996384 train acc 0.9409461886910341


 81%|████████  | 3401/4204 [30:20<07:09,  1.87it/s]

epoch 2 batch id 3401 loss 0.21516314148902893 train acc 0.9414694207586004


 86%|████████▌ | 3601/4204 [32:07<05:22,  1.87it/s]

epoch 2 batch id 3601 loss 0.0707777664065361 train acc 0.942025652596501


 90%|█████████ | 3801/4204 [33:54<03:35,  1.87it/s]

epoch 2 batch id 3801 loss 0.11862091720104218 train acc 0.9425808997632202


 95%|█████████▌| 4001/4204 [35:42<01:48,  1.87it/s]

epoch 2 batch id 4001 loss 0.04799315705895424 train acc 0.9431001624593851


100%|█████████▉| 4201/4204 [37:29<00:01,  1.88it/s]

epoch 2 batch id 4201 loss 0.18729785084724426 train acc 0.943819179957153


100%|██████████| 4204/4204 [37:30<00:00,  1.87it/s]

epoch 2 train acc 0.9438406874405328



100%|██████████| 1402/1402 [04:10<00:00,  5.61it/s]

epoch 2 test acc 0.9365749821683309



  0%|          | 1/4204 [00:00<36:27,  1.92it/s]

epoch 3 batch id 1 loss 0.21786531805992126 train acc 0.90625


  5%|▍         | 201/4204 [01:47<35:40,  1.87it/s]

epoch 3 batch id 201 loss 0.20705074071884155 train acc 0.9497823383084577


 10%|▉         | 401/4204 [03:34<33:55,  1.87it/s]

epoch 3 batch id 401 loss 0.039436206221580505 train acc 0.9528522443890274


 14%|█▍        | 601/4204 [05:21<32:11,  1.87it/s]

epoch 3 batch id 601 loss 0.05108603462576866 train acc 0.9545549084858569


 19%|█▉        | 801/4204 [07:08<30:26,  1.86it/s]

epoch 3 batch id 801 loss 0.07231220602989197 train acc 0.9561095505617978


 24%|██▍       | 1001/4204 [08:55<28:33,  1.87it/s]

epoch 3 batch id 1001 loss 0.07741715759038925 train acc 0.9569493006993007


 29%|██▊       | 1201/4204 [10:43<26:45,  1.87it/s]

epoch 3 batch id 1201 loss 0.20922128856182098 train acc 0.9577175270607827


 33%|███▎      | 1401/4204 [12:30<25:08,  1.86it/s]

epoch 3 batch id 1401 loss 0.186459481716156 train acc 0.9582552640970735


 38%|███▊      | 1601/4204 [14:17<23:04,  1.88it/s]

epoch 3 batch id 1601 loss 0.040096353739500046 train acc 0.9589221580262336


 43%|████▎     | 1801/4204 [16:04<21:28,  1.86it/s]

epoch 3 batch id 1801 loss 0.0372525118291378 train acc 0.9597359106052193


 48%|████▊     | 2001/4204 [17:51<19:35,  1.87it/s]

epoch 3 batch id 2001 loss 0.14670135080814362 train acc 0.9603245252373813


 52%|█████▏    | 2201/4204 [19:38<17:49,  1.87it/s]

epoch 3 batch id 2201 loss 0.11553248018026352 train acc 0.9608984552476147


 57%|█████▋    | 2401/4204 [21:25<16:05,  1.87it/s]

epoch 3 batch id 2401 loss 0.14017708599567413 train acc 0.9614743856726364


 62%|██████▏   | 2601/4204 [23:12<14:19,  1.87it/s]

epoch 3 batch id 2601 loss 0.14788664877414703 train acc 0.9620218185313341


 67%|██████▋   | 2801/4204 [24:59<12:31,  1.87it/s]

epoch 3 batch id 2801 loss 0.19709154963493347 train acc 0.9625635933595145


 71%|███████▏  | 3001/4204 [26:46<10:44,  1.87it/s]

epoch 3 batch id 3001 loss 0.03631721809506416 train acc 0.9629862962345884


 76%|███████▌  | 3201/4204 [28:34<08:57,  1.87it/s]

epoch 3 batch id 3201 loss 0.005295498296618462 train acc 0.9634342783505154


 81%|████████  | 3401/4204 [30:21<07:10,  1.87it/s]

epoch 3 batch id 3401 loss 0.2150292992591858 train acc 0.963769847103793


 86%|████████▌ | 3601/4204 [32:08<05:23,  1.86it/s]

epoch 3 batch id 3601 loss 0.10755099356174469 train acc 0.9643371632879756


 90%|█████████ | 3801/4204 [33:55<03:36,  1.86it/s]

epoch 3 batch id 3801 loss 0.09058941155672073 train acc 0.9647748947645356


 95%|█████████▌| 4001/4204 [35:42<01:48,  1.87it/s]

epoch 3 batch id 4001 loss 0.024729078635573387 train acc 0.9651688640339915


100%|█████████▉| 4201/4204 [37:29<00:01,  1.87it/s]

epoch 3 batch id 4201 loss 0.08839485049247742 train acc 0.9656480599857177


100%|██████████| 4204/4204 [37:31<00:00,  1.87it/s]

epoch 3 train acc 0.9656614236441484



100%|██████████| 1402/1402 [04:11<00:00,  5.58it/s]

epoch 3 test acc 0.9392385877318117



  0%|          | 1/4204 [00:00<36:14,  1.93it/s]

epoch 4 batch id 1 loss 0.06578122824430466 train acc 0.96875


  5%|▍         | 201/4204 [01:47<35:49,  1.86it/s]

epoch 4 batch id 201 loss 0.1422354280948639 train acc 0.9684390547263682


 10%|▉         | 401/4204 [03:34<33:50,  1.87it/s]

epoch 4 batch id 401 loss 0.0192982479929924 train acc 0.9712827306733167


 14%|█▍        | 601/4204 [05:22<32:27,  1.85it/s]

epoch 4 batch id 601 loss 0.048108797520399094 train acc 0.9732217138103162


 19%|█▉        | 801/4204 [07:09<30:29,  1.86it/s]

epoch 4 batch id 801 loss 0.13177327811717987 train acc 0.974211922596754


 24%|██▍       | 1001/4204 [08:56<28:43,  1.86it/s]

epoch 4 batch id 1001 loss 0.1322590857744217 train acc 0.9747908341658341


 29%|██▊       | 1201/4204 [10:43<26:45,  1.87it/s]

epoch 4 batch id 1201 loss 0.2027164250612259 train acc 0.975176935886761


 33%|███▎      | 1401/4204 [12:31<25:13,  1.85it/s]

epoch 4 batch id 1401 loss 0.17166078090667725 train acc 0.9757873840114204


 38%|███▊      | 1601/4204 [14:18<23:12,  1.87it/s]

epoch 4 batch id 1601 loss 0.007427994627505541 train acc 0.976235555902561


 43%|████▎     | 1801/4204 [16:05<21:25,  1.87it/s]

epoch 4 batch id 1801 loss 0.01476898230612278 train acc 0.9768184342032205


 48%|████▊     | 2001/4204 [17:52<19:35,  1.87it/s]

epoch 4 batch id 2001 loss 0.07528883963823318 train acc 0.9772379435282359


 52%|█████▏    | 2201/4204 [19:39<17:52,  1.87it/s]

epoch 4 batch id 2201 loss 0.0529000423848629 train acc 0.9777302930486143


 57%|█████▋    | 2401/4204 [21:26<16:04,  1.87it/s]

epoch 4 batch id 2401 loss 0.062436170876026154 train acc 0.9780755414410662


 62%|██████▏   | 2601/4204 [23:13<14:23,  1.86it/s]

epoch 4 batch id 2601 loss 0.10317745059728622 train acc 0.9784578046905037


 67%|██████▋   | 2801/4204 [25:01<12:30,  1.87it/s]

epoch 4 batch id 2801 loss 0.059405311942100525 train acc 0.9788747322384862


 71%|███████▏  | 3001/4204 [26:48<10:43,  1.87it/s]

epoch 4 batch id 3001 loss 0.01971041038632393 train acc 0.9790850966344552


 76%|███████▌  | 3201/4204 [28:35<08:57,  1.87it/s]

epoch 4 batch id 3201 loss 0.00317357387393713 train acc 0.9794449000312402


 81%|████████  | 3401/4204 [30:22<07:09,  1.87it/s]

epoch 4 batch id 3401 loss 0.1611330360174179 train acc 0.9796750955601293


 86%|████████▌ | 3601/4204 [32:09<05:23,  1.87it/s]

epoch 4 batch id 3601 loss 0.029594168066978455 train acc 0.9799534851430158


 90%|█████████ | 3801/4204 [33:56<03:35,  1.87it/s]

epoch 4 batch id 3801 loss 0.08613819628953934 train acc 0.9802107997895291


 95%|█████████▌| 4001/4204 [35:43<01:48,  1.87it/s]

epoch 4 batch id 4001 loss 0.0122053362429142 train acc 0.9804423894026494


100%|█████████▉| 4201/4204 [37:31<00:01,  1.87it/s]

epoch 4 batch id 4201 loss 0.03261991962790489 train acc 0.9806928409902405


100%|██████████| 4204/4204 [37:32<00:00,  1.87it/s]

epoch 4 train acc 0.9807029019980971



100%|██████████| 1402/1402 [04:10<00:00,  5.59it/s]

epoch 4 test acc 0.9398849857346647



  0%|          | 1/4204 [00:00<36:39,  1.91it/s]

epoch 5 batch id 1 loss 0.06526624411344528 train acc 0.96875


  5%|▍         | 201/4204 [01:47<35:49,  1.86it/s]

epoch 5 batch id 201 loss 0.04494037479162216 train acc 0.9849191542288557


 10%|▉         | 401/4204 [03:34<33:56,  1.87it/s]

epoch 5 batch id 401 loss 0.004987768363207579 train acc 0.9848425810473815


 14%|█▍        | 601/4204 [05:21<32:12,  1.86it/s]

epoch 5 batch id 601 loss 0.11397985368967056 train acc 0.985648918469218


 19%|█▉        | 801/4204 [07:08<30:18,  1.87it/s]

epoch 5 batch id 801 loss 0.015250646509230137 train acc 0.9858770287141073


 24%|██▍       | 1001/4204 [08:55<28:41,  1.86it/s]

epoch 5 batch id 1001 loss 0.0535762794315815 train acc 0.9862481268731269


 29%|██▊       | 1201/4204 [10:42<26:44,  1.87it/s]

epoch 5 batch id 1201 loss 0.17787984013557434 train acc 0.9865346586178185


 33%|███▎      | 1401/4204 [12:29<25:06,  1.86it/s]

epoch 5 batch id 1401 loss 0.15884895622730255 train acc 0.9868286045681656


 38%|███▊      | 1601/4204 [14:16<23:08,  1.88it/s]

epoch 5 batch id 1601 loss 0.006339280866086483 train acc 0.9869710337289195


 43%|████▎     | 1801/4204 [16:03<21:23,  1.87it/s]

epoch 5 batch id 1801 loss 0.004466739017516375 train acc 0.9873160744031094


 48%|████▊     | 2001/4204 [17:50<19:35,  1.87it/s]

epoch 5 batch id 2001 loss 0.010880510322749615 train acc 0.9873891179410295


 52%|█████▏    | 2201/4204 [19:37<17:50,  1.87it/s]

epoch 5 batch id 2201 loss 0.005512502044439316 train acc 0.9876831553839164


 57%|█████▋    | 2401/4204 [21:24<16:05,  1.87it/s]

epoch 5 batch id 2401 loss 0.020474033430218697 train acc 0.9877915451895044


 62%|██████▏   | 2601/4204 [23:11<14:17,  1.87it/s]

epoch 5 batch id 2601 loss 0.10626467317342758 train acc 0.987943339100346


 67%|██████▋   | 2801/4204 [24:58<12:34,  1.86it/s]

epoch 5 batch id 2801 loss 0.05251137167215347 train acc 0.98814597465191


 71%|███████▏  | 3001/4204 [26:46<10:43,  1.87it/s]

epoch 5 batch id 3001 loss 0.004179703537374735 train acc 0.9882122625791403


 76%|███████▌  | 3201/4204 [28:33<08:57,  1.86it/s]

epoch 5 batch id 3201 loss 0.001932527287863195 train acc 0.9883288425492034


 81%|████████  | 3401/4204 [30:20<07:07,  1.88it/s]

epoch 5 batch id 3401 loss 0.16232435405254364 train acc 0.9883168553366657


 86%|████████▌ | 3601/4204 [32:07<05:24,  1.86it/s]

epoch 5 batch id 3601 loss 0.023717235773801804 train acc 0.9884146764787559


 90%|█████████ | 3801/4204 [33:53<03:36,  1.87it/s]

epoch 5 batch id 3801 loss 0.008132265880703926 train acc 0.9885556432517758


 95%|█████████▌| 4001/4204 [35:41<01:48,  1.87it/s]

epoch 5 batch id 4001 loss 0.016083626076579094 train acc 0.9886356535866033


100%|█████████▉| 4201/4204 [37:28<00:01,  1.87it/s]

epoch 5 batch id 4201 loss 0.04420106112957001 train acc 0.9886857295881933


100%|██████████| 4204/4204 [37:29<00:00,  1.87it/s]

epoch 5 train acc 0.9886900868220743



100%|██████████| 1402/1402 [04:11<00:00,  5.58it/s]

epoch 5 test acc 0.9408991619115549





In [None]:
# Save the trained model to 'trained_model.pt' in the current working directory.
# Note: the whole model object is passed to torch.save here, not only its
# state_dict.

torch.save(model,os.path.join(os.getcwd(),'trained_model.pt'))