Given the MNIST handwritten digit recognition dataset (at 64x64 resolution), we define the following two sets as the "source" and the "target":
The goal is to maximize the accuracy of a classifier on the "target" set; however, labels are available only for the "source" set. We formulate the task as an adversarial domain adaptation problem, in which two networks, the classifier and the domain discriminator, compete to optimize opposing objectives.
In particular, the discriminator predicts whether a sample image belongs to the "source" or the "target" domain, accessing features of the input images only through the classifier network. At the same time, the classifier pursues two simultaneous objectives: a) recognizing digits from the "source" domain in a supervised fashion, and b) fooling the discriminator by maximizing its domain-classification error.
The idea behind adversarial domain adaptation is that the classifier will eventually learn to suppress the features that are useful for discriminating between domains. By doing so, it becomes robust to domain differences and improves its classification accuracy on the "target" set. A sketch of the two competing objectives is given below.
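The following sketch illustrates this setup with Theano and Lasagne, the stack used by this project. It is a minimal illustration, not the repository's actual architecture: the single dense trunk, the layer sizes, the adversarial weight `lam`, and the learning rates are all assumptions.

```python
import theano
import theano.tensor as T
import lasagne
from lasagne.layers import InputLayer, DenseLayer, get_output, get_all_params
from lasagne.nonlinearities import rectify, softmax, sigmoid

# Symbolic inputs: a labeled "source" batch and an unlabeled "target" batch.
X_src = T.tensor4('X_src')   # shape (batch, 1, 64, 64)
y_src = T.ivector('y_src')   # digit labels, 0..9
X_tgt = T.tensor4('X_tgt')

# Shared trunk: the only features the discriminator ever sees.
l_in = InputLayer((None, 1, 64, 64))
l_feat = DenseLayer(l_in, 256, nonlinearity=rectify)

# Classifier head: 10-way digit prediction on top of the shared features.
l_digit = DenseLayer(l_feat, 10, nonlinearity=softmax)

# Discriminator head: predicts P(sample comes from the "source" domain).
l_disc = DenseLayer(l_feat, 1, nonlinearity=sigmoid)

p_digit_src = get_output(l_digit, X_src)
p_dom_src = get_output(l_disc, X_src)
p_dom_tgt = get_output(l_disc, X_tgt)

# Objective a): supervised digit recognition on the labeled source batch.
loss_digit = lasagne.objectives.categorical_crossentropy(p_digit_src, y_src).mean()

# Discriminator objective: source samples labeled 1, target samples labeled 0.
loss_disc = (lasagne.objectives.binary_crossentropy(p_dom_src, T.ones_like(p_dom_src)).mean()
             + lasagne.objectives.binary_crossentropy(p_dom_tgt, T.zeros_like(p_dom_tgt)).mean())

# Objective b): the classifier fools the discriminator by maximizing its error.
lam = 0.1  # hypothetical weight balancing the two classifier objectives
loss_clf = loss_digit - lam * loss_disc

# The classifier update touches the trunk and the digit head; the discriminator
# update touches only its own head, so the two networks genuinely compete over
# the shared features.
clf_params = get_all_params(l_digit, trainable=True)
disc_params = l_disc.get_params(trainable=True)

train_clf = theano.function(
    [X_src, y_src, X_tgt], loss_clf,
    updates=lasagne.updates.adam(loss_clf, clf_params, learning_rate=1e-4))
train_disc = theano.function(
    [X_src, X_tgt], loss_disc,
    updates=lasagne.updates.adam(loss_disc, disc_params, learning_rate=1e-4))
```

A training loop such as the one in `train.py` would then alternate the two compiled updates on each mini-batch, GAN-style; the inverted-label variant of the adversarial term is a common, more stable alternative to the simple negation used in this sketch.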
Classification accuracy:
| | Source Set | Target Set | Performance Gap |
|---|---|---|---|
| Without ADDA | 0.99 | 0.38 | 0.61 |
| With ADDA | 0.99 | 0.78 | 0.21 |
Using adversarial domain adaptation, the performance gap between the "source" and "target" sets shrinks roughly three-fold, from 0.61 to 0.21.
- Simply execute `python train.py`
Requirements:

- Anaconda Python 2.5
- Lasagne 0.2.dev1
- Theano 0.9
- A GPU (recommended for fast training)
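If training is unexpectedly slow, a quick way to confirm that Theano is actually using the GPU is to inspect its configuration. This is a minimal check; the exact device string (e.g. `gpu` or `cuda0`) depends on the backend and your `THEANO_FLAGS`:

```python
import theano

# Prints the active compute device; anything other than "cpu" means
# compiled graphs will run on the GPU.
print(theano.config.device)
# float32 is typically required for GPU execution in Theano 0.9.
print(theano.config.floatX)
```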