<a href="https://colab.research.google.com/github/TamarSdeChen/Self-Learner-DeepLearning-Course-Technion/blob/main/BYOL_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install lightly



In [None]:
# Report the attached GPU (model, driver/CUDA version, memory usage) before training.
!nvidia-smi

Thu Jun 29 11:20:55 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    24W / 300W |      0MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import copy

import torch
import torchvision
from torch import nn

from torch.optim.lr_scheduler import MultiStepLR


from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.data import SimCLRCollateFunction
from lightly.transforms.simclr_transform import SimCLRTransform
from lightly.utils.scheduler import cosine_schedule

In [None]:
import time

from torchvision.datasets import STL10
from torchvision.transforms import ToTensor

# SimCLR-style augmentation pipeline producing two random views per image
# (consumed as `x0, x1 = batch[0]` in the training loop), 96 px output,
# colour-jitter strength lowered to 0.5.
transform = SimCLRTransform(cj_strength=0.5, input_size=96)

# STL10 "unlabeled" split for self-supervised pre-training; downloaded into ./data on first use.
dataset_train_unlabeled = STL10(
    root="data",
    split="unlabeled",
    transform=transform,
    download=True,
)


Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to data/stl10_binary.tar.gz


100%|██████████| 2640397119/2640397119 [00:48<00:00, 54470093.30it/s]


Extracting data/stl10_binary.tar.gz to data


In [None]:
import os

import torch

# The SimCLR augmentations run inside the loader. With num_workers=0 they execute
# serially in the main process, which is likely the bottleneck here (~510 s/epoch
# on a V100 in the log below), so decode/augment in background worker processes.
NUM_WORKERS = min(8, os.cpu_count() or 1)

trainloader_unlabeled = torch.utils.data.DataLoader(
    dataset_train_unlabeled,
    batch_size=512,
    shuffle=True,
    drop_last=True,  # every batch has exactly 512 samples; the last partial batch is dropped
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),  # page-locked host memory speeds up .to("cuda") copies
    persistent_workers=NUM_WORKERS > 0,    # keep workers alive between epochs (requires workers > 0)
)


In [None]:
class BYOL(nn.Module):
  """BYOL model: an online branch (backbone -> projection head -> prediction head)
  plus a momentum (target) copy of the backbone and projection head.

  The momentum copies start as deep copies of the online modules and have
  ``requires_grad`` disabled; they are not trained by the optimizer but updated
  externally (this notebook calls lightly's ``update_momentum`` each batch).

  Args:
    backbone: feature extractor whose output, flattened from dim 1, has
      ``num_ftrs`` features (512 for the ResNet-18 trunk used in this notebook).
    num_ftrs: backbone feature size, i.e. the projection head's input size.
    hidden_dim: hidden width of the projection and prediction MLPs.
    out_dim: output size of the projection head and of the prediction head.
  """

  def __init__(self, backbone, num_ftrs=512, hidden_dim=1024, out_dim=256):
    super().__init__()
    self.backbone = backbone  # e.g., a ResNet trunk without its classifier
    self.projection_head = BYOLProjectionHead(num_ftrs, hidden_dim, out_dim)
    self.prediction_head = BYOLPredictionHead(out_dim, hidden_dim, out_dim)
    # Target network: frozen copies that only change through momentum updates.
    self.backbone_momentum = copy.deepcopy(self.backbone)
    self.projection_head_momentum = copy.deepcopy(self.projection_head)
    deactivate_requires_grad(self.backbone_momentum)
    deactivate_requires_grad(self.projection_head_momentum)

  def forward(self, x):
    """Online branch: return the prediction p = prediction_head(projection_head(backbone(x)))."""
    y = self.backbone(x).flatten(start_dim=1)
    z = self.projection_head(y)
    p = self.prediction_head(z)
    return p

  def forward_momentum(self, x):
    """Target branch: return the momentum projection z, detached from the autograd graph."""
    y = self.backbone_momentum(x).flatten(start_dim=1)
    z = self.projection_head_momentum(y)
    z = z.detach()
    return z


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# ResNet-18 created without pretrained (ImageNet) weights; its last child module
# (the fc classifier) is dropped so the backbone outputs pooled features.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = BYOL(backbone)

# Resume from a checkpoint on Google Drive, stored as {'net': model.state_dict()}.
from google.colab import drive
drive.mount('/content/drive')
checkpoint_path = '/content/drive/MyDrive/DL_project/BYOL_50_epoc_2906.pth'
# map_location='cpu' lets a checkpoint saved during a GPU run load on any runtime;
# the model is moved to `device` afterwards.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint['net'])
# Assign instead of ending the cell with model.to(device), which dumps the whole model repr.
model = model.to(device)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


