In [1]:
import os
import numpy as np
import polars as pl
import pandas as pd
import torch
from typing import List
from numpy.typing import NDArray

In [26]:
from sklearn.linear_model import Ridge, RidgeCV, LogisticRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.pipeline import make_pipeline
from sklearn.metrics import roc_auc_score, classification_report, pairwise_distances

In [3]:
from torch.utils.data import TensorDataset, Dataset, DataLoader, WeightedRandomSampler
import torch.nn as nn
import torch.optim as optim

In [4]:
seed = 42
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
train_years = [2012, 2013]
val_years = [2014]

train_meta_df = pd.DataFrame()
val_meta_df = pd.DataFrame()

for ty in train_years:
    for i in [1, 2]:
        train_meta_df = pd.concat([train_meta_df, pd.read_csv("./data/doc_features/{year}_{part}_master_metadata.csv".format(year=ty, part=i))])

for ty in val_years:
    for i in [1, 2]:
        val_meta_df = pd.concat([val_meta_df, pd.read_csv("./data/doc_features/{year}_{part}_master_metadata.csv".format(year=ty, part=i))])

train_meta_df = train_meta_df[(train_meta_df['SUESCORE'] <= -0.5) | (train_meta_df['SUESCORE'] >= 0.5)].reset_index(drop=True)
val_meta_df = val_meta_df[(val_meta_df['SUESCORE'] <= -0.5) | (val_meta_df['SUESCORE'] >= 0.5)].reset_index(drop=True)

In [6]:
test_meta_df = val_meta_df.sample(frac=0.5, random_state=seed)
test_meta_df = test_meta_df.reset_index(drop=True)

val_meta_df = val_meta_df.drop(test_meta_df.index)
val_meta_df = val_meta_df.reset_index(drop=True)

print(len(val_meta_df), len(test_meta_df))

1280 1280


In [7]:
def gather_embeddings(valid_transcriptids, model_name, cls_type, years=[2012,2013]):
    dirname = "./data/doc_features/{}_filtered/".format(model_name)

    if cls_type not in ["X_cls", "X_mean"]:
        raise Exception("data type in valid")
    
    X_list = []
    
    for ty in years:
        for part in [1, 2]:
            filename = "transcript_componenttext_{}_{}_cls_mean.npz".format(ty, part)

            path = os.path.join(dirname, filename)

            data = np.load(path)

            X = data[cls_type]
            tids = data["transcriptids"]
            tids_int = tids.astype(np.int64)

            mask_ids = np.isin(tids_int, valid_transcriptids)

            X_filt = X[mask_ids]
            # tids_filt = tids[mask_ids]

            X_list.append(X_filt)

    return np.vstack(X_list)

In [8]:
llama3_train_embs = gather_embeddings(train_meta_df.transcriptid, "llama3_3b_cls", "X_mean", years=train_years)
qwen_train_embs = gather_embeddings(train_meta_df.transcriptid, "qwen_4b_cls", "X_mean", years=train_years)
gemma_train_embs = gather_embeddings(train_meta_df.transcriptid, "gemma_2b_cls", "X_mean", years=train_years)

In [9]:
llama3_val_embs = gather_embeddings(val_meta_df.transcriptid, "llama3_3b_cls", "X_mean", years=val_years)
qwen_val_embs = gather_embeddings(val_meta_df.transcriptid, "qwen_4b_cls", "X_mean", years=val_years)
gemma_val_embs = gather_embeddings(val_meta_df.transcriptid, "gemma_2b_cls", "X_mean", years=val_years)

llama3_test_embs = gather_embeddings(test_meta_df.transcriptid, "llama3_3b_cls", "X_mean", years=val_years)
qwen_test_embs = gather_embeddings(test_meta_df.transcriptid, "qwen_4b_cls", "X_mean", years=val_years)
gemma_test_embs = gather_embeddings(test_meta_df.transcriptid, "gemma_2b_cls", "X_mean", years=val_years)

