In [1]:
import random
import numpy as np
import time
import math
from tqdm.notebook import tqdm
import os
import sys
import psutil

import torch
import torch.nn as nn
import torch.utils.data as data_utils

# logging jupiterlab notebook 
import logging
nblog = open("nb.log", "a+")
sys.stdout.echo = nblog
sys.stderr.echo = nblog
get_ipython().log.handlers[0].stream = nblog
get_ipython().log.setLevel(logging.INFO)


# GPU infos
num_GPUs = torch.cuda.device_count()
for i in range(num_GPUs):
    info = torch.cuda.get_device_properties(i)
    mem_info = torch.cuda.mem_get_info(i)
    print(f"CUDA:{i} {info} [{mem_info[0]/ 1024 ** 2, mem_info[1]/ 1024 ** 2}]")

# variables
# operators = ['-', '+', '*', '/']
operators = ['+', '-', '*']
brackets = ['(', ')']
input_chars = [" "] + [str(d) for d in range(10)] + ["."] + operators + brackets
output_chars = [" "] + [str(d) for d in range(10)] + ["-"]
max_number = 999
max_digit = 5
max_bracket = 3
input_dim = 25

CUDA:0 _CudaDeviceProperties(name='NVIDIA GeForce GTX 1070', major=6, minor=1, total_memory=8105MB, multi_processor_count=15) [(8004.625, 8105.0625)]


In [2]:
# function for generating simple arithematic equations and answers
def generate_equation():
    # number of numbers will be in the equation
    num_digits = random.randint(2, max_digit)  # Choose a random number of digits for each operand

    # Generate a list of elements in equation
    equations = []
    for i in range(num_digits):
        equations.append(str(random.randint(0, max_number)))
        if i < num_digits - 1:
            equations.append(random.choice(operators))

    # Add brackets randomly
    num_brackets = random.randint(0, max_bracket)
    for _ in range(num_brackets):
        pos1 = random.randint(0, len(equations) - 1)
        while equations[pos1] in operators + brackets:
            pos1 += 1
        new_equations = equations[:pos1] + ['('] + equations[pos1:]
        
        pos2 = random.randint(pos1 + 2, len(new_equations))
        while 2 < pos2 < len(new_equations) and new_equations[pos2 - 1] in operators + brackets:
            pos2 += 1
        if pos2 == len(new_equations):
            continue
        new_equations = new_equations[:pos2] + [')'] + new_equations[pos2:]
        equations = new_equations

    # concatenate them into a single string
    final_equation = "".join(equations)
    return final_equation

# evaluate the equation and get the result
def evaluate_equation(equation):
    try:
        # result = f"{{:.6f}}".format(eval(equation)).zfill(input_dim)
        result = str(eval(equation))
        return result
    except ZeroDivisionError:
        return " " * input_dim

# function to generate a string equation with answer
def generate_eq():
    # Generate and evaluate a random equation
    equation = generate_equation()
    result = evaluate_equation(equation)
    return equation, result

In [3]:
# generate some examples
for _ in range(5):
    print(generate_eq())

('977-392-(789)-875', '-1079')
('821-249', '572')
('64-956-682*115+179', '-79143')
('986-750*188', '-140014')
('819-787', '32')


In [4]:
# run the generation function many times to get the maxium length of the equation and maximum range of the answer
num_trail = 100000
max_len_eq = 0
max_len_result = 0
max_result = float(-np.inf)
min_result = float(np.inf)
max_result_str = None
min_result_str = None
invalid_count = 0
for _ in range(num_trail):
    equation, result = generate_eq()
    if "!" not in result:
        fresult = float(result)
        if fresult > max_result:
            max_result = fresult
            max_result_str = result
        if fresult < min_result:
            min_result = fresult
            min_result_str = result
    else:
        invalid_count += 1
    len_eq, len_result = len(equation), len(result)
    max_len_eq = len_eq if len_eq > max_len_eq else max_len_eq
    max_len_result = len_result if len_result > max_len_result else max_len_result

