In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

# Import GT4SD’s RegressionTransformer


In [4]:
import torch
print(torch.float8_e4m3fn)

AttributeError: module 'torch' has no attribute 'float8_e4m3fn'

In [3]:
from pydantic.v1 import BaseSettings

In [1]:
from gt4sd.algorithms.conditional_generation.regression_transformer import RegressionTransformer

ImportError: cannot import name 'cached_download' from 'huggingface_hub' (/Users/sb/miniforge3/envs/gt4sd/lib/python3.9/site-packages/huggingface_hub/__init__.py)

In [None]:

# (If you have a molecule-specific configuration, GT4SD provides classes like RegressionTransformerMolecules.
# Here we define our own minimal config for tabular data.)

###########################################
# 1. Data Preparation: Convert Tabular Data to Text
###########################################
# Assume your DataFrame `df` has several feature columns and one numeric target column called 'target'.
# For each row, we create a text sequence:
#    "<target>{target_value}|{feat1},{feat2},...,{featN}"

def row_to_text(row):
    target = row['target']
    features = row.drop('target').astype(str).tolist()
    return f"<target>{target}|{','.join(features)}"

# For demonstration, create a dummy DataFrame (replace with your actual data)
data = {
    'feat1': np.random.rand(200),
    'feat2': np.random.rand(200) * 10,
    'feat3': np.random.randint(0, 100, 200),
    'target': np.random.rand(200) * 50  # continuous target values
}
df = pd.DataFrame(data)
df['text'] = df.apply(row_to_text, axis=1)

# Split the DataFrame into training and testing sets.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
train_texts = train_df['text'].tolist()
test_texts = test_df['text'].tolist()

###########################################
# 2. Define a Minimal Configuration Class for Tabular Data
###########################################
# GT4SD’s RegressionTransformer expects a configuration object.
# For tabular data we define a simple configuration class.
class RegressionTransformerTabular:
    def __init__(self, algorithm_version, search, temperature, tolerance, sampling_wrapper):
        self.algorithm_version = algorithm_version  # e.g. "tabular"
        self.search = search                        # e.g. "sample" for sampling-based search
        self.temperature = temperature              # controls randomness of sampling
        self.tolerance = tolerance                  # tolerance for property matching
        self.sampling_wrapper = sampling_wrapper    # dictionary with property goal and mask fraction

# For demonstration, we won’t perform full training here; rather, we assume that the RegressionTransformer
# (which in GT4SD is typically pretrained and/or fine-tuned via provided scripts) can be instantiated
# with a configuration and then used to sample candidate outputs.

# Let’s choose a desired target value for conditional generation.
# (In practice you would train/fine-tune the RT on your training data.)
desired_target = 42.0  # Example target value

# Create a configuration instance for tabular data.
config = RegressionTransformerTabular(
    algorithm_version="tabular",
    search="sample",
    temperature=2,
    tolerance=5,
    sampling_wrapper={
        'property_goal': {'<target>': desired_target},
        'fraction_to_mask': 0.2  # fraction of tokens to mask during generation
    }
)

###########################################
# 3. Inference with GT4SD’s RegressionTransformer
###########################################
# For demonstration we pick one seed sample from the test set.
# In practice, you would train or fine-tune the model on your training set.
seed_text = test_texts[0]
# Also extract the true target value from the seed text.
def extract_target(text):
    # Expected format: "<target>{number}|..."
    if text.startswith("<target>"):
        try:
            remainder = text[len("<target>"):]
            target_str = remainder.split("|")[0]
            return float(target_str)
        except:
            return None
    return None

true_target_seed = extract_target(seed_text)
print(f"Seed text: {seed_text}")
print(f"True target for seed: {true_target_seed}")

# Instantiate the RegressionTransformer with the given configuration and seed target.
rt_model = RegressionTransformer(configuration=config, target=seed_text)

# Generate candidate outputs. For example, sample 8 candidates.
candidates = list(rt_model.sample(8))
print("Generated candidate outputs:")
for cand in candidates:
    print(cand)

# Extract the predicted target from each generated candidate.
predicted_candidates = [extract_target(cand) for cand in candidates]
print("Extracted candidate target values:", predicted_candidates)

###########################################
# 4. Evaluation over the Test Set
###########################################
# Now, for each test sample, we instantiate a new RegressionTransformer (using the same configuration)
# and generate a candidate output. Then we extract the predicted target value and compare it to the true value.
y_true = []
y_pred = []
for txt in test_texts:
    # For each sample, create a new instance with the sample text as target.
    instance = RegressionTransformer(configuration=config, target=txt)
    generated_text = list(instance.sample(1))[0]
    pred_val = extract_target(generated_text)
    true_val = extract_target(txt)
    if pred_val is not None and true_val is not None:
        y_true.append(true_val)
        y_pred.append(pred_val)

# Compute evaluation metrics.
mse = mean_squared_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)
print(f"Test MSE: {mse:.4f}")
print(f"Test R2 Score: {r2:.4f}")

# Plot Actual vs Predicted Target Values.
plt.figure(figsize=(8, 6))
plt.scatter(y_true, y_pred, alpha=0.7)
plt.xlabel("Actual Target")
plt.ylabel("Predicted Target")
plt.title("Actual vs Predicted Target Values")
min_val, max_val = min(min(y_true), min(y_pred)), max(max(y_true), max(y_pred))
plt.plot([min_val, max_val], [min_val, max_val], 'r--')
plt.show()
