<a href="https://colab.research.google.com/github/Sambosis/Historic_Crypto/blob/main/untitled56.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install necessary packages
# Uncomment the following lines if running in a new environment
# !pip install fluidstack -q
# !pip install pytorch_lightning tensorflow icecream tensorboardX rich wandb -q

# Standard Library Imports
import os
import io
import time
import random
from dataclasses import dataclass
import multiprocessing
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Third-Party Imports
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_curve
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar
from icecream import ic
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from PIL import Image
import requests
from rich.console import Console
from rich.table import Table
from rich.text import Text
from rich.box import ROUNDED
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback
import torch.distributed as dist
from scipy.ndimage import gaussian_filter1d

# Initialize Rich Console
# Module-level Rich console instance (presumably used for formatted output
# further down the file — not referenced in the visible section).
console = Console()
# CPU count as reported by multiprocessing; printed once at start-up.
num_cpus = multiprocessing.cpu_count()
print(f"Number of CPU cores available: {num_cpus}")

# Configuration Dataclass
def _next_version(version_file: str = "version") -> int:
    """Read the run counter from ``version_file``, increment it and persist it.

    Args:
        version_file: Path of the text file holding the last used version number.

    Returns:
        The incremented version number (1 when the file does not exist yet).

    Raises:
        ValueError: If the file exists but does not contain an integer.
    """
    try:
        with open(version_file, "r") as f:
            current = int(f.read())
    except FileNotFoundError:
        # Fresh environment: start counting from zero instead of crashing on import.
        current = 0

    next_version = current + 1
    with open(version_file, "w") as f:
        f.write(str(next_version))
    return next_version


@dataclass
class Config:
    """Experiment configuration.

    Annotated attributes are dataclass fields (overridable through the
    constructor); un-annotated ones (``VERSION_N``, ``num_cpus``,
    ``NUM_WORKERS``) are plain class attributes computed once, when the class
    body is executed.

    Side effect: defining this class bumps the counter stored in the
    ``version`` file in the current working directory.
    """

    # Run counter; evaluated at class-definition time and used in MODEL_SAVE_PATH.
    VERSION_N = _next_version()
    RECORDS_TO_LOAD: int = 1205040
    # 3 * 12 * 3 = 108 time steps of history.
    # NOTE(review): the previous comment said "1 week of 10-minute intervals",
    # which would be 1008 steps — confirm the sampling interval.
    N_PAST: int = 3 * 12 * 3
    # 1 * 12 * 2 = 24 time steps to predict.
    # NOTE(review): the previous comment said "1 day of 10-minute intervals",
    # which would be 144 steps — confirm the sampling interval.
    N_FUTURE: int = 1 * 12 * 2
    BATCH_SIZE: int = 2000
    HIDDEN_SIZE: int = 256
    NUM_LAYERS: int = 2
    NUM_EPOCHS: int = 150
    HOT_RESTART: bool = False
    TRAIN_FIRST: bool = True
    EPOCH_TO_RESTART: int = 50
    BATCH_FACTOR: int = 81
    DEBUG_FREQ: int = 180
    num_cpus = multiprocessing.cpu_count()
    # NOTE(review): for 17-19 cores this evaluates to 0 workers and for 20-23
    # cores to 1, i.e. fewer than the 4 used on small machines — confirm intent.
    NUM_WORKERS = (num_cpus // 4 - 4) if num_cpus > 16 else 4
    DEBUG_ON: bool = False
    DATA_URL: str = 'https://sambo.us-iad-1.linodeobjects.com/fillnan_combined_df.csv'
    DATA_FILE: str = './data/fill_nan_df.csv'
    MODEL_PATH: str = "/teamspace/studios/this_studio/models/TransformerModel351/model-351-epoch=148-val_loss=0.71.ckpt"
    MODEL_SAVE_PATH: str = f'./models/TransformerModel{VERSION_N}'
    DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    EPSILON: float = 1e-4

# Initialize Configuration
# Instantiate the configuration with its class-level defaults.
cfg = Config()

# Set Random Seed for Reproducibility
# (fixed seed 40; workers=True is forwarded to Lightning's seed_everything)
pl.seed_everything(40, workers=True)

# Print Device Information
print(f"Using device: {cfg.DEVICE}")
# Create the per-version model output directory if it does not exist yet.
os.makedirs(cfg.MODEL_SAVE_PATH, exist_ok=True)

# Initialize IceCream Debugging
# ic(...) calls throughout the file are switched on/off by cfg.DEBUG_ON.
if cfg.DEBUG_ON:
    ic.enable()
else:
    ic.disable()

# Callback to Update Percentile Cutoff
class UpdatePercentileCutoffCallback(Callback):
    """Shrink the criterion's percentage cutoff when validation reward is high.

    After each validation epoch the ``'val/reward'`` metric is read from
    ``trainer.callback_metrics``. When it exceeds ``reduction_threshold`` the
    cutoff stored on ``pl_module.criterion`` is multiplied by
    ``reduction_factor``.

    Args:
        reduction_threshold: Reward value that must be exceeded to trigger a reduction.
        reduction_factor: Multiplier applied to the current cutoff.
    """

    def __init__(self, reduction_threshold=1.9, reduction_factor=0.9):
        super().__init__()
        self.reduction_threshold = reduction_threshold
        self.reduction_factor = reduction_factor

    def on_validation_epoch_end(self, trainer, pl_module):
        """Reduce the cutoff on the global-zero process if the reward is high enough."""
        # Skip during sanity check to prevent freezing
        if trainer.sanity_checking:
            return

        # Only the main process (rank 0) determines if reduction is needed.
        # NOTE(review): the buffer is mutated on rank 0 only; whether other
        # ranks see the new value depends on buffer synchronisation — confirm
        # when training with more than one process.
        if not trainer.is_global_zero:
            return

        # The metric is normally a 0-dim tensor; convert once so that the
        # comparison and the logged value are plain Python floats.
        avg_reward = float(trainer.callback_metrics.get('val/reward', 0))

        # Written as "not >" so that a NaN reward never triggers a reduction.
        if not avg_reward > self.reduction_threshold:
            return

        criterion = pl_module.criterion
        old_perc_cutoff = criterion.get_perc_cutoff()
        new_perc_cutoff = old_perc_cutoff * self.reduction_factor
        # set_perc_cutoff already fills the buffer in place; no second fill needed.
        criterion.set_perc_cutoff(new_perc_cutoff)
        print(f"PercentileCutoffCallback: Reducing perc_cutoff from {old_perc_cutoff:.5f} to {new_perc_cutoff:.5f}")

        # Log reduction event to WandB (skipped when the trainer runs without a logger)
        logger = pl_module.logger
        if logger is not None:
            logger.experiment.log({
                "percentile_cutoff_reduction": new_perc_cutoff,
                "avg_reward": avg_reward
            })

# PositionalEncoding Class
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    Args:
        d_model: Size of the last (feature) dimension of the input. Both even
            and odd sizes are supported.
        max_len: Longest sequence length for which encodings are precomputed.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))  # (ceil(d_model/2),)
        angles = position * div_term  # (max_len, ceil(d_model/2))

        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)  # Even indices
        # For odd d_model there is one fewer odd index than even index, so the
        # cosine part uses only the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # Odd indices

        # Buffer: moves with the module across devices but is not a parameter.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor indexed as (batch, sequence, d_model).
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]

