In [1]:
from typing import Optional
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import warnings
import os.path as osp
from math import pi as PI
from ase.io import read
import torch
import torch.nn.functional as F
from torch.nn import Embedding, Sequential, Linear, ModuleList
import random
from sklearn.model_selection import train_test_split
from torch_scatter import scatter
from torch_geometric.data.makedirs import makedirs
from torch_geometric.data import download_url, extract_zip, Dataset
from torch_geometric.nn import radius_graph, MessagePassing

In [53]:
class SchNet(torch.nn.Module):
    r"""The continuous-filter convolutional neural network SchNet from the
    `"SchNet: A Continuous-filter Convolutional Neural Network for Modeling
    Quantum Interactions" <https://arxiv.org/abs/1706.08566>`_ paper that uses
    the interactions blocks of the form

    .. math::
        \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot
        h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))),

    here :math:`h_{\mathbf{\Theta}}` denotes an MLP and
    :math:`\mathbf{e}_{j,i}` denotes the interatomic distances between atoms.

    .. note::

        For an example of using a pretrained SchNet variant, see
        `examples/qm9_pretrained_schnet.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        qm9_pretrained_schnet.py>`_.

    Args:
        hidden_channels (int, optional): Hidden embedding size.
            (default: :obj:`128`)
        num_filters (int, optional): The number of filters to use.
            (default: :obj:`128`)
        num_interactions (int, optional): The number of interaction blocks.
            (default: :obj:`6`)
        num_gaussians (int, optional): The number of gaussians :math:`\mu`.
            (default: :obj:`50`)
        cutoff (float, optional): Cutoff distance for interatomic interactions.
            (default: :obj:`10.0`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            collect for each node within the :attr:`cutoff` distance.
            (default: :obj:`32`)
        readout (string, optional): Whether to apply :obj:`"add"` or
            :obj:`"mean"` global aggregation. (default: :obj:`"add"`)
        dipole (bool, optional): If set to :obj:`True`, will use the magnitude
            of the dipole moment to make the final prediction, *e.g.*, for
            target 0 of :class:`torch_geometric.datasets.QM9`.
            (default: :obj:`False`)
        mean (float, optional): The mean of the property to predict.
            (default: :obj:`None`)
        std (float, optional): The standard deviation of the property to
            predict. (default: :obj:`None`)
        atomref (torch.Tensor, optional): The reference of single-atom
            properties.
            Expects a vector of shape :obj:`(max_atomic_number, )`.
    """

    url = 'http://www.quantum-machine.org/datasets/trained_schnet_models.zip'
    def __init__(self, hidden_channels: int = 128, num_filters: int = 128,
                 num_interactions: int = 10, num_gaussians: int = 50,
                 cutoff: float = 10.0, max_num_neighbors: int = 32,
                 readout: str = 'mean', dipole: bool = False,
                # mean: Optional[float] = None, std: Optional[float] = None,
                 atomref: Optional[torch.Tensor] = None
            ):
        super(SchNet, self).__init__()

        import ase

        self.hidden_channels = hidden_channels
        self.num_filters = num_filters
        self.num_interactions = num_interactions
        self.num_gaussians = num_gaussians
        self.cutoff = cutoff
        self.max_num_neighbors = max_num_neighbors
        self.readout = readout
        self.dipole = dipole
        self.readout = 'add' if self.dipole else self.readout
        self.mean = None
        self.std = None
        self.scale = None

        atomic_mass = torch.from_numpy(ase.data.atomic_masses)
        self.register_buffer('atomic_mass', atomic_mass)

        self.embedding = Embedding(100, hidden_channels)
        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)

        self.interactions = ModuleList()
        for _ in range(num_interactions):
            block = InteractionBlock(hidden_channels, num_gaussians,
                                     num_filters, cutoff)
            self.interactions.append(block)

        self.lin1 = Linear(hidden_channels, hidden_channels // 2)
        self.act = ShiftedSoftplus()
        self.lin2 = Linear(hidden_channels // 2, 1)

        self.register_buffer('initial_atomref', atomref)
        self.atomref = None
        if atomref is not None:
            self.atomref = Embedding(100, 1)
            self.atomref.weight.data.copy_(atomref)

        self.reset_parameters()

    def reset_parameters(self):
        self.embedding.reset_parameters()
        for interaction in self.interactions:
            interaction.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        self.lin1.bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)
        if self.atomref is not None:
            self.atomref.weight.data.copy_(self.initial_atomref)


    def forward(self, data):
        z = data.z.long()
        pos = data.pos
        batch = data.batch

        """"""
        assert z.dim() == 1 and z.dtype == torch.long
        batch = torch.zeros_like(z) if batch is None else batch

        h = self.embedding(z)

        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,
                                  max_num_neighbors=self.max_num_neighbors)
        row, col = edge_index
        edge_weight = (pos[row] - pos[col]).norm(dim=-1)
        edge_attr = self.distance_expansion(edge_weight)

        for interaction in self.interactions:
            h = h + interaction(h, edge_index, edge_weight, edge_attr)

        h = self.lin1(h)
        h = self.act(h)
        h = self.lin2(h)

        if self.dipole:
            # Get center of mass.
            mass = self.atomic_mass[z].view(-1, 1)
            c = scatter(mass * pos, batch, dim=0) / scatter(mass, batch, dim=0)
            h = h * (pos - c[batch])

        if not self.dipole and self.mean is not None and self.std is not None:
            h = h * self.std + self.mean

        if not self.dipole and self.atomref is not None:
            h = h + self.atomref(z)

        out = scatter(h, batch, dim=0, reduce=self.readout)

        if self.dipole:
            out = torch.norm(out, dim=-1, keepdim=True)

        if self.scale is not None:
            out = self.scale * out

        return out


    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'hidden_channels={self.hidden_channels}, '
                f'num_filters={self.num_filters}, '
                f'num_interactions={self.num_interactions}, '
                f'num_gaussians={self.num_gaussians}, '
                f'cutoff={self.cutoff})')



class InteractionBlock(torch.nn.Module):
    def __init__(self, hidden_channels, num_gaussians, num_filters, cutoff):
        super(InteractionBlock, self).__init__()
        self.mlp = Sequential(
            Linear(num_gaussians, num_filters),
            ShiftedSoftplus(),
            Linear(num_filters, num_filters),
        )
        self.conv = CFConv(hidden_channels, hidden_channels, num_filters,
                           self.mlp, cutoff)
        self.act = ShiftedSoftplus()
        self.lin = Linear(hidden_channels, hidden_channels)

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.mlp[0].weight)
        self.mlp[0].bias.data.fill_(0)
        torch.nn.init.xavier_uniform_(self.mlp[2].weight)
        self.mlp[0].bias.data.fill_(0)
        self.conv.reset_parameters()
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        x = self.conv(x, edge_index, edge_weight, edge_attr)
        x = self.act(x)
        x = self.lin(x)
        return x


class CFConv(MessagePassing):
    def __init__(self, in_channels, out_channels, num_filters, nn, cutoff):
        super(CFConv, self).__init__(aggr='add')
        self.lin1 = Linear(in_channels, num_filters, bias=False)
        self.lin2 = Linear(num_filters, out_channels)
        self.nn = nn
        self.cutoff = cutoff

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin1.weight)
        torch.nn.init.xavier_uniform_(self.lin2.weight)
        self.lin2.bias.data.fill_(0)

    def forward(self, x, edge_index, edge_weight, edge_attr):
        C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)
        W = self.nn(edge_attr) * C.view(-1, 1)

        x = self.lin1(x)
        x = self.propagate(edge_index, x=x, W=W)
        x = self.lin2(x)
        return x

    def message(self, x_j, W):
        return x_j * W


class GaussianSmearing(torch.nn.Module):
    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super(GaussianSmearing, self).__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2
        self.register_buffer('offset', offset)

    def forward(self, dist):
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))


class ShiftedSoftplus(torch.nn.Module):
    def __init__(self):
        super(ShiftedSoftplus, self).__init__()
        self.shift = torch.log(torch.tensor(2.0)).item()

    def forward(self, x):
        return F.softplus(x) - self.shift

In [47]:
targets=pd.read_csv('targets.csv')
label=targets['energy_per_atom']

In [46]:
targets.head()

Unnamed: 0.1,Unnamed: 0,energy,energy_per_atom,fermi_level,homo,lumo,fold
0,0,-349.34764,-7.278076,1.263305,0.253529,2.448582,5
1,1,-349.353704,-7.278202,1.251832,0.229737,2.455558,2
2,2,-349.359991,-7.278333,1.253671,0.232639,2.468721,2
3,3,-349.352917,-7.278186,1.253987,0.232242,2.464204,3
4,4,-349.362675,-7.278389,1.253701,0.232852,2.466078,3


In [16]:
targets['fold'].value_counts()

7    435
5    435
3    435
1    435
6    435
4    435
2    435
0    435
Name: fold, dtype: int64

In [31]:
trains=np.where([targets['fold']!=5])[1]

In [48]:
structures=pd.read_csv('structures.csv')
structures

Unnamed: 0.1,Unnamed: 0,_id
0,0,Full Formula (W16 Se1 S30 O1)\nReduced Formula...
1,1,Full Formula (W16 Se1 S30 O1)\nReduced Formula...
2,2,Full Formula (W16 Se1 S30 O1)\nReduced Formula...
3,3,Full Formula (W16 Se1 S30 O1)\nReduced Formula...
4,4,Full Formula (W16 Se1 S30 O1)\nReduced Formula...
...,...,...
3475,3475,Full Formula (W36 Se70 S1 O1)\nReduced Formula...
3476,3476,Full Formula (W36 Se70 S1 O1)\nReduced Formula...
3477,3477,Full Formula (W36 Se70 S1 O1)\nReduced Formula...
3478,3478,Full Formula (W36 Se70 S1 O1)\nReduced Formula...


In [49]:
atoms_list=np.load('atoms_list.npy', allow_pickle=True)

In [50]:
atoms_list

