The IMDB dataset comes from https://ai.stanford.edu/~amaas/data/sentiment/

In [None]:
import os
import re
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

**Download Dataset**

In [None]:
!wget https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz

--2022-04-03 22:11:21--  https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
Resolving ai.stanford.edu (ai.stanford.edu)... 171.64.68.10
Connecting to ai.stanford.edu (ai.stanford.edu)|171.64.68.10|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84125825 (80M) [application/x-gzip]
Saving to: ‘aclImdb_v1.tar.gz’


2022-04-03 22:11:30 (10.2 MB/s) - ‘aclImdb_v1.tar.gz’ saved [84125825/84125825]



In [None]:
!tar -zxvf aclImdb_v1.tar.gz

**Tokenization**  
Filter out special characters and return a list whose elements are the individual words

In [None]:
def tokenlize(content):
  '''Split raw review text into lowercase word tokens.

  HTML tags (e.g. ``<br />``) and a fixed set of punctuation / control
  characters are replaced by spaces, then the text is split on whitespace.

  :param content: str, raw text
  :return: list of lowercase tokens (empty list for empty input)
  '''
  # strip HTML tags
  content = re.sub(r'<.*?>', ' ', content)
  # Characters to blank out. Each one is passed through re.escape so that
  # regex metacharacters ('.' and '$') are matched literally; an unescaped
  # '$' would only match the end of the string and never a dollar sign.
  filters = ['.', ':', '\t', '\n', '\x97', '\x96', '#', '$', '%', '&']
  pattern = '|'.join(re.escape(ch) for ch in filters)
  content = re.sub(pattern, ' ', content)
  # str.split() already drops surrounding whitespace, no strip() needed
  return [token.lower() for token in content.split()]

In [None]:
class ImdbDataset(Dataset):
  '''IMDB sentiment dataset that reads one review file per item.

  The chosen split directory must contain a ``pos`` and a ``neg``
  sub-directory; every ``.txt`` file inside them is one sample.

  :param train_path: directory of the training split
  :param test_path: directory of the test split
  :param train: use the training split if True, otherwise the test split
  '''
  def __init__(self, train_path, test_path, train=True):
    self.train_data_path = train_path
    self.test_data_path = test_path
    data_path = self.train_data_path if train else self.test_data_path

    # collect the path of every review file, positives first
    self.total_file_path = []
    for label_dir in ('pos', 'neg'):
      path = os.path.join(data_path, label_dir)
      # os.listdir returns names in arbitrary order; sorting keeps the
      # index -> file mapping stable between runs
      self.total_file_path.extend(
          os.path.join(path, file_name)
          for file_name in sorted(os.listdir(path))
          if file_name.endswith('.txt'))

  def __getitem__(self, index):
    '''
    :param index: int, position in the file list
    :return: (tokens, label) where label is 0 for 'neg' and 1 otherwise
    '''
    file_path = self.total_file_path[index]
    # the label is the name of the parent directory; using os.path keeps this
    # working with any path separator
    label_str = os.path.basename(os.path.dirname(file_path))
    label = 0 if label_str == 'neg' else 1
    # read the review and close the file right away
    with open(file_path, encoding='utf-8') as f:
      content = f.read()
    tokens = tokenlize(content)
    return tokens, label

  def __len__(self):
    '''Number of review files in the selected split.'''
    return len(self.total_file_path)

In [None]:
imdb_dataset = ImdbDataset(train_path='/content/aclImdb/train', test_path='/content/aclImdb/test')
# The tensor-producing collate_fn is only defined later in the notebook (it
# needs the fitted vocabulary), so referencing it here raises NameError on a
# fresh kernel. For this first look at the raw data, just regroup each batch
# into a tuple of token lists and a tuple of labels.
data_loader = DataLoader(imdb_dataset, batch_size=128, shuffle=True,
                         collate_fn=lambda batch: tuple(zip(*batch)))

In [None]:
# Peek at the first batch only. The loop variables are named so that the
# built-in input() is not shadowed for the rest of the session.
for idx, (batch_tokens, batch_labels) in enumerate(data_loader):
  print(idx)
  print(batch_tokens)
  print(batch_labels)
  break

0
(1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1)


**Word to Sequence**

