In [1]:
import sys
# Put the project checkout on the import path; presumably this is where the
# `attribution_gnn1`, `src` and `training` packages imported in the next cell
# live — TODO confirm. NOTE(review): absolute, user-specific path.
sys.path.append('/p/home/jusers/kotobi2/juwels/hida_project/')

In [2]:
import os.path as osp
import numpy as np
import torch
from torch_geometric.loader import DataLoader

from attribution_gnn1.QM9_SpecData import QM9_SpecData
from attribution_gnn1.split import save_split

from src.models import SpectraGNN, SpectraGAT, SpectraGraphNet

from training.trainer import GNNTrainer

In [3]:
# --- run configuration ---
# Checkpoint file name that the model cells look up under ./best_model.
model_name = 'spectragraphnet_50k_bl.pt'

# Number of epochs handed to the trainer.
num_epochs = 100

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Initial learning rate for the optimizer.
lr = 1e-3

# Epochs at which the step scheduler decays the learning rate: 10, 20, ..., 90.
milestones = list(range(10, 100, 10))

# preparing the data

In [4]:
# Location of the processed graph dataset and of the raw QM9 files.
# NOTE(review): absolute, user-specific paths.
root = '/p/home/jusers/kotobi2/juwels/data_qm9/all_graph_data/qm9_spec_50k_0-8eV_bl.pt'
qm9_raw_dir = '/p/home/jusers/kotobi2/juwels/data_qm9/raw/'

# `spectra` is passed as an empty list here (an earlier version passed
# `broadened_spectra_stk`); presumably the processed file under `root`
# already contains the spectra — TODO confirm.
qm9_spec = QM9_SpecData(
    root=root,
    raw_dir=qm9_raw_dir,
    spectra=[],
)

In [5]:
# Build the train/val/test index split: 40000 train, 10000 val, 0 test.
# The result is indexed below by the keys 'train' and 'val'.
split_file = '/p/home/jusers/kotobi2/juwels/hida_project/data/split_files2/qm9_split_50k_0-8eV_bl.npz'
idxs = save_split(
    path=split_file,
    ndata=len(qm9_spec),
    ntrain=40000, nval=10000, ntest=0,
    save_split=True,
    shuffle=True,
    print_nsample=True,
)

In [6]:
# Materialise the training and validation subsets as plain Python lists,
# picking each sample from the dataset by its split index.
train_qm9, val_qm9 = (
    [qm9_spec[sample_idx] for sample_idx in idxs[split_name]]
    for split_name in ('train', 'val')
)
# No test subset is built because the split above uses ntest=0.
#test_qm9 = qm9_spec[idxs['test']]

In [7]:
# Mini-batch loaders (100 graphs per batch).
# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_qm9, batch_size=100, shuffle=True)
# Validation batches keep a fixed order: the loss passed to the trainer (RMSE)
# takes a square root per batch, so — assuming the trainer averages per-batch
# losses — reshuffling would make the validation figure depend on batch
# composition and no longer be comparable between epochs.
val_loader = DataLoader(val_qm9, batch_size=100, shuffle=False)
#test_loader = DataLoader(test_qm9, batch_size=100)

In [17]:
# Sanity check: total number of samples in the dataset; this must cover
# ntrain + nval + ntest requested from save_split (40000 + 10000 + 0).
len(qm9_spec)

50000

# some more metrics

In [8]:
def RSE_loss(prediction, target, e_min=270.0, e_max=300.0, n_points=100):
    """RSE metric between a predicted and a reference spectrum.

    Computes ``sqrt(sum(dE * (target - prediction)**2)) / sum(dE * target)``,
    where ``dE = (e_max - e_min) / n_points`` is the spacing of the energy
    grid. Both sums run over every element of the tensors.

    Args:
        prediction: tensor of predicted values.
        target: tensor of reference values, broadcastable with ``prediction``.
        e_min: lower bound of the energy grid. The default reproduces the
            previously hard-coded value 270.
        e_max: upper bound of the energy grid. The default reproduces the
            previously hard-coded value 300.
        n_points: number of grid points used to derive the spacing. The
            default reproduces the previously hard-coded value 100.

    Returns:
        A scalar tensor. If ``sum(dE * target)`` is zero the result is
        ``inf`` or ``nan``, because no guard is applied to the division.
    """
    # Spacing of the energy grid; with the defaults this is 0.3.
    dE = (e_max - e_min) / n_points
    nom = torch.sum(dE * torch.pow(target - prediction, 2))
    denom = torch.sum(dE * target)
    return torch.sqrt(nom) / denom

