In [7]:
# Environment setup: install the third-party packages used by this notebook.
# NOTE(review): transformers is pinned to 2.1.1, presumably because a later cell
# imports WarmupLinearSchedule from transformers.optimization -- confirm before bumping.
# NOTE(review): the `kobert` package imported further down is not installed by
# this cell; it must already be present in the environment.
!pip install mxnet
!pip install gluonnlp pandas tqdm
!pip install sentencepiece==0.1.85
!pip install transformers==2.1.1

Collecting transformers==2.1.1
  Downloading https://files.pythonhosted.org/packages/fd/f9/51824e40f0a23a49eab4fcaa45c1c797cbf9761adedd0b558dab7c958b34/transformers-2.1.1-py3-none-any.whl (311kB)
Installing collected packages: transformers
  Found existing installation: transformers 2.3.0
    Uninstalling transformers-2.3.0:
      Successfully uninstalled transformers-2.3.0
Successfully installed transformers-2.1.1


In [1]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the validation loss and the model.
    When the loss improves, the model's ``state_dict`` is saved to ``path``;
    otherwise a counter is incremented and ``early_stop`` becomes True once
    the counter reaches ``patience``.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint2.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): File the best model's state_dict is written to.
                            Default: 'checkpoint2.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lower-case np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            val_loss (float): Validation loss for the epoch just finished.
            model: Object exposing ``state_dict()``; saved when the loss improves.
        """
        # Scores are negated losses, so a larger score is better.
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we got is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not an improvement of at least `delta`.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import gluonnlp as nlp
import numpy as np
from tqdm import tqdm, tqdm_notebook

In [3]:
from kobert.utils import get_tokenizer
from kobert.pytorch_kobert import get_pytorch_kobert_model

In [4]:
from transformers import AdamW
from transformers.optimization import WarmupLinearSchedule

In [5]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# Obtain the KoBERT encoder and its vocabulary from the kobert helper
# (presumably pretrained weights; the cell output reports a cached copy being used).
bertmodel, vocab = get_pytorch_kobert_model()

using cached model
using cached model


In [7]:
# Load the train/test splits from tab-separated files. The directory name is
# Korean for "old version".
# num_discard_samples=1 presumably skips a header row -- TODO confirm against the files.
dataset_train = nlp.data.TSVDataset("구 버전/train_3.txt", num_discard_samples=1)
dataset_test = nlp.data.TSVDataset("구 버전/test_3.txt", num_discard_samples=1)

In [8]:
# Build the SentencePiece-based BERT tokenizer over the KoBERT vocabulary;
# lower=False is passed so the tokenizer is not asked to lowercase input.
tokenizer = get_tokenizer()
tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower=False)

using cached model


In [9]:
class BERTDataset(Dataset):
    """Torch dataset that pre-encodes every sentence of a row-based dataset.

    Each row of ``dataset`` is indexed with ``sent_idx`` to get the sentence and
    with ``label_idx`` to get the label. Sentences are encoded once, up front,
    with gluonnlp's ``BERTSentenceTransform``; labels are stored as ``np.int32``.
    """

    def __init__(self, dataset, sent_idx, label_idx, bert_tokenizer, max_len,
                 pad, pair):
        to_features = nlp.data.BERTSentenceTransform(
            bert_tokenizer, max_seq_length=max_len, pad=pad, pair=pair)

        # First pass: encode the sentence column.
        encoded = []
        for row in dataset:
            encoded.append(to_features([row[sent_idx]]))
        self.sentences = encoded

        # Second pass: convert the label column.
        targets = []
        for row in dataset:
            targets.append(np.int32(row[label_idx]))
        self.labels = targets

    def __getitem__(self, i):
        """Return the encoded sentence tuple for item ``i`` with its label appended."""
        features = self.sentences[i]
        target = self.labels[i]
        return features + (target,)

    def __len__(self):
        """Number of samples (one per stored label)."""
        return len(self.labels)


In [10]:
## Hyperparameters for fine-tuning
# Author's note, presumably (max_len, batch_size) combinations that were tried: 256/16 and 128/32
max_len = 256            # passed to BERTDataset as the maximum sequence length
batch_size = 16          # batch size for both train and test DataLoaders
warmup_ratio = 0.1       # fraction of total optimizer steps used for LR warmup
num_epochs = 10          # upper bound on epochs (early stopping may end sooner)
max_grad_norm = 1        # gradient-norm clipping threshold used in the training loop
log_interval = 200       # print training progress every this many batches
learning_rate =  5e-5    # learning rate given to AdamW

In [11]:
# Encode both splits: column 0 is the sentence, column 1 the label;
# pad=True and pair=False (single-sentence input) are passed through to the transform.
data_train = BERTDataset(dataset_train, 0, 1, tok, max_len, True, False)
data_test = BERTDataset(dataset_test, 0, 1, tok, max_len, True, False)

In [12]:
# Shuffle the training split every epoch so batches are not drawn in file order
# (which biases SGD if the file is sorted, e.g. by label). The test split keeps
# its original order so evaluation is deterministic.
train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, num_workers=0, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=batch_size, num_workers=0)

In [13]:
class BERTClassifier(nn.Module):
    """Linear classification head on top of a BERT encoder.

    The wrapped ``bert`` module is called with token ids, segment ids and an
    attention mask, and is expected to return a 2-tuple whose second element is
    the pooled sentence vector. That vector optionally goes through dropout and
    then through one linear layer that yields ``num_classes`` logits.
    """
    def __init__(self,
                 bert,
                 hidden_size = 768,
                 num_classes=2,
                 dr_rate=None,
                 params=None):
        """
        Args:
            bert: Encoder module returning ``(sequence_output, pooled_output)``.
            hidden_size (int): Width of the pooled vector fed to the linear layer.
            num_classes (int): Number of output logits.
            dr_rate (float or None): Dropout probability; a falsy value disables dropout.
            params: Unused; kept so existing callers keep working.
        """
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        self.classifier = nn.Linear(hidden_size , num_classes)
        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def gen_attention_mask(self, token_ids, valid_length):
        """Build a float mask with 1.0 on the first ``valid_length[i]`` positions of row i.

        Args:
            token_ids: 2-D tensor (rows x sequence positions); only its shape and
                device are used.
            valid_length: Per-row count of valid tokens (tensor or sequence).

        Returns:
            Float tensor shaped like ``token_ids`` on the same device.
        """
        # One broadcasted comparison replaces a per-row Python loop.
        positions = torch.arange(token_ids.size(1), device=token_ids.device)
        lengths = torch.as_tensor(valid_length, dtype=torch.long, device=token_ids.device)
        return (positions.unsqueeze(0) < lengths.reshape(-1, 1)).float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return class logits for a batch.

        Args:
            token_ids: Token id tensor.
            valid_length: Number of valid tokens per row.
            segment_ids: Segment (token type) id tensor.
        """
        attention_mask = self.gen_attention_mask(token_ids, valid_length)

        _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device))
        # Without dropout the pooled vector goes straight to the classifier;
        # previously `out` was left unassigned in that case.
        out = self.dropout(pooler) if self.dr_rate else pooler
        return self.classifier(out)