# Transformer-Based Model
class CryptoTransformer(nn.Module):
    """Encoder-decoder Transformer for multi-step sequence forecasting.

    Source and target sequences are projected with the same linear layer,
    receive sinusoidal positional encodings, and are passed through a
    batch-first ``nn.TransformerEncoder`` / ``nn.TransformerDecoder`` pair
    with a causal mask on the target side.

    Args:
        input_size: Number of features per time step of ``src`` and ``tgt``.
        d_model: Embedding width used inside the Transformer.
        nhead: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        dim_feedforward: Width of the feed-forward sub-layers.
        dropout: Dropout probability.
        activation: Activation name passed to the Transformer layers.
        n_future: Stored on the instance; not otherwise used in this class.
        num_outputs: Number of values predicted per target time step.
        max_seq_length: Maximum length supported by the positional encodings.
    """

    def __init__(
        self,
        input_size,
        d_model=256,
        nhead=8,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=2048,
        dropout=0.3,
        activation="gelu",
        n_future=24,
        num_outputs=24,
        max_seq_length=5000
    ):
        super().__init__()

        self.input_size = input_size
        self.d_model = d_model
        self.n_future = n_future
        self.num_outputs = num_outputs

        # Input linear layer (shared by the source and target sequences)
        self.input_fc = nn.Linear(input_size, d_model)

        # Positional Encoding
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_length)
        self.pos_decoder = PositionalEncoding(d_model, max_len=max_seq_length)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Output linear layer
        self.output_fc = nn.Linear(d_model, num_outputs)

        # NOTE: layer_norm and dropout are not used in forward(); they are kept
        # so that state_dict keys stay compatible with existing checkpoints.
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def generate_square_subsequent_mask(self, sz, device=None):
        """Build an additive causal mask of shape ``(sz, sz)``.

        Entries on and below the diagonal are ``0.0``; entries above the
        diagonal are ``-inf`` so that a position cannot attend to later ones.

        Args:
            sz: Sequence length.
            device: Optional device on which to create the mask (default:
                PyTorch's current default device).
        """
        blocked = torch.full((sz, sz), float('-inf'), device=device)
        # triu(diagonal=1) keeps the strict upper triangle and zeroes the rest.
        return torch.triu(blocked, diagonal=1)

    def forward(self, src, tgt):
        """
        Args:
            src: (batch_size, n_past, input_size)
            tgt: (batch_size, n_future, input_size) decoder input sequence
        Returns:
            out: (batch_size, n_future, num_outputs)
        """
        scale = np.sqrt(self.d_model)

        # Input embedding
        src = self.input_fc(src) * scale  # (batch_size, n_past, d_model)
        tgt = self.input_fc(tgt) * scale  # (batch_size, n_future, d_model)

        # Add positional encoding
        src = self.pos_encoder(src)
        tgt = self.pos_decoder(tgt)

        # Causal mask, created directly on the target's device
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1), device=tgt.device)

        # Transformer forward pass
        memory = self.transformer_encoder(src)  # (batch_size, n_past, d_model)
        output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)  # (batch_size, n_future, d_model)

        # Final linear layer
        return self.output_fc(output)  # (batch_size, n_future, num_outputs)

