# Import and Set

In [4]:
%load_ext autoreload
%autoreload 2

import sys 
sys.path.append('../')
import os
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from utils import *
from datas import *
from models.Decoder_only import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
stock_symbol = '5871.TW'
end_date = '2024-12-31'

num_class = 2
batch_size = 128
init = True
fp16_training = True
num_epochs = 200
lr = 0.0001

# Data
trainloader, validloader, testloader, test_date, df, src = data(stock_symbol=stock_symbol, batch_size=batch_size)
for x, y in trainloader:
    break
x.shape, y.shape, src.shape

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
x_train_len: 2842, valid_len: 149, test_len: 157


(torch.Size([128, 6, 10]), torch.Size([128, 2]), torch.Size([2842, 6, 100]))

## Init: Model, Criteria, Optimizer, Fp16, Previous Tarin Inofrmation

In [5]:
"""
Choose if fp16 and define model
pip install accelerate==0.2.0
"""
MODEL = "Decoder-only"
# Model
if fp16_training:
    print('Accelerating')
    from accelerate import Accelerator
    accelerator = Accelerator()
    device = accelerator.device
    model = TransformerDecoderOnly(num_class=num_class)
else:
    model = TransformerDecoderOnly(num_class=num_class).to(device)
        

# Settings
print("Init model")
lr = lr
last_epoch = 0
min_val_loss = float("inf")
if os.path.exists(f'Temp//{MODEL}_{stock_symbol}_LastTrainInfo.pk'):
    print('Load from last train epoch')
    with open(f'Temp//{MODEL}_class{num_class}_{stock_symbol}_LastTrainInfo.pk', 'rb') as f:
        last_train_info = pickle.load(f)
    lr = last_train_info['lr']
    last_epoch = last_train_info['epoch']
    min_val_loss = last_train_info['min val loss']
    model.load_state_dict(torch.load(f'Temp//{MODEL}_class{num_class}_{stock_symbol}_checkpoint_LastTrainModel.pt'))
    
print(f'Last train epoch: {last_epoch}  '
        f'Last train lr: {lr}   '
        f'Min val loss: {min_val_loss}')

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.00001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=len(trainloader)*1, gamma=0.9)        

if fp16_training:
    print('Accelerate Prepare')    
    model, optimizer, trainloader, validloader, scheduler = \
        accelerator.prepare(model, optimizer, trainloader, validloader, scheduler)
        


Accelerating
Init model
Last train epoch: 0  Last train lr: 0.0001   Min val loss: inf
Accelerate Prepare


## Train

In [8]:
src.shape, src.squeeze(2).unsqueeze(0).shape

(torch.Size([2842, 6, 100]), torch.Size([1, 2842, 6, 100]))

In [3]:
"""
--- Original shape when loading ---------
batch_x: (batch_size, d_model, seqlen) 
src: (total_length, d_model, seq_len)
--- Input shape of model ----------------
batch_x: (batch_size, seq_len, d_model) 
src: (total_length, seq_len, d_model)   
"""

# Code for fp16
src = src.squeeze(2).unsqueeze(0).to(device)
for epoch in tqdm(range(last_epoch, num_epochs)):
    # Training phase
    model.train()
    loss_train_e = 0
    for batch_x, batch_y in trainloader: 
        optimizer.zero_grad() 
        batch_x = batch_x.permute(0, 2, 1)
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        accelerator.backward(loss)
        optimizer.step()
        
        loss_train_e += loss.item()
        if epoch > 50:
            scheduler.step()
        
    loss_train_e /= len(trainloader)
    
    loss_valid_e = 0
    with torch.no_grad():
        model.eval()
        for batch_x_val, batch_y_val in validloader:
            batch_x_val = batch_x_val.permute(0, 2, 1)
            outputs_val = model(batch_x_val)
            loss = criterion(outputs_val, batch_y_val)
            loss_valid_e += loss.item()
        loss_valid_e /= len(validloader)
        
        # Save model for each epoch
        torch.save(model.state_dict(), f'Temp/{MODEL}_class{num_class}_{stock_symbol}_checkpoint_LastTrainModel.pt')
        if loss_valid_e < min_val_loss:
            min_val_loss = loss_valid_e
            
            # Save model for best eval result
            print(f'New best model found in epoch {epoch} with val loss: {min_val_loss}')
            torch.save(model.state_dict(), f'Model_Result/{MODEL}_class{num_class}_{stock_symbol}_best_model.pt')            
    
    # Save training information
    with open(f'Temp/{MODEL}_class{num_class}_{stock_symbol}_LastTrainInfo.pk', 'wb') as f:
        pickle.dump({'min val loss': min_val_loss, 'epoch': epoch, 'lr': optimizer.param_groups[0]['lr']}, f)
        
    """
    print(f'Epoch [{epoch}/{num_epochs}]',
        f'Training Loss: {loss_train_e:.10f}',
        f'Valid Loss: {loss_valid_e:.10f}')
    """

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 1/200 [00:04<15:53,  4.79s/it]

New best model found in epoch 0 with val loss: 1.9583598375320435


  1%|          | 2/200 [00:09<15:46,  4.78s/it]

New best model found in epoch 1 with val loss: 1.8926794528961182


  4%|▍         | 9/200 [00:41<14:43,  4.62s/it]

New best model found in epoch 8 with val loss: 1.8767502307891846


  5%|▌         | 10/200 [00:46<15:02,  4.75s/it]

New best model found in epoch 9 with val loss: 1.8494727611541748


  6%|▌         | 11/200 [00:51<15:05,  4.79s/it]

New best model found in epoch 10 with val loss: 1.8404202461242676


  6%|▋         | 13/200 [01:01<15:06,  4.85s/it]

