
Rationality implies generalization

This is the code accompanying the ICLR 2021 submission "For self-supervised learning, Rationality implies Generalization, provably".

We prove non-vacuous generalization bounds for SSS learning algorithms, i.e., algorithms that learn by
(i) performing pre-training with a self-supervised task (i.e., without labels) to obtain a complex representation of the data points, and then
(ii) fitting a simple (e.g., linear) classifier on the representation and the labels.

In this repository, we assume that the self-supervised pre-trained model and its representations are already available. The code shows how to compute the RRM bound for any given representation.

Training the simple classifier

For each SSS-algorithm, calculating the RRM bound requires running two experiments:

  1. the clean experiment, where we train the simple classifier on the data and labels $(x, y)$
  2. the $\eta$-noisy experiment, where we train the simple classifier on $(x, \tilde{y})$, with $\tilde{y}$ the $\eta$-noised labels.
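The $\eta$-noising step can be sketched as follows. This is a minimal illustration, not the code in fitlabels.py; in particular, whether the random replacement label may coincide with the original label is a convention that may differ in the actual implementation.

```python
import numpy as np

def noise_labels(y, eta, num_classes, rng=None):
    """Return eta-noisy labels: each label is independently replaced,
    with probability eta, by a class drawn uniformly at random
    (possibly equal to the original label)."""
    rng = np.random.default_rng(0) if rng is None else rng
    y = np.asarray(y)
    y_tilde = y.copy()
    flip = rng.random(y.size) < eta                      # which samples get noised
    y_tilde[flip] = rng.integers(0, num_classes, size=flip.sum())
    return y_tilde
```

With $\eta = 0.05$ (the setting used in the noisy run below), only about 5% of the training labels are resampled.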

Clean run

%run fitlabels.py --train_noise_prob 0.0 --dataname CIFAR10 --feature_path ./data --log_predictions --batch_size 512 --epochs 100 --eval_type linear --from_features --weight_decay 1e-06 --optimname adam --lr_sched_type const --lr 0.0002 --beta1 0.8 --beta2 0.999

Noisy run

%run fitlabels.py --train_noise_prob 0.05 --dataname CIFAR10 --feature_path ./data --log_predictions --batch_size 512 --epochs 100 --eval_type linear --from_features --weight_decay 1e-06 --optimname adam --lr_sched_type const --lr 0.0002 --beta1 0.8 --beta2 0.999

We compute these simple classifiers for a variety of self-supervised training methods. Given the training and test accuracies of both these runs, we compute the empirical RRM bound.
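The RRM bound decomposes the generalization gap into a robustness gap, a memorization gap, and a rationality gap. Given those three quantities (as reported in the tables below), the empirical bound is just their sum, with the rationality gap clipped at zero. A small helper for illustration (the function name is ours, not from fitlabels.py):

```python
def rrm_bound(robustness, memorization, rationality):
    """Empirical RRM bound on the generalization gap:
    robustness + memorization + max(rationality, 0)."""
    return robustness + memorization + max(rationality, 0.0)
```

For example, plugging in the first CIFAR-10 row below, `rrm_bound(2.076349, 5.688700, 0.0)` reproduces its RRM bound of 7.765049.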

Computing the Theorem II bound

We provide a theoretical bound for the Memorization gap in Theorem II. This bound can be computed empirically as follows.

  1. Compute $K$ noisy runs using the code above.
  2. Create the $N \times K$ matrix pred_matrix of classifier predictions, where $N$ is the number of samples and $K$ is the number of trials.
  3. Create the $N \times K$ matrices y_matrix and y_tilde_matrix of the clean and noisy labels, respectively.
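Assembling these matrices from $K$ per-run prediction vectors might look like this. This is a sketch under the assumption that each noisy run's predictions and noisy labels have already been loaded as length-$N$ arrays (e.g. from the --log_predictions output; the on-disk format is not specified here):

```python
import numpy as np

def build_matrices(pred_runs, noisy_label_runs, y):
    """Stack K per-run vectors (each of length N) into N x K matrices."""
    pred_matrix = np.stack(pred_runs, axis=1)            # (N, K) predictions
    y_tilde_matrix = np.stack(noisy_label_runs, axis=1)  # (N, K) noisy labels
    K = pred_matrix.shape[1]
    y_matrix = np.tile(np.asarray(y)[:, None], (1, K))   # clean labels repeated
    return pred_matrix, y_matrix, y_tilde_matrix
```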

Complexity measure $C^{dc}$

import numpy as np
from complexity_functions import complexity, complexity_average

num_classes = 10

# Work in "difference" space: an entry of 0 means agreement with the clean label.
noise_matrix = (y_tilde_matrix - y_matrix) % num_classes
diff_matrix = (pred_matrix - y_matrix) % num_classes

Cdc = complexity_average(diff_matrix, noise_matrix)
Cdc = np.maximum(Cdc, 0)  # clip small negative estimates

bound_Cdc = np.mean(np.sqrt(0.5 * Cdc)) / 0.05  # divide by the noise level eta

print(f'Bound based on Cdc is {bound_Cdc}')

Complexity measure $C^{pc}$

Cpc = complexity(diff_matrix, noise_matrix)
Cpc = np.maximum(Cpc, 0)  # clip small negative estimates

bound_Cpc = np.sqrt(0.5 * np.mean(Cpc)) / 0.05  # divide by the noise level eta

print(f'Bound based on Cpc is {bound_Cpc}')

RRM bound for CIFAR-10

We now list the various quantities of interest for CIFAR-10 (all values in percentage points).

| Method | Backbone | Data Augmentation | Generalization Gap | Robustness | Memorization | Rationality | Theorem II bound | RRM bound | Test Performance |
|---|---|---|---|---|---|---|---|---|---|
| amdim | amdim_encoder | False | 6.682000 | 2.076349 | 5.688700 | 0.000000 | 70.516720 | 7.765049 | 87.380000 |
| amdim | resnet101 | False | 12.458000 | 1.220833 | 14.264408 | 0.000000 | 100.000000 | 15.485241 | 62.430000 |
| amdim | resnet18 | False | 4.338000 | 0.422667 | 4.581044 | 0.000000 | 33.470433 | 5.003710 | 62.280000 |
| amdim | resnet50_bn | False | 14.731333 | 1.809750 | 16.625074 | 0.000000 | 100.000000 | 18.434824 | 66.283333 |
| amdim | wide_resnet50_2 | False | 13.070667 | 1.698750 | 15.327215 | 0.000000 | 100.000000 | 17.025965 | 63.803333 |
| mocov2 | resnet101 | False | 2.821333 | 0.329500 | 3.032190 | 0.000000 | 22.779988 | 3.361690 | 69.080000 |
| mocov2 | resnet18 | False | 1.425333 | 0.150250 | 1.243309 | 0.031775 | 14.144346 | 1.425333 | 67.596667 |
| mocov2 | resnet50 | False | 2.718667 | 0.296083 | 2.964104 | 0.000000 | 24.181311 | 3.260187 | 70.086667 |
| mocov2 | wide_resnet50_2 | False | 3.106667 | 0.384917 | 2.791697 | 0.000000 | 22.386794 | 3.176614 | 70.843333 |
| simclr | resnet18 | False | 1.434000 | 0.283048 | 0.791300 | 0.359652 | 13.349844 | 1.434000 | 82.496667 |
| simclr | resnet50 | False | 1.974000 | 0.215833 | 0.784471 | 0.973696 | 15.745243 | 1.974000 | 92.003333 |
| simclr | resnet50 | False | 2.240000 | 0.520000 | 1.711757 | 0.008243 | 19.532210 | 2.240000 | 84.943333 |
| amdim | amdim_encoder | True | 4.430000 | 0.682200 | 0.356427 | 3.391373 | 10.323196 | 4.430000 | 87.326667 |
| amdim | resnet101 | True | -0.908600 | 0.642133 | 3.698682 | 0.000000 | 25.993151 | 4.340815 | 63.563333 |
| amdim | resnet18 | True | 0.331400 | 0.229575 | 1.148386 | 0.000000 | 8.660545 | 1.377961 | 62.843333 |
| amdim | resnet50_bn | True | 3.693067 | 0.837233 | 4.222282 | 0.000000 | 31.119562 | 5.059515 | 66.440000 |
| amdim | wide_resnet50_2 | True | 1.600533 | 0.685423 | 2.462525 | 0.000000 | 19.200017 | 3.147948 | 64.383333 |
| mocov2 | resnet101 | True | -6.013333 | 0.152892 | 0.706704 | 0.000000 | 6.377163 | 0.859596 | 68.576667 |
| mocov2 | resnet18 | True | -7.350733 | 0.068200 | 0.214771 | 0.000000 | 3.469925 | 0.282971 | 67.190000 |
| mocov2 | resnet50 | True | -5.381000 | 0.189875 | 0.836944 | 0.000000 | 6.986381 | 1.026819 | 69.683333 |
| mocov2 | wide_resnet50_2 | True | -6.371867 | 0.180308 | 1.026729 | 0.000000 | 7.632505 | 1.207037 | 70.993333 |
| simclr | resnet50 | True | -2.886267 | 0.304940 | 0.545692 | 0.000000 | 6.634170 | 0.850632 | 91.956667 |

RRM bound for ImageNet

| Method | Backbone | Data Augmentation | Generalization Gap | Robustness | Memorization | Rationality | Theorem II bound | RRM bound | Test Performance |
|---|---|---|---|---|---|---|---|---|---|
| CMC | ResNet-50 | False | 14.730569 | 2.298659 | 12.304347 | 0.127563 | NaN | 14.730569 | 54.596667 |
| InfoMin | ResNet-50 | False | 10.207255 | 2.343046 | 8.963331 | 0.000000 | NaN | 11.306377 | 70.312667 |
| InsDis | ResNet-50 | False | 12.022083 | 1.395160 | 8.524625 | 2.102298 | NaN | 12.022083 | 56.673333 |
| PiRL | ResNet-50 | False | 11.433350 | 1.493768 | 8.260058 | 1.679524 | NaN | 11.433350 | 59.105333 |
| amdim | ResNet-50 | False | 13.624736 | 0.902634 | 9.715600 | 3.006502 | NaN | 13.624736 | 67.693000 |
| bigbigan | ResNet-50 | False | 29.595812 | 3.132483 | 25.189973 | 1.273357 | NaN | 29.595812 | 50.238667 |
| moco | ResNet-50 | False | 10.718562 | 1.822505 | 7.860507 | 1.035550 | NaN | 10.718562 | 68.390667 |
| simclr | ResNet50_1x | False | 11.071524 | 1.218472 | 7.727698 | 2.125353 | NaN | 11.071524 | 68.725333 |
| simclrv2 | ResNet-50 | False | 11.164953 | 0.639183 | 7.674531 | 2.851239 | NaN | 11.164953 | 74.987333 |
| simclrv2 | r101_1x_sk0 | False | 10.528165 | 1.113542 | 6.992656 | 2.421967 | NaN | 10.528165 | 73.044000 |
| simclrv2 | r101_1x_sk1 | False | 8.234167 | 0.709457 | 4.663610 | 2.861099 | NaN | 8.234167 | 76.067333 |
| simclrv2 | r101_2x_sk0 | False | 11.024481 | 0.736880 | 7.512353 | 2.775247 | NaN | 11.024481 | 76.720000 |
| simclrv2 | r152_1x_sk0 | False | 10.316767 | 1.120541 | 6.932093 | 2.264134 | NaN | 10.316767 | 74.171333 |
| simclrv2 | r152_2x_sk0 | False | 10.924102 | 0.753688 | 7.445563 | 2.724851 | NaN | 10.924102 | 77.247333 |
| simclrv2 | r50_1x_sk0 | False | 10.621827 | 0.993703 | 7.314382 | 2.313741 | NaN | 10.621827 | 70.693333 |
| InfoMin | ResNet-50 | True | 4.882868 | 0.807126 | 1.012511 | 3.063231 | NaN | 4.882868 | 72.286400 |
| InsDis | ResNet-50 | True | 6.845196 | 0.254156 | 1.128115 | 5.462925 | NaN | 6.845196 | 58.300800 |
| PiRL | ResNet-50 | True | 6.225963 | 0.291859 | 0.987236 | 4.946868 | NaN | 6.225963 | 60.559600 |
| moco | ResNet-50 | True | 1.316271 | 0.565968 | 0.927215 | 0.000000 | NaN | 1.493183 | 70.153600 |
| simclrv2 | r101_2x_sk0 | True | 0.633055 | 0.103505 | 0.804871 | 0.000000 | 47.90 | 0.908376 | 77.243600 |
| simclrv2 | r152_2x_sk0 | True | 1.004486 | 0.130411 | 0.772675 | 0.101400 | NaN | 1.004486 | 77.649600 |
| simclrv2 | r50_1x_sk0 | True | -2.336561 | 0.261075 | 0.675008 | 0.000000 | 46.93 | 0.936083 | 70.962400 |
