In [None]:
import nutils
from models.model_mlp_v2 import MtNet
import copy
import sutils
import common as cm
from models.model_mlp_v2 import MtNet
import numpy as np
import pandas as pd
import SharedArray as sa
import torch as th
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    GradientAccumulationScheduler,
    RichProgressBar,
)
from pytorch_lightning.loggers import wandb as wandb_logger
import pytorch_lightning as pl
from functools import lru_cache
from rich.progress import track as tqdm
from torch.utils import data
import os
from argparse import ArgumentParser
import wandb
from datetime import datetime
from mlutils.data import DataLoaderY

In [None]:
# Run configuration. `parse_args([])` deliberately ignores the real argv so the
# notebook always runs with the defaults declared here; edit a default (or the
# list passed to parse_args) to change a run.
parser = ArgumentParser()
parser.add_argument("--fold", "-f", type=int, default=3)
parser.add_argument("--gpu", "-g", type=int, default=0)
parser.add_argument("--batch_size", "-b", type=int, default=100, help="batch size multiplier")
parser.add_argument("--num_stocks", "-n", type=int, default=100)
parser.add_argument("--save", "-s", type=int, default=0)
parser.add_argument("--split", "-sp", type=int, default=0)
parser.add_argument("--cur", type=int, default=0)
parser.add_argument("--fut", type=int, default=60)
args = parser.parse_args([])

In [None]:
def get_data(code):
    """Attach the three SharedArray segments published for one stock code.

    Args:
        code: stock identifier used as the suffix of the shared-memory names.

    Returns:
        A tuple ``(factor, label, timestamp)`` of the arrays attached from
        ``factor_<code>``, ``label_<code>`` and ``timestamp_<code>``, in
        that order.
    """
    # Tuple elements are evaluated left to right, so the attach order is
    # factor -> label -> timestamp.
    return (
        sa.attach(f"factor_{code}"),
        sa.attach(f"label_{code}"),
        sa.attach(f"timestamp_{code}"),
    )

In [None]:
# Run identifier, also used for the checkpoint directory and W&B config.
timestamp = datetime.now().strftime("%y%m%d%H%M")
stk_list = cm.SELECTED_CODES
seed = 2022
# Seed every run, not only the sub-sampled ones, so that model weight
# initialisation further down is reproducible as well. Nothing consumes
# randomness between this call and the sampling below, so sub-sampled runs
# still pick exactly the same subset of codes as before.
pl.seed_everything(seed)
if args.num_stocks < 100:
    # Reproducible random subset of `num_stocks` codes, drawn without replacement.
    idx = np.random.choice(len(stk_list), size=args.num_stocks, replace=False)
    stk_list = [stk_list[i] for i in idx]
print(f"selected stk_list:\n{stk_list}")

## Model
# Label indices to train on; they map to names via ["ret", "mean", "var", "rv"]
# when building the experiment name.
labels = [0, 1, 3]
# One entry per task head. NOTE(review): presumably the layer widths of each
# task tower, ending in a single output unit -- confirm against MtNet.
tower_sizes = [[128, 128, 1], [32, 1], [32, 1]]
# Per-task loss weights, forwarded to MtNet as `loss_weights`.
loss_weights = [1, 0.4, 0.4]

# Every selected label must be paired with exactly one tower and one loss
# weight; checking all three lengths catches a label added without a matching
# head/weight (NOTE(review): assumes MtNet pairs towers with labels by position).
assert len(labels) == len(tower_sizes) == len(loss_weights), (
    f"labels ({len(labels)}), tower_sizes ({len(tower_sizes)}) and "
    f"loss_weights ({len(loss_weights)}) must all have the same length"
)
model_param = {
    # NOTE(review): presumably the number of factor columns per sample in the
    # "factor_*" data; must match the data fed by the datamodule -- confirm.
    "input_size": 101,
    "hidden_sizes": [64, 128],
    "tower_sizes": tower_sizes,
    "act": "leakyrelu",
    "dropout": 0.4,
    "lr": 1e-4,
    "loss_fn": "mse",
    "weight_decay": 1e-3,
    "loss_weights": loss_weights,
}

model = MtNet(**model_param)

## Logger details

# Human-readable names for the selected label indices.
label_names = [["ret", "mean", "var", "rv"][i] for i in labels]
# The list is built outside the f-string: reusing double quotes inside a
# double-quoted f-string is a SyntaxError before Python 3.12 (PEP 701).
# The rendered run name, including the list repr, is unchanged.
experiment_name = (
    f"f{args.fold}_n{len(stk_list)}_b{args.batch_size}"
    f"_{args.cur}-{args.fut}_{label_names}"
)
datamodule = sutils.MTDataModule(
    codes=stk_list,
    labels_idx=labels,
    fold=args.fold,
    split=args.split,
    # --batch_size is a multiplier (see its help text); the value actually
    # passed to the datamodule is args.batch_size * 4000.
    batch_size=args.batch_size * 4000,
    cur=args.cur,
    fut=args.fut,
)
# Full-universe runs log to "MultiTask"; sub-sampled runs go to
# "MultiTask_small". The condition mirrors the `num_stocks < 100` check used
# when sampling stk_list, so a run's project always matches its universe.
logger = wandb_logger.WandbLogger(
    project="MultiTask" if args.num_stocks >= 100 else "MultiTask_small",
    name=experiment_name,
)
logger.experiment.config.update(
    {
        "num_stocks": len(stk_list),
        "tmstamp": timestamp,
        "fold": args.fold,
        "release": False,
        "batch_size": args.batch_size,
        # NOTE: key is misspelled ("lables"); kept as-is so existing W&B runs
        # and filters keyed on it remain comparable.
        "lables": labels,
    }
)
logger.experiment.config.update(model_param)
## Training details
# Stop after 15 epochs without improvement in the logged "valid/ic_0" metric
# (presumably the validation IC of the first task -- higher is better).
earlystop_callback = EarlyStopping(
    monitor="valid/ic_0",
    patience=15,
    mode="max",
)
checkpoint_callback = ModelCheckpoint(
    monitor="valid/ic_0",
    mode="max",
    dirpath=f"./checkpoints/{timestamp}/{args.fold}/",
    # The metric name contains "/". With the default auto_insert_metric_name=True,
    # Lightning renders "{valid/ic_0:.3f}" as "valid/ic_0=0.123", which creates
    # an unintended "model_valid/" sub-directory. Spell the name out and
    # disable auto-insertion so the file lands directly in `dirpath`.
    filename="model_ic_0={valid/ic_0:.3f}",
    auto_insert_metric_name=False,
    save_top_k=1,
)
# Accumulate gradients over 10 batches for epochs 0-4, 5 for epochs 5-9,
# then no accumulation from epoch 10 on.
gas_callback = GradientAccumulationScheduler(scheduling={0: 10, 5: 5, 10: 1})


trainer = pl.Trainer(
    devices=[args.gpu],
    # `test_callback` used to be listed here but is not defined in this
    # notebook, so building the Trainer raised NameError. Re-add it once a
    # test-time callback actually exists.
    callbacks=[
        earlystop_callback,
        checkpoint_callback,
        RichProgressBar(),
        gas_callback,
    ],
    logger=logger,
    max_epochs=200,
    precision="16-mixed",
    # detect_anomaly=True,  # enable temporarily to debug NaN/inf losses
)

trainer.fit(model, datamodule)
# "best" reloads the checkpoint selected by checkpoint_callback before testing.
trainer.test(model, datamodule, ckpt_path="best")