In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score, mean_squared_error, mean_absolute_error, r2_score
from torch.utils.tensorboard import SummaryWriter

# Sinusoidal Positional Encoding
def get_sinusoidal_encoding(seq_len, embed_dim):
    pe = torch.zeros(seq_len, embed_dim)
    position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

# Improved Multi-task Transformer Encoder with explicit NaN handling
class MultiTaskNumericEncoder(nn.Module):
    def __init__(self, seq_len, embed_dim=128, num_heads=4, num_layers=2, dropout=0.1):
        super().__init__()
        self.seq_len = seq_len

        self.value_embedding = nn.Linear(1, embed_dim)
        self.missing_embedding = nn.Parameter(torch.randn(1, 1, embed_dim))

        pos_enc = get_sinusoidal_encoding(seq_len, embed_dim)
        self.register_buffer("pos_enc", pos_enc.unsqueeze(0))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.classification_head = nn.Linear(embed_dim, 1)
        self.regression_head = nn.Linear(embed_dim, 1)

    def forward(self, x_filled, nan_mask):
        value_embed = self.value_embedding(x_filled.unsqueeze(-1))
        missing_embed = self.missing_embedding.expand_as(value_embed)

        x_embed = nan_mask.unsqueeze(-1) * value_embed + (1 - nan_mask).unsqueeze(-1) * missing_embed
        x_embed += self.pos_enc

        encoder_output = self.transformer_encoder(x_embed)

        class_logits = self.classification_head(encoder_output).squeeze(-1)
        classification_output = torch.sigmoid(class_logits)
        regression_output = self.regression_head(encoder_output).squeeze(-1)

        return classification_output, regression_output


