In [1]:
import math_dataset 
from math_dataset import MathDatasetManager
import torch
import torch.optim as optim
from torch.utils import data
from math_dataset import (
    question_answer_to_position_batch_collate_fn
)
import model_process


import utils

%matplotlib notebook  

print("Torch Version", torch.__version__)

%load_ext autoreload
%autoreload 2

Torch Version 1.5.0+cu101


## Math Dataset Manager

This class is a NumPy/PyTorch helper that manages all files in Mathematics Dataset v1.0 and selects parts of it by category or module to build a PyTorch dataset for training. The PyTorch datasets it creates do not load all questions/answers into memory; instead they use Pandas' limited streaming features to buffer data. This allows loading huge files quickly while keeping the memory footprint reasonable. It also caches lazily built datasets so that previously created ones can be reused quickly.

Here are the main features provided right now.

### Initialize Math Dataset Manager

In [2]:
import os

# Root folder of the extracted Mathematics Dataset v1.0. Set the
# MATH_DATASET_DIR environment variable to point elsewhere instead of editing
# this machine-specific absolute path.
MATH_DATASET_DIR = os.environ.get(
    "MATH_DATASET_DIR",
    "C:\\Users\\Jesús\\Documents\\PC2\\TorchDemo\\hs-math-nlp\\mathematics_dataset-v1.0\\mathematics_dataset-v1.0\\",
)
mdsmgr = MathDatasetManager(MATH_DATASET_DIR)

initialized MultiFilesMathDataset with categories ['algebra', 'arithmetic', 'calculus', 'comparison', 'measurement', 'numbers', 'polynomials', 'probability'] and types ['train-easy', 'train-medium', 'train-hard', 'interpolate', 'extrapolate']


### Check available types (difficulties + interpolate + extrapolate)

In [3]:
# List every dataset type (difficulty level or interpolate/extrapolate split) the manager found.
print("types", [dataset_type for dataset_type in mdsmgr.get_types()])

types ['train-easy', 'train-medium', 'train-hard', 'interpolate', 'extrapolate']


### Check available problem categories

In [4]:
# List every problem category the manager found on disk.
print("categories", [category for category in mdsmgr.get_categories()])

categories ['algebra', 'arithmetic', 'calculus', 'comparison', 'measurement', 'numbers', 'polynomials', 'probability']


## Pytorch Initialization

In [5]:
# Seed PyTorch's RNG (torch.manual_seed seeds all devices) so the random
# train/validation split and shuffling are reproducible.
seed = 1
torch.manual_seed(seed)

# Use the GPU when one is available; fall back to the CPU instead of failing
# on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device", device)

device cuda


## Train on Arithmetic add_or_sub (train-hard)

### Create an experiment with a name and a unique ID

In [6]:
# Experiment name and run identifier; together they name the run's checkpoint files.
exp_name, unique_id = "add_or_sub", "2021-07-25"

### Build Dataset for training

#### Training dataset (train-hard)

In [7]:
# Training data: one module of one category at one difficulty level.
# The printed label is derived from TRAIN_DIFFICULTY so it cannot drift from
# the split actually loaded (previously it said "train-easy" while loading "train-hard").
TRAIN_CATEGORY = "arithmetic"
TRAIN_MODULE = "add_or_sub"
TRAIN_DIFFICULTY = "train-hard"

ds = mdsmgr.build_dataset_from_module(TRAIN_CATEGORY, TRAIN_MODULE, TRAIN_DIFFICULTY)
print(f"{TRAIN_DIFFICULTY} dataset size", len(ds))

train-easy dataset size 666666


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  errors=errors,


#### Interpolate dataset

In [8]:
# Held-out "interpolate" split of the same module; used below as a test set.
ds_interpolate = mdsmgr.build_dataset_from_module("arithmetic", "add_or_sub", "interpolate")
print("interpolate dataset size", len(ds_interpolate))

interpolate dataset size 10000


### Create default Transformer model

Here we use the best model reported in the paper, a multi-head self-attention Transformer, as a default baseline.


In [9]:
# Build the default Transformer model via the project helper utils.build_transformer().
model = utils.build_transformer()

### Create basic optimizer

In [10]:
# Move the model to the target device *before* constructing the optimizer, as
# the torch.optim docs recommend. Calling .to() again later is a no-op.
model = model.to(device)

# Adam hyper-parameters. betas/eps follow the setup I believe Saxton et al. (2019) report.
# NOTE(review): that paper reports lr=6e-4; 6e-6 is 100x smaller — confirm this is intentional.
LEARNING_RATE = 6e-6
ADAM_BETAS = (0.9, 0.995)
ADAM_EPS = 1e-9
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS, eps=ADAM_EPS)

### Create Pytorch dataloaders

In [11]:
# Split the training module 90/10 into train/validation; the interpolate split
# is kept aside as the test set.
train_ds, val_ds = math_dataset.random_split_dataset(ds, split_rate=0.9)

BATCH_SIZE = 128
NUM_WORKERS = 12  # reduce on machines with fewer CPU cores


def make_loader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):
    """Wrap ``dataset`` in a DataLoader using the project's question/answer collate function.

    question_answer_to_position_batch_collate_fn collates all questions/answers
    into transformer format, enhanced with character positions.

    Args:
        dataset: dataset to iterate over.
        shuffle: reshuffle every epoch (only wanted for the training split).
        batch_size: samples per batch.
        num_workers: number of worker processes used for loading.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    return data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
        collate_fn=question_answer_to_position_batch_collate_fn)


train_loader = make_loader(train_ds, shuffle=True)
val_loader = make_loader(val_ds)
interpolate_loader = make_loader(ds_interpolate)


In [12]:
# Optionally resume training from a saved checkpoint instead of starting from
# the freshly built model. Leave as None to train from scratch (the default).
# Example: RESUME_CHECKPOINT = ".\\checkpoints\\checkpoint_b37504_e7.pth"
RESUME_CHECKPOINT = None

if RESUME_CHECKPOINT is not None:
    import checkpoints
    # NOTE(review): assumes restore_checkpoint loads weights into `model` in place,
    # so the optimizer built earlier still references the same parameters — confirm.
    _ = checkpoints.restore_checkpoint(RESUME_CHECKPOINT, "", model=model)


In [13]:
# Place the model's parameters on the selected device. Assigning the result keeps
# the cell from echoing the module repr.
model = model.to(device=device)


In [None]:
# Run the project's training loop for this experiment. Checkpoint files are
# named after run_name (see the checkpoint log lines in the output).
# tb=None presumably disables TensorBoard logging — confirm in model_process.train.
run_name = f"{exp_name}-{unique_id}"
model_process.train(
    name=run_name,
    model=model,
    training_data=train_loader,
    validation_data=val_loader,
    interpolate_data=interpolate_loader,
    optimizer=optimizer,
    device=device,
    epochs=8,
    tb=None,
    log_interval=100,
)

~~~ Beginning Training ~~~~
Start epoch: 0, Start batch: 0, Max batch: None
[ Epoch: 0 / 8, Run Batch: 0 / None]
Batch: 0. Acc: 0.011730. Loss: 5.227549. Batch_acc: 0.011730. Batch_loss: 5.227549 
Batch: 1. Acc: 0.019573. Loss: 5.078589. Batch_acc: 0.027357. Batch_loss: 4.930756 
Batch: 2. Acc: 0.036892. Loss: 4.888388. Batch_acc: 0.071765. Batch_loss: 4.505412 
Batch: 3. Acc: 0.048684. Loss: 4.726118. Batch_acc: 0.083867. Batch_loss: 4.241954 
Batch: 4. Acc: 0.055477. Loss: 4.584731. Batch_acc: 0.083037. Batch_loss: 4.011135 
Batch: 5. Acc: 0.061344. Loss: 4.453219. Batch_acc: 0.090855. Batch_loss: 3.791699 
Batch: 6. Acc: 0.064137. Loss: 4.332433. Batch_acc: 0.080271. Batch_loss: 3.634551 
Batch: 7. Acc: 0.066735. Loss: 4.224350. Batch_acc: 0.084677. Batch_loss: 3.477855 
Batch: 8. Acc: 0.069298. Loss: 4.123463. Batch_acc: 0.089647. Batch_loss: 3.322553 
Batch: 9. Acc: 0.072976. Loss: 4.033534. Batch_acc: 0.106930. Batch_loss: 3.203272 
Batch: 10. Acc: 0.075828. Loss: 3.952317. Batch

Batch: 96. Acc: 0.130339. Loss: 2.711562. Batch_acc: 0.152632. Batch_loss: 2.449801 
Batch: 97. Acc: 0.130475. Loss: 2.708751. Batch_acc: 0.143503. Batch_loss: 2.441198 
Batch: 98. Acc: 0.130731. Loss: 2.706103. Batch_acc: 0.155280. Batch_loss: 2.451573 
Batch: 99. Acc: 0.130877. Loss: 2.703791. Batch_acc: 0.145821. Batch_loss: 2.468104 
Batch: 100. Acc: 0.131144. Loss: 2.701082. Batch_acc: 0.156983. Batch_loss: 2.438200 
Batch: 101. Acc: 0.131427. Loss: 2.698730. Batch_acc: 0.160279. Batch_loss: 2.459063 
Batch: 102. Acc: 0.131773. Loss: 2.696201. Batch_acc: 0.167245. Batch_loss: 2.436856 
Batch: 103. Acc: 0.132127. Loss: 2.693534. Batch_acc: 0.168682. Batch_loss: 2.418709 
Batch: 104. Acc: 0.132377. Loss: 2.690947. Batch_acc: 0.158537. Batch_loss: 2.419548 
Batch: 105. Acc: 0.132619. Loss: 2.688648. Batch_acc: 0.158107. Batch_loss: 2.446695 
Batch: 106. Acc: 0.132711. Loss: 2.686417. Batch_acc: 0.142525. Batch_loss: 2.447489 
Batch: 107. Acc: 0.132905. Loss: 2.684240. Batch_acc: 0.15

Batch: 192. Acc: 0.165023. Loss: 2.518801. Batch_acc: 0.233100. Batch_loss: 2.240252 
Batch: 193. Acc: 0.165337. Loss: 2.517332. Batch_acc: 0.224341. Batch_loss: 2.241188 
Batch: 194. Acc: 0.165625. Loss: 2.515889. Batch_acc: 0.220957. Batch_loss: 2.238852 
Batch: 195. Acc: 0.165971. Loss: 2.514502. Batch_acc: 0.234880. Batch_loss: 2.238497 
Batch: 196. Acc: 0.166341. Loss: 2.513063. Batch_acc: 0.239420. Batch_loss: 2.229047 
Batch: 197. Acc: 0.166680. Loss: 2.511684. Batch_acc: 0.233661. Batch_loss: 2.238823 
Batch: 198. Acc: 0.166969. Loss: 2.510322. Batch_acc: 0.224267. Batch_loss: 2.240872 
Batch: 199. Acc: 0.167245. Loss: 2.509036. Batch_acc: 0.223335. Batch_loss: 2.246944 
Batch: 200. Acc: 0.167565. Loss: 2.507583. Batch_acc: 0.231926. Batch_loss: 2.215740 
Batch: 201. Acc: 0.167877. Loss: 2.506270. Batch_acc: 0.230903. Batch_loss: 2.240991 
Batch: 202. Acc: 0.168119. Loss: 2.505044. Batch_acc: 0.217926. Batch_loss: 2.253050 
Batch: 203. Acc: 0.168430. Loss: 2.503738. Batch_acc: 

Batch: 280. Acc: 0.186494. Loss: 2.424831. Batch_acc: 0.243151. Batch_loss: 2.192437 
Batch: 281. Acc: 0.186665. Loss: 2.424046. Batch_acc: 0.234152. Batch_loss: 2.205135 
Batch: 282. Acc: 0.186832. Loss: 2.423288. Batch_acc: 0.233352. Batch_loss: 2.212170 
Batch: 283. Acc: 0.187016. Loss: 2.422621. Batch_acc: 0.239080. Batch_loss: 2.234367 
Batch: 284. Acc: 0.187215. Loss: 2.421842. Batch_acc: 0.243028. Batch_loss: 2.202984 
Batch: 285. Acc: 0.187401. Loss: 2.421088. Batch_acc: 0.241481. Batch_loss: 2.202091 
Batch: 286. Acc: 0.187581. Loss: 2.420349. Batch_acc: 0.239650. Batch_loss: 2.206257 
Batch: 287. Acc: 0.187758. Loss: 2.419659. Batch_acc: 0.238754. Batch_loss: 2.221364 
Batch: 288. Acc: 0.187961. Loss: 2.418887. Batch_acc: 0.245594. Batch_loss: 2.199306 
Batch: 289. Acc: 0.188189. Loss: 2.418107. Batch_acc: 0.253273. Batch_loss: 2.195477 
Batch: 290. Acc: 0.188399. Loss: 2.417289. Batch_acc: 0.247486. Batch_loss: 2.187320 
Batch: 291. Acc: 0.188566. Loss: 2.416506. Batch_acc: 

Batch: 376. Acc: 0.200684. Loss: 2.364721. Batch_acc: 0.247405. Batch_loss: 2.176139 
Batch: 377. Acc: 0.200823. Loss: 2.364215. Batch_acc: 0.253303. Batch_loss: 2.174044 
Batch: 378. Acc: 0.200938. Loss: 2.363721. Batch_acc: 0.243721. Batch_loss: 2.178491 
Batch: 379. Acc: 0.201015. Loss: 2.363249. Batch_acc: 0.230199. Batch_loss: 2.186127 
Batch: 380. Acc: 0.201114. Loss: 2.362784. Batch_acc: 0.238979. Batch_loss: 2.184648 
Batch: 381. Acc: 0.201219. Loss: 2.362330. Batch_acc: 0.240847. Batch_loss: 2.190408 
Batch: 382. Acc: 0.201287. Loss: 2.361905. Batch_acc: 0.227804. Batch_loss: 2.196945 
Batch: 383. Acc: 0.201424. Loss: 2.361430. Batch_acc: 0.253586. Batch_loss: 2.180168 
Batch: 384. Acc: 0.201530. Loss: 2.360897. Batch_acc: 0.242147. Batch_loss: 2.157825 
Batch: 385. Acc: 0.201651. Loss: 2.360455. Batch_acc: 0.247991. Batch_loss: 2.190718 
Batch: 386. Acc: 0.201767. Loss: 2.359987. Batch_acc: 0.246136. Batch_loss: 2.180077 
Batch: 387. Acc: 0.201871. Loss: 2.359529. Batch_acc: 

Batch: 472. Acc: 0.210488. Loss: 2.324299. Batch_acc: 0.250000. Batch_loss: 2.161367 
Batch: 473. Acc: 0.210572. Loss: 2.324011. Batch_acc: 0.252263. Batch_loss: 2.181156 
Batch: 474. Acc: 0.210664. Loss: 2.323607. Batch_acc: 0.254169. Batch_loss: 2.132104 
Batch: 475. Acc: 0.210753. Loss: 2.323249. Batch_acc: 0.251973. Batch_loss: 2.156819 
Batch: 476. Acc: 0.210853. Loss: 2.322837. Batch_acc: 0.257919. Batch_loss: 2.130198 
Batch: 477. Acc: 0.210953. Loss: 2.322461. Batch_acc: 0.257336. Batch_loss: 2.146707 
Batch: 478. Acc: 0.211045. Loss: 2.322093. Batch_acc: 0.255014. Batch_loss: 2.146756 
Batch: 479. Acc: 0.211136. Loss: 2.321742. Batch_acc: 0.254902. Batch_loss: 2.153474 
Batch: 480. Acc: 0.211245. Loss: 2.321342. Batch_acc: 0.264294. Batch_loss: 2.126476 
Batch: 481. Acc: 0.211315. Loss: 2.321013. Batch_acc: 0.244571. Batch_loss: 2.163937 
Batch: 482. Acc: 0.211407. Loss: 2.320647. Batch_acc: 0.255575. Batch_loss: 2.145590 
Batch: 483. Acc: 0.211486. Loss: 2.320326. Batch_acc: 

Batch: 561. Acc: 0.218310. Loss: 2.294466. Batch_acc: 0.254296. Batch_loss: 2.126526 
Batch: 562. Acc: 0.218402. Loss: 2.294128. Batch_acc: 0.270930. Batch_loss: 2.101886 
Batch: 563. Acc: 0.218471. Loss: 2.293828. Batch_acc: 0.258560. Batch_loss: 2.120948 
Batch: 564. Acc: 0.218547. Loss: 2.293519. Batch_acc: 0.261547. Batch_loss: 2.118165 
Batch: 565. Acc: 0.218633. Loss: 2.293194. Batch_acc: 0.266138. Batch_loss: 2.112931 
Batch: 566. Acc: 0.218732. Loss: 2.292864. Batch_acc: 0.276084. Batch_loss: 2.102225 
Batch: 567. Acc: 0.218814. Loss: 2.292552. Batch_acc: 0.266588. Batch_loss: 2.110359 
Batch: 568. Acc: 0.218919. Loss: 2.292235. Batch_acc: 0.278286. Batch_loss: 2.113920 
Batch: 569. Acc: 0.218995. Loss: 2.291904. Batch_acc: 0.262881. Batch_loss: 2.100243 
Batch: 570. Acc: 0.219071. Loss: 2.291575. Batch_acc: 0.260966. Batch_loss: 2.110368 
Batch: 571. Acc: 0.219146. Loss: 2.291253. Batch_acc: 0.262029. Batch_loss: 2.106079 
Batch: 572. Acc: 0.219239. Loss: 2.290921. Batch_acc: 

Batch: 657. Acc: 0.225811. Loss: 2.267155. Batch_acc: 0.252018. Batch_loss: 2.129925 
Batch: 658. Acc: 0.225866. Loss: 2.266922. Batch_acc: 0.262727. Batch_loss: 2.110526 
Batch: 659. Acc: 0.225944. Loss: 2.266660. Batch_acc: 0.276427. Batch_loss: 2.097258 
Batch: 660. Acc: 0.226010. Loss: 2.266436. Batch_acc: 0.269099. Batch_loss: 2.119660 
Batch: 661. Acc: 0.226086. Loss: 2.266170. Batch_acc: 0.276278. Batch_loss: 2.090709 
Batch: 662. Acc: 0.226144. Loss: 2.265948. Batch_acc: 0.265838. Batch_loss: 2.114531 
Batch: 663. Acc: 0.226188. Loss: 2.265738. Batch_acc: 0.255828. Batch_loss: 2.124950 
Batch: 664. Acc: 0.226259. Loss: 2.265504. Batch_acc: 0.273623. Batch_loss: 2.108867 
Batch: 665. Acc: 0.226338. Loss: 2.265247. Batch_acc: 0.278868. Batch_loss: 2.094066 
Batch: 666. Acc: 0.226425. Loss: 2.264975. Batch_acc: 0.282755. Batch_loss: 2.088660 
Batch: 667. Acc: 0.226476. Loss: 2.264773. Batch_acc: 0.260672. Batch_loss: 2.131520 
Batch: 668. Acc: 0.226535. Loss: 2.264579. Batch_acc: 

Batch: 753. Acc: 0.231863. Loss: 2.245381. Batch_acc: 0.252971. Batch_loss: 2.125078 
Checkpointing on batch: 753. Accuracy: 0.2318634999168278. Loss per char: 2.2453808190020252. Time: 1627203181.4777737
Last question is tensor([ 2, 56, 73, 66, 85,  1, 74, 84,  1, 19, 25, 23, 20, 19, 15, 21, 24,  1,
        78, 74, 79, 86, 84,  1, 23, 19, 17, 23, 21, 18, 21, 22, 32,  3,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0], device='cuda:0')
Removing existing model file at checkpoints\add_or_sub-2021-07-25_latest_checkpoint.pth
Starting checkpoint save of checkpoints\add_or_sub-2021-07-25_latest_checkpoint.pth...
Final saved model size: 530790651
Batch: 754. Acc: 0.231917. Loss: 2.245205. Batch_acc: 0.272464. Batch_loss: 2.111274 
Batch: 755. Acc: 0.231973. Loss: 2.245018. Batch_acc: 0.274453. Batch_loss: 2.104357 
Batch: 756. Acc: 0.232041. Loss: 2.244793. Batch_acc: 0.282572. Batch_loss: 2.078081 
Batch: 757. Acc: 0.232077

Batch: 841. Acc: 0.236569. Loss: 2.229033. Batch_acc: 0.284642. Batch_loss: 2.083010 
Batch: 842. Acc: 0.236603. Loss: 2.228891. Batch_acc: 0.265130. Batch_loss: 2.108960 
Batch: 843. Acc: 0.236641. Loss: 2.228745. Batch_acc: 0.268986. Batch_loss: 2.104736 
Batch: 844. Acc: 0.236678. Loss: 2.228580. Batch_acc: 0.267468. Batch_loss: 2.090385 
Batch: 845. Acc: 0.236719. Loss: 2.228419. Batch_acc: 0.271828. Batch_loss: 2.090645 
Batch: 846. Acc: 0.236778. Loss: 2.228215. Batch_acc: 0.287623. Batch_loss: 2.053750 
Batch: 847. Acc: 0.236817. Loss: 2.228077. Batch_acc: 0.268823. Batch_loss: 2.114770 
Batch: 848. Acc: 0.236868. Loss: 2.227911. Batch_acc: 0.278819. Batch_loss: 2.089081 
Batch: 849. Acc: 0.236912. Loss: 2.227750. Batch_acc: 0.275540. Batch_loss: 2.088746 
Batch: 850. Acc: 0.236958. Loss: 2.227579. Batch_acc: 0.275723. Batch_loss: 2.081585 
Batch: 851. Acc: 0.236992. Loss: 2.227423. Batch_acc: 0.266944. Batch_loss: 2.090703 
Batch: 852. Acc: 0.237040. Loss: 2.227263. Batch_acc: 

Batch: 937. Acc: 0.240685. Loss: 2.214166. Batch_acc: 0.288059. Batch_loss: 2.067142 
Batch: 938. Acc: 0.240715. Loss: 2.214031. Batch_acc: 0.268732. Batch_loss: 2.089836 
Batch: 939. Acc: 0.240747. Loss: 2.213878. Batch_acc: 0.270715. Batch_loss: 2.072808 
Batch: 940. Acc: 0.240795. Loss: 2.213721. Batch_acc: 0.285880. Batch_loss: 2.064743 
Batch: 941. Acc: 0.240843. Loss: 2.213575. Batch_acc: 0.286290. Batch_loss: 2.075513 
Batch: 942. Acc: 0.240865. Loss: 2.213453. Batch_acc: 0.261822. Batch_loss: 2.098753 
Batch: 943. Acc: 0.240905. Loss: 2.213299. Batch_acc: 0.277716. Batch_loss: 2.071386 
Batch: 944. Acc: 0.240932. Loss: 2.213173. Batch_acc: 0.266007. Batch_loss: 2.092701 
Batch: 945. Acc: 0.240970. Loss: 2.213036. Batch_acc: 0.276826. Batch_loss: 2.084394 
Batch: 946. Acc: 0.241025. Loss: 2.212846. Batch_acc: 0.292325. Batch_loss: 2.036957 
Batch: 947. Acc: 0.241080. Loss: 2.212672. Batch_acc: 0.293811. Batch_loss: 2.047090 
Batch: 948. Acc: 0.241133. Loss: 2.212516. Batch_acc: 

Batch: 1025. Acc: 0.244046. Loss: 2.202260. Batch_acc: 0.284746. Batch_loss: 2.050782 
Batch: 1026. Acc: 0.244091. Loss: 2.202127. Batch_acc: 0.290780. Batch_loss: 2.061925 
Batch: 1027. Acc: 0.244126. Loss: 2.201990. Batch_acc: 0.280551. Batch_loss: 2.061507 
Batch: 1028. Acc: 0.244178. Loss: 2.201853. Batch_acc: 0.297564. Batch_loss: 2.059909 
Batch: 1029. Acc: 0.244192. Loss: 2.201769. Batch_acc: 0.259281. Batch_loss: 2.113453 
Batch: 1030. Acc: 0.244227. Loss: 2.201637. Batch_acc: 0.280692. Batch_loss: 2.065763 
Batch: 1031. Acc: 0.244257. Loss: 2.201533. Batch_acc: 0.275132. Batch_loss: 2.091914 
Batch: 1032. Acc: 0.244298. Loss: 2.201415. Batch_acc: 0.288427. Batch_loss: 2.075570 
Batch: 1033. Acc: 0.244352. Loss: 2.201295. Batch_acc: 0.300756. Batch_loss: 2.075631 
Batch: 1034. Acc: 0.244389. Loss: 2.201147. Batch_acc: 0.282286. Batch_loss: 2.049043 
Batch: 1035. Acc: 0.244433. Loss: 2.201017. Batch_acc: 0.292088. Batch_loss: 2.062116 
Batch: 1036. Acc: 0.244467. Loss: 2.200913.

Batch: 1120. Acc: 0.247177. Loss: 2.191231. Batch_acc: 0.283401. Batch_loss: 2.076801 
Batch: 1121. Acc: 0.247200. Loss: 2.191117. Batch_acc: 0.272675. Batch_loss: 2.065115 
Batch: 1122. Acc: 0.247239. Loss: 2.190985. Batch_acc: 0.289977. Batch_loss: 2.045282 
Batch: 1123. Acc: 0.247277. Loss: 2.190868. Batch_acc: 0.288988. Batch_loss: 2.063715 
Batch: 1124. Acc: 0.247305. Loss: 2.190754. Batch_acc: 0.279043. Batch_loss: 2.063913 
Batch: 1125. Acc: 0.247341. Loss: 2.190661. Batch_acc: 0.288416. Batch_loss: 2.083296 
Batch: 1126. Acc: 0.247372. Loss: 2.190561. Batch_acc: 0.281787. Batch_loss: 2.078097 
Batch: 1127. Acc: 0.247414. Loss: 2.190431. Batch_acc: 0.294658. Batch_loss: 2.044258 
Batch: 1128. Acc: 0.247446. Loss: 2.190323. Batch_acc: 0.283607. Batch_loss: 2.070449 
Batch: 1129. Acc: 0.247475. Loss: 2.190253. Batch_acc: 0.280069. Batch_loss: 2.110995 
Batch: 1130. Acc: 0.247518. Loss: 2.190085. Batch_acc: 0.295313. Batch_loss: 2.004538 
Batch: 1131. Acc: 0.247541. Loss: 2.189998.

Batch: 1215. Acc: 0.249886. Loss: 2.181677. Batch_acc: 0.272033. Batch_loss: 2.065346 
Batch: 1216. Acc: 0.249925. Loss: 2.181564. Batch_acc: 0.296045. Batch_loss: 2.046174 
Batch: 1217. Acc: 0.249956. Loss: 2.181464. Batch_acc: 0.286678. Batch_loss: 2.062533 
Batch: 1218. Acc: 0.249989. Loss: 2.181354. Batch_acc: 0.290379. Batch_loss: 2.045523 
Batch: 1219. Acc: 0.250010. Loss: 2.181256. Batch_acc: 0.276081. Batch_loss: 2.061839 
Batch: 1220. Acc: 0.250047. Loss: 2.181132. Batch_acc: 0.294785. Batch_loss: 2.032133 
Batch: 1221. Acc: 0.250070. Loss: 2.181055. Batch_acc: 0.277939. Batch_loss: 2.087189 
Batch: 1222. Acc: 0.250095. Loss: 2.180981. Batch_acc: 0.280783. Batch_loss: 2.090099 
Batch: 1223. Acc: 0.250125. Loss: 2.180918. Batch_acc: 0.286942. Batch_loss: 2.103530 
Batch: 1224. Acc: 0.250161. Loss: 2.180822. Batch_acc: 0.294529. Batch_loss: 2.062601 
Batch: 1225. Acc: 0.250198. Loss: 2.180723. Batch_acc: 0.295195. Batch_loss: 2.060216 
Batch: 1226. Acc: 0.250220. Loss: 2.180647.

Batch: 1303. Acc: 0.252205. Loss: 2.173587. Batch_acc: 0.281287. Batch_loss: 2.046254 
Batch: 1304. Acc: 0.252238. Loss: 2.173498. Batch_acc: 0.295612. Batch_loss: 2.057472 
Batch: 1305. Acc: 0.252267. Loss: 2.173399. Batch_acc: 0.288926. Batch_loss: 2.046735 
Batch: 1306. Acc: 0.252303. Loss: 2.173296. Batch_acc: 0.299652. Batch_loss: 2.037601 
Batch: 1307. Acc: 0.252326. Loss: 2.173203. Batch_acc: 0.283063. Batch_loss: 2.050318 


### Plotting Training from Tensorboard data

#### Restore the best model for this experiment

In [None]:

# `checkpoints` was previously only imported in a commented-out cell, so this
# cell raised NameError on a fresh kernel.
import checkpoints

# Rebuild the default Transformer and load saved weights into it.
model = utils.build_transformer()

# NOTE(review): this run saves "checkpoints\add_or_sub-2021-07-25_latest_checkpoint.pth";
# the file below comes from a different run — confirm it is the intended model.
_ = checkpoints.restore_checkpoint(".\\checkpoints\\checkpoint_b37504_e7.pth", "", model=model)

# Earlier in the notebook the freshly built model is moved with .to(device)
# before use; do the same here so predict_single(..., device) sees matching devices.
model = model.to(device)


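#### Loading TensorBoard events

**TODO:** the code cell that loads these events is missing. That cell should define `valid_accuracy` and `interpolate_accuracy`, which the plot below uses. Also note that training above ran with `tb=None`, so this run may not have written any TensorBoard events.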

> As we can see, the loss per character on the validation dataset follows a nice optimization curve, but this is not the case for the interpolate dataset. This is expected: the interpolate dataset contains more difficult and more general cases.

#### Accuracy Evolution during training

In [None]:
# pyplot was used here without ever being imported; import it locally.
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# NOTE(review): valid_accuracy / interpolate_accuracy are expected to be sequences of
# TensorBoard scalar events (objects with .step and .value). No cell in this
# notebook defines them yet — see the "Loading TensorBoard events" section.
plt.rcParams['figure.figsize'] = [10, 6]

fig, ax = plt.subplots()
for events, label in ((valid_accuracy, 'Validation Accuracy'),
                      (interpolate_accuracy, 'Interpolate Accuracy')):
    ax.plot([event.step for event in events], [event.value for event in events],
            marker='+', label=label)

# NOTE(review): the title refers to Algebra/Linear_1d while the experiment trained
# above is arithmetic/add_or_sub — update once the plotted data source is settled.
ax.set_title('Algebra/Linear_1d Accuracy')
ax.set_xlabel('Step')
ax.set_ylabel('Accuracy')
# Derive integer x ticks from the data instead of a fixed 0..19 range that
# ignores the number of logged steps.
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
# Accuracy is a fraction (the training log prints values in [0, 1]); show the
# full range rather than fixed ticks starting at 0.3.
ax.set_ylim(0.0, 1.0)
ax.legend(loc='upper left', frameon=False)
plt.show()


> Validation accuracy grows steadily up to 85%, while interpolate accuracy barely changes. The interpolate dataset contains problems that are more complicated and more general than those in the training set.

### Test Model

In [None]:
# Single-question inference; keeps only the best candidate answer.
# NOTE(review): this is an algebra/linear_1d-style question, while the model trained
# above is arithmetic/add_or_sub — confirm which checkpoint is loaded.
question = "Solve 5*w + 3 = -2 for w."
model_process.predict_single(question, model, device, n_best=1)


In [None]:
# Single-question inference; keeps only the best candidate answer.
# NOTE(review): algebra/linear_1d-style question vs. an add_or_sub-trained model — confirm intent.
question = "Solve 212 = 56*z - 12 for z."
model_process.predict_single(question, model, device, n_best=1)


In [None]:
# Single-question inference; keeps only the best candidate answer.
# NOTE(review): algebra/linear_1d-style question vs. an add_or_sub-trained model — confirm intent.
question = "Solve 2514*m = 2508*m - 24 for m."
model_process.predict_single(question, model, device, n_best=1)


