In [1]:
# !pip install -q langchain pypdf2 tiktoken textract faiss-cpu huggingface_hub pypdfium2 InstructorEmbedding sentence-transformers python-docx contractions -q

In [None]:
# !pip install qdrant-client

In [7]:
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import CollectionStatus

In [26]:
from transformers import AutoModel, AutoTokenizer
import torch
from torch import Tensor
import os
import re
import deepl
import json
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOllama
from langchain.prompts import ChatPromptTemplate
from langchain.chains import LLMChain

In [9]:
# Two GTE checkpoints are loaded once per kernel. `embedding` uses the "_en"
# pair (gte-small) only when lang == 'en'; every other language code goes
# through the "_de" pair (gte-base).
tokenizer_de, model_de = (
    AutoTokenizer.from_pretrained("thenlper/gte-base"),
    AutoModel.from_pretrained("thenlper/gte-base"),
)
tokenizer_en, model_en = (
    AutoTokenizer.from_pretrained("thenlper/gte-small"),
    AutoModel.from_pretrained("thenlper/gte-small"),
)

Downloading model.safetensors: 100%|███████| 66.7M/66.7M [00:04<00:00, 14.9MB/s]


In [10]:
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states over dim 1, ignoring masked-out positions.

    Positions where `attention_mask` is zero are zeroed before summing, and
    the sum is divided by the per-row count of attended positions.

    Args:
        last_hidden_states: Hidden states; assumed (batch, seq, dim), which is
            what the dim=1 reduction and trailing-axis mask imply.
        attention_mask: Mask with one entry per position; assumed (batch, seq).

    Returns:
        One pooled vector per row of the batch.
    """
    # Add a trailing axis so the mask broadcasts across the hidden dimension.
    is_padding = attention_mask.unsqueeze(-1).bool().logical_not()
    zeroed = last_hidden_states.masked_fill(is_padding, 0.0)
    token_counts = attention_mask.sum(dim=1, keepdim=True)
    return zeroed.sum(dim=1) / token_counts

In [11]:
def embedding(text, lang):
    """Embed `text` with the checkpoint selected by `lang`.

    Args:
        text: Input handed to the tokenizer unchanged.
        lang: 'en' selects `tokenizer_en`/`model_en`; any other value selects
            `tokenizer_de`/`model_de`.

    Returns:
        NumPy array with the mean-pooled embedding of the first row of the
        tokenized batch.
    """
    if lang == 'en':
        tokenizer, model = tokenizer_en, model_en
    else:
        tokenizer, model = tokenizer_de, model_de

    batch_dict = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    # Inference only: disable autograd so no computation graph is built or
    # kept alive for every query.
    with torch.no_grad():
        outputs = model(**batch_dict)
        pooled = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    # Only the first row of the batch is returned.
    return pooled.detach().numpy()[0]

In [23]:
def doc_to_rscope(docs):
    """Flatten search hits into one context string plus their links.

    Args:
        docs: Iterable of hits whose `payload` mapping has 'content' and
            'link' entries.

    Returns:
        Tuple `(rscope, links)`: every 'content' value followed by a blank
        line, concatenated in input order, and the list of 'link' values in
        the same order.
    """
    pieces = []
    links = []
    for hit in docs:
        pieces.append(hit.payload['content'] + "\n\n")
        links.append(hit.payload['link'])
    return "".join(pieces), links
        

In [13]:
def get_top_results(hits1, hits2, top_k=3):
    """Merge two hit sequences and keep the highest-scoring entries.

    Args:
        hits1: Sequence of objects exposing a numeric `score` attribute.
        hits2: Second sequence of the same kind.
        top_k: Maximum number of hits to return. Defaults to 3, the value
            that used to be hard-coded.

    Returns:
        List of at most `top_k` hits ordered by descending `score`. Hits with
        equal scores keep their input order (`hits1` before `hits2`).

    Raises:
        ValueError: If `top_k` is negative.
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative")
    # list() lets callers pass tuples or other sequences, not only lists.
    merged = list(hits1) + list(hits2)
    merged.sort(key=lambda hit: hit.score, reverse=True)
    return merged[:top_k]

In [14]:
# Read the DeepL credentials from a local JSON file. The context manager
# closes the handle, which the bare open() call used to leak.
with open("deepl.json") as key_fh:
    key_file = json.load(key_fh)
# The first element of the "key" entry is used as the auth key.
key_deepl = key_file['key'][0]
translator = deepl.Translator(key_deepl)

