# Base Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import mean_absolute_error
import matplotlib.pyplot as plt
import dateutil

# ------------------ MODEL COMPONENTS ------------------

class TaskSpecificAttention(nn.Module):
    """Feature-wise soft attention with a learned residual projection.

    A linear layer scores every feature on the last axis of ``x``; the scores
    are softmax-normalised over that axis, used to re-weight ``x``, squashed
    with ``tanh``, projected by ``residual_fc`` and added back onto ``x``.
    The output therefore has the same shape as the input.

    Args:
        input_dim: size of the last axis of the input.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Attribute names define the state_dict keys; keep them stable.
        self.fc = nn.Linear(input_dim, input_dim)
        self.residual_fc = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        scores = self.fc(x)
        weights = F.softmax(scores, dim=-1)
        gated = torch.tanh(weights * x)
        correction = self.residual_fc(gated)
        # Residual connection: the attention only adjusts the raw input.
        return correction + x


class SharedGlobalTemporalAttention(nn.Module):
    """Temporal attention pooling shared by all tasks.

    Each task's hidden sequence is scored step by step with a shared
    ``tanh`` MLP (``fc`` -> ``final_fc``). The scores are softmax-normalised
    over the time axis, and the resulting per-task context vectors are
    averaged across tasks. The pooled vector is projected by ``residual_fc``
    and broadcast back along each task's time axis.

    Inputs are expected to be ``(batch, seq, hidden_dim)`` tensors, such as
    the output of a ``batch_first`` LSTM. All tasks must share the batch size
    because their contexts are stacked; sequence lengths may differ.

    Args:
        hidden_dim: size of the last axis of every input tensor.
    """

    def __init__(self, hidden_dim):
        super(SharedGlobalTemporalAttention, self).__init__()
        self.fc = nn.Linear(hidden_dim, hidden_dim)
        self.final_fc = nn.Linear(hidden_dim, 1)
        self.residual_fc = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x_list):
        """Return one shared-context tensor per task.

        Args:
            x_list: non-empty sequence of per-task hidden states.

        Returns:
            A list aligned with ``x_list``. Entry ``k`` has shape
            ``(batch, x_list[k].size(1), hidden_dim)`` and holds the same
            context vector at every time step.

        Raises:
            ValueError: if ``x_list`` is empty.
        """
        x_list = list(x_list)
        if not x_list:
            raise ValueError("x_list must contain at least one task's hidden states")

        task_contexts = []
        for x in x_list:
            # One score per time step: (batch, seq, 1).
            scores = self.final_fc(torch.tanh(self.fc(x)))
            # Normalise over the time axis (dim=1). A softmax over the batch
            # axis would make each sample's context depend on the others.
            weights = F.softmax(scores, dim=1)
            task_contexts.append((x * weights).sum(dim=1))  # (batch, hidden)

        combined = torch.stack(task_contexts, dim=0).mean(dim=0)  # (batch, hidden)
        shared = self.residual_fc(combined).unsqueeze(1)  # (batch, 1, hidden)
        # Consumers concatenate [h, ctx, h * ctx] themselves, so return the
        # bare context with hidden_dim features rather than a 3x-wide tensor.
        return [shared.expand(-1, x.size(1), -1) for x in x_list]


class FATHOMModel(nn.Module):
    """Single-task sequence model that fuses its encoding with a shared context.

    Pipeline: feature attention -> ``lstm1`` -> fuse with ``shared_context``
    as ``[h, ctx, h * ctx]`` -> ``lstm2`` -> MLP head on the last time step
    concatenated with the last context step.

    Args:
        input_dim: number of features per time step of ``x``.
        hidden_dim: LSTM hidden size. ``shared_context`` must broadcast
            against the ``lstm1`` output, so its last axis must be this size.
        output_dim: size of the prediction vector.
        window_size: accepted for signature compatibility; unused here.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, window_size=32):
        super().__init__()
        self.task_attention = TaskSpecificAttention(input_dim)
        self.lstm1 = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        # lstm2 reads [h, ctx, h * ctx] -> 3 * hidden_dim features per step.
        self.lstm2 = nn.LSTM(hidden_dim * 3, hidden_dim, batch_first=True)
        # The head reads [last lstm2 step, last context step] -> 2 * hidden_dim.
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, shared_context):
        attended = self.task_attention(x)
        first_hidden, _ = self.lstm1(attended)
        fused = torch.cat([first_hidden, shared_context, first_hidden * shared_context], dim=-1)
        second_hidden, _ = self.lstm2(fused)
        last_step = second_hidden[:, -1, :]
        last_context = shared_context[:, -1, :]
        head_input = torch.cat([last_step, last_context], dim=-1)
        hidden = F.relu(self.fc1(head_input))
        return self.fc2(hidden)


