# In [None]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
import wandb
import os
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from pmdarima import auto_arima
from models import RNNModel, LSTMModel, GRUModel
from config import stock_ticker, architecture, seq_length, start_date, end_date, num_epochs, learning_rate, wandb_config, model_config
from sklearn.model_selection import train_test_split


class StockPredictionModule(pl.LightningModule):
    """Lightning module that trains a sequence model and benchmarks it after testing.

    Every step reports the MSE between the wrapped model's output and the batch
    labels.  When the test epoch ends, the module re-runs the test loader, maps
    predictions and labels back through ``scaler.inverse_transform`` and compares
    the model with a previous-value baseline and an ARIMA forecast, logging a
    plot and a metrics table to Weights & Biases.
    """

    def __init__(self, model, scaler, train_loader, val_loader, test_loader, test_dates, arima_results):
        """Store the model and everything the end-of-test evaluation needs.

        Args:
            model: Network invoked as ``model(x)`` in ``forward``.
            scaler: Object whose ``inverse_transform`` maps a ``(n, 1)`` array of
                model-scale values back to the original scale.
            train_loader: Stored on the module; not read by the methods below.
            val_loader: Stored on the module; not read by the methods below.
            test_loader: Iterated again in ``on_test_epoch_end`` to collect
                predictions and labels.
            test_dates: X-axis values for the test plot; expected to have one
                entry per test label.
            arima_results: Object exposing ``predict(n_periods=...)``; its output
                is compared directly with the original-scale labels.
        """
        super().__init__()
        self.model = model
        self.scaler = scaler
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.test_dates = test_dates
        self.arima_results = arima_results
        self.criterion = nn.MSELoss()

    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    def configure_optimizers(self):
        """Return an Adam optimizer using the module-level ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

    def _batch_loss(self, batch):
        """Compute the criterion between the model output and the batch labels."""
        seqs, labels = batch
        return self.criterion(self(seqs), labels)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        loss = self._batch_loss(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for one batch."""
        loss = self._batch_loss(batch)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log and return the test loss for one batch."""
        loss = self._batch_loss(batch)
        self.log("test_loss", loss)
        return loss

    @staticmethod
    def _regression_metrics(actuals, predictions):
        """Return MSE, RMSE, MAE, MAPE and R^2 of ``predictions`` against ``actuals``.

        The same guarded MAPE denominator is used for every forecaster so the
        columns of the metrics table are directly comparable.
        """
        actuals = np.asarray(actuals, dtype=float)
        predictions = np.asarray(predictions, dtype=float)
        mse = mean_squared_error(actuals, predictions)
        # Key order is significant: it defines the row order of the W&B table.
        return {
            "mse": mse,
            "rmse": np.sqrt(mse),
            "mae": mean_absolute_error(actuals, predictions),
            "mape": np.mean(np.abs((actuals - predictions) / (actuals + 1e-8))),
            "r2": r2_score(actuals, predictions),
        }

    @staticmethod
    def _percent_improvement(model_metrics, reference_metrics):
        """Express the model's metrics relative to a reference forecaster, in percent.

        Error metrics use ``reference / model - 1`` while R^2 uses
        ``model / reference - 1``, so a positive number means the model is
        better in both cases.
        """
        comparison = {
            name: round((reference_metrics[name] / model_metrics[name] - 1) * 100, 2)
            for name in ("mse", "rmse", "mae", "mape")
        }
        comparison["r2"] = round((model_metrics["r2"] / reference_metrics["r2"] - 1) * 100, 2)
        return comparison

    def _log_prediction_plot(self, actuals, predictions, baseline, arima_predictions):
        """Plot all four series against ``test_dates``, show the figure and log it to W&B."""
        fig, ax = plt.subplots(figsize=(15, 7))
        ax.plot(self.test_dates, actuals, label='Actual Price', color='black', linestyle='-', marker='o')
        ax.plot(self.test_dates, predictions, label='Predicted Price', color='green', linestyle='-')
        ax.plot(self.test_dates, baseline, label='Baseline', color='blue', linestyle='-', marker='o')
        ax.plot(self.test_dates, arima_predictions, label='ARIMA', color='orange', linestyle='-', marker='o')
        ax.set_title('Stock Price Prediction')
        ax.set_xlabel('Date')
        ax.set_ylabel('Stock Price')
        ax.legend()
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
        plt.tight_layout()
        plt.show()

        filename = "plot.png"
        fig.savefig(filename)
        try:
            wandb.log({"Stock Price Prediction": wandb.Image(filename)})
        finally:
            # Never leave the temporary image or the open figure behind, even
            # when logging fails.
            os.remove(filename)
            plt.close(fig)

    def on_test_epoch_end(self):
        """Evaluate on the test loader in the original scale and log plot + metrics."""
        predictions, actuals = [], []
        for seqs, labels in self.test_loader:
            output = self(seqs.to(self.device))
            predictions.extend(output.view(-1).tolist())
            actuals.extend(labels.view(-1).tolist())

        predictions_original_scale = self.scaler.inverse_transform(
            np.array(predictions).reshape(-1, 1)
        ).flatten()
        actuals_original_scale = self.scaler.inverse_transform(
            np.array(actuals).reshape(-1, 1)
        ).flatten()
        # Previous-value baseline: each point is predicted by the actual value
        # before it (the first point is predicted by itself).  This is only
        # meaningful when the test loader yields samples in chronological order.
        baseline_original_scale = np.concatenate(
            (actuals_original_scale[:1], actuals_original_scale[:-1])
        )
        arima_predictions = self.arima_results.predict(n_periods=len(actuals_original_scale))

        self._log_prediction_plot(
            actuals_original_scale,
            predictions_original_scale,
            baseline_original_scale,
            arima_predictions,
        )

        model_metrics = self._regression_metrics(actuals_original_scale, predictions_original_scale)
        baseline_metrics = self._regression_metrics(actuals_original_scale, baseline_original_scale)
        arima_metrics = self._regression_metrics(actuals_original_scale, arima_predictions)

        model_baseline_performance_metrics = self._percent_improvement(model_metrics, baseline_metrics)
        model_arima_performance_metrics = self._percent_improvement(model_metrics, arima_metrics)

        metrics_table = wandb.Table(columns=["metric", "model", "baseline", "arima", "model-baseline performance comparison [%]", "model-arima performance comparison [%]"])

        for metric in model_metrics:
            metrics_table.add_data(
                metric,
                model_metrics[metric],
                baseline_metrics[metric],
                arima_metrics[metric],
                model_baseline_performance_metrics[metric],
                model_arima_performance_metrics[metric],
            )

        wandb.log({"metrics": metrics_table})


def create_sequences(data, seq_length):
    """Slice ``data`` into sliding input windows and their next-step labels.

    Sample ``i`` pairs the window ``data[i:i + seq_length]`` with the label
    ``data[i + seq_length]``.  Every window that has a following element is
    emitted, including the one whose label is the final element of ``data``.

    Args:
        data: Indexable, sliceable sequence (e.g. a NumPy array) ordered in time.
        seq_length: Number of consecutive elements in each input window.

    Returns:
        A tuple ``(xs, ys)`` of NumPy arrays holding the windows and labels.
        Both are empty when ``data`` has no more than ``seq_length`` elements.
    """
    # The label index i + seq_length must stay within data, so valid starts are
    # 0 .. len(data) - seq_length - 1 inclusive.
    n_samples = max(len(data) - seq_length, 0)
    xs = [data[i:i + seq_length] for i in range(n_samples)]
    ys = [data[i + seq_length] for i in range(n_samples)]
    return np.array(xs), np.array(ys)


def load_data(stock_ticker, start_date, end_date, seq_length):
    """Download prices and build chronological loaders plus an ARIMA reference model.

    The series is split by position: the leading ``int(len * 0.8) + seq_length``
    rows are used for training and the remainder is windowed and cut, in time
    order, into a validation half followed by a test half.  The scaler is fit
    on the training rows only.

    Args:
        stock_ticker: Ticker symbol passed to ``yf.download``.
        start_date: Start of the download range.
        end_date: End of the download range.
        seq_length: Number of time steps in each input window.

    Returns:
        ``(train_loader, val_loader, test_loader, scaler, test_dates,
        auto_arima_model)`` where ``test_dates`` holds the index entry of each
        test label and ``auto_arima_model`` is fit on the original-scale prices
        that precede the first test label.

    Raises:
        ValueError: If the selected price data does not have exactly one column.
    """
    stock_df = yf.download(stock_ticker, start=start_date, end=end_date)

    # The evaluation code inverse-transforms a single-column array with this
    # scaler, so the whole pipeline has to run on exactly one price column.
    # NOTE(review): "Close" is assumed to be the intended target - confirm.
    price_df = stock_df[["Close"]]
    if price_df.shape[1] != 1:
        raise ValueError("Expected exactly one 'Close' column for a single ticker")

    split_idx = int(len(price_df) * 0.8) + seq_length
    train_df = price_df.iloc[:split_idx]
    temp_df = price_df.iloc[split_idx:]

    scaler = MinMaxScaler(feature_range=(-1, 1))
    train_normalized = scaler.fit_transform(train_df)
    X_train, y_train = create_sequences(train_normalized, seq_length)

    temp_normalized = scaler.transform(temp_df)
    X_temp, y_temp = create_sequences(temp_normalized, seq_length)
    # shuffle=False keeps the windows in time order: the previous-value
    # baseline, the date axis of the plot and the ARIMA forecast horizon all
    # require a contiguous, chronological test set.
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, shuffle=False
    )

    # The label of held-out window i sits at row split_idx + seq_length + i of
    # the full frame; the test labels follow directly after the validation ones.
    test_start = split_idx + seq_length + len(y_val)
    test_dates = price_df.index[test_start:test_start + len(y_test)]

    # ARIMA forecasts n periods past the end of its training data, so it is fit
    # on every original-scale price before the first test label.
    auto_arima_model = auto_arima(price_df.iloc[:test_start, 0].to_numpy())

    def make_loader(features, targets):
        dataset = torch.utils.data.TensorDataset(
            torch.FloatTensor(features), torch.FloatTensor(targets)
        )
        return torch.utils.data.DataLoader(dataset=dataset, batch_size=64, shuffle=False)

    train_loader = make_loader(X_train, y_train)
    val_loader = make_loader(X_val, y_val)
    test_loader = make_loader(X_test, y_test)

    return train_loader, val_loader, test_loader, scaler, test_dates, auto_arima_model


def main():
    """Train the configured architecture, test the best checkpoint and close the W&B run."""

    torch.set_float32_matmul_precision("medium")
    wandb_logger = WandbLogger(project="RNN_single_step_forecasts", log_model="all", config=wandb_config)

    train_loader, val_loader, test_loader, scaler, test_dates, auto_arima_model = load_data(
        stock_ticker, start_date, end_date, seq_length
    )

    if architecture == "RNN":
        model = RNNModel(**model_config)
    elif architecture == "LSTM":
        model = LSTMModel(**model_config)
    elif architecture == "GRU":
        model = GRUModel(**model_config)
    else:
        raise ValueError("Unsupported architecture specified")

    module = StockPredictionModule(model=model, scaler=scaler, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader,
                                   test_dates=test_dates, arima_results=auto_arima_model)

    # Device agnostic initialization.  An explicit "cpu" fallback is used
    # because recent Lightning releases reject accelerator=None.
    if torch.cuda.is_available():
        accelerator = "gpu"
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        # is_available() also checks that an MPS device can actually be used,
        # not merely that this PyTorch build was compiled with MPS support.
        accelerator = "mps"
    else:
        accelerator = "cpu"
    devices = 1

    trainer = Trainer(max_epochs=num_epochs, logger=wandb_logger, accelerator=accelerator, devices=devices, enable_checkpointing=True)
    trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)
    trainer.test(module, dataloaders=test_loader, ckpt_path="best")

    wandb.finish()


if __name__ == "__main__":
    main()

