<a href="https://colab.research.google.com/github/TamarSdeChen/Self-Learner-DeepLearning-Course-Technion/blob/main/pre_trained_ResNet18_BYOL_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install lightly

In [None]:
import copy

import torch
import torchvision
from torch import nn

from torch.optim.lr_scheduler import MultiStepLR


from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.data import SimCLRCollateFunction
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.utils.scheduler import cosine_schedule

In [None]:
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor
import time

# Per-channel mean / std used to normalise every augmented view.
normalize_stats = {
    'mean': [0.485, 0.456, 0.406],
    'std': [0.229, 0.224, 0.225],
}

# SimCLR-style augmentation pipeline for 96x96 inputs; cj_prob=0 sets the
# colour-jitter probability to zero.
transform = SimCLRTransform(
    input_size=96,
    cj_prob=0,
    normalize=normalize_stats,
)

# Unlabeled split of STL10, downloaded into ./data when not already present.
# The augmentation pipeline above is applied to each sample on access.
dataset_train_unlabeled = STL10(
    root="data",
    split="unlabeled",
    download=True,
    transform=transform,
)


Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to data/stl10_binary.tar.gz


100%|██████████| 2640397119/2640397119 [02:31<00:00, 17389537.55it/s]


Extracting data/stl10_binary.tar.gz to data


In [None]:
import torch
# Batching options for the self-supervised training loader.
loader_options = dict(
    batch_size=512,
    shuffle=True,    # reshuffle the samples at every epoch
    drop_last=True,  # discard the final incomplete batch
    num_workers=0,   # load data in the main process
)
trainloader_unlabeled = torch.utils.data.DataLoader(
    dataset_train_unlabeled, **loader_options
)


In [None]:
class BYOL(nn.Module):
  """BYOL model: an online network plus a momentum copy used to produce targets.

  The online branch is ``backbone -> projection_head -> prediction_head``.
  The momentum branch consists of deep copies of the backbone and projection
  head, made at construction time and passed through
  ``deactivate_requires_grad`` so that they are not trained by the optimizer;
  the training code is expected to refresh them (e.g. via ``update_momentum``).

  Args:
    backbone: feature extractor (e.g. a ResNet without its classification
      layer). After flattening from dim 1, its output must have
      ``num_features`` values per sample.
    num_features: input size of the projection head. Defaults to 512.
    hidden_dim: hidden size of the projection and prediction heads.
      Defaults to 1024.
    out_dim: size of the projected / predicted embedding. Defaults to 256.
  """

  def __init__(self, backbone, num_features=512, hidden_dim=1024, out_dim=256):
    super().__init__()
    # Online network (trained by gradient descent).
    self.backbone = backbone # e.g., resnet
    self.projection_head = BYOLProjectionHead(num_features, hidden_dim, out_dim)
    self.prediction_head = BYOLPredictionHead(out_dim, hidden_dim, out_dim)
    # Momentum network: starts as an exact copy of the online backbone and
    # projection head; there is no momentum counterpart of the prediction head.
    self.backbone_momentum = copy.deepcopy(self.backbone)
    self.projection_head_momentum = copy.deepcopy(self.projection_head)
    deactivate_requires_grad(self.backbone_momentum)
    deactivate_requires_grad(self.projection_head_momentum)

  def forward(self, x):
    """Run the online branch and return the prediction ``p`` for batch ``x``."""
    y = self.backbone(x).flatten(start_dim=1)
    z = self.projection_head(y)
    p = self.prediction_head(z)
    return p

  def forward_momentum(self, x):
    """Run the momentum branch and return its projection, detached from autograd."""
    y = self.backbone_momentum(x).flatten(start_dim=1)
    z = self.projection_head_momentum(y)
    z = z.detach()
    return z


In [None]:
from google.colab import drive

# Make Google Drive available under /content/drive.
drive.mount('/content/drive')

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Initialise from an ImageNet pre-trained ResNet-18 and keep every child
# module except the last one as the BYOL backbone.
resnet = torchvision.models.resnet18(weights='IMAGENET1K_V1')
feature_layers = list(resnet.children())[:-1]
backbone = nn.Sequential(*feature_layers)
model = BYOL(backbone)

# To continue from a previously saved checkpoint instead, uncomment:
#path_to_data = '/content/drive/MyDrive/DL_project/BYOL_50_epoc_2906.pth'
#state = torch.load(path_to_data)
#model.load_state_dict(state['net'])

