In [1]:
import os, re, sys, torch
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))
from classification_early_stopping import ModelTrainer
from constants import *

In [2]:
def extract_hyperparams_from_folder(folder_name):
    """Parse training hyperparameters out of a run-folder name.

    The name must contain a substring of the form
    ``bs=<digits>_lr=<number>_wd=<number>_dr=<number>`` (for example
    ``bs=8_lr=1e-04_wd=1e-04_dr=0.3``). The substring may appear anywhere in
    ``folder_name`` because the pattern is applied with ``re.search``.

    Args:
        folder_name: Folder name (string) to parse.

    Returns:
        A dict with keys ``"batch_size"`` (int), ``"lr"``, ``"weight_decay"``
        and ``"dropout"`` (floats), or ``None`` when the name does not match
        the pattern or a matched field is not a valid number.
    """
    pattern = r"bs=(\d+)_lr=([\de.-]+)_wd=([\de.-]+)_dr=([\d.]+)"
    match = re.search(pattern, folder_name)
    if match is None:
        return None

    batch_size, lr, weight_decay, dropout = match.groups()
    try:
        return {
            "batch_size": int(batch_size),
            "lr": float(lr),
            "weight_decay": float(weight_decay),
            "dropout": float(dropout),
        }
    except ValueError:
        # The character classes above are permissive: they also accept text
        # such as "--" or "..." that is not a number. Treat those names like
        # any other non-matching folder instead of raising mid-scan.
        return None

In [3]:
# Keep only the entries directly under MODEL_DIR that are directories.
model_folders = list(
    filter(
        lambda name: os.path.isdir(os.path.join(MODEL_DIR, name)),
        os.listdir(MODEL_DIR),
    )
)

In [4]:
# Bare expression: the notebook renders the list of discovered folder names as this cell's output.
model_folders

