In [4]:
import torch
from torch import nn
from tqdm import tqdm

device = 'cpu'

## V1

### Dataset

In [7]:
def generate_dataset(pe, n_samples):
    # Invariant features ~ N(0, I)
    X_inv = torch.randn(n_samples, 2)

    # Label Y = x1 + x2 + noise
    noise = torch.randn(n_samples) * (0.1 ** 0.5)
    Y = X_inv.sum(dim=1) + noise  # Shape: (n_samples,)

    # Spurious features ~ N([Y, Y], pe * I)
    mean_spurious = torch.stack([Y, Y], dim=1)
    spurious_noise = torch.randn(n_samples, 2) * (pe ** 0.5)
    X_env = mean_spurious + spurious_noise

    # Concatenate invariant and spurious features
    X = torch.cat([X_inv, X_env], dim=1)
    # Reshape labels to (n_samples, 1)
    Y = Y.unsqueeze(1)
    return X, Y

def combine_datasets(datasets):
    X = torch.concat([dataset[0] for dataset in datasets])
    Y = torch.concat([dataset[1] for dataset in datasets])
    return X, Y

In [8]:
d_1 = generate_dataset(0.5, 4)
d_2 = generate_dataset(1.0, 4)
d_3 = generate_dataset(9.9, 4)
comb = combine_datasets([d_1, d_2, d_3])
d_1

(tensor([[ 0.5002, -0.8092,  0.0520,  1.2263],
         [ 0.5710,  0.8687,  1.2774,  2.7452],
         [-1.2138,  0.6937,  0.2596, -0.8624],
         [ 1.6968, -0.0557,  1.0721,  3.5960]]),
 tensor([[-0.1811],
         [ 1.6353],
         [-0.2467],
         [ 1.9391]]))

### Models

In [4]:
from torch import nn
from torch.nn import functional as F

out_features = 1

In [5]:
class FeatureExtractor(nn.Module):

    def __init__(self) -> None:
        super(FeatureExtractor, self).__init__()
        self.lin = nn.Linear(4,out_features,bias=False)

    def forward(self, x):
        return self.lin(x)

In [6]:
class Classifier(nn.Module):

    def __init__(self, mean, std, epsilon) -> None:
        super(Classifier, self).__init__()
        self.w = mean + std*epsilon

    def forward(self, x):
        return torch.matmul(x, self.w.T)

In [7]:
class AutoEncoder(nn.Module):

    def __init__(self) -> None:
        super(AutoEncoder, self).__init__()
        self.m_u = torch.nn.parameter.Parameter(torch.rand(1,out_features))
        self.std = torch.nn.parameter.Parameter(torch.rand(1,out_features))
        torch.nn.init.uniform_(self.m_u, -1, 1)
        torch.nn.init.uniform_(self.std, 0, 1)

    def recon_loss(self, X, Y, classifier, f_e): #minimize this
        return F.mse_loss(Y, classifier(f_e(X)))

    def KL_loss(self):   #maximize this
        return torch.sum(1 + torch.log(torch.square(self.std)) - torch.square(self.m_u) - torch.square(self.std))/2

    def sample(self, epsilon=None):
        if epsilon is None:
            epsilon =  torch.randn_like(self.m_u).to(device)
        return Classifier(self.m_u, self.std, epsilon)

    def fit(self, X, Y, f_e, epochs):
        optim = torch.optim.Adam([self.m_u, self.std], betas=(0.5, 0.5))
        for _ in range(epochs):
            loss = self.recon_loss(X, Y, self.sample(), f_e) - self.KL_loss()
            optim.zero_grad()
            loss.backward()
            optim.step()

### Train

In [22]:
pe_train = [0.1, 0.3, 0.5, 0.7, 0.9]
pe_val = [0.4, 0.8]
pe_test = [10, 100]
n_samples = 500

train_env = [generate_dataset(pe, n_samples) for pe in pe_train]
train_env.append(combine_datasets(train_env))
val_env = [generate_dataset(pe, n_samples) for pe in pe_val]
X_val = torch.cat([x for x, _ in val_env], dim=0)
Y_val = torch.cat([y for _, y in val_env], dim=0)
val_env = [(X_val, Y_val)]
test_env = [generate_dataset(pe, n_samples) for pe in pe_test]
X_test = torch.cat([x for x, _ in test_env], dim=0)
Y_test = torch.cat([y for _, y in test_env], dim=0)
test_env = [(X_test, Y_test)]

In [35]:
f_e = FeatureExtractor().to(device)
q_u = [AutoEncoder().to(device) for _ in train_env]

