# CAsT Session

In [1]:
import protobuf

ModuleNotFoundError: No module named 'protobuf'

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [3]:
from elasticsearch import Elasticsearch
from typing import Dict, List, Optional
import json

import ipytest
import pytest

ipytest.autoconfig()

In [4]:
INDEX_NAME = "cast_base"
es = Elasticsearch()

## T5 testing

In [5]:
tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")

model = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")

Downloading: 100%|██████████| 850M/850M [00:43<00:00, 20.5MB/s]


In [17]:
input_ids = tokenizer('Jafar is funny. <sep> Is he funny?', return_tensors='pt').input_ids
outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Is Jafar funny?


## Framework

In [126]:
class CAsT():
    def __init__(self) -> None:
        self.INDEX_NAME = "cast_base"
        self.es = Elasticsearch()
        self.queries = []
        self.responses = []

        self.tokenizer = AutoTokenizer.from_pretrained(
            "castorini/t5-base-canard")
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            "castorini/t5-base-canard")

    def clear_context(self, clear_queries: bool, clear_responses: bool):
        if clear_queries:
            self.queries = []
        if clear_responses:
            self.responses = []

    def query(self, q: str, context_queries: int = 0, context_responses: int = 0) -> str:
        """ 
            returns: passage_id NOTE: for now complete hit is returned
        """
        sep = " <sep>"
        qs = []
        if context_queries > 0 or context_responses > 0:
            for i in range(1, max(context_queries, context_responses)+1):
                if i <= context_queries:
                    print("Entered 'i <= context_queries:'")
                    if len(self.queries) >= i:
                        print("Entered 'len(self.queries) >= i:'")
                        qs.append(self.queries[-i])

                if i <= context_responses:
                    if len(self.responses) >= i:
                        qs.append(self.responses[-i])
        qs.append(q)

        input_ids = tokenizer(sep.join(qs), return_tensors='pt').input_ids
        outputs = model.generate(input_ids)

        query = tokenizer.decode(outputs[0], skip_special_tokens=True)
        self.queries.append(query)  # * Adding reformated query to context

        hits = es.search(
            index=self.INDEX_NAME, q=query, _source=True, size=100
        ).get("hits", {}).get("hits")

        if hits is not None:
            print("Query: " + query)
            self.responses.append(hits[0].get("_source").get("passage"))
            return hits[0]
        else:
            return None


### Framework tests

In [127]:
test = CAsT()

In [128]:
test.query("Tell me about Oslo?", context_queries=1)

Entered 'i <= context_queries:'
Query: Tell me about Oslo?


{'_index': 'cast_base',
 '_type': '_doc',
 '_id': '8841272',
 '_score': 9.102427,
 '_source': {'passage': "Tell a friend about us, add a link to this page, or visit the webmaster's page for free fun content. Link to this page: <a href=http://acronyms.thefreedictionary.com/South+African+Board+for+Personnel+Practice>SABPP</a>. Facebook."}}

In [129]:
test.query("Where is it?", context_queries=1)

Entered 'i <= context_queries:'
Entered 'len(self.queries) >= i:'
Query: Where is Oslo located?


{'_index': 'cast_base',
 '_type': '_doc',
 '_id': '8841012',
 '_score': 4.786257,
 '_source': {'passage': 'Refraction of Sound. Refraction is the bending of waves when they enter a medium where their speed is different. Refraction is not so important a phenomenon with sound as it is with light where it is responsible for image formation by lenses, the eye, cameras, etc.But bending of sound waves does occur and is an interesting phenomena in sound.efraction of Sound. Refraction is the bending of waves when they enter a medium where their speed is different. Refraction is not so important a phenomenon with sound as it is with light where it is responsible for image formation by lenses, the eye, cameras, etc.'}}