In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

def multi_task_feature_learning(X_tasks, Y_tasks, d_shared, lr=1e-3, num_epochs=1000, lambda_reg=0.1):
    """
    Implements Multi-Task Feature Learning using PyTorch.

    Every task is modelled as ``Y_pred = (X @ W_shared) @ W_task[t]``: one linear
    projection ``W_shared`` is shared by all tasks and each task owns a small
    regression head. All parameters are trained jointly with Adam on the sum of
    the per-task MSE losses plus ``lambda_reg * ||W_shared||_F``.

    Parameters:
        X_tasks (list of torch.Tensor): Input data for each task. Each tensor has shape (n_samples, d_features).
            All tasks must share the same d_features; n_samples may differ per task.
        Y_tasks (list of torch.Tensor): Output data for each task. Each tensor has shape (n_samples, )
            or (n_samples, 1).
        d_shared (int): Dimensionality of the shared feature space.
        lr (float): Learning rate for the optimizer.
        num_epochs (int): Number of training epochs (one full-batch step per epoch).
        lambda_reg (float): Coefficient of the Frobenius-norm penalty on W_shared.

    Returns:
        W_shared (torch.nn.Parameter): Learned shared feature matrix, shape (d_features, d_shared).
        W_task (list of torch.nn.Parameter): Learned task-specific weight matrices, each of shape (d_shared, 1).
            The returned parameters still require grad; call ``.detach()`` for plain tensors.

    Raises:
        ValueError: If X_tasks and Y_tasks differ in length, if no task is given,
            if an input tensor is not 2-D, if tasks disagree on d_features, or if
            a task's targets do not contain exactly one value per sample.
    """
    # Check number of tasks
    num_tasks = len(X_tasks)
    if num_tasks != len(Y_tasks):
        raise ValueError("Number of tasks in X_tasks and Y_tasks must match.")
    if num_tasks == 0:
        raise ValueError("At least one task is required.")

    # Dimensionality of input features (taken from the first task)
    reference = X_tasks[0]
    if reference.dim() != 2:
        raise ValueError("Each tensor in X_tasks must have shape (n_samples, d_features).")
    d_features = reference.shape[1]

    # Validate every task up front and reshape targets to column vectors once,
    # instead of on every epoch. Without the sample-count check a single-element
    # Y would be silently broadcast against all predictions by MSELoss.
    Y_columns = []
    for task_idx, (X, Y) in enumerate(zip(X_tasks, Y_tasks)):
        if X.dim() != 2 or X.shape[1] != d_features:
            raise ValueError(
                f"Task {task_idx}: expected inputs of shape (n_samples, {d_features}), "
                f"got {tuple(X.shape)}."
            )
        Y_col = Y.reshape(-1, 1)  # Ensure Y is a column vector
        if Y_col.shape[0] != X.shape[0]:
            raise ValueError(
                f"Task {task_idx}: got {X.shape[0]} samples but {Y_col.shape[0]} target values."
            )
        Y_columns.append(Y_col)

    # Create the parameters with the inputs' dtype/device so float64 or GPU
    # data works; for default float32 CPU inputs this draws the same values
    # as an unqualified torch.randn call.
    param_dtype = reference.dtype if reference.is_floating_point() else torch.get_default_dtype()
    param_device = reference.device

    # Initialize shared feature matrix (d_features x d_shared)
    W_shared = nn.Parameter(
        torch.randn(d_features, d_shared, dtype=param_dtype, device=param_device) * 0.01
    )

    # Initialize task-specific weight matrices (d_shared x 1 per task)
    W_task = [
        nn.Parameter(torch.randn(d_shared, 1, dtype=param_dtype, device=param_device) * 0.01)
        for _ in range(num_tasks)
    ]

    # Optimizer
    optimizer = optim.Adam([W_shared] + W_task, lr=lr)

    # Loss function
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        total_loss = 0

        # Zero gradients
        optimizer.zero_grad()

        for X, Y, W_head in zip(X_tasks, Y_columns, W_task):
            # Forward pass: shared features and task-specific prediction
            Z = X @ W_shared  # Project input into shared space
            Y_pred = Z @ W_head  # Task-specific prediction

            # Accumulate task loss
            total_loss += criterion(Y_pred, Y)

        # Add regularization term for shared feature matrix
        # (torch.linalg.norm replaces the deprecated torch.norm)
        total_loss += lambda_reg * torch.linalg.norm(W_shared, ord="fro")

        # Backward pass and optimization
        total_loss.backward()
        optimizer.step()

        # Print loss every 100 epochs
        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss.item():.4f}")

    return W_shared, W_task

# Example usage
if __name__ == "__main__":
    # Seed the RNG so the synthetic data, and therefore the whole run, is repeatable.
    torch.manual_seed(42)

    # Task 1: linear targets from random ground-truth weights plus small Gaussian noise.
    # The draw order (inputs, weights, noise) determines the generated values.
    X_task1 = torch.randn(100, 10)
    W_true_task1 = torch.randn(10, 1)
    noise_task1 = 0.1 * torch.randn(100, 1)
    Y_task1 = X_task1 @ W_true_task1 + noise_task1

    # Task 2: same recipe with independent inputs, weights and noise.
    X_task2 = torch.randn(100, 10)
    W_true_task2 = torch.randn(10, 1)
    noise_task2 = 0.1 * torch.randn(100, 1)
    Y_task2 = X_task2 @ W_true_task2 + noise_task2

    # Bundle the per-task data in matching order.
    X_tasks, Y_tasks = [X_task1, X_task2], [Y_task1, Y_task2]

    # Fit a 5-dimensional shared representation on both tasks.
    W_shared, W_task = multi_task_feature_learning(
        X_tasks,
        Y_tasks,
        d_shared=5,
        lr=1e-2,
        num_epochs=1000,
        lambda_reg=0.1,
    )

    # Report the learned parameters.
    print("Learned shared feature matrix W_shared:")
    print(W_shared)
    print("Learned task-specific weight matrices W_task:")
    for task_number, weights in enumerate(W_task, start=1):
        print(f"Task {task_number} weights:")
        print(weights)


Epoch [100/1000], Loss: 0.5044
Epoch [200/1000], Loss: 0.3795
Epoch [300/1000], Loss: 0.3694
Epoch [400/1000], Loss: 0.3578
Epoch [500/1000], Loss: 0.3454
Epoch [600/1000], Loss: 0.3326
Epoch [700/1000], Loss: 0.3196
Epoch [800/1000], Loss: 0.3069
Epoch [900/1000], Loss: 0.2946
Epoch [1000/1000], Loss: 0.2829
Learned shared feature matrix W_shared:
Parameter containing:
tensor([[ 0.0050, -0.0863,  0.0885, -0.0899, -0.0873],
        [-0.5076, -0.0011,  0.0149,  0.0125, -0.0240],
        [ 0.0505,  0.0372, -0.0366,  0.0371,  0.0370],
        [-0.2631, -0.2566,  0.2592, -0.2586, -0.2607],
        [ 0.2818, -0.1807,  0.1798, -0.1965, -0.1695],
        [-0.7366, -0.1411,  0.1514, -0.1267, -0.1626],
        [ 0.9508,  0.4917, -0.5236,  0.4766,  0.5322],
        [ 0.4060,  0.5401, -0.5516,  0.5512,  0.5630],
        [-0.7428,  0.2510, -0.2543,  0.2862,  0.2321],
        [-0.8807,  0.3548, -0.3592,  0.3948,  0.3330]], requires_grad=True)
Learned task-specific weight matrices W_task:
Task 1 wei