## 3. 使用CNN进行文本分类 

<img src='img/textcnn.jfif' width=500>

References:
- https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html
- https://github.com/649453932/Chinese-Text-Classification-Pytorch/

In [1]:
import argparse

# All tunable settings for the experiment live in one namespace; later cells
# read them as attributes and attach derived values (vocab_size, pad_id, ...).
hparams = argparse.Namespace(
    batch_size=16,            # reviews per mini-batch
    learning_rate=0.004,      # initial learning rate handed to the optimizer
    # max_grad_norm=1.,       # gradient clipping is currently disabled
    max_length=2000,          # reviews are truncated / padded to this many ids
    dropout=0.2,              # dropout probability before the classifier
    embedding_dim=200,        # size of each word embedding
    hidden_dim=200,           # not referenced by the cells shown here
    seed=42,                  # seed for the `random` module
    num_filters=200,          # output channels of every convolution
    filter_sizes=[1, 2, 3],   # n-gram widths, one convolution per entry
    num_train_epochs=20,
    model_save_path='data/save_model/textcnn.path',
)

hparams

Namespace(batch_size=16, dropout=0.2, embedding_dim=200, filter_sizes=[1, 2, 3], hidden_dim=200, learning_rate=0.004, max_length=2000, model_save_path='data/save_model/textcnn.path', num_filters=200, num_train_epochs=20, seed=42)

### 加载数据

In [2]:
from nltk.corpus import movie_reviews
import random
random.seed(hparams.seed)


def load_movie_reviews(train_size=1600):
    """Load the NLTK movie-review corpus and split it into train and test sets.

    Every review is read as raw text and paired with the string label
    ``'positive'`` or ``'negative'``. Positive reviews are collected first,
    then negative ones, and the combined list is shuffled in place with the
    module-level ``random`` generator, so the split is only reproducible when
    ``random.seed`` has been called beforehand.

    Args:
        train_size: Number of shuffled reviews placed in the training split;
            the remainder becomes the test split. Defaults to 1600.

    Returns:
        A ``(train_reviews, test_reviews)`` tuple of lists of
        ``(raw_text, label)`` pairs.

    Raises:
        ValueError: If ``train_size`` is negative.
    """
    if train_size < 0:
        raise ValueError(f'train_size must be non-negative, got {train_size}')

    # Keep the original ordering (all positives, then all negatives) so that a
    # given seed produces the same shuffle as before.
    all_reviews = [
        (movie_reviews.raw(file_id), label)
        for category, label in (('pos', 'positive'), ('neg', 'negative'))
        for file_id in movie_reviews.fileids(category)
    ]

    random.shuffle(all_reviews)
    return all_reviews[:train_size], all_reviews[train_size:]

# Materialise both splits once and report how many reviews each one holds.
train_reviews, test_reviews = load_movie_reviews()
for split_tag, split_reviews in (('train:', train_reviews), ('test:', test_reviews)):
    print(split_tag, len(split_reviews))

train: 1600
test: 400


### Tokenize

In [3]:
from nltk import word_tokenize


def tokenize_reviews(reviews):
    """Tokenize labelled reviews and convert their labels to integers.

    Args:
        reviews: Iterable of ``(raw_text, label)`` pairs where ``label`` is a
            string; ``'negative'`` maps to 0 and anything else maps to 1.

    Returns:
        A ``(tokenized_reviews, labels)`` tuple: a list of token lists produced
        by ``word_tokenize`` and a parallel list of integer labels.
    """
    tokenized_reviews = []
    labels = []
    for text, label in reviews:
        labels.append(0 if label == 'negative' else 1)
        tokenized_reviews.append(word_tokenize(text))
    return tokenized_reviews, labels


# The same preprocessing is applied to both splits through one helper instead
# of two copy-pasted loops.
train_reviews_tokenized, train_labels = tokenize_reviews(train_reviews)
test_reviews_tokenized, test_labels = tokenize_reviews(test_reviews)

### 建立词表、将单词变成id

