# CAsT Session

In [1]:
import protobuf

ModuleNotFoundError: No module named 'protobuf'

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [3]:
from elasticsearch import Elasticsearch
from typing import Dict, List, Optional
import json

import ipytest
import pytest

ipytest.autoconfig()

In [4]:
INDEX_NAME = "cast_base"
es = Elasticsearch()

## T5 testing

In [5]:
tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")

model = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")

Downloading: 100%|██████████| 850M/850M [00:43<00:00, 20.5MB/s]


In [17]:
input_ids = tokenizer('Jafar is funny. <sep> Is he funny?', return_tensors='pt').input_ids
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Is Jafar funny?


## Framework

In [177]:
class CAsT():
    def __init__(self, context_queries: int = 0, context_responses: int = 0) -> None:
        self.INDEX_NAME = "cast_base"
        self.es = Elasticsearch()
        self.queries = []
        self.responses = []
        self.context_queries = context_queries
        self.context_responses = context_responses

        self.tokenizer = AutoTokenizer.from_pretrained(
            "castorini/t5-base-canard")
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            "castorini/t5-base-canard")

    def clear_context(self, clear_queries: bool = True, clear_responses: bool = True):
        if clear_queries:
            self.queries = []
        if clear_responses:
            self.responses = []

    def query(self, q: str) -> str:
        """ 
            returns: passage_id NOTE: for now complete hit is returned
        """
        sep = " <sep>"
        qs = []
        if self.context_queries > 0 or self.context_responses > 0:
            for i in range(1, max(self.context_queries, self.context_responses)+1):
                if i <= self.context_queries:
                    if len(self.queries) >= i:
                        qs.insert(0,self.queries[-i])

                if i <= self.context_responses:
                    if len(self.responses) >= i:
                        qs.insert(0,self.responses[-i])
        qs.append(q)

        input_ids = tokenizer(sep.join(qs), return_tensors='pt').input_ids
        outputs = model.generate(input_ids)

        query = tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.queries.append(query)  # * Adding reformated query to context

        hits = es.search(
            index=self.INDEX_NAME, q=query, _source=True, size=100
        ).get("hits", {}).get("hits")

        if len(hits) > 0:
            print("Query: " + query)
            self.responses.append(hits[0].get("_source").get("passage"))
            return hits[:500]
        else:
            return []


### Framework tests

In [164]:
A = []
A.insert(0,1)
A.insert(0,2)
A.insert(0,3)
A

[3, 2, 1]

In [140]:
test = CAsT()

In [141]:
test.query("Tell me about Oslo?", context_queries=1)

Entered 'i <= context_queries:'
Query: Tell me about Oslo?




{'_index': 'cast_base',
 '_type': '_doc',
 '_id': '8841272',
 '_score': 9.102427,
 '_source': {'passage': "Tell a friend about us, add a link to this page, or visit the webmaster's page for free fun content. Link to this page: <a href=http://acronyms.thefreedictionary.com/South+African+Board+for+Personnel+Practice>SABPP</a>. Facebook."}}

In [142]:
test.query("Where is it?", context_queries=1)

Entered 'i <= context_queries:'
Entered 'len(self.queries) >= i:'
Query: Where is Oslo located?


{'_index': 'cast_base',
 '_type': '_doc',
 '_id': '8841012',
 '_score': 4.786257,
 '_source': {'passage': 'Refraction of Sound. Refraction is the bending of waves when they enter a medium where their speed is different. Refraction is not so important a phenomenon with sound as it is with light where it is responsible for image formation by lenses, the eye, cameras, etc.But bending of sound waves does occur and is an interesting phenomena in sound.efraction of Sound. Refraction is the bending of waves when they enter a medium where their speed is different. Refraction is not so important a phenomenon with sound as it is with light where it is responsible for image formation by lenses, the eye, cameras, etc.'}}

## Run Queries

In [172]:
path = "../eval/2020_automatic_evaluation_topics_v1.0.json"
key = "raw_utterance"
queries = json.load(open(path))

In [178]:
def run_queries(query_file: str, key: str, CAsT: object, run_id: str):

    queries = json.load(open(query_file))
    if queries[0].get("turn", {})[0].get(key) is None:
        raise KeyError("Provided key: " + key +
                       "is not a valid key for queryfile")

    f = open(run_id + ".txt", "w")

    for topic in queries:
        CAsT.clear_context()
        topic_id = topic.get("number")
        for turn in topic.get("turn"):
            turn_id = turn.get("number")
            hits = CAsT.query(turn.get(key))
            for j, hit in enumerate(hits):
                hit_id = "MARCO_" + hit.get("_id") if hit.get("_source").get(
                    "origin") == "msmarco" else "CAR_" + hit.get("_id")
                f.write(str(topic_id) + "_" + str(turn_id) + "\t" + "Q0" + "\t" + str(hit_id) +
                        "\t" + str(j) + "\t" + str(hit.get("_score")) + "\t" + str(run_id) + "\n")
    f.close()


In [179]:
test_obj = CAsT(context_queries=3)

In [180]:
run_queries(path, key=key, CAsT=test_obj, run_id="Test02")



Query: How do you know when your garage door opener is going bad?
Query: Now the garage door opener stopped working. Why?
Query: How much does it cost for someone to fix a garage door opener that stopped working?
Query: How much does it cost for someone to fix a garage door opener that stopped working?
Query: How do I choose a new garage door opener?
Query: What does a smart garage door opener do?
Query: What's important for me to know about a smart garage door opener's safety
Query: How could a smart garage door opener be hacked?
Query: I would like to learn about GMO Food labeling.
Query: What are the pros and cons of GMO Food labeling?
Query: What are the pros and cons of GMO Food labeling?
Query: What are the EU rules for GMO Food labeling?
Query: Tell me more about traceability tools.
Query: What is the role of Co-Extra in the EU?
Query: How is testing done for contamination of GMO Food in the EU?
Query: What's the difference between the European and US approaches for testing for 