<a href="https://colab.research.google.com/github/TamarSdeChen/Self-Learner-DeepLearning-Course-Technion/blob/main/train_BYOL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install lightly



In [None]:
#import libraries
import copy
import time
import os

import torch
import torchvision
from torch import nn

from torch.optim.lr_scheduler import MultiStepLR
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor

from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.data import SimCLRCollateFunction
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.utils.scheduler import cosine_schedule

from google.colab import files
from google.colab import drive
# Pick the compute device: fall back to the CPU only when CUDA is unavailable.
device = "cpu" if not torch.cuda.is_available() else "cuda"

In [None]:
# Augmentation pipeline applied to every image of the unlabeled split.
# `cj_prob=0` sets the colour-jitter probability to zero; the normalisation
# constants are the mean/std values commonly used with ImageNet-trained
# torchvision models.
transform = SimCLRTransform(
    input_size=96,
    cj_prob=0,
    normalize={
        'mean': [0.485, 0.456, 0.406],
        'std': [0.229, 0.224, 0.225],
    },
)

# Download STL10 into ./data (if it is not there yet) and expose its
# "unlabeled" split with the transform above.
dataset_train_unlabeled = STL10(
    root="data",
    split="unlabeled",
    download=True,
    transform=transform,
)

Files already downloaded and verified


In [None]:
# Batch iterator over the unlabeled dataset used for self-supervised training.
trainloader_unlabeled = torch.utils.data.DataLoader(
    dataset_train_unlabeled,
    batch_size=512,   # samples per optimisation step
    shuffle=True,     # new sample order every epoch
    drop_last=True,   # skip the final incomplete batch
    num_workers=0,    # load data in the main process
)

In [None]:
class BYOL(nn.Module):
  """BYOL model: an online network plus a frozen momentum copy of it.

  The online branch is `backbone -> projection_head -> prediction_head`.
  The momentum branch is a deep copy of the backbone and the projection head
  whose parameters have `requires_grad` deactivated; this class never updates
  them itself, the caller is expected to do so (e.g. with `update_momentum`).

  Args:
    backbone: feature extractor (e.g. a ResNet without its classifier). Its
      output is flattened from dimension 1 before entering the projection head.
    input_dim: size of the flattened backbone output (default 512).
    hidden_dim: hidden size of both heads (default 1024).
    output_dim: size of the projection/prediction vectors (default 256).
  """

  def __init__(self, backbone, input_dim=512, hidden_dim=1024, output_dim=256):
    super().__init__()
    self.backbone = backbone # e.g., resnet
    # Head sizes are parameters so that backbones with another feature width
    # can be used; the defaults reproduce the previously hard-coded values.
    self.projection_head = BYOLProjectionHead(input_dim, hidden_dim, output_dim)
    self.prediction_head = BYOLPredictionHead(output_dim, hidden_dim, output_dim)
    # Momentum branch: starts as an exact copy of the online branch and is
    # excluded from gradient computation.
    self.backbone_momentum = copy.deepcopy(self.backbone)
    self.projection_head_momentum = copy.deepcopy(self.projection_head)
    deactivate_requires_grad(self.backbone_momentum)
    deactivate_requires_grad(self.projection_head_momentum)

  def forward(self, x):
    """Run the online branch on `x` and return the prediction-head output."""
    y = self.backbone(x).flatten(start_dim=1)
    z = self.projection_head(y)
    p = self.prediction_head(z)
    return p

  def forward_momentum(self, x):
    """Run the momentum branch on `x` and return its detached projection."""
    y = self.backbone_momentum(x).flatten(start_dim=1)
    z = self.projection_head_momentum(y)
    # Detach so the returned target never carries gradient information.
    z = z.detach()
    return z

In [None]:
# Ask which ResNet-18 initialisation to use: "True" selects the ImageNet
# pre-trained weights, any other answer selects a randomly initialised network.
# The answer is stripped and compared case-insensitively so that inputs such as
# "true" or " True" are not silently treated as a request for the untrained
# network. `pre_trained` is kept as the string "True"/"False" because the
# hyper-parameter cell compares it against "True".
pre_trained = input("Use ImageNet pre-trained ResNet-18? [True/False]: ").strip()
pre_trained = "True" if pre_trained.lower() == "true" else "False"

if pre_trained == "True":
  # initialization with pre-trained ResNet-18
  resnet = torchvision.models.resnet18(weights='IMAGENET1K_V1')
else:
  resnet = torchvision.models.resnet18()

# Create the backbone for the BYOL model from every child module of the ResNet
# except the last one (expected to be its classification layer).
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BYOL(backbone)
model.to(device)

True


BYOL(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 

In [None]:
# Hyper-parameters. The learning rate and the number of epochs both depend on
# whether the backbone started from pre-trained weights.
if pre_trained == "True":
  learning_rate, epochs = 1e-3, 50
else:
  learning_rate, epochs = 1e-4, 150

# Loss and optimiser used by the training loop.
criterion = NegativeCosineSimilarity()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# BYOL training loop over the unlabeled data.
print("Starting Training")
for epoch in range(epochs):
  epoch_start = time.time()
  total_loss = 0
  # Momentum coefficient for this epoch; presumably moves from 0.996 towards 1
  # along a cosine curve as `epoch` approaches `epochs`.
  momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
  for batch in trainloader_unlabeled:
    x0, x1 = batch[0] # batch[0] contains 2 augmentations of the same image
    # Update the momentum networks from the online networks before the step.
    update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
    update_momentum(model.projection_head, model.projection_head_momentum, m=momentum_val)

    x0 = x0.to(device)
    x1 = x1.to(device)
    # Online predictions (p*) and momentum projections (z*) for both views.
    p0 = model(x0)
    z0 = model.forward_momentum(x0)
    p1 = model(x1)
    z1 = model.forward_momentum(x1)
    # Symmetric loss: each view's prediction is compared with the momentum
    # projection of the other view.
    loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
    total_loss += loss.detach()

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

  # print for each epoch
  # The duration is computed exactly once; subtracting it from the current
  # timestamp a second time would report an epoch-timestamp-sized number.
  epoch_time = time.time() - epoch_start
  avg_loss = total_loss / len(trainloader_unlabeled)
  log = "Epoch: {} | Loss: {:.4f} | ".format(epoch, avg_loss)
  log += "Epoch Time: {:.2f} secs".format(epoch_time)
  print(log)

Starting Training
Epoch: 0 | Loss: -0.7323 | Epoch Time: 1688627825.86 secs


KeyboardInterrupt: ignored

In [None]:
# Persist the trained weights: the model's state dict is stored under the
# 'net' key of a dictionary written to `checkpoint`.
checkpoint = 'BYOL_final_checkpoint.pth'

print('==> Saving model ...')
state = {'net': model.state_dict()}
torch.save(state, checkpoint)