# Mamba in stock prediction

In [1]:
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import mean_squared_error,mean_absolute_error,r2_score

from cnn_lstm_kan.mamba import Mamba, MambaConfig

ModuleNotFoundError: No module named 'pscan'

## Configurations

In [62]:
class args:
    """Notebook-wide settings, read as plain class attributes (``args.lr`` etc.)."""

    # Reproducibility: value handed to set_seed().
    seed = 1

    # Optimisation: number of full-batch epochs, Adam learning rate and weight decay.
    epochs = 100
    lr = 0.01
    wd = 1e-5

    # Model size: Mamba d_model (also the width of the linear layers) and n_layers.
    hidden = 16
    layer = 2

    # Data: trailing fraction of rows held out for testing, and the CSV file stem.
    test_size = 0.2
    ts_code = '600519'

    # Hardware: tensors and the model are moved to the GPU when this is true.
    cuda = torch.cuda.is_available()

In [63]:
def evaluation_metric(y_test,y_hat):
    """Print MSE, RMSE, MAE and R2 of ``y_hat`` against ``y_test`` on one line.

    The four scores are written space-separated with four decimals each, in
    that order. Nothing is returned.
    """
    mse = mean_squared_error(y_test, y_hat)
    scores = (
        mse,
        mse**0.5,  # RMSE is derived from the MSE instead of being recomputed
        mean_absolute_error(y_test, y_hat),
        r2_score(y_test, y_hat),
    )
    print('%.4f %.4f %.4f %.4f' % scores)

