In [1]:
# KoBERT
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

In [2]:
# Transformers
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup

In [3]:
# Setting Library
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook
import pandas as pd

In [4]:
# GPU 사용
device = torch.device("cuda:0") 

# CPU 사용 시
# device = torch.device("cpu") 

In [5]:
device

device(type='cuda', index=0)

In [6]:
# BERT 모델, Vocabulary 불러오기
bertmodel, vocab = get_pytorch_kobert_model(cachedir=".cache")

using cached model. /home/ai/ai/PAPER/.cache/kobert_v1.zip
using cached model. /home/ai/ai/PAPER/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [7]:
# [AI Hub] 감정 분류를 위한 대화 음성 데이터셋 불러오기
data = pd.read_csv("./total_df.csv")

# 2. 데이터 전처리

In [8]:
data.head()

Unnamed: 0.1,Unnamed: 0,chat,emotion
0,0,헐! 나 이벤트에 당첨 됐어.,happiness
1,1,내가 좋아하는 인플루언서가 이벤트를 하더라고. 그래서 그냥 신청 한번 해봤지.,happiness
2,2,"한 명 뽑는 거였는데, 그게 바로 내가 된 거야.",happiness
3,3,"당연히 마음에 드는 선물이니깐, 이벤트에 내가 신청 한번 해본 거지. 비싼 거야. ...",happiness
4,4,에피타이저 정말 좋아해. 그 것도 괜찮은 생각인 것 같애.,neutral


In [9]:
total_df = data[['chat','emotion']]

In [10]:
total_df.head()

Unnamed: 0,chat,emotion
0,헐! 나 이벤트에 당첨 됐어.,happiness
1,내가 좋아하는 인플루언서가 이벤트를 하더라고. 그래서 그냥 신청 한번 해봤지.,happiness
2,"한 명 뽑는 거였는데, 그게 바로 내가 된 거야.",happiness
3,"당연히 마음에 드는 선물이니깐, 이벤트에 내가 신청 한번 해본 거지. 비싼 거야. ...",happiness
4,에피타이저 정말 좋아해. 그 것도 괜찮은 생각인 것 같애.,neutral


In [11]:
total_df.shape

(207662, 2)

In [12]:
total_df['emotion'].unique()

array(['happiness', 'neutral', 'sadness', 'angry', 'surprise', 'disgust',
       'fear', '분노', '기쁨', '불안', '당황', '슬픔', '상처', '무감정'], dtype=object)

In [13]:
total_df.loc[total_df['emotion'] == '기쁨', 'emotion'] = 'happiness'

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[total_df['emotion'] == '기쁨', 'emotion'] = 'happiness'


In [14]:
total_df.loc[total_df['emotion'] == '무감정', 'emotion'] = 'neutral'

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[total_df['emotion'] == '무감정', 'emotion'] = 'neutral'


In [15]:
total_df.loc[total_df['emotion'] == '슬픔', 'emotion'] = 'sadness'

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[total_df['emotion'] == '슬픔', 'emotion'] = 'sadness'


In [16]:
total_df.loc[total_df['emotion'] == '분노', 'emotion'] = 'angry'

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[total_df['emotion'] == '분노', 'emotion'] = 'angry'


In [17]:
total_df.loc[total_df['emotion'] == '당황', 'emotion'] = 'surprise'

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[total_df['emotion'] == '당황', 'emotion'] = 'surprise'


In [18]:
total_df.loc[total_df['emotion'] == '불안', 'emotion'] = 'fear'

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[total_df['emotion'] == '불안', 'emotion'] = 'fear'


In [19]:
total_df.loc[total_df['emotion'] == '상처', 'emotion'] = 'disgust'

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[total_df['emotion'] == '상처', 'emotion'] = 'disgust'


In [20]:
total_df['emotion'].unique()

array(['happiness', 'neutral', 'sadness', 'angry', 'surprise', 'disgust',
       'fear'], dtype=object)

In [21]:
total_df.to_csv('./total_df_convert.csv')

In [22]:
# 7개의 감정 class → 숫자
total_df.loc[(total_df['emotion'] == "fear"), 'emotion'] = 0  # fear → 0
total_df.loc[(total_df['emotion'] == "surprise"), 'emotion'] = 1  # surprise → 1
total_df.loc[(total_df['emotion'] == "angry"), 'emotion'] = 2  # angry → 2
total_df.loc[(total_df['emotion'] == "sadness"), 'emotion'] = 3  # sadness → 3
total_df.loc[(total_df['emotion'] == "neutral"), 'emotion'] = 4  # neutral → 4
total_df.loc[(total_df['emotion'] == "happiness"), 'emotion'] = 5  # happiness → 5
total_df.loc[(total_df['emotion'] == "disgust"), 'emotion'] = 6  # disgust → 6

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[(total_df['emotion'] == "fear"), 'emotion'] = 0  # fear → 0
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[(total_df['emotion'] == "surprise"), 'emotion'] = 1  # surprise → 1
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  total_df.loc[(total_df['emotion'] == "angry"), 'emotion'] = 2  # angry → 2
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stab

In [23]:
total_df['emotion'].unique()

array([5, 4, 3, 2, 1, 6, 0], dtype=object)

In [24]:
# [발화문, 상황] data_list 생성
data_list = []
for ques, label in zip (total_df['chat'], total_df['emotion']):
  data = []
  data.append(ques)
  data.append(str(label))

  data_list.append(data)

In [25]:
print(data)
print(data_list[:10])

['마찬가지야, 깨달음을 얻어 생사를 벗어난다 해도.', '4']
[['헐! 나 이벤트에 당첨 됐어.', '5'], ['내가 좋아하는 인플루언서가 이벤트를 하더라고. 그래서 그냥 신청 한번 해봤지.', '5'], ['한 명 뽑는 거였는데, 그게 바로 내가 된 거야.', '5'], ['당연히 마음에 드는 선물이니깐, 이벤트에 내가 신청 한번 해본 거지. 비싼 거야. 그래서 못 산 향수야.', '5'], ['에피타이저 정말 좋아해. 그 것도 괜찮은 생각인 것 같애.', '4'], ['난 부페 형식의 음식들도 정말 좋아해. 그 것도 좀 알려 줘.', '4'], ['응. 완전히 끝난 거야. 한 달 동안 주말에 쉬지도 못하고 일만 했거든.', '5'], ['신나는 음악 듣는 것도 좋고, 어디 여행 가고 싶고 이 것 저 것 다 해보고 싶어.', '5'], ['친구들도 내 연락 기다리고 있을 텐데 내가 까먹고 있었네?', '5'], ['그래. 일단은 친구들부터 만나서 여행 계획에 대해서 얘기 좀 해봐야 되겠어.', '5']]


# 2-1.Split train & test data set

In [26]:
from sklearn.model_selection import train_test_split

In [27]:
dataset_train, dataset_test = train_test_split(data_list, test_size = 0.2, shuffle = True, random_state = 32)

In [28]:
print(len(dataset_train), len(dataset_test))

166129 41533


# 2-2.입력 데이터셋 토큰화

In [29]:
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower = False)

using cached model. /home/ai/ai/PAPER/.cache/kobert_news_wiki_ko_cased-1087f8699e.spiece


In [30]:
# BERTDataset : 각 데이터가 BERT 모델의 입력으로 들어갈 수 있도록 tokenization, int encoding, padding하는 함수
# 출처 : https://github.com/SKTBrain/KoBERT/blob/master/scripts/NSMC/naver_review_classifications_pytorch_kobert.ipynb

class BERTDataset(Dataset):
    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, vocab, max_len, pad, pair):
   
        transform = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len,vocab = vocab, pad = pad, pair = pair)
        
        self.sentences = [transform([i[sent_idx]]) for i in dataset]
        self.labels = [np.int32(i[label_idx]) for i in dataset]

    def __getitem__(self, i):
        return (self.sentences[i] + (self.labels[i], ))
         

    def __len__(self):
        return (len(self.labels))

# 2-3.Setting parameters

In [31]:
# parameter 값 출처 : https://github.com/SKTBrain/KoBERT/blob/master/scripts/NSMC/naver_review_classifications_pytorch_kobert.ipynb
max_len = 64
batch_size = 64
warmup_ratio = 0.1
num_epochs = 50
max_grad_norm = 1
log_interval = 200
learning_rate =  5e-5

# 2-4.tokenization, int encoding, padding

In [32]:
# BERTDataset : 각 데이터가 BERT 모델의 입력으로 들어갈 수 있도록 tokenization, int encoding, padding하는 함수
data_train = BERTDataset(dataset_train, 0, 1, tok, vocab, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, vocab, max_len, True, False)

In [33]:
# torch 형식의 dataset을 만들어주면서, 입력 데이터셋의 전처리
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size = batch_size, num_workers = 5)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size = batch_size, num_workers = 5)

# 3.KoBERT 모델 구현

In [34]:
# KoBERT 오픈소스 내 예제코드 : https://github.com/SKTBrain/KoBERT/blob/master/scripts/NSMC/naver_review_classifications_pytorch_kobert.ipynb
class BERTClassifier(nn.Module):
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes = 7,   # 클래스 수 조정
                 dr_rate = None,
                 params = None):
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate
                 
        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p = dr_rate)
    
    def gen_attention_mask(self, token_ids, valid_length):
        attention_mask = torch.zeros_like(token_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        attention_mask = self.gen_attention_mask(token_ids, valid_length)
        
        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device),return_dict = False)
        if self.dr_rate:
            out = self.dropout(pooler)
        return self.classifier(out)

In [35]:
# BERT  모델 불러오기
model = BERTClassifier(bertmodel,  dr_rate = 0.5).to(device)

In [36]:
# optimizer와 schedule 설정
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr = learning_rate)
loss_fn = nn.CrossEntropyLoss() # 다중분류를 위한 loss function

t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps = warmup_step, num_training_steps = t_total)

In [37]:
# calc_accuracy : 정확도 측정을 위한 함수
def calc_accuracy(X,Y):
    max_vals, max_indices = torch.max(X, 1)
    train_acc = (max_indices == Y).sum().data.cpu().numpy()/max_indices.size()[0]
    return train_acc
    
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x70120369a910>

