# Import Libraries

In [21]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt

# Kernels
## RBF Kernel

In [22]:
# def rbf_kernel(x1, x2, length_scale=1.0, variance=1.0):
#     """
#     Compute the RBF kernel between two sets of points.
    
#     Args:
#         x1: First set of points (N x D).
#         x2: Second set of points (M x D).
#         length_scale: Length scale parameter for the RBF kernel.
    
#     Returns:
#         Kernel matrix (N x M).
#     """
#     dists = torch.cdist(x1, x2, p=2)
#     kernel_matrix = variance * torch.exp(-0.5 * (dists / length_scale) ** 2)
#     return kernel_matrix

class RBFKernel(nn.Module):
    """Squared-exponential (RBF) kernel with learnable length scale and variance.

    k(x1, x2) = variance * exp(-0.5 * ||x1 - x2||^2 / length_scale^2)

    Args:
        length_scale: initial length scale (a Python number or tensor).
        variance: initial signal variance (a Python number or tensor).
    """

    def __init__(self, length_scale=1.0, variance=1.0):
        super().__init__()
        # Force a floating-point dtype: torch.tensor(1) would be int64, and an
        # integer nn.Parameter cannot require gradients (construction raises).
        dtype = torch.get_default_dtype()
        self.length_scale = nn.Parameter(torch.tensor(length_scale, dtype=dtype))
        self.variance = nn.Parameter(torch.tensor(variance, dtype=dtype))

    def forward(self, x1, x2):
        """Return the (N, M) kernel matrix between rows of x1 (N, D) and x2 (M, D)."""
        dists = torch.cdist(x1, x2)
        return self.variance * torch.exp(-0.5 * (dists / self.length_scale) ** 2)

## Coregionalize Kernel

In [23]:
# class CoregionalizeKernel:
#     def __init__(self, base_kernel, num_tasks, rank_R, scale=0.1, min_value=1e-3, max_value=5e-2, seed=42):
#         self.base_kernel = base_kernel
#         self.num_tasks = num_tasks
#         self.rank_R = rank_R

#         if seed is not None:
#             torch.manual_seed(seed)
#             np.random.seed(seed)
#             torch.random.manual_seed(seed)

#         self.W = nn.Parameter(scale * torch.randn(num_tasks, rank_R))
#         v = (max_value - min_value) * torch.rand(num_tasks) + min_value 
#         self.v = nn.Parameter(v)

#     def compute(self, X1, task1, X2, task2): 
#         """
#         X1: (N1, D) input points
#         task1: (N1,) task indices for X1
#         X2: (N2, D) input points
#         task2: (N2,) task indices for X2
#         """
#         K_input = self.base_kernel(X1, X2)
#         B = self.W @ self.W.T + torch.diag(self.v)

#         # Task correlation part 
#         B_tasks = B[task1, :][:, task2] # (n1 x n2)

#         # Final kernel
#         K = K_input * B_tasks
#         return K

class CoregionalizeKernel(nn.Module):
    """Intrinsic coregionalization kernel over (input, task) pairs.

    k((x, t), (x', t')) = base_kernel(x, x') * B[t, t'],  with
    B = W @ W.T + diag(v), a (num_tasks x num_tasks) task covariance.

    Args:
        base_kernel: callable/module mapping (X1 (N1, D), X2 (N2, D)) to an
            (N1, N2) input covariance.
        num_tasks: number of tasks (outputs); task ids must be in [0, num_tasks).
        rank_R: rank of the low-rank factor W.
    """

    def __init__(self, base_kernel, num_tasks, rank_R):
        super().__init__()
        self.base_kernel = base_kernel
        # Kept as attributes so callers (e.g. plotting code) can query them.
        self.num_tasks = num_tasks
        self.rank_R = rank_R
        self.W = nn.Parameter(0.1 * torch.randn(num_tasks, rank_R))
        # NOTE(review): v is unconstrained; if training drives an entry
        # negative, B may stop being positive semi-definite.
        self.v = nn.Parameter(1e-2 * torch.ones(num_tasks))

    def forward(self, X1, task1, X2, task2):
        """Return the (N1, N2) covariance between (X1, task1) and (X2, task2).

        task1/task2 may be any numeric dtype; they are cast to long for indexing.
        """
        task1 = task1.to(torch.long)
        task2 = task2.to(torch.long)

        K_input = self.base_kernel(X1, X2)
        B = self.W @ self.W.T + torch.diag(self.v)
        # Broadcasted advanced indexing picks B[task1[i], task2[j]] in one step.
        B_tasks = B[task1.unsqueeze(1), task2.unsqueeze(0)]
        return K_input * B_tasks


