In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

In [2]:
import transformers
from transformers import BertTokenizer

In [3]:
from src.data_processing.process_labels import *
from src.data_processing.process_reviews import *
from src.data_processing.train_val_test import train_val_test
from src.models.model_evalaute import *

## Data processing
Load the raw reviews, build the food/service labels, split the data into train and test sets, and compute BERT encodings for each split.

In [4]:
### DATA PROCESSING ###
# Raw review dataset (one row per review).
reviews_csv_path = 'data/raw_reviews/reviews_v1.csv'
df = pd.read_csv(reviews_csv_path)

# Features: the free-text review body.
X = df['text']

# Targets: the food and service annotation columns are handed to label_generator,
# whose trim_and_fetch_labels() produces y.
# NOTE(review): the exact label encoding (and whether "trim" drops any rows) is
# defined in src.data_processing.process_labels -- confirm len(y) == len(X).
food_labels = df['food']
service_labels = df['service']
labeler = label_generator(food_labels=food_labels.values,
                          service_labels=service_labels.values)
y = labeler.trim_and_fetch_labels()

In [5]:
# 80/20 split of reviews and labels. train_val_test presumably returns
# (train, val, test); with test_frac=0 the "val" portion is used here as the
# held-out test set and the empty third slot is discarded.
X_train, X_test, _ = train_val_test(data=X, train_frac=0.8, val_frac=0.2, test_frac=0)
y_train, y_test, _ = train_val_test(data=y, train_frac=0.8, val_frac=0.2, test_frac=0)

# X and y are split by two independent calls, so review/label pairs only stay
# aligned if both inputs have the same length and train_val_test splits them the
# same way. NOTE(review): confirm train_val_test is order-preserving (no unseeded shuffle).
assert len(X_train) == len(y_train), "train reviews and train labels are misaligned"
assert len(X_test) == len(y_test), "test reviews and test labels are misaligned"

In [6]:
# Get Bert encodings
# Both splits use identical tokenizer settings: truncate to the model's maximum
# length, pad to the longest sequence within each split, and return PyTorch tensors.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
encoding_options = {'truncation': True, 'padding': True, 'return_tensors': 'pt'}
train_encodings, test_encodings = (
    tokenizer(list(split_texts), **encoding_options)
    for split_texts in (X_train, X_test)
)

## Load the saved fine-tuned (weighted) BERT model

In [7]:
from src.models.model_zoo import *
from src.models.model_train import *

# Instantiate the model architecture. BERTClass presumably comes from the
# src.models.model_zoo star import -- its trained weights are loaded separately.
bert_model_wt = BERTClass()

In [8]:
# Load the saved fine-tuned parameters.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine;
# load_state_dict then copies the tensors into the model's own parameters.
checkpoint_path = 'src/models/saved_models/bert_fine_tuned_weighted.pt'
state_dict = torch.load(checkpoint_path, map_location='cpu')

# eval() switches off training-only behavior (e.g. dropout) for inference.
bert_model_wt.eval()
# Last expression so the cell displays the key-matching report.
bert_model_wt.load_state_dict(state_dict)

<All keys matched successfully>

## Sanity check: try made-up reviews

In [15]:
# Test review: a made-up sentence to sanity-check the loaded model.
# Named review_text rather than `input` so the built-in input() is not shadowed.
review_text = "I asked for the check and they overcharged"
input_list = [review_text]

tokenized_input = tokenizer(input_list, truncation=True, padding=True, return_tensors='pt')

ids, mask, token_type_ids = (tokenized_input['input_ids'], tokenized_input['attention_mask'],
                             tokenized_input['token_type_ids'])

# Get BERT prediction output. Inference only, so skip building the autograd graph.
with torch.no_grad():
    review_probs = bert_model_wt(ids, mask, token_type_ids)

# One row per input review. The model output previously shown for this cell came from
# a softmax and summed to ~1. NOTE(review): which column corresponds to which label
# depends on label_generator's ordering -- confirm before interpreting.
review_probs

tensor([[0.0715, 0.0067, 0.9134, 0.0084]], grad_fn=<SoftmaxBackward0>)