# 4.Train

In [38]:
# KoBERT 오픈소스 내 예제코드 : https://github.com/SKTBrain/KoBERT/blob/master/scripts/NSMC/naver_review_classifications_pytorch_kobert.ipynb
train_history = []
test_history = []
loss_history = []

for e in range(num_epochs):
    train_acc = 0.0
    test_acc = 0.0
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length= valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
         
        # print(label.shape, out.shape)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))
            train_history.append(train_acc / (batch_id+1))
            loss_history.append(loss.data.cpu().numpy())
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))
    # train_history.append(train_acc / (batch_id+1))

    # .eval() : nn.Module에서 train time과 eval time에서 수행하는 다른 작업을 수행할 수 있도록 switching 하는 함수
    # 즉, model이 Dropout이나 BatNorm2d를 사용하는 경우, train 시에는 사용하지만 evaluation을 할 때에는 사용하지 않도록 설정해주는 함수
    model.eval()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        valid_length = valid_length
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        test_acc += calc_accuracy(out, label)
    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    test_history.append(test_acc / (batch_id+1))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 1 batch id 1 loss 1.9845285415649414 train acc 0.125
epoch 1 batch id 201 loss 1.926171898841858 train acc 0.16402363184079602
epoch 1 batch id 401 loss 1.7320576906204224 train acc 0.1925654613466334
epoch 1 batch id 601 loss 1.6617203950881958 train acc 0.21667013311148087
epoch 1 batch id 801 loss 1.552905559539795 train acc 0.24079275905118602
epoch 1 batch id 1001 loss 1.3918482065200806 train acc 0.26287774725274726
epoch 1 batch id 1201 loss 1.4500925540924072 train acc 0.2878200457951707
epoch 1 batch id 1401 loss 1.2333296537399292 train acc 0.3124219307637402
epoch 1 batch id 1601 loss 1.1572463512420654 train acc 0.336079013116802
epoch 1 batch id 1801 loss 1.2057034969329834 train acc 0.35904532204330925
epoch 1 batch id 2001 loss 1.0680456161499023 train acc 0.3812078335832084
epoch 1 batch id 2201 loss 1.0428285598754883 train acc 0.4028282598818719
epoch 1 batch id 2401 loss 0.8436154723167419 train acc 0.4235149416909621
epoch 1 train acc 0.4425999526351373


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):


  0%|          | 0/649 [00:00<?, ?it/s]

epoch 1 test acc 0.728764618959812


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 2 batch id 1 loss 0.9049066305160522 train acc 0.65625
epoch 2 batch id 201 loss 0.8277468681335449 train acc 0.7081778606965174
epoch 2 batch id 401 loss 0.6674256920814514 train acc 0.7231140897755611
epoch 2 batch id 601 loss 0.8253136873245239 train acc 0.7346609816971714
epoch 2 batch id 801 loss 0.5414807200431824 train acc 0.7439918851435705
epoch 2 batch id 1001 loss 0.337825208902359 train acc 0.7536057692307693
epoch 2 batch id 1201 loss 0.5556764602661133 train acc 0.7636084512905912
epoch 2 batch id 1401 loss 0.2520119547843933 train acc 0.771279443254818
epoch 2 batch id 1601 loss 0.3958921730518341 train acc 0.7773461898813242
epoch 2 batch id 1801 loss 0.5235512256622314 train acc 0.7838353692393115
epoch 2 batch id 2001 loss 0.3451877534389496 train acc 0.7891913418290855
epoch 2 batch id 2201 loss 0.6320160031318665 train acc 0.7936449341208541
epoch 2 batch id 2401 loss 0.46713191270828247 train acc 0.7974997396917951
epoch 2 train acc 0.8015966233373164


  0%|          | 0/649 [00:00<?, ?it/s]

epoch 2 test acc 0.8615358748642299


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 3 batch id 1 loss 0.47693026065826416 train acc 0.84375
epoch 3 batch id 201 loss 0.44402673840522766 train acc 0.8475590796019901
epoch 3 batch id 401 loss 0.353430837392807 train acc 0.8533743765586035
epoch 3 batch id 601 loss 0.34017255902290344 train acc 0.8514455074875208
epoch 3 batch id 801 loss 0.3370334506034851 train acc 0.85104556803995
epoch 3 batch id 1001 loss 0.26475974917411804 train acc 0.8511956793206793
epoch 3 batch id 1201 loss 0.4256453216075897 train acc 0.8534814737718568
epoch 3 batch id 1401 loss 0.21005721390247345 train acc 0.8538543897216274
epoch 3 batch id 1601 loss 0.3128219544887543 train acc 0.8547977826358526
epoch 3 batch id 1801 loss 0.4717724025249481 train acc 0.8561823292615214
epoch 3 batch id 2001 loss 0.3983374834060669 train acc 0.8571026986506747
epoch 3 batch id 2201 loss 0.569298267364502 train acc 0.8579338936846888
epoch 3 batch id 2401 loss 0.22427982091903687 train acc 0.8585615368596418
epoch 3 train acc 0.8596102274692619


  0%|          | 0/649 [00:00<?, ?it/s]

epoch 3 test acc 0.8753287687236353


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 4 batch id 1 loss 0.32215651869773865 train acc 0.875
epoch 4 batch id 201 loss 0.38797539472579956 train acc 0.8650497512437811
epoch 4 batch id 401 loss 0.2670033872127533 train acc 0.8703241895261845
epoch 4 batch id 601 loss 0.36659467220306396 train acc 0.8684744176372712
epoch 4 batch id 801 loss 0.29688790440559387 train acc 0.867626404494382
epoch 4 batch id 1001 loss 0.15064498782157898 train acc 0.8677572427572428
epoch 4 batch id 1201 loss 0.47881290316581726 train acc 0.869132493755204
epoch 4 batch id 1401 loss 0.17480291426181793 train acc 0.8693232512491078
epoch 4 batch id 1601 loss 0.23587755858898163 train acc 0.8696908182386008
epoch 4 batch id 1801 loss 0.4500165581703186 train acc 0.8702370210993893
epoch 4 batch id 2001 loss 0.3410252332687378 train acc 0.8707911669165417
epoch 4 batch id 2201 loss 0.4244142472743988 train acc 0.8713155951840073
epoch 4 batch id 2401 loss 0.18275907635688782 train acc 0.8715769471053728
epoch 4 train acc 0.872167443633848


  0%|          | 0/649 [00:00<?, ?it/s]

epoch 4 test acc 0.8738613472176615


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 5 batch id 1 loss 0.18522235751152039 train acc 0.90625
epoch 5 batch id 201 loss 0.29888466000556946 train acc 0.8775652985074627
epoch 5 batch id 401 loss 0.3407535254955292 train acc 0.8805720074812967
epoch 5 batch id 601 loss 0.3203809857368469 train acc 0.8797576955074875
epoch 5 batch id 801 loss 0.2915413975715637 train acc 0.8779650436953808
epoch 5 batch id 1001 loss 0.14070133864879608 train acc 0.877997002997003
epoch 5 batch id 1201 loss 0.41849470138549805 train acc 0.8790461074104913
epoch 5 batch id 1401 loss 0.09675915539264679 train acc 0.879026142041399
epoch 5 batch id 1601 loss 0.24858897924423218 train acc 0.8791868363522798
epoch 5 batch id 1801 loss 0.47516390681266785 train acc 0.8800579539144919
epoch 5 batch id 2001 loss 0.22263289988040924 train acc 0.8804425912043978
epoch 5 batch id 2201 loss 0.4410880208015442 train acc 0.880920604270786
epoch 5 batch id 2401 loss 0.16320763528347015 train acc 0.8811562890462308
epoch 5 train acc 0.8816374681613786


  0%|          | 0/649 [00:00<?, ?it/s]

epoch 5 test acc 0.8752324667205537


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 6 batch id 1 loss 0.28065362572669983 train acc 0.90625
epoch 6 batch id 201 loss 0.2953062057495117 train acc 0.8826181592039801
epoch 6 batch id 401 loss 0.27868759632110596 train acc 0.8863388403990025
epoch 6 batch id 601 loss 0.2884986400604248 train acc 0.8850353577371048
epoch 6 batch id 801 loss 0.3370751738548279 train acc 0.884246254681648
epoch 6 batch id 1001 loss 0.2243240922689438 train acc 0.8847714785214785
epoch 6 batch id 1201 loss 0.3330165445804596 train acc 0.8858633430474604
epoch 6 batch id 1401 loss 0.09140406548976898 train acc 0.8865877052105638
epoch 6 batch id 1601 loss 0.24419212341308594 train acc 0.8871213304184884
epoch 6 batch id 1801 loss 0.48390913009643555 train acc 0.8881437395891172
epoch 6 batch id 2001 loss 0.17282871901988983 train acc 0.8888212143928036
epoch 6 batch id 2201 loss 0.6473410129547119 train acc 0.8893045774647887
epoch 6 batch id 2401 loss 0.11919718980789185 train acc 0.8895967825905873
epoch 6 train acc 0.8902824154507719


  0%|          | 0/649 [00:00<?, ?it/s]

