# Import libraries

In [1]:
from src.preprocessing.dataLoader_CelebA import get_partitioned_dataloaders, create_subset_loader
from src.ml.resNet50 import SiameseResNet
import torch
from src.ml.hyperparam_study import run_optuna_study

# Import losses

In [2]:
from pytorch_metric_learning.losses import ContrastiveLoss
from pytorch_metric_learning.losses import MarginLoss
from pytorch_metric_learning.losses import AngularLoss
from pytorch_metric_learning.losses import ArcFaceLoss
from pytorch_metric_learning.losses import TupletMarginLoss
from pytorch_metric_learning.losses import MultiSimilarityLoss
from pytorch_metric_learning.losses import CosFaceLoss
from pytorch_metric_learning.losses import HistogramLoss

# 1. Load the data

In [3]:
IMAGE_DIR = "data/celeba/img_align_celeba"
LABEL_FILE = "data/celeba/identity_CelebA.txt"
PARTITION_FILE = "data/celeba/list_eval_partition.csv"
IMG_SIZE = 224
BATCH_SIZE = 32
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
train_loader, val_loader, test_loader = get_partitioned_dataloaders(image_dir= IMAGE_DIR,
                                                               label_file= LABEL_FILE,
                                                               partition_file= PARTITION_FILE,
                                                               batch_size=BATCH_SIZE,
                                                               img_size=IMG_SIZE)

# Create the model

In [5]:
model = SiameseResNet()

# Find best Hyperparameters

In [6]:
train_loader_study = (create_subset_loader(train_loader,10000))
val_loader_study = (create_subset_loader(train_loader,2000))
study = run_optuna_study(train_loader_study, val_loader_study, n_trials=10, study_name="siamese_constrastive_HP_study")
best_params = study.best_params

[I 2025-05-25 18:14:31,850] Using an existing study with name 'siamese_constrastive_HP_study' instead of creating a new one.


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 0 - Iteration 0 - Training Loss: 31.4555
Epoch 0 - Iteration 10 - Training Loss: 22.2276
Epoch 0 - Iteration 20 - Training Loss: 24.1295
Epoch 0 - Iteration 30 - Training Loss: 16.9374
Epoch 0 - Iteration 40 - Training Loss: 20.4434
Epoch 0 - Iteration 50 - Training Loss: 19.8936
Epoch 0 - Iteration 60 - Training Loss: 15.2382
Epoch 0 - Iteration 70 - Training Loss: 21.5090
Epoch 0 - Iteration 80 - Training Loss: 21.4501
Epoch 0 - Iteration 90 - Training Loss: 21.2901
Epoch 0 - Iteration 100 - Training Loss: 20.0849
Epoch 0 - Iteration 110 - Training Loss: 18.6803
Epoch 0 - Iteration 120 - Training Loss: 16.1173
Epoch 0 - Iteration 130 - Training Loss: 15.9739
Epoch 0 - Iteration 140 - Training Loss: 14.0331
Epoch 0 - Iteration 150 - Training Loss: 17.9047
Epoch 0 - Iteration 160 - Training Loss: 16.3593
Epoch 0 - Iteration 170 - Training Loss: 15.6632
Epoch 0 - Iteration 180 - Training Loss: 17.0294
Epoch 0 - Iteration 190 - Training Loss: 16.4768
Epoch 0 - Iteration 200 - Train

KeyboardInterrupt: 

# Train Model

In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.00001)

In [7]:
contrastive_loss = ContrastiveLoss(neg_margin=1.0, pos_margin=0)
histogram_loss = HistogramLoss(n_bins=100)
multi_similarity_loss = MultiSimilarityLoss()
angular_loss = AngularLoss(alpha=180)
tuplet_margin_loss = TupletMarginLoss(margin=0.5)

In [12]:
results = model.train_model(
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=contrastive_loss,
        optimizer=optimizer,
        num_epochs=10,
        device=DEVICE,
        patience=5,
        experiment_name='SiameseResNet',
        tuning_mode=False
    )

Epoch 1/10 - Training:   0%|          | 1/5087 [00:04<6:11:47,  4.39s/it]

Epoch 0 - Iteration 0 - Training Loss: 1.44293690


Epoch 1/10 - Training:   0%|          | 6/5087 [00:28<7:00:29,  4.97s/it]

Epoch 0 - Iteration 5 - Training Loss: 0.00000000


Epoch 1/10 - Training:   0%|          | 11/5087 [00:50<6:23:49,  4.54s/it]

Epoch 0 - Iteration 10 - Training Loss: 0.02955031


Epoch 1/10 - Training:   0%|          | 15/5087 [01:30<8:30:14,  6.04s/it] 


KeyboardInterrupt: 

# Plot the results

In [None]:
! mlflow ui --port 5000