In [1]:
import os, re, json, copy
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# One seed for both NumPy's global RNG and PyTorch's default generator so that
# shuffling, sampling and weight initialisation repeat across fresh-kernel runs.
SEED = 7
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"Torch version: {torch.__version__}")

Torch version: 2.9.1+cpu


In [2]:
DATA_PATH = "../data/nexkey_synthetic_dataset_v1"

# The three tables share a directory and naming scheme, so read them in one pass.
queries, properties, interactions = (
    pd.read_csv(f"{DATA_PATH}/{table_name}.csv")
    for table_name in ("queries", "properties", "interactions")
)

# Quick sanity check on (rows, columns) of every table.
print(f"queries: {queries.shape}")
print(f"properties: {properties.shape}")
print(f"interactions: {interactions.shape}")

queries: (30000, 16)
properties: (15000, 27)
interactions: (480000, 4)


In [3]:
# Split at the query level (80/10/10) so that every interaction belonging to a
# given query lands in exactly one of train / val / test.
rng = np.random.RandomState(7)

all_qids = queries["query_id"].unique()
rng.shuffle(all_qids)  # in-place shuffle of the unique ids

n = len(all_qids)
cut_train, cut_val = int(0.80 * n), int(0.90 * n)

train_qids = set(all_qids[:cut_train])
val_qids   = set(all_qids[cut_train:cut_val])
test_qids  = set(all_qids[cut_val:])

# Route each interaction row to the split that owns its query id.
train_int, val_int, test_int = (
    interactions.loc[interactions["query_id"].isin(split_qids)].copy()
    for split_qids in (train_qids, val_qids, test_qids)
)

print(f"train_int: {train_int.shape}")
print(f"val_int  : {val_int.shape}")
print(f"test_int : {test_int.shape}")

train_int: (384000, 4)
val_int  : (48000, 4)
test_int : (48000, 4)


In [4]:
def property_to_text(row):
    """Render one property record as a fixed-template, single-line description.

    Args:
        row: Any mapping-like object indexable by column name (e.g. a pandas
            row Series). It must provide the keys read below.

    Returns:
        A string made of four sentences: deal/location, size, financials and
        condition/occupancy. ``beds``, ``sqft``, ``purchase_price``, ``arv``
        and ``entry_fee`` are truncated with ``int()``; ``baths`` and
        ``estimated_monthly_payment`` are inserted as-is.
    """
    location = "{} {} in {} {}.".format(
        row["deal_type"], row["property_type"], row["city"], row["state"]
    )
    size = "{} bed {} bath, {} sqft.".format(
        int(row["beds"]), row["baths"], int(row["sqft"])
    )
    financials = "Purchase {}, ARV {}, Entry {}, Payment {}.".format(
        int(row["purchase_price"]),
        int(row["arv"]),
        int(row["entry_fee"]),
        row["estimated_monthly_payment"],
    )
    status = "Condition {}, Occupancy {}.".format(row["condition"], row["occupancy"])

    return " ".join((location, size, financials, status))

# Build the text representation row by row, then preview id + text side by side.
properties["deal_text"] = properties.apply(property_to_text, axis="columns")
properties.loc[:, ["property_id", "deal_text"]].head(5)

Unnamed: 0,property_id,deal_text
0,1,Subto Single Family in Raleigh NC. 4 bed 1.0 b...
1,2,Hybrid Single Family in Sacramento CA. 5 bed 1...
2,3,"Cash Condo in Charleston SC. 4 bed 2.5 bath, 3..."
3,4,Subto Manufactured in Greenville AL. 4 bed 2.0...
4,5,Seller Finance Single Family in Fairview AL. 3...


In [5]:
CKPT_PATH = "../models/checkpoints"

# Start from the vocabulary saved at dual_vocab_v1.json.
with open(f"{CKPT_PATH}/dual_vocab_v1.json", "r") as f:
    dual_vocab = json.loads(f.read())

# The cross-encoder needs one extra token to separate query from deal text;
# it is appended at the next free index unless the vocabulary already has it.
cross_vocab = copy.deepcopy(dual_vocab)
cross_vocab.setdefault("<SEP>", len(cross_vocab))

