In [5]:
import os
import json
from transformers import AutoModel

def find_best_checkpoint(base_dir):
    best_metric = float('inf')  # Adjust if the best metric is max e.g., accuracy
    best_checkpoint = None

    # Loop through all checkpoint directories
    for dirname in os.listdir(base_dir):
        if dirname.startswith("checkpoint-"):
            checkpoint_dir = os.path.join(base_dir, dirname)
            state_file = os.path.join(checkpoint_dir, 'trainer_state.json')

            if os.path.exists(state_file):
                with open(state_file, 'r') as file:
                    state_data = json.load(file)
                    # Assuming lower metric is better; adjust if necessary
                    if state_data['best_metric'] > best_metric:
                        best_metric = state_data['best_metric']
                        best_checkpoint = checkpoint_dir

    return best_checkpoint, best_metric

def load_model_from_checkpoint(checkpoint_path):
    if checkpoint_path:
        # Load the model
        model = AutoModel.from_pretrained(checkpoint_path)
        return model
    else:
        print("No checkpoint found.")
        return None

# Specify the base directory containing all checkpoints
base_dir = '/home/yusuf/python/SeaTurtle/model_results/ViT3_31test'
best_checkpoint, best_metric = find_best_checkpoint(base_dir)
print(f"Best checkpoint directory: {best_checkpoint}")
print(f"Best model metric (eval loss): {best_metric}")

best_model = load_model_from_checkpoint(best_checkpoint)


Best checkpoint directory: None
Best model metric (eval loss): inf
No checkpoint found.
