In [1]:
from transformers import BertForSequenceClassification, AutoTokenizer, pipeline, AutoModelForCausalLM
import pandas as pd
import numpy as np
import random
import re
import torch
attribute_names = [
    "health", "strength", "dexterity", "perception", 
    "intelligence", "charisma", "stamina"
]

In [2]:
#load risk-model
risk_tokenizer = AutoTokenizer.from_pretrained("samwu1/risk-model")
risk_model = BertForSequenceClassification.from_pretrained("samwu1/risk-model")

def get_risk_output(input):
    encoding = risk_tokenizer(input, return_tensors="pt", truncation=True, padding=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    risk_model.to(device)
    encoding = {k: v.to(device) for k, v in encoding.items()}
    with torch.no_grad():
        output = risk_model(**encoding)
    return output.logits[0][0]

In [3]:
#load risk-attribute
attribute_tokenizer =  AutoTokenizer.from_pretrained("samwu1/attribute-model")
attribute_model = BertForSequenceClassification.from_pretrained("samwu1/attribute-model")

def get_attribute_output(dm_prompt, user_input, dm_output):
    input = f"<DM>{dm_prompt}</DM>\n<Player>{user_input}</Player>\n<DM>{dm_output}</DM>"
    encoding = attribute_tokenizer(input, return_tensors="pt")
    encoding = {k: v.to(attribute_model.device) for k,v in encoding.items()}
    prediction = attribute_model(**encoding).logits
    return prediction.squeeze().tolist()

def update_attributes(attributes, d_attributes):
    new_attributes_stats = attributes + d_attributes
    attributes_string = "\n".join(f"{name}:{value:.2f}" for name, value in zip(attribute_names, new_attributes_stats))
    return attributes_string, new_attributes_stats


In [12]:
#load summarization pipeline (defaults to bart-large-cnn)
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
def get_summarizer_output(input):
    num_tokens = len(summarizer.tokenizer.encode(input, truncation=False))
    num_tokens = 20 if num_tokens<10 else num_tokens
    summary = summarizer(input, max_length=int(num_tokens*.8), min_length=int(num_tokens*.3), do_sample=True)
    return summary[0]['summary_text']

Device set to use cuda:0


In [None]:
#load dm-model
# Load tokenizer and model
dm_tokenizer = AutoTokenizer.from_pretrained("samwu1/dm-model")
dm_model = AutoModelForCausalLM.from_pretrained("samwu1/dm-model", torch_dtype=torch.float16)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dm_model.to(device)

def get_dm_output(dm_input, user_input):
    message = [{"role": "system","content": dm_input,},{"role": "user", "content": user_input}]
    prompt = dm_tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    input_ids = dm_tokenizer(prompt, return_tensors="pt").input_ids.to(dm_model.device)
    with torch.inference_mode():
        generated_ids = dm_model.generate(
            input_ids,
            max_new_tokens=80,
            do_sample=True,
            top_p=0.95,
            temperature=.8,
        )
    generated_text = dm_tokenizer.decode(generated_ids[0], skip_special_tokens=False)
    return generated_text.split("<|assistant|>")[-1]

In [14]:
#init setup: generate random attributes, get inital scenario from the dataset as a introduction:
attributes_stats = np.random.rand(7)
attribute_string, attributes_stats = update_attributes(attributes_stats, 0)
scenarios = pd.read_csv("data/scenario-dataset.csv")["scenarios"].to_list()
intro_scene = random.choice(scenarios)
dm_match = re.search(r"<DM>(.*?)</DM>", intro_scene)
dm_prompt = dm_match.group(1) if dm_match else ""
history = ""

In [15]:
#MAIN METHOD!! putting it all together FINALLY!
print(attribute_string)
print(dm_prompt)
user_input = None
while user_input!="quit":
    user_input = input()
    print(user_input)
    print("\nloading risk:")
    risk_input = "\n".join([attribute_string, dm_prompt, user_input])
    #feed risk_input into risk_model
    risk = get_risk_output(risk_input)
    print(f"Risk:{risk}")

    #rolling dice
    print("Rolling Dice:")
    dice_roll = np.random.rand()
    print(dice_roll)
    roll = "<Roll: Success>" if dice_roll<risk else "<Roll: Failure>"
    print(f"{roll}")

    #generate history:
    #trying this
    prev_text = get_summarizer_output(history)
    history = prev_text+"\n"+dm_prompt+"\n"+user_input

    #feed history, attribute, dm_prompt, roll, and user_input to dm-model
    dm_input = f"history:{{{history}}}\n\n({attribute_string})\n\n<DM>{dm_prompt}</DM>\n\n{roll}"
    dm_output = get_dm_output(dm_input, user_input)

    #update attributes
    d_attributes = get_attribute_output(dm_prompt, user_input, dm_output)
    attribute_string, attributes_stats = update_attributes(attributes_stats, d_attributes)

    #reprint output
    print("\n"+attribute_string)
    print(dm_output)

    #update dm_prompt
    dm_prompt = dm_output
    

health:0.79
strength:0.98
dexterity:0.43
perception:0.69
intelligence:0.79
charisma:0.03
stamina:0.40
You stand in front of a locked wooden door, the faint sound of footsteps echoing beyond it. The door has no visible keyhole. 


Your max_length is set to 16, but your input_length is only 3. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=1)


i try to climb through the window

loading risk:
Risk:0.750372052192688
Rolling Dice:
0.06847590078297172
<Roll: Success>

health:0.76
strength:1.01
dexterity:0.44
perception:0.68
intelligence:0.81
charisma:0.03
stamina:0.43

You grip the smallest crevice between two stones and haul yourself up, heart pounding, with a swift, decisive movement. The door sways open to reveal a narrow stone hallway where whispers drift from deeper within. Pressing closer, you overhear two people arguing about a missing artifact and someone called "The Warden," with urgency about recovering


loading risk:
Risk:0.7257801294326782
Rolling Dice:
0.7155067787283218
<Roll: Success>

health:0.75
strength:1.02
dexterity:0.48
perception:0.69
intelligence:0.87
charisma:0.06
stamina:0.44

You stand your ground, your hand lifting in a defensive gesture as the conversation halts. You listen carefully, taking note of the key players and their motivations. You manage to gather crucial intelligence, enough to warn of Th