# BYOL

In this notebook we are going to implement [BYOL: Bootstrap Your Own Latent](https://arxiv.org/pdf/2006.07733.pdf) and compare the results of a classification task before and after pretraining the model with BYOL.

### Data Augmentations

In [1]:
import random
from typing import Callable, Tuple
import torch
import torchvision
from torch import nn, Tensor
from torchvision import transforms as T
from torch.nn import functional as F


class RandomApply(nn.Module):
    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        if random.random() > self.p:
            return x
        return self.fn(x)
    

def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:

    """
        1. resize images to 'image_size'
        2. RandomApply color jitter
        3. RandomApply grayscale
        4. RandomApply horizon flip
        5. RandomApply gaussian blur with kernel_size(3, 3), sigma=(1.5, 1.5)
        6. RandomApply ResizedCrop to 'image_size'
        7. Normalize
        choosing hyperparameters that are not mentioned is up to you
    """
    return nn.Sequential(
        # your code
        RandomApply(T.ColorJitter(0.8, 0.8, 0.8, 0.2), p = 0.3),
        T.RandomGrayscale(p=0.2),
        T.RandomHorizontalFlip(),
        RandomApply(T.GaussianBlur((3, 3), (1.0, 2.0)), p = 0.2),
        T.RandomResizedCrop((image_size)),
        T.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
            )
    )

# Model
We will use ResNet18 as our representation model.

In [2]:
def get_encoder_model():
    resnet = torchvision.models.resnet18()
    # remove last fully-connected layer
    # your code
    modules = list(resnet.children())[:-1]
    resnet = torch.nn.Sequential(*modules, nn.Flatten())
    return resnet

### Loss Function
We need to use NormalizedMSELoss as our loss function.
$$NormalizedMSELoss(v_1, v_2) = \Vert \bar{v_1} - \bar{v_2}\Vert_2^2 = 2 - 2.\frac{\langle v_1, v_2 \rangle}{\Vert v_1\Vert_2 \Vert v_2\Vert_2}$$

In [3]:
class NormalizedMSELoss(nn.Module):
    def __init__(self) -> None:
        super(NormalizedMSELoss,self).__init__()
        
    def forward(self, view1: Tensor, view2: Tensor) -> Tensor:
        view1 = F.normalize(view1, dim=-1, p=2)
        view2 = F.normalize(view2, dim=-1, p=2)
        return 2 - 2 * (view1 * view2).sum(dim=-1)

### MLP
Here you will implement a simple MLP class with one hidden layer with BatchNorm and ReLU activation, and a linear output layer. This class will be used for both the projections and the prediction networks.

In [4]:
class MLP(nn.Module):
    def __init__(self, input_dim: int, projection_dim: int = 256, hidden_dim: int = 4096) -> None:
        super(MLP,self).__init__()

        # your code
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim),
            )
    
    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)
        

### Encoder + Projector Network
This is the network structure that is shared between online and target networks. It consists of our encoder model, followed by a projection MLP.

In [5]:
class EncoderProjecter(nn.Module):
    def __init__(self,
                 encoder: nn.Module,
                 hidden_dim: int = 4096,
                 projection_out_dim: int = 256
                 ) -> None:
        super(EncoderProjecter, self).__init__()

        # your code
        self.encoder = encoder
        self.projector = MLP(512) 

    def forward(self, x: Tensor) -> Tensor:
        # your code
        return self.projector(self.encoder(x))

## BYOL

In [6]:
import copy