print(f"Max input string length: {max_len_eq}")
print(f"Max result string length: {max_len_result}")
print(f"Min result string: {min_result_str}")
print(f"Max result string: {max_result_str}")
print(f"Number of invalid input string: {invalid_count}/{num_trail} = {invalid_count/num_trail}")

Max input string length: 25
Max result string length: 15
Min result string: -646317433909
Max result string: 420998436312660
Number of invalid input string: 0/100000 = 0.0


In [5]:
# the embedding dimensions and mappings
input_embed_dim = len(input_chars)
output_embed_dim = len(output_chars)
print("input_chars", input_chars, "input_embed_dim", input_embed_dim)
print("output_chars", output_chars, "output_embed_dim", output_embed_dim)
input_embed_map = {e: np.eye(input_embed_dim)[i] for i, e in enumerate(input_chars)}
output_embed_map = {e: np.eye(output_embed_dim)[i] for i, e in enumerate(output_chars)}
input_embed_inverse_map = {i: k for i, k in enumerate(input_embed_map.keys())}
output_embed_inverse_map = {i: k for i, k in enumerate(output_embed_map.keys())}

print("==========[input_embed_map]============")
for k in input_embed_map.keys():
    print(k, input_embed_map[k])

print("==========[output_embed_map]============")
for k in output_embed_map.keys():
    print(k, output_embed_map[k])

print("==========[input_embed_inverse_map]============")
for k in input_embed_inverse_map.keys():
    print(k, input_embed_inverse_map[k])
    
print("==========[output_embed_inverse_map]============")
for k in output_embed_inverse_map.keys():
    print(k, output_embed_inverse_map[k])

