In [None]:
import os, pickle
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, LogitsProcessor
from transformers import BitsAndBytesConfig

# 8-bit quantization settings, handed to from_pretrained below.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Load the model and tokenizer.
model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'
# The quantization mode is passed only through `quantization_config`:
# from_pretrained raises ValueError when `load_in_8bit=True` is given as a
# keyword argument together with a `quantization_config` object.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Device that the tokenized prompt is moved to before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the set of token ids that generation is allowed to use.
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only load 'final_target.pkl' if it comes from a trusted source.
with open('final_target.pkl', 'rb') as f:
    final_target = pickle.load(f)

allowed_token_ids = list(final_target)
# Show only the size of the allow-list; echoing the full list floods the cell output.
len(allowed_token_ids)

In [None]:

class CustomLogitsProcessor(LogitsProcessor):
    """Logits processor that restricts generation to an allow-list of token ids.

    Scores of allowed token ids are passed through unchanged; every other
    token id gets a score of -inf, so it receives zero probability after
    softmax and can never be selected.

    Note: if the end-of-sequence token id is not in the allow-list it is
    masked out as well, so generation can then only stop on a length limit.

    Args:
        allowed_ids: Iterable of integer token ids that may be generated.
            When None (the default), the notebook-level ``allowed_token_ids``
            list is used, so ``CustomLogitsProcessor()`` keeps working.

    Raises:
        ValueError: If the allow-list is empty (every score would be -inf).
    """

    def __init__(self, allowed_ids=None):
        # Resolve the fallback at construction time so the processor captures
        # the allow-list that exists when it is created.
        ids = list(allowed_token_ids if allowed_ids is None else allowed_ids)
        if not ids:
            raise ValueError("allowed token id list must not be empty")
        # Build the index tensor once instead of converting the list on every step.
        self.allowed_ids = torch.as_tensor(ids, dtype=torch.long)

    def __call__(self, input_ids, scores):
        """Mask out every token id that is not in the allow-list.

        Args:
            input_ids: Token ids generated so far (unused).
            scores: Next-token scores, indexed here as ``[batch, token_id]``.

        Returns:
            A tensor like ``scores`` with disallowed entries set to -inf.
        """
        # Additive mask: 0 leaves a logit untouched, -inf removes the token.
        # (A multiplicative 0/1 mask would not work: a logit of 0 is still
        # a perfectly sampleable score.)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed_ids.to(scores.device)] = 0
        return scores + mask

# Text generation function
def generate_text(input_text, max_length=300):
    """Generate a continuation of ``input_text`` with the notebook's model.

    The prompt is tokenized with the notebook-level ``tokenizer``, moved to
    ``device``, and passed to ``model.generate`` with a
    ``CustomLogitsProcessor`` attached.

    Args:
        input_text: Prompt string to continue.
        max_length: Value forwarded to ``generate`` as ``max_length``
            (this limit counts the prompt tokens as well as the new ones).

    Returns:
        The decoded first output sequence (prompt plus continuation) with
        special tokens stripped.
    """
    # Keep the whole tokenizer output so the attention mask is forwarded to
    # generate() together with the input ids.
    encoded = tokenizer(input_text, return_tensors='pt').to(device)
    logits_processor = LogitsProcessorList([CustomLogitsProcessor()])
    output = model.generate(
        **encoded,
        max_length=max_length,
        logits_processor=logits_processor,
        # Request sampling explicitly: temperature only takes effect when
        # sampling is on, and this should not depend on the checkpoint's
        # default generation config.
        do_sample=True,
        temperature=0.3,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


In [None]:
# Prompt handed to generate_text; the leading and trailing newlines inside the
# triple quotes are part of the string.
# NOTE(review): "Explaination" is misspelled ("Explanation"); left unchanged
# here because the literal is model input and editing it changes the output.
input_text = """
Explaination of the contents and significance of the theory of relativity : 

"""

# Run generation on the prompt and keep the decoded text for inspection.
alpha = generate_text(input_text)

print(alpha)