# Multi-output GP Classification 

In [24]:
# class MultiOutputGPClassifier(nn.Module):
#     def __init__(self, kernel, noise=1e-6, max_iter=20, device=None):
#         """
#         coregionalization_kernel: an instance with a `.compute(X1, task1, X2, task2)` method returning a torch.Tensor
#         noise: float, jitter for numerical stability
#         max_iter: int, Laplace iterations
#         device: 'cuda' or 'cpu'
#         """
#         self.kernel = kernel
#         self.noise = noise
#         self.max_iter = max_iter
#         self.is_trained = False
#         self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

#     def fit(self, X_train, task_train, y_train):
#         """
#         X_train: (n x d) torch.Tensor
#         task_train: (n,) torch.LongTensor
#         y_train: (n,) torch.FloatTensor (binary: 0 or 1)
#         """
#         X_train = X_train.to(self.device)
#         task_train = task_train.to(self.device)
#         y_train = y_train.to(self.device)

#         n = X_train.shape[0]

#         # Compute kernel matrix
#         K = self.kernel.compute(X_train, task_train, X_train, task_train)  # (n x n)
#         K += self.noise * torch.eye(n, device=self.device)

#         # Initialize latent function
#         f = torch.zeros(n, device=self.device)

#         for iter in range(self.max_iter):
#             pi = torch.sigmoid(f)
#             W_diag = pi * (1 - pi)
#             sqrt_W = torch.sqrt(W_diag + 1e-9)

#             B = torch.eye(n, device=self.device) + sqrt_W[:, None] * K * sqrt_W[None, :]
#             B += 1e-5 * torch.eye(B.shape[0], device=self.device)
#             L = torch.linalg.cholesky(B)

#             b = W_diag * f + (y_train - pi)
#             # Solve: a = b - sqrt_W * solve(L, solve(L.T, sqrt_W * (K @ b)))
#             temp = torch.cholesky_solve((sqrt_W[:, None] * (K @ b).unsqueeze(1)), L)
#             a = b - sqrt_W * temp.squeeze(1)
#             f = K @ a

#             if iter % 5 == 0:
#                 ll = torch.sum(y_train * torch.log(pi + 1e-6) + (1 - y_train) * torch.log(1 - pi + 1e-6))
#                 print(f"Iteration {iter}: Bernoulli Log-Likelihood = {ll.item():.4f}")

#         self.f_hat = f
#         self.X_train = X_train
#         self.task_train = task_train
#         self.W_diag = W_diag
#         self.L = L
#         self.y_train = y_train
#         self.is_trained = True

#     def compute_nll(self, X, task_ids, y):
#         X = X.to(self.device)
#         task_ids = task_ids.to(self.device)
#         y = y.to(self.device)

#         n = X.shape[0]
#         K = self.kernel(X, task_ids, X, task_ids)
#         K += self.noise * torch.eye(n, device=self.device)

#         f = torch.zeros(n, device=self.device)  # Initial latent function

#         for _ in range(self.max_iter):
#             pi = torch.sigmoid(f)
#             W_diag = pi * (1 - pi)
#             sqrt_W = torch.sqrt(W_diag + 1e-9)

#             B = torch.eye(n, device=self.device) + sqrt_W[:, None] * K * sqrt_W[None, :]
#             L = torch.linalg.cholesky(B)

#             b = W_diag * f + (y - pi)
#             temp = torch.cholesky_solve((sqrt_W[:, None] * (K @ b).unsqueeze(1)), L)
#             a = b - sqrt_W * temp.squeeze(1)
#             f = K @ a

