## prepare data

### Step 1: tokenize

In [16]:
import torch
import os
from torch.utils.data import Dataset, DataLoader
import re
import word2seq
import pickle
import torch.nn as nn
import torch.nn.functional as F 
from torch.optim import Adam

# Load the pickled vocabulary object used by collate_fn (ws.transform) and
# Model (len(ws)). The `with` block closes the file handle, which the
# previous bare open() call leaked.
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted,
# locally produced ws.pkl. Unpickling presumably needs the word2seq module
# importable (it is imported above) — confirm.
with open('./model/ws.pkl', 'rb') as ws_file:
    ws = pickle.load(ws_file)

# Characters replaced by a space before splitting. A translation table matches
# them literally. The old regex join left "$" unescaped, where it acts as an
# end-of-string anchor, so dollar signs were never removed.
_TOKEN_FILTER_TABLE = str.maketrans(dict.fromkeys(':\t\n\x97\x96#$%&.', ' '))


def tokenize(content):
    """Split raw review text into lower-cased word tokens.

    Steps:
      1. Replace HTML-like tags (e.g. ``<br />``) with a space.
      2. Expand "'s" -> " is" and "'m" -> " am". This also rewrites
         possessives such as "John's" -> "John is".
      3. Replace the characters ``: \\t \\n \\x97 \\x96 # $ % & .`` with spaces.
      4. Split on whitespace and lower-case each token.

    Args:
        content: The raw text of one review.

    Returns:
        A list of token strings. The list is empty for empty or
        whitespace-only input.
    """
    content = re.sub(r"<.*?>", " ", content)
    content = re.sub("'s", " is", content)
    content = re.sub("'m", " am", content)
    content = content.translate(_TOKEN_FILTER_TABLE)
    # str.split() with no argument already drops surrounding whitespace,
    # so no per-token strip() is needed.
    return [token.lower() for token in content.split()]

class IMDBDataset(Dataset):
    """IMDB review dataset read from the aclImdb directory layout.

    Expects ``<root>/{train,test}/{pos,neg}/*.txt``. Each item is a
    ``(tokens, label)`` pair. ``tokens`` comes from ``tokenize``. ``label``
    is 0 for files in a ``neg`` directory and 1 otherwise.

    Args:
        train: Read the ``train`` split when True, the ``test`` split otherwise.
        root: Dataset root directory (default ``'./data/aclImdb'``, as before).
    """

    def __init__(self, train=True, root='./data/aclImdb'):
        super(IMDBDataset, self).__init__()
        self.train_path = os.path.join(root, 'train')
        self.test_path = os.path.join(root, 'test')
        data_path = self.train_path if train else self.test_path
        self.total_file_path = []  # every review file path: pos first, then neg
        for sub_dir in ('pos', 'neg'):
            dir_path = os.path.join(data_path, sub_dir)
            # os.listdir order is arbitrary; sort so the index -> file mapping is stable.
            self.total_file_path.extend(
                os.path.join(dir_path, name)
                for name in sorted(os.listdir(dir_path))
                if name.endswith('.txt')
            )

    @staticmethod
    def _label_from_path(file_path):
        """Return 0 for a file whose parent directory is ``neg``, else 1."""
        # Splitting on a hard-coded '\\' only worked with Windows separators
        # (it raised IndexError on POSIX). os.path handles the platform's separator.
        parent = os.path.basename(os.path.dirname(file_path))
        return 0 if parent == 'neg' else 1

    def __getitem__(self, index):
        file_path = self.total_file_path[index]
        label = self._label_from_path(file_path)
        with open(file_path, 'r', encoding='UTF-8') as data:
            tokens = tokenize(data.read())
        return tokens, label

    def __len__(self):
        return len(self.total_file_path)

In [17]:
max_len = 20  # sequence length passed to ws.transform; Model sizes its linear layer from it


