In [1]:
import sys
import torch
import numpy as np
sys.path.append("..")
from src import mdata, mmodel, method
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### Define Dataloader

Here we generate data from SCM III in the paper, where label-flipping features exist.

In [2]:
# Simulation settings; these names are read again by later cells.
# M is the number of domains (the evaluation cells treat the first M-1 as
# sources and the last as the target).
# NOTE(review): n and d look like samples-per-domain and feature dimension --
# confirm against mdata.simu.SCM_3.
nb_classes = 2
M, n, d = 12, 1000, 18

# Seed NumPy and PyTorch before sampling so the simulated data is reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Simulate the data from SCM_3. The last dataloader is the target domain.
dataloaders = mdata.simu.SCM_3(M=M, n=n, d=d)

### Define model

In [3]:
# Simple linear model, built first and then moved onto the selected device.
# NOTE(review): presumably maps d input features to nb_classes outputs --
# confirm against mmodel.Linear.
model = mmodel.Linear(d, nb_classes)
model = model.to(device)

### CIP

In [4]:
# Fit CIP on the source domains only; the target dataloader (the last one)
# is held out and only used for evaluation afterwards.
alg = method.CIP(
    device=device,
    model=model,
    lamCIP=1.,
    discrepType='mean',
    nb_classes=nb_classes,         # reuse the notebook-level setting so it cannot drift from the model head
    loss_type='CrossEntropyLoss',
    optimizer='Adam',
    lr=1e-2,
)
result = alg.fit(
    dataloaders=dataloaders[:-1],  # CIP uses only source domains
    grouper=None,                  # set grouper only for WILD dataset (e.g. camelyon)
    tarId=None,                    # index of target domain. Default to -1
    epochs=50,
    verbose_every=10
)

Epoch 9: 100%|██████████| 10/10 [00:00<00:00, 10.54it/s, epoch_loss=0.42732008]
Epoch 19: 100%|██████████| 10/10 [00:00<00:00, 10.77it/s, epoch_loss=0.42399898]
Epoch 29: 100%|██████████| 10/10 [00:00<00:00, 12.30it/s, epoch_loss=0.42386653]
Epoch 39: 100%|██████████| 10/10 [00:01<00:00,  9.23it/s, epoch_loss=0.42381443]
Epoch 49: 100%|██████████| 10/10 [00:01<00:00,  9.43it/s, epoch_loss=0.42378902]


In [5]:
print("CIP results")
for i in range(M):
    # predict_dataloader yields three values; only the accuracy is reported.
    ypreds, acc, correct = alg.predict_dataloader(dataloaders[i])
    # Every domain except the last is a source domain; the last one is the
    # target and is printed without an index.
    is_source = i < M - 1
    domain_kind = 'Source' if is_source else 'Target'
    domain_no = i + 1 if is_source else ''
    print(f"{domain_kind} domain {domain_no:>2} accuracy: {acc*100:.2f}%")

CIP results
Source domain  1 accuracy: 84.00%
Source domain  2 accuracy: 83.10%
Source domain  3 accuracy: 85.30%
Source domain  4 accuracy: 83.70%
Source domain  5 accuracy: 82.00%
Source domain  6 accuracy: 84.30%
Source domain  7 accuracy: 83.40%
Source domain  8 accuracy: 84.20%
Source domain  9 accuracy: 81.00%
Source domain 10 accuracy: 83.60%
Source domain 11 accuracy: 83.50%
Target domain    accuracy: 81.90%


CIP achieves reasonable accuracy on the target domain (81.90%), comparable to its accuracy on the source domains.

### DIP

In [6]:
# Start from a freshly initialised model so DIP does not inherit the weights
# trained in the previous section.
model = mmodel.Linear(d, nb_classes).to(device)
alg = method.DIP(
    device=device,
    model=model,
    lamDIP=1.,
    discrepType='mean',
    nb_classes=nb_classes,         # reuse the notebook-level setting so it cannot drift from the model head
    loss_type='CrossEntropyLoss',
    optimizer='Adam',
    lr=1e-2,
)
result = alg.fit(
    dataloaders=[dataloaders[0], dataloaders[-1]],  # DIP uses one source domain and one target domain
                                                    # Can provide more for DIP-Pool, e.g. [dataloaders[0], dataloaders[1], dataloaders[-1]]
    grouper=None,
    tarId=None,
    epochs=50,
    verbose_every=10
)

Epoch 9: 100%|██████████| 10/10 [00:00<00:00, 131.91it/s, epoch_loss=0.14145027]
Epoch 19: 100%|██████████| 10/10 [00:00<00:00, 157.21it/s, epoch_loss=0.08715069]
Epoch 29: 100%|██████████| 10/10 [00:00<00:00, 165.18it/s, epoch_loss=0.06564525]
Epoch 39: 100%|██████████| 10/10 [00:00<00:00, 270.69it/s, epoch_loss=0.05497827]
Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 276.86it/s, epoch_loss=0.04896626]