In [9]:
def RMSE(prediction, target):
    """Root-mean-square error taken over every element of the two tensors.

    Args:
        prediction: tensor of predicted values.
        target: tensor of reference values, broadcastable with ``prediction``.

    Returns:
        A scalar tensor holding ``sqrt(mean((target - prediction)**2))``.
    """
    squared_error = torch.square(target - prediction)
    return squared_error.mean().sqrt()

# loading the model 

## SpectraGNN

In [17]:
# Build the SpectraGNN model (GATv2 backbone, 4 layers, 3 attention heads,
# 100 output targets) and move it to the selected device.
spectragnn = SpectraGNN(
    gnn_name='gatv2',
    in_channels=[11, 128, 256, 512],
    out_channels=[128, 256, 512, 600],
    num_targets=100,
    num_layers=4,
    heads=3
).to(device)

# Restore previously saved weights when a checkpoint is available.
# NOTE(review): `model_name` is set to 'spectragraphnet_50k_bl.pt' in the
# config cell; if that file exists it is presumably a SpectraGraphNet
# checkpoint whose keys will not match this model — confirm the intended file.
path_to_model = osp.join('./best_model',
                         model_name)

if osp.exists(path_to_model):
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU node can still be loaded on a CPU-only machine.
    spectragnn.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('model is not loaded')

model is not loaded


In [18]:
# Display the module tree of the SpectraGNN instance.
spectragnn

SpectraGNN(
  (interaction_layers): ModuleList(
    (0): GATv2Conv(11, 128, heads=3)
    (1): ReLU(inplace=True)
    (2): GATv2Conv(384, 256, heads=3)
    (3): ReLU(inplace=True)
    (4): GATv2Conv(768, 512, heads=3)
    (5): ReLU(inplace=True)
    (6): GATv2Conv(1536, 600, heads=1)
  )
  (dropout): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=600, out_features=100, bias=True)
)

## SpectraGAT

In [11]:
# Build the SpectraGAT model (custom GATv2 layers, 4 layers, 3 heads, with
# residual and jumping-knowledge options enabled) and move it to the device.
spectragat = SpectraGAT(
    node_features_dim=11,
    in_channels=[128, 128, 128, 128],
    out_channels=[128, 128, 128, 400],
    targets=100,
    n_layers=4,
    n_heads=3,
    gat_type='gatv2_custom',
    use_residuals=True,
    use_jk=True
).to(device)

# Restore previously saved weights when a checkpoint is available.
# NOTE(review): `model_name` is set to 'spectragraphnet_50k_bl.pt' in the
# config cell; if that file exists it is presumably a SpectraGraphNet
# checkpoint whose keys will not match this model — confirm the intended file.
path_to_model = osp.join('./best_model',
                         model_name)

if osp.exists(path_to_model):
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU node can still be loaded on a CPU-only machine.
    spectragat.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('model is not loaded')

model is not loaded


In [12]:
# Display the module tree of the SpectraGAT instance.
spectragat