def collate_fn(batch):
    """Collate a list of ``(tokens, label)`` samples into two LongTensors.

    Each token list goes through ``ws.transform(..., max_len=max_len)``.
    Presumably that pads/truncates to exactly ``max_len`` ids (TODO confirm
    in word2seq); ``torch.LongTensor`` needs equal-length rows.

    Returns:
        ``(content, label)``: the stacked id rows and the label vector.
    """
    token_lists, labels = zip(*batch)
    id_rows = [ws.transform(tokens, max_len=max_len) for tokens in token_lists]
    return torch.LongTensor(id_rows), torch.LongTensor(labels)

def get_dataloader(train=True, batch_size=128, shuffle=True):
    """Build a DataLoader over the IMDB train or test split.

    Args:
        train: Use the training split when True, the test split otherwise.
        batch_size: Reviews per batch (default 128, the previous hard-coded value).
        shuffle: Reshuffle sample order each epoch (default True, as before).

    Returns:
        A ``DataLoader`` whose batches are the ``(content, label)`` pairs
        produced by ``collate_fn``.
    """
    imdb = IMDBDataset(train)
    return DataLoader(imdb, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn)

In [18]:
# Sanity check: print only the first training batch (index, inputs, labels).
for batch_idx, (batch_inputs, batch_targets) in enumerate(get_dataloader()):
    print(batch_idx, batch_inputs, batch_targets)
    break

0 tensor([[  259,   328,     3,  ...,     3,    13,  1319],
        [  729,   618,   858,  ...,  2037,  3692, 12959],
        [14146, 14147,     3,  ...,     0,   798,  8487],
        ...,
        [13028, 18525,  1195,  ...,  4744,    24,    10],
        [   10,   373,  1334,  ...,  1976,   504,   127],
        [ 1778,  1373,  1092,  ...,  2363,  4139,    24]]) tensor([1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1,
        1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1,
        1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1,
        1, 0, 0, 1, 0, 0, 1, 0])


### create model

In [19]:
class Model(nn.Module):
    """Embed every token id, flatten per sample, and apply one linear layer.

    The network returns log-probabilities over 2 classes, which pair with
    ``F.nll_loss``.

    Args:
        vocab_size: Embedding table size; defaults to ``len(ws)``.
        seq_len: Token ids per sample; defaults to the module-level ``max_len``.
        embedding_dim: Size of each token embedding (default 100, as before).
    """

    def __init__(self, vocab_size=None, seq_len=None, embedding_dim=100):
        super(Model, self).__init__()
        if vocab_size is None:
            vocab_size = len(ws)
        if seq_len is None:
            seq_len = max_len
        self.seq_len = seq_len
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(seq_len * embedding_dim, 2)

    def forward(self, Input):
        """Map a (batch, seq_len) LongTensor of ids to (batch, 2) log-probabilities."""
        x = self.embedding(Input)  # (batch, seq_len, embedding_dim)
        # Flatten per sample, keeping the batch dimension explicit. With
        # view(-1, seq_len*dim), a wrong sequence length could silently regroup
        # tokens into a different number of rows. Now it fails in self.fc.
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        return F.log_softmax(out, dim=-1)

In [24]:
model = Model()
optimizer = Adam(model.parameters(), lr=0.001)


def train(epoch):
    """Run one optimization pass over the training split.

    The loss of every 10th batch is printed. ``epoch`` is accepted for the
    caller's loop but is not used inside the function.
    """
    loader = get_dataloader(train=True)
    for step, (batch_inputs, batch_targets) in enumerate(loader):
        log_probs = model(batch_inputs)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(log_probs, batch_targets)
        batch_loss.backward()
        optimizer.step()
        if not step % 10:
            print(batch_loss.item())


for epoch_idx in range(1):
    train(epoch_idx)

0.7546595931053162
0.7508937120437622
0.7551367878913879
0.7034208178520203
0.7682862877845764
0.7443813681602478
0.7299940586090088
0.6973925232887268
0.6996548175811768
0.7131869196891785
0.685352623462677
0.7099332213401794
0.6893857717514038
0.7115482687950134
0.7044959664344788
0.6678099036216736
0.6535273790359497
0.6993693113327026
0.713651716709137
0.6516751050949097
