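# Toy Example: Trie-Constrained Decoding

A trie built from a few candidate answers is passed to `model.generate` as `prefix_allowed_tokens_fn`, so the model can only output one of those candidates.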

In [1]:
import torch
import pandas as pd
from transformers.utils import logging
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sectors.config import INDUSTRY_DATA_DIR
from sectors.utils.trie import Trie


# Only show transformers errors so cell outputs stay readable.
logging.set_verbosity_error()

MODEL_NAME = "google/flan-t5-small"
# Inference device. The model is placed here so that inputs sent `.to(device)` match it.
device = "cpu"

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

[2023-08-10 16:14:54,592] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# Baseline: free generation, with no constraint on which tokens the model may emit.
# `encode(..., return_tensors="pt")` already returns a torch tensor, so it is used directly.
input_ids = tokenizer.encode("Name good ingredients for a vegan salad.", return_tensors="pt").to(device)
out = model.generate(input_ids=input_ids, max_new_tokens=20, num_return_sequences=1)
tokenizer.decode(out[0], skip_special_tokens=True)

'Veggies, Veggies, and Veggies'

In [3]:
# The Trie needs BOS/SEP ids, so they are assigned on the tokenizer here.
# Later cells read them back from `tokenizer`.
# NOTE(review): 117 is presumably the id of the separator used between labels;
# confirm with `tokenizer.convert_ids_to_tokens(117)`.
tokenizer.sep_token_id = 117
# Decoding starts from the model's decoder start token (0 for Flan-T5), so the trie root uses it as BOS.
tokenizer.bos_token_id = model.config.decoder_start_token_id

candidate_answers = ["olives", "salmon", "seeds"]
trie = Trie(
    bos_token_id=tokenizer.bos_token_id,
    sep_token_id=tokenizer.sep_token_id,
    eos_token_id=tokenizer.eos_token_id,
    sequences=[tokenizer.encode(answer) for answer in candidate_answers],
)


def trie_fn(batch_id, sent):
    """Callback for `generate(prefix_allowed_tokens_fn=...)`.

    transformers passes the batch index and the decoder ids generated so far (a tensor)
    and expects the list of token ids allowed next; that lookup is delegated to `trie`.
    """
    return trie.get(batch_id, sent.tolist())


salad_input_ids = tokenizer.encode("Name good ingredients for a vegan salad.", return_tensors="pt").to(device)
output = model.generate(salad_input_ids, max_length=20, prefix_allowed_tokens_fn=trie_fn)
tokenizer.decode(output[0], skip_special_tokens=True)

'olives'

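# Sector Classification

Predict the sector of a company from a short description: first with unconstrained generation, then with decoding restricted to the sector labels of the training set.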

## Without Trie-Constrained Decoding

In [4]:
# NOTE(review): "sectors(s)" looks like a typo for "sector(s)". It is left unchanged because
# `prompt` is reused below, and editing it would invalidate the recorded outputs.
prompt = "SmartHealth produces fitness trackers for health conscious individuals. This company classifies into the sectors(s): "

# Free generation: the model may produce any text, not necessarily a known sector label.
prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
output = model.generate(prompt_ids, max_length=300)
tokenizer.decode(output[0], skip_special_tokens=True)

'Health-conscious'

## With Trie-Constrained Decoding

In [5]:
TRAIN_PATH = INDUSTRY_DATA_DIR / "train_preprocessed.json"
train = pd.read_json(TRAIN_PATH, lines=True)

# Every column that is not in this metadata/text list is treated as a sector label.
remove = ['id', 'legal_name', 'description', 'short_description', 'tags', 'len_des', 'tags_string', 'len_tags', 'prompt']
labels = [col for col in train.columns if col not in remove]
# An empty label set would give a Trie with no sequences; fail loudly instead.
assert labels, f"no label columns found in {TRAIN_PATH}"

# Restrict decoding to the sector label strings. This reuses the BOS/SEP ids
# assigned to `tokenizer` in the toy example above.
trie = Trie(
    bos_token_id=tokenizer.bos_token_id,
    sep_token_id=tokenizer.sep_token_id,
    eos_token_id=tokenizer.eos_token_id,
    sequences=[tokenizer.encode(lab) for lab in labels],
)


def trie_fn(batch_id, sent):
    """Callback for `generate(prefix_allowed_tokens_fn=...)`.

    Delegates to `trie` to look up which token ids may follow the decoder ids
    generated so far (`sent`, a tensor) for batch element `batch_id`.
    """
    return trie.get(batch_id, sent.tolist())


prompt_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
output = model.generate(prompt_ids, max_length=300, prefix_allowed_tokens_fn=trie_fn)
tokenizer.decode(output[0], skip_special_tokens=True)

'Healthcare IT'