array([Structure Summary
Lattice
    abc : 12.76292132 12.76292132 14.202402
 angles : 90.0 90.0 119.99999999999999
 volume : 2003.515085416219
      A : 12.76292132 0.0 7.815035371153756e-16
      B : -6.381460659999997 11.053014089622023 7.815035371153756e-16
      C : 0.0 0.0 14.202402
PeriodicSite: W (-0.0000, 1.8422, 3.5506) [0.0833, 0.1667, 0.2500]
PeriodicSite: W (-1.5954, 4.6054, 3.5506) [0.0833, 0.4167, 0.2500]
PeriodicSite: W (-3.1907, 7.3687, 3.5506) [0.0833, 0.6667, 0.2500]
PeriodicSite: W (-4.7861, 10.1319, 3.5506) [0.0833, 0.9167, 0.2500]
PeriodicSite: W (3.1907, 1.8422, 3.5506) [0.3333, 0.1667, 0.2500]
PeriodicSite: W (1.5954, 4.6054, 3.5506) [0.3333, 0.4167, 0.2500]
PeriodicSite: W (0.0000, 7.3687, 3.5506) [0.3333, 0.6667, 0.2500]
PeriodicSite: W (-1.5954, 10.1319, 3.5506) [0.3333, 0.9167, 0.2500]
PeriodicSite: W (6.3815, 1.8422, 3.5506) [0.5833, 0.1667, 0.2500]
PeriodicSite: W (4.7861, 4.6054, 3.5506) [0.5833, 0.4167, 0.2500]
PeriodicSite: W (3.1907, 7.3687, 3.5506) [0

In [51]:
from torch_geometric.data import Data
import torch
import ase
from pymatgen.io.ase import AseAtomsAdaptor

i=0
data_atoms = []
for atoms in tqdm(atoms_list):
    atoms=AseAtomsAdaptor.get_atoms(atoms)
    # set the atomic numbers, positions, and cell
    atom = torch.Tensor(atoms.get_atomic_numbers())
    positions = torch.Tensor(atoms.get_positions())
    natoms = positions.shape[0]
    
    # put the minimum data in torch geometric data object
    data = Data(
        pos=positions,
        z= atom,
       # natoms=natoms,
    )
    
    # calculate energy
    data.y = label[i]
    i=i+1
    data_atoms.append(data)

100%|██████████| 3480/3480 [00:02<00:00, 1299.88it/s]


In [9]:
from torch_geometric.data import DataLoader
train_dataset, test_dataset = train_test_split(data_atoms, test_size=0.2)
train_loader = DataLoader(train_dataset, batch_size=4)
test_loader = DataLoader(test_dataset, batch_size=4)

In [10]:
model=SchNet()
device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
#device='cpu'
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 10
#
loss_func = torch.nn.L1Loss() #define loss
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=256)

In [None]:
tr_loss = []
ts_loss=[]#set model to training mode
for epoch in range(200):
    model.train() 
    valid_loss = 0

    train_loss=0

     #shuffle the training data each epoch
    for d in tqdm(train_loader): #go over each training point
        data = d.to(device)#send data to device
       
        out = model(data) 
        optimizer.zero_grad() #zero gradients
        #evaluate data point
        loss = loss_func(out, data.y) #L1 error loss
         #add loss value to aggregate loss
        loss.backward() #compute gradients
        optimizer.step() #apply optimization
        train_loss += loss.item()
    
    model.eval()     # Optional when not using Model Specific layer
    for d in tqdm(test_loader):
        data = d.to(device)
        target = model(data)
        loss = loss_func(target, data.y)
        valid_loss += loss.item()
        
    print('Epoch: {:03d}, Average loss: {:.5f}'.format(epoch, train_loss/len(train_loader)))
    tr_loss.append(train_loss/len(train_loader))
    print('Epoch: {:03d}, Average loss: {:.5f}'.format(epoch, valid_loss/len(test_loader)))
    ts_loss.append(valid_loss/len(test_loader))
    
    


Using a target size (torch.Size([4])) that is different to the input size (torch.Size([4, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.

100%|██████████| 696/696 [00:24<00:00, 28.61it/s]
100%|██████████| 174/174 [00:03<00:00, 50.49it/s]


Epoch: 000, Average loss: 134.42364
Epoch: 000, Average loss: 119.33974


100%|██████████| 696/696 [00:24<00:00, 28.93it/s]
100%|██████████| 174/174 [00:03<00:00, 45.87it/s]


Epoch: 001, Average loss: 127.82084
Epoch: 001, Average loss: 120.88367


100%|██████████| 696/696 [00:23<00:00, 29.93it/s]
100%|██████████| 174/174 [00:03<00:00, 47.35it/s]


Epoch: 002, Average loss: 125.83390
Epoch: 002, Average loss: 124.45187


100%|██████████| 696/696 [00:21<00:00, 32.52it/s]
100%|██████████| 174/174 [00:03<00:00, 51.09it/s]


Epoch: 003, Average loss: 123.77319
Epoch: 003, Average loss: 117.66858


100%|██████████| 696/696 [00:20<00:00, 33.15it/s]
100%|██████████| 174/174 [00:03<00:00, 49.40it/s]


Epoch: 004, Average loss: 123.38754
Epoch: 004, Average loss: 119.50837


100%|██████████| 696/696 [00:20<00:00, 34.15it/s]
100%|██████████| 174/174 [00:03<00:00, 57.12it/s]


Epoch: 005, Average loss: 122.62859
Epoch: 005, Average loss: 117.82337


100%|██████████| 696/696 [00:20<00:00, 33.23it/s]
100%|██████████| 174/174 [00:03<00:00, 49.00it/s]


Epoch: 006, Average loss: 122.55601
Epoch: 006, Average loss: 119.03779


100%|██████████| 696/696 [00:20<00:00, 33.15it/s]
100%|██████████| 174/174 [00:03<00:00, 55.88it/s]


Epoch: 007, Average loss: 122.01835
Epoch: 007, Average loss: 118.11694


100%|██████████| 696/696 [00:20<00:00, 33.37it/s]
100%|██████████| 174/174 [00:03<00:00, 50.60it/s]


Epoch: 008, Average loss: 122.09188
Epoch: 008, Average loss: 118.18667


100%|██████████| 696/696 [00:20<00:00, 33.36it/s]
100%|██████████| 174/174 [00:03<00:00, 48.11it/s]


Epoch: 009, Average loss: 121.83959
Epoch: 009, Average loss: 118.57265


100%|██████████| 696/696 [00:21<00:00, 32.84it/s]
100%|██████████| 174/174 [00:03<00:00, 49.82it/s]


Epoch: 010, Average loss: 121.87870
Epoch: 010, Average loss: 116.21666


100%|██████████| 696/696 [00:20<00:00, 34.37it/s]
100%|██████████| 174/174 [00:03<00:00, 53.99it/s]


Epoch: 011, Average loss: 121.56763
Epoch: 011, Average loss: 116.19798


100%|██████████| 696/696 [00:20<00:00, 33.19it/s]
100%|██████████| 174/174 [00:03<00:00, 55.29it/s]


Epoch: 012, Average loss: 121.50928
Epoch: 012, Average loss: 117.26267


100%|██████████| 696/696 [00:20<00:00, 33.15it/s]
100%|██████████| 174/174 [00:03<00:00, 49.78it/s]


Epoch: 013, Average loss: 121.34610
Epoch: 013, Average loss: 118.11378


100%|██████████| 696/696 [00:20<00:00, 33.18it/s]
100%|██████████| 174/174 [00:03<00:00, 52.74it/s]


Epoch: 014, Average loss: 121.20595
Epoch: 014, Average loss: 116.77405


100%|██████████| 696/696 [00:20<00:00, 33.57it/s]
100%|██████████| 174/174 [00:03<00:00, 47.64it/s]


Epoch: 015, Average loss: 121.18663
Epoch: 015, Average loss: 116.51326


100%|██████████| 696/696 [00:21<00:00, 33.01it/s]
100%|██████████| 174/174 [00:03<00:00, 50.25it/s]


Epoch: 016, Average loss: 121.03969
Epoch: 016, Average loss: 116.69373


100%|██████████| 696/696 [00:20<00:00, 33.18it/s]
100%|██████████| 174/174 [00:03<00:00, 53.65it/s]


Epoch: 017, Average loss: 121.06359
Epoch: 017, Average loss: 116.16187


100%|██████████| 696/696 [00:20<00:00, 34.09it/s]
100%|██████████| 174/174 [00:03<00:00, 48.93it/s]


Epoch: 018, Average loss: 120.91281
Epoch: 018, Average loss: 116.28699


100%|██████████| 696/696 [00:20<00:00, 33.28it/s]
100%|██████████| 174/174 [00:03<00:00, 56.20it/s]


Epoch: 019, Average loss: 120.90908
Epoch: 019, Average loss: 116.02501


100%|██████████| 696/696 [00:21<00:00, 33.06it/s]
100%|██████████| 174/174 [00:03<00:00, 48.90it/s]


Epoch: 020, Average loss: 120.80400
Epoch: 020, Average loss: 116.07510


100%|██████████| 696/696 [00:21<00:00, 33.03it/s]
100%|██████████| 174/174 [00:03<00:00, 50.69it/s]


Epoch: 021, Average loss: 120.62498
Epoch: 021, Average loss: 116.61733


100%|██████████| 696/696 [00:21<00:00, 33.02it/s]
100%|██████████| 174/174 [00:03<00:00, 48.47it/s]


Epoch: 022, Average loss: 120.67766
Epoch: 022, Average loss: 115.94030


100%|██████████| 696/696 [00:20<00:00, 33.40it/s]
100%|██████████| 174/174 [00:03<00:00, 47.22it/s]


Epoch: 023, Average loss: 121.04433
Epoch: 023, Average loss: 115.83962


100%|██████████| 696/696 [00:20<00:00, 33.36it/s]
100%|██████████| 174/174 [00:03<00:00, 49.44it/s]


Epoch: 024, Average loss: 120.55706
Epoch: 024, Average loss: 115.98043


100%|██████████| 696/696 [00:20<00:00, 33.97it/s]
100%|██████████| 174/174 [00:03<00:00, 52.68it/s]


Epoch: 025, Average loss: 120.49396
Epoch: 025, Average loss: 115.86176


100%|██████████| 696/696 [00:21<00:00, 32.90it/s]
100%|██████████| 174/174 [00:03<00:00, 55.57it/s]


Epoch: 026, Average loss: 120.48897
Epoch: 026, Average loss: 116.36128


100%|██████████| 696/696 [00:21<00:00, 33.04it/s]
100%|██████████| 174/174 [00:03<00:00, 46.60it/s]


Epoch: 027, Average loss: 120.51984
Epoch: 027, Average loss: 116.40148


100%|██████████| 696/696 [00:21<00:00, 32.96it/s]
100%|██████████| 174/174 [00:03<00:00, 51.00it/s]


Epoch: 028, Average loss: 127.13017
Epoch: 028, Average loss: 129.06945


100%|██████████| 696/696 [00:21<00:00, 33.11it/s]
100%|██████████| 174/174 [00:03<00:00, 47.48it/s]


Epoch: 029, Average loss: 132.64835
Epoch: 029, Average loss: 129.10023


100%|██████████| 696/696 [00:21<00:00, 32.75it/s]
100%|██████████| 174/174 [00:03<00:00, 48.03it/s]


Epoch: 030, Average loss: 132.66454
Epoch: 030, Average loss: 128.93214


100%|██████████| 696/696 [00:16<00:00, 41.25it/s]
100%|██████████| 174/174 [00:02<00:00, 58.29it/s]


Epoch: 031, Average loss: 132.58110
Epoch: 031, Average loss: 128.92663


100%|██████████| 696/696 [00:19<00:00, 36.25it/s]
100%|██████████| 174/174 [00:01<00:00, 124.33it/s]


Epoch: 032, Average loss: 132.64319
Epoch: 032, Average loss: 129.00436


100%|██████████| 696/696 [00:14<00:00, 47.69it/s]
100%|██████████| 174/174 [00:02<00:00, 86.43it/s] 


Epoch: 033, Average loss: 132.63599
Epoch: 033, Average loss: 128.96066


100%|██████████| 696/696 [00:14<00:00, 47.30it/s]
100%|██████████| 174/174 [00:02<00:00, 86.41it/s] 


Epoch: 034, Average loss: 132.59458
Epoch: 034, Average loss: 129.01102


100%|██████████| 696/696 [00:17<00:00, 40.11it/s]
100%|██████████| 174/174 [00:03<00:00, 51.51it/s]


Epoch: 035, Average loss: 132.59382
Epoch: 035, Average loss: 129.01058


100%|██████████| 696/696 [00:20<00:00, 33.49it/s]
100%|██████████| 174/174 [00:03<00:00, 54.93it/s]


Epoch: 036, Average loss: 132.59440
Epoch: 036, Average loss: 129.01036


100%|██████████| 696/696 [00:20<00:00, 33.40it/s]
100%|██████████| 174/174 [00:03<00:00, 49.67it/s]


Epoch: 037, Average loss: 132.59470
Epoch: 037, Average loss: 129.01013


100%|██████████| 696/696 [00:20<00:00, 33.61it/s]
100%|██████████| 174/174 [00:03<00:00, 54.21it/s]


Epoch: 038, Average loss: 132.59488
Epoch: 038, Average loss: 129.00992


100%|██████████| 696/696 [00:20<00:00, 33.62it/s]
100%|██████████| 174/174 [00:03<00:00, 48.35it/s]


Epoch: 039, Average loss: 132.59497
Epoch: 039, Average loss: 129.00972


 53%|█████▎    | 368/696 [00:11<00:09, 33.03it/s]

In [12]:
print(model)

SchNet(hidden_channels=128, num_filters=128, num_interactions=4, num_gaussians=50, cutoff=10.0)


# Cross-val

In [54]:
from torch_geometric.data import DataLoader
maes=[]
for a in range(7):
    trains=np.where([targets['fold']!=a])[1]
    tests=np.where([targets['fold']==a])[1]
    train_dataset=[data_atoms[i] for i in trains]
    test_dataset=[data_atoms[i] for i in tests]
    train_loader = DataLoader(train_dataset, batch_size=4)
    test_loader = DataLoader(test_dataset, batch_size=4)
    model=SchNet()
    device = 'cuda:2' if torch.cuda.is_available() else 'cpu'
    #device='cpu'
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    epochs = 10
    #
    loss_func = torch.nn.L1Loss() #define loss
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=256)
    tr_loss = []
    ts_loss=[]#set model to training mode
    for epoch in range(100):
        model.train() 
        valid_loss = 0

        train_loss=0

         #shuffle the training data each epoch
        for d in tqdm(train_loader): #go over each training point
            data = d.to(device)#send data to device

            out = model(data) 
            optimizer.zero_grad() #zero gradients
            #evaluate data point
            loss = loss_func(out, data.y) #L1 error loss
             #add loss value to aggregate loss
            loss.backward() #compute gradients
            optimizer.step() #apply optimization
            train_loss += loss.item()

        model.eval()     # Optional when not using Model Specific layer
        for d in tqdm(test_loader):
            data = d.to(device)
            target = model(data)
            loss = loss_func(target, data.y)
            valid_loss += loss.item()

        print('Epoch: {:03d}, Average loss: {:.5f}'.format(epoch, train_loss/len(train_loader)))
        tr_loss.append(train_loss/len(train_loader))
        
        print('Epoch: {:03d}, Average loss: {:.5f}'.format(epoch, valid_loss/len(test_loader)))
        ts_loss.append(valid_loss/len(test_loader))
    maes.append(valid_loss/len(test_loader))

    


Using a target size (torch.Size([4])) that is different to the input size (torch.Size([4, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.

100%|██████████| 762/762 [00:30<00:00, 24.65it/s]

Using a target size (torch.Size([3])) that is different to the input size (torch.Size([3, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.

100%|██████████| 109/109 [00:01<00:00, 59.84it/s]


Epoch: 000, Average loss: 0.41070
Epoch: 000, Average loss: 0.75743


100%|██████████| 762/762 [00:30<00:00, 24.84it/s]
100%|██████████| 109/109 [00:01<00:00, 54.84it/s]


Epoch: 001, Average loss: 0.13164
Epoch: 001, Average loss: 0.33875


100%|██████████| 762/762 [00:30<00:00, 24.68it/s]
100%|██████████| 109/109 [00:01<00:00, 55.27it/s]


Epoch: 002, Average loss: 0.11214
Epoch: 002, Average loss: 0.13221


100%|██████████| 762/762 [00:31<00:00, 24.49it/s]
100%|██████████| 109/109 [00:01<00:00, 56.32it/s]


Epoch: 003, Average loss: 0.07704
Epoch: 003, Average loss: 0.24454


100%|██████████| 762/762 [00:30<00:00, 24.83it/s]
100%|██████████| 109/109 [00:01<00:00, 54.85it/s]


Epoch: 004, Average loss: 0.07801
Epoch: 004, Average loss: 0.16103


100%|██████████| 762/762 [00:30<00:00, 24.72it/s]
100%|██████████| 109/109 [00:01<00:00, 55.12it/s]


Epoch: 005, Average loss: 0.07259
Epoch: 005, Average loss: 0.17880


100%|██████████| 762/762 [00:29<00:00, 25.81it/s]
100%|██████████| 109/109 [00:01<00:00, 76.16it/s]


Epoch: 006, Average loss: 0.06672
Epoch: 006, Average loss: 0.17376


100%|██████████| 762/762 [00:25<00:00, 29.56it/s]
100%|██████████| 109/109 [00:01<00:00, 76.75it/s]


Epoch: 007, Average loss: 0.05521
Epoch: 007, Average loss: 0.27626


100%|██████████| 762/762 [00:30<00:00, 25.08it/s]
100%|██████████| 109/109 [00:02<00:00, 51.99it/s]


Epoch: 008, Average loss: 0.05231
Epoch: 008, Average loss: 0.09857


100%|██████████| 762/762 [00:30<00:00, 25.20it/s]
100%|██████████| 109/109 [00:01<00:00, 56.11it/s]


Epoch: 012, Average loss: 0.04263
Epoch: 012, Average loss: 0.33979


100%|██████████| 762/762 [00:30<00:00, 25.38it/s]
100%|██████████| 109/109 [00:01<00:00, 57.65it/s]


Epoch: 013, Average loss: 0.04369
Epoch: 013, Average loss: 0.26611


100%|██████████| 762/762 [00:29<00:00, 25.45it/s]
100%|██████████| 109/109 [00:01<00:00, 58.76it/s]


Epoch: 014, Average loss: 0.04227
Epoch: 014, Average loss: 0.19421


100%|██████████| 762/762 [00:29<00:00, 25.48it/s]
100%|██████████| 109/109 [00:01<00:00, 56.26it/s]


Epoch: 015, Average loss: 0.03895
Epoch: 015, Average loss: 0.09937


100%|██████████| 762/762 [00:29<00:00, 25.47it/s]
100%|██████████| 109/109 [00:01<00:00, 55.56it/s]


Epoch: 016, Average loss: 0.03919
Epoch: 016, Average loss: 0.13360


100%|██████████| 762/762 [00:30<00:00, 25.38it/s]
100%|██████████| 109/109 [00:01<00:00, 54.83it/s]


Epoch: 017, Average loss: 0.03538
Epoch: 017, Average loss: 0.06488


100%|██████████| 762/762 [00:30<00:00, 25.37it/s]
100%|██████████| 109/109 [00:01<00:00, 55.07it/s]


Epoch: 018, Average loss: 0.03847
Epoch: 018, Average loss: 0.23016


100%|██████████| 762/762 [00:30<00:00, 24.82it/s]
100%|██████████| 109/109 [00:01<00:00, 55.34it/s]


Epoch: 021, Average loss: 0.03697
Epoch: 021, Average loss: 0.09610


100%|██████████| 762/762 [00:30<00:00, 24.98it/s]
100%|██████████| 109/109 [00:02<00:00, 54.28it/s]


Epoch: 022, Average loss: 0.03468
Epoch: 022, Average loss: 0.12203


100%|██████████| 762/762 [00:30<00:00, 24.87it/s]
100%|██████████| 109/109 [00:01<00:00, 55.52it/s]


Epoch: 023, Average loss: 0.03298
Epoch: 023, Average loss: 0.10802


100%|██████████| 762/762 [00:30<00:00, 24.63it/s]
100%|██████████| 109/109 [00:01<00:00, 59.24it/s]


Epoch: 024, Average loss: 0.03325
Epoch: 024, Average loss: 0.13699


100%|██████████| 762/762 [00:30<00:00, 24.91it/s]
100%|██████████| 109/109 [00:01<00:00, 61.27it/s]


Epoch: 025, Average loss: 0.03319
Epoch: 025, Average loss: 0.09834


100%|██████████| 762/762 [00:30<00:00, 24.85it/s]
100%|██████████| 109/109 [00:01<00:00, 56.29it/s]


Epoch: 026, Average loss: 0.03278
Epoch: 026, Average loss: 0.12171


100%|██████████| 762/762 [00:29<00:00, 26.25it/s]
100%|██████████| 109/109 [00:01<00:00, 73.20it/s]


Epoch: 027, Average loss: 0.03244
Epoch: 027, Average loss: 0.15437


100%|██████████| 762/762 [00:25<00:00, 29.97it/s]
100%|██████████| 109/109 [00:01<00:00, 77.83it/s]


Epoch: 028, Average loss: 0.03180
Epoch: 028, Average loss: 0.12911


100%|██████████| 762/762 [00:28<00:00, 27.07it/s]
100%|██████████| 109/109 [00:01<00:00, 55.17it/s]


Epoch: 029, Average loss: 527.10044
Epoch: 029, Average loss: 0.30803


100%|██████████| 762/762 [00:30<00:00, 24.97it/s]
100%|██████████| 109/109 [00:01<00:00, 54.82it/s]


Epoch: 030, Average loss: 0.17711
Epoch: 030, Average loss: 0.28400


100%|██████████| 762/762 [00:29<00:00, 25.76it/s]
100%|██████████| 109/109 [00:01<00:00, 58.78it/s]


Epoch: 031, Average loss: 0.11022
Epoch: 031, Average loss: 0.20743


100%|██████████| 762/762 [00:29<00:00, 25.79it/s]
100%|██████████| 109/109 [00:01<00:00, 55.99it/s]


Epoch: 032, Average loss: 0.08039
Epoch: 032, Average loss: 0.14100


100%|██████████| 762/762 [00:29<00:00, 25.81it/s]
100%|██████████| 109/109 [00:01<00:00, 56.26it/s]


Epoch: 033, Average loss: 0.07084
Epoch: 033, Average loss: 0.19696


 48%|████▊     | 364/762 [00:14<00:16, 23.48it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:29<00:00, 25.92it/s]
100%|██████████| 109/109 [00:01<00:00, 55.55it/s]


Epoch: 038, Average loss: 0.06653
Epoch: 038, Average loss: 0.17569


100%|██████████| 762/762 [00:29<00:00, 25.74it/s]
100%|██████████| 109/109 [00:01<00:00, 55.00it/s]


Epoch: 039, Average loss: 1889.41761
Epoch: 039, Average loss: 0.30646


100%|██████████| 762/762 [00:29<00:00, 25.80it/s]
100%|██████████| 109/109 [00:01<00:00, 54.89it/s]


Epoch: 040, Average loss: 0.14714
Epoch: 040, Average loss: 0.30650


100%|██████████| 762/762 [00:29<00:00, 25.67it/s]
100%|██████████| 109/109 [00:01<00:00, 56.94it/s]


Epoch: 041, Average loss: 0.14467
Epoch: 041, Average loss: 0.30655


100%|██████████| 762/762 [00:30<00:00, 24.76it/s]
100%|██████████| 109/109 [00:02<00:00, 48.21it/s]


Epoch: 042, Average loss: 0.13954
Epoch: 042, Average loss: 0.30648


100%|██████████| 762/762 [00:31<00:00, 23.83it/s]
100%|██████████| 109/109 [00:02<00:00, 46.16it/s]


Epoch: 043, Average loss: 0.13906
Epoch: 043, Average loss: 0.30752


100%|██████████| 762/762 [00:31<00:00, 24.13it/s]
100%|██████████| 109/109 [00:02<00:00, 44.51it/s]


Epoch: 044, Average loss: 0.13660
Epoch: 044, Average loss: 0.31072


100%|██████████| 762/762 [00:31<00:00, 24.11it/s]
100%|██████████| 109/109 [00:02<00:00, 42.37it/s]


Epoch: 045, Average loss: 0.13276
Epoch: 045, Average loss: 0.31497


100%|██████████| 762/762 [00:31<00:00, 24.03it/s]
100%|██████████| 109/109 [00:02<00:00, 47.90it/s]


Epoch: 046, Average loss: 0.12566
Epoch: 046, Average loss: 0.32819


100%|██████████| 762/762 [00:31<00:00, 24.17it/s]
100%|██████████| 109/109 [00:02<00:00, 49.65it/s]


Epoch: 047, Average loss: 0.12655
Epoch: 047, Average loss: 0.32652


100%|██████████| 762/762 [00:32<00:00, 23.73it/s]
100%|██████████| 109/109 [00:02<00:00, 48.96it/s]


Epoch: 048, Average loss: 0.13031
Epoch: 048, Average loss: 0.32627


100%|██████████| 762/762 [00:26<00:00, 29.10it/s]
100%|██████████| 109/109 [00:01<00:00, 68.29it/s]


Epoch: 049, Average loss: 0.13127
Epoch: 049, Average loss: 0.32749


100%|██████████| 762/762 [00:28<00:00, 26.59it/s]
100%|██████████| 109/109 [00:02<00:00, 48.30it/s]


Epoch: 050, Average loss: 0.12654
Epoch: 050, Average loss: 0.32930


100%|██████████| 762/762 [00:31<00:00, 24.14it/s]
100%|██████████| 109/109 [00:02<00:00, 46.89it/s]


Epoch: 051, Average loss: 0.12841
Epoch: 051, Average loss: 0.30765


100%|██████████| 762/762 [00:31<00:00, 23.97it/s]
100%|██████████| 109/109 [00:01<00:00, 56.42it/s]


Epoch: 052, Average loss: 0.12607
Epoch: 052, Average loss: 0.30861


100%|██████████| 762/762 [00:31<00:00, 24.28it/s]
100%|██████████| 109/109 [00:02<00:00, 49.78it/s]


Epoch: 053, Average loss: 0.12565
Epoch: 053, Average loss: 0.30691


100%|██████████| 762/762 [00:30<00:00, 24.69it/s]
100%|██████████| 109/109 [00:02<00:00, 48.61it/s]


Epoch: 054, Average loss: 0.12311
Epoch: 054, Average loss: 0.30665


100%|██████████| 762/762 [00:31<00:00, 24.42it/s]
100%|██████████| 109/109 [00:02<00:00, 46.47it/s]


Epoch: 055, Average loss: 0.12392
Epoch: 055, Average loss: 0.30817


100%|██████████| 762/762 [00:32<00:00, 23.23it/s]
100%|██████████| 109/109 [00:01<00:00, 56.89it/s]


Epoch: 056, Average loss: 0.12358
Epoch: 056, Average loss: 0.30687


100%|██████████| 762/762 [00:31<00:00, 24.25it/s]
100%|██████████| 109/109 [00:02<00:00, 47.06it/s]


Epoch: 057, Average loss: 0.12037
Epoch: 057, Average loss: 0.30650


100%|██████████| 762/762 [00:32<00:00, 23.63it/s]
100%|██████████| 109/109 [00:02<00:00, 48.50it/s]


Epoch: 058, Average loss: 0.11983
Epoch: 058, Average loss: 0.30653


100%|██████████| 762/762 [00:32<00:00, 23.62it/s]
100%|██████████| 109/109 [00:02<00:00, 46.37it/s]


Epoch: 059, Average loss: 0.11550
Epoch: 059, Average loss: 0.30687


100%|██████████| 762/762 [00:31<00:00, 23.97it/s]
100%|██████████| 109/109 [00:02<00:00, 48.62it/s]


Epoch: 060, Average loss: 0.11207
Epoch: 060, Average loss: 0.30723


100%|██████████| 762/762 [00:34<00:00, 22.22it/s]
100%|██████████| 109/109 [00:02<00:00, 48.96it/s]


Epoch: 061, Average loss: 0.11179
Epoch: 061, Average loss: 0.30658


100%|██████████| 762/762 [00:32<00:00, 23.73it/s]
100%|██████████| 109/109 [00:02<00:00, 49.54it/s]


Epoch: 062, Average loss: 0.11538
Epoch: 062, Average loss: 0.30867


100%|██████████| 762/762 [00:33<00:00, 22.81it/s]
100%|██████████| 109/109 [00:02<00:00, 47.32it/s]


Epoch: 063, Average loss: 0.11238
Epoch: 063, Average loss: 0.32580


100%|██████████| 762/762 [00:31<00:00, 24.11it/s]
100%|██████████| 109/109 [00:02<00:00, 42.54it/s]


Epoch: 064, Average loss: 0.11394
Epoch: 064, Average loss: 0.30743


100%|██████████| 762/762 [00:33<00:00, 22.79it/s]
100%|██████████| 109/109 [00:02<00:00, 44.65it/s]


Epoch: 065, Average loss: 0.11322
Epoch: 065, Average loss: 0.31148


100%|██████████| 762/762 [00:32<00:00, 23.72it/s]
100%|██████████| 109/109 [00:02<00:00, 47.38it/s]


Epoch: 066, Average loss: 0.11315
Epoch: 066, Average loss: 0.30655


100%|██████████| 762/762 [00:33<00:00, 22.42it/s]
100%|██████████| 109/109 [00:02<00:00, 49.58it/s]


Epoch: 067, Average loss: 0.11444
Epoch: 067, Average loss: 0.30720


100%|██████████| 762/762 [00:31<00:00, 24.27it/s]
100%|██████████| 109/109 [00:02<00:00, 48.89it/s]


Epoch: 068, Average loss: 0.11194
Epoch: 068, Average loss: 0.31141


100%|██████████| 762/762 [00:31<00:00, 24.04it/s]
100%|██████████| 109/109 [00:02<00:00, 48.53it/s]


Epoch: 069, Average loss: 0.10827
Epoch: 069, Average loss: 0.31438


100%|██████████| 762/762 [00:31<00:00, 23.98it/s]
100%|██████████| 109/109 [00:02<00:00, 42.89it/s]


Epoch: 070, Average loss: 0.11036
Epoch: 070, Average loss: 0.31004


100%|██████████| 762/762 [00:27<00:00, 27.49it/s]
100%|██████████| 109/109 [00:02<00:00, 49.15it/s]


Epoch: 071, Average loss: 0.10618
Epoch: 071, Average loss: 0.32447


100%|██████████| 762/762 [00:34<00:00, 22.16it/s]
100%|██████████| 109/109 [00:02<00:00, 44.25it/s]


Epoch: 072, Average loss: 0.11080
Epoch: 072, Average loss: 0.31059


100%|██████████| 762/762 [00:36<00:00, 21.01it/s]
100%|██████████| 109/109 [00:02<00:00, 43.62it/s]


Epoch: 073, Average loss: 0.11080
Epoch: 073, Average loss: 0.31021


100%|██████████| 109/109 [00:02<00:00, 45.82it/s]


Epoch: 076, Average loss: 0.10896
Epoch: 076, Average loss: 0.33713


100%|██████████| 762/762 [00:31<00:00, 24.05it/s]
100%|██████████| 109/109 [00:02<00:00, 43.59it/s]


Epoch: 077, Average loss: 0.11079
Epoch: 077, Average loss: 0.30731


100%|██████████| 762/762 [00:32<00:00, 23.74it/s]
100%|██████████| 109/109 [00:02<00:00, 43.93it/s]


Epoch: 078, Average loss: 0.10993
Epoch: 078, Average loss: 0.33188


100%|██████████| 762/762 [00:31<00:00, 23.84it/s]
100%|██████████| 109/109 [00:02<00:00, 48.78it/s]


Epoch: 079, Average loss: 0.10607
Epoch: 079, Average loss: 0.33138


100%|██████████| 762/762 [00:31<00:00, 24.07it/s]
100%|██████████| 109/109 [00:02<00:00, 44.67it/s]


Epoch: 080, Average loss: 0.10860
Epoch: 080, Average loss: 0.33378


100%|██████████| 762/762 [00:31<00:00, 23.99it/s]
100%|██████████| 109/109 [00:02<00:00, 47.00it/s]


Epoch: 081, Average loss: 0.10600
Epoch: 081, Average loss: 0.31717


100%|██████████| 762/762 [00:31<00:00, 24.08it/s]
100%|██████████| 109/109 [00:02<00:00, 42.78it/s]


Epoch: 082, Average loss: 0.10765
Epoch: 082, Average loss: 0.32481


100%|██████████| 762/762 [00:32<00:00, 23.80it/s]
100%|██████████| 109/109 [00:02<00:00, 49.32it/s]


Epoch: 083, Average loss: 0.10797
Epoch: 083, Average loss: 0.32457


100%|██████████| 762/762 [00:32<00:00, 23.75it/s]
100%|██████████| 109/109 [00:02<00:00, 47.52it/s]


Epoch: 084, Average loss: 0.10814
Epoch: 084, Average loss: 0.31033


100%|██████████| 762/762 [00:31<00:00, 23.85it/s]
100%|██████████| 109/109 [00:02<00:00, 46.89it/s]


Epoch: 085, Average loss: 0.10537
Epoch: 085, Average loss: 0.34790


100%|██████████| 762/762 [00:31<00:00, 23.97it/s]
100%|██████████| 109/109 [00:02<00:00, 44.43it/s]


Epoch: 086, Average loss: 0.11095
Epoch: 086, Average loss: 0.33449


100%|██████████| 762/762 [00:31<00:00, 24.07it/s]
100%|██████████| 109/109 [00:02<00:00, 45.52it/s]


Epoch: 087, Average loss: 0.11173
Epoch: 087, Average loss: 0.33319


100%|██████████| 762/762 [00:32<00:00, 23.74it/s]
100%|██████████| 109/109 [00:02<00:00, 48.20it/s]


Epoch: 088, Average loss: 0.10927
Epoch: 088, Average loss: 0.32081


 22%|██▏       | 167/762 [00:06<00:22, 26.29it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:31<00:00, 23.89it/s]
100%|██████████| 109/109 [00:02<00:00, 45.48it/s]


Epoch: 009, Average loss: 0.03851
Epoch: 009, Average loss: 0.15289


100%|██████████| 762/762 [00:31<00:00, 23.87it/s]
100%|██████████| 109/109 [00:02<00:00, 44.84it/s]


Epoch: 010, Average loss: 0.03817
Epoch: 010, Average loss: 0.11236


100%|██████████| 762/762 [00:31<00:00, 23.88it/s]
100%|██████████| 109/109 [00:02<00:00, 44.17it/s]


Epoch: 011, Average loss: 0.03538
Epoch: 011, Average loss: 0.14328


100%|██████████| 762/762 [00:26<00:00, 28.96it/s]
100%|██████████| 109/109 [00:01<00:00, 65.10it/s]


Epoch: 012, Average loss: 0.03486
Epoch: 012, Average loss: 0.06781


100%|██████████| 762/762 [00:26<00:00, 28.97it/s]
100%|██████████| 109/109 [00:02<00:00, 50.16it/s]


Epoch: 013, Average loss: 0.03424
Epoch: 013, Average loss: 0.09027


100%|██████████| 762/762 [00:31<00:00, 24.33it/s]
100%|██████████| 109/109 [00:02<00:00, 47.69it/s]


Epoch: 014, Average loss: 0.03012
Epoch: 014, Average loss: 0.13319


100%|██████████| 762/762 [00:31<00:00, 23.93it/s]
100%|██████████| 109/109 [00:02<00:00, 47.16it/s]


Epoch: 015, Average loss: 0.03120
Epoch: 015, Average loss: 0.13122


100%|██████████| 762/762 [00:31<00:00, 24.32it/s]
100%|██████████| 109/109 [00:02<00:00, 49.59it/s]


Epoch: 016, Average loss: 0.03202
Epoch: 016, Average loss: 0.08323


100%|██████████| 762/762 [00:31<00:00, 24.12it/s]
100%|██████████| 109/109 [00:02<00:00, 44.52it/s]


Epoch: 017, Average loss: 0.03396
Epoch: 017, Average loss: 0.08244


100%|██████████| 762/762 [00:31<00:00, 24.05it/s]
100%|██████████| 109/109 [00:02<00:00, 47.98it/s]


Epoch: 018, Average loss: 762.16973
Epoch: 018, Average loss: 2.20237


100%|██████████| 762/762 [00:31<00:00, 24.01it/s]
100%|██████████| 109/109 [00:02<00:00, 48.26it/s]


Epoch: 019, Average loss: 0.92607
Epoch: 019, Average loss: 0.29763


100%|██████████| 762/762 [00:31<00:00, 24.18it/s]
100%|██████████| 109/109 [00:02<00:00, 47.79it/s]


Epoch: 020, Average loss: 0.31278
Epoch: 020, Average loss: 0.29701


 65%|██████▍   | 493/762 [00:20<00:12, 20.87it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:31<00:00, 24.42it/s]
100%|██████████| 109/109 [00:02<00:00, 49.46it/s]


Epoch: 041, Average loss: 0.11526
Epoch: 041, Average loss: 0.33044


100%|██████████| 762/762 [00:31<00:00, 24.09it/s]
100%|██████████| 109/109 [00:02<00:00, 48.59it/s]


Epoch: 042, Average loss: 0.11740
Epoch: 042, Average loss: 0.36780


100%|██████████| 762/762 [00:31<00:00, 24.10it/s]
100%|██████████| 109/109 [00:02<00:00, 45.09it/s]


Epoch: 043, Average loss: 0.11076
Epoch: 043, Average loss: 0.36298


100%|██████████| 762/762 [00:31<00:00, 24.05it/s]
100%|██████████| 109/109 [00:02<00:00, 46.50it/s]


Epoch: 044, Average loss: 0.11237
Epoch: 044, Average loss: 0.33309


100%|██████████| 762/762 [00:31<00:00, 23.95it/s]
100%|██████████| 109/109 [00:02<00:00, 47.37it/s]


Epoch: 045, Average loss: 0.11257
Epoch: 045, Average loss: 0.37572


100%|██████████| 762/762 [00:31<00:00, 24.13it/s]
100%|██████████| 109/109 [00:01<00:00, 56.45it/s]


Epoch: 046, Average loss: 0.11377
Epoch: 046, Average loss: 0.33563


100%|██████████| 762/762 [00:31<00:00, 23.94it/s]
100%|██████████| 109/109 [00:02<00:00, 46.78it/s]


Epoch: 047, Average loss: 0.11055
Epoch: 047, Average loss: 0.32778


100%|██████████| 762/762 [00:31<00:00, 23.98it/s]
100%|██████████| 109/109 [00:02<00:00, 48.74it/s]


Epoch: 048, Average loss: 0.11562
Epoch: 048, Average loss: 0.32984


100%|██████████| 762/762 [00:31<00:00, 24.04it/s]
100%|██████████| 109/109 [00:02<00:00, 45.57it/s]


Epoch: 049, Average loss: 0.11291
Epoch: 049, Average loss: 0.37225


100%|██████████| 762/762 [00:31<00:00, 23.99it/s]
100%|██████████| 109/109 [00:02<00:00, 49.99it/s]


Epoch: 050, Average loss: 0.10835
Epoch: 050, Average loss: 0.33905


100%|██████████| 762/762 [00:31<00:00, 24.02it/s]
100%|██████████| 109/109 [00:02<00:00, 49.17it/s]


Epoch: 051, Average loss: 0.11510
Epoch: 051, Average loss: 0.37798


100%|██████████| 762/762 [00:31<00:00, 24.09it/s]
100%|██████████| 109/109 [00:02<00:00, 47.87it/s]


Epoch: 052, Average loss: 0.11306
Epoch: 052, Average loss: 0.36955


 62%|██████▏   | 475/762 [00:19<00:11, 25.56it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:31<00:00, 24.31it/s]
100%|██████████| 109/109 [00:02<00:00, 47.14it/s]


Epoch: 073, Average loss: 0.10735
Epoch: 073, Average loss: 0.32992


100%|██████████| 762/762 [00:31<00:00, 24.16it/s]
100%|██████████| 109/109 [00:02<00:00, 47.76it/s]


Epoch: 074, Average loss: 0.11073
Epoch: 074, Average loss: 0.36629


100%|██████████| 762/762 [00:31<00:00, 24.24it/s]
100%|██████████| 109/109 [00:02<00:00, 49.28it/s]


Epoch: 075, Average loss: 0.11355
Epoch: 075, Average loss: 0.37473


100%|██████████| 762/762 [00:31<00:00, 24.13it/s]
100%|██████████| 109/109 [00:02<00:00, 46.53it/s]


Epoch: 076, Average loss: 0.11477
Epoch: 076, Average loss: 0.33919


100%|██████████| 762/762 [00:26<00:00, 28.44it/s]
100%|██████████| 109/109 [00:01<00:00, 65.25it/s]


Epoch: 077, Average loss: 0.11157
Epoch: 077, Average loss: 0.37336


100%|██████████| 762/762 [00:26<00:00, 28.96it/s]
100%|██████████| 109/109 [00:02<00:00, 45.66it/s]


Epoch: 078, Average loss: 0.11121
Epoch: 078, Average loss: 0.37025


100%|██████████| 762/762 [00:31<00:00, 24.18it/s]
100%|██████████| 109/109 [00:02<00:00, 45.17it/s]


Epoch: 079, Average loss: 0.11091
Epoch: 079, Average loss: 0.36836


100%|██████████| 762/762 [00:31<00:00, 23.94it/s]
100%|██████████| 109/109 [00:02<00:00, 48.64it/s]


Epoch: 080, Average loss: 0.11228
Epoch: 080, Average loss: 0.35628


100%|██████████| 762/762 [00:31<00:00, 24.02it/s]
100%|██████████| 109/109 [00:02<00:00, 43.82it/s]


Epoch: 081, Average loss: 0.10978
Epoch: 081, Average loss: 0.36885


100%|██████████| 762/762 [00:31<00:00, 24.07it/s]
100%|██████████| 109/109 [00:02<00:00, 50.04it/s]


Epoch: 082, Average loss: 0.10905
Epoch: 082, Average loss: 0.34433


100%|██████████| 762/762 [00:31<00:00, 24.18it/s]
100%|██████████| 109/109 [00:02<00:00, 43.12it/s]


Epoch: 083, Average loss: 0.10737
Epoch: 083, Average loss: 0.36677


100%|██████████| 762/762 [00:31<00:00, 24.06it/s]
100%|██████████| 109/109 [00:02<00:00, 49.46it/s]


Epoch: 084, Average loss: 0.10833
Epoch: 084, Average loss: 0.32755


100%|██████████| 762/762 [00:31<00:00, 24.09it/s]
  0%|          | 0/109 [00:00<?, ?it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:31<00:00, 24.05it/s]
100%|██████████| 109/109 [00:02<00:00, 45.90it/s]


Epoch: 005, Average loss: 0.04827
Epoch: 005, Average loss: 0.12258


100%|██████████| 762/762 [00:31<00:00, 24.04it/s]
100%|██████████| 109/109 [00:02<00:00, 44.79it/s]


Epoch: 006, Average loss: 0.04683
Epoch: 006, Average loss: 0.10531


100%|██████████| 762/762 [00:31<00:00, 24.38it/s]
100%|██████████| 109/109 [00:02<00:00, 49.71it/s]


Epoch: 007, Average loss: 0.04164
Epoch: 007, Average loss: 0.18110


100%|██████████| 762/762 [00:31<00:00, 23.98it/s]
100%|██████████| 109/109 [00:02<00:00, 46.81it/s]


Epoch: 008, Average loss: 0.03578
Epoch: 008, Average loss: 0.29228


100%|██████████| 762/762 [00:31<00:00, 24.03it/s]
100%|██████████| 109/109 [00:02<00:00, 47.31it/s]


Epoch: 009, Average loss: 0.03663
Epoch: 009, Average loss: 0.14317


100%|██████████| 762/762 [00:31<00:00, 24.32it/s]
100%|██████████| 109/109 [00:02<00:00, 50.02it/s]


Epoch: 010, Average loss: 0.03355
Epoch: 010, Average loss: 0.08910


100%|██████████| 762/762 [00:31<00:00, 24.07it/s]
100%|██████████| 109/109 [00:02<00:00, 49.19it/s]


Epoch: 011, Average loss: 572.30415
Epoch: 011, Average loss: 5.19365


100%|██████████| 762/762 [00:31<00:00, 23.87it/s]
100%|██████████| 109/109 [00:02<00:00, 49.79it/s]


Epoch: 012, Average loss: 0.51665
Epoch: 012, Average loss: 0.49793


100%|██████████| 762/762 [00:31<00:00, 23.97it/s]
100%|██████████| 109/109 [00:02<00:00, 43.76it/s]


Epoch: 013, Average loss: 0.30034
Epoch: 013, Average loss: 0.34889


100%|██████████| 762/762 [00:31<00:00, 24.04it/s]
100%|██████████| 109/109 [00:02<00:00, 48.47it/s]


Epoch: 014, Average loss: 0.16901
Epoch: 014, Average loss: 0.33404


100%|██████████| 762/762 [00:31<00:00, 23.96it/s]
100%|██████████| 109/109 [00:02<00:00, 48.27it/s]


Epoch: 015, Average loss: 0.12312
Epoch: 015, Average loss: 0.28097


100%|██████████| 762/762 [00:31<00:00, 23.88it/s]
100%|██████████| 109/109 [00:02<00:00, 46.16it/s]


Epoch: 016, Average loss: 0.15766
Epoch: 016, Average loss: 0.34651


 86%|████████▌ | 654/762 [00:27<00:05, 21.55it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:31<00:00, 24.18it/s]
100%|██████████| 109/109 [00:02<00:00, 49.31it/s]


Epoch: 037, Average loss: 0.11049
Epoch: 037, Average loss: 0.30368


100%|██████████| 762/762 [00:31<00:00, 24.11it/s]
100%|██████████| 109/109 [00:02<00:00, 47.42it/s]


Epoch: 038, Average loss: 0.11241
Epoch: 038, Average loss: 0.30368


100%|██████████| 762/762 [00:31<00:00, 23.97it/s]
100%|██████████| 109/109 [00:02<00:00, 42.48it/s]


Epoch: 039, Average loss: 0.11138
Epoch: 039, Average loss: 0.31398


100%|██████████| 762/762 [00:31<00:00, 23.90it/s]
100%|██████████| 109/109 [00:02<00:00, 45.01it/s]


Epoch: 040, Average loss: 0.10687
Epoch: 040, Average loss: 0.30942


100%|██████████| 762/762 [00:26<00:00, 28.58it/s]
100%|██████████| 109/109 [00:01<00:00, 65.80it/s]


Epoch: 041, Average loss: 0.10502
Epoch: 041, Average loss: 0.30368


100%|██████████| 762/762 [00:27<00:00, 28.09it/s]
100%|██████████| 109/109 [00:02<00:00, 49.71it/s]


Epoch: 042, Average loss: 0.11242
Epoch: 042, Average loss: 0.30368


100%|██████████| 762/762 [00:30<00:00, 24.62it/s]
100%|██████████| 109/109 [00:02<00:00, 43.58it/s]


Epoch: 043, Average loss: 0.11045
Epoch: 043, Average loss: 0.31483


100%|██████████| 762/762 [00:31<00:00, 24.00it/s]
100%|██████████| 109/109 [00:02<00:00, 47.20it/s]


Epoch: 044, Average loss: 0.10553
Epoch: 044, Average loss: 0.33399


100%|██████████| 762/762 [00:32<00:00, 23.63it/s]
100%|██████████| 109/109 [00:02<00:00, 46.47it/s]


Epoch: 045, Average loss: 0.10248
Epoch: 045, Average loss: 0.30942


100%|██████████| 762/762 [00:31<00:00, 23.88it/s]
100%|██████████| 109/109 [00:02<00:00, 47.52it/s]


Epoch: 046, Average loss: 0.10984
Epoch: 046, Average loss: 0.30380


100%|██████████| 762/762 [00:32<00:00, 23.80it/s]
100%|██████████| 109/109 [00:02<00:00, 47.60it/s]


Epoch: 047, Average loss: 0.10804
Epoch: 047, Average loss: 0.30380


100%|██████████| 762/762 [00:33<00:00, 23.07it/s]
100%|██████████| 109/109 [00:02<00:00, 37.55it/s]


Epoch: 048, Average loss: 0.11234
Epoch: 048, Average loss: 0.31930


 70%|███████   | 536/762 [00:24<00:09, 24.40it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 109/109 [00:02<00:00, 48.43it/s]


Epoch: 068, Average loss: 0.10566
Epoch: 068, Average loss: 0.30368


100%|██████████| 762/762 [00:32<00:00, 23.70it/s]
100%|██████████| 109/109 [00:02<00:00, 46.81it/s]


Epoch: 069, Average loss: 0.10429
Epoch: 069, Average loss: 0.30621


100%|██████████| 762/762 [00:32<00:00, 23.76it/s]
100%|██████████| 109/109 [00:02<00:00, 47.88it/s]


Epoch: 070, Average loss: 0.10424
Epoch: 070, Average loss: 0.30368


100%|██████████| 762/762 [00:31<00:00, 23.87it/s]
100%|██████████| 109/109 [00:02<00:00, 43.94it/s]


Epoch: 071, Average loss: 0.10823
Epoch: 071, Average loss: 0.31436


100%|██████████| 762/762 [00:31<00:00, 23.93it/s]
100%|██████████| 109/109 [00:02<00:00, 46.30it/s]


Epoch: 072, Average loss: 0.10564
Epoch: 072, Average loss: 0.31276


100%|██████████| 762/762 [00:31<00:00, 23.82it/s]
100%|██████████| 109/109 [00:02<00:00, 46.50it/s]


Epoch: 073, Average loss: 0.10447
Epoch: 073, Average loss: 0.31220


100%|██████████| 762/762 [00:32<00:00, 23.65it/s]
100%|██████████| 109/109 [00:02<00:00, 47.42it/s]


Epoch: 074, Average loss: 0.10331
Epoch: 074, Average loss: 0.30368


100%|██████████| 762/762 [00:32<00:00, 23.77it/s]
100%|██████████| 109/109 [00:02<00:00, 47.87it/s]


Epoch: 075, Average loss: 0.10557
Epoch: 075, Average loss: 0.31890


100%|██████████| 762/762 [00:32<00:00, 23.65it/s]
100%|██████████| 109/109 [00:02<00:00, 46.25it/s]


Epoch: 076, Average loss: 0.10824
Epoch: 076, Average loss: 0.32671


100%|██████████| 762/762 [00:31<00:00, 23.86it/s]
100%|██████████| 109/109 [00:02<00:00, 45.25it/s]


Epoch: 077, Average loss: 0.10806
Epoch: 077, Average loss: 0.31878


100%|██████████| 762/762 [00:31<00:00, 23.91it/s]
100%|██████████| 109/109 [00:02<00:00, 43.55it/s]


Epoch: 078, Average loss: 0.10684
Epoch: 078, Average loss: 0.31987


100%|██████████| 762/762 [00:32<00:00, 23.58it/s]
100%|██████████| 109/109 [00:02<00:00, 49.11it/s]


Epoch: 079, Average loss: 0.10625
Epoch: 079, Average loss: 0.31099


100%|██████████| 762/762 [00:32<00:00, 23.69it/s]
100%|██████████| 109/109 [00:02<00:00, 48.85it/s]


Epoch: 080, Average loss: 0.11482
Epoch: 080, Average loss: 0.30940


 10%|▉         | 76/762 [00:03<00:30, 22.84it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)


Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.

100%|██████████| 762/762 [00:32<00:00, 23.53it/s]

Using a target size (torch.Size([3])) that is different to the input size (torch.Size([3, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.

100%|██████████| 109/109 [00:02<00:00, 48.83it/s]


Epoch: 000, Average loss: 0.35749
Epoch: 000, Average loss: 0.68610


100%|██████████| 762/762 [00:32<00:00, 23.49it/s]
100%|██████████| 109/109 [00:02<00:00, 47.51it/s]


Epoch: 001, Average loss: 0.16099
Epoch: 001, Average loss: 0.30857


100%|██████████| 762/762 [00:32<00:00, 23.52it/s]
100%|██████████| 109/109 [00:02<00:00, 43.71it/s]


Epoch: 002, Average loss: 0.10458
Epoch: 002, Average loss: 0.25907


100%|██████████| 762/762 [00:27<00:00, 27.79it/s]
100%|██████████| 109/109 [00:01<00:00, 68.72it/s]


Epoch: 003, Average loss: 0.11786
Epoch: 003, Average loss: 0.10373


100%|██████████| 762/762 [00:28<00:00, 26.64it/s]
100%|██████████| 109/109 [00:02<00:00, 47.08it/s]


Epoch: 004, Average loss: 0.08672
Epoch: 004, Average loss: 0.26308


100%|██████████| 762/762 [00:32<00:00, 23.11it/s]
100%|██████████| 109/109 [00:02<00:00, 44.16it/s]


Epoch: 005, Average loss: 0.06738
Epoch: 005, Average loss: 0.20685


100%|██████████| 762/762 [00:32<00:00, 23.81it/s]
100%|██████████| 109/109 [00:02<00:00, 48.18it/s]


Epoch: 006, Average loss: 0.06805
Epoch: 006, Average loss: 0.17360


100%|██████████| 762/762 [00:32<00:00, 23.50it/s]
100%|██████████| 109/109 [00:02<00:00, 43.76it/s]


Epoch: 007, Average loss: 0.07780
Epoch: 007, Average loss: 0.13118


100%|██████████| 762/762 [00:32<00:00, 23.71it/s]
100%|██████████| 109/109 [00:02<00:00, 48.19it/s]


Epoch: 008, Average loss: 0.05941
Epoch: 008, Average loss: 0.35054


100%|██████████| 762/762 [00:32<00:00, 23.41it/s]
100%|██████████| 109/109 [00:02<00:00, 48.88it/s]


Epoch: 009, Average loss: 0.05695
Epoch: 009, Average loss: 0.09533


100%|██████████| 762/762 [00:32<00:00, 23.40it/s]
100%|██████████| 109/109 [00:02<00:00, 49.30it/s]


Epoch: 010, Average loss: 0.06506
Epoch: 010, Average loss: 0.12578


100%|██████████| 762/762 [00:32<00:00, 23.61it/s]
100%|██████████| 109/109 [00:02<00:00, 47.05it/s]


Epoch: 011, Average loss: 0.05056
Epoch: 011, Average loss: 0.10020


 92%|█████████▏| 703/762 [00:29<00:02, 25.83it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:31<00:00, 23.85it/s]
100%|██████████| 109/109 [00:02<00:00, 47.43it/s]


Epoch: 032, Average loss: 0.03176
Epoch: 032, Average loss: 0.14064


100%|██████████| 762/762 [00:32<00:00, 23.77it/s]
100%|██████████| 109/109 [00:02<00:00, 48.18it/s]


Epoch: 033, Average loss: 1082.23499
Epoch: 033, Average loss: 0.50125


100%|██████████| 762/762 [00:31<00:00, 23.94it/s]
100%|██████████| 109/109 [00:02<00:00, 48.11it/s]


Epoch: 034, Average loss: 0.20582
Epoch: 034, Average loss: 0.69599


100%|██████████| 762/762 [00:32<00:00, 23.61it/s]
100%|██████████| 109/109 [00:02<00:00, 49.18it/s]


Epoch: 035, Average loss: 0.12854
Epoch: 035, Average loss: 0.15496


100%|██████████| 762/762 [00:32<00:00, 23.44it/s]
100%|██████████| 109/109 [00:02<00:00, 48.79it/s]


Epoch: 036, Average loss: 0.06809
Epoch: 036, Average loss: 0.25144


100%|██████████| 762/762 [00:32<00:00, 23.81it/s]
100%|██████████| 109/109 [00:02<00:00, 43.14it/s]


Epoch: 037, Average loss: 0.06801
Epoch: 037, Average loss: 0.11350


100%|██████████| 762/762 [00:31<00:00, 23.83it/s]
100%|██████████| 109/109 [00:02<00:00, 49.32it/s]


Epoch: 038, Average loss: 0.05411
Epoch: 038, Average loss: 0.18417


100%|██████████| 762/762 [00:32<00:00, 23.77it/s]
100%|██████████| 109/109 [00:02<00:00, 47.25it/s]


Epoch: 039, Average loss: 0.04599
Epoch: 039, Average loss: 0.09109


100%|██████████| 762/762 [00:32<00:00, 23.50it/s]
100%|██████████| 109/109 [00:02<00:00, 48.62it/s]


Epoch: 040, Average loss: 0.06294
Epoch: 040, Average loss: 0.13198


100%|██████████| 762/762 [00:32<00:00, 23.56it/s]
100%|██████████| 109/109 [00:02<00:00, 44.51it/s]


Epoch: 041, Average loss: 0.07062
Epoch: 041, Average loss: 0.30993


100%|██████████| 762/762 [00:32<00:00, 23.61it/s]
100%|██████████| 109/109 [00:02<00:00, 48.01it/s]


Epoch: 042, Average loss: 0.04870
Epoch: 042, Average loss: 0.15540


100%|██████████| 762/762 [00:32<00:00, 23.55it/s]
100%|██████████| 109/109 [00:02<00:00, 46.48it/s]


Epoch: 043, Average loss: 0.03920
Epoch: 043, Average loss: 0.07530


 49%|████▉     | 377/762 [00:16<00:17, 21.65it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.70it/s]
100%|██████████| 109/109 [00:02<00:00, 45.09it/s]


Epoch: 063, Average loss: 0.11568
Epoch: 063, Average loss: 0.29158


100%|██████████| 762/762 [00:31<00:00, 23.83it/s]
100%|██████████| 109/109 [00:02<00:00, 48.99it/s]


Epoch: 064, Average loss: 0.11768
Epoch: 064, Average loss: 0.29552


100%|██████████| 762/762 [00:32<00:00, 23.65it/s]
100%|██████████| 109/109 [00:02<00:00, 48.27it/s]


Epoch: 065, Average loss: 0.11640
Epoch: 065, Average loss: 0.29541


100%|██████████| 762/762 [00:27<00:00, 27.56it/s]
100%|██████████| 109/109 [00:01<00:00, 62.60it/s]


Epoch: 066, Average loss: 0.11102
Epoch: 066, Average loss: 0.29319


100%|██████████| 762/762 [00:29<00:00, 26.14it/s]
100%|██████████| 109/109 [00:02<00:00, 44.09it/s]


Epoch: 067, Average loss: 0.11319
Epoch: 067, Average loss: 0.29493


100%|██████████| 762/762 [00:32<00:00, 23.09it/s]
100%|██████████| 109/109 [00:02<00:00, 48.36it/s]


Epoch: 068, Average loss: 0.10844
Epoch: 068, Average loss: 0.29197


100%|██████████| 762/762 [00:32<00:00, 23.34it/s]
100%|██████████| 109/109 [00:02<00:00, 43.22it/s]


Epoch: 069, Average loss: 0.11047
Epoch: 069, Average loss: 0.29517


100%|██████████| 762/762 [00:32<00:00, 23.66it/s]
100%|██████████| 109/109 [00:02<00:00, 45.38it/s]


Epoch: 070, Average loss: 0.11097
Epoch: 070, Average loss: 0.29514


100%|██████████| 762/762 [00:32<00:00, 23.50it/s]
100%|██████████| 109/109 [00:02<00:00, 44.60it/s]


Epoch: 071, Average loss: 0.11171
Epoch: 071, Average loss: 0.29177


100%|██████████| 762/762 [00:32<00:00, 23.48it/s]
100%|██████████| 109/109 [00:02<00:00, 49.17it/s]


Epoch: 072, Average loss: 0.11176
Epoch: 072, Average loss: 0.29381


100%|██████████| 762/762 [00:32<00:00, 23.44it/s]
100%|██████████| 109/109 [00:02<00:00, 43.43it/s]


Epoch: 073, Average loss: 0.10899
Epoch: 073, Average loss: 0.29364


100%|██████████| 762/762 [00:32<00:00, 23.51it/s]
100%|██████████| 109/109 [00:02<00:00, 46.93it/s]


Epoch: 074, Average loss: 0.11604
Epoch: 074, Average loss: 0.29672


 48%|████▊     | 366/762 [00:15<00:15, 25.15it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.75it/s]
100%|██████████| 109/109 [00:02<00:00, 49.16it/s]


Epoch: 094, Average loss: 0.11209
Epoch: 094, Average loss: 0.29239


100%|██████████| 762/762 [00:32<00:00, 23.68it/s]
100%|██████████| 109/109 [00:02<00:00, 43.67it/s]


Epoch: 095, Average loss: 0.10939
Epoch: 095, Average loss: 0.29251


100%|██████████| 762/762 [00:31<00:00, 23.96it/s]
100%|██████████| 109/109 [00:02<00:00, 46.22it/s]


Epoch: 096, Average loss: 0.11206
Epoch: 096, Average loss: 0.29255


100%|██████████| 762/762 [00:32<00:00, 23.74it/s]
100%|██████████| 109/109 [00:02<00:00, 46.50it/s]


Epoch: 097, Average loss: 0.10905
Epoch: 097, Average loss: 0.29270


100%|██████████| 762/762 [00:32<00:00, 23.75it/s]
100%|██████████| 109/109 [00:02<00:00, 47.08it/s]


Epoch: 098, Average loss: 0.10887
Epoch: 098, Average loss: 0.29109


100%|██████████| 762/762 [00:32<00:00, 23.80it/s]
100%|██████████| 109/109 [00:02<00:00, 46.74it/s]


Epoch: 099, Average loss: 0.11073
Epoch: 099, Average loss: 0.29270



Using a target size (torch.Size([4])) that is different to the input size (torch.Size([4, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.


Using a target size (torch.Size([1])) that is different to the input size (torch.Size([1, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.

100%|██████████| 762/762 [00:31<00:00, 23.92it/s]

Using a target size (torch.Size([3])) that is different to the input size (torch.Size([3, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.

100%|██████████| 109/109 [00:02<00:00, 42.47it/s]


Epoch: 000, Average loss: 0.31523
Epoch: 000, Average loss: 0.31174


100%|██████████| 762/762 [00:32<00:00, 23.72it/s]
100%|██████████| 109/109 [00:02<00:00, 48.68it/s]


Epoch: 001, Average loss: 0.11442
Epoch: 001, Average loss: 0.35797


100%|██████████| 762/762 [00:32<00:00, 23.61it/s]
100%|██████████| 109/109 [00:02<00:00, 48.12it/s]


Epoch: 002, Average loss: 0.08030
Epoch: 002, Average loss: 0.19793


100%|██████████| 762/762 [00:31<00:00, 23.90it/s]
100%|██████████| 109/109 [00:02<00:00, 46.93it/s]


Epoch: 003, Average loss: 0.06145
Epoch: 003, Average loss: 0.21568


100%|██████████| 762/762 [00:31<00:00, 23.86it/s]
100%|██████████| 109/109 [00:02<00:00, 48.15it/s]


Epoch: 004, Average loss: 0.05996
Epoch: 004, Average loss: 0.25375


100%|██████████| 762/762 [00:32<00:00, 23.70it/s]
100%|██████████| 109/109 [00:02<00:00, 47.33it/s]


Epoch: 005, Average loss: 0.05783
Epoch: 005, Average loss: 0.07205


 43%|████▎     | 329/762 [00:13<00:16, 25.47it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.72it/s]
100%|██████████| 109/109 [00:02<00:00, 46.82it/s]


Epoch: 019, Average loss: 0.20650
Epoch: 019, Average loss: 0.32042


100%|██████████| 762/762 [00:32<00:00, 23.42it/s]
100%|██████████| 109/109 [00:02<00:00, 47.52it/s]


Epoch: 020, Average loss: 0.19906
Epoch: 020, Average loss: 0.32295


100%|██████████| 762/762 [00:32<00:00, 23.37it/s]
100%|██████████| 109/109 [00:02<00:00, 48.91it/s]


Epoch: 021, Average loss: 0.18713
Epoch: 021, Average loss: 0.33234


100%|██████████| 762/762 [00:32<00:00, 23.39it/s]
100%|██████████| 109/109 [00:02<00:00, 48.39it/s]


Epoch: 022, Average loss: 0.17698
Epoch: 022, Average loss: 0.36231


100%|██████████| 762/762 [00:32<00:00, 23.15it/s]
100%|██████████| 109/109 [00:02<00:00, 47.99it/s]


Epoch: 023, Average loss: 0.16580
Epoch: 023, Average loss: 0.34555


 24%|██▎       | 180/762 [00:07<00:21, 26.47it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.42it/s]
100%|██████████| 109/109 [00:02<00:00, 48.25it/s]


Epoch: 030, Average loss: 0.12031
Epoch: 030, Average loss: 0.33967


100%|██████████| 762/762 [00:32<00:00, 23.47it/s]
100%|██████████| 109/109 [00:02<00:00, 48.47it/s]


Epoch: 031, Average loss: 0.11658
Epoch: 031, Average loss: 0.33055


100%|██████████| 762/762 [00:32<00:00, 23.53it/s]
100%|██████████| 109/109 [00:02<00:00, 45.07it/s]


Epoch: 032, Average loss: 0.11516
Epoch: 032, Average loss: 0.33079


100%|██████████| 762/762 [00:32<00:00, 23.61it/s]
100%|██████████| 109/109 [00:02<00:00, 49.02it/s]


Epoch: 033, Average loss: 0.11227
Epoch: 033, Average loss: 0.36095


100%|██████████| 762/762 [00:32<00:00, 23.54it/s]
100%|██████████| 109/109 [00:02<00:00, 41.87it/s]


Epoch: 034, Average loss: 0.11232
Epoch: 034, Average loss: 0.38274


100%|██████████| 762/762 [00:32<00:00, 23.49it/s]
100%|██████████| 109/109 [00:02<00:00, 48.55it/s]


Epoch: 035, Average loss: 0.11176
Epoch: 035, Average loss: 0.33477


100%|██████████| 762/762 [00:32<00:00, 23.62it/s]
100%|██████████| 109/109 [00:02<00:00, 43.16it/s]


Epoch: 036, Average loss: 0.10360
Epoch: 036, Average loss: 0.34587


 60%|██████    | 459/762 [00:19<00:11, 27.54it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:30<00:00, 24.92it/s]
100%|██████████| 109/109 [00:02<00:00, 48.23it/s]


Epoch: 050, Average loss: 0.10770
Epoch: 050, Average loss: 0.33683


100%|██████████| 762/762 [00:32<00:00, 23.79it/s]
100%|██████████| 109/109 [00:02<00:00, 48.28it/s]


Epoch: 051, Average loss: 0.10987
Epoch: 051, Average loss: 0.33590


100%|██████████| 762/762 [00:32<00:00, 23.52it/s]
100%|██████████| 109/109 [00:02<00:00, 44.62it/s]


Epoch: 052, Average loss: 0.10525
Epoch: 052, Average loss: 0.36954


100%|██████████| 762/762 [00:32<00:00, 23.68it/s]
100%|██████████| 109/109 [00:02<00:00, 47.79it/s]


Epoch: 053, Average loss: 0.11128
Epoch: 053, Average loss: 0.35991


100%|██████████| 762/762 [00:32<00:00, 23.61it/s]
100%|██████████| 109/109 [00:02<00:00, 49.42it/s]


Epoch: 054, Average loss: 0.10297
Epoch: 054, Average loss: 0.35996


 11%|█▏        | 87/762 [00:03<00:32, 20.73it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.81it/s]
100%|██████████| 109/109 [00:02<00:00, 48.67it/s]


Epoch: 061, Average loss: 0.10935
Epoch: 061, Average loss: 0.35326


100%|██████████| 762/762 [00:31<00:00, 23.91it/s]
100%|██████████| 109/109 [00:02<00:00, 43.54it/s]


Epoch: 062, Average loss: 0.10690
Epoch: 062, Average loss: 0.35184


100%|██████████| 762/762 [00:32<00:00, 23.68it/s]
100%|██████████| 109/109 [00:02<00:00, 47.49it/s]


Epoch: 063, Average loss: 0.10624
Epoch: 063, Average loss: 0.33555


100%|██████████| 762/762 [00:32<00:00, 23.81it/s]
100%|██████████| 109/109 [00:02<00:00, 48.59it/s]


Epoch: 064, Average loss: 0.10918
Epoch: 064, Average loss: 0.33938


100%|██████████| 762/762 [00:31<00:00, 23.85it/s]
100%|██████████| 109/109 [00:02<00:00, 48.24it/s]


Epoch: 065, Average loss: 0.10560
Epoch: 065, Average loss: 0.35449


100%|██████████| 762/762 [00:31<00:00, 23.87it/s]
100%|██████████| 109/109 [00:02<00:00, 47.01it/s]


Epoch: 066, Average loss: 0.11009
Epoch: 066, Average loss: 0.35682


100%|██████████| 762/762 [00:31<00:00, 23.87it/s]
100%|██████████| 109/109 [00:02<00:00, 46.73it/s]


Epoch: 067, Average loss: 0.10830
Epoch: 067, Average loss: 0.33697


 42%|████▏     | 321/762 [00:13<00:19, 22.80it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.58it/s]
100%|██████████| 109/109 [00:02<00:00, 44.78it/s]


Epoch: 081, Average loss: 0.10731
Epoch: 081, Average loss: 0.34456


100%|██████████| 762/762 [00:31<00:00, 23.87it/s]
100%|██████████| 109/109 [00:02<00:00, 45.61it/s]


Epoch: 082, Average loss: 0.10355
Epoch: 082, Average loss: 0.33766


100%|██████████| 762/762 [00:32<00:00, 23.60it/s]
100%|██████████| 109/109 [00:02<00:00, 42.26it/s]


Epoch: 083, Average loss: 0.10534
Epoch: 083, Average loss: 0.34282


100%|██████████| 762/762 [00:32<00:00, 23.55it/s]
100%|██████████| 109/109 [00:02<00:00, 48.34it/s]


Epoch: 084, Average loss: 0.10735
Epoch: 084, Average loss: 0.33665


100%|██████████| 762/762 [00:32<00:00, 23.39it/s]
100%|██████████| 109/109 [00:02<00:00, 48.92it/s]


Epoch: 085, Average loss: 0.10612
Epoch: 085, Average loss: 0.36923


 11%|█▏        | 86/762 [00:03<00:31, 21.41it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:28<00:00, 27.02it/s]
100%|██████████| 109/109 [00:01<00:00, 55.61it/s]


Epoch: 092, Average loss: 0.10264
Epoch: 092, Average loss: 0.33791


100%|██████████| 762/762 [00:30<00:00, 24.86it/s]
100%|██████████| 109/109 [00:02<00:00, 45.30it/s]


Epoch: 093, Average loss: 0.10249
Epoch: 093, Average loss: 0.33373


100%|██████████| 762/762 [00:32<00:00, 23.66it/s]
100%|██████████| 109/109 [00:02<00:00, 47.48it/s]


Epoch: 094, Average loss: 0.10525
Epoch: 094, Average loss: 0.37104


100%|██████████| 762/762 [00:32<00:00, 23.62it/s]
100%|██████████| 109/109 [00:02<00:00, 47.57it/s]


Epoch: 095, Average loss: 0.10515
Epoch: 095, Average loss: 0.34021


100%|██████████| 762/762 [00:31<00:00, 23.87it/s]
100%|██████████| 109/109 [00:02<00:00, 48.17it/s]


Epoch: 096, Average loss: 0.10690
Epoch: 096, Average loss: 0.33983


100%|██████████| 762/762 [00:31<00:00, 23.82it/s]
100%|██████████| 109/109 [00:02<00:00, 46.21it/s]


Epoch: 097, Average loss: 0.11195
Epoch: 097, Average loss: 0.34485


100%|██████████| 762/762 [00:32<00:00, 23.62it/s]
100%|██████████| 109/109 [00:02<00:00, 47.64it/s]


Epoch: 098, Average loss: 0.10699
Epoch: 098, Average loss: 0.35832


 70%|███████   | 536/762 [00:22<00:09, 24.64it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:28<00:00, 26.47it/s]
100%|██████████| 109/109 [00:01<00:00, 62.02it/s]


Epoch: 012, Average loss: 0.04039
Epoch: 012, Average loss: 0.19664


100%|██████████| 762/762 [00:26<00:00, 28.69it/s]
100%|██████████| 109/109 [00:02<00:00, 49.09it/s]


Epoch: 013, Average loss: 0.03614
Epoch: 013, Average loss: 0.15727


100%|██████████| 762/762 [00:32<00:00, 23.57it/s]
100%|██████████| 109/109 [00:02<00:00, 43.21it/s]


Epoch: 014, Average loss: 0.03495
Epoch: 014, Average loss: 0.09424


100%|██████████| 762/762 [00:32<00:00, 23.59it/s]
100%|██████████| 109/109 [00:02<00:00, 48.29it/s]


Epoch: 015, Average loss: 0.03654
Epoch: 015, Average loss: 0.14905


100%|██████████| 762/762 [00:32<00:00, 23.70it/s]
100%|██████████| 109/109 [00:02<00:00, 49.52it/s]


Epoch: 016, Average loss: 0.03358
Epoch: 016, Average loss: 0.15251


 28%|██▊       | 216/762 [00:09<00:19, 27.91it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.42it/s]
100%|██████████| 109/109 [00:02<00:00, 44.92it/s]


Epoch: 023, Average loss: 0.09343
Epoch: 023, Average loss: 0.14523


100%|██████████| 762/762 [00:31<00:00, 23.83it/s]
100%|██████████| 109/109 [00:02<00:00, 44.44it/s]


Epoch: 024, Average loss: 0.08727
Epoch: 024, Average loss: 0.14017


100%|██████████| 762/762 [00:32<00:00, 23.44it/s]
100%|██████████| 109/109 [00:02<00:00, 48.45it/s]


Epoch: 025, Average loss: 0.06762
Epoch: 025, Average loss: 0.38868


100%|██████████| 762/762 [00:32<00:00, 23.53it/s]
100%|██████████| 109/109 [00:02<00:00, 48.96it/s]


Epoch: 026, Average loss: 0.06916
Epoch: 026, Average loss: 0.25916


100%|██████████| 762/762 [00:32<00:00, 23.66it/s]
100%|██████████| 109/109 [00:02<00:00, 48.22it/s]


Epoch: 027, Average loss: 0.08092
Epoch: 027, Average loss: 0.16978


100%|██████████| 762/762 [00:32<00:00, 23.52it/s]
100%|██████████| 109/109 [00:02<00:00, 43.82it/s]


Epoch: 028, Average loss: 0.05506
Epoch: 028, Average loss: 0.17676


100%|██████████| 762/762 [00:32<00:00, 23.56it/s]
100%|██████████| 109/109 [00:02<00:00, 46.80it/s]


Epoch: 029, Average loss: 0.05327
Epoch: 029, Average loss: 0.20858


 71%|███████   | 539/762 [00:23<00:11, 20.10it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.40it/s]
100%|██████████| 109/109 [00:02<00:00, 49.80it/s]


Epoch: 043, Average loss: 0.11212
Epoch: 043, Average loss: 0.30416


100%|██████████| 762/762 [00:32<00:00, 23.50it/s]
100%|██████████| 109/109 [00:02<00:00, 48.18it/s]


Epoch: 044, Average loss: 0.10796
Epoch: 044, Average loss: 0.30542


100%|██████████| 762/762 [00:32<00:00, 23.20it/s]
100%|██████████| 109/109 [00:02<00:00, 48.64it/s]


Epoch: 045, Average loss: 0.10997
Epoch: 045, Average loss: 0.30326


100%|██████████| 762/762 [00:32<00:00, 23.70it/s]
100%|██████████| 109/109 [00:02<00:00, 43.37it/s]


Epoch: 046, Average loss: 0.10927
Epoch: 046, Average loss: 0.30326


100%|██████████| 762/762 [00:32<00:00, 23.33it/s]
100%|██████████| 109/109 [00:02<00:00, 47.47it/s]


Epoch: 047, Average loss: 0.11259
Epoch: 047, Average loss: 0.31557


 16%|█▋        | 124/762 [00:05<00:26, 24.17it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:28<00:00, 26.31it/s]
100%|██████████| 109/109 [00:01<00:00, 63.06it/s]


Epoch: 054, Average loss: 0.11498
Epoch: 054, Average loss: 0.30421


100%|██████████| 762/762 [00:26<00:00, 28.32it/s]
100%|██████████| 109/109 [00:02<00:00, 45.89it/s]


Epoch: 055, Average loss: 0.11195
Epoch: 055, Average loss: 0.31678


100%|██████████| 762/762 [00:32<00:00, 23.54it/s]
100%|██████████| 109/109 [00:02<00:00, 44.86it/s]


Epoch: 056, Average loss: 0.11324
Epoch: 056, Average loss: 0.30550


100%|██████████| 762/762 [00:32<00:00, 23.46it/s]
100%|██████████| 109/109 [00:02<00:00, 49.77it/s]


Epoch: 057, Average loss: 0.11222
Epoch: 057, Average loss: 0.31793


100%|██████████| 762/762 [00:32<00:00, 23.77it/s]
100%|██████████| 109/109 [00:02<00:00, 48.66it/s]


Epoch: 058, Average loss: 0.10195
Epoch: 058, Average loss: 0.31760


100%|██████████| 762/762 [00:32<00:00, 23.75it/s]
100%|██████████| 109/109 [00:02<00:00, 44.80it/s]


Epoch: 059, Average loss: 0.11284
Epoch: 059, Average loss: 0.31751


100%|██████████| 762/762 [00:32<00:00, 23.47it/s]
100%|██████████| 109/109 [00:02<00:00, 45.13it/s]


Epoch: 060, Average loss: 0.11196
Epoch: 060, Average loss: 0.31754


 74%|███████▍  | 565/762 [00:24<00:07, 26.25it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.58it/s]
100%|██████████| 109/109 [00:02<00:00, 46.30it/s]


Epoch: 073, Average loss: 0.10798
Epoch: 073, Average loss: 0.31568


100%|██████████| 762/762 [00:32<00:00, 23.58it/s]
100%|██████████| 109/109 [00:02<00:00, 48.70it/s]


Epoch: 074, Average loss: 0.10695
Epoch: 074, Average loss: 0.30723


100%|██████████| 762/762 [00:28<00:00, 26.67it/s]
100%|██████████| 109/109 [00:01<00:00, 65.23it/s]


Epoch: 075, Average loss: 0.10720
Epoch: 075, Average loss: 0.30483


100%|██████████| 762/762 [00:29<00:00, 26.15it/s]
100%|██████████| 109/109 [00:01<00:00, 64.55it/s]


Epoch: 076, Average loss: 0.10543
Epoch: 076, Average loss: 0.31531


100%|██████████| 762/762 [00:30<00:00, 24.74it/s]
100%|██████████| 109/109 [00:02<00:00, 43.61it/s]


Epoch: 077, Average loss: 0.11003
Epoch: 077, Average loss: 0.31382


100%|██████████| 762/762 [00:32<00:00, 23.57it/s]
100%|██████████| 109/109 [00:02<00:00, 48.77it/s]


Epoch: 078, Average loss: 0.10687
Epoch: 078, Average loss: 0.31162


  3%|▎         | 21/762 [00:00<00:29, 25.10it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.53it/s]
100%|██████████| 109/109 [00:02<00:00, 48.96it/s]


Epoch: 085, Average loss: 0.10669
Epoch: 085, Average loss: 0.31171


100%|██████████| 762/762 [00:32<00:00, 23.48it/s]
100%|██████████| 109/109 [00:02<00:00, 45.14it/s]


Epoch: 086, Average loss: 0.10509
Epoch: 086, Average loss: 0.31146


100%|██████████| 762/762 [00:32<00:00, 23.29it/s]
100%|██████████| 109/109 [00:02<00:00, 49.28it/s]


Epoch: 087, Average loss: 0.10410
Epoch: 087, Average loss: 0.30476


100%|██████████| 762/762 [00:32<00:00, 23.50it/s]
100%|██████████| 109/109 [00:02<00:00, 47.72it/s]


Epoch: 088, Average loss: 0.10597
Epoch: 088, Average loss: 0.30348


100%|██████████| 762/762 [00:32<00:00, 23.46it/s]
100%|██████████| 109/109 [00:02<00:00, 43.73it/s]


Epoch: 089, Average loss: 0.10508
Epoch: 089, Average loss: 0.32257


100%|██████████| 762/762 [00:32<00:00, 23.63it/s]
100%|██████████| 109/109 [00:02<00:00, 46.28it/s]


Epoch: 090, Average loss: 0.11042
Epoch: 090, Average loss: 0.30458


100%|██████████| 762/762 [00:32<00:00, 23.36it/s]
100%|██████████| 109/109 [00:02<00:00, 46.36it/s]


Epoch: 091, Average loss: 0.10746
Epoch: 091, Average loss: 0.31144


 46%|████▌     | 347/762 [00:14<00:17, 23.75it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.43it/s]
100%|██████████| 109/109 [00:02<00:00, 49.16it/s]


Epoch: 004, Average loss: 0.06773
Epoch: 004, Average loss: 0.18886


100%|██████████| 762/762 [00:32<00:00, 23.28it/s]
100%|██████████| 109/109 [00:02<00:00, 47.84it/s]


Epoch: 005, Average loss: 0.06494
Epoch: 005, Average loss: 0.20151


100%|██████████| 762/762 [00:32<00:00, 23.60it/s]
100%|██████████| 109/109 [00:02<00:00, 48.64it/s]


Epoch: 006, Average loss: 0.05015
Epoch: 006, Average loss: 0.08905


100%|██████████| 762/762 [00:32<00:00, 23.59it/s]
100%|██████████| 109/109 [00:02<00:00, 46.44it/s]


Epoch: 007, Average loss: 0.05645
Epoch: 007, Average loss: 0.11456


100%|██████████| 762/762 [00:32<00:00, 23.42it/s]
100%|██████████| 109/109 [00:02<00:00, 46.69it/s]


Epoch: 008, Average loss: 0.04499
Epoch: 008, Average loss: 0.17990


100%|██████████| 762/762 [00:32<00:00, 23.42it/s]
100%|██████████| 109/109 [00:02<00:00, 48.64it/s]


Epoch: 009, Average loss: 0.04742
Epoch: 009, Average loss: 0.16177


100%|██████████| 762/762 [00:32<00:00, 23.61it/s]
100%|██████████| 109/109 [00:02<00:00, 47.16it/s]


Epoch: 010, Average loss: 0.03615
Epoch: 010, Average loss: 0.10501


100%|██████████| 762/762 [00:32<00:00, 23.17it/s]
100%|██████████| 109/109 [00:02<00:00, 48.66it/s]


Epoch: 011, Average loss: 0.03767
Epoch: 011, Average loss: 0.09073


100%|██████████| 762/762 [00:32<00:00, 23.44it/s]
100%|██████████| 109/109 [00:02<00:00, 49.09it/s]


Epoch: 012, Average loss: 0.03855
Epoch: 012, Average loss: 0.15463


100%|██████████| 762/762 [00:32<00:00, 23.75it/s]
100%|██████████| 109/109 [00:02<00:00, 44.57it/s]


Epoch: 013, Average loss: 0.03788
Epoch: 013, Average loss: 0.08303


100%|██████████| 762/762 [00:32<00:00, 23.68it/s]
100%|██████████| 109/109 [00:02<00:00, 50.18it/s]


Epoch: 014, Average loss: 0.03870
Epoch: 014, Average loss: 0.08555


100%|██████████| 762/762 [00:32<00:00, 23.56it/s]
100%|██████████| 109/109 [00:02<00:00, 43.95it/s]


Epoch: 015, Average loss: 0.03419
Epoch: 015, Average loss: 0.11145


 90%|█████████ | 689/762 [00:29<00:02, 30.08it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.35it/s]
100%|██████████| 109/109 [00:02<00:00, 45.28it/s]


Epoch: 035, Average loss: 0.14260
Epoch: 035, Average loss: 0.32972


100%|██████████| 762/762 [00:32<00:00, 23.37it/s]
100%|██████████| 109/109 [00:02<00:00, 45.40it/s]


Epoch: 036, Average loss: 0.13252
Epoch: 036, Average loss: 0.36392


100%|██████████| 762/762 [00:29<00:00, 26.21it/s]
100%|██████████| 109/109 [00:01<00:00, 68.01it/s]


Epoch: 037, Average loss: 0.12335
Epoch: 037, Average loss: 0.34101


100%|██████████| 762/762 [00:28<00:00, 26.74it/s]
100%|██████████| 109/109 [00:02<00:00, 50.74it/s]


Epoch: 038, Average loss: 0.12152
Epoch: 038, Average loss: 0.34281


100%|██████████| 762/762 [00:32<00:00, 23.54it/s]
100%|██████████| 109/109 [00:02<00:00, 43.57it/s]


Epoch: 039, Average loss: 0.11606
Epoch: 039, Average loss: 0.33307


100%|██████████| 762/762 [00:32<00:00, 23.54it/s]
100%|██████████| 109/109 [00:02<00:00, 43.87it/s]


Epoch: 040, Average loss: 0.11856
Epoch: 040, Average loss: 0.33876


100%|██████████| 762/762 [00:32<00:00, 23.55it/s]
100%|██████████| 109/109 [00:02<00:00, 48.27it/s]


Epoch: 041, Average loss: 0.11641
Epoch: 041, Average loss: 0.33764


100%|██████████| 762/762 [00:32<00:00, 23.16it/s]
100%|██████████| 109/109 [00:02<00:00, 47.10it/s]


Epoch: 042, Average loss: 0.11735
Epoch: 042, Average loss: 0.33340


100%|██████████| 762/762 [00:32<00:00, 23.34it/s]
100%|██████████| 109/109 [00:02<00:00, 47.52it/s]


Epoch: 043, Average loss: 0.11433
Epoch: 043, Average loss: 0.33543


100%|██████████| 762/762 [00:32<00:00, 23.44it/s]
100%|██████████| 109/109 [00:02<00:00, 48.56it/s]


Epoch: 044, Average loss: 0.11562
Epoch: 044, Average loss: 0.33085


100%|██████████| 762/762 [00:32<00:00, 23.38it/s]
100%|██████████| 109/109 [00:02<00:00, 46.38it/s]


Epoch: 045, Average loss: 0.11789
Epoch: 045, Average loss: 0.33801


100%|██████████| 762/762 [00:32<00:00, 23.27it/s]
100%|██████████| 109/109 [00:02<00:00, 48.72it/s]


Epoch: 046, Average loss: 0.11226
Epoch: 046, Average loss: 0.33116


 84%|████████▍ | 643/762 [00:27<00:05, 21.63it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.61it/s]
100%|██████████| 109/109 [00:01<00:00, 54.75it/s]


Epoch: 065, Average loss: 0.10676
Epoch: 065, Average loss: 0.35499


100%|██████████| 762/762 [00:32<00:00, 23.66it/s]
100%|██████████| 109/109 [00:02<00:00, 45.77it/s]


Epoch: 066, Average loss: 0.10411
Epoch: 066, Average loss: 0.36605


100%|██████████| 762/762 [00:32<00:00, 23.65it/s]
100%|██████████| 109/109 [00:02<00:00, 49.30it/s]


Epoch: 067, Average loss: 0.10845
Epoch: 067, Average loss: 0.36661


100%|██████████| 762/762 [00:32<00:00, 23.50it/s]
100%|██████████| 109/109 [00:02<00:00, 49.83it/s]


Epoch: 068, Average loss: 0.10813
Epoch: 068, Average loss: 0.35591


100%|██████████| 762/762 [00:32<00:00, 23.47it/s]
100%|██████████| 109/109 [00:02<00:00, 45.17it/s]


Epoch: 069, Average loss: 0.11001
Epoch: 069, Average loss: 0.35213


100%|██████████| 762/762 [00:32<00:00, 23.31it/s]
100%|██████████| 109/109 [00:02<00:00, 49.29it/s]


Epoch: 070, Average loss: 0.10888
Epoch: 070, Average loss: 0.34945


100%|██████████| 762/762 [00:32<00:00, 23.34it/s]
100%|██████████| 109/109 [00:02<00:00, 48.99it/s]


Epoch: 071, Average loss: 0.11409
Epoch: 071, Average loss: 0.36524


100%|██████████| 762/762 [00:32<00:00, 23.66it/s]
100%|██████████| 109/109 [00:02<00:00, 46.44it/s]


Epoch: 072, Average loss: 0.10828
Epoch: 072, Average loss: 0.33410


100%|██████████| 762/762 [00:32<00:00, 23.64it/s]
100%|██████████| 109/109 [00:02<00:00, 45.18it/s]


Epoch: 073, Average loss: 0.10912
Epoch: 073, Average loss: 0.34489


100%|██████████| 762/762 [00:32<00:00, 23.48it/s]
100%|██████████| 109/109 [00:02<00:00, 45.85it/s]


Epoch: 074, Average loss: 0.10860
Epoch: 074, Average loss: 0.33193


100%|██████████| 762/762 [00:31<00:00, 23.87it/s]
100%|██████████| 109/109 [00:02<00:00, 43.82it/s]


Epoch: 075, Average loss: 0.10518
Epoch: 075, Average loss: 0.34120


100%|██████████| 762/762 [00:32<00:00, 23.63it/s]
100%|██████████| 109/109 [00:02<00:00, 44.55it/s]


Epoch: 076, Average loss: 0.10792
Epoch: 076, Average loss: 0.34589


 94%|█████████▎| 713/762 [00:30<00:01, 25.17it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 762/762 [00:32<00:00, 23.58it/s]
100%|██████████| 109/109 [00:02<00:00, 49.25it/s]


Epoch: 096, Average loss: 0.10718
Epoch: 096, Average loss: 0.33190


100%|██████████| 762/762 [00:32<00:00, 23.41it/s]
100%|██████████| 109/109 [00:02<00:00, 49.17it/s]


Epoch: 097, Average loss: 0.11068
Epoch: 097, Average loss: 0.34512


100%|██████████| 762/762 [00:32<00:00, 23.50it/s]
100%|██████████| 109/109 [00:02<00:00, 46.77it/s]


Epoch: 098, Average loss: 0.10822
Epoch: 098, Average loss: 0.34916


100%|██████████| 762/762 [00:32<00:00, 23.39it/s]
100%|██████████| 109/109 [00:02<00:00, 46.93it/s]

Epoch: 099, Average loss: 0.10737
Epoch: 099, Average loss: 0.37682





In [55]:
maes

[0.32182802721311193,
 0.33955699766755376,
 0.3224277138675845,
 0.29270045722433186,
 0.33688947923242224,
 0.31148186747685347,
 0.3768232881920327]

In [56]:
np.mean(maes)

0.3288154044105558

In [58]:
np.std(maes)

0.02443561907944912