# Import the trained model

In [None]:
import pickle
import mlflow
import os
import pandas as pd
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from utils.utilities import load_checkpoint, save_checkpoint
from utils.metrics import IntegratedEvaluator
import pyarrow.parquet as pq
from omegaconf import OmegaConf
from datasets.dataOps import create_ood_datasets, create_datasets
from hydra.utils import instantiate
from torch.utils.data import DataLoader, Subset

In [None]:
# MLflow run id of the S2S_{msd} experiment trained on the D_{ood-GP} partition.
# Replace this value with the id of your own run.
RUN_ID = "5c6f56f74c4146f4b2aedc6a9546816f"

In [None]:
# Point MLflow at the local file store and fetch the metadata of the selected run.
mlflow.set_tracking_uri("file:./mlruns")
client = mlflow.client.MlflowClient()

# `dico` holds the run description as a plain dictionary; later cells read
# its "info" -> "artifact_uri" entry.
selected_run = client.get_run(RUN_ID)
dico = selected_run.to_dictionary()

# Show the experiment name stored in the run tags.
run_tags = dico["data"]["tags"]
print(run_tags["exp_name"])

In [None]:
# Load the experiment configuration saved next to the run's artifacts.
artifact_dir = dico["info"]["artifact_uri"].removeprefix("file://")
cfg = OmegaConf.load(os.path.join(artifact_dir, "config_exp.yaml"))

# Select the compute device and fix the PyTorch seed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
torch.manual_seed(42)

# Read each pickled array of the raw dataset into a dictionary keyed by its name.
# NOTE: pickle can execute arbitrary code on load; only read trusted files.
array_names = ["static_data", "before_ts", "after_ts", "target_ts", "mask_target", "cat_dicos"]
data = {}
for array_name in array_names:
    pickle_path = f"{cfg.raw_data_folder + array_name}.pkl"
    with open(pickle_path, "rb") as handle:
        data[array_name] = pickle.load(handle)

# The sample identifiers are the index of the time-series information table.
table = pq.read_table(cfg.raw_data_folder + cfg.info_ts_file)
ids = table.to_pandas().index.to_list()

# One entry per item of data["cat_dicos"]: the number of keys of that mapping.
list_unic_cat = [len(mapping.keys()) for mapping in data["cat_dicos"].values()]

train_dataset, val_dataset, test_dataset = create_datasets(
    ids=ids,
    static_data=data["static_data"],
    before_ts=data["before_ts"],
    after_ts=data["after_ts"],
    target_ts=data["target_ts"],
    mask_target=data["mask_target"],
    train_size=cfg.training.train_size,
    val_size=cfg.training.val_size,
    raw_data_folder=cfg.raw_data_folder,
    means_and_stds_path=cfg.means_and_stds_path,
)

# Build the encoder/decoder from the configuration and move them to the device.
encoder = instantiate(cfg.model.encoder, list_unic_cat=list_unic_cat).to(device)
decoder = instantiate(cfg.model.decoder).to(device)

# The optimizer is built over the parameters of both networks.
model_params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = instantiate(cfg.training.optimizer, params=model_params)

In [None]:
# Load the trained model.
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written on a GPU machine can still be read when this notebook
# falls back to the CPU.
checkpoint = torch.load(dico["info"]["artifact_uri"].removeprefix("file://") + "/best_model/best_model.pth",
                        map_location=device)

# Restore the saved state into the freshly instantiated encoder, decoder and optimizer.
load_checkpoint(checkpoint,
                encoder,
                decoder,
                optimizer)

In [None]:
# Freeze both networks: none of their parameters will receive gradients.
for frozen_module in (encoder, decoder):
    for weight in frozen_module.parameters():
        weight.requires_grad = False

In [None]:
# Restrict the training set to the observation(s) whose id is listed below.
selected_id = ["011a0bab-5f91-4784-a7b5-eb895dcdbcda"]

# Walk the whole training set and remember the positions of the matching items.
indices_to_use = []
for position, sample in enumerate(train_dataset):
    if sample['id'] in selected_id:
        indices_to_use.append(position)

# Wrap the matching items in a Subset and serve them one per batch.
subset_dataset = Subset(train_dataset, indices_to_use)
loader = DataLoader(subset_dataset, batch_size=1, shuffle=True)

In [None]:
# Display the positions in train_dataset whose id matched selected_id.
indices_to_use

In [None]:
# Print the "static_data_num" entry of every batch served by the loader.
for selected_batch in loader:
    static_numeric = selected_batch["static_data_num"]
    print(static_numeric)

In [None]:
# Read the pickled normalisation statistics; later cells index them with keys
# such as "before_ts_mean" and "target_ts_std".
# NOTE: pickle can execute arbitrary code on load; only read trusted files.
means_and_stds_file = "data/work_data/means_and_stds.pkl"
with open(means_and_stds_file, "rb") as stats_handle:
    means_and_stds = pickle.load(stats_handle)