class BYOL(nn.Module):
    def __init__(self,
                 model: nn.Module,
                 hidden_dim: int = 4096,
                 projection_out_dim: int = 256,
                 target_decay: float = 0.99         
                ) -> None:
        super(BYOL, self).__init__()
        
        # your code

        self.augment = default_augmentation((96,96))
        self.beta = target_decay

        self.online_network = EncoderProjecter(model, projection_out_dim, hidden_dim)  # encoder + projector
        self.online_predictor = MLP(projection_out_dim)

        self.target_network = copy.deepcopy(self.online_network)  # init with copy of parameters of online network
        # set target_network's weights to be untrainable
        for param in self.target_network.parameters():
            param.requires_grad = False
        self.target_network.eval()
                
        self.loss_function = NormalizedMSELoss()
        
        
    @torch.no_grad()    
    def soft_update_target_network(self) -> None:
        # your code
        for online_params, target_params in zip(self.online_network.parameters(), self.target_network.parameters()) :
            target_params.data = self.beta * target_params + (1 - self.beta) * online_params

            

    def forward(self, view) -> Tuple[Tensor]:
        # return online projection and target projection of view
        # your code
        online_projection = self.online_network(view)
        target_projection = self.target_network(view)

        return online_projection, target_projection
    
    
    def loss(self, view1, view2):
        # compute loss once for (online_prediction1, target_projection2) and once for (online_prediction2, target_projection1). 
        # then return the mean.
        # your code
        online_projection1, target_projection1 = self.forward(view1)
        online_projection2, target_projection2 = self.forward(view2)
        online_prediction1 = self.online_predictor(online_projection1)
        online_prediction2 = self.online_predictor(online_projection2)
        loss = self.loss_function(online_prediction1, target_projection2) + self.loss_function(online_prediction2, target_projection1)
        return torch.mean(loss)

# STL10 Datasets

We need 3 separate datasets from STL10 for this experiment:
1. `"train"` -- Contains only labeled training images. Used for supervised training.
2. `"train+unlabeled"` -- Contains training images, plus a large number of unlabelled images.  Used for self-supervised learning with BYOL.
3. `"test"` -- Labeled test images.  We use it both as a validation set, and for computing the final model accuracy.

In [7]:
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor


TRAIN_DATASET = STL10(root="data", split="train", download=True, transform=ToTensor())
TRAIN_UNLABELED_DATASET = STL10(root="data", split="train+unlabeled", download=True, transform=ToTensor())
TEST_DATASET = STL10(root="data", split="test", download=True, transform=ToTensor())

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to data/stl10_binary.tar.gz


100%|██████████| 2640397119/2640397119 [04:46<00:00, 9207428.30it/s] 


Extracting data/stl10_binary.tar.gz to data
Files already downloaded and verified
Files already downloaded and verified


Create dataloaders:

In [8]:
# your code
from torch.utils.data.dataloader import DataLoader
TRAIN_dataloader = DataLoader(dataset=TRAIN_DATASET, batch_size=256, shuffle=True)
TRAIN_UNLABELED_dataloader = DataLoader(dataset=TRAIN_UNLABELED_DATASET, batch_size=256, shuffle=True)
TEST_dataloader = DataLoader(dataset=TEST_DATASET, batch_size=256, shuffle=True)
classes = TRAIN_DATASET.classes

# Supervised Training without BYOL

First create a classifier model by simply adding a linear layer on top of the encoder model. Then train the model using the labeled training set. Performance should be pretty good already. 

In [9]:
import numpy as np
import torch.optim as optim
from tqdm import tqdm

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
encoder = get_encoder_model()
classifier = nn.Sequential(encoder, nn.Linear(512, len(classes))).to(device)

In [12]:
n_epochs = 40
criterion = nn.CrossEntropyLoss()
sup_optimizer = optim.SGD(classifier.parameters(), lr=0.001, momentum=0.9)

In [13]:
# training loop
def sup_train_model(model, train_dataloader, n_epochs, optimizer, criterion, device):
    
    for epoch in range(n_epochs):
        ####### Training Phase ########
        model.train()
        with tqdm(train_dataloader, unit="batch") as batches:
            running_loss = 0
            epoch_loss = 0
            for data, target in batches:
                batches.set_description(f"Epoch {epoch + 1}")

                # move to GPU
                data = data.to(device)
                target = target.to(device)

            
                # train 
                out = model(data)
                train_loss = criterion(out, target)

                # Backward and optimize
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                batches.set_postfix(train_loss = train_loss.item())

                running_loss += train_loss.item() * data.size(0)

        epoch_loss = running_loss / len(TRAIN_DATASET)   
        print(f"loss for epoch {epoch+1} = {epoch_loss}")

    return model

In [14]:
sup_model = sup_train_model(classifier, TRAIN_dataloader, n_epochs, sup_optimizer, criterion, device)

Epoch 1: 100%|██████████| 20/20 [00:13<00:00,  1.49batch/s, train_loss=2.27]


