In [17]:
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

class RBFKernel(nn.Module):
    def __init__(self, input_dim, lengthscale_init=1.0):
        super().__init__()
        self.input_dim = input_dim
        # Use one lengthscale per dimension (ARD - Automatic Relevance Determination)
        self.log_lengthscale = nn.Parameter(torch.log(torch.ones(input_dim) * lengthscale_init))

    def forward(self, X1, X2):
        """
        Compute RBF (Gaussian) kernel matrix between X1 and X2 using ARD lengthscales.
        Args:
            X1: Tensor of shape (N1, D)
            X2: Tensor of shape (N2, D)
        Returns:
            Kernel matrix of shape (N1, N2)
        """
        # Ensure input is 2D
        if X1.ndimension() == 1:
            X1 = X1.unsqueeze(0)
        if X2.ndimension() == 1:
            X2 = X2.unsqueeze(0)

        # Scale by lengthscale (ARD: each dimension can have a different scale)
        X1_scaled = X1 / self.log_lengthscale.exp()
        X2_scaled = X2 / self.log_lengthscale.exp()

        # Compute squared Euclidean distance
        sqdist = torch.cdist(X1_scaled, X2_scaled, p=2).pow(2)

        # Compute RBF
        return torch.exp(-0.5 * sqdist)


In [18]:
class LMCKernel(nn.Module):
    def __init__(self, input_dim, num_outputs, num_latents, base_kernel_cls=RBFKernel):
        """
        input_dim: Dimension of input x
        num_outputs: Number of labels (T)
        num_latents: Number of latent processes (Q)
        base_kernel_cls: Class of base kernel, e.g., RBFKernel
        """
        super().__init__()
        self.num_outputs = num_outputs
        self.num_latents = num_latents

        # B ∈ ℝ^{T × Q}
        self.B = nn.Parameter(torch.randn(num_outputs, num_latents))

        # Each latent process has its own kernel
        self.kernels = nn.ModuleList([
            base_kernel_cls(input_dim) for _ in range(num_latents)
        ])

    def forward(self, X1, X2):
        """
        Returns: Tensor of shape (num_outputs, N, M)
        """
        N, M = X1.size(0), X2.size(0)

        # Compute all latent kernels
        latent_kernels = []  # Will be list of (N, M)
        for q in range(self.num_latents):
            k_q = self.kernels[q](X1, X2)  # (N, M)
            latent_kernels.append(k_q.unsqueeze(0))  # (1, N, M)

        K_latents = torch.cat(latent_kernels, dim=0)  # (Q, N, M)

        # K_output[t] = ∑_q B[t,q]^2 * K_q
        B = self.B  # (T, Q)
        K_output = []
        for t in range(self.num_outputs):
            weights = B[t]  # (Q,)
            weighted_k = (weights.view(-1, 1, 1) * K_latents).sum(dim=0)  # (N, M)
            K_output.append(weighted_k.unsqueeze(0))

        return torch.cat(K_output, dim=0)  # (T, N, M)

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MultivariateNormal, Normal

