### Imports

In [1]:
import os
import pickle
from dataclasses import dataclass, asdict

import numpy as np
import torch
import torch.nn as nn

import scripts.data_loader as data_loader
import src.training as training
import src.evaluation as evaluation

# NOTE(review): star import hides which names come from src.types; prefer explicit imports.
from src.types import *
from src.models.model1 import OB_05Model

### Training

In [2]:
# Seed the global RNGs before the split and before the model is built in the
# next cell, so Restart & Run All reproduces the same split and initial weights.
# NOTE(review): assumes split_images_dataset draws from torch/numpy RNGs — confirm.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Split the image dataset into training / validation / testing partitions.
training_dataset, validation_dataset, testing_dataset = data_loader.split_images_dataset()

# torch.save does not create parent directories; ensure "models/" exists so a
# fresh checkout does not fail with FileNotFoundError.
os.makedirs("models", exist_ok=True)
torch.save(training_dataset, "models/training_dataset.pth")
torch.save(validation_dataset, "models/validation_dataset.pth")
torch.save(testing_dataset, "models/testing_dataset.pth")

# Wrap each partition in a loader built by the project's data_loader helper.
training_set_loader = data_loader.create_data_loader(training_dataset)
validation_set_loader = data_loader.create_data_loader(validation_dataset)
testing_set_loader = data_loader.create_data_loader(testing_dataset)

In [3]:
# Learning rate shared by the optimizer and the training config below.
learning_rate = 0.001

# Model, loss and optimizer for this run (Adam over every model parameter).
model = OB_05Model()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Class names taken from the dataset returned by data_loader.get_trainset().
class_names = data_loader.get_trainset().classes

# Bundle model components, schedule and data loaders into the single config
# object consumed by training.train_model in the next cell.
training_config = training.TrainingConfig(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    classes=class_names,
    epochs=10,
    learning_rate=learning_rate,
    training_set_loader=training_set_loader,
    validation_set_loader=validation_set_loader,
    testing_set_loader=testing_set_loader,
)

In [4]:
# Train the model described by training_config. The returned logger object is
# pickled below so results can be reloaded without retraining.
# NOTE(review): in the printed output, precision equals accuracy on every
# epoch while recall stays around 0.2 — this suggests the precision/recall
# computations in src.training are averaged inconsistently (precision looks
# micro-averaged); verify before relying on these metrics.
training_logger = training.train_model(training_config)

logger_path = "models/training_logger.pkl"
with open(logger_path, "wb") as logger_file:
    pickle.dump(training_logger, logger_file)

# Persist only the learned parameters; restore with model.load_state_dict(...).
torch.save(model.state_dict(), 'models/model.pth')

Epoch 1/10:
	Training precision: 0.4198
	Training recall: 0.1478
	Training accuracy: 0.4198
	Training f1-score: 0.2186

	Validation precision: 0.5879
	Validation recall: 0.1851
	Validation accuracy: 0.5879
	Validation f1-score: 0.2816


Epoch 2/10:
	Training precision: 0.6148
	Training recall: 0.1904
	Training accuracy: 0.6148
	Training f1-score: 0.2907

	Validation precision: 0.6254
	Validation recall: 0.1924
	Validation accuracy: 0.6254
	Validation f1-score: 0.2942


Epoch 3/10:
	Training precision: 0.6778
	Training recall: 0.2020
	Training accuracy: 0.6778
	Training f1-score: 0.3112

	Validation precision: 0.7061
	Validation recall: 0.2069
	Validation accuracy: 0.7061
	Validation f1-score: 0.3201


Epoch 4/10:
	Training precision: 0.6963
	Training recall: 0.2052
	Training accuracy: 0.6963
	Training f1-score: 0.3170

	Validation precision: 0.6888
	Validation recall: 0.2039
	Validation accuracy: 0.6888
	Validation f1-score: 0.3147


Epoch 5/10:
	Training precision: 0.7358
	Training re