class MultiTaskFATHOM(nn.Module):
    """Per-task FATHOM models coupled through one shared attention module.

    Args:
        num_tasks: number of task-specific sub-models.
        input_dim, hidden_dim, output_dim, window_size: forwarded to every
            ``FATHOMModel``; ``hidden_dim`` also sizes the shared attention.
    """

    def __init__(self, num_tasks, input_dim, hidden_dim, output_dim, window_size=32):
        super(MultiTaskFATHOM, self).__init__()
        self.shared_global_attention = SharedGlobalTemporalAttention(hidden_dim)
        self.tasks = nn.ModuleList([
            FATHOMModel(input_dim, hidden_dim, output_dim, window_size) for _ in range(num_tasks)
        ])

    def forward(self, inputs):
        """Predict for every task at once.

        Args:
            inputs: one input tensor per task, in task order.

        Returns:
            A list with one prediction per task, in task order.

        Raises:
            ValueError: if the number of inputs differs from the number of tasks.
        """
        # Materialise once: inputs is iterated twice below, which would silently
        # see an exhausted generator on the second pass.
        inputs = list(inputs)
        if len(inputs) != len(self.tasks):
            # zip() would otherwise drop tasks silently and return fewer outputs.
            raise ValueError(
                f"Expected {len(self.tasks)} task inputs, got {len(inputs)}"
            )

        # Stage 1: per-task attention + first LSTM feed the shared attention.
        first_stage_outputs = [
            task_model.lstm1(task_model.task_attention(x))[0]
            for task_model, x in zip(self.tasks, inputs)
        ]
        shared_contexts = self.shared_global_attention(first_stage_outputs)

        # Stage 2: each sub-model runs its full forward with its shared context.
        # NOTE(review): FATHOMModel.forward recomputes task_attention + lstm1
        # on the raw input, so stage 1 work is done twice per call.
        return [
            task_model(x, shared_context)
            for task_model, x, shared_context in zip(self.tasks, inputs, shared_contexts)
        ]


# ------------------ DATA LOADER ------------------

def df_to_X_y(df, features, target, window_size=32, horizon=1):
    """Slice a row-ordered frame into sliding input windows and forecast targets.

    Window ``i`` covers rows ``i .. i + window_size - 1`` of the selected
    columns. Its target is the next ``horizon`` values of ``target``.

    Args:
        df: frame containing ``features`` and ``target`` columns.
        features: iterable of input column names; ``target`` is prepended to
            the window columns when it is not already listed.
        target: name of the column to forecast.
        window_size: number of rows per input window.
        horizon: number of future target values per sample.

    Returns:
        ``(X, y)``:
        - ``X`` has shape ``(n_windows, window_size, n_columns)``.
        - ``y`` has shape ``(n_windows, horizon)``.
        - ``n_windows = len(df) - window_size - horizon + 1``, clamped to 0.
    """
    features = list(features)
    if target not in features:
        features = [target, *features]

    data = df[features].to_numpy()
    target_data = df[target].to_numpy()

    n_windows = len(data) - window_size - horizon + 1
    if n_windows <= 0:
        # np.array([]) would have shape (0,), dropping the window/feature axes
        # that downstream tensor code relies on.
        return (
            np.empty((0, window_size, data.shape[1]), dtype=data.dtype),
            np.empty((0, horizon), dtype=target_data.dtype),
        )

    X = np.stack([data[i:i + window_size] for i in range(n_windows)])
    y = np.stack([
        target_data[i + window_size:i + window_size + horizon] for i in range(n_windows)
    ])
    return X, y


