In [None]:
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from model import StocksPredictionModel
from stocks_dataset import StocksDataSet
from type_enums.ModelType import ModelType as mt
from type_enums.SplitType import SplitType as st

In [None]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Saving the Model

In [None]:
def save_with_kwargs(model, epochs, lr, split_type, standardized = True, serial_num = "01",
                     directory = "./Models Examples"):
    """Save the model as a .pth file.

    The file contains the list ``[model.kwargs, model.state_dict()]``. It is named
    ``{serial_num}_{epochs}epochs_{model.model_type.value}_{lr}lr_{split_type}_std{standardized}.pth``.
    The target directory is created first if it does not exist yet, because
    ``torch.save`` does not create missing folders.

    Args:
        model (StocksPredictionModel class): model which we want to save to the file.
            It must expose ``kwargs``, ``model_type.value`` and ``state_dict()``.
        epochs (integer): how many epochs were used during the training
        lr (float): what learning rate was used during the training
        split_type (SplitType enum or str): what split type was used during the training
        standardized (bool, optional): whether the data was standardized with
            scikit-learn tools during the training. Defaults to True.
        serial_num (str, optional): optional string that allows distinguishing
            training approaches in the file name. Defaults to "01".
        directory (str or path-like, optional): folder the file is written to.
            Defaults to "./Models Examples".

    Returns:
        str: path of the written .pth file.
    """
    import os

    os.makedirs(directory, exist_ok=True)
    file_name = (f"{serial_num}_{epochs}epochs_{model.model_type.value}"
                 f"_{lr}lr_{split_type}_std{standardized}.pth")
    path = os.path.join(directory, file_name)
    torch.save([model.kwargs, model.state_dict()], path)
    return path

# Training the model

In [None]:
def train_model(ds, model, learning_rate, epochs, log_every=100):
    """Train the model on the data set, save it to disk and return it.

    For every epoch, iterates over ``ds``, unpacking each item as
    ``(X_train, X_test, y_train, y_test)``. Items whose ``X_train`` is ``None``
    are skipped. For every remaining item:

    - one Adam optimisation step is taken on the training part;
    - then the MSE loss is evaluated on the test part.

    A ``ValueError`` raised while processing an item is reported and that item
    is skipped. After training, the model is saved via ``save_with_kwargs``.

    Args:
        ds (StocksDataSet): data set on which our model will be trained. It must
            also expose ``prep_type`` and ``standardized``, which are used to
            build the saved file name.
        model (StocksPredictionModel): model to train
        learning_rate (float): learning rate for the Adam optimizer
        epochs (integer): number of epochs
        log_every (integer, optional): positive interval at which losses are
            printed. Losses are also printed on the first and the last epoch.
            Defaults to 100.

    Returns:
        StocksPredictionModel: the trained model, moved to ``device``.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be a positive integer")

    loss_fn = nn.MSELoss()
    # Move the model once, before the optimizer is built, as recommended by the
    # torch.optim docs (moving it inside the loop was loop-invariant work).
    model = model.to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        should_log = epoch == epochs - 1 or epoch % log_every == 0

        for X_train, X_test, y_train, y_test in ds:
            if X_train is None:
                continue
            try:
                # Tensor.to is NOT in-place, so its result must be reassigned.
                # NOTE(review): squeeze(0) presumably drops a leading batch
                # dimension of size 1 — confirm against StocksDataSet.
                X_train = torch.Tensor(X_train).squeeze(0).to(device)
                y_train = y_train.to(device)

                model.train()
                optimizer.zero_grad()
                y_pred = model(X_train)
                loss = loss_fn(y_pred, y_train)
                loss.backward()
                optimizer.step()

                model.eval()
                with torch.inference_mode():
                    # The test inputs must live on the same device as the model.
                    X_test = torch.as_tensor(X_test).to(device)
                    y_test = y_test.to(device)
                    test_loss = loss_fn(model(X_test), y_test)

                if should_log:
                    print(f"Epoch number: {epoch}")
                    print(f"Test Loss is: {test_loss.item()}")
                    print(f"Train Loss is: {loss.item()}")
            except ValueError:
                print("File skipped because it's too short!")

        # Reported once per epoch (previously this printed after every file).
        print(f"{epoch} epoch finished")

    save_with_kwargs(model, epochs, learning_rate, ds.prep_type.value, ds.standardized)
    return model

# Test run of the training process

In [None]:
# Train a ComplexModel for 10 epochs on the stocks data prepared with the
# CustomSplit strategy; `our_model` receives whatever train_model returns.
our_model = train_model(
    ds=StocksDataSet(
        "../Custom LSTM Model/Data/Stocks Data",
        preparation_type=st.CustomSplit,
    ),
    model=StocksPredictionModel(
        input_size=16,
        num_classes=30,
        hidden_size=256,
        num_layers=1,
        modelType=mt.ComplexModel,
    ),
    learning_rate=0.001,
    epochs=10,
)