#         # Save variables for prediction
#         self.f_hat = f.detach()
#         self.X_train = X.detach()
#         self.task_train = task_ids.detach()
#         self.W_diag = W_diag.detach()
#         self.L = L.detach()
#         self.y_train = y.detach()

#         # Negative Log Likelihood
#         pi = torch.sigmoid(f)
#         nll = -torch.sum(y * torch.log(pi + 1e-6) + (1 - y) * torch.log(1 - pi + 1e-6))

#         return nll

#     def predict(self, X_test, task_test):
#         if not self.is_trained:
#             raise RuntimeError("Train the model first!")

#         X_test = X_test.to(self.device)
#         task_test = task_test.to(self.device)

#         K_s = self.kernel.compute(self.X_train, self.task_train, X_test, task_test)  # (n_train x n_test)

#         f_mean = K_s.T @ (self.y_train - torch.sigmoid(self.f_hat))  # (n_test,)

#         sqrt_W = torch.sqrt(self.W_diag + 1e-9)
#         v = torch.cholesky_solve((sqrt_W[:, None] * K_s), self.L)
#         K_test = self.kernel.compute(X_test, task_test, X_test, task_test)
#         f_var = torch.diagonal(K_test) - torch.sum(v**2, dim=0)
#         f_var = torch.clamp(f_var, min=1e-6)

#         gamma = 1.0 / torch.sqrt(1.0 + (np.pi * f_var) / 8.0)
#         probs = torch.sigmoid(gamma * f_mean)

#         return probs.cpu()  # move back to CPU if needed