def load_and_preprocess_site_data(site_path, features, target, window_size=32, horizon=1, min_date=None, max_date=None, batch_size=16, device='cpu'):
    """Load one site's CSV and build chronological train/val/test DataLoaders.

    Rows are used in file order. If a ``date`` column exists, it is parsed,
    optionally filtered to ``[min_date, max_date]`` (both inclusive) and then
    dropped.

    Splits:
    - The first 80% of rows form the training pool.
    - The last 20% of that pool is held out for validation.
    - The remaining rows are the test set.

    All selected columns are standardised with the training split's
    mean/std, then windowed by ``df_to_X_y``.

    Args:
        site_path: anything ``pd.read_csv`` accepts.
        features: input column names; ``target`` is prepended if absent.
        target: name of the column to forecast.
        window_size: number of time steps per input window.
        horizon: number of future target values per sample.
        min_date, max_date: optional inclusive bounds on the ``date`` column
            (strings or datetime-like).
        batch_size: batch size for all three loaders.
        device: device the tensors are moved to before batching.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: if any required column is missing.
    """
    df = pd.read_csv(site_path)

    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
        # Parse the bounds with the same parser as the column so the
        # comparisons are like-for-like.
        if min_date:
            df = df[df['date'] >= pd.to_datetime(min_date)]
        if max_date:
            df = df[df['date'] <= pd.to_datetime(max_date)]
        # Reassign rather than drop(inplace=True): after filtering, df may be
        # a slice of the frame read from disk.
        df = df.drop(columns=['date'])

    if target not in features:
        features = [target] + features

    all_columns = features
    missing = [col for col in all_columns if col not in df.columns]
    if missing:
        raise ValueError(f"Missing columns in dataset: {missing}")

    train_size = int(0.8 * len(df))
    train_pool, test_df = df.iloc[:train_size], df.iloc[train_size:]

    val_size = int(0.2 * len(train_pool))
    # Split at an explicit index: with val_size == 0, `iloc[:-val_size]` is
    # `iloc[:0]` and would silently empty the training set.
    split_at = len(train_pool) - val_size
    # .copy() so the normalisation below writes into independent frames
    # instead of (possibly) views of df.
    train_df = train_pool.iloc[:split_at].copy()
    val_df = train_pool.iloc[split_at:].copy()
    test_df = test_df.copy()

    print(f"Train size: {len(train_df)} | Validation size: {len(val_df)} | Test size: {len(test_df)}")

    # Statistics come from the training split only, so val/test do not leak in.
    train_mean = train_df[all_columns].mean()
    train_std = train_df[all_columns].std()
    for part in (train_df, val_df, test_df):
        part[all_columns] = (part[all_columns] - train_mean) / (train_std + 1e-8)

    X_train, y_train = df_to_X_y(train_df, features, target, window_size, horizon)
    X_val, y_val = df_to_X_y(val_df, features, target, window_size, horizon)
    X_test, y_test = df_to_X_y(test_df, features, target, window_size, horizon)

    train_loader = _make_site_loader(X_train, y_train, True, batch_size, device)
    val_loader = _make_site_loader(X_val, y_val, False, batch_size, device)
    test_loader = _make_site_loader(X_test, y_test, False, batch_size, device)
    return train_loader, val_loader, test_loader


def _make_site_loader(X, y, shuffle, batch_size, device):
    """Wrap windowed arrays as float32 tensors on *device* in a DataLoader."""
    dataset = TensorDataset(
        torch.tensor(X, dtype=torch.float32).to(device),
        torch.tensor(y, dtype=torch.float32).to(device),
    )
    # drop_last keeps every batch at exactly batch_size rows; the multi-task
    # model stacks per-site tensors, which requires equal batch sizes.
    return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, drop_last=True)

# ------------------ TRAINING & EVALUATION ------------------

def train_fathom_model(site_loaders, model, optimizer, criterion, scheduler, num_epochs, device):
    """Jointly train a multi-task model on all sites.

    Every optimisation step draws one batch from each site's training loader
    and passes them to ``model`` together. The model's forward takes one
    input per task and returns one prediction per task, in ``site_loaders``
    order. The per-task losses are averaged into a single objective.

    Loaders are iterated in lockstep with ``zip``, so each epoch stops at the
    shortest site loader.

    Args:
        site_loaders: sequence of ``(train_loader, val_loader, test_loader)``
            tuples, one per task, in the model's task order.
        model: module called as ``model(list_of_inputs) -> list_of_predictions``.
        optimizer: optimiser over ``model``'s parameters.
        criterion: loss called as ``criterion(prediction, target)``.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
        num_epochs: number of passes over the data.
        device: device each batch is moved to.
    """
    train_loaders = [loaders[0] for loaders in site_loaders]
    val_loaders = [loaders[1] for loaders in site_loaders]

    for epoch in range(num_epochs):
        model.train()
        train_losses = []
        for batches in zip(*train_loaders):
            optimizer.zero_grad()
            loss = _joint_task_loss(model, criterion, batches, device)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        model.eval()
        val_losses = []
        with torch.no_grad():
            for batches in zip(*val_loaders):
                val_losses.append(_joint_task_loss(model, criterion, batches, device).item())

        if scheduler is not None:
            scheduler.step()

        print(f"Epoch {epoch + 1}/{num_epochs} | Train Loss: {_mean_or_nan(train_losses):.4f} | Validation Loss: {_mean_or_nan(val_losses):.4f}")

    print("Training complete.")


