### CS182 project - Deliver ideas of CGCNN

In [1]:
### General Explanation

### Embedding Crystal Graph

In [2]:
# Note: Need pymatgen
# TODO:Explain about Crystal Graph

In [10]:
import os
import sys
import csv
import json
import torch
import torch.nn as nn
import random
import warnings
warnings.filterwarnings('ignore')
import functools
import numpy as np

from pymatgen.core.structure import Structure
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data.sampler import SubsetRandomSampler
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

from data_utils import CIFData
from data_utils import AtomCustomJSONInitializer
from data_utils import AtomInitializer
from data_utils import GaussianDistance

In [2]:
# Let's convert salt (NaCl) to crystal graph.
# You can use print method to see the lattice and position of Na, Cl
# Atoms in the cell.
nacl = Structure.from_file('hw_data/cifs/1000041.cif')
print(nacl)

Full Formula (Na4 Cl4)
Reduced Formula: NaCl
abc   :   5.620000   5.620000   5.620000
angles:  90.000000  90.000000  90.000000
pbc   :       True       True       True
Sites (8)
  #  SP      a    b    c
---  ----  ---  ---  ---
  0  Na+   0    0    0
  1  Na+   0    0.5  0.5
  2  Na+   0.5  0    0.5
  3  Na+   0.5  0.5  0
  4  Cl-   0.5  0.5  0.5
  5  Cl-   0.5  0    0
  6  Cl-   0    0.5  0
  7  Cl-   0    0    0.5


In [3]:
# add visualization of the structure using the CIF file 

In [5]:
# First we will convert atoms to atomistic features using
# pre-difined atom embedding.
# atom_init.json is containing vector embedding of atoms,
# Where key (1, 2, 3, ..., 100) represent atomic number
# and values are embedding vectors.
# You can try different atom embedding too.

# Load embedding file.
element_embedding_file = 'hw_data/atom_init.json'
with open(element_embedding_file) as f:
    elem_embedding = json.load(f)
elem_embedding = {int(key): value for key, value
                  in elem_embedding.items()}

# Encode crsytal data to atomistic features.
# Atom_feature shd have shape of (# of atoms, len(embedding vector))
atom_fea = np.vstack([elem_embedding[nacl[i].specie.number]
                      for i in range(len(nacl))])
atom_fea = torch.Tensor(atom_fea)

assert atom_fea.shape == (8,92)

In [6]:
# Next, we will get neighbor information from each atoms in the cell.
# We will get help from pymatgen package.
# get_all_neighbor function of structure object returns atoms within
# the input radius. Note that here len(all_nbrs) is 8 since there are
# 8 atoms (4 Na+, 4 Cl-) in the cell. Each list contain the neighbor
# atom information considering periodicity.
# We will use 12 nearest neighbors after sorting with distance.

all_nbrs = nacl.get_all_neighbors(r = 8, include_index=True)
all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]

assert len(all_nbrs) == 8

nbr_fea_idx, nbr_fea = [], []
for nbr in all_nbrs:
    # Note: x[1] returns distance to neighbors.
    # Note: x[2] returns index of original structure object
    nbr_fea_idx.append(list(map(lambda x: x[2],
                                nbr[:12])))
    nbr_fea.append(list(map(lambda x: x[1],
                            nbr[:12])))

# nbr_fea_idx contain information of nearest neighbor atoms 
# from ith row (ith atom in the cell)
# For example, 0th atom (Na+ (0.0000, 0.0000, 0.0000)) is neighbored
# with 5th, 6th, 7th, etc...
# nbr_fea contain information of nearest neighbot distance.
nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
nbr_fea_idx = torch.LongTensor(nbr_fea_idx)

nbr_fea_idx_standard = torch.LongTensor([[5, 6, 7, 7, 6, 5, 2, 1, 2, 1, 3, 3],
                                         [4, 7, 6, 7, 6, 4, 2, 0, 3, 2, 3, 3],
                                         [4, 7, 5, 7, 5, 4, 0, 3, 1, 3, 3, 1],
                                         [4, 6, 5, 6, 5, 4, 2, 1, 1, 2, 0, 0],
                                         [3, 1, 2, 3, 2, 1, 6, 5, 6, 5, 7, 7],
                                         [2, 3, 0, 3, 0, 2, 4, 4, 6, 6, 4, 7],
                                         [1, 3, 0, 3, 0, 1, 4, 4, 5, 5, 5, 4],
                                         [2, 1, 0, 2, 1, 0, 4, 5, 5, 4, 6, 6]])