In [9]:
from collections import Counter
# Aliased so that the `vocab` object created below does not shadow the factory
# function it was built with.
from torchtext.vocab import vocab as build_vocab


# Token frequencies are counted on the training split only; the test reviews
# are deliberately left out of the vocabulary.
counter = Counter()
for review in train_reviews_tokenized:
    counter.update(review)

vocab = build_vocab(counter, min_freq=1, specials=['<unk>', '<pad>', '<sos>', '<eos>'])

# Derived settings that the model and the collate function read later.
hparams.vocab_size = len(vocab)
hparams.pad_id = vocab['<pad>']
hparams.num_classes = 2  # labels are encoded as 0 (negative) / 1 (positive)

print(hparams.vocab_size)

42013


In [11]:
# Use the '<unk>' id as the default index for tokens that are not in the
# vocabulary, then convert every tokenized review into a list of ids.
unk_id = vocab['<unk>']
vocab.set_default_index(unk_id)

train_reviews_ids = list(map(vocab.lookup_indices, train_reviews_tokenized))
test_reviews_ids = list(map(vocab.lookup_indices, test_reviews_tokenized))

### 将数据打包为dataloader

In [12]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch


class TextDataset(Dataset):
    """Map-style dataset pairing each review with its label.

    Args:
        reviews: Indexable collection of reviews (here: lists of token ids).
        labels: Indexable collection of labels, parallel to ``reviews``.

    Raises:
        ValueError: If ``reviews`` and ``labels`` differ in length.
    """

    def __init__(self, reviews, labels):
        # Fail early instead of raising an IndexError on some later sample.
        if len(reviews) != len(labels):
            raise ValueError(
                f'reviews and labels must have the same length, '
                f'got {len(reviews)} and {len(labels)}')
        self.reviews = reviews
        self.labels = labels

    def __getitem__(self, index):
        """Return the ``(review, label)`` pair stored at ``index``."""
        return self.reviews[index], self.labels[index]

    def __len__(self):
        """Return the number of reviews in the dataset."""
        return len(self.reviews)


def collate_to_max_length(batch, max_length=None, pad_id=None):
    """Collate ``(ids, label)`` samples into fixed-length tensors.

    Each id sequence is truncated to ``max_length`` or right-padded with
    ``pad_id`` until it has exactly ``max_length`` entries.

    Args:
        batch: Iterable of ``(ids, label)`` pairs, where ``ids`` is a sequence
            of integer token ids.
        max_length: Target sequence length. Defaults to ``hparams.max_length``.
        pad_id: Id used for padding. Defaults to ``hparams.pad_id``.

    Returns:
        A ``(X, y)`` tuple of tensors: ``X`` holds one row of ``max_length``
        ids per sample and ``y`` holds the labels.
    """
    # The global configuration is only consulted when no explicit value is
    # given, so a DataLoader can still pass this function as-is.
    if max_length is None:
        max_length = hparams.max_length
    if pad_id is None:
        pad_id = hparams.pad_id

    X_batch = []
    y_batch = []
    for X, y in batch:
        X = list(X[:max_length])
        # A non-positive repeat count yields an empty list, so sequences that
        # are already long enough are left untouched.
        X = X + [pad_id] * (max_length - len(X))

        X_batch.append(X)
        y_batch.append(y)

    return torch.tensor(X_batch), torch.tensor(y_batch)


# Wrap the id sequences and labels of each split in a dataset.
train_dataset = TextDataset(train_reviews_ids, train_labels)
test_dataset = TextDataset(test_reviews_ids, test_labels)

# Both loaders share the batch size and the padding collate function; only the
# training loader reshuffles its samples.
loader_kwargs = dict(
    batch_size=hparams.batch_size,
    collate_fn=collate_to_max_length,
)

train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

### 定义模型

In [13]:
from torch import nn
from torch.nn import functional as F


