In [1]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

In [2]:
from src.models.byol.BYOL import BYOL
from src.scripts.etl_process.ETLProcessor import ETLProcessor
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
input_dim = 3
hidden_dim = 128
residual_hiddens = 64
num_residual_layers = 2  # 1 or 2
mlp_hidden_dim = 256  # 256, 512, 1024, 2048
mask_size_ratio = 0.35  # 0.3, 0.4, 0.5
tau = 0.99  # 0.95, 0.99, 0.999
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
byol = BYOL(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    num_residual_layers=num_residual_layers,
    residual_hiddens=residual_hiddens,
    mlp_hidden_dim=mlp_hidden_dim,
    mask_size_ratio=mask_size_ratio,
    tau=tau,
    device=device,
)

In [5]:
etl = ETLProcessor(
    kaggle_dataset="mahmudulhaqueshawon/cat-image",
    raw_dir="../data/raw_data",
    split_dir="../data/data_splits",
)
train_loader, _, _ = etl.process()

[INFO] Copying dataset 'mahmudulhaqueshawon/cat-image' to ../data/raw_data...
[INFO] Dataset ready at ../data/raw_data


In [None]:
from tqdm import tqdm
import torch
import torch.optim as optim


def fit_byol(
    model: BYOL,
    optimiser: optim.Optimizer,
    dataloader: torch.utils.data.DataLoader,
    epochs: int,
    device: str = "cuda",
    print_metrics: bool = True,
    model_name: str = "byol_model",
):
    model.to(device)
    best_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for x_batch, _ in tqdm(dataloader, desc=f"Epoch {epoch}"):
            x_batch = x_batch.to(device)

            q, z = model(x_batch)

            loss = model.byol_loss(q, z)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            model.update_target_network()

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches
        if not torch.isfinite(loss).all():
            print("NaN loss encountered, stopping training.")
            break

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"{model_name}.pt")

        if print_metrics:
            print(f"Epoch {epoch}: BYOL training loss = {avg_loss:.8f}")

In [8]:
byol = byol.to(device)
fit_byol(
    model=byol,
    optimiser=optim.Adam(byol.parameters(), lr=1e-4),
    dataloader=train_loader,
    epochs=20,
    device=device,
)

Epoch 0: 100%|██████████| 44/44 [00:15<00:00,  2.90it/s]


Epoch 0: BYOL training loss = 1.47985332


Epoch 1: 100%|██████████| 44/44 [00:13<00:00,  3.38it/s]


Epoch 1: BYOL training loss = 0.81724281


Epoch 2: 100%|██████████| 44/44 [00:13<00:00,  3.32it/s]


Epoch 2: BYOL training loss = 0.50021682


Epoch 3: 100%|██████████| 44/44 [00:12<00:00,  3.43it/s]


Epoch 3: BYOL training loss = 0.35095221


Epoch 4: 100%|██████████| 44/44 [00:13<00:00,  3.33it/s]


Epoch 4: BYOL training loss = 0.26830548


Epoch 5: 100%|██████████| 44/44 [00:13<00:00,  3.26it/s]


Epoch 5: BYOL training loss = 0.24031349


Epoch 6: 100%|██████████| 44/44 [00:12<00:00,  3.39it/s]


Epoch 6: BYOL training loss = 0.20490843


Epoch 7: 100%|██████████| 44/44 [00:13<00:00,  3.35it/s]


Epoch 7: BYOL training loss = 0.19691454


Epoch 8: 100%|██████████| 44/44 [00:13<00:00,  3.25it/s]


Epoch 8: BYOL training loss = 0.18380386


Epoch 9: 100%|██████████| 44/44 [00:14<00:00,  3.10it/s]


Epoch 9: BYOL training loss = 0.15496917


Epoch 10: 100%|██████████| 44/44 [00:14<00:00,  3.10it/s]


Epoch 10: BYOL training loss = 0.15259458


Epoch 11: 100%|██████████| 44/44 [00:13<00:00,  3.22it/s]


Epoch 11: BYOL training loss = 0.15282445


Epoch 12: 100%|██████████| 44/44 [00:12<00:00,  3.41it/s]


Epoch 12: BYOL training loss = 0.12741959


Epoch 13: 100%|██████████| 44/44 [00:12<00:00,  3.41it/s]


Epoch 13: BYOL training loss = 0.11881730


Epoch 14: 100%|██████████| 44/44 [00:13<00:00,  3.32it/s]


Epoch 14: BYOL training loss = 0.13963510


Epoch 15: 100%|██████████| 44/44 [00:14<00:00,  3.08it/s]


Epoch 15: BYOL training loss = 0.13001266


Epoch 16: 100%|██████████| 44/44 [00:13<00:00,  3.16it/s]


Epoch 16: BYOL training loss = 0.13474760


Epoch 17: 100%|██████████| 44/44 [00:13<00:00,  3.31it/s]


Epoch 17: BYOL training loss = 0.12845911


Epoch 18: 100%|██████████| 44/44 [00:13<00:00,  3.31it/s]


Epoch 18: BYOL training loss = 0.10844661


Epoch 19: 100%|██████████| 44/44 [00:13<00:00,  3.21it/s]

Epoch 19: BYOL training loss = nan



