# COSMO

COSMO is a conversation agent that generalizes well to both in-domain and out-of-domain chitchat datasets (e.g., DailyDialog, BlendedSkillTalk). It is trained on two datasets, SODA and ProsocialDialog, and is designed specifically to model natural human conversations. Besides the conversation history, it accepts a description of the situation and an instruction specifying which role it should play in that situation.

Link: https://github.com/skywalker023/sodaverse
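COSMO receives the situation, the role instruction and the conversation history as a single text input. The notebook's `set_input` function, defined in the model-loading cell, joins the turns with ` <turn> ` and prepends the role instruction and then the situation, each followed by ` <sep> `. The standalone sketch below reproduces that format. `build_cosmo_input` is a hypothetical name. Unlike `set_input`, it treats `None` the same as an empty string.

```python
from typing import List


def build_cosmo_input(situation_narrative: str, role_instruction: str,
                      conversation_history: List[str]) -> str:
    """Hypothetical standalone version of the notebook's set_input."""
    # Previous turns are separated by " <turn> "
    text = " <turn> ".join(conversation_history)
    # Prepend the role instruction, then the situation, each followed by " <sep> ", when non-empty
    if role_instruction:
        text = f"{role_instruction} <sep> {text}"
    if situation_narrative:
        text = f"{situation_narrative} <sep> {text}"
    return text


print(build_cosmo_input(
    "Cosmo had a really fun time participating in the EMNLP conference at Abu Dhabi.",
    "You are Cosmo and you are talking to a friend.",
    ["Hey, how was your trip to Abu Dhabi?"],
))
```

In the resulting string, the situation comes first, the instruction second and the turns last. If both the situation and the instruction are empty, only the joined turns remain.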

In [2]:
pip install matplotlib

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [3]:
pip install --pre torch torchvision torchaudio 


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


# Initialize

In [4]:
import torch
import numpy as np
import pandas as pd
import sklearn
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

print(f"PyTorch version: {torch.__version__}")

# Check whether PyTorch has access to MPS (Metal Performance Shaders, the PyTorch backend for GPUs on Apple devices)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Select the device: prefer MPS, then CUDA (NVIDIA GPUs), and fall back to the CPU
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")

PyTorch version: 2.1.0
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Using device: mps
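Whether cosmo-xl fits on the selected device depends mostly on the size of its weights, which is roughly the number of parameters multiplied by the bytes per parameter. The sketch below does that arithmetic. The figure of about 3 billion parameters is an assumption, so check it against the model card. The estimate also ignores activations, the generation cache and framework overhead.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Memory needed for the model weights alone, in decimal gigabytes (1 GB = 1e9 bytes).

    Activations, the generation cache, and framework overhead are not included.
    """
    return num_params * bytes_per_param / 1e9


# Assumption: cosmo-xl has roughly 3 billion parameters (verify against the model card)
ASSUMED_NUM_PARAMS = 3e9

for dtype_name, nbytes in [("float32", 4), ("float16 / bfloat16", 2)]:
    print(f"{dtype_name}: ~{weight_memory_gb(ASSUMED_NUM_PARAMS, nbytes):.1f} GB")
# float32: ~12.0 GB
# float16 / bfloat16: ~6.0 GB
```

Halving the bytes per parameter halves the estimate, which is why half-precision weights are a common choice when device memory is tight.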


In [5]:
# The tokenizer and model weights are downloaded from the Hugging Face Hub on first use
device = torch.device(device)
tokenizer = AutoTokenizer.from_pretrained("allenai/cosmo-xl", legacy=False, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/cosmo-xl").to(device)

def set_input(situation_narrative, role_instruction, conversation_history):
    """Join the turns with " <turn> ", then prepend the role instruction and the
    situation narrative (each followed by " <sep> ") when they are non-empty."""
    input_text = " <turn> ".join(conversation_history)

    if role_instruction != "":
        input_text = "{} <sep> {}".format(role_instruction, input_text)

    if situation_narrative != "":
        input_text = "{} <sep> {}".format(situation_narrative, input_text)

    return input_text

def generate(situation_narrative, role_instruction, conversation_history):
    """
    situation_narrative: a description of the situation/context, including the characters (e.g., "David goes to an amusement park")
    role_instruction: the perspective/speaker instruction (e.g., "Imagine you are David and speak to his friend Sarah")
    conversation_history: the previous utterances in the conversation, as a list of strings
    """

    input_text = set_input(situation_narrative, role_instruction, conversation_history)

    inputs = tokenizer([input_text], return_tensors="pt").to(device)
    # Sample the next turn with nucleus (top-p) sampling; **inputs also passes the attention mask
    outputs = model.generate(**inputs, max_new_tokens=128, temperature=1.0, top_p=.95, do_sample=True)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

    return response
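`generate` samples the reply with `do_sample=True` and `top_p=.95` (nucleus sampling). Setting `temperature=1.0` leaves the model's probabilities unscaled. The following is a conceptual sketch of top-p sampling and is not the Hugging Face implementation. It keeps the smallest set of most-likely tokens whose cumulative probability reaches `top_p`, renormalizes that set and samples from it. The toy distribution and the function names are made up for illustration.

```python
import random
from typing import Dict, List, Tuple


def top_p_filter(probs: Dict[str, float], top_p: float) -> Dict[str, float]:
    """Keep the smallest set of most-likely tokens whose cumulative probability
    reaches top_p, then renormalize the kept probabilities to sum to 1."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    kept: List[Tuple[str, float]] = []
    cumulative = 0.0
    for token, p in ranked:
        kept.append((token, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {token: p / total for token, p in kept}


def sample_top_p(probs: Dict[str, float], top_p: float, rng: random.Random) -> str:
    """Draw one token from the top-p filtered distribution."""
    filtered = top_p_filter(probs, top_p)
    tokens = list(filtered)
    return rng.choices(tokens, weights=[filtered[t] for t in tokens], k=1)[0]


# Made-up next-token distribution, for illustration only
toy = {"great": 0.6, "fun": 0.3, "tiring": 0.07, "purple": 0.03}
print(top_p_filter(toy, top_p=0.95))  # "purple" (the 3% tail) is dropped
print(sample_top_p(toy, top_p=0.95, rng=random.Random(0)))
```

Because the reply is sampled rather than chosen greedily, running the same prompt twice can produce different responses.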


# Ask a question

In [11]:
situation = "Cosmo had a really fun time participating in the EMNLP conference at Abu Dhabi."
instruction = "You are Cosmo and you are talking to a friend." # You can also leave the instruction empty

conversation = [
    "Hey, how was your trip to Abu Dhabi?"
]

response = generate(situation, instruction, conversation)
print(response)

It was great! I had a lot of fun participating in the EMNLP conference.
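To continue the conversation, append COSMO's reply and the friend's next utterance to `conversation` and call `generate` again. The `conversation_history` argument is simply the list of previous turns. The following is a minimal sketch of that bookkeeping. `next_history` and the follow-up question are hypothetical.

```python
from typing import List


def next_history(history: List[str], model_reply: str, user_reply: str) -> List[str]:
    """Return a new history with COSMO's reply and the friend's next utterance
    appended; the original list is left unchanged."""
    return history + [model_reply, user_reply]


conversation = ["Hey, how was your trip to Abu Dhabi?"]
conversation = next_history(
    conversation,
    "It was great! I had a lot of fun participating in the EMNLP conference.",
    "Nice! Which talk did you like the most?",  # hypothetical follow-up
)
print(" <turn> ".join(conversation))

# In the notebook, the next reply would then come from:
# response = generate(situation, instruction, conversation)
```

Returning a new list instead of mutating the old one keeps earlier histories intact. This makes it easy to branch a conversation and try several follow-ups from the same point.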
