# Imports

In [1]:
import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.nn.init as init

from torchvision import transforms as T

# Helper Functions

In [2]:
# helper functions

def default(val, def_val):
    return def_val if val is None else val

def flatten(t):
    return t.reshape(t.shape[0], -1)

def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

def get_module_device(module):
    return next(module.parameters()).device

def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val

def MaybeSyncBatchnorm(is_distributed = None):
    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d

# loss fn

def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)

# augmentation utils

class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p
    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# exponential moving average

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
    
def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)

# BYOL Models

In [3]:
# MLP class for projector and predictor

def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size)
    )

def SimSiamMLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    return nn.Sequential(
        nn.Linear(dim, hidden_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, hidden_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(projection_size, affine=False)
    )

# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets

class NetWrapper(nn.Module):
    def __init__(self, backbone, projection_size, projection_hidden_size, use_simsiam_mlp = False, sync_batchnorm = None):
        super().__init__()
        self.backbone = backbone
        self.projector = None
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size
        self.use_simsiam_mlp = use_simsiam_mlp
        self.sync_batchnorm = sync_batchnorm

    @singleton('projector')
    def _get_projector(self, hidden):
        _, dim = hidden.shape
        create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm = self.sync_batchnorm)
        return projector.to(hidden)

    def get_representation(self, x):
        return self.backbone(x) #  (backbone + mlp) features: mlp can be trained instead of finetuning the backbone!

    def forward(self, x, return_projection = True):
        representation = self.get_representation(x)

        if not return_projection:
            return representation

        projector = self._get_projector(representation)
        projection = projector(representation)
        return projection, representation


# main class

class BYOLProj(nn.Module):
    """BYOL model: only projection head can be trained, parameters of the backbone are frozen.
    """

    def __init__(
        self,
        net,
        image_size,
        projection_size = 224,
        projection_hidden_size = 4096,
        augment_fn = None,
        augment_fn2 = None,
        moving_average_decay = 0.99,
        use_momentum = True,
        sync_batchnorm = None
    ):
        super().__init__()
        self.net = net

        # --custom set of augmentations

        DEFAULT_AUG = torch.nn.Sequential(
            # RandomApply(
            #     T.ColorJitter(0.8, 0.8, 0.8, 0.2),
            #     p = 0.3
            # ),
            # T.RandomGrayscale(p=0.2),
            T.RandomPerspective(distortion_scale=0.6, p=0.9),
            T.RandomHorizontalFlip(),
            # RandomApply(
            #     T.GaussianBlur((3, 3), (1.0, 2.0)),
            #     p = 0.2
            # ),
            T.RandomRotation(degrees=(0, 360)),
            T.RandomResizedCrop((image_size, image_size)),
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            use_simsiam_mlp = not use_momentum,
            sync_batchnorm = sync_batchnorm
        )

        self.use_momentum = use_momentum
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device)) #! Trick to initialize weights!!!
     

    @singleton('target_encoder')
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder

    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(
        self,
        x,
        return_embedding = False,
        return_projection = True
    ):
        assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'

        if return_embedding:
            return self.online_encoder(x, return_projection = return_projection)
        
        # -- create two views: anchor view, target view
        image_one, image_two = self.augment1(x), self.augment2(x)

        images = torch.cat((image_one, image_two), dim = 0)

        online_projections, _ = self.online_encoder(images)
        online_predictions = self.online_predictor(online_projections)

        online_pred_one, online_pred_two = online_predictions.chunk(2, dim = 0)

        with torch.no_grad():
            target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder

            target_projections, _ = target_encoder(images)
            target_projections = target_projections.detach()

            target_proj_one, target_proj_two = target_projections.chunk(2, dim = 0)

        loss_one = loss_fn(online_pred_one, target_proj_two.detach())
        loss_two = loss_fn(online_pred_two, target_proj_one.detach())

        loss = loss_one + loss_two
        return loss.mean()


# Models

## WideResNet50


