In [1]:
import torch
from torchvision import transforms as T
from lightly.transforms import SimCLRTransform, DINOTransform, MAETransform, MoCoV2Transform, utils
from datasets import create_dataset
from models import SimCLRModel
import pytorch_lightning as pl
import os
import copy
import gc
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


# Seed

In [2]:
SEED = 42

def seed_everything(seed: int=42):
    pl.seed_everything(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    ####### Normaly you would also need to seed those generators but `pytorch_lightning` does it in one func
    # random.seed(seed)
    # np.random.seed(seed)
    # torch.manual_seed(seed)
    ######
    torch.cuda.manual_seed(seed) # Don't know if pytorch lightning does this
    torch.cuda.manual_seed_all(seed) # Don't know if pytorch lightning does this
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

Seed set to 42


# Transformacje

In [3]:
test_transform = T.v2.Compose(
    [
        T.Resize((224, 224)),
        T.v2.ToImage(),
        T.v2.ToDtype(torch.float32, scale=True),
        T.Normalize(
            mean=utils.IMAGENET_NORMALIZE["mean"],
            std=utils.IMAGENET_NORMALIZE["std"],
        )
    ]
)
scratch_transform = T.v2.Compose(
    [
        T.RandomResizedCrop((224, 224)),
        T.RandomHorizontalFlip(),
        T.v2.ToImage(),
        T.v2.ToDtype(torch.float32, scale=True),
        T.Normalize(
            mean=utils.IMAGENET_NORMALIZE["mean"],
            std=utils.IMAGENET_NORMALIZE["std"],
        )
    ]
)

simclr_transform = SimCLRTransform(
    input_size=(224, 224),
    vf_prob=0.5,
    rr_prob=0.5,
)



# Zbiorki

In [4]:
train_full_cifar10_simclr, train_ssl_cifar10_simclr, train_cifar10_simclr, test_cifar10_simclr = create_dataset("CIFAR10", 0.9, simclr_transform, scratch_transform, test_transform, "data", False)

Length of entire train dataset:  50000
Length of SSL train dataset:  45000
Length of classification train dataset:  5000
Length of test dataset:  10000


# Hiperparametry

In [5]:
### PARAMETERS ###
BATCH_SIZE = 128
NUM_EPOCHS = 20
LEARNING_RATE = [0.1, 0.001]
NUM_WORKERS = 3
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
#GARBAGE COLLECTOR FAJNA SPRAWA - BEZ NIEGO VRAMu BRAKUJE
if device == "gpu":
    torch.cuda.empty_cache()
    gc.collect()
    torch.set_float32_matmul_precision('high')

Using device: cuda


# Dataloadery i trening

In [6]:
dl_train_cifar10_simclr = torch.utils.data.DataLoader(
    train_ssl_cifar10_simclr,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True
)
dl_val_cifar10_simclr = torch.utils.data.DataLoader(
    test_cifar10_simclr,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS
)

In [7]:
simclr_model = SimCLRModel(lr=LEARNING_RATE[0], max_epochs=NUM_EPOCHS)



In [9]:
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="checkpoints/cifar10_simclr",
    every_n_epochs=2
)

trainer = pl.Trainer(
    accelerator='cuda', 
    devices=1, 
    max_epochs=NUM_EPOCHS, 
    log_every_n_steps=10,
    # callbacks=[checkpoint_callback],
)

trainer.fit(simclr_model, dl_train_cifar10_simclr, dl_val_cifar10_simclr)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                 | Params | Mode
----------------------------------------------------------------
0 | backbone        | Sequential           | 11.2 M | eval
1 | projection_head | SimCLRProjectionHead | 328 K  | eval
2 | criterion       | NTXentLoss           | 0      | eval
----------------------------------------------------------------
11.5 M    Trainable params
0         Non-trainable params
11.5 M    Total params
46.022    Total estimated model params size (MB)
0         Modules in train mode
77        Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

ValueError: too many values to unpack (expected 2)