In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from transformers.optimization import AdamW
from transformers import BertTokenizerFast, BertModel
from torch.utils.data import DataLoader

from tqdm import tqdm
import collections 

import model
import inference_utils
import dataset_utils

import spacy
import re

In [2]:
MAX_SEQ_LEN = 80
LABEL_LIST = ["O", "B-ASP", "I-ASP", "[CLS]", "[SEP]"]
THRESHOLD_SRD = 5
LCF = 'fusion'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", do_lower_case=True)

In [4]:
bert_base_model = BertModel.from_pretrained("bert-base-uncased")
bert_base_model.to(device)
ATEPCNet = model.LCF_ATEPC(bert_base_model, True, 0.1, MAX_SEQ_LEN, LCF, device)
ATEPCNet.load_state_dict(torch.load('./model_restaraunts.pt', map_location=device))
ATEPCNet.to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


LCF_ATEPC(
  (bert_for_global_context): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementw

In [5]:
examples = ['Staff was so nice to us .',
            'Staff was so horrible to us .']

In [6]:
dataset_ate = dataset_utils.ATEPCDataset('infer_ate',
                                         'syntax_tree',
                                         [x.split() for x in examples],
                                         None,
                                         None,
                                         tokenizer,
                                         MAX_SEQ_LEN,
                                         THRESHOLD_SRD,
                                         'en_core_web_sm')

In [7]:
aspects_res = inference_utils.extract_aspects(dataset_ate, ATEPCNet, device)

In [8]:
dataset_apc = dataset_utils.ATEPCDataset('infer_apc',
                                         'syntax_tree',
                                         [sample[0] for sample in aspects_res],
                                         [sample[1] for sample in aspects_res],
                                         [sample[2] for sample in aspects_res],
                                         tokenizer,
                                         MAX_SEQ_LEN,
                                         THRESHOLD_SRD,
                                         'en_core_web_sm')

In [9]:
result = inference_utils.classify_polarity(dataset_apc, ATEPCNet, device)

In [10]:
[x[:3] for x in result]

[(['Staff', 'was', 'so', 'nice', 'to', 'us', '.'], ['Staff'], 'Positive'),
 (['Staff', 'was', 'so', 'horrible', 'to', 'us', '.'], ['Staff'], 'Negative')]