lamda = 5
with tqdm(range(1000)) as tepoch:
    for ep in tepoch:
        # Update posteriors distribution of classifiers for each training environment
        for q, D in zip(q_u, train_env):
            q.fit(D[0], D[1], f_e, 20)

        # Update feature extractor only
        loss = 0
        optim = torch.optim.Adam(f_e.parameters(), betas=(0.5, 0.5), lr=0.01)
        for _ in range(10):
            epsilon = torch.randn_like(q_u[-1].m_u).to(device)
            [ loss := loss + (1+lamda)*q_u[-1].recon_loss(D[0], D[1], q_u[-1].sample(epsilon), f_e)-lamda*q.recon_loss(D[0], D[1], q.sample(epsilon), f_e)  for q, D in zip(q_u[:-1], train_env[:-1]) ]
        loss = loss/10
        optim.zero_grad()
        loss.backward()
        optim.step()

        epsilon = torch.zeros_like(q_u[-1].m_u).to(device)
        test_acc = q_u[-1].recon_loss(test_env[0][0], test_env[0][1], q_u[-1].sample(epsilon), f_e)
        train_acc = q_u[-1].recon_loss(train_env[-1][0], train_env[-1][1], q_u[-1].sample(epsilon), f_e)
        tepoch.set_postfix_str(f'train:{train_acc}, test:{test_acc}')

        if ep%50==0:
            print(torch.matmul(q_u[-1].m_u, f_e.lin.weight))
            print(f_e.lin.weight)

  0%|          | 2/1000 [00:00<02:40,  6.22it/s, train:2.7753427028656006, test:3.357396125793457]

tensor([[-0.1106,  0.0435, -0.0517, -0.0903]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[-0.4330,  0.1702, -0.2025, -0.3537]], requires_grad=True)


  5%|▌         | 52/1000 [00:07<02:16,  6.96it/s, train:1.4322160482406616, test:1.6862764358520508]

tensor([[0.0146, 0.1525, 0.0614, 0.0144]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.0470, 0.4902, 0.1975, 0.0463]], requires_grad=True)


 10%|█         | 102/1000 [00:14<02:01,  7.39it/s, train:0.6221087574958801, test:1.4772703647613525]

tensor([[0.2831, 0.3205, 0.1191, 0.0747]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.3270, 0.3702, 0.1375, 0.0863]], requires_grad=True)


 15%|█▌        | 152/1000 [00:20<02:02,  6.90it/s, train:0.3509792387485504, test:1.5431060791015625] 

tensor([[0.4109, 0.4140, 0.1131, 0.1216]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.4270, 0.4302, 0.1175, 0.1263]], requires_grad=True)


 20%|██        | 202/1000 [00:28<01:50,  7.25it/s, train:0.2563991844654083, test:2.046274185180664]  

tensor([[0.4381, 0.4411, 0.1290, 0.1560]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.4670, 0.4702, 0.1375, 0.1663]], requires_grad=True)


 25%|██▌       | 252/1000 [00:34<01:40,  7.46it/s, train:0.1273856908082962, test:3.2789905071258545] 

tensor([[0.4605, 0.4816, 0.1612, 0.1692]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5070, 0.5302, 0.1775, 0.1863]], requires_grad=True)


 30%|███       | 301/1000 [00:41<01:48,  6.47it/s, train:0.3417434096336365, test:1.2421739101409912] 

tensor([[0.4472, 0.4502, 0.0896, 0.1160]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.4870, 0.4902, 0.0975, 0.1263]], requires_grad=True)


 35%|███▌      | 352/1000 [00:48<01:31,  7.12it/s, train:0.34626781940460205, test:1.2922898530960083]

tensor([[0.4446, 0.4477, 0.0929, 0.1203]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.4670, 0.4702, 0.0975, 0.1263]], requires_grad=True)


 40%|████      | 402/1000 [00:55<01:23,  7.16it/s, train:0.15628857910633087, test:2.424215316772461] 

tensor([[0.4845, 0.4875, 0.1265, 0.1529]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5270, 0.5302, 0.1375, 0.1663]], requires_grad=True)


 45%|████▌     | 452/1000 [01:03<01:16,  7.13it/s, train:0.2671554386615753, test:1.5219216346740723] 

tensor([[0.4417, 0.4627, 0.0885, 0.1146]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.4870, 0.5102, 0.0975, 0.1263]], requires_grad=True)


 50%|█████     | 502/1000 [01:10<01:11,  7.01it/s, train:0.17667031288146973, test:2.308046817779541] 

tensor([[0.4682, 0.4711, 0.1222, 0.1478]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5270, 0.5302, 0.1375, 0.1663]], requires_grad=True)


 55%|█████▌    | 552/1000 [01:18<01:01,  7.23it/s, train:0.18404708802700043, test:2.0895962715148926]