In [4]:
########################### - Imports - #####################################
import os
import torch
import torch.nn as nn
import torch.nn.init as init
from torchvision import models

#############################################################################


########################### - Models - #####################################
class WideResnet50(nn.Module):
    def __init__(self, embedding_size, is_norm=True, bn_freeze=True):
        super(WideResnet50, self).__init__()

        self.embedding_size = embedding_size
        self.is_norm = is_norm
        self.model = models.wide_resnet50_2(pretrained=True)

        # Freezing the model weights in the backbone
        for param in self.model.parameters():
            param.requires_grad = False

        # Linear projection
        # self.model.fc = nn.Linear(2048, self.embedding_size)
        self.model.fc = nn.Sequential(
                            nn.Linear(2048, self.embedding_size),
                            # nn.BatchNorm1d(self.embedding_size),
                            # nn.ReLU(inplace=True),
                            # nn.Linear(self.embedding_size, self.embedding_size)
                        )
        self._initialize_weights()

        
        # Freezing the layernomralization parameters
        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)


    def forward(self, x):
        x = self.model(x)

        if self.is_norm:
            x = self.l2_norm(x)
        
        return x


    def _initialize_weights(self):
        # init.kaiming_normal_(self.model.fc.weight, mode='fan_out')
        # init.constant_(self.model.fc.bias, 0)
        for m in self.model.fc.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True


    def l2_norm(self,input):
        input_size = input.size()
        buffer = torch.pow(input, 2)

        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)

        _output = torch.div(input, norm.view(-1, 1).expand_as(input))

        output = _output.view(input_size)

        return output
    
#############################################################################


########################### - Testing - #####################################
# x = torch.randn(3, 3, 224, 224)
# model = WideResnet50(400, True, True)
# print('\n==============\n')
# print(model)
# print(model(x).shape)

## DINOv2

In [5]:
########################### - Imports - #####################################
import os
import torch
import torch.nn as nn
import torch.nn.init as init

from tqdm import tqdm

#############################################################################


########################### - Helper Functions - #####################################
def load_model(model_type='dinov2_vitg14'):
    """Loading the specified pretrained model from DINOv2 repository
        Args:
            model_type: name of the model to be loaded, default is 'dinov2_vitg14'

        Ref:
            https://github.com/facebookresearch/dinov2/blob/main/MODEL_CARD.md

        Returns:
            model: DINOv2 model
            embed_size: size of the embedding at the last layer
    """

    # Dowloading the specified dinov2 model
    model = torch.hub.load('facebookresearch/dinov2', model_type)
    # print('\nDINOv2 model summary:\n', model)

    # Assign the last_layer model type based on the model type
    if model_type == 'dinov2_vitg14':
        embed_size = 1536

    elif model_type == 'dinov2_vitl14':
        embed_size = 1024

    elif model_type == 'dinov2_vitb14':
        embed_size = 768
    
    elif model_type == 'dinov2_vits14':
        embed_size = 384

    return model, embed_size
#############################################################################


########################### - Models - #####################################
class DINOv2(nn.Module):
    def __init__(self, model_type, embedding_size,  is_norm=True, ln_freeze=True):
        super(DINOv2, self).__init__()

        self.embedding_size = embedding_size # For embedding layer
        self.is_norm = is_norm

        # Load the pretrained backbone model
        backbone, last_layer_embed_size = load_model(model_type)
        self.model = backbone

        # Freezing the model weights in the backbone
        for param in self.model.parameters():
            param.requires_grad = False

        # Linear projection
        # self.model.head = nn.Linear(last_layer_embed_size, self.embedding_size)
        self.model.head = nn.Sequential(
                            nn.Linear(last_layer_embed_size, self.embedding_size),
                            # nn.BatchNorm1d(self.embedding_size),
                            # nn.ReLU(inplace=True),
                            # nn.Linear(self.embedding_size, self.embedding_size)
                        )
        self._initialize_weights()

        # Freezing the layernomralization parameters
        if ln_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.LayerNorm):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)


    def forward(self, x):
        x = self.model(x)

        if self.is_norm:
            x = self.l2_norm(x)
        
        return x


    def _initialize_weights(self):
        # init.kaiming_normal_(self.model.head.weight, mode='fan_out')
        # init.constant_(self.model.head.bias, 0)
        for m in self.model.head.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight)
                m.weight.requires_grad = True
                if m.bias is not None:
                    m.bias.data.zero_()
                    m.bias.requires_grad = True


    def l2_norm(self,input):
        input_size = input.size()
        buffer = torch.pow(input, 2)

        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)

        _output = torch.div(input, norm.view(-1, 1).expand_as(input))

        output = _output.view(input_size)

        return output

