In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

def generate_data(seed=42, n_samples_list=None, systematic_errors=None, test_size=0.2):
    """Build a synthetic chemical-shift benchmark with per-dataset referencing errors.

    Samples are laid out as consecutive "datasets", one block per entry of
    ``n_samples_list``. Each sample gets:
      * three uniform features in [0, 10),
      * a noise-free true shift ``sin(3*f1 - 5*f2) - 4*f2**2 + f3``,
      * a random atom type in {0, 1, 2}, and
      * an observed shift: the true shift plus the systematic (referencing) error
        of its (dataset, atom type) pair.
    The rows are then split into train and test sets and converted to tensors.

    Args:
        seed: Seed for NumPy's global RNG, also used as ``random_state`` of the split.
        n_samples_list: Samples per dataset. Defaults to five datasets of 100 samples.
        systematic_errors: One ``{atom_type: error}`` dict per dataset, with keys
            for atom types 0, 1 and 2. Defaults to the five hand-picked offsets below.
        test_size: Test fraction passed to ``train_test_split``.

    Returns:
        ``(X_train, y_train, ID_train, atom_types_train,
        X_test, y_test, ID_test, atom_types_test)``. X is float32 ``(n, 3)``.
        y is the observed (error-shifted) shift as float32 ``(n, 1)``. The ID and
        atom-type tensors are int64 ``(n,)``.

    Raises:
        ValueError: If fewer ``systematic_errors`` entries than datasets are given.
    """
    np.random.seed(seed)

    if n_samples_list is None:
        n_samples_list = [100, 100, 100, 100, 100]
    if systematic_errors is None:
        systematic_errors = [
            {0: 10  , 1: -0.3, 2: 2 },
            {0: 2   , 1:  3.1, 2: 0 },
            {0: -1  , 1: -0.5, 2: -1},
            {0: 0   , 1: -1.2, 2: 1 },
            {0: -1.3, 1:  0.2, 2: 0 },
        ]  # Referencing errors, one dict per dataset
    if len(systematic_errors) < len(n_samples_list):
        raise ValueError(
            f"need one systematic_errors entry per dataset: got {len(systematic_errors)} "
            f"for {len(n_samples_list)} datasets"
        )

    n_samples = sum(n_samples_list)
    atom_type_choices = [0, 1, 2]

    # Features (e.g. structural properties) and noise-free true chemical shifts.
    # RNG draw order (features first, then atom types per dataset) is kept fixed so
    # the default call reproduces the same data.
    X = np.random.rand(n_samples, 3) * 10
    dataset_id = np.repeat(np.arange(len(n_samples_list)), np.array(n_samples_list))
    true_cs = np.sin(3 * X[:, 0] - 5 * X[:, 1]) - 4 * X[:, 1] ** 2 + X[:, 2]

    # Assign atom types per dataset and add that dataset's per-atom-type error.
    atom_type_labels_list = [np.random.choice(atom_type_choices, size=n) for n in n_samples_list]
    atom_type_labels = np.concatenate(atom_type_labels_list)
    systematic_error = np.concatenate([
        np.array([systematic_errors[ds][atom] for atom in labels], dtype=float)
        for ds, labels in enumerate(atom_type_labels_list)
    ])
    adjusted_cs = true_cs + systematic_error

    feature_cols = ['Feature1', 'Feature2', 'Feature3']
    data = pd.DataFrame({
        'Feature1': X[:, 0],
        'Feature2': X[:, 1],
        'Feature3': X[:, 2],
        'DatasetID': dataset_id,
        'TrueCS': true_cs,
        'AdjustedCS': adjusted_cs,
        'AtomType': atom_type_labels,
    })

    train_data, test_data = train_test_split(data, test_size=test_size, random_state=seed)

    def _to_tensors(frame):
        # Features, (n, 1) observed targets, dataset IDs, atom types for one split.
        # IDs use int64 (int8 would overflow beyond 127 datasets).
        return (
            torch.tensor(frame[feature_cols].values, dtype=torch.float32),
            torch.tensor(frame['AdjustedCS'].values, dtype=torch.float32).view(-1, 1),
            torch.tensor(frame['DatasetID'].values, dtype=torch.long),
            torch.tensor(frame['AtomType'].values, dtype=torch.long),
        )

    return (*_to_tensors(train_data), *_to_tensors(test_data))