assert torch.equal(nbr_fea_idx, nbr_fea_idx_standard)

In [7]:
# Now we have two features, atomic feature and neighbor feature.
# Note that neighbor feature is discontionous information with respect to the distanc.
# Therefore we will expand neighbor feature using Gaussian Kernel (or Gaussian filter)
# https://en.wikipedia.org/wiki/Gaussian_filter
dmin = 0
dmax = 12
step = 0.2
var = step
filter_step = np.arange(dmin, dmax+step, step)

def expand(distances):
    # Ask student to do this
    return np.exp(-(distances[..., np.newaxis] - filter_step)**2 / var**2)

nbr_fea_gaussian = expand(nbr_fea)

In [8]:
assert np.array_equal(nbr_fea_gaussian, nbr_fea)

AssertionError: 

In [9]:
gdf = GaussianDistance(dmin=0, dmax=12, step=0.2)
nbr_fea = gdf.expand(nbr_fea)

In [10]:
# Will not be used
data_dir = './cgcnn_data/sample-regression'
test = CIFData(data_dir)
(atom_fea, nbr_fea, nbr_fea_idx), target, cif_id = test[-1]

In [11]:
# Will not be used
structures, _, _ = test[-1]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

### Build a Model

In [12]:
# TODO: Graphics of layers

In [2]:
from model import CrystalGraphConvNet

In [14]:
# Simple test
# Set seed using manual_seed
# TODO explanation about crystal_atom_idx
torch.manual_seed(123)
crystal_atom_idx = [torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7]).long()]
model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len)
model.forward(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

tensor([[0.6627]], grad_fn=<AddmmBackward0>)

### Training

In [2]:
from data_utils import collate_pool, get_train_val_test_loader
from train_utils import Normalizer, train, validate, save_checkpoint
from model import CrystalGraphConvNet
from random import sample

In [3]:
torch.manual_seed(123)

# set parameters
data_dir = './hw_data/perovskite_energy'
batch_size = 8
train_ratio, val_ratio, test_ratio = 0.8, 0.1, 0.1

# get dataset

dataset = CIFData(data_dir)
collate_fn = collate_pool
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset=dataset,
    collate_fn=collate_fn,
    batch_size=batch_size,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    test_ratio=test_ratio,
    return_test=True)

In [4]:
# normalize target

if len(dataset) < 500:
    warnings.warn('Dataset has less than 500 data points. '
                    'Lower accuracy is expected. ')
    sample_data_list = [dataset[i] for i in range(len(dataset))]
else:
    sample_data_list = [dataset[i] for i in
                        sample(range(len(dataset)), 500)]
_, sample_target, _ = collate_pool(sample_data_list)
normalizer = Normalizer(sample_target)



In [5]:
# build model
structures, _, _ = dataset[0]
orig_atom_fea_len = structures[0].shape[-1]
nbr_fea_len = structures[1].shape[-1]

# number of hidden atom features in conv layers
atom_fea_len = 64
# number of hidden features after pooling
h_fea_len = 128
# number of conv layers
n_conv = 3
# number of hidden layers after pooling
n_h = 1

model = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                            atom_fea_len=atom_fea_len,
                            n_conv=n_conv,
                            h_fea_len=h_fea_len,
                            n_h=n_h)

In [6]:
# set hyperparameters
epochs = 15
criterion = nn.MSELoss()
lr = 0.01
momentum = 0.9
weight_decay = 0

optimizer = optim.SGD(model.parameters(), lr,
                              momentum=momentum,
                              weight_decay=weight_decay)

# optimizer = optim.Adam(model.parameters(), lr,
#                         weight_decay=weight_decay)
lr_milestones = [100]
scheduler = MultiStepLR(optimizer, milestones=lr_milestones,
                            gamma=0.1)