class TextCNN(nn.Module):
    """Convolutional text classifier.

    Token ids are embedded, passed through one convolution per entry of
    ``hparams.filter_sizes`` (each spanning the full embedding width),
    max-pooled over the sequence axis, concatenated, regularised with dropout
    and projected to ``hparams.num_classes`` logits.

    Args:
        hparams: Namespace providing ``vocab_size``, ``embedding_dim``,
            ``pad_id``, ``num_filters``, ``filter_sizes``, ``dropout`` and
            ``num_classes``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        self.embedding = nn.Embedding(
            hparams.vocab_size,
            hparams.embedding_dim,
            padding_idx=hparams.pad_id)

        # One convolution per filter size; the kernel covers k consecutive
        # tokens and the whole embedding dimension.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, hparams.num_filters, (k, hparams.embedding_dim))
            for k in hparams.filter_sizes
        ])
        self.dropout = nn.Dropout(hparams.dropout)

        hidden_size = hparams.num_filters * len(hparams.filter_sizes)
        self.classifier = nn.Linear(hidden_size, hparams.num_classes)

        # Custom initialisation is opt-in; PyTorch defaults are used unless
        # init_weights() is called explicitly.
        # self.init_weights()

    def init_weights(self):
        """Re-initialise weights with Xavier-normal values and zero all biases.

        The embedding row at ``padding_idx`` is reset to zeros afterwards,
        because the Xavier pass would otherwise overwrite it.
        """
        for name, w in self.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(w)
            elif 'bias' in name:
                nn.init.zeros_(w)

        pad_idx = self.embedding.padding_idx
        if pad_idx is not None:
            with torch.no_grad():
                self.embedding.weight[pad_idx].fill_(0)

    def forward(self, x):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of shape ``(B, L)``.

        Returns:
            Tensor of shape ``(B, num_classes)`` with unnormalised scores.
        """
        # (B, L) => (B, L, embedding_dim)
        embed = self.embedding(x)
        # (B, L, embedding_dim) => (B, 1, L, embedding_dim)
        embed = embed.unsqueeze(1)

        # [(B, num_filters), ...] => (B, num_filters * len(filter_sizes))
        hidden = torch.cat([self.conv_and_pool(embed, conv) for conv in self.convs], dim=1)
        hidden = self.dropout(hidden)
        logits = self.classifier(hidden)

        return logits

    def conv_and_pool(self, x, conv):
        """Apply one convolution, ReLU, and max-pooling over the sequence axis.

        Args:
            x: Tensor of shape ``(B, 1, L, embedding_dim)``.
            conv: One of the ``nn.Conv2d`` modules in ``self.convs`` with
                kernel size ``(k, embedding_dim)``.

        Returns:
            Tensor of shape ``(B, num_filters)``.
        """
        # (B, 1, L, embedding_dim) => (B, num_filters, L-k+1, 1)
        # (B, num_filters, L-k+1, 1) => (B, num_filters, L-k+1)
        x = F.relu(conv(x).squeeze(3))
        # (B, num_filters, L-k+1) => (B, num_filters, 1)
        # (B, num_filters, 1) => (B, num_filters)
        x = F.max_pool1d(x, x.size(2)).squeeze(2)
        return x

In [46]:
# vocab.load_vectors('glove.6B.200d')

In [14]:
# Seed PyTorch's global RNG before building the model: until now only the
# `random` module was seeded, which left weight initialisation (and later
# draws from the default torch generator) different on every run.
torch.manual_seed(hparams.seed)

model = TextCNN(hparams)

# Optional: initialise the embedding from pretrained vectors and freeze it.
# model.embedding.weight.data.copy_(vocab.vectors)
# model.embedding.weight.requires_grad = False
if torch.cuda.is_available():
    model.cuda()

In [15]:
def lr_decay_factor(epoch):
    """Return the multiplier applied to the base learning rate at ``epoch``."""
    return 0.95**epoch


# Cross-entropy over the class logits, SGD with momentum, and a learning rate
# that is scaled by 0.95 for every scheduler step.
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=hparams.learning_rate,
    momentum=0.9,
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_decay_factor)

In [16]:
from tqdm import tqdm