epoch 6 test acc 0.8756212268559448


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 7 batch id 1 loss 0.20176619291305542 train acc 0.921875
epoch 7 batch id 201 loss 0.2998729646205902 train acc 0.8958333333333334
epoch 7 batch id 401 loss 0.2829831838607788 train acc 0.8983011221945137
epoch 7 batch id 601 loss 0.2723332345485687 train acc 0.8979825291181365
epoch 7 batch id 801 loss 0.22733643651008606 train acc 0.8969062109862672
epoch 7 batch id 1001 loss 0.10265221446752548 train acc 0.8975087412587412
epoch 7 batch id 1201 loss 0.2937442660331726 train acc 0.898834304746045
epoch 7 batch id 1401 loss 0.05987562611699104 train acc 0.8994022127052106
epoch 7 batch id 1601 loss 0.16275019943714142 train acc 0.8995744846970644
epoch 7 batch id 1801 loss 0.33919447660446167 train acc 0.9007409078289839
epoch 7 batch id 2001 loss 0.19460277259349823 train acc 0.9014711394302849
epoch 7 batch id 2201 loss 0.40188634395599365 train acc 0.9021041572012721
epoch 7 batch id 2401 loss 0.1476064920425415 train acc 0.9023193461057892
epoch 7 train acc 0.902845650490550

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 7 test acc 0.8760522177877693


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 8 batch id 1 loss 0.2385074347257614 train acc 0.953125
epoch 8 batch id 201 loss 0.27020326256752014 train acc 0.9073383084577115
epoch 8 batch id 401 loss 0.3326324224472046 train acc 0.9104192643391521
epoch 8 batch id 601 loss 0.22640489041805267 train acc 0.9099677620632279
epoch 8 batch id 801 loss 0.33565425872802734 train acc 0.9090589887640449
epoch 8 batch id 1001 loss 0.08024349808692932 train acc 0.9098245504495505
epoch 8 batch id 1201 loss 0.24121125042438507 train acc 0.9112718567860116
epoch 8 batch id 1401 loss 0.08376356214284897 train acc 0.9122390256959315
epoch 8 batch id 1601 loss 0.11954901367425919 train acc 0.9126717676452217
epoch 8 batch id 1801 loss 0.39077791571617126 train acc 0.9133380760688506
epoch 8 batch id 2001 loss 0.29813045263290405 train acc 0.9140039355322339
epoch 8 batch id 2201 loss 0.3978688418865204 train acc 0.9149037369377556
epoch 8 batch id 2401 loss 0.11631018668413162 train acc 0.9149508017492711
epoch 8 train acc 0.915481603369

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 8 test acc 0.8766071383465103


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 9 batch id 1 loss 0.1583421677350998 train acc 0.9375
epoch 9 batch id 201 loss 0.2190466821193695 train acc 0.9240516169154229
epoch 9 batch id 401 loss 0.15958094596862793 train acc 0.9242908354114713
epoch 9 batch id 601 loss 0.23831753432750702 train acc 0.9239028702163061
epoch 9 batch id 801 loss 0.26936036348342896 train acc 0.9225967540574282
epoch 9 batch id 1001 loss 0.08267020434141159 train acc 0.9233266733266733
epoch 9 batch id 1201 loss 0.1563831865787506 train acc 0.9243729184013322
epoch 9 batch id 1401 loss 0.060635991394519806 train acc 0.9254438793718772
epoch 9 batch id 1601 loss 0.09371648728847504 train acc 0.9260325577763897
epoch 9 batch id 1801 loss 0.37769708037376404 train acc 0.9265078428650749
epoch 9 batch id 2001 loss 0.14036200940608978 train acc 0.9269818215892054
epoch 9 batch id 2201 loss 0.2952404022216797 train acc 0.9277814061790095
epoch 9 batch id 2401 loss 0.0980643630027771 train acc 0.9281354123281966
epoch 9 train acc 0.928522172062199

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 9 test acc 0.8786282963702039


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 10 batch id 1 loss 0.1428861916065216 train acc 0.953125
epoch 10 batch id 201 loss 0.1722133457660675 train acc 0.9350901741293532
epoch 10 batch id 401 loss 0.16437596082687378 train acc 0.9360582917705735
epoch 10 batch id 601 loss 0.1702992022037506 train acc 0.936278078202995
epoch 10 batch id 801 loss 0.23380634188652039 train acc 0.9352957240948814
epoch 10 batch id 1001 loss 0.07823178172111511 train acc 0.936500999000999
epoch 10 batch id 1201 loss 0.14554694294929504 train acc 0.9372788301415487
epoch 10 batch id 1401 loss 0.03650921583175659 train acc 0.938459136331192
epoch 10 batch id 1601 loss 0.18783405423164368 train acc 0.9387394597126796
epoch 10 batch id 1801 loss 0.30462443828582764 train acc 0.939157065519156
epoch 10 batch id 2001 loss 0.1985362470149994 train acc 0.9395068090954523
epoch 10 batch id 2201 loss 0.2719416320323944 train acc 0.9400201612903226
epoch 10 batch id 2401 loss 0.04439377039670944 train acc 0.9404610058309038
epoch 10 train acc 0.9409

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 10 test acc 0.8758836892823765


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 11 batch id 1 loss 0.12171006947755814 train acc 0.953125
epoch 11 batch id 201 loss 0.1810160130262375 train acc 0.9459732587064676
epoch 11 batch id 401 loss 0.1313295066356659 train acc 0.9453709476309227
epoch 11 batch id 601 loss 0.1864653378725052 train acc 0.9450655158069884
epoch 11 batch id 801 loss 0.10750865936279297 train acc 0.9446200062421972
epoch 11 batch id 1001 loss 0.03594360873103142 train acc 0.9453359140859141
epoch 11 batch id 1201 loss 0.18893058598041534 train acc 0.9461126144879267
epoch 11 batch id 1401 loss 0.015030749142169952 train acc 0.946990988579586
epoch 11 batch id 1601 loss 0.13434503972530365 train acc 0.9471521705184259
epoch 11 batch id 1801 loss 0.2657940685749054 train acc 0.9478241254858412
epoch 11 batch id 2001 loss 0.0758536085486412 train acc 0.9481899675162418
epoch 11 batch id 2201 loss 0.18389707803726196 train acc 0.9485532144479782
epoch 11 batch id 2401 loss 0.032373812049627304 train acc 0.9489535610162433
epoch 11 train acc 0

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 11 test acc 0.8712876367172699


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 12 batch id 1 loss 0.05918636918067932 train acc 0.984375
epoch 12 batch id 201 loss 0.09276413172483444 train acc 0.9539023631840796
epoch 12 batch id 401 loss 0.14663034677505493 train acc 0.9523456982543641
epoch 12 batch id 601 loss 0.08646738529205322 train acc 0.953176996672213
epoch 12 batch id 801 loss 0.10344263166189194 train acc 0.9535541510611736
epoch 12 batch id 1001 loss 0.01979481615126133 train acc 0.954202047952048
epoch 12 batch id 1201 loss 0.14461219310760498 train acc 0.9547902789342215
epoch 12 batch id 1401 loss 0.009722571820020676 train acc 0.9557682012847966
epoch 12 batch id 1601 loss 0.09135570377111435 train acc 0.9557600718301061
epoch 12 batch id 1801 loss 0.29075950384140015 train acc 0.956057398667407
epoch 12 batch id 2001 loss 0.03725742921233177 train acc 0.9564124187906047
epoch 12 batch id 2201 loss 0.12171641737222672 train acc 0.9568378009995456
epoch 12 batch id 2401 loss 0.02783815935254097 train acc 0.9571077155351937
epoch 12 train acc

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 12 test acc 0.8746076877415444


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 13 batch id 1 loss 0.05443507805466652 train acc 0.984375
epoch 13 batch id 201 loss 0.0667264461517334 train acc 0.9592661691542289
epoch 13 batch id 401 loss 0.06734952330589294 train acc 0.9612687032418953
epoch 13 batch id 601 loss 0.11262334138154984 train acc 0.9612624792013311
epoch 13 batch id 801 loss 0.09245134890079498 train acc 0.9610447877652933
epoch 13 batch id 1001 loss 0.012071946635842323 train acc 0.9609296953046953
epoch 13 batch id 1201 loss 0.23423872888088226 train acc 0.9617766444629475
epoch 13 batch id 1401 loss 0.011463014408946037 train acc 0.9620025874375446
epoch 13 batch id 1601 loss 0.07884171605110168 train acc 0.9622208775765146
epoch 13 batch id 1801 loss 0.23903875052928925 train acc 0.9624080372015547
epoch 13 batch id 2001 loss 0.022441772744059563 train acc 0.9627842328835582
epoch 13 batch id 2201 loss 0.24935558438301086 train acc 0.9630210699681963
epoch 13 batch id 2401 loss 0.009819946251809597 train acc 0.9630622657226156
epoch 13 trai

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 13 test acc 0.8723915576296446


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 14 batch id 1 loss 0.04265613481402397 train acc 0.984375
epoch 14 batch id 201 loss 0.0730186402797699 train acc 0.9647077114427861
epoch 14 batch id 401 loss 0.05204812064766884 train acc 0.9652431421446384
epoch 14 batch id 601 loss 0.05750257149338722 train acc 0.9654742096505824
epoch 14 batch id 801 loss 0.12685832381248474 train acc 0.9654728464419475
epoch 14 batch id 1001 loss 0.00870201550424099 train acc 0.9657217782217782
epoch 14 batch id 1201 loss 0.05747578665614128 train acc 0.9662520815986678
epoch 14 batch id 1401 loss 0.005662761628627777 train acc 0.9668763383297645
epoch 14 batch id 1601 loss 0.06244184449315071 train acc 0.9669152092442224
epoch 14 batch id 1801 loss 0.19169308245182037 train acc 0.9668760410882843
epoch 14 batch id 2001 loss 0.051285747438669205 train acc 0.9674615817091454
epoch 14 batch id 2201 loss 0.09416540712118149 train acc 0.9677845297592004
epoch 14 batch id 2401 loss 0.015926295891404152 train acc 0.967832413577676
epoch 14 train 

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 14 test acc 0.8736711112935411


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 15 batch id 1 loss 0.10972421616315842 train acc 0.984375
epoch 15 batch id 201 loss 0.020132651552557945 train acc 0.9710043532338308
epoch 15 batch id 401 loss 0.06749652326107025 train acc 0.9719451371571073
epoch 15 batch id 601 loss 0.0763111487030983 train acc 0.9721037853577371
epoch 15 batch id 801 loss 0.029555441811680794 train acc 0.9715980024968789
epoch 15 batch id 1001 loss 0.0009010813082568347 train acc 0.971544080919081
epoch 15 batch id 1201 loss 0.026325171813368797 train acc 0.9714690882597835
epoch 15 batch id 1401 loss 0.019354745745658875 train acc 0.9720400606709493
epoch 15 batch id 1601 loss 0.10441777855157852 train acc 0.9721365552779513
epoch 15 batch id 1801 loss 0.22095754742622375 train acc 0.972220294280955
epoch 15 batch id 2001 loss 0.04982816427946091 train acc 0.9723497626186907
epoch 15 batch id 2201 loss 0.08362994343042374 train acc 0.972505395274875
epoch 15 batch id 2401 loss 0.0663485899567604 train acc 0.9725960537276135
epoch 15 train 

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 15 test acc 0.8745090176564196


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 16 batch id 1 loss 0.026198649778962135 train acc 0.984375
epoch 16 batch id 201 loss 0.04361971095204353 train acc 0.9742692786069652
epoch 16 batch id 401 loss 0.056611862033605576 train acc 0.9742440773067331
epoch 16 batch id 601 loss 0.002891833893954754 train acc 0.9742096505823628
epoch 16 batch id 801 loss 0.0809599906206131 train acc 0.9745435393258427
epoch 16 batch id 1001 loss 0.000862160581164062 train acc 0.9744630369630369
epoch 16 batch id 1201 loss 0.04842482879757881 train acc 0.9750078059950041
epoch 16 batch id 1401 loss 0.08776913583278656 train acc 0.9754862598144183
epoch 16 batch id 1601 loss 0.07232070714235306 train acc 0.9753084009993754
epoch 16 batch id 1801 loss 0.17751900851726532 train acc 0.9756472098833981
epoch 16 batch id 2001 loss 0.018041331321001053 train acc 0.9758402048975512
epoch 16 batch id 2201 loss 0.08227740973234177 train acc 0.9760549182189914
epoch 16 batch id 2401 loss 0.0710664764046669 train acc 0.9762403685964182
epoch 16 trai

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 16 test acc 0.8710879284649776


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 17 batch id 1 loss 0.07000727206468582 train acc 0.984375
epoch 17 batch id 201 loss 0.005182334221899509 train acc 0.9779228855721394
epoch 17 batch id 401 loss 0.12346728146076202 train acc 0.9776730049875312
epoch 17 batch id 601 loss 0.051872655749320984 train acc 0.9777194259567388
epoch 17 batch id 801 loss 0.03760964795947075 train acc 0.9774305555555556
epoch 17 batch id 1001 loss 0.006251759827136993 train acc 0.9778190559440559
epoch 17 batch id 1201 loss 0.035553038120269775 train acc 0.978143213988343
epoch 17 batch id 1401 loss 0.008257186971604824 train acc 0.97834136331192
epoch 17 batch id 1601 loss 0.015582817606627941 train acc 0.97865591817614
epoch 17 batch id 1801 loss 0.03775428235530853 train acc 0.9786143114936147
epoch 17 batch id 2001 loss 0.01025336142629385 train acc 0.9788308970514743
epoch 17 batch id 2201 loss 0.06496711820363998 train acc 0.978944229895502
epoch 17 batch id 2401 loss 0.0030832411721348763 train acc 0.9792013744273219
epoch 17 train

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 17 test acc 0.8737180782540603


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 18 batch id 1 loss 0.11373323202133179 train acc 0.984375
epoch 18 batch id 201 loss 0.07604016363620758 train acc 0.980488184079602
epoch 18 batch id 401 loss 0.007563331630080938 train acc 0.9821539900249376
epoch 18 batch id 601 loss 0.016751766204833984 train acc 0.980865224625624
epoch 18 batch id 801 loss 0.023660697042942047 train acc 0.9807857365792759
epoch 18 batch id 1001 loss 0.0025016185827553272 train acc 0.981018981018981
epoch 18 batch id 1201 loss 0.04343416541814804 train acc 0.9814997918401333
epoch 18 batch id 1401 loss 0.05891503393650055 train acc 0.9814641327623126
epoch 18 batch id 1601 loss 0.08995713293552399 train acc 0.9813593066833229
epoch 18 batch id 1801 loss 0.3485892415046692 train acc 0.9816594947251527
epoch 18 batch id 2001 loss 0.006176646798849106 train acc 0.981907483758121
epoch 18 batch id 2201 loss 0.18645521998405457 train acc 0.9819258291685597
epoch 18 batch id 2401 loss 0.02893836982548237 train acc 0.9819020720533111
epoch 18 train 

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 18 test acc 0.8736446677107278


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 19 batch id 1 loss 0.11512687057256699 train acc 0.984375
epoch 19 batch id 201 loss 0.03222537413239479 train acc 0.9837531094527363
epoch 19 batch id 401 loss 0.04437709599733353 train acc 0.9837515586034913
epoch 19 batch id 601 loss 0.09121517837047577 train acc 0.9835170549084858
epoch 19 batch id 801 loss 0.10238514840602875 train acc 0.9831460674157303
epoch 19 batch id 1001 loss 0.00042731367284432054 train acc 0.9834228271728271
epoch 19 batch id 1201 loss 0.09060501307249069 train acc 0.9836724604496253
epoch 19 batch id 1401 loss 0.006965440697968006 train acc 0.9840404175588865
epoch 19 batch id 1601 loss 0.04681723937392235 train acc 0.9839358213616489
epoch 19 batch id 1801 loss 0.1106170192360878 train acc 0.98379372570794
epoch 19 batch id 2001 loss 0.03606417402625084 train acc 0.9837268865567217
epoch 19 batch id 2201 loss 0.05523901432752609 train acc 0.9839987505679236
epoch 19 batch id 2401 loss 0.004869468975812197 train acc 0.9839780299875052
epoch 19 train

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 19 test acc 0.8737180782540603


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 20 batch id 1 loss 0.08222443610429764 train acc 0.96875
epoch 20 batch id 201 loss 0.018447762355208397 train acc 0.9851523631840796
epoch 20 batch id 401 loss 0.1310364454984665 train acc 0.9846477556109726
epoch 20 batch id 601 loss 0.0013434733264148235 train acc 0.984426996672213
epoch 20 batch id 801 loss 0.06685736030340195 train acc 0.9842969725343321
epoch 20 batch id 1001 loss 0.02564098685979843 train acc 0.9845467032967034
epoch 20 batch id 1201 loss 0.11886604130268097 train acc 0.9846482098251457
epoch 20 batch id 1401 loss 0.0002996843250002712 train acc 0.9847988044254105
epoch 20 batch id 1601 loss 0.003819718025624752 train acc 0.9845409119300437
epoch 20 batch id 1801 loss 0.09980011731386185 train acc 0.9847654081066074
epoch 20 batch id 2001 loss 0.02785882167518139 train acc 0.9850387306346826
epoch 20 batch id 2201 loss 0.017468370497226715 train acc 0.985290776919582
epoch 20 batch id 2401 loss 0.19612723588943481 train acc 0.9852600478967097
epoch 20 trai

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 20 test acc 0.8756188587739019


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 21 batch id 1 loss 0.0015323562547564507 train acc 1.0
epoch 21 batch id 201 loss 0.06251730769872665 train acc 0.9872512437810945
epoch 21 batch id 401 loss 0.0026550900656729937 train acc 0.9858556733167082
epoch 21 batch id 601 loss 0.08629786968231201 train acc 0.9852069467554077
epoch 21 batch id 801 loss 0.0019486017990857363 train acc 0.9853503433208489
epoch 21 batch id 1001 loss 0.0003618251357693225 train acc 0.9855144855144855
epoch 21 batch id 1201 loss 0.04716123640537262 train acc 0.9855719192339717
epoch 21 batch id 1401 loss 0.0028280019760131836 train acc 0.9854791220556746
epoch 21 batch id 1601 loss 0.0016256766393780708 train acc 0.985507104934416
epoch 21 batch id 1801 loss 0.15684525668621063 train acc 0.9859106052193226
epoch 21 batch id 2001 loss 0.0012011260259896517 train acc 0.9860148050974513
epoch 21 batch id 2201 loss 0.14977557957172394 train acc 0.9861071671967288
epoch 21 batch id 2401 loss 0.0006023261812515557 train acc 0.9861971574344023
epoch 

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 21 test acc 0.874851994872313


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 22 batch id 1 loss 0.0033655501902103424 train acc 1.0
epoch 22 batch id 201 loss 0.00289111933670938 train acc 0.9877953980099502
epoch 22 batch id 401 loss 0.0020760095212608576 train acc 0.9883494389027432
epoch 22 batch id 601 loss 0.11010799556970596 train acc 0.9876767886855241
epoch 22 batch id 801 loss 0.07885513454675674 train acc 0.9878082084893882
epoch 22 batch id 1001 loss 0.000305596535326913 train acc 0.9876841908091908
epoch 22 batch id 1201 loss 0.07902412861585617 train acc 0.987835657785179
epoch 22 batch id 1401 loss 0.0003309061867184937 train acc 0.987631602426838
epoch 22 batch id 1601 loss 0.06251966953277588 train acc 0.9877908338538414
epoch 22 batch id 1801 loss 0.06133681908249855 train acc 0.9879494031093837
epoch 22 batch id 2001 loss 0.06867989152669907 train acc 0.98804503998001
epoch 22 batch id 2201 loss 0.03464880213141441 train acc 0.9881729895502045
epoch 22 batch id 2401 loss 0.004238664638251066 train acc 0.9882275614327364
epoch 22 train ac

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 22 test acc 0.8783153148601884


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 23 batch id 1 loss 0.00087804562645033 train acc 1.0
epoch 23 batch id 201 loss 0.0849708616733551 train acc 0.9891169154228856
epoch 23 batch id 401 loss 0.06441119313240051 train acc 0.989323566084788
epoch 23 batch id 601 loss 0.0008155154064297676 train acc 0.9890806988352745
epoch 23 batch id 801 loss 0.08600211143493652 train acc 0.98895911360799
epoch 23 batch id 1001 loss 0.0005852434551343322 train acc 0.9888080669330669
epoch 23 batch id 1201 loss 0.12671995162963867 train acc 0.9891496669442131
epoch 23 batch id 1401 loss 0.002846040762960911 train acc 0.9894718058529621
epoch 23 batch id 1601 loss 0.002014716388657689 train acc 0.9895182698313554
epoch 23 batch id 1801 loss 0.10465074330568314 train acc 0.989502359800111
epoch 23 batch id 2001 loss 0.057542335242033005 train acc 0.9893256496751625
epoch 23 batch id 2201 loss 0.013110724277794361 train acc 0.9893727283053158
epoch 23 batch id 2401 loss 0.004361099563539028 train acc 0.9894184714702208
epoch 23 train ac

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 23 test acc 0.8768960443557553


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 24 batch id 1 loss 0.007545886095613241 train acc 1.0
epoch 24 batch id 201 loss 0.0005727996467612684 train acc 0.9908271144278606
epoch 24 batch id 401 loss 0.00036089675268158317 train acc 0.9910769950124688
epoch 24 batch id 601 loss 0.012253810651600361 train acc 0.990380615640599
epoch 24 batch id 801 loss 0.020276451483368874 train acc 0.9906367041198502
epoch 24 batch id 1001 loss 0.0004866990784648806 train acc 0.9908216783216783
epoch 24 batch id 1201 loss 0.11139972507953644 train acc 0.9908539758534555
epoch 24 batch id 1401 loss 0.00027365927235223353 train acc 0.9910331905781584
epoch 24 batch id 1601 loss 0.0003188106056768447 train acc 0.991040755777639
epoch 24 batch id 1801 loss 0.026966391131281853 train acc 0.9909859106052193
epoch 24 batch id 2001 loss 0.0005022427067160606 train acc 0.9911684782608695
epoch 24 batch id 2201 loss 0.012883977964520454 train acc 0.991119093593821
epoch 24 batch id 2401 loss 0.0003193554875906557 train acc 0.9911885672636401
epo

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 24 test acc 0.8759811753264795


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 25 batch id 1 loss 0.013273412361741066 train acc 1.0
epoch 25 batch id 201 loss 0.0002954571391455829 train acc 0.9933924129353234
epoch 25 batch id 401 loss 0.06106961518526077 train acc 0.9921290523690773
epoch 25 batch id 601 loss 0.0019565215334296227 train acc 0.9914205490848585
epoch 25 batch id 801 loss 0.06452019512653351 train acc 0.9912999375780275
epoch 25 batch id 1001 loss 0.00016395033162552863 train acc 0.9914928821178821
epoch 25 batch id 1201 loss 0.09785747528076172 train acc 0.9918427352206495
epoch 25 batch id 1401 loss 0.00021306662529241294 train acc 0.9920927016416845
epoch 25 batch id 1601 loss 0.000771899416577071 train acc 0.9920167083073079
epoch 25 batch id 1801 loss 0.03321068361401558 train acc 0.9919749444752916
epoch 25 batch id 2001 loss 0.0007793504046276212 train acc 0.9920977011494253
epoch 25 batch id 2201 loss 0.04143016040325165 train acc 0.9920277714675148
epoch 25 batch id 2401 loss 0.011939000338315964 train acc 0.9919955226988755
epoch 

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 25 test acc 0.8749194852105383


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 26 batch id 1 loss 0.010353908874094486 train acc 1.0
epoch 26 batch id 201 loss 0.00042368052527308464 train acc 0.9911380597014925
epoch 26 batch id 401 loss 0.0011248469818383455 train acc 0.9908042394014963
epoch 26 batch id 601 loss 0.0006341097760014236 train acc 0.9912125623960066
epoch 26 batch id 801 loss 0.0007576039643026888 train acc 0.9913974719101124
epoch 26 batch id 1001 loss 0.0003764632565435022 train acc 0.9917738511488512
epoch 26 batch id 1201 loss 0.000840947323013097 train acc 0.9920378850957535
epoch 26 batch id 1401 loss 0.00043188745621591806 train acc 0.9922042291220556
epoch 26 batch id 1601 loss 0.000448485225206241 train acc 0.992172860712055
epoch 26 batch id 1801 loss 0.012179936282336712 train acc 0.992278595224875
epoch 26 batch id 2001 loss 0.07867436110973358 train acc 0.9922226386806596
epoch 26 batch id 2201 loss 0.015031857416033745 train acc 0.9923472285324852
epoch 26 batch id 2401 loss 0.0927671417593956 train acc 0.992405508121616
epoch 

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 26 test acc 0.8767515913511328


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 27 batch id 1 loss 0.0005310570122674108 train acc 1.0
epoch 27 batch id 201 loss 0.0008775045862421393 train acc 0.9937033582089553
epoch 27 batch id 401 loss 0.0001625462027732283 train acc 0.9934928304239401
epoch 27 batch id 601 loss 0.01890593208372593 train acc 0.9933704242928453
epoch 27 batch id 801 loss 0.00034211695310659707 train acc 0.9933286516853933
epoch 27 batch id 1001 loss 0.00014611198275815696 train acc 0.9931006493506493
epoch 27 batch id 1201 loss 0.0010347855277359486 train acc 0.9932478143213989
epoch 27 batch id 1401 loss 0.0001356704015051946 train acc 0.9933641149179158
epoch 27 batch id 1601 loss 0.00018648298282641917 train acc 0.9933244846970644
epoch 27 batch id 1801 loss 0.0004647172463592142 train acc 0.9932502776235425
epoch 27 batch id 2001 loss 0.0013012764975428581 train acc 0.9932611819090454
epoch 27 batch id 2201 loss 0.012843379750847816 train acc 0.9933055997273966
epoch 27 batch id 2401 loss 0.000419817486545071 train acc 0.9934142024156

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 27 test acc 0.8774955637929728


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 28 batch id 1 loss 0.00031731335911899805 train acc 1.0
epoch 28 batch id 201 loss 0.0018152204575017095 train acc 0.9930814676616916
epoch 28 batch id 401 loss 0.00021419388940557837 train acc 0.993531795511222
epoch 28 batch id 601 loss 0.00025665294378995895 train acc 0.9937344009983361
epoch 28 batch id 801 loss 0.01876474730670452 train acc 0.9938748439450686
epoch 28 batch id 1001 loss 0.00014368019765242934 train acc 0.9936157592407593
epoch 28 batch id 1201 loss 0.021463584154844284 train acc 0.9936771440466278
epoch 28 batch id 1401 loss 0.00030725536635145545 train acc 0.9938436830835118
epoch 28 batch id 1601 loss 0.00025925226509571075 train acc 0.9939393347907558
epoch 28 batch id 1801 loss 0.22910907864570618 train acc 0.9939790394225431
epoch 28 batch id 2001 loss 0.0010438233148306608 train acc 0.9941435532233883
epoch 28 batch id 2201 loss 0.00777006009593606 train acc 0.9941148909586551
epoch 28 batch id 2401 loss 0.0002507455355953425 train acc 0.99418861932528

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 28 test acc 0.8772559928262902


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 29 batch id 1 loss 0.0002056067605735734 train acc 1.0
epoch 29 batch id 201 loss 0.007156399078667164 train acc 0.9951803482587065
epoch 29 batch id 401 loss 0.00017324619693681598 train acc 0.9950903990024937
epoch 29 batch id 601 loss 0.00017627184570301324 train acc 0.9947483361064892
epoch 29 batch id 801 loss 0.031774602830410004 train acc 0.9948306803995006
epoch 29 batch id 1001 loss 0.0008863079128786922 train acc 0.9946615884115884
epoch 29 batch id 1201 loss 0.00047021257341839373 train acc 0.994626873438801
epoch 29 batch id 1401 loss 0.00019673097995109856 train acc 0.994635528194147
epoch 29 batch id 1601 loss 0.0004030277777928859 train acc 0.9946029825109307
epoch 29 batch id 1801 loss 0.06054703891277313 train acc 0.9946817740144365
epoch 29 batch id 2001 loss 0.004976477473974228 train acc 0.9947370064967517
epoch 29 batch id 2201 loss 0.1081252321600914 train acc 0.9947963993639255
epoch 29 batch id 2401 loss 0.0003168264520354569 train acc 0.9948198667221991
e

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 29 test acc 0.8770381292783349


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 30 batch id 1 loss 0.0003363729047123343 train acc 1.0
epoch 30 batch id 201 loss 0.0002278962783748284 train acc 0.9954135572139303
epoch 30 batch id 401 loss 0.00018457973783370107 train acc 0.9948176433915212
epoch 30 batch id 601 loss 0.00044018292101100087 train acc 0.9945663477537438
epoch 30 batch id 801 loss 0.0002180104493163526 train acc 0.9945575842696629
epoch 30 batch id 1001 loss 0.033059846609830856 train acc 0.9946147602397603
epoch 30 batch id 1201 loss 0.11397315561771393 train acc 0.9947829933388843
epoch 30 batch id 1401 loss 0.00014696571452077478 train acc 0.994925499643112
epoch 30 batch id 1601 loss 0.0011910772882401943 train acc 0.9950519206745784
epoch 30 batch id 1801 loss 0.00023115900694392622 train acc 0.9951329122709606
epoch 30 batch id 2001 loss 0.000204972704523243 train acc 0.995283608195902
epoch 30 batch id 2201 loss 0.01281448733061552 train acc 0.9952720354384371
epoch 30 batch id 2401 loss 0.0001528276625322178 train acc 0.9953534985422741

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 30 test acc 0.8779541823486322


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 31 batch id 1 loss 0.0001637848763493821 train acc 1.0
epoch 31 batch id 201 loss 0.006831794511526823 train acc 0.9958799751243781
epoch 31 batch id 401 loss 0.010318439453840256 train acc 0.995791770573566
epoch 31 batch id 601 loss 0.0002686085645109415 train acc 0.9956062811980033
epoch 31 batch id 801 loss 0.00017213258252013475 train acc 0.9956499687890137
epoch 31 batch id 1001 loss 0.0012132520787417889 train acc 0.9955981518481518
epoch 31 batch id 1201 loss 0.026170894503593445 train acc 0.9956286427976686
epoch 31 batch id 1401 loss 0.06125519797205925 train acc 0.9957061920057102
epoch 31 batch id 1601 loss 0.0025183723773807287 train acc 0.9956960493441599
epoch 31 batch id 1801 loss 0.0005609108484350145 train acc 0.9956968350916158
epoch 31 batch id 2001 loss 9.64514838415198e-05 train acc 0.9957286981509246
epoch 31 batch id 2201 loss 0.005212233401834965 train acc 0.9957760676965016
epoch 31 batch id 2401 loss 0.00012395433441270143 train acc 0.9957830070803831
e

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 31 test acc 0.8768936762737124


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 32 batch id 1 loss 0.00012874309322796762 train acc 1.0
epoch 32 batch id 201 loss 0.000182808144018054 train acc 0.9961909203980099
epoch 32 batch id 401 loss 0.00012462316954042763 train acc 0.9960645261845387
epoch 32 batch id 601 loss 0.050805188715457916 train acc 0.9959442595673876
epoch 32 batch id 801 loss 0.04262164235115051 train acc 0.9962351747815231
epoch 32 batch id 1001 loss 4.935025208396837e-05 train acc 0.996082042957043
epoch 32 batch id 1201 loss 0.09549586474895477 train acc 0.9962921523730225
epoch 32 batch id 1401 loss 0.00014122969878371805 train acc 0.9964088151320485
epoch 32 batch id 1601 loss 0.0004893653676845133 train acc 0.9964280137414117
epoch 32 batch id 1801 loss 0.0335882194340229 train acc 0.9965036785119378
epoch 32 batch id 2001 loss 5.929913095314987e-05 train acc 0.9965720264867566
epoch 32 batch id 2201 loss 0.007484341971576214 train acc 0.9965569627442071
epoch 32 batch id 2401 loss 8.895833161659539e-05 train acc 0.9966355164514785
epo

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 32 test acc 0.8778807718052994


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 33 batch id 1 loss 0.00018088656361214817 train acc 1.0
epoch 33 batch id 201 loss 9.393785148859024e-05 train acc 0.9957245024875622
epoch 33 batch id 401 loss 0.0001302907185163349 train acc 0.9961034912718204
epoch 33 batch id 601 loss 9.554115240462124e-05 train acc 0.9960742512479202
epoch 33 batch id 801 loss 0.003333909437060356 train acc 0.996274188514357
epoch 33 batch id 1001 loss 6.444779137382284e-05 train acc 0.996363011988012
epoch 33 batch id 1201 loss 0.06455878168344498 train acc 0.9964482722731057
epoch 33 batch id 1401 loss 8.51244039949961e-05 train acc 0.9964311206281228
epoch 33 batch id 1601 loss 0.0005278104799799621 train acc 0.9964084946908183
epoch 33 batch id 1801 loss 0.09606914222240448 train acc 0.9964429483620211
epoch 33 batch id 2001 loss 0.00010203116835327819 train acc 0.9965251749125438
epoch 33 batch id 2201 loss 0.01515271794050932 train acc 0.9965711608359836
epoch 33 batch id 2401 loss 0.043516311794519424 train acc 0.9965379008746356
epoc

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 33 test acc 0.8767504073101114


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 34 batch id 1 loss 0.0001398363965563476 train acc 1.0
epoch 34 batch id 201 loss 0.0002794505562633276 train acc 0.9968905472636815
epoch 34 batch id 401 loss 6.258774374146014e-05 train acc 0.9968438279301746
epoch 34 batch id 601 loss 5.33274287590757e-05 train acc 0.9968801996672213
epoch 34 batch id 801 loss 6.051566015230492e-05 train acc 0.9970934769038702
epoch 34 batch id 1001 loss 3.586106686270796e-05 train acc 0.9972995754245755
epoch 34 batch id 1201 loss 0.00010924995149252936 train acc 0.997463051623647
epoch 34 batch id 1401 loss 6.275884516071528e-05 train acc 0.9975463954318344
epoch 34 batch id 1601 loss 8.20437926449813e-05 train acc 0.9974625234228607
epoch 34 batch id 1801 loss 6.620923522859812e-05 train acc 0.9974840366463076
epoch 34 batch id 2001 loss 9.85095466603525e-05 train acc 0.997532483758121
epoch 34 batch id 2201 loss 0.008738399483263493 train acc 0.9974869377555656
epoch 34 batch id 2401 loss 6.930845847819e-05 train acc 0.9974750104123282
epo

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 34 test acc 0.878293607441461


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 35 batch id 1 loss 0.00010152617323910818 train acc 1.0
epoch 35 batch id 201 loss 8.28018892207183e-05 train acc 0.9971237562189055
epoch 35 batch id 401 loss 4.7855734010227025e-05 train acc 0.99696072319202
epoch 35 batch id 601 loss 4.7986879508243874e-05 train acc 0.9970101913477537
epoch 35 batch id 801 loss 6.283995026024058e-05 train acc 0.9970934769038702
epoch 35 batch id 1001 loss 3.8431408029282466e-05 train acc 0.9971590909090909
epoch 35 batch id 1201 loss 0.002636182587593794 train acc 0.9971768318068276
epoch 35 batch id 1401 loss 0.0006331180920824409 train acc 0.9973233404710921
epoch 35 batch id 1601 loss 0.006133288145065308 train acc 0.9973454091193005
epoch 35 batch id 1801 loss 6.343977292999625e-05 train acc 0.9972844947251527
epoch 35 batch id 2001 loss 3.382373324711807e-05 train acc 0.9973606946526736
epoch 35 batch id 2201 loss 0.01188573706895113 train acc 0.9973804520672421
epoch 35 batch id 2401 loss 7.242004357976839e-05 train acc 0.997422948771345

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 35 test acc 0.8782659798176261


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 36 batch id 1 loss 0.00015209098637569696 train acc 1.0
epoch 36 batch id 201 loss 9.778283856576309e-05 train acc 0.9974347014925373
epoch 36 batch id 401 loss 0.0016031241975724697 train acc 0.9973114089775561
epoch 36 batch id 601 loss 4.508928395807743e-05 train acc 0.9971141846921797
epoch 36 batch id 801 loss 5.5781565606594086e-05 train acc 0.997249531835206
epoch 36 batch id 1001 loss 7.047876715660095e-05 train acc 0.997362012987013
epoch 36 batch id 1201 loss 0.0005785960820503533 train acc 0.997463051623647
epoch 36 batch id 1401 loss 6.023440073477104e-05 train acc 0.9975240899357601
epoch 36 batch id 1601 loss 5.746231181547046e-05 train acc 0.9974918019987508
epoch 36 batch id 1801 loss 5.607104685623199e-05 train acc 0.99749271238201
epoch 36 batch id 2001 loss 3.760101753869094e-05 train acc 0.9975481009495253
epoch 36 batch id 2201 loss 0.018924277275800705 train acc 0.997557928214448
epoch 36 batch id 2401 loss 4.00264143536333e-05 train acc 0.9975661182840483
e

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 36 test acc 0.8783382063199373


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 37 batch id 1 loss 0.0005237675504758954 train acc 1.0
epoch 37 batch id 201 loss 0.00315148220397532 train acc 0.9978233830845771
epoch 37 batch id 401 loss 3.7252386391628534e-05 train acc 0.997701059850374
epoch 37 batch id 601 loss 6.036921695340425e-05 train acc 0.9976861480865225
epoch 37 batch id 801 loss 3.073355765081942e-05 train acc 0.9979127652933832
epoch 37 batch id 1001 loss 3.6464567529037595e-05 train acc 0.9978615134865135
epoch 37 batch id 1201 loss 4.6247638238128275e-05 train acc 0.9979834512905912
epoch 37 batch id 1401 loss 3.564977669157088e-05 train acc 0.9980259635974305
epoch 37 batch id 1601 loss 5.600748772849329e-05 train acc 0.998048094940662
epoch 37 batch id 1801 loss 7.22473268979229e-05 train acc 0.9980219322598556
epoch 37 batch id 2001 loss 3.7898924347246066e-05 train acc 0.9980478510744628
epoch 37 batch id 2201 loss 0.006456555332988501 train acc 0.9980122671512949
epoch 37 batch id 2401 loss 5.250226968200877e-05 train acc 0.99802816534777

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 37 test acc 0.8793047784738185


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 38 batch id 1 loss 3.7018977309344336e-05 train acc 1.0
epoch 38 batch id 201 loss 4.505105971475132e-05 train acc 0.9982120646766169
epoch 38 batch id 401 loss 3.3267653634538874e-05 train acc 0.9977400249376559
epoch 38 batch id 601 loss 0.00013325904728844762 train acc 0.9979201331114809
epoch 38 batch id 801 loss 3.077144720009528e-05 train acc 0.9978542446941323
epoch 38 batch id 1001 loss 2.433568261039909e-05 train acc 0.9979551698301699
epoch 38 batch id 1201 loss 0.0004501363728195429 train acc 0.9980354912572856
epoch 38 batch id 1401 loss 2.8799360734410584e-05 train acc 0.9981374910778016
epoch 38 batch id 1601 loss 3.3698601328069344e-05 train acc 0.9981359306683323
epoch 38 batch id 1801 loss 0.00018924289906863123 train acc 0.9981607440310938
epoch 38 batch id 2001 loss 4.631770934793167e-05 train acc 0.9981571714142928
epoch 38 batch id 2201 loss 0.013562743552029133 train acc 0.998168446160836
epoch 38 batch id 2401 loss 3.1903520721243694e-05 train acc 0.9982364

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 38 test acc 0.8787017069135366


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 39 batch id 1 loss 4.488653939915821e-05 train acc 1.0
epoch 39 batch id 201 loss 3.03524975606706e-05 train acc 0.9984452736318408
epoch 39 batch id 401 loss 2.9784185244352557e-05 train acc 0.9982076059850374
epoch 39 batch id 601 loss 3.42141865985468e-05 train acc 0.9980761231281198
epoch 39 batch id 801 loss 7.172789628384635e-05 train acc 0.9980688202247191
epoch 39 batch id 1001 loss 2.2981859729043208e-05 train acc 0.9981737012987013
epoch 39 batch id 1201 loss 0.000177214911673218 train acc 0.9982436511240633
epoch 39 batch id 1401 loss 3.7355031963670626e-05 train acc 0.9982824768022841
epoch 39 batch id 1601 loss 2.3252132450579666e-05 train acc 0.9983116021236728
epoch 39 batch id 1801 loss 0.00012528066872619092 train acc 0.9983516102165464
epoch 39 batch id 2001 loss 2.1920754079474136e-05 train acc 0.9983992378810594
epoch 39 batch id 2201 loss 0.07110467553138733 train acc 0.998381417537483
epoch 39 batch id 2401 loss 6.216271867742762e-05 train acc 0.998418627655

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 39 test acc 0.8786776314127661


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 40 batch id 1 loss 3.503022890072316e-05 train acc 1.0
epoch 40 batch id 201 loss 3.343451317050494e-05 train acc 0.9983675373134329
epoch 40 batch id 401 loss 2.1362116967793554e-05 train acc 0.9981686408977556
epoch 40 batch id 601 loss 2.3771955966367386e-05 train acc 0.9981801164725458
epoch 40 batch id 801 loss 0.00011022415128536522 train acc 0.9983029026217228
epoch 40 batch id 1001 loss 1.8424729205435142e-05 train acc 0.9983297952047953
epoch 40 batch id 1201 loss 2.8377984563121572e-05 train acc 0.9984127810158201
epoch 40 batch id 1401 loss 2.4412418497377075e-05 train acc 0.9984497680228408
epoch 40 batch id 1601 loss 2.3386372049571946e-05 train acc 0.998428716427233
epoch 40 batch id 1801 loss 0.000601377512793988 train acc 0.9984470433092726
epoch 40 batch id 2001 loss 3.246125197620131e-05 train acc 0.9984929410294853
epoch 40 batch id 2201 loss 0.0124746048822999 train acc 0.9984311108587006
epoch 40 batch id 2401 loss 2.5204204575857148e-05 train acc 0.998379581

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 40 test acc 0.8792072924297153


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 41 batch id 1 loss 6.209644925547764e-05 train acc 1.0
epoch 41 batch id 201 loss 3.072846448048949e-05 train acc 0.9989116915422885
epoch 41 batch id 401 loss 3.173333243466914e-05 train acc 0.9985193266832918
epoch 41 batch id 601 loss 2.6665869881981052e-05 train acc 0.9984141014975042
epoch 41 batch id 801 loss 5.07211298099719e-05 train acc 0.9985369850187266
epoch 41 batch id 1001 loss 2.0819785277126357e-05 train acc 0.998563936063936
epoch 41 batch id 1201 loss 0.0007484892848879099 train acc 0.9986209408825978
epoch 41 batch id 1401 loss 2.63094170804834e-05 train acc 0.9986616702355461
epoch 41 batch id 1601 loss 0.00014803609519731253 train acc 0.9986727045596502
epoch 41 batch id 1801 loss 2.5306862880825065e-05 train acc 0.9986899639089395
epoch 41 batch id 2001 loss 1.9832908947137184e-05 train acc 0.9987193903048476
epoch 41 batch id 2201 loss 0.01318568829447031 train acc 0.9986795774647887
epoch 41 batch id 2401 loss 0.0001611233747098595 train acc 0.998704966680

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 41 test acc 0.8784838433655814


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 42 batch id 1 loss 0.0047777919098734856 train acc 1.0
epoch 42 batch id 201 loss 1.9484496078803204e-05 train acc 0.9986007462686567
epoch 42 batch id 401 loss 1.9467484889901243e-05 train acc 0.99848036159601
epoch 42 batch id 601 loss 2.6291989343008026e-05 train acc 0.9984920965058236
epoch 42 batch id 801 loss 1.8925833501270972e-05 train acc 0.9985759987515606
epoch 42 batch id 1001 loss 1.8392991478322074e-05 train acc 0.9985795454545454
epoch 42 batch id 1201 loss 2.5691846531117335e-05 train acc 0.9986990008326395
epoch 42 batch id 1401 loss 1.7029848095262423e-05 train acc 0.9987508922198429
epoch 42 batch id 1601 loss 2.1266905605443753e-05 train acc 0.9987507807620237
epoch 42 batch id 1801 loss 2.1032465156167746e-05 train acc 0.9988027484730705
epoch 42 batch id 2001 loss 1.9492037608870305e-05 train acc 0.998852136431784
epoch 42 batch id 2201 loss 0.009271142072975636 train acc 0.9988002612448886
epoch 42 batch id 2401 loss 2.2192498363438062e-05 train acc 0.99880

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 42 test acc 0.8800740104574503


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 43 batch id 1 loss 2.1046833353466354e-05 train acc 1.0
epoch 43 batch id 201 loss 2.0345140001154505e-05 train acc 0.9986784825870647
epoch 43 batch id 401 loss 2.2690797777613625e-05 train acc 0.9985193266832918
epoch 43 batch id 601 loss 1.9922219507861882e-05 train acc 0.9983881031613977
epoch 43 batch id 801 loss 1.630528458917979e-05 train acc 0.9984589575530587
epoch 43 batch id 1001 loss 2.1495739929378033e-05 train acc 0.9984234515484516
epoch 43 batch id 1201 loss 5.129926648805849e-05 train acc 0.9984778309741882
epoch 43 batch id 1401 loss 2.6107272788067348e-05 train acc 0.998561295503212
epoch 43 batch id 1601 loss 1.7031561583280563e-05 train acc 0.9984970331043098
epoch 43 batch id 1801 loss 2.2421703761210665e-05 train acc 0.9985511521377013
epoch 43 batch id 2001 loss 1.690131648501847e-05 train acc 0.9986256871564217
epoch 43 batch id 2201 loss 0.007021885830909014 train acc 0.9986369831894594
epoch 43 batch id 2401 loss 2.005089663725812e-05 train acc 0.998704

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 43 test acc 0.8801703124605319


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 44 batch id 1 loss 2.271744961035438e-05 train acc 1.0
epoch 44 batch id 201 loss 0.0010809069499373436 train acc 0.9986784825870647
epoch 44 batch id 401 loss 1.9823481125058606e-05 train acc 0.9985193266832918
epoch 44 batch id 601 loss 2.0873869289061986e-05 train acc 0.9985700915141431
epoch 44 batch id 801 loss 1.622330819373019e-05 train acc 0.9986345193508115
epoch 44 batch id 1001 loss 1.7098425814765505e-05 train acc 0.9987668581418582
epoch 44 batch id 1201 loss 4.357178113423288e-05 train acc 0.9988030807660283
epoch 44 batch id 1401 loss 1.9100016288575716e-05 train acc 0.9988624197002142
epoch 44 batch id 1601 loss 1.791077920643147e-05 train acc 0.9988288569643973
epoch 44 batch id 1801 loss 2.9016529879299924e-05 train acc 0.9988721543586896
epoch 44 batch id 2001 loss 1.6614429114270024e-05 train acc 0.9989224137931034
epoch 44 batch id 2201 loss 0.009883986786007881 train acc 0.9988925488414357
epoch 44 batch id 2401 loss 1.9657727534649894e-05 train acc 0.998919

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 44 test acc 0.8801932039202809


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 45 batch id 1 loss 2.3777212845743634e-05 train acc 1.0
epoch 45 batch id 201 loss 2.2322738004731946e-05 train acc 0.9987562189054726
epoch 45 batch id 401 loss 1.4383127563633025e-05 train acc 0.998675187032419
epoch 45 batch id 601 loss 1.7003718312480487e-05 train acc 0.998778078202995
epoch 45 batch id 801 loss 2.6926460122922435e-05 train acc 0.9988295880149812
epoch 45 batch id 1001 loss 1.404215618094895e-05 train acc 0.9988605144855145
epoch 45 batch id 1201 loss 2.0479114027693868e-05 train acc 0.9988941507077436
epoch 45 batch id 1401 loss 1.5936446288833395e-05 train acc 0.9988958779443254
epoch 45 batch id 1601 loss 1.4630816622229759e-05 train acc 0.9989166926920675
epoch 45 batch id 1801 loss 1.8696826373343356e-05 train acc 0.9989415602443087
epoch 45 batch id 2001 loss 1.882302967715077e-05 train acc 0.9989458395802099
epoch 45 batch id 2201 loss 0.006939645856618881 train acc 0.9989209450249886
epoch 45 batch id 2401 loss 3.8194852095330134e-05 train acc 0.99894

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 45 test acc 0.8804098834272147


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 46 batch id 1 loss 1.6789579603937455e-05 train acc 1.0
epoch 46 batch id 201 loss 1.907684054458514e-05 train acc 0.9989116915422885
epoch 46 batch id 401 loss 1.3666010090673808e-05 train acc 0.9989479426433915
epoch 46 batch id 601 loss 1.6688973119016737e-05 train acc 0.9989860648918469
epoch 46 batch id 801 loss 1.6547412087675184e-05 train acc 0.999005149812734
epoch 46 batch id 1001 loss 1.3699510418518912e-05 train acc 0.9989853896103896
epoch 46 batch id 1201 loss 2.2010146494721994e-05 train acc 0.999011240632806
epoch 46 batch id 1401 loss 1.7545695300213993e-05 train acc 0.9990854746609564
epoch 46 batch id 1601 loss 1.42116732604336e-05 train acc 0.9990630855715178
epoch 46 batch id 1801 loss 2.348833550058771e-05 train acc 0.9990890477512493
epoch 46 batch id 2001 loss 1.7115318769356236e-05 train acc 0.9990785857071465
epoch 46 batch id 2201 loss 0.009251575917005539 train acc 0.9990416288050886
epoch 46 batch id 2401 loss 2.3794114895281382e-05 train acc 0.9990628

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 46 test acc 0.8804327748869636


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 47 batch id 1 loss 1.9775268810917623e-05 train acc 1.0
epoch 47 batch id 201 loss 1.6674070138833486e-05 train acc 0.9993781094527363
epoch 47 batch id 401 loss 1.5705469195381738e-05 train acc 0.9990258728179551
epoch 47 batch id 601 loss 2.201940878876485e-05 train acc 0.998778078202995
epoch 47 batch id 801 loss 1.3379174561123364e-05 train acc 0.9988295880149812
epoch 47 batch id 1001 loss 1.2447865628928412e-05 train acc 0.9988605144855145
epoch 47 batch id 1201 loss 2.4243783627753146e-05 train acc 0.9988681307243963
epoch 47 batch id 1401 loss 2.0881172531517223e-05 train acc 0.9989181834403997
epoch 47 batch id 1601 loss 1.6348009012290277e-05 train acc 0.9989557307932542
epoch 47 batch id 1801 loss 2.1212897991063073e-05 train acc 0.9989849389228207
epoch 47 batch id 2001 loss 1.4722133528266568e-05 train acc 0.9990083083458271
epoch 47 batch id 2201 loss 0.008896791376173496 train acc 0.9990132326215356
epoch 47 batch id 2401 loss 2.3423437596648e-05 train acc 0.999062

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 47 test acc 0.8806253788931269


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 48 batch id 1 loss 1.3634376045956742e-05 train acc 1.0
epoch 48 batch id 201 loss 1.9294486264698207e-05 train acc 0.9990671641791045
epoch 48 batch id 401 loss 2.9281391107360832e-05 train acc 0.9987141521197007
epoch 48 batch id 601 loss 1.8529082808527164e-05 train acc 0.9988560732113144
epoch 48 batch id 801 loss 1.6489675545017235e-05 train acc 0.9988100811485643
epoch 48 batch id 1001 loss 1.4684765119454823e-05 train acc 0.9989385614385614
epoch 48 batch id 1201 loss 1.792017610569019e-05 train acc 0.999011240632806
epoch 48 batch id 1401 loss 1.8434006051393226e-05 train acc 0.9990854746609564
epoch 48 batch id 1601 loss 1.7804306480684318e-05 train acc 0.9990435665209244
epoch 48 batch id 1801 loss 1.8331740648136474e-05 train acc 0.9990630205441421
epoch 48 batch id 2001 loss 1.7102438505389728e-05 train acc 0.9991020114942529
epoch 48 batch id 2201 loss 0.006751569453626871 train acc 0.999084223080418
epoch 48 batch id 2401 loss 1.8702268789638765e-05 train acc 0.9990

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 48 test acc 0.8803846238854227


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 49 batch id 1 loss 3.4599062928464264e-05 train acc 1.0
epoch 49 batch id 201 loss 1.9808763681794517e-05 train acc 0.9988339552238806
epoch 49 batch id 401 loss 1.3986362318973988e-05 train acc 0.9987531172069826
epoch 49 batch id 601 loss 1.583570883667562e-05 train acc 0.9988040765391015
epoch 49 batch id 801 loss 2.832842801581137e-05 train acc 0.9988686017478152
epoch 49 batch id 1001 loss 1.602195879968349e-05 train acc 0.9989073426573427
epoch 49 batch id 1201 loss 1.9231254555052146e-05 train acc 0.9988941507077436
epoch 49 batch id 1401 loss 1.6761381630203687e-05 train acc 0.9989181834403997
epoch 49 batch id 1601 loss 1.5858107872190885e-05 train acc 0.9989459712679575
epoch 49 batch id 1801 loss 1.7487847799202427e-05 train acc 0.9989762631871183
epoch 49 batch id 2001 loss 1.3518884770746808e-05 train acc 0.9989848825587206
epoch 49 batch id 2201 loss 0.009238813072443008 train acc 0.998991935483871
epoch 49 batch id 2401 loss 1.9674622308230028e-05 train acc 0.99903

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 49 test acc 0.8803364728838818


  0%|          | 0/2596 [00:00<?, ?it/s]

