In [1]:
import pickle
import torch
from torch import nn

In [2]:
# Names of the pre-trained embeddings; the second entry is the one selected below.
models = ['glove-wiki-gigaword-200', 'word2vec-google-news-300']
model_type = models[1]

# File-name prefixes, one per path listed (train, test, dev).
wv_path = [
    './data/train_data/train',
    './data/test_data/test',
    './data/dev_data/dev',
]

# Only the test prefix (index 1) is read in this cell.
test_prefix = wv_path[1]
label_file = test_prefix + f'_{model_type}_label_IMDB.pkl'
tensor_file = test_prefix + f'_{model_type}_tensor_IMDB.pt'

# Labels are stored as a pickle next to the tensor file.
with open(label_file, 'rb') as f:
    y_test = pickle.load(f)

X_test_tensor = torch.load(tensor_file)

In [3]:
# Sizes of axis 1 and axis 2 of the test tensor; later cells read these names.
wv_num = X_test_tensor.size(1)
max_length = X_test_tensor.size(2)

In [4]:
# Report the sizes of the first three axes of the test tensor.
test_dims = [X_test_tensor.size(axis) for axis in range(3)]
print(f"Shape of testing data {test_dims}")

Shape of testing data [1000, 300, 1506]


In [5]:
# model
class CNN(nn.Module):
    """Three-branch 1-D convolutional classifier.

    The same input is fed to three parallel ``Conv1d`` branches with kernel
    widths 3, 4 and 5. Each branch output goes through ReLU and is max-pooled
    over its entire length; the three pooled feature vectors are concatenated
    and passed through a linear layer, dropout and a softmax.

    Args:
        dim_in: Number of input channels of every convolution.
        dim_conv: Number of output channels of every convolution.
        dim_out: Size of the output (number of classes).
        dropout_rate: Dropout probability applied to the linear layer's output.
    """

    def __init__(self, dim_in, dim_conv, dim_out, dropout_rate=0.5):
        super().__init__()
        # With padding=5 an input of length L yields L+8, L+7 and L+6 steps
        # for kernel widths 3, 4 and 5 respectively.
        self.conv1_3 = nn.Conv1d(dim_in, dim_conv, 3, padding=5)
        self.conv1_4 = nn.Conv1d(dim_in, dim_conv, 4, padding=5)
        self.conv1_5 = nn.Conv1d(dim_in, dim_conv, 5, padding=5)
        # Not used in forward(); kept so the module's state-dict keys stay the
        # same and previously saved checkpoints still load.
        self.bn = nn.BatchNorm1d(dim_conv * 3)
        self.ReLU = nn.ReLU()
        # Pool each branch down to a single step. Adaptive pooling gives the
        # same result as a MaxPool1d whose kernel spans the whole branch
        # output, without depending on a module-level sequence length.
        self.maxpool_1 = nn.AdaptiveMaxPool1d(1)
        self.maxpool_2 = nn.AdaptiveMaxPool1d(1)
        self.maxpool_3 = nn.AdaptiveMaxPool1d(1)
        self.fc = nn.Linear(dim_conv * 3, dim_out)
        self.dropout = nn.Dropout(p=dropout_rate)
        # Explicit dim: normalise over the class axis of the (batch, dim_out)
        # tensor produced by fc. This is the axis the implicit choice picked
        # for 2-D input, but without the deprecation warning.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return softmax scores of shape ``(batch, dim_out)``.

        Args:
            x: Float tensor of shape ``(batch, dim_in, length)``.
        """
        x1 = self.maxpool_1(self.ReLU(self.conv1_3(x)))
        x2 = self.maxpool_2(self.ReLU(self.conv1_4(x)))
        x3 = self.maxpool_3(self.ReLU(self.conv1_5(x)))

        # (batch, dim_conv, 1) x 3 -> (batch, 3 * dim_conv, 1) -> (batch, 3 * dim_conv)
        x = torch.cat((x1, x2, x3), dim=1)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.dropout(x)
        x = self.softmax(x)
        return x


In [6]:
# Hyperparameters of the network; load_state_dict() below requires them to
# match the shapes stored in the checkpoint.
dim_in = wv_num
dim_conv = 100
dim_out = 2

model = CNN(dim_in, dim_conv, dim_out)
# map_location='cpu': the model is never moved off the CPU in this notebook,
# so load the weights onto the CPU regardless of the device they were saved from.
model.load_state_dict(torch.load('./model/CNN_vector_model.pth', map_location='cpu'))
# Evaluation mode disables dropout.
model.eval()

# Inference only: skip autograd bookkeeping so no graph is built over the
# whole test tensor.
with torch.no_grad():
    pred = model(X_test_tensor.float())

  return self._call_impl(*args, **kwargs)


In [7]:
# Predicted class = index of the largest score in each row of `pred`.
predicted_labels = torch.argmax(pred, dim=1)
true_labels = torch.tensor(y_test)

# Count matching predictions as a float, then normalise by the number of labels.
matches = torch.eq(predicted_labels, true_labels)
correct_num = matches.float().sum().item()
test_acc = correct_num / len(y_test)
print(f"test accuracy on IMDB is {test_acc}")

test accuracy on IMDB is 0.817
