# Efficient Learning of CNNs using Patch Based Features

## Table of Contents

- Summary of our Empirical Study
- Getting started
- Reproducing the experiments from the paper
  - Table 1: Comparison between our algorithm and baselines
  - Table 2: The effect of the constraint on the linear function
  - Table 3: Examining the effect of pooling and bottleneck
  - Table 4: Test accuracy across depth
  - Figure 5: Mean distance between patches and centroids
  - Appendix C.4: Whitening
  - Appendix C.5: The effects of the different parameters in our model
## Summary of our Empirical Study

Check out the wandb report summarizing the results of our empirical study.
Clicking on a run will open its page in wandb, which can be useful for: 1) viewing the exact command line and package requirements needed to reproduce the run; 2) downloading the checkpoint to experiment with the model yourself; 3) additional visualizations (accuracy, losses, images, etc.).
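A checkpoint can also be fetched programmatically. Below is a minimal sketch using the public wandb API; the run path and the checkpoint filename are placeholders you would copy from the run's page, not values from this repository:

```python
import wandb

# Query a finished run through the public wandb API.
api = wandb.Api()
run = api.run("ENTITY/PROJECT/RUN_ID")  # hypothetical path; copy the real one from the run's page

# List everything the run logged, then download a checkpoint to play with locally.
for f in run.files():
    print(f.name)
run.file("checkpoint.pt").download(replace=True)  # placeholder filename
```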
An illustration of the patch-based image embedding. For more information, please refer to the paper.
An illustration of the semi-supervised algorithm, containing an unsupervised stage to obtain the patches dictionary, which is then used by the embedding in the supervised stage. For more information, please refer to the paper.
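To make the illustration concrete, here is a minimal PyTorch sketch of a hard patch embedding in the spirit of the Phi_hard variant used in the experiments below: each image patch is mapped to a binary vector marking its k nearest entries in the patches dictionary. The tensor shapes and the Euclidean distance are our assumptions for illustration; the paper's exact definition is authoritative.

```python
import torch
import torch.nn.functional as F

def phi_hard(images: torch.Tensor, centroids: torch.Tensor, k: int) -> torch.Tensor:
    """Hard patch-embedding sketch.

    images:    (B, C, H, W) input batch
    centroids: (N, C * p * p) patch dictionary (flattened p-by-p patches)
    Returns a binary (B, N, H-p+1, W-p+1) map whose entry n is 1 iff centroid n
    is among the k nearest centroids of the patch at that spatial location.
    """
    n_centroids, dim = centroids.shape
    p = int(round((dim // images.shape[1]) ** 0.5))        # patch side length
    patches = F.unfold(images, kernel_size=p)              # (B, C*p*p, L) sliding patches
    dists = torch.cdist(patches.transpose(1, 2),           # (B, L, N) pairwise l2 distances
                        centroids.unsqueeze(0))
    knn = dists.topk(k, dim=-1, largest=False).indices     # k nearest centroids per patch
    emb = torch.zeros_like(dists).scatter_(-1, knn, 1.0)   # binary k-NN membership mask
    h = images.shape[2] - p + 1                            # output height (stride 1, no padding)
    return emb.transpose(1, 2).reshape(images.shape[0], n_centroids, h, -1)
```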
## Getting started

Create the conda environment named patch-based-learning and activate it:

```shell
conda env create --file environment.yml
conda activate patch-based-learning
```
Notes:

- It's recommended to install faiss as well, to make k-means run much faster (a couple of minutes with faiss, instead of almost 1 hour with sklearn). Run `conda install -c pytorch faiss-cpu=1.7.2`, and the code will recognize that the environment has faiss and use it instead of sklearn (a sketch of this fallback appears after these notes).
- faiss is not included in the environment.yml file by default because it is not always available in conda (for example, for macOS with the M1 chip, as of June 2022).
- If the creation of the environment takes too long, consider using mamba instead of conda.
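The faiss/sklearn fallback mentioned above follows a standard pattern; here is a simplified sketch (not the repository's actual code):

```python
import numpy as np

def cluster_patches(patches: np.ndarray, n_clusters: int) -> np.ndarray:
    """Run k-means with faiss when available, otherwise fall back to sklearn.

    patches: (n_samples, dim) array; returns (n_clusters, dim) centroids.
    """
    try:
        import faiss  # much faster, but not always available in conda
        kmeans = faiss.Kmeans(d=patches.shape[1], k=n_clusters, niter=20)
        kmeans.train(patches.astype(np.float32))  # faiss expects float32
        return kmeans.centroids
    except ImportError:
        from sklearn.cluster import KMeans  # slow but always installable
        return KMeans(n_clusters=n_clusters).fit(patches).cluster_centers_
```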
We use wandb for visualizing the experiments. This setup needs to be performed only once (like building the environment): log in to wandb online, and then, in the environment containing wandb, run `wandb login`.
You can run the commands on GPU by passing `--device cuda:0` (or any other device number). You can also use multiple GPUs by passing `--multi_gpu X`, where `X` can be `-1` (which takes all available GPUs) or a list of GPUs like `--multi_gpu 0 1 4 5 6 8`.
To log each run to the corresponding wandb project, add `--wandb_project_name WANDB-PROJECT-NAME` (or edit the default value in `schemas/environment.py`). To give a name to a run, use the argument `--wandb_run_name RUN-NAME` (the names can also be changed later using the wandb website).
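For example, a hypothetical invocation combining the flags above (the project and run names are placeholders):

```shell
python main.py --device cuda:0 --wandb_project_name my-project --wandb_run_name my-first-run
```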
## Reproducing the experiments from the paper

Below you can find command lines to reproduce each experiment from the paper.
Notes:

- The accuracy reported in each table is the maximal accuracy achieved during the training phase (measured after each epoch).
- The numbers you'll get may differ slightly from those reported in the paper, because each experiment in the paper was launched 5 times and we report the mean (±std) of the (maximal) accuracy.
### Table 1: Comparison between our algorithm and baselines

- Vanilla 1 hidden-layer CNN

  ```shell
  python main.py --train_locally_linear_network True --replace_embedding_with_regular_conv_relu True --use_conv False --wandb_run_name "Table 1 - Vanilla 1 hidden-layer CNN"
  ```

- Phi_hard with random patches

  ```shell
  python main.py --train_locally_linear_network True --use_conv False --random_gaussian_patches True --wandb_run_name "Table 1 - Phi_hard with random patches"
  ```

- Phi_full with random patches

  ```shell
  python main.py --train_locally_linear_network True --use_conv True --random_gaussian_patches True --wandb_run_name "Table 1 - Phi_full with random patches"
  ```

- Phi_hard with clustered patches

  ```shell
  python main.py --train_locally_linear_network True --use_conv False --random_gaussian_patches False --wandb_run_name "Table 1 - Phi_hard with clustered patches"
  ```

- Phi_full with clustered patches

  ```shell
  python main.py --train_locally_linear_network True --use_conv True --random_gaussian_patches False --wandb_run_name "Table 1 - Phi_full with clustered patches"
  ```
### Table 2: The effect of the constraint on the linear function

- Simple

  ```shell
  python main.py --train_locally_linear_network True --use_conv False --use_avg_pool False --use_batch_norm False --use_bottle_neck False --k 64 --full_embedding True --batch_size 32 --n_clusters 256 --n_patches 65536 --learning_rate 0.00001 --wandb_run_name "Table 2 - Simple"
  ```

- Constrained

  ```shell
  python main.py --train_locally_linear_network True --use_conv True --use_avg_pool False --use_batch_norm False --use_bottle_neck False --k 64 --full_embedding False --batch_size 32 --n_clusters 256 --n_patches 65536 --learning_rate 0.001 --wandb_run_name "Table 2 - Constrained"
  ```
### Table 3: Examining the effect of pooling and bottleneck

- Original

  ```shell
  python main.py --train_locally_linear_network True --use_avg_pool False --use_batch_norm False --use_bottle_neck False --wandb_run_name "Table 3 - Original"
  ```

- AvgPool

  ```shell
  python main.py --train_locally_linear_network True --use_avg_pool True --use_batch_norm True --use_bottle_neck False --wandb_run_name "Table 3 - AvgPool"
  ```

- Bottleneck

  ```shell
  python main.py --train_locally_linear_network True --use_avg_pool False --use_batch_norm False --use_bottle_neck True --wandb_run_name "Table 3 - Bottleneck"
  ```

- Both

  ```shell
  python main.py --train_locally_linear_network True --use_avg_pool True --use_batch_norm True --use_bottle_neck True --wandb_run_name "Table 3 - Both"
  ```
### Table 4: Test accuracy across depth

- Phi_full depths 1, 2, 3, 4

  ```shell
  python main.py --train_locally_linear_network True --depth 4 --kernel_size 5 3 3 3 --use_avg_pool True False False False --pool_size 2 --pool_stride 2 --use_batch_norm True False False False --k 256 128 128 128 --wandb_run_name "Table 4 - Phi_full"
  ```

- Phi_hard depths 1, 2, 3, 4

  ```shell
  python main.py --train_locally_linear_network True --depth 4 --kernel_size 5 3 3 3 --use_conv False --use_avg_pool True False False False --pool_size 2 --pool_stride 2 --use_batch_norm True False False False --k 256 128 128 128 --wandb_run_name "Table 4 - Phi_hard"
  ```

- CNN

  - Depth 1

    ```shell
    python main.py --model_name VGGc1024d1A --kernel_size 5 --padding 0 --use_batch_norm True --final_mlp_n_hidden_layers 0 --use_relu_after_bottleneck True --wandb_run_name "Table 4 - CNN Depth 1"
    ```

  - Depth 2

    ```shell
    python main.py --model_name VGGc1024d2A --kernel_size 5 3 --padding 0 --use_batch_norm True False --final_mlp_n_hidden_layers 0 --use_relu_after_bottleneck True --wandb_run_name "Table 4 - CNN Depth 2"
    ```

  - Depth 3

    ```shell
    python main.py --model_name VGGc1024d3A --kernel_size 5 3 3 --padding 0 --use_batch_norm True False False --final_mlp_n_hidden_layers 0 --use_relu_after_bottleneck True --wandb_run_name "Table 4 - CNN Depth 3"
    ```

  - Depth 4

    ```shell
    python main.py --model_name VGGc1024d4A --kernel_size 5 3 3 3 --padding 0 --use_batch_norm True False False False --final_mlp_n_hidden_layers 0 --use_relu_after_bottleneck True --wandb_run_name "Table 4 - CNN Depth 4"
    ```

- CNN (layerwise) depths 1, 2, 3, 4

  ```shell
  python main.py --train_locally_linear_network True --depth 4 --kernel_size 5 3 3 3 --replace_embedding_with_regular_conv_relu True --use_conv False --use_avg_pool True False False False --pool_size 2 --pool_stride 2 --use_batch_norm True False False False --wandb_run_name "Table 4 - CNN layerwise"
  ```
### Figure 5: Mean distance between patches and centroids

- CIFAR-10:

  ```shell
  python intrinsic_dimension_playground.py --dataset_name CIFAR10 --wandb_run_name "CIFAR10 patches intrinsic-dimension"
  ```

- ImageNet (the validation dataset needs to be downloaded beforehand):

  ```shell
  python intrinsic_dimension_playground.py --dataset_name ImageNet --wandb_run_name "ImageNet patches intrinsic-dimension"
  ```
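If you want to compute a similar statistic yourself, below is a minimal sketch of the quantity the figure's title describes, the mean distance from sampled patches to their nearest k-means centroid. This is our simplified reading for illustration, not the actual implementation in intrinsic_dimension_playground.py:

```python
import numpy as np
from sklearn.cluster import KMeans

def mean_patch_centroid_distance(patches: np.ndarray, n_clusters: int) -> float:
    """Mean l2 distance from each patch to its nearest k-means centroid.

    patches: (n_samples, dim) array of flattened image patches.
    """
    kmeans = KMeans(n_clusters=n_clusters).fit(patches)
    assigned = kmeans.cluster_centers_[kmeans.labels_]  # nearest centroid per patch
    return float(np.linalg.norm(patches - assigned, axis=1).mean())
```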
### Appendix C.4: Whitening

The patches used by our algorithm (the patches dictionary) are logged to wandb in each run. Generally, in the figures we take the patches from the very end of the training phase (you can also view the patches from the beginning of the training).
- Figure 6: CNN kernels (left), our whitened patches (right)

  - CNN kernels (Vanilla 1 hidden-layer CNN from Table 1); the patches are taken from the wandb plot best_patches:

    ```shell
    python main.py --train_locally_linear_network True --replace_embedding_with_regular_conv_relu True --use_conv False --wandb_run_name "Vanilla 1 hidden-layer CNN"
    ```

  - Our whitened patches (similar to Phi_full with clustered patches from Table 1, but with ZCA-whitening instead of the default PCA-whitening); the patches are taken from the wandb plot best_patches_whitened:

    ```shell
    python main.py --train_locally_linear_network True --use_conv True --random_gaussian_patches False --wandb_run_name "Phi_full with clustered patches using ZCA whitening" --zca_whitening True
    ```
- Figure 7: Patches before and after ZCA-whitening

  Use the same run as in Figure 6 (Phi_full with clustered patches using ZCA whitening) and observe best_patches and best_patches_whitened.
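For reference, PCA- and ZCA-whitening differ only by a rotation back to the original basis, which is why ZCA-whitened patches still look patch-like. A minimal numpy sketch (an illustration, not the repository's implementation):

```python
import numpy as np

def whiten(patches: np.ndarray, zca: bool = False, eps: float = 1e-5) -> np.ndarray:
    """PCA- or ZCA-whiten flattened patches of shape (n_samples, dim)."""
    x = patches - patches.mean(axis=0)                # center the data
    cov = x.T @ x / len(x)                            # empirical covariance
    eigvals, eigvecs = np.linalg.eigh(cov)
    w = eigvecs / np.sqrt(eigvals + eps)              # PCA whitening: W = E @ diag(eigvals)^(-1/2)
    if zca:
        w = w @ eigvecs.T                             # ZCA whitening: rotate back to the input basis
    return x @ w
```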
### Appendix C.5: The effects of the different parameters in our model

- Figure 8: Accuracy per k (number of neighbors), where the patches-dictionary size is 1,024 (a tcsh loop; adapt to your shell as needed)

  ```shell
  foreach K ( 1 2 4 8 16 32 64 128 256 512 1024 )
      python main.py --train_locally_linear_network True --wandb_run_name "Phi_full with clustered patches k=${K}" --k ${K}
  end
  ```
- Figure 9: Accuracy per dictionary-size (a single-script sweep sketch follows this list)

  - Dictionary size 1,024 (256 neighbors, 262,144 sampled patches for clustering)

    ```shell
    python main.py --train_locally_linear_network True --wandb_run_name "Phi_full with clustered patches M=262144 N=1024 k=256" --n_patches 262144 --n_clusters 1024 --k 256
    ```

  - Dictionary size 2,048 (512 neighbors, 524,288 sampled patches for clustering)

    ```shell
    python main.py --train_locally_linear_network True --wandb_run_name "Phi_full with clustered patches M=524288 N=2048 k=512" --n_patches 524288 --n_clusters 2048 --k 512
    ```

  - Dictionary size 4,096 (1,024 neighbors, 1,048,576 sampled patches for clustering)

    ```shell
    python main.py --train_locally_linear_network True --wandb_run_name "Phi_full with clustered patches M=1048576 N=4096 k=1024" --n_patches 1048576 --n_clusters 4096 --k 1024
    ```

  - Dictionary size 8,192 (2,048 neighbors, 2,097,152 sampled patches for clustering)

    ```shell
    python main.py --train_locally_linear_network True --wandb_run_name "Phi_full with clustered patches M=2097152 N=8192 k=2048" --n_patches 2097152 --n_clusters 8192 --k 2048
    ```

  - Dictionary size 16,384 (4,096 neighbors, 4,194,304 sampled patches for clustering)

    ```shell
    python main.py --train_locally_linear_network True --wandb_run_name "Phi_full with clustered patches M=4194304 N=16384 k=4096 lr=0.001" --n_patches 4194304 --n_clusters 16384 --k 4096 --learning_rate 0.001
    ```
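Note the pattern across these runs: the number of sampled patches is always 256 times the dictionary size, and the number of neighbors is a quarter of it. If you prefer launching the whole sweep from a single script, here is a small Python sketch that reproduces the commands above:

```python
import subprocess

# Sweep dictionary sizes N with M = 256 * N sampled patches and k = N / 4
# neighbors, matching the individual commands above.
for n_clusters in (1024, 2048, 4096, 8192, 16384):
    n_patches, k = 256 * n_clusters, n_clusters // 4
    extra = ["--learning_rate", "0.001"] if n_clusters == 16384 else []
    name = f"Phi_full with clustered patches M={n_patches} N={n_clusters} k={k}"
    if extra:
        name += " lr=0.001"  # the largest run also lowers the learning rate
    subprocess.run(["python", "main.py",
                    "--train_locally_linear_network", "True",
                    "--n_patches", str(n_patches),
                    "--n_clusters", str(n_clusters),
                    "--k", str(k),
                    "--wandb_run_name", name] + extra, check=True)
```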