In [1]:
import torch
import torch.nn.functional as F
from torchtext import data
from torchtext import datasets
import time
import random
from tqdm import tqdm

torch.backends.cudnn.deterministic = True

## General Settings

In [2]:
RANDOM_SEED = 123
torch.manual_seed(RANDOM_SEED)

VOCABULARY_SIZE = 20000
LEARNING_RATE = 1e-5
BATCH_SIZE = 16
NUM_EPOCHS = 4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
OUTPUT_DIM = 1

PATH = 'best_model.pth' # PATH to save and load model

MAX_LEN=64

cuda


## Dataset

In [3]:
TEXT = data.Field(tokenize='spacy',
                  include_lengths=True,
                 fix_length=MAX_LEN)
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
train_data, valid_data = train_data.split(random_state=random.seed(RANDOM_SEED),
                                          split_ratio=0.8)

print(f'Num Train: {len(train_data)}')
print(f'Num Valid: {len(valid_data)}')
print(f'Num Test: {len(test_data)}')

Num Train: 20000
Num Valid: 5000
Num Test: 25000


In [4]:
TEXT.build_vocab(train_data,
                 max_size=VOCABULARY_SIZE,
                 vectors='glove.6B.100d',
                 unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)

print(f'Vocabulary size: {len(TEXT.vocab)}')
print(f'Number of classes: {len(LABEL.vocab)}')

Vocabulary size: 20002
Number of classes: 2


In [5]:
train_loader, valid_loader, test_loader = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    device=DEVICE)

In [6]:
print('Train')
for batch in train_loader:
    print(f'Text matrix size: {batch.text[0].size()}')
    print(f'Target vector size: {batch.label.size()}')
    break

print('\nValid:')
for batch in valid_loader:
    print(f'Text matrix size: {batch.text[0].size()}')
    print(f'Target vector size: {batch.label.size()}')
    break

print('\nTest:')
for batch in test_loader:
    print(f'Text matrix size: {batch.text[0].size()}')
    print(f'Target vector size: {batch.label.size()}')
    break

Train
Text matrix size: torch.Size([64, 16])
Target vector size: torch.Size([16])

Valid:
Text matrix size: torch.Size([64, 16])
Target vector size: torch.Size([16])

Test:
Text matrix size: torch.Size([64, 16])
Target vector size: torch.Size([16])


## model

In [7]:
from transformers import BertForSequenceClassification, AdamW, BertConfig

# Load BertForSequenceClassification, the pretrained BERT model with a single
# linear classification layer on top.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
    num_labels = 2, # The number of output labels--2 for binary classification.
                    # You can increase this for multi-class tasks.
    output_attentions = False, # Whether the model returns attentions weights.
    output_hidden_states = False, # Whether the model returns all hidden-states.
#     mirror='tuna',
)

# Tell pytorch to run this model on the GPU.
model.cuda()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [8]:
optimizer = AdamW(model.parameters(),
                  lr = LEARNING_RATE, # args.learning_rate - default is 5e-5, our notebook had 2e-5
                  eps = 1e-8 # args.adam_epsilon  - default is 1e-8.
                )

# optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

from transformers import get_linear_schedule_with_warmup

# Number of training epochs (authors recommend between 2 and 4)
epochs = 2

# Total number of training steps is number of batches * number of epochs.
total_steps = len(train_loader) * NUM_EPOCHS

# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)

In [9]:
def train_val(model,optimizer,criterion,train_loader, valid_loader, epochs, path):
    since = time.time()
    min_val_loss = 1e10
    for epoch in range(epochs):
        train_loss=0.0
        val_loss=0.0
        val_acc = 0.0
        
        # trainning
        model.train()
        print('training...')
        for idx, batch in enumerate(tqdm(train_loader)):
            if idx%50== 0 and not idx == 0:
                print("batch:",idx,"loss:",loss.data.item())
            optimizer.zero_grad()
            inputs, labels=batch.text, batch.label
#             print(inputs[0].size())
            # print(inputs[1].size())
            inputs=inputs[0].transpose(0,1).to(DEVICE)
            labels=labels.to(torch.int64).to(DEVICE)
            outputs = model(inputs,token_type_ids=None,labels=labels)
#             print(labels)
            loss = outputs[0]
        
            train_loss += loss.data.item()
            
#             Before=list(model.parameters())[0].clone()
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            
#             After = list(model.parameters())[0].clone()
            
