In [1]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import yfinance as yf
from pl_models import RNNModel, LSTMModel, GRUModel
from pl_config import stock_ticker, company_name, num_units, num_layers, dropout_prob, architecture, seq_length, start_date, end_date, num_epochs, learning_rate
import wandb


class StockPredictionModule(pl.LightningModule):
    """LightningModule that trains a sequence regressor on scaled prices.

    Wraps ``model`` with an MSE objective and an Adam optimizer (learning rate
    from the module-level ``learning_rate`` config value), logs
    train/val/test losses, and at the end of testing logs a plot of predicted
    vs. actual values, mapped back with ``scaler.inverse_transform``, to the
    active Weights & Biases run.

    Args:
        model: ``nn.Module`` mapping a batch of input sequences to predictions.
        scaler: fitted scaler whose ``inverse_transform`` accepts a
            ``(n, 1)`` array (e.g. sklearn ``MinMaxScaler``).
        train_loader, val_loader, test_loader: stored as attributes only; this
            module defines no ``*_dataloader`` hooks, so loaders must be passed
            to the Trainer explicitly.
    """

    def __init__(self, model, scaler, train_loader, val_loader, test_loader):
        super().__init__()
        self.model = model
        self.scaler = scaler
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.criterion = nn.MSELoss()
        # Per-batch outputs collected by test_step and consumed by
        # on_test_epoch_end (avoids a second inference pass over the data).
        self._test_predictions = []
        self._test_actuals = []

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

    def _shared_step(self, batch):
        """Run the model on one ``(seqs, labels)`` batch; return (pred, labels, loss)."""
        seqs, labels = batch
        y_pred = self(seqs)
        return y_pred, labels, self.criterion(y_pred, labels)

    def training_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        _, _, loss = self._shared_step(batch)
        self.log("val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def on_test_epoch_start(self):
        # Start every test run with empty buffers.
        self._test_predictions = []
        self._test_actuals = []

    def test_step(self, batch, batch_idx):
        y_pred, labels, loss = self._shared_step(batch)
        self.log("test_loss", loss)
        # Keep the outputs so the epoch-end plot reflects exactly the
        # dataloader given to trainer.test(), not necessarily self.test_loader.
        self._test_predictions.append(y_pred.detach().reshape(-1).float().cpu())
        self._test_actuals.append(labels.detach().reshape(-1).float().cpu())
        return loss

    def on_test_epoch_end(self):
        """Plot predicted vs. actual values and log the figure to W&B."""
        if not self._test_predictions:
            # Nothing was tested; inverse_transform would fail on an empty array.
            return

        predictions = torch.cat(self._test_predictions).numpy().reshape(-1, 1)
        actuals = torch.cat(self._test_actuals).numpy().reshape(-1, 1)
        self._test_predictions, self._test_actuals = [], []

        predictions_original_scale = self.scaler.inverse_transform(predictions).flatten()
        actuals_original_scale = self.scaler.inverse_transform(actuals).flatten()

        fig, ax = plt.subplots(figsize=(10, 5))
        try:
            ax.plot(actuals_original_scale, label='Actual Price', color='blue')
            ax.plot(predictions_original_scale, label='Predicted Price', color='red', linestyle='--')
            ax.set_title('Stock Price Prediction')
            ax.set_xlabel('Time')
            ax.set_ylabel('Stock Price')
            ax.legend()
            fig.tight_layout()
            # wandb.log raises when no run is active (e.g. when testing
            # without a WandbLogger); skip logging in that case.
            if wandb.run is not None:
                wandb.log({"Predictions vs Actuals": wandb.Image(fig)})
        finally:
            # Always release the figure so repeated test runs don't leak memory.
            plt.close(fig)


def create_sequences(data, seq_length):
    """Split a series into sliding input windows and next-step targets.

    Every window ``data[i:i + seq_length]`` is paired with the element that
    immediately follows it, ``data[i + seq_length]``. All
    ``len(data) - seq_length`` such pairs are produced, including the one whose
    target is the last element of ``data``.

    Args:
        data: array-like indexed along its first axis, e.g. shape ``(n, 1)``.
        seq_length: number of consecutive elements per input window (>= 1).

    Returns:
        ``(xs, ys)`` with shapes ``(m, seq_length, *data.shape[1:])`` and
        ``(m, *data.shape[1:])``, where ``m = max(len(data) - seq_length, 0)``.

    Raises:
        ValueError: if ``seq_length`` is less than 1.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")

    data = np.asarray(data)
    n_windows = len(data) - seq_length
    if n_windows <= 0:
        # Not enough data for a single (window, target) pair: return empty
        # arrays that still carry the expected trailing dimensions.
        trailing = data.shape[1:]
        return (
            np.empty((0, seq_length) + trailing, dtype=data.dtype),
            np.empty((0,) + trailing, dtype=data.dtype),
        )

    xs = np.stack([data[i:i + seq_length] for i in range(n_windows)])
    # The target of window i is data[i + seq_length], i.e. everything after
    # the first window; copy so callers don't alias the input buffer.
    ys = data[seq_length:].copy()
    return xs, ys


def load_data(batch_size=64):
    """Download closing prices and build chronological train/val/test loaders.

    Prices for ``stock_ticker`` between ``start_date`` and ``end_date`` are cut
    into ``seq_length`` windows with next-step targets. The windows are split
    80/10/10 in time order (no shuffling), so validation and test data always
    come after the training data. A ``MinMaxScaler`` with range ``(-1, 1)`` is
    fit on the training windows only, then applied to every split; this keeps
    future prices from influencing the scaling.

    Args:
        batch_size: mini-batch size for all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader, scaler)``. Only the training
        loader shuffles its batches. ``scaler`` maps model outputs back to
        prices via ``inverse_transform``.

    Raises:
        ValueError: if no price data is returned. sklearn's ``train_test_split``
            also raises ValueError if there are too few windows to split.
    """
    stock_df = yf.download(stock_ticker, start=start_date, end=end_date)
    if stock_df.empty:
        raise ValueError(
            f"No price data returned for {stock_ticker!r} between {start_date} and {end_date}"
        )
    # [['Close']] keeps a 2-D frame; drop rows where no close price was reported.
    closes = stock_df[['Close']].dropna().values.reshape(-1, 1)

    X, y = create_sequences(closes, seq_length)

    # Chronological split: shuffling time-series windows would leak future
    # prices into training and scramble the time axis of the test plot.
    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, shuffle=False)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, shuffle=False)

    # Fit the scaler on training values only (inputs and targets).
    scaler = MinMaxScaler(feature_range=(-1, 1))
    scaler.fit(np.concatenate([X_train.reshape(-1, 1), y_train.reshape(-1, 1)]))

    def _scaled(arr):
        # MinMaxScaler is element-wise, so flatten, transform, restore shape.
        return scaler.transform(arr.reshape(-1, 1)).reshape(arr.shape)

    def _loader(features, targets, shuffle):
        dataset = TensorDataset(torch.FloatTensor(_scaled(features)), torch.FloatTensor(_scaled(targets)))
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    # Windows are independent samples, so shuffle training batches each epoch.
    train_loader = _loader(X_train, y_train, shuffle=True)
    val_loader = _loader(X_val, y_val, shuffle=False)
    test_loader = _loader(X_test, y_test, shuffle=False)

    return train_loader, val_loader, test_loader, scaler


def main():
    """Train, validate and test the configured recurrent model, logging to W&B.

    Hyperparameters come from the ``pl_config`` imports. The model class is
    chosen by ``architecture``. After fitting, the checkpoint with the lowest
    ``val_loss`` is evaluated on the test split.

    Raises:
        ValueError: if ``architecture`` is not "RNN", "LSTM" or "GRU".
    """
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Experiment metadata recorded with the W&B run.
    config = {
        "dataset": f"{company_name} closing prices",
        "architecture": architecture,
        "num_units": num_units,
        "num_layers": num_layers,
        "dropout": dropout_prob,
        "seq_length": seq_length,
        "start_date": start_date,
        "end_date": end_date,
        "epochs": num_epochs,
        "learning_rate": learning_rate
    }

    # Resolve the architecture first so a bad setting fails before any
    # download or W&B run is started.
    model_classes = {"RNN": RNNModel, "LSTM": LSTMModel, "GRU": GRUModel}
    if architecture not in model_classes:
        raise ValueError("Unsupported architecture specified")
    model = model_classes[architecture](
        input_size=1,                  # one feature: the closing price
        hidden_layer_size=num_units,
        num_layers=num_layers,
        output_size=1,                 # single-step forecast
        dropout_prob=dropout_prob,
    )

    wandb_logger = WandbLogger(project="RNN_single_step_forecasts", log_model="all", config=config)
    try:
        train_loader, val_loader, test_loader, scaler = load_data()

        module = StockPredictionModule(model=model, scaler=scaler, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader)

        # Track val_loss so ckpt_path="best" really selects the best epoch.
        # Lightning's default checkpoint callback has monitor=None and only
        # keeps the latest checkpoint.
        checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

        trainer = Trainer(
            max_epochs=num_epochs,
            logger=wandb_logger,
            # `gpus=` was removed in Lightning 2.0; accelerator/devices replaces it.
            accelerator="gpu" if torch.cuda.is_available() else "cpu",
            devices=1,
            callbacks=[checkpoint_callback],
            enable_checkpointing=True,
        )
        trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)
        trainer.test(dataloaders=test_loader, ckpt_path="best")
    finally:
        # Close the W&B run even if training or testing raises.
        wandb.finish()


if __name__ == "__main__":
    main()





ModuleNotFoundError: No module named 'models'