### 作業目的: 熟練自定義collate_fn與sampler進行資料讀取

本此作業主要會使用[IMDB](http://ai.stanford.edu/~amaas/data/sentiment/)資料集利用Pytorch的Dataset與DataLoader進行
客製化資料讀取。
下載後的資料有分成train與test，因為這份作業目的在讀取資料，所以我們取用train部分來進行練習。
(請同學先行至IMDB下載資料)

### 載入套件

In [44]:
# Import torch and other required modules
import glob
import torch
import re
import nltk
import numpy as np
from torch.utils.data import Dataset, DataLoader, RandomSampler
from sklearn.datasets import load_svmlight_file
from nltk.corpus import stopwords

nltk.download('stopwords') #下載stopwords
nltk.download('punkt') #下載word_tokenize需要的corpus

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\USER\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\USER\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

### 探索資料與資料前處理
這份作業我們使用test資料中的pos與neg


In [2]:
# 讀取字典，這份字典為review內所有出現的字詞
with open('D:/aclImdb_v1/imdb.vocab', 'r',encoding='utf-8') as f:
    vocab = [line.strip() for line in f.readlines()]
    
# 以nltk stopwords移除贅字，過多的贅字無法提供有用的訊息，也可能影響模型的訓練
print(f"vocab length before removing stopwords: {len(vocab)}")
stop_w = set(stopwords.words('english'))
vocab = [i for i in vocab if i not in stop_w]
print(f"vocab length after removing stopwords: {len(vocab)}")

# 將字典轉換成dictionary
vocab_dic = {word:idx for idx, word in enumerate(vocab)}

vocab length before removing stopwords: 89527
vocab length after removing stopwords: 89356


In [56]:
# 將資料打包成(x, y)配對，其中x為review的檔案路徑，y為正評(1)或負評(0)
# 這裡將x以檔案路徑代表的原因是讓同學練習不一次將資料全讀取進來，若電腦記憶體夠大(所有資料檔案沒有很大)
# 可以將資料全一次讀取，可以減少在訓練時I/O時間，增加訓練速度
import os
review_pairs = []
for folder,label in [('pos',1),('neg',0)]:
    filepaths = glob.glob(os.path.join('D:','aclImdb_v1','test',folder,'*'))
    for i in filepaths:
        review_pairs.append((i,label))
print(review_pairs[:2])
print(f"Total reviews: {len(review_pairs)}")

[('D:aclImdb_v1\\test\\pos\\0_10.txt', 1), ('D:aclImdb_v1\\test\\pos\\10000_7.txt', 1)]
Total reviews: 25000


### 建立Dataset, DataLoader, Sampler與Collate_fn讀取資料
這裡我們會需要兩個helper functions，其中一個是讀取資料與清洗資料的函式(load_review)，另外一個是生成詞向量函式
(generate_vec)，注意這裡我們用來產生詞向量的方法是單純將文字tokenize(為了使產生的文本長度不同，而不使用BoW)

In [84]:
def load_review(review_path):
    with open(review_path, 'r',encoding='utf-8') as f:
        review = f.read()
 
    #移除non-alphabet符號、贅字與tokenize
    review = nltk.word_tokenize(re.sub(r'\W', ' ', review ))
    review = [word for word in review if word not in stop_w]
    return review

def generate_vec(review, vocab_dic):
    vec = [vocab_dic[i] for i in review if vocab_dic.get(i)]
    return np.asarray(vec)

In [86]:
#建立客製化dataset

class dataset(Dataset):
    '''custom dataset to load reviews and labels
    Parameters
    ----------
    data_pairs: list
        directory of all review-label pairs
    vocab: list
        list of vocabularies
    '''
    def __init__(self, data_dirs, vocab):
        self.data_dirs = data_dirs
        self.vocab = vocab

    def __len__(self):
        return len(self.data_dirs)
    
    def __getitem__(self, idx):
        review_path, label = self.data_dirs[idx]
        gec_vec = generate_vec(load_review(review_path),self.vocab)
        return gec_vec,label
    

#建立客製化collate_fn，將長度不一的文本pad 0 變成相同長度
def collate_fn(batch):
    corpus, labels = zip(*batch) 
    ### create pads for corpus ###
    lengths = [len(x) for x in corpus]
    max_length = max(lengths)
    
    batch_corpus = []
    
    for i in range(len(corpus)):
        # pad corpus
        tmp_pads = torch.zeros(max_length)
        tmp_pads[:lengths[i]] = torch.from_numpy(corpus[i])
        tmp_pads.view(-1, 1)
        batch_corpus.append(tmp_pads.view(1,-1))

    return torch.cat(batch_corpus,dim=0), torch.tensor(labels) , torch.tensor(lengths)

In [87]:
# 使用Pytorch的RandomSampler來進行indice讀取並建立dataloader
custom_dst = dataset(review_pairs, vocab_dic)
custom_dataloader = DataLoader(dataset=custom_dst, batch_size=4, sampler=RandomSampler(custom_dst), collate_fn=collate_fn)
next(iter(custom_dataloader))

(tensor([[1.2800e+03, 1.2600e+02, 1.3570e+03, 1.8000e+01, 8.2760e+03, 6.3284e+04,
          4.2200e+02, 1.5400e+02, 1.4166e+04, 4.3880e+03, 1.6660e+03, 2.2900e+02,
          9.6350e+03, 7.5154e+04, 7.5154e+04, 1.8600e+02, 8.3100e+02, 5.7050e+03,
          9.6700e+02, 5.6700e+03, 7.3770e+03, 4.5800e+03, 3.8000e+01, 4.1000e+02,
          1.6200e+02, 2.5030e+03, 9.0100e+02, 2.9170e+03, 1.3300e+02, 1.3170e+03,
          1.6000e+01, 3.9310e+03, 1.8600e+02, 1.0701e+04, 8.5000e+01, 1.6200e+02,
          1.9080e+03, 9.0100e+02, 2.5100e+02, 3.3300e+02, 3.8020e+03, 1.9550e+03,
          8.2000e+02, 1.3000e+01, 1.8390e+03, 4.3040e+03, 6.3800e+02, 2.9150e+03,
          7.5970e+03, 1.2610e+03, 4.3040e+03, 2.8930e+03, 3.5700e+02, 1.5000e+02,
          3.3740e+03, 2.2600e+02, 2.7000e+01, 3.4200e+02, 3.0310e+03, 5.3000e+02,
          1.6400e+02, 1.7850e+03, 1.6200e+02, 7.5300e+02, 1.2402e+04, 7.5154e+04,
          7.5154e+04, 7.2230e+03, 6.8700e+02, 7.4000e+01, 3.4300e+02, 2.3200e+02,
          3.3500