In [11]:
for epoch in range(epochs):
    best_mae_error = 1e10
    # train for one epoch
    # TO-Do, add tqdm in the train method, don't print too much here
    # To-Do, fix warning
    train(train_loader, model, criterion, optimizer, epoch, normalizer)

    # evaluate on validation set
    mae_error = validate(val_loader, model, criterion, normalizer)

    if mae_error != mae_error:
        print('Exit due to NaN')
        sys.exit(1)

    scheduler.step()

    # remember the best mae_eror and save checkpoint
    is_best = mae_error < best_mae_error
    best_mae_error = min(mae_error, best_mae_error)

    save_checkpoint({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_mae_error': best_mae_error,
        'optimizer': optimizer.state_dict(),
        'normalizer': normalizer.state_dict(),
    }, is_best)




Epoch: [0][0/67]	Time 4.174 (4.174)	Data 4.149 (4.149)	Loss 0.2806 (0.2806)	MAE 33.661 (33.661)
Epoch: [0][10/67]	Time 0.173 (0.506)	Data 0.144 (0.481)	Loss 0.2260 (0.4284)	MAE 27.904 (31.323)
Epoch: [0][20/67]	Time 0.199 (0.329)	Data 0.168 (0.304)	Loss 0.9276 (0.4242)	MAE 40.531 (31.370)
Epoch: [0][30/67]	Time 0.099 (0.266)	Data 0.076 (0.240)	Loss 0.2496 (0.4043)	MAE 28.405 (30.997)
Epoch: [0][40/67]	Time 0.131 (0.239)	Data 0.103 (0.213)	Loss 0.0990 (0.4683)	MAE 21.543 (31.908)
Epoch: [0][50/67]	Time 0.173 (0.222)	Data 0.140 (0.196)	Loss 0.7011 (0.4656)	MAE 49.489 (32.808)
Epoch: [0][60/67]	Time 0.185 (0.217)	Data 0.159 (0.190)	Loss 0.6097 (0.5658)	MAE 55.410 (34.588)




Test: [0/9]	Time 3.988 (3.988)	Loss 0.2735 (0.2735)	MAE 34.290 (34.290)




 * MAE 62.041




Epoch: [1][0/67]	Time 4.171 (4.171)	Data 4.144 (4.144)	Loss 0.2162 (0.2162)	MAE 32.096 (32.096)
Epoch: [1][10/67]	Time 0.080 (0.494)	Data 0.052 (0.466)	Loss 0.2973 (0.4158)	MAE 40.910 (32.976)
Epoch: [1][20/67]	Time 0.167 (0.325)	Data 0.136 (0.297)	Loss 0.4507 (0.6693)	MAE 37.172 (40.992)
Epoch: [1][30/67]	Time 0.255 (0.270)	Data 0.205 (0.242)	Loss 3.6540 (0.7501)	MAE 86.633 (41.564)
Epoch: [1][40/67]	Time 0.123 (0.237)	Data 0.099 (0.210)	Loss 0.2269 (0.6288)	MAE 27.458 (38.487)
Epoch: [1][50/67]	Time 0.108 (0.222)	Data 0.081 (0.194)	Loss 0.2174 (0.5636)	MAE 32.236 (37.268)
Epoch: [1][60/67]	Time 0.179 (0.217)	Data 0.150 (0.188)	Loss 0.0719 (0.5981)	MAE 13.082 (37.436)




Test: [0/9]	Time 3.866 (3.866)	Loss 0.1769 (0.1769)	MAE 24.936 (24.936)
 * MAE 47.265




Epoch: [2][0/67]	Time 4.025 (4.025)	Data 3.999 (3.999)	Loss 0.1654 (0.1654)	MAE 22.428 (22.428)
Epoch: [2][10/67]	Time 0.187 (0.490)	Data 0.156 (0.464)	Loss 0.5146 (0.4998)	MAE 30.591 (32.601)
Epoch: [2][20/67]	Time 0.082 (0.326)	Data 0.060 (0.299)	Loss 0.1939 (0.6601)	MAE 30.182 (35.491)
Epoch: [2][30/67]	Time 0.191 (0.261)	Data 0.155 (0.235)	Loss 0.2796 (0.5313)	MAE 25.685 (34.110)
Epoch: [2][40/67]	Time 0.155 (0.233)	Data 0.130 (0.207)	Loss 0.0647 (0.5090)	MAE 17.004 (34.367)
Epoch: [2][50/67]	Time 0.128 (0.217)	Data 0.100 (0.191)	Loss 0.4045 (0.5317)	MAE 36.333 (34.877)
Epoch: [2][60/67]	Time 0.150 (0.208)	Data 0.126 (0.180)	Loss 0.2873 (0.5137)	MAE 34.638 (34.106)




