In [9]:
import os

import torch
import mlflow
from mlflow.tracking import MlflowClient

from db import database as db
from utils.modelling import (
    reshape_and_split,
    normalize_dataset,
    move_to_device,
    vizualize_and_view_prediction,
    vizualize_dataset
)

In [None]:
# MLflow tracking/registry server location, read from the environment so the
# notebook runs unchanged locally and elsewhere. Set ML_HOST / ML_PORT before
# starting the kernel, or in a cell with e.g. `%env ML_HOST=localhost`.
ML_HOST = os.environ.get("ML_HOST")
ML_PORT = os.environ.get("ML_PORT")

if not ML_HOST or not ML_PORT:
    # Fail fast: otherwise the URI silently becomes "http://None:None" and the
    # first MLflow call fails with an opaque connection error.
    raise RuntimeError(
        "Set the ML_HOST and ML_PORT environment variables to the MLflow "
        "server address before running this notebook."
    )

ML_SERVER = f"http://{ML_HOST}:{ML_PORT}"
mlflow.set_tracking_uri(ML_SERVER)
mlflow.set_registry_uri(ML_SERVER)

# Client picks up the tracking URI configured above.
client = MlflowClient()

In [None]:
# Print the name of every model registered on the MLflow server.
registered_names = (entry.name for entry in client.search_registered_models())
for registered_name in registered_names:
    print(f"Name: {registered_name}")

In [None]:
MODEL_NAME = "lstm_model"

# Print (name, version, run_id) for every registered version of MODEL_NAME.
version_filter = f"name='{MODEL_NAME}'"
for model_version in client.search_model_versions(version_filter):
    print((model_version.name, model_version.version, model_version.run_id))

In [None]:
# Device the model is mapped onto at load time and the data is moved to.
DEVICE_TYPE = "cpu"
device = torch.device(DEVICE_TYPE)

In [None]:
MODEL_NAME = "lstm_model"
MODEL_VERSION = "1"

# Load a pinned registry version, map its weights onto `device`, cast the
# floating-point parameters to float64 and switch the module to eval mode.
model_uri = f"models:/{MODEL_NAME}/{MODEL_VERSION}"
model = (
    mlflow.pytorch.load_model(model_uri=model_uri, map_location=device)
    .double()
    .to(device)
    .eval()
)
model

### Load dataset

In [None]:
SPLIT_SIZE = 5
SPLIT_RATIO = 0.3
DATASET_NAME = "Simple Vibes"

# Load -> reshape/split -> normalize, with one name per stage so intermediate
# results stay inspectable; `dataset` holds the normalized result.
raw_series = db.load_time_series_as_numpy(DATASET_NAME)
split_series = reshape_and_split(
    raw_series, split_ratio=SPLIT_RATIO, split_size=SPLIT_SIZE
)
dataset = normalize_dataset(split_series)
# NOTE(review): normalization is applied after the train/test split; confirm
# normalize_dataset does not derive statistics from the test portion (leakage).

x_train, y_train, x_test, y_test = move_to_device(data=dataset, device=device)

vizualize_dataset(x_train)

### Evaluate

In [None]:
# Value passed to the model's `future` argument.
# NOTE(review): presumably the number of steps forecast beyond the input
# sequence; confirm against the model's forward().
future = 1000

# Inference only: no_grad avoids building an autograd graph (and keeping every
# intermediate activation alive) for the whole forecast rollout.
with torch.no_grad():
    predictions = model(x_test, future=future)

In [None]:
# Store the NumPy copy under a new name and leave the `predictions` tensor
# intact, so re-running this cell does not fail on `ndarray.detach()`.
# `.cpu()` is a no-op for CPU tensors and keeps this working on other devices.
predictions_np = predictions.detach().cpu().numpy()
# NOTE(review): dimension 1 of x_test is treated as the sample count; confirm
# the tensor layout expected by vizualize_and_view_prediction.
n_samples = x_test.size(1)

vizualize_and_view_prediction(predictions_np, n_samples, future)