# Contrastive learning

This notebook trains and evaluates a SimCLR model on our dummy images (squares or circles).

In [1]:
import os

from cnn_framework.utils.data_loader_generators.data_loader_generator import DataLoaderGenerator
from cnn_framework.utils.data_managers.default_data_manager import DefaultDataManager
from cnn_framework.utils.lr_schedulers.linear_warmup_cosine_annealing_lr import (
    LinearWarmupCosineAnnealingLR,
)
from cnn_framework.utils.metrics.positive_pair_matching_metric import PositivePairMatchingMetric
from cnn_framework.utils.optimizers.lars import create_optimizer_lars
from cnn_framework.utils.model_managers.contrastive_model_manager import ContrastiveModelManager
from cnn_framework.utils.losses.info_nce_loss import InfoNceLoss
from cnn_framework.utils.create_dummy_data_set import generate_data_set

from cnn_framework.dummy_sim_clr.data_set import SimCLRDataSet
from cnn_framework.dummy_sim_clr.model import ResNetSimCLR
from cnn_framework.dummy_sim_clr.model_params import SimCLRModelParams


In [2]:
# Instantiate the SimCLR parameter object and call its update() hook before anything reads it.
params = SimCLRModelParams()
params.update()

# First run only: generate the data set when the data directory does not exist yet.
if not os.path.exists(params.data_dir):
    generate_data_set(params.data_dir)
    print(f"\nData set created in {params.data_dir}")

# Report where this run stores its model, predictions and Tensorboard logs.
print(
    f"\nModel will be saved in {params.models_folder}",
    f"Predictions will be saved in {params.output_dir}",
    f"Tensorboard logs will be saved in {params.tensorboard_folder_path}",
    sep="\n",
)

Model time id: 20230928-144116-local
epochs 50 | batch 32 | lr 0.0375 | weight decay 1e-06 | dropout 0.0

Model will be saved in C:\Users\thoma\models/local/sim_clr/20230928-144116-local
Predictions will be saved in C:\Users\thoma\predictions/local/sim_clr/20230928-144116-local
Tensorboard logs will be saved in C:\Users\thoma\tensorboard/local/20230928-144116-local_sim_clr


In [3]:
# Build the data-loader generator from the run parameters, the SimCLR data set class
# and the default data manager class.
loader_generator = DataLoaderGenerator(params, SimCLRDataSet, DefaultDataManager)
# One call returns three loaders, unpacked as train / validation / test.
train_dl, val_dl, test_dl = loader_generator.generate_data_loader()

### Data source ###
train data is loaded from C:\Users\thoma\data\dummy - 80% elements
val data is loaded from C:\Users\thoma\data\dummy - 10% elements
test data is loaded from C:\Users\thoma\data\dummy - 10% elements
###################
train has 160 images.
val has 20 images.
test has 20 images.
###################


In [4]:
# Number of input channels = (number of c indexes) x (number of z indexes) set in params.
model = ResNetSimCLR(nb_input_channels=len(params.c_indexes) * len(params.z_indexes))

# The manager wraps the model and the run parameters; PositivePairMatchingMetric is
# the metric class it is given.
manager = ContrastiveModelManager(model, params, PositivePairMatchingMetric)

# LARS optimizer: learning rate and weight decay come from params, the remaining
# settings are fixed in this cell.
optimizer = create_optimizer_lars(
    model,
    lr=params.learning_rate,
    weight_decay=params.weight_decay,
    momentum=0.9,
    epsilon=1e-5,
    bn_bias_separately=True,
)

# InfoNCE loss, built from the manager's device and the temperature set in params.
loss_function = InfoNceLoss(manager.device, params.temperature)

# Learning-rate schedule configured with the warm-up length and total epoch count from params.
lr_scheduler = LinearWarmupCosineAnnealingLR(
    optimizer,
    warmup_epochs=params.nb_warmup_epochs,
    max_epochs=params.num_epochs,
)

Current commit hash: 3531c7365ab2d6c158b506b2444c84fade7107e6


In [5]:
# Train the model: the manager receives the training and validation loaders, the
# optimizer, the loss function and the learning-rate scheduler.
# NOTE(review): the validation loader is presumably what drives the "best model"
# checkpoint reported in the run log — confirm in ContrastiveModelManager.
manager.fit(train_dl, val_dl, optimizer, loss_function, lr_scheduler=lr_scheduler)

Training in progress: 100.0% | Local step 5 | Epoch 50
Best model saved at epoch 45.

Training successfully finished in 0:01:26.267799.


In [6]:
# Run the manager's prediction pass on the test loader.
# NOTE(review): presumably this uses the best checkpoint from training and writes to
# params.output_dir — confirm in ContrastiveModelManager.
manager.predict(test_dl)

Model evaluation in progress: 100.0% | Batch #0
Average PositivePairMatchingMetric: 1.0