def calculate_referencing_offset(y_true, y_pred, ID_tensor, atom_types):
    """
    Calculate systematic referencing offsets for each (atom type, dataset) group.

    For every combination of atom type and dataset ID that occurs in the inputs,
    the offset is the mean of ``y_true - y_pred`` over that group's samples.

    Args:
        y_true: Target values, shape ``(N,)`` or ``(N, 1)``.
        y_pred: Predictions, shape ``(N,)`` or ``(N, 1)``; may require grad.
        ID_tensor: Per-sample dataset IDs, shape ``(N,)``.
        atom_types: Per-sample atom types, shape ``(N,)``.

    Returns:
        dict mapping ``(atom_type, dataset_id)`` (Python ints) to a Python float
        offset. The offsets are detached from autograd. Keys are inserted in
        ascending atom-type, then dataset-ID order.
    """
    # Flatten so mixed (N,) / (N, 1) inputs cannot broadcast into an (N, N) difference.
    y_true_flat = y_true.reshape(-1)
    y_pred_flat = y_pred.reshape(-1)

    offsets = {}
    unique_dataset_ID = torch.unique(ID_tensor)
    # The offsets are returned as plain floats, so no autograd graph is needed here.
    with torch.no_grad():
        for atom in torch.unique(atom_types):
            atom_mask = atom_types == atom
            for dataset_id in unique_dataset_ID:
                mask = atom_mask & (ID_tensor == dataset_id)
                if mask.any():
                    offsets[atom.item(), dataset_id.item()] = torch.mean(
                        y_true_flat[mask] - y_pred_flat[mask]
                    ).item()
    return offsets

def apply_offsets(y_pred, ID_tensor, atom_types, offsets, alpha=1.):
    """
    Add scaled referencing offsets to the predictions of each (atom type, dataset) group.

    Args:
        y_pred: Predictions. Not modified; a clone is corrected and returned.
        ID_tensor: Per-sample dataset IDs, aligned with the first dim of ``y_pred``.
        atom_types: Per-sample atom types, aligned with the first dim of ``y_pred``.
        offsets: Mapping ``{(atom_type, dataset_id): offset}``, e.g. the output of
            ``calculate_referencing_offset``.
        alpha: Scale applied to every offset (1.0 applies them in full).

    Returns:
        A new tensor equal to ``y_pred`` plus ``alpha * offset`` on every sample of
        each listed group. Samples in groups absent from ``offsets`` are unchanged.
        Gradients still flow back to ``y_pred`` through the clone.
    """
    corrected_y = y_pred.clone()
    for (atom, dataset_id), offset in offsets.items():
        mask = (atom_types == atom) & (ID_tensor == dataset_id)
        corrected_y[mask] += alpha * offset
    return corrected_y