def train(model, dataloader, loss_func, optimizer, epoch_idx, hparams):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module to optimise; switched to training mode.
        dataloader: Iterable yielding ``(X, y)`` tensor batches.
        loss_func: Callable mapping ``(output, y)`` to a scalar loss tensor.
        optimizer: Optimizer that updates the parameters of ``model``.
        epoch_idx: Epoch number shown in the progress-bar description.
        hparams: Currently unused; kept so existing call sites keep working.
    """
    model.train()

    # Move each batch to wherever the model actually lives instead of keying
    # on global CUDA availability, which fails for a CPU model on a GPU host.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    pbar = tqdm(dataloader)
    pbar.set_description(f'Epoch {epoch_idx}')

    for X, y in pbar:
        X = X.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        output = model(X)  # (B, num_classes)
        loss = loss_func(output, y)
        loss.backward()
        optimizer.step()

        # Show the loss of the most recent batch next to the progress bar.
        pbar.set_postfix(loss=loss.item())

In [17]:
def evaluate(model, dataloader, loss_func):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        dataloader: DataLoader yielding ``(X, y)`` tensor batches; both
            ``len(dataloader)`` and ``len(dataloader.dataset)`` are used.
        loss_func: Callable mapping ``(output, y)`` to a scalar loss tensor.

    Returns:
        A ``(avg_loss, accuracy)`` tuple: the mean of the per-batch losses and
        the fraction of samples whose arg-max prediction equals the label.
    """
    model.eval()
    total_loss = 0.
    correct_num = 0

    # Follow the model's own device rather than global CUDA availability.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        pbar = tqdm(dataloader)
        pbar.set_description('Valid')
        for X, y in pbar:
            X = X.to(device)
            y = y.to(device)
            output = model(X)

            loss = loss_func(output, y)
            total_loss += loss.item()

            correct_num += (output.argmax(1) == y).sum().item()

        # The loss is averaged over batches, the accuracy over samples.
        avg_loss = total_loss / len(dataloader)
        accuracy = correct_num / len(dataloader.dataset)

    return avg_loss, accuracy

In [18]:
import os

best_val_loss = None
accuracy_at_lowest_loss = 0
best_accuracy = 0

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save.
save_dir = os.path.dirname(hparams.model_save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

for epoch_idx in range(hparams.num_train_epochs):
    train(model, train_dataloader, loss_func, optimizer, epoch_idx+1, hparams)
    # The scheduler advances before evaluation, so the LR printed below is the
    # one the next epoch will use.
    scheduler.step()
    val_loss, accuracy = evaluate(model, test_dataloader, loss_func)
    best_accuracy = max(best_accuracy, accuracy)
    print(f'\r[Validation] loss: {val_loss:.4f}, accuracy: {accuracy:.4f}, LR: {scheduler.get_last_lr()}     ')

    # Compare against None explicitly: a truthiness test would treat a best
    # loss of exactly 0.0 as "no best loss yet".
    if best_val_loss is None or val_loss < best_val_loss:
        torch.save(model.state_dict(), hparams.model_save_path)
        print(f'\rsave model to {hparams.model_save_path}\n\n')
        best_val_loss = val_loss
        accuracy_at_lowest_loss = accuracy

print(f'accuracy_at_lowest_loss: {accuracy_at_lowest_loss}, best_accuracy: {best_accuracy}')


Epoch 1: 100%|████████████████████████████████████████████████████████████| 100/100 [00:15<00:00,  6.52it/s, loss=2.31]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 95.67it/s]


[Validation] loss: 2.6607, accuracy: 0.5050, LR: [0.0038]     
save model to data/save_model/textcnn.path




Epoch 2: 100%|███████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 58.40it/s, loss=0.765]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 95.67it/s]


[Validation] loss: 0.8244, accuracy: 0.6925, LR: [0.00361]     
save model to data/save_model/textcnn.path




Epoch 3: 100%|███████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.93it/s, loss=0.442]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.03it/s]


