In [2]:
import torch
import torch.nn as nn
from torch.nn.utils import parametrizations
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingWarmRestarts
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    roc_auc_score,
    confusion_matrix,
)

In [51]:
class Embedding(nn.Module):
    def __init__(self, in_dim, n):
        super().__init__()
        self.embedding = torch.nn.Sequential(
            nn.Linear(in_dim, 4*n),
            nn.Hardswish(),
            nn.Linear(4*n, 2*n),
            nn.Hardswish(),
            nn.Linear(2*n, n)
        )
    def forward(self, x):
        return self.embedding(x)
class PreferenceEmbedding(nn.Module):
    def __init__(self, in_dim):
        """
        Args:
          n (int): dimension of input vectors x, x'
          J (Tensor[n,n]): constant matrix in the bilinear form
        """
        super().__init__()
        n = in_dim * 4 #embeddings dimension must be even
        J = torch.eye(n//2).kron(torch.tensor([[0, -1], [1, 0]]))
        # 1) Declare an unconstrained weight
        self.embedding = Embedding(in_dim, n)
        self.U = nn.Parameter(torch.randn(n, n))
        # 2) Apply an orthogonal re-parametrization
        parametrizations.orthogonal(self, 'U')
        # 3) Store J as a buffer
        self.register_buffer('J', J)

    def forward(self, x, x_prime):
        """Confirms whether x_prime is preferred to x."""
        # x, x_prime: [B, in_dim]
        x       = self.embedding(x)       # [B, 2*in_dim]
        x_prime = self.embedding(x_prime)   # [B, 2*in_dim]

        # U: [n, n] where n = 2*in_dim
        # apply U to each row of x and x_prime:
        Ux   = x       @ self.U.T          # [B, n]
        Uxp  = x_prime @ self.U.T          # [B, n]

        # compute batch of scalars w[i] = Uxp[i]^T J Ux[i]
        # J is [n, n]
        JUx = Ux @ self.J                     # [B, n]
        w   = (JUx * Uxp).sum(dim=1)          # [B]
        return torch.sigmoid(w)               # [B]




In [52]:
# 2) Quadrant‐comparison dataset
class QuadPairDataset(Dataset):
    def __init__(self, num_pairs):
        self.num_pairs = num_pairs
        # sample points once; you could re-sample each epoch if you like
        self.X = torch.randn(num_pairs, 2)
        self.Y = torch.randn(num_pairs, 2)
        # precompute labels
        self.labels = torch.zeros(num_pairs)
        for i in range(num_pairs):
            qx  = self.quadrant(self.X[i])
            qy  = self.quadrant(self.Y[i])
            # define cycle Q1>Q2>Q3>Q4>Q1
            # map: Q1=0, Q2=1, Q3=2, Q4=3
            #delta = (qy - qx) % 4
            # x > y if delta in {1,2}? Actually we want Qx > Qy if
            # moving from x to y you go forward less than 2 steps
            self.labels[i] = 1.0 if qx==3 and qy ==0 else qx < qy

    @staticmethod
    def quadrant(pt):
        x, y = pt
        if   x>=0 and y>=0: return 0   # Q1
        elif x < 0 <= y: return 1   # Q2
        elif x<0  and y<0 : return 2   # Q3
        else:                return 3   # Q4

    def __len__(self):
        return self.num_pairs

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx], self.labels[idx]

import torch
from torch.utils.data import Dataset

