# BYOL from Scratch

## This notebook implements BYOL end‑to‑end

1. Build augmentations

2. Build the BYOL model

3. Train on STL‑10 (unlabeled)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from torchvision.datasets import STL10
import matplotlib.pyplot as plt
from PIL import Image
import os
import random
import numpy as np

In [2]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(40)

# 1. BYOL Augmentations

In [3]:
class BYOLAugmentations:
    def __init__(self, image_size=96):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=9),
            transforms.ToTensor(),
        ])


    def __call__(self, x):
        v1 = self.transform(x)
        v2 = self.transform(x)
        return v1, v2

byol_transform = BYOLAugmentations()

In [4]:
class STL10BYOL(STL10):
    def __init__(self, *args, simclr_transform=None, **kwargs):
        super().__init__(*args, transform=None, **kwargs)
        self.simclr_transform = simclr_transform

    def __getitem__(self, index):
        img, _ = super().__getitem__(index)  # raw PIL image
        v1, v2 = self.simclr_transform(img)
        return v1, v2

# 2. Build the BYOL model

## 2.1 Encoder (ResNet18)

In [5]:
import torchvision.models as models

def get_encoder():
    resnet = models.resnet18(weights=None)
    resnet.fc = torch.nn.Identity()
    return resnet


## 2.2 Projection head

In [6]:
class MLP(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim=512, out_dim=256):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.net(x)
predictor = MLP(256, 512, 256)


## 2.3 BYOL model

In [7]:
import copy
import torch.nn.functional as F

class BYOL(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.online_encoder = get_encoder()
        self.target_encoder = copy.deepcopy(self.online_encoder)

        self.online_proj = MLP(512)
        self.target_proj = MLP(512)

        self.predictor = MLP(256, 512, 256)

        for p in self.target_encoder.parameters():
            p.requires_grad = False
        for p in self.target_proj.parameters():
            p.requires_grad = False

    @torch.no_grad()
    def update_target(self, m=0.996):
        for o, t in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            t.data = t.data * m + o.data * (1 - m)

        for o, t in zip(self.online_proj.parameters(), self.target_proj.parameters()):
            t.data = t.data * m + o.data * (1 - m)

    def forward(self, v1, v2):
        o1 = self.predictor(self.online_proj(self.online_encoder(v1)))
        o2 = self.predictor(self.online_proj(self.online_encoder(v2)))

        with torch.no_grad():
            t1 = self.target_proj(self.target_encoder(v1))
            t2 = self.target_proj(self.target_encoder(v2))

        o1 = F.normalize(o1, dim=1)
        o2 = F.normalize(o2, dim=1)
        t1 = F.normalize(t1, dim=1)
        t2 = F.normalize(t2, dim=1)

        loss = 2 - 2 * (
            (o1 * t2).sum(dim=1).mean() +
            (o2 * t1).sum(dim=1).mean()
        ) / 2

        return loss


# 3. Training loop 

In [8]:
dataset = STL10BYOL(
    root="./data",
    split="unlabeled",
    download=True,
    simclr_transform=byol_transform
)

loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=True,
    num_workers=4,
    drop_last=True,
    pin_memory=True
)


100%|██████████| 2.64G/2.64G [02:20<00:00, 18.8MB/s] 


In [9]:
model = BYOL().cuda()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=3e-4
)

scaler = torch.cuda.amp.GradScaler()


  scaler = torch.cuda.amp.GradScaler()


In [10]:
for epoch in range(50):
    model.train()
    total_loss = 0

    for v1, v2 in loader:
        v1, v2 = v1.cuda(), v2.cuda()

        with torch.cuda.amp.autocast():
            loss = model(v1, v2)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        model.update_target()

        total_loss += loss.item()

    print(f"Epoch {epoch}: BYOL Loss = {total_loss/len(loader):.4f}")


  with torch.cuda.amp.autocast():


