In [1]:
import os
# Presumably a workaround for the Intel OpenMP "OMP: Error #15 ... already
# initialized" crash that occurs when two OpenMP runtimes get loaded into one
# process. The runtime itself describes this flag as an unsafe workaround;
# prefer fixing the duplicate runtime in the environment when possible.
# It is set here, before the project/torch imports below, so that it is
# already in the environment when those libraries load.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

In [2]:
from train import *
from utils import *
from data_load import *

In [4]:
from torchvision import models


In [3]:
# Diagnostic only: show the GPU model, driver/CUDA versions and current memory
# usage. The output is machine-specific and is not captured by later cells.
!nvidia-smi

Thu Dec 12 09:25:17 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 561.03                 Driver Version: 561.03         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3050 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   49C    P8              9W /   35W |       0MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [5]:
# Pick the compute device once and report it; later cells use `device`.
use_cuda = torch.cuda.is_available()
print(
    f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}"
    if use_cuda
    else "CUDA is not available. Using CPU."
)

device = torch.device("cuda" if use_cuda else "cpu")
print(f"Device used : {device}")

CUDA is available. Using GPU: NVIDIA GeForce RTX 3050 Ti Laptop GPU
Device used : cuda


## Load Datasets

In [6]:
# Raw (un-padded) SF-MASK dataset locations, relative to the notebook.
train_dir = '../SF-MASK-dataset/train'
test_dir = '../SF-MASK-dataset/test'

# OS-generated metadata files that can appear inside image folders.
_IGNORED_FILENAMES = {"Thumbs.db", "desktop.ini"}


def count_images(split_dir, class_name):
    """Return the number of image files in ``split_dir/class_name``.

    Only regular files are counted. Sub-directories, hidden entries (names
    starting with '.', e.g. ``.DS_Store``) and OS metadata files such as
    ``Thumbs.db`` are skipped, so they do not inflate the per-class counts
    the way a plain ``len(os.listdir(...))`` would.

    Raises:
        FileNotFoundError: if the class directory does not exist.
    """
    class_dir = os.path.join(split_dir, class_name)
    # `with` closes the directory iterator deterministically.
    with os.scandir(class_dir) as entries:
        return sum(
            1
            for entry in entries
            if entry.is_file()
            and not entry.name.startswith('.')
            and entry.name not in _IGNORED_FILENAMES
        )


# Count the images in each class sub-folder.
# NOTE(review): these counts describe the raw dataset; the training data is
# read from ../SF-MASK-dataset-padded — confirm both copies hold the same images.
train_compliant = count_images(train_dir, 'compliant')
train_non_compliant = count_images(train_dir, 'non-compliant')
test_compliant = count_images(test_dir, 'compliant')
test_non_compliant = count_images(test_dir, 'non-compliant')

print(f"Training dataset: Compliant: {train_compliant}, Non-compliant: {train_non_compliant}")
print(f"Test dataset: Compliant: {test_compliant}, Non-compliant: {test_non_compliant}")

Training dataset: Compliant: 21384, Non-compliant: 15772
Test dataset: Compliant: 3622, Non-compliant: 5048


In [9]:
# Root of the padded copy of SF-MASK; the data loaders are built from it.
PADDED_DATASET_ROOT = "../SF-MASK-dataset-padded"
output_dir_train = f"{PADDED_DATASET_ROOT}/train"
output_dir_test = f"{PADDED_DATASET_ROOT}/test"

In [10]:
# Passed as `augment` to get_dataloaders (presumably toggles data
# augmentation — confirm in data_load).
AUGMENT = True

# get_dataloaders yields the full training dataset, then the train/val/test
# subsets, then one DataLoader per subset — unpacked in that order.
(
    train_dataset,
    train_subset, val_subset, test_subset,
    train_loader, val_loader, test_loader,
) = get_dataloaders(
    train_dir=output_dir_train,
    test_dir=output_dir_test,
    augment=AUGMENT,
)

# Report how much of the dataset is used and the resulting subset/batch counts.
print(f"We take {DATASET_PROPORTION*100}% of the entire dataset (size {len(train_dataset)})\n")
print(f"Train Subset: {len(train_subset)}, Val subset: {len(val_subset)}, Test subset: {len(test_subset)}")
print(f"Train Loader: {len(train_loader)}, Val Loader: {len(val_loader)}, Test Loader: {len(test_loader)}")

We take 10.0% of the entire dataset (size 37156)

Train Subset: 2972, Val subset: 743, Test subset: 867
Train Loader: 93, Val Loader: 24, Test Loader: 28


## Training

In [11]:
# Hyperparameter search space: every learning rate is paired with every batch
# size (3 x 3 = 9 settings).
param_grid = dict(
    learning_rate=[0.01, 0.001, 0.0001],
    batch_size=[16, 32, 64],
)

In [12]:
# K-fold cross-validation over every combination in `param_grid`; each
# setting is trained K_FOLDS times for EPOCHS epochs.
# NOTE(review): the recorded run was interrupted (KeyboardInterrupt). With
# batch_size=16 it reported 1858 training batches per fold, i.e. about 4/5 of
# all 37,156 training images: the search runs on the FULL `train_dataset`,
# not on the reduced `train_subset`. On a 4 GB laptop GPU this is very long;
# consider passing `train_subset` or lowering EPOCHS for exploratory runs.
# NOTE(review): torchvision's `mobilenet_v2` builds a 1000-class head by
# default; confirm `kfold_cross_validation_with_hyperparams` adapts it to the
# two classes (compliant / non-compliant).
K_FOLDS = 5
EPOCHS = 30
RESULTS_PATH = "kfold_hyperparam_results.json"

results = kfold_cross_validation_with_hyperparams(
    model_class=models.mobilenet_v2,
    dataset=train_dataset,
    criterion=nn.CrossEntropyLoss(),
    optimizer_class=torch.optim.Adam,
    param_grid=param_grid,
    k_folds=K_FOLDS,
    epochs=EPOCHS,
    # Reuse the notebook-level AUGMENT flag (previously hard-coded to True),
    # so the loaders and the cross-validation stay in sync. Flip AUGMENT to
    # compare runs with and without augmentation.
    augment=AUGMENT,
    training_time_name="mask_detection",
)

# Save results to JSON for later analysis.
save_metrics_to_json(results, RESULTS_PATH)



Testing hyperparameters: {'batch_size': 16, 'learning_rate': 0.01}
Fold 1/5 for hyperparameters {'batch_size': 16, 'learning_rate': 0.01}




size of train_loader : 1858, size of val_loader : 465
Training MobileNetV2 for 30 epochs...


KeyboardInterrupt: 