In [27]:
# !pip install datasets
# !pip install qdrant-client sentence-transformers datasets
# !pip install --upgrade huggingface_hub
# !python -c "import huggingface_hub; print(huggingface_hub.__version__)"
# !pip uninstall huggingface_hub -y
# !pip install huggingface_hub 
# !cat ~/miniconda3/envs/image_entity_extract_env/lib/python3.12/site-packages/huggingface_hub/errors.py | grep GatedRepoError

!find ~/miniconda3/envs/image_entity_extract_env/lib/python3.12/site-packages/huggingface_hub -name "*.pyc" -delete


In [47]:
import pandas as pd
import numpy as np
from datasets import load_dataset


In [48]:

ds = load_dataset("MohammadOthman/mo-customer-support-tweets-945k")

In [3]:
len(ds['train'])

945278

# Embed 

## Load all data in qdrant

In [51]:
import logging
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.http import models  # <-- ADD THIS IMPORT
from qdrant_client.http.models import VectorParams

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

QDRANT_COLLECTION = "customer_support_tweets"

def load_data(split="train", limit=None):
    """
    Load dataset from Hugging Face with optional limit for prototyping.
    """
    logger.info(f"Loading dataset split '{split}' from Hugging Face...")
    dataset = load_dataset("MohammadOthman/mo-customer-support-tweets-945k", split=split)
    if limit:
        dataset = dataset.select(range(limit))
    logger.info(f"Loaded {len(dataset)} records.")
    return dataset

def initialize_qdrant(collection_name, vector_size):
    """
    Initialize Qdrant collection. Deletes existing collection if any.
    """
    # Using a local, in-memory client for this example. 
    # Replace with QdrantClient(url="...") for a remote server.
    # Use a persistent Qdrant client (not in-memory) so it can be used by other threads/processes
    # client = QdrantClient(path="qdrant_data")  # This will persist data to ./qdrant_data directory
    client = QdrantClient(host="localhost", port=6333)
    if client.collection_exists(collection_name) == False:

        logger.info(f"Creating collection '{collection_name}' with vector size {vector_size} and cosine distance...")
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=vector_size, distance="Cosine")
        )
    return client

def embed_texts(model, texts, batch_size=512):
    """
    Embed a list of texts using SentenceTransformer in batches.
    """
    embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        emb = model.encode(batch, show_progress_bar=True)
        embeddings.extend(emb)
    return embeddings

# --- MODIFIED FUNCTION ---
def upsert_data(client, collection_name, inputs, outputs, embeddings, batch_size=256):
    """
    Upload vectors with payload to Qdrant in batches.
    """
    logger.info(f"Uploading {len(inputs)} vectors to Qdrant collection '{collection_name}' in batches of {batch_size}...")
    
    points_batch = []
    # Iterate through all data points
    for idx, (inp, out, vec) in enumerate(zip(inputs, outputs, embeddings)):
        # Create a PointStruct for each item
        points_batch.append(
            models.PointStruct(
                id=idx,
                vector=vec.tolist(),  # Convert numpy array to list
                payload={"input": inp, "reply": out}
            )
        )

        # When the batch is full or it's the last item, upsert it
        if len(points_batch) >= batch_size or idx == len(inputs) - 1:
            client.upsert(
                collection_name=collection_name,
                points=points_batch,
                wait=True  # Wait for the operation to be indexed
            )
            logger.info(f"Upserted batch of {len(points_batch)} points.")
            points_batch = []  # Clear the batch for the next set of points

    logger.info("All batches upserted successfully.")

def search_similar(client, collection_name, model, query, top_k=25):
    """
    Given a query string, embed and search in Qdrant.
    Returns list of matched documents with scores.
    """
    query_vec = model.encode([query])[0].tolist()
    hits = client.search(
        collection_name=collection_name,
        query_vector=query_vec,
        limit=top_k
    )
    results = []
    for hit in hits:
        results.append({
            "id": hit.id,
            "score": hit.score,
            "input": hit.payload.get("input"),
            "reply": hit.payload.get("reply")
        })
    return results

