# 🪙Crypto 📈forecasting with Lightning⚡Flash

[Flash](https://lightning-flash.readthedocs.io/en/stable) makes complex AI recipes for over 15 tasks across 7 data domains accessible to all.

In a nutshell, Flash is the production grade research framework you always dreamed of but didn't have time to build.

In [None]:
# Install Lightning Flash with the tabular extras. The PyPI release is left commented out
# in favour of the archive of the GitHub master branch.
# ! pip install -q lightning-flash[tabular]
! pip install -q 'https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/master.zip#egg=lightning-flash[tabular]'
# Plotting package imported later in this notebook for the candlestick charts.
! pip install -q mplfinance
# Upgrade pandas and force a reinstall of it.
! pip install -q --upgrade pandas --force-reinstall
# Show the installed packages whose names match the libraries of interest.
! pip list | grep -E "lightning|torch|finance|crypto"

## Data exploration & preparation

Checking the input data and pairing with Crypto names

In [None]:
%reload_ext autoreload
%autoreload 2

import pandas as pd

# Explicit dtype per CSV column, passed to `pd.read_csv` below.
# Each column name appears exactly once (a repeated 'Count' entry was removed;
# it carried the same dtype, so the resulting mapping is unchanged).
TAB_COLUMN_TYPES = {
    'Asset_ID': 'int8',
    'Count': 'int32',
    'row_id': 'int32',
    'Open': 'float64',
    'High': 'float64',
    'Low': 'float64',
    'Close': 'float64',
    'Volume': 'float64',
    'VWAP': 'float64',
}
# Load the full training table with the dtypes above and preview the first rows.
df_train = pd.read_csv("/kaggle/input/g-research-crypto-forecasting/train.csv", low_memory=False, dtype=TAB_COLUMN_TYPES)
display(df_train.head())

### Linking with coins

Show how many data points we have per crypto asset

In [None]:
# Number of training rows per asset id, drawn as a bar chart with grid lines.
rows_by_asset = df_train.groupby("Asset_ID")
df_counts = rows_by_asset.size()
df_counts.plot.bar(grid=True)

In [None]:
# Read the asset details table and show it transposed.
df_asset_details = pd.read_csv(
    "/kaggle/input/g-research-crypto-forecasting/asset_details.csv",
)
display(df_asset_details.transpose())

In [None]:
# Lookup table from asset id to asset name, built from the asset details table,
# then used to add a readable "Asset_Name" column to the training data.
asset_ids = df_asset_details['Asset_ID']
asset_names = df_asset_details['Asset_Name']
mapping = dict(zip(asset_ids, asset_names))
df_train["Asset_Name"] = df_train["Asset_ID"].map(mapping)

### Financial time series

Showing a short tail of each coin's history with the `mplfinance` package

In [None]:
import mplfinance as mpf

# Draw one candlestick chart per asset from its last 300 rows.
for n, dfg in df_train.groupby("Asset_Name"):
    # Copy only the rows that get plotted: this avoids writing into the groupby
    # slice and avoids converting timestamps of rows that are never drawn.
    dfg = dfg.tail(300).copy()
    # `pd.to_datetime` treats bare integers as nanoseconds unless `unit` is given.
    # NOTE(review): assumes the raw `timestamp` column holds Unix epoch seconds --
    # confirm against the dataset description.
    dfg['timestamp'] = pd.to_datetime(dfg['timestamp'], unit='s')
    dfg.set_index("timestamp", inplace=True)
    mpf.plot(
        dfg, # the dataframe containing the OHLC (Open, High, Low and Close) data
        type='candle', # use candlesticks
        volume=True, # also show the volume
        mav=(3, 6, 9), # use three different moving averages
        figsize=(14, 2), # set the ratio of the figure
        style='yahoo',
        title=n,
    )

### Prune

Limit the amount of data due to limited computing resources

In [None]:
# Alternative selection of several assets, kept for reference:
# df_train_small = df_train[df_train["Asset_ID"].isin([0, 1, 2])]
# Keep only the rows of asset 0 and fill missing values with pandas' default interpolation.
is_asset_zero = df_train["Asset_ID"] == 0
df_train_small = df_train.loc[is_asset_zero].interpolate()

# Training with Flash Lightning

See the forecasting docs: https://lightning-flash.readthedocs.io/en/stable/reference/tabular_forecasting.html

In [None]:
import flash
import pandas as pd
import torch
from flash.tabular.forecasting import TabularForecaster, TabularForecastingData
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data

### 1. Create the DataModule

Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html

In [None]:
# Settings for the forecasting data module: the "Target" column is both the value
# to predict and the only time-varying real input, series are grouped by "Asset_ID",
# and a 60-step encoder window is used to predict a single step ahead.
# NOTE(review): the validation frame is the same object as the training frame.
forecasting_data_options = dict(
    time_idx="timestamp",
    target="Target",
    # categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
    group_ids=["Asset_ID"],
    time_varying_unknown_reals=["Target"],
    allow_missing_timesteps=True,
    max_encoder_length=60,
    max_prediction_length=1,
    train_data_frame=df_train_small,
    val_data_frame=df_train_small,
    batch_size=512,
    num_workers=4,
)
datamodule = TabularForecastingData.from_data_frame(**forecasting_data_options)

### 2. Build the task

In [None]:
# Build the forecasting task with the "deep_ar" backbone, configured from the
# parameters exposed by the data module.
# Unused backbone options kept for reference:
# backbone_kwargs={"widths": [32, 512], "backcast_loss_ratio": 0.1},
forecaster_parameters = datamodule.parameters
model = TabularForecaster(forecaster_parameters, backbone="deep_ar")

### 3. Create the trainer and train the model

In [None]:
import torch
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import StochasticWeightAveraging

# NOTE(review): `swa` is created here but is not passed to the trainer below --
# confirm whether it was meant to go into the trainer's callbacks.
swa = StochasticWeightAveraging(swa_epoch_start=0.6)
logger = CSVLogger(save_dir='logs/')

# Trainer: 10 epochs, CSV logging, every visible CUDA device, gradient clipping at 0.01.
trainer_options = {
    "max_epochs": 10,
    "logger": logger,
    "gpus": torch.cuda.device_count(),
    "gradient_clip_val": 0.01,
}
trainer = flash.Trainer(**trainer_options)

# ==============================

# Learning-rate search between 2e-5 and 1 over 65 training steps.
lr_search_options = {"min_lr": 2e-5, "max_lr": 1, "num_training": 65}
trainer.tune(model, datamodule=datamodule, lr_find_kwargs=lr_search_options)
print(f"Learning Rate: {model.learning_rate}")

# ==============================

# Fit the model on the data module.
trainer.fit(model, datamodule=datamodule)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's default theme to the plots below.
sns.set()

# Read the metrics written by the CSV logger, keep only the columns whose name
# does not contain "_step", drop the "step" counter and index the rows by epoch.
logged_metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
kept_columns = [name for name in logged_metrics.columns if "_step" not in name]
metrics = logged_metrics[kept_columns].drop(columns="step").set_index("epoch")
display(metrics.dropna(axis=1, how="all").head())

# Line plot of every remaining metric on a logarithmic y axis, in a wide figure.
g = sns.relplot(data=metrics, kind="line")
g.set(yscale="log")
plt.gcf().set_size_inches(15, 5)

### 4. Generate predictions

In [None]:
# Run the model on the last 1000 rows of the pruned training frame, show which
# keys the first output carries, then collect and print every 'prediction' entry.
recent_rows = df_train_small.iloc[-1000:]
outputs = model.predict(recent_rows)
print(outputs[0].keys())
predictions = []
for output in outputs:
    predictions.append(output['prediction'])
print(predictions)

In [None]:
import gresearch_crypto
from pprint import pprint
from tqdm.auto import tqdm

env = gresearch_crypto.make_env()   # initialize the environment
iter_test = env.iter_test()

# Walk the test iterator; only the first batch is processed (see the `break`).
for i, (df_test, df_pred) in tqdm(enumerate(iter_test)):
    display(df_test.head())
    # Copy the filtered rows so the new column is written to an independent frame
    # instead of a slice of `df_test`.
    df_test_small = df_test[df_test["Asset_ID"] == 0].copy()
    # A scalar broadcasts to any number of rows; the former one-element list
    # only fit a frame with exactly one matching row.
    df_test_small['Target'] = None
    # TODO
#     out = model.predict(df_test_small)
#     preds = zip(df_test_small[row_id], out['prediction'])
#     pprint(preds)
    env.predict(df_pred)   # register your predictions
    break