# TS Transformers

> *注：本代码在编写过程中出现的Crossing Transformer即为文章中提出的TS-Transformer。在编写代码时未发现Crossing Transformer这一名称已在其他工作中使用，因此在文章撰写时改名为TS-Transformer。代码中未及时替换导致的误解作者表示抱歉。*

## 模型超参数

In [None]:
# Minimum number of rows a sampled stock must have to be kept
# (applied in the groupby filter after sampling).
FILTER_LEN = 2000

# Number of distinct stocks drawn at random from the raw data.
SAMPLE_SIZE = 200

# Batch size of the train/test DataLoaders; each dataset item is one stock.
BATCH_SIZE = 64

# Number of training epochs passed to the training function.
NUM_EPOCHS = 10

# Length (in rows) of the input window returned by StockDataset.
LOOK_BACK = 45

# Decay constant of the scheduled-sampling probability k / (k + exp(epoch / k)).
K = 150

# Patience handed to EarlyStopping by the training function.
EARLYSTOP_PATIENCE = 100

## 数据处理

In [None]:
import torch
import numpy as np
import random 
import time
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader,Dataset
import torch.nn as nn
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from IPython.display import clear_output

####################################### Random Seed #######################################
# Seed every RNG used in this notebook (NumPy, PyTorch CPU and CUDA, Python's
# `random`) with the same value so sampling and initialisation are repeatable.
# NOTE(review): cuDNN determinism flags are not set here, so GPU runs may still
# differ slightly between executions — confirm whether that matters.
seed = 3407
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)

####################################### Import Data #######################################
# The CSV is read relative to the current working directory.
file_path = 'label.csv'
data_raw = pd.read_csv(file_path)

In [None]:
# Drop the columns that are not used as model inputs, then remove incomplete rows.
df = data_raw.drop(columns=['label', 'Volume']).dropna()

############### Data Preprocessing #################
# Work in log-price space.
price_cols = ['High', 'Open', 'Low', 'Close', 'Avg']
df[price_cols] = np.log(df[price_cols])
# log(0) yields -inf and log(negative) yields NaN; `dropna` alone would keep the
# infinities, which would later turn the loss into inf/NaN, so treat both as missing.
df = df.replace([np.inf, -np.inf], np.nan).dropna()

# Parse the date column once and split it into integer year / month / day columns.
parsed_dates = pd.to_datetime(df['date'])
df['year'] = parsed_dates.dt.year.astype(int)
df['month'] = parsed_dates.dt.month.astype(int)
df['day'] = parsed_dates.dt.day.astype(int)

df = df.drop(columns='date')

# Put the calendar columns and the stock identifier first; keep the rest in order.
leading_cols = ['year', 'month', 'day', 'stock']
df = df[leading_cols + [col for col in df.columns if col not in leading_cols]]

df.head(20)

In [None]:
class StockDataset(Dataset):
    """Dataset that yields one randomly positioned window per stock.

    Item ``idx`` refers to the stock whose label-encoded id equals ``idx``.
    Every access draws a fresh window start with ``np.random.randint``, so the
    same index returns different windows on different calls.

    Args:
        data: DataFrame with a ``stock`` column plus the columns listed in
            ``FEATURE_COLUMNS``. It is modified in place: a ``stock_enc``
            column holding the encoded stock id is added.
        lookback: Number of consecutive rows in the input window.
        predict_len: Number of rows following the window returned as target.
    """

    # Columns returned for both input and target, in this order.
    FEATURE_COLUMNS = ['year', 'month', 'day', 'High', 'Open', 'Low', 'Close', 'Avg']

    def __init__(self, data, lookback, predict_len=1):
        # These three containers are not used by the class itself; they are kept
        # so existing attribute access keeps working.
        self.data = []
        self.targets = []
        self.stock_idx = []
        self.lookback = lookback
        self.stock_encoder = LabelEncoder()
        self.minmaxscale = MinMaxScaler()
        self.dta = data
        self.predict_len = predict_len
        self.dta['stock_enc'] = self.stock_encoder.fit_transform(self.dta['stock'])
        # One item per distinct stock; computed once instead of re-grouping the
        # whole frame on every len() call.
        self.num_stocks = len(self.stock_encoder.classes_)

    def __len__(self):
        """Return the number of distinct stocks."""
        return self.num_stocks

    def __getitem__(self, idx):
        """Return ``(index_input, index_target, _input, target, stock_name)``.

        ``index_input`` / ``index_target`` are the row positions (within the
        stock's own rows) of the window and of the target; ``_input`` and
        ``target`` hold the corresponding ``FEATURE_COLUMNS`` values.

        Raises:
            ValueError: If the stock has fewer than ``lookback + predict_len`` rows.
        """
        # Select this stock's rows once and reuse them for every slice below.
        stock_rows = self.dta[self.dta['stock_enc'] == idx]
        n_rows = stock_rows.shape[0]
        span = self.lookback + self.predict_len
        if n_rows < span:
            raise ValueError(
                f"stock with encoded id {idx} has {n_rows} rows, "
                f"but lookback + predict_len = {span} are required"
            )

        # Leave room for the whole target, not just a single row. For
        # predict_len == 1 this is the same bound as n_rows - lookback.
        start = np.random.randint(0, n_rows - span + 1)
        split = start + self.lookback
        stop = split + self.predict_len

        stock_name = str(stock_rows['stock'].iloc[start])
        index_input = torch.arange(start, split)
        index_target = torch.arange(split, stop)

        values = stock_rows[self.FEATURE_COLUMNS].to_numpy()
        _input = torch.tensor(values[start:split])
        # The target now honours predict_len instead of always being one row.
        target = torch.tensor(values[split:stop])

        return index_input, index_target, _input, target, stock_name