In [4]:
# Training function with metrics logging and checkpointing
def train_model(model, train_loader, val_loader, epochs, lr=1e-4, alpha=0.5, device='cuda'):
    criterion_class = nn.BCELoss()
    criterion_reg = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
    model.to(device)

    writer = SummaryWriter()

    best_metrics = {}
    steps_per_epoch = len(train_loader)
    eval_interval = max(1, steps_per_epoch // 10)

    for epoch in range(epochs):
        model.train()
        for step, batch in enumerate(train_loader):
            x_filled = batch['x_filled'].to(device)
            nan_mask = batch['nan_mask'].to(device)
            training_mask = batch['training_mask'].to(device)
            classification_output, regression_output = model(x_filled, nan_mask)

            class_loss = criterion_class(classification_output, nan_mask)
            reg_loss = criterion_reg(regression_output*training_mask, x_filled*training_mask)

            total_loss = alpha * class_loss + (1 - alpha) * reg_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            if step % eval_interval == 0 or step == steps_per_epoch - 1:
                model.eval()
                for phase, loader in [('Train', train_loader), ('Validation', val_loader)]:
                    all_nan_masks, all_train_masks = [], []
                    all_targets, all_preds_class, all_preds_reg = [], [], []

                    with torch.no_grad():
                        for batch in loader:
                            x_filled = batch['x_filled'].to(device)
                            nan_mask = batch['nan_mask'].to(device)
                            training_mask = batch['training_mask'].to(device)
                            val_class_out, val_reg_out = model(x_filled, nan_mask)

                            all_nan_masks.extend(nan_mask.cpu().numpy().flatten())
                            all_train_masks.extend(training_mask.cpu().numpy().flatten())
                            all_targets.extend(x_filled.cpu().numpy().flatten())
                            all_preds_class.extend(val_class_out.cpu().numpy().flatten())
                            all_preds_reg.extend(val_reg_out.cpu().numpy().flatten())

                    # Classification metrics: Evaluate original missingness prediction (nan_mask)
                    class_mask_indices = np.array(all_nan_masks) >= 0  # Evaluate classification everywhere
                    bin_preds = (np.array(all_preds_class)[class_mask_indices] > 0.5).astype(int)
                    bin_targets = np.array(all_nan_masks)[class_mask_indices]

                    # Regression metrics: Evaluate only on hidden-for-training positions (training_mask)
                    reg_mask_indices = np.array(all_train_masks) == 1
                    reg_preds = np.array(all_preds_reg)[reg_mask_indices]
                    reg_targets = np.array(all_targets)[reg_mask_indices]

                    metrics = {}

                    # Classification metrics (ensure valid classes present) Sometimes no missing classes
                    if len(np.unique(bin_targets)) > 1:
                        metrics.update({
                            'F1': f1_score(bin_targets, bin_preds),
                            'Precision': precision_score(bin_targets, bin_preds),
                            'Recall': recall_score(bin_targets, bin_preds),
                            'ROC-AUC': roc_auc_score(bin_targets, np.array(all_preds_class)[class_mask_indices])
                        })

                    # Regression metrics (veryy slim possibiliy empty)
                    if reg_targets.size > 0:
                        metrics.update({
                            'MSE': mean_squared_error(reg_targets, reg_preds),
                            'MAE': mean_absolute_error(reg_targets, reg_preds),
                            'R2': r2_score(reg_targets, reg_preds)
                        })

                    # Log metrics safely
                    for metric_name, metric_value in metrics.items():
                        writer.add_scalar(f'{phase}/{metric_name}', metric_value, epoch * steps_per_epoch + step)
                        if phase == 'Validation':
                            if metric_name not in best_metrics or \
                                (metric_name in ['MSE', 'MAE'] and metric_value < best_metrics[metric_name]) or \
                                (metric_name not in ['MSE', 'MAE'] and metric_value > best_metrics[metric_name]):
                                best_metrics[metric_name] = metric_value
                                torch.save(model.state_dict(), f'checkpoint_best_{metric_name}.pt')

                model.train()

        scheduler.step()

    writer.close()

In [3]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

samples = np.load('samples43Regions.npy')

class MaskedNumericDataset(Dataset):
    def __init__(self, data_tensor, nan_mask_tensor, training_mask_tensor):
        self.data = data_tensor
        self.nan_mask = nan_mask_tensor
        self.training_mask = training_mask_tensor

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return {
            'x_filled': self.data[idx],
            'nan_mask': self.nan_mask[idx],
            'training_mask': self.training_mask[idx]
        }


# Prepare data function ensuring at least 1 hidden position
def prepare_data(data, hidden_fraction=0.3):
    nan_mask = np.isnan(data).astype(np.float32)
    hidden_mask = np.zeros_like(data, dtype=np.float32)

    for i in range(len(data)):
        observed_indices = np.where(~np.isnan(data[i]))[0]

        # Ensure at least one hidden position if possible
        num_hidden = max(1, int(len(observed_indices) * hidden_fraction)) if len(observed_indices) > 0 else 0

        if num_hidden > 0:
            hidden_indices = np.random.choice(observed_indices, size=num_hidden, replace=False)
            hidden_mask[i, hidden_indices] = 1

    data_filled = np.nan_to_num(data, nan=0.0)
    data_filled[hidden_mask == 1] = 0.0

    data_tensor = torch.tensor(data_filled, dtype=torch.float32)
    nan_mask_tensor = torch.tensor(nan_mask, dtype=torch.float32)
    training_mask_tensor = torch.tensor(hidden_mask, dtype=torch.float32)

    return data_tensor, nan_mask_tensor, training_mask_tensor


train_data, test_data = train_test_split(samples, test_size=0.2, random_state=42)

train_dataset = MaskedNumericDataset(*prepare_data(train_data))
test_dataset = MaskedNumericDataset(*prepare_data(test_data))

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)

input_dim = samples.shape[1]  # which is 43
model = MultiTaskNumericEncoder(input_dim, embed_dim=256, num_heads=4, num_layers=6, dropout=0.1)

train_model(model, train_loader, test_loader, epochs=6, lr=1e-4, alpha=0.5)


RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.