PAD_ID, UNK_ID, SEP_ID = (
    cross_vocab[special] for special in ("<PAD>", "<UNK>", "<SEP>")
)

print("dual_vocab size :", len(dual_vocab))
print("cross_vocab size:", len(cross_vocab))
print("PAD/UNK/SEP:", PAD_ID, UNK_ID, SEP_ID)

dual_vocab size : 18145
cross_vocab size: 18146
PAD/UNK/SEP: 0 1 18145


In [6]:
# Plain dict lookups (id -> text) used by the dataset when building each pair.
query_text_map = dict(
    zip(queries["query_id"].tolist(), queries["query_text"].tolist())
)
deal_text_map = dict(
    zip(properties["property_id"].tolist(), properties["deal_text"].tolist())
)

print(f"query_text_map: {len(query_text_map)}")
print(f"deal_text_map : {len(deal_text_map)}")

query_text_map: 30000
deal_text_map : 15000


In [7]:
# Upper bounds on how many interaction rows are used per split.
MAX_TRAIN, MAX_VAL = 250_000, 50_000

train_df, val_df = train_int.copy(), val_int.copy()

print("Before sampling:")
print("Train size:", train_df.shape[0])
print("Val size  :", val_df.shape[0])
print("\nTrain relevance counts:\n", train_df["relevance"].value_counts())

# Down-sample (seeded) only when a split is larger than its cap.
train_df = (
    train_df.sample(n=MAX_TRAIN, random_state=7)
    if len(train_df) > MAX_TRAIN
    else train_df
)
val_df = (
    val_df.sample(n=MAX_VAL, random_state=7)
    if len(val_df) > MAX_VAL
    else val_df
)

print("\nAfter sampling:")
print("Train size:", train_df.shape[0])
print("Val size  :", val_df.shape[0])

Before sampling:
Train size: 384000
Val size  : 48000

Train relevance counts:
 relevance
0    192000
2     72000
1     72000
3     48000
Name: count, dtype: int64

After sampling:
Train size: 250000
Val size  : 48000


In [8]:
def tokenize(text: str):
    """Lower-case ``text`` and split it into maximal runs of ``a-z`` / ``0-9``.

    The input is passed through ``str()`` first, so non-string values are
    tokenized by their string form. Every other character is a separator.
    """
    lowered = str(text).lower()
    return [match.group(0) for match in re.finditer(r"[a-z0-9]+", lowered)]

def encode_pair_cross(query_text: str, deal_text: str, max_len: int = 96):
    """Encode a (query, deal) pair as ``query ids + <SEP> + deal ids + padding``.

    Args:
        query_text: Raw query string.
        deal_text: Raw deal/property description string.
        max_len: Target sequence length. 45% of it (rounded down) is the query
            budget, one position is kept for ``<SEP>``, the rest goes to the deal.

    Returns:
        ``np.int64`` array of token ids, right-padded with ``PAD_ID`` up to
        ``max_len`` and clipped into ``[0, len(cross_vocab) - 1]``.
    """
    def lookup(text):
        # Words missing from the vocabulary fall back to the <UNK> id.
        return [cross_vocab.get(token, UNK_ID) for token in tokenize(text)]

    query_budget = int(max_len * 0.45)
    deal_budget = max_len - query_budget - 1  # one slot is held back for <SEP>

    sequence = [
        *lookup(query_text)[:query_budget],
        SEP_ID,
        *lookup(deal_text)[:deal_budget],
    ]
    # A non-positive repeat count yields an empty list, so this only ever pads.
    sequence.extend([PAD_ID] * (max_len - len(sequence)))

    encoded = np.asarray(sequence, dtype=np.int64)

    # Guard rail: never hand the embedding layer an index outside the vocabulary.
    return encoded.clip(0, len(cross_vocab) - 1)

