This repo implements a domain adaptation neural network. The design was largely inspired by the paper *Unsupervised Domain Adaptation by Backpropagation* by Yaroslav Ganin and Victor Lempitsky.
The goal of this repo is to train a network that classifies MNIST samples while using only labelled SVHN samples and unlabelled MNIST samples during training.
```bash
virtualenv -p python3 env
source env/bin/activate
pip install -r requirements.txt
python utils/download_data.py
```
All the configuration variables are in the file `utils/config.py`.
To get a baseline of the performance without domain adaptation, I tested a simple Convolutional Neural Network. The architecture is described in the `cnn` picture.
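For reference, a classifier of this kind could look like the following Keras sketch. The layer types and sizes here are assumptions (the picture is the authoritative description); only the 512-unit Dense layer is confirmed later in this README:

```python
from tensorflow import keras
from tensorflow.keras import layers

def build_cnn(input_shape=(32, 32, 3), n_classes=10):
    """Illustrative convolutional classifier; the actual filter counts
    and kernel sizes in the repo may differ."""
    return keras.Sequential([
        keras.Input(shape=input_shape),  # SVHN images are 32x32x3
        layers.Conv2D(32, 5, activation="relu"),
        layers.MaxPooling2D(2),
        layers.Conv2D(64, 5, activation="relu"),
        layers.MaxPooling2D(2),
        layers.Flatten(),
        layers.Dense(512, activation="relu"),  # features visualized with t-SNE below
        layers.Dense(n_classes, activation="softmax"),
    ])
```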
The network with domain adaptation was designed so that the label-prediction part of the architecture is the same as the previous CNN, for two main reasons:
- to compare the performance of similar networks
- to load pre-trained weights and make training easier

The architecture is described in these pictures: `cnn_grl_model` and `cnn_grl_fe`. A sketch of the key component follows.
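The core of this architecture is the gradient reversal layer (GRL): it acts as the identity in the forward pass and multiplies the gradient by -λ in the backward pass, so the shared feature extractor learns to confuse the domain classifier. Here is a minimal sketch of such a layer in TensorFlow/Keras; the repo's actual implementation and framework may differ:

```python
import tensorflow as tf
from tensorflow.keras import layers

class GradientReversal(layers.Layer):
    """Identity in the forward pass; scales the gradient by -lam
    in the backward pass (Ganin & Lempitsky)."""

    def __init__(self, lam=1.0, **kwargs):
        super().__init__(**kwargs)
        self.lam = lam

    def call(self, x):
        @tf.custom_gradient
        def _reverse(x):
            def grad(dy):
                # Reverse the gradient flowing back into the feature extractor.
                return -self.lam * dy
            return tf.identity(x), grad
        return _reverse(x)
```

The domain classifier branch would then be attached to the shared features through this layer, e.g. `domain_output = domain_head(GradientReversal(lam=1.0)(features))`, where `domain_head` is a hypothetical name for the domain classification sub-network.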
To display the training options:

```bash
python train.py -h
```

To evaluate a trained network:

```bash
python evaluate.py
```
Network | Source (accuracy) | Target (accuracy) |
---|---|---|
CNN | SVHN (0.908) | MNIST (0.601) |
CNN-GRL | SVHN (0.883) | MNIST (0.711) |
CNN | MNIST (0.986) | SVHN (0.230) |
CNN-GRL | MNIST (0.982) | SVHN (0.238) |
It seems that the Gradient Reversal Layer leads to a significant improvement in classifying MNIST when the network is trained on SVHN. However, the GRL does not bring a comparable gain in the opposite direction (MNIST to SVHN).
The next two plots visualize the features built by the different networks: a 2-component t-SNE of the output of the 512-unit Dense layer in each network. I only used 3000 samples to make the t-SNE computation faster.
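For illustration, this computation could be done with scikit-learn as follows. The `features` and `domains` arrays are hypothetical names standing in for the 512-unit Dense activations and the domain of each sample:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
features = rng.normal(size=(3000, 512))   # stand-in for real Dense-layer activations
domains = rng.integers(0, 2, size=3000)   # 0 = MNIST sample, 1 = SVHN sample

# Project the 512-dimensional features down to 2 components.
embedded = TSNE(n_components=2, random_state=0).fit_transform(features)

for domain, name in [(0, "MNIST"), (1, "SVHN")]:
    mask = domains == domain
    plt.scatter(embedded[mask, 0], embedded[mask, 1], s=4, label=name)
plt.legend()
plt.show()
```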
The goal of the network with the gradient reversal layer is to build features that are independent of the input domain (MNIST or SVHN, for example). Thus, the features built by the CNN-GRL network should be more mixed across domains than the ones built by the classic CNN.
Even if it isn't completely obvious, the features built by the CNN-GRL architecture do seem more mixed than the ones built by the CNN architecture.
The same computation on the whole datasets would probably lead to a clearer result.