In [10]:
scaler = StandardScaler().fit(gemma_train_embs)
gemma_train_aligned = scaler.transform(gemma_train_embs)
gemma_val_aligned = scaler.transform(gemma_val_embs)
gemma_test_aligned = scaler.transform(gemma_test_embs)

In [11]:
y_train = train_meta_df['label']
y_val = val_meta_df['label']
y_test = test_meta_df['label']

In [12]:
def fit_linear_transformation(X_train, target_train, X_test, target_test, X_val=None, use_skl=True, reg=10.0):
    # train T: llama -> Gemma
    scaler_X = StandardScaler().fit(X_train)
    scaler_Y = StandardScaler().fit(target_train)

    X_tr = scaler_X.transform(X_train)
    Y_tr = scaler_Y.transform(target_train)
    X_te = scaler_X.transform(X_test)
    Y_te = scaler_Y.transform(target_test)

    if X_val is not None:
        X_va = scaler_X.transform(X_val)
        
    alpha = reg
    d_X = X_tr.shape[1]

    if not use_skl:
        I = np.eye(d_X)
        # compute W_closed:
        W_closed = np.linalg.inv(X_tr.T @ X_tr + alpha * I) @ X_tr.T @ Y_tr

        # Map test embeddings and inverse-transform:
        Y_tr_from_transform_scaled = X_tr @ W_closed
        
        if X_val is not None:
            Y_val_from_transform_scaled = X_va @ W_closed

        Y_pred_scaled = X_te @ W_closed
        Y_pred = scaler_Y.inverse_transform(Y_pred_scaled)

        # Evaluate (e.g. MSE or cosine similarity)
        mse = np.mean((Y_pred - scaler_Y.inverse_transform(Y_te))**2)
        print(f"Closed-form ridge MSE: {mse:.4f}")
        print(np.diag(cosine_similarity(Y_pred_scaled, Y_te)).mean())

        if X_val is not None:
            return W_closed, Y_tr_from_transform_scaled, Y_val_from_transform_scaled, Y_pred_scaled
        else:
            return W_closed, Y_tr_from_transform_scaled, Y_pred_scaled
        
    else:
        # 3b) Using sklearn.Ridge:
        # Note: sklearn’s Ridge solves for each output dimension jointly when Y is 2D.
        model = Ridge(alpha=alpha, fit_intercept=False, solver="auto")# intercept is already handled by StandardScaler
        model.fit(X_tr, Y_tr)
        W_sklearn = model.coef_.T   

        Y_tr_from_transform_scaled = model.predict(X_tr)

        if X_val is not None:
            Y_val_from_transform_scaled = model.predict(X_va)
        
        Y_pred_scaled = model.predict(X_te)

        Y_pred = scaler_Y.inverse_transform(Y_pred_scaled)
        mse = np.mean((Y_pred - scaler_Y.inverse_transform(Y_te))**2)
        print(f"sklearn.Ridge MSE: {mse:.4f}")
        print(np.diag(cosine_similarity(Y_pred_scaled, Y_te)).mean())

        if X_val is not None:
            return model, Y_tr_from_transform_scaled, Y_val_from_transform_scaled, Y_pred_scaled
        else:
            return model, Y_tr_from_transform_scaled, Y_pred_scaled

In [13]:
# train T: llama -> Gemma
scaler_X = StandardScaler().fit(llama3_train_embs)
scaler_Y = StandardScaler().fit(gemma_train_embs)

X_tr = scaler_X.transform(llama3_train_embs)
Y_tr = scaler_Y.transform(gemma_train_embs)
X_te = scaler_X.transform(llama3_test_embs)
Y_te = scaler_Y.transform(gemma_test_embs)


alpha = 10.0
d_llama = X_tr.shape[1]
I = np.eye(d_llama)
# compute W_closed:
W_closed = np.linalg.inv(X_tr.T @ X_tr + alpha * I) @ X_tr.T @ Y_tr