In [None]:
# Randomly pick up to SAMPLE_SIZE stocks (random.sample raises if asked for more
# items than exist), then keep only those with at least FILTER_LEN rows.
stock_unique = df['stock'].unique()
stock_candidates = stock_unique.tolist()
selected_stocks = random.sample(stock_candidates, min(SAMPLE_SIZE, len(stock_candidates)))
df_selected = df[df['stock'].isin(selected_stocks)]

df_selected = df_selected.groupby('stock').filter(lambda x: len(x) >= FILTER_LEN)

# Recombine the year / month / day columns of the selected rows into one
# datetime column so the data can be split by calendar date.
df_selected.loc[:, 'combined_date'] = pd.to_datetime(df_selected[['year', 'month', 'day']])

# Distinct dates in chronological order. `unique()` alone returns them in order
# of first appearance, which is only chronological if the CSV is sorted by date;
# sorting guarantees the test period really lies after the training period.
unique_dates = df_selected['combined_date'].drop_duplicates().sort_values().to_numpy()
split_at = int(len(unique_dates) * 0.8)
train_dates = unique_dates[:split_at]
test_dates = unique_dates[split_at:]

# Split by date: first 80% of the dates for training, the remainder for testing.
train_df = df_selected[df_selected['combined_date'].isin(train_dates)].copy()
test_df = df_selected[df_selected['combined_date'].isin(test_dates)].copy()

# The helper column is only needed for the split.
train_df.drop('combined_date', axis=1, inplace=True)
test_df.drop('combined_date', axis=1, inplace=True)

# Drop stocks that have fewer than 150 rows left inside a split.
train_df = train_df.groupby('stock').filter(lambda x: len(x) >= 150)
test_df = test_df.groupby('stock').filter(lambda x: len(x) >= 150)

print(f"train_df.shape:{train_df.shape}")

# Each item is a window of LOOK_BACK rows.
train_dataset = StockDataset(train_df, lookback=LOOK_BACK)
test_dataset = StockDataset(test_df, lookback=LOOK_BACK)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

In [None]:
# 1 where a Close value exists for a (stock, date) pair; pairs that are absent
# from the data become NaN in the pivot below.
df_selected['NewFeature'] = df_selected['Close'].notna().astype(int)
df_pivot = df_selected.pivot(index='stock', columns='combined_date', values='NewFeature')

plt.figure(figsize=(14, 10))
sns.heatmap(df_pivot, annot=False, fmt=".2f", cmap="YlGnBu")
plt.title('Stock Data Usage Heatmap')
# The pivot has stocks as rows and dates as columns, so dates run along the
# x axis and stocks along the y axis (the labels were swapped before).
plt.xlabel('Date')
plt.ylabel('Stock Code')
plt.show()

print(f"train_df.shape:{train_df.shape}")
print(f"train stock num:{train_df['stock'].nunique()}")
print(f"test_df.shape:{test_df.shape}")
print(f"test stock num:{test_df['stock'].nunique()}")

## 模型定义

In [None]:

# Generic multi-head attention module
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are laid out as ``(N, length, embed_size)``: axis 0 is treated as
    the batch axis and attention is computed along axis 1.

    Args:
        embed_size: Size of the last axis of the inputs; must be divisible by ``heads``.
        heads: Number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        # Per-head projections, applied after the embedding is split into heads.
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query):
        """Attend from ``query`` over ``keys`` / ``values``.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into `self.heads` pieces and apply the learned
        # projections. Previously the three projection layers were created but
        # never called, so attention ran on the raw inputs and those weights
        # received no gradient.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

        # Query-key score matrix: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        # Scores are scaled by sqrt(embed_size) (not sqrt(head_dim)), as before.
        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # Weighted sum of the values, then merge the heads back together.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

# Helpers for the date-based positional encoding
def encode_cyclic_feature(value, max_value):
    """Project a periodic quantity onto the unit circle.

    Args:
        value: Tensor of raw values (e.g. month or day numbers); cast to float.
        max_value: Period of the quantity.

    Returns:
        Tuple ``(sin, cos)`` of the angle ``2 * pi * value / max_value``.
    """
    angle = 2 * np.pi * value.float() / max_value
    return torch.sin(angle), torch.cos(angle)

class FeedForward_silu(nn.Module):
    """Two-layer feed-forward network: Linear -> SiLU -> Linear.

    Args:
        embed_size: Size of the input and output feature axis.
        ff_hidden_size: Size of the hidden layer.
    """

    def __init__(self, embed_size, ff_hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(embed_size, ff_hidden_size)
        self.fc2 = nn.Linear(ff_hidden_size, embed_size)
        self.silu = nn.SiLU()

    def forward(self, x):
        """Apply the network along the last axis of ``x``."""
        hidden = self.silu(self.fc1(x))
        return self.fc2(hidden)

class DatePositionalEncoding(nn.Module):
    """Positional encoding derived from calendar dates.

    Each (year, month, day) triple is turned into five features — the
    normalised year and the sine/cosine of month and day — which a linear
    layer maps to ``d_model`` dimensions.

    Args:
        d_model: Size of the produced encoding.
        year_min: Year that is mapped to 0 by the normalisation.
        year_max: Year that is mapped to 1 by the normalisation.

    Raises:
        ValueError: If ``year_max`` is not greater than ``year_min``.
    """

    def __init__(self, d_model, year_min=2012, year_max=2023):
        super(DatePositionalEncoding, self).__init__()
        if year_max <= year_min:
            raise ValueError("year_max must be greater than year_min")
        self.d_model = d_model
        self.year_min = year_min
        self.year_max = year_max
        # Input features: normalized_year, sin_month, cos_month, sin_day, cos_day
        self.linear = nn.Linear(5, d_model)

    def forward(self, date_features):
        """Encode dates.

        Args:
            date_features: 3-D tensor (or nested list) whose last axis holds
                year, month and day at positions 0, 1 and 2.

        Returns:
            Tensor with the same two leading axes and ``d_model`` as last axis.
        """
        if isinstance(date_features, list):
            # Build the tensor on the module's device so the linear layer can use it.
            date_features = torch.tensor(
                date_features, dtype=torch.float32, device=self.linear.weight.device
            )

        year = date_features[:, :, 0]
        month = date_features[:, :, 1]
        day = date_features[:, :, 2]

        # Linear rescaling of the year; values outside the range fall outside [0, 1].
        normalized_year = (year - self.year_min) / (self.year_max - self.year_min)

        # Month and day are periodic, so encode them as points on a circle.
        sin_month, cos_month = encode_cyclic_feature(month, 12)
        sin_day, cos_day = encode_cyclic_feature(day, 31)

        encoded_features = torch.stack(
            [normalized_year, sin_month, cos_month, sin_day, cos_day], dim=2
        )

        # Map the five date features to d_model dimensions.
        return self.linear(encoded_features)
    
class TransformerBlock_silu(nn.Module):
    """Attention block followed by a SiLU feed-forward network.

    Both sub-layers use a residual connection, then LayerNorm, then dropout.

    Args:
        embed_size: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        ff_hidden_size: Hidden size of the feed-forward network.
        dropout_rate: Dropout probability applied after each LayerNorm.
    """

    def __init__(self, embed_size, heads, ff_hidden_size, dropout_rate):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = FeedForward_silu(embed_size, ff_hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, value, key, query):
        """Return the block output; it has the same shape as ``query``."""
        attended = self.attention(value, key, query)
        hidden = self.dropout(self.norm1(attended + query))
        return self.dropout(self.norm2(self.feed_forward(hidden) + hidden))
    
class CrossingTransformer(nn.Module):
    """TS-Transformer: two attention passes over swapped axes, merged by a gate.

    The same stack of blocks is applied twice. ``SelfAttention`` treats axis 0
    as the batch axis and attends along axis 1, so the first pass attends along
    axis 1 of the input and the second pass (run on the axis-swapped result of
    the first) attends along axis 0. A sigmoid gate blends the two results.

    Args:
        embed_size: Internal feature size.
        num_layers: Number of transformer blocks in the shared stack.
        heads: Number of attention heads per block.
        device: Device the positional encodings are moved to.
        ff_hidden_size: Hidden size of the feed-forward sub-layers.
        dropout_rate: Dropout probability inside the blocks.
        input_features: Number of value features per step (excluding the 3 date columns).
    """

    def __init__(self, embed_size, num_layers, heads, device, ff_hidden_size, dropout_rate, input_features):
        super().__init__()
        self.device = device
        self.positional_encoding = DatePositionalEncoding(embed_size)
        blocks = [
            TransformerBlock_silu(embed_size, heads, ff_hidden_size, dropout_rate)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.feature_embedding = nn.Linear(input_features, embed_size)
        self.output_layer = nn.Linear(embed_size, input_features)
        self.gate = nn.Linear(embed_size*2, embed_size)

    def forward(self, x):
        """Predict one feature vector per entry of axis 1.

        Args:
            x: 3-D tensor whose last axis is ``[year, month, day, *features]``.
                The training code in this file passes it as
                ``(time, batch, 3 + input_features)``.

        Returns:
            Tensor of shape ``(x.shape[1], input_features)``, computed from the
            last index along axis 0.
        """
        # Separate the date columns from the value columns.
        dates, features = x[:, :, :3], x[:, :, 3:]

        # Embed the values and add the date-based positional encoding.
        embedded = self.feature_embedding(features)
        embedded = embedded + self.positional_encoding(dates).to(self.device)

        # First pass: attention along axis 1.
        first = embedded
        for block in self.layers:
            first = block(first, first, first)

        # Second pass: swap the two leading axes so the same blocks attend
        # along the other axis, then swap back.
        second = first.permute(1, 0, 2)
        for block in self.layers:
            second = block(second, second, second)
        second = second.permute(1, 0, 2)

        # Gate: element-wise convex combination of the two passes.
        gate_weight = torch.sigmoid(self.gate(torch.cat((first, second), dim=-1)))
        blended = gate_weight * first + (1-gate_weight) * second

        # Keep only the last index along axis 0 and project back to feature space.
        return self.output_layer(blended[-1, :, :])
    

class PositionalEncoding(torch.nn.Module):
    """Fixed sinusoidal positional encoding.

    The table is stored as a non-persistent buffer named ``encoding``: it
    follows the module across ``.to(device)`` calls but is not written to the
    ``state_dict``, so checkpoints keep the same keys as before.

    Args:
        d_model: Size of the encoding (odd sizes are supported).
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * -(np.log(10000.0) / d_model)
        )
        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Returns:
            Tensor of shape ``(1, x.size(1), d_model)``.
        """
        return self.encoding[:, :x.size(1)]
       
class VanillaTransformer(nn.Module):
    """Baseline transformer encoder with sinusoidal positional encoding.

    Args:
        embed_size: Internal feature size.
        num_layers: Number of transformer blocks.
        heads: Number of attention heads per block.
        device: Kept for interface compatibility and stored on the instance.
        ff_hidden_size: Hidden size of the feed-forward sub-layers.
        dropout_rate: Dropout probability inside the blocks.
        input_features: Number of value features per step (excluding the 3 date columns).
    """

    def __init__(self, embed_size, num_layers, heads, device, ff_hidden_size, dropout_rate, input_features):
        super(VanillaTransformer, self).__init__()
        self.device = device
        # NOTE: the attribute name contains a typo ("encodnig"); it is kept so
        # existing references to it keep working.
        self.positional_encodnig_original = PositionalEncoding(embed_size)
        self.layers = nn.ModuleList([TransformerBlock_relu(embed_size, heads, ff_hidden_size, dropout_rate) for _ in range(num_layers)])
        self.feature_embedding = nn.Linear(input_features, embed_size)
        self.output_layer = nn.Linear(embed_size, input_features)
        self.embed_size = embed_size

    def forward(self, x):
        """Predict the next feature vector for every sequence in the batch.

        Args:
            x: Tensor of shape ``(time, batch, 3 + input_features)``; the first
                three columns of the last axis (the date) are ignored.

        Returns:
            Tensor of shape ``(batch, input_features)``.
        """
        # Drop the date columns; this baseline uses index-based positions.
        x = x[:, :, 3:]
        x = self.feature_embedding(x) * np.sqrt(self.embed_size)

        # The attention blocks treat axis 0 as batch and attend along axis 1,
        # and the positional encoding is indexed by axis 1 as well. With the
        # time-major input used by the callers this made the model attend
        # across the batch and assign positions per batch entry instead of per
        # time step, so switch to (batch, time, embed) first.
        x = x.permute(1, 0, 2)

        pos = self.positional_encodnig_original(x).to(x.device)
        x = x + pos

        for layer in self.layers:
            x = layer(x, x, x)

        # Representation of the last time step of every sequence.
        x = x[:, -1, :]
        return self.output_layer(x)

    
class FeedForward_relu(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    Args:
        embed_size: Size of the input and output feature axis.
        ff_hidden_size: Size of the hidden layer.
    """

    def __init__(self, embed_size, ff_hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(embed_size, ff_hidden_size)
        self.fc2 = nn.Linear(ff_hidden_size, embed_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the network along the last axis of ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
    
class TransformerBlock_relu(nn.Module):
    """Attention block followed by a ReLU feed-forward network.

    Both sub-layers use a residual connection, then LayerNorm, then dropout.

    Args:
        embed_size: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        ff_hidden_size: Hidden size of the feed-forward network.
        dropout_rate: Dropout probability applied after each LayerNorm.
    """

    def __init__(self, embed_size, heads, ff_hidden_size, dropout_rate):
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = FeedForward_relu(embed_size, ff_hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, value, key, query):
        """Return the block output; it has the same shape as ``query``."""
        attended = self.attention(value, key, query)
        hidden = self.dropout(self.norm1(attended + query))
        return self.dropout(self.norm2(self.feed_forward(hidden) + hidden))


class EarlyStopping:
    """Track the best validation loss and signal when training should stop.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If True, print a message whenever the loss improves.
        delta: Minimum decrease of the loss that counts as an improvement.
        epoch_sleep: Warm-up; stopping is only signalled for epochs greater
            than this value.

    Attributes set during use: ``best_score`` (lowest loss seen),
    ``val_loss_min`` (same value, used for the verbose message), ``counter``
    (current number of non-improving calls) and ``early_stop``.
    """

    def __init__(self, patience=10, verbose=False, delta=0, epoch_sleep=700):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        self.val_loss_min = np.inf
        self.counter = 0
        self.EpochSleep = epoch_sleep

    def __call__(self, val_loss, epoch):
        """Record one validation result.

        Returns:
            ``None`` when the loss improved (or on the first call); otherwise a
            tuple ``(status_message, best_loss)``.
        """
        # Lower loss is better. The comparison used to be inverted, so a
        # decreasing loss increased the patience counter and an increasing loss
        # was recorded as the new best.
        improved = self.best_score is None or val_loss <= self.best_score - self.delta
        if improved:
            self.best_score = val_loss
            self.save_checkpoint(val_loss)
            self.counter = 0
            return None

        self.counter += 1
        if self.counter >= self.patience and epoch > self.EpochSleep:
            self.early_stop = True
        return (f'EarlyStopping counter: {self.counter} out of {self.patience}', self.best_score)

    def save_checkpoint(self, val_loss):
        """Record the new best validation loss.

        Despite the name and the printed message, nothing is written to disk
        here; only ``val_loss_min`` is updated.
        """
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        self.val_loss_min = val_loss




## Utils

In [None]:
def Show_TrainInfo(index_input, index_target, _input, target, stock_name,epoch,verbose=True):
    """Print a short summary of one training batch.

    Only ``index_input`` (its shape), ``stock_name`` and ``epoch`` are used;
    the remaining parameters are accepted but ignored.
    """
    shape = index_input.shape
    report = (
        f"===================================TRAIN INFO / EPOCH: {epoch}===================================",
        f"[Train Stock] ----- {stock_name[:]}",
        f"[Input Shape] ----- {shape}",
        f"[Stock Nums] ----- {shape[0]}",
        f"[LookBack Len] ----- {shape[-1]}",
    )
    for line in report:
        print(line)

def Show_ValInfo(index_input, index_target, _input, target, stock_name,epoch,verbose=True):
    """Print a short summary of one validation batch.

    Only the shape of ``index_input`` and ``epoch`` are used; the remaining
    parameters are accepted but ignored.
    """
    banner = f"===================================VAL INFO / EPOCH: {epoch}==================================="
    shape_line = f"[Input Shape] ----- {index_input.shape}"
    for line in (banner, shape_line):
        print(line)
    
def record_data_directly(stock_names, date_tensors, train_info):
    """Count how often each (stock, date) pair has been used.

    Args:
        stock_names: Sequence of stock identifiers, one per batch entry.
        date_tensors: Tensor indexed as ``[entry][row]`` whose rows hold
            ``(year, month, day)``.
        train_info: Dict mapping ``(stock_name, "YYYY-MM-DD")`` to a count;
            updated in place.
    """
    for position, name in enumerate(stock_names):
        # Convert all rows of this entry at once, then walk over the date triples.
        rows = date_tensors[position].numpy().astype(int)
        for year, month, day in rows:
            key = (name, f"{year:04d}-{month:02d}-{day:02d}")
            train_info[key] = train_info.get(key, 0) + 1

## 模型训练

In [None]:
def TrainModel_RandomSampling(TrainLoader, TestLoader, Model, Optimizer, Criterion, Device, 
                              NumEpochs, EarlyStopPatience, LearningRate, Sample_K,MODELNAME, VerboseFreq=5):
    """Train ``Model`` with a scheduled-sampling loop and plot the training loss.

    For every batch the input window is fed to the model step by step. Each
    new step is either the ground-truth row or the model's own (detached)
    prediction, chosen at random with a probability that depends on the epoch.
    Only the prediction of the last loop iteration enters the loss.

    Args:
        TrainLoader: Iterable yielding ``(index_input, index_target, _input,
            target, stock_name)`` batches.
        TestLoader: Validation loader; currently unused because the validation
            section below is commented out.
        Model: The network to train; cast to float and moved to ``Device``.
        Optimizer: Optimizer factory, called as
            ``Optimizer(model.parameters(), lr=LearningRate)``.
        Criterion: Loss function, called as ``Criterion(target, prediction)``.
        Device: Device used for the model and the batches.
        NumEpochs: Number of epochs to run.
        EarlyStopPatience: Patience given to ``EarlyStopping`` (the object is
            created but only referenced in commented-out code).
        LearningRate: Learning rate for the optimizer.
        Sample_K: Constant ``k`` of the sampling probability ``k / (k + exp(epoch / k))``.
        MODELNAME: Prefix of the checkpoint files ``{MODELNAME}_epoch{epoch}.pth``.
        VerboseFreq: The loss curve is redrawn every ``VerboseFreq`` epochs.

    Returns:
        ``(model, train_losses, val_losses, train_info)``. ``train_losses``
        holds the mean batch loss of every epoch except epoch 0,
        ``val_losses`` stays empty while validation is disabled, and
        ``train_info`` maps ``(stock, date)`` to the number of times that row
        appeared in a training window.
    """
    
    device = Device
    model = Model.float().to(device)
    criterion = Criterion
    optimizer = Optimizer(model.parameters(), lr=LearningRate)

    train_loader = TrainLoader
    test_loader = TestLoader
    early_stopping = EarlyStopping(patience=EarlyStopPatience, verbose=True)
    num_epochs = NumEpochs
    k = Sample_K
    
    train_losses = []
    val_losses = []
    # NOTE(review): train_data_counts, num_epochs, test_loader and early_stopping
    # are not used by the active code below.
    train_data_counts = {}
    train_info = {}

    for epoch in range(NumEpochs):
        
        print(f"{epoch}/{NumEpochs}")

        ####################################### Training #######################################
        model.train()
        
        total_loss = 0
        
        for j, (index_input, index_target, _input, target, stock_name) in enumerate(train_loader):
            # Show_TrainInfo(index_input, index_target, _input, target, stock_name, epoch)
            # Count every (stock, date) row of this batch's input windows.
            record_data_directly(stock_name, _input[:,:, :3], train_info)

            optimizer.zero_grad()

            # Swap the two leading axes and drop the last row of the window.
            src = _input.permute(1,0,2).float().to(device)[:-1,:,:]  # src.shape:(lookback, batch_size, features)

            # print(f"src.shape:{src.shape}")
            # Keep only the value columns (drop the 3 date columns) and flatten to 5 columns.
            target = target.permute(1,0,2).float().to(device)[:,:,3:].view(-1,5) # target.shape:(predict_len, batch_size, features)
            # print(f"target.shape:{target.shape}")
            # Start from the first step only; the sequence is grown inside the loop.
            sampled_src = src[:1,:,:]
            # print(f"sampled_src.shape:{sampled_src.shape}")
            # Number of steps in this batch that were filled with predictions.
            count_false = 0
            
            # NOTE(review): the step appended in the last iteration is never fed
            # to the model, so the prediction used for the loss is computed from
            # src.shape[0]-1 steps — confirm this is intended.
            for i in range (src.shape[0]-1):
                prediction = model(sampled_src) # prediction.shape: (batch_size, features)
                if i<15:                            
                        # The first 15 steps always use the ground truth.
                        prob_true_val = True
                else: 
                        # Probability of using the ground truth; decreases as epoch grows.
                        v = k / (k + np.exp(epoch/k))  
                        prob_true_val = np.random.choice([True, False], p=[v, 1-v])

                # Once more than 3 predicted steps were inserted, only ground truth is used.
                if prob_true_val or count_false > 3: # use true value
                        sampled_src = torch.cat((sampled_src.detach(), src[i+1,:,:].unsqueeze(0).detach()))
                else: # use predicted value
                        count_false += 1
                        # Date columns of the next step, combined with the predicted values.
                        positional_encodings_new_val = src[i+1,:,:3].unsqueeze(0) #shape(1,13,3)                        
                        predicted_features = torch.cat((positional_encodings_new_val, prediction.unsqueeze(0)), dim=2)
                        sampled_src = torch.cat((sampled_src.detach(), predicted_features.detach()))

            # Only the prediction of the last loop iteration contributes to the loss.
            loss = criterion(target,prediction)
            loss.backward()
            optimizer.step()

            total_loss += loss.detach().item()
            # print(f"len(target): {len(target)} / BATCH: {j} / EPOCH: {epoch} / ITERATION: {i}")
            # print(f"\r\nEPOCH {epoch+1} \nT_LOSS ----- [{round(total_loss/(j+1),3)}]")
        
        epoch_loss = total_loss / len(train_loader)
        # The loss of epoch 0 is not recorded.
        if epoch > 0:
            train_losses.append(epoch_loss)
            
        # Save a checkpoint every 50 epochs (including epoch 0).
        if epoch % 50 == 0:
            torch.save(model.state_dict(),f'{MODELNAME}_epoch{epoch}.pth')
            


        ####################################### Validation #######################################
        # Validation and early stopping are disabled: the code below is commented
        # out, so val_loss stays 0 and val_losses stays empty.

        model.eval()
        val_loss = 0

        # with torch.no_grad():
        #     for index_input, index_target, _input, target, stock_name in test_loader:
        #         # Show_ValInfo(index_input, index_target, _input, target, stock_name,epoch)

        #         src = _input.permute(1,0,2).float().to(device)[:-1,:,:]
        #         target = target.permute(1,0,2).float().to(device)[:,:,3:].view(-1,5)
                
        #         prediction = model(src)
        #         v_loss = criterion(target, prediction)
        #         val_loss += v_loss.item()
        # val_loss /= len(test_loader)
        # if epoch > 0:
        #     val_losses.append(val_loss)
        
        # print(f"\r\nEPOCH {epoch+1} \nT_LOSS ----- [{round(total_loss/(j+1),3)}] \nV_LOSS ----- [{round(val_loss,3)}]")


        # ####################################### Early Stopping #######################################
        # check = early_stopping(val_loss,epoch)
        # print(check)
        # if early_stopping.early_stop:
        #     print("Early stopping")
        #     break

        ####################################### Training Visualization #######################################
        # Redraw the loss curves every VerboseFreq epochs, replacing the previous output.
        if (epoch + 1) % VerboseFreq == 0:
            clear_output(wait=True)
            plt.figure(figsize=(20, 8))
            # plt.xlim(0,700)
            plt.plot(train_losses, label='Training Loss',color='blue',marker='o')
            plt.plot(val_losses, label='Validation Loss',color='red', linestyle='dashed')
            plt.legend()
            plt.xlabel('Epoch = '+str(epoch+1))
            plt.title('T Loss = '+str(round(epoch_loss,3)))
            # plt.annotate(check, (0,0), (600, 400), xycoords='axes fraction', textcoords='offset points', va='top')
            plt.show()
            plt.clf() 
            
    return model, train_losses, val_losses, train_info

def Training_HeatMap(train_info):
    """Plot how often each (stock, date) pair was used during training.

    Args:
        train_info: Dict mapping ``(stock_code, date)`` to a usage count.
    """
    records = [(stock, day, uses) for (stock, day), uses in train_info.items()]
    usage = pd.DataFrame(records, columns=['Stock', 'Date', 'Count'])
    # Dates as rows, stocks as columns; pairs that never occurred count as 0.
    usage_grid = usage.pivot(index='Date', columns='Stock', values='Count')
    usage_grid = usage_grid.fillna(0)

    plt.figure(figsize=(8, 10))
    sns.heatmap(usage_grid, annot=False, fmt=".2f", cmap="YlGnBu")
    plt.title('Stock Data Usage Heatmap')
    plt.xlabel('Stock Code')
    plt.ylabel('Date')
    plt.show()





In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {device}")
# TS-Transformer (6 blocks) and the vanilla baseline (12 blocks); both use
# 5 value features per step.
model_CrossingTransformer = CrossingTransformer(embed_size=256, num_layers=6, heads=8, device=device, ff_hidden_size=512, dropout_rate=0.1, input_features=5).float().to(device)
model_VanillaTransformer = VanillaTransformer(embed_size=256, num_layers=12, heads=8, device=device, ff_hidden_size=512, dropout_rate=0.1, input_features=5).float().to(device)
criterion = nn.MSELoss()
# The optimizer class itself is passed on; the training function instantiates it.
optimizer = torch.optim.Adam

In [None]:
# Train the TS-Transformer (CrossingTransformer) and report the wall-clock time.
time_start = time.time()
(
    model_CrossingTransformer,
    train_losses_cstrans,
    val_losses_cstrans,
    train_info_cstrans,
) = TrainModel_RandomSampling(
    TrainLoader=train_loader,
    TestLoader=test_loader,
    Model=model_CrossingTransformer,
    Optimizer=optimizer,
    Criterion=criterion,
    Device=device,
    NumEpochs=NUM_EPOCHS,
    EarlyStopPatience=EARLYSTOP_PATIENCE,
    LearningRate=0.0001,
    Sample_K=K,
    MODELNAME='CrT',
    VerboseFreq=5,
)
time_end = time.time()
print('Time Cost:', time_end - time_start, 's')

In [None]:
# Train the vanilla baseline with the same settings and report the wall-clock time.
time_start = time.time()
(
    model_VanillaTransformer,
    train_losses_vtrans,
    val_losses_vtrans,
    train_info_vtrans,
) = TrainModel_RandomSampling(
    TrainLoader=train_loader,
    TestLoader=test_loader,
    Model=model_VanillaTransformer,
    Optimizer=optimizer,
    Criterion=criterion,
    Device=device,
    NumEpochs=NUM_EPOCHS,
    EarlyStopPatience=EARLYSTOP_PATIENCE,
    LearningRate=0.0001,
    Sample_K=K,
    MODELNAME='VnlT',
    VerboseFreq=5,
)
time_end = time.time()
print('Time Cost:', time_end - time_start, 's')

绘制训练热图、损失图

In [None]:
# Flatten the {(stock, date): count} record of the TS-Transformer run into
# rows, then pivot to a Date x Stock grid; pairs that never occurred become 0.
train_info_list = [(stock, day, uses) for (stock, day), uses in train_info_cstrans.items()]
train_info_dt = pd.DataFrame(train_info_list, columns=['Stock', 'Date', 'Count'])
train_info_dt_pivot = train_info_dt.pivot(index='Date', columns='Stock', values='Count')
train_info_dt_pivot = train_info_dt_pivot.fillna(0)


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
# NOTE(review): figsize 4x5 in at dpi=3000 renders a 12000x15000 px canvas;
# confirm that this resolution is really needed.
plt.figure(figsize=(4, 5),dpi=3000) 
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# 3.8 removed the old names, so pick whichever name this installation knows.
heatmap_style = 'seaborn-v0_8-darkgrid' if 'seaborn-v0_8-darkgrid' in plt.style.available else 'seaborn-darkgrid'
plt.style.use(heatmap_style)
sns.heatmap(train_info_dt_pivot, annot=False, fmt=".2f", cmap="YlGnBu")
plt.title('Training Heatmap')
plt.xlabel('Stock Code')
# no x labels
plt.xticks([])
plt.ylabel('Date')
plt.show()

In [None]:
# import pandas as pd
# import matplotlib.pyplot as plt

# plt.figure(figsize=(10, 6))
# plt.plot(losses, label='Training Loss')
# plt.title('Loss During Training')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.legend()
# plt.annotate(f"K={K} / NumEpochs={NUM_EPOCHS} / SampleSize={SAMPLE_SIZE} / BatchSize={BATCH_SIZE} / LookBack={LOOK_BACK} / \nStock: {stock_name[:]} ", (0,0), (0, -40), xycoords='axes fraction', textcoords='offset points', va='top')
# plt.grid(True)
# plt.show()

## 模型测试

TS-Transformer

In [None]:
from tqdm import tqdm

# Evaluate the TS-Transformer: 50 passes over the test loader (each pass draws
# new random windows), collecting predictions and targets batch by batch.
model_CrossingTransformer.eval()
predictions = []
actuals = []

with torch.no_grad():
    for epoch in tqdm(range(50)):
        for index_input, index_target, _input, target, stock_name in test_loader:
            # Same preprocessing as in training: swap the leading axes, drop
            # the last input row, keep only the 5 value columns of the target.
            src = _input.permute(1,0,2).float().to(device)
            src = src[:-1,:,:]
            target = target.permute(1,0,2).float().to(device)
            target = target[:,:,3:].view(-1,5)

            prediction = model_CrossingTransformer(src)

            predictions.append(prediction.cpu().numpy())
            actuals.append(target.cpu().numpy())

predictions = np.array(predictions)
actuals = np.array(actuals)

# Flatten to (samples, 5) before computing the error metrics.
flat_actuals = actuals.reshape(-1,5)
flat_predictions = predictions.reshape(-1,5)
mse = mean_squared_error(flat_actuals, flat_predictions)
mae = mean_absolute_error(flat_actuals, flat_predictions)
rmse = np.sqrt(mse)

print(f"MSE: {mse}")
print(f"MAE: {mae}")
print(f"RMSE: {rmse}")


Vanilla Transformer

In [None]:

# Evaluate the vanilla baseline: 50 passes over the test loader (each pass
# draws new random windows), collecting predictions and targets batch by batch.
model_VanillaTransformer.eval()
predictions = []
actuals = []

with torch.no_grad():
    for epoch in tqdm(range(50)):
        for index_input, index_target, _input, target, stock_name in test_loader:
            # Same preprocessing as in training: swap the leading axes, drop
            # the last input row, keep only the 5 value columns of the target.
            src = _input.permute(1,0,2).float().to(device)
            src = src[:-1,:,:]
            target = target.permute(1,0,2).float().to(device)
            target = target[:,:,3:].view(-1,5)

            prediction = model_VanillaTransformer(src)

            predictions.append(prediction.cpu().numpy())
            actuals.append(target.cpu().numpy())

predictions = np.array(predictions)
actuals = np.array(actuals)

# Flatten to (samples, 5) before computing the error metrics.
flat_actuals = actuals.reshape(-1,5)
flat_predictions = predictions.reshape(-1,5)
mse = mean_squared_error(flat_actuals, flat_predictions)
mae = mean_absolute_error(flat_actuals, flat_predictions)
rmse = np.sqrt(mse)

print(f"MSE: {mse}")
print(f"MAE: {mae}")
print(f"RMSE: {rmse}")


In [None]:

# model_CrossingTransformer.eval()
# # test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
# predictions = []
# actuals = []

# # with torch.no_grad():

#     for epoch in range(3):
#         for index_input, index_target, _input, target, stock_name in test_loader:

#             src = _input.permute(1,0,2).float().to(device)[:-1,:,:]
#             target = target.permute(1,0,2).float().to(device)[:,:,3:].view(-1,5)
#             print(f"***target shape:{torch.exp(target)}***")     

#             prediction = model_CrossingTransformer(src)
#             print(f"***prediction shape:{torch.exp(prediction)}***")

#             predictions.append(prediction.cpu().numpy())
#             actuals.append(target.cpu().numpy())

#             print(torch.exp(target)-torch.exp(prediction))



In [None]:
# for index_input, index_target, _input, target, stock_name in train_loader:
#     # Show_ValInfo(index_input, index_target, _input, target, stock_name,epoch)
#     # 对验证数据进行相同的数据处理
#     src = _input.permute(1,0,2).float().to(device)[:-1,:,:]
#     target = target.permute(1,0,2).float().to(device)[:,:,3:].view(-1,6)
#     # 进行预测
#     prediction = model(src)
#     print(prediction)

In [None]:
# plt.plot(predictions[:100,0,0], label='Actuals')
# # plt.plot(predictions[0][:,:], label='Predictions')
# plt.xlabel('Time')
# plt.ylabel('Value')
# plt.legend()
# plt.show()


## 可视化

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import scienceplots

# Compare the recorded training losses of both models in one figure.
# NOTE(review): the style is selected after the figure has been created, so
# figure-level style settings may not apply to this figure — confirm whether
# plt.style.use should come first.
plt.figure(figsize=(5, 2.65), dpi=2000)
plt.style.use(['ieee','science'])
plt.plot(train_losses_cstrans, label='TS-Transformer')
plt.plot(train_losses_vtrans, label='Vanilla Transformer')
plt.title('Training Losses')
plt.xlabel('Epoch')
# Fixed y range; losses above 3 are cut off.
plt.ylim(0, 3)
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()