In [25]:
class MultiOutputGPClassifier(nn.Module):
    """Binary multi-output GP classifier fitted with a Laplace approximation.

    The latent function ``f`` has a zero-mean GP prior whose covariance is
    ``kernel(X1, task1, X2, task2)`` plus ``noise`` on the diagonal. Its
    posterior mode under a logistic (sigmoid) Bernoulli likelihood is found
    with Newton iterations (Rasmussen & Williams, GPML Algorithm 3.1).

    Args:
        kernel: module/callable returning the (n1, n2) covariance between
            ``(X1, task1)`` and ``(X2, task2)``.
        noise: diagonal jitter added to the training covariance.
        max_iter: number of Newton iterations used to locate the mode.
        device: device for inputs and parameters; defaults to CUDA when
            available, otherwise CPU.
    """

    def __init__(self, kernel, noise=1e-6, max_iter=20, device=None):
        super().__init__()
        self.kernel = kernel
        self.noise = noise
        self.max_iter = max_iter
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Populated by compute_nll(); predict() refuses to run until then.
        self.f_hat = None
        # compute_nll()/predict() move their inputs to self.device, so the
        # kernel parameters must live there too, otherwise every kernel call
        # fails with a device mismatch on CUDA machines.
        self.to(self.device)

    def _prior_covariance(self, X, task_ids):
        """Return ``kernel(X, X) + noise * I`` for a single set of inputs."""
        K = self.kernel(X, task_ids, X, task_ids)
        eye = torch.eye(X.shape[0], device=K.device, dtype=K.dtype)
        return K + self.noise * eye

    def compute_nll(self, X, task_ids, y):
        """Locate the Laplace mode for (X, task_ids, y) and return the NLL there.

        The returned value is the Bernoulli negative log-likelihood
        ``-sum_i log p(y_i | f_hat_i)``. It is differentiable w.r.t. the kernel
        parameters through the unrolled Newton iterations. It is NOT the
        Laplace-approximate marginal likelihood.

        Side effect: caches detached ``f_hat``, ``X_train``, ``task_train``,
        ``W_diag``, ``L`` and ``y_train`` for use by ``predict``.

        Args:
            X: (n, d) inputs.
            task_ids: (n,) integer task ids.
            y: (n,) binary labels in {0, 1}.

        Returns:
            Scalar tensor with the summed negative log-likelihood.
        """
        X = X.to(self.device)
        task_ids = task_ids.to(self.device)

        K = self._prior_covariance(X, task_ids)
        n = K.shape[0]
        # Match the kernel's float dtype so integer labels work too.
        y = y.to(device=K.device, dtype=K.dtype)
        eye = torch.eye(n, device=K.device, dtype=K.dtype)

        f = torch.zeros(n, device=K.device, dtype=K.dtype)  # start at the prior mean
        for _ in range(self.max_iter):
            pi = torch.sigmoid(f)
            W_diag = pi * (1 - pi)  # diagonal of -Hessian of log p(y | f)
            sqrt_W = torch.sqrt(W_diag + 1e-9)

            # B = I + W^1/2 K W^1/2 has eigenvalues >= 1 for a PSD K, so its
            # Cholesky factor is well conditioned.
            B = eye + sqrt_W[:, None] * K * sqrt_W[None, :]
            L = torch.linalg.cholesky(B)

            b = W_diag * f + (y - pi)
            temp = torch.cholesky_solve((sqrt_W * (K @ b)).unsqueeze(1), L)
            a = b - sqrt_W * temp.squeeze(1)
            f = K @ a

        # Re-evaluate W and L at the final mode so the cached state matches
        # f_hat (and is defined even when max_iter == 0).
        with torch.no_grad():
            pi_hat = torch.sigmoid(f)
            W_hat = pi_hat * (1 - pi_hat)
            sqrt_W_hat = torch.sqrt(W_hat + 1e-9)
            L_hat = torch.linalg.cholesky(eye + sqrt_W_hat[:, None] * K * sqrt_W_hat[None, :])

        # Save variables for prediction
        self.f_hat = f.detach()
        self.X_train = X.detach()
        self.task_train = task_ids.detach()
        self.W_diag = W_hat
        self.L = L_hat
        self.y_train = y.detach()

        # Computed in log-space: no saturation at |f| >> 0 and no +eps bias inside log().
        return F.binary_cross_entropy_with_logits(f, y, reduction="sum")

    def predict(self, X_test, task_ids_test):
        """Return P(y* = 1) = sigmoid(E[f*]) for each test point.

        At the Laplace mode, (K + noise*I)^{-1} f_hat = y - sigmoid(f_hat), so
        the predictive mean is K_* (y - sigmoid(f_hat)) (GPML Algorithm 3.2).
        The predictive variance is not used (no probit averaging).

        Raises:
            RuntimeError: if ``compute_nll`` has not been called yet.
        """
        if getattr(self, "f_hat", None) is None:
            raise RuntimeError("Call compute_nll() (e.g. via train_model) before predict().")

        self.eval()
        X_test = X_test.to(self.device)
        task_ids_test = task_ids_test.to(self.device)

        # Cross-covariance between test and training points (no jitter).
        K_star = self.kernel(X_test, task_ids_test, self.X_train, self.task_train)

        f_mean = K_star @ (self.y_train - torch.sigmoid(self.f_hat))
        return torch.sigmoid(f_mean)


# Dataset Generation 

