
A PyTorch implementation of VICReg by Bardes et al. (2021), trained on CIFAR-10 with a ResNet-18 backbone.


augustwester/vicreg


VICReg on CIFAR-10

[Figure: Siamese architecture]

This repo accompanies the blog post "Self-Supervised Learning of Image Representations With VICReg" over at sigmoid prime. It contains an implementation of the VICReg training procedure applied to CIFAR-10 with a ResNet-18 backbone, which makes it feasible to run on a single GPU.

After 500 epochs of pre-training, the model achieves ~85.5% accuracy in linear evaluation. This will likely improve if you increase the number of pre-training epochs.

Since CIFAR-10 is much smaller than ImageNet, a few simplifications have been made to the training process:

  1. We use an encoder dimension of 512 and projector dimension of 1024. This is in contrast with 2048 and 8192, respectively, in the original paper.
  2. Since the batch size is small (256), we use Adam, not LARS. The authors emphasize that VICReg doesn't require large batch sizes, so LARS shouldn't be too important anyway.
  3. Linear evaluation is performed using 50 finetuning epochs instead of 100.
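For readers unfamiliar with linear evaluation (item 3), here is a minimal sketch of one training step. It assumes the standard protocol in which the pre-trained encoder is frozen and only a linear classifier on its 512-dimensional output is trained. The function name `linear_eval_step`, the stand-in encoder, and the learning rate are illustrative assumptions, not taken from this repo.

```python
import torch
import torch.nn.functional as F
from torch import nn


def linear_eval_step(encoder, classifier, opt, x, y):
    """One training step of a linear probe on top of a frozen encoder."""
    encoder.eval()  # keep BatchNorm statistics fixed
    with torch.no_grad():  # no gradients flow into the encoder
        features = encoder(x)
    loss = F.cross_entropy(classifier(features), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()


# Stand-in encoder for illustration only (assumption); the repo uses a ResNet-18.
torch.manual_seed(0)
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512))
classifier = nn.Linear(512, 10)  # 512-d features -> 10 CIFAR-10 classes
# Only the classifier's parameters are optimized; lr is an assumed value.
opt = torch.optim.Adam(classifier.parameters(), lr=1e-3)

x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
loss_value = linear_eval_step(encoder, classifier, opt, x, y)
```

Because the optimizer only sees the classifier's parameters and the features are computed under `torch.no_grad()`, the encoder weights stay unchanged throughout.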

The training procedure is simple. We first instantiate our model:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder_dim, projector_dim = 512, 1024
model = VICReg(encoder_dim, projector_dim).to(device)

We then load CIFAR-10, setting the transform parameter to an instance of the Augmentation class, which produces two augmented versions for each image in a batch:

batch_size = 256  # the small batch size mentioned above
data = CIFAR10(root=".", train=True, download=True, transform=Augmentation())
dataloader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=2)

The inner loop of the training process is defined as follows:

for images, _ in dataloader:
    # images holds the two augmented views of the same batch
    x1, x2 = [x.to(device) for x in images]
    z1, z2 = model(x1, x2)

    # weights for the invariance, variance, and covariance terms
    la, mu, nu = 25, 25, 1
    var1, var2 = variance(z1), variance(z2)
    inv = invariance(z1, z2)
    cov1, cov2 = covariance(z1), covariance(z2)
    loss = la*inv + mu*(var1 + var2) + nu*(cov1 + cov2)

    # standard optimizer step
    opt.zero_grad()
    loss.backward()
    opt.step()
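The three loss terms used in the loop are not defined above. The sketch below follows the definitions in the VICReg paper: invariance is the mean squared error between the two embeddings, variance is a hinge loss that pushes the per-dimension standard deviation towards at least `gamma`, and covariance penalizes the squared off-diagonal entries of the covariance matrix. The defaults `gamma=1.0` and `eps=1e-4` are assumptions based on the paper; the repo's own functions may differ in detail.

```python
import torch
import torch.nn.functional as F


def invariance(z1, z2):
    """Mean squared distance between the embeddings of the two views."""
    return F.mse_loss(z1, z2)


def variance(z, gamma=1.0, eps=1e-4):
    """Hinge loss on the standard deviation of each embedding dimension."""
    std = torch.sqrt(z.var(dim=0) + eps)  # eps avoids sqrt(0) instability
    return F.relu(gamma - std).mean()


def covariance(z):
    """Sum of squared off-diagonal covariance entries, divided by dimension."""
    n, d = z.shape
    z = z - z.mean(dim=0)      # center each dimension over the batch
    cov = (z.T @ z) / (n - 1)  # (d, d) covariance matrix
    off_diag = cov - torch.diag(torch.diag(cov))
    return off_diag.pow(2).sum() / d


# Collapsed embeddings (all identical) are penalized by the variance term.
collapsed = torch.zeros(8, 4)
collapse_penalty = variance(collapsed)
```

The variance term is what prevents the trivial solution in which every image maps to the same embedding: for a collapsed batch the standard deviation is near zero, so the hinge is close to `gamma`.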

If you would like to pre-train the model beyond 500 epochs, you can download a 500-epoch checkpoint here. Move the file to the root of the repository, set num_epochs in train.py to a target value (e.g. 750 or 1000), and run python3 train.py.
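How train.py restores the checkpoint is not shown here. As a general sketch of resuming training in PyTorch, the example below saves and restores the model state, the optimizer state, and the epoch counter. The helper names, the dictionary keys, and the file name are all hypothetical and may not match the format of the downloadable checkpoint.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, model, opt, epoch):
    # Key names are hypothetical, chosen for this sketch.
    torch.save(
        {"model": model.state_dict(), "opt": opt.state_dict(), "epoch": epoch},
        path,
    )


def load_checkpoint(path, model, opt, device):
    """Restore model and optimizer state; return the last completed epoch."""
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    opt.load_state_dict(ckpt["opt"])
    return ckpt["epoch"]


# Round trip with a tiny stand-in model (assumption, not the repo's model).
device = torch.device("cpu")
model = nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")

save_checkpoint(path, model, opt, epoch=500)
start_epoch = load_checkpoint(path, model, opt, device)
# Training would then continue with: for epoch in range(start_epoch, num_epochs)
```

Restoring the optimizer state alongside the weights matters for Adam, since its running moment estimates would otherwise restart from zero.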