SpectraGAT(
  (pre_layer): LinearLayer(
    (linear): Linear(in_features=11, out_features=128, bias=False)
    (_activation): ReLU(inplace=True)
  )
  (res_block): Residual_block(
    (res_layers): Sequential(
      (0): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
      (1): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
      (2): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
      (3): LinearLayer(
        (linear): Linear(in_features=128, out_features=128, bias=False)
        (_activation): ReLU(inplace=True)
      )
    )
  )
  (gat_layers): ModuleList(
    (0): GATv2LayerCus(
      (lin_r): LinearLayer(
        (linear): Linear(in_features=128, out_features=384, bias=False)
        (_activation): Identity()
      )
  

## SpectraGraphNet

In [None]:
# Build the SpectraGraphNet model and move it to the selected device.
spectragraphnet = SpectraGraphNet(
    node_dim=14,
    edge_dim=5,
    hidden_channels=512,
    out_channels=200,
    gat_hidd=512,
    gat_out=100,
    n_layers=3,
    n_targets=100).to(device)

# Restore previously saved weights when a checkpoint is available.
# NOTE(review): this cell loads 'spectragraphnet_50k_bl_pd.pt', while the
# trainer below is created with model_name "spectragraphnet_50k_bl" —
# confirm which checkpoint is meant to be resumed.
path_to_model = osp.join('./best_model',
                         'spectragraphnet_50k_bl_pd.pt')

if osp.exists(path_to_model):
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # written on a GPU node can still be loaded on a CPU-only machine.
    spectragraphnet.load_state_dict(torch.load(path_to_model, map_location=device))
else:
    print('model is not loaded')

NameError: name 'SpectraGraphNet' is not defined

In [13]:
# Display the module tree of the SpectraGraphNet instance.
spectragraphnet

SpectraGraphNet(
  (graphnets): ModuleList(
    (0): GraphNetwork(
      (gatencoder): GATEncoder(
        (gats): ModuleList(
          (0): GATv2Conv(12, 64, heads=3)
          (1): ReLU(inplace=True)
          (2): GATv2Conv(192, 64, heads=3)
          (3): ReLU(inplace=True)
          (4): GATv2Conv(192, 64, heads=3)
          (5): ReLU(inplace=True)
          (6): GATv2Conv(192, 20, heads=1)
        )
      )
      (node_model): NodeModel(
        (mlp): Sequential(
          (0): Linear(in_features=42, out_features=64, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=64, out_features=50, bias=True)
          (3): ReLU(inplace=True)
          (4): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        )
      )
      (edge_model): EdgeModel(
        (mlp): Sequential(
          (0): Linear(in_features=125, out_features=64, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=64, out_features=50, bias=True)
          (3)

# Training with the trainer class

In [14]:
# Alternative criteria; the training call below passes RMSE instead, so these
# two are only defined here.
loss_fn = torch.nn.L1Loss()
loss_fn2 = torch.nn.MSELoss()

# AdamW over all SpectraGraphNet parameters, starting at the configured rate.
optimizer = torch.optim.AdamW(params=spectragraphnet.parameters(), lr=lr)

# Multiply the learning rate by 0.8 each time a milestone epoch is reached.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=milestones,
    gamma=0.8,
)

In [15]:
# Wrap the SpectraGraphNet model in the project's trainer. `model_name` and
# `metric_path` are presumably used to name the saved checkpoint and to choose
# where metrics are written — TODO confirm against GNNTrainer.
trainer = GNNTrainer(
    model=spectragraphnet,
    model_name="spectragraphnet_50k_bl",
    device=device,
    metric_path="./metrics",
)

In [16]:
# Run the train/validation loop for `num_epochs` epochs. RMSE is passed in
# the loss position (presumably the training criterion — TODO confirm against
# GNNTrainer.train_val).
trainer.train_val(
    train_loader,
    val_loader,
    optimizer,
    RMSE,
    scheduler,
    num_epochs,
    write_every=1,
    train_graphnet=True,
)

  1%|          | 1/100 [01:20<2:12:11, 80.12s/it]

time = 1.33 mins mins
epoch 0 | average train loss = 19.55  and average validation loss = 16.39  |learning rate = 0.00100


  2%|▏         | 2/100 [02:19<1:50:50, 67.86s/it]

time = 2.32 mins mins
epoch 1 | average train loss = 12.47  and average validation loss = 9.61  |learning rate = 0.00100


  3%|▎         | 3/100 [03:20<1:44:25, 64.59s/it]

time = 3.34 mins mins
epoch 2 | average train loss = 9.07  and average validation loss = 8.70  |learning rate = 0.00100


  4%|▍         | 4/100 [04:23<1:42:36, 64.13s/it]

time = 4.39 mins mins
epoch 3 | average train loss = 8.84  and average validation loss = 8.66  |learning rate = 0.00100


  5%|▌         | 5/100 [05:26<1:40:37, 63.55s/it]

time = 5.43 mins mins
epoch 4 | average train loss = 8.84  and average validation loss = 8.67  |learning rate = 0.00100


  6%|▌         | 6/100 [06:27<1:38:15, 62.72s/it]

time = 6.45 mins mins
epoch 5 | average train loss = 8.83  and average validation loss = 8.66  |learning rate = 0.00100


  7%|▋         | 7/100 [07:28<1:36:27, 62.23s/it]

time = 7.47 mins mins
epoch 6 | average train loss = 8.83  and average validation loss = 8.65  |learning rate = 0.00100


  8%|▊         | 8/100 [08:29<1:34:58, 61.94s/it]

time = 8.49 mins mins
epoch 7 | average train loss = 8.80  and average validation loss = 8.40  |learning rate = 0.00100


  9%|▉         | 9/100 [09:31<1:33:38, 61.74s/it]

time = 9.52 mins mins
epoch 8 | average train loss = 8.21  and average validation loss = 7.78  |learning rate = 0.00100


 10%|█         | 10/100 [10:32<1:32:41, 61.79s/it]

time = 10.55 mins mins
epoch 9 | average train loss = 7.90  and average validation loss = 7.59  |learning rate = 0.00064


 11%|█         | 11/100 [11:35<1:32:13, 62.18s/it]

time = 11.60 mins mins
epoch 10 | average train loss = 7.75  and average validation loss = 7.48  |learning rate = 0.00080


 12%|█▏        | 12/100 [12:37<1:30:48, 61.92s/it]

time = 12.62 mins mins
epoch 11 | average train loss = 7.65  and average validation loss = 7.34  |learning rate = 0.00080


 13%|█▎        | 13/100 [13:39<1:30:01, 62.09s/it]

time = 13.66 mins mins
epoch 12 | average train loss = 7.54  and average validation loss = 7.22  |learning rate = 0.00080


 14%|█▍        | 14/100 [14:41<1:28:56, 62.05s/it]

time = 14.70 mins mins
epoch 13 | average train loss = 7.46  and average validation loss = 7.20  |learning rate = 0.00080


 15%|█▌        | 15/100 [15:43<1:27:48, 61.99s/it]

time = 15.73 mins mins
epoch 14 | average train loss = 7.40  and average validation loss = 7.07  |learning rate = 0.00080


 16%|█▌        | 16/100 [16:45<1:26:51, 62.04s/it]

time = 16.76 mins mins
epoch 15 | average train loss = 7.33  and average validation loss = 6.89  |learning rate = 0.00080


 17%|█▋        | 17/100 [17:46<1:25:28, 61.79s/it]

time = 17.78 mins mins
epoch 16 | average train loss = 7.12  and average validation loss = 6.72  |learning rate = 0.00080


 18%|█▊        | 18/100 [18:48<1:24:13, 61.63s/it]

time = 18.80 mins mins
epoch 17 | average train loss = 7.03  and average validation loss = 6.74  |learning rate = 0.00080


 19%|█▉        | 19/100 [19:49<1:23:03, 61.53s/it]

time = 19.82 mins mins
epoch 18 | average train loss = 6.93  and average validation loss = 6.56  |learning rate = 0.00080


 20%|██        | 20/100 [20:51<1:22:07, 61.60s/it]

time = 20.85 mins mins
epoch 19 | average train loss = 6.83  and average validation loss = 6.56  |learning rate = 0.00051


 21%|██        | 21/100 [21:53<1:21:13, 61.69s/it]

time = 21.89 mins mins
epoch 20 | average train loss = 6.76  and average validation loss = 6.35  |learning rate = 0.00064


 22%|██▏       | 22/100 [22:56<1:20:43, 62.10s/it]

time = 22.94 mins mins
epoch 21 | average train loss = 6.70  and average validation loss = 6.33  |learning rate = 0.00064


 23%|██▎       | 23/100 [23:58<1:19:51, 62.22s/it]

time = 23.98 mins mins
epoch 22 | average train loss = 6.65  and average validation loss = 6.24  |learning rate = 0.00064


 24%|██▍       | 24/100 [25:02<1:19:32, 62.80s/it]

time = 25.05 mins mins
epoch 23 | average train loss = 6.62  and average validation loss = 6.20  |learning rate = 0.00064


 25%|██▌       | 25/100 [26:07<1:19:02, 63.23s/it]

time = 26.12 mins mins
epoch 24 | average train loss = 6.58  and average validation loss = 6.15  |learning rate = 0.00064


 26%|██▌       | 26/100 [27:11<1:18:17, 63.49s/it]

time = 27.19 mins mins
epoch 25 | average train loss = 6.53  and average validation loss = 6.09  |learning rate = 0.00064


 27%|██▋       | 27/100 [28:15<1:17:30, 63.70s/it]

time = 28.26 mins mins
epoch 26 | average train loss = 6.50  and average validation loss = 6.10  |learning rate = 0.00064


 28%|██▊       | 28/100 [29:17<1:16:02, 63.37s/it]

time = 29.30 mins mins
epoch 27 | average train loss = 6.45  and average validation loss = 6.09  |learning rate = 0.00064


 29%|██▉       | 29/100 [30:22<1:15:19, 63.66s/it]

time = 30.37 mins mins
epoch 28 | average train loss = 6.43  and average validation loss = 5.99  |learning rate = 0.00064


 30%|███       | 30/100 [31:26<1:14:28, 63.84s/it]

time = 31.44 mins mins
epoch 29 | average train loss = 6.40  and average validation loss = 5.93  |learning rate = 0.00041


 31%|███       | 31/100 [32:30<1:13:31, 63.94s/it]

time = 32.51 mins mins
epoch 30 | average train loss = 6.33  and average validation loss = 5.89  |learning rate = 0.00051


 32%|███▏      | 32/100 [33:34<1:12:15, 63.76s/it]

time = 33.57 mins mins
epoch 31 | average train loss = 6.27  and average validation loss = 5.87  |learning rate = 0.00051


 33%|███▎      | 33/100 [34:37<1:11:11, 63.75s/it]

time = 34.63 mins mins
epoch 32 | average train loss = 6.23  and average validation loss = 5.77  |learning rate = 0.00051


 34%|███▍      | 34/100 [35:40<1:09:46, 63.43s/it]

time = 35.68 mins mins
epoch 33 | average train loss = 6.19  and average validation loss = 5.72  |learning rate = 0.00051


 35%|███▌      | 35/100 [36:42<1:08:10, 62.92s/it]

time = 36.70 mins mins
epoch 34 | average train loss = 6.17  and average validation loss = 5.66  |learning rate = 0.00051


 36%|███▌      | 36/100 [37:46<1:07:30, 63.29s/it]

time = 37.77 mins mins
epoch 35 | average train loss = 6.13  and average validation loss = 5.72  |learning rate = 0.00051


 37%|███▋      | 37/100 [38:50<1:06:45, 63.58s/it]

time = 38.84 mins mins
epoch 36 | average train loss = 6.10  and average validation loss = 5.61  |learning rate = 0.00051


 38%|███▊      | 38/100 [39:53<1:05:21, 63.25s/it]

time = 39.89 mins mins
epoch 37 | average train loss = 6.06  and average validation loss = 5.56  |learning rate = 0.00051


 39%|███▉      | 39/100 [40:55<1:03:55, 62.88s/it]

time = 40.92 mins mins
epoch 38 | average train loss = 6.04  and average validation loss = 5.54  |learning rate = 0.00051


 40%|████      | 40/100 [41:59<1:03:18, 63.31s/it]

time = 41.99 mins mins
epoch 39 | average train loss = 6.00  and average validation loss = 5.48  |learning rate = 0.00033


 41%|████      | 41/100 [43:03<1:02:29, 63.55s/it]

time = 43.06 mins mins
epoch 40 | average train loss = 5.97  and average validation loss = 5.49  |learning rate = 0.00041


 42%|████▏     | 42/100 [44:06<1:01:11, 63.31s/it]

time = 44.11 mins mins
epoch 41 | average train loss = 5.96  and average validation loss = 5.46  |learning rate = 0.00041


 43%|████▎     | 43/100 [45:09<1:00:01, 63.19s/it]

time = 45.15 mins mins
epoch 42 | average train loss = 5.93  and average validation loss = 5.42  |learning rate = 0.00041


 44%|████▍     | 44/100 [46:12<58:55, 63.13s/it]  

time = 46.20 mins mins
epoch 43 | average train loss = 5.93  and average validation loss = 5.42  |learning rate = 0.00041


 46%|████▌     | 46/100 [48:17<56:32, 62.82s/it]

time = 48.29 mins mins
epoch 45 | average train loss = 5.88  and average validation loss = 5.39  |learning rate = 0.00041


 47%|████▋     | 47/100 [49:18<55:04, 62.35s/it]

time = 49.31 mins mins
epoch 46 | average train loss = 5.89  and average validation loss = 5.37  |learning rate = 0.00041


 48%|████▊     | 48/100 [50:21<54:00, 62.32s/it]

time = 50.35 mins mins
epoch 47 | average train loss = 5.87  and average validation loss = 5.44  |learning rate = 0.00041


 49%|████▉     | 49/100 [51:24<53:16, 62.67s/it]

time = 51.41 mins mins
epoch 48 | average train loss = 5.86  and average validation loss = 5.36  |learning rate = 0.00041


 50%|█████     | 50/100 [52:26<52:04, 62.48s/it]

time = 52.44 mins mins
epoch 49 | average train loss = 5.83  and average validation loss = 5.32  |learning rate = 0.00026


 51%|█████     | 51/100 [53:28<50:46, 62.18s/it]

time = 53.47 mins mins
epoch 50 | average train loss = 5.80  and average validation loss = 5.28  |learning rate = 0.00033


 52%|█████▏    | 52/100 [54:33<50:37, 63.27s/it]

time = 54.57 mins mins
epoch 51 | average train loss = 5.81  and average validation loss = 5.28  |learning rate = 0.00033


 53%|█████▎    | 53/100 [55:41<50:29, 64.46s/it]

time = 55.69 mins mins
epoch 52 | average train loss = 5.79  and average validation loss = 5.31  |learning rate = 0.00033


 54%|█████▍    | 54/100 [56:44<49:06, 64.04s/it]

time = 56.74 mins mins
epoch 53 | average train loss = 5.78  and average validation loss = 5.28  |learning rate = 0.00033


 55%|█████▌    | 55/100 [57:46<47:32, 63.40s/it]

time = 57.77 mins mins
epoch 54 | average train loss = 5.77  and average validation loss = 5.24  |learning rate = 0.00033


 56%|█████▌    | 56/100 [58:49<46:28, 63.37s/it]

time = 58.82 mins mins
epoch 55 | average train loss = 5.77  and average validation loss = 5.25  |learning rate = 0.00033


 57%|█████▋    | 57/100 [59:51<45:10, 63.04s/it]

time = 59.86 mins mins
epoch 56 | average train loss = 5.75  and average validation loss = 5.31  |learning rate = 0.00033


 58%|█████▊    | 58/100 [1:00:53<43:55, 62.74s/it]

time = 60.90 mins mins
epoch 57 | average train loss = 5.75  and average validation loss = 5.30  |learning rate = 0.00033


 59%|█████▉    | 59/100 [1:01:56<42:50, 62.69s/it]

time = 61.94 mins mins
epoch 58 | average train loss = 5.73  and average validation loss = 5.18  |learning rate = 0.00033


 60%|██████    | 60/100 [1:03:00<42:08, 63.21s/it]

time = 63.01 mins mins
epoch 59 | average train loss = 5.72  and average validation loss = 5.19  |learning rate = 0.00021


 61%|██████    | 61/100 [1:04:05<41:28, 63.80s/it]

time = 64.10 mins mins
epoch 60 | average train loss = 5.70  and average validation loss = 5.19  |learning rate = 0.00026


 62%|██████▏   | 62/100 [1:05:10<40:31, 63.99s/it]

time = 65.17 mins mins
epoch 61 | average train loss = 5.68  and average validation loss = 5.20  |learning rate = 0.00026


 63%|██████▎   | 63/100 [1:06:12<39:04, 63.35s/it]

time = 66.20 mins mins
epoch 62 | average train loss = 5.69  and average validation loss = 5.15  |learning rate = 0.00026


 64%|██████▍   | 64/100 [1:07:14<37:44, 62.91s/it]

time = 67.23 mins mins
epoch 63 | average train loss = 5.68  and average validation loss = 5.18  |learning rate = 0.00026


 65%|██████▌   | 65/100 [1:08:15<36:26, 62.47s/it]

time = 68.26 mins mins
epoch 64 | average train loss = 5.66  and average validation loss = 5.15  |learning rate = 0.00026


 66%|██████▌   | 66/100 [1:09:17<35:19, 62.32s/it]

time = 69.29 mins mins
epoch 65 | average train loss = 5.66  and average validation loss = 5.09  |learning rate = 0.00026


 67%|██████▋   | 67/100 [1:10:19<34:10, 62.15s/it]

time = 70.32 mins mins
epoch 66 | average train loss = 5.64  and average validation loss = 5.10  |learning rate = 0.00026


 68%|██████▊   | 68/100 [1:11:20<33:04, 62.03s/it]

time = 71.35 mins mins
epoch 67 | average train loss = 5.63  and average validation loss = 5.09  |learning rate = 0.00026


 69%|██████▉   | 69/100 [1:12:22<31:58, 61.87s/it]

time = 72.38 mins mins
epoch 68 | average train loss = 5.63  and average validation loss = 5.12  |learning rate = 0.00026


 70%|███████   | 70/100 [1:13:23<30:52, 61.74s/it]

time = 73.40 mins mins
epoch 69 | average train loss = 5.62  and average validation loss = 5.14  |learning rate = 0.00017


 71%|███████   | 71/100 [1:14:28<30:12, 62.48s/it]

time = 74.47 mins mins
epoch 70 | average train loss = 5.60  and average validation loss = 5.06  |learning rate = 0.00021


 72%|███████▏  | 72/100 [1:15:32<29:23, 62.97s/it]

time = 75.54 mins mins
epoch 71 | average train loss = 5.60  and average validation loss = 5.08  |learning rate = 0.00021


 73%|███████▎  | 73/100 [1:16:34<28:10, 62.63s/it]

time = 76.57 mins mins
epoch 72 | average train loss = 5.59  and average validation loss = 5.04  |learning rate = 0.00021


 74%|███████▍  | 74/100 [1:17:35<27:01, 62.36s/it]

time = 77.60 mins mins
epoch 73 | average train loss = 5.58  and average validation loss = 5.04  |learning rate = 0.00021


 75%|███████▌  | 75/100 [1:18:37<25:56, 62.25s/it]

time = 78.63 mins mins
epoch 74 | average train loss = 5.57  and average validation loss = 5.07  |learning rate = 0.00021


 76%|███████▌  | 76/100 [1:19:38<24:45, 61.90s/it]

time = 79.65 mins mins
epoch 75 | average train loss = 5.58  and average validation loss = 5.04  |learning rate = 0.00021


 77%|███████▋  | 77/100 [1:20:39<23:38, 61.66s/it]

time = 80.67 mins mins
epoch 76 | average train loss = 5.57  and average validation loss = 5.06  |learning rate = 0.00021


 78%|███████▊  | 78/100 [1:21:40<22:31, 61.43s/it]

time = 81.68 mins mins
epoch 77 | average train loss = 5.56  and average validation loss = 5.02  |learning rate = 0.00021


 79%|███████▉  | 79/100 [1:22:41<21:24, 61.18s/it]

time = 82.69 mins mins
epoch 78 | average train loss = 5.56  and average validation loss = 5.00  |learning rate = 0.00021


 80%|████████  | 80/100 [1:23:45<20:37, 61.90s/it]

time = 83.75 mins mins
epoch 79 | average train loss = 5.55  and average validation loss = 5.13  |learning rate = 0.00013


 81%|████████  | 81/100 [1:24:48<19:43, 62.29s/it]

time = 84.80 mins mins
epoch 80 | average train loss = 5.54  and average validation loss = 5.00  |learning rate = 0.00017


 82%|████████▏ | 82/100 [1:25:50<18:39, 62.20s/it]

time = 85.84 mins mins
epoch 81 | average train loss = 5.53  and average validation loss = 4.99  |learning rate = 0.00017


 83%|████████▎ | 83/100 [1:26:50<17:26, 61.53s/it]

time = 86.84 mins mins
epoch 82 | average train loss = 5.51  and average validation loss = 4.98  |learning rate = 0.00017


 84%|████████▍ | 84/100 [1:27:49<16:15, 60.97s/it]

time = 87.83 mins mins
epoch 83 | average train loss = 5.53  and average validation loss = 4.97  |learning rate = 0.00017


 85%|████████▌ | 85/100 [1:28:49<15:08, 60.54s/it]

time = 88.82 mins mins
epoch 84 | average train loss = 5.50  and average validation loss = 4.96  |learning rate = 0.00017


 86%|████████▌ | 86/100 [1:29:50<14:08, 60.60s/it]

time = 89.84 mins mins
epoch 85 | average train loss = 5.51  and average validation loss = 4.99  |learning rate = 0.00017


 87%|████████▋ | 87/100 [1:30:49<13:03, 60.26s/it]

time = 90.83 mins mins
epoch 86 | average train loss = 5.51  and average validation loss = 4.97  |learning rate = 0.00017


 88%|████████▊ | 88/100 [1:31:49<12:01, 60.13s/it]

time = 91.82 mins mins
epoch 87 | average train loss = 5.49  and average validation loss = 4.98  |learning rate = 0.00017


 89%|████████▉ | 89/100 [1:32:49<10:59, 59.96s/it]

time = 92.82 mins mins
epoch 88 | average train loss = 5.50  and average validation loss = 4.96  |learning rate = 0.00017


 90%|█████████ | 90/100 [1:33:48<09:58, 59.82s/it]

time = 93.81 mins mins
epoch 89 | average train loss = 5.49  and average validation loss = 4.95  |learning rate = 0.00011


 91%|█████████ | 91/100 [1:34:48<08:57, 59.72s/it]

time = 94.80 mins mins
epoch 90 | average train loss = 5.47  and average validation loss = 4.93  |learning rate = 0.00013


 92%|█████████▏| 92/100 [1:35:47<07:57, 59.64s/it]

time = 95.79 mins mins
epoch 91 | average train loss = 5.47  and average validation loss = 4.98  |learning rate = 0.00013


 93%|█████████▎| 93/100 [1:36:46<06:57, 59.59s/it]

time = 96.78 mins mins
epoch 92 | average train loss = 5.46  and average validation loss = 4.97  |learning rate = 0.00013


 94%|█████████▍| 94/100 [1:37:46<05:57, 59.56s/it]

time = 97.77 mins mins
epoch 93 | average train loss = 5.46  and average validation loss = 4.98  |learning rate = 0.00013


 95%|█████████▌| 95/100 [1:38:45<04:57, 59.53s/it]

time = 98.76 mins mins
epoch 94 | average train loss = 5.46  and average validation loss = 4.95  |learning rate = 0.00013


 96%|█████████▌| 96/100 [1:39:45<03:58, 59.51s/it]

time = 99.76 mins mins
epoch 95 | average train loss = 5.45  and average validation loss = 4.92  |learning rate = 0.00013


 97%|█████████▋| 97/100 [1:40:45<02:59, 59.69s/it]

time = 100.76 mins mins
epoch 96 | average train loss = 5.46  and average validation loss = 4.92  |learning rate = 0.00013


 98%|█████████▊| 98/100 [1:41:45<01:59, 59.92s/it]

time = 101.76 mins mins
epoch 97 | average train loss = 5.44  and average validation loss = 4.91  |learning rate = 0.00013


 99%|█████████▉| 99/100 [1:42:45<00:59, 59.78s/it]

time = 102.76 mins mins
epoch 98 | average train loss = 5.43  and average validation loss = 4.92  |learning rate = 0.00013


100%|██████████| 100/100 [1:43:44<00:00, 62.25s/it]

time = 103.75 mins mins
epoch 99 | average train loss = 5.43  and average validation loss = 4.92  |learning rate = 0.00013



