In [1]:
import matplotlib.pyplot as plt
import time
import import_ipynb
import torch
import IPython
from IPython.core.display_functions import clear_output
#from orca.orca_state import device

# The train_model function trains a PyTorch model with the given loss function and optimizer using
# full-batch gradient steps; by default it loops until interrupted (e.g. with Ctrl+C).
# It uses the GPU when one is available and saves the model's state_dict whenever a new best
# *training* loss is reached. The training and test loss are printed after every step.

# Parameters:
# - model: A PyTorch model instance to be trained (torch.nn.Module).
# - loss_fn: The loss function used to compute the error between predictions and ground truth (e.g., torch.nn.MSELoss).
# - optimizer: The optimization algorithm used to update the model's parameters (e.g., torch.optim.Adam).
# - x_train: The training input data (PyTorch tensor).
# - y_train: The training target data (PyTorch tensor).
# - x_test: The test input data (PyTorch tensor).
# - y_test: The test target data (PyTorch tensor).
# - save_path: The file path where the model's state_dict is saved each time a new best training loss is reached.



def train_model(model, loss_fn, optimizer, x_train, y_train, x_test, y_test, save_path,
                max_epochs=None):
    """Train ``model`` with full-batch gradient steps, checkpointing the best training loss.

    Each iteration does three things:

    * runs one forward/backward pass over the whole of ``x_train``;
    * evaluates the loss on ``x_test`` with gradients disabled and the model in
      eval mode;
    * saves ``model.state_dict()`` to ``save_path`` whenever the training loss
      improves.

    Training and test loss are printed after every iteration.

    Args:
        model: The ``torch.nn.Module`` to train; it is moved to the selected device.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        optimizer: Optimizer already constructed over ``model``'s parameters.
        x_train, y_train: Training inputs and targets (tensors).
        x_test, y_test: Test inputs and targets (tensors).
        save_path: Destination passed to ``torch.save`` for the best state dict.
        max_epochs: Number of iterations to run. ``None`` (the default) keeps the
            original behaviour of looping until interrupted (e.g. Ctrl+C).

    Returns:
        The best training loss observed (``float('inf')`` if no iteration ran).
        Only returned when ``max_epochs`` is not ``None``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    model = model.to(device)
    x_train, y_train = x_train.to(device), y_train.to(device)
    x_test, y_test = x_test.to(device), y_test.to(device)

    best_loss = float('inf')
    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        epoch += 1

        # Forward pass on the training data (gradients tracked).
        loss = loss_fn(model(x_train), y_train)

        # Evaluate on the test data without building an autograd graph, with
        # dropout/batch-norm in inference mode; then restore the caller's mode.
        was_training = model.training
        model.eval()
        with torch.no_grad():
            test_loss = loss_fn(model(x_test), y_test)
        model.train(was_training)

        # Read each scalar once (.item() forces a device sync).
        loss_value = loss.item()
        test_loss_value = test_loss.item()

        # Checkpoint BEFORE optimizer.step(): `loss` was produced by the current
        # parameters, so this is the state that actually achieved best_loss.
        if loss_value < best_loss:
            best_loss = loss_value
            torch.save(model.state_dict(), save_path)
            print(f"New best loss: {best_loss:.4f}. Model saved to {save_path}.")

        # Backward pass and optimization.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Current Loss: {loss_value:.8f}, Test Loss: {test_loss_value:.8f}")

    return best_loss


# Example Usage:
# model = MyModel()
# loss_fn = torch.nn.MSELoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# train_model(
#     model=model,
#     loss_fn=loss_fn,
#     optimizer=optimizer,
#     x_train=x_train,
#     y_train=y_train,
#     x_test=x_test,
#     y_test=y_test,
#     save_path="best_model.pth"
# )

In [None]:
i