In [9]:
class CrossEncoderDataset(Dataset):
    """Map-style dataset yielding ``(token_ids, relevance)`` per interaction row.

    Args:
        df: DataFrame with ``query_id``, ``property_id`` and ``relevance``
            columns; its index is reset so positions run ``0..len-1``.
        query_text_map: Mapping ``query_id -> query text``.
        deal_text_map: Mapping ``property_id -> deal text``.
        max_len: Sequence length forwarded to ``encode_pair_cross``.

    Note:
        The three columns are snapshotted as NumPy arrays at construction time,
        so later edits to ``self.df`` are not reflected in the samples.
    """

    def __init__(self, df, query_text_map, deal_text_map, max_len=96):
        self.df = df.reset_index(drop=True)
        self.qmap = query_text_map
        self.dmap = deal_text_map
        self.max_len = max_len

        # ``DataFrame.iloc[idx]`` builds a fresh Series on every call, which is
        # paid once per sample per epoch. Extract the columns a single time and
        # index plain arrays in ``__getitem__`` instead.
        self._query_ids = self.df["query_id"].to_numpy()
        self._property_ids = self.df["property_id"].to_numpy()
        self._labels = self.df["relevance"].to_numpy()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(LongTensor[max_len], LongTensor[])`` for row ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range.
            KeyError: if the row's query or property id has no text mapping.
        """
        qid = int(self._query_ids[idx])
        pid = int(self._property_ids[idx])
        label = int(self._labels[idx])

        token_ids = encode_pair_cross(self.qmap[qid], self.dmap[pid], max_len=self.max_len)
        return (
            torch.tensor(token_ids, dtype=torch.long),
            torch.tensor(label, dtype=torch.long),
        )

print("Dataset class ready âœ…")

Dataset class ready âœ…


In [10]:
class CrossEncoder(nn.Module):
    """Embed tokens, mean-pool over non-pad positions, classify with a small MLP.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Embedding width.
        hidden: Width of the MLP's hidden layer.
        pad_id: Token id treated as padding (excluded from the pooled mean).
    """

    def __init__(self, vocab_size, emb_dim=128, hidden=128, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        self.emb = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=emb_dim, padding_idx=pad_id
        )
        head_layers = [
            nn.Linear(emb_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 4),  # one logit per class 0-3
        ]
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, token_ids):
        """Map ``(B, L)`` token ids to ``(B, 4)`` class logits."""
        embedded = self.emb(token_ids)  # (B, L, D)

        # 1.0 where a real token sits, 0.0 on padding; trailing dim broadcasts over D.
        keep = token_ids.ne(self.pad_id).unsqueeze(-1).to(torch.float32)

        # Clamp the count so an all-padding row divides by 1 rather than 0.
        token_count = keep.sum(dim=1).clamp(min=1.0)
        summed = (embedded * keep).sum(dim=1)

        return self.mlp(summed / token_count)

# The embedding table must cover every id in cross_vocab, including <SEP>.
model = CrossEncoder(len(cross_vocab), emb_dim=128, hidden=128, pad_id=PAD_ID)
print("CrossEncoder created âœ… | vocab_size:", model.emb.num_embeddings)

CrossEncoder created âœ… | vocab_size: 18146


In [11]:
# Same dataset/batch settings for both splits; only the training loader reshuffles.
train_loader, val_loader = (
    DataLoader(
        CrossEncoderDataset(split_df, query_text_map, deal_text_map, max_len=96),
        batch_size=256,
        shuffle=reshuffle,
    )
    for split_df, reshuffle in ((train_df, True), (val_df, False))
)

# Smoke test: pull one batch and confirm shapes and that ids fit the vocabulary.
Xb, yb = next(iter(train_loader))
print("Batch X:", Xb.shape, "Batch y:", yb.shape)
print("Token min/max:", int(Xb.min()), int(Xb.max()), " | vocab_size:", len(cross_vocab))

Batch X: torch.Size([256, 96]) Batch y: torch.Size([256])
Token min/max: 0 18145  | vocab_size: 18146


In [12]:
# Cross-entropy over the model's four output logits (classes 0-3), optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

def run_epoch(model, loader, train=True, loss_fn=None, optim=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: Module mapping a batch ``X`` to class logits.
        loader: Iterable of ``(X, y)`` batches.
        train: If truthy, put the model in training mode and take an optimizer
            step per batch; otherwise evaluate with gradients disabled.
        loss_fn: Loss callable ``(logits, y) -> scalar``. Defaults to the
            notebook-level ``criterion``. It is expected to return the batch
            *mean* (the ``nn.CrossEntropyLoss`` default), because the epoch
            loss is rebuilt by weighting each batch loss by its batch size.
        optim: Optimizer used when ``train`` is truthy. Defaults to the
            notebook-level ``optimizer``.

    Returns:
        ``(loss, accuracy)`` where ``loss`` is the per-sample mean over the
        whole pass and ``accuracy`` is the fraction of correct argmax
        predictions.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    train = bool(train)
    if loss_fn is None:
        loss_fn = criterion
    if train and optim is None:
        optim = optimizer

    model.train(train)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for X, y in loader:
        # Only build the autograd graph when a backward pass will follow.
        with torch.set_grad_enabled(train):
            logits = model(X)
            loss = loss_fn(logits, y)

        if train:
            optim.zero_grad()
            loss.backward()
            optim.step()

        batch_size = y.size(0)
        # Weight by batch size: averaging per-batch means would over-weight a
        # short final batch and skew the metric used for early stopping.
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == y).sum().item()
        n_seen += batch_size

    if n_seen == 0:
        raise ValueError("run_epoch received an empty loader; nothing to train or evaluate on")

    return loss_sum / n_seen, n_correct / n_seen

# Make sure the checkpoint directory exists before the first save.
os.makedirs(CKPT_PATH, exist_ok=True)

EPOCHS = 10
patience = 2  # consecutive non-improving epochs tolerated before stopping
best_val_loss = float("inf")
bad_epochs = 0

for epoch in range(EPOCHS):
    tr_loss, tr_acc = run_epoch(model, train_loader, train=True)
    va_loss, va_acc = run_epoch(model, val_loader, train=False)

    print(
        f"Epoch {epoch+1}: "
        f"train loss={tr_loss:.4f} acc={tr_acc:.4f} | "
        f"val loss={va_loss:.4f} acc={va_acc:.4f}"
    )

    # An epoch only counts as progress if it beats the best val loss by more than 1e-4.
    if not (va_loss < best_val_loss - 1e-4):
        bad_epochs += 1
        if bad_epochs >= patience:
            print("  ðŸ›‘ early stopping")
            break
        continue

    # New best: reset the patience counter and overwrite the checkpoint.
    best_val_loss, bad_epochs = va_loss, 0
    torch.save(model.state_dict(), f"{CKPT_PATH}/cross_encoder_best.pt")
    print("  âœ… saved cross_encoder_best.pt")

print("Training finished âœ…")

Epoch 1: train loss=1.2302 acc=0.5002 | val loss=1.2198 acc=0.5001
  âœ… saved cross_encoder_best.pt
Epoch 2: train loss=1.1975 acc=0.5033 | val loss=1.1860 acc=0.5050
  âœ… saved cross_encoder_best.pt
Epoch 3: train loss=1.1632 acc=0.5115 | val loss=1.1780 acc=0.5074
  âœ… saved cross_encoder_best.pt
Epoch 4: train loss=1.1449 acc=0.5171 | val loss=1.1698 acc=0.5097
  âœ… saved cross_encoder_best.pt
Epoch 5: train loss=1.1329 acc=0.5216 | val loss=1.1673 acc=0.5113
  âœ… saved cross_encoder_best.pt
Epoch 6: train loss=1.1235 acc=0.5254 | val loss=1.1674 acc=0.5125
Epoch 7: train loss=1.1150 acc=0.5287 | val loss=1.1674 acc=0.5124
  ðŸ›‘ early stopping
Training finished âœ…


In [13]:
# Persist the extended vocabulary (with <SEP>) next to the model checkpoint.
with open(f"{CKPT_PATH}/cross_vocab_v1.json", "w") as f:
    f.write(json.dumps(cross_vocab))

print("Saved cross_vocab_v1.json âœ…")

Saved cross_vocab_v1.json âœ…