Epoch 0: BYOL Loss = 0.6509
Epoch 1: BYOL Loss = 0.4477
Epoch 2: BYOL Loss = 0.3755
Epoch 3: BYOL Loss = 0.3621
Epoch 4: BYOL Loss = 0.3554
Epoch 5: BYOL Loss = 0.3504
Epoch 6: BYOL Loss = 0.3458
Epoch 7: BYOL Loss = 0.3390
Epoch 8: BYOL Loss = 0.3325
Epoch 9: BYOL Loss = 0.3283
Epoch 10: BYOL Loss = 0.3240
Epoch 11: BYOL Loss = 0.3185
Epoch 12: BYOL Loss = 0.3139
Epoch 13: BYOL Loss = 0.3065
Epoch 14: BYOL Loss = 0.3014
Epoch 15: BYOL Loss = 0.2951
Epoch 16: BYOL Loss = 0.2912
Epoch 17: BYOL Loss = 0.2860
Epoch 18: BYOL Loss = 0.2829
Epoch 19: BYOL Loss = 0.2785
Epoch 20: BYOL Loss = 0.2782
Epoch 21: BYOL Loss = 0.2768
Epoch 22: BYOL Loss = 0.2783
Epoch 23: BYOL Loss = 0.2758
Epoch 24: BYOL Loss = 0.2742
Epoch 25: BYOL Loss = 0.2734
Epoch 26: BYOL Loss = 0.2739
Epoch 27: BYOL Loss = 0.2729
Epoch 28: BYOL Loss = 0.2731
Epoch 29: BYOL Loss = 0.2725
Epoch 30: BYOL Loss = 0.2712
Epoch 31: BYOL Loss = 0.2725
Epoch 32: BYOL Loss = 0.2706
Epoch 33: BYOL Loss = 0.2724
Epoch 34: BYOL Loss = 0.

In [13]:
torch.save(
    model.online_encoder.state_dict(),
    "encoder_byol.pth"
)


# 5. Validation
5.1 Linear probing
5.2 KNN testing

In [14]:
encoder = get_encoder().cuda()
encoder.load_state_dict(torch.load("encoder_byol.pth"))
encoder = encoder.cuda()

for param in encoder.parameters():
    param.requires_grad = False

encoder.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### 5.1 Linear probing

In [15]:
linear_head = torch.nn.Linear(512, 10).cuda()

In [16]:
from torchvision.datasets import STL10
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(96),
    transforms.ToTensor()
])

train_set = STL10(
    root="./data",
    split="train",
    download=True,
    transform=transform
)

test_set = STL10(
    root="./data",
    split="test",
    download=True,
    transform=transform
)

train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4)


In [17]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(linear_head.parameters(), lr=1e-3)


for epoch in range(20):
    linear_head.train()
    total_loss = 0

    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()

        with torch.no_grad():
            features = encoder(x)

        logits = linear_head(features)
        loss = criterion(logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch}: Loss = {total_loss/len(train_loader):.4f}")


Epoch 0: Loss = 1.7736
Epoch 1: Loss = 1.2320
Epoch 2: Loss = 1.0625
Epoch 3: Loss = 0.9862
Epoch 4: Loss = 0.9476
Epoch 5: Loss = 0.9239
Epoch 6: Loss = 0.8964
Epoch 7: Loss = 0.8806
Epoch 8: Loss = 0.8657
Epoch 9: Loss = 0.8570
Epoch 10: Loss = 0.8492
Epoch 11: Loss = 0.8392
Epoch 12: Loss = 0.8309
Epoch 13: Loss = 0.8223
Epoch 14: Loss = 0.8169
Epoch 15: Loss = 0.8132
Epoch 16: Loss = 0.8042
Epoch 17: Loss = 0.7997
Epoch 18: Loss = 0.7987
Epoch 19: Loss = 0.7909


In [18]:
linear_head.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in test_loader:
        x, y = x.cuda(), y.cuda()
        features = encoder(x)
        logits = linear_head(features)
        preds = logits.argmax(dim=1)

        correct += (preds == y).sum().item()
        total += y.size(0)

acc = correct / total * 100
print(f"Linear Probe Accuracy: {acc:.2f}%")


Linear Probe Accuracy: 69.58%


## 5.2 KNN testing

In [19]:
import numpy as np

def extract_features(loader):
    feats = []
    labels = []

    with torch.no_grad():
        for x, y in loader:
            x = x.cuda()
            f = encoder(x)
            feats.append(f.cpu())
            labels.append(y)

    return torch.cat(feats), torch.cat(labels)

train_feats, train_labels = extract_features(train_loader)
test_feats, test_labels = extract_features(test_loader)


In [20]:
train_feats = torch.nn.functional.normalize(train_feats, dim=1)
test_feats = torch.nn.functional.normalize(test_feats, dim=1)


In [21]:
def knn_accuracy(train_feats, train_labels, test_feats, test_labels, k=20):
    correct = 0

    for i in range(test_feats.size(0)):
        sim = torch.matmul(train_feats, test_feats[i])
        topk = sim.topk(k).indices
        pred = train_labels[topk].mode()[0]

        correct += (pred == test_labels[i]).item()

    return correct / test_feats.size(0) * 100


In [22]:
acc_knn = knn_accuracy(
    train_feats, train_labels,
    test_feats, test_labels,
    k=20
)

print(f"k-NN Accuracy (k=20): {acc_knn:.2f}%")


k-NN Accuracy (k=20): 67.58%
