In [None]:
import os
import sys
import json
import copy
import random

sys.path.append("/home/tejas/projects/teach/src/")

from teach.dataset.definitions import Definitions
from teach.dataset.dataset import Dataset
from teach.dataset.actions import Action_Keyboard, Action_ObjectInteraction

# Root of the TEACh dataset. Set the TEACH_DATA_DIR environment variable to
# point somewhere else (e.g. the directory passed to `teach_download`) instead
# of editing this cell. The earlier "/tmp/teach-dataset" assignment was dead
# code: it was overwritten on the very next line.
data_dir = os.environ.get("TEACH_DATA_DIR", "/home/shared/teach/")

definitions = Definitions(version="2.0")

images_dir = os.path.join(data_dir, "images/train/")
# os.listdir returns entries in arbitrary order; sorting keeps the seeded
# sampling in later cells reproducible across machines and runs.
game_ids = sorted(os.listdir(images_dir))


In [None]:
def get_interactions_with_dialog_acts(game_id, do_print=False):
    """Collect dialog-act-labelled Assistant utterance segments for one game.

    Reads ``games/train/{game_id}.game.json`` under ``data_dir`` and walks the
    interactions of the first episode of the first task.

    Args:
        game_id: Identifier used to build the game file name.
        do_print: When true, print every utterance prefixed with the agent
            name looked up in ``definitions``.

    Returns:
        A list of dicts with keys ``"context"`` (all earlier utterances, one
        per line, prefixed with "User"/"Assistant"), ``"response"`` (the
        Assistant-prefixed text segment) and ``"dialog_act"``. Only segments
        spoken by a non-Commander agent, with a non-empty dialog act and a
        non-empty preceding context, are included.
    """
    game_path = os.path.join(data_dir, f"games/train/{game_id}.game.json")
    with open(game_path) as game_file:
        game = json.load(game_file)

    records = []
    history = ""
    for turn in game["tasks"][0]["episodes"][0]["interactions"]:
        # Interactions without an utterance are skipped.
        if "utterance" not in turn:
            continue

        agent_name = definitions.map_agents_id2info[turn["agent_id"]]["agent_name"]
        speaker = "User" if agent_name == "Commander" else "Assistant"
        text = turn["utterance"]
        if do_print:
            print(f"{agent_name}: {text}")

        # "text_segments" and "das" are parallel lists of length 3; when an
        # utterance has fewer than 3 dialog acts the extra entries are empty.
        # No utterance has more than 3 dialog acts.
        da_meta = turn["da_metadata"]
        for position, raw_act in enumerate(da_meta["das"]):
            segment = da_meta["text_segments"][position]
            act = raw_act.strip()
            if act and history and speaker == "Assistant":
                records.append({
                    "context": history,
                    "response": f"{speaker}: {segment}",
                    "dialog_act": act
                })

        # The context grows by the full utterance, not by individual segments.
        history += f"{speaker}: {text} \n"
    return records


# Spot-check the extraction on one randomly chosen training game: the dialogue
# is printed and the returned list of records is shown as the cell output.
game_id = random.choice(game_ids)
get_interactions_with_dialog_acts(game_id, do_print=True)

In [None]:
from collections import defaultdict
from tqdm import tqdm

# Bucket every extracted Assistant segment from every training game under its
# dialog act label.
da2instances = defaultdict(list)
for current_game in tqdm(game_ids):
    for instance in get_interactions_with_dialog_acts(current_game):
        da2instances[instance["dialog_act"]].append(instance)

In [None]:
# Number of instances to annotate per dialog act; acts with fewer instances
# than this are skipped entirely.
INSTANCES_PER_DA = 50

all_selected_instances = []
random.seed(0)
for da, instances in da2instances.items():
    if len(instances) < INSTANCES_PER_DA:
        continue
    print(da, len(instances))
    # random.sample draws WITHOUT replacement. The previous random.choices
    # call sampled with replacement, so the same instance could be selected
    # (and sent for annotation) several times.
    filtered_instances = random.sample(instances, k=INSTANCES_PER_DA)
    all_selected_instances.extend(filtered_instances)