['bs=8_lr=1e-04_wd=1e-04_dr=0.3',
 'bs=8_lr=7e-04_wd=9e-04_dr=0.2',
 'bs=16_lr=1e-04_wd=3e-04_dr=0.4',
 'bs=32_lr=7e-04_wd=1e-04_dr=0.3',
 'bs=8_lr=9e-04_wd=5e-04_dr=0.4',
 'bs=16_lr=1e-04_wd=1e-04_dr=0.3',
 'bs=8_lr=5e-04_wd=5e-04_dr=0.3',
 'bs=32_lr=3e-04_wd=9e-04_dr=0.3',
 'bs=8_lr=1e-04_wd=7e-04_dr=0.3',
 'bs=8_lr=5e-04_wd=1e-04_dr=0.2',
 'bs=8_lr=3e-04_wd=3e-04_dr=0.2',
 'bs=8_lr=5e-04_wd=5e-04_dr=0.2',
 'bs=8_lr=3e-04_wd=7e-04_dr=0.2',
 'bs=32_lr=3e-04_wd=5e-04_dr=0.3',
 'bs=8_lr=1e-04_wd=9e-04_dr=0.2',
 'bs=16_lr=1e-04_wd=1e-04_dr=0.4',
 'bs=32_lr=5e-04_wd=9e-04_dr=0.3',
 'bs=8_lr=1e-04_wd=7e-04_dr=0.2',
 'bs=32_lr=9e-04_wd=5e-04_dr=0.3',
 'bs=32_lr=7e-04_wd=3e-04_dr=0.3',
 'bs=8_lr=3e-04_wd=3e-04_dr=0.3',
 'bs=32_lr=9e-04_wd=1e-04_dr=0.3',
 'bs=8_lr=7e-04_wd=3e-04_dr=0.2',
 'bs=8_lr=7e-04_wd=9e-04_dr=0.3',
 'bs=8_lr=1e-04_wd=5e-04_dr=0.4',
 'bs=8_lr=9e-04_wd=7e-04_dr=0.4',
 'bs=16_lr=7e-04_wd=3e-04_dr=0.3',
 'bs=8_lr=7e-04_wd=7e-04_dr=0.2',
 'bs=32_lr=5e-04_wd=3e-04_dr=0.3',
 '

In [5]:
# Maps each analysed run folder to either the epoch at which its validation
# loss stopped improving (int) or the string 'No overfitting detected'.
overfitting_epochs = {}

for folder in model_folders:
    hyperparams = extract_hyperparams_from_folder(folder)
    if not hyperparams:
        # Folder name could not be parsed into hyperparameters; skip it.
        continue

    model_dir = os.path.join(MODEL_DIR, folder)
    chk_path = os.path.join(model_dir, 'last.pth')

    # Skip runs without a final checkpoint.
    # NOTE(review): only runs with dropout == 0.3 are analysed; the value is
    # hard-coded here -- confirm this filter is intentional.
    if not os.path.exists(chk_path) or hyperparams['dropout'] != 0.3:
        continue

    # Build the ModelTrainer for this run; it is used here only to read the
    # recorded validation-loss history from `training_stats`.
    trainer = ModelTrainer(
        D1_DATA_DIR, DEVICE, model_dir, chk_path, hyperparams['batch_size'], NUM_EPOCHS,
        hyperparams['lr'], MOMENTUM, hyperparams['weight_decay'], MODEL, PREDICTION_ONLY,
        CLASS_WEIGHTS, hyperparams['dropout'], None, PATIENCE,
    )

    val_loss = trainer.training_stats['val_loss']

    best_val_loss = float('inf')
    early_stopping_counter = 0
    overfitting_epoch = None

    # Replay the loss history with an early-stopping rule: count consecutive
    # epochs whose loss exceeds the best loss seen so far by more than DELTA.
    for epoch in range(len(val_loss)):
        epoch_loss = val_loss[epoch]

        if epoch_loss > best_val_loss + DELTA:
            early_stopping_counter += 1
        else:
            # Any epoch within DELTA of the best loss resets the streak, even
            # if it does not set a new best.
            early_stopping_counter = 0
            if epoch_loss < best_val_loss:
                best_val_loss = epoch_loss

        if early_stopping_counter >= PATIENCE:
            # `epoch` is the 0-based index of the last epoch in the streak, so
            # this is the 0-based index of the first epoch of the streak
            # (equivalently, the 1-based number of the epoch just before it).
            overfitting_epoch = epoch + 1 - PATIENCE
            break

    # Compare against the None sentinel explicitly rather than relying on the
    # truthiness of an int.
    if overfitting_epoch is not None:
        overfitting_epochs[folder] = overfitting_epoch
    else:
        overfitting_epochs[folder] = 'No overfitting detected'

Loading checkpoint from /home/priyansh/Downloads/code/weights/d1_cell_balanced_v2/params/bs=8_lr=1e-04_wd=1e-04_dr=0.3/last.pth
Loading checkpoint from /home/priyansh/Downloads/code/weights/d1_cell_balanced_v2/params/bs=32_lr=7e-04_wd=1e-04_dr=0.3/last.pth
Loading checkpoint from /home/priyansh/Downloads/code/weights/d1_cell_balanced_v2/params/bs=16_lr=1e-04_wd=1e-04_dr=0.3/last.pth
Loading checkpoint from /home/priyansh/Downloads/code/weights/d1_cell_balanced_v2/params/bs=8_lr=5e-04_wd=5e-04_dr=0.3/last.pth
Loading checkpoint from /home/priyansh/Downloads/code/weights/d1_cell_balanced_v2/params/bs=32_lr=3e-04_wd=9e-04_dr=0.3/last.pth
Loading checkpoint from /home/priyansh/Downloads/code/weights/d1_cell_balanced_v2/params/bs=8_lr=1e-04_wd=7e-04_dr=0.3/last.pth
Loading checkpoint from /home/priyansh/Downloads/code/weights/d1_cell_balanced_v2/params/bs=32_lr=3e-04_wd=5e-04_dr=0.3/last.pth
Loading checkpoint from /home/priyansh/Downloads/code/weights/d1_cell_balanced_v2/params/bs=32_lr=5e

In [6]:
# Order the results for reporting. Each sort key is a (detected, epoch) tuple and
# the sort is descending, so runs with a detected overfitting epoch come first
# (latest epoch first) and 'No overfitting detected' entries come last.
sorted_overfitting_epochs = list(overfitting_epochs.items())
sorted_overfitting_epochs.sort(
    key=lambda item: (
        item[1] != 'No overfitting detected',
        item[1] if isinstance(item[1], int) else float('inf'),
    ),
    reverse=True,
)

In [7]:
# Emit one summary line per model, in the order held by sorted_overfitting_epochs.
for folder, epoch in sorted_overfitting_epochs:
    summary_line = f"Model: {folder}, Overfitting Epoch: {epoch}"
    print(summary_line)

Model: bs=8_lr=9e-04_wd=9e-04_dr=0.3, Overfitting Epoch: 11
Model: bs=8_lr=3e-04_wd=5e-04_dr=0.3, Overfitting Epoch: 10
Model: bs=16_lr=3e-04_wd=5e-04_dr=0.3, Overfitting Epoch: 10
Model: bs=16_lr=9e-04_wd=5e-04_dr=0.3, Overfitting Epoch: 10
Model: bs=32_lr=7e-04_wd=1e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=8_lr=5e-04_wd=5e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=8_lr=1e-04_wd=7e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=16_lr=7e-04_wd=3e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=8_lr=5e-04_wd=3e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=8_lr=3e-04_wd=7e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=32_lr=1e-04_wd=1e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=32_lr=1e-04_wd=5e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=16_lr=3e-04_wd=7e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=32_lr=9e-04_wd=3e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=16_lr=5e-04_wd=7e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=32_lr=1e-04_wd=7e-04_dr=0.3, Overfitting Epoch: 9
Model: bs=8_lr=1e-04_wd=9e-04_dr=0.3, Over