class MultiLabelSVGP(nn.Module):
    def __init__(self, input_dim, num_labels, num_inducing, num_latents=2):
        super().__init__()
        self.input_dim = input_dim
        self.num_labels = num_labels
        self.num_latents = num_latents
        self.num_inducing = num_inducing

        # Inducing inputs Z ∈ ℝ^{Q×M×D}
        self.Z = nn.Parameter(torch.randn(num_latents, num_inducing, input_dim))

        # Variational parameters: mean and L (lower-triangular) for each latent process
        self.q_mu = nn.Parameter(torch.zeros(num_latents, num_inducing))
        self.q_log_diag = nn.Parameter(torch.zeros(num_latents, num_inducing))  # log diag for stability

        # Kernel (shared across latents), coregionalization learned via B
        self.kernel = LMCKernel(input_dim, num_labels, num_latents)

    def compute_kernel_matrices(self, X):
        """
        Returns:
            Kxz: (Q, N, M)
            Kzz: (Q, M, M)
            Kxx_diag: (Q, N)
        """
        Kxz = []
        Kzz = []
        Kxx_diag = []

        for q in range(self.num_latents):
            k_q = self.kernel.kernels[q]
            Zq = self.Z[q]  # (M, D)

            Kxz_q = k_q(X, Zq)        # (N, M)
            Kzz_q = k_q(Zq, Zq) + 1e-6 * torch.eye(self.num_inducing)  # (M, M)
            Kxx_q_diag = torch.diagonal(k_q(X, X))  # (N,)

            Kxz.append(Kxz_q)
            Kzz.append(Kzz_q)
            Kxx_diag.append(Kxx_q_diag)

        return torch.stack(Kxz), torch.stack(Kzz), torch.stack(Kxx_diag)

    def forward(self, X, Y, full_n=None):
        """
        Compute negative ELBO for multi-label classification
        Inputs:
            X: (N, D)
            Y: (N, T)
        Returns:
            Negative ELBO loss
        """
        N = X.shape[0]
        Q = self.num_latents
        T = self.num_labels

        # Compute kernel matrices
        Kxz, Kzz, Kxx_diag = self.compute_kernel_matrices(X)  # Q×N×M, Q×M×M, Q×N

        # KL Divergence: KL[q(u) || p(u)] = sum_q KL[N(μ_q, S_q) || N(0, K_zz)]
        kl = 0
        for q in range(Q):
            mu_q = self.q_mu[q]  # (M,)
            S_q_diag = self.q_log_diag[q].exp().pow(2)  # diagonal covariance (M,)
            Kzz_q = Kzz[q]  # (M, M)

            L_q = torch.diag(S_q_diag)
            q_u = MultivariateNormal(mu_q, scale_tril=L_q)
            p_u = MultivariateNormal(torch.zeros_like(mu_q), covariance_matrix=Kzz_q)
            kl += torch.distributions.kl.kl_divergence(q_u, p_u)

        # Predictive mean for each latent
        f_q_means = []  # list of shape (N,) for each q
        for q in range(Q):
            Kxz_q = Kxz[q]  # (N, M)
            Kzz_q = Kzz[q]  # (M, M)
            Kzz_inv_q = torch.linalg.inv(Kzz_q)
            mu_q = self.q_mu[q]  # (M,)
            f_q = Kxz_q @ Kzz_inv_q @ mu_q  # (N,)
            f_q_means.append(f_q.unsqueeze(1))  # (N, 1)

        F_q = torch.cat(f_q_means, dim=1)  # (N, Q)

        # Project to T labels: f_t(x) = B_{tq} f_q(x)
        B = self.kernel.B  # (T, Q)
        F_pred = F_q @ B.T  # (N, T)

        # Binary classification loss: sigmoid cross-entropy
        likelihood = torch.nn.BCEWithLogitsLoss()
        recon_loss = likelihood(F_pred, Y)

        scale = N / (full_n if full_n is not None else N)
        return recon_loss * scale + kl
    
    @torch.no_grad()
    def predict(self, X):
        """
        Predict the sigmoid probabilities for each label.

        Inputs:
            X: (N, D) input data

        Returns:
            probs: (N, T) probability of each label being 1
        """
        self.eval()
        N = X.shape[0]
        Q = self.num_latents
        T = self.num_labels

        # Kernel computations
        Kxz, Kzz, _ = self.compute_kernel_matrices(X)  # Q×N×M, Q×M×M, Q×N

        f_q_means = []
        for q in range(Q):
            Kxz_q = Kxz[q]  # (N, M)
            Kzz_q = Kzz[q]  # (M, M)
            Kzz_inv_q = torch.linalg.inv(Kzz_q)
            mu_q = self.q_mu[q]  # (M,)
            f_q = Kxz_q @ Kzz_inv_q @ mu_q  # (N,)
            f_q_means.append(f_q.unsqueeze(1))  # (N, 1)

        F_q = torch.cat(f_q_means, dim=1)  # (N, Q)
        B = self.kernel.B  # (T, Q)
        F_pred = F_q @ B.T  # (N, T)

        probs = torch.sigmoid(F_pred)  # convert logits to probabilities
        return probs

In [20]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split

# Seed for reproducibility
torch.manual_seed(42)

# Toy multi-label dataset
class ToyMultiLabelDataset(Dataset):
    def __init__(self, n_samples=200, input_dim=2, num_labels=3):
        super().__init__()
        self.X = torch.randn(n_samples, input_dim)  # input_dim = 2
        self.Y = torch.zeros(n_samples, num_labels)  # num_labels = 3

        # Define labels with simple nonlinear logic
        self.Y[:, 0] = (self.X[:, 0] > 0).float()                   # Label 1: X1 > 0
        self.Y[:, 1] = (self.X[:, 1] > 0.5).float()                 # Label 2: X2 > 0.5
        self.Y[:, 2] = ((self.X[:, 0]**2 + self.X[:, 1]**2) > 1).float()  # Label 3: distance from origin > 1

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

# Create dataloader
dataset = ToyMultiLabelDataset(n_samples=300)

X_all = dataset.X
Y_all = dataset.Y

X_train, X_test, Y_train, Y_test = train_test_split(X_all, Y_all, test_size=0.2, random_state=42)