print(f"Total instances: {len(all_selected_instances)}")


In [None]:
class Friction:
    """Prompt pieces for the friction-category annotation task.

    Attributes:
        system: Text of ``system_new.txt``, read once when the class body is
            executed (the file must exist in the working directory).
    """

    # NOTE(review): readlines() keeps each line's trailing "\n", so joining
    # with "\n" doubles every line break of the file. Left unchanged so the
    # prompt stays identical to the one used for existing annotations; use
    # txt.read() if the file should be passed through verbatim.
    with open('system_new.txt', 'r') as txt:
        system = '\n'.join(txt.readlines())

    @staticmethod
    def prompt(x):
        """Build the per-example part of the prompt.

        Declared as a staticmethod so that it works both as
        ``Friction.prompt(x)`` and on an instance (``f.prompt(x)``); without
        the decorator the instance call would pass the instance as ``x``.

        Args:
            x: Mapping with a ``"context"`` entry and, optionally, a
                ``"response"`` entry. A missing response is rendered as empty.

        Returns:
            The prompt text: dialogue context, response, and the question.
        """
        response = x.get("response", "")
        return (
            f'The example dialogue is provided next.\n\n{x["context"]}\n\n\n'
            f'The response is provided below.\n\nResponse: {response}\n\n\n'
            'What friction category if any?'
        )

f = Friction()

In [None]:
import random
from openai_utils import openai_caller

# Pick one selected instance at random and show what will be annotated.
idx = random.choice(range(len(all_selected_instances)))
d = all_selected_instances[idx]
for label, field in (("Context", "context"), ("Response", "response"), ("Dialog acts", "dialog_act")):
    print(f"{label}: {d[field]}")

# The whole prompt (instructions followed by the example) is sent as a single
# system message.
prompt = f"{Friction.system}{Friction.prompt(d)}"
messages = [dict(role='system', content=prompt)]

gpt_response = openai_caller(messages, max_new_tokens=256, model='gpt4o')

In [None]:
# Display the raw model reply for the instance sampled in the previous cell.
gpt_response

In [None]:
import openai, logging

# Only show errors from the HTTP/OpenAI client loggers so the progress bar is
# not flooded with per-request messages.
for noisy_logger in ("openai", "httpcore", "httpx"):
    logging.getLogger(noisy_logger).setLevel(logging.ERROR)


# presumably resets the token counter so compute_cost() below covers only this
# run; confirm against openai_utils
openai_caller.reset_tokens_used()
annotated_data = []

for d in tqdm(all_selected_instances):
    prompt = Friction.system + Friction.prompt(d)
    messages = [{'role': 'system', 'content': prompt}]
    gpt_response = openai_caller(messages, max_new_tokens=256, model='gpt4o')
    # The label is whatever follows the last 'ANSWER = ' marker (the whole
    # reply if the marker is missing). Surrounding whitespace and periods are
    # removed here so that "X", "X." and "X.\n" are counted and saved as the
    # same category, rather than being normalised only after the JSON dump.
    friction_anno = gpt_response.split('ANSWER = ')[-1].strip().strip('.').strip()
    new_d = d.copy()
    new_d['friction_anno'] = friction_anno
    new_d['gpt_response'] = gpt_response
    annotated_data.append(new_d)
print(f"Annotation cost: ${openai_caller.compute_cost():.4f}")

In [None]:
from collections import Counter
# Distribution of predicted friction categories over all annotated instances.
annotations = [x['friction_anno'] for x in annotated_data]
print(Counter(annotations))

import json
# Persist the annotations; the instance count is part of the file name.
output_path = f'data_new_prompt/teach_anno-{len(annotated_data)}instances.json'
# Create the output directory if needed so a finished annotation run is not
# lost to a missing folder.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# Context manager guarantees the file is flushed and closed.
with open(output_path, 'w') as out_file:
    json.dump(annotated_data, out_file, indent=2)