In [14]:
# Wrap the KoBERT encoder in the classification head (dropout p=0.5) and move it to `device`.
model = BERTClassifier(bertmodel, dr_rate=0.5).to(device)

In [15]:
# Split parameters into two optimizer groups: names containing 'bias' or
# 'LayerNorm.weight' get no weight decay, everything else decays at 0.01.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {
        'params': [
            param
            for name, param in model.named_parameters()
            if all(marker not in name for marker in no_decay)
        ],
        'weight_decay': 0.01,
    },
    {
        'params': [
            param
            for name, param in model.named_parameters()
            if any(marker in name for marker in no_decay)
        ],
        'weight_decay': 0.0,
    },
]

In [16]:
# AdamW over the two parameter groups (with / without weight decay);
# cross-entropy over the classifier logits is the training objective.
optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [17]:
# Total number of optimizer steps = batches per epoch x epochs;
# the first `warmup_ratio` share of them is used for warmup.
t_total = len(train_dataloader) * num_epochs
warmup_step = int(t_total * warmup_ratio)

In [18]:
# Learning-rate schedule: warmup for `warmup_step` steps, planned over `t_total` steps.
# NOTE(review): the schedule is created once; training cells that are re-run keep
# stepping this same object rather than starting a fresh schedule.
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_step, t_total=t_total)