train_loader = DataLoader(list(zip(X_train, Y_train)), batch_size=32, shuffle=True)
test_loader = DataLoader(list(zip(X_test, Y_test)), batch_size=32, shuffle=False)


In [21]:
input_dim = 2
num_outputs = 3
num_latents = 2
num_inducing = 30

svgp_model = MultiLabelSVGP(input_dim, num_outputs, num_latents, num_inducing)
optimizer = torch.optim.Adam(svgp_model.parameters(), lr=0.1)

In [22]:
def train_svgp_multilabel(model, dataloader, optimizer, full_n, epochs=100):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        all_preds = []
        all_targets = []

        for X, Y in dataloader:
            optimizer.zero_grad()
            loss = model(X, Y, full_n)  
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            preds = model.predict(X)  # [B, L]
            all_preds.append(preds.cpu())
            all_targets.append(Y.cpu())

        Y_pred = torch.cat(all_preds).numpy()
        Y_true = torch.cat(all_targets).numpy()

        Y_pred_binary = (Y_pred >= 0.5).astype(int)

        # accuracy = accuracy_score(Y_true, Y_pred_binary)
        accuracy = np.mean((Y_true == Y_pred_binary).mean(axis=0))
        precision = precision_score(Y_true, Y_pred_binary, average="macro", zero_division=0)
        recall = recall_score(Y_true, Y_pred_binary, average="macro", zero_division=0)
        f1 = f1_score(Y_true, Y_pred_binary, average="macro", zero_division=0)

        avg_loss = total_loss / len(dataloader)

        if epoch % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")



In [23]:
@torch.no_grad()
def evaluate_svgp_multilabel(model, dataloader):
    model.eval()
    all_preds = []
    all_targets = []

    for X, Y in dataloader:
        preds = model.predict(X)  # [B, L]
        all_preds.append(preds.cpu())
        all_targets.append(Y.cpu())

    Y_pred = torch.cat(all_preds).numpy()
    Y_true = torch.cat(all_targets).numpy()

    # Convert probabilities to binary using threshold = 0.5
    Y_pred_binary = (Y_pred >= 0.5).astype(int)

    accuracy_score = np.mean((Y_true == Y_pred_binary).mean(axis=0))

    metrics = {
        # "accuracy": accuracy_score(Y_true, Y_pred_binary),
        "accuracy": accuracy_score,
        "precision": precision_score(Y_true, Y_pred_binary, average="macro", zero_division=0),
        "recall": recall_score(Y_true, Y_pred_binary, average="macro", zero_division=0),
        "f1_score": f1_score(Y_true, Y_pred_binary, average="macro", zero_division=0),
    }

    return metrics

In [24]:
train_svgp_multilabel(svgp_model, train_loader, optimizer, full_n=len(X_train), epochs=100)

evaluate_svgp_multilabel(svgp_model, test_loader)

Epoch 1, Loss: 3.4629, Accuracy: 0.4444, Precision: 0.4213, Recall: 0.4561, F1 Score: 0.4333
Epoch 11, Loss: 0.0857, Accuracy: 0.6875, Precision: 0.9015, Recall: 0.4995, F1 Score: 0.5334
Epoch 21, Loss: 0.0734, Accuracy: 0.8347, Precision: 0.9074, Recall: 0.7501, F1 Score: 0.8125
Epoch 31, Loss: 0.0615, Accuracy: 0.8917, Precision: 0.9445, Recall: 0.8381, F1 Score: 0.8865
Epoch 41, Loss: 0.0556, Accuracy: 0.9222, Precision: 0.9460, Recall: 0.9057, F1 Score: 0.9253
Epoch 51, Loss: 0.0503, Accuracy: 0.9222, Precision: 0.9481, Recall: 0.8983, F1 Score: 0.9213
Epoch 61, Loss: 0.0448, Accuracy: 0.9569, Precision: 0.9823, Recall: 0.9312, F1 Score: 0.9558
Epoch 71, Loss: 0.0420, Accuracy: 0.9597, Precision: 0.9762, Recall: 0.9389, F1 Score: 0.9571
Epoch 81, Loss: 0.0425, Accuracy: 0.9514, Precision: 0.9810, Recall: 0.9164, F1 Score: 0.9472
Epoch 91, Loss: 0.0388, Accuracy: 0.9722, Precision: 0.9851, Recall: 0.9575, F1 Score: 0.9711
Epoch 100, Loss: 0.0422, Accuracy: 0.9597, Precision: 0.9709,

{'accuracy': 0.9333333333333335,
 'precision': 0.9160419790104948,
 'recall': 0.9885057471264368,
 'f1_score': 0.9478553406223718}