In [1]:
import sys 
sys.path.append('../')
import os
import pickle
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from utils import *
from datas import data
from set_train import *
from models.ConformerResnet import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
stock_symbol, end_date, num_class, batch_size, init, fp16_training, num_epochs, lr = set_train()
trainloader, validloader, testloader, test_date, df = data(stock_symbol, num_class, end_date, batch_size, window=100)
torch.cuda.empty_cache()

100%|██████████| 2863/2863 [00:03<00:00, 854.45it/s] 
100%|██████████| 2863/2863 [00:06<00:00, 425.46it/s]


x_train_len: 2543, valid_len: 160, test_len: 160


## Init: Model, Criteria, Optimizer, Fp16, Load previous trained

In [2]:
if fp16_training:
    from accelerate import Accelerator
    accelerator = Accelerator()
    device = accelerator.device
    model = Conformer_Resnet(num_class)
else:
    model = Conformer_Resnet(num_class).to(device)
Model = model.model_type

# Check path
if os.path.exists(f'Temp//{Model}_{stock_symbol}_LastTrainInfo.pk'):
    if init:
        print("Init model")
        lr = lr
        last_epoch = 0
        min_val_loss = 10000
        loss_train = []
        loss_valid = []
    else:
        print('Load from last train epoch')
        model.load_state_dict(torch.load(f'Temp//{Model}_class{num_class}_{stock_symbol}_checkpoint_LastTrainModel.pt'))
        with open(f'Temp//{Model}_class{num_class}_{stock_symbol}_LastTrainInfo.pk', 'rb') as f:
            last_train_info = pickle.load(f)
        with open(f'Temp//{Model}_class{num_class}_{stock_symbol}_TrainValHistLoss.pk', 'rb') as f:
            loss_train_val = pickle.load(f)            
        lr = last_train_info['lr']
        last_epoch = last_train_info['epoch']
        min_val_loss = last_train_info['min val loss']
        loss_train = loss_train_val['train']
        loss_valid = loss_train_val['valid']
else:
        print("Init model")
        lr = lr
        last_epoch = 0
        min_val_loss = 10000.0
        loss_train = []
        loss_valid = []
        
print(
    f'Last train epoch: {last_epoch}  '
    f'Last train lr: {lr}   '
    f'Min val loss: {min_val_loss}'
    )

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.00001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)        

if fp16_training:
    print('Accelerate Prepare')    
    model, optimizer, trainloader, validloader, scheduler = \
    accelerator.prepare(model, optimizer, trainloader, validloader, scheduler)
        
for name, param in model.named_parameters():
    print(f"Parameter '{name}' is on device: {param.device}")
    break

Init model
Last train epoch: 0  Last train lr: 1e-05   Min val loss: 10000.0
Accelerate Prepare
Parameter 'conv1.weight' is on device: cuda:0


# Train