loss for epoch 1 = 2.319526119995117


Epoch 2: 100%|██████████| 20/20 [00:05<00:00,  3.45batch/s, train_loss=2.11]


loss for epoch 2 = 2.16297726020813


Epoch 3: 100%|██████████| 20/20 [00:05<00:00,  3.71batch/s, train_loss=2]


loss for epoch 3 = 2.0153730031967165


Epoch 4: 100%|██████████| 20/20 [00:05<00:00,  3.45batch/s, train_loss=1.76]


loss for epoch 4 = 1.8734141416549683


Epoch 5: 100%|██████████| 20/20 [00:05<00:00,  3.71batch/s, train_loss=1.68]


loss for epoch 5 = 1.7611118068695069


Epoch 6: 100%|██████████| 20/20 [00:05<00:00,  3.59batch/s, train_loss=1.67]


loss for epoch 6 = 1.6807172481536865


Epoch 7: 100%|██████████| 20/20 [00:05<00:00,  3.61batch/s, train_loss=1.58]


loss for epoch 7 = 1.6113888906478882


Epoch 8: 100%|██████████| 20/20 [00:05<00:00,  3.74batch/s, train_loss=1.6]


loss for epoch 8 = 1.5571433547973632


Epoch 9: 100%|██████████| 20/20 [00:05<00:00,  3.46batch/s, train_loss=1.34]


loss for epoch 9 = 1.5004141241073607


Epoch 10: 100%|██████████| 20/20 [00:05<00:00,  3.68batch/s, train_loss=1.47]


loss for epoch 10 = 1.4515578979492187


Epoch 11: 100%|██████████| 20/20 [00:05<00:00,  3.48batch/s, train_loss=1.46]


loss for epoch 11 = 1.4098481840133668


Epoch 12: 100%|██████████| 20/20 [00:05<00:00,  3.74batch/s, train_loss=1.26]


loss for epoch 12 = 1.3623339349746704


Epoch 13: 100%|██████████| 20/20 [00:05<00:00,  3.47batch/s, train_loss=1.4]


loss for epoch 13 = 1.3284835964202881


Epoch 14: 100%|██████████| 20/20 [00:05<00:00,  3.70batch/s, train_loss=1.22]


loss for epoch 14 = 1.2843469665527343


Epoch 15: 100%|██████████| 20/20 [00:05<00:00,  3.44batch/s, train_loss=1.27]


loss for epoch 15 = 1.2360808605194091


Epoch 16: 100%|██████████| 20/20 [00:05<00:00,  3.68batch/s, train_loss=1.17]


loss for epoch 16 = 1.187850731086731


Epoch 17: 100%|██████████| 20/20 [00:05<00:00,  3.53batch/s, train_loss=1.12]


loss for epoch 17 = 1.144682172012329


Epoch 18: 100%|██████████| 20/20 [00:05<00:00,  3.53batch/s, train_loss=1.11]


loss for epoch 18 = 1.104270755004883


Epoch 19: 100%|██████████| 20/20 [00:05<00:00,  3.63batch/s, train_loss=1.04]


loss for epoch 19 = 1.0647072813034058


Epoch 20: 100%|██████████| 20/20 [00:05<00:00,  3.38batch/s, train_loss=0.957]


loss for epoch 20 = 1.0100868553161622


Epoch 21: 100%|██████████| 20/20 [00:05<00:00,  3.62batch/s, train_loss=1.03]


loss for epoch 21 = 0.9616181591033935


Epoch 22: 100%|██████████| 20/20 [00:05<00:00,  3.40batch/s, train_loss=0.826]


loss for epoch 22 = 0.9083940101623535


Epoch 23: 100%|██████████| 20/20 [00:05<00:00,  3.70batch/s, train_loss=0.832]


loss for epoch 23 = 0.8554403490066528


Epoch 24: 100%|██████████| 20/20 [00:05<00:00,  3.42batch/s, train_loss=0.785]


loss for epoch 24 = 0.8065814290046692


Epoch 25: 100%|██████████| 20/20 [00:05<00:00,  3.69batch/s, train_loss=0.766]


loss for epoch 25 = 0.7693659346580506