tensor([[0.4972, 0.4999, 0.1334, 0.1578]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5870, 0.5902, 0.1575, 0.1863]], requires_grad=True)


 60%|██████    | 602/1000 [01:24<00:52,  7.55it/s, train:0.36993879079818726, test:1.1244726181030273]

tensor([[0.4264, 0.4291, 0.0652, 0.0894]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5070, 0.5102, 0.0775, 0.1063]], requires_grad=True)


 65%|██████▌   | 652/1000 [01:31<00:47,  7.33it/s, train:0.19690154492855072, test:1.7407044172286987]

tensor([[0.4933, 0.4964, 0.1100, 0.1183]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5270, 0.5302, 0.1175, 0.1263]], requires_grad=True)


 70%|███████   | 702/1000 [01:38<00:42,  7.03it/s, train:0.24483706057071686, test:1.6063544750213623]

tensor([[0.4801, 0.4830, 0.1207, 0.1284]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5470, 0.5502, 0.1375, 0.1463]], requires_grad=True)


 75%|███████▌  | 752/1000 [01:45<00:34,  7.23it/s, train:0.21476979553699493, test:1.5810916423797607]

tensor([[0.4784, 0.4987, 0.1028, 0.1105]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5470, 0.5702, 0.1175, 0.1263]], requires_grad=True)


 80%|████████  | 802/1000 [01:52<00:26,  7.60it/s, train:0.3476116359233856, test:0.953772783279419]  

tensor([[0.4785, 0.4812, 0.0823, 0.0897]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5670, 0.5702, 0.0975, 0.1063]], requires_grad=True)


 85%|████████▌ | 852/1000 [01:58<00:20,  7.05it/s, train:0.1763838529586792, test:1.5947726964950562] 

tensor([[0.5211, 0.5240, 0.1043, 0.1122]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5870, 0.5902, 0.1175, 0.1263]], requires_grad=True)


 90%|█████████ | 902/1000 [02:05<00:13,  7.23it/s, train:0.25676387548446655, test:1.0496997833251953]

tensor([[0.5283, 0.5314, 0.0942, 0.1027]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5470, 0.5502, 0.0975, 0.1063]], requires_grad=True)


 95%|█████████▌| 952/1000 [02:13<00:07,  6.76it/s, train:0.22807025909423828, test:1.3618232011795044]

tensor([[0.4993, 0.5024, 0.0924, 0.1007]], grad_fn=<MmBackward0>)
Parameter containing:
tensor([[0.5270, 0.5302, 0.0975, 0.1063]], requires_grad=True)


100%|██████████| 1000/1000 [02:20<00:00,  7.14it/s, train:0.1296904981136322, test:2.4738852977752686]


In [29]:
f_e.lin.weight, q_u[-1].m_u

(Parameter containing:
 tensor([[-0.5433, -0.4727, -0.0763, -0.1268]], requires_grad=True),
 Parameter containing:
 tensor([[-0.8624]], requires_grad=True))

## V2

In [16]:
E = 1
D = 0.1

### Models

In [65]:
class EBD(nn.Module):
    def __init__(self, d_env):
        super(EBD, self).__init__()
        self.embedings = torch.nn.Embedding(d_env, 1).to(device)
        self.re_init()

    def re_init(self):
        self.embedings.weight.data.fill_(1.)

    def re_init_with_noise(self, d_env):
        rd = torch.normal(
            torch.Tensor([E] * d_env),
            torch.Tensor([D] * d_env))
        self.embedings.weight.data = rd.view(-1, 1).to(device)

    def forward(self, e):
        return self.embedings(e.long())