In [19]:
def calc_accuracy(X,Y):
    """Return the fraction of rows of X whose highest-scoring column equals the label in Y.

    Args:
        X: 2-D tensor of per-class scores (one row per sample).
        Y: 1-D tensor of integer class labels.

    Returns:
        NumPy scalar: number of correct predictions divided by the number of rows.
    """
    _, predicted = torch.max(X, 1)
    n_correct = (predicted == Y).sum().data.cpu().numpy()
    n_rows = predicted.size()[0]
    return n_correct / n_rows

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the validation loss and the model.
    When the loss improves, the model's ``state_dict`` is saved to ``path``;
    otherwise a counter is incremented and ``early_stop`` becomes True once
    the counter reaches ``patience``.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): File the best model's state_dict is written to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lower-case np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            val_loss (float): Validation loss for the epoch just finished.
            model: Object exposing ``state_dict()``; saved when the loss improves.
        """
        # Scores are negated losses, so a larger score is better.
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we got is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not an improvement of at least `delta`.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# Fine-tune with early stopping on validation loss, then restore the best weights.
# NOTE(review): `optimizer` and `scheduler` come from earlier cells; re-running
# this cell continues from their current state instead of a fresh LR schedule.
early_stopping = EarlyStopping(patience=5, verbose=True)
for e in range(num_epochs):
    value_losses = 0
    train_acc = 0.0
    test_acc = 0.0

    # ---- training pass ----
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.item(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    # ---- evaluation pass ----
    model.eval()
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time without changing the computed loss/accuracy.
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
            loss = loss_fn(out, label)
            value_losses += loss.item()

    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    # Mean per-batch loss on the test split is the early-stopping signal.
    valid_loss = value_losses / len(test_dataloader)
    print('validation loss : {}'.format(valid_loss))
    early_stopping(valid_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break
# Restore the best weights saved by early stopping.
model.load_state_dict(torch.load('checkpoint.pt'))

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the validation loss and the model.
    When the loss improves, the model's ``state_dict`` is saved to ``path``;
    otherwise a counter is incremented and ``early_stop`` becomes True once
    the counter reaches ``patience``.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint1.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): File the best model's state_dict is written to.
                            Default: 'checkpoint1.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lower-case np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            val_loss (float): Validation loss for the epoch just finished.
            model: Object exposing ``state_dict()``; saved when the loss improves.
        """
        # Scores are negated losses, so a larger score is better.
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we got is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not an improvement of at least `delta`.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [45]:
# Fine-tune with early stopping on validation loss, then restore the best weights.
# NOTE(review): `optimizer` and `scheduler` come from earlier cells; re-running
# this cell continues from their current state instead of a fresh LR schedule.
early_stopping = EarlyStopping(patience=5, verbose=True)
for e in range(num_epochs):
    value_losses = 0
    train_acc = 0.0
    test_acc = 0.0

    # ---- training pass ----
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.item(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    # ---- evaluation pass ----
    model.eval()
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time without changing the computed loss/accuracy.
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
            loss = loss_fn(out, label)
            value_losses += loss.item()

    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    # Mean per-batch loss on the test split is the early-stopping signal.
    valid_loss = value_losses / len(test_dataloader)
    print('validation loss : {}'.format(valid_loss))
    early_stopping(valid_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break
# Restore the best weights saved by early stopping.
model.load_state_dict(torch.load('checkpoint1.pt'))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 1 batch id 1 loss 0.683671236038208 train acc 0.625
epoch 1 batch id 201 loss 0.5163155198097229 train acc 0.6147388059701493
epoch 1 batch id 401 loss 0.6582837700843811 train acc 0.6472880299251871
epoch 1 batch id 601 loss 0.6173977851867676 train acc 0.6549500831946755
epoch 1 batch id 801 loss 0.7966580390930176 train acc 0.6586298377028714
epoch 1 batch id 1001 loss 0.6101324558258057 train acc 0.6611513486513486
epoch 1 batch id 1201 loss 0.593669056892395 train acc 0.6638218151540383
epoch 1 batch id 1401 loss 0.637549638748169 train acc 0.666131334760885

epoch 1 train acc 0.6662520729684909


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 1 test acc 0.6734330484330484
validation loss : 0.6317418986066454
Validation loss decreased (inf --> 0.631742).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 2 batch id 1 loss 0.5770298838615417 train acc 0.8125
epoch 2 batch id 201 loss 0.5602574944496155 train acc 0.6570273631840796
epoch 2 batch id 401 loss 0.6311184167861938 train acc 0.6687967581047382
epoch 2 batch id 601 loss 0.6350998282432556 train acc 0.6696131447587355
epoch 2 batch id 801 loss 0.7522932291030884 train acc 0.6692415730337079
epoch 2 batch id 1001 loss 0.5602304339408875 train acc 0.6695804195804196
epoch 2 batch id 1201 loss 0.6017241477966309 train acc 0.670951290591174
epoch 2 batch id 1401 loss 0.6549298167228699 train acc 0.6722876516773733

epoch 2 train acc 0.6723821369343758


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 2 test acc 0.6734330484330484
validation loss : 0.6318848115256709
EarlyStopping counter: 1 out of 5


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 3 batch id 1 loss 0.5519777536392212 train acc 0.8125
epoch 3 batch id 201 loss 0.5481972694396973 train acc 0.6551616915422885
epoch 3 batch id 401 loss 0.620429277420044 train acc 0.6678615960099751
epoch 3 batch id 601 loss 0.6096762418746948 train acc 0.6691971713810316
epoch 3 batch id 801 loss 0.780620276927948 train acc 0.6692415730337079
epoch 3 batch id 1001 loss 0.5656611323356628 train acc 0.6695804195804196
epoch 3 batch id 1201 loss 0.5356923937797546 train acc 0.6710033305578684
epoch 3 batch id 1401 loss 0.657416582107544 train acc 0.6723322626695217

epoch 3 train acc 0.6724265576877517


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 3 test acc 0.6734330484330484
validation loss : 0.6317371024705066
Validation loss decreased (0.631742 --> 0.631737).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 4 batch id 1 loss 0.5506930947303772 train acc 0.8125
epoch 4 batch id 201 loss 0.5487828254699707 train acc 0.6570273631840796
epoch 4 batch id 401 loss 0.6064774990081787 train acc 0.6687967581047382
epoch 4 batch id 601 loss 0.648193359375 train acc 0.6698211314475874
epoch 4 batch id 801 loss 0.7299001812934875 train acc 0.6697097378277154
epoch 4 batch id 1001 loss 0.5925991535186768 train acc 0.669955044955045
epoch 4 batch id 1201 loss 0.5565282702445984 train acc 0.671315570358035
epoch 4 batch id 1401 loss 0.6661913990974426 train acc 0.6725999286224126

epoch 4 train acc 0.6726930822080076


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 4 test acc 0.6734330484330484
validation loss : 0.6317328645972444
Validation loss decreased (0.631737 --> 0.631733).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 5 batch id 1 loss 0.5457884073257446 train acc 0.8125
epoch 5 batch id 201 loss 0.49243974685668945 train acc 0.6570273631840796
epoch 5 batch id 401 loss 0.6185859441757202 train acc 0.6687967581047382
epoch 5 batch id 601 loss 0.6208499073982239 train acc 0.6698211314475874
epoch 5 batch id 801 loss 0.7933619618415833 train acc 0.6697097378277154
epoch 5 batch id 1001 loss 0.5840980410575867 train acc 0.669955044955045
epoch 5 batch id 1201 loss 0.5373005270957947 train acc 0.671315570358035
epoch 5 batch id 1401 loss 0.6582377552986145 train acc 0.6725999286224126

epoch 5 train acc 0.6726930822080076


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 5 test acc 0.6734330484330484
validation loss : 0.6317239150702104
Validation loss decreased (0.631733 --> 0.631724).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 6 batch id 1 loss 0.531432032585144 train acc 0.8125
epoch 6 batch id 201 loss 0.5527076721191406 train acc 0.6573383084577115
epoch 6 batch id 401 loss 0.6178285479545593 train acc 0.6689526184538653
epoch 6 batch id 601 loss 0.6276085376739502 train acc 0.6699251247920133
epoch 6 batch id 801 loss 0.7677193284034729 train acc 0.6697877652933832
epoch 6 batch id 1001 loss 0.5489223599433899 train acc 0.6700174825174825
epoch 6 batch id 1201 loss 0.5465782284736633 train acc 0.6713676103247294
epoch 6 batch id 1401 loss 0.6772615313529968 train acc 0.672644539614561

epoch 6 train acc 0.6727375029613836


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 6 test acc 0.6734330484330484
validation loss : 0.6319347530009061
EarlyStopping counter: 1 out of 5


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 7 batch id 1 loss 0.5112602114677429 train acc 0.8125
epoch 7 batch id 201 loss 0.5401614904403687 train acc 0.6573383084577115
epoch 7 batch id 401 loss 0.6089612245559692 train acc 0.6689526184538653
epoch 7 batch id 601 loss 0.6424015760421753 train acc 0.6699251247920133
epoch 7 batch id 801 loss 0.7502754330635071 train acc 0.6697877652933832
epoch 7 batch id 1001 loss 0.5633200407028198 train acc 0.6700174825174825
epoch 7 batch id 1201 loss 0.5295676589012146 train acc 0.6713676103247294
epoch 7 batch id 1401 loss 0.6777652502059937 train acc 0.672644539614561

epoch 7 train acc 0.6727375029613836


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 7 test acc 0.6734330484330484
validation loss : 0.6321480563732973
EarlyStopping counter: 2 out of 5


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 8 batch id 1 loss 0.5268442034721375 train acc 0.8125
epoch 8 batch id 201 loss 0.527959406375885 train acc 0.6573383084577115
epoch 8 batch id 401 loss 0.6615091562271118 train acc 0.6689526184538653
epoch 8 batch id 601 loss 0.6205856800079346 train acc 0.6699251247920133
epoch 8 batch id 801 loss 0.7765102386474609 train acc 0.6697877652933832
epoch 8 batch id 1001 loss 0.5660990476608276 train acc 0.6700174825174825
epoch 8 batch id 1201 loss 0.5279337167739868 train acc 0.6713676103247294
epoch 8 batch id 1401 loss 0.6735869646072388 train acc 0.672644539614561

epoch 8 train acc 0.6727375029613836


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 8 test acc 0.6734330484330484
validation loss : 0.6322710125194995
EarlyStopping counter: 3 out of 5


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 9 batch id 1 loss 0.5506366491317749 train acc 0.8125
epoch 9 batch id 201 loss 0.5324851870536804 train acc 0.6573383084577115
epoch 9 batch id 401 loss 0.6287277340888977 train acc 0.6689526184538653
epoch 9 batch id 601 loss 0.6294515132904053 train acc 0.6699251247920133
epoch 9 batch id 801 loss 0.7513949871063232 train acc 0.6697877652933832
epoch 9 batch id 1001 loss 0.5805321931838989 train acc 0.6700174825174825
epoch 9 batch id 1201 loss 0.531319797039032 train acc 0.6713676103247294
epoch 9 batch id 1401 loss 0.6738660335540771 train acc 0.672644539614561

epoch 9 train acc 0.6727375029613836


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 9 test acc 0.6734330484330484
validation loss : 0.6326436622020526
EarlyStopping counter: 4 out of 5


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 10 batch id 1 loss 0.5221284627914429 train acc 0.8125
epoch 10 batch id 201 loss 0.5398997664451599 train acc 0.6573383084577115
epoch 10 batch id 401 loss 0.6392308473587036 train acc 0.6689526184538653
epoch 10 batch id 601 loss 0.6347752809524536 train acc 0.6699251247920133
epoch 10 batch id 801 loss 0.7530544400215149 train acc 0.6697877652933832
epoch 10 batch id 1001 loss 0.5624060034751892 train acc 0.6700174825174825
epoch 10 batch id 1201 loss 0.5238558053970337 train acc 0.6713676103247294
epoch 10 batch id 1401 loss 0.6670626997947693 train acc 0.672644539614561

epoch 10 train acc 0.6727375029613836


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 10 test acc 0.6734330484330484
validation loss : 0.632304512367629
EarlyStopping counter: 5 out of 5
Early stopping


<All keys matched successfully>

In [20]:
# When more than one CUDA device is visible, wrap the model so each batch is
# split across the GPUs.
# NOTE(review): state_dict keys of a DataParallel-wrapped model presumably gain a
# 'module.' prefix, so checkpoints written before this cell may not load into the
# wrapped model (and vice versa) -- confirm.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

Let's use 2 GPUs!


In [21]:
# Ask PyTorch to release cached GPU memory it is not currently using before the next training run.
torch.cuda.empty_cache()

In [None]:
# Fine-tune with early stopping on validation loss, then restore the best weights.
# NOTE(review): `optimizer` and `scheduler` come from earlier cells; re-running
# this cell continues from their current state instead of a fresh LR schedule.
# NOTE(review): the final line loads 'checkpoint2.pt', which is only written when
# the EarlyStopping definition currently in the kernel saves to that file name --
# confirm the matching definition was the last one executed.
early_stopping = EarlyStopping(patience=5, verbose=True)
for e in range(num_epochs):
    value_losses = 0
    train_acc = 0.0
    test_acc = 0.0

    # ---- training pass ----
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.item(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    # ---- evaluation pass ----
    model.eval()
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time without changing the computed loss/accuracy.
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
            loss = loss_fn(out, label)
            value_losses += loss.item()

    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    # Mean per-batch loss on the test split is the early-stopping signal.
    valid_loss = value_losses / len(test_dataloader)
    print('validation loss : {}'.format(valid_loss))
    early_stopping(valid_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break
# Restore the best weights saved by early stopping.
model.load_state_dict(torch.load('checkpoint2.pt'))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))



epoch 1 batch id 1 loss 0.7247582674026489 train acc 0.4375
epoch 1 batch id 201 loss 0.5726904273033142 train acc 0.6110074626865671
epoch 1 batch id 401 loss 0.681823194026947 train acc 0.6443266832917706
epoch 1 batch id 601 loss 0.6738764643669128 train acc 0.6520382695507487
epoch 1 batch id 801 loss 0.8680327534675598 train acc 0.6552746566791511
epoch 1 batch id 1001 loss 0.5611388087272644 train acc 0.6577797202797203
epoch 1 batch id 1201 loss 0.6194086670875549 train acc 0.6605432972522898
epoch 1 batch id 1401 loss 0.646765410900116 train acc 0.6629193433261956

epoch 1 train acc 0.6630537787254205


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 1 test acc 0.6734330484330484
validation loss : 0.6334708693020704
Validation loss decreased (inf --> 0.633471).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 2 batch id 1 loss 0.5729288458824158 train acc 0.8125
epoch 2 batch id 201 loss 0.5127244591712952 train acc 0.6545398009950248
epoch 2 batch id 401 loss 0.6507270336151123 train acc 0.6673940149625935
epoch 2 batch id 601 loss 0.6278367042541504 train acc 0.6690931780366056
epoch 2 batch id 801 loss 0.7625033855438232 train acc 0.6686953807740325
epoch 2 batch id 1001 loss 0.5872196555137634 train acc 0.6690184815184815
epoch 2 batch id 1201 loss 0.5866000056266785 train acc 0.6703788509575354
epoch 2 batch id 1401 loss 0.6693872213363647 train acc 0.6717969307637401

epoch 2 train acc 0.67189350864724


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 2 test acc 0.6734330484330484
validation loss : 0.6317200771084538
Validation loss decreased (0.633471 --> 0.631720).  Saving model ...


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 3 batch id 1 loss 0.5674735307693481 train acc 0.8125
epoch 3 batch id 201 loss 0.541002094745636 train acc 0.6573383084577115
epoch 3 batch id 401 loss 0.6204591989517212 train acc 0.6689526184538653
epoch 3 batch id 601 loss 0.6019932627677917 train acc 0.6691971713810316
epoch 3 batch id 801 loss 0.7365024089813232 train acc 0.6693976279650437
epoch 3 batch id 1001 loss 0.5384328365325928 train acc 0.6697052947052947
epoch 3 batch id 1201 loss 0.5401071310043335 train acc 0.6711594504579517
epoch 3 batch id 1401 loss 0.7434086799621582 train acc 0.6724214846538187

epoch 3 train acc 0.6725153991945036


HBox(children=(FloatProgress(value=0.0, max=351.0), HTML(value='')))


epoch 3 test acc 0.6734330484330484
validation loss : 0.6317925742718569
EarlyStopping counter: 1 out of 5


HBox(children=(FloatProgress(value=0.0, max=1407.0), HTML(value='')))

epoch 4 batch id 1 loss 0.5113425254821777 train acc 0.8125
epoch 4 batch id 201 loss 0.513543426990509 train acc 0.6548507462686567
epoch 4 batch id 401 loss 0.639189600944519 train acc 0.6677057356608479
epoch 4 batch id 601 loss 0.6095196604728699 train acc 0.6690931780366056
epoch 4 batch id 801 loss 0.7925604581832886 train acc 0.669085518102372
epoch 4 batch id 1001 loss 0.5642462372779846 train acc 0.6693931068931069


In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the validation loss and the model.
    When the loss improves, the model's ``state_dict`` is saved to ``path``;
    otherwise a counter is incremented and ``early_stop`` becomes True once
    the counter reaches ``patience``.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint3.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): File the best model's state_dict is written to.
                            Default: 'checkpoint3.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lower-case np.inf: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            val_loss (float): Validation loss for the epoch just finished.
            model: Object exposing ``state_dict()``; saved when the loss improves.
        """
        # Scores are negated losses, so a larger score is better.
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we got is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not an improvement of at least `delta`.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# Fine-tune with early stopping on validation loss, then restore the best weights.
# NOTE(review): `optimizer` and `scheduler` come from earlier cells; re-running
# this cell continues from their current state instead of a fresh LR schedule.
early_stopping = EarlyStopping(patience=5, verbose=True)
for e in range(num_epochs):
    value_losses = 0
    train_acc = 0.0
    test_acc = 0.0

    # ---- training pass ----
    model.train()
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(train_dataloader)):
        optimizer.zero_grad()
        token_ids = token_ids.long().to(device)
        segment_ids = segment_ids.long().to(device)
        label = label.long().to(device)
        out = model(token_ids, valid_length, segment_ids)
        loss = loss_fn(out, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()  # Update learning rate schedule
        train_acc += calc_accuracy(out, label)
        if batch_id % log_interval == 0:
            print("epoch {} batch id {} loss {} train acc {}".format(e+1, batch_id+1, loss.item(), train_acc / (batch_id+1)))
    print("epoch {} train acc {}".format(e+1, train_acc / (batch_id+1)))

    # ---- evaluation pass ----
    model.eval()
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time without changing the computed loss/accuracy.
    with torch.no_grad():
        for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(tqdm_notebook(test_dataloader)):
            token_ids = token_ids.long().to(device)
            segment_ids = segment_ids.long().to(device)
            label = label.long().to(device)
            out = model(token_ids, valid_length, segment_ids)
            test_acc += calc_accuracy(out, label)
            loss = loss_fn(out, label)
            value_losses += loss.item()

    print("epoch {} test acc {}".format(e+1, test_acc / (batch_id+1)))
    # Mean per-batch loss on the test split is the early-stopping signal.
    valid_loss = value_losses / len(test_dataloader)
    print('validation loss : {}'.format(valid_loss))
    early_stopping(valid_loss, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break
# Restore the best weights saved by early stopping.
model.load_state_dict(torch.load('checkpoint3.pt'))