In [26]:
def dataset_generation(n_samples=100, n_tasks=3, seed=None):
    """
    Build a synthetic 2-D dataset for multi-output GP classification.

    Each sample gets a random task id; its binary label is produced by a
    task-specific linear rule on the two features:
      - task 0: x0 > 0
      - task 1: x1 > 0
      - any other task: x0 + x1 > 0

    Args:
        n_samples: How many samples to draw.
        n_tasks: Number of tasks; ids are drawn uniformly from [0, n_tasks).
        seed: Optional seed applied to torch and numpy for reproducibility.

    Returns:
        X: Input features (n_samples x 2), float tensor.
        y: Target labels (n_samples,), float tensor of 0/1.
        task_train: Task indices (n_samples,), long tensor.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Draw order (features first, then task ids) fixes the RNG stream.
    features = torch.randn(n_samples, 2)
    task_ids = torch.randint(0, n_tasks, (n_samples,))

    rule_first = features[:, 0] > 0
    rule_second = features[:, 1] > 0
    rule_sum = (features[:, 0] + features[:, 1]) > 0

    fallback = torch.where(task_ids == 1, rule_second, rule_sum)
    labels = torch.where(task_ids == 0, rule_first, fallback)

    return features, labels.float(), task_ids

# Main Experiment: Data Preparation and Model Setup

In [27]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score

num_tasks = 2

# Building the kernel draws from torch's global RNG (e.g. the random init of
# the coregionalization weights), so seed first to make runs reproducible.
# dataset_generation() reseeds itself, so the generated data is unaffected.
torch.manual_seed(42)

model = MultiOutputGPClassifier(
    kernel=CoregionalizeKernel(
        base_kernel=RBFKernel(length_scale=1.0, variance=1.0),
        num_tasks=num_tasks,
        rank_R=2,
    ),
    max_iter=20,
    noise=1e-5,
)

# 1. Generate data
X, y, task_train = dataset_generation(n_samples=300, n_tasks=num_tasks, seed=42)

# dataset_generation already returns tensors; as_tensor only enforces dtypes
# (torch.tensor(tensor) would copy and emit a UserWarning).
X = torch.as_tensor(X, dtype=torch.float32)
y = torch.as_tensor(y, dtype=torch.float32)
task_train = torch.as_tensor(task_train, dtype=torch.long)

# 2. Train/test split, stratified by task so every task appears in both splits
X_train, X_test, y_train, y_test, task_train_split, task_test_split = train_test_split(
    X, y, task_train, test_size=0.2, random_state=42, stratify=task_train
)

# 3./4. Training and evaluation are done in the following cells with
# train_model(...) and evaluate_model(...).


In [None]:
def train_model(model, X, task_ids, y, lr=0.01, epochs=100):
    """Fit the model's parameters by minimising ``model.compute_nll`` with Adam.

    Args:
        model: module exposing ``compute_nll(X, task_ids, y)`` returning a scalar loss.
        X: training inputs.
        task_ids: task index per training input.
        y: training labels.
        lr: Adam learning rate.
        epochs: number of full-batch optimisation steps.

    Returns:
        The same ``model`` instance, updated in place.
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for step in range(epochs):
        opt.zero_grad()
        nll = model.compute_nll(X, task_ids, y)
        nll.backward()
        opt.step()

        # Report progress on every 10th step.
        if not step % 10:
            print(f"Epoch {step} - Loss: {nll.item():.4f}")

    return model

# Optimise the model's parameters on the training split.
# NOTE(review): the recorded output below reaches "Epoch 190", which requires
# epochs >= 191; it appears to come from an earlier run and is stale for epochs=100.
trained_model = train_model(model, X_train, task_train_split, y_train, lr=0.01, epochs=100)

Epoch 0 - Loss: 149.9748
Epoch 10 - Loss: 99.6918
Epoch 20 - Loss: 77.5921
Epoch 30 - Loss: 65.0688
Epoch 40 - Loss: 57.0103
Epoch 50 - Loss: 51.3285
Epoch 60 - Loss: 47.0646
Epoch 70 - Loss: 43.7065
Epoch 80 - Loss: 40.9669
Epoch 90 - Loss: 38.6716
Epoch 100 - Loss: 36.7086
Epoch 110 - Loss: 35.0025
Epoch 120 - Loss: 33.5000
Epoch 130 - Loss: 32.1624
Epoch 140 - Loss: 30.9603
Epoch 150 - Loss: 29.8713
Epoch 160 - Loss: 28.8779
Epoch 170 - Loss: 27.9660
Epoch 180 - Loss: 27.1243
Epoch 190 - Loss: 26.3435


In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

