In [1]:
from torch import nn, optim
from torch.utils.data import DataLoader
import torch

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import numpy as np

from tqdm import tqdm

from IMDBDataset import IMDBDataset
from IMDBTokenizers import IMDBTokenizer
from IMDBmodels import IMDBLstm

# dataLoader
# tokenizer
# Model
# train
# showLossAndError


In [2]:
# 常量层
torch.manual_seed(1)


word_dim=100
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

In [3]:
imdbTokenizer = IMDBTokenizer(vocab_path='aclImdb/imdb.vocab', glove_path='../../glove.6B.100d.txt', word_dim=word_dim, special_tokens=['<PAD>', '<UNK>'])

400001it [00:10, 36618.47it/s]
100%|██████████| 62596/62596 [00:00<00:00, 389919.23it/s]


In [4]:
trainIMDBDataset = IMDBDataset('./aclImdb/train')
trainIMDBDataLoader = DataLoader(trainIMDBDataset, batch_size=16, shuffle=True)
trainSmallIMDBDataLoader = DataLoader(trainIMDBDataset, batch_size=1, shuffle=True)

testIMDBDataset = IMDBDataset('./aclImdb/test')
testIMDBDataLoader = DataLoader(testIMDBDataset, batch_size=8, shuffle=True)

共有pos数据 12500 条, neg数据 12500 条，共 25000 条
共有pos数据 12500 条, neg数据 12500 条，共 25000 条


In [5]:
imdbLSTM = IMDBLstm(imdbTokenizer, 100, word_dim, 256, 2, True, num_labels=2, device=device)

In [6]:
imdbLSTM = imdbLSTM.to(device)

# 查看模型的各个层，设置初始化策略，设置参数
base_lr = 0.1

embedding_parameters = list(map(id, imdbLSTM.embedding.parameters()))

base_params = filter(lambda p: id(p) not in embedding_parameters, imdbLSTM.parameters())

optimizer = optim.Adam([{'params': base_params}, 
                        {'params': imdbLSTM.embedding.parameters(), 'lr': base_lr * 0.1}], lr=base_lr,betas=(0.9,0.999))


criterion = nn.NLLLoss()

In [None]:
epochs = 10
skip = 100
import copy
model = imdbLSTM

for epoch in range(epochs):
    model = model.to(device)
    print("epoch {}/{}".format(epoch, epochs))
    
    epoch_total_loss = 0.0
    epoch_acc_num = 0
    epoch_total_num = 0
    
    best_model = copy.deepcopy(model.state_dict())
    best_model_acc = 0.0
    
    for i, (inputs, labels) in enumerate(trainIMDBDataLoader):
        labels = labels.to(device)
        
        inputs = [str(i) for i in inputs]
        imdbLSTM.train()
        optimizer.zero_grad()
        outputs = imdbLSTM(inputs)
        loss = criterion(outputs, labels)
        
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        acc_num = torch.sum(preds==labels)
        
        epoch_total_loss += loss.item() * len(outputs)
        epoch_total_num += len(outputs)
        epoch_acc_num += acc_num
                    
        
        if i % skip == 0:
            print('Mini epoch: {}/{} loss {}, acc {}/{}={}'.format(i, len(trainIMDBDataLoader), loss.item(), acc_num, len(labels), acc_num.double()/len(labels)))
        else:
            
    epoch_acc = epoch_acc_num.double()/epoch_total_num
    print('Epoch: {}/{} loss:{}, acc {}/{}={}'.format(epoch, epochs, epoch_total_loss/epoch_total_num, epoch_acc_num, epoch_total_num, epoch_acc))
    
    if epoch_acc > best_model_acc:
        best_model = copy.deepcopy(model.state_dict())


epoch 0/10
Mini epoch: 0/1563 loss 0.6902734041213989, acc 9/16=0.5625
Mini epoch: 100/1563 loss 2.6618688106536865, acc 9/16=0.5625
Mini epoch: 200/1563 loss 1.402869462966919, acc 10/16=0.625
Mini epoch: 300/1563 loss 2.0885932445526123, acc 11/16=0.6875
Mini epoch: 400/1563 loss 6.414031505584717, acc 8/16=0.5
Mini epoch: 500/1563 loss 1.8173062801361084, acc 7/16=0.4375
Mini epoch: 600/1563 loss 0.8748462200164795, acc 6/16=0.375
Mini epoch: 700/1563 loss 0.7348121404647827, acc 7/16=0.4375
Mini epoch: 800/1563 loss 1.4061627388000488, acc 4/16=0.25
Mini epoch: 900/1563 loss 3.287309169769287, acc 9/16=0.5625
Mini epoch: 1000/1563 loss 0.6530912518501282, acc 11/16=0.6875
Mini epoch: 1100/1563 loss 2.831125497817993, acc 4/16=0.25
Mini epoch: 1200/1563 loss 0.8249785900115967, acc 9/16=0.5625
Mini epoch: 1300/1563 loss 2.341109275817871, acc 10/16=0.625
Mini epoch: 1400/1563 loss 0.9394770264625549, acc 8/16=0.5
Mini epoch: 1500/1563 loss 0.7378884553909302, acc 10/16=0.625
Epoch: 