In [None]:
import torch
import model
import helper
import os

## Out-of-distribution Generalization with ERM
The most straightforward way to handle OOD generalization is empirical risk minimization (ERM).
In short, you merge the data from multiple sources (a.k.a. domains, environments, or subpopulations) and train a single model on the pooled data.
Previous work such as [DomainBed](https://github.com/facebookresearch/DomainBed) has found that this simple strategy can beat several more elaborately designed methods in practical settings.
We therefore introduce ERM as our very first baseline for the OOD generalization problem.
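
Written out, and assuming $E$ training domains where domain $e$ contributes $n_e$ labelled samples $(x_i^e, y_i^e)$ (this notation is ours, not the notebook's), the pooled ERM objective for a model $f_\theta$ with loss $\ell$ is:

$$\min_\theta \; \frac{1}{\sum_{e=1}^{E} n_e} \sum_{e=1}^{E} \sum_{i=1}^{n_e} \ell\big(f_\theta(x_i^e),\, y_i^e\big)$$

The domain index $e$ only appears in the sums, which is why the domain labels can be dropped once the data is merged.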

We first set up the experimental environment, including the random seed, `gpu_id`, and several other arguments.

In [None]:
helper.fix_seed(0)
args = helper.Args()
device = torch.device("cuda:%d"%args.gpu_id if torch.cuda.is_available() else "cpu")

Then we load the ERM model (we instantiate it with a ResNet-18 backbone whose classifier layer is replaced to match the number of classes).

In [None]:
my_model = model.ERM(args.num_classes, args)
my_model.to(device)

We build the training dataloader from the sample NICO++ data (for ERM, the domain labels are simply ignored).
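
The training loop later in this notebook calls `next(train_dataloader)`, so the helper presumably returns an iterator that never runs out rather than a plain `DataLoader`. One common way to get that behaviour is a generator that restarts the underlying loader whenever it is exhausted. The sketch below is an assumption about `helper.get_ERM_dataloader`, not its actual code; `infinite_batches` is a hypothetical name and the list of tuples stands in for a real loader.

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def infinite_batches(loader: Iterable[T]) -> Iterator[T]:
    """Yield batches forever, starting a new pass each time the loader is exhausted."""
    while True:
        for batch in loader:
            yield batch


toy_loader = [("x0", "y0"), ("x1", "y1"), ("x2", "y2")]  # stand-in for a DataLoader
stream = infinite_batches(toy_loader)
first_four = [next(stream) for _ in range(4)]
print(first_four[3])  # the fourth batch wraps around to the first one
```

With this pattern the loop can be written in terms of a fixed number of steps instead of epochs.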

In [None]:
train_dataloader = helper.get_ERM_dataloader(args, 'train')

We train the ResNet-18 with backpropagation and print the training loss every 20 iterations.
We also evaluate the model on a separate test set and report the test accuracy every 100 iterations.
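
For readers who have not opened model.py, a single update typically looks like the sketch below: forward pass, cross-entropy loss, backward pass, optimizer step, and a dictionary of scalars for logging. It mirrors the `update(mini_batches)` call and the `step_vals.items()` loop in the next cell, but the class name `TinyERM`, the linear network, and the SGD settings are all placeholders, not the actual `model.ERM`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyERM(nn.Module):
    """Placeholder ERM learner; a linear layer stands in for the ResNet-18 backbone."""

    def __init__(self, in_features: int, num_classes: int, lr: float = 0.1):
        super().__init__()
        self.network = nn.Linear(in_features, num_classes)
        self.optimizer = torch.optim.SGD(self.network.parameters(), lr=lr)

    def update(self, mini_batches):
        x, y = mini_batches
        loss = F.cross_entropy(self.network(x), y)  # average loss over the pooled batch
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}  # scalars to log, keyed by name


torch.manual_seed(0)
learner = TinyERM(in_features=8, num_classes=5)
x, y = torch.randn(16, 8), torch.randint(0, 5, (16,))
losses = [learner.update([x, y])["loss"] for _ in range(50)]
print("first loss: %.4f, last loss: %.4f" % (losses[0], losses[-1]))
```

Returning a dictionary keeps the logging code independent of how many quantities a given algorithm reports.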

In [None]:
for step in range(args.num_steps):
    # Draw the next pooled minibatch and move it to the training device.
    x, y = next(train_dataloader)
    mini_batches = [x.to(device), y.to(device)]
    step_vals = my_model.update(mini_batches)

    # Log the training statistics every 20 steps.
    if (step+1) % 20 == 0:
        log_str = "Step %d " % (step+1)
        for k, v in step_vals.items():
            log_str = log_str + "%s: %.4f, " % (k, v)
        print(log_str)

    # Evaluate on the test split every 100 steps.
    if (step+1) % 100 == 0:
        test_dataloader = helper.get_ERM_dataloader(args, 'test')
        accuracy = helper.test(my_model, test_dataloader, device)
        print("ite: %d, test accuracy: %.4f" % (step+1, accuracy))

The training loss generally decreases over time, and the test accuracy is better than random guessing (20%). Seems good :)

## Domain Generalization with Mixup

Besides naive training on pooled data, we show another family of methods, Mixup, which interpolates minibatches drawn from different domains. See:

https://arxiv.org/pdf/2001.00677.pdf

https://arxiv.org/pdf/1912.01805.pdf

Building on the simple yet effective mixup training scheme, these methods apply mixup to images and labels from different domains in order to improve robustness to domain shift.
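
Concretely, mixup draws a weight $\lambda \sim \mathrm{Beta}(\alpha, \alpha)$, forms $\tilde{x} = \lambda x_a + (1-\lambda) x_b$ from two samples, and trains on the correspondingly mixed labels. The sketch below applies this to two minibatches taken from different domains. It is a minimal illustration under our own assumptions, not the code in model.py: `mixup_loss`, the linear placeholder network, and `alpha=0.2` are all hypothetical choices.

```python
import random

import torch
import torch.nn as nn
import torch.nn.functional as F


def mixup_loss(network, batch_a, batch_b, alpha=0.2):
    """Cross-entropy on a convex combination of two same-sized minibatches."""
    (x_a, y_a), (x_b, y_b) = batch_a, batch_b
    lam = random.betavariate(alpha, alpha)  # mixing weight in [0, 1]
    x_mix = lam * x_a + (1.0 - lam) * x_b   # interpolate the inputs
    logits = network(x_mix)
    # Interpolating the two losses equals using interpolated one-hot labels.
    loss = lam * F.cross_entropy(logits, y_a) + (1.0 - lam) * F.cross_entropy(logits, y_b)
    return loss, lam


random.seed(0)
torch.manual_seed(0)
network = nn.Linear(8, 5)  # placeholder for the real backbone
domain_a = (torch.randn(4, 8), torch.randint(0, 5, (4,)))  # toy batch from one domain
domain_b = (torch.randn(4, 8), torch.randint(0, 5, (4,)))  # toy batch from another domain
loss, lam = mixup_loss(network, domain_a, domain_b)
print("lambda: %.3f, mixup loss: %.4f" % (lam, loss.item()))
```

Because the targets are mixed, the loss generally cannot reach zero even when the model fits the data well.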

We first initialize the Mixup model. Please refer to the model.py file for more details.

In [None]:
my_model = model.Mixup(args.num_classes, args)
my_model.to(device)

Then we build the training dataloader. Unlike the ERM dataloader, this one also samples by domain label, so that mixup can be performed across different domains.
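
If the dataloader yields one minibatch per domain at each step (an assumption; check helper.py for the actual format), cross-domain mixup needs a rule for choosing which domains to mix. One simple option, sketched here with the hypothetical helper `random_domain_pairs`, is to shuffle the domains and pair each one with its successor in that random cyclic order, so a batch is never mixed with itself as long as there are at least two domains.

```python
import random


def random_domain_pairs(mini_batches, rng=random):
    """Pair each domain's batch with the batch of the next domain in a random cyclic order."""
    order = list(range(len(mini_batches)))
    rng.shuffle(order)
    pairs = []
    for pos, i in enumerate(order):
        j = order[(pos + 1) % len(order)]  # wrap around so the last domain pairs with the first
        pairs.append((mini_batches[i], mini_batches[j]))
    return pairs


random.seed(0)
batches = ["batch_from_domain_0", "batch_from_domain_1", "batch_from_domain_2"]  # toy stand-ins
pairs = random_domain_pairs(batches)
for a, b in pairs:
    print(a, "<->", b)
```

Each domain appears exactly once as the first element of a pair, so every domain contributes equally to the mixed loss.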

In [None]:
train_dataloader = helper.get_DG_dataloader(args, 'train')

Now we start training. The training loss is higher than in ERM training, which is expected: with cross-domain mixup the model is fitting interpolated inputs with mixed labels.

In [None]:
for step in range(args.num_steps):
    # Draw the next set of domain-aware minibatches and run one Mixup update.
    mini_batches = next(train_dataloader)
    step_vals = my_model.update(mini_batches)

    # Log the training statistics every 20 steps.
    if (step+1) % 20 == 0:
        log_str = "Step %d " % (step+1)
        for k, v in step_vals.items():
            log_str = log_str + "%s: %.4f, " % (k, v)
        print(log_str)

    # Evaluate on the test split every 100 steps (the test loader ignores domain labels).
    if (step+1) % 100 == 0:
        test_dataloader = helper.get_ERM_dataloader(args, 'test')
        accuracy = helper.test(my_model, test_dataloader, device)
        print("ite: %d, test accuracy: %.4f" % (step+1, accuracy))

The final test accuracy looks OK :)

In practice, you can either use the domain labels as extra information or simply ignore them, depending on your application.
We show both families of methods here to illustrate the two general paradigms for handling the OOD problem.
Feel free to play with NICO++ and happy researching!