# Map test embeddings and inverse-transform:
Y_pred_closed = scaler_Y.inverse_transform(X_te @ W_closed)

# Evaluate (e.g. MSE or cosine similarity)
mse_closed = np.mean((Y_pred_closed - scaler_Y.inverse_transform(Y_te))**2)
print(f"Closed-form ridge MSE: {mse_closed:.4f}")

# 3b) Using sklearn.Ridge:
# Note: sklearn’s Ridge solves for each output dimension jointly when Y is 2D.
model = Ridge(alpha=alpha, fit_intercept=False, solver="auto")# intercept is already handled by StandardScaler
model.fit(X_tr, Y_tr)
W_sklearn = model.coef_.T   # shape (d_llama, d_gemma)

Y_pred_skl = scaler_Y.inverse_transform(model.predict(X_te))
mse_skl = np.mean((Y_pred_skl - scaler_Y.inverse_transform(Y_te))**2)
print(f"sklearn.Ridge MSE:      {mse_skl:.4f}")

Closed-form ridge MSE: 0.1197
sklearn.Ridge MSE:      0.1197


In [14]:
# Y_te_scaled = scaler_Y.transform(gemma_test_embs)
Y_pred_scaled = X_te @ W_closed
# compute pairwise cosines for each i
cosines = np.diag(cosine_similarity(Y_pred_scaled, Y_te))
print("Avg. cosine sim:", cosines.mean())

Avg. cosine sim: 0.8950016481689836


In [15]:
Y_pred_scaled = model.predict(X_te)
cosines = np.diag(cosine_similarity(Y_pred_scaled, Y_te))
print("Avg. cosine sim:", cosines.mean())

Avg. cosine sim: 0.8950016


In [None]:
# trans_llama_gemma, llama_train_aligned, llama_val_aligned, llama_test_aligned = fit_linear_transformation(llama3_train_embs, gemma_train_embs, llama3_test_embs, gemma_test_embs, llama3_val_embs, use_skl=True, reg=200)
# trans_qwen_gemma, qwen_train_aligned, qwen_val_aligned, qwen_test_aligned = fit_linear_transformation(qwen_train_embs, gemma_train_embs, qwen_test_embs, gemma_test_embs, qwen_val_embs, use_skl=False, reg=150)

sklearn.Ridge MSE: 0.0954
0.91327083
Closed-form ridge MSE: 0.1153
0.8903838461753889


In [16]:
trans_llama_gemma, llama_train_aligned, llama_val_aligned, llama_test_aligned = fit_linear_transformation(llama3_train_embs, gemma_train_embs, llama3_test_embs, gemma_test_embs, llama3_val_embs, use_skl=True, reg=2)
trans_qwen_gemma, qwen_train_aligned, qwen_val_aligned, qwen_test_aligned = fit_linear_transformation(qwen_train_embs, gemma_train_embs, qwen_test_embs, gemma_test_embs, qwen_val_embs, use_skl=False, reg=1)

sklearn.Ridge MSE: 0.1474
0.8751319
Closed-form ridge MSE: 0.1668
0.8507743229369137


In [17]:
print(np.diag(cosine_similarity(gemma_test_aligned, llama_test_aligned)).mean())
print(np.diag(cosine_similarity(qwen_test_aligned, gemma_test_aligned)).mean())
print(np.diag(cosine_similarity(qwen_test_aligned, llama_test_aligned)).mean())

0.8751319
0.8507743229369137
0.8261375792973121


In [18]:
print(np.diag(cosine_similarity(gemma_val_aligned, llama_val_aligned)).mean())
print(np.diag(cosine_similarity(qwen_val_aligned, gemma_val_aligned)).mean())
print(np.diag(cosine_similarity(qwen_val_aligned, llama_val_aligned)).mean())

0.8727312
0.8465191930255955
0.8219528728637322


