In [None]:
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error

In [None]:
class MultiHeadRegression(nn.Module):
    """Multi-output linear regressor with one independent ``nn.Linear`` head per target.

    Each head maps the ``input_size`` features to a single value; ``forward``
    concatenates the head outputs along dim 1, so predictions have one column
    per head.

    Args:
        input_size: Number of input features (last dimension of ``x``).
        num_heads: Number of regression heads, i.e. output columns.
    """

    def __init__(self, input_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # Kept as a ModuleList (rather than a single Linear(input_size, num_heads))
        # so existing state_dict keys ("heads.<i>.weight") remain compatible.
        self.heads = nn.ModuleList([nn.Linear(input_size, 1) for _ in range(num_heads)])

    def forward(self, x):
        """Return the head outputs concatenated column-wise.

        Expects a 2-D ``x`` of shape (n_samples, input_size); the result has
        shape (n_samples, num_heads).
        """
        return torch.cat([head(x) for head in self.heads], dim=1)

    @staticmethod
    def _coerce(data, dtype, device):
        """Convert tensors or array-likes (numpy, pandas, lists) to a tensor of ``dtype`` on ``device``."""
        if not isinstance(data, torch.Tensor):
            data = np.asarray(data)
        return torch.as_tensor(data, dtype=dtype, device=device)

    def _to_tensor(self, data):
        """Convert ``data`` to a tensor matching the model parameters' dtype and device."""
        param = next(self.parameters(), None)
        if param is None:  # no heads: fall back to torch's default float dtype on CPU
            return self._coerce(data, torch.get_default_dtype(), None)
        return self._coerce(data, param.dtype, param.device)

    @staticmethod
    def _to_numpy(data):
        """Detach tensors and move them to CPU as numpy arrays; pass other inputs through."""
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return data

    def predict(self, x):
        """Return predictions for ``x`` without tracking gradients.

        ``x`` may be a tensor or an array-like; it is converted to the model's
        dtype/device. The module is put in eval mode for the call, and its
        previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(self._to_tensor(x))
        finally:
            self.train(was_training)

    def loss(self, y_true, y_pred):
        """Mean-squared error between ``y_pred`` and ``y_true``.

        ``y_true`` is cast to ``y_pred``'s dtype and device. If it has the same
        number of elements but a different shape, it is reshaped to match
        ``y_pred``. One example is ``(N,)`` targets against the ``(N, 1)``
        output of a single head. Without this, broadcasting would silently
        average over an ``(N, N)`` error matrix; its warning is easy to miss
        when warnings are filtered.
        """
        y_true = self._coerce(y_true, y_pred.dtype, y_pred.device)
        if y_true.shape != y_pred.shape and y_true.numel() == y_pred.numel():
            y_true = y_true.reshape(y_pred.shape)
        return F.mse_loss(y_pred, y_true)

    def evaluate(self, y_true, y_pred):
        """Root-mean-squared error computed with sklearn's ``root_mean_squared_error``.

        Tensors are detached and moved to CPU first: sklearn cannot convert a
        tensor that requires grad, such as the direct output of ``forward``.
        """
        return root_mean_squared_error(self._to_numpy(y_true), self._to_numpy(y_pred))

    def fit(self, X, y, epochs=10, lr=0.001, *, verbose=True):
        """Train all heads with full-batch Adam on the MSE loss.

        Args:
            X: Features of shape (n_samples, input_size); a tensor or array-like.
            y: Targets with one column per head; a 1-D vector is accepted when
                there is a single head.
            epochs: Number of full-batch optimisation steps.
            lr: Adam learning rate.
            verbose: If True, print the loss at each epoch (computed before that
                epoch's parameter update).

        Returns:
            ``self``, to allow chaining.
        """
        X = self._to_tensor(X)
        y = self._to_tensor(y)
        self.train()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        for epoch in range(epochs):
            y_pred = self.forward(X)
            loss = self.loss(y, y_pred)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if verbose:
                print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

        return self