def evaluate_model(model, X_test, task_ids_test, y_test, threshold=0.5):
    """Threshold the model's predicted probabilities and report binary metrics.

    Args:
        model: classifier exposing ``predict(X, task_ids)`` that returns P(y = 1).
        X_test: test inputs.
        task_ids_test: task index per test input.
        y_test: true 0/1 labels (tensor).
        threshold: probability at or above which a point is labelled 1.

    Returns:
        dict with "accuracy", "precision", "recall", "f1" and a 2x2
        "confusion_matrix" (rows = true label, columns = predicted; order [0, 1]).
    """
    model.eval()

    # Get predicted probabilities
    with torch.no_grad():
        probs = model.predict(X_test, task_ids_test)

    # Convert probabilities to class labels
    y_pred = (probs >= threshold).long()

    # Convert to NumPy with a common integer label type on both sides.
    y_true_np = y_test.detach().long().cpu().numpy()
    y_pred_np = y_pred.cpu().numpy()

    # Metrics
    acc = accuracy_score(y_true_np, y_pred_np)
    # zero_division=0: report 0 instead of warning when no positives are
    # predicted / present.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true_np, y_pred_np, average='binary', zero_division=0
    )
    # labels=[0, 1] keeps the matrix 2x2 even if only one class occurs.
    cm = confusion_matrix(y_true_np, y_pred_np, labels=[0, 1])

    print("📊 Evaluation Metrics:")
    print(f"  🔹 Accuracy:  {acc:.4f}")
    print(f"  🔹 Precision: {precision:.4f}")
    print(f"  🔹 Recall:    {recall:.4f}")
    print(f"  🔹 F1-score:  {f1:.4f}")
    print("\n🧮 Confusion Matrix:")
    print(cm)

    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "confusion_matrix": cm
    }
    
# Evaluate the trained model on the held-out test split.
# NOTE(review): the recorded output below ends with a tensor printout that
# evaluate_model does not print; that part of the output looks stale.
evaluation = evaluate_model(trained_model, X_test, task_test_split, y_test, threshold=0.5)

📊 Evaluation Metrics:
  🔹 Accuracy:  0.9833
  🔹 Precision: 1.0000
  🔹 Recall:    0.9600
  🔹 F1-score:  0.9796

🧮 Confusion Matrix:
[[35  0]
 [ 1 24]]
tensor([0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,
        1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0,
        1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1])


# Visualization

In [30]:
def plot_multi_output_gpc(model, X_train, task_train, y_train, X_test, task_test, y_test):
    """
    Visualize the predictions of the multi-output GP classifier, one subplot per task.

    Training points (circles) are coloured by their true label and test points
    (crosses) by the predicted probability P(y = 1). Both share a fixed 0..1
    colour scale, so colours are directly comparable.

    Args:
        model: Trained multi-output GP classifier (exposes ``predict`` and ``kernel``).
        X_train: Training input features (n_train x 2).
        task_train: Task indices for training data.
        y_train: Training target labels.
        X_test: Test input features (n_test x 2).
        task_test: Task indices for test data.
        y_test: Test target labels (accepted for API compatibility; not plotted).
    """
    with torch.no_grad():
        probs = model.predict(X_test, task_test)

    # matplotlib needs host-side arrays; predict() may return a CUDA tensor.
    X_tr = X_train.detach().cpu().numpy()
    t_tr = task_train.detach().cpu().numpy()
    y_tr = y_train.detach().cpu().numpy()
    X_te = X_test.detach().cpu().numpy()
    t_te = task_test.detach().cpu().numpy()
    p_te = probs.detach().cpu().numpy()

    # Use the kernel's num_tasks when it exposes one; otherwise infer the
    # task count from the largest task id present in the data.
    num_tasks = getattr(model.kernel, "num_tasks", None)
    if num_tasks is None:
        num_tasks = int(max(t_tr.max(initial=-1), t_te.max(initial=-1))) + 1
    num_tasks = max(int(num_tasks), 1)

    fig, axes = plt.subplots(1, num_tasks, figsize=(12, 6), squeeze=False)
    for i, ax in enumerate(axes[0]):
        train_mask = t_tr == i
        test_mask = t_te == i
        ax.scatter(X_tr[train_mask, 0], X_tr[train_mask, 1], c=y_tr[train_mask],
                   cmap='coolwarm', vmin=0.0, vmax=1.0, label='Train')
        test_points = ax.scatter(X_te[test_mask, 0], X_te[test_mask, 1], c=p_te[test_mask],
                                 cmap='coolwarm', vmin=0.0, vmax=1.0, marker='x',
                                 label='Test Predicted')
        ax.set_title(f'Task {i + 1}')
        ax.legend()
        fig.colorbar(test_points, ax=ax)

    fig.tight_layout()
    plt.show()
# plot_multi_output_gpc(model, X_train, task_train_split, y_train, X_test, task_test_split, y_test)