In [19]:
def cluster_std_analysis(gemma_emb, qwen_trans_emb, llama_trans_emb, n_clusters=50):
    """
    Quantifies uniqueness of transformed embeddings within Gemma-defined clusters
    
    Args:
        gemma_emb: Original Gemma embeddings (n_samples, dim)
        qwen_trans_emb: Qwen embeddings transformed to Gemma space (n_samples, dim)
        llama_trans_emb: Llama embeddings transformed to Gemma space (n_samples, dim)
        n_clusters: Number of clusters to create
        
    Returns:
        Dictionary containing uniqueness metrics and cluster statistics
    """
    # Cluster in original Gemma space
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    cluster_labels = kmeans.fit_predict(gemma_emb)
    cluster_centers = kmeans.cluster_centers_
    
    gemma_stds, qwen_stds, llama_stds = [], [], []
    uniqueness_scores_qwen, uniqueness_scores_llama = [], []
    
    for c in range(n_clusters):
        # Get indices for current cluster
        cluster_idx = np.where(cluster_labels == c)[0]
        if len(cluster_idx) < 5:  # Skip small clusters
            continue
            
        # Get cluster center
        center = cluster_centers[c]
        
        # Calculate distances to cluster center
        gemma_dists = pairwise_distances(gemma_emb[cluster_idx], [center])
        qwen_dists = pairwise_distances(qwen_trans_emb[cluster_idx], [center])
        llama_dists = pairwise_distances(llama_trans_emb[cluster_idx], [center])
        
        # Compute standard deviations
        gemma_std = np.std(gemma_dists)
        qwen_std = np.std(qwen_dists)
        llama_std = np.std(llama_dists)
        
        # Calculate uniqueness scores
        uniqueness_qwen = max(0, 1 - (qwen_std / gemma_std))
        uniqueness_llama = max(0, 1 - (llama_std / gemma_std))
        
        # Store results
        gemma_stds.append(gemma_std)
        qwen_stds.append(qwen_std)
        llama_stds.append(llama_std)
        uniqueness_scores_qwen.append(uniqueness_qwen)
        uniqueness_scores_llama.append(uniqueness_llama)
    
    # Compute cross-model divergence
    cross_divergence = 1 - np.mean([
        np.dot(qwen_trans_emb[i], llama_trans_emb[i]) / 
        (np.linalg.norm(qwen_trans_emb[i]) * np.linalg.norm(llama_trans_emb[i]))
        for i in range(len(gemma_emb))
    ])
    
    return {
        "mean_uniqueness_qwen": np.mean(uniqueness_scores_qwen),
        "mean_uniqueness_llama": np.mean(uniqueness_scores_llama),
        "std_ratio_qwen": np.mean(qwen_stds) / np.mean(gemma_stds),
        "std_ratio_llama": np.mean(llama_stds) / np.mean(gemma_stds),
        "cross_model_divergence": cross_divergence,
        "per_cluster": {
            "gemma_std": gemma_stds,
            "qwen_std": qwen_stds,
            "llama_std": llama_stds,
            "uniqueness_qwen": uniqueness_scores_qwen,
            "uniqueness_llama": uniqueness_scores_llama
        }
    }

In [20]:
results = cluster_std_analysis(gemma_test_aligned, llama_test_aligned, qwen_test_aligned)
    
print(f"Qwen uniqueness: {results['mean_uniqueness_qwen']:.3f}")
print(f"Llama uniqueness: {results['mean_uniqueness_llama']:.3f}")
print(f"Qwen std ratio: {results['std_ratio_qwen']:.3f}")
print(f"Llama std ratio: {results['std_ratio_llama']:.3f}")
print(f"Cross-model divergence: {results['cross_model_divergence']:.3f}")


Qwen uniqueness: 0.024
Llama uniqueness: 0.049
Qwen std ratio: 1.098
Llama std ratio: 1.027
Cross-model divergence: 0.174