#############################################################################


########################### - Testing - #####################################
# x = torch.randn(1, 3, 224, 224)
# model = DINOv2('dinov2_vitb14', 400, True, True)
# print('\n==============\n')
# print(model)
# print(model(x).shape)

# Traning Process

In [6]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.2.1-py3-none-any.whl.metadata (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.8/56.8 kB[0m [31m783.5 kB/s[0m eta [36m0:00:00[0m
Downloading lightning-2.2.1-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning
Successfully installed lightning-2.2.1


In [7]:
import os
import sys
import argparse
import numpy as np
import multiprocessing
from pathlib import Path
from PIL import Image, ImageFilter

import logging

import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
import lightning as l
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

In [8]:
# seeding
_GLOBAL_SEED = 0
np.random.seed(_GLOBAL_SEED)
torch.manual_seed(_GLOBAL_SEED)
torch.backends.cudnn.benchmark = True


# logging
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()

# constants

BATCH_SIZE = 32
EPOCHS     = 5 #50
LR         = 1e-4
lr_decay_step = 10
lr_decay_gamma = 0.5
NUM_GPUS   = 2
IMAGE_SIZE = 224 # 256
IMAGE_CROP_SIZE = 224
IMAGE_RESIZE = 256
IMAGE_EXTS = ['.jpg', '.png', '.jpeg']
NUM_WORKERS = multiprocessing.cpu_count()

In [9]:
# pytorch lightning module

class SelfSupervisedLearner(l.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        # self.save_hyperparameters()
        self.learner = BYOLProj(net, **kwargs)

    def forward(self, images):
        return self.learner(images)

    def training_step(self, images, _):
        loss_list = []
        for image_batch in images:
            multi_view_loss = self.forward(image_batch)
            loss_list.append(multi_view_loss)

        loss = sum(loss_list) / len(loss_list)
        self.log("train_loss", loss, prog_bar=True, on_epoch=True, on_step=True, sync_dist=True)
        return {'loss': loss}

    def validation_step(self, images, _):
        val_loss = self.forward(images)
        self.log("val_loss", val_loss, prog_bar=True, on_epoch=True, on_step=True, sync_dist=True)
        self.log('metric_to_track', val_loss, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=LR, weight_decay=1.5e-6)
        # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma = lr_decay_gamma)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=lr_decay_step, eta_min=0)
        lr_scheduler_config = {
            # REQUIRED: The scheduler instance
            "scheduler": scheduler,
            # The unit of the scheduler's step size, could also be 'step'.
            # 'epoch' updates the scheduler on epoch end whereas 'step'
            # updates it after a optimizer update.
            "interval": "epoch",
            # How many epochs/steps should pass between calls to
            # `scheduler.step()`. 1 corresponds to updating the learning
            # rate after every epoch/step.
            "frequency": 1,
            # Metric to to monitor for schedulers like `ReduceLROnPlateau`
            "monitor": "metric_to_track",
            # If set to `True`, will enforce that the value specified 'monitor'
            # is available when the scheduler is updated, thus stopping
            # training if not found. If set to `False`, it will only produce a warning
            "strict": True,
            # If using the `LearningRateMonitor` callback to monitor the
            # learning rate progress, this keyword can be used to specify
            # a custom logged name
            "name": None,
        }
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler_config
        }

    def on_before_zero_grad(self, _):
        if self.learner.use_momentum:
            self.learner.update_moving_average()


# -- dataset transformations
def make_transforms(
    rand_size=224,
    focal_size=224,
    rand_crop_scale=(0.3, 1.0),
    focal_crop_scale=(0.05, 0.3),
    color_jitter=1.0,
    rand_views=2,
    focal_views=10,
):

    def get_color_distortion(s=1.0):
        # s is the strength of color distortion.
        color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
        rnd_gray = transforms.RandomGrayscale(p=0.2)
        color_distort = transforms.Compose([
            rnd_color_jitter,
            rnd_gray])
        return color_distort

    rand_transform = transforms.Compose([
        transforms.RandomResizedCrop(rand_size, scale=rand_crop_scale),
        transforms.RandomHorizontalFlip(),
        # get_color_distortion(s=color_jitter),
        GaussianBlur(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.485, 0.456, 0.406),
            (0.229, 0.224, 0.225))
    ])
    focal_transform = transforms.Compose([
        transforms.RandomResizedCrop(focal_size, scale=focal_crop_scale),
        transforms.RandomHorizontalFlip(),
        # get_color_distortion(s=color_jitter),
        GaussianBlur(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.485, 0.456, 0.406),
            (0.229, 0.224, 0.225))
    ])

    transform = MultiViewTransform(
        rand_transform=rand_transform,
        focal_transform=focal_transform,
        rand_views=rand_views,
        focal_views=focal_views
    )
    return transform


class MultiViewTransform(object):

    def __init__(
        self,
        rand_transform=None,
        focal_transform=None,
        rand_views=1,
        focal_views=1,
    ):
        self.rand_views = rand_views
        self.focal_views = focal_views
        self.rand_transform = rand_transform
        self.focal_transform = focal_transform

    def __call__(self, img):
        img_views = []

        # -- generate random views
        if self.rand_views > 0:
            img_views += [self.rand_transform(img) for i in range(self.rand_views)]

        # -- generate focal views
        if self.focal_views > 0:
            img_views += [self.focal_transform(img) for i in range(self.focal_views)]

        return img_views


class GaussianBlur(object):
    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        if torch.bernoulli(torch.tensor(self.prob)) == 0:
            return img

        radius = self.radius_min + torch.rand(1) * (self.radius_max - self.radius_min)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))


# images dataset
class Identity(): # used for skipping transforms
    def __call__(self, im):
        return im
    
class ScaleIntensities():
    def __init__(self, in_range, out_range):
        """ Scales intensities. For example [-1, 1] -> [0, 255]."""
        self.in_range = in_range
        self.out_range = out_range

    def __oldcall__(self, tensor):
        tensor.mul_(255)
        return tensor

    def __call__(self, tensor):
        tensor = (
            tensor - self.in_range[0]
        ) / (
            self.in_range[1] - self.in_range[0]
        ) * (
            self.out_range[1] - self.out_range[0]
        ) + self.out_range[0]
        return tensor

def expand_greyscale(t):
    return t.expand(3, -1, -1)

class ImagesDataset(Dataset):
    def __init__(self, folder, image_size, transform=None, is_train=True):
        super().__init__()
        self.folder = folder
        self.paths = []

        for path in Path(f'{folder}').glob('**/*'):
            _, ext = os.path.splitext(path)
            if ext.lower() in IMAGE_EXTS:
                self.paths.append(path)

        print(f'{len(self.paths)} images found')

        # self.transform = transforms.Compose([
        #     transforms.RandomResizedCrop(IMAGE_CROP_SIZE) if is_train else Identity(), 
        #     transforms.RandomRotation(degrees=(0, 360)),
        #     transforms.Resize(IMAGE_RESIZE) if not is_train else Identity(),
        #     transforms.CenterCrop(image_size) if not is_train else Identity(),
        #     transforms.ToTensor(),
        #     # ScaleIntensities([0, 1], [0, 255]),
        #     transforms.Lambda(expand_greyscale) if is_train else Identity()
        # ])
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Lambda(expand_greyscale)
            ])
        else:
            self.transform = transform # focal + random data augmentations

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        img = img.convert('RGB')
        return self.transform(img)

