# ‘Streaming’ ANN: a neural-network similarity regressor trained on SentenceTransformers embeddings streamed through `tf.data`

In [None]:
import pandas as pd
import numpy as np
import tensorflow as tf
import gc
import pickle
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, Lambda, concatenate, GlobalAveragePooling1D
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

In [None]:
# Mount Google Drive at /content/drive/ (only available inside Colab) so files kept
# in MyDrive, such as the embeddings pickle and the CSV loaded later, can be read.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive/'
drive.mount(DRIVE_MOUNT_POINT)

# Load pickled sentence embeddings
def load_embeddings(file_path):
    with open(file_path, "rb") as f:
        data = pickle.load(f)
    return {k: np.array(v, dtype=np.float16) for k, v in data.items()}

# Define data generator
def data_generator(df, embedding_map):
    for _, row in df.iterrows():
        emb1 = embedding_map.get(row['sentence1_clean'])
        emb2 = embedding_map.get(row['sentence2_clean'])
        if emb1 is not None and emb2 is not None:
            yield (emb1.astype(np.float16), emb2.astype(np.float16)), np.float16(row['score'])

# Build ANN model
def build_ann_model(input_shape):
    input1 = Input(shape=input_shape)
    input2 = Input(shape=input_shape)

    pooled1 = GlobalAveragePooling1D()(input1)
    pooled2 = GlobalAveragePooling1D()(input2)

    abs_diff = Lambda(lambda x: tf.abs(x[0] - x[1]))([pooled1, pooled2])
    mult = Lambda(lambda x: x[0] * x[1])([pooled1, pooled2])

    merged = concatenate([pooled1, pooled2, abs_diff, mult])

    dense = Dense(128, activation='relu')(merged)
    drop = Dropout(0.3)(dense)
    output = Dense(1, activation='linear')(drop)

    model = Model(inputs=[input1, input2], outputs=output)
    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    return model

# Load resources
embedding_path = "drive/MyDrive/sentence_to_embedding.pkl"
sentence_to_embedding = load_embeddings(embedding_path)

csv_path = "drive/MyDrive/rs2_augmented.csv"
df_full = pd.read_csv(csv_path)[['sentence1_clean', 'sentence2_clean', 'score']]

# Remove unknown embeddings
df_full = df_full[
    df_full['sentence1_clean'].isin(sentence_to_embedding) &
    df_full['sentence2_clean'].isin(sentence_to_embedding)
].reset_index(drop=True)

# Train-test split
df_train, df_val = train_test_split(df_full, test_size=0.2, random_state=42)

# Infer input shape
sample_embedding = next(iter(sentence_to_embedding.values()))
embedding_dim = sample_embedding.shape[1] if sample_embedding.ndim == 2 else sample_embedding.shape[0]
max_len = 30
input_shape = (max_len, embedding_dim)

# Create tf.data datasets
batch_size = 16

train_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(df_train, sentence_to_embedding),
    output_signature=(
        (tf.TensorSpec(shape=input_shape, dtype=tf.float16),
         tf.TensorSpec(shape=input_shape, dtype=tf.float16)),
        tf.TensorSpec(shape=(), dtype=tf.float16)
    )
).batch(batch_size).prefetch(tf.data.AUTOTUNE)

val_dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(df_val, sentence_to_embedding),
    output_signature=(
        (tf.TensorSpec(shape=input_shape, dtype=tf.float16),
         tf.TensorSpec(shape=input_shape, dtype=tf.float16)),
        tf.TensorSpec(shape=(), dtype=tf.float16)
    )
).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Free memory
del df_full
gc.collect()

# Build and train model
model = build_ann_model(input_shape)
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10
)

# Predictions (optional, be careful with RAM here)
#y_pred = model.predict(val_dataset).flatten()

In [None]:
pearson_corr, _ = pearsonr(y_val, y_pred)
spearman_corr, _ = spearmanr(y_val, y_pred)

print(f"📈 Pearson Correlation:  {pearson_corr:.4f}")
print(f"📊 Spearman Correlation: {spearman_corr:.4f}")

# Plotting
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Loss over Epochs')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['mae'], label='Train MAE')
plt.plot(history.history['val_mae'], label='Val MAE')
plt.title('MAE over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Mean Absolute Error')
plt.legend()

plt.tight_layout()
plt.show()