In [None]:
class Word2Sequence():
  '''Vocabulary that maps words to integer ids and back.

  Typical use: call fit() on every tokenized sentence, call build_vocab()
  once, then use transform() / inverse_transform().
  Ids 0 and 1 are reserved for the unknown-word and padding tags.
  '''
  UNK_TAG = 'UNK'
  PAD_TAG = 'PAD'

  UNK = 0
  PAD = 1

  def __init__(self):
    # word -> id
    self.dict = {
        self.UNK_TAG : self.UNK,
        self.PAD_TAG : self.PAD
    }
    # word -> number of occurrences seen by fit()
    self.count = {}
    # id -> word; created here so inverse_transform works before build_vocab
    self.inversed_dict = {idx : word for word, idx in self.dict.items()}

  def __len__(self):
    '''Vocabulary size, including the two reserved tags.'''
    return len(self.dict)

  def fit(self, sentence):
    '''Add the words of one sentence to the frequency counts
    :param sentence: [word1, word2, word3 ...]
    '''
    for word in sentence:
      self.count[word] = self.count.get(word, 0) + 1

  def build_vocab(self, min=None, max=None, max_features=None):
    '''
    Build the word <-> id dictionaries from the counts gathered by fit()
    :param min: drop words whose frequency is less than min
    :param max: drop words whose frequency is greater than max
    :param max_features: keep only the max_features most frequent words
    :return: None
    '''
    # delete words whose frequency is less than min (min itself is kept)
    if min is not None:
      self.count = {word : value for word, value in self.count.items() if value >= min}

    # delete words whose frequency is greater than max (max itself is kept)
    if max is not None:
      self.count = {word : value for word, value in self.count.items() if value <= max}

    # limit the number of words in the vocabulary
    if max_features is not None:
      temp = sorted(self.count.items(), key=lambda x : x[-1], reverse=True)[:max_features]
      self.count = dict(temp)

    for word in self.count:
      # skip words that already have an id, so calling build_vocab again (or
      # a token equal to a reserved tag) never renumbers existing entries
      if word not in self.dict:
        self.dict[word] = len(self.dict)

    self.inversed_dict = dict(zip(self.dict.values(), self.dict.keys()))

  def transform(self, sentence, max_len=None):
    '''
    Sentence 2 Sequence
    :param sentence: [word1, word2, ...]; never modified by this method
    :param max_len: int, pad with PAD / truncate the sentence to this length;
                    None keeps the original length
    :return: list of ids, unknown words map to UNK
    '''
    # work on a copy: padding must not leak into the caller's list
    tokens = list(sentence)
    if max_len is not None:
      tokens = tokens[:max_len]
      # a non-positive multiplier yields an empty list, i.e. no padding
      tokens += [self.PAD_TAG]*(max_len-len(tokens))

    return [self.dict.get(word, self.UNK) for word in tokens]

  def inverse_transform(self, indices):
    '''
    Sequence 2 Sentence
    :param indices: [1, 2, 3, 4, ...]
    :return: list of words; None for ids that are not in the vocabulary
    '''
    return [self.inversed_dict.get(idx) for idx in indices]

Demo Word2Sequence

In [None]:
# Shared settings: sequence length used to size the model input, and the
# number of samples per DataLoader batch.
max_len, batch_size = 20, 128

In [None]:
ws = Word2Sequence()
# The sentences are bound to `text`, not `str`, so the built-in str type is
# not shadowed for the rest of the kernel session.
text = 'I like machine learning'
tok = tokenlize(text)
ws.fit(tok)

text = "Today's weather is good"
tok = tokenlize(text)
ws.fit(tok)
ws.build_vocab()
print(ws.dict)

{'UNK': 0, 'PAD': 1, 'i': 2, 'like': 3, 'machine': 4, 'learning': 5, "today's": 6, 'weather': 7, 'is': 8, 'good': 9}


In [None]:
# `text` instead of `str` keeps the built-in str type usable afterwards
text = "I like today's beautiful weather"
tok = tokenlize(text)
print(tok)
# 'beautiful' was never fitted, so it is encoded as UNK; the sequence is
# padded with PAD up to max_len
ret = ws.transform(tok, max_len=10)
print(ret)

['i', 'like', "today's", 'beautiful', 'weather']
[2, 3, 6, 0, 7, 1, 1, 1, 1, 1]


In [None]:
# Keep the decoded words under their own name: overwriting `ret` would make
# this cell give a different result when it is run a second time.
decoded = ws.inverse_transform(ret)
print(decoded)