def top_p_filtering(results, p=0.9, score_key="score"):
    """
    Given a list of results (dicts with a score), return the smallest set of results
    whose normalized cumulative score >= p.
    Assumes higher score is better.
    """
    sorted_results = sorted(results, key=lambda x: x[score_key], reverse=True)
    scores = [r[score_key] for r in sorted_results]
    total = sum(scores)
    if total == 0:
        return []
    normalized_scores = [s / total for s in scores]
    cumulative = 0.0
    filtered = []
    for r, ns in zip(sorted_results, normalized_scores):
        filtered.append(r)
        cumulative += ns
        if cumulative >= p:
            break
    return filtered

def top_p_filtering_with_temperature(results, p=0.9, temperature=0.1, score_key="score"):
    """
    Apply temperature to scores, then do top-p (nucleus) filtering.
    """
    if not results:
        return []
    scores = np.array([r[score_key] for r in results])
    # Apply temperature to scores
    if temperature != 1.0:
        scores = scores / temperature
    # Softmax to get probabilities
    exp_scores = np.exp(scores - np.max(scores))  # for numerical stability
    probs = exp_scores / np.sum(exp_scores)
    # Sort results by probability descending
    sorted_indices = np.argsort(probs)[::-1]
    sorted_results = [results[i] for i in sorted_indices]
    sorted_probs = probs[sorted_indices]
    # Top-p filtering
    cumulative = 0.0
    filtered = []
    for r, prob in zip(sorted_results, sorted_probs):
        filtered.append(r)
        cumulative += prob
        if cumulative >= p:
            break
    return filtered

def main():
    # Load data (limit to 10k for prototyping, remove limit for full dataset)
    global ds
    dataset = ds['train'][:1000]
    inputs = dataset["input"]
    outputs = dataset["output"]

    # Load embedding model
    logger.info("Loading embedding model 'all-MiniLM-L6-v2'...")
    model = SentenceTransformer("all-MiniLM-L6-v2")

    # Initialize Qdrant collection
    vector_size = model.get_sentence_embedding_dimension()
    client = initialize_qdrant(QDRANT_COLLECTION, vector_size=vector_size)

    # Embed inputs
    embeddings = embed_texts(model, inputs)

    # Upload vectors and payloads using the new batching function
    upsert_data(client, QDRANT_COLLECTION, inputs, outputs, embeddings)

    # Example query
    user_query = "I am having trouble logging into my account"
    logger.info(f"\nSearching for documents matching query: '{user_query}'")
    results = search_similar(client, QDRANT_COLLECTION, model, user_query)

    # logger.info("Search results:")
    # for res in results:
    #     logger.info(f"Score: {res['score']:.4f} | Input: {res['input']} | Reply: {res['reply']}")


    logger.info(f"Search results: {len(results)}")
    results_top_p = top_p_filtering_with_temperature(results , score_key="score")
    logger.info(f"Search results: {len(results_top_p)}")

    for res in results_top_p:
        logger.info(f"Score: {res['score']:.4f} | Input: {res['input']} | Reply: {res['reply']}")


if __name__ == '__main__':
    main()

INFO:__main__:Loading embedding model 'all-MiniLM-L6-v2'...
INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2
INFO:httpx:HTTP Request: GET http://localhost:6333 "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: GET http://localhost:6333/collections/customer_support_tweets/exists "HTTP/1.1 200 OK"


Batches:   0%|          | 0/16 [00:00<?, ?it/s]

Batches:   0%|          | 0/16 [00:00<?, ?it/s]

