# Spuriosity Didn’t Kill the Classifier: Using Invariant Predictions to Harness Spurious Features

In [123]:
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def rademacher(beta, size):
    """Generate Rademacher random variables with probability `beta` for +1 and `1-beta` for -1."""
    return np.where(np.random.rand(size) < beta, 1, -1)

def generate_dataset(beta_e, num_samples):
    """Generate dataset for given `beta_e` and number of samples."""
    # Generate Y ~ Rad(0.5)
    Y = rademacher(0.5, num_samples)
    
    # Generate X_S = Y * Rad(0.75)
    X_S = Y * rademacher(0.75, num_samples)
    
    # Generate X_U = Y * Rad(beta_e)
    X_U = Y * rademacher(beta_e, num_samples)
    
    # Combine X_S and X_U into X
    X = np.stack((X_S, X_U), axis=1)

    X_tensor = torch.tensor(X, dtype=torch.float32)
    Y_tensor = torch.tensor(Y, dtype=torch.float32).unsqueeze(1)
    Y_tensor = (Y_tensor + 1) / 2
    return X_tensor, Y_tensor

# Parameters
num_samples = 1000  # Number of samples per domain
train_betas = [0.95, 0.7]
val_beta = 0.6
test_beta = 0.1
batch_size = 32

# Generate datasets
train_domains = [generate_dataset(beta, num_samples) for beta in train_betas]
val_domain = generate_dataset(val_beta, num_samples)
test_domain = generate_dataset(test_beta, num_samples)

# Concatenate the training data from different domains
X_train = torch.cat([X for X, _ in train_domains], dim=0)
Y_train = torch.cat([Y for _, Y in train_domains], dim=0)

train_datasets = [TensorDataset(X, Y) for X, Y in train_domains]
val_dataset = TensorDataset(*val_domain)
test_dataset = TensorDataset(*test_domain)

train_loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=True) for dataset in train_datasets]
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Create a combined training dataset and dataloader
train_dataset = TensorDataset(X_train, Y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [124]:
# Get an iterator from the DataLoader
data_iter = iter(train_loader)

# Get the first batch (or sample) from the iterator
X_sample, Y_sample = next(data_iter)

# Print the sample
print("X_sample:\n", X_sample.shape)
print("Y_sample:\n", Y_sample.shape)

X_sample:
 torch.Size([32, 2])
Y_sample:
 torch.Size([32, 1])


In [125]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define the simple three-layer neural network model
class SimpleNN(nn.Module):
    def __init__(self, input_size=2, hidden_size=8, output_size=1):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # Use sigmoid to output a probability
        return x

# Initialize the model, loss function, and optimizer
model = SimpleNN(input_size=1, hidden_size=8, output_size=1)
criterion = nn.BCELoss()  # Binary cross-entropy loss for binary classification
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 50
best_val_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    for X_batch, Y_batch in train_loader:
        optimizer.zero_grad()  # Clear gradients
        outputs = model(X_batch[:,0:1])  # Forward pass
        loss = criterion(outputs, Y_batch)  # Compute loss
        loss.backward()  # Backward pass
        optimizer.step()  # Update weights
    
    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for X_batch, Y_batch in val_loader:
            outputs = model(X_batch[:,0:1])
            loss = criterion(outputs, Y_batch)
            val_loss += loss.item()

    val_loss /= len(val_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}')
    
    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = model.state_dict()

# Load the best model for testing
model.load_state_dict(best_model_state)

# Testing
model.eval()

train_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for X_batch, Y_batch in train_loader:
        outputs = model(X_batch[:,0:1])
        loss = criterion(outputs, Y_batch)
        train_loss += loss.item()

        # Convert predictions to binary (0 or 1)
        predicted = (outputs > 0.5).float()
        correct += (predicted == Y_batch).sum().item()
        total += Y_batch.size(0)

train_loss /= len(train_loader)
accuracy = correct / total
print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {accuracy:.4f}')


test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for X_batch, Y_batch in test_loader:
        outputs = model(X_batch[:,0:1])
        loss = criterion(outputs, Y_batch)
        test_loss += loss.item()

        # Convert predictions to binary (0 or 1)
        predicted = (outputs > 0.5).float()
        correct += (predicted == Y_batch).sum().item()
        total += Y_batch.size(0)

test_loss /= len(test_loader)
accuracy = correct / total
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.4f}')

