In [11]:
import os

DATA_DIR = "data"  # This may need to be changed on different machines

# This notebook lives in src/nb, while DATA_DIR is expected to be relative to the
# root directory two levels up. If DATA_DIR is not visible from the current working
# directory, switch to that root -- but only after confirming DATA_DIR exists there.
# Checking *before* os.chdir() keeps a failed run from leaving the kernel two levels
# up, which would make every re-run of this cell climb two more directories.
if not os.path.exists(DATA_DIR):
    assert os.path.exists(os.path.join("../..", DATA_DIR)), f"ERROR: DATA_DIR={DATA_DIR} not found"
    os.chdir("../..")

import torch

from transformers import BertForSequenceClassification, BertTokenizer, BertConfig, BertModel

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "models/gpt2_large"  # NOTE(review): unused in this cell and not a BERT path; kept in case other cells read it
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Build the classifier architecture from the bert-base-uncased config; the actual
# weights are loaded from the local checkpoint below.
model = BertForSequenceClassification(BertConfig.from_pretrained(
    "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
    num_labels = 11, # The number of output labels.
    output_attentions = False, # Whether the model returns attentions weights.
    output_hidden_states = True, # Whether the model returns all hidden-states.
))

# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved from a GPU still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
model.load_state_dict(torch.load("models/style_classifier.pth", map_location=device))
model = model.to(device)

# A module constructed directly (rather than via from_pretrained) starts in training
# mode; switch to eval mode so dropout is disabled and outputs are deterministic.
model.eval()

In [5]:
def get_style_vector(text):
    """Return the style vector(s) for `text` from the style classifier.

    The style vector is the hidden state of the first token ([CLS]) taken from
    the last entry of the model's hidden states.

    Uses the notebook-level `tokenizer`, `model` and `device`.

    Args:
        text: A single string, or a list of strings to embed as one batch.
            Inputs longer than the tokenizer's maximum length are truncated.

    Returns:
        A tensor of shape (batch_size, hidden_size), detached from autograd.
        A single string gives batch_size 1.
    """
    # padding=True lets a list of texts of different lengths form one batch, and
    # truncation=True keeps over-long texts within the model's position limit.
    # The whole encoding is moved to the device and passed on, so the attention
    # mask tells the model to ignore padded positions.
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)

    # Inference only: skip building the autograd graph so the returned vectors do
    # not keep every intermediate activation alive.
    with torch.no_grad():
        output = model(**encoded)

    # Last set of hidden states, first token position of every sequence.
    return output.hidden_states[-1][:, 0, :]


In [9]:
# Sanity check: embed one sentence and display the shape of the resulting style vector.
get_style_vector("I like to eat pizza").shape

torch.Size([1, 768])