Test: [0/9]	Time 3.970 (3.970)	Loss 0.3331 (0.3331)	MAE 37.114 (37.114)
 * MAE 37.180




Epoch: [3][0/67]	Time 4.116 (4.116)	Data 4.090 (4.090)	Loss 0.1915 (0.1915)	MAE 25.482 (25.482)
Epoch: [3][10/67]	Time 0.118 (0.508)	Data 0.099 (0.481)	Loss 0.4429 (0.4854)	MAE 47.275 (35.168)
Epoch: [3][20/67]	Time 0.093 (0.329)	Data 0.074 (0.302)	Loss 0.1300 (0.4238)	MAE 21.316 (33.470)
Epoch: [3][30/67]	Time 0.117 (0.268)	Data 0.095 (0.241)	Loss 0.1396 (0.4238)	MAE 25.126 (33.110)
Epoch: [3][40/67]	Time 0.100 (0.245)	Data 0.078 (0.217)	Loss 0.4427 (0.4572)	MAE 49.124 (34.942)
Epoch: [3][50/67]	Time 0.124 (0.230)	Data 0.101 (0.202)	Loss 0.1866 (0.5403)	MAE 32.779 (36.064)
Epoch: [3][60/67]	Time 0.193 (0.217)	Data 0.164 (0.188)	Loss 0.3944 (0.5059)	MAE 36.211 (34.639)




Test: [0/9]	Time 3.803 (3.803)	Loss 0.1328 (0.1328)	MAE 22.599 (22.599)




 * MAE 41.498




Epoch: [4][0/67]	Time 3.975 (3.975)	Data 3.947 (3.947)	Loss 0.0808 (0.0808)	MAE 14.079 (14.079)
Epoch: [4][10/67]	Time 0.074 (0.497)	Data 0.053 (0.464)	Loss 1.3206 (0.7289)	MAE 81.152 (36.244)
Epoch: [4][20/67]	Time 0.157 (0.324)	Data 0.133 (0.294)	Loss 0.1289 (0.6021)	MAE 19.363 (36.151)
Epoch: [4][30/67]	Time 0.166 (0.266)	Data 0.139 (0.235)	Loss 0.0561 (0.5469)	MAE 16.138 (34.099)
Epoch: [4][40/67]	Time 0.177 (0.240)	Data 0.145 (0.209)	Loss 0.1343 (0.5142)	MAE 23.284 (34.241)
Epoch: [4][50/67]	Time 0.121 (0.223)	Data 0.089 (0.193)	Loss 0.2556 (0.5118)	MAE 28.796 (34.097)
Epoch: [4][60/67]	Time 0.150 (0.213)	Data 0.111 (0.183)	Loss 0.2790 (0.5388)	MAE 32.774 (35.748)




Test: [0/9]	Time 4.107 (4.107)	Loss 2.1809 (2.1809)	MAE 94.566 (94.566)
 * MAE 55.424




Epoch: [5][0/67]	Time 5.273 (5.273)	Data 5.123 (5.123)	Loss 5.4160 (5.4160)	MAE 131.260 (131.260)
Epoch: [5][10/67]	Time 0.374 (0.715)	Data 0.278 (0.622)	Loss 0.4092 (1.1857)	MAE 35.552 (50.309)
Epoch: [5][20/67]	Time 0.276 (0.497)	Data 0.206 (0.418)	Loss 0.0435 (0.9553)	MAE 15.285 (47.045)
Epoch: [5][30/67]	Time 0.154 (0.411)	Data 0.117 (0.338)	Loss 0.1168 (0.7195)	MAE 22.252 (40.874)
Epoch: [5][40/67]	Time 0.092 (0.347)	Data 0.070 (0.285)	Loss 0.1375 (0.6110)	MAE 25.502 (37.626)
Epoch: [5][50/67]	Time 0.192 (0.307)	Data 0.169 (0.251)	Loss 0.0366 (0.5385)	MAE 12.196 (35.206)
Epoch: [5][60/67]	Time 0.167 (0.284)	Data 0.143 (0.232)	Loss 0.6370 (0.5225)	MAE 34.554 (34.908)




Test: [0/9]	Time 4.318 (4.318)	Loss 0.1747 (0.1747)	MAE 28.157 (28.157)




 * MAE 38.763