class NeuralTest(nn.Module):
    def __init__(self, dim = 5):
        super(NeuralTest, self).__init__()
        self.fc1 = nn.Linear(in_features=4, out_features=12)
        # self.fc1 = nn.Linear(in_features=2 + dim, out_features=4 + dim*2)
        self.fc2 = nn.Linear(in_features=12, out_features=2 + dim//2)
        self.drop = nn.Dropout(p = 0.2)
        self.fc3 = nn.Linear(in_features=2 + dim//2, out_features=1)
        self.selu = nn.SELU()

    def forward(self, z):
        x = self.selu(self.fc1(z))
        x = self.selu(self.fc2(x))
        x = self.drop(x)
        return self.fc3(x)

class LinearTest(nn.Module):
    def __init__(self, dim):
        super(LinearTest, self).__init__()
        self.fc1 = nn.Linear(in_features=4, out_features=1)
    def forward(self, z):
        return self.fc1(z)

### Train

In [82]:
pe_train = [0.1, 0.3, 0.5, 0.7, 0.9]
pe_val = [0.4, 0.8]
pe_test = [10, 100]
n_samples = 1000

train_env = [generate_dataset(pe, n_samples) for pe in pe_train]
# train_env.append(combine_datasets(train_env))

val_env = [generate_dataset(pe, n_samples) for pe in pe_val]
X_val = torch.cat([x for x, _ in val_env], dim=0)
Y_val = torch.cat([y for _, y in val_env], dim=0)
val_env = [(X_val, Y_val)]

test_env = [generate_dataset(pe, n_samples) for pe in pe_test]
X_test = torch.cat([x for x, _ in test_env], dim=0)
Y_test = torch.cat([y for _, y in test_env], dim=0)
test_env = [(X_test, Y_test)]

In [144]:
d_env = 5
lamda = 100
annealing_epochs = 1000
ebd = EBD(d_env)
model = LinearTest(d_env)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3)

In [145]:
len(train_env)

5

In [146]:
for epoch in tqdm(range(5000)):
    n_sample = 3
    loss = 0
    penalty = 0
    for j in range(n_sample):
        ebd.re_init_with_noise(d_env)
        loss_list = []
        for i_env, env in enumerate(train_env):
            # X_train_e = torch.from_numpy(env[0].astype('float32')).to(device)
            # Y_train_e = torch.from_numpy(env[1].astype('float32')).to(device)
            X_train_e = env[0]
            Y_train_e = env[1]
            y_pred_e = model(X_train_e)
            G_train_e = torch.ones_like(Y_train_e) * i_env
            y_pred_e_w = y_pred_e * ebd(G_train_e).view(-1,1)
            loss_e = criterion(y_pred_e_w, Y_train_e)
        loss_list.append(loss_e)
        loss_t = torch.stack(loss_list)
        train_penalty0 = ((loss_t - loss_t.mean())**2).mean()
        loss += loss_t.mean() / n_sample
        penalty += train_penalty0 / n_sample
    loss += penalty * (lamda if epoch > annealing_epochs else 0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % 500 == 0:
        val_mse = 0
        for i_env, env in enumerate(val_env):
            X_val = env[0]
            Y_val = env[1]
            y_pred = model(X_val)
            val_mse += (y_pred - Y_val).pow(2).mean()
        test_mse = 0
        for i_env, env in enumerate(test_env):
            X_test = env[0]
            Y_test = env[1]
            y_pred = model(X_test)
            test_mse += (y_pred - Y_test).pow(2).mean()

        print(f'train_loss: {loss:.4f}, penalty: {penalty}, val_mse:{val_mse:.4f}, test_mse: {test_mse:.4f}')
print(model.fc1.weight)

  0%|          | 24/5000 [00:00<00:21, 228.33it/s]

train_loss: 2.9508, penalty: 0.0, val_mse:2.7837, test_mse: 9.8602


 10%|█         | 517/5000 [00:02<00:17, 256.79it/s]

train_loss: 0.3639, penalty: 0.0, val_mse:0.2708, test_mse: 14.0961


 21%|██        | 1048/5000 [00:04<00:16, 241.49it/s]

train_loss: 0.2167, penalty: 0.0, val_mse:0.1549, test_mse: 10.6605


 30%|███       | 1525/5000 [00:06<00:13, 250.45it/s]

train_loss: 0.1492, penalty: 0.0, val_mse:0.1092, test_mse: 6.7428


 40%|████      | 2017/5000 [00:08<00:13, 226.97it/s]

train_loss: 0.1553, penalty: 0.0, val_mse:0.0870, test_mse: 4.3798


 51%|█████     | 2528/5000 [00:10<00:09, 254.49it/s]

train_loss: 0.1276, penalty: 0.0, val_mse:0.0773, test_mse: 2.8674


 61%|██████    | 3037/5000 [00:12<00:07, 246.10it/s]

train_loss: 0.0887, penalty: 0.0, val_mse:0.0725, test_mse: 2.0450


 71%|███████   | 3550/5000 [00:14<00:05, 255.78it/s]

train_loss: 0.0824, penalty: 0.0, val_mse:0.0721, test_mse: 1.4774


 81%|████████  | 4026/5000 [00:16<00:04, 209.11it/s]

train_loss: 0.1269, penalty: 0.0, val_mse:0.0722, test_mse: 1.2258


 91%|█████████ | 4533/5000 [00:19<00:01, 248.13it/s]

train_loss: 0.1068, penalty: 0.0, val_mse:0.0726, test_mse: 1.0898


100%|██████████| 5000/5000 [00:20<00:00, 238.66it/s]

Parameter containing:
tensor([[0.7930, 0.7981, 0.0860, 0.0900]], requires_grad=True)





NameError: name 'generate_dataset' is not defined