#             print("update:",torch.sum(After-Before))
            
            # print(batch.text[0].size(0))
        train_loss /= len(train_loader)
        # print(len(train_loader))
        print(f"epoch: {epoch+1}: train loss:{train_loss}")
        
        print("valing")
        model.eval()
        with torch.no_grad():
            for val_idx, val_batch in enumerate(tqdm(valid_loader)):
                val_inputs, val_labels=val_batch.text, val_batch.label
                val_inputs = val_inputs[0].transpose(0,1).to(DEVICE)
                val_labels = val_labels.to(torch.int64).to(DEVICE)
                val_outputs =  model(val_inputs,token_type_ids=None)[0]
                val_outputs_acc = torch.argmax(val_outputs,axis=1).view(-1)
                val_acc+=torch.sum(val_outputs_acc==val_labels)
                loss = criterion(val_outputs,val_labels)
                val_loss += loss.data.item()
        val_loss /= len(valid_loader)
        val_acc /= len(valid_data)
        print(f"epoch: {epoch+1}: val loss:{val_loss} val acc:{val_acc}")

#         if val_loss<min_val_loss:
#             min_val_loss = val_loss
#             torch.save(model.state_dict(),path)
#             print('model saved!')

        
    time_elapsed = time.time() - since
    print('\nTraining complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

In [10]:
import torch.nn as nn
criterion = nn.CrossEntropyLoss()
train_val(model,optimizer,criterion,train_loader, valid_loader, NUM_EPOCHS, PATH)


training...


  4%|███▏                                                                            | 50/1250 [01:38<11:49,  1.69it/s]

batch: 50 loss: 0.6383523941040039


  8%|██████▍                                                                        | 101/1250 [02:09<11:33,  1.66it/s]

batch: 100 loss: 0.6873492002487183


 12%|█████████▍                                                                     | 150/1250 [02:38<11:03,  1.66it/s]

batch: 150 loss: 0.7341209650039673


 16%|████████████▋                                                                  | 200/1250 [03:08<10:30,  1.67it/s]

batch: 200 loss: 0.6647279858589172


 20%|███████████████▊                                                               | 251/1250 [03:39<10:02,  1.66it/s]

batch: 250 loss: 0.7270231246948242


 24%|██████████████████▉                                                            | 300/1250 [04:09<09:34,  1.65it/s]

batch: 300 loss: 0.6823735237121582


 28%|██████████████████████                                                         | 350/1250 [04:39<08:59,  1.67it/s]

batch: 350 loss: 0.7665712237358093


 32%|█████████████████████████▎                                                     | 400/1250 [05:09<08:31,  1.66it/s]

batch: 400 loss: 0.7034977078437805


 36%|████████████████████████████▍                                                  | 450/1250 [05:39<08:02,  1.66it/s]

batch: 450 loss: 0.6697101593017578


 40%|███████████████████████████████▌                                               | 500/1250 [06:09<07:33,  1.65it/s]

batch: 500 loss: 0.691288411617279


 44%|██████████████████████████████████▊                                            | 550/1250 [06:40<07:02,  1.66it/s]

batch: 550 loss: 0.6686034798622131


 48%|█████████████████████████████████████▉                                         | 600/1250 [07:10<06:32,  1.66it/s]

batch: 600 loss: 0.6973676085472107


 52%|█████████████████████████████████████████                                      | 650/1250 [07:40<06:03,  1.65it/s]

batch: 650 loss: 0.7103231549263


 56%|████████████████████████████████████████████▏                                  | 700/1250 [08:10<05:29,  1.67it/s]

batch: 700 loss: 0.6918174624443054


 60%|███████████████████████████████████████████████▍                               | 750/1250 [08:40<05:03,  1.65it/s]

batch: 750 loss: 0.6254110336303711


 64%|██████████████████████████████████████████████████▌                            | 801/1250 [09:11<04:30,  1.66it/s]

batch: 800 loss: 0.6177152395248413


 68%|█████████████████████████████████████████████████████▋                         | 850/1250 [09:41<04:02,  1.65it/s]

batch: 850 loss: 0.6991162300109863


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [10:11<03:31,  1.66it/s]

batch: 900 loss: 0.6769292950630188


 76%|████████████████████████████████████████████████████████████                   | 950/1250 [10:41<03:00,  1.66it/s]

batch: 950 loss: 0.8073201179504395


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [11:11<02:31,  1.65it/s]

batch: 1000 loss: 0.6525235176086426


 84%|█████████████████████████████████████████████████████████████████▌            | 1050/1250 [11:42<02:01,  1.65it/s]

batch: 1050 loss: 0.6871134638786316


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [12:12<01:30,  1.67it/s]

batch: 1100 loss: 0.6411629319190979


 92%|███████████████████████████████████████████████████████████████████████▊      | 1150/1250 [12:42<01:00,  1.67it/s]

batch: 1150 loss: 0.6024967432022095


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [13:13<00:30,  1.66it/s]

batch: 1200 loss: 0.6682289242744446


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [13:43<00:00,  1.52it/s]


epoch: 1: train loss:0.6953458051204682
valing


100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:52<00:00,  5.91it/s]