In [21]:
class EmbeddingConcatDataset(Dataset):
    def __init__(self, aligned_embs: List[NDArray[np.float32]], labels):
        
        self.concat_embs = np.concat(aligned_embs, axis=1)
        self.num_models = len(aligned_embs)
        self.emb_dim = aligned_embs[0].shape[1]
        self.concat_emb_dim = self.num_models * self.emb_dim
        self.labels = labels

        assert self.concat_embs.shape[0] == len(self.labels)

    def __len__(self):
        return len(self.concat_embs)

    def __getitem__(self, index):
        x = self.concat_embs[index]
        y = self.labels[index] 

        x_tensor = torch.from_numpy(x)
        y_tensor = torch.tensor(y)

        return x_tensor, y_tensor



In [22]:
# train_aligned_embs= [gemma_train_aligned, llama_train_aligned, qwen_train_aligned]
# val_aligned_embs= [gemma_val_aligned, llama_val_aligned, qwen_val_aligned]
# test_aligned_embs= [gemma_test_aligned, llama_test_aligned, qwen_test_aligned]

# train_ds = EmbeddingConcatDataset(train_aligned_embs, y_train)
# val_ds = EmbeddingConcatDataset(val_aligned_embs, y_val)
# test_ds = EmbeddingConcatDataset(test_aligned_embs, y_test)

# batch_size = 32
# train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  drop_last=True)
# val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)
# test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

def to_tensor_dataset(X, y):
    Xt = torch.from_numpy(X).float().to(device)
    yt = torch.from_numpy(y).float().unsqueeze(1).to(device)
    return TensorDataset(Xt, yt)

train_aligned_embs = np.array([gemma_train_aligned, llama_train_aligned, qwen_train_aligned]).mean(axis=0)
val_aligned_embs = np.array([gemma_val_aligned, llama_val_aligned, qwen_val_aligned]).mean(axis=0)
test_aligned_embs = np.array([gemma_test_aligned, llama_test_aligned, qwen_test_aligned]).mean(axis=0)

train_ds = to_tensor_dataset(train_aligned_embs, y_train.to_numpy())
val_ds = to_tensor_dataset(val_aligned_embs, y_val.to_numpy())
test_ds = to_tensor_dataset(test_aligned_embs, y_test.to_numpy())

batch_size = 32

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
val_dl   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)
test_dl  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)



In [28]:
y_train.to_numpy()

array([1., 1., 1., ..., 1., 1., 0.], shape=(4920,))

In [None]:
# X_mean, l2 0.005
param_grid = {"logisticregression__C": [0.0001, 0.006, 0.1]}

pipeline = make_pipeline(
    # StandardScaler(),
    LogisticRegression(
        penalty="l2",
        solver="saga",    
        # solver="liblinear",    
        # class_weight="balanced",
        max_iter=3000,
        random_state=42
    )
)

search = GridSearchCV(
    pipeline,
    param_grid,
    cv=5,
    scoring="roc_auc",
    n_jobs=-1,
    verbose=1
)
search.fit(train_aligned_embs, y_train.to_numpy())

print("Best C (inverse reg. strength):", search.best_params_["logisticregression__C"])
print("CV ROC AUC:", search.best_score_)


best_clf = search.best_estimator_
y_pred_probs = best_clf.predict_proba(test_aligned_embs)[:, 1]
y_pred       = best_clf.predict(test_aligned_embs)

print(classification_report(y_test.to_numpy(), y_pred))
print("Test ROC AUC:", roc_auc_score(y_test.to_numpy(), y_pred_probs))

Fitting 5 folds for each of 3 candidates, totalling 15 fits
Best C (inverse reg. strength): 0.006
CV ROC AUC: 0.6585952175685856
              precision    recall  f1-score   support

         0.0       0.19      0.08      0.11       270
         1.0       0.79      0.91      0.85      1010

    accuracy                           0.74      1280
   macro avg       0.49      0.50      0.48      1280
weighted avg       0.66      0.74      0.69      1280

Test ROC AUC: 0.494987165383205


: 

In [23]:
class ShallowMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=512, dropout=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.net(x)

