In [1]:
import torch

from src.utils import initialize_model, save_model
from src.train import train_model
from src.dataset import load_data, prepare_data_loaders, prepare_test_loader
from src.visualize import plot_loss, plot_accuracy, plot_f1_precision_recall, plot_auc
import src.config as config

ModuleNotFoundError: No module named 'src'

In [None]:
def main(best_model_path='model_weights/final_best_model.pth',
         final_model_path='model_weights/complete_trained_model.pth'):
    """Train the classifier end to end, save its weights, and plot the metrics.

    Pipeline: load the train/validation and test splits, build data loaders,
    initialise model/criterion/optimizer/scheduler, train for
    ``config.NUM_EPOCHS`` epochs, save two sets of weights, and write metric
    plots to ``config.GRAPHS_SAVE_PATH``.

    Args:
        best_model_path: File to save the "best" model returned by
            ``train_model`` (its selection criterion is defined inside
            ``train_model``). Defaults to the previously hard-coded path.
        final_model_path: File to save the first model returned by
            ``train_model`` (presumably the state after the last epoch —
            confirm against ``train_model``). Defaults to the previously
            hard-coded path.

    Returns:
        None. Side effects: writes the two weight files above, whatever
        ``train_model`` saves under ``config.MODEL_SAVE_PATH``, and the plots
        under ``config.GRAPHS_SAVE_PATH``.
    """
    import os

    # Create destination directories up front so a missing folder fails fast
    # instead of after a (potentially long) training run.
    for weights_path in (best_model_path, final_model_path):
        parent_dir = os.path.dirname(weights_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train_val_data, test_data, all_labels = load_data()

    train_loader, valid_loader = prepare_data_loaders(
        train_val_data, all_labels)
    # NOTE(review): test_loader is built but never used below — the pipeline
    # does not evaluate on the test split yet. TODO: add a test-set evaluation.
    test_loader = prepare_test_loader(test_data, all_labels)

    model, criterion, optimizer, scheduler = initialize_model(
        device, num_classes=len(all_labels))
    trained_model, best_trained_model, metrics_history = train_model(
        model, criterion, optimizer, scheduler, train_loader, valid_loader,
        device, num_epochs=config.NUM_EPOCHS, model_save=True,
        save_path=config.MODEL_SAVE_PATH)

    # Save both models: the one train_model selected as best, and the model as
    # returned at the end of training.
    save_model(best_trained_model, best_model_path)
    save_model(trained_model, final_model_path)

    # Plot the metrics collected during training.
    graphs_path = config.GRAPHS_SAVE_PATH

    plot_loss(metrics_history, graphs_path)
    plot_accuracy(metrics_history, graphs_path)
    plot_f1_precision_recall(metrics_history, graphs_path)
    plot_auc(metrics_history, graphs_path)


if __name__ == "__main__":
    # Entry point. Inside a Jupyter kernel __name__ is "__main__", so executing
    # this cell launches the full training pipeline immediately.
    main()