class ReshuffleQuadDataset(Dataset):
    def __init__(self, num_pairs, pool_size=10000, seed=None):
        """
        num_pairs:   nominal __len__ of the dataset (how many pairs you draw per epoch)
        pool_size:   how many base points to keep around
        seed:        optional RNG seed for reproducibility
        """
        super().__init__()
        self.num_pairs = num_pairs
        self.pool_size = pool_size
        if seed is not None:
            torch.manual_seed(seed)

        # sample a fixed pool of points once
        # shape: [pool_size, 2]
        self.points = torch.randn(pool_size, 2)

    @staticmethod
    def quadrant(pt):
        x, y = pt
        if   x >= 0 and y >= 0: return 0   # Q1
        elif x <  0 and y >= 0: return 1   # Q2
        elif x <  0 and y <  0: return 2   # Q3
        else:                   return 3   # Q4

    @staticmethod
    def label_from_quadrants(qx, qy):
        # your cycle Q1>Q2>Q3>Q4>Q1
        # return 1.0 if x > y else 0.0
        if qx == 3 and qy == 0:
            return 1.0
        return 1.0 if qx < qy else 0.0

    def __len__(self):
        return self.num_pairs

    def __getitem__(self, idx):
        # ignore idx — we just draw a random pair each call
        i = torch.randint(0, self.pool_size, (1,)).item()
        j = torch.randint(0, self.pool_size, (1,)).item()
        # if you prefer no repeats, you can:
        #  while j == i:
        #      j = torch.randint(0, self.pool_size, (1,)).item()

        x       = self.points[i]
        x_prime = self.points[j]

        qx = self.quadrant(x)
        qy = self.quadrant(x_prime)
        label = self.label_from_quadrants(qx, qy)

        return x, x_prime, label

class FuncDataset(Dataset):
    def __init__(self, num_pairs, func, minimize = True, in_dim = 2, seed=None):
        self.num_pairs = num_pairs
        # sample points once; you could re-sample each epoch if you like
        if seed is not None:
            torch.manual_seed(seed)
        self.X = torch.randn(num_pairs, in_dim)
        self.Y = torch.randn(num_pairs, in_dim)
        # precompute labels
        self.labels = torch.zeros(num_pairs)
        self.func = func
        for i in range(num_pairs):
            qx  = self.func(self.X[i])
            qy  = self.func(self.Y[i])
            # define cycle Q1>Q2>Q3>Q4>Q1
            # map: Q1=0, Q2=1, Q3=2, Q4=3
            #delta = (qy - qx) % 4
            # x > y if delta in {1,2}? Actually we want Qx > Qy if
            # moving from x to y you go forward less than 2 steps
            self.labels[i] = qx < qy if minimize else qx > qy
    def __len__(self):
        return self.num_pairs

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx], self.labels[idx]


In [53]:
def evaluate(model, test_loader, device='cpu'):
    model.eval()
    all_preds = []
    all_scores = []
    all_targets = []

    with torch.no_grad():
        for x, x_prime, y_true in test_loader:
            x, x_prime, y_true = x.to(device), x_prime.to(device), y_true.to(device)
            y_prob = model(x, x_prime)       # [B] float probabilities
            y_pred = (y_prob > 0.5).long()   # [B] binary predictions

            all_scores.append(y_prob.cpu())
            all_preds.append(y_pred.cpu())
            all_targets.append(y_true.cpu().long())

    # flatten lists
    scores  = torch.cat(all_scores).numpy()
    preds   = torch.cat(all_preds).numpy()
    targets = torch.cat(all_targets).numpy()

    # basic metrics
    acc = accuracy_score(targets, preds)
    prec, rec, f1, _ = precision_recall_fscore_support(
        targets, preds, average='binary'
    )
    cm = confusion_matrix(targets, preds)

    # optional: ROC-AUC
    try:
        auc = roc_auc_score(targets, scores)
    except ValueError:
        auc = float('nan')  # e.g. if one class missing

    print(f"Accuracy : {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall   : {rec:.4f}")
    print(f"F1-score : {f1:.4f}")
    print(f"ROC-AUC  : {auc:.4f}")
    print("Confusion Matrix:")
    print(cm)

    return {
        'accuracy': acc,
        'precision': prec,
        'recall': rec,
        'f1': f1,
        'roc_auc': auc,
        'confusion_matrix': cm,
    }

In [54]:
# 6) Model, optimizer, loss
model = PreferenceEmbedding(in_dim=2)