Epoch: [6][0/67]	Time 4.798 (4.798)	Data 4.771 (4.771)	Loss 0.4615 (0.4615)	MAE 41.225 (41.225)




Epoch: [6][10/67]	Time 0.105 (0.594)	Data 0.066 (0.543)	Loss 0.0779 (0.2944)	MAE 15.299 (30.862)
Epoch: [6][20/67]	Time 0.211 (0.426)	Data 0.160 (0.368)	Loss 0.3534 (0.3725)	MAE 30.058 (31.751)
Epoch: [6][30/67]	Time 0.157 (0.384)	Data 0.110 (0.318)	Loss 0.1812 (0.4184)	MAE 29.057 (34.832)
Epoch: [6][40/67]	Time 0.148 (0.345)	Data 0.121 (0.282)	Loss 0.1879 (0.4427)	MAE 26.660 (36.207)
Epoch: [6][50/67]	Time 0.161 (0.311)	Data 0.126 (0.254)	Loss 0.0827 (0.4288)	MAE 19.727 (35.172)
Epoch: [6][60/67]	Time 0.138 (0.293)	Data 0.098 (0.236)	Loss 0.1378 (0.4715)	MAE 22.450 (36.690)




Test: [0/9]	Time 6.022 (6.022)	Loss 0.4549 (0.4549)	MAE 38.987 (38.987)
 * MAE 32.877




Epoch: [7][0/67]	Time 5.104 (5.104)	Data 5.057 (5.057)	Loss 1.2032 (1.2032)	MAE 60.582 (60.582)




Epoch: [7][10/67]	Time 0.113 (0.642)	Data 0.087 (0.585)	Loss 0.1335 (0.5059)	MAE 24.466 (43.351)
Epoch: [7][20/67]	Time 0.309 (0.419)	Data 0.246 (0.373)	Loss 1.1074 (0.5037)	MAE 50.862 (40.960)
Epoch: [7][30/67]	Time 0.236 (0.352)	Data 0.196 (0.307)	Loss 0.2293 (0.5112)	MAE 30.798 (40.398)
Epoch: [7][40/67]	Time 0.175 (0.315)	Data 0.140 (0.270)	Loss 0.7153 (0.5501)	MAE 43.474 (40.422)
Epoch: [7][50/67]	Time 0.187 (0.287)	Data 0.141 (0.243)	Loss 0.1917 (0.5296)	MAE 23.188 (39.482)
Epoch: [7][60/67]	Time 0.196 (0.271)	Data 0.166 (0.228)	Loss 0.2351 (0.4847)	MAE 31.101 (37.045)




Test: [0/9]	Time 4.774 (4.774)	Loss 0.0182 (0.0182)	MAE 9.561 (9.561)




 * MAE 34.392




Epoch: [8][0/67]	Time 4.757 (4.757)	Data 4.725 (4.725)	Loss 0.3419 (0.3419)	MAE 38.151 (38.151)
Epoch: [8][10/67]	Time 0.090 (0.601)	Data 0.061 (0.561)	Loss 0.2959 (0.6291)	MAE 40.075 (36.898)
Epoch: [8][20/67]	Time 0.103 (0.424)	Data 0.078 (0.381)	Loss 0.2526 (0.5653)	MAE 36.961 (38.016)
Epoch: [8][30/67]	Time 0.311 (0.357)	Data 0.269 (0.316)	Loss 0.3743 (0.5305)	MAE 35.245 (37.780)
Epoch: [8][40/67]	Time 0.111 (0.314)	Data 0.085 (0.272)	Loss 0.0567 (0.4525)	MAE 16.721 (34.924)
Epoch: [8][50/67]	Time 0.255 (0.293)	Data 0.206 (0.251)	Loss 0.2837 (0.4567)	MAE 25.701 (34.398)
Epoch: [8][60/67]	Time 0.147 (0.275)	Data 0.118 (0.233)	Loss 0.1834 (0.4333)	MAE 27.639 (34.610)




Test: [0/9]	Time 4.260 (4.260)	Loss 1.2740 (1.2740)	MAE 63.013 (63.013)
 * MAE 41.071




