In [None]:
from DomainBed.domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader
from DomainBed.domainbed.datasets import get_dataset_class
from networks import CMNIST_MLP, Classifier, ClassifierV2, MMDClassifier
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
from utils import AverageMeter, compute_acc
from dataloader_factory import get_kfold_cross_validation_splits

# Number of repetitions: passed to get_kfold_cross_validation_splits as the
# split count and used as the stride / loop count when subsampling the
# environments for the EMD estimate.
n_repeats = 5

In [None]:
# Dataset configuration with identical settings in both list slots
# (presumably one slot per environment -- confirm against the dataset class):
# env p = 0.1, blue mean = 0.0, blue std = 0.0.
# NOTE: the following cell rebinds `hparams`, `dataset_class` and `dataset`.
hparams = dict(
    data_augmentation=False,
    cmnist_env_ps=[0.1, 0.1],      # args.cmnist_env_ps
    cmnist_blue_means=[0.0, 0.0],  # args.cmnist_blue_means
    cmnist_blue_stds=[0.0, 0.0],   # args.cmnist_blue_stds
)
dataset_class = get_dataset_class('ColoredMNIST_IRM_Blue')
dataset = dataset_class('/home/ma-user/work/kaican/dataset', [], hparams)

In [None]:
# Dataset configuration whose two list slots differ only in the blue mean
# (0.0 vs 1.0); env p = 0.5 and blue std = 0.1 in both slots.
# NOTE: this rebinds `hparams`, `dataset_class` and `dataset` from the cell above.
hparams = dict(
    data_augmentation=False,
    cmnist_env_ps=[0.5, 0.5],      # args.cmnist_env_ps
    cmnist_blue_means=[0.0, 1.0],  # args.cmnist_blue_means
    cmnist_blue_stds=[0.1, 0.1],   # args.cmnist_blue_stds
)
dataset_class = get_dataset_class('ColoredMNIST_IRM_Blue')
dataset = dataset_class('/home/ma-user/work/kaican/dataset', [], hparams)

In [None]:
def compute_NI(model, envs):
    ''' Computes Non-I.I.D. Index (https://arxiv.org/pdf/1906.02899.pdf)

    Every example of each environment is embedded with ``model.backbone``. The
    index is the L2 norm of the per-feature difference between the two mean
    embeddings, where each feature is divided by its standard deviation over
    the embeddings of both environments pooled together.

    Args:
        model: a torch module exposing a ``backbone`` callable that maps a
            stacked batch of inputs to a 2-D embedding tensor. It is switched
            to eval mode and left in eval mode.
        envs: a sequence of exactly two environments (first is treated as the
            training side, second as the test side). Each environment must
            support ``len()`` and integer indexing, with the input tensor as
            the first element of every item.
    Returns:
        The Non-I.I.D. Index (a numpy scalar).
    '''
    model.eval()

    # Run on whatever device the model lives on instead of assuming CUDA.
    # A parameter-free model leaves the inputs where they already are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    embs = []
    # Inference only: no_grad avoids building an autograd graph for the
    # whole environment, which is embedded in a single forward pass.
    with torch.no_grad():
        for env in envs:
            x = torch.stack([env[j][0] for j in range(len(env))])
            if device is not None:
                x = x.to(device)
            embs.append(model.backbone(x).detach().cpu().numpy())

    # Raises ValueError unless exactly two environments were given.
    embs_tr, embs_te = embs
    embs_all = np.concatenate([embs_tr, embs_te])

    mean_tr = embs_tr.mean(0)
    mean_te = embs_te.mean(0)
    std_emb = embs_all.std(0)

    # The epsilon keeps constant features (zero std) from dividing by zero.
    normalized = (mean_tr - mean_te) / (std_emb + 1e-8)
    return np.linalg.norm(normalized, ord=2)

In [None]:
# NI experiment: fit one classifier per split, then score it with compute_NI.
batch_size = 64
num_workers = 8
n_steps = 1000
lr = 0.01

splits = get_kfold_cross_validation_splits(dataset, n_repeats, seed=0)

