## Measure time

In [1]:
import time
start_time = time.time()

## Reproducible results

In [2]:
from determinism import Determinism
determinism = Determinism(seed=42).sow()  # Keep this before any torch import
print("Training results should now be reproducible.")

Training results should now be reproducible.


## Setup dataset and hyperparameters
To investigate the impact of imbalanced training data on fine-tuning performance, we simulate class imbalance by restricting the training set to 20% of the images for each cat breed while keeping all dog images intact. We do not create a validation set as we are not tuning hyperparameters.
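
The per-class subsampling described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the actual implementation behind `DatasetParams`: the helper `subsample_per_class` and the toy labels are hypothetical, and the real `datasets` module may select images differently.

```python
import random
from collections import defaultdict


def subsample_per_class(labels, class_fractions, seed=42):
    """Return sorted indices that keep a given fraction of each class.

    Hypothetical helper for illustration only.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    kept = []
    for label, indices in by_class.items():
        n_keep = round(len(indices) * class_fractions[label])
        # Sample without replacement so no image is duplicated.
        kept.extend(rng.sample(indices, n_keep))
    return sorted(kept)


# Toy data: class 0 plays the role of a cat breed, class 1 of a dog breed.
labels = [0] * 100 + [1] * 100
kept = subsample_per_class(labels, class_fractions=(0.2, 1.0))
n_cat = sum(1 for i in kept if labels[i] == 0)
n_dog = sum(1 for i in kept if labels[i] == 1)
print(n_cat, n_dog)
```

Fixing the seed keeps the selected subset identical across runs, which is the role `splitting_seed` presumably plays in the cell below.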

In [3]:
%reload_ext autoreload
%autoreload 2
from datasets import DatasetParams
from training import TrainParams, AdamParams
from augmentation import AugmentationParams

# Map breed index to family index (e.g. breed 0 (Abyssinian) -> family 0 (cat); family 1 = dog)
breed_family_idx = [0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1,
                    1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1]

dataset_params = DatasetParams(
    splitting_seed=determinism.seed,
    shuffler_seed=determinism.seed,
    batch_size=32,
    # Keep 20% of the images of each cat breed (family 0); keep all dog images
    class_fractions=tuple(0.2 if x == 0 else 1.0 for x in breed_family_idx),
    validation_set_fraction=0,  # no validation set
)

baseline_params = TrainParams(
    seed=determinism.seed,
    architecture="resnet50",
    n_epochs=3,
    optimizer=AdamParams(
        learning_rate=1e-3,
        weight_decay=0,
    ),
    freeze_layers=True,
    unfreezing_epochs=None,
    augmentation=AugmentationParams(
        enabled=False,
        transform=None,
    ),
    validation_freq=0,  # no validation set
)

## Experiments
We now compare four configurations under the imbalanced training setup: (1) standard cross-entropy loss, (2) weighted cross-entropy loss, (3) over-sampling of minority classes, and (4) weighted loss combined with over-sampling. All other parameters remain fixed.

### 1. Standard cross-entropy loss
We first measure the test accuracy obtained on the imbalanced training set without any compensation. This serves as the reference point for the remaining experiments.

In [4]:
from evaluation import evaluate_final_test_accuracy
evaluate_final_test_accuracy(dataset_params, baseline_params, determinism, trials=3)

Test size: 3669
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]



Epoch [1/3], Loss: 2.1248, Train Acc: 62.83%
Epoch [2/3], Loss: 0.7762, Train Acc: 90.76%
Epoch [3/3], Loss: 0.4636, Train Acc: 94.06%
Total elapsed: 405.52s, average per update step: 1.57s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 84.328 %
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch [1/3], Loss: 2.1152, Train Acc: 63.38%
Epoch [2/3], Loss: 0.7780, Train Acc: 90.25%
Epoch [3/3], Loss: 0.4587, Train Acc: 94.46%
Total elapsed: 402.34s, average per update step: 1.54s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 83.701 %
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch [1/3], Loss: 2.1169, Train Acc: 63.71%
Epoch [2/3], Loss: 0.7739, Train Acc: 90.76%
Epoch [3/3], Loss: 0.4612, Train Acc: 93.95%
Total elapsed: 389.29s, average per update step: 1.49s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 83.620 %
Test Accuracy Mean: 83.88 %
Test Accuracy Standard Error: 0.22 percentage points
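
The summary statistics can be reproduced from the three per-trial accuracies. The sketch below assumes the standard error is the sample standard deviation (with `n - 1` in the denominator) divided by the square root of the number of trials. That assumption matches the figures printed above, but the actual code inside `evaluate_final_test_accuracy` is not shown here.

```python
import math
import statistics

# Per-trial test accuracies (in %) reported in the output above.
accuracies = [84.328, 83.701, 83.620]

mean = statistics.mean(accuracies)
# statistics.stdev is the sample standard deviation (n - 1 denominator).
std_err = statistics.stdev(accuracies) / math.sqrt(len(accuracies))

print(f"Test Accuracy Mean: {mean:.2f} %")
print(f"Test Accuracy Standard Error: {std_err:.2f} percentage points")
```

With only three trials, the standard error is itself a rough estimate, so small differences between configurations should be read with care.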


### 2. Weighted cross-entropy loss
To compensate, we weight the cross-entropy loss per class. Since each cat breed retains only 1/5 of its training images, a natural weight to try for the cat breeds is 5, while dog breeds keep a weight of 1.
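
To make the effect of class weights concrete, here is a minimal pure-Python sketch of a weighted cross-entropy. It assumes the loss is a weighted mean, normalized by the total weight of the batch, which is how `torch.nn.CrossEntropyLoss(weight=...)` behaves with its default `reduction="mean"`. Whether the `training` module uses exactly this reduction is an assumption. The function `weighted_cross_entropy` and the toy logits are hypothetical.

```python
import math


def weighted_cross_entropy(logits, targets, class_weights):
    """Weighted mean of per-sample cross-entropy losses.

    Each sample's loss is scaled by the weight of its true class, and the
    sum is normalized by the total weight of the batch.
    """
    total_loss = 0.0
    total_weight = 0.0
    for sample_logits, target in zip(logits, targets):
        # Numerically stable log-sum-exp over the class logits.
        max_logit = max(sample_logits)
        log_sum_exp = max_logit + math.log(
            sum(math.exp(z - max_logit) for z in sample_logits)
        )
        nll = log_sum_exp - sample_logits[target]  # negative log-likelihood
        total_loss += class_weights[target] * nll
        total_weight += class_weights[target]
    return total_loss / total_weight


# Toy batch with two classes: 0 = a cat breed, 1 = a dog breed.
logits = [[0.0, 0.0], [2.0, 0.0]]
targets = [0, 1]

unweighted = weighted_cross_entropy(logits, targets, class_weights=(1.0, 1.0))
weighted = weighted_cross_entropy(logits, targets, class_weights=(5.0, 1.0))
print(f"unweighted: {unweighted:.4f}, weighted: {weighted:.4f}")
```

With weights `(5.0, 1.0)`, the cat sample contributes five times as much to the batch loss as the dog sample, so gradients from the under-represented class are amplified accordingly.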

In [5]:
cat_vs_dog_ratio = 5
weights = tuple(
    cat_vs_dog_ratio if fam_idx == 0 else 1.0
    for fam_idx in breed_family_idx
)
baseline_params.loss_weights = weights
dataset_params.oversampling_weights = None
evaluate_final_test_accuracy(dataset_params, baseline_params, determinism, trials=3)

Test size: 3669
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]



Epoch [1/3], Loss: 2.3536, Train Acc: 63.27%
Epoch [2/3], Loss: 0.9247, Train Acc: 91.31%
Epoch [3/3], Loss: 0.5488, Train Acc: 94.10%
Total elapsed: 388.45s, average per update step: 1.48s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 85.582 %
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch [1/3], Loss: 2.3485, Train Acc: 63.34%
Epoch [2/3], Loss: 0.9287, Train Acc: 90.95%
Epoch [3/3], Loss: 0.5429, Train Acc: 94.21%
Total elapsed: 389.48s, average per update step: 1.50s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 85.010 %
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch [1/3], Loss: 2.3517, Train Acc: 63.67%
Epoch [2/3], Loss: 0.9264, Train Acc: 91.68%
Epoch [3/3], Loss: 0.5409, Train Acc: 93.95%
Total elapsed: 392.06s, average per update step: 1.50s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 84.737 %
Test Accuracy Mean: 85.11 %
Test Accuracy Standard Error: 0.25 percentage points


### 3. Over-sampling of minority classes
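Next we try over-sampling instead of loss weighting: the loss weights are switched off, and the same per-class weights are passed to the dataset as over-sampling weights.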
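
One common way to implement over-sampling is to draw training indices with replacement, with a probability proportional to the weight of each sample's class. The sketch below illustrates that idea with the standard library only. It is an assumption about how `oversampling_weights` might be used, not the actual `datasets` implementation, and `oversample_indices` is a hypothetical helper. Keeping the number of draws equal to the dataset size is at least consistent with the unchanged 258 update steps in the output.

```python
import random
from collections import Counter


def oversample_indices(labels, class_weights, seed=42):
    """Draw len(labels) indices with replacement, weighted by class."""
    rng = random.Random(seed)
    # Every sample inherits the weight of its class.
    sample_weights = [class_weights[label] for label in labels]
    return rng.choices(range(len(labels)), weights=sample_weights, k=len(labels))


# Toy data: 200 "cat" samples (class 0) versus 1000 "dog" samples (class 1).
labels = [0] * 200 + [1] * 1000
drawn = oversample_indices(labels, class_weights=(5.0, 1.0))
counts = Counter(labels[i] for i in drawn)
print(counts)
```

With a weight of 5 on the class that is five times smaller, both classes are drawn about equally often in expectation, at the cost of showing each minority image several times per epoch.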

In [6]:
baseline_params.loss_weights = None
dataset_params.oversampling_weights = weights
evaluate_final_test_accuracy(dataset_params, baseline_params, determinism, trials=3)

Test size: 3669
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]



Epoch [1/3], Loss: 2.1325, Train Acc: 64.70%
Epoch [2/3], Loss: 0.7638, Train Acc: 93.11%
Epoch [3/3], Loss: 0.4216, Train Acc: 94.90%
Total elapsed: 391.27s, average per update step: 1.50s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 87.108 %
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch [1/3], Loss: 2.1389, Train Acc: 63.49%
Epoch [2/3], Loss: 0.7745, Train Acc: 92.96%
Epoch [3/3], Loss: 0.4432, Train Acc: 95.01%
Total elapsed: 390.45s, average per update step: 1.49s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 87.653 %
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch [1/3], Loss: 2.1272, Train Acc: 66.09%
Epoch [2/3], Loss: 0.7501, Train Acc: 94.57%
Epoch [3/3], Loss: 0.4159, Train Acc: 95.53%
Total elapsed: 391.36s, average per update step: 1.49s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 87.163 %
Test Accuracy Mean: 87.31 %
Test Accuracy Standard Error: 0.17 percentage points


### 4. Weighted loss and over-sampling combined
Finally we apply both techniques at the same time, using the same weights for the loss and for over-sampling.

In [7]:
baseline_params.loss_weights = weights
dataset_params.oversampling_weights = weights
evaluate_final_test_accuracy(dataset_params, baseline_params, determinism, trials=3)

Test size: 3669
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]



Epoch [1/3], Loss: 2.0150, Train Acc: 50.40%
Epoch [2/3], Loss: 0.7323, Train Acc: 91.02%
Epoch [3/3], Loss: 0.3800, Train Acc: 94.39%
Total elapsed: 391.93s, average per update step: 1.50s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 87.299 %
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch [1/3], Loss: 2.0266, Train Acc: 50.00%
Epoch [2/3], Loss: 0.7332, Train Acc: 91.09%
Epoch [3/3], Loss: 0.3969, Train Acc: 94.28%
Total elapsed: 390.64s, average per update step: 1.50s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 87.953 %
Trying to load trainer from disk...
Trainer not found. Retraining...


Update step:   0%|          | 0/258 [00:00<?, ?it/s]

Epoch [1/3], Loss: 2.0277, Train Acc: 51.43%
Epoch [2/3], Loss: 0.7260, Train Acc: 92.56%
Epoch [3/3], Loss: 0.3820, Train Acc: 94.57%
Total elapsed: 391.94s, average per update step: 1.50s


Evaluating:   0%|          | 0/115 [00:00<?, ?it/s]

Test Accuracy: 87.463 %
Test Accuracy Mean: 87.57 %
Test Accuracy Standard Error: 0.20 percentage points
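
To put the four results side by side, the snippet below tabulates the reported means and the gain of each configuration over the standard cross-entropy run. The standard error of each difference is computed as if the two means were independent. That is an assumption: the trials of different configurations may share seeds, and three trials give only a rough estimate.

```python
import math

# (mean test accuracy in %, standard error in percentage points),
# copied from the outputs of experiments 1 to 4.
results = {
    "standard": (83.88, 0.22),
    "weighted loss": (85.11, 0.25),
    "over-sampling": (87.31, 0.17),
    "both": (87.57, 0.20),
}

base_mean, base_se = results["standard"]
gains = {}
for name, (mean, se) in results.items():
    if name == "standard":
        continue
    gain = mean - base_mean
    # Standard error of a difference of two means, assuming independence.
    se_diff = math.sqrt(se**2 + base_se**2)
    gains[name] = (gain, se_diff)
    print(f"{name:>14}: {mean:.2f} % (gain {gain:+.2f} +/- {se_diff:.2f})")
```

Under the same assumptions, the gap between over-sampling alone and the combined setup (0.26 percentage points) is about one standard error of the difference, so these three trials do not clearly separate the two.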


## Time elapsed

In [8]:
elapsed = time.time() - start_time
print(f"Total elapsed time: {elapsed:.2f} seconds")

Total elapsed time: 6863.60 seconds
