This is code accompanying the ICLR 2021 submission "For self-supervised learning, Rationality implies Generalization, provably"
We prove non-vacuous generalization bounds for SSS learning algorithms, i.e., algorithms that learn by
(i) performing pre-training with a self-supervised task (i.e., without labels) to obtain a complex representation of the data points, and then
(ii) fitting a simple (e.g., linear) classifier on the representation and the labels.
In this repository, we assume that the self-supervised pre-trained model and its representations are already available, and we show how to compute the RRM bound for any given representation.
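As a toy illustration of step (ii), a linear classifier can be fit on frozen representations. Everything below (the random features, labels, and the least-squares probe) is a synthetic stand-in, not this repository's API:

```python
import numpy as np

# Stand-ins: in practice, `features` comes from the pre-trained
# self-supervised encoder and `labels` from the labeled dataset.
rng = np.random.default_rng(0)
N, d, num_classes = 200, 16, 3
features = rng.normal(size=(N, d))
labels = rng.integers(0, num_classes, size=N)

# A basic linear probe: least-squares fit on one-hot targets.
Y = np.eye(num_classes)[labels]
W, *_ = np.linalg.lstsq(features, Y, rcond=None)
train_acc = ((features @ W).argmax(axis=1) == labels).mean()
```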
For each SSS algorithm, computing the RRM bound requires running two experiments:

- the **clean experiment**, where we train the simple classifier on the data and labels $(x, y)$;
- the **$\eta$-noisy experiment**, where we train the simple classifier on $(x, \tilde{y})$, where $\tilde{y}$ are the $\eta$-noised labels.
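For the noisy experiment, the $\eta$-noised labels can be generated along the following lines. This is only a sketch: the convention of resampling to a uniformly random *different* class is an assumption here, and the repository may instead sample uniformly over all classes.

```python
import numpy as np

def noise_labels(y, eta, num_classes, seed=None):
    """Return eta-noised labels: with probability eta, replace each label
    with a uniformly random *different* class (assumed convention)."""
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    flip = rng.random(y.shape) < eta
    # an offset in {1, ..., num_classes - 1} guarantees a different class
    offsets = rng.integers(1, num_classes, size=y.shape)
    return np.where(flip, (y + offsets) % num_classes, y)
```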
```
%run fitlabels.py --train_noise_prob 0.0 --dataname CIFAR10 --feature_path ./data --log_predictions --batch_size 512 --epochs 100 --eval_type linear --from_features --weight_decay 1e-06 --optimname adam --lr_sched_type const --lr 0.0002 --beta1 0.8 --beta2 0.999
%run fitlabels.py --train_noise_prob 0.05 --dataname CIFAR10 --feature_path ./data --log_predictions --batch_size 512 --epochs 100 --eval_type linear --from_features --weight_decay 1e-06 --optimname adam --lr_sched_type const --lr 0.0002 --beta1 0.8 --beta2 0.999
```
We compute these simple classifiers for a variety of self-supervised training methods. Given the training and test accuracies of both these runs, we compute the empirical RRM bound.
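Since the RRM decomposition writes the generalization gap as the sum of the robustness, rationality, and memorization gaps, the rationality gap can be recovered as the remainder (clipped at zero, matching the tables below). A minimal sketch with placeholder numbers, all in percentage points (the gap values are illustrative, not measured):

```python
train_clean, test_clean = 99.9, 92.0   # clean experiment (placeholder accuracies)
robustness_gap = 0.3                   # from the eta-noisy runs (placeholder)
memorization_gap = 1.5                 # from the eta-noisy runs (placeholder)

generalization_gap = train_clean - test_clean
# The decomposition is an identity, so rationality is the remainder,
# clipped at zero as in the tables below.
rationality_gap = max(generalization_gap - robustness_gap - memorization_gap, 0.0)
rrm_bound = robustness_gap + memorization_gap + rationality_gap
```

When the rationality gap is positive, the RRM bound coincides with the generalization gap itself, as several rows in the tables below show.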
We provide a theoretical bound on the memorization gap in Theorem II. This bound can be computed empirically as follows:

- Compute $K$ noisy runs using the code above.
- Create the $N \times K$ matrix `pred_matrix` of classifier predictions, where $N$ is the number of samples and $K$ is the number of trials.
- Create the $N \times K$ matrices `y_matrix` and `y_tilde_matrix` of the clean labels and noisy labels, respectively.
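A sketch of assembling these matrices. The per-run arrays here are random stand-ins; in practice each run's predictions and noisy labels would be loaded from the output logged with `--log_predictions` (the loading logic is left to the reader):

```python
import numpy as np

N, K, num_classes = 100, 5, 10
rng = np.random.default_rng(0)

# Stand-ins for the k-th run's logged predictions and its noisy labels.
preds_per_run = [rng.integers(0, num_classes, size=N) for _ in range(K)]
y_tilde_per_run = [rng.integers(0, num_classes, size=N) for _ in range(K)]
y = rng.integers(0, num_classes, size=N)            # clean labels, shared by all runs

pred_matrix = np.stack(preds_per_run, axis=1)       # N x K predictions
y_matrix = np.tile(y[:, None], (1, K))              # N x K (clean labels repeated)
y_tilde_matrix = np.stack(y_tilde_per_run, axis=1)  # N x K noisy labels
```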
```python
import numpy as np

from complexity_functions import complexity, complexity_average

num_classes = 10
eta = 0.05  # label-noise probability used in the noisy runs

noise_matrix = (y_tilde_matrix - y_matrix) % num_classes
diff_matrix = (pred_matrix - y_matrix) % num_classes

# Bound based on Cdc
Cdc = complexity_average(diff_matrix, noise_matrix)
Cdc = np.maximum(Cdc, 0)  # clip small negative estimates to zero
bound_Cdc = np.mean(np.sqrt(0.5 * Cdc)) / eta
print(f'Bound based on Cdc is {bound_Cdc}')

# Bound based on Cpc
Cpc = complexity(diff_matrix, noise_matrix)
Cpc = np.maximum(Cpc, 0)  # clip small negative estimates to zero
bound_Cpc = np.sqrt(0.5 * np.mean(Cpc)) / eta
print(f'Bound based on Cpc is {bound_Cpc}')
```
We now list the various quantities of interest (all in percentage points) for CIFAR-10:
| | Method | Backbone | Data Augmentation | Generalization Gap | Robustness | Memorization | Rationality | Theorem II bound | RRM bound | Test Performance |
|---|---|---|---|---|---|---|---|---|---|---|
18 | amdim | amdim_encoder | False | 6.682000 | 2.076349 | 5.688700 | 0.000000 | 70.516720 | 7.765049 | 87.380000 |
19 | amdim | resnet101 | False | 12.458000 | 1.220833 | 14.264408 | 0.000000 | 100.000000 | 15.485241 | 62.430000 |
16 | amdim | resnet18 | False | 4.338000 | 0.422667 | 4.581044 | 0.000000 | 33.470433 | 5.003710 | 62.280000 |
21 | amdim | resnet50_bn | False | 14.731333 | 1.809750 | 16.625074 | 0.000000 | 100.000000 | 18.434824 | 66.283333 |
20 | amdim | wide_resnet50_2 | False | 13.070667 | 1.698750 | 15.327215 | 0.000000 | 100.000000 | 17.025965 | 63.803333 |
13 | mocov2 | resnet101 | False | 2.821333 | 0.329500 | 3.032190 | 0.000000 | 22.779988 | 3.361690 | 69.080000 |
7 | mocov2 | resnet18 | False | 1.425333 | 0.150250 | 1.243309 | 0.031775 | 14.144346 | 1.425333 | 67.596667 |
12 | mocov2 | resnet50 | False | 2.718667 | 0.296083 | 2.964104 | 0.000000 | 24.181311 | 3.260187 | 70.086667 |
14 | mocov2 | wide_resnet50_2 | False | 3.106667 | 0.384917 | 2.791697 | 0.000000 | 22.386794 | 3.176614 | 70.843333 |
8 | simclr | resnet18 | False | 1.434000 | 0.283048 | 0.791300 | 0.359652 | 13.349844 | 1.434000 | 82.496667 |
10 | simclr | resnet50 | False | 1.974000 | 0.215833 | 0.784471 | 0.973696 | 15.745243 | 1.974000 | 92.003333 |
11 | simclr | resnet50 | False | 2.240000 | 0.520000 | 1.711757 | 0.008243 | 19.532210 | 2.240000 | 84.943333 |
17 | amdim | amdim_encoder | True | 4.430000 | 0.682200 | 0.356427 | 3.391373 | 10.323196 | 4.430000 | 87.326667 |
5 | amdim | resnet101 | True | -0.908600 | 0.642133 | 3.698682 | 0.000000 | 25.993151 | 4.340815 | 63.563333 |
6 | amdim | resnet18 | True | 0.331400 | 0.229575 | 1.148386 | 0.000000 | 8.660545 | 1.377961 | 62.843333 |
15 | amdim | resnet50_bn | True | 3.693067 | 0.837233 | 4.222282 | 0.000000 | 31.119562 | 5.059515 | 66.440000 |
9 | amdim | wide_resnet50_2 | True | 1.600533 | 0.685423 | 2.462525 | 0.000000 | 19.200017 | 3.147948 | 64.383333 |
2 | mocov2 | resnet101 | True | -6.013333 | 0.152892 | 0.706704 | 0.000000 | 6.377163 | 0.859596 | 68.576667 |
0 | mocov2 | resnet18 | True | -7.350733 | 0.068200 | 0.214771 | 0.000000 | 3.469925 | 0.282971 | 67.190000 |
3 | mocov2 | resnet50 | True | -5.381000 | 0.189875 | 0.836944 | 0.000000 | 6.986381 | 1.026819 | 69.683333 |
1 | mocov2 | wide_resnet50_2 | True | -6.371867 | 0.180308 | 1.026729 | 0.000000 | 7.632505 | 1.207037 | 70.993333 |
4 | simclr | resnet50 | True | -2.886267 | 0.304940 | 0.545692 | 0.000000 | 6.634170 | 0.850632 | 91.956667 |
The corresponding quantities for ImageNet:

| | Method | Backbone | Data Augmentation | Generalization Gap | Robustness | Memorization | Rationality | Theorem II bound | RRM bound | Test Performance |
|---|---|---|---|---|---|---|---|---|---|---|
20 | CMC | ResNet-50 | False | 14.730569 | 2.298659 | 12.304347 | 0.127563 | NaN | 14.730569 | 54.596667 |
8 | InfoMin | ResNet-50 | False | 10.207255 | 2.343046 | 8.963331 | 0.000000 | NaN | 11.306377 | 70.312667 |
18 | InsDis | ResNet-50 | False | 12.022083 | 1.395160 | 8.524625 | 2.102298 | NaN | 12.022083 | 56.673333 |
17 | PiRL | ResNet-50 | False | 11.433350 | 1.493768 | 8.260058 | 1.679524 | NaN | 11.433350 | 59.105333 |
19 | amdim | ResNet-50 | False | 13.624736 | 0.902634 | 9.715600 | 3.006502 | NaN | 13.624736 | 67.693000 |
21 | bigbigan | ResNet-50 | False | 29.595812 | 3.132483 | 25.189973 | 1.273357 | NaN | 29.595812 | 50.238667 |
12 | moco | ResNet-50 | False | 10.718562 | 1.822505 | 7.860507 | 1.035550 | NaN | 10.718562 | 68.390667 |
15 | simclr | ResNet50_1x | False | 11.071524 | 1.218472 | 7.727698 | 2.125353 | NaN | 11.071524 | 68.725333 |
16 | simclrv2 | ResNet-50 | False | 11.164953 | 0.639183 | 7.674531 | 2.851239 | NaN | 11.164953 | 74.987333 |
10 | simclrv2 | r101_1x_sk0 | False | 10.528165 | 1.113542 | 6.992656 | 2.421967 | NaN | 10.528165 | 73.044000 |
7 | simclrv2 | r101_1x_sk1 | False | 8.234167 | 0.709457 | 4.663610 | 2.861099 | NaN | 8.234167 | 76.067333 |
14 | simclrv2 | r101_2x_sk0 | False | 11.024481 | 0.736880 | 7.512353 | 2.775247 | NaN | 11.024481 | 76.720000 |
9 | simclrv2 | r152_1x_sk0 | False | 10.316767 | 1.120541 | 6.932093 | 2.264134 | NaN | 10.316767 | 74.171333 |
13 | simclrv2 | r152_2x_sk0 | False | 10.924102 | 0.753688 | 7.445563 | 2.724851 | NaN | 10.924102 | 77.247333 |
11 | simclrv2 | r50_1x_sk0 | False | 10.621827 | 0.993703 | 7.314382 | 2.313741 | NaN | 10.621827 | 70.693333 |
4 | InfoMin | ResNet-50 | True | 4.882868 | 0.807126 | 1.012511 | 3.063231 | NaN | 4.882868 | 72.286400 |
6 | InsDis | ResNet-50 | True | 6.845196 | 0.254156 | 1.128115 | 5.462925 | NaN | 6.845196 | 58.300800 |
5 | PiRL | ResNet-50 | True | 6.225963 | 0.291859 | 0.987236 | 4.946868 | NaN | 6.225963 | 60.559600 |
3 | moco | ResNet-50 | True | 1.316271 | 0.565968 | 0.927215 | 0.000000 | NaN | 1.493183 | 70.153600 |
1 | simclrv2 | r101_2x_sk0 | True | 0.633055 | 0.103505 | 0.804871 | 0.000000 | 47.90 | 0.908376 | 77.243600 |
2 | simclrv2 | r152_2x_sk0 | True | 1.004486 | 0.130411 | 0.772675 | 0.101400 | NaN | 1.004486 | 77.649600 |
0 | simclrv2 | r50_1x_sk0 | True | -2.336561 | 0.261075 | 0.675008 | 0.000000 | 46.93 | 0.936083 | 70.962400 |