In [55]:
batch_size = 64
lr         = 1e-1
epochs     = 500
# 5) DataLoader
dataset = FuncDataset(num_pairs=50_000, func = lambda x: 100 * torch.abs(torch.sqrt(x[1] - 0.01*x[0])) + 0.01 * torch.abs(x[0] + 10))
loader  = DataLoader(dataset, batch_size=len(dataset), shuffle=True)
total_steps = epochs * len(loader)
warmup_steps = int(0.1 * total_steps)

#optimizer = torch.optim.LBFGS(model.parameters(), lr=lr, max_iter=50)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
def lr_lambda(step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    else:
        # cosine decay after warmup
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + torch.cos(torch.tensor(torch.pi * progress)))
scheduler = LambdaLR(optimizer, lr_lambda)
loss_fn = nn.BCELoss()
def closure():
    optimizer.zero_grad()
    preds = model(X, Xp)
    loss  = loss_fn(preds, Y)
    loss.backward()
    return loss


In [56]:
patience = 10
threshold = 1e-10
best_loss = torch.inf
bad_steps = 10
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda)
for epoch in range(1, epochs):
    # pull the entire dataset in one batch
    X, Xp, Y = next(iter(loader))
    #X, Xp, Y = X.to(device), Xp.to(device), Y.to(device)

    optimizer.zero_grad()
    preds = model(X, Xp)
    loss  = loss_fn(preds, Y)
    loss.backward()
    optimizer.step()
    scheduler.step()

    # recompute for logging
    with torch.no_grad():
        loss_val = loss_fn(model(X, Xp), Y).item()
        if loss_val < best_loss - threshold:
            best_loss = loss_val
            bad_steps = 0
        else:
            bad_steps += 1
    if bad_steps >= patience:
        pass
    if epoch % 10 == 0:
        print(f"Epoch {epoch:03d} — loss: {loss_val:.6f}")

Epoch 010 — loss: 0.682885
Epoch 020 — loss: 0.636989
Epoch 030 — loss: 0.613953
Epoch 040 — loss: 0.592428
Epoch 050 — loss: 0.649253
Epoch 060 — loss: 0.580834
Epoch 070 — loss: 0.556883
Epoch 080 — loss: 0.544193
Epoch 090 — loss: 0.542075
Epoch 100 — loss: 0.540675
Epoch 110 — loss: 0.534487
Epoch 120 — loss: 0.532277
Epoch 130 — loss: 0.551060
Epoch 140 — loss: 0.540675
Epoch 150 — loss: 0.534384
Epoch 160 — loss: 0.530531
Epoch 170 — loss: 0.529010
Epoch 180 — loss: 0.528269
Epoch 190 — loss: 0.527244
Epoch 200 — loss: 0.526681
Epoch 210 — loss: 0.526341
Epoch 220 — loss: 0.552621
Epoch 230 — loss: 1.139943
Epoch 240 — loss: 0.636734
Epoch 250 — loss: 0.576384
Epoch 260 — loss: 0.546818
Epoch 270 — loss: 0.536866
Epoch 280 — loss: 0.532781
Epoch 290 — loss: 0.529748
Epoch 300 — loss: 0.528576
Epoch 310 — loss: 0.528424
Epoch 320 — loss: 0.528308
Epoch 330 — loss: 0.529826
Epoch 340 — loss: 0.527622
Epoch 350 — loss: 0.527218
Epoch 360 — loss: 0.526680
Epoch 370 — loss: 0.526496
E

In [50]:
res = evaluate(model, loader)

Accuracy : 0.6343
Precision: 0.2581
Recall   : 0.9995
F1-score : 0.4103
ROC-AUC  : 0.9994
Confusion Matrix:
[[25353 18283]
 [    3  6361]]


In [67]:
torch.save(model.state_dict(), 'model.pt')

In [68]:
loss_val

0.2851998209953308