epoch 50 batch id 1 loss 1.7689015294308774e-05 train acc 1.0
epoch 50 batch id 201 loss 1.926300501509104e-05 train acc 0.9989116915422885
epoch 50 batch id 401 loss 2.396802301518619e-05 train acc 0.9987920822942643
epoch 50 batch id 601 loss 1.570741915202234e-05 train acc 0.9989080698835274
epoch 50 batch id 801 loss 1.595698995515704e-05 train acc 0.999083177278402
epoch 50 batch id 1001 loss 0.0006871878285892308 train acc 0.9991258741258742
epoch 50 batch id 1201 loss 2.458513881720137e-05 train acc 0.99920639050791
epoch 50 batch id 1401 loss 1.8553035260993056e-05 train acc 0.9992193076374019
epoch 50 batch id 1601 loss 1.8247765183332376e-05 train acc 0.9991704403497814
epoch 50 batch id 1801 loss 2.0086275981157087e-05 train acc 0.9991844808439756
epoch 50 batch id 2001 loss 1.832777343224734e-05 train acc 0.9991957146426786
epoch 50 batch id 2201 loss 0.00597362220287323 train acc 0.9991907087687415
epoch 50 batch id 2401 loss 2.4723243768676184e-05 train acc 0.999186536859

  0%|          | 0/649 [00:00<?, ?it/s]

