In [18]:
# Install necessary packages
!pip install kaggle fuzzywuzzy python-levenshtein

import os
import glob
import pandas as pd
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from fuzzywuzzy import process
import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt

# Kaggle API authentication setup
os.makedirs('/root/.kaggle', exist_ok=True)
if os.path.exists('kaggle.json'):
    !cp kaggle.json /root/.kaggle/
    !chmod 600 /root/.kaggle/kaggle.json
    print("Kaggle authentication configured.")
else:
    print("WARNING: 'kaggle.json' not found! Upload your Kaggle API token first.")

# Download and unzip Leafly dataset from Kaggle
if not os.path.exists('leafly-cannabis-strains-metadata.csv'):
    !kaggle datasets download gthrosa/leafly-cannabis-strains-metadata --unzip -q
    print("Dataset downloaded and unzipped.")
else:
    print("Dataset CSV found, skipping download.")

# Locate CSV file dynamically
csv_files = glob.glob('*.csv')
# Sorted so the same file is chosen on every run when several CSVs match.
leafly_csvs = sorted(f for f in csv_files if 'leafly' in f.lower())
if not leafly_csvs:
    raise FileNotFoundError("Leafly dataset CSV not found after unzip.")
dataset_file = leafly_csvs[0]

# Load dataset
df_full = pd.read_csv(dataset_file)
df_full['strain_name'] = df_full['name'].str.strip()

# Define chemical profile columns (update as needed)
chemical_cols = ['thc', 'cbd', 'myrcene', 'limonene', 'pinene', 'linalool', 'caryophyllene', 'humulene']
# Columns absent from the CSV are created as zeros so the model input width stays fixed,
# but that is reported: silently zero-filled features make every similarity meaningless.
missing_chemical_cols = [col for col in chemical_cols if col not in df_full.columns]
for col in missing_chemical_cols:
    df_full[col] = 0.0
if missing_chemical_cols:
    print(f"WARNING: {len(missing_chemical_cols)} of {len(chemical_cols)} chemical columns are missing "
          f"from '{dataset_file}' and were filled with zeros: {missing_chemical_cols}")

# `df` keeps only the names; `X` holds the features in the same row order.
df = df_full[['strain_name']].copy()
X = df_full[chemical_cols].fillna(0)
if not X.to_numpy().any():
    print("WARNING: every chemical feature is zero; update 'chemical_cols' to match the dataset columns.")

# Align input features to your model expected input size
expected_input_features = 12  # Change to your model's input layer feature count
current_features = X.shape[1]

if current_features < expected_input_features:
    # Too few features: append constant-zero padding columns named pad_0, pad_1, ...
    n_missing = expected_input_features - current_features
    for i in range(n_missing):
        X[f'pad_{i}'] = 0.0
    print(f"Padded X with {n_missing} zero columns to match model input features.")

elif current_features > expected_input_features:
    # Too many features: keep only the leading columns.
    X = X.iloc[:, :expected_input_features]
    print(f"Trimmed X to first {expected_input_features} features to match model input.")

# float32 tensor view of X (one row per strain) for the encoder.
X_tensor = torch.tensor(X.values, dtype=torch.float32)

# Maps each effect key (also used as the `<key>_pred` column prefix and as the options of the
# effect selector) to the short label that goes into generated candidate names.
effect_mapping = dict(
    reduces_stress='Stress',
    analgesic='PainRelief',
    low_psychoactivity='LowPsycho',
    anti_inflammatory='AntiInflammatory',
    antioxidant='Antioxidant',
    sedative='Sedative',
    mood_uplift='MoodUplift',
    appetite_stim='Appetite',
    neuroprotective='Neuroprotective',
)

# Show a handful of real names so the user knows what can be typed below.
# Missing names are skipped and the sample size is capped at the number of available names,
# so this also works on datasets with fewer than 10 rows.
example_names = df['strain_name'].dropna()
print("Example valid strain names from dataset:")
print(example_names.sample(min(10, len(example_names))).tolist())

# Widgets initialization
# Free-text box for one or two parent strain names, separated by commas.
strain_text = widgets.Text(
    description='Parents (comma sep):',
    placeholder='Type 1 or 2 strain names',
    layout=widgets.Layout(width='70%'),
    # Let the label take the width it needs; the default label width cuts this description off.
    style={'description_width': 'initial'},
)
display(strain_text)