Epoch 26: 100%|██████████| 20/20 [00:05<00:00,  3.42batch/s, train_loss=0.708]


loss for epoch 26 = 0.7037437643051148


Epoch 27: 100%|██████████| 20/20 [00:05<00:00,  3.64batch/s, train_loss=0.659]


loss for epoch 27 = 0.6492395677566528


Epoch 28: 100%|██████████| 20/20 [00:05<00:00,  3.49batch/s, train_loss=0.588]


loss for epoch 28 = 0.5940336296081543


Epoch 29: 100%|██████████| 20/20 [00:05<00:00,  3.58batch/s, train_loss=0.578]


loss for epoch 29 = 0.542984739112854


Epoch 30: 100%|██████████| 20/20 [00:05<00:00,  3.61batch/s, train_loss=0.556]


loss for epoch 30 = 0.492848752784729


Epoch 31: 100%|██████████| 20/20 [00:05<00:00,  3.41batch/s, train_loss=0.481]


loss for epoch 31 = 0.447937020778656


Epoch 32: 100%|██████████| 20/20 [00:05<00:00,  3.50batch/s, train_loss=0.41]


loss for epoch 32 = 0.3914671799659729


Epoch 33: 100%|██████████| 20/20 [00:06<00:00,  3.32batch/s, train_loss=0.347]


loss for epoch 33 = 0.3432642992496491


Epoch 34: 100%|██████████| 20/20 [00:05<00:00,  3.34batch/s, train_loss=0.31]


loss for epoch 34 = 0.3021245875835419


Epoch 35: 100%|██████████| 20/20 [00:06<00:00,  3.32batch/s, train_loss=0.27]


loss for epoch 35 = 0.26231801719665526


Epoch 36: 100%|██████████| 20/20 [00:05<00:00,  3.59batch/s, train_loss=0.246]


loss for epoch 36 = 0.23142368609905242


Epoch 37: 100%|██████████| 20/20 [00:06<00:00,  3.27batch/s, train_loss=0.214]


loss for epoch 37 = 0.20061892545223237


Epoch 38: 100%|██████████| 20/20 [00:05<00:00,  3.54batch/s, train_loss=0.189]


loss for epoch 38 = 0.17092267065048217


Epoch 39: 100%|██████████| 20/20 [00:06<00:00,  3.31batch/s, train_loss=0.189]


loss for epoch 39 = 0.15120679137706758


Epoch 40: 100%|██████████| 20/20 [00:05<00:00,  3.59batch/s, train_loss=0.118]

loss for epoch 40 = 0.1275482924938202





In [15]:
def multi_acc(y_pred, y_test):
    _, y_pred_tags = torch.max(y_pred, dim = 1)    
    correct_pred = (y_pred_tags == y_test).float()
    acc = correct_pred.sum() / len(correct_pred)
    acc = torch.round(acc * 100)
    return acc

In [26]:
accuracy = 0
for x,y in TEST_dataloader:
    sup_model.eval()
    x = x.to(device)
    y = y.to(device)
    accuracy += multi_acc(sup_model(x), y) / len(TEST_dataloader)
print(f"accuracy without BYOL = {accuracy}")

accuracy without BYOL = 48.125


### Self-Supervised Training with BYOL

Now perform the self-supervised training. This is the most computationally intensive part of the script.

In [17]:
n_epochs = 20
byol = BYOL(get_encoder_model())
ssl_optimizer = torch.optim.Adam(byol.parameters(), lr=3e-4)

In [18]:
def SSL_train_model(model, train_unlabeled_dataloader, n_epochs, optimizer, criterion, device):
    model.to(device)
    
    for epoch in range(n_epochs):
        ####### Training Phase ########
        model.train()
        with tqdm(train_unlabeled_dataloader, unit="batch") as batches:
            running_loss = 0
            for data in batches:
                batches.set_description(f"Epoch {epoch + 1}")

                # move to GPU
                data = data[0].to(device)

 
                with torch.no_grad():
                    data1, data2 = model.augment(data), model.augment(data)

                train_loss = model.loss(data1, data2)
                
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                model.soft_update_target_network()


                batches.set_postfix(train_loss = train_loss.item())

                running_loss += train_loss.item() * data.size(0)
      

        epoch_loss = running_loss / len(TRAIN_DATASET)   
        print(f"loss for epoch {epoch+1} = {epoch_loss}")

    return model

