In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd

def create_states(df, window_size=9):
    states = []
    for i in range(window_size, len(df)):
        state = df.iloc[i-window_size:i].values
        states.append(state)
    return np.array(states)

class ConvDQN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ConvDQN, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * window_size, 128)
        self.fc2 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # Change the shape to (batch_size, num_features, window_size)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten the output from the conv layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [2]:
# Assuming df_train is your DataFrame
df_train = pd.DataFrame(np.random.randn(100, 5))  # Example DataFrame, replace with actual data
window_size = 9
states = create_states(df_train, window_size)

# Example dimensions
input_dim = states.shape[2]  # number of features
output_dim = 3  # buy, sell, hold

# Initialize the model and move it to the device
model = ConvDQN(input_dim, output_dim)

# Example of using the model
# Assuming states is the input data
states_tensor = torch.tensor(states, dtype=torch.float32)
output = model(states_tensor)
print(output)

tensor([[ 0.0849, -0.0457, -0.0547],
        [ 0.1491, -0.0138, -0.0896],
        [ 0.1372, -0.0313, -0.1212],
        [ 0.0391, -0.0294, -0.0590],
        [ 0.0991, -0.0357, -0.1130],
        [ 0.1326, -0.0102, -0.0918],
        [ 0.0711, -0.0159, -0.1189],
        [ 0.0794,  0.0101, -0.0640],
        [ 0.0801, -0.0242, -0.0952],
        [ 0.1157, -0.0397, -0.1595],
        [ 0.0922, -0.0196, -0.0977],
        [ 0.1124, -0.0270, -0.0634],
        [ 0.1143, -0.0206, -0.0745],
        [ 0.1115, -0.0148, -0.0764],
        [ 0.1151, -0.0542, -0.0947],
        [ 0.1175, -0.0220, -0.0843],
        [ 0.1262,  0.0002, -0.0697],
        [ 0.1224, -0.0259, -0.0746],
        [ 0.1133, -0.0116, -0.1009],
        [ 0.1275,  0.0088, -0.1040],
        [ 0.0690, -0.0482, -0.0752],
        [ 0.1659, -0.0286, -0.0827],
        [ 0.1088, -0.0507, -0.1364],
        [ 0.1158, -0.0618, -0.0931],
        [ 0.1213,  0.0324, -0.1045],
        [ 0.1487, -0.0344, -0.1001],
        [ 0.0829, -0.0419, -0.0805],
 