# Import and Set

## Init: Model, Criteria, Optimizer, Fp16, Previous Tarin Inofrmation

In [13]:
%load_ext autoreload
%autoreload 2

import sys 
sys.path.append('../')
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from utils import *
from datas import *
from models.Transformers import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
stock_symbol = '5871.TW'
end_date = '2024-12-31'

num_class = 2
batch_size = 128
init = True
fp16_training = True
num_epochs = 200
lr = 0.0001

# Data
trainloader, validloader, testloader, test_date, df, src = data(stock_symbol=stock_symbol, batch_size=batch_size)
for x, y in trainloader:
    break
x.shape, y.shape, src.shape


"""
Choose if fp16 and define model
pip install accelerate==0.2.0
"""
MODEL = "Transformer"
# Model
if fp16_training:
    print('Accelerating')
    from accelerate import Accelerator
    accelerator = Accelerator()
    device = accelerator.device
    model = Transformer(num_class=num_class)
else:
    model = Transformer(num_class=num_class).to(device)
        

# Settings
print("Init model")
lr = lr
last_epoch = 0
min_val_loss = float("inf")
if os.path.exists(f'Temp//{MODEL}_{stock_symbol}_LastTrainInfo.pk'):
    print('Load from last train epoch')
    with open(f'Temp//{MODEL}_class{num_class}_{stock_symbol}_LastTrainInfo.pk', 'rb') as f:
        last_train_info = pickle.load(f)
    lr = last_train_info['lr']
    last_epoch = last_train_info['epoch']
    min_val_loss = last_train_info['min val loss']
    model.load_state_dict(torch.load(f'Temp//{MODEL}_class{num_class}_{stock_symbol}_checkpoint_LastTrainModel.pt'))
    
print(f'Last train epoch: {last_epoch}  '
        f'Last train lr: {lr}   '
        f'Min val loss: {min_val_loss}')

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.00001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=len(trainloader)*1, gamma=0.9)        

if fp16_training:
    print('Accelerate Prepare')    
    model, optimizer, trainloader, validloader, scheduler = \
        accelerator.prepare(model, optimizer, trainloader, validloader, scheduler)
        


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
x_train_len: 2842, valid_len: 149, test_len: 157
Accelerating
Init model
Last train epoch: 0  Last train lr: 0.0001   Min val loss: inf
Accelerate Prepare


## Train

In [15]:
src.shape

torch.Size([2842, 6, 100])

In [16]:
"""
--- Original ---------
batch_x: (batch_size, d_model, seqlen) 
src: (total_length, d_model, seq_len)
--- Input of model ---
batch_x: (batch_size, seq_len, d_model) -> use src.permute()
src: (total_length, seq_len, d_model)   -> use batch.permute()
"""

src = src.permute(0, 2, 1).to(device)
for epoch in range(last_epoch, num_epochs):
    # Training
    model.train()
    loss_train_e = 0
    for batch_x, batch_y in tqdm(trainloader): 
        batch_x = batch_x.permute(0, 2, 1)    
        optimizer.zero_grad()       
        memory, outputs = model(src=src.to(device), tgt=batch_x, train=True)    
        # Loss
        loss = criterion(outputs, batch_y)
        accelerator.backward(loss)
        optimizer.step()
        loss_train_e += loss.item()
    
    # Train loss
    loss_train_e /= len(trainloader)
    
    # Scheduler 
    if epoch > 100:
        scheduler.step()
    
    # Validating
    loss_valid_e = 0
    with torch.no_grad():
        model.eval()
        for batch_x_val, batch_y_val in tqdm(validloader):
            batch_x_val = batch_x_val.permute(0, 2, 1)    
            memory, outputs_val = model(src, batch_x_val, False, memory)
                
            loss = criterion(outputs_val, batch_y_val)
            loss_valid_e += loss.item()
        loss_valid_e /= len(validloader)
            
        torch.save(model.state_dict(), f'Temp/{MODEL}_class{num_class}_{stock_symbol}_checkpoint_LastTrainModel.pt')
        if loss_valid_e < min_val_loss:
            min_val_loss = loss_valid_e
            print(f'New best model found in epoch {epoch} with val loss: {min_val_loss}')
            torch.save(model.state_dict(), f'Model_Result/{MODEL}_class{num_class}_{stock_symbol}_best_model.pt')  
    
    with open(f'Temp/{MODEL}_class{num_class}_{stock_symbol}_LastTrainInfo.pk', 'wb') as f:
        pickle.dump({'min val loss': min_val_loss, 'epoch': epoch, 'lr': optimizer.param_groups[0]['lr']}, f)        
        
    """print(
        f'Epoch [{epoch}/{num_epochs}]',
        f'Training Loss: {loss_train_e:.5f}',
        f'Valid Loss: {loss_valid_e:.5f}'
        )"""

  0%|          | 0/22 [00:00<?, ?it/s]

torch.Size([2842, 100, 6]) torch.Size([2842, 1, 6])
torch.Size([128, 10, 6]) torch.Size([128, 1, 6])


  0%|          | 0/22 [00:00<?, ?it/s]


RuntimeError: shape '[100, 128, 6]' is invalid for input of size 218265600