BYOL(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 

In [None]:
import os


criterion = NegativeCosineSimilarity()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Multiply the learning rate by 0.1 after epochs 50 and 80. The milestones count
# epochs, so scheduler.step() is called once per epoch (not once per batch):
#   epochs 0-49: lr = 1e-4 | epochs 50-79: lr = 1e-5 | epochs 80-99: lr = 1e-6
scheduler = MultiStepLR(optimizer, milestones=[50, 80], gamma=0.1)

epochs = 100
save_every = 10  # checkpoint after every `save_every` completed epochs, and after the last one

print("Starting Training")
for epoch in range(epochs):
  epoch_start = time.time()
  total_loss = 0
  # Momentum (EMA) coefficient for the target network, scheduled from 0.996 towards 1.
  momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
  for batch in trainloader_unlabeled:
    x0, x1 = batch[0]  # batch[0] holds the two augmented views of the same images
    # Move the momentum (target) networks towards the online networks.
    update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
    update_momentum(model.projection_head, model.projection_head_momentum, m=momentum_val)

    x0 = x0.to(device)
    x1 = x1.to(device)
    p0 = model(x0)
    z0 = model.forward_momentum(x0)
    p1 = model(x1)
    z1 = model.forward_momentum(x1)
    # Symmetrised loss: each view's online prediction is matched to the other view's target.
    loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
    total_loss += loss.detach()  # detached so the accumulator keeps no autograd graph alive
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
  scheduler.step()  # once per epoch, matching the epoch-based milestones

  # Elapsed wall-clock seconds for this epoch.
  epoch_time = time.time() - epoch_start
  avg_loss = float(total_loss) / len(trainloader_unlabeled)
  print("Epoch: {} | Loss: {:.4f} | Epoch Time: {:.2f} secs".format(epoch, avg_loss, epoch_time))

  epochs_done = epoch + 1
  if epochs_done % save_every == 0 or epochs_done == epochs:
    # The file name counts epochs completed in THIS run. NOTE(review): when resuming from a
    # checkpoint that uses the same naming scheme, this can overwrite the file that was loaded.
    path_to_data = '/content/drive/MyDrive/DL_project/BYOL_{}_epoc_2906.pth'.format(epochs_done)

    # save model after training
    print('==> Saving model ...')
    state = {
      'net': model.state_dict(),
    }
    # Ensure the actual destination folder exists before saving.
    os.makedirs(os.path.dirname(path_to_data), exist_ok=True)
    torch.save(state, path_to_data)




Starting Training
Epoch: 0 | Loss: -0.8953Epoch Time: 1688106609.98 secs
==> Saving model ...
Epoch: 1 | Loss: -0.8957Epoch Time: 1688107163.13 secs
Epoch: 2 | Loss: -0.8952Epoch Time: 1688107680.01 secs
Epoch: 3 | Loss: -0.8957Epoch Time: 1688108193.93 secs
Epoch: 4 | Loss: -0.8959Epoch Time: 1688108708.24 secs
Epoch: 5 | Loss: -0.8953Epoch Time: 1688109222.77 secs
Epoch: 6 | Loss: -0.8958Epoch Time: 1688109734.84 secs
Epoch: 7 | Loss: -0.8964Epoch Time: 1688110248.00 secs
Epoch: 8 | Loss: -0.8958Epoch Time: 1688110760.70 secs
Epoch: 9 | Loss: -0.8969Epoch Time: 1688111271.33 secs
Epoch: 10 | Loss: -0.8964Epoch Time: 1688111781.88 secs
==> Saving model ...
Epoch: 11 | Loss: -0.8976Epoch Time: 1688112294.83 secs
Epoch: 12 | Loss: -0.8962Epoch Time: 1688112805.69 secs
Epoch: 13 | Loss: -0.8974Epoch Time: 1688113316.55 secs
Epoch: 14 | Loss: -0.8974Epoch Time: 1688113826.46 secs
Epoch: 15 | Loss: -0.8975Epoch Time: 1688114337.01 secs
Epoch: 16 | Loss: -0.8973Epoch Time: 1688114847.56 sec

KeyboardInterrupt: ignored

In [None]:

import os

# Save the final weights in the {'net': state_dict} format the loading cell expects.
path_to_data = '/content/drive/MyDrive/DL_project/after_training_BYOL_160_epoch_adam.pth'

# save model after training
print('==> Saving model ...')
state = {
  'net': model.state_dict(),
}
# Create the folder the checkpoint is actually written to (on Drive) if it is missing.
os.makedirs(os.path.dirname(path_to_data), exist_ok=True)
torch.save(state, path_to_data)


==> Saving model ...