epoch 50 test acc 0.8803123973831115


In [41]:
PATH = './models/'
torch.save(model.state_dict(), PATH + 'kobert_new_nlp_epoch_50.pt')

# 5.Test

In [42]:
# predict : 학습 모델을 활용하여 다중 분류된 클래스를 출력해주는 함수
# 코드 출처 : https://hoit1302.tistory.com/159

def predict(predict_sentence): # input = 감정분류하고자 하는 sentence

    data = [predict_sentence, '0']
    dataset_another = [data]

    another_test = BERTDataset(dataset_another, 0, 1, tok, vocab, max_len, True, False) # 토큰화한 문장
    test_dataloader = torch.utils.data.DataLoader(another_test, batch_size = batch_size, num_workers = 5) # torch 형식 변환
    
    model.eval() 

    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)

        valid_length = valid_length
        label = label.long().to(device)

        out = model(token_ids, valid_length, segment_ids)

        # print('out = ', out)

        logit = out[0]
        softmax_logit = torch.softmax(logit, dim=-1)
        softmax_logit = softmax_logit.squeeze()

        max_index = torch.argmax(softmax_logit).item()
        max_index_value = softmax_logit[torch.argmax(softmax_logit)].item()

        # print( 'max_index_value = ', max_index_value )
        
        test_eval = []
        for i in out: # out = model(token_ids, valid_length, segment_ids)
            # print('i = ', i)
            logits = i
            logits = logits.detach().cpu().numpy()

            if np.argmax(logits) == 0:
                test_eval.append("공포가")
            elif np.argmax(logits) == 1:
                test_eval.append("놀람이")
            elif np.argmax(logits) == 2:
                test_eval.append("분노가")
            elif np.argmax(logits) == 3:
                test_eval.append("슬픔이")
            elif np.argmax(logits) == 4:
                test_eval.append("중립이")
            elif np.argmax(logits) == 5:
                test_eval.append("행복이")
            elif np.argmax(logits) == 6:
                test_eval.append("혐오가")

        print(">> 입력하신 내용에서 " + test_eval[0] + "(" + str(max_index_value) + ") 느껴집니다.")