epoch: 1: val loss:0.6749234770814451 val acc:0.5853999853134155
training...


  4%|███▎                                                                            | 51/1250 [00:30<12:05,  1.65it/s]

batch: 50 loss: 0.7065174579620361


  8%|██████▎                                                                        | 100/1250 [00:59<11:28,  1.67it/s]

batch: 100 loss: 0.6664404273033142


 12%|█████████▍                                                                     | 150/1250 [01:30<11:07,  1.65it/s]

batch: 150 loss: 0.6892347931861877


 16%|████████████▋                                                                  | 201/1250 [02:00<10:30,  1.66it/s]

batch: 200 loss: 0.673011839389801


 20%|███████████████▊                                                               | 250/1250 [02:30<10:00,  1.67it/s]

batch: 250 loss: 0.6029416918754578


 24%|██████████████████▉                                                            | 300/1250 [03:00<09:41,  1.63it/s]

batch: 300 loss: 0.6327592134475708


 28%|██████████████████████                                                         | 350/1250 [03:30<08:58,  1.67it/s]

batch: 350 loss: 0.6148663759231567


 32%|█████████████████████████▎                                                     | 400/1250 [04:00<08:32,  1.66it/s]

batch: 400 loss: 0.8347190618515015


 36%|████████████████████████████▍                                                  | 450/1250 [04:31<08:02,  1.66it/s]

batch: 450 loss: 0.7777694463729858


 40%|███████████████████████████████▌                                               | 500/1250 [05:01<07:32,  1.66it/s]

batch: 500 loss: 0.6324566602706909


 44%|██████████████████████████████████▊                                            | 550/1250 [05:31<07:03,  1.65it/s]

batch: 550 loss: 0.8103207945823669


 48%|█████████████████████████████████████▉                                         | 600/1250 [06:01<06:31,  1.66it/s]

batch: 600 loss: 0.6263577342033386


 52%|█████████████████████████████████████████                                      | 650/1250 [06:32<06:01,  1.66it/s]

batch: 650 loss: 0.5702188611030579


 56%|████████████████████████████████████████████▏                                  | 700/1250 [07:02<05:35,  1.64it/s]

batch: 700 loss: 0.6350577473640442


 60%|███████████████████████████████████████████████▍                               | 750/1250 [07:32<05:03,  1.64it/s]

batch: 750 loss: 0.6593537926673889


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [08:03<04:34,  1.64it/s]

batch: 800 loss: 0.6269738674163818


 68%|█████████████████████████████████████████████████████▋                         | 850/1250 [08:33<04:02,  1.65it/s]

batch: 850 loss: 0.6406251192092896


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [09:04<03:32,  1.65it/s]

batch: 900 loss: 0.4666021168231964


 76%|████████████████████████████████████████████████████████████                   | 950/1250 [09:34<03:00,  1.66it/s]

batch: 950 loss: 0.5642251968383789


 80%|██████████████████████████████████████████████████████████████▍               | 1001/1250 [10:05<02:31,  1.65it/s]

batch: 1000 loss: 0.5271121263504028


 84%|█████████████████████████████████████████████████████████████████▌            | 1050/1250 [10:35<02:01,  1.64it/s]

batch: 1050 loss: 0.516744077205658


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [11:05<01:30,  1.65it/s]

batch: 1100 loss: 0.6457004547119141


 92%|███████████████████████████████████████████████████████████████████████▊      | 1150/1250 [11:35<01:01,  1.64it/s]

batch: 1150 loss: 0.5530329942703247


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [12:06<00:30,  1.66it/s]

