In [ ]:
# Installs Unsloth, Xformers (Flash Attention) and all other packages!
# Step 1: install Unsloth (with its "colab-new" extra) directly from the GitHub repository.
# NOTE(review): no tag or commit is pinned here, so a re-run may pick up different code —
# consider pinning a revision for reproducibility.
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# Step 2: install the remaining libraries with --no-deps so pip does not install their
# dependencies; xformers is capped below 0.0.27 and trl below 0.9.0.
!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes

In [0]:
from unsloth import FastLanguageModel
import torch

# Both checkpoints are loaded with exactly the same options; only the
# Hugging Face repository name differs between the two calls.
shared_load_options = dict(
    max_seq_length = 2048,
    dtype = None,          # presumably lets Unsloth choose the dtype — TODO confirm
    load_in_4bit = True,   # request 4-bit loading
)

# First speaker (printed as "PwoAI" in the conversation cell).
model1, tokenizer1 = FastLanguageModel.from_pretrained(
    model_name = "ThePwo/PwoAI",
    **shared_load_options,
)

# Second speaker (printed as "FishAI" in the conversation cell).
model2, tokenizer2 = FastLanguageModel.from_pretrained(
    model_name = "ThePwo/FishAI",
    **shared_load_options,
)

In [ ]:
from unsloth.chat_templates import get_chat_template

# The conversation dicts in this notebook use "from"/"value" keys with
# "human"/"gpt" speakers, so the template's role/content/user/assistant
# names are mapped onto those. Each tokenizer receives its own copy of
# the mapping, exactly as when the literal was written out twice.
conversation_field_mapping = {"role" : "from", "content" : "value", "user" : "human", "assistant" : "gpt"}

tokenizer1 = get_chat_template(
    tokenizer1,
    chat_template = "llama-3",
    mapping = dict(conversation_field_mapping),
)
tokenizer2 = get_chat_template(
    tokenizer2,
    chat_template = "llama-3",
    mapping = dict(conversation_field_mapping),
)

def get_last_message(output, header_end='<|end_header_id|>\n\n', end_of_turn='<|eot_id|>'):
    """Extract the final message's text from a decoded llama-3 style transcript.

    The reply is taken to be everything after the last ``header_end`` marker,
    so replies that themselves contain blank lines are returned in full
    (splitting on the last blank line alone would keep only the final
    paragraph). When no header marker is present, the text after the last
    blank line is used instead.

    Args:
        output: Decoded text of prompt plus generation, special tokens included.
        header_end: Marker that closes a role header and precedes message text.
        end_of_turn: Marker that terminates a message; it and anything after
            it are dropped from the result.

    Returns:
        The reply text with surrounding whitespace removed, or ``None`` when
        ``output`` contains neither ``header_end`` nor a blank-line separator.
    """
    if header_end in output:
        reply = output.rsplit(header_end, 1)[1]
    else:
        # No role header found: fall back to "text after the last blank line".
        pieces = output.rsplit('\n\n', 1)
        if len(pieces) < 2:
            return None
        reply = pieces[1]

    # Drop the end-of-turn marker (and anything trailing it) before stripping,
    # so whitespace sitting just in front of the marker is removed as well.
    reply = reply.split(end_of_turn, 1)[0]
    return reply.strip()

# Put both models through Unsloth's for_inference() before any generation is run.
# NOTE(review): Unsloth documents this as enabling its faster native inference path —
# confirm for the installed version.
# The second call is the last expression of the cell, so its return value is what
# the notebook displays as the cell output.
FastLanguageModel.for_inference(model1)
FastLanguageModel.for_inference(model2)

In [ ]:
# Conversation settings (tune here rather than inside the loop).
start_message = "hi"   # opening line, attributed to model1 (PwoAI)
NUM_STEPS = 40         # total number of replies generated across both models
MAX_NEW_TOKENS = 64    # cap on the length of each reply
TEMPERATURE = 0.8      # sampling temperature
HISTORY_LIMIT = 10     # number of most recent messages each model keeps as context

# Each model keeps its own view of the dialogue: what it said is "gpt",
# what the other model said is "human".
messages1, messages2 = [{"from": "gpt", "value": start_message}], [{"from": "human", "value": start_message}]
turn = 1  # truthy -> model2 (FishAI) speaks next; falsy -> model1 (PwoAI)

for step in range(NUM_STEPS):
    tokenizer = tokenizer2 if turn else tokenizer1
    model = model2 if turn else model1
    messages = messages2 if turn else messages1
    other_messages = messages1 if turn else messages2

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize = True,
        add_generation_prompt = True,
        return_tensors = "pt",
    ).to("cuda")

    # do_sample=True is required for `temperature` to have any effect;
    # without it generation is greedy and the temperature is ignored.
    outputs = model.generate(
        input_ids = inputs,
        max_new_tokens = MAX_NEW_TOKENS,
        use_cache = True,
        do_sample = True,
        temperature = TEMPERATURE,
    )

    # generate() returns the prompt followed by the new tokens. Decoding only
    # the new tokens yields the full reply (blank lines included) and always
    # a string, so None can never end up in the message history.
    new_tokens = outputs[0][inputs.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens = True).strip()

    messages.append({"from": "gpt", "value": response})
    other_messages.append({"from": "human", "value": response})

    # Trim in place: rebinding the loop-local names would leave messages1 and
    # messages2 untouched and let the prompts grow without bound.
    messages[:] = messages[-HISTORY_LIMIT:]
    other_messages[:] = other_messages[-HISTORY_LIMIT:]

    print(f"{'FishAI' if turn else 'PwoAI'}: {response}")
    turn = not turn