In [3]:
import os

import torch
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaTokenizer, LlamaModel, LlamaForCausalLM

class CurseEncoder():
    """Frozen causal language model used to turn a prompt into next-token logits.

    The wrapper loads a tokenizer and a causal LM, registers a dedicated
    ``<pad>`` token (right padding), freezes every model parameter, and exposes
    ``encode`` / ``decode`` helpers.
    """

    def __init__(self, enc_id, hf_token, device='cuda:0'):
        """Load the tokenizer and model and prepare them for inference.

        Args:
            enc_id: Model identifier passed to ``from_pretrained``.
            hf_token: Access token passed to ``from_pretrained`` as ``token``.
            device: Device the model and the tokenized inputs are moved to.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(enc_id, token=hf_token)
        # Add an explicit pad token and pad on the right, so that the first pad
        # position in a row marks the end of the real prompt (see ``encode``).
        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        self.tokenizer.padding_side = "right"
        torch_dtype = 'auto' if torch.cuda.is_available() else torch.float32
        self.llm = AutoModelForCausalLM.from_pretrained(enc_id, token=hf_token, torch_dtype=torch_dtype).to(device)
        # The embedding matrix must grow to cover the newly added pad token.
        self.llm.resize_token_embeddings(len(self.tokenizer))
        self.llm.config.pad_token_id = self.tokenizer.pad_token_id
        # The model is only used for inference: freeze all of its weights.
        for param in self.llm.parameters():
            param.requires_grad = False
        self.device = device

    def encode(self, curse):
        """Return the logits at the last non-pad position of every prompt.

        Args:
            curse: A prompt string, or a list of prompt strings (padded to a
                common length by the tokenizer).

        Returns:
            Tensor with one row per prompt, taken from the first element of the
            model output at that prompt's last non-pad token.
        """
        inputs = self.tokenizer(curse, padding=True, return_tensors='pt').to(self.device)
        input_ids = inputs["input_ids"]
        batch_size, seq_len = input_ids.shape

        # argmax over the pad mask yields the first pad position; one step back
        # is the last real token. A row without any pad gives 0 - 1 = -1, which
        # the modulo maps to the final position of the sequence.
        is_pad = torch.eq(input_ids, self.llm.config.pad_token_id).long()
        sequence_lengths = ((is_pad.argmax(-1) - 1) % seq_len).to(self.device)

        transformer_outputs = self.llm(**inputs)
        hidden_states = transformer_outputs[0]

        # Index with the real batch size. A fixed arange(1) would broadcast and
        # read every row from batch element 0.
        batch_index = torch.arange(batch_size, device=self.device)
        logits = hidden_states[batch_index, sequence_lengths]

        return logits

    def decode(self, logits):
        """Greedy-decode logits into text.

        Args:
            logits: Tensor whose last dimension is scored over the vocabulary.

        Returns:
            The string produced by the tokenizer for the arg-max token ids,
            with special tokens skipped.
        """
        predicted_ids = logits.argmax(-1)
        return self.tokenizer.decode(predicted_ids, skip_special_tokens=True)
    
# model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model_name = "meta-llama/Meta-Llama-3.1-8B"
# model_name = "meta-llama/Llama-2-7b-hf"
# Never commit an access token to a notebook: read it from the HF_TOKEN
# environment variable instead. When the variable is unset, None is passed.
# NOTE(review): a token that was previously hard-coded here should be treated
# as leaked and revoked.
token = os.environ.get("HF_TOKEN")
encoder = CurseEncoder(model_name, token)

Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 14.14it/s]


In [4]:
# Few-shot prompt: a system instruction listing the supported class names,
# two solved examples, and a final request left open after "Assistant:".
input_text = '''<<SYS>>
You read the user input, understand the intention, and speculate which class is requested to be removed. No follow up questions. No explanation. Be concise. You only need to output one of the following class names:
Supported Class Names: aeroplane bicycle bird boat bottle bus car cat chair cow diningtable dog horse motorbike person pottedplant sheep sofa train tvmonitor
<</SYS>>
User: Remove all the person.
Assistant: person
User: Let all cat disappear.
Assistant: cat
User: Remove cat.
Assistant:'''

# Horizontal rule printed before and after the decoded answer.
divider = '-------------------------------------------------------'

# Encode the prompt and show the shape of the returned logits.
logits = encoder.encode(input_text)
print(logits.shape)

# Decode the logits and report the answer's length, then the answer itself.
response = encoder.decode(logits)
print(divider)
print(len(response), response, sep='\n')
print(divider)

torch.Size([1, 128257])
-------------------------------------------------------
4
 cat
-------------------------------------------------------