# Custom Balanced Loss Function
class BalancedCryptoLoss(nn.Module):
    """Weighted combination of several error terms minus an accuracy reward.

    ``forward`` returns a 9-tuple:
    ``(final_loss, mse, mae, perc_diff, max_diff, imbalance, direction,
    reward, perc_cutoff)``.

    Args:
        config: Object exposing ``EPSILON`` (added to the denominator of
            relative errors) and ``DEBUG_FREQ``.
    """

    def __init__(self, config):
        super().__init__()
        # Cutoff is a buffer so it follows the module across devices and is
        # stored in the state_dict.
        self.register_buffer('perc_cutoff_buffer', torch.tensor(0.015))
        self.config = config
        self.mse_weight = 900.0
        self.mae_weight = 25.0
        self.max_diff_weight = 3.0
        self.balance_weight = 3.0
        self.direction_weight = 0.001
        self.mean_diff_weight = 15.0
        self.perc_diff_weight = 15.0
        self.within_1pct_reward_weight = 5.0
        self.reward_scaling = 5.0
        self.epsilon = config.EPSILON
        self.debug_freq = config.DEBUG_FREQ
        # Counts forward() calls (incremented once per call, not per epoch).
        self.epoch = 0
        self.mean_mean_diff = 0.0
        self.reward = 0.0

    def _relative_diff(self, y_pred, y_true):
        """Signed relative error ``(y_pred - y_true) / (epsilon + y_true)``."""
        return (y_pred - y_true) / (self.epsilon + y_true)

    def directional_loss(self, preds, target):
        """Binary cross-entropy between the step-to-step directions of preds and target.

        NOTE(review): ``sign()`` has zero gradient, so this term does not
        contribute to back-propagation; it acts as a monitored quantity only.
        """
        direction_pred = (preds[:, 1:] - preds[:, :-1]).sign()
        direction_true = (target[:, 1:] - target[:, :-1]).sign()

        # Map signs {-1, 0, 1} to {0, 0.5, 1}
        direction_pred = (direction_pred + 1) / 2
        direction_true = (direction_true + 1) / 2

        # Clamp values to prevent BCE from receiving exact 0 or 1
        direction_pred = torch.clamp(direction_pred, 1e-7, 1 - 1e-7)
        direction_true = torch.clamp(direction_true, 1e-7, 1 - 1e-7)

        return F.binary_cross_entropy(direction_pred, direction_true) * self.direction_weight

    def mse_loss_component(self, y_pred, y_true):
        """Weighted mean squared error."""
        return F.mse_loss(y_pred, y_true) * self.mse_weight

    def mae_loss_component(self, y_pred, y_true):
        """Weighted mean absolute error."""
        return F.l1_loss(y_pred, y_true) * self.mae_weight

    def percentage_diff_component(self, y_pred, y_true):
        """Squared, weighted mean absolute relative error.

        Also caches the un-weighted mean in ``self.mean_mean_diff``.
        """
        mean_perc_diff = torch.mean(torch.abs(self._relative_diff(y_pred, y_true)))
        self.mean_mean_diff = mean_perc_diff.item()
        return (mean_perc_diff * self.perc_diff_weight) ** 2

    def max_diff_component(self, perc_diff):
        """Weighted mean of the maximum of ``perc_diff`` along dim 1."""
        max_diffs, _ = torch.max(perc_diff, dim=1)
        return torch.mean(max_diffs) * self.max_diff_weight

    def imbalance_component(self, perc_diff):
        """Penalise a systematic bias towards over- or under-prediction.

        Args:
            perc_diff: *Signed* relative error; positive values are
                over-predictions, negative values under-predictions.
        """
        overpredict = torch.relu(perc_diff)
        underpredict = torch.relu(-perc_diff)
        imbalance = torch.abs(torch.mean(overpredict, dim=1) - torch.mean(underpredict, dim=1))
        return torch.mean(imbalance) * self.balance_weight

    def reward_component(self, y_pred, y_true):
        """Weighted fraction of predictions whose relative error is within the cutoff.

        The comparison yields a constant tensor, so the reward carries no
        gradient.
        """
        percentage_diff = torch.abs(self._relative_diff(y_pred, y_true))
        within_1pct = (percentage_diff <= self.perc_cutoff_buffer).float()
        return torch.mean(within_1pct) * self.within_1pct_reward_weight

    def compute_all_losses(self, y_pred, y_true):
        """Compute every component; see the class docstring for the order."""
        signed_diff = self._relative_diff(y_pred, y_true)
        abs_diff = torch.abs(signed_diff)

        mse_loss = self.mse_loss_component(y_pred, y_true)
        mae_loss = self.mae_loss_component(y_pred, y_true)
        perc_diff_loss = self.percentage_diff_component(y_pred, y_true)
        max_diff_loss = self.max_diff_component(abs_diff)
        # The imbalance term needs the signed error: with the absolute error
        # the under-prediction branch is always zero and the term degenerates
        # into a second mean-absolute-error penalty.
        imbalance_loss = self.imbalance_component(signed_diff)
        direction_loss = self.directional_loss(y_pred, y_true)
        reward = self.reward_component(y_pred, y_true)
        return mse_loss, mae_loss, perc_diff_loss, max_diff_loss, imbalance_loss, direction_loss, reward

    def forward(self, y_pred, y_true):
        """Return the combined loss followed by its components and the cutoff."""
        self.epoch += 1

        # Compute Loss Components
        mse_loss, mae_loss, perc_diff_loss, max_diff_loss, imbalance_loss, direction_loss, reward = self.compute_all_losses(y_pred, y_true)
        self.reward = reward

        # Combine Loss Components
        final_loss = (mse_loss + mae_loss + perc_diff_loss + max_diff_loss +
                      imbalance_loss + direction_loss - (reward * self.reward_scaling))

        # Clamp Final Loss to prevent negative values
        final_loss = torch.clamp(final_loss, min=0.00001)
        perc_cutoff = self.get_perc_cutoff()
        return final_loss, mse_loss, mae_loss, perc_diff_loss, max_diff_loss, imbalance_loss, direction_loss, reward, perc_cutoff

    def get_reward(self):
        """Return the most recent reward as a float (0.0 before the first forward)."""
        # self.reward is a plain float until forward() has run once.
        return float(self.reward)

    def set_perc_cutoff(self, perc_cutoff):
        """Overwrite the cutoff buffer in place."""
        self.perc_cutoff_buffer.fill_(perc_cutoff)

    def get_last_mean_diff(self):
        """Return the mean absolute relative error cached by the last call."""
        return self.mean_mean_diff

    def get_perc_cutoff(self):
        """Return the current cutoff as a float."""
        return self.perc_cutoff_buffer.item()

# Custom Dataset
class CryptoDataset(Dataset):
    """Sliding-window dataset over the rows of a DataFrame.

    Item ``i`` is the pair ``(x, y)`` where ``x`` holds rows
    ``[i, i + n_past)`` and ``y`` holds the following ``n_future`` rows, both
    as float tensors with one column per DataFrame column.

    Args:
        data: Source DataFrame (all columns are used).
        n_past: Number of rows in each input window.
        n_future: Number of rows in each target window.
    """

    def __init__(self, data: pd.DataFrame, n_past: int, n_future: int):
        self.data = data
        self.n_past = n_past
        self.n_future = n_future

    def __len__(self):
        # Clamp at zero: len() must not be negative when the frame is shorter
        # than a single window.
        return max(0, len(self.data) - self.n_past - self.n_future + 1)

    def __getitem__(self, idx):
        size = len(self)
        if idx < 0:
            idx += size
        # An explicit IndexError lets plain iteration over the dataset stop;
        # out-of-range iloc slices would silently return short or empty windows.
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")

        split = idx + self.n_past
        x = self.data.iloc[idx:split].values  # (n_past, num_features)
        y = self.data.iloc[split:split + self.n_future].values  # (n_future, num_features)
        return torch.FloatTensor(x), torch.FloatTensor(y)

