In [1]:
import os
import sys
import getpass
from binance.spot import Spot

# Resolve the project root as the directory two levels above the notebook's
# working directory (assumes the notebook lives at <root>/<dir>/<dir> — TODO confirm).
# The root is cached in a private name (not overwritten by the star imports below)
# so that re-running this cell after os.chdir() does not climb two more levels.
if "_PROJECT_ROOT" not in globals():
    _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.getcwd()))
HOME_LOC = _PROJECT_ROOT
# Make the project's CODE package importable without piling up duplicate entries.
if HOME_LOC not in sys.path:
    sys.path.append(HOME_LOC)
# Work from the project root for the rest of the notebook.
os.chdir(HOME_LOC)
from CODE.Utils.encrypt import *
from CODE.Utils.encrypt import Encrypted_API_key
from CODE.Utils.encrypt import Encrypted_API_secret
from CODE.Utils.utils import *
from CODE.Utils.indicators import *
from CODE.Utils.normalize import *
from CODE.Runner import *
from CODE.RNN import *


/homes/David_Li/Mega/University_of_Adelaide/Works/Courses/4339_COMP_SCI_7318_Deep_Learning_Fundamentals/Assignment3
/homes/David_Li/Mega/University_of_Adelaide/Works/Courses/4339_COMP_SCI_7318_Deep_Learning_Fundamentals/Assignment3


In [2]:
# Gather the raw 1-minute data files (sorted by the project helper) and split them
# into training and testing windows: 85% of files go to training; each sample uses
# 58 input steps and 2 target steps as requested from split_data.
raw_1m_dir = os.path.join(HOME_LOC, "DATA", "RAW", "1m")
file_paths = load_and_sort_files(raw_1m_dir)
train_X, train_Y, test_X, test_Y = split_data(
    file_paths,
    train_ratio=0.85,
    x_length=58,
    y_length=2,
)

Processing Training Files: 100%|██████████| 3060/3060 [04:26<00:00, 11.49it/s]
Processing Testing Files: 100%|██████████| 540/540 [00:47<00:00, 11.29it/s]


In [3]:
print(train_X.shape, train_Y.shape, test_X.shape, test_Y.shape)

# Sanity-check the split: inputs and targets must have matching sample counts, and
# train/test must agree on window length and feature count, because the model below
# is sized from train_X but is also evaluated on test_X.
assert train_X.shape[0] == train_Y.shape[0], "train_X / train_Y sample counts differ"
assert test_X.shape[0] == test_Y.shape[0], "test_X / test_Y sample counts differ"
assert train_X.shape[1:] == test_X.shape[1:], "train/test inputs differ in window length or features"
assert train_Y.shape[1:] == test_Y.shape[1:], "train/test targets differ in shape"

(3050, 47, 23) (3050, 2) (540, 47, 23) (540, 2)


In [4]:
# --- Baseline: a single CNN + RNN regressor, built only to inspect its architecture ---

# Sizes derived from the data.
num_channels = train_X.shape[-1]  # input channels = features per time step
output_size = train_Y.shape[-1]  # number of regression targets
batch_size = 36

# Convolutional stack; every entry is (out_channels, kernel_size, padding).
conv_layers_config = [(64, 3, 1), (128, 3, 1)]

# Recurrent part.
hidden_size = 100  # RNN hidden-state size
num_rnn_layers = 2  # number of stacked RNN layers

model_p = dict(
    conv_layers_config=conv_layers_config,
    num_channels=num_channels,
    output_size=output_size,
    data_len=train_X.shape[1],  # time-series window length
    hidden_size=hidden_size,
    num_rnn_layers=num_rnn_layers,
)

trainer = GroupedTrainer(
    model=FlexibleRNN,
    # Order expected here is (train_X, test_X, train_Y, test_Y).
    data=(train_X, test_X, train_Y, test_Y),
    batch_size=batch_size,
    model_p=model_p,
    learning_rate=0.01,
    weight_decay=0.0,  # no weight-decay regularisation
)