model = ShallowMLP(input_dim=gemma_train_embs.shape[1]).to(device)

# 4) Optimizer and loss
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-7, weight_decay=1e-5)
criterion = nn.BCELoss()
# criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# 5) Training & Validation Loop
n_epochs = 300
best_val_loss = float('inf')

for epoch in range(1, n_epochs+1):
    # -- Training
    model.train()
    train_loss = 0.0
    for xb, yb in train_dl:
        preds = model(xb)
        loss = criterion(preds, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_dl.dataset)

    # -- Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_dl:
            preds = model(xb)
            loss = criterion(preds, yb)
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_dl.dataset)

    print(f"Epoch {epoch:2d}  Train Loss: {train_loss:.4f}  Val Loss: {val_loss:.4f}")

    # Optional: save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")

# 6) Load best model and test evaluation
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

y_probs = []
y_true  = []
with torch.no_grad():
    for xb, yb in test_dl:
        probs = model(xb)
        y_probs.extend(probs.cpu().numpy().flatten().tolist())
        y_true .extend(yb.cpu().numpy().flatten().tolist())

y_pred = (np.array(y_probs) >= 0.5).astype(int)

print("\nTest Classification Report:")
print(classification_report(y_true, y_pred))
print("Test ROC AUC:", roc_auc_score(y_true, y_probs))

Epoch  1  Train Loss: 0.7511  Val Loss: 0.7380
Epoch  2  Train Loss: 0.7388  Val Loss: 0.7248
Epoch  3  Train Loss: 0.7324  Val Loss: 0.7151
Epoch  4  Train Loss: 0.7216  Val Loss: 0.7156
Epoch  5  Train Loss: 0.7119  Val Loss: 0.7026
Epoch  6  Train Loss: 0.7044  Val Loss: 0.6991
Epoch  7  Train Loss: 0.7061  Val Loss: 0.6925
Epoch  8  Train Loss: 0.7018  Val Loss: 0.6888
Epoch  9  Train Loss: 0.6935  Val Loss: 0.6897
Epoch 10  Train Loss: 0.6894  Val Loss: 0.6773
Epoch 11  Train Loss: 0.6830  Val Loss: 0.6800
Epoch 12  Train Loss: 0.6775  Val Loss: 0.6757
Epoch 13  Train Loss: 0.6775  Val Loss: 0.6727
Epoch 14  Train Loss: 0.6682  Val Loss: 0.6654
Epoch 15  Train Loss: 0.6642  Val Loss: 0.6637
Epoch 16  Train Loss: 0.6608  Val Loss: 0.6630
Epoch 17  Train Loss: 0.6591  Val Loss: 0.6530
Epoch 18  Train Loss: 0.6612  Val Loss: 0.6497
Epoch 19  Train Loss: 0.6540  Val Loss: 0.6498
Epoch 20  Train Loss: 0.6496  Val Loss: 0.6493
Epoch 21  Train Loss: 0.6479  Val Loss: 0.6458
Epoch 22  Tra

In [42]:
class EmbeddingMeanDataset(Dataset):
    def __init__(self, aligned_embs: List[NDArray[np.float32]], labels):
        
        self.mean_embs = np.mean(aligned_embs, axis=0)
        self.num_models = len(aligned_embs)
        self.emb_dim = aligned_embs[0].shape[1]
        self.labels = labels

        assert self.mean_embs.shape[0] == len(self.labels)

    def __len__(self):
        return len(self.mean_embs)

    def __getitem__(self, index):
        x = self.mean_embs[index]
        y = self.labels[index] 

        x_tensor = torch.from_numpy(x)
        y_tensor = torch.tensor(y)

        return x_tensor, y_tensor


In [43]:
class_counts = torch.bincount(torch.tensor(y_train.astype(int)))
# weight for each class = 1 / count
class_weights = 1.0 / class_counts.float()
# assign a sample-level weight based on its label
sample_weights = class_weights[y_train]