def train(model, criterion, optimizer, X_train_tensor, y_train_tensor, ID_train_tensor,
          atom_types_train_tensor, epochs, use_rereferencing, alpha_breakeven_epoch=0,
          alpha_width=1.0, log_every=1000):
    """Full-batch training loop, optionally with per-group re-referencing.

    Each epoch runs one forward pass over the whole training set, then one
    backward pass and one optimizer step.

    When ``use_rereferencing`` is true, the mean residual (target minus prediction)
    of every (atom type, dataset) group is computed each epoch. Before the loss is
    evaluated, that residual is added back to the group's *predictions*, scaled by
    ``alpha = tanh((epoch - alpha_breakeven_epoch) / alpha_width) / 2 + 0.5``.
    The offsets are plain floats, so gradients flow only through the predictions.

    Args:
        model: Module called as ``model(X, atom_types)``.
        criterion: Loss called as ``criterion(pred, target)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of full-batch steps.
        use_rereferencing: Whether to apply the offset correction above.
        alpha_breakeven_epoch: Epoch at which ``alpha`` crosses 0.5.
        alpha_width: Width of the tanh ramp in epochs. The default 1.0 makes the
            switch from alpha≈0 to alpha≈1 happen within a few epochs; larger
            values give a smoother ramp.
        log_every: Print the loss every this many epochs; 0 disables printing.

    Raises:
        ValueError: If ``alpha_width`` is not positive.
    """
    if alpha_width <= 0:
        raise ValueError("alpha_width must be positive")

    model.train()
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()

        y_pred_train = model(X_train_tensor, atom_types_train_tensor)

        if use_rereferencing:
            offsets = calculate_referencing_offset(
                y_train_tensor,
                y_pred_train,
                ID_train_tensor,
                atom_types_train_tensor,
            )
            alpha = np.tanh((epoch - alpha_breakeven_epoch) / alpha_width) / 2 + 0.5
            # Shift the predictions (not the targets) by the scaled per-group offsets.
            y_pred_train = apply_offsets(
                y_pred_train, ID_train_tensor, atom_types_train_tensor, offsets, alpha
            ).view(-1, 1)

        # Loss on the (possibly offset-corrected) predictions.
        loss = criterion(y_pred_train, y_train_tensor)
        loss.backward()
        optimizer.step()

        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch}/{epochs}, Loss: {loss.item()}")

def evaluate(model, X_test_tensor, y_test_tensor, ID_test_tensor, atom_types_test_tensor):
    """Report and plot test MSE before and after per-group re-referencing.

    Predictions are made in eval mode without gradient tracking. The "corrected"
    predictions add, for every (atom type, dataset) group, that group's mean
    residual ``y_test - y_pred``. Those offsets are fitted on the test targets
    themselves, so the corrected error is the error left after an ideal
    re-referencing of each test dataset, not an independent held-out score.

    Side effects: leaves ``model`` in eval mode, shows a scatter plot of predicted
    vs target shifts, and prints ``original_test_error corrected_test_error``.
    """
    model.eval()
    with torch.no_grad():
        y_pred_test = model(X_test_tensor, atom_types_test_tensor)

    test_offsets = calculate_referencing_offset(
        y_test_tensor,
        y_pred_test,
        ID_test_tensor,
        atom_types_test_tensor)
    y_pred_test_corrected = apply_offsets(y_pred_test, ID_test_tensor, atom_types_test_tensor, test_offsets)

    # Convert once to flat NumPy arrays for sklearn and matplotlib.
    y_true_np = y_test_tensor.detach().cpu().reshape(-1).numpy()
    y_pred_np = y_pred_test.detach().cpu().reshape(-1).numpy()
    y_corr_np = y_pred_test_corrected.detach().cpu().reshape(-1).numpy()

    original_test_error = mean_squared_error(y_true_np, y_pred_np)
    corrected_test_error = mean_squared_error(y_true_np, y_corr_np)

    # Vectorised min/max (builtin min()/max() would iterate the tensor row by row).
    lo, hi = float(y_true_np.min()), float(y_true_np.max())
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(y_true_np, y_pred_np, alpha=0.7, label='Original Predictions')
    ax.scatter(y_true_np, y_corr_np, alpha=0.7, label='Corrected Predictions')
    ax.plot([lo, hi], [lo, hi], color='red', linestyle='--', label='Ideal')
    ax.set_xlabel('Target Chemical Shifts')
    ax.set_ylabel('Predicted Chemical Shifts')
    ax.set_title('MLP Model Performance Before and After Re-referencing Correction')
    ax.legend()
    plt.show()

    print(original_test_error, corrected_test_error)

