In [1]:
import torch
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import Dataset, DataLoader, TensorDataset
import numpy as np

In [13]:
# data generation
N = 100 # data per domain
D = 100 # data dimension
batch_size = 128
ratio = 0.8 # ratio of the number of invariant features 
sigma_inv = 1
sigma_spu = 0.1
sigma_factor = [2, 100]


Z_1 = torch.normal(0, 1, (N, int(round(D*ratio)))).repeat(len(sigma_factor)+1,1)
theta_1 = torch.normal(0, sigma_inv, (int(round(D*ratio)), 1))
Y = ((Z_1 @ theta_1).squeeze() > 0).to(torch.float32) # Y is a binary vector with 0 and 1
Z_2 = [torch.normal(1/(D*(1-ratio))*(2*Y[:N]-1).unsqueeze(1).repeat(1,int(round(D*(1-ratio)))), sigma_spu).to(torch.float32)]
# Z_2 = [torch.normal((2*Y[:N]-1).unsqueeze(1).repeat(1,int(round(D*(1-ratio)))), sigma_spu).to(torch.float32)]
for d, factor in enumerate(sigma_factor):
    Z_2.append((Z_2[0]-1/(D*(1-ratio))*(2*Y[:N]-1).unsqueeze(1))*factor+2/D*(2*Y[:N]-1).unsqueeze(1))
Z_2 = torch.cat(Z_2)
print(Z_1.shape, Z_2.shape, Y.shape)
Z = torch.cat((Z_1,Z_2),1)
while True:
    matrix = torch.randn(D, D)  # Create a DxD matrix with random values between 0 and 1
    if torch.linalg.matrix_rank(matrix) == D:
        Q, R = torch.linalg.qr(matrix)
        break
X = Z @ Q

X_train, X_test = X[:N*(len(sigma_factor))], X[N*(len(sigma_factor)):]
Y_train, Y_test = Y[:N*(len(sigma_factor))], Y[N*(len(sigma_factor)):]
Z_train, Z_test = Z[:N*(len(sigma_factor))], Z[N*(len(sigma_factor)):]




class PairedDomainDataset(Dataset):
    def __init__(self, X, Y, Z, domain_labels):
        self.X = X
        self.Y = Y
        self.Z = Z
        self.domain_labels = domain_labels

    def __len__(self):
        return len(self.X) // 2  # Half the size since we're pairing

    def __getitem__(self, idx):
        # Get a pair of indices from the same domain
        idx1 = idx
        idx2 = idx + len(self.X) // 2
        return self.X[idx1], self.Y[idx1], self.X[idx2], self.Y[idx2]

# Generate domain labels (assuming the first half is one domain, the second half is the other)
domain_labels = np.concatenate([np.zeros(N * len(sigma_factor)), np.ones(N)])

# Create the paired dataset
paired_dataset = PairedDomainDataset(X_train, Y_train, Z_train, domain_labels)


paired_loader = DataLoader(paired_dataset, batch_size=batch_size, shuffle=True) 

loader = DataLoader(TensorDataset(X_train, Y_train), batch_size=batch_size, shuffle=True)


torch.Size([300, 80]) torch.Size([300, 20]) torch.Size([300])


In [14]:
# LogisticRegression
clf = LogisticRegression(random_state=0, fit_intercept=False).fit(X_train, Y_train)
clf.score(X_test, Y_test)

0.74

In [15]:
# Oracle (ERM using invariant feature only)
clf = LogisticRegression(random_state=0, fit_intercept=False).fit(Z_train[:,:int(round(D*ratio))], Y_train)
clf.score(Z_test[:,:int(round(D*ratio))], Y_test)

1.0

In [16]:
# A simple NN model.
import torch.nn as nn
import torch.optim as optim