# Utility Functions
def get_random_sample(dataframe: pd.DataFrame):
    """
    Draw one random (history, future) window pair from the DataFrame.

    The window lengths come from the module-level configuration
    (``cfg.N_PAST`` rows of history followed by ``cfg.N_FUTURE`` rows).

    Args:
        dataframe (pd.DataFrame): DataFrame to sample from.

    Returns:
        tuple: (input_data, target_data) as float tensors.
    """
    past_len, future_len = cfg.N_PAST, cfg.N_FUTURE

    # randint is inclusive on both ends, so the last valid start is allowed
    start = random.randint(0, len(dataframe) - past_len - future_len)
    split = start + past_len

    history = dataframe.iloc[start:split].values
    future = dataframe.iloc[split:split + future_len].values
    return torch.FloatTensor(history), torch.FloatTensor(future)

def prepare_input(input_data, device):
    """
    Add a leading batch dimension of size one and move the tensor to ``device``.

    Args:
        input_data (torch.FloatTensor): Input data.
        device: Target device (anything accepted by ``Tensor.to``).

    Returns:
        torch.FloatTensor: Tensor with one extra leading dimension, on ``device``.
    """
    batched = torch.unsqueeze(input_data, 0)
    return batched.to(device)

def convert_to_numpy(input_data, target, prediction):
    """
    Convert tensors to NumPy arrays.

    Each tensor is detached from the autograd graph and moved to the CPU
    first, so tensors that require grad (e.g. model outputs produced outside
    ``torch.no_grad()``) are accepted as well.

    Args:
        input_data (torch.FloatTensor): Input data.
        target (torch.FloatTensor): Target data.
        prediction (torch.FloatTensor): Prediction data.

    Returns:
        tuple: (input_np, target_np, prediction_np)
    """
    return tuple(
        tensor.detach().cpu().numpy()
        for tensor in (input_data, target, prediction)
    )
def gaussian_smoothing(data, window_size, sigma):
    """
    Smooth 1-D data by convolving it with a normalised Gaussian kernel.

    Args:
        data (np.ndarray): 1-D input data.
        window_size (int): Number of kernel taps (must be >= 1).
        sigma (float): Standard deviation of the Gaussian kernel (must be > 0).

    Returns:
        np.ndarray: Smoothed data in 'valid' mode, i.e. only positions where
        the kernel and the data overlap completely.

    Raises:
        ValueError: If ``window_size`` < 1 or ``sigma`` <= 0.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")

    # Symmetric sample points. Note that `-window_size // 2` would parse as
    # `(-window_size) // 2` and shift the kernel off-centre for odd sizes.
    half_width = window_size // 2
    x = np.linspace(-half_width, half_width, window_size)
    kernel = np.exp(-(x ** 2) / (2 * sigma ** 2))
    kernel /= kernel.sum()

    # Convolve data with Gaussian kernel
    return np.convolve(data, kernel, 'valid')

# def gaussian_smoothing(data, window_size, sigma):
#     """
#     Compute the Gaussian smoothing of the data.

#     Args:
#         data (np.ndarray): Input data.
#         window_size (int): Window size for Gaussian smoothing.
#         sigma (float): Standard deviation of the Gaussian kernel.

#     Returns:
#         np.ndarray: Gaussian smoothed data.
#     """
#     return gaussian_filter1d(data, sigma=sigma)

# Inverse Transformation Function
def inverse_transform_predictions(scaled_value, scaler, log_transform=True):
    """
    Map scaled value(s) back to the original scale.

    The input is reshaped into a single column, passed through
    ``scaler.inverse_transform`` and flattened; when ``log_transform`` is
    true the result is additionally exponentiated.

    Args:
        scaled_value (np.ndarray or float): Scaled value(s).
        scaler (MinMaxScaler): Fitted scaler used during preprocessing.
        log_transform (bool): Indicates whether a log transform was applied.

    Returns:
        np.ndarray: 1-D array of value(s) in the original scale.
    """
    # inverse_transform expects a 2-D (n_samples, 1) array
    column = np.array(scaled_value).reshape(-1, 1)
    unscaled = scaler.inverse_transform(column).flatten()

    return np.exp(unscaled) if log_transform else unscaled