New best model found in epoch 12 with val loss: 1.81757390499115


  7%|▋         | 14/200 [01:06<14:56,  4.82s/it]

New best model found in epoch 13 with val loss: 1.7975507974624634


  8%|▊         | 16/200 [01:15<14:37,  4.77s/it]

New best model found in epoch 15 with val loss: 1.7748100757598877


  9%|▉         | 18/200 [01:25<14:25,  4.75s/it]

New best model found in epoch 17 with val loss: 1.7527772188186646


 10%|▉         | 19/200 [01:29<14:21,  4.76s/it]

New best model found in epoch 18 with val loss: 1.7443220615386963


 10%|█         | 20/200 [01:34<14:15,  4.75s/it]

New best model found in epoch 19 with val loss: 1.7418633699417114


 10%|█         | 21/200 [01:39<14:12,  4.76s/it]

New best model found in epoch 20 with val loss: 1.7275147438049316


 12%|█▏        | 23/200 [01:48<13:59,  4.74s/it]

New best model found in epoch 22 with val loss: 1.7265808582305908


 12%|█▏        | 24/200 [01:53<14:08,  4.82s/it]

New best model found in epoch 23 with val loss: 1.7159842252731323


 12%|█▎        | 25/200 [01:58<14:10,  4.86s/it]

New best model found in epoch 24 with val loss: 1.6905359029769897


 14%|█▎        | 27/200 [02:08<13:40,  4.74s/it]

New best model found in epoch 26 with val loss: 1.6746641397476196


 14%|█▍        | 29/200 [02:17<13:16,  4.66s/it]

New best model found in epoch 28 with val loss: 1.6682322025299072


 16%|█▌        | 31/200 [02:26<13:10,  4.68s/it]

New best model found in epoch 30 with val loss: 1.6632614135742188


 16%|█▌        | 32/200 [02:31<13:25,  4.79s/it]

New best model found in epoch 31 with val loss: 1.6542341709136963


 16%|█▋        | 33/200 [02:37<13:49,  4.97s/it]

New best model found in epoch 32 with val loss: 1.6466714143753052


 17%|█▋        | 34/200 [02:41<13:37,  4.93s/it]

New best model found in epoch 33 with val loss: 1.6397109031677246


 18%|█▊        | 35/200 [02:46<13:17,  4.83s/it]

New best model found in epoch 34 with val loss: 1.6389164924621582


 18%|█▊        | 36/200 [02:51<13:15,  4.85s/it]

New best model found in epoch 35 with val loss: 1.6316556930541992


 18%|█▊        | 37/200 [02:56<13:16,  4.89s/it]

New best model found in epoch 36 with val loss: 1.6268718242645264


 19%|█▉        | 38/200 [03:01<13:03,  4.84s/it]

New best model found in epoch 37 with val loss: 1.615032434463501


 20%|█▉        | 39/200 [03:05<12:52,  4.80s/it]

New best model found in epoch 38 with val loss: 1.6125450134277344


 20%|██        | 41/200 [03:15<12:45,  4.81s/it]

New best model found in epoch 40 with val loss: 1.6066936254501343


 21%|██        | 42/200 [03:20<12:39,  4.81s/it]

New best model found in epoch 41 with val loss: 1.5868661403656006


 23%|██▎       | 46/200 [03:39<12:32,  4.89s/it]

New best model found in epoch 45 with val loss: 1.580669641494751


 24%|██▍       | 48/200 [03:48<11:57,  4.72s/it]

New best model found in epoch 47 with val loss: 1.5682976245880127


 25%|██▌       | 50/200 [03:59<12:23,  4.96s/it]

New best model found in epoch 49 with val loss: 1.566779613494873


 26%|██▌       | 52/200 [04:08<11:52,  4.81s/it]

New best model found in epoch 51 with val loss: 1.555202841758728


 26%|██▋       | 53/200 [04:13<12:15,  5.00s/it]

New best model found in epoch 52 with val loss: 1.5541694164276123


 27%|██▋       | 54/200 [04:19<12:20,  5.07s/it]

New best model found in epoch 53 with val loss: 1.5514241456985474


 28%|██▊       | 55/200 [04:23<11:59,  4.96s/it]

New best model found in epoch 54 with val loss: 1.5497925281524658


 29%|██▉       | 58/200 [04:39<12:10,  5.14s/it]

New best model found in epoch 57 with val loss: 1.549411654472351


 30%|██▉       | 59/200 [04:44<11:55,  5.08s/it]

New best model found in epoch 58 with val loss: 1.5414636135101318


 30%|███       | 61/200 [04:53<11:38,  5.02s/it]

New best model found in epoch 60 with val loss: 1.5363045930862427


 36%|███▌      | 72/200 [05:48<10:44,  5.04s/it]

New best model found in epoch 71 with val loss: 1.535282850265503


 37%|███▋      | 74/200 [05:57<10:06,  4.81s/it]

New best model found in epoch 73 with val loss: 1.5328164100646973


 42%|████▏     | 84/200 [06:46<09:30,  4.92s/it]

New best model found in epoch 83 with val loss: 1.532764196395874


 44%|████▍     | 88/200 [07:06<09:26,  5.05s/it]

New best model found in epoch 87 with val loss: 1.5315158367156982


 53%|█████▎    | 106/200 [08:34<07:37,  4.86s/it]

New best model found in epoch 105 with val loss: 1.529775857925415


 57%|█████▋    | 114/200 [09:13<07:16,  5.08s/it]

New best model found in epoch 113 with val loss: 1.522719144821167


 73%|███████▎  | 146/200 [11:53<04:25,  4.91s/it]

New best model found in epoch 145 with val loss: 1.5207993984222412


100%|██████████| 200/200 [16:22<00:00,  4.91s/it]
