In [None]:
import copy
import gc
import os
from itertools import product

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from lightly.transforms import SimCLRTransform, utils
from torchvision import transforms as T
from torchvision.transforms import v2

from datasets import create_dataset, create_stratified_bootstrap_dataloader
from models import BYOLModel

  from .autonotebook import tqdm as notebook_tqdm


# Seed

In [None]:
SEED = 42


def seed_everything(seed: int = 42, workers: bool = True) -> None:
    """Seed the RNGs used by this notebook so training runs are reproducible.

    Args:
        seed: Seed passed to ``pl.seed_everything`` and the CUDA generators.
        workers: Forwarded to ``pl.seed_everything``. When True, Lightning
            derives a distinct, deterministic seed for torch, NumPy and
            ``random`` inside every DataLoader worker process. Without it,
            worker processes (NUM_WORKERS > 0 in this notebook) inherit identical
            NumPy/``random`` states, so augmentations drawing from those RNGs
            repeat across workers and across epochs.
    """
    # pl.seed_everything seeds Python's `random`, NumPy and torch in one call,
    # so those generators are not seeded individually here.
    pl.seed_everything(seed, workers=workers)
    # NOTE: PYTHONHASHSEED only affects Python interpreters started after this
    # point (e.g. spawned subprocesses); the running kernel's hash seed was
    # fixed when it started.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Explicit CUDA seeding; redundant but harmless if already covered above.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer deterministic cuDNN kernels over autotuned (possibly faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

Seed set to 42


# Transformacje (data transforms / augmentations)

In [None]:
# ImageNet mean/std normalisation shared by the two supervised pipelines below.
# Normalize is stateless, so a single instance can be reused in both.
imagenet_normalize = T.Normalize(
    mean=utils.IMAGENET_NORMALIZE["mean"],
    std=utils.IMAGENET_NORMALIZE["std"],
)