In [19]:
SSL_model = SSL_train_model(byol, TRAIN_UNLABELED_dataloader, n_epochs, ssl_optimizer, criterion, device)

Epoch 1: 100%|██████████| 411/411 [04:09<00:00,  1.65batch/s, train_loss=0.646]


loss for epoch 1 = 21.386822641944885


Epoch 2: 100%|██████████| 411/411 [04:08<00:00,  1.66batch/s, train_loss=0.223]


loss for epoch 2 = 12.804969591498375


Epoch 3: 100%|██████████| 411/411 [04:09<00:00,  1.65batch/s, train_loss=0.495]


loss for epoch 3 = 11.014533401966094


Epoch 4: 100%|██████████| 411/411 [04:08<00:00,  1.66batch/s, train_loss=0.515]


loss for epoch 4 = 9.469083618927002


Epoch 5: 100%|██████████| 411/411 [04:08<00:00,  1.66batch/s, train_loss=0.774]


loss for epoch 5 = 8.640660267353057


Epoch 6: 100%|██████████| 411/411 [04:06<00:00,  1.67batch/s, train_loss=0.183]


loss for epoch 6 = 7.605020485448837


Epoch 7: 100%|██████████| 411/411 [04:07<00:00,  1.66batch/s, train_loss=0.211]


loss for epoch 7 = 7.453278600096702


Epoch 8: 100%|██████████| 411/411 [04:07<00:00,  1.66batch/s, train_loss=0.334]


loss for epoch 8 = 7.097322121667862


Epoch 9: 100%|██████████| 411/411 [04:08<00:00,  1.65batch/s, train_loss=0.396]


loss for epoch 9 = 6.69550075135231


Epoch 10: 100%|██████████| 411/411 [04:08<00:00,  1.65batch/s, train_loss=0.402]


loss for epoch 10 = 6.54262501039505


Epoch 11: 100%|██████████| 411/411 [04:08<00:00,  1.65batch/s, train_loss=0.304]


loss for epoch 11 = 6.101790420007705


Epoch 12: 100%|██████████| 411/411 [04:08<00:00,  1.66batch/s, train_loss=0.298]


loss for epoch 12 = 6.305572372961044


Epoch 13: 100%|██████████| 411/411 [04:09<00:00,  1.65batch/s, train_loss=0.172]


loss for epoch 13 = 5.387404637765885


Epoch 14: 100%|██████████| 411/411 [04:09<00:00,  1.65batch/s, train_loss=0.234]


loss for epoch 14 = 5.887911245965958


Epoch 15: 100%|██████████| 411/411 [04:08<00:00,  1.65batch/s, train_loss=0.563]


loss for epoch 15 = 5.93919371061325


Epoch 16: 100%|██████████| 411/411 [04:09<00:00,  1.65batch/s, train_loss=0.985]


loss for epoch 16 = 6.002861955356598


Epoch 17: 100%|██████████| 411/411 [04:07<00:00,  1.66batch/s, train_loss=0.333]


loss for epoch 17 = 6.156608871936798


Epoch 18: 100%|██████████| 411/411 [04:11<00:00,  1.63batch/s, train_loss=0.837]


loss for epoch 18 = 5.653103424930572


Epoch 19: 100%|██████████| 411/411 [04:08<00:00,  1.65batch/s, train_loss=0.28]


loss for epoch 19 = 5.777278143787384


Epoch 20: 100%|██████████| 411/411 [04:08<00:00,  1.65batch/s, train_loss=0.539]

loss for epoch 20 = 5.904601030540467





### Supervised Training Again

Extract the encoder network's state dictionary from BYOL, and load it into our ResNet18 model before starting training.  Then run supervised training, and watch the accuracy improve from last time!

In [20]:
modules = list(SSL_model.online_network.children())[:-1]

In [21]:
b_encoder = torch.nn.Sequential(*modules)
b_classifier = nn.Sequential(b_encoder, nn.Linear(512, len(classes))).to(device)