In [35]:
def query_embedding(query, lang, kanton):
    """Embed a user query, translating it first when needed.

    Args:
        query: The query text.
        lang: Language code of `query` as supplied by the caller.
        kanton: 'all' or any other value; see the routing below.

    Returns:
        The value returned by `embedding` for the (possibly translated) query.

    Routing:
        * kanton == 'all' and lang is 'en' or 'de': embed as-is with that
          language's model.
        * kanton == 'all' and any other lang: translate to English, then
          embed with the 'en' model.
        * any other kanton: translate to German unless lang is already 'de',
          then embed with the 'de' model.
    """
    if kanton == 'all':
        if lang in ('en', 'de'):
            return embedding(query, lang)
        # Neither 'en' nor 'de': translate to English before embedding.
        query = translator.translate_text(query, target_lang='EN-GB').text
        return embedding(query, 'en')

    if lang != 'de':
        query = translator.translate_text(query, target_lang='DE').text
    return embedding(query, 'de')
    

In [21]:
def search_context(query, lang, kanton):
    """Retrieve the best-matching stored entries for a query from Qdrant.

    Args:
        query: The user's question.
        lang: Language code of the question.
        kanton: 'all' to search a single collection chosen by `lang`;
            otherwise the suffix of a 'swiss-<kanton>' collection that is
            searched together with 'swiss-de'.

    Returns:
        Up to three hits as returned by the Qdrant client.
    """
    vector = query_embedding(query, lang, kanton)
    qdrant = QdrantClient(host="localhost", port=6333)

    if kanton == 'all':
        # 'de' queries go to 'swiss-de'; every other language uses 'swiss-or'.
        target = 'swiss-de' if lang == 'de' else 'swiss-or'
        return qdrant.search(
            collection_name=target,
            query_vector=vector,
            limit=3,
        )

    # Query the kanton-specific collection and 'swiss-de', then keep the
    # overall best hits across both result lists.
    kanton_hits = qdrant.search(
        collection_name='swiss-' + kanton,
        query_vector=vector,
        limit=3,
    )
    shared_hits = qdrant.search(
        collection_name='swiss-de',
        query_vector=vector,
        limit=3,
    )
    return get_top_results(kanton_hits, shared_hits)

In [28]:
def qa_chatbot(query, lang, kanton):
  """Answer `query` from the articles retrieved by `search_context`.

  Args:
    query: The user's question.
    lang: Language code of the question. For 'en' and 'de' the model's answer
      is returned unchanged; for any other value it is translated back into
      the language DeepL detects for `query`.
    kanton: Passed through to `search_context`.

  Returns:
    The answer text.
  """
  context_raw = search_context(query, lang, kanton)
  context,links = doc_to_rscope(context_raw)
  chat_text = """
  You are a legal assistant expert on the Swiss Code of Obligations.
  Answer questions related to contract law, employment regulations,
  or corporate obligations.
  Base your answers exclusively on the provided top 3 articles from the Swiss Code of Obligations.
  Please provide a summary of the relevant article(s), along with the source link(s) for reference.
  The souce link(s) should be from the following collection {source_links}, if none of the links works, just don't provide the information.
  If an answer is not explicitly covered in the provided context, please indicate so by saying 'Whoopsie! It seems we took a detour from the legal zone. Let's hop back to law talk. Ask me anything about contracts, family law, or legal advice!'
  Context: {context}
  Question: {query}
  """
  prompt_template = ChatPromptTemplate.from_template(chat_text)
  ollama_llm_chain = LLMChain(prompt=prompt_template, llm=ollama)
  # {context} must receive the retrieved articles. The raw template text was
  # passed here before, so the model never saw the retrieved content.
  answer = ollama_llm_chain.run(context=context,
                                query=query, source_links=links)

  # `lang != ('en' or 'de')` reduced to `lang != 'en'`; test membership
  # explicitly, as `query_embedding` does for the same two languages.
  # NOTE(review): 'de' answers are no longer round-tripped through DeepL, so
  # they rely on the model replying in German - confirm this is intended.
  if lang not in ('en', 'de'):
      # Translating the query is only used to obtain DeepL's detected source
      # language, which then becomes the target for the answer.
      detector = translator.translate_text(query, target_lang="DE")
      reply_lang = detector.detected_source_lang
      # Detection yields bare codes, but DeepL requires a regional variant
      # for English and Portuguese targets.
      reply_lang = {"EN": "EN-GB", "PT": "PT-PT"}.get(reply_lang, reply_lang)
      answer = translator.translate_text(answer, target_lang=reply_lang).text

  return answer

In [38]:
# Chat model handle; `qa_chatbot` reads this module-level name at call time,
# so this cell must run before the first call.
ollama = ChatOllama(
    model="mistral",
    temperature=0.1,
    base_url='http://localhost:11434',
)

In [39]:
%%time
qa_chatbot('产假有多久','cn','all')

cn
CPU times: user 907 ms, sys: 68.2 ms, total: 975 ms
Wall time: 4min 8s


'\n根据《瑞士义务法典》第 329 条，雇员有权享受每年至少四周的带薪休假。假期可分为两部分，每部分至少两周。如果合理且不损害雇员的利益，雇主和雇员可以商定不同的休假时间。\n\n资料来源<https://www.fedlex.admin.ch/eli/cc/27/317_321_377/en#art_329_f>'