In [None]:
# Count (dialog act, friction category) pairs. Trailing/leading periods are
# stripped from the label in place so "X" and "X." fall in the same column.
acttype_frictionanno_counter = Counter()
for d in annotated_data:
    d['friction_anno'] = d['friction_anno'].strip('.')
    acttype_frictionanno_counter[(d['dialog_act'], d['friction_anno'])] += 1

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Sorted lists instead of sets: the heatmap layout is then the same on every
# run, and the labels can be used as the DataFrame index/columns (pandas
# rejects sets there).
rows = sorted({da for da, _ in acttype_frictionanno_counter})
cols = sorted({fc for _, fc in acttype_frictionanno_counter})
# Row totals are computed once; every row has at least one count, so the
# division below cannot hit zero.
row_totals = {r: sum(acttype_frictionanno_counter[(r, c)] for c in cols) for r in rows}
data = [[acttype_frictionanno_counter[(r, c)] / row_totals[r] for c in cols] for r in rows]
df = pd.DataFrame(data, index=rows, columns=cols)
plt.figure(figsize=(8, 8), dpi=150)
sns.heatmap(df, annot=True, fmt='.2f', cmap='viridis')
plt.xlabel('Friction Category', fontsize=14)
plt.ylabel('Dialog Act type', fontsize=14)
# Tick positions are offset by 0.5 to sit at the centre of each heatmap cell.
plt.xticks(ticks=[x+0.5 for x in range(len(cols))], labels=cols, rotation=45, fontsize=12)
plt.yticks(ticks=[x+0.5 for x in range(len(rows))], labels=rows, rotation=0, fontsize=12)
plt.title('Fraction of instances in each TEACh Dialog Act type \nclassified under different Friction categories', fontsize=14)
plt.show()

In [None]:
# Inspect the first annotated record (original fields plus 'friction_anno'
# and 'gpt_response').
annotated_data[0]

In [2]:
import json

# Reload the saved annotations (550-instance run). The context manager closes
# the file handle, which the bare open() call used to leave open.
with open('data_new_prompt/teach_anno-550instances.json') as anno_file:
    annotations = json.load(anno_file)

In [3]:
from collections import defaultdict

# Distinct dialog act labels and friction categories present in the data.
das = set(anno['dialog_act'] for anno in annotations)
fcs = set(anno['friction_anno'] for anno in annotations)

# Lookup tables: records per dialog act, per friction category, and per
# (dialog act, friction category) pair. Pairs that never occur map to [].
da2annos = {
    da: [anno for anno in annotations if anno['dialog_act'] == da]
    for da in das
}
fc2annos = {
    fc: [anno for anno in annotations if anno['friction_anno'] == fc]
    for fc in fcs
}
dafc2annos = {
    (da, fc): [anno for anno in da2annos[da] if anno['friction_anno'] == fc]
    for da in das
    for fc in fcs
}

In [13]:
import random
# Draw one record labelled "Reflective Pause" for manual inspection.
# (To browse by dialog act instead, sample from da2annos, e.g. da2annos["MiscOther"].)
x = random.choice(fc2annos["Reflective Pause"])

# Context first, then the response under review, then a blank line followed by
# the predicted label and the model's full reply.
report_lines = (
    x['context'],
    f"RESPONSE TO ANNOTATE: {x['response']}",
    f"\nANNOTATION: {x['friction_anno']}",
    f"GPT REASONING: {x['gpt_response']}",
)
print(*report_lines, sep="\n")


Assistant: what can i do today 
User: make a slice of tomato 
User: tomato is on the chair 
User: knife is on the left side to the oven 
Assistant: done 
User: potato is inside the wash basin 
Assistant: what can i do next 
User: slice it 
User: and cook it in the microwave 
Assistant: am i to make a slice of tomatoe or potatoe? 
User: both 
User: tomato slicing done 
User: now potato 
User: potato is inside wash basin 
User: turn off the tap to find it 
User: left side basin 
Assistant: i have sliced the potatoe and tomatoe 
User: cook a slice of potato in the microwave 
Assistant: cant seem to be able to put the knife down 
User: put it on the right side of the wash basin 
User: enough area there 
Assistant: its still not working 
Assistant: now its working 
User: put it on the table 
User: ok 
User: remove extra items from the oven 
User: to place the slice inside directly 
User: just the slice 
User: not with plate 
Assistant: done 
User: now place the slices in this order on the p