In [1]:
import requests

def fetch_data_from_api(api_url):
    results = []
    page = 1
    while True:
        response = requests.get(f"{api_url}?page={page}")
        data = response.json()
        if not data:
            break
        results.extend(data)
        page += 1
    return results


In [2]:
from transformers import BertTokenizer, BertModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from collections import defaultdict

class BertEmbedder:
    def __init__(self, model_name='bert-base-uncased'):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.cache = {}

    def embed(self, texts):
        # Check cache first
        cached_texts = {text: self.cache[text] for text in texts if text in self.cache}
        new_texts = [text for text in texts if text not in self.cache]

        if not new_texts:
            return np.array([cached_texts[text] for text in texts])

        inputs = self.tokenizer(new_texts, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embeddings = outputs.last_hidden_state[:, 0, :].numpy()

        # Update cache
        for text, embedding in zip(new_texts, embeddings):
            self.cache[text] = embedding

        # Combine cached and new embeddings
        return np.array([self.cache[text] for text in texts])

def identify_citations_bert(response_texts, sources_list, embedder):
    citations = defaultdict(list)

    all_texts = response_texts + [source['context'] for sources in sources_list for source in sources]
    embeddings = embedder.embed(all_texts)

    response_embeddings = embeddings[:len(response_texts)]
    source_embeddings = embeddings[len(response_texts):]

    source_idx = 0
    for i, response_embedding in enumerate(response_embeddings):
        for source in sources_list[i]:
            source_embedding = source_embeddings[source_idx]
            similarity = cosine_similarity([response_embedding], [source_embedding]).flatten()[0]
            if similarity > 0.5:  # Threshold for considering a source as relevant
                citations[i].append(source)
            source_idx += 1

    return citations


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def process_data(api_url):
    data = fetch_data_from_api(api_url)
    embedder = BertEmbedder()
    response_texts = [item['response'] for item in data]
    sources_list = [item['sources'] for item in data]

    citations = identify_citations_bert(response_texts, sources_list, embedder)

    results_with_citations = []
    for i, item in enumerate(data):
        results_with_citations.append({
            'response': item['response'],
            'citations': citations[i]
        })

    return results_with_citations


In [8]:
def display_results(results):
    for item in results:
        print(f"Response: {item['response']}")
        if item['citations']:
            print("Citations:")
            for citation in item['citations']:
                print(f" - {citation['context']}")
                if 'link' in citation:
                    print(f"   Link: {citation['link']}")
        else:
            print("Citations: None")
        print("\n" + "-"*80 + "\n")

if __name__ == '__main__':
    api_url = "https://devapi.beyondchats.com/api/get_message_with_sources"
    results = process_data(api_url)
    display_results(results)


ReadTimeout: HTTPSConnectionPool(host='devapi.beyondchats.com', port=443): Read timed out. (read timeout=None)