sampler = WeightedRandomSampler(
    weights=sample_weights, 
    num_samples=len(sample_weights), 
    replacement=True
)

In [151]:
train_aligned_embs= [gemma_train_aligned, llama_train_aligned, qwen_train_aligned]
val_aligned_embs= [gemma_val_aligned, llama_val_aligned, qwen_val_aligned]
test_aligned_embs= [gemma_test_aligned, llama_test_aligned, qwen_test_aligned]

train_ds = EmbeddingMeanDataset(train_aligned_embs, y_train)
val_ds = EmbeddingMeanDataset(val_aligned_embs, y_val)
test_ds = EmbeddingMeanDataset(test_aligned_embs, y_test)

batch_size = 32
# train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

In [149]:
input_dim = gemma_train_embs.shape[1]
c = 1

model = nn.Linear(input_dim, c)
criterion = nn.BCEWithLogitsLoss()

optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)

In [152]:
class ShallowMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, dropout=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.net(x)

model = ShallowMLP(input_dim=gemma_train_embs.shape[1]).to(device)

# 4) Optimizer and loss
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-5)
criterion = nn.BCELoss()
# criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# 5) Training & Validation Loop
n_epochs = 300
best_val_loss = float('inf')

for epoch in range(1, n_epochs+1):
    # -- Training
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        xb = xb.to(device).float()
        yb = yb.to(device).unsqueeze(-1).float()
        preds = model(xb)
        loss = criterion(preds, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)

    # -- Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device).float()
            yb = yb.to(device).unsqueeze(-1).float()
            preds = model(xb)
            loss = criterion(preds, yb)
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch:2d}  Train Loss: {train_loss:.4f}  Val Loss: {val_loss:.4f}")

    # Optional: save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")

# 6) Load best model and test evaluation
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

y_probs = []
y_true  = []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device).float()
        yb = yb.to(device).unsqueeze(-1).float()
        probs = model(xb)
        y_probs.extend(probs.cpu().numpy().flatten().tolist())
        y_true .extend(yb.cpu().numpy().flatten().tolist())

y_pred = (np.array(y_probs) >= 0.5).astype(int)

print("\nTest Classification Report:")
print(classification_report(y_true, y_pred))
print("Test ROC AUC:", roc_auc_score(y_true, y_probs))

Epoch  1  Train Loss: 0.7830  Val Loss: 0.7461
Epoch  2  Train Loss: 0.7717  Val Loss: 0.7332
Epoch  3  Train Loss: 0.7652  Val Loss: 0.7378
Epoch  4  Train Loss: 0.7590  Val Loss: 0.7267
Epoch  5  Train Loss: 0.7524  Val Loss: 0.7299
Epoch  6  Train Loss: 0.7448  Val Loss: 0.7258
Epoch  7  Train Loss: 0.7395  Val Loss: 0.7098
Epoch  8  Train Loss: 0.7312  Val Loss: 0.7074
Epoch  9  Train Loss: 0.7263  Val Loss: 0.7025
Epoch 10  Train Loss: 0.7210  Val Loss: 0.7017
Epoch 11  Train Loss: 0.7124  Val Loss: 0.6916
Epoch 12  Train Loss: 0.7047  Val Loss: 0.6893
Epoch 13  Train Loss: 0.7056  Val Loss: 0.6903
Epoch 14  Train Loss: 0.7028  Val Loss: 0.6841
Epoch 15  Train Loss: 0.6988  Val Loss: 0.6705
Epoch 16  Train Loss: 0.6913  Val Loss: 0.6742
Epoch 17  Train Loss: 0.6835  Val Loss: 0.6792
Epoch 18  Train Loss: 0.6779  Val Loss: 0.6693
Epoch 19  Train Loss: 0.6769  Val Loss: 0.6637
Epoch 20  Train Loss: 0.6670  Val Loss: 0.6595
Epoch 21  Train Loss: 0.6661  Val Loss: 0.6581
Epoch 22  Tra