['i', 'like', "today's", 'UNK', 'weather', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD']


In [None]:
# drop the demo vocabulary; the name is rebound to a new instance afterwards
del ws

**Apply Word2Sequence on IMDB**

In [None]:
# empty vocabulary that the following cells fit on the IMDB training reviews
ws = Word2Sequence()

In [None]:
# Root of the training split and its two label sub-directories.
path = '/content/aclImdb/train'
temp_data_path = [os.path.join(path, label_dir) for label_dir in ('pos', 'neg')]

In [None]:
# Count word frequencies over every training review.
for data_path in temp_data_path:
  file_paths = [os.path.join(data_path, file_name) for file_name in os.listdir(data_path) if file_name.endswith('.txt')]
  for file_path in file_paths:
    # read and close each file explicitly instead of leaking the handle
    with open(file_path, encoding='utf-8') as f:
      sentence = tokenlize(f.read())
    # fit() must run inside the inner loop; one level further out it would
    # only see the last file of each directory
    ws.fit(sentence)

In [None]:
# filter out rare words (frequency threshold 10) and cap the vocabulary at the
# 10000 most frequent remaining words
ws.build_vocab(min=10, max_features=10000)

In [None]:
# Persist the fitted vocabulary so it can be reloaded without re-reading the
# dataset. makedirs(exist_ok=True) replaces the separate exists/mkdir pair,
# and the with-block guarantees the file is flushed and closed.
os.makedirs('./model', exist_ok=True)
with open('./model/ws.pkl', 'wb') as f:
  pickle.dump(ws, f)

**Model**

In [None]:
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def collate_fn(batch):
  '''
  Turn a list of dataset items into two tensors.
  Every token list is encoded with the global vocabulary `ws` and padded or
  truncated to the global `max_len`, the same constant the model is sized with.
  :param batch: (retVal of getitem [tokens,label], retVal of getitem ...)
  :return: (content, label) as LongTensors; content has one row of max_len ids
           per sample
  '''
  content, label = zip(*batch)
  # hand transform() a copy so padding can never modify the dataset's own
  # token lists
  content = [ws.transform(list(tokens), max_len=max_len) for tokens in content]
  content = torch.LongTensor(content)
  label = torch.LongTensor(label)
  return content, label

In [None]:
# drop the earlier loader before the name is rebound below
del data_loader

In [None]:
# Loader used for training: shuffled batches collated by collate_fn.
data_loader = DataLoader(
    imdb_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
class MyModel(nn.Module):
  '''Embedding layer followed by one linear classifier over the flattened
  sequence.

  All sizes are fixed at construction time. Called without arguments the
  model is sized from the notebook globals `ws` and `max_len`, as before.

  :param vocab_size: number of embeddings; defaults to len(ws)
  :param embedding_dim: size of each word vector
  :param seq_len: number of tokens per sample; defaults to max_len
  :param num_classes: number of output classes
  '''
  def __init__(self, vocab_size=None, embedding_dim=100, seq_len=None, num_classes=2):
    super().__init__()
    # remember the shape so forward() does not depend on globals that may
    # change after the model has been built
    self.seq_len = max_len if seq_len is None else seq_len
    self.embedding_dim = embedding_dim
    num_embeddings = len(ws) if vocab_size is None else vocab_size
    self.embedding = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    self.fc = nn.Linear(self.seq_len * embedding_dim, num_classes)

  def forward(self, input):
    '''
    :param input: [batch_size, seq_len] tensor of word ids
    :return: [batch_size, num_classes] log-probabilities
    '''
    x = self.embedding(input) # [batch_size, seq_len, embedding_dim]
    # flatten every sample into a single vector for the linear layer
    x = x.view([-1, self.seq_len * self.embedding_dim])
    out = self.fc(x)
    return F.log_softmax(out, dim=-1)

**Train Model**

In [None]:
from torch.optim import Adam

In [None]:
# build the classifier with its default sizes
model = MyModel()

In [None]:
# Adam over all model parameters; the learning rate is passed by keyword
optimizer = Adam(model.parameters(), lr=0.001)

In [None]:
# One pass over the training data, printing the loss of every batch.
model.train()  # explicit, so the cell still trains if the model was switched to eval mode
# `batch_input` avoids shadowing the built-in input(); the batch index is not
# needed, so the loader is iterated directly
for batch_input, target in data_loader:
  optimizer.zero_grad()
  output = model(batch_input)
  # the model returns log-probabilities, hence nll_loss
  loss = F.nll_loss(output, target)
  loss.backward()
  optimizer.step()
  print(loss.item())

0.7185078859329224
1.2328746318817139
0.796586275100708
0.7791052460670471
0.9869107604026794
0.8765200972557068
0.7295302152633667
0.7161595225334167
0.8665590286254883
0.9150087237358093
0.6838130354881287
0.7077015042304993
0.7546396255493164
0.7585675120353699
0.6838388442993164
0.7083402872085571
0.7199370861053467
0.7295805215835571
0.7117388248443604
0.6869719624519348
0.6886753439903259
0.7148804068565369
0.7039846777915955
0.6960114240646362
0.7123207449913025
0.7073391079902649
0.7007230520248413
0.6972960233688354
0.7043099999427795
0.694233775138855
0.6860768795013428
0.6889455914497375
0.697074294090271
0.6978955268859863
0.7163321375846863
0.6948528289794922
0.69910728931427
0.7097565531730652
0.6877046823501587
0.6988693475723267
0.6951066255569458
0.6874724626541138
0.6869089603424072
0.6892951130867004
0.7093034386634827
0.6888047456741333
0.7051665782928467
0.7027527093887329
0.7002254724502563
0.7002366781234741
0.6896883845329285
0.705477237701416
0.7168882489204407