Epoch: [9][0/67]	Time 4.075 (4.075)	Data 4.053 (4.053)	Loss 0.1521 (0.1521)	MAE 25.077 (25.077)
Epoch: [9][10/67]	Time 0.157 (0.494)	Data 0.127 (0.469)	Loss 0.3805 (0.3237)	MAE 36.517 (26.542)
Epoch: [9][20/67]	Time 0.137 (0.329)	Data 0.115 (0.303)	Loss 0.2524 (0.4506)	MAE 31.187 (32.501)
Epoch: [9][30/67]	Time 0.106 (0.281)	Data 0.081 (0.251)	Loss 0.2478 (0.4751)	MAE 31.255 (33.522)
Epoch: [9][40/67]	Time 0.211 (0.260)	Data 0.186 (0.229)	Loss 0.1590 (0.4346)	MAE 26.126 (32.183)
Epoch: [9][50/67]	Time 0.143 (0.244)	Data 0.115 (0.213)	Loss 0.2609 (0.4008)	MAE 35.301 (31.701)
Epoch: [9][60/67]	Time 0.139 (0.230)	Data 0.118 (0.199)	Loss 0.4291 (0.4189)	MAE 40.670 (32.093)




Test: [0/9]	Time 4.194 (4.194)	Loss 0.1188 (0.1188)	MAE 21.267 (21.267)
 * MAE 46.972




Epoch: [10][0/67]	Time 4.278 (4.278)	Data 4.251 (4.251)	Loss 0.2240 (0.2240)	MAE 30.050 (30.050)
Epoch: [10][10/67]	Time 0.117 (0.505)	Data 0.091 (0.480)	Loss 0.0448 (0.4201)	MAE 13.951 (30.798)
Epoch: [10][20/67]	Time 0.162 (0.343)	Data 0.128 (0.310)	Loss 0.5876 (0.5305)	MAE 39.796 (37.671)
Epoch: [10][30/67]	Time 0.100 (0.280)	Data 0.077 (0.248)	Loss 0.0818 (0.4823)	MAE 18.064 (35.066)
Epoch: [10][40/67]	Time 0.103 (0.247)	Data 0.080 (0.216)	Loss 0.0168 (0.4477)	MAE 7.298 (33.290)
Epoch: [10][50/67]	Time 0.096 (0.245)	Data 0.073 (0.207)	Loss 0.4013 (0.4633)	MAE 43.726 (34.126)
Epoch: [10][60/67]	Time 0.156 (0.225)	Data 0.129 (0.190)	Loss 0.1651 (0.4332)	MAE 29.955 (33.865)




Test: [0/9]	Time 5.161 (5.161)	Loss 0.2406 (0.2406)	MAE 30.573 (30.573)




 * MAE 37.595




Epoch: [11][0/67]	Time 4.663 (4.663)	Data 4.637 (4.637)	Loss 0.2169 (0.2169)	MAE 30.399 (30.399)




Epoch: [11][10/67]	Time 0.240 (0.544)	Data 0.198 (0.517)	Loss 2.1983 (0.3381)	MAE 85.936 (28.806)
Epoch: [11][20/67]	Time 0.158 (0.339)	Data 0.115 (0.314)	Loss 0.4609 (0.4252)	MAE 33.067 (33.389)
Epoch: [11][30/67]	Time 0.157 (0.274)	Data 0.130 (0.248)	Loss 0.0691 (0.4393)	MAE 19.011 (32.476)
Epoch: [11][40/67]	Time 0.330 (0.274)	Data 0.284 (0.238)	Loss 0.6025 (0.4568)	MAE 45.531 (33.119)
Epoch: [11][50/67]	Time 0.104 (0.261)	Data 0.079 (0.224)	Loss 0.0611 (0.4773)	MAE 13.241 (33.396)
Epoch: [11][60/67]	Time 0.266 (0.248)	Data 0.215 (0.208)	Loss 3.1125 (0.4996)	MAE 82.552 (33.632)




Test: [0/9]	Time 4.490 (4.490)	Loss 0.8158 (0.8158)	MAE 43.290 (43.290)
 * MAE 53.350




Epoch: [12][0/67]	Time 4.480 (4.480)	Data 4.460 (4.460)	Loss 0.1702 (0.1702)	MAE 26.113 (26.113)