In [161]:
class ShallowMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, dropout=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        return self.net(x)

model = ShallowMLP(input_dim=gemma_train_embs.shape[1]).to(device)

# 4) Optimizer and loss
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-5)
criterion = nn.BCELoss()
# criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# 5) Training & Validation Loop
n_epochs = 300
best_val_loss = float('inf')

for epoch in range(1, n_epochs+1):
    # -- Training
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        xb = xb.to(device).float()
        yb = yb.to(device).unsqueeze(-1).float()
        preds = model(xb)
        loss = criterion(preds, yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(train_loader.dataset)

    # -- Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device).float()
            yb = yb.to(device).unsqueeze(-1).float()
            preds = model(xb)
            loss = criterion(preds, yb)
            val_loss += loss.item() * xb.size(0)
    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch:2d}  Train Loss: {train_loss:.4f}  Val Loss: {val_loss:.4f}")

    # Optional: save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_model.pth")

# 6) Load best model and test evaluation
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

y_probs = []
y_true  = []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device).float()
        yb = yb.to(device).unsqueeze(-1).float()
        probs = model(xb)
        y_probs.extend(probs.cpu().numpy().flatten().tolist())
        y_true .extend(yb.cpu().numpy().flatten().tolist())

y_pred = (np.array(y_probs) >= 0.5).astype(int)

print("\nTest Classification Report:")
print(classification_report(y_true, y_pred))
print("Test ROC AUC:", roc_auc_score(y_true, y_probs))

Epoch  1  Train Loss: 0.6456  Val Loss: 0.6377
Epoch  2  Train Loss: 0.6335  Val Loss: 0.6286
Epoch  3  Train Loss: 0.6276  Val Loss: 0.6217
Epoch  4  Train Loss: 0.6132  Val Loss: 0.6171
Epoch  5  Train Loss: 0.6146  Val Loss: 0.6142
Epoch  6  Train Loss: 0.6108  Val Loss: 0.6070
Epoch  7  Train Loss: 0.6119  Val Loss: 0.5987
Epoch  8  Train Loss: 0.6029  Val Loss: 0.6028
Epoch  9  Train Loss: 0.6007  Val Loss: 0.5980
Epoch 10  Train Loss: 0.5957  Val Loss: 0.5931
Epoch 11  Train Loss: 0.5960  Val Loss: 0.5893
Epoch 12  Train Loss: 0.5909  Val Loss: 0.5835
Epoch 13  Train Loss: 0.5873  Val Loss: 0.5849
Epoch 14  Train Loss: 0.5844  Val Loss: 0.5809
Epoch 15  Train Loss: 0.5810  Val Loss: 0.5735
Epoch 16  Train Loss: 0.5787  Val Loss: 0.5769
Epoch 17  Train Loss: 0.5746  Val Loss: 0.5772
Epoch 18  Train Loss: 0.5786  Val Loss: 0.5733
Epoch 19  Train Loss: 0.5684  Val Loss: 0.5702
Epoch 20  Train Loss: 0.5616  Val Loss: 0.5739
Epoch 21  Train Loss: 0.5674  Val Loss: 0.5709
Epoch 22  Tra

In [159]:
train_aligned_embs= [gemma_train_aligned, gemma_train_aligned, gemma_train_aligned]
val_aligned_embs= [gemma_val_aligned, gemma_val_aligned, gemma_val_aligned]
test_aligned_embs= [gemma_test_aligned, gemma_test_aligned, gemma_test_aligned]

train_ds = EmbeddingMeanDataset(train_aligned_embs, y_train)
val_ds = EmbeddingMeanDataset(val_aligned_embs, y_val)
test_ds = EmbeddingMeanDataset(test_aligned_embs, y_test)

batch_size = 32
# train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

In [160]:
input_dim = gemma_train_embs.shape[1]
c = 1

model = nn.Linear(input_dim, c)
criterion = nn.BCEWithLogitsLoss()

optimizer = optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)