In [3]:
for epoch in range(last_epoch, num_epochs):
    # Training phase
    model.train()
    loss_train_e = 0
    for batch_x, batch_y in tqdm(trainloader):
        if not fp16_training:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward() if not fp16_training else accelerator.backward(loss)     
        optimizer.step()
        scheduler.step()
        loss_train_e += loss.item()
    loss_train_e /= len(trainloader)
    loss_train.append(loss_train_e)
    
    # Scheduler 
    if epoch > 200:
        scheduler.step()
    
    loss_valid_e = 0
    with torch.no_grad():
        model.eval()
        for batch_x_val, batch_y_val in tqdm(validloader):
            batch_x_val = batch_x_val.to(device)
            batch_y_val = batch_y_val.to(device)
            outputs_val = model(batch_x_val)
            loss = criterion(outputs_val, batch_y_val)
            loss_valid_e += loss.item()
        loss_valid_e /= len(validloader)
        loss_valid.append(loss_valid_e)            
        torch.save(model.state_dict(), f'Temp/{Model}_class{num_class}_{stock_symbol}_checkpoint_LastTrainModel.pt')
        if loss_valid_e < min_val_loss:
            min_val_loss = loss_valid_e
            print(f'New best model found in epoch {epoch} with val loss: {min_val_loss}')
            torch.save(model.state_dict(), f'Model_Result/{Model}_class{num_class}_{stock_symbol}_best_model.pt')            
        if epoch % 50 == 0:
            pass
            
    with open(f'Temp/{Model}_class{num_class}_{stock_symbol}_TrainValHistLoss.pk', 'wb') as f:
        pickle.dump({'train': loss_train, 'valid': loss_valid}, f)
    with open(f'Temp/{Model}_class{num_class}_{stock_symbol}_LastTrainInfo.pk', 'wb') as f:
        pickle.dump({'min val loss': min_val_loss, 'epoch': epoch, 'lr': optimizer.param_groups[0]['lr']}, f)
    print(
        f'Epoch [{epoch}/{num_epochs}]',
        f'Training Loss: {loss_train_e:.10f}',
        f'Valid Loss: {loss_valid_e:.10f}'
        )


100%|██████████| 159/159 [00:28<00:00,  5.56it/s]
100%|██████████| 10/10 [00:00<00:00, 26.78it/s]


New best model found in epoch 0 with val loss: 3.83794686794281
Epoch [0/50] Training Loss: 3.7793970507 Valid Loss: 3.8379468679


100%|██████████| 159/159 [00:28<00:00,  5.60it/s]
100%|██████████| 10/10 [00:00<00:00, 28.32it/s]


New best model found in epoch 1 with val loss: 3.837934124469757
Epoch [1/50] Training Loss: 3.7053200976 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:24<00:00,  6.48it/s]
100%|██████████| 10/10 [00:00<00:00, 29.36it/s]


Epoch [2/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:24<00:00,  6.50it/s]
100%|██████████| 10/10 [00:00<00:00, 28.68it/s]


Epoch [3/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:24<00:00,  6.56it/s]
100%|██████████| 10/10 [00:00<00:00, 28.43it/s]


Epoch [4/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.74it/s]
100%|██████████| 10/10 [00:00<00:00, 28.99it/s]


Epoch [5/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.71it/s]
100%|██████████| 10/10 [00:00<00:00, 24.72it/s]


Epoch [6/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:25<00:00,  6.18it/s]
100%|██████████| 10/10 [00:00<00:00, 26.33it/s]


Epoch [7/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.01it/s]
100%|██████████| 10/10 [00:00<00:00, 29.10it/s]


Epoch [8/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:24<00:00,  6.59it/s]
100%|██████████| 10/10 [00:00<00:00, 21.72it/s]


Epoch [9/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.71it/s]
100%|██████████| 10/10 [00:00<00:00, 28.51it/s]


Epoch [10/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.85it/s]
100%|██████████| 10/10 [00:00<00:00, 28.77it/s]


Epoch [11/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.88it/s]
100%|██████████| 10/10 [00:00<00:00, 26.89it/s]


Epoch [12/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.76it/s]
100%|██████████| 10/10 [00:00<00:00, 28.89it/s]


Epoch [13/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.79it/s]
100%|██████████| 10/10 [00:00<00:00, 28.87it/s]


Epoch [14/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.76it/s]
100%|██████████| 10/10 [00:00<00:00, 28.38it/s]


Epoch [15/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  6.99it/s]
100%|██████████| 10/10 [00:00<00:00, 28.88it/s]


Epoch [16/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.09it/s]
100%|██████████| 10/10 [00:00<00:00, 29.06it/s]


Epoch [17/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.10it/s]
100%|██████████| 10/10 [00:00<00:00, 28.73it/s]


Epoch [18/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.04it/s]
100%|██████████| 10/10 [00:00<00:00, 28.89it/s]