In [43]:
# 질문에 0 입력 시 종료
end = 1
while end == 1 :
    sentence = input("하고싶은 말을 입력해주세요 : ")
    if sentence == "0" :
        break
    predict(sentence)
    print("\n")

하고싶은 말을 입력해주세요 :  안녕!


>> 입력하신 내용에서 행복이(0.9999961853027344) 느껴집니다.




하고싶은 말을 입력해주세요 :  앗 이거 너무 우낀데...


>> 입력하신 내용에서 행복이(0.999981164932251) 느껴집니다.




하고싶은 말을 입력해주세요 :  앗 ㅋㅋ 너무 웃겨


>> 입력하신 내용에서 행복이(0.9999963045120239) 느껴집니다.




하고싶은 말을 입력해주세요 :  배고프다


>> 입력하신 내용에서 슬픔이(0.9998705387115479) 느껴집니다.




하고싶은 말을 입력해주세요 :  배고파


>> 입력하신 내용에서 혐오가(0.9985101819038391) 느껴집니다.




하고싶은 말을 입력해주세요 :  헐


>> 입력하신 내용에서 행복이(0.9999955892562866) 느껴집니다.




하고싶은 말을 입력해주세요 :  생각보다 성능이 좋지 못함


>> 입력하신 내용에서 공포가(0.9998078942298889) 느껴집니다.




