### Imports

In [None]:
import os
import pickle
from dataclasses import dataclass, asdict

import numpy as np
import torch
import torch.nn as nn

import scripts.data_loader as data_loader
import src.training as training
import src.evaluation as evaluation
from src.types import *
from src.models.main_model import OB_05Model
from src.models.main_model_v1 import OB_05Model_Variant1
from src.models.main_model_v2 import OB_05Model_Variant2

# Training

In [None]:
# torch.save does not create missing parent directories, so make sure the
# output folder exists before writing the dataset partitions into it.
os.makedirs("models", exist_ok=True)

# Split the image dataset into its three partitions and persist each one
# under models/.
training_dataset, validation_dataset, testing_dataset = data_loader.split_images_dataset()
torch.save(training_dataset, "models/training_dataset.pth")
torch.save(validation_dataset, "models/validation_dataset.pth")
torch.save(testing_dataset, "models/testing_dataset.pth")

# One loader per partition; these are handed to the training configuration.
training_set_loader = data_loader.create_data_loader(training_dataset)
validation_set_loader = data_loader.create_data_loader(validation_dataset)
testing_set_loader = data_loader.create_data_loader(testing_dataset)

In [None]:
# Base model together with its loss, optimizer and learning-rate scheduler.
model = OB_05Model()
initial_learning_rate = 0.0001
criterion = nn.CrossEntropyLoss()

# `optim` is never imported in this notebook, so the optimizer and scheduler
# are reached through the already-imported `torch` package instead.
optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate, weight_decay=5e-2)
# 'min' mode: the learning rate is multiplied by 0.1 when the monitored
# quantity stops decreasing (patience of 5 scheduler steps).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5)

# Bundle everything train_model needs into a single configuration object.
training_config = training.TrainingConfig(
    output_dir=r"models/",

    training_set_loader=training_set_loader,
    validation_set_loader=validation_set_loader,
    testing_set_loader=testing_set_loader,

    epochs=100,

    classes=data_loader.get_trainset().classes,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler
)

In [None]:
# Run training with the configuration assembled above and keep the object
# returned by train_model.
training_logger = training.train_model(training_config)

# Pickle the returned logger next to the other artefacts in models/.
with open("models/training_logger.pkl", "wb") as logger_file:
    pickle.dump(training_logger, logger_file)

# Store only the parameters (state dict), not the whole module object.
torch.save(model.state_dict(), 'models/model.pth')

# Training: variant #1

In [None]:
# torch.save does not create missing parent directories, so make sure the
# output folder exists before writing the dataset partitions into it.
os.makedirs("model_v1", exist_ok=True)

# Fresh split for variant 1, persisted under model_v1/.
training_dataset, validation_dataset, testing_dataset = data_loader.split_images_dataset()
torch.save(training_dataset, "model_v1/training_dataset.pth")
torch.save(validation_dataset, "model_v1/validation_dataset.pth")
torch.save(testing_dataset, "model_v1/testing_dataset.pth")

training_set_loader = data_loader.create_data_loader(training_dataset)
validation_set_loader = data_loader.create_data_loader(validation_dataset)
testing_set_loader = data_loader.create_data_loader(testing_dataset)

model = OB_05Model_Variant1()

initial_learning_rate = 0.0001
criterion = nn.CrossEntropyLoss()
# `optim` is never imported in this notebook, so the optimizer and scheduler
# are reached through the already-imported `torch` package instead.
optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate, weight_decay=5e-2)
# 'min' mode: the learning rate is multiplied by 0.1 when the monitored
# quantity stops decreasing (patience of 5 scheduler steps).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5)

# Bundle everything train_model needs into a single configuration object.
training_config = training.TrainingConfig(
    output_dir=r"model_v1/",

    training_set_loader=training_set_loader,
    validation_set_loader=validation_set_loader,
    testing_set_loader=testing_set_loader,

    epochs=100,

    classes=data_loader.get_trainset().classes,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler
)

# Initialize weights
def weights_init(m):
    """Re-initialize the weight of a Conv2d or Linear module with Kaiming-normal.

    Written to be passed to ``Module.apply``. Modules of any other type are
    left untouched, and only ``m.weight`` is modified (never the bias).
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight)


# Module.apply visits every submodule, so weights_init sees each layer once.
model.apply(weights_init)

# Run training with the configuration assembled above and keep the object
# returned by train_model.
training_logger = training.train_model(training_config)

# Pickle the returned logger next to the other artefacts in model_v1/.
with open("model_v1/training_logger.pkl", "wb") as logger_file:
    pickle.dump(training_logger, logger_file)

# Store only the parameters (state dict), not the whole module object.
torch.save(model.state_dict(), 'model_v1/model.pth')

In [None]:
# torch.save does not create missing parent directories, so make sure the
# output folder exists before writing the dataset partitions into it.
os.makedirs("model_v2", exist_ok=True)

# Fresh split for variant 2, persisted under model_v2/.
training_dataset, validation_dataset, testing_dataset = data_loader.split_images_dataset()
torch.save(training_dataset, "model_v2/training_dataset.pth")
torch.save(validation_dataset, "model_v2/validation_dataset.pth")
torch.save(testing_dataset, "model_v2/testing_dataset.pth")

training_set_loader = data_loader.create_data_loader(training_dataset)
validation_set_loader = data_loader.create_data_loader(validation_dataset)
testing_set_loader = data_loader.create_data_loader(testing_dataset)

model = OB_05Model_Variant2()

initial_learning_rate = 0.0001
criterion = nn.CrossEntropyLoss()
# `optim` is never imported in this notebook, so the optimizer and scheduler
# are reached through the already-imported `torch` package instead.
optimizer = torch.optim.Adam(model.parameters(), lr=initial_learning_rate, weight_decay=5e-2)
# 'min' mode: the learning rate is multiplied by 0.1 when the monitored
# quantity stops decreasing (patience of 5 scheduler steps).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=5)

# Bundle everything train_model needs into a single configuration object.
training_config = training.TrainingConfig(
    # NOTE(review): this keyword was `models_output_dir`; it now matches the
    # `output_dir` keyword used by the other two TrainingConfig calls in this
    # notebook. Confirm against the TrainingConfig definition.
    output_dir=r"model_v2/",

    training_set_loader=training_set_loader,
    validation_set_loader=validation_set_loader,
    testing_set_loader=testing_set_loader,

    epochs=100,

    classes=data_loader.get_trainset().classes,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler
)


# Initialize weights
def weights_init(m):
    """Re-initialize the weight of a Conv2d or Linear module with Kaiming-normal.

    Written to be passed to ``Module.apply``. Modules of any other type are
    left untouched, and only ``m.weight`` is modified (never the bias).
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.kaiming_normal_(m.weight)


# Module.apply visits every submodule, so weights_init sees each layer once.
model.apply(weights_init)

# Run training with the configuration assembled above and keep the object
# returned by train_model.
training_logger = training.train_model(training_config)

# Pickle the returned logger next to the other artefacts in model_v2/.
with open("model_v2/training_logger.pkl", "wb") as logger_file:
    pickle.dump(training_logger, logger_file)

# Store only the parameters (state dict), not the whole module object.
torch.save(model.state_dict(), 'model_v2/model.pth')