def get_parents_from_text(text, min_score=70):
    """Resolve a comma-separated string of strain names to names present in ``df``.

    Each non-empty entry is lower-cased and fuzzy-matched (``process.extractOne``) against
    the lower-cased ``df['strain_name']`` values. A match is kept only when its score is
    strictly greater than ``min_score``; entries that do not match are dropped silently.

    Args:
        text: Raw user input, e.g. ``"Blue Dream, OG Kush"``. ``None`` or empty gives ``[]``.
        min_score: Exclusive lower bound on the fuzzy score (default 70).

    Returns:
        List of matched names in their original dataset spelling, in input order, with
        duplicates removed (typing the same strain twice yields one parent).
    """
    queries = [part.strip().lower() for part in (text or '').split(',') if part.strip()]
    if not queries:
        return []

    # lower-cased name -> original spelling; the first occurrence wins. Missing names (NaN)
    # are skipped because they cannot be lower-cased or matched.
    lookup = {}
    for original in df['strain_name'].dropna():
        lookup.setdefault(str(original).lower(), original)
    if not lookup:
        return []
    choices = list(lookup)

    matched = []
    for query in queries:
        result = process.extractOne(query, choices)
        if result is None:  # the matcher may report "no candidate" instead of a tuple
            continue
        best_match, score = result[0], result[1]
        if score > min_score and lookup[best_match] not in matched:
            matched.append(lookup[best_match])
    return matched

# Container for the per-parent weight sliders. It is displayed empty here; its children are
# replaced by update_weight_sliders_for_parents().
weight_sliders_box = widgets.VBox()
display(weight_sliders_box)

def update_weight_sliders_for_parents(parent_names):
    """Show one blend-weight slider per parent strain inside ``weight_sliders_box``.

    Each slider ranges from 0.0 to 2.0 in steps of 0.1 and is labelled with the strain name.
    If the box already holds a slider with the same label, its current value is carried over,
    so a weight the user has set survives the next rebuild; new strains start at 1.0.

    Args:
        parent_names: Iterable of strain names, one slider is created for each, in order.
    """
    # Remember what the user set on the sliders that are currently displayed.
    previous_values = {
        getattr(child, 'description', None): child.value
        for child in getattr(weight_sliders_box, 'children', ())
        if hasattr(child, 'value')
    }

    sliders = []
    for strain in parent_names:
        slider = widgets.FloatSlider(
            value=previous_values.get(strain, 1.0), min=0.0, max=2.0, step=0.1,
            description=strain, continuous_update=False,
            orientation='horizontal', readout_format='.1f',
            layout=widgets.Layout(width='70%'))
        sliders.append(slider)
    weight_sliders_box.children = sliders

# chem_sliders maps a feature column name to its FloatSlider; it is filled by
# update_chem_sliders(). chem_box is the container in which those sliders are shown.
chem_sliders = {}
chem_box = widgets.VBox()
display(chem_box)

def update_chem_sliders():
    """Create one multiplier slider per column of ``X`` and publish them.

    Every slider starts at 1.0 and spans 0.0-2.0 in 0.05 steps. The sliders are registered in
    the module-level ``chem_sliders`` dict under their column name and become the children of
    ``chem_box``, in column order.
    """
    def _build(label):
        # One horizontal slider labelled with the feature name.
        return widgets.FloatSlider(
            description=label,
            value=1.0, min=0.0, max=2.0, step=0.05,
            readout_format='.2f',
            orientation='horizontal',
            continuous_update=False,
            layout=widgets.Layout(width='70%'),
        )

    built = [(feature, _build(feature)) for feature in X.columns]
    chem_sliders.update(built)
    chem_box.children = [control for _, control in built]

# Build one multiplier slider per feature column of X and show them inside chem_box.
update_chem_sliders()

# Multi-select of target effects; the selectable values are the keys of effect_mapping.
effect_select = widgets.SelectMultiple(
    options=list(effect_mapping.keys()), description='Target Effects:')
display(effect_select)