batch: 1200 loss: 0.5930404663085938


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [12:36<00:00,  1.65it/s]


epoch: 2: train loss:0.6383023795366287
valing


100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:53<00:00,  5.86it/s]


epoch: 2: val loss:0.6378705627240312 val acc:0.6539999842643738
training...


  4%|███▏                                                                            | 50/1250 [00:30<12:15,  1.63it/s]

batch: 50 loss: 0.6980108022689819


  8%|██████▎                                                                        | 100/1250 [01:00<11:40,  1.64it/s]

batch: 100 loss: 0.44636860489845276


 12%|█████████▍                                                                     | 150/1250 [01:30<11:11,  1.64it/s]

batch: 150 loss: 0.5512357950210571


 16%|████████████▋                                                                  | 200/1250 [02:01<10:37,  1.65it/s]

batch: 200 loss: 0.7590818405151367


 20%|███████████████▊                                                               | 250/1250 [02:31<10:10,  1.64it/s]

batch: 250 loss: 0.5131892561912537


 24%|██████████████████▉                                                            | 300/1250 [03:02<09:44,  1.63it/s]

batch: 300 loss: 0.5294904112815857


 28%|██████████████████████▏                                                        | 351/1250 [03:33<09:05,  1.65it/s]

batch: 350 loss: 0.8582481741905212


 32%|█████████████████████████▎                                                     | 400/1250 [04:03<08:40,  1.63it/s]

batch: 400 loss: 0.47246745228767395


 36%|████████████████████████████▌                                                  | 451/1250 [04:34<08:09,  1.63it/s]

batch: 450 loss: 0.46520280838012695


 40%|███████████████████████████████▌                                               | 500/1250 [05:04<07:40,  1.63it/s]

batch: 500 loss: 0.7468345165252686


 44%|██████████████████████████████████▊                                            | 551/1250 [05:35<07:04,  1.65it/s]

batch: 550 loss: 0.5817596912384033


 48%|█████████████████████████████████████▉                                         | 601/1250 [06:06<06:34,  1.65it/s]

batch: 600 loss: 0.5700541138648987


 52%|█████████████████████████████████████████                                      | 650/1250 [06:35<06:06,  1.64it/s]

batch: 650 loss: 0.5327909588813782


 56%|████████████████████████████████████████████▏                                  | 700/1250 [07:06<05:35,  1.64it/s]

batch: 700 loss: 0.6202393770217896


 60%|███████████████████████████████████████████████▍                               | 750/1250 [07:36<05:04,  1.64it/s]

batch: 750 loss: 0.8207709789276123


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [08:07<04:34,  1.64it/s]

batch: 800 loss: 0.5097620487213135


 68%|█████████████████████████████████████████████████████▋                         | 850/1250 [08:37<04:04,  1.63it/s]

batch: 850 loss: 0.40602797269821167


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [09:08<03:31,  1.65it/s]

batch: 900 loss: 0.5602698922157288


 76%|████████████████████████████████████████████████████████████                   | 950/1250 [09:38<03:03,  1.63it/s]

batch: 950 loss: 0.37067362666130066


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [10:08<02:33,  1.63it/s]

batch: 1000 loss: 0.3593118190765381


 84%|█████████████████████████████████████████████████████████████████▌            | 1050/1250 [10:38<02:01,  1.65it/s]

batch: 1050 loss: 0.477824330329895


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [11:09<01:30,  1.66it/s]

batch: 1100 loss: 0.6273440718650818


 92%|███████████████████████████████████████████████████████████████████████▊      | 1151/1250 [11:40<00:58,  1.69it/s]

batch: 1150 loss: 0.3836294710636139


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [12:09<00:30,  1.63it/s]

batch: 1200 loss: 0.5584975481033325


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [12:39<00:00,  1.64it/s]


epoch: 3: train loss:0.5633237728595734
valing


100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:52<00:00,  5.91it/s]


epoch: 3: val loss:0.5576662848742244 val acc:0.7217999696731567
training...


  4%|███▏                                                                            | 50/1250 [00:29<12:08,  1.65it/s]

batch: 50 loss: 0.49496397376060486


  8%|██████▎                                                                        | 100/1250 [01:00<11:30,  1.67it/s]

batch: 100 loss: 0.5001287460327148


 12%|█████████▍                                                                     | 150/1250 [01:30<11:11,  1.64it/s]