def set_seed(seed,cuda):
    """Seed every random number generator this notebook can draw from.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch.
        cuda: When truthy, additionally seed the CUDA generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        # manual_seed_all covers every visible GPU, not only the current device.
        torch.cuda.manual_seed_all(seed)

def dateinf(series, n_test):
    """Print the first and last dates of the training and testing windows.

    The last ``n_test`` entries of ``series`` form the testing window and
    everything before them forms the training window.

    Args:
        series: Ordered dates; a pandas Series (with any index) or a plain
            sequence such as a list.
        n_test: Number of trailing entries that belong to the testing window.

    Raises:
        ValueError: If ``n_test`` does not leave at least one entry in each
            window.
    """
    lt = len(series)
    if not 0 < n_test < lt:
        raise ValueError(f'n_test must be between 1 and {lt - 1}, got {n_test}')
    # Index by position. Plain ``series[k]`` on a pandas Series looks up the
    # *label* k, which fails as soon as the index is not 0..n-1 (e.g. a slice).
    at = series.iloc if hasattr(series, 'iloc') else series
    split = lt - n_test
    print('Training start',at[0])
    print('Training end',at[split-1])
    print('Testing start',at[split])
    print('Testing end',at[lt-1])

In [64]:
# Prefer the GPU when one is visible, otherwise stay on the CPU, then seed the RNGs.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
set_seed(args.seed, use_cuda)

## Define network

In [65]:
class Net(nn.Module):
    """Linear projection -> Mamba stack -> linear head -> tanh.

    The width and depth of the Mamba stack are taken from the notebook-level
    ``args.hidden`` and ``args.layer``. The forward pass returns the output
    collapsed to a single dimension.
    """

    def __init__(self,in_dim,out_dim):
        """Build the pipeline for ``in_dim`` input features and ``out_dim`` outputs."""
        super().__init__()
        width = args.hidden
        self.config = MambaConfig(d_model=width, n_layers=args.layer)
        # Stages are listed (and therefore constructed) in pipeline order.
        stages = [
            nn.Linear(in_dim, width),
            Mamba(self.config),
            nn.Linear(width, out_dim),
            nn.Tanh(),
        ]
        self.mamba = nn.Sequential(*stages)

    def forward(self,x):
        """Run ``x`` through the pipeline and flatten the result to 1-D."""
        return self.mamba(x).flatten()

## Train and test

In [86]:
# NOTE(review): in this notebook `trainX` is only assigned in the "Split data"
# cell further down, so on a fresh kernel this inspection cell raises NameError
# unless it is run after that cell.
trainX.shape

(1362, 12)

In [89]:
# xt = torch.from_numpy(trainX).float().unsqueeze(0)
# NOTE(review): in this notebook `trainy` is only assigned in the "Split data"
# cell further down, so on a fresh kernel this inspection cell raises NameError
# unless it is run after that cell.
trainy.shape

(1362,)

In [96]:
def PredictWithData(trainX, trainy, testX):
    """Fit a fresh ``Net`` on the training window and predict the test window.

    Each feature matrix is fed to the network as one sequence (a leading
    dimension of size 1 is added), and the network is trained full-batch with
    Adam on the MSE against ``trainy`` for ``args.epochs`` epochs. The loss is
    printed every 10 epochs.

    Args:
        trainX: 2-D NumPy array of training features, one row per time step.
        trainy: 1-D NumPy array of training targets, one per row of ``trainX``.
        testX: 2-D NumPy array of test features with the same columns as
            ``trainX``.

    Returns:
        Flat NumPy array with the network output for ``testX``.
    """
    clf = Net(trainX.shape[1],1)
    opt = torch.optim.Adam(clf.parameters(),lr=args.lr,weight_decay=args.wd)
    xt = torch.from_numpy(trainX).float().unsqueeze(0)
    xv = torch.from_numpy(testX).float().unsqueeze(0)
    yt = torch.from_numpy(trainy).float()
    print(f'xt {xt.shape} xv {xv.shape} yt {yt.shape}')
    if args.cuda:
        clf = clf.cuda()
        xt = xt.cuda()
        xv = xv.cuda()
        yt = yt.cuda()

    # Training mode does not change between epochs, so set it once.
    clf.train()
    for e in range(args.epochs):
        z = clf(xt)
        loss = F.mse_loss(z,yt)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if e % 10 == 0:
            print('Epoch %d | Lossp: %.4f' % (e, loss.item()))

    clf.eval()
    # Inference only: no gradients are needed, so skip building the autograd
    # graph for the test pass.
    with torch.no_grad():
        mat = clf(xv)
    # .cpu() is a no-op for tensors that already live on the CPU.
    yhat = mat.cpu().numpy().flatten()
    return yhat

## Read Data

In [97]:
# The CSV for the configured ticker sits one directory up, under data/.
csv_path = f'../data/{args.ts_code}.csv'
data = pd.read_csv(csv_path)
data

Unnamed: 0,Date,Code,Open,High,Low,Close,Preclose,Volume,Amount,Adjustflag,Turn,Tradestatus,Pctchg,Pettm,Pbmrq,Psttm,Pcfncfttm
0,2017-08-14,sh.600519,485.21,500.10,485.21,499.83,484.06,3933147,1.952354e+09,3,0.313099,1,3.257858,32.759372,8.302876,13.384504,25.786295
1,2017-08-15,sh.600519,500.11,501.10,495.01,495.97,499.83,2716322,1.350105e+09,3,0.216234,1,-0.772260,32.506384,8.238756,13.281141,25.587157
2,2017-08-16,sh.600519,498.00,498.80,493.00,496.49,495.97,1858722,9.214130e+08,3,0.147964,1,0.104843,32.540465,8.247394,13.295065,25.613984
3,2017-08-17,sh.600519,497.60,497.60,489.80,492.69,496.49,2584673,1.272341e+09,3,0.205754,1,-0.765370,32.291409,8.184271,13.193308,25.417941
4,2017-08-18,sh.600519,492.80,494.44,488.00,489.65,492.69,2385775,1.169733e+09,3,0.189920,1,-0.617023,32.092165,8.133772,13.111903,25.261107
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1697,2024-08-08,sh.600519,1411.00,1448.18,1409.00,1430.69,1415.50,2513082,3.592705e+09,3,0.200100,1,1.073100,23.040091,7.496835,11.398984,1892.855626
1698,2024-08-09,sh.600519,1460.03,1469.00,1436.80,1436.80,1430.69,3013849,4.368706e+09,3,0.239900,1,0.427100,22.435319,8.257577,11.071387,-408.037924
1699,2024-08-12,sh.600519,1430.00,1443.00,1426.58,1436.10,1436.80,1363203,1.956573e+09,3,0.108500,1,-0.048700,22.424389,8.253554,11.065993,-407.839130
1700,2024-08-13,sh.600519,1433.00,1435.00,1412.01,1423.01,1436.10,1630843,2.317038e+09,3,0.129800,1,-0.911500,22.219991,8.178323,10.965127,-404.121691


In [98]:
# Derive everything from `data` without mutating it, so this cell can be re-run
# and the frame displayed by the loading cell stays accurate.

# Actual closing prices, used later to rebuild predicted prices.
close = data['Close'].values

# Training target: Pctchg scaled by 0.01 (downstream it is applied as
# close * (1 + rate), i.e. as a fractional change).
ratechg = (0.01 * data['Pctchg']).values

# Model inputs: drop the close/pre-close/target columns, then skip the first
# two remaining columns (presumably Date and Code -- confirm against the CSV).
dat = data.drop(columns=['Close', 'Preclose', 'Pctchg']).iloc[:, 2:].values
dat

array([[ 485.21    ,  500.1     ,  485.21    , ...,    8.302876,
          13.384504,   25.786295],
       [ 500.11    ,  501.1     ,  495.01    , ...,    8.238756,
          13.281141,   25.587157],
       [ 498.      ,  498.8     ,  493.      , ...,    8.247394,
          13.295065,   25.613984],
       ...,
       [1430.      , 1443.      , 1426.58    , ...,    8.253554,
          11.065993, -407.83913 ],
       [1433.      , 1435.      , 1412.01    , ...,    8.178323,
          10.965127, -404.121691],
       [1423.01    , 1424.9     , 1412.02    , ...,    8.122518,
          10.890306, -401.364141]])

## Split data

In [99]:
# Chronological hold-out: the trailing `test_size` fraction of rows is the test set.
n_test = int(args.test_size * len(dat))

# Slice from an explicit split index. With negative slicing, `dat[:-n_test]`
# would be empty and `dat[-n_test:]` the whole array if n_test came out as 0.
n_train = len(dat) - n_test
trainX, testX = dat[:n_train, :], dat[n_train:, :]
trainy = ratechg[:n_train]
trainX.shape, testX.shape, trainy.shape

((1362, 12), (340, 12), (1362,))

## Start training and testing

In [100]:
# Train on the training window and predict the rate of change for each test row.
predictions = PredictWithData(trainX, trainy, testX)

# Dates and actual closing prices of the test window.
time = data['Date'][-n_test:]
data1 = close[-n_test:]

# Turn each predicted rate into a price by applying it to the actual close of
# the row immediately before the one being predicted.
pred = close[-n_test - 1]
finalpredicted_stock_price = []
for day in range(n_test):
    previous_close = close[day - n_test - 1]
    pred = previous_close * (1 + predictions[day])
    finalpredicted_stock_price.append(pred)

xt torch.Size([1, 1362, 12]) xv torch.Size([1, 340, 12]) yt torch.Size([1362])
Epoch 0 | Lossp: 0.0134
Epoch 10 | Lossp: 0.0043
Epoch 20 | Lossp: 0.0009
Epoch 30 | Lossp: 0.0004
Epoch 40 | Lossp: 0.0006
Epoch 50 | Lossp: 0.0006
Epoch 60 | Lossp: 0.0004
Epoch 70 | Lossp: 0.0004
Epoch 80 | Lossp: 0.0004
Epoch 90 | Lossp: 0.0004


## Plotting result

In [78]:
# Report the date range of each window, then the error metrics on prices.
dateinf(data['Date'], n_test)
print('MSE RMSE MAE R2')
evaluation_metric(data1, finalpredicted_stock_price)

# The Date column is read from the CSV without date parsing, i.e. as strings.
# Plotting strings makes Matplotlib use a categorical axis with one tick label
# per row; converting to datetimes gives a proper, readable time axis.
plot_dates = pd.to_datetime(time)

plt.figure(figsize=(10, 6))
plt.plot(plot_dates, data1, label='Stock Price')
plt.plot(plot_dates, finalpredicted_stock_price, label='Predicted Stock Price')
plt.title('Stock Price Prediction')
plt.xlabel('Time', fontsize=12, verticalalignment='top')
plt.ylabel('Close', fontsize=14, horizontalalignment='center')
plt.legend()
plt.show()

Training start 2017-08-14
Training end 2023-03-22
Testing start 2023-03-23
Testing end 2024-08-14
MSE RMSE MAE R2


ValueError: Input contains NaN.