Epoch: [12][10/67]	Time 0.265 (0.569)	Data 0.204 (0.526)	Loss 0.1067 (0.4938)	MAE 20.978 (34.577)
Epoch: [12][20/67]	Time 0.067 (0.379)	Data 0.046 (0.341)	Loss 0.2766 (0.5203)	MAE 37.055 (37.242)
Epoch: [12][30/67]	Time 0.168 (0.299)	Data 0.146 (0.265)	Loss 0.0984 (0.4311)	MAE 22.558 (34.585)
Epoch: [12][40/67]	Time 0.207 (0.259)	Data 0.183 (0.228)	Loss 0.0509 (0.3635)	MAE 14.234 (30.946)
Epoch: [12][50/67]	Time 0.101 (0.241)	Data 0.075 (0.210)	Loss 0.2263 (0.3702)	MAE 27.671 (29.787)
Epoch: [12][60/67]	Time 0.083 (0.227)	Data 0.062 (0.197)	Loss 0.5622 (0.3549)	MAE 47.992 (30.060)




Test: [0/9]	Time 4.087 (4.087)	Loss 0.5369 (0.5369)	MAE 38.020 (38.020)
 * MAE 41.242




Epoch: [13][0/67]	Time 4.110 (4.110)	Data 4.082 (4.082)	Loss 0.0350 (0.0350)	MAE 10.244 (10.244)
Epoch: [13][10/67]	Time 0.079 (0.500)	Data 0.057 (0.471)	Loss 0.0458 (0.4286)	MAE 11.100 (31.277)
Epoch: [13][20/67]	Time 0.101 (0.339)	Data 0.069 (0.308)	Loss 0.1457 (0.4509)	MAE 26.722 (30.901)
Epoch: [13][30/67]	Time 0.314 (0.297)	Data 0.250 (0.257)	Loss 2.0181 (0.4448)	MAE 73.606 (31.790)
Epoch: [13][40/67]	Time 0.163 (0.267)	Data 0.137 (0.229)	Loss 0.0811 (0.4779)	MAE 16.309 (31.435)
Epoch: [13][50/67]	Time 0.189 (0.250)	Data 0.163 (0.210)	Loss 0.6603 (0.5050)	MAE 38.515 (33.613)
Epoch: [13][60/67]	Time 0.147 (0.240)	Data 0.123 (0.198)	Loss 0.0574 (0.4996)	MAE 15.993 (33.068)




Test: [0/9]	Time 4.926 (4.926)	Loss 0.0698 (0.0698)	MAE 17.816 (17.816)




 * MAE 33.741




Epoch: [14][0/67]	Time 4.046 (4.046)	Data 4.021 (4.021)	Loss 0.1527 (0.1527)	MAE 28.092 (28.092)
Epoch: [14][10/67]	Time 0.107 (0.499)	Data 0.081 (0.472)	Loss 0.2020 (0.2569)	MAE 31.608 (28.263)
Epoch: [14][20/67]	Time 0.205 (0.340)	Data 0.170 (0.311)	Loss 0.1723 (0.2780)	MAE 22.488 (29.299)
Epoch: [14][30/67]	Time 0.117 (0.287)	Data 0.097 (0.257)	Loss 0.1051 (0.2559)	MAE 18.838 (26.860)
Epoch: [14][40/67]	Time 0.197 (0.256)	Data 0.163 (0.226)	Loss 0.1952 (0.3270)	MAE 29.434 (27.938)
Epoch: [14][50/67]	Time 0.186 (0.237)	Data 0.154 (0.206)	Loss 0.1825 (0.3276)	MAE 29.538 (28.590)
Epoch: [14][60/67]	Time 0.126 (0.225)	Data 0.104 (0.195)	Loss 0.2848 (0.3781)	MAE 33.501 (30.566)




Test: [0/9]	Time 4.526 (4.526)	Loss 0.3072 (0.3072)	MAE 37.472 (37.472)




 * MAE 42.685


In [12]:
# test best model
print('---------Evaluate Model on Test Set---------------')
best_checkpoint = torch.load('model_best.pth.tar')
model.load_state_dict(best_checkpoint['state_dict'])
validate(test_loader, model, criterion, normalizer, test=True)

---------Evaluate Model on Test Set---------------




Test: [0/9]	Time 4.027 (4.027)	Loss 0.7605 (0.7605)	MAE 48.910 (48.910)
 ** MAE 54.624


tensor(54.6244)

In [None]:
# hw task, visualize the training loss