### 作業目的: 熟練自定義collate_fn與sampler進行資料讀取

本此作業主要會使用[IMDB](http://ai.stanford.edu/~amaas/data/sentiment/)資料集利用Pytorch的Dataset與DataLoader進行
客製化資料讀取。
下載後的資料有分成train與test，因為這份作業目的在讀取資料，所以我們取用train部分來進行練習。
(請同學先行至IMDB下載資料)

### 載入套件

In [1]:
# Import torch and other required modules
import glob
import torch
import re
import nltk
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import load_svmlight_file
from nltk.corpus import stopwords

nltk.download('stopwords') #下載stopwords
nltk.download('punkt') #下載word_tokenize需要的corpus

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\user\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\user\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

### 探索資料與資料前處理
這份作業我們使用test資料中的pos與neg


In [8]:
# 讀取字典，這份字典為review內所有出現的字詞
###<your code>###
with open('.\\aclImdb\\imdb.vocab','r',encoding = 'utf-8') as f:
    vocab = f.read()

vocab = vocab.split('\n')

# 以nltk stopwords移除贅字，過多的贅字無法提供有用的訊息，也可能影響模型的訓練
print(f"vocab length before removing stopwords: {len(vocab)}")
vocab = list(set(vocab).difference(stopwords.words('english')))
print(f"vocab length after removing stopwords: {len(vocab)}")

# 將字典轉換成dictionary
### <your code> ###
vocab_dict = dict(zip(vocab,range(len(vocab))))

vocab length before removing stopwords: 89527
vocab length after removing stopwords: 89356


{'second-to-last': 0,
 'unguarded': 1,
 'cheap-o': 2,
 'nimh': 3,
 'exeter': 4,
 'over-powerful': 5,
 'documentary-not': 6,
 'melato': 7,
 'tries': 8,
 'louie': 9,
 'drama-less': 10,
 'clone': 11,
 'corpsified': 12,
 'zering': 13,
 'alike': 14,
 'intellectual-wannabe': 15,
 'structuralists': 16,
 'may-june': 17,
 'mafia': 18,
 'presaging': 19,
 'okw': 20,
 'bodyguards': 21,
 'fail': 22,
 'rectify': 23,
 'straightening': 24,
 'rockabilly': 25,
 'half-assed': 26,
 'lamps': 27,
 'crazy-hippie-cult-killing-spree': 28,
 'honestly': 29,
 'doy': 30,
 'branson': 31,
 'investigator': 32,
 'archival': 33,
 'tried-and-true': 34,
 'pfeh': 35,
 'drivel': 36,
 'conor': 37,
 'counterespionage': 38,
 'ivana': 39,
 'wll': 40,
 'clean-shaven': 41,
 'dango': 42,
 'snow-capped': 43,
 'nonchalantly': 44,
 'peevishness': 45,
 'splendiferously': 46,
 'cortez': 47,
 'mentirosos': 48,
 'do-do': 49,
 'toyoko': 50,
 'chirin': 51,
 'diry': 52,
 'cute-but-obnoxious': 53,
 'correct': 54,
 'observers': 55,
 'rachell

In [9]:
# 將資料打包成(x, y)配對，其中x為review的檔案路徑，y為正評(1)或負評(0)
# 這裡將x以檔案路徑代表的原因是讓同學練習不一次將資料全讀取進來，若電腦記憶體夠大(所有資料檔案沒有很大)
# 可以將資料全一次讀取，可以減少在訓練時I/O時間，增加訓練速度

### <your code> ###
pos = glob.glob(".\\aclImdb\\train\\pos\\*.txt")
neg = glob.glob("./aclImdb/test/neg/*.txt")
review = pos + neg
y =[1]*len(pos)+[0]*len(neg)

review_pairs = list(zip(review,y))

print(review_pairs[:2])
print(f"Total reviews: {len(review_pairs)}")

[('.\\aclImdb\\train\\pos\\0_9.txt', 1), ('.\\aclImdb\\train\\pos\\10000_8.txt', 1)]
Total reviews: 25000


### 建立Dataset, DataLoader, Sampler與Collate_fn讀取資料
這裡我們會需要兩個helper functions，其中一個是讀取資料與清洗資料的函式(load_review)，另外一個是生成詞向量函式
(generate_vec)，注意這裡我們用來產生詞向量的方法是單純將文字tokenize(為了使產生的文本長度不同，而不使用BoW)

In [33]:
def load_review(review_path):
    
    ###<your code>###
    
    with open(review_path,"r",encoding = "utf-8") as f:
        review = f.read()
    
    #移除non-alphabet符號、贅字與tokenize
    ###<your code>###
    review = re.sub('[^A-Za-z]',' ',review)
    review = nltk.word_tokenize(review)
    review = set(review).difference(set(stopwords.words('english')))
    
    return review    
    

def generate_vec(review, vocab_dict):
    ### <your code> ###
    bag_vector = []
    for word in review:
        if vocab_dict.get(word):
            bag_vector.append(vocab_dict.get(word))
            
    return torch.tensor(bag_vector)    

In [44]:
#建立客製化dataset

class dataset(Dataset):
    '''custom dataset to load reviews and labels
    Parameters
    ----------
    data_pairs: list
        directory of all review-label pairs
    vocab: list
        list of vocabularies
    '''
    ### <your code> ###
    def __init__(self, data_dirs, vocab):
        ###<your code>###
        self.data_dirs = data_dirs
        self.vocab = vocab
        
    def __len__(self):
        ###<your code>###
        return len(self.data_dirs)

    def __getitem__(self, idx):
        ###<your code>###
        pairs = self.data_dirs[idx]
        review = pairs[0]
        review = load_review(review)
        review = generate_vec(review,self.vocab)
        return review, pairs[1]    

#建立客製化collate_fn，將長度不一的文本pad 0 變成相同長度
def collate_fn(batch):
    ### <your code> ###
    corpus, labels = zip(*batch)
    lengths = [len(x) for x in corpus]
    max_len = max(lengths)
    
    batch_corpus = []
    
    for i in range(len(corpus)):
        tmp_pads = torch.zeros(max_len)
        tmp_pads[:lengths[i]] = corpus[i]
#         tmp_pads.view(-1,1)
        batch_corpus.append(tmp_pads.view(1,-1))
    
    return torch.cat(batch_corpus,dim=0), torch.tensor(labels), torch.tensor(lengths)

In [45]:
# 使用Pytorch的RandomSampler來進行indice讀取並建立dataloader
### <your code> ###
custom_dst = dataset(review_pairs,vocab_dict)
custom_dataloader = DataLoader(custom_dst,collate_fn=collate_fn,batch_size=2,shuffle = True)
next(iter(custom_dataloader))

(tensor([[7.6628e+04, 1.0326e+04, 8.3439e+04, 7.2104e+04, 2.4067e+04, 5.9449e+04,
          8.0000e+00, 1.7700e+02, 2.9560e+03, 4.8474e+04, 1.8296e+04, 6.7040e+04,
          1.5480e+04, 4.4871e+04, 8.5534e+04, 3.5854e+04, 7.2307e+04, 7.7381e+04,
          1.3140e+04, 1.2077e+04, 3.6031e+04, 3.9798e+04, 8.6811e+04, 1.1585e+04,
          6.2504e+04, 4.1248e+04, 4.8188e+04, 5.7546e+04, 1.3167e+04, 6.8749e+04,
          6.4633e+04, 5.4835e+04, 7.9638e+04, 7.8491e+04, 5.6400e+02, 3.3133e+04,
          5.5049e+04, 2.4451e+04, 1.1313e+04, 7.2182e+04, 6.0592e+04, 1.0099e+04,
          5.1566e+04, 6.8276e+04, 5.6151e+04, 6.8113e+04, 1.3933e+04, 3.6468e+04,
          6.5205e+04, 7.8519e+04, 6.7115e+04, 5.3448e+04, 5.2419e+04, 7.8979e+04,
          3.3685e+04, 6.3281e+04, 6.9963e+04, 8.6260e+03, 2.4008e+04, 7.7483e+04,
          9.4850e+03, 5.5811e+04, 7.2925e+04, 1.2180e+04, 4.0607e+04, 7.6430e+04,
          6.2360e+03, 4.0976e+04, 7.4320e+03, 6.0979e+04, 2.3853e+04, 6.2415e+04,
          8.4623