input_chars [' ', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '+', '-', '*', '(', ')'] input_embed_dim 17
output_chars [' ', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'] output_embed_dim 12
  [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
0 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
1 [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
2 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
3 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
4 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
5 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
6 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
7 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
8 [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
9 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
. [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
+ [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
- [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
* [0. 0. 0. 0. 0. 0

In [6]:
# function to generate neural netowrk (nn) data for the Transformer
def generate_nn_data(num_sample, outout_raw_data_pair=False):
    data_pair = [generate_eq() for _ in range(num_sample)]
    input_data = []
    output_data = []
    for i in range(num_sample):
        i_str = data_pair[i][0]
        o_str = data_pair[i][1]
        
        i_vec = np.zeros((input_dim, input_embed_dim), dtype=np.float32)
        for i, char in enumerate(i_str):
            i_vec[i] = input_embed_map[char]
        input_data.append(i_vec)

        j_vec = np.zeros((input_dim, output_embed_dim), dtype=np.float32)
        for i, char in enumerate(o_str):
            j_vec[i] = output_embed_map[char]
        output_data.append(j_vec)

    input_data = np.array(input_data, dtype=np.float32)
    output_data = np.array(output_data, dtype=np.float32)
    if outout_raw_data_pair:
        return input_data, output_data, data_pair
    else:
        return input_data, output_data

time_start = time.time()
input_data, output_data, data_pair = generate_nn_data(2, outout_raw_data_pair=True)
time_spent = time.time() - time_start
print(f"raw input data: {data_pair}")
print(f"raw input size: input={input_data.shape}, output={output_data.shape}")
print(f"time_spent: {time_spent}s")

raw input data: [('687+598', '1285'), ('907*599', '543293')]
raw input size: input=(2, 25, 17), output=(2, 25, 12)
time_spent: 0.0006783008575439453s


In [7]:
# function to decode the nn data back to string
def decode_nn_data(output_nn_data, is_input, apply_float=False):
    decoded_output = []
    c_batchsize =  output_nn_data.shape[0]
    c_slen =  output_nn_data.shape[1]
    if is_input:
        embed_inverse_map = input_embed_inverse_map
    else:
        embed_inverse_map = output_embed_inverse_map
    for b in range(c_batchsize):
        e_output = [embed_inverse_map[np.argmax(output_nn_data[b][s])] for s in range(c_slen)]
        joint_e_output = "".join(e_output)
        if apply_float:
            try:
                decoded_output.append(float(joint_e_output))
            except:
                decoded_output.append(0.0)
        else:
            decoded_output.append(joint_e_output)
    return decoded_output
print(f"input : {decode_nn_data(input_data, is_input=True, apply_float=False)}")
print(f"output: {decode_nn_data(output_data, is_input=False, apply_float=False)}")

input : ['687+598                  ', '907*599                  ']
output: ['1285                     ', '543293                   ']


In [8]:
# our custom pytorch dataset class
class CustomDataset(data_utils.Dataset):
    def __init__(self, num_sample, random_seed=0):
        random.seed(0)
        self.num_sample = num_sample
        self.refresh_data()

    def refresh_data(self):
        print("refreshing dataset...")
        self.input_data, self.output_data = generate_nn_data(self.num_sample)

    def __len__(self):
        return self.num_sample

    def __getitem__(self, idx):
        return self.input_data[idx], self.output_data[idx]

train_dataset = CustomDataset(6)
train_dataloader = data_utils.DataLoader(train_dataset, batch_size=2)
for i in train_dataloader:
    print(f"input: {decode_nn_data(i[0].numpy(), is_input=True)}, output: {decode_nn_data(i[1].numpy(), is_input=False)}")
    print(f"input size: {i[0].shape}, output size: {i[1].shape}")

refreshing dataset...
input: ['776-41-988*497-940       ', '288+(773)+633            '], output: ['-491241                  ', '1694                     ']
input size: torch.Size([2, 25, 17]), output size: torch.Size([2, 25, 12])
input: ['920*338                  ', '453*266+824*((937))+95   '], output: ['310960                   ', '892681                   ']
input size: torch.Size([2, 25, 17]), output size: torch.Size([2, 25, 12])
input: ['939*227+822              ', '82-896                   '], output: ['213975                   ', '-814                     ']
input size: torch.Size([2, 25, 17]), output size: torch.Size([2, 25, 12])


In [9]:
# classes and functions for our basic Transformer model
class SwiGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x

class MyActF(nn.Module):
    def __init__(self, mode="silu"):
        super().__init__()
        if mode == "leaky_relu":
            self.act_f = nn.LeakyReLU(negative_slope=0.01)
        elif mode == "relu":
            self.act_f = nn.ReLU()
        elif mode == "gelu":
            self.act_f = nn.GELU()
        elif mode in ["silu", "swish"]:
            self.act_f = nn.SiLU()
        elif mode == "hardswish":
            self.act_f = nn.Hardswish()
        elif mode == "SwiGLU":
            self.act_f = SwiGLU()
        else:
            assert False

    def forward(self, x):
        return self.act_f(x)


class BatchRenorm(torch.jit.ScriptModule):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-3,
        momentum: float = 0.01,
        affine: bool = True,
    ):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(num_features, dtype=torch.float))
        self.register_buffer("running_std", torch.ones(num_features, dtype=torch.float))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.weight = torch.nn.Parameter(torch.ones(num_features, dtype=torch.float))
        self.bias = torch.nn.Parameter(torch.zeros(num_features, dtype=torch.float))
        self.affine = affine
        self.eps = eps
        self.step = 0
        self.momentum = momentum

    def _check_input_dim(self, x: torch.Tensor) -> None:
        raise NotImplementedError()  # pragma: no cover

    @property
    def rmax(self) -> torch.Tensor:
        return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_(1.0, 3.0)

    @property
    def dmax(self) -> torch.Tensor:
        return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_(0.0, 5.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._check_input_dim(x)
        if x.dim() > 2:
            x = x.transpose(1, -1)
        if self.training:
            dims = [i for i in range(x.dim() - 1)]
            batch_mean = x.mean(dims)
            batch_std = x.std(dims, unbiased=False) + self.eps
            r = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_(1 / self.rmax, self.rmax)
            d = ((batch_mean.detach() - self.running_mean.view_as(batch_mean)) / self.running_std.view_as(batch_std)).clamp_(-self.dmax, self.dmax)
            x = (x - batch_mean) / batch_std * r + d
            self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)
            self.running_std += self.momentum * (batch_std.detach() - self.running_std)
            self.num_batches_tracked += 1
        else:
            x = (x - self.running_mean) / self.running_std
        if self.affine:
            x = self.weight * x + self.bias
        if x.dim() > 2:
            x = x.transpose(1, -1)
        return x


class BatchRenorm1d(BatchRenorm):
    def _check_input_dim(self, x: torch.Tensor) -> None:
        if x.dim() not in [2, 3]:
            raise ValueError("expected 2D or 3D input (got {x.dim()}D input)")


class BatchRenorm2d(BatchRenorm):
    def _check_input_dim(self, x: torch.Tensor) -> None:
        if x.dim() != 4:
            raise ValueError("expected 4D input (got {x.dim()}D input)")


class BatchRenorm3d(BatchRenorm):
    def _check_input_dim(self, x: torch.Tensor) -> None:
        if x.dim() != 5:
            raise ValueError("expected 5D input (got {x.dim()}D input)")

class MyNorm(nn.Module):
    def __init__(self, input_dim, mode="BatchNorm"):
        super().__init__()
        self.mode = mode
        if mode == "LayerNorm":
            self.norm = nn.LayerNorm(input_dim)
        elif mode == "BatchNorm":
            self.norm = nn.BatchNorm1d(input_dim)
        elif mode == "BatchRenorm":
            self.norm = BatchRenorm1d(input_dim)
        elif mode == "GroupNorm":
            self.num_groups = 8
            self.remainder_dim = None
            self.rounded_input_dim = None
            if self.num_groups > input_dim:
                self.num_groups = 1
                self.remainder_dim = 0
                self.rounded_input_dim = input_dim
            else:
                self.remainder_dim = input_dim % self.num_groups
                self.rounded_input_dim = input_dim - self.remainder_dim
            self.norm = nn.GroupNorm(self.num_groups, self.rounded_input_dim)
            if self.remainder_dim > 0:
                self.remainder_norm = nn.LayerNorm(self.remainder_dim)
        else:
            assert False

    def forward(self, x):
        if self.mode == "GroupNorm":
            if self.remainder_dim == 0:
                self.norm(x)
            else:
                split_x1, split_x2 = torch.split(x, [self.rounded_input_dim, self.remainder_dim])
                norm_x1 = self.norm(split_x1)
                norm_x2 = self.remainder_norm(split_x2)
                return torch.cat([norm_x1, norm_x2], -1)
        else:
            return self.norm(x)

class PositionalEncoding(nn.Module):

    def __init__(self, token_dim: int, dropout: float = 0.0, max_len: int = 5000, learnable: bool = False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.learnable = learnable
        if self.learnable:
            self.pe = nn.Parameter(torch.normal(mean=0, std=0.001, size=(1, max_len, token_dim)))
        else:
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, token_dim, 2) * (-math.log(10000.0) / token_dim))
            pe = torch.zeros(1, max_len, token_dim)
            pe[0, :, 0::2] = torch.sin(position * div_term)
            pe[0, :, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class MyTransformerEncoder(nn.Module):
    def __init__(self,
                 input_token_dim=64,
                 output_token_dim=16,
                 num_token=20,
                 nhead=16,
                 dim_feedforward=256,
                 dropout=0.1,
                 activation='gelu',
                 nlayers=3,
                 positional_encoding=True,
                 proj_norm_mode="LayerNorm",
                 ):
        super(MyTransformerEncoder, self).__init__()
        self.input_token_dim = input_token_dim
        self.output_token_dim = output_token_dim
        self.num_token = num_token
        self.positional_encoding = positional_encoding
        if input_token_dim % 2 == 0:
            self.corrected_input_token_dim = input_token_dim
        else:
            self.corrected_input_token_dim = input_token_dim + 1
        self.transformer_pre_projection = nn.Sequential(
            nn.Linear(self.input_token_dim, self.corrected_input_token_dim, bias=False),
            MyNorm(self.corrected_input_token_dim, mode=proj_norm_mode),
            MyActF(activation),
        )
        if positional_encoding:
            self.pos_encoder = PositionalEncoding(token_dim=self.corrected_input_token_dim, dropout=0.0, max_len=num_token, learnable=False)
        else:
            self.pos_encoder = None
        encoder_layer = nn.TransformerEncoderLayer(d_model=self.corrected_input_token_dim,
                                                   nhead=nhead,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout,
                                                   activation=activation,
                                                   batch_first=True,
                                                   norm_first=True,)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, nlayers)
        self.transformer_final_projection = nn.Sequential(
            nn.Linear(self.corrected_input_token_dim, self.output_token_dim, bias=True),
        )
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, num_token, input_token_dim]
        output:
            y: Tensor, shape [batch_size, num_token, output_token_dim]
        """
        bs, t, d = x.size(0), x.size(1), x.size(2)
        assert d == self.input_token_dim
        
        x_t = x.view(bs * t, d).contiguous()
        proj_x = self.transformer_pre_projection(x_t)  # [batch_size, num_token, input_token_dim]
        proj_x_t = proj_x.view(bs, t, self.corrected_input_token_dim).contiguous()
        
        # proj_x_t = x
        
        if self.positional_encoding:
            endcoded_x = self.pos_encoder(proj_x_t)
        else:
            endcoded_x = proj_x_t
        transformed_x = self.transformer_encoder(endcoded_x)
        transformed_x_t = transformed_x.view(bs * t, self.corrected_input_token_dim).contiguous()
        proj_transformed_x = self.transformer_final_projection(transformed_x_t).view(bs, t, self.output_token_dim).contiguous()
        return proj_transformed_x

In [10]:
# some helper functions

# a simple average meter class for monitoring averages
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count != 0 else 0
        
# function to get memory usage
def mem():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 ** 3.

In [11]:
# training variables
num_epoch = 1000
batch_size = 400
test_batch_size = batch_size * 4
num_of_sample_per_epoch = batch_size * 1000
num_of_test_sample = 10000
lr = 1e-5
lr_factor = 0.1
lr_patience = 3
lr_threshold=lr * 0.1
warmup_epoch = 5
warmup_factor = 5
wd = 0.0001
GPUs = list(range(num_GPUs))

# datasets
train_dataset = CustomDataset(num_of_sample_per_epoch)
train_dataloader = data_utils.DataLoader(train_dataset, batch_size=batch_size)
test_dataset = CustomDataset(num_of_test_sample, random_seed=1234567890)
test_dataloader = data_utils.DataLoader(test_dataset, batch_size=test_batch_size)

# model
model = MyTransformerEncoder(
    input_token_dim=input_embed_dim,
    output_token_dim=output_embed_dim,
    num_token=input_dim,
    nhead=6,
    dim_feedforward=1024,
    dropout=0.1,
    activation='gelu',
    nlayers=64,
    positional_encoding=True,
    proj_norm_mode="LayerNorm", # LayerNorm BatchNorm
)
model = torch.nn.DataParallel(model, device_ids=GPUs).cuda()

# optimizer
train_parameters = model.parameters()
optimizer = torch.optim.AdamW(train_parameters, lr=lr, weight_decay=wd)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=lr_factor,
    patience=lr_patience,
    threshold=lr_threshold,
    threshold_mode='rel')

# loss
loss_function = nn.CrossEntropyLoss(reduction="mean").cuda()
# loss_function = nn.MSELoss(reduction="mean")
# class CustomLoss(nn.Module):
#     def __init__(self):
#         super(CustomLoss, self).__init__()
#         self.mseloss = nn.MSELoss(reduction='mean')

#     def forward(self, output, target):
#         batch_size = output.size(0)
#         n_token = output.size(1)
#         heatmaps_pred = output.reshape((batch_size, n_token, -1)).split(1, 1)
#         heatmaps_gt = target.reshape((batch_size, n_token, -1)).split(1, 1)
#         loss = 0

#         for idx in range(n_token):
#             heatmap_pred = heatmaps_pred[idx].squeeze()
#             heatmap_gt = heatmaps_gt[idx].squeeze()
#             loss += 0.5 * self.mseloss(heatmap_pred, heatmap_gt)

#         return loss / n_token
# loss_function = CustomLoss().cuda()

# helper functions
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']
    
def set_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

refreshing dataset...
refreshing dataset...




In [None]:
# monitoring variables
begin_epoch = 0
avg_loss = AverageMeter()
avg_loss_test = AverageMeter()
avg_diff_ratio = AverageMeter()
avg_diff_ratio_test = AverageMeter()
best_test_monitor = float(np.inf)
best_epoch = 0
patient_count = 0

# checkpoint paths
checkpoint_dir = "checkpoint"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pth")
checkpoint_best_path = os.path.join(checkpoint_dir, "checkpoint_best.pth")
if os.path.exists(checkpoint_path):
    print("=> loading checkpoint '{}'".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path)
    begin_epoch = checkpoint['epoch']
    best_epoch = checkpoint['best_epoch']
    best_test_monitor = checkpoint['best_test_monitor']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, begin_epoch))

current_lr = lr
# set_lr(optimizer, 1e-4)
skip_first_epoch = False
for epoch in range(begin_epoch, num_epoch):
    current_lr = get_lr(optimizer)
    print(f"Epoch {epoch}|lr:{current_lr}|Memory: {mem():.2f} GB...")
    
    if not skip_first_epoch or epoch > 0:
        # training
        model.train()
        for train_data in tqdm(train_dataloader):
            model_pred = model(train_data[0])
            train_loss = loss_function(model_pred.view(-1, output_embed_dim), train_data[1].view(-1, output_embed_dim).cuda(non_blocking=True))

            # loss and accuracy monitor
            avg_loss.update(train_loss.cpu().detach().numpy().item())
            gt_result = np.array(decode_nn_data(train_data[1].numpy(), is_input=False, apply_float=True), dtype=np.float32)
            pred_result = np.array(decode_nn_data(model_pred.cpu().detach().numpy(), is_input=False, apply_float=True), dtype=np.float32)
            diff_ratio = np.mean(np.absolute(gt_result - pred_result) / (np.absolute(gt_result) + 1e-3))
            avg_diff_ratio.update(diff_ratio)

            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

    # testing
    model.eval()
    with torch.no_grad():
        for test_data in tqdm(test_dataloader):
            model_pred = model(test_data[0])
            test_loss = loss_function(model_pred.view(-1, output_embed_dim), test_data[1].view(-1, output_embed_dim).to(model_pred.device))
            
            # loss and accuracy monitor
            avg_loss_test.update(test_loss.cpu().detach().numpy().item())
            gt_result = np.array(decode_nn_data(test_data[1].numpy(), is_input=False, apply_float=True), dtype=np.float32)
            pred_result = np.array(decode_nn_data(model_pred.cpu().detach().numpy(), is_input=False, apply_float=True), dtype=np.float32)
            diff_ratio = np.mean(np.absolute(gt_result - pred_result) / (np.absolute(gt_result) + 1e-3))
            avg_diff_ratio_test.update(diff_ratio)
    
            # display some example
            test_input_data_numpy = test_data[0].numpy()
            current_batch_size = test_input_data_numpy.shape[0]
            sample_idx = np.random.choice(current_batch_size, 2)
            sample_input = decode_nn_data(test_input_data_numpy[sample_idx], is_input=True)
            sample_raw_output = decode_nn_data(test_data[1].numpy()[sample_idx], is_input=False, apply_float=False)
            sample_output = decode_nn_data(test_data[1].numpy()[sample_idx], is_input=False, apply_float=True)
            sample_raw_pred = decode_nn_data(model_pred.cpu().detach().numpy()[sample_idx], is_input=False, apply_float=False)
            sample_pred = decode_nn_data(model_pred.cpu().detach().numpy()[sample_idx], is_input=False, apply_float=True)
            print("sample_input", sample_input)
            print("sample_raw_output", sample_raw_output, "sample_raw_pred", sample_raw_pred)
            print("sample_output", sample_output, "sample_pred", sample_pred)

    # update best_test_monitor
    # test_monitor = avg_diff_ratio_test.val
    test_monitor = avg_loss_test.val
    if test_monitor != float("inf") and test_monitor != float('nan') and test_monitor < best_test_monitor:
        best_test_monitor = test_monitor
        best_epoch = epoch + 1
        patient_count = 0
        
        # save best checkpoint
        torch.save({
            'epoch': epoch + 1,
            'best_epoch': best_epoch,
            'best_test_monitor': best_test_monitor,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint_best_path)
    else:
        patient_count += 1

    print(f"[Epoch:{epoch}]|avg_loss:{avg_loss.val:.6f}|lr:{current_lr}|avg_diff_ratio:{avg_diff_ratio.val:.3f}|avg_loss_test:{avg_loss_test.val:.6f}|avg_diff_ratio_test:{avg_diff_ratio_test.val:.3f}|"
          f"best_test_monitor:{best_test_monitor:.3f}|best_epoch:{best_epoch}|patient_count:{patient_count}/{lr_patience}|Memory: {mem():.2f} GB")
    
    # update learning rate if needed
    if epoch < warmup_epoch:
        print(f"warming up! increase learning rate from {current_lr:.8f} to {current_lr * warmup_factor:.8f}")
        set_lr(optimizer, current_lr * warmup_factor)
        patient_count = 0
    else:
        if patient_count >= lr_patience and current_lr > lr_threshold:
            print(f"reach patient threshold! reducing learning rate from {current_lr:.8f} to {current_lr * lr_factor:.8f}")
            set_lr(optimizer, current_lr * lr_factor)
            patient_count = 0
    
    # save checkpoint
    torch.save({
        'epoch': epoch + 1,
        'best_epoch': best_epoch,
        'best_test_monitor': best_test_monitor,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, checkpoint_path)
    
    # reset variables
    avg_loss.reset()
    avg_loss_test.reset()
    avg_diff_ratio.reset()
    avg_diff_ratio_test.reset()
    train_dataloader.dataset.refresh_data()
print("Finished!")

Epoch 0|lr:1e-05|Memory: 2.75 GB...


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

sample_input ['196+709                  ', '(652*836)-691            ']
sample_raw_output ['905                      ', '544381                   '] sample_raw_pred ['1110000000001160001100000', '-12000000000110000-100000']
sample_output [905.0, 544381.0] sample_pred [1.11000000000116e+24, 0.0]
sample_input ['2-118                    ', '512*105+(659*423)*274    ']
sample_raw_output ['-116                     ', '76433178                 '] sample_raw_pred ['-12000000000-10000-100000', '1100000000000100001100000']
sample_output [-116.0, 76433178.0] sample_pred [0.0, 1.1000000000001e+24]
sample_input ['649+343                  ', '951*167-538              ']
sample_raw_output ['992                      ', '158279                   '] sample_raw_pred ['1110000000001160001100000', '112000100000110000-100000']
sample_output [992.0, 158279.0] sample_pred [1.11000000000116e+24, 0.0]
sample_input ['312+(950-609)+78         ', '17*43+660+517            ']
sample_raw_output ['731               

  0%|          | 0/1000 [00:00<?, ?it/s]