# Move all parameters to the selected device (also displays the model as cell output).
model.to(device)

In [None]:
import os

# BYOL pre-training loop: symmetric negative-cosine-similarity loss between the
# online predictions and the momentum projections of two augmented views.
learning_rate = 1e-3
criterion = NegativeCosineSimilarity()

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Optional step decay (currently disabled). If enabled, call scheduler.step()
# once per epoch, after the batch loop - not once per batch.
#scheduler = MultiStepLR(optimizer, milestones=[80, 90], gamma=0.1)
# epoch 0-79:  lr = 0.001
# epoch 80-89: lr = 0.0001
# epoch 90+:   lr = 0.00001

epochs = 100

print("Starting Training")
for epoch in range(epochs):
  epoch_start = time.time()
  total_loss = 0
  # Momentum coefficient for the target network, scheduled between 0.996 and 1.
  momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
  for batch in trainloader_unlabeled:
    x0, x1 = batch[0]  # batch[0] contains 2 augmentations of the same images
    # Refresh the momentum networks from the online networks before the forward pass.
    update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
    update_momentum(model.projection_head, model.projection_head_momentum, m=momentum_val)

    x0 = x0.to(device)
    x1 = x1.to(device)
    p0 = model(x0)
    z0 = model.forward_momentum(x0)
    p1 = model(x1)
    z1 = model.forward_momentum(x1)
    # Symmetrised loss: each view's prediction is compared with the other view's target.
    loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
    total_loss += loss.detach()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
  # print for each epoch
  # Elapsed wall-clock time is computed exactly once; subtracting a second time
  # would turn it back into (roughly) an absolute timestamp.
  epoch_time = time.time() - epoch_start
  avg_loss = total_loss / len(trainloader_unlabeled)
  log = "Epoch: {} | Loss: {:.4f} | Epoch Time: {:.2f} secs".format(epoch, avg_loss, epoch_time)
  print(log)
  # Checkpoint every 10th epoch (epochs 0, 10, 20, ...).
  if epoch % 10 == 0:
    path_to_data = '/content/drive/MyDrive/DL_project/pretrained_BYOL_{}_epoc_0207.pth'.format(epoch)

    # save model after training
    print('==> Saving model ...')
    state = {
    'net': model.state_dict(),
    }
    torch.save(state, path_to_data)




Starting Training
Epoch: 0 | Loss: -0.7253Epoch Time: 1688303529.94 secs
==> Saving model ...
Epoch: 1 | Loss: -0.8506Epoch Time: 1688303819.44 secs
Epoch: 2 | Loss: -0.8769Epoch Time: 1688304108.74 secs
Epoch: 3 | Loss: -0.8872Epoch Time: 1688304391.82 secs
Epoch: 4 | Loss: -0.8924Epoch Time: 1688304673.75 secs
Epoch: 5 | Loss: -0.8956Epoch Time: 1688304953.95 secs
Epoch: 6 | Loss: -0.8972Epoch Time: 1688305235.81 secs
Epoch: 7 | Loss: -0.8991Epoch Time: 1688305515.83 secs
Epoch: 8 | Loss: -0.8999Epoch Time: 1688305795.29 secs
Epoch: 9 | Loss: -0.9010Epoch Time: 1688306078.43 secs
Epoch: 10 | Loss: -0.9017Epoch Time: 1688306359.96 secs
==> Saving model ...
Epoch: 11 | Loss: -0.9024Epoch Time: 1688306644.10 secs
Epoch: 12 | Loss: -0.9035Epoch Time: 1688306930.60 secs
Epoch: 13 | Loss: -0.9037Epoch Time: 1688307215.62 secs
Epoch: 14 | Loss: -0.9049Epoch Time: 1688307501.66 secs
Epoch: 15 | Loss: -0.9043Epoch Time: 1688307789.69 secs
Epoch: 16 | Loss: -0.9053Epoch Time: 1688308076.58 sec

KeyboardInterrupt: ignored

In [None]:

# Write the current model weights to Google Drive, stored under the 'net' key.
path_to_data = '/content/drive/MyDrive/DL_project/pre_trained_BYOL_100_epoch_adam.pth'

print('==> Saving model ...')
state = dict(net=model.state_dict())
torch.save(state, path_to_data)