In [None]:
class Trainer(object):
    """Optimise the numeric static inputs of an encoder/decoder pair by gradient descent.

    The only tensor handed to the optimizer is ``static_data_num``, a copy of the
    numeric static features of the first batch of the loader given to
    :meth:`train_loop`. After every backward pass the gradient is zeroed outside
    ``optimized_columns``, so only those columns move. The objective is the
    negated decoder output for feature ``target_index`` at the last step flagged
    by ``mask_target``.

    The time-series inputs do not come from the batches: they are read from
    ``data/work_data/RCP_85/{before_ts,after_ts}.pkl`` and normalised with the
    notebook-level ``means_and_stds`` dictionary, which must exist before training.

    Args:
        exp_name: Free-form experiment label (stored only).
        encoder: Encoder module, called as ``encoder(static_num, static_cat, before_ts)``.
        decoder: Decoder module, called as ``decoder(x, h_0, after_ts, ar=False)``.
        learning_rate: Learning rate of the Adam optimizer built in :meth:`train_loop`.
        num_epochs: Number of passes over ``train_dataloader``.
        train_dataloader: Loader iterated by :meth:`train_epoch`.
        checkpoints_path: Folder in which ``checkpoint.pth`` is written after each epoch.
        device: Device on which the modules and tensors are placed.
        optimized_columns: Column indices of ``static_data_num`` that may be updated.
        target_index: Index of the decoder output feature whose value is maximised.
    """

    def __init__(self,
                 exp_name,
                 encoder,
                 decoder,
                 learning_rate,
                 num_epochs,
                 train_dataloader,
                 checkpoints_path,
                 device,
                 optimized_columns=(3, 4),
                 target_index=3):

        # Core components
        self.exp_name = exp_name
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)
        self.device = device

        # Data
        self.train_dataloader = train_dataloader

        # Paths
        self.checkpoints_path = checkpoints_path

        # Training options
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.optimized_columns = tuple(optimized_columns)
        self.target_index = target_index

        # Internal trackers (best_val_loss and val_losses are kept for
        # compatibility; nothing in this class updates them).
        self.best_val_loss = float("inf")
        self.train_losses = []
        self.val_losses = []

        # State created by train_loop()
        self.static_data_num = None
        self.optimizer = None
        self._scenario_inputs = None

    def _load_scenario_inputs(self):
        """Return the normalised scenario ``before_ts``/``after_ts`` tensors, reading them once."""
        if self._scenario_inputs is None:
            tensors = {}
            for array in ["before_ts", "after_ts"]:
                # NOTE: pickle can execute arbitrary code on load; only read trusted files.
                with open(f"data/work_data/RCP_85/{array}.pkl", "rb") as f:
                    raw = pickle.load(f)
                normalised = (raw - means_and_stds[f"{array}_mean"]) / means_and_stds[f"{array}_std"]
                tensors[array] = torch.tensor(normalised, dtype=torch.float32).to(self.device)
            self._scenario_inputs = tensors
        return self._scenario_inputs

    def train_epoch(self,
                    epoch,
                    num_epochs):
        """Run one pass over ``train_dataloader`` and write a checkpoint.

        Args:
            epoch: Zero-based epoch index (used for logging and the checkpoint).
            num_epochs: Total number of epochs (used for logging only).

        Returns:
            The detached loss of the last processed batch, or ``None`` when the
            dataloader yielded no batch.

        Raises:
            RuntimeError: If called before :meth:`train_loop` created the
                optimised tensor and its optimizer.
        """
        if self.static_data_num is None or self.optimizer is None:
            raise RuntimeError("train_epoch() needs train_loop() to create static_data_num and the optimizer first.")

        # NOTE(review): train mode is kept from the original code. If the modules
        # contain dropout or batch-norm layers this makes the objective stochastic
        # and updates running statistics even though the weights are frozen;
        # confirm this is intended.
        self.encoder.train()
        self.decoder.train()
        total_epoch_loss = 0.0
        total_batch_loss = None

        # Loop-invariant inputs: the scenario series and the gradient mask do not
        # depend on the batch, so they are prepared once per epoch.
        scenario = self._load_scenario_inputs()
        before_ts = scenario["before_ts"]
        after_ts = scenario["after_ts"]

        mask = torch.zeros_like(self.static_data_num)
        mask[:, list(self.optimized_columns)] = 1.0

        for batch in tqdm(self.train_dataloader):
            static_data_cat = batch["static_data_cat"].to(self.device)
            target_ts = batch["target_ts"].to(self.device)
            mask_target = batch["mask_target"].to(self.device)

            self.optimizer.zero_grad()
            latent, x_t = self.encoder(self.static_data_num, static_data_cat, before_ts)
            # Decoder input: the encoder's x_t followed by the targets shifted by one step.
            x = torch.cat([x_t.unsqueeze(1), target_ts[:, :-1, :]], dim=1)
            outputs, _ = self.decoder(x, latent, after_ts, ar=False)

            # Index of the last step flagged by the mask; cast so it is always a valid index dtype.
            last_idx = mask_target.sum(dim=1).long() - 1

            # Only sample 0 of the batch enters the objective (the notebook uses batch_size=1).
            total_batch_loss = -outputs[0, last_idx, self.target_index]

            total_batch_loss.backward()
            # Zero the gradient of every column that must stay fixed.
            self.static_data_num.grad *= mask
            self.optimizer.step()
            print(self.static_data_num)

            total_epoch_loss += total_batch_loss.item()

        print(f"TRAIN : Epoch [{epoch+1}/{num_epochs}], Loss: {total_epoch_loss:.4f}")
        self.train_losses.append(total_epoch_loss)

        ckpt_path = f"{self.checkpoints_path}/checkpoint.pth"
        checkpoint = {
            "epoch": epoch+1,
            "state_encoder_dict": self.encoder.state_dict(),
            "state_decoder_dict": self.decoder.state_dict(),
            # The optimised inputs are the actual result of this procedure.
            "static_data_num": self.static_data_num.detach().cpu(),
            }
        save_checkpoint(checkpoint, filename=ckpt_path)

        if total_batch_loss is None:
            return None
        # Detach so the caller does not keep the autograd graph alive.
        return total_batch_loss.detach()

    def train_loop(self, loader):
        """Initialise the optimised inputs from ``loader`` and run every epoch.

        Args:
            loader: Loader whose first batch provides the initial ``static_data_num``.
                The epochs themselves iterate ``self.train_dataloader``.

        Returns:
            A tuple ``(static_data_num, last_loss)``: the optimised tensor and the
            value returned by the final :meth:`train_epoch` call (``None`` when
            ``num_epochs`` is 0).

        Raises:
            ValueError: If ``loader`` yields no batch.
        """
        try:
            batch = next(iter(loader))
        except StopIteration:
            raise ValueError("loader is empty: no observation to initialise static_data_num from.") from None

        start = batch["static_data_num"].to(self.device)
        # Independent leaf tensor: the only thing the optimizer updates.
        self.static_data_num = start.clone().detach().requires_grad_(True)
        print(self.static_data_num)
        self.optimizer = torch.optim.Adam(params=[self.static_data_num], lr=self.learning_rate)

        # Re-read the scenario files on every run of the loop.
        self._scenario_inputs = None

        agr_yield = None
        for epoch in tqdm(range(self.num_epochs)):
            agr_yield = self.train_epoch(epoch, self.num_epochs)
        return self.static_data_num, agr_yield