# Evaluation preprocessing: deterministic resize to 224x224, conversion to a
# float32 tensor scaled to [0, 1], then normalisation. The v1 Resize runs on
# the input image before ToImage converts it to a tensor.
# `v2` is imported explicitly at the top of the notebook rather than reached
# through `T.v2`, which only resolves if some other import happened to load
# the torchvision.transforms.v2 submodule.
test_transform = v2.Compose(
    [
        T.Resize((224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        imagenet_normalize,
    ]
)

# Training-from-scratch preprocessing: random resized crop to 224x224 and
# random horizontal flip, followed by the same tensor conversion/normalisation.
scratch_transform = v2.Compose(
    [
        T.RandomResizedCrop((224, 224)),
        T.RandomHorizontalFlip(),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        imagenet_normalize,
    ]
)

# Two-view augmentation used for BYOL pretraining (lightly's SimCLR pipeline).
# Per lightly's SimCLRTransform docs, vf_prob / rr_prob are the probabilities
# of a vertical flip and a random rotation.
byol_transform = SimCLRTransform(
    input_size=(224, 224),
    vf_prob=0.5,
    rr_prob=0.5,
)



# Zbiory danych (datasets)

In [None]:
# Build the CIFAR-10 splits. Per the printed lengths, 0.9 of the 50000 training
# images (45000) go to the SSL split and the remaining 5000 to the
# classification split; the 10000-image test set is separate.
(
    train_full_cifar10_byol,
    train_ssl_cifar10_byol,
    train_cifar10_byol,
    test_cifar10_byol,
    labels_cifar10,
) = create_dataset(
    "CIFAR10",
    0.9,                # fraction of the train set assigned to the SSL split
    byol_transform,     # presumably the BYOL two-view transform — confirm order in create_dataset
    scratch_transform,  # presumably the supervised training transform
    test_transform,     # presumably the evaluation transform
    "data",             # presumably the dataset root directory
    False,              # NOTE(review): flag meaning not visible here (download?) — confirm
)

Length of entire train dataset:  50000
Length of SSL train dataset:  45000
Length of classification train dataset:  5000
Length of test dataset:  10000


In [None]:
# Build the CIFAR-100 splits with the same SSL fraction and transforms as
# CIFAR-10. Per the printed lengths: 45000 SSL, 5000 classification, and
# 10000 test images.
(
    train_full_cifar100_byol,
    train_ssl_cifar100_byol,
    train_cifar100_byol,
    test_cifar100_byol,
    labels_cifar100,
) = create_dataset(
    "CIFAR100",
    0.9,                # fraction of the train set assigned to the SSL split
    byol_transform,     # presumably the BYOL two-view transform — confirm order in create_dataset
    scratch_transform,  # presumably the supervised training transform
    test_transform,     # presumably the evaluation transform
    "data",             # presumably the dataset root directory
    False,              # NOTE(review): flag meaning not visible here (download?) — confirm
)

Length of entire train dataset:  50000
Length of SSL train dataset:  45000
Length of classification train dataset:  5000
Length of test dataset:  10000


# Hiperparametry

In [None]:
### PARAMETERS ###
BATCH_SIZE = 128
NUM_EPOCHS = 20
# Hyperparameter grid: training runs every (learning rate, backbone, tau) combination.
LEARNING_RATE = [0.001, 0.01]
BACKBONE_TYPE = ['pretrained', 'random']
TAU = [0.98, 0.996]
NUM_WORKERS = 3
# Lightning accelerator name: "gpu" when CUDA is available, otherwise "cpu".
device = "gpu" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Release leftover VRAM before training (without this the GPU runs out of
# memory). gc.collect() must run first: it frees tensors kept alive only by
# reference cycles, returning their memory to PyTorch's caching allocator,
# which empty_cache() can then hand back to the driver.
if device == "gpu":
    gc.collect()
    torch.cuda.empty_cache()
    # Let float32 matmuls use faster lower-precision internals (e.g. TF32) where supported.
    torch.set_float32_matmul_precision('high')

Using device: gpu


# CIFAR10

In [None]:
# Class-stratified bootstrap training loader over the CIFAR-10 classification
# split, stratified on labels_cifar10.
# NOTE(review): BYOL is trained on the 5000-image classification split
# (train_cifar10_byol), not the 45000-image SSL split (train_ssl_cifar10_byol) —
# confirm this is intended.
# Pinned host memory only helps host->GPU copies, so it is enabled only when
# training on a GPU (on CPU-only hosts it just triggers a warning).
dl_train_cifar10_byol = create_stratified_bootstrap_dataloader(
    train_cifar10_byol,
    labels=labels_cifar10,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    seed=SEED,
    shuffle=True,
    drop_last=True,
    pin_memory=(device == "gpu"),
)

# Validation loader over the 10000-image test split, in fixed order.
# NOTE(review): the test split doubles as the validation set here; any model
# selection based on these metrics leaks test information.
dl_valid_cifar10_byol = torch.utils.data.DataLoader(
    test_cifar10_byol,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    pin_memory=(device == "gpu"),
)

In [None]:
from pytorch_lightning.callbacks import Callback
import os

class SaveAtEpochsCallback(Callback):
    """Write a full Lightning checkpoint at the end of selected training epochs.

    Args:
        save_epochs: 1-based epoch numbers after which a checkpoint is written
            (e.g. ``[10, 20]`` means after the 10th and the 20th epoch).
        dirpath: Directory receiving the checkpoint files; created if missing.

    Note:
        File names use Lightning's 0-based ``trainer.current_epoch``, so the
        checkpoint taken after the 10th epoch is ``model_epoch_9.ckpt``.
    """

    def __init__(self, save_epochs, dirpath="checkpoints"):
        super().__init__()
        self.dirpath = dirpath
        self.save_epochs = set(save_epochs)
        os.makedirs(self.dirpath, exist_ok=True)

    def on_train_epoch_end(self, trainer, pl_module):
        epoch_idx = trainer.current_epoch  # 0-based index of the epoch just finished
        # save_epochs holds 1-based epoch numbers, hence the +1.
        if (epoch_idx + 1) not in self.save_epochs:
            return
        ckpt_path = os.path.join(self.dirpath, f"model_epoch_{epoch_idx}.ckpt")
        trainer.save_checkpoint(ckpt_path)
        # Message (Polish): "Saved model after epoch {epoch_idx}: {ckpt_path}".
        print(f"Zapisano model po epoce {epoch_idx}: {ckpt_path}")

In [None]:
for lr, backbone_type, tau in product(LEARNING_RATE, BACKBONE_TYPE, TAU):
    # Re-seed per configuration so weight init and augmentation RNG state do
    # not depend on how many runs preceded this one in the grid.
    # NOTE(review): the bootstrap loader was built once with seed=SEED; whether
    # its resampling restarts for each run depends on
    # create_stratified_bootstrap_dataloader — confirm.
    seed_everything(SEED)

    byol_model = BYOLModel(
        backbone_type=backbone_type + '_resnet18',
        lr=lr,
        max_epochs=NUM_EPOCHS,
        tau=tau,
    )

    # One directory per (backbone, lr, tau) combination; also used as the
    # Trainer's default_root_dir.
    dirpath = f"checkpoints/byol_cifar10_{backbone_type}_{lr}_{tau}"

    # Milestone checkpoints after these 1-based epoch numbers.
    checkpoint_callback = SaveAtEpochsCallback(
        save_epochs=[10, 15, 20],
        dirpath=dirpath
    )

    trainer = pl.Trainer(
        # Follow the detected device instead of assuming a GPU, so the loop
        # also runs on CPU-only machines.
        accelerator=device,
        devices=1,
        max_epochs=NUM_EPOCHS,
        log_every_n_steps=1,
        callbacks=[checkpoint_callback],
        default_root_dir=dirpath,
    )

    trainer.fit(byol_model, dl_train_cifar10_byol, dl_valid_cifar10_byol)

    # Free garbage (collect first, then release cached CUDA blocks) so VRAM
    # from this run does not accumulate into the next configuration.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