In [22]:
n_epochs = 40
b_criterion = nn.CrossEntropyLoss()
b_sup_optimizer = optim.SGD(b_classifier.parameters(), lr=0.001, momentum=0.9)

In [23]:
# training loop
def b_sup_train_model(model, train_dataloader, n_epochs, optimizer, criterion, device):
    
    for epoch in range(n_epochs):
        ####### Training Phase ########
        model.train()
        with tqdm(train_dataloader, unit="batch") as batches:
            running_loss = 0
            epoch_loss = 0
            for data, target in batches:
                batches.set_description(f"Epoch {epoch + 1}")

                # move to GPU
                data = data.to(device)
                target = target.to(device)

            
                # train 
                out = model(data)
                train_loss = criterion(out, target)

                # Backward and optimize
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                batches.set_postfix(train_loss = train_loss.item())

                running_loss += train_loss.item() * data.size(0)

        epoch_loss = running_loss / len(TRAIN_DATASET)   
        print(f"loss for epoch {epoch+1} = {epoch_loss}")

    return model

In [24]:
b_sup_model = b_sup_train_model(b_classifier, TRAIN_dataloader, n_epochs, b_sup_optimizer, b_criterion, device)

Epoch 1: 100%|██████████| 20/20 [00:05<00:00,  3.47batch/s, train_loss=2.02]


loss for epoch 1 = 2.248141623687744


Epoch 2: 100%|██████████| 20/20 [00:06<00:00,  3.26batch/s, train_loss=1.76]


loss for epoch 2 = 1.8876248138427734


Epoch 3: 100%|██████████| 20/20 [00:05<00:00,  3.52batch/s, train_loss=1.59]


loss for epoch 3 = 1.6879217853546142


Epoch 4: 100%|██████████| 20/20 [00:06<00:00,  3.30batch/s, train_loss=1.51]


loss for epoch 4 = 1.5709947483062745


Epoch 5: 100%|██████████| 20/20 [00:05<00:00,  3.49batch/s, train_loss=1.46]


loss for epoch 5 = 1.488596364402771


Epoch 6: 100%|██████████| 20/20 [00:05<00:00,  3.36batch/s, train_loss=1.38]


loss for epoch 6 = 1.4263654218673707


Epoch 7: 100%|██████████| 20/20 [00:05<00:00,  3.43batch/s, train_loss=1.38]


loss for epoch 7 = 1.3773021192550658


Epoch 8: 100%|██████████| 20/20 [00:05<00:00,  3.42batch/s, train_loss=1.25]


loss for epoch 8 = 1.3371961051940917


Epoch 9: 100%|██████████| 20/20 [00:05<00:00,  3.39batch/s, train_loss=1.21]


loss for epoch 9 = 1.3049393882751466


Epoch 10: 100%|██████████| 20/20 [00:05<00:00,  3.50batch/s, train_loss=1.25]


loss for epoch 10 = 1.2724446174621582


Epoch 11: 100%|██████████| 20/20 [00:06<00:00,  3.33batch/s, train_loss=1.19]


loss for epoch 11 = 1.2503278528213502


Epoch 12: 100%|██████████| 20/20 [00:05<00:00,  3.50batch/s, train_loss=1.34]


loss for epoch 12 = 1.2298801166534423


Epoch 13: 100%|██████████| 20/20 [00:06<00:00,  3.25batch/s, train_loss=1.23]


loss for epoch 13 = 1.2089129018783569


Epoch 14: 100%|██████████| 20/20 [00:05<00:00,  3.50batch/s, train_loss=1.15]


loss for epoch 14 = 1.1911247747421265


Epoch 15: 100%|██████████| 20/20 [00:06<00:00,  3.30batch/s, train_loss=1.25]


loss for epoch 15 = 1.176486005783081


Epoch 16: 100%|██████████| 20/20 [00:05<00:00,  3.50batch/s, train_loss=1.13]


loss for epoch 16 = 1.1640207962036133


Epoch 17: 100%|██████████| 20/20 [00:06<00:00,  3.28batch/s, train_loss=1.05]


loss for epoch 17 = 1.1478514373779296


Epoch 18: 100%|██████████| 20/20 [00:05<00:00,  3.49batch/s, train_loss=1.14]