In [10]:
def get_model(model_name):
    if model_name == 'dinov2':
        return DINOv2(model_type='dinov2_vitb14', embedding_size=256)
#     elif model_name == 'resnet50':
#         return Resnet50(embedding_size=256)
    elif model_name == 'wideresnet50':
        return WideResnet50(embedding_size=256)
    else:
        raise ValueError('Unknown model type!!!')

## Training Loop

In [11]:
class Args(argparse.ArgumentParser):
    def __init__(self, model_type, image_folder, use_momentum):
        super().__init__()
        self.model_type = model_type
        self.image_folder = image_folder
        self.use_momentum = use_momentum
        

folder_path = '/kaggle/input/product-image-amazone/Categories'
args = Args(model_type = 'wideresnet50',
           image_folder = folder_path,
           use_momentum = False)
args

Args(prog='ipykernel_launcher.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)

In [12]:
transform = make_transforms()

# Dataset: [train, val] -> train [0, 1, 2] and val [4, 5, 6]
train_dataset = ImagesDataset(os.path.join(args.image_folder, 'train_set'), IMAGE_SIZE, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
val_dataset = ImagesDataset(os.path.join(args.image_folder, 'valid_set'), IMAGE_SIZE, is_train=False)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

base_model = get_model(args.model_type)


model = SelfSupervisedLearner(
    base_model,
    image_size = IMAGE_SIZE,
    projection_size = 256,
    projection_hidden_size = 4096,
    moving_average_decay = 0.99,
    use_momentum = args.use_momentum
)
logger.info('#### - Model - ###')
print(model)
logger.info('#### - Arguements - ###')
print(args)

# -- callbacks
early_stop_fn = EarlyStopping(monitor='val_loss', mode='min', patience=5)
trainer = l.Trainer(
                    callbacks=[early_stop_fn],
                     max_epochs=EPOCHS,
                     accelerator="auto",
                     accumulate_grad_batches = 1,
                     sync_batchnorm = True,
                     log_every_n_steps=1,
                     devices=-1)

trainer.fit(model, train_loader, val_loader)

7020 images found
1755 images found


Downloading: "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" to /root/.cache/torch/hub/checkpoints/wide_resnet50_2-95faca4d.pth
100%|██████████| 132M/132M [00:01<00:00, 100MB/s]


SelfSupervisedLearner(
  (learner): BYOLProj(
    (net): WideResnet50(
      (model): ResNet(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, 

INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
INFO: Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
INFO: ----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

2024-03-31 11:01:37.927372: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-31 11:01:37.927554: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already b

Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=5` reached.


In [13]:
!zip -r file.zip /kaggle/working

  adding: kaggle/working/ (stored 0%)
  adding: kaggle/working/__notebook__.ipynb (deflated 85%)
  adding: kaggle/working/lightning_logs/ (stored 0%)
  adding: kaggle/working/lightning_logs/version_0/ (stored 0%)
  adding: kaggle/working/lightning_logs/version_0/checkpoints/ (stored 0%)
  adding: kaggle/working/lightning_logs/version_0/checkpoints/epoch=4-step=550.ckpt (deflated 26%)
  adding: kaggle/working/lightning_logs/version_0/hparams.yaml (stored 0%)
  adding: kaggle/working/lightning_logs/version_0/events.out.tfevents.1711882911.9bd9fa44b6e7.57.0 (deflated 69%)


In [14]:
!ls /kaggle/working

__notebook__.ipynb  file.zip  lightning_logs


In [15]:
from IPython.display import FileLink
FileLink(r'file.zip')