In [1]:
import html
import json
import re
import unicodedata

import nltk
import numpy as np
import pandas as pd
from bs4 import BeautifulSoup
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
# Tokenizer data loaded by `sent_tokenize` (used below to chunk descriptions).
nltk.download('punkt_tab')

import os
from dotenv import load_dotenv

# Read OPENAI_API_KEY from a local .env file into the process environment.
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

# Report only whether the key is present: any fragment of a secret that is
# printed here is persisted in the saved notebook output.
if openai_api_key:
    print("✅ OPENAI_API_KEY loaded")
else:
    print("❌ OPENAI_API_KEY NOT FOUND - add it to .env before running the embedding cells")



✅ Key loaded: sk-pr...


[nltk_data] Downloading package punkt_tab to /home/ken/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


In [2]:
# Raw Leafly scrape. The first six columns are strain metadata
# (name, img_url, type, thc_level, most_common_terpene, description); the
# remaining columns are reported effects and medical conditions, stored as
# percentage strings such as "0%".
df = pd.read_csv('data/leafly_strain_data_project.csv')
(df.shape, df.columns[:6])

((4762, 64),
 Index(['name', 'img_url', 'type', 'thc_level', 'most_common_terpene',
        'description'],
       dtype='object'))

In [3]:
# Column-wise summary. The recorded output shows only count / unique / top /
# freq, i.e. pandas detected no numeric columns: the effect columns hold
# strings such as "0%" and need parsing before any numeric analysis.
(
    df
    .describe()
)

Unnamed: 0,name,img_url,type,thc_level,most_common_terpene,description,relaxed,happy,euphoric,uplifted,...,fibromyalgia,crohn's_disease,phantom_limb_pain,epilepsy,multiple_sclerosis,parkinson's,tourette's_syndrome,alzheimer's,hiv/aids,tinnitus
count,4762,98,4107,2735,2447,4727,4762,4762,4762,4762,...,4762,4762,4762,4762,4762,4762,4762,4762,4762,4762
unique,4762,98,3,29,8,4727,81,79,76,79,...,13,6,3,7,6,2,3,4,3,2
top,Blueberry Waltz,https://images.leafly.com/flower-images/gg-4.jpg,Hybrid,18%,Myrcene,Blueberry Waltz is an indica-dominant strain t...,0%,0%,0%,0%,...,0%,0%,0%,0%,0%,0%,0%,0%,0%,0%
freq,1,1,2772,366,1195,1,2158,1991,2446,2306,...,4748,4755,4759,4753,4755,4760,4760,4758,4760,4760


In [4]:
def aggressive_clean_description(text):
    r"""Normalize a raw strain description into lowercase, ASCII-only plain text.

    Steps, in order:
      1. Repair UTF-8 text that was mis-decoded as CP-1252 ("mojibake", e.g.
         "itâ€™s" -> "it’s"). Text that does not round-trip cleanly is kept as-is
         instead of having bytes silently dropped.
      2. Decode literal ``\uXXXX`` escape sequences into real characters so they
         can be normalized (rather than deleted) below.
      3. Replace control characters, including newlines and tabs, with spaces so
         adjacent words and sentences are not glued together.
      4. Unescape HTML entities and strip HTML tags.
      5. Map typographic quotes, dashes and ellipses to ASCII equivalents.
      6. Fold accented letters to their base letter ("é" -> "e"), then replace
         any remaining non-ASCII runs with a space.
      7. Collapse whitespace and lowercase.

    Parameters
    ----------
    text : str or missing scalar
        Raw description. ``None``/``NaN`` yield an empty string; other
        non-string values are converted with ``str``.

    Returns
    -------
    str
        The cleaned, lowercased description.
    """
    if pd.isna(text):
        return ""
    text = str(text)

    # 1. Undo CP-1252 mojibake. Both codecs are strict: if either step fails the
    #    text was not mojibake, so keep it unchanged. (A proper "’" encodes to
    #    0x92 in CP-1252, which is invalid UTF-8, so it is left alone.)
    try:
        text = text.encode("cp1252").decode("utf-8")
    except UnicodeError:
        pass

    # 2. Literal "\u2019"-style escapes -> the characters they name.
    text = re.sub(r"\\u([0-9a-fA-F]{4})", lambda m: chr(int(m.group(1), 16)), text)

    # 3. Control characters -> space (not ""), so "end\nStart" stays two words
    #    and sentence boundaries survive for sent_tokenize.
    text = re.sub(r"[\x00-\x1F\x7F-\x9F]", " ", text)

    # 4. HTML entities and tags.
    text = html.unescape(text)
    text = BeautifulSoup(text, "html.parser").get_text()

    # 5. Typographic punctuation -> ASCII.
    replacements = {
        "\u201c": "\"", "\u201d": "\"",  # curly double quotes
        "\u2018": "'", "\u2019": "'",    # curly single quotes / apostrophe
        "\u2013": "-", "\u2014": "-",    # en dash, em dash
        "\u2026": "...",                 # horizontal ellipsis
    }
    for bad, good in replacements.items():
        text = text.replace(bad, good)

    # 6. Canonical decomposition splits "é" into "e" + a combining accent; drop
    #    the accent, then blank out everything else outside ASCII (emoji,
    #    symbols, ...).
    text = unicodedata.normalize("NFD", text)
    text = "".join(ch for ch in text if not unicodedata.combining(ch))
    text = re.sub(r"[^\x00-\x7F]+", " ", text)

    # 7. Collapse whitespace.
    text = re.sub(r"\s+", " ", text).strip()

    return text.lower()


