In [1]:
import numpy as np
import pandas as pd
import torch
from torch.autograd import grad
def generate_environment_df(n_samples, pe, env_id):
#ГЕНЕРАЦИЯ ДАННЫХ
    I2 = np.eye(2)

    # Invariant features ~ N(0, I)
    X_inv = np.random.multivariate_normal(mean=[0, 0], cov=I2, size=n_samples)

    # Label Y = x1 + x2 + noise
    Y = X_inv.sum(axis=1) + np.random.normal(loc=0.0, scale=np.sqrt(0.1), size=n_samples)

    # Spurious features ~ N([Y, Y], pe * I)
    X_env = np.stack([Y, Y], axis=1) + np.random.multivariate_normal(mean=[0, 0], cov=pe * I2, size=n_samples)

    # Combine into DataFrame
    df = pd.DataFrame({
        "x_inv_0": X_inv[:, 0],
        "x_inv_1": X_inv[:, 1],
        "x_env_0": X_env[:, 0],
        "x_env_1": X_env[:, 1],
        "y": Y,
        "env_id": env_id,
        "pe": pe
    })

    return df

def generate_irm_dataset(n_samples_per_env=10000):
    pe_train = [0.1, 0.3, 0.5, 0.7, 0.9]
    pe_val = [0.4, 0.8]
    pe_test = [10, 100]

    # Build full DataFrames for each split
    train_df = pd.concat(
        [generate_environment_df(n_samples_per_env, pe, f"train_{i}") for i, pe in enumerate(pe_train)],
        ignore_index=True
    )
    val_df = pd.concat(
        [generate_environment_df(n_samples_per_env, pe, f"val_{i}") for i, pe in enumerate(pe_val)],
        ignore_index=True
    )
    test_df = pd.concat(
        [generate_environment_df(n_samples_per_env, pe, f"test_{i}") for i, pe in enumerate(pe_test)],
        ignore_index=True
    )

    return train_df, val_df, test_df

In [2]:
train_data, val_data, test_data = generate_irm_dataset()

In [None]:
import pandas as pd
import torch
from torch.autograd import grad

def detect_device():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    return device

def to_torch(df, device):
    X = torch.tensor(df[["x_inv_0", "x_inv_1", "x_env_0", "x_env_1"]].values, dtype=torch.float32, device=device)
    y = torch.tensor(df["y"].values, dtype=torch.float32, device=device).unsqueeze(1)
    env_ids = df["env_id"].unique()
    environments = []
    for env in env_ids:
        mask = df["env_id"] == env
        environments.append((X[mask], y[mask]))
    return environments

def initialize_model(dim_x, device):
    phi = torch.nn.Parameter(torch.eye(dim_x,1, device=device))
    dummy_w = torch.ones(1, 1, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([phi], lr=1e-3)
    return phi, dummy_w, optimizer

def compute_penalty(error, dummy_w):
    grad_w = grad(error, dummy_w, create_graph=True)[0]
    penalty = grad_w.pow(2).mean()
    return penalty

def train_model(train_envs,val_envs, phi, dummy_w, optimizer, reg=1e-3, iterations=500, verbose=True):
    mse = torch.nn.MSELoss()
    
    for iteration in range(iterations):
        total_error = 0
        total_penalty = 0
        
        for x_e, y_e in train_envs:
            error_e = mse((x_e @ phi) @ dummy_w, y_e)
            penalty_e = compute_penalty(error_e, dummy_w)
            total_error += error_e
            total_penalty += penalty_e
        
        optimizer.zero_grad()
        loss = reg * total_error + (1 - reg) * total_penalty
        loss.backward()
        optimizer.step()
        if verbose and iteration % 100 == 0:
            total_val = 0
            with torch.no_grad():
                for x_e, y_e in val_envs:
                    error_val = mse((x_e @ phi) @ dummy_w, y_e)
                    total_val += error_val
                print(f"Iteration {iteration}, Training Loss: {loss.item():.6f}, Validation MSE: {total_val:.6f}")

def tune_regularization(train_envs, val_envs, dim_x, device, reg_values, iterations=100, verbose=True):
    best_err = float('inf')
    best_reg = None
    best_phi = None
    mse = torch.nn.MSELoss()
    for reg in reg_values:
        phi, dummy_w, optimizer = initialize_model(dim_x, device)
        train_model(train_envs,val_envs,  phi, dummy_w, optimizer, reg=reg, iterations=iterations, verbose=True)
        
        with torch.no_grad():
           total_val = 0
           for x_e, y_e in val_envs:
                    error_val = mse((x_e @ phi) @ dummy_w, y_e)
                    total_val += error_val
        
        if verbose:
            print(f"IRM (reg={reg:.3e}) has {error_val:.6f} validation error.")
        
        if total_val < best_err:
            best_err = total_val
            best_reg = reg
            best_phi = phi.clone()

    print(f"\nBest reg={best_reg:.3e} with validation error={total_val:.6f}")
    return best_phi, best_reg

def evaluate_model(phi, dummy_w, test_envs, val_envs):
    total_test = 0
    total_val = 0
    with torch.no_grad():
        mse = torch.nn.MSELoss()
        for t, v in zip(test_envs,val_envs):
            error_test = mse((t[0] @ phi) @ dummy_w, t[1])
            error_val = mse((v[0] @ phi) @ dummy_w, v[1])
            total_test += error_test
            total_val += error_val
        ratio = total_test / total_val
    print(f"\nValidation MSE: {total_val:.6f}")
    print(f"Test MSE: {total_test:.6f}")
    print(f"Validation/Test Ratio: {ratio:.6f}")

def main():
    device = detect_device()
    
    # Load datasets
    train_envs = to_torch(train_data, device)
    val_envs = to_torch(val_data, device)
    test_envs= to_torch(test_data, device)
    dim_x = train_envs[0][0].shape[1]
    reg_values = [1e-3]
    
    best_phi, best_reg = tune_regularization(train_envs,val_envs, dim_x, device, reg_values, iterations=3000)
    
    dummy_w = torch.ones(1, 1, device=device)
    evaluate_model(best_phi, dummy_w, test_envs, val_envs)
    print(best_phi)
if __name__ == "__main__":
    main()


Using device: cuda
torch.Size([10000, 4]) torch.Size([10000, 4]) torch.Size([10000, 4])
Iteration 0, Training Loss: 0.006103, Validation MSE: 2.185445
Iteration 500, Training Loss: 0.003209, Validation MSE: 0.985494
Iteration 1000, Training Loss: 0.001733, Validation MSE: 0.495343
Iteration 1500, Training Loss: 0.000981, Validation MSE: 0.263894
Iteration 2000, Training Loss: 0.000743, Validation MSE: 0.197901
Iteration 2500, Training Loss: 0.000688, Validation MSE: 0.185176
IRM (reg=1.000e-03) has 0.091965 validation error.

Best reg=1.000e-03 with validation error=0.183896
torch.Size([10000, 4])
torch.Size([10000, 4])

Validation MSE: 0.183896
Test MSE: 0.403704
Validation/Test Ratio: 2.195287
tensor([[0.9626],
        [0.9259],
        [0.0429],
        [0.0142]], device='cuda:0', grad_fn=<CloneBackward0>)
