In [1]:
"""对IMDB电影评论进行情感分析， 属于文本分类中的二分类任务 positive/negative   two-class classification"""
# 50000 movie reviews, train_split:25000,test_split:25000
import torch
from torch.utils.data import Dataset, DataLoader
import os
import re
from collections import Counter
from torch import nn
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
data_base_path = './aclImdb'
UNK_TAG = 'UNK' # 表示未知词
PAD_TAG = 'PAD' # 表示填充词
reserved_tokens = [UNK_TAG, PAD_TAG]

In [3]:
# Tokenization
def tokenize(text):
    """Split ``text`` into lowercase words made only of ASCII letters.

    Every run of characters outside ``A-Za-z`` (digits, punctuation, whitespace,
    apostrophes, non-ASCII letters, ...) acts as a separator, so for example
    ``"Don't"`` yields ``['don', 't']`` and ``"<br />"`` yields ``['br']``.
    """
    letters_only = re.sub('[^A-Za-z]+', ' ', text)
    # After the substitution each piece contains only letters, so no extra strip is needed.
    return [word.lower() for word in letters_only.split()]

In [5]:
# Word-frequency statistics
def count_corpus(splits=("train/neg", "train/pos")):
    """Count token frequencies over every review file in the given dataset folders.

    Args:
        splits: folders, relative to ``data_base_path``, whose files are read.
            Defaults to the training split only.

    Returns:
        ``collections.Counter`` mapping token -> number of occurrences.
    """
    counts = Counter()
    for split in splits:
        folder = os.path.join(data_base_path, split)
        for name in os.listdir(folder):
            with open(os.path.join(folder, name), 'r', encoding='utf-8') as f:
                # Update incrementally instead of first concatenating every token
                # of every review into one huge Python list.
                counts.update(tokenize(f.read()))
    return counts


