# In [2]:
import os
import import_ipynb
import matplotlib.pyplot as plt
import time
import torch
import IPython
from IPython.core.display_functions import clear_output
import locations as l
#from orca.orca_state import device


def _r2_loss(y_pred, y_true):
    """Return ``1 - R²`` of *y_pred* against *y_true* as a 0-dim tensor.

    ``1 - R²`` equals ``ss_residual / ss_total``; it is computed directly
    in that form to avoid the float cancellation of ``1 - (1 - x)``.

    The result is ``inf``/``nan`` when *y_true* has zero variance (for
    example a batch holding a single sample).

    NOTE(review): assumes *y_pred* and *y_true* have the same shape; with
    mismatched shapes the subtraction would broadcast -- confirm at callers.
    """
    ss_total = torch.sum((y_true - torch.mean(y_true)) ** 2)
    ss_residual = torch.sum((y_true - y_pred) ** 2)
    return ss_residual / ss_total


def _train_minibatch_epoch(model, loss_fn, optimizer, loader, device):
    """Run one optimisation pass over *loader*; return mean (loss, 1 - R²)."""
    model.train()
    loss_sum = 0.0
    r2_sum = 0.0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        # Forward pass
        outputs = model(batch_x)
        loss = loss_fn(outputs, batch_y)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        # Metric only: detach so it does not keep the autograd graph alive.
        r2_sum += _r2_loss(outputs.detach(), batch_y).item()

    batch_count = len(loader)
    return loss_sum / batch_count, r2_sum / batch_count


def _evaluate_minibatches(model, loss_fn, loader, device):
    """Evaluate *model* over *loader*; return mean per-batch (loss, 1 - R²)."""
    model.eval()
    batch_losses = []
    batch_r2_losses = []
    with torch.no_grad():  # no gradients needed for evaluation
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            predictions = model(batch_x)
            batch_losses.append(loss_fn(predictions, batch_y).item())
            batch_r2_losses.append(_r2_loss(predictions, batch_y).item())
    return (
        sum(batch_losses) / len(batch_losses),
        sum(batch_r2_losses) / len(batch_r2_losses),
    )


def train_model(model, loss_fn, optimizer, x_train, y_train, x_test, y_test,
                model_name, batch_size=0, max_epochs=None):
    """Train *model*, checkpointing its ``state_dict`` on each new best loss.

    Every epoch clears the notebook output, performs one optimisation pass,
    evaluates on the test data, prints the metrics, and saves
    ``model.state_dict()`` whenever the training loss improves on the best
    value seen so far.

    Args:
        model: ``torch.nn.Module`` to train; it is moved to CUDA when
            available, otherwise to the CPU.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor that supports ``backward()``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        x_train, y_train: Training inputs and targets (tensors).
        x_test, y_test: Test inputs and targets (tensors).
        model_name: File name of the checkpoint, joined onto the directory
            returned by ``l.locations.get_models_dir()``.
        batch_size: ``0`` trains on the full training set in a single step
            per epoch; any other value trains in shuffled mini-batches of
            that size.
        max_epochs: Number of epochs to run. ``None`` (the default) loops
            until interrupted, e.g. by ``KeyboardInterrupt``; values
            ``<= 0`` run no epochs.

    Returns:
        The best (lowest) training loss observed, as a float. Only reachable
        when *max_epochs* is given.

    Raises:
        ValueError: In mini-batch mode, if the training or test data yields
            no batches.

    Notes:
        * In mini-batch mode the "Current Loss" and the checkpoint criterion
          are the mean of the per-batch training losses over the epoch, not
          the loss of whichever batch happened to come last.
        * The mini-batch R² figures are means of per-batch values; a batch
          whose targets have zero variance makes them ``inf``/``nan``.
        * Sets ``os.environ['TERM']`` to ``'xterm'`` as a side effect.
    """
    os.environ['TERM'] = 'xterm'
    best_loss = float('inf')

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    model = model.to(device)

    model_folder = l.locations.get_models_dir()
    save_path = os.path.join(model_folder, model_name)

    use_minibatches = batch_size != 0
    if use_minibatches:
        train_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(x_train, y_train),
            batch_size=batch_size,
            shuffle=True,
        )
        # Evaluation order is fixed so the per-batch averages are
        # reproducible from one epoch to the next.
        test_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(x_test, y_test),
            batch_size=batch_size,
            shuffle=False,
        )
        if len(train_loader) == 0 or len(test_loader) == 0:
            raise ValueError(
                "training and test data must be non-empty when batch_size != 0"
            )
    else:
        x_train, y_train = x_train.to(device), y_train.to(device)
        x_test, y_test = x_test.to(device), y_test.to(device)

    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        epoch += 1
        IPython.display.clear_output(wait=True)

        if use_minibatches:
            train_loss, train_r2_loss = _train_minibatch_epoch(
                model, loss_fn, optimizer, train_loader, device
            )
            test_loss, test_r2_loss = _evaluate_minibatches(
                model, loss_fn, test_loader, device
            )

            print(f"Average Test Loss: {test_loss:.8f}")
            print(f"Average Test R² Loss: {test_r2_loss:.8f}")

            if train_loss < best_loss:
                best_loss = train_loss
                torch.save(model.state_dict(), save_path)
                print(f"New best loss: {best_loss:.8f}. Model saved to {save_path}.")
            print(f"Current Loss: {train_loss:.8f}")
            print(f"Current R2 Loss: {train_r2_loss:.8f}")
        else:
            # Forward pass
            model.train()
            outputs = model(x_train)
            loss = loss_fn(outputs, y_train)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss = loss.item()

            model.eval()
            with torch.no_grad():
                # Training R² uses the predictions made before this step,
                # i.e. the same ones the reported loss was computed from.
                train_r2_loss = _r2_loss(outputs.detach(), y_train).item()
                test_pred = model(x_test)
                test_loss = loss_fn(test_pred, y_test).item()
                test_r2_loss = _r2_loss(test_pred, y_test).item()

            if train_loss < best_loss:
                best_loss = train_loss
                torch.save(model.state_dict(), save_path)
                print(f"New best loss: {best_loss:.4f}. Model saved to {save_path}.")

            print(f"Current Loss: {train_loss:.8f}, Test Loss: {test_loss:.8f}")
            print(f"Current R2 Loss: {train_r2_loss:.8f}, Test R2 Loss: {test_r2_loss:.8f}")

    return best_loss
