In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class TextCNN(nn.Module):
    """Convolutional text classifier over a pretrained embedding table.

    The input token ids are embedded, then passed through one Conv2d per kernel
    size (each spanning the full embedding width). Each conv is followed by ReLU
    and max-pooling over the whole sequence. The pooled features are
    concatenated, passed through dropout, and mapped to class logits.

    Keyword Args:
        embed_num (int): number of rows (vocabulary size) of the embedding table.
        embed_dim (int): embedding width D.
        embed_mat (numpy.ndarray): pretrained weights copied into the embedding
            table; must have shape (embed_num, embed_dim).
        class_num (int): number of output logits.
        kernel_num (int): number of filters per kernel size.
        kernel_sizes (list of int): convolution heights (in tokens); must be non-empty.
        dropout (float): dropout probability applied before the final linear layer.
        static (bool): if true, the embedding output is detached in forward(), so
            no gradient flows back into the embedding weights.

    Raises:
        ValueError: if embed_mat's shape is not (embed_num, embed_dim), or if
            kernel_sizes is empty.
    """

    def __init__(self, **args):
        super(TextCNN, self).__init__()
        # Kept as-is for compatibility: forward() reads args['static'] and callers
        # may inspect model.args. Note this also keeps embed_mat referenced.
        self.args = args

        V = args['embed_num']
        D = args['embed_dim']
        embedding = torch.from_numpy(args['embed_mat'])
        C = args['class_num']
        Ci = 1  # the embedded sequence is treated as a single-channel 2-D "image"
        Co = args['kernel_num']
        Ks = args['kernel_sizes']
        dropout = args['dropout']

        # Tensor.copy_ broadcasts, so e.g. a (1, D) matrix would be silently
        # replicated into every row; insist on an exact (V, D) match instead.
        if tuple(embedding.size()) != (V, D):
            raise ValueError(
                'embed_mat has shape %s, expected (%d, %d) from embed_num/embed_dim'
                % (tuple(embedding.size()), V, D))
        # With no kernels the classifier input would be empty and forward()'s
        # torch.cat would fail; reject it up front with a clear message.
        if len(Ks) == 0:
            raise ValueError('kernel_sizes must contain at least one kernel size')

        self.embed = nn.Embedding(V, D)
        self.embed.weight.data.copy_(embedding)
        # Each kernel covers K tokens and the full embedding width D.
        self.convs = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D)) for K in Ks])
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(len(Ks) * Co, C)

    def forward(self, x):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: integer (long) tensor of token ids of shape (N, L). L must be at
                least max(kernel_sizes), otherwise the widest convolution has no
                valid output position.

        Returns:
            Tensor of shape (N, class_num) with unnormalized logits.
        """
        x = self.embed(x)  # (N, L, D)
        if self.args['static']:
            # Cut the graph here so the embedding weights receive no gradient
            # (they still appear in parameters()).
            x = x.detach()
        x = x.unsqueeze(1)  # (N, 1, L, D): add the single input channel for Conv2d
        # Kernel width equals D, so each conv output has width 1 -> drop dim 3.
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]  # each (N, Co, L-K+1)
        # Max-pool over the entire remaining sequence length.
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # each (N, Co)
        x = torch.cat(x, 1)  # (N, len(Ks) * Co)
        x = self.dropout(x)
        logit = self.fc1(x)  # (N, class_num)
        return logit

In [5]:
import numpy as np

# Pretrained word-embedding matrix; the path is relative to the notebook's
# working directory.
WORD_EMBED_PATH = '../../data_preprocess/embed/word_embed_mat.npy'
word_embed = np.load(WORD_EMBED_PATH)
# The matrix is copied row-for-row into an nn.Embedding table, so it must be 2-D.
assert word_embed.ndim == 2, 'expected a 2-D embedding matrix, got shape %s' % (word_embed.shape,)

In [14]:
# Take vocabulary size and embedding width from the loaded matrix rather than
# hard-coding them, so the model can never disagree with the embedding file.
vocab_size, embed_dim = word_embed.shape
# NOTE(review): class_num=1999 is a dataset-specific label count; document its source.
textCNN = TextCNN(embed_num=vocab_size, embed_dim=embed_dim, embed_mat=word_embed,
                  class_num=1999, kernel_num=100, kernel_sizes=[1, 2, 3, 4, 5],
                  dropout=0.5, static=True)

In [15]:
# Variable is a deprecated no-op wrapper since PyTorch 0.4 but is required on
# older releases, so it is kept for compatibility.
from torch.autograd import Variable

# Smoke test: a batch of 64 sequences of 71 token ids (all id 1) should yield
# one logit per class for every sequence.
sample_ids = Variable(torch.ones(64, 71).long())
output = textCNN(sample_ids)
assert tuple(output.size()) == (64, 1999), 'unexpected logit shape %s' % (tuple(output.size()),)
print(output.size())

torch.Size([64, 1999])