# Clean every raw description and store it next to the original column.
df["aggressive_cleaned_description"] = df["description"].map(
    aggressive_clean_description
)

# Side-by-side view of a single row to eyeball the cleaning.
df.loc[5, ["description", "aggressive_cleaned_description"]]


description                       Purple Punch is the sweet and sedating union o...
aggressive_cleaned_description    purple punch is the sweet and sedating union o...
Name: 5, dtype: object

In [5]:
def chunk_sentences(text, chunk_size=2):
    """Group the sentences of ``text`` into consecutive, space-joined chunks.

    Parameters
    ----------
    text : str
        Text to split with NLTK's ``sent_tokenize``.
    chunk_size : int, default 2
        Number of sentences per chunk; the final chunk may be shorter.

    Returns
    -------
    list of str
        One string per chunk, in original order (empty for text with no sentences).
    """
    sentences = sent_tokenize(text)
    chunks = []
    for start in range(0, len(sentences), chunk_size):
        window = sentences[start:start + chunk_size]
        chunks.append(" ".join(window))
    return chunks


tqdm.pandas()  # registers Series.progress_apply (a progress-bar apply)
df["chunks"] = (
    df["aggressive_cleaned_description"]
    .progress_apply(chunk_sentences)
)


100%|██████████| 4762/4762 [00:00<00:00, 36954.70it/s]


In [6]:
## Flatten the per-strain chunk lists into one row per chunk, so each chunk is
## embedded on its own. `strain_id` is the strain's row position in `df`.
docs = [
    {
        "strain_id": strain_id,
        "strain_name": name,
        "chunk_index": chunk_index,
        "chunk": chunk,
    }
    for strain_id, (name, chunks) in enumerate(zip(df["name"], df["chunks"]))
    for chunk_index, chunk in enumerate(chunks)
]

docs_df = pd.DataFrame(docs)

In [7]:
from openai import OpenAI
# Single client shared by every embedding request in this notebook; the key
# was read from .env in the first cell.
client = OpenAI(
    api_key=openai_api_key,
)

def get_embedding(text, model="text-embedding-3-small"):
    """Embed one string with the OpenAI embeddings endpoint.

    Parameters
    ----------
    text : str
        The text to embed (sent as a one-element batch).
    model : str, default "text-embedding-3-small"
        Embedding model name.

    Returns
    -------
    list of float or None
        The embedding vector, or ``None`` if anything in the request or
        response handling raised; the error is printed, not re-raised.
    """
    try:
        result = client.embeddings.create(model=model, input=[text])
        first_item = result.data[0]
        return first_item.embedding
    except Exception as err:
        # Best-effort: one failed chunk should not abort a long bulk run.
        print(f"Embedding error: {err}")
        return None