In [None]:
# Make sure the folder that receives the checkpoints exists.
os.makedirs("agronomic/", exist_ok=True)

In [None]:
# Settings of the input-optimisation run, gathered before building the trainer.
trainer_settings = dict(
    exp_name="Agronomic_inverse_problem",
    encoder=encoder,
    decoder=decoder,
    learning_rate=0.001,
    num_epochs=100,
    train_dataloader=loader,
    checkpoints_path="agronomic/",
    device=device,
)
trainer = Trainer(**trainer_settings)

In [None]:
# Run the optimisation: `vector` is the optimised static input tensor and
# `agr_yield` the value returned by the last epoch.
vector, agr_yield = trainer.train_loop(loader=loader)

In [None]:
# Display the last entry of the first row of the optimised vector as a Python number.
vector[0, -1].item()

In [None]:
# Undo the sign flip of the objective, then undo the normalisation using the
# statistics stored at index 3 of the target series.
target_std = means_and_stds["target_ts_std"][3]
target_mean = means_and_stds["target_ts_mean"][3]
(agr_yield * -1 * target_std) + target_mean

In [None]:
import numpy as np

# Turn the last two entries of the optimised vector into a day of the year.
# NOTE(review): this treats them as the (sin, cos) pair of an annual angle in
# un-normalised units - confirm against how static_data_num is built.
sin_component = vector[0][-2].item()
cos_component = vector[0][-1].item()

# Angle of the pair, wrapped to [0, 2*pi)
theta = np.mod(np.arctan2(sin_component, cos_component), 2 * np.pi)

# Map the angle onto a 365-day year
day = theta * (365 / (2 * np.pi))

print(day)

In [None]:
import datetime

# Convert a "<year> <day-of-year>" pair into a calendar date; day 350 of 2035
# is 16 December. The day is typed in by hand, not read from `day`.
datetime.datetime.strptime("2035 350", "%Y %j")