def _joint_task_loss(model, criterion, batches, device):
    """Average ``criterion`` over tasks for one aligned ``(X, y)`` batch per task."""
    inputs = [X.to(device) for X, _ in batches]
    preds = model(inputs)
    losses = [
        # Flatten targets to (batch, -1) so they line up with (batch, output_dim).
        criterion(pred, y.to(device).view(y.size(0), -1))
        for pred, (_, y) in zip(preds, batches)
    ]
    return torch.stack(losses).mean()


def _mean_or_nan(values):
    """Mean of *values*, or NaN for an empty list (avoids numpy's empty-mean warning)."""
    return float(np.mean(values)) if values else float("nan")


def evaluate_fathom_model(model, site_loaders, device='cpu'):
    """Compute the mean absolute error of every task on its site's test loader.

    One batch per site is fed to ``model`` together, because the model's
    forward takes one input per task and returns one prediction per task.
    Test loaders are iterated in lockstep with ``zip``, so evaluation stops
    at the shortest one.

    Args:
        model: module called as ``model(list_of_inputs) -> list_of_predictions``.
        site_loaders: sequence of ``(train_loader, val_loader, test_loader)``
            tuples, one per task, in the model's task order.
        device: device each input batch is moved to.

    Returns:
        A list with one MAE per task. The MAE is the mean absolute error
        averaged over samples and output columns. The entry is NaN when no
        test batch was seen.

    Raises:
        ValueError: if a task's predictions and targets have different shapes.
    """
    model.eval()
    test_loaders = [loaders[2] for loaders in site_loaders]
    preds = [[] for _ in test_loaders]
    targets = [[] for _ in test_loaders]

    with torch.no_grad():
        for batches in zip(*test_loaders):
            outputs = model([X.to(device) for X, _ in batches])
            for task_id, (out, (_, y)) in enumerate(zip(outputs, batches)):
                preds[task_id].append(out.cpu().numpy())
                # Same (batch, -1) target layout as used for the training loss.
                targets[task_id].append(y.view(y.size(0), -1).cpu().numpy())

    mae_scores = []
    for task_id, (task_preds, task_targets) in enumerate(zip(preds, targets)):
        if not task_targets:
            mae_scores.append(float("nan"))
            continue
        pred_arr = np.concatenate(task_preds)
        target_arr = np.concatenate(task_targets)
        # Guard explicitly: numpy would otherwise broadcast mismatched shapes.
        if pred_arr.shape != target_arr.shape:
            raise ValueError(
                f"Task {task_id}: prediction shape {pred_arr.shape} "
                f"does not match target shape {target_arr.shape}"
            )
        mae_scores.append(float(np.mean(np.abs(target_arr - pred_arr))))

    print("Evaluation complete.")
    return mae_scores


# Usage

In [None]:
if __name__ == "__main__":
    device = "cpu"
    batch_size, window_size, hidden_dim, horizon = 32, 32, 64, 1

    # One CSV per site/task; each must contain the feature and target columns.
    site_paths = ["site_1.csv", "site_2.csv", "site_3.csv"]
    features = ["feature1", "feature2", "feature3"]
    target = "target"

    num_tasks = len(site_paths)
    # Windows hold [target] + features (the loader prepends the target when it
    # is not listed), so the model input width must follow the column list
    # rather than a hard-coded constant.
    input_dim = len(features) + (0 if target in features else 1)
    # One output per forecast step, so predictions line up with the
    # (batch, horizon) targets in both the MSE loss and the MAE metric.
    output_dim = horizon

    model = MultiTaskFATHOM(num_tasks, input_dim, hidden_dim, output_dim, window_size).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
    criterion = nn.MSELoss()

    site_loaders = [
        load_and_preprocess_site_data(
            site_path, features, target,
            window_size=window_size, horizon=horizon,
            batch_size=batch_size, device=device,
        )
        for site_path in site_paths
    ]

    train_fathom_model(site_loaders, model, optimizer, criterion, scheduler, num_epochs=5, device=device)
    mae_scores = evaluate_fathom_model(model, site_loaders, device=device)
    print(f"MAE per task: {mae_scores}")