In [0]:
import torch
from torchtext import data
from torchtext import datasets
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import spacy
# spaCy pipeline; its tokenizer is used by extract_sentimentfeatures to split raw sentences.
# NOTE(review): the 'en' shortcut name may not resolve on newer spaCy releases — confirm the pinned version.
nlp = spacy.load('en')

In [0]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Text is tokenized with spaCy and batched as (batch, seq); labels are floats.
TEXT = data.Field(tokenize = 'spacy', batch_first = True)
LABEL = data.LabelField(dtype = torch.float)

train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)

# random.seed() returns None, so it must not be used as the random_state
# argument itself: seed the generator first, then hand split() the resulting
# generator state explicitly so the train/validation split is reproducible.
random.seed(1234)
train_data, valid_data = train_data.split(random_state = random.getstate())

downloading aclImdb_v1.tar.gz


aclImdb_v1.tar.gz: 100%|██████████| 84.1M/84.1M [00:07<00:00, 11.3MB/s]


In [4]:
# Build the vocabulary from the training split only, capped at max_size
# entries, and attach the pretrained "glove.6B.300d" vectors.
TEXT.build_vocab(
    train_data,
    max_size=10000,
    vectors="glove.6B.300d",
    unk_init=torch.Tensor.normal_,
)


.vector_cache/glove.6B.zip: 862MB [06:30, 2.21MB/s]                           
100%|█████████▉| 399721/400000 [00:52<00:00, 7629.10it/s]

In [0]:
class sentimentCNN(nn.Module):
    """Two-stage 1-D CNN that maps a batch of token-id sequences to one score each.

    Pipeline (see ``forward``): embedding -> zero-pad/truncate to ``max_len``
    -> two parallel convolutions (kernel sizes 4 and 5) with global max-pooling
    -> a second convolution over the pooled channel axis with global
    max-pooling -> two linear layers.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        dropout: accepted for interface compatibility; no dropout layer is
            created, so the value is currently unused.
        pad_idx: index of the padding token (passed to ``nn.Embedding`` as
            ``padding_idx``).
        max_len: fixed sequence length the convolutions see. Shorter inputs
            are right-padded with zeros, longer inputs are truncated.

    Raises:
        ValueError: if ``max_len`` is smaller than the largest first-stage
            kernel size, which would make the convolution impossible.
    """

    def __init__(self, vocab_size, embedding_dim, dropout, pad_idx, max_len=3000):
        super().__init__()
        filter_sizes = [4, 5]
        if max_len < max(filter_sizes):
            raise ValueError(
                "max_len must be at least %d, got %d" % (max(filter_sizes), max_len)
            )
        self.max_len = max_len

        # Submodule names and creation order are kept stable so that existing
        # checkpoints (state_dict keys) keep loading.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=embedding_dim, out_channels=150, kernel_size=fs)
            for fs in filter_sizes
        ])
        self.conv2 = nn.Conv1d(in_channels=1, out_channels=100, kernel_size=3)
        # 200 = two first-stage branches x 100 conv2 output channels.
        self.fc1 = nn.Linear(200, 100)
        self.fc2 = nn.Linear(100, 1)

    def forward(self, text):
        """Return a ``(batch, 1)`` tensor of raw ``fc2`` outputs.

        Args:
            text: integer tensor of token ids, laid out ``(batch, seq_len)``.
        """
        # (batch, seq, emb) -> (batch, emb, seq), the layout Conv1d expects.
        embedded = self.embedding(text).permute(0, 2, 1)

        # Clip over-long inputs; otherwise the padding length below would be
        # negative and torch.zeros would raise.
        embedded = embedded[:, :, :self.max_len]

        batch_size, channels, sent_len = embedded.shape
        # Channel count is read from the tensor rather than hard-coded so the
        # model works for any embedding_dim.
        zero_tail = torch.zeros(
            batch_size, channels, self.max_len - sent_len,
            dtype=embedded.dtype, device=embedded.device,
        )
        padded = torch.cat([embedded, zero_tail], dim=2)

        # Stage 1: parallel convolutions, each globally max-pooled over the
        # sequence axis -> (batch, 150, 1) per branch.
        conved = [F.relu(conv(padded)) for conv in self.convs]
        pooled = [F.max_pool1d(c, c.shape[2]) for c in conved]

        # Treat the 150 filter responses as a length-150 single-channel signal
        # -> (batch, 1, 150), then halve it -> (batch, 1, 75).
        pooled = [F.max_pool1d(p.permute(0, 2, 1), 2) for p in pooled]

        # Stage 2: convolve over the filter axis and pool globally
        # -> (batch, 100, 1) per branch.
        conved2 = [F.relu(self.conv2(p)) for p in pooled]
        pooled2 = [F.max_pool1d(c, c.shape[2]) for c in conved2]

        # Concatenate branches along channels and drop the trailing axis.
        final = torch.cat(pooled2, dim=1).reshape(batch_size, -1)
        return self.fc2(self.fc1(final))