In [7]:
print("DIP results")
for i in range(M):
    # predict_dataloader yields three values; only the accuracy is reported.
    ypreds, acc, correct = alg.predict_dataloader(dataloaders[i])
    # Every domain except the last is a source domain; the last one is the
    # target and is printed without an index.
    is_source = i < M - 1
    domain_kind = 'Source' if is_source else 'Target'
    domain_no = i + 1 if is_source else ''
    print(f"{domain_kind} domain {domain_no:>2} accuracy: {acc*100:.2f}%")

DIP results
Source domain  1 accuracy: 100.00%
Source domain  2 accuracy: 100.00%
Source domain  3 accuracy: 100.00%
Source domain  4 accuracy: 99.00%
Source domain  5 accuracy: 99.70%
Source domain  6 accuracy: 90.90%
Source domain  7 accuracy: 50.60%
Source domain  8 accuracy: 36.10%
Source domain  9 accuracy: 23.10%
Source domain 10 accuracy: 10.10%
Source domain 11 accuracy: 36.90%
Target domain    accuracy: 9.80%


The target accuracy of DIP (9.80%) is even lower than random guessing, because it incorrectly picks up the label-flipping features.

### JointDIP

In [8]:
# Start from a freshly initialised model so JointDIP does not inherit the
# weights trained in the previous section.
model = mmodel.Linear(d, nb_classes).to(device)
alg = method.CIP_JointCIPDIP(
    device=device,
    model=model,
    modelA=None,               # model for joint matching features. Default to copy of model
    pretrained_modelA=False,   # if False, train CIP on modelA first. Otherwise, skip the CIP training step.
    lamCIP_A=1.,               # lambda for CIP
    lamDIP=10.,                # lambda for jointDIP
    discrepType='MMD',
    nb_classes=nb_classes,     # reuse the notebook-level setting so it cannot drift from the model head
    loss_type='CrossEntropyLoss',
    optimizer='Adam',
    lr=1e-2,
)
result = alg.fit(
    dataloaders=dataloaders,    # all dataloaders are used in JointDIP
    grouper=None,
    srcIds=[0],                 # indices of source domains used for final joint matching
    tarId=-1,
    epochs=50,
    verbose_every=10
)

Epoch 9: 100%|██████████| 10/10 [00:01<00:00,  6.47it/s, epoch_loss=0.54299492]
Epoch 19: 100%|██████████| 10/10 [00:01<00:00,  5.40it/s, epoch_loss=0.50208367]
Epoch 29: 100%|██████████| 10/10 [00:01<00:00,  5.99it/s, epoch_loss=0.48868067]
Epoch 39: 100%|██████████| 10/10 [00:01<00:00,  6.06it/s, epoch_loss=0.48273060]
Epoch 49: 100%|██████████| 10/10 [00:01<00:00,  5.52it/s, epoch_loss=0.47849036]
Epoch 9: 100%|██████████| 10/10 [00:00<00:00, 172.47it/s, epoch_loss=0.69390416]
Epoch 19: 100%|██████████| 10/10 [00:00<00:00, 185.60it/s, epoch_loss=0.64479875]
Epoch 29: 100%|██████████| 10/10 [00:00<00:00, 132.15it/s, epoch_loss=0.62938417]
Epoch 39: 100%|██████████| 10/10 [00:00<00:00, 120.98it/s, epoch_loss=0.62331256]
Epoch 49: 100%|██████████| 10/10 [00:00<00:00, 111.72it/s, epoch_loss=0.62047727]


In [9]:
print("JointDIP results")
for i in range(M):
    # predict_dataloader yields three values; only the accuracy is reported.
    ypreds, acc, correct = alg.predict_dataloader(dataloaders[i])
    # Every domain except the last is a source domain; the last one is the
    # target and is printed without an index.
    is_source = i < M - 1
    domain_kind = 'Source' if is_source else 'Target'
    domain_no = i + 1 if is_source else ''
    print(f"{domain_kind} domain {domain_no:>2} accuracy: {acc*100:.2f}%")

JointDIP results
Source domain  1 accuracy: 85.50%
Source domain  2 accuracy: 81.10%
Source domain  3 accuracy: 86.60%
Source domain  4 accuracy: 84.30%
Source domain  5 accuracy: 82.90%
Source domain  6 accuracy: 80.70%
Source domain  7 accuracy: 73.60%
Source domain  8 accuracy: 81.40%
Source domain  9 accuracy: 82.80%
Source domain 10 accuracy: 82.90%
Source domain 11 accuracy: 80.90%
Target domain    accuracy: 84.10%


JointDIP solves the problem of DIP by matching the DIP features jointly with the CIP features. Its target accuracy (84.10%) is better than that of CIP (81.90%).