batch: 150 loss: 0.5843247771263123


 16%|████████████▋                                                                  | 201/1250 [02:01<10:34,  1.65it/s]

batch: 200 loss: 0.35362082719802856


 20%|███████████████▊                                                               | 250/1250 [02:30<10:00,  1.67it/s]

batch: 250 loss: 0.6555014848709106


 24%|██████████████████▉                                                            | 300/1250 [03:01<09:34,  1.65it/s]

batch: 300 loss: 0.7150256037712097


 28%|██████████████████████                                                         | 350/1250 [03:31<09:03,  1.66it/s]

batch: 350 loss: 0.6478661298751831


 32%|█████████████████████████▎                                                     | 400/1250 [04:01<08:35,  1.65it/s]

batch: 400 loss: 0.6792751550674438


 36%|████████████████████████████▍                                                  | 450/1250 [04:31<08:01,  1.66it/s]

batch: 450 loss: 0.6323157548904419


 40%|███████████████████████████████▌                                               | 500/1250 [05:02<07:34,  1.65it/s]

batch: 500 loss: 0.31502461433410645


 44%|██████████████████████████████████▊                                            | 550/1250 [05:32<07:13,  1.61it/s]

batch: 550 loss: 0.5604267120361328


 48%|█████████████████████████████████████▉                                         | 601/1250 [06:03<06:31,  1.66it/s]

batch: 600 loss: 0.16701672971248627


 52%|█████████████████████████████████████████                                      | 650/1250 [06:33<06:04,  1.65it/s]

batch: 650 loss: 0.47688427567481995


 56%|████████████████████████████████████████████▏                                  | 700/1250 [07:03<05:34,  1.64it/s]

batch: 700 loss: 0.4945117235183716


 60%|███████████████████████████████████████████████▍                               | 750/1250 [07:33<05:00,  1.66it/s]

batch: 750 loss: 0.5775353312492371


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [08:03<04:29,  1.67it/s]

batch: 800 loss: 0.5904060006141663


 68%|█████████████████████████████████████████████████████▋                         | 850/1250 [08:33<04:00,  1.66it/s]

batch: 850 loss: 0.48135992884635925


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [09:03<03:31,  1.66it/s]

batch: 900 loss: 0.5441907644271851


 76%|████████████████████████████████████████████████████████████                   | 950/1250 [09:34<02:58,  1.68it/s]

batch: 950 loss: 0.37045133113861084


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [10:04<02:30,  1.66it/s]

batch: 1000 loss: 0.4378727972507477


 84%|█████████████████████████████████████████████████████████████████▌            | 1050/1250 [10:34<02:00,  1.67it/s]

batch: 1050 loss: 0.5358057618141174


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [11:04<01:29,  1.67it/s]

batch: 1100 loss: 0.3715573251247406


 92%|███████████████████████████████████████████████████████████████████████▊      | 1150/1250 [11:34<01:00,  1.65it/s]

batch: 1150 loss: 0.38805827498435974


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1201/1250 [12:05<00:29,  1.65it/s]

batch: 1200 loss: 0.565887987613678


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [12:34<00:00,  1.66it/s]


epoch: 4: train loss:0.5124421606302262
valing


100%|████████████████████████████████████████████████████████████████████████████████| 313/313 [00:53<00:00,  5.90it/s]

epoch: 4: val loss:0.5488730688064624 val acc:0.7319999933242798

Training complete in 55m 7s





In [11]:
def test(model, test_loader):
    acc = 0.0
    test_loss = 0.0
    model.eval()
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(test_loader)):
            inputs, labels=batch.text, batch.label
            inputs = inputs[0].transpose(0,1).to(DEVICE)
            labels = labels.to(torch.int64).to(DEVICE)
            outputs = model(inputs,token_type_ids=None)[0]
            loss = criterion(outputs,labels)
            test_loss += loss.data.item()
            outputs = torch.argmax(outputs,axis=1).view(-1)
            acc+=torch.sum(outputs==labels)
    acc /= len(test_data)
    test_loss /= len(test_loader)
    print(f"test loss: {test_loss}: test acc:{acc}")

In [12]:
test(model, test_loader)

100%|██████████████████████████████████████████████████████████████████████████████| 1563/1563 [04:46<00:00,  5.46it/s]

test loss: 0.5694077742534498: test acc:0.7275999784469604



