In [None]:
from utils import evaluate, train


def train_model(model, criterion, optimizer, num_epochs, *, patience=5,
                restore_best_weights=False):
    """Train ``model`` for up to ``num_epochs`` epochs with early stopping on eval F1.

    Each epoch calls ``train`` on ``train_loader`` and ``evaluate`` on
    ``test_loader``. Both loaders and ``device`` are read from the enclosing
    notebook scope. The epoch's metrics are printed and logged to Weights &
    Biases in a single ``wandb.log`` call. The learning rate follows a
    one-cycle schedule that is advanced once per epoch.

    Args:
        model: Network to train; it is updated in place.
        criterion: Loss function forwarded to ``train``.
        optimizer: Optimizer forwarded to ``train``. Its learning rate is
            driven by the one-cycle scheduler (peak ``max_lr=1e-1``).
        num_epochs: Maximum number of epochs; must be a positive integer.
        patience: Number of consecutive epochs without an eval-F1
            improvement after which training stops early.
        restore_best_weights: If True, reload the parameters from the epoch
            with the best eval F1 before returning.

    Returns:
        The same ``model`` object that was passed in.
    """
    import copy

    # train() runs a whole epoch and is not given the scheduler, so the
    # schedule can only be advanced once per epoch. It is therefore sized
    # in epochs. Sizing it per batch (steps_per_epoch=len(train_loader))
    # would leave the LR stuck near the start of its warm-up phase.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-1,
        total_steps=num_epochs,
        pct_start=0.3,
        div_factor=25,
        final_div_factor=1000,
        anneal_strategy='cos',
    )

    # Start at -inf so the first epoch always counts as an improvement.
    best_eval_f1 = float('-inf')
    best_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        train_loss, train_acc, train_prec, train_rec, train_f1 = train(
            model, criterion, optimizer, device, train_loader
        )
        eval_acc, eval_prec, eval_rec, eval_f1 = evaluate(model, device, test_loader)

        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, Prec: {train_prec:.4f}, Rec: {train_rec:.4f},F1: {train_f1:.4f}")
        print(f"Eval  - Acc: {eval_acc:.4f}, Prec: {eval_prec:.4f}, Rec: {eval_rec:.4f}, F1: {eval_f1:.4f}")

        # Log once per epoch. Every wandb.log call advances wandb's step
        # counter, so this keeps train and eval metrics on the same step.
        wandb.log({
            "Training/Loss": train_loss,
            "Training/Accuracy": train_acc,
            "Training/Precision": train_prec,
            "Training/Recall": train_rec,
            "Training/F1_Score": train_f1,
            # LR that was in effect during this epoch (the scheduler is
            # stepped only at the end of the epoch).
            "Learning_Rate": optimizer.param_groups[0]['lr'],
            "Evaluation/Accuracy": eval_acc,
            "Evaluation/Precision": eval_prec,
            "Evaluation/Recall": eval_rec,
            "Evaluation/F1_Score": eval_f1,
        })

        if eval_f1 > best_eval_f1:
            best_eval_f1 = eval_f1
            patience_counter = 0
            if restore_best_weights:
                # Deep copy so later optimizer updates do not alter the snapshot.
                best_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping на епосі {epoch+1}")
                break

        # Advance the one-cycle LR schedule by one epoch-sized step.
        scheduler.step()

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

# Run the training loop; train_model returns the model it trained in place.
model = train_model(
    model,
    criterion,
    optimizer,
    num_epochs=EPOCHS_NUM,
)