loss for epoch 18 = 1.134900359916687


Epoch 19: 100%|██████████| 20/20 [00:06<00:00,  3.27batch/s, train_loss=1.07]


loss for epoch 19 = 1.1247735124588012


Epoch 20: 100%|██████████| 20/20 [00:05<00:00,  3.54batch/s, train_loss=1.12]


loss for epoch 20 = 1.1137183206558228


Epoch 21: 100%|██████████| 20/20 [00:06<00:00,  3.28batch/s, train_loss=1.1]


loss for epoch 21 = 1.105375039291382


Epoch 22: 100%|██████████| 20/20 [00:05<00:00,  3.50batch/s, train_loss=1.09]


loss for epoch 22 = 1.0904491165161132


Epoch 23: 100%|██████████| 20/20 [00:06<00:00,  3.22batch/s, train_loss=1.06]


loss for epoch 23 = 1.0832419496536254


Epoch 24: 100%|██████████| 20/20 [00:05<00:00,  3.47batch/s, train_loss=1.03]


loss for epoch 24 = 1.0708614572525024


Epoch 25: 100%|██████████| 20/20 [00:06<00:00,  3.01batch/s, train_loss=0.999]


loss for epoch 25 = 1.0614239099502563


Epoch 26: 100%|██████████| 20/20 [00:05<00:00,  3.52batch/s, train_loss=1.09]


loss for epoch 26 = 1.0562662368774414


Epoch 27: 100%|██████████| 20/20 [00:06<00:00,  3.27batch/s, train_loss=1.04]


loss for epoch 27 = 1.0478030960083007


Epoch 28: 100%|██████████| 20/20 [00:05<00:00,  3.50batch/s, train_loss=1.08]


loss for epoch 28 = 1.040783090019226


Epoch 29: 100%|██████████| 20/20 [00:06<00:00,  3.25batch/s, train_loss=1.06]


loss for epoch 29 = 1.034704252243042


Epoch 30: 100%|██████████| 20/20 [00:05<00:00,  3.43batch/s, train_loss=1.05]


loss for epoch 30 = 1.0224911287307739


Epoch 31: 100%|██████████| 20/20 [00:06<00:00,  3.32batch/s, train_loss=1.08]


loss for epoch 31 = 1.0174109817504884


Epoch 32: 100%|██████████| 20/20 [00:05<00:00,  3.45batch/s, train_loss=1.06]


loss for epoch 32 = 1.0140555765151977


Epoch 33: 100%|██████████| 20/20 [00:05<00:00,  3.38batch/s, train_loss=1.09]


loss for epoch 33 = 1.0065037279129028


Epoch 34: 100%|██████████| 20/20 [00:05<00:00,  3.38batch/s, train_loss=0.851]


loss for epoch 34 = 1.0000760180473327


Epoch 35: 100%|██████████| 20/20 [00:05<00:00,  3.47batch/s, train_loss=1.01]


loss for epoch 35 = 0.9942375047683716


Epoch 36: 100%|██████████| 20/20 [00:06<00:00,  3.32batch/s, train_loss=1.08]


loss for epoch 36 = 0.9857145727157592


Epoch 37: 100%|██████████| 20/20 [00:05<00:00,  3.49batch/s, train_loss=0.849]


loss for epoch 37 = 0.9792817165374755


Epoch 38: 100%|██████████| 20/20 [00:06<00:00,  3.25batch/s, train_loss=1.09]


loss for epoch 38 = 0.9776198558807373


Epoch 39: 100%|██████████| 20/20 [00:05<00:00,  3.51batch/s, train_loss=0.856]


loss for epoch 39 = 0.9675305804252624


Epoch 40: 100%|██████████| 20/20 [00:06<00:00,  3.29batch/s, train_loss=0.842]

loss for epoch 40 = 0.9592639419555664





In [27]:
accuracy = 0
for x,y in TEST_dataloader:
    b_sup_model.eval()
    x = x.to(device)
    y = y.to(device)
    accuracy += multi_acc(b_sup_model(x), y) / len(TEST_dataloader)
print(f"accuracy with BYOL = {accuracy}")

accuracy with BYOL = 60.71875


# Conclusion:
## The accuracy increased by 13 percent.