# Reproduction code for MatchDG

### Paper: Domain Generalization using Causal Matching [arXiv](https://arxiv.org/abs/2006.07500)

The following commands reproduce the results for the Rotated MNIST and Fashion-MNIST datasets reported in Tables 1, 2, and 3 of the paper.

For convenience, we provide the exact commands for the Rotated MNIST dataset, with the training domains set to [15, 30, 45, 60, 75] and the test domains set to [0, 90].

To obtain results for the Fashion-MNIST dataset, change the `--dataset` parameter from `rot_mnist` to `fashion_mnist`.

To obtain results for the other sets of training domains reported in the paper, pass the desired domains to `--train_domains`, e.g. `--train_domains 30 45` or `--train_domains 30 45 60`.
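To cover every dataset and training-domain combination without editing commands by hand, a small driver can build them for you. The sketch below is hypothetical (`build_train_cmd` and `sweep` are not part of the repository); it assumes it is launched from the repository root and that `--train_domains` accepts each domain as a separate space-separated argument. It only prints the commands unless `dry_run=False` is passed.

```python
import shlex
import subprocess

# Dataset names and training-domain sets mentioned in the text above.
DATASETS = ["rot_mnist", "fashion_mnist"]
TRAIN_DOMAIN_SETS = [
    ["15", "30", "45", "60", "75"],
    ["30", "45", "60"],
    ["30", "45"],
]


def build_train_cmd(dataset, train_domains, extra_args):
    """Return the argv list for one train.py run.

    Assumption: --train_domains takes each domain as a separate argument.
    """
    return ["python", "train.py", "--dataset", dataset,
            "--train_domains", *train_domains, *extra_args]


def sweep(extra_args, dry_run=True):
    """Print (and, unless dry_run, run) one command per dataset/domain set."""
    cmds = [build_train_cmd(dataset, domains, extra_args)
            for dataset in DATASETS
            for domains in TRAIN_DOMAIN_SETS]
    for cmd in cmds:
        print(shlex.join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError, stopping the sweep early.
            subprocess.run(cmd, check=True)
    return cmds


if __name__ == "__main__":
    # ERM flags from Table 1; pass dry_run=False to actually train.
    sweep(["--method_name", "erm_match", "--match_case", "0.01",
           "--penalty_ws", "0.0"])
```

Any of the method flags listed below can be passed as `extra_args` in the same way.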

%cd ../../

## Prepare Data

Generate the data by running `data_gen.py` from the `data/rot_mnist` directory:

%%bash
cd data/rot_mnist
python data_gen.py resnet18

## Table 1

Return to the repository root directory and run the following commands.

* ERM: 

%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0

* ERM_RandomMatch:

%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.1

* ERM_PerfectMatch:

%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 1.0 --penalty_ws 0.1
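The three runs above differ only in `--match_case` and `--penalty_ws`. As a sketch of why, assume that `--penalty_ws` is the weight $\lambda$ on the matching term of the paper's objective and that `--match_case` controls how the matched pairs are formed (0.01 for the RandomMatch row, 1.0 for the PerfectMatch row). Here $m$ is the number of training domains, $n_d$ the number of examples in domain $d$, $\phi$ the learned representation, $h$ the classifier, $\ell$ the classification loss, and $\Omega(j,k)=1$ marks a matched pair of inputs from different domains $d \neq d'$. Under those assumptions the training loss has roughly this form:

```latex
\mathcal{L}(h, \phi) = \sum_{d=1}^{m} \sum_{i=1}^{n_d} \ell\big(h(\phi(x_i^{(d)})), y_i^{(d)}\big)
  + \lambda \sum_{\Omega(j,k)=1,\; d \neq d'} \operatorname{dist}\big(\phi(x_j^{(d)}), \phi(x_k^{(d')})\big)
```

Under this reading, `--penalty_ws 0.0` sets $\lambda = 0$, which removes the matching term and leaves plain ERM, so the `--match_case` value has no effect in the ERM row.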

* MatchDG:

%%bash
python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --epochs 30 --batch_size 64 --pos_metric cos
python train.py --dataset rot_mnist --method_name matchdg_erm --penalty_ws 0.1 --match_case -1 --ctr_match_case 0.01 --ctr_match_flag 1 --ctr_match_interrupt 5
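MatchDG runs in two phases: `matchdg_ctr` first learns a representation with a contrastive objective to infer matches, and `matchdg_erm` then trains the classifier with `--match_case -1`, presumably taking its matches from phase one. In the commands above, phase two's `--ctr_match_case` and `--ctr_match_flag` repeat phase one's `--match_case` and `--match_flag`. The hypothetical helper below (`matchdg_cmds` and `run_matchdg` are not repository functions) derives both commands from one set of phase-one values so they cannot drift apart, and stops if phase one fails.

```python
import shlex
import subprocess


def matchdg_cmds(dataset="rot_mnist", match_case="0.01", match_flag="1",
                 penalty_ws="0.1", ctr_match_interrupt="5"):
    """Return both MatchDG commands, reusing phase-1 values in phase 2."""
    phase1 = ["python", "train.py", "--dataset", dataset,
              "--method_name", "matchdg_ctr",
              "--match_case", match_case, "--match_flag", match_flag,
              "--epochs", "30", "--batch_size", "64", "--pos_metric", "cos"]
    phase2 = ["python", "train.py", "--dataset", dataset,
              "--method_name", "matchdg_erm", "--penalty_ws", penalty_ws,
              "--match_case", "-1",  # presumably: use phase-1 matches
              # These mirror phase 1's --match_case / --match_flag.
              "--ctr_match_case", match_case, "--ctr_match_flag", match_flag,
              "--ctr_match_interrupt", ctr_match_interrupt]
    return phase1, phase2


def run_matchdg(dry_run=True, **kwargs):
    """Print both commands; unless dry_run, run them in order."""
    for cmd in matchdg_cmds(**kwargs):
        print(shlex.join(cmd))
        if not dry_run:
            # check=True raises if phase 1 fails, so phase 2 never starts.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_matchdg()  # prints the two Table 1 MatchDG commands
```

With the defaults, the printed commands are identical to the two MatchDG commands above.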

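## Table 2

These commands evaluate models with `--test_metric match_score`. The ERM and MatchDG (Default) entries evaluate models trained with the corresponding Table 1 settings, so run those training commands first; MatchDG (PerfMatch) trains its own model before testing.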

* ERM: 

%%bash
python test.py --dataset rot_mnist --method_name erm_match --match_case 0.01 --penalty_ws 0.0 --test_metric match_score 

* MatchDG (Default):

%%bash
python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 0.01 --match_flag 1 --pos_metric cos --test_metric match_score

* MatchDG (PerfMatch):

%%bash
python train.py --dataset rot_mnist --method_name matchdg_ctr --match_case 1.0 --match_flag 1 --epochs 30 --batch_size 64 --pos_metric cos
python test.py --dataset rot_mnist --method_name matchdg_ctr --match_case 1.0 --match_flag 1 --pos_metric cos --test_metric match_score 
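The PerfMatch pair above shows how each `test.py` command relates to its training run: it repeats the `train.py` flags, drops the training-only `--epochs` and `--batch_size`, and appends `--test_metric match_score`. The same rule reproduces the ERM and MatchDG (Default) test commands from their Table 1 training commands. The hypothetical helper below applies it; treat the rule as an observation from these commands, not a documented contract of `test.py`.

```python
# Flags that appear only in the train.py commands above; each takes one value.
TRAIN_ONLY_FLAGS = {"--epochs", "--batch_size"}


def to_test_cmd(train_cmd, test_metric="match_score"):
    """Turn a train.py argv list into the corresponding test.py argv list."""
    if train_cmd[:2] != ["python", "train.py"]:
        raise ValueError("expected a 'python train.py ...' command")
    args = train_cmd[2:]
    kept = []
    i = 0
    while i < len(args):
        if args[i] in TRAIN_ONLY_FLAGS:
            i += 2  # skip the flag and its value
        else:
            kept.append(args[i])
            i += 1
    return ["python", "test.py", *kept, "--test_metric", test_metric]


if __name__ == "__main__":
    train = ("python train.py --dataset rot_mnist --method_name matchdg_ctr "
             "--match_case 1.0 --match_flag 1 --epochs 30 --batch_size 64 "
             "--pos_metric cos").split()
    print(" ".join(to_test_cmd(train)))
```

Running it prints the MatchDG (PerfMatch) `test.py` command shown above.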

## Table 3

* Approx 25:

%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.25 --penalty_ws 0.1

* Approx 50:

%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.50 --penalty_ws 0.1

* Approx 75:

%%bash
python train.py --dataset rot_mnist --method_name erm_match --match_case 0.75 --penalty_ws 0.1
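The three Table 3 runs differ only in `--match_case`, which for these rows appears to set the approximate fraction of perfect matches (0.25, 0.50, 0.75). A minimal loop over those values, using only the flags shown above and assuming it is launched from the repository root, could look like this (`table3_cmds` and `DRY_RUN` are illustrative names, not part of the repository):

```python
import subprocess

# --match_case values for the Approx 25 / 50 / 75 rows of Table 3.
MATCH_CASES = ["0.25", "0.50", "0.75"]
DRY_RUN = True  # set to False to launch the training runs


def table3_cmds(dataset="rot_mnist"):
    """Build one erm_match command per approximate-match level."""
    return [["python", "train.py", "--dataset", dataset,
             "--method_name", "erm_match", "--match_case", mc,
             "--penalty_ws", "0.1"]
            for mc in MATCH_CASES]


if __name__ == "__main__":
    for cmd in table3_cmds():
        print(" ".join(cmd))
        if not DRY_RUN:
            subprocess.run(cmd, check=True)  # stop on the first failure
```

Passing `"fashion_mnist"` to `table3_cmds` gives the corresponding Fashion-MNIST runs.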