class LR(nn.Module):
    def __init__(self, D_in):
        super(LR, self).__init__()
        self.linear_1 = nn.Linear(D_in, D_in)
        self.linear_2 = nn.Linear(D_in, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        return self.classifier(self.featurizer(x))

    def featurizer(self, x):
        x = self.linear_1(x)
        return x
    
    def classifier(self, x):
        x = self.linear_2(x)
        x = self.sigmoid(x)
        return x.squeeze()


In [8]:
# ERM using our NN model
epoch = 100

repeats = 20
sum_in = []
sum_out = []
for repeat in range(repeats):
    model = LR(D)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    model.train()
    for i in range(epoch):
        for j, (X, Y) in enumerate(loader):
            y = model(X)
            loss = criterion(y, Y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    model.eval()
    y_test = model(X_test)
    y_train = model(X_train)
    sum_in.append(float(((y_train>0.5)==Y_train).sum()/len(y_train)))
    sum_out.append(float(((y_test>0.5)==Y_test).sum()/len(y_test)))
    print(float(((y_train>0.5)==Y_train).sum()/len(y_train)), float(((y_test>0.5)==Y_test).sum()/len(y_test)))

print("---")
print("mean: " + str(torch.tensor(sum_out).mean().item()) + "std: " + str(torch.tensor(sum_out).std().item()))

1.0 0.6200000047683716
1.0 0.699999988079071
1.0 0.699999988079071
1.0 0.699999988079071
1.0 0.6499999761581421
1.0 0.6700000166893005
1.0 0.6800000071525574
1.0 0.6700000166893005
1.0 0.7099999785423279
1.0 0.7200000286102295
1.0 0.6800000071525574
1.0 0.7200000286102295
1.0 0.6700000166893005
1.0 0.6700000166893005
1.0 0.6600000262260437
1.0 0.6800000071525574
1.0 0.6399999856948853
1.0 0.6700000166893005
1.0 0.6499999761581421
1.0 0.7300000190734863
---
mean: 0.6794999837875366std: 0.02874113619327545


In [17]:
# CFM
epoch = 100

repeats = 20
sum_in = 0
sum_out = 0
for repeat in range(repeats):
    model = LR(D)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    lmda = 1000
    model.train()
    for i in range(epoch):
        for j, (X_1, Y_1, X_2, Y_2) in enumerate(paired_loader):
            z_1 = model.featurizer(X_1)
            z_2 = model.featurizer(X_2)
            y_1 = model.classifier(z_1)
            y_2 = model.classifier(z_2)
            loss_1 = criterion(torch.cat((y_1,y_2)), torch.cat((Y_1,Y_2)))
            # print(z_1.shape,z_2.shape)
            loss_2 = (torch.norm(z_1-z_2, p=2) / len(z_1)) ** 2
            # print(loss_1,loss_2)
            loss = loss_1 + lmda * loss_2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    model.eval()
    y_test = model(X_test)
    y_train = model(X_train)
    sum_in +=  float(((y_train>0.5)==Y_train).sum()/len(y_train))
    sum_out += float(((y_test>0.5)==Y_test).sum()/len(y_test))
    print(float(((y_train>0.5)==Y_train).sum()/len(y_train)), float(((y_test>0.5)==Y_test).sum()/len(y_test)))
    
print("---")
print(sum_in/repeats, sum_out/repeats)

1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
---
1.0 1.0


In [None]:
# Few shot CF Pair
epoch = 100

repeats = 20


shots = np.arange(1,100)
for shot in shots:
    sum_in = []
    sum_out = []

    for repeat in range(repeats):
        model = LR(D)
        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.01)

        lmda = 2000
        model.train()
        for i in range(epoch):
            for j, (X, Y) in enumerate(loader):
                X= torch.cat((X_train[0:shot], X, X_train[N:N+shot]))
                Y = torch.cat((Y_train[0:shot], Y, Y_train[N:N+shot]))
                z = model.featurizer(X)
                y = model.classifier(z)
                loss_1 = criterion(y, Y)
                loss_2 = (torch.norm(z[:shot]-z[-shot:], p=2) / shot)
                loss = loss_1 + lmda * loss_2
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        model.eval()
        y_test = model(X_test)
        y_train = model(X_train)
        sum_in.append(float(((y_train>0.5)==Y_train).sum()/len(y_train)))
        sum_out.append(float(((y_test>0.5)==Y_test).sum()/len(y_test)))
        # print(float(((y_train>0.5)==Y_train).sum()/len(y_train)), float(((y_test>0.5)==Y_test).sum()/len(y_test)))

    print("num_of_pair: " + str(shot) + "; mean: " + str(torch.tensor(sum_out).mean().item()) + "; std: " + str(torch.tensor(sum_out).std().item()))

num_of_pair: 1; mean: 0.6220000386238098; std: 0.038879022002220154
num_of_pair: 2; mean: 0.6720000505447388; std: 0.054733797907829285
num_of_pair: 3; mean: 0.6540001034736633; std: 0.05103146657347679
num_of_pair: 4; mean: 0.6609999537467957; std: 0.04216508939862251
num_of_pair: 5; mean: 0.6664999723434448; std: 0.04858795925974846
num_of_pair: 6; mean: 0.6639999747276306; std: 0.05761944130063057
num_of_pair: 7; mean: 0.6940000057220459; std: 0.06777672469615936
num_of_pair: 8; mean: 0.6995000243186951; std: 0.05679835006594658
num_of_pair: 9; mean: 0.6864999532699585; std: 0.055276818573474884
num_of_pair: 10; mean: 0.7024999260902405; std: 0.04666509851813316
num_of_pair: 11; mean: 0.7464999556541443; std: 0.040036171674728394
num_of_pair: 12; mean: 0.7430000305175781; std: 0.06489668786525726
num_of_pair: 13; mean: 0.7489999532699585; std: 0.060253847390413284
num_of_pair: 14; mean: 0.75; std: 0.06341176480054855
num_of_pair: 15; mean: 0.8075000047683716; std: 0.0685085430741310

In [10]:
# IRM

repeats = 20
sum_in = 0
sum_out = 0


import torch.autograd as autograd

scale = torch.tensor(1.).requires_grad_()


def irm_penalty(loss_0, loss_1):
    grad_0 = autograd.grad(loss_0.mean(), [scale], create_graph=True)[0]
    grad_1 = autograd.grad(loss_1.mean(), [scale], create_graph=True)[0]
    result = torch.sum(grad_0 * grad_1)
    del grad_0, grad_1
    return result

for repeat in range(repeats):
    model = LR(D)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    penalty_weight = 1

    model.train()
    for i in range(epoch):
        for j, (X_1, Y_1, X_2, Y_2) in enumerate(paired_loader):
            z_1 = model.featurizer(X_1)
            z_2 = model.featurizer(X_2)
            y_1 = model.classifier(z_1)
            y_2 = model.classifier(z_2)
            loss_1 = criterion(y_1*scale, Y_1)
            loss_2 = criterion(y_2*scale, Y_2)
            loss_3 = irm_penalty(loss_1, loss_2)
            # print(loss_1,loss_2)
            loss = loss_1 + lmda * loss_2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    model.eval()
    y_test = model(X_test)
    y_train = model(X_train)
    sum_in +=  float(((y_train>0.5)==Y_train).sum()/len(y_train))
    sum_out += float(((y_test>0.5)==Y_test).sum()/len(y_test))
    print(float(((y_train>0.5)==Y_train).sum()/len(y_train)), float(((y_test>0.5)==Y_test).sum()/len(y_test)))

print("---")
print(sum_in/repeats, sum_out/repeats)

1.0 0.7300000190734863
1.0 0.7400000095367432
1.0 0.7400000095367432
1.0 0.7699999809265137
1.0 0.7400000095367432
1.0 0.7200000286102295
1.0 0.699999988079071
1.0 0.7099999785423279
1.0 0.7099999785423279
1.0 0.7200000286102295
1.0 0.7400000095367432
1.0 0.7599999904632568
1.0 0.7400000095367432
1.0 0.7400000095367432
1.0 0.7400000095367432
1.0 0.6600000262260437
1.0 0.7400000095367432
1.0 0.7200000286102295
1.0 0.7300000190734863
1.0 0.7699999809265137
---
1.0 0.7310000061988831