# Step 2: Define MLP Model
class ResidualMLPModel(nn.Module):
    """Sum of three bias-free branches over the input features, then a linear head.

    Branches (each producing ``hidden_dim`` features):
      * periodic:  sin(Linear([x, embedding(atom_types)]))
      * quadratic: Linear(x) ** 2
      * linear:    Linear(x)
    The branch outputs are summed and projected to ``output_dim`` by a bias-free
    Linear layer.
    """

    class SinActivation(nn.Module):
        """Elementwise sine."""

        def forward(self, x):
            return torch.sin(x)

    class Pow2Activation(nn.Module):
        """Elementwise square."""

        def forward(self, x):
            return x ** 2

    def __init__(self, input_dim, hidden_dim, output_dim, num_types, embedding_dim=8):
        super().__init__()
        # Submodule registration order is kept as-is: it determines the order of
        # random weight-initialisation draws and of the state_dict keys.
        self.embedder = torch.nn.Embedding(num_types, embedding_dim=embedding_dim)
        self.input_layer = nn.Linear(input_dim + embedding_dim, hidden_dim, bias=False)
        self.sin_activation = ResidualMLPModel.SinActivation()
        self.hidden_layer_1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.pow2_activation = ResidualMLPModel.Pow2Activation()
        self.hidden_layer_2 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.output_layer = nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x, atom_types):
        typed_input = torch.cat([x, self.embedder(atom_types)], dim=-1)

        periodic = self.sin_activation(self.input_layer(typed_input))
        quadratic = self.pow2_activation(self.hidden_layer_1(x))
        linear = self.hidden_layer_2(x)

        # Same summation grouping as before: linear + (quadratic + periodic).
        combined = linear + (quadratic + periodic)
        return self.output_layer(combined)

In [13]:
# Build the synthetic dataset and unpack the train/test tensors.
(
    X_train_tensor, y_train_tensor, ID_train_tensor, atom_types_train_tensor,
    X_test_tensor, y_test_tensor, ID_test_tensor, atom_types_test_tensor,
) = generate_data()

In [14]:
# Step 4: Initialize Model, Loss, and Optimizer (re-referencing run)
# Seed torch so the random weight initialisation is reproducible on Restart & Run All.
torch.manual_seed(42)
mlp_model = ResidualMLPModel(input_dim=X_train_tensor.shape[1], hidden_dim=32, output_dim=1, num_types=3, embedding_dim=8)
criterion = nn.MSELoss()
optimizer = optim.Adam(mlp_model.parameters(), lr=0.001)

In [None]:
# Train with re-referencing; offsets switch on around epoch 10000.
train(
    mlp_model,
    criterion,
    optimizer,
    X_train_tensor,
    y_train_tensor,
    ID_train_tensor,
    atom_types_train_tensor,
    epochs=20000,
    use_rereferencing=True,
    alpha_breakeven_epoch=10000,
)

In [None]:
evaluate(
    model=mlp_model,
    X_test_tensor=X_test_tensor,
    y_test_tensor=y_test_tensor,
    ID_test_tensor=ID_test_tensor,
    atom_types_test_tensor=atom_types_test_tensor,
)

In [None]:
# Step 4 (baseline): Initialize Model, Loss, and Optimizer without re-referencing
# NOTE(review): this baseline uses hidden_dim=128 / embedding_dim=64 and is trained for
# 10000 epochs, whereas the re-referencing run uses 32 / 8 and 20000 epochs, so the two
# test errors are not a like-for-like comparison.
# NOTE(review): this cell rebinds `criterion` and `optimizer`. Re-running the earlier
# re-referencing training cell after this one would step base_mlp_model's parameters,
# not mlp_model's.
# Seed torch so the random weight initialisation is reproducible on Restart & Run All.
torch.manual_seed(42)
base_mlp_model = ResidualMLPModel(input_dim=X_train_tensor.shape[1], hidden_dim=128, output_dim=1, num_types=3, embedding_dim=64)
criterion = nn.MSELoss()
optimizer = optim.Adam(base_mlp_model.parameters(), lr=0.001)

In [None]:
# Baseline: plain MSE training, no referencing correction.
train(
    base_mlp_model,
    criterion,
    optimizer,
    X_train_tensor,
    y_train_tensor,
    ID_train_tensor,
    atom_types_train_tensor,
    epochs=10000,
    use_rereferencing=False,
)

In [None]:
evaluate(
    model=base_mlp_model,
    X_test_tensor=X_test_tensor,
    y_test_tensor=y_test_tensor,
    ID_test_tensor=ID_test_tensor,
    atom_types_test_tensor=atom_types_test_tensor,
)