Epoch [1/50], Validation Loss: 0.7090
Epoch [2/50], Validation Loss: 0.7038
Epoch [3/50], Validation Loss: 0.6988
Epoch [4/50], Validation Loss: 0.6939
Epoch [5/50], Validation Loss: 0.6892
Epoch [6/50], Validation Loss: 0.6847
Epoch [7/50], Validation Loss: 0.6802
Epoch [8/50], Validation Loss: 0.6755
Epoch [9/50], Validation Loss: 0.6710
Epoch [10/50], Validation Loss: 0.6667
Epoch [11/50], Validation Loss: 0.6624
Epoch [12/50], Validation Loss: 0.6580
Epoch [13/50], Validation Loss: 0.6538
Epoch [14/50], Validation Loss: 0.6493
Epoch [15/50], Validation Loss: 0.6452
Epoch [16/50], Validation Loss: 0.6409
Epoch [17/50], Validation Loss: 0.6366
Epoch [18/50], Validation Loss: 0.6325
Epoch [19/50], Validation Loss: 0.6285
Epoch [20/50], Validation Loss: 0.6245
Epoch [21/50], Validation Loss: 0.6207
Epoch [22/50], Validation Loss: 0.6171
Epoch [23/50], Validation Loss: 0.6136
Epoch [24/50], Validation Loss: 0.6104
Epoch [25/50], Validation Loss: 0.6071
Epoch [26/50], Validation Loss: 0.

In [126]:
# Initialize the model, loss function, and optimizer
modelU = SimpleNN(input_size=1, hidden_size=8, output_size=1)
criterionU = nn.BCELoss()  # Binary cross-entropy loss for binary classification
optimizerU = optim.Adam(modelU.parameters(), lr=1e-4)

# Training loop
num_epochs = 50
best_val_loss = float('inf')

for epoch in range(num_epochs):
    modelU.train()
    model.eval()
    for X_batch, _ in test_loader:
        optimizerU.zero_grad()  # Clear gradients
        Y_batch = model(X_batch[:,0:1])  # Forward pass
        Y_batch = (Y_batch > 0.5).float()
        outputs = modelU(X_batch[:,1:2])  # Forward pass
        lossU = criterionU(outputs, Y_batch)  # Compute loss
        lossU.backward()  # Backward pass
        optimizerU.step()  # Update weights
    
    # Validation
    modelU.eval()
    val_loss = 0.0
    with torch.no_grad():
        for X_batch, Y_batch in val_loader:
            Y_batch = model(X_batch[:,0:1])  # Forward pass
            Y_batch = (Y_batch > 0.5).float()
            outputs = modelU(X_batch[:,1:2])  # Forward pass
            loss = criterion(outputs, Y_batch)
            val_loss += loss.item()

    val_loss /= len(val_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}')
    
    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_stateU = modelU.state_dict()
    best_model_stateU = modelU.state_dict()

# Load the best model for testing
modelU.load_state_dict(best_model_stateU)

# Testing
model.eval()
modelU.eval()


train_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for X_batch, _ in train_loader:
        Y_batch = model(X_batch[:,0:1])  # Forward pass
        Y_batch = (Y_batch > 0.5).float()
        outputs = modelU(X_batch[:,1:2])  # Forward pass
        loss = criterion(outputs, Y_batch)
        train_loss += loss.item()

        # Convert predictions to binary (0 or 1)
        predicted = (outputs > 0.5).float()
        correct += (predicted == Y_batch).sum().item()
        total += Y_batch.size(0)

train_loss /= len(train_loader)
accuracy = correct / total
print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {accuracy:.4f}')