하고싶은 말을 입력해주세요 :  오늘은 어떤 커피를 마시면 좋을까 ?


>> 입력하신 내용에서 분노가(0.9996351003646851) 느껴집니다.




하고싶은 말을 입력해주세요 :  오늘 날씨가 좋다


>> 입력하신 내용에서 공포가(0.9534702301025391) 느껴집니다.




하고싶은 말을 입력해주세요 :  날씨가 좋다


>> 입력하신 내용에서 공포가(0.9824321866035461) 느껴집니다.




하고싶은 말을 입력해주세요 :  좋다


>> 입력하신 내용에서 행복이(0.9998262524604797) 느껴집니다.




하고싶은 말을 입력해주세요 :  날씨가


>> 입력하신 내용에서 중립이(0.9960732460021973) 느껴집니다.




하고싶은 말을 입력해주세요 :  온르


>> 입력하신 내용에서 슬픔이(0.9999898672103882) 느껴집니다.




하고싶은 말을 입력해주세요 :  오늘


>> 입력하신 내용에서 분노가(0.4211781322956085) 느껴집니다.




하고싶은 말을 입력해주세요 :  커피는 어떤것을 마시는게 좋을까?


>> 입력하신 내용에서 분노가(0.840933084487915) 느껴집니다.




하고싶은 말을 입력해주세요 :  오늘은 학교가는 날이야


>> 입력하신 내용에서 공포가(0.9997609257698059) 느껴집니다.




하고싶은 말을 입력해주세요 :  학교가는 날이군


>> 입력하신 내용에서 공포가(0.6179559230804443) 느껴집니다.




하고싶은 말을 입력해주세요 :  학교


>> 입력하신 내용에서 중립이(0.6105743050575256) 느껴집니다.




하고싶은 말을 입력해주세요 :  날이군


>> 입력하신 내용에서 슬픔이(0.9999918937683105) 느껴집니다.




KeyboardInterrupt: Interrupted by user