results = []
for k, (tr_envs, vl_envs) in enumerate(splits):
    model = Classifier(CMNIST_MLP(), 2).cuda()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=0.9,
        nesterov=True,
        weight_decay=0.0005,
    )
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_steps)

    # Only the first environment of each side is used for fitting / monitoring.
    tr_env, vl_env = tr_envs[0], vl_envs[0]

    tr_loader = InfiniteDataLoader(tr_env, None, batch_size, num_workers)
    tr_iter = iter(tr_loader)
    vl_loader = FastDataLoader(vl_env, 128, 16)

    # zip() stops once the step counter is exhausted, so exactly n_steps
    # batches are drawn from the infinite iterator.
    for i, (x, y) in zip(range(1, n_steps + 1), tr_iter):
        x, y = x.cuda(), y.cuda()

        logits = model(x)
        loss = F.cross_entropy(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 200 != 0:
            continue

        # Periodic validation pass (every 200 steps).
        model.eval()
        acc = AverageMeter()
        with torch.no_grad():
            for vl_x, vl_y in vl_loader:
                vl_x, vl_y = vl_x.cuda(), vl_y.cuda()
                acc.update(compute_acc(model(vl_x), vl_y), vl_x.size(0))

        print(f'step {i}:')
        # The reported loss is the training loss of the most recent batch.
        print(f'loss {loss.item():.4f}\tacc: {acc.avg:6.3f}')
        model.train()

    # Score the trained model on both validation environments of this split.
    r = compute_NI(model, vl_envs)
    print(r)
    results.append(r)

arr = np.array(results)
print(arr.mean(), arr.std())

In [None]:
import ot
import ot.plot


def compute_EMD(xs, xt):
    ''' Earth mover's distance, a.k.a. Wasserstein metric.

    Both point sets get uniform weights. The pairwise cost matrix comes from
    ``ot.dist`` with its default metric (squared Euclidean per the POT
    documentation) and is rescaled by its maximum entry, so the returned
    value is a transport cost on that normalised scale.

    Args:
        xs: points from source domain, array of shape (n, d).
        xt: points from target domain, array of shape (n, d); must contain
            the same number of points as ``xs``.
    Returns:
        Earth mover's distance (total cost of the optimal transport plan
        under the normalised cost matrix).
    '''
    assert xs.shape[0] == xt.shape[0]
    n = xs.shape[0]

    M = ot.dist(xs, xt)
    max_cost = M.max()
    if max_cost == 0:
        # Every pairwise cost is zero (all points coincide), so any plan
        # costs nothing; dividing by the zero maximum would fill M with NaN.
        return 0.0
    M /= max_cost

    # Uniform mass 1/n on every source and target point.
    a, b = np.ones((n,)) / n, np.ones((n,)) / n
    G0 = ot.emd(a, b, M, numItermax=1000000)

    return np.sum(G0 * M)

In [None]:
# EMD between strided subsamples of the first two environments of `dataset`.
results = []

for i in range(n_repeats):
    # Take every n_repeats-th example starting at offset i; element [0] of the
    # sliced environment is presumably the image tensor -- confirm against
    # the dataset class.
    xs = dataset[0][i::n_repeats][0]
    xt = dataset[1][i::n_repeats][0]
    # Flatten each example to a vector using the actual subsample size rather
    # than a hard-coded row count, so other dataset sizes also work.
    r = compute_EMD(xs.view(xs.size(0), -1).data.numpy(),
                    xt.view(xt.size(0), -1).data.numpy())
    results.append(r)
arr = np.array(results)
print(arr.mean(), arr.std())

In [None]:
def compute_MMD(model, envs):
    ''' Maximum mean discrepancy (https://papers.nips.cc/paper/2006/file/e9fb2eda3d9c55a0d89c98d6c54b5b3e-Paper.pdf)

    Averages the sigmoid of the model output separately over the training-side
    and test-side examples and returns the absolute difference of the two
    averages.

    Args:
        model: an ERM model predicting either an example is from training or test environments.
            It is switched to eval mode and left in eval mode.
        envs: an iterable of groups of environments. Within each group the
            last environment is treated as the test side and all earlier ones
            as the training side. Each environment is indexed with the key
            ``'images'``.
    Returns:
        Maximum mean discrepancy: ``abs(pred1.avg - pred2.avg)`` of the two
        running averages.
    '''
    model.eval()

    pred1 = AverageMeter()  # running mean prediction on training-side envs
    pred2 = AverageMeter()  # running mean prediction on the test-side env

    # Inference only: no_grad keeps autograd graphs from piling up in the meters.
    with torch.no_grad():
        for batch in envs:
            n_envs = len(batch)
            for j, env in enumerate(batch):
                if j + 1 < n_envs:
                    # Training-side env: keep every (n_envs-1)-th image,
                    # presumably so the training envs together contribute about
                    # as many examples as the single test env -- confirm.
                    x = env['images'][::n_envs-1].cuda()
                    pred = pred1
                else:
                    x = env['images'].cuda()
                    pred = pred2

                logits = model(x)
                pred_ = torch.sigmoid(logits)
                pred.update(pred_.mean(), x.size(0))

    return abs(pred1.avg - pred2.avg)

In [None]:
# MMD experiment: per split, train a domain discriminator (training env -> 0,
# test env -> 1) and record the absolute difference between its mean predicted
# probability on the held-out training-env and test-env data.
batch_size = 32
num_workers = 8
n_steps = 1000
lr = 0.01

splits = get_kfold_cross_validation_splits(dataset, n_repeats, seed=0)

results = []
for k, (envs, vl_envs) in enumerate(splits):
    model = MMDClassifier(CMNIST_MLP()).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=0.0005)
    # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_steps)

    tr_env, te_env = envs
    vl_tr_env, vl_te_env = vl_envs

    tr_loader = InfiniteDataLoader(tr_env, None, batch_size, num_workers)
    tr_iter = iter(tr_loader)

    te_loader = InfiniteDataLoader(te_env, None, batch_size, num_workers)
    te_iter = iter(te_loader)

    vl_tr_loader = FastDataLoader(vl_tr_env, 64, 8)
    vl_te_loader = FastDataLoader(vl_te_env, 64, 8)

    for i in range(1, n_steps + 1):
        tr_x, _ = next(tr_iter)
        te_x, _ = next(te_iter)

        x = torch.cat([tr_x, te_x]).cuda()
        # Domain labels: 0 for training-env examples, 1 for test-env examples.
        # Each half is sized from its own batch so the labels always line up.
        e = torch.cat([torch.zeros(tr_x.size(0)), torch.ones(te_x.size(0))]).unsqueeze(1).cuda()

        # binary_cross_entropy_with_logits applies the sigmoid itself, so it
        # must receive the raw logits, not sigmoid outputs.
        logits = model(x)
        loss = F.binary_cross_entropy_with_logits(logits, e)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Evaluate at the first step, every 200 steps, and always at the last
        # step so a result is recorded even if n_steps is not a multiple of 200.
        if i == 1 or not i % 200 or i == n_steps:
            model.eval()

            tr_meter = AverageMeter()
            te_meter = AverageMeter()
            # Separate name so the training `loss` tensor is not shadowed.
            vl_loss = AverageMeter()

            with torch.no_grad():
                # zip() stops at the shorter of the two validation loaders.
                for (vl_tr_x, _), (vl_te_x, _) in zip(vl_tr_loader, vl_te_loader):
                    vl_tr_x = vl_tr_x.cuda()
                    vl_te_x = vl_te_x.cuda()

                    tr_e = torch.zeros(vl_tr_x.size(0)).unsqueeze(1).cuda()
                    te_e = torch.ones(vl_te_x.size(0)).unsqueeze(1).cuda()

                    tr_logits = model(vl_tr_x)
                    vl_loss.update(F.binary_cross_entropy_with_logits(tr_logits, tr_e).item(), tr_e.size(0))

                    te_logits = model(vl_te_x)
                    vl_loss.update(F.binary_cross_entropy_with_logits(te_logits, te_e).item(), te_e.size(0))

                    # Mean predicted probability of "test env" for each side.
                    tr_meter.update(torch.sigmoid(tr_logits).mean().item(), vl_tr_x.size(0))
                    te_meter.update(torch.sigmoid(te_logits).mean().item(), vl_te_x.size(0))

            print(f'step {i}:')
            print(f'loss {vl_loss.avg:.4f}')

            model.train()

            if i == n_steps:
                results.append(abs(tr_meter.avg - te_meter.avg))

#     r = compute_MMD(model, vl_envs)
#     print(r)
#     results.append(r)

arr = np.array(results)
print(arr.mean(), arr.std())