INFO:__main__:Uploading 1000 vectors to Qdrant collection 'customer_support_tweets' in batches of 256...
INFO:httpx:HTTP Request: PUT http://localhost:6333/collections/customer_support_tweets/points?wait=true "HTTP/1.1 200 OK"
INFO:__main__:Upserted batch of 256 points.
INFO:httpx:HTTP Request: PUT http://localhost:6333/collections/customer_support_tweets/points?wait=true "HTTP/1.1 200 OK"
INFO:__main__:Upserted batch of 256 points.
INFO:httpx:HTTP Request: PUT http://localhost:6333/collections/customer_support_tweets/points?wait=true "HTTP/1.1 200 OK"
INFO:__main__:Upserted batch of 256 points.
INFO:httpx:HTTP Request: PUT http://localhost:6333/collections/customer_support_tweets/points?wait=true "HTTP/1.1 200 OK"
INFO:__main__:Upserted batch of 232 points.
INFO:__main__:All batches upserted successfully.
INFO:__main__:
Searching for documents matching query: 'I am having trouble logging into my account'


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  hits = client.search(
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/search "HTTP/1.1 200 OK"
INFO:__main__:Search results: 25
INFO:__main__:Search results: 16
INFO:__main__:Score: 0.7478 | Input: I need help with my Account, It gets disabled every time I try to login and it is getting really frustrating! | Reply: It seems a member of our team has reached out to you. Please check your email and follow up with them there.
INFO:__main__:Score: 0.6445 | Input: I can’t log in because it’s locked. | Reply: The link TN provided will allow you to contact us without signing in! Select skip log in and you will be in touch in no time!
INFO:__main__:Score: 0.5934 | Input: You have to log in. I do not remember the email address I used. Can you look up the account by my bank information? | Reply: It would be best if you get in touch with us by phone or chat.use this link to do so
INFO:__main__:Score: 0.5816 | Input: um, my account was locked as soon

In [52]:
client = QdrantClient(host="localhost", port=6333)
print(client.count(collection_name=QDRANT_COLLECTION, exact=True))

INFO:httpx:HTTP Request: GET http://localhost:6333 "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/count "HTTP/1.1 200 OK"


count=1000


## test 

In [12]:
model = SentenceTransformer("all-MiniLM-L6-v2")
vector_size = model.get_sentence_embedding_dimension()
client = initialize_qdrant(QDRANT_COLLECTION, vector_size=vector_size)
user_query = "Input: is the worst customer service | Reply: I would love the chance to review the account and provide assistance	"
logger.info(f"\nSearching for documents matching query: '{user_query}'")
results = search_similar(client, QDRANT_COLLECTION, model, user_query)

results_top_p = top_p_filtering(results , score_key="score")
logger.info(f"Search results: {len(results)=} , {len(results_top_p)=}")

for res in results_top_p:
    logger.info(f"Score: {res['score']:.4f} | Input: {res['input']} | Reply: {res['reply']}")

INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2
INFO:httpx:HTTP Request: GET http://localhost:6333 "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: GET http://localhost:6333/collections/customer_support_tweets/exists "HTTP/1.1 200 OK"
INFO:__main__:
Searching for documents matching query: 'Input: is the worst customer service | Reply: I would love the chance to review the account and provide assistance	'


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

  hits = client.search(
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/search "HTTP/1.1 200 OK"
INFO:__main__:Search results: len(results)=25 , len(results_top_p)=23
INFO:__main__:Score: 0.7390 | Input: undoubtedly the worst customer service. Informed you several times I am overseas due to passing of a close family member. Nevertheless you threaten to close my account bc I cannot send paperwork you already have. Very upsetting. | Reply: I appreciate the tweet amp would like to see how I can help. I am sorry to hear about the passing of your family member. Pls click link to send me your account type. For security, do not include any account numbers or PINs when responding. I hope to hear back soon.
INFO:__main__:Score: 0.7194 | Input: Your customer service is terrible. Having to talk to multiple different people, all giving me nonanswers and passing me off to the next is terrible. Highly considering closing my account. I will tell all frie

# Improvements which can be done

- 1. identify duplicate data and remove -  quantify percentage of duplicate data and delete
- 2. cluster similar questions together - and their answers together in one document row

## Identify duplicate data and remove 

In [None]:
!python -m ai.retrieval.load.flag_duplicate_data

INFO:httpx:HTTP Request: GET http://localhost:6333 "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: GET http://localhost:6333/collections/customer_support_tweets "HTTP/1.1 200 OK"


Attempting to flag duplicate sentences from existing Qdrant data (test mode, seeding from first 100)...
Connecting to Qdrant at localhost:6333
Collection 'customer_support_tweets' found. Total points: 945278
Retrieving all 945278 points from Qdrant...


INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/scroll "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/scroll "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/scroll "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/scroll "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/scroll "HTTP/1.1 200 OK"
Flagging duplicate sentences:   1%|          | 1/100 [3:36:52<357:51:02, 13012.75s/it]
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/scroll "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/scroll "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_

In [23]:
import pandas as pd
flagged_df = pd.read_csv('flagged_df_1000000.csv')
flagged_df.info()
flagged_df.head()

flagged_df['group'].describe()

count    945278.000000
mean     452125.246029
std      269005.786724
min           1.000000
25%      217569.250000
50%      450482.500000
75%      685189.750000
max      920799.000000
Name: group, dtype: float64

In [27]:
# 534289
flagged_df[flagged_df['group'] == 1]
# ['grouped_with']

Unnamed: 0,qdrant_id,input_sentence,reply,group,grouped_with
0,0,is the worst customer service,I would love the chance to review the account ...,1,"[{'neighbor_id': 534289, 'score': 0.9574863}, ..."
80733,80733,customer service is the absolute worst,We do not want you to have problems. what is t...,1,"[{'neighbor_id': 784954, 'score': 0.9735185}, ..."
160301,160301,customer service is the worst.,we are here for you. what is going on?,1,"[{'neighbor_id': 784954, 'score': 0.97044206},..."
264317,264317,customer service is the worst,"Hello, I apologize about your eerience with us...",1,"[{'neighbor_id': 0, 'score': 0.9556223}, {'nei..."
340705,340705,customer service is the worst.,We see your message and will be following up v...,1,"[{'neighbor_id': 784954, 'score': 0.97044206},..."
534289,534289,is THE WORST IN CUSTOMER SERVICES.,Hey Asim! I am sorry for the frustration. Plea...,1,"[{'neighbor_id': 0, 'score': 0.9574863}]"
577984,577984,customer service is the absolute worst,"Hey, Brian. I would like to offer my help to t...",1,"[{'neighbor_id': 784954, 'score': 0.9735185}, ..."
784954,784954,customer service is the worst,we are here for you! what is going on?,1,"[{'neighbor_id': 0, 'score': 0.9556223}, {'nei..."


### delete duplicate rows

In [32]:
flagged_df[flagged_df['group'].duplicated()]


Unnamed: 0,qdrant_id,input_sentence,reply,group,grouped_with
355,355,,Thanks. It looks like that domain is not regis...,210,"[{'neighbor_id': 209, 'score': 1.0}, {'neighbo..."
414,414,,"Hi James, would you be able to let me know how...",210,"[{'neighbor_id': 209, 'score': 1.0}, {'neighbo..."
500,500,,Oh my goodness – I love it! Nicely done. Becky,210,"[{'neighbor_id': 209, 'score': 1.0}, {'neighbo..."
567,567,,"Hi Carrie, awwww, they are awesome. Happy Hall...",210,"[{'neighbor_id': 209, 'score': 1.0}, {'neighbo..."
730,730,,Got it. Are you getting a specific error messa...,210,"[{'neighbor_id': 209, 'score': 1.0}, {'neighbo..."
...,...,...,...,...,...
945128,945128,I need help As Soon As Possible,How may we assist?,5188,"[{'neighbor_id': 5242, 'score': 0.9999999}, {'..."
945145,945145,mobileCareXI Why is not my internet working?,"Hi, I have responded to your DM. Please see an...",61677,"[{'neighbor_id': 168760, 'score': 0.9536183}, ..."
945210,945210,mobileCare,Are you having a billing or service issue that...,3005,"[{'neighbor_id': 3029, 'score': 1.0000001}, {'..."
945257,945257,架空請求きたよww しかとショートメールでAmazon相談係。,ご承知のとおり、残念ながら悪質な詐欺が増加しているようですのでお気を付けください。 を装った...,861714,"[{'neighbor_id': 884396, 'score': 0.99392235}]"


In [44]:
to_delete = flagged_df[flagged_df['group'].duplicated()]['qdrant_id'].to_list()
result = client.delete(
    collection_name="customer_support_tweets",
    points_selector=to_delete,
    wait=True  # Set to True to wait for the operation to complete
)

INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/delete?wait=true "HTTP/1.1 200 OK"


In [46]:
result = client.query_points(
    collection_name="customer_support_tweets",
    query=to_delete[0])
result

INFO:httpx:HTTP Request: POST http://localhost:6333/collections/customer_support_tweets/points/query "HTTP/1.1 404 Not Found"


UnexpectedResponse: Unexpected Response: 404 (Not Found)
Raw response content:
b'{"status":{"error":"Not found: No point with id 355 found"},"time":0.000348834}'