Epoch [19/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.03it/s]
100%|██████████| 10/10 [00:00<00:00, 28.92it/s]


Epoch [20/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.04it/s]
100%|██████████| 10/10 [00:00<00:00, 29.14it/s]


Epoch [21/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.04it/s]
100%|██████████| 10/10 [00:00<00:00, 26.85it/s]


Epoch [22/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.08it/s]
100%|██████████| 10/10 [00:00<00:00, 29.01it/s]


Epoch [23/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  6.92it/s]
100%|██████████| 10/10 [00:00<00:00, 28.83it/s]


Epoch [24/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.05it/s]
100%|██████████| 10/10 [00:00<00:00, 27.54it/s]


Epoch [25/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.06it/s]
100%|██████████| 10/10 [00:00<00:00, 27.89it/s]


Epoch [26/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  6.98it/s]
100%|██████████| 10/10 [00:00<00:00, 28.85it/s]


Epoch [27/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.85it/s]
100%|██████████| 10/10 [00:00<00:00, 23.40it/s]


Epoch [28/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.89it/s]
100%|██████████| 10/10 [00:00<00:00, 28.81it/s]


Epoch [29/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:24<00:00,  6.62it/s]
100%|██████████| 10/10 [00:00<00:00, 28.87it/s]


Epoch [30/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.10it/s]
100%|██████████| 10/10 [00:00<00:00, 28.59it/s]


Epoch [31/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.08it/s]
100%|██████████| 10/10 [00:00<00:00, 24.58it/s]


Epoch [32/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  6.92it/s]
100%|██████████| 10/10 [00:00<00:00, 28.95it/s]


Epoch [33/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  7.04it/s]
100%|██████████| 10/10 [00:00<00:00, 29.07it/s]


Epoch [34/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:24<00:00,  6.50it/s]
100%|██████████| 10/10 [00:00<00:00, 20.90it/s]


Epoch [35/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:29<00:00,  5.40it/s]
100%|██████████| 10/10 [00:00<00:00, 17.85it/s]


Epoch [36/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:29<00:00,  5.33it/s]
100%|██████████| 10/10 [00:00<00:00, 18.56it/s]


Epoch [37/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:29<00:00,  5.38it/s]
100%|██████████| 10/10 [00:00<00:00, 20.46it/s]


Epoch [38/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:29<00:00,  5.35it/s]
100%|██████████| 10/10 [00:00<00:00, 26.19it/s]


Epoch [39/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:27<00:00,  5.85it/s]
100%|██████████| 10/10 [00:00<00:00, 18.82it/s]


Epoch [40/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:30<00:00,  5.13it/s]
100%|██████████| 10/10 [00:00<00:00, 25.97it/s]


Epoch [41/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:29<00:00,  5.39it/s]
100%|██████████| 10/10 [00:00<00:00, 19.66it/s]


Epoch [42/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:30<00:00,  5.28it/s]
100%|██████████| 10/10 [00:00<00:00, 17.08it/s]


Epoch [43/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:26<00:00,  6.03it/s]
100%|██████████| 10/10 [00:00<00:00, 22.23it/s]


Epoch [44/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:28<00:00,  5.54it/s]
100%|██████████| 10/10 [00:00<00:00, 22.67it/s]


Epoch [45/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:29<00:00,  5.32it/s]
100%|██████████| 10/10 [00:00<00:00, 18.99it/s]


Epoch [46/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:23<00:00,  6.90it/s]
100%|██████████| 10/10 [00:00<00:00, 29.22it/s]


Epoch [47/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  6.99it/s]
100%|██████████| 10/10 [00:00<00:00, 29.08it/s]


Epoch [48/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245


100%|██████████| 159/159 [00:22<00:00,  6.99it/s]
100%|██████████| 10/10 [00:00<00:00, 28.39it/s]


Epoch [49/50] Training Loss: 3.7053207101 Valid Loss: 3.8379341245