test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for X_batch, _ in test_loader:
        Y_batch = model(X_batch[:,0:1])  # Forward pass
        Y_batch = (Y_batch > 0.5).float()
        outputs = modelU(X_batch[:,1:2])  # Forward pass
        loss = criterion(outputs, Y_batch)
        test_loss += loss.item()

        # Convert predictions to binary (0 or 1)
        predicted = (outputs > 0.5).float()
        correct += (predicted == Y_batch).sum().item()
        total += Y_batch.size(0)

test_loss /= len(test_loader)
accuracy = correct / total
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {accuracy:.4f}')

Epoch [1/50], Validation Loss: 0.7043
Epoch [2/50], Validation Loss: 0.7053
Epoch [3/50], Validation Loss: 0.7064
Epoch [4/50], Validation Loss: 0.7074
Epoch [5/50], Validation Loss: 0.7084
Epoch [6/50], Validation Loss: 0.7095
Epoch [7/50], Validation Loss: 0.7107
Epoch [8/50], Validation Loss: 0.7118
Epoch [9/50], Validation Loss: 0.7130
Epoch [10/50], Validation Loss: 0.7143
Epoch [11/50], Validation Loss: 0.7155
Epoch [12/50], Validation Loss: 0.7168
Epoch [13/50], Validation Loss: 0.7181
Epoch [14/50], Validation Loss: 0.7195
Epoch [15/50], Validation Loss: 0.7209
Epoch [16/50], Validation Loss: 0.7223
Epoch [17/50], Validation Loss: 0.7237
Epoch [18/50], Validation Loss: 0.7251
Epoch [19/50], Validation Loss: 0.7266
Epoch [20/50], Validation Loss: 0.7281
Epoch [21/50], Validation Loss: 0.7296
Epoch [22/50], Validation Loss: 0.7311
Epoch [23/50], Validation Loss: 0.7326
Epoch [24/50], Validation Loss: 0.7342
Epoch [25/50], Validation Loss: 0.7357
Epoch [26/50], Validation Loss: 0.

In [127]:
model.load_state_dict(best_model_state)
model.eval()
modelU.load_state_dict(best_model_stateU)
modelU.eval()

def f(which_loader, name):
    PY = 0
    n1 = 0
    n  = 0
    e0 = 0
    e1 = 0

    with torch.no_grad():
        for X_batch, Y_batch in which_loader:
            PY += Y_batch.sum().item()
            Y_batch = model(X_batch[:,0:1])  # Forward pass
            Y_batch = (Y_batch > 0.5).float()
            outputs = modelU(X_batch[:,1:2])  # Forward pass

            e0 += ((1-Y_batch)*(1-outputs)).sum().item()
            e1 += (Y_batch*outputs).sum().item()

            n1 += Y_batch.sum().item()
            n += Y_batch.size(0)
    e0 = e0 / (n-n1)
    e1 = e1 / n1
    PY = PY / n

    correct = 0
    total = 0
    OOD = 0
    with torch.no_grad():
        for X_batch, Y_batch in which_loader:
            YY_batch = model(X_batch[:,0:1])  # Forward pass
            Xlogit = torch.logit(YY_batch, eps=1e-6)
            YY_batch = (YY_batch > 0.5).float()
            outputs = modelU(X_batch[:,1:2])  # Forward pass

            outputs = (outputs + e0 - 1) / (e1 + e0 - 1)
            outputs = torch.clamp(outputs, min=0, max=1)
            Ulogit = torch.logit(outputs, eps=1e-6)

            predict = torch.sigmoid(Xlogit + Ulogit - np.log(PY / (1 - PY)))

            # Convert predictions to binary (0 or 1)
            predicted = (predict > 0.5).float()
            OOD += predicted.sum().item()
            correct += (predicted == Y_batch).sum().item()
            total += Y_batch.size(0)
    OOD = OOD / total
    accuracy = correct / total
    print(name + f' Accuracy: {accuracy:.4f}')

f(train_loader, 'train')
f(test_loader, 'test')

# PY, n1, n, e0, e1, torch.logit(outputs, eps=1e-6)

train Accuracy: 0.8325
test Accuracy: 0.8890