# Visualization Function
def visualize_predictions(target_np, prediction_np, n_future, scalers, filtered_df, model_save_path):
    """
    Plot past, target and predicted series for every column and save the figure.

    One subplot is drawn per column of ``filtered_df`` (at most 4 subplots per
    row). All series are mapped back to the original scale with
    ``inverse_transform_predictions`` before plotting.

    Args:
        target_np: Not used by this function.
        prediction_np (np.ndarray): Predictions, indexed as ``[0, :, j]`` for
            column ``j`` (presumably (batch, n_future, num_features) — confirm).
        n_future (int): Number of predicted time steps.
        scalers (dict): Maps column name to the fitted scaler for that column;
            must contain ``'XBTUSD_price'``.
        filtered_df (pd.DataFrame): Scaled data holding both the past window
            and the known future values; must contain an ``'XBTUSD_price'`` column.
        model_save_path (str): Directory in which the PNG is written.

    Returns:
        str: Path of the saved image.
    """
    num_features = filtered_df.shape[1]
    # Subplot grid: up to max_cols columns, as many rows as needed.
    max_cols = 4
    num_rows = (num_features - 1) // max_cols + 1
    num_cols = min(num_features, max_cols)

    plt.figure(figsize=(18 * num_cols / max_cols, 6 * num_rows))

    window_size = 7  # Adjust this value for smoothing

    for j in range(num_features):
        plt.subplot(num_rows, num_cols, j + 1)
        col_name = filtered_df.columns[j]

        # Extract past data from filtered_df
        past_scaled = filtered_df[col_name].values  # Full column (past + future rows)
        # Drop the last n_future - 1 rows.
        # NOTE(review): for n_future == 1 this slice is [:-0], i.e. empty — confirm.
        past_scaled = past_scaled[:-(n_future-1)]  # Only consider past data
        past_inverted = inverse_transform_predictions(past_scaled, scalers[col_name])

        # Directly extract the known future target data from filtered_df
        target_scaled = filtered_df[col_name].values  # Full column (past + future rows)
        # Keep the last n_future + 1 rows (one more than the forecast horizon).
        target_scaled = target_scaled[-(n_future+1):]  # Only consider future data
        target_inverted = inverse_transform_predictions(target_scaled, scalers[col_name])

        # Extract the predicted future data
        prediction_scaled = prediction_np[0, :, j]  # Shape: (n_future,)
        prediction_inverted = inverse_transform_predictions(prediction_scaled, scalers[col_name])

        # Last XBTUSD price in original scale (loop-invariant, recomputed per column).
        last_xbtusd_price_scaled = filtered_df['XBTUSD_price'].iloc[-1]
        last_xbtusd_price = inverse_transform_predictions(last_xbtusd_price_scaled, scalers['XBTUSD_price'])

        # Adjust if column ends with 'XBT_price': multiply by the last XBTUSD price
        # (presumably to convert XBT-quoted prices to USD — confirm).
        if col_name.endswith('XBT_price'):
            past_inverted *= last_xbtusd_price
            target_inverted *= last_xbtusd_price
            prediction_inverted *= last_xbtusd_price

        # Combine past and future data
        # total_inverted = np.concatenate((past_inverted, target_inverted))
        total_predicted = np.concatenate((past_inverted, prediction_inverted))

        # Create time indices
        # Note: n_past here is the length of the sliced past series, not a config value.
        n_past = len(past_inverted)
        total_timesteps = n_past + n_future
        time_indices = range(total_timesteps)

        # Plot past data
        # print the length of the x and y axis
        # print(len(time_indices[-(len(past_inverted)):]), len(past_inverted[n_past-n_future:]))

        # Both slices have n_future + 1 elements.
        # NOTE(review): x covers indices n_past - n_future .. n_past while y is the
        # last n_future + 1 past values — looks shifted by one step; confirm.
        plt.plot(time_indices[n_past-n_future:(n_past+1)], past_inverted[-(n_future+1):], 'b', label='Past Data' if j == 0 else "")
        # plt.plot(time_indices[-(len(past_inverted))+n_future:], past_inverted[:], 'b', label='Past Data' if j == 0 else "")

        # Plot known target data
        plt.plot(time_indices[n_past-1:], target_inverted, 'g', alpha=0.7, label='Target Data' if j == 0 else "")

        # Plot prediction data
        plt.plot(time_indices[n_past:], prediction_inverted, 'r', alpha=0.7, label='Prediction Data' if j == 0 else "")

        # Optionally apply smoothing
        # ('valid' convolution shortens each series by window_size - 1 samples)
        total_inverted_smooth = gaussian_smoothing(past_inverted, window_size, sigma=10)
        total_predicted_smooth = gaussian_smoothing(total_predicted, window_size, sigma=10)

        # Plot smoothed data
        # print the length of the x and y axis
        # plt.plot(time_indices[n_past:], total_inverted_smooth[-n_future:], 'g', linewidth=2, label='Target Smoothed' if j == 0 else "")

        # plt.plot(time_indices[n_past:], total_predicted_smooth[-n_future:], 'r', linewidth=2, label='Prediction Smoothed' if j == 0 else "")

        # plt.fill_between(range(-n_future), total_predicted_smooth[:-n_future], total_inverted_smooth[:-n_future], color='blue', alpha=0.1)
        # Ensure that both arrays have the same length for the fill_between operation
        # NOTE: min_length, start_index and end_index are computed but not used below.
        min_length = min(len(total_predicted_smooth), len(total_inverted_smooth))

        # Adjust the indices to ensure matching lengths
        start_index = n_past - min_length
        end_index = n_past

        # Plot smoothed data
        # plt.plot(time_indices[n_past:], total_inverted_smooth[-n_future:], 'g', linewidth=2, label='Target Smoothed' if j == 0 else "")
        # For window_size == 7 both slices have n_future - 7 elements.
        plt.plot(time_indices[n_past+window_size:], total_predicted_smooth[n_past+1:], 'r', linewidth=2, label='Prediction Smoothed' if j == 0 else "")

        # Fill between the smoothed prediction and the (unsmoothed) target over the last n_future steps
        plt.fill_between(time_indices[-n_future:], total_predicted_smooth[-n_future:], target_inverted[-n_future:], color='blue', alpha=0.1)
        # Adjust plot settings
        plt.title(col_name)
        if j == 0:
            plt.legend(loc='upper right')

    plt.tight_layout()
    # Timestamped file name so that successive calls do not overwrite each other
    # (at one-second resolution).
    time_date = time.strftime("%Y%m%d-%H%M%S")
    image_path = os.path.join(model_save_path, f"{time_date}_predictions.png")
    plt.savefig(image_path)
    plt.close()

    return image_path

# Checkpoint Saving Function
def save_checkpoint(state, filename):
    """
    Save a training checkpoint.

    When ``filename`` is a filesystem path, its parent directory is created
    first so that saving into a not-yet-existing folder does not fail.

    Args:
        state (dict): State dictionary containing model, optimizer, scheduler states, etc.
        filename (str): Path to save the checkpoint (file-like objects accepted
            by ``torch.save`` are passed through unchanged).
    """
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filename))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(state, filename)