In [6]:
# Instantiate the classifier and move it to the selected device.
# NOTE(review): 10002 is presumably the 10000-token vocabulary plus two
# special tokens, and pad_idx=1 the '<pad>' index — verify against TEXT.vocab.
sentimentModel = sentimentCNN(
    vocab_size=10002,
    embedding_dim=300,
    dropout=0.5,
    pad_idx=1,
).to(device)

100%|█████████▉| 399721/400000 [01:10<00:00, 7629.10it/s]

In [7]:
# Restore the trained weights. map_location remaps the stored tensors onto
# `device`, so a checkpoint saved on a GPU also loads on a CPU-only runtime.
sentimentModel.load_state_dict(
    torch.load('/content/drive/My Drive/sentimentModel.pt', map_location = device)
)

<All keys matched successfully>

In [0]:
# Most recent forward output of each hooked module, keyed by the name given
# to get_sentiactivation.
sentimentActivation = {}


def get_sentiactivation(name):
    """Build a forward hook that stores a module's output under ``name``.

    The returned callable has the ``(module, inputs, output)`` forward-hook
    signature; on every call it writes ``output.detach()`` into the
    module-level ``sentimentActivation`` dict, overwriting any earlier entry
    for ``name``.
    """
    def _record(module, inputs, output):
        sentimentActivation[name] = output.detach()

    return _record

# Remove hooks left behind by an earlier run of this cell so that re-running
# it does not stack duplicate hooks on the modules.
for _handle in globals().get('sentimentHookHandles', []):
    _handle.remove()

# Hook every module of the model and keep the returned handles, so the hooks
# can be detached later with handle.remove().
sentimentHookHandles = [
    layer.register_forward_hook(get_sentiactivation(name))
    for name, layer in sentimentModel.named_modules()
]

In [0]:
def extract_sentimentfeatures(model, sentence, min_len=5):
    """Return the ``fc1`` activation of ``model`` for a single raw sentence.

    The sentence is tokenized with the module-level spaCy ``nlp`` tokenizer,
    clipped to 3000 tokens, right-padded with ``'<pad>'`` up to ``min_len``
    tokens, mapped to ids through ``TEXT.vocab.stoi`` and run through the
    model as a batch of one.

    Args:
        model: module exposing an ``fc1`` submodule; it is switched to eval
            mode as a side effect.
        sentence: raw text to encode.
        min_len: minimum number of tokens fed to the model.

    Returns:
        The detached output tensor of ``model.fc1`` for this sentence.
    """
    max_tokens = 3000

    model.eval()
    tokenized = [tok.text for tok in nlp.tokenizer(sentence)][:max_tokens]
    if len(tokenized) < min_len:
        tokenized += ['<pad>'] * (min_len - len(tokenized))
    indexed = [TEXT.vocab.stoi[t] for t in tokenized]
    tensor = torch.LongTensor(indexed).to(device).unsqueeze(0)

    # Capture fc1's output with a temporary hook on *this* model instead of
    # reading a global dict, which could hold a stale activation (or nothing)
    # if no global hooks were registered on the model passed in.
    captured = []
    handle = model.fc1.register_forward_hook(
        lambda module, inputs, output: captured.append(output.detach())
    )
    try:
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            model(tensor)
    finally:
        handle.remove()
    return captured[-1]



In [18]:
# Example call on a short sentence; as the cell's last expression, the
# returned fc1 activation tensor is displayed as the cell output.
extract_sentimentfeatures(sentimentModel, "Good at being terrible ")

tensor([[5.3510e-01, 3.2975e-04, 5.1191e-01, 9.9680e-01, 6.6998e-01, 3.1177e-01,
         6.4213e-01, 8.1708e-01, 2.9019e-01, 3.9390e-02, 3.9507e-01, 6.9788e-02,
         2.6327e-01, 3.4447e-02, 3.4543e-01, 6.9755e-01, 9.9717e-01, 7.9936e-01,
         3.1308e-01, 8.8863e-01, 6.3043e-01, 7.7977e-01, 9.0573e-01, 2.4440e-02,
         9.9921e-01, 2.8196e-01, 2.7151e-01, 1.0833e-01, 7.1744e-01, 4.7538e-01,
         3.9037e-01, 9.9995e-01, 6.8818e-01, 3.3306e-01, 2.1636e-01, 2.4784e-01,
         4.0579e-01, 1.4494e-04, 8.7693e-01, 9.9367e-01, 2.1836e-01, 2.8644e-03,
         9.1185e-01, 9.8155e-01, 1.5622e-03, 7.5711e-01, 6.1195e-01, 5.5578e-01,
         9.9837e-01, 6.9460e-01, 8.8170e-01, 6.0920e-01, 3.5408e-02, 6.8476e-01,
         8.0683e-01, 5.8673e-01, 9.9394e-01, 6.4912e-01, 4.7899e-02, 6.8219e-01,
         9.8514e-01, 9.9936e-01, 1.0143e-01, 2.3946e-01, 8.8059e-01, 1.1258e-01,
         3.5249e-02, 2.4357e-01, 2.6061e-04, 7.4591e-01, 6.6398e-01, 2.7523e-02,
         1.9432e-01, 4.3737e