# Regression task -> mean-squared-error loss.
trainer.criterion = nn.MSELoss()

# Show the instantiated layer stack.
print(trainer.model)

FlexibleRNN(
  (conv_layers): ModuleList(
    (0): Conv1d(23, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  )
  (rnn): RNN(128, 100, num_layers=2, batch_first=True)
  (fc1): Linear(in_features=1100, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=2, bias=True)
)


In [5]:
def _result_file_name(save_dir, conv_config, hidden_size, num_rnn_layers):
    """Return the CSV path that stores the metrics of one (conv, RNN) configuration.

    The naming scheme (conv depth, RNN depth, hidden size) is unchanged so existing
    result files stay comparable. It does NOT encode conv widths or kernel sizes.
    """
    return os.path.join(
        save_dir,
        f"conv_{len(conv_config)}_rnn_{num_rnn_layers}_hidden_{hidden_size}.csv",
    )


def run_various_models(
    train_data,
    conv_configs,
    rnn_configs,
    num_epochs,
    iter_n,
    save_dir,
    silent=False,
    batch_size=36,
    learning_rate=0.01,
    weight_decay=0.0,
    model_cls=None,
    trainer_cls=None,
):
    """Train one model per (conv config, RNN config) pair and save each run's metrics.

    Runs are executed with the conv configs in the outer loop and the RNN configs
    in the inner loop; each run writes one CSV via ``trainer.save_metrics_to_csv``.

    Parameters
    ----------
    train_data : tuple
        ``(train_X, test_X, train_Y, test_Y)`` -- note the X, X, Y, Y order, which is
        passed unchanged to the trainer's ``data=`` argument.
    conv_configs : iterable of list of tuple
        Each item is a list of ``(out_channels, kernel_size, padding)`` tuples
        describing one convolutional stack.
    rnn_configs : iterable of (hidden_size, num_rnn_layers)
    num_epochs, iter_n, silent
        Forwarded to ``trainer.train``.
    save_dir : str
        Directory for the per-run CSV files; created if it does not exist.
    batch_size, learning_rate, weight_decay
        Trainer hyper-parameters (defaults are the previously hard-coded values).
    model_cls, trainer_cls
        Model and trainer classes; ``None`` selects ``FlexibleRNN`` and
        ``GroupedTrainer``.

    Raises
    ------
    ValueError
        If two configurations would write to the same CSV file. This is checked
        before any training starts, so no result is silently overwritten.
    """
    if model_cls is None:
        model_cls = FlexibleRNN
    if trainer_cls is None:
        trainer_cls = GroupedTrainer

    train_X, test_X, train_Y, test_Y = train_data
    num_channels = train_X.shape[-1]  # input channels = features per time step
    output_size = train_Y.shape[-1]  # number of regression targets
    data_len = train_X.shape[1]  # time-series window length

    # Materialise the configs so a generator is not exhausted after the first
    # conv config, then plan every output path up front.
    conv_configs = list(conv_configs)
    rnn_configs = list(rnn_configs)
    runs = [
        (
            conv_config,
            hidden_size,
            num_rnn_layers,
            _result_file_name(save_dir, conv_config, hidden_size, num_rnn_layers),
        )
        for conv_config in conv_configs
        for hidden_size, num_rnn_layers in rnn_configs
    ]

    # The file name only encodes the conv depth, so two stacks of equal depth would
    # overwrite each other's results -- fail now rather than after hours of training.
    seen, duplicates = set(), set()
    for *_, file_name in runs:
        if file_name in seen:
            duplicates.add(file_name)
        seen.add(file_name)
    if duplicates:
        raise ValueError(
            "Multiple configurations map to the same result file(s): "
            f"{sorted(duplicates)}. The file name only encodes the number of conv "
            "layers; use conv configs of distinct depth or separate save_dir values."
        )

    # Create the output directory before training instead of discovering it is
    # missing when the first (long) run tries to save its metrics.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for conv_config, hidden_size, num_rnn_layers, file_name in runs:
        model_p = {
            "conv_layers_config": conv_config,
            "num_channels": num_channels,
            "output_size": output_size,
            "data_len": data_len,
            "hidden_size": hidden_size,
            "num_rnn_layers": num_rnn_layers,
        }

        trainer = trainer_cls(
            model=model_cls,
            data=(train_X, test_X, train_Y, test_Y),
            batch_size=batch_size,
            model_p=model_p,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
        )
        trainer.criterion = nn.MSELoss()  # regression loss

        trainer.train(num_epochs, iter_n=iter_n, silent=silent)
        trainer.save_metrics_to_csv(file_name)

        # Drop the reference so this run's model can be freed before the next
        # trainer (and model) is constructed.
        del trainer

In [6]:
# Convolutional stacks to compare; each layer is (out_channels, kernel_size, padding).
# Every layer uses kernel size 3 with padding 1 -- only depth and width change.
conv_configs = [
    [(32, 3, 1), (64, 3, 1)],  # 2 conv layers
    [(32, 3, 1), (64, 3, 1), (128, 3, 1)],  # 3 conv layers
    [(32, 3, 1), (64, 3, 1), (128, 3, 1), (192, 3, 1)],  # 4 conv layers
    [(32, 3, 1), (64, 3, 1), (128, 3, 1), (192, 3, 1), (256, 3, 1)],  # 5 conv layers
]

# RNN settings to compare, as (hidden_size, num_rnn_layers).
# The hidden size is fixed at 50; only the number of stacked RNN layers changes.
rnn_configs = [
    (50, 4),  # 4 RNN layers
    (50, 8),  # 8 RNN layers
]

# Keep results inside the project tree instead of a machine-specific absolute path.
RESULT_DIR = os.path.join(HOME_LOC, "DATA", "Result")

# Train every conv x RNN combination (4 x 2 = 8 runs); one metrics CSV per run.
run_various_models(
    train_data=(train_X, test_X, train_Y, test_Y),
    conv_configs=conv_configs,
    rnn_configs=rnn_configs,
    num_epochs=7200,
    iter_n=10,
    silent=True,
    save_dir=RESULT_DIR,
)

total_loop_n: 6048, len(self.group_size): 84
group_index: 0


  return (1 + overall_return) ** (365 / num_days) - 1


group_index: 1
group_index: 2
group_index: 3
group_index: 4
group_index: 5
group_index: 6
group_index: 7
group_index: 8
group_index: 9
group_index: 10
group_index: 11
group_index: 12
group_index: 13
group_index: 14
group_index: 15
group_index: 16
group_index: 17
group_index: 18
group_index: 19
group_index: 20
group_index: 21
group_index: 22
group_index: 23
group_index: 24
group_index: 25
group_index: 26
group_index: 27
group_index: 28
group_index: 29
group_index: 30
group_index: 31
group_index: 32
group_index: 33
group_index: 34
group_index: 35
group_index: 36
group_index: 37
group_index: 38
group_index: 39
group_index: 40
group_index: 41
group_index: 42
group_index: 43
group_index: 44
group_index: 45
group_index: 46
group_index: 47
group_index: 48
group_index: 49
group_index: 50
group_index: 51
group_index: 52
group_index: 53
group_index: 54
group_index: 55
group_index: 56
group_index: 57
group_index: 58
group_index: 59
group_index: 60
group_index: 61
group_index: 62
group_index: 63
g

  return (1 + overall_return) ** (365 / num_days) - 1


group_index: 1
group_index: 2
group_index: 3
group_index: 4
group_index: 5
group_index: 6
group_index: 7
group_index: 8
group_index: 9
group_index: 10
group_index: 11
group_index: 12
group_index: 13
group_index: 14
group_index: 15
group_index: 16
group_index: 17
group_index: 18
group_index: 19
group_index: 20
group_index: 21
group_index: 22
group_index: 23
group_index: 24
group_index: 25
group_index: 26
group_index: 27
group_index: 28
group_index: 29
group_index: 30
group_index: 31
group_index: 32
group_index: 33
group_index: 34
group_index: 35
group_index: 36
group_index: 37
group_index: 38
group_index: 39
group_index: 40
group_index: 41
group_index: 42
group_index: 43
group_index: 44
group_index: 45
group_index: 46
group_index: 47
group_index: 48
group_index: 49
group_index: 50
group_index: 51
group_index: 52
group_index: 53
group_index: 54
group_index: 55
group_index: 56
group_index: 57
group_index: 58
group_index: 59
group_index: 60
group_index: 61
group_index: 62
group_index: 63
g

  return (1 + overall_return) ** (365 / num_days) - 1


group_index: 1
group_index: 2
group_index: 3
group_index: 4
group_index: 5
group_index: 6
group_index: 7
group_index: 8
group_index: 9
group_index: 10
group_index: 11
group_index: 12
group_index: 13
group_index: 14
group_index: 15
group_index: 16
group_index: 17
group_index: 18
group_index: 19
group_index: 20
group_index: 21
group_index: 22
group_index: 23
group_index: 24
group_index: 25
group_index: 26
group_index: 27
group_index: 28
group_index: 29
group_index: 30
group_index: 31
group_index: 32
group_index: 33
group_index: 34
group_index: 35
group_index: 36
group_index: 37
group_index: 38
group_index: 39
group_index: 40
group_index: 41
group_index: 42
group_index: 43
group_index: 44
group_index: 45
group_index: 46
group_index: 47
group_index: 48
group_index: 49
group_index: 50
group_index: 51
group_index: 52
group_index: 53
group_index: 54
group_index: 55
group_index: 56
group_index: 57
group_index: 58
group_index: 59
group_index: 60
group_index: 61
group_index: 62
group_index: 63
g

  return (1 + overall_return) ** (365 / num_days) - 1


group_index: 1
group_index: 2
group_index: 3
group_index: 4
group_index: 5
group_index: 6
group_index: 7
group_index: 8
group_index: 9
group_index: 10
group_index: 11
group_index: 12
group_index: 13
group_index: 14
group_index: 15
group_index: 16
group_index: 17
group_index: 18
group_index: 19
group_index: 20
group_index: 21
group_index: 22
group_index: 23
group_index: 24
group_index: 25
group_index: 26
group_index: 27
group_index: 28
group_index: 29
group_index: 30
group_index: 31
group_index: 32
group_index: 33
group_index: 34
group_index: 35
group_index: 36
group_index: 37
group_index: 38
group_index: 39
group_index: 40
group_index: 41
group_index: 42
group_index: 43
group_index: 44
group_index: 45
group_index: 46
group_index: 47
group_index: 48
group_index: 49
group_index: 50
group_index: 51
group_index: 52
group_index: 53
group_index: 54
group_index: 55
group_index: 56
group_index: 57
group_index: 58
group_index: 59
group_index: 60
group_index: 61
group_index: 62
group_index: 63
g

  return (1 + overall_return) ** (365 / num_days) - 1


group_index: 1
group_index: 2
group_index: 3
group_index: 4
group_index: 5
group_index: 6
group_index: 7
group_index: 8
group_index: 9
group_index: 10
group_index: 11
group_index: 12
group_index: 13
group_index: 14
group_index: 15
group_index: 16
group_index: 17
group_index: 18
group_index: 19
group_index: 20
group_index: 21
group_index: 22
group_index: 23
group_index: 24
group_index: 25
group_index: 26
group_index: 27
group_index: 28
group_index: 29
group_index: 30
group_index: 31
group_index: 32
group_index: 33
group_index: 34
group_index: 35
group_index: 36
group_index: 37
group_index: 38
group_index: 39
group_index: 40
group_index: 41
group_index: 42
group_index: 43
group_index: 44
group_index: 45
group_index: 46
group_index: 47
group_index: 48
group_index: 49
group_index: 50
group_index: 51
group_index: 52
group_index: 53
group_index: 54
group_index: 55
group_index: 56
group_index: 57
group_index: 58
group_index: 59
group_index: 60
group_index: 61
group_index: 62
group_index: 63
g

  return (1 + overall_return) ** (365 / num_days) - 1


group_index: 1
group_index: 2
group_index: 3
group_index: 4
group_index: 5
group_index: 6
group_index: 7
group_index: 8
group_index: 9
group_index: 10
group_index: 11
group_index: 12
group_index: 13
group_index: 14
group_index: 15
group_index: 16
group_index: 17
group_index: 18
group_index: 19
group_index: 20
group_index: 21
group_index: 22
group_index: 23
group_index: 24
group_index: 25
group_index: 26
group_index: 27
group_index: 28
group_index: 29
group_index: 30
group_index: 31
group_index: 32
group_index: 33
group_index: 34
group_index: 35
group_index: 36
group_index: 37
group_index: 38
group_index: 39
group_index: 40
group_index: 41
group_index: 42
group_index: 43
group_index: 44
group_index: 45
group_index: 46
group_index: 47
group_index: 48
group_index: 49
group_index: 50
group_index: 51
group_index: 52
group_index: 53
group_index: 54
group_index: 55
group_index: 56
group_index: 57
group_index: 58
group_index: 59
group_index: 60
group_index: 61
group_index: 62
group_index: 63
g

  return (1 + overall_return) ** (365 / num_days) - 1


group_index: 1
group_index: 2
group_index: 3
group_index: 4
group_index: 5
group_index: 6
group_index: 7
group_index: 8
group_index: 9
group_index: 10
group_index: 11
group_index: 12
group_index: 13
group_index: 14
group_index: 15
group_index: 16
group_index: 17
group_index: 18
group_index: 19
group_index: 20
group_index: 21
group_index: 22
group_index: 23
group_index: 24
group_index: 25
group_index: 26
group_index: 27
group_index: 28
group_index: 29
group_index: 30
group_index: 31
group_index: 32
group_index: 33
group_index: 34
group_index: 35
group_index: 36
group_index: 37
group_index: 38
group_index: 39
group_index: 40
group_index: 41
group_index: 42
group_index: 43
group_index: 44
group_index: 45
group_index: 46
group_index: 47
group_index: 48
group_index: 49
group_index: 50
group_index: 51
group_index: 52
group_index: 53
group_index: 54
group_index: 55
group_index: 56
group_index: 57
group_index: 58
group_index: 59
group_index: 60
group_index: 61
group_index: 62
group_index: 63
g

  return (1 + overall_return) ** (365 / num_days) - 1


group_index: 1
group_index: 2
group_index: 3
group_index: 4
group_index: 5
group_index: 6
group_index: 7
group_index: 8
group_index: 9
group_index: 10
group_index: 11
group_index: 12
group_index: 13
group_index: 14
group_index: 15
group_index: 16
group_index: 17
group_index: 18
group_index: 19
group_index: 20
group_index: 21
group_index: 22
group_index: 23
group_index: 24
group_index: 25
group_index: 26
group_index: 27
group_index: 28
group_index: 29
group_index: 30
group_index: 31
group_index: 32
group_index: 33
group_index: 34
group_index: 35
group_index: 36
group_index: 37
group_index: 38
group_index: 39
group_index: 40
group_index: 41
group_index: 42
group_index: 43
group_index: 44
group_index: 45
group_index: 46
group_index: 47
group_index: 48
group_index: 49
group_index: 50
group_index: 51
group_index: 52
group_index: 53
group_index: 54
group_index: 55
group_index: 56
group_index: 57
group_index: 58
group_index: 59
group_index: 60
group_index: 61
group_index: 62
group_index: 63
g