In [20]:
# Import torch and other required modules
import glob
import torch
import re
import os
import nltk
import numpy as np
from torch.utils.data import Dataset, DataLoader, RandomSampler
from sklearn.datasets import load_svmlight_file
from nltk.corpus import stopwords
from torch.nn.utils.rnn import pad_sequence
# Fetch the NLTK resources used later in the notebook: the stopword lists
# (read via nltk.corpus.stopwords) and the 'punkt' tokenizer data
# (presumably required by nltk.word_tokenize -- confirm for your NLTK version).
nltk.download('stopwords') 
nltk.download('punkt')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
# Download the aclImdb_v1 archive from ai.stanford.edu and extract it into the
# current working directory (this creates the 'aclImdb' folder read below).
!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -zxf aclImdb_v1.tar.gz

--2021-03-25 07:28:22--  http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10
Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘aclImdb_v1.tar.gz’


2021-03-25 07:28:27 (17.7 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]



In [6]:
# Load the vocabulary file shipped with the dataset (one entry per line).
# The path is relative so it resolves against the directory the archive was
# extracted into, the same way the review files under 'aclImdb/train' are found.
with open(os.path.join('aclImdb', 'imdb.vocab'), encoding="utf-8") as f:
  vocab = [line.strip() for line in f]
print(f"vocab length before removing stopwords: {len(vocab)}")

# Drop English stopwords; a set keeps each membership test constant-time.
en_stopwords = set(stopwords.words('english'))
vocab = [word for word in vocab if word not in en_stopwords]
print(f"vocab length after removing stopwords: {len(vocab)}")

# Map every remaining word to its position in the filtered list (starting at 0).
vocab_dic = {word: idx for idx, word in enumerate(vocab)}

vocab length before removing stopwords: 89527
vocab length before removing stopwords: 89356


In [8]:
# Collect (file path, label) pairs for the training split:
# files under 'pos' are labelled 1, files under 'neg' are labelled 0.
review_pairs = [
    (filepath, label)
    for folder, label in (('pos', 1), ('neg', 0))
    for filepath in glob.glob(os.path.join('aclImdb', 'train', folder, '*'))
]

# Quick sanity check: show the first two pairs and the total count.
print(review_pairs[:2])
print(f"Total reviews:{len(review_pairs)}")

[('aclImdb/train/pos/6217_8.txt', 1), ('aclImdb/train/pos/6231_10.txt', 1)]
Total reviews:25000


In [27]:
def load_review(review_path):
    """Read one review file and return its text as a list of tokens.

    Parameters
    ----------
    review_path: str
        Path of the UTF-8 encoded review file to read.

    Returns
    -------
    list
        Tokens produced by ``nltk.word_tokenize`` after every non-word
        character in the text has been replaced by a space.
    """
    with open(review_path, encoding='utf-8') as handle:
        raw_text = handle.read()

    # Replace every non-word character (punctuation, markup symbols, ...) with
    # a space so that only word characters reach the tokenizer.
    cleaned_text = re.sub(r'\W', ' ', raw_text)

    return nltk.word_tokenize(cleaned_text)

def generate_vec(review, vocab_dic):
    """Convert a tokenized review into a list of vocabulary indices.

    Parameters
    ----------
    review: iterable of str
        Tokens of one review, in order.
    vocab_dic: dict
        Mapping from word to integer index.

    Returns
    -------
    list
        The index of every token that is present in ``vocab_dic``, in the
        original token order. Tokens missing from the mapping are skipped.
    """
    # Look each word up once and compare against None instead of relying on
    # truthiness: index 0 is a valid entry and must not be discarded together
    # with the unknown words.
    # NOTE(review): batches are zero-padded in collate_fn, so index 0 shares its
    # value with padding; use the returned lengths to tell them apart.
    looked_up = (vocab_dic.get(word) for word in review)
    idx_vec = [idx for idx in looked_up if idx is not None]

    return idx_vec

In [28]:
class dataset(Dataset):
    '''Dataset that reads a review from disk on access and pairs it with its label.

    Parameters
    ----------
    data_dirs: list
        (review_path, label) pairs, one entry per review
    vocab: dict
        word-to-index mapping handed to generate_vec
    '''
    def __init__(self, data_dirs, vocab):
        self.vocab = vocab
        self.data_dirs = data_dirs

    def __len__(self):
        # One sample per (path, label) pair.
        return len(self.data_dirs)

    def __getitem__(self, idx):
        path, target = self.data_dirs[idx]
        # Tokenize the file, then keep only the indices of known words.
        tokens = load_review(path)

        return generate_vec(tokens, self.vocab), target
    

# Custom collate_fn: zero-pad reviews of different lengths to a common length
def collate_fn(batch):
    """Merge a list of (index_vector, label) samples into batch tensors.

    Parameters
    ----------
    batch: list
        (index_vector, label) pairs as returned by the dataset.

    Returns
    -------
    tuple
        ``(reviews, labels, lengths)`` where ``reviews`` is a LongTensor with
        one row per sample, right-padded with 0 to the longest review in the
        batch, ``labels`` is a LongTensor of the labels, and ``lengths`` is a
        LongTensor holding each review's length before padding.
    """
    reviews, labels = zip(*batch)

    # batch_first=True puts the batch dimension first in the padded tensor.
    padded_reviews = pad_sequence(
        [torch.LongTensor(tokens) for tokens in reviews],
        batch_first=True,
        padding_value=0,
    )
    label_tensor = torch.LongTensor(labels)
    length_tensor = torch.LongTensor(list(map(len, reviews)))

    return padded_reviews, label_tensor, length_tensor

In [29]:
# Wrap the (path, label) pairs in the custom dataset; reviews are read per item.
custom_dataset = dataset(data_dirs=review_pairs, vocab=vocab_dic)

# Draw samples in random order, four per batch, padded by collate_fn.
custom_dataloader = DataLoader(
    custom_dataset,
    batch_size=4,
    sampler=RandomSampler(custom_dataset),
    collate_fn=collate_fn,
)

# Pull a single batch to inspect the (reviews, labels, lengths) tensors.
next(iter(custom_dataloader))

(tensor([[ 2736,    78,   261,    78, 11250,   981,  6606,   247,  1167,  1283,
           3662,  3120,  1130,  2196,   338,   981,    15, 25205,    78,    11,
             65,    51,  1662,  1777,  7138,  4059,  1681,    78,    84,   704,
           1167,  3207, 19752, 17304,    78,   263,  7337,  2186,  1900,  5511,
          15029,    78,   100,    65,   506,   297,  2223,  6303,  4100,   625,
            293,    98,    75,    49,     4,    17,  5398,    65,   734,  4862,
            572,  6827,    98,  4411, 21732,    49,   306, 75154, 75154,    27,
           1033,   352,     2,   733,   590,   628,   215,  2040,   432,    49,
            514,  1389,  1413,   686, 75154, 75154,    18,   145,   568,   352,
            476,  4543,  1606,  1606,  1606],
         [   42,   265,  4186,  2996,  4303,  2895, 13254,  1105,    14,    15,
             15,  7806,   208,    71,   457,  8055,   166,    77,    42,   422,
             89,   242,     4,    98,   218,  1213,   852,   321,    59,  