['0_3.txt',
 '10000_4.txt',
 '10001_4.txt',
 '10002_1.txt',
 '10003_1.txt',
 '10004_3.txt',
 '10005_3.txt',
 '10006_4.txt',
 '10007_1.txt',
 '10008_2.txt',
 '10009_1.txt',
 '1000_4.txt',
 '10010_3.txt',
 '10011_3.txt',
 '10012_1.txt',
 '10013_1.txt',
 '10014_2.txt',
 '10015_2.txt',
 '10016_4.txt',
 '10017_4.txt',
 '10018_3.txt',
 '10019_3.txt',
 '1001_4.txt',
 '10020_3.txt',
 '10021_2.txt',
 '10022_4.txt',
 '10023_1.txt',
 '10024_3.txt',
 '10025_1.txt',
 '10026_2.txt',
 '10027_1.txt',
 '10028_2.txt',
 '10029_1.txt',
 '1002_3.txt',
 '10030_1.txt',
 '10031_2.txt',
 '10032_4.txt',
 '10033_1.txt',
 '10034_1.txt',
 '10035_1.txt',
 '10036_1.txt',
 '10037_1.txt',
 '10038_3.txt',
 '10039_1.txt',
 '1003_3.txt',
 '10040_2.txt',
 '10041_1.txt',
 '10042_1.txt',
 '10043_1.txt',
 '10044_1.txt',
 '10045_1.txt',
 '10046_1.txt',
 '10047_1.txt',
 '10048_4.txt',
 '10049_1.txt',
 '1004_4.txt',
 '10050_2.txt',
 '10051_4.txt',
 '10052_4.txt',
 '10053_4.txt',
 '10054_1.txt',
 '10055_3.txt',
 '10056_2.txt',
 

In [9]:
# Build the vocabulary
class Vocab:
    """Token <-> index mapping built from corpus word frequencies.

    Layout: the reserved tokens come first, in the order given (with
    ``[UNK_TAG, PAD_TAG]`` that means UNK -> 0 and PAD -> 1), followed by the
    corpus tokens sorted by descending frequency.
    """

    def __init__(self, reseverd_tokens, token_counts=None):
        """
        Args:
            reseverd_tokens: special tokens placed at the start of the vocabulary.
                Index 0 is used for unknown words and index 1 as padding
                (see ``token_to_index`` / ``transform``). The misspelled name is
                kept so existing keyword callers keep working.
            token_counts: optional mapping token -> frequency. When omitted, the
                counts are computed with ``count_corpus()``.
        """
        if token_counts is None:
            token_counts = count_corpus()
        token_freqs = sorted(token_counts.items(), key=lambda x: x[1], reverse=True)
        reserved = list(reseverd_tokens)
        reserved_set = set(reserved)
        # Skip corpus tokens that collide with a reserved token so each token has exactly one index.
        self.unique_tokens = reserved + [token for token, _ in token_freqs if token not in reserved_set]
        self.unique_dict = {token: idx for idx, token in enumerate(self.unique_tokens)}
        # Index returned for out-of-vocabulary tokens (the first reserved token, UNK by convention).
        self.unk_index = 0

    def __getitem__(self, idx):
        """Return the token stored at index ``idx``."""
        return self.unique_tokens[idx]

    def __len__(self):
        """Number of tokens, reserved ones included."""
        return len(self.unique_tokens)

    def token_to_index(self, token):
        """Return the index of ``token``; tokens never seen in the corpus map to ``self.unk_index``."""
        return self.unique_dict.get(token, self.unk_index)

    def transform(self, text, max_len=None):
        """Convert a list of tokens into a list of vocabulary indices.

        If ``max_len`` is given, the token list is truncated to ``max_len`` or
        right-padded with the padding token (``self.unique_tokens[1]``).
        If ``max_len`` is None, every token is converted and the length is kept.
        """
        if max_len is not None:
            if len(text) < max_len:
                text = text + [self.unique_tokens[1]] * (max_len - len(text))
            else:
                text = text[:max_len]
        return [self.token_to_index(token) for token in text]


In [11]:
# Build the vocabulary from word counts over the training review files (reads every file once).
vocab = Vocab(reserved_tokens)
vocab_size = len(vocab)
# vocab[1]
vocab_size


73274

In [15]:
vocab[0]

'UNK'

In [16]:
# Dataset: label 0 = neg, 1 = pos
class ImdbDataset(Dataset):
    """IMDB reviews read lazily from ``<root>/{train,test}/{neg,pos}``.

    Items are ``(tokens, label)`` pairs: ``tokens`` is ``tokenize`` applied to
    the file contents, ``label`` is 0 for files under ``neg`` and 1 under ``pos``.
    """

    def __init__(self, is_train=True, base_path=None):
        """
        Args:
            is_train: read the ``train`` split if truthy, otherwise ``test``.
            base_path: dataset root; defaults to the global ``data_base_path``.
        """
        super(ImdbDataset, self).__init__()
        root = data_base_path if base_path is None else base_path
        split = "train" if is_train else "test"

        self.total_file_path_list = []
        self.labels = []
        # The label is the position in this tuple: neg -> 0, pos -> 1.
        for label, polarity in enumerate(("neg", "pos")):
            folder = os.path.join(root, split, polarity)
            # List each folder once (paths and labels built from the same listing)
            # and sort it so the index -> file mapping is reproducible.
            file_names = sorted(os.listdir(folder))
            self.total_file_path_list += [os.path.join(folder, name) for name in file_names]
            self.labels += [label] * len(file_names)

    def __getitem__(self, idx):
        """Read review ``idx`` from disk and return ``(tokens, label)``."""
        with open(self.total_file_path_list[idx], "r", encoding='utf-8') as f:
            text = f.read()

        tokens = tokenize(text)
        return tokens, self.labels[idx]

    def __len__(self):
        """Number of review files found."""
        return len(self.labels)


In [17]:
dataset = ImdbDataset()  # training split; labels: neg -> 0, pos -> 1

In [None]:
# Toy example of the input collate_fn receives: a list of (tokens, label) pairs
# whose token lists have different lengths.
batch = [
        (['hello', 'world'], 0), 
        (['hao', 'hao', 'xue', 'xi'], 1)
]

In [9]:
# Passing raw token lists straight to a DataLoader fails because reviews differ in length; pad/truncate them here.
def collate_fn(batch, max_len=50, token_vocab=None):
    """Turn a list of ``(tokens, label)`` pairs into index and label tensors.

    Args:
        batch: list of ``(tokens, label)`` pairs, e.g. items of ``ImdbDataset``.
        max_len: length every token list is padded/truncated to; must equal the
            ``max_len`` the model was built with.
        token_vocab: object providing ``transform(tokens, max_len=...)``;
            defaults to the global ``vocab``.

    Returns:
        ``(tokens, labels)`` as int64 tensors of shape ``[batch_size, max_len]``
        and ``[batch_size]``.
    """
    token_vocab = vocab if token_vocab is None else token_vocab
    token_lists, labels = zip(*batch)
    tokens = torch.tensor(
        [token_vocab.transform(t, max_len=max_len) for t in token_lists],
        dtype=torch.long,
    )
    return tokens, torch.tensor(labels, dtype=torch.long)


In [10]:
# Training loader: 512 reviews per batch, reshuffled every epoch; drop_last discards the final incomplete batch.
dataloader = DataLoader(dataset, batch_size=512, shuffle=True, collate_fn=collate_fn, drop_last=True)

In [11]:
# Build the network
class IMDBModel(nn.Module):
    """Embedding + MLP sentiment classifier.

    Each of the ``max_len`` token ids is embedded into ``embed_size`` values; the
    embeddings are flattened into one vector and fed through an MLP
    (-> 512 -> 32 -> 2) that outputs two unnormalized class scores.
    """

    def __init__(self, vocab_size, embed_size, max_len):
        super(IMDBModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        flat_dim = max_len * embed_size
        # Layers are created in the same order as before (same random init) and
        # keep the names embedding / net.0 / net.2 / net.4 for state_dict compatibility.
        layers = [
            nn.Linear(flat_dim, 512), nn.ReLU(),
            nn.Linear(512, 32), nn.ReLU(),
            nn.Linear(32, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map token ids ``[batch_size, max_len]`` to logits ``[batch_size, 2]``."""
        embedded = self.embedding(x)                        # [batch_size, max_len, embed_size]
        flat = embedded.reshape(embedded.shape[0], -1)      # [batch_size, max_len * embed_size]
        return self.net(flat)

In [12]:
# Hyper-parameters and training objects.
embed_size = 300
max_len = 50  # must equal the length collate_fn pads/truncates reviews to (50)
lr = 0.01
model = IMDBModel(vocab_size, embed_size=embed_size, max_len=max_len).to(device)
loss_fn = nn.CrossEntropyLoss()  # applied to the model's 2 raw scores and int64 class labels
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [13]:
# Model training
def train(epoches, dataloader, net=None, criterion=None, opt=None, dev=None):
    """Train for ``epoches`` epochs, printing each epoch's summed batch loss and the total time.

    Args:
        epoches: number of passes over ``dataloader``.
        dataloader: iterable of ``(X, y)`` tensor batches.
        net: model to train; defaults to the global ``model``.
        criterion: loss function; defaults to the global ``loss_fn``.
        opt: optimizer; defaults to the global ``optimizer``.
        dev: device batches are moved to; defaults to the global ``device``.
    """
    net = model if net is None else net
    criterion = loss_fn if criterion is None else criterion
    opt = optimizer if opt is None else opt
    dev = device if dev is None else dev

    # Re-enable training mode in case an earlier evaluation called net.eval().
    net.train()
    start = time.time()
    for i in range(epoches):
        running_loss = 0.0
        for X, y in dataloader:
            X = X.to(dev)
            y = y.to(dev)
            opt.zero_grad()
            loss = criterion(net(X), y)
            loss.backward()
            opt.step()
            # Accumulate a Python float; adding the loss tensor itself would keep
            # extending autograd history across the whole epoch.
            running_loss += loss.item()
        print(f'epoch{i}: {running_loss}')
    end = time.time()
    print(f'time:{end-start}')

In [14]:
train(10, dataloader)

epoch0: 42.92527770996094
epoch1: 33.14415740966797
epoch2: 32.64594650268555
epoch3: 28.08721351623535
epoch4: 18.209306716918945
epoch5: 11.824258804321289
epoch6: 8.542959213256836
epoch7: 6.990544319152832
epoch8: 5.935871124267578
epoch9: 5.4933648109436035
time:108.8706841468811


In [15]:
def test():
    test_loss = 0
    correct = 0
    model.eval()
    test_dataset = ImdbDataset(is_train=False)
    test_dataloader = DataLoader(dataset, batch_size=500, collate_fn=collate_fn)
    with torch.no_grad():
        for X, y in test_dataloader:
            X = X.to(device)
            y = y.to(device)
            output = model(X)
            pred = torch.max(output, dim=-1, keepdim=False)[-1]
            correct += pred.eq(y.data).sum()
        print(f'accuracy: {100 * correct / len(test_dataset)} %')


In [16]:
test()

accuracy: 97.27999877929688 %