[Validation] loss: 0.5675, accuracy: 0.7850, LR: [0.0034295]     
save model to data/save_model/textcnn.path




Epoch 4: 100%|███████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.51it/s, loss=0.345]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.03it/s]


[Validation] loss: 0.8155, accuracy: 0.7300, LR: [0.0032580249999999995]     


Epoch 5: 100%|███████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.93it/s, loss=0.701]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.03it/s]


[Validation] loss: 0.5656, accuracy: 0.7875, LR: [0.003095123749999999]     
save model to data/save_model/textcnn.path




Epoch 6: 100%|██████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.68it/s, loss=0.0219]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.03it/s]


[Validation] loss: 0.9408, accuracy: 0.7500, LR: [0.0029403675624999994]     


Epoch 7: 100%|███████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.93it/s, loss=0.333]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.41it/s]


[Validation] loss: 0.6598, accuracy: 0.7850, LR: [0.002793349184374999]     


Epoch 8: 100%|███████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.86it/s, loss=0.471]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.04it/s]


[Validation] loss: 0.7276, accuracy: 0.8000, LR: [0.002653681725156249]     


Epoch 9: 100%|███████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.90it/s, loss=0.989]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.04it/s]


[Validation] loss: 1.4216, accuracy: 0.7000, LR: [0.0025209976388984364]     


Epoch 10: 100%|█████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.97it/s, loss=0.0247]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.40it/s]


[Validation] loss: 0.5436, accuracy: 0.8350, LR: [0.0023949477569535148]     
save model to data/save_model/textcnn.path




Epoch 11: 100%|████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.97it/s, loss=0.00733]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.41it/s]


[Validation] loss: 0.6729, accuracy: 0.8025, LR: [0.0022752003691058386]     


Epoch 12: 100%|█████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.90it/s, loss=0.0124]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 97.16it/s]


[Validation] loss: 0.5416, accuracy: 0.8325, LR: [0.0021614403506505465]     
save model to data/save_model/textcnn.path




Epoch 13: 100%|████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.90it/s, loss=0.00684]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.04it/s]


[Validation] loss: 0.5546, accuracy: 0.8400, LR: [0.002053368333118019]     


Epoch 14: 100%|██████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.86it/s, loss=0.118]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 94.59it/s]


[Validation] loss: 0.5583, accuracy: 0.8350, LR: [0.0019506999164621182]     


Epoch 15: 100%|█████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 60.08it/s, loss=0.0949]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 91.82it/s]


[Validation] loss: 0.6302, accuracy: 0.8350, LR: [0.001853164920639012]     


Epoch 16: 100%|██████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.72it/s, loss=0.277]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 91.14it/s]


[Validation] loss: 1.0945, accuracy: 0.7650, LR: [0.0017605066746070614]     


Epoch 17: 100%|█████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.83it/s, loss=0.0481]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.78it/s]


[Validation] loss: 0.5800, accuracy: 0.8350, LR: [0.0016724813408767083]     


Epoch 18: 100%|█████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.86it/s, loss=0.0172]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 97.15it/s]


[Validation] loss: 0.5524, accuracy: 0.8525, LR: [0.0015888572738328728]     


Epoch 19: 100%|█████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.82it/s, loss=0.0102]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 95.30it/s]


[Validation] loss: 0.5377, accuracy: 0.8500, LR: [0.001509414410141229]     
save model to data/save_model/textcnn.path




Epoch 20: 100%|███████████████████████████████████████████████████████| 100/100 [00:01<00:00, 59.65it/s, loss=0.000122]
Valid: 100%|███████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 96.03it/s]

[Validation] loss: 0.6067, accuracy: 0.8250, LR: [0.0014339436896341675]     
accuracy_at_lowest_loss: 0.85, best_accuracy: 0.8525





In [None]:
拓展

- 如何理解textcnn中的卷积核和pooling层
- 如何确定卷积核的大小，调参？
    - RCNN[1]

[1] Lai, Siwei, et al. "Recurrent convolutional neural networks for text classification." AAAI2015.