# Checkpoint Loading Function
def load_checkpoint(model, optimizer, scheduler, model_path, device):
    """
    Load a checkpoint file into a model and, optionally, optimizer and scheduler.

    The model weights are looked up under ``'model_state_dict'``, then
    ``'state_dict'``; if neither key exists the checkpoint itself is treated
    as the state dict. A leading ``'model.'`` is stripped from every key and
    the weights are loaded non-strictly (mismatched keys are only printed).

    Args:
        model: Module that receives the weights.
        optimizer: Optimizer to restore, or None to skip.
        scheduler: Scheduler to restore, or None to skip.
        model_path (str): Path of the checkpoint file.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The raw checkpoint object returned by ``torch.load``.
    """
    from collections import OrderedDict

    print(f"Loading checkpoint from {model_path}")
    # NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    print("Checkpoint keys:", checkpoint.keys())

    # Locate the model weights inside the checkpoint
    for candidate in ('model_state_dict', 'state_dict'):
        if candidate in checkpoint:
            state_dict = checkpoint[candidate]
            break
    else:
        # The checkpoint is the model's state_dict itself
        state_dict = checkpoint

    # Drop the 'model.' prefix from the keys that carry it
    prefix = 'model.'
    cleaned_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, tensor)
        for key, tensor in state_dict.items()
    )

    # Non-strict load: report, rather than fail on, mismatched keys
    missing_keys, unexpected_keys = model.load_state_dict(cleaned_state_dict, strict=False)
    if missing_keys:
        print(f"Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"Unexpected keys: {unexpected_keys}")

    # Restore optimizer and scheduler state when both object and state are available
    restorable = (
        (optimizer, 'optimizer_state_dict'),
        (scheduler, 'scheduler_state_dict'),
    )
    for target, state_key in restorable:
        if target is not None and state_key in checkpoint:
            target.load_state_dict(checkpoint[state_key])

    print("Checkpoint loaded successfully.")
    return checkpoint
# Data Loading and Preprocessing
def _download_file(url, dest_path, chunk_size=8192, timeout=60):
    """Stream ``url`` to ``dest_path`` via a temporary file.

    The data is written to ``dest_path + '.part'`` and renamed only after the
    download completed, so an interrupted transfer never leaves a truncated
    file at ``dest_path``.
    """
    tmp_path = f"{dest_path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, dest_path)
    finally:
        # Only present if the download failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def load_and_preprocess_data(file_path: str, download_url: str = None):
    """
    Load and preprocess data from a CSV file. If the file does not exist, download it.

    Preprocessing keeps the last ``cfg.RECORDS_TO_LOAD`` rows, applies a
    natural-log transform to every column and then min-max scales each column
    with its own ``MinMaxScaler``.

    Args:
        file_path (str): Path to the CSV file (must contain a 'timestamp' column).
        download_url (str, optional): URL to download the CSV file. Defaults to None.

    Returns:
        pd.DataFrame: Preprocessed DataFrame indexed by timestamp.
        dict: Dictionary of scalers used for each column.

    Raises:
        FileNotFoundError: If the file is missing and no download URL is given.
        requests.exceptions.RequestException: If the download fails.
        ValueError: If a column contains non-positive values.
    """
    ic("Starting data loading and preprocessing...")

    # Check if the file exists
    if not os.path.exists(file_path):
        ic(f"File {file_path} does not exist.")
        # dirname is '' for a bare file name; makedirs('') would raise.
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if not download_url:
            ic("Download URL not provided. Cannot download the file.")
            raise FileNotFoundError(f"The file {file_path} does not exist and no download URL was provided.")

        ic(f"Downloading file from {download_url}...")
        try:
            _download_file(download_url, file_path)
        except requests.exceptions.RequestException as e:
            ic(f"Failed to download the file: {e}")
            raise
        ic(f"File downloaded and saved to {file_path}")

    # Load the DataFrame
    df = pd.read_csv(file_path, parse_dates=['timestamp'])
    df.set_index('timestamp', inplace=True)
    # Explicit copy: the columns are overwritten below and must not alias the
    # full frame returned by read_csv.
    df = df.tail(cfg.RECORDS_TO_LOAD).copy()
    scalers = {}
    start_time_preprocess = time.time()

    for col in df.columns:
        # Ensure no non-positive values before log transform
        if (df[col] <= 0).any():
            raise ValueError(f"Column {col} contains non-positive values, cannot apply log transform.")

        # Apply natural logarithm transformation
        df[col] = np.log(df[col])

        # Fit one MinMaxScaler per column and keep it for the inverse transform
        scaler = MinMaxScaler()
        df[col] = scaler.fit_transform(df[[col]])
        scalers[col] = scaler

    ic(f"Data preprocessing completed in {time.time() - start_time_preprocess:.2f} seconds")
    ic(f"DataFrame shape: {df.shape}")

    return df, scalers

# Lightning Wrapper
class LightningWrapper(pl.LightningModule):
    """Lightning module that trains a sequence-to-sequence model with teacher forcing.

    The wrapper owns the model, the loss (``criterion``), and an externally
    constructed optimizer/scheduler pair. It logs every loss component per
    epoch, logs gradient norms after each backward pass, and uploads a
    prediction-vs-target plot to the experiment logger after each validation
    epoch.

    Args:
        model: Network called as ``model(src, tgt)``.
        criterion: Callable invoked as ``criterion(y_pred, y_true)`` returning
            nine values: final loss, MSE, MAE, percentage-difference,
            max-difference, imbalance and direction losses, reward, and the
            percentile cutoff. It must also expose ``set_perc_cutoff`` and
            ``get_perc_cutoff``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        scheduler: Learning-rate scheduler bound to ``optimizer``.
        num_epochs: Number of training epochs (stored, not used internally).
        scaler_dict: Per-column scalers forwarded to the plotting helper.
        val_data: Validation DataFrame used to draw plotting samples.
    """

    # Threshold used for the "clipped" indicator when the trainer does not
    # report a ``gradient_clip_val`` of its own.
    DEFAULT_GRAD_CLIP = 50.0

    # Names of the criterion outputs between the final loss and the percentile
    # cutoff, in the order the criterion returns them.
    _COMPONENT_NAMES = (
        'mse_loss',
        'mae_loss',
        'perc_diff_loss',
        'max_diff_loss',
        'imbalance_loss',
        'direction_loss',
        'reward',
    )

    def __init__(self, model, criterion, optimizer, scheduler, num_epochs: int, scaler_dict: dict, val_data: pd.DataFrame):
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.num_epochs = num_epochs
        self.scaler_dict = scaler_dict
        self.val_data = val_data  # For making predictions during logging
        self.validation_rewards = []  # One reward value appended per validation batch

    def forward(self, src, tgt):
        """Delegate to the wrapped model."""
        return self.model(src, tgt)

    def _run_criterion(self, batch):
        """Run a teacher-forced forward pass and return the criterion's outputs.

        The decoder input is the target shifted one step to the right along
        dimension 1, with zeros in the first position.
        """
        batch_X, batch_y = batch
        tgt_input = torch.zeros_like(batch_y)
        tgt_input[:, 1:, :] = batch_y[:, :-1, :]
        y_pred = self.model(batch_X, tgt_input)
        return self.criterion(y_pred, batch_y)

    def _log_components(self, stage, components):
        """Log the per-component losses as epoch-level ``<stage>/<name>`` metrics."""
        for name, value in zip(self._COMPONENT_NAMES, components):
            self.log(f'{stage}/{name}', value, on_step=False, on_epoch=True, sync_dist=True)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log all loss components."""
        final_loss, *components, _perc_cutoff = self._run_criterion(batch)

        self.log('train/final_loss', final_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self._log_components('train', components)

        return final_loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch, log it, and record the reward."""
        final_loss, *components, percentile_cutoff = self._run_criterion(batch)

        # 'val_loss' is the key monitored in configure_optimizers.
        self.log('val_loss', final_loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self._log_components('val', components)
        self.log('val/perc_cutoff', percentile_cutoff, on_step=False, on_epoch=True, sync_dist=True)

        reward = components[-1]
        self.validation_rewards.append(reward.item())

        return final_loss

    def configure_optimizers(self):
        """Return the externally supplied optimizer and scheduler."""
        return {
            'optimizer': self.optimizer,
            'lr_scheduler': {
                'scheduler': self.scheduler,
                'monitor': 'val_loss'
            }
        }

    def on_validation_epoch_end(self):
        """Upload a prediction-vs-target plot from rank 0, except during the sanity check."""
        trainer = self.trainer
        # Current Lightning exposes `sanity_checking`; older releases used
        # `running_sanity_check`. Check both so the skip works on either.
        if getattr(trainer, 'sanity_checking', False) or getattr(trainer, 'running_sanity_check', False):
            self.print("Skipping ALL on_validation_epoch_end logic during sanity check.")
            return

        if self.global_rank != 0:
            return

        # Plotting is best-effort: a failure here must not abort training.
        try:
            plot_path = self.generate_and_log_plots()
            if plot_path:
                # Context manager closes the image file once it has been logged.
                with Image.open(plot_path) as img:
                    self.logger.experiment.log({
                        "Validation/Prediction_vs_Target": wandb.Image(img),
                        "global_step": self.global_step
                    })
        except Exception as e:
            self.print(f"Error in generate_and_log_plots: {e}")

    def on_train_epoch_end(self):
        """Log the first parameter group's learning rate from the main process."""
        if self.global_rank == 0:
            optimizer = self.optimizers()
            lr = optimizer.param_groups[0]['lr']
            self.logger.experiment.log({'learning_rate': lr, 'epoch': self.current_epoch})

    def on_after_backward(self):
        """Log the total gradient L2 norm and whether it exceeds the clip threshold."""
        # Only the main process should perform logging
        if self.global_rank != 0:
            return

        # Collect per-parameter norms as tensors and reduce once, so there is a
        # single host sync instead of one `.item()` call per parameter.
        grad_norms = [
            p.grad.detach().norm(2)
            for p in self.model.parameters()
            if p.grad is not None
        ]
        total_norm = torch.stack(grad_norms).norm(2).item() if grad_norms else 0.0

        # Prefer the trainer's configured threshold; fall back to the default.
        clip_value = getattr(self.trainer, 'gradient_clip_val', None) or self.DEFAULT_GRAD_CLIP

        # One log call keeps both metrics on the same logger step. The flag is
        # a float so it can be plotted.
        self.logger.experiment.log({
            'Gradients/grad_total_norm': total_norm,
            'Gradients/clipped': float(total_norm > clip_value),
            'step': self.global_step
        })

    def generate_and_log_plots(self):
        """
        Generate a prediction vs target plot and save it to disk.

        Returns:
            The value returned by ``visualize_predictions`` (used by the caller
            as the image path), or ``None`` when the validation data is shorter
            than one past+future window.
        """
        max_start = len(self.val_data) - cfg.N_PAST - cfg.N_FUTURE
        if max_start < 0:
            print("Not enough data for plotting.")
            return None

        # Make predictions on a random sample from validation data. The decoder
        # input is all zeros here (no teacher forcing).
        input_data, target = get_random_sample(self.val_data)
        input_tensor = prepare_input(input_data, self.device)
        tgt_input = torch.zeros_like(target).unsqueeze(0).to(self.device)
        with torch.no_grad():
            prediction = self.model(input_tensor, tgt_input)
        _, target_np, prediction_np = convert_to_numpy(input_tensor, target, prediction)

        # NOTE(review): this window is drawn independently of the sample used
        # for the prediction above, so the plotted context may not be the data
        # the model actually saw - confirm against get_random_sample.
        start_idx = random.randint(0, max_start)
        past_df = self.val_data.iloc[start_idx:start_idx + cfg.N_PAST]
        future_df = self.val_data.iloc[start_idx + cfg.N_PAST: start_idx + cfg.N_PAST + cfg.N_FUTURE]

        # Combine past and future rows into the frame handed to the plotter.
        filtered_df = pd.concat([past_df, future_df])

        return visualize_predictions(target_np, prediction_np, cfg.N_FUTURE, self.scaler_dict, filtered_df, cfg.MODEL_SAVE_PATH)

    def set_perc_cutoff(self, perc_cutoff):
        """Forward a new percentile cutoff to the criterion."""
        self.criterion.set_perc_cutoff(perc_cutoff)

    def get_perc_cutoff(self):
        """Return the criterion's current percentile cutoff."""
        return self.criterion.get_perc_cutoff()

# Main Execution Block
if __name__ == "__main__":
    torch.set_float32_matmul_precision("medium")

    # Optimizer hyperparameters, defined once so the values logged to Wandb
    # are the ones the optimizer actually uses.
    LEARNING_RATE = 4e-5
    WEIGHT_DECAY = 1e-4

    # Load and preprocess data
    df, scalers = load_and_preprocess_data(cfg.DATA_FILE, cfg.DATA_URL)
    NUM_FEATURES = df.shape[1]

    # Initialize the Wandb logger; log_model=True also uploads checkpoints
    logger = WandbLogger(project='my-awesome-project', log_model=True)

    # Log hyperparameters to Wandb
    logger.log_hyperparams({
        "batch_size": cfg.BATCH_SIZE,
        "hidden_size": cfg.HIDDEN_SIZE,
        "num_layers": cfg.NUM_LAYERS,
        "num_epochs": cfg.NUM_EPOCHS,
        "learning_rate": LEARNING_RATE,
        "weight_decay": WEIGHT_DECAY
    })

    # Chronological 80/20 split into training and validation
    train_size = int(0.8 * len(df))
    train_data = df.iloc[:train_size]
    val_data = df.iloc[train_size:]

    # Create Datasets
    train_dataset = CryptoDataset(train_data, cfg.N_PAST, cfg.N_FUTURE)
    val_dataset = CryptoDataset(val_data, cfg.N_PAST, cfg.N_FUTURE)

    # Create DataLoaders (both iterate in dataset order)
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.BATCH_SIZE,
        shuffle=False,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.BATCH_SIZE,
        shuffle=False,
        num_workers=cfg.NUM_WORKERS,
        pin_memory=True
    )

    # Initialize Transformer Model.
    # The model is deliberately NOT moved to a device here: Lightning places it
    # when training starts, and touching CUDA beforehand makes fork-based
    # strategies fail with "Lightning can't create new processes if CUDA is
    # already initialized".
    model = CryptoTransformer(
        input_size=NUM_FEATURES,
        d_model=cfg.HIDDEN_SIZE,
        nhead=8,
        num_encoder_layers=cfg.NUM_LAYERS,
        num_decoder_layers=cfg.NUM_LAYERS,
        dim_feedforward=2048,
        dropout=0.2,
        activation="gelu",
        n_future=cfg.N_FUTURE,
        num_outputs=NUM_FEATURES,
        max_seq_length=cfg.N_PAST + cfg.N_FUTURE
    )

    # Initialize Loss Function
    criterion = BalancedCryptoLoss(cfg)

    # Initialize Optimizer and Scheduler
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Cosine annealing with warm restarts
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=12,           # First restart after 12 epochs
        T_mult=2,         # Double the cycle length after each restart
        eta_min=1e-6,     # Minimum learning rate
        last_epoch=-1,    # Start from the beginning
    )

    # Constructor arguments shared by every LightningWrapper built below
    wrapper_kwargs = dict(
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=cfg.NUM_EPOCHS,
        scaler_dict=scalers,
        val_data=val_data,
    )

    # Handle Hot Restart: load checkpoint weights into `model`.
    if cfg.HOT_RESTART:
        try:
            # map_location="cpu" keeps checkpoint loading from initializing
            # CUDA before the Trainer starts.
            restored = LightningWrapper.load_from_checkpoint(
                checkpoint_path=cfg.MODEL_PATH,
                map_location="cpu",
                model=model,
                **wrapper_kwargs
            )
            # Continue with the network held by the restored wrapper
            model = restored.model
        except FileNotFoundError:
            print(f"No checkpoint found at {cfg.MODEL_PATH}. Starting fresh.")

    if cfg.TRAIN_FIRST:
        # Initialize Lightning Wrapper around the (possibly restored) model
        wrapped_model = LightningWrapper(model=model, **wrapper_kwargs)

        # Initialize Callbacks
        early_stopping_callback = EarlyStopping(
            monitor='val_loss',  # Must match the metric logged in validation_step
            patience=75,
            mode='min'
        )

        checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',  # Must match the metric logged in validation_step
            dirpath=cfg.MODEL_SAVE_PATH,
            filename=f'model-{cfg.VERSION_N}-{{epoch:02d}}-{{val_loss:.2f}}',
            save_top_k=9,
            mode='min',
            save_weights_only=False
        )

        # Initialize Progress Bar Callback
        progress_bar = RichProgressBar(refresh_rate=2)

        # Initialize Percentile Cutoff Callback
        perc_cutoff_callback = UpdatePercentileCutoffCallback(
            reduction_threshold=0.8,  # Reward threshold that triggers a reduction
            reduction_factor=0.95
        )

        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0

        # Initialize Trainer with Wandb logger and all callbacks
        trainer = pl.Trainer(
            max_epochs=cfg.NUM_EPOCHS,
            logger=logger,
            accelerator='gpu' if num_gpus else 'cpu',
            devices=num_gpus or 1,
            # DDP only when there are several GPUs. A single device needs no
            # extra processes, so let Lightning pick the single-device strategy.
            strategy='ddp_find_unused_parameters_true' if num_gpus > 1 else 'auto',
            callbacks=[progress_bar, checkpoint_callback, early_stopping_callback, perc_cutoff_callback],
            enable_progress_bar=True,
            log_every_n_steps=10,
            # precision=16,  # Optional: Use mixed precision for faster training
            gradient_clip_val=50.0,
        )

        # Start Training
        trainer.fit(wrapped_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    else:
        print("Skipping training as TRAIN_FIRST is set to False.")

    # Finalizing WandB
    wandb.finish()

    # Clean up CUDA cache (best-effort; a failure here is only reported)
    try:
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"Failed to empty CUDA cache: {e}")

    # Terminate the script.
    # NOTE(review): raising here looks deliberate, to halt any code that follows
    # (e.g. a notebook "Run all") - confirm before replacing with a clean exit.
    raise Exception("Training completed and script terminated.")

INFO:lightning_fabric.utilities.seed:Seed set to 40


Number of CPU cores available: 8
Using device: cuda


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msambosis[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


RuntimeError: Lightning can't create new processes if CUDA is already initialized. Did you manually call `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any other way? Please remove any such calls, or change the selected strategy. You will have to restart the Python kernel.