In [8]:
# Smoke test: one request to confirm the key and model work, and to check the
# vector size before the long bulk run.
response = client.embeddings.create(
    model="text-embedding-3-small",
    input=["what helps with anxiety and sleep?"],
)
embedding = response.data[0].embedding
print(f"✅ Got embedding with {len(embedding)} dimensions")


✅ Got embedding with 1536 dimensions


In [9]:
# Re-registering progress_apply is harmless and keeps this cell runnable alone.
tqdm.pandas()
# One API request per chunk (~33 min for 9,361 chunks in the recorded run).
docs_df["embedding"] = docs_df["chunk"].progress_apply(get_embedding)

# get_embedding prints and returns None on API errors. Surface the failures
# here so a partially failed run is noticed before it is cached to parquet and
# averaged into strain vectors.
failed_embeddings = docs_df["embedding"].isna()
if failed_embeddings.any():
    print(
        f"⚠️ {int(failed_embeddings.sum())} of {len(docs_df)} chunks have no "
        "embedding; re-embed them before saving."
    )


100%|██████████| 9361/9361 [33:21<00:00,  4.68it/s]   


In [10]:
EMBEDDINGS_PATH = "data/docs_df_with_embeddings.parquet"

# Persist freshly computed embeddings, then always reload from disk. The guard
# matters when the slow embedding cell was skipped: docs_df then has no
# "embedding" column, and writing it would overwrite the cached embeddings.
if "embedding" in docs_df.columns:
    docs_df.to_parquet(EMBEDDINGS_PATH, index=False)
docs_df = pd.read_parquet(EMBEDDINGS_PATH)

# Chunks whose embedding failed are missing (None). Drop them so they cannot
# break the per-strain mean; a strain with no usable chunk is left out.
has_embedding = docs_df["embedding"].notna()
if not has_embedding.all():
    print(f"⚠️ Ignoring {int((~has_embedding).sum())} chunks without an embedding")

# Mean chunk vector per strain.
strain_embeddings = (
    docs_df[has_embedding]
    .groupby("strain_name")["embedding"]
    .apply(lambda vectors: np.mean(vectors.tolist(), axis=0))
    .reset_index()
)

# strain name -> mean embedding vector, for fast lookup by the recommender.
strain_embedding_dict = dict(
    zip(strain_embeddings["strain_name"], strain_embeddings["embedding"])
)


In [11]:
from sklearn.metrics.pairwise import cosine_similarity

def recommend_similar_strains(strain_name, top_n=3):
    """Return the strains whose mean embedding is closest to ``strain_name``'s.

    Parameters
    ----------
    strain_name : str
        Exact key in ``strain_embedding_dict``.
    top_n : int, default 3
        Maximum number of neighbours to return; values <= 0 yield ``[]``.

    Returns
    -------
    list of (str, float)
        ``(name, cosine_similarity)`` pairs sorted from most to least similar,
        excluding ``strain_name`` itself. Empty (after printing a message) if
        the strain is unknown.
    """
    if strain_name not in strain_embedding_dict:
        print(f"❌ '{strain_name}' not found.")
        return []
    if top_n <= 0:
        return []

    candidates = [name for name in strain_embedding_dict if name != strain_name]
    if not candidates:
        return []

    query_vector = np.asarray(strain_embedding_dict[strain_name]).reshape(1, -1)
    # One batched similarity call over all candidates instead of one sklearn
    # call (with its own input validation) per strain.
    candidate_matrix = np.vstack(
        [np.asarray(strain_embedding_dict[name]) for name in candidates]
    )
    scores = cosine_similarity(query_vector, candidate_matrix)[0]

    # sorted() is stable, so ties keep dictionary order as before.
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]


In [12]:
# Sanity check: five nearest neighbours of a well-known strain.
recommend_similar_strains(strain_name="Blue Dream", top_n=5)


[('Super Blue Dream', np.float64(0.8015651693296493)),
 ('Blue Dream CBD', np.float64(0.7836638187491647)),
 ('Blue Wonder', np.float64(0.7825964484481065)),
 ('Double Dream', np.float64(0.7819953158244742)),
 ('Blue Diesel', np.float64(0.7802391620739217))]