# Trigger button and the Output area that generate_strain() writes its messages and results into.
generate_button = widgets.Button(description="Generate Candidate Strain")
output_area = widgets.Output()
display(generate_button, output_area)

# Load your pretrained model and effect predictors here
# NOTE(review): generate_strain() reads two globals that this cell never defines:
#   `ae`             - used as ae.encoder(...) and ae.decoder(...)
#   `trained_models` - iterated with .items(); each value is called as model.predict(...)
# Both must exist before the button is clicked.
# Example:
# ae = YourAutoencoderModel()
# trained_models = {'reduces_stress': model1, ...}

def generate_strain(b):
    """Button callback: blend the chosen parent strains into one candidate profile.

    Steps, all rendered inside ``output_area``:
      1. Validate the inputs (1-2 parent names, at least one selected effect, models loaded).
      2. With a single parent, pick the most cosine-similar other strain as the partner,
         using the parent's features scaled by the chemical sliders.
      3. Encode both parents with ``ae.encoder``, average the latent vectors using the parent
         weight sliders, and decode with ``ae.decoder``.
      4. Scale the decoded features by the chemical sliders, run every model in
         ``trained_models`` and derive the candidate's names.
      5. Plot candidate vs. mean parent profile and display the result table.

    Args:
        b: The clicked button (passed by ``Button.on_click``); not used.

    Notes:
        * ``effect_select`` only gates execution; the selected effects do not influence the
          generated profile or its name (every key of ``effect_mapping`` is considered).
        * The ``parent_strains`` column holds the two existing strains most similar to the
          *generated* profile, which are not necessarily the parents that were blended.
        * Strains are addressed by row position, so ``df``, ``X`` and ``X_tensor`` must share
          the same row order.
    """
    with output_area:
        output_area.clear_output()

        # `ae` and `trained_models` are created by the user elsewhere in the notebook; give a
        # readable message instead of a NameError traceback when they are still missing.
        undefined = [name for name in ('ae', 'trained_models') if name not in globals()]
        if undefined:
            print("Load your pretrained models first; not defined: " + ", ".join(undefined))
            return

        parents = get_parents_from_text(strain_text.value)
        if len(parents) == 0:
            print("Enter at least one valid parent strain name.")
            return
        if len(parents) > 2:
            print("Please enter at most two parent strain names separated by commas.")
            return
        if not effect_select.value:
            print("Select at least one target effect.")
            return

        feature_cols = list(X.columns)
        chem_weights_np = np.array([chem_sliders[c].value for c in feature_cols], dtype=float)

        # Row positions (not index labels) of the parents: first row whose name matches.
        parent_indices = [
            int(np.flatnonzero((df['strain_name'] == s).to_numpy())[0]) for s in parents
        ]

        if len(parent_indices) == 1:
            if len(X) < 2:
                print("The dataset needs at least two strains to pick a partner strain.")
                return
            single_idx = parent_indices[0]
            weighted_chem = X.iloc[single_idx].to_numpy(dtype=float) * chem_weights_np

            similarities = cosine_similarity(weighted_chem.reshape(1, -1), X)[0]
            # -inf (not -1, which is a legal cosine value) guarantees the parent never
            # selects itself as its own partner.
            similarities[single_idx] = -np.inf
            parent_indices.append(int(np.argmax(similarities)))

        parent_names_real = [df['strain_name'].iloc[i] for i in parent_indices]
        # The sliders are (re)built right before being read, so the weights used below are
        # whatever update_weight_sliders_for_parents leaves in them.
        update_weight_sliders_for_parents(parent_names_real)

        weights_np = np.array([s.value for s in weight_sliders_box.children], dtype=np.float32)
        if weights_np.sum() <= 0:
            # A zero total would turn the weighted average into a division by zero (NaN).
            print("Set at least one parent weight above zero.")
            return
        weights = torch.tensor(weights_np).unsqueeze(1)

        # Inference only: no gradients are needed for encoding/decoding.
        with torch.no_grad():
            latent_vectors = ae.encoder(X_tensor[parent_indices])
            combined_latent = (latent_vectors * weights).sum(dim=0) / weights.sum()
            decoded = ae.decoder(combined_latent.unsqueeze(0)).detach().cpu().numpy()

        # One-row frame of decoded features, scaled column-wise by the chemical sliders.
        generated_df_custom = pd.DataFrame(decoded, columns=feature_cols)
        generated_df_custom = generated_df_custom.mul(chem_weights_np, axis='columns')

        # Two existing strains closest to each generated profile, best first.
        sim = cosine_similarity(generated_df_custom[feature_cols], X)
        top_parents = []
        for sim_row in sim:
            top_idxs = np.argsort(sim_row)[-2:][::-1]
            top_parents.append(df['strain_name'].iloc[top_idxs].tolist())
        generated_df_custom['parent_strains'] = top_parents

        for lbl, model in trained_models.items():
            generated_df_custom[lbl + '_pred'] = model.predict(generated_df_custom[feature_cols])

        def generate_effect_name(row):
            # Join the labels of every effect predicted as 1; "Neutral" when none is.
            effects = []
            for col, label in effect_mapping.items():
                pred_col = col + '_pred'
                if pred_col in row and row[pred_col] == 1:
                    effects.append(label)
            return "_".join(effects) + "_candidate" if effects else "Neutral_candidate"
        generated_df_custom['strain_name'] = generated_df_custom.apply(generate_effect_name, axis=1)

        def generate_full_name(row):
            # For names containing "_" only the second "_"-separated token is used.
            parent_ids = "_".join([p.split("_")[1] if "_" in p else p for p in row['parent_strains']])
            return f"{row['strain_name']}_{parent_ids}"
        generated_df_custom['full_strain_name'] = generated_df_custom.apply(generate_full_name, axis=1)

        # Plot data is computed before the try block so the matplotlib fallback can always
        # use it. The parent profile comes from X: `df` holds only the strain names.
        sample_plot = generated_df_custom[feature_cols].iloc[0].astype(float)
        parent_plot = X.iloc[parent_indices].mean()
        try:
            import plotly.express as px
            # Long form (one row per compound and profile) so both traces are drawn.
            plot_long = (
                pd.DataFrame({'Candidate': sample_plot, 'Parents': parent_plot})
                .rename_axis('compound')
                .reset_index()
                .melt(id_vars='compound', var_name='profile', value_name='level')
            )
            fig = px.line_polar(plot_long, r='level', theta='compound', color='profile',
                                line_close=True, title='Candidate vs Parent Chemical Profile')
            fig.show()
        except Exception as exc:
            # Best effort: fall back to a static matplotlib chart. A missing plotly install
            # is an expected case; anything else is reported rather than silently swallowed.
            if not isinstance(exc, ImportError):
                print(f"Plotly chart failed ({type(exc).__name__}: {exc}); using matplotlib instead.")
            plt.figure(figsize=(6, 6))
            plt.plot(sample_plot.index, sample_plot.values, label='Candidate', marker='o')
            plt.plot(parent_plot.index, parent_plot.values, label='Parents', marker='x')
            plt.xticks(rotation=90)
            plt.title('Candidate vs Parent Chemical Profile')
            plt.legend()
            plt.show()

        cols_to_display = ['full_strain_name', 'parent_strains'] + \
            [c + '_pred' for c in effect_mapping.keys() if c + '_pred' in generated_df_custom.columns]
        display(generated_df_custom[cols_to_display])

# Run generate_strain each time the button is clicked; the callback receives the button as its argument.
generate_button.on_click(generate_strain)



Kaggle authentication configured.
Dataset URL: https://www.kaggle.com/datasets/gthrosa/leafly-cannabis-strains-metadata
License(s): CC0-1.0
Dataset downloaded and unzipped.
Padded X with 4 zero columns to match model input features.
Example valid strain names from dataset:
['Golden Pineapple', 'Lost Coast Hash Plant', 'Gammaberry', 'Pineapple Muffin', 'Medibud', 'Grape Stomper OG', 'Zkittlez OG', 'Purple Sunset', 'Notorious THC', 'Kosher Dawg']


Text(value='', description='Parents (comma sep):', layout=Layout(width='70%'), placeholder='Type 1 or 2 strain…

VBox()

VBox()

SelectMultiple(description='Target Effects:', options=('reduces_stress', 'analgesic', 'low_psychoactivity', 'a…

Button(description='Generate Candidate Strain', style=ButtonStyle())

Output()