This repo is associated with the blog post "Self-Supervised Learning of Image Representations With VICReg" over at sigmoid prime. It contains an implementation of the VICReg training procedure applied to CIFAR-10 with a ResNet-18 backbone, making it feasible to run on a single GPU.
After 500 epochs of pre-training, the model achieves an accuracy of ~85.5% in linear evaluation (this will almost certainly improve if you increase the number of pre-training epochs).
Since CIFAR-10 is much smaller than ImageNet, a few simplifications have been made to the training process:
- We use an encoder dimension of 512 and projector dimension of 1024. This is in contrast with 2048 and 8192, respectively, in the original paper.
- Since the batch size is small (256), we use Adam, not LARS. The authors emphasize that VICReg doesn't require large batch sizes, so LARS shouldn't be too important anyway.
- Linear evaluation is performed using 50 finetuning epochs instead of 100 (see the sketch of the linear evaluation protocol after this list).
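For context, linear evaluation freezes the pre-trained encoder and trains a linear classifier on top of its features. Below is a minimal sketch of the protocol, assuming a frozen `encoder` (the ResNet-18 backbone after pre-training), a hypothetical labelled `train_loader`, and the `device` set up further down; the repo's evaluation script may be organized differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# `encoder`, `train_loader`, and `device` are assumed to exist already.
encoder.eval()
for p in encoder.parameters():
    p.requires_grad = False                      # freeze the pre-trained backbone

classifier = nn.Linear(512, 10).to(device)       # 512-d features -> 10 CIFAR-10 classes
opt = torch.optim.Adam(classifier.parameters())

for epoch in range(50):                          # 50 finetuning epochs, as noted above
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        with torch.no_grad():
            features = encoder(images)           # frozen features
        loss = F.cross_entropy(classifier(features), labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
```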
The training procedure is simple. We first instantiate our model:
```python
import torch

# VICReg is defined in this repo.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder_dim, projector_dim = 512, 1024
model = VICReg(encoder_dim, projector_dim).to(device)
```
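For reference, here is a rough sketch of what the `VICReg` module might look like; the actual class in this repo may differ, e.g. in the projector depth or in CIFAR-specific tweaks to the ResNet-18 stem.

```python
import torch.nn as nn
from torchvision.models import resnet18

class VICReg(nn.Module):
    """Sketch: ResNet-18 encoder followed by an MLP projector (expander)."""

    def __init__(self, encoder_dim=512, projector_dim=1024):
        super().__init__()
        self.encoder = resnet18()            # 512-d features after global pooling
        self.encoder.fc = nn.Identity()      # drop the classification head
        self.projector = nn.Sequential(      # the original paper uses a deeper 3-layer MLP
            nn.Linear(encoder_dim, projector_dim),
            nn.BatchNorm1d(projector_dim),
            nn.ReLU(inplace=True),
            nn.Linear(projector_dim, projector_dim),
        )

    def forward(self, x1, x2):
        # Project both augmented views into the space where the loss is computed.
        return self.projector(self.encoder(x1)), self.projector(self.encoder(x2))
```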
We then load CIFAR-10, setting the `transform` parameter to an instance of the `Augmentation` class, which produces two augmented versions for each image in a batch:
```python
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

data = CIFAR10(root=".", train=True, download=True, transform=Augmentation())
dataloader = DataLoader(data, batch_size=256, shuffle=True, num_workers=2)
```
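The exact augmentation recipe lives in the repo; as a rough sketch of the idea (the specific transforms below are an assumption, loosely following the SimCLR-style augmentations used by VICReg):

```python
from torchvision import transforms

class Augmentation:
    """Sketch: return two independently augmented views of the same image."""

    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(32),              # CIFAR-10 images are 32x32
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
        ])

    def __call__(self, img):
        return self.transform(img), self.transform(img)
```

With this setup, each batch yields a pair of tensors, which is why the training loop below unpacks `images` into `x1` and `x2`.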
The inner loop of the training process is defined as follows:
```python
for images, _ in dataloader:
    x1, x2 = [x.to(device) for x in images]
    z1, z2 = model(x1, x2)

    # Loss coefficients from the paper: lambda = mu = 25, nu = 1.
    la, mu, nu = 25, 25, 1
    var1, var2 = variance(z1), variance(z2)
    inv = invariance(z1, z2)
    cov1, cov2 = covariance(z1), covariance(z2)
    loss = la * inv + mu * (var1 + var2) + nu * (cov1 + cov2)

    # `opt` is the optimizer (Adam, per the note on batch size above).
    opt.zero_grad()
    loss.backward()
    opt.step()
```
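The `invariance`, `variance`, and `covariance` functions are provided by the repo. As a reference for what each term computes, here is a sketch following the definitions in the VICReg paper (the repo's implementations may differ in details such as the epsilon):

```python
import torch
import torch.nn.functional as F

def invariance(z1, z2):
    # Mean squared error between the embeddings of the two views.
    return F.mse_loss(z1, z2)

def variance(z, eps=1e-4):
    # Hinge loss pushing the standard deviation of every embedding dimension towards 1.
    std = torch.sqrt(z.var(dim=0) + eps)
    return torch.mean(F.relu(1.0 - std))

def covariance(z):
    # Penalize the off-diagonal entries of the covariance matrix of the batch.
    n, d = z.shape
    z = z - z.mean(dim=0)
    cov = (z.T @ z) / (n - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    return off_diag.pow(2).sum() / d
```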
If you would like to pre-train the model beyond 500 epochs, you can download a 500-epoch checkpoint here. All you need to do is move the file to the root of the repository, set `num_epochs` in `train.py` to a target value (e.g. 750 or 1000), and run `python3 train.py`.
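If you instead want to load the checkpoint directly in your own code (e.g. for evaluation), it would look roughly like this; the filename and the structure of the saved file are assumptions, so check how `train.py` writes checkpoints:

```python
import torch

# "vicreg_500.pth" is a placeholder; use the actual name of the downloaded file.
state = torch.load("vicreg_500.pth", map_location=device)
model.load_state_dict(state)  # or e.g. state["model"] if the file stores more than the weights
```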