In [2]:
import pandas as pd
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
ercot_2022_actual_load_df = pd.read_csv('downloaded/20220101-20221212 ERCOT Actual Load.csv')
ercot_2022_actual_load_df.head(), ercot_2022_actual_load_df.tail()

(                   Date      Load
 0  1/1/2022 12:00:00 AM  38145.79
 1   1/1/2022 1:00:00 AM  37158.13
 2   1/1/2022 2:00:00 AM  35966.24
 3   1/1/2022 3:00:00 AM  35148.96
 4   1/1/2022 4:00:00 AM  34610.33,
                         Date      Load
 8299   12/12/2022 7:00:00 PM  46775.92
 8300   12/12/2022 8:00:00 PM  46217.27
 8301   12/12/2022 9:00:00 PM  44998.13
 8302  12/12/2022 10:00:00 PM  42774.80
 8303  12/12/2022 11:00:00 PM  40435.58)

In [42]:

def train_test_split(df, test_size=0.2):
    train_size = int(len(df) * (1 - test_size))
    train_set = df[:train_size]
    test_set = df[train_size:]
    return train_set, test_set

def train_test_split_by_date(df):
    train_set = df[df['Date'] < '2022-10-01']
    test_set = df[df['Date'] >= '2022-10-01']
    return train_set, test_set
    

train_set, test_set = train_test_split_by_date(ercot_2022_actual_load_df)
train_set.shape, test_set.shape

((3169, 2), (5135, 2))

In [43]:
class TimeseriesDataset(torch.utils.data.Dataset):   
    def __init__(self, X, y, seq_len=1):
        self.X = X
        self.y = y
        self.seq_len = seq_len

    def __len__(self):
        return self.X.__len__() - (self.seq_len-1)

    def __getitem__(self, index):
        #print(index)
        a, b = self.X[index:index+self.seq_len], self.y[index+self.seq_len-1]
        #print(a.shape, b.shape)
        return(a, b)

In [44]:
def df_to_tensor(df):
    X = torch.tensor(df['Load'].values).float()
    return X
#print(test_set.values[1663])
x_time_series = df_to_tensor(train_set)
y_time_series = df_to_tensor(test_set)
x_time_series.shape, y_time_series.shape

(torch.Size([3169]), torch.Size([5135]))

In [45]:
seq_len = 73
batch_size = seq_len
train_dataset = TimeseriesDataset(x_time_series, y_time_series, seq_len=seq_len)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = False)


for nth_batch, (batch, _) in enumerate(train_dataset):
    print(f'Batch {nth_batch}:')
    print(f'batch shape: {batch.shape}')
    print(f'batch: {batch}')


Batch 0:
batch shape: torch.Size([73])
batch: tensor([38145.7891, 37158.1289, 35966.2383, 35148.9609, 34610.3281, 34450.7305,
        34568.6211, 34632.4883, 35562.5000, 37846.8984, 39935.8984, 41790.8516,
        43130.2383, 43991.5117, 44247.7891, 44208.9414, 44333.1016, 44944.0391,
        46102.6211, 45950.9688, 45760.7188, 45528.6016, 44827.8008, 43890.7188,
        42904.2812, 42178.6211, 42060.8203, 42491.2617, 43433.8203, 44828.0586,
        46790.8281, 49240.6094, 51317.8789, 52436.7500, 52321.4414, 51451.2617,
        50290.2109, 48566.2188, 47052.3398, 46143.6914, 46202.1484, 48623.6797,
        51632.0703, 52375.0391, 52327.2891, 52032.5703, 50742.8398, 49477.8008,
        48526.7695, 48268.9492, 48373.2617, 49039.6211, 50314.0195, 52550.2383,
        55561.1602, 57599.7695, 57072.8008, 54390.7305, 51564.8008, 48765.6289,
        46376.4297, 44457.5781, 43011.8086, 42364.5508, 42559.4219, 44827.1914,
        48052.6094, 49029.6719, 49358.1016, 48763.4102, 47228.4492, 45556.

In [46]:
import torch
import torch.nn as nn
dims = seq_len

class Transformer(nn.Module):
  def __init__(self, input_dim, output_dim, hidden_dim, num_layers, num_heads):
    super(Transformer, self).__init__()

    self.encoder = nn.TransformerEncoder(
      nn.TransformerEncoderLayer(input_dim, num_heads, hidden_dim),
      num_layers
    )

    self.decoder = nn.TransformerDecoder(
      nn.TransformerDecoderLayer(output_dim, num_heads, hidden_dim),
      num_layers
    )

    self.output_layer = nn.Linear(hidden_dim, output_dim)

  def forward(self, src, trg):
    context = self.encoder(src)
    output = self.decoder(trg, context)
    return self.output_layer(output)

Transformer = Transformer(input_dim=dims, output_dim=dims, hidden_dim=seq_len, num_layers=4, num_heads=seq_len)

In [47]:
from tqdm import tqdm
import os
def train_transformer(model, train_loader, epochs):
  criterion = nn.MSELoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

  for epoch in range(epochs):
    for batch, _ in tqdm(train_loader):
      #print(batch.shape)
      #batch2 = batch.transpose(-2, -1)

      #print(batch.shape)
      output = model(batch, batch)
      loss = criterion(output, batch)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}')


if os.path.exists('transformer_model.pt'):
  Transformer = torch.load('transformer_model.pt')
  
else:    
    train_transformer(Transformer, train_loader, epochs=100)
    torch.save(Transformer.state_dict(), 'transformer_model.pt')

  5%|▍         | 2/43 [00:00<00:02, 18.86it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 16%|█▋        | 7/43 [00:00<00:01, 20.10it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 30%|███       | 13/43 [00:00<00:01, 21.53it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 37%|███▋      | 16/43 [00:00<00:01, 23.25it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 51%|█████     | 22/43 [00:01<00:00, 22.31it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 60%|██████    | 26/43 [00:01<00:00, 23.96it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 74%|███████▍  | 32/43 [00:01<00:00, 21.90it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 88%|████████▊ | 38/43 [00:01<00:00, 21.64it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


100%|██████████| 43/43 [00:01<00:00, 21.87it/s]


torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([31, 73])
Epoch 1, Loss: 1708113920.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])


  7%|▋         | 3/43 [00:00<00:01, 20.88it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 21%|██        | 9/43 [00:00<00:01, 22.57it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 21.90it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 35%|███▍      | 15/43 [00:00<00:01, 21.58it/s]

torch.Size([73, 73])


 42%|████▏     | 18/43 [00:00<00:01, 21.53it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 56%|█████▌    | 24/43 [00:01<00:00, 21.37it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 63%|██████▎   | 27/43 [00:01<00:00, 22.57it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 70%|██████▉   | 30/43 [00:01<00:00, 21.93it/s]

torch.Size([73, 73])


 77%|███████▋  | 33/43 [00:01<00:00, 21.22it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 91%|█████████ | 39/43 [00:01<00:00, 21.52it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


100%|██████████| 43/43 [00:01<00:00, 21.71it/s]


torch.Size([73, 73])
torch.Size([31, 73])
Epoch 2, Loss: 1694287232.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  7%|▋         | 3/43 [00:00<00:01, 21.37it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 20.16it/s]

torch.Size([73, 73])


 21%|██        | 9/43 [00:00<00:01, 22.73it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 30%|███       | 13/43 [00:00<00:01, 26.83it/s]

torch.Size([73, 73])


 40%|███▉      | 17/43 [00:00<00:00, 29.42it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 49%|████▉     | 21/43 [00:00<00:00, 30.68it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 58%|█████▊    | 25/43 [00:00<00:00, 31.71it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 67%|██████▋   | 29/43 [00:00<00:00, 32.55it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 77%|███████▋  | 33/43 [00:01<00:00, 33.21it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 86%|████████▌ | 37/43 [00:01<00:00, 33.24it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 95%|█████████▌| 41/43 [00:01<00:00, 31.66it/s]

torch.Size([73, 73])


100%|██████████| 43/43 [00:01<00:00, 30.19it/s]


torch.Size([31, 73])
Epoch 3, Loss: 1669138560.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  9%|▉         | 4/43 [00:00<00:01, 33.14it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 19%|█▊        | 8/43 [00:00<00:01, 30.99it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 25.82it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 35%|███▍      | 15/43 [00:00<00:01, 23.82it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 42%|████▏     | 18/43 [00:00<00:01, 23.54it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 49%|████▉     | 21/43 [00:00<00:00, 22.43it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 56%|█████▌    | 24/43 [00:01<00:00, 22.20it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 63%|██████▎   | 27/43 [00:01<00:00, 21.59it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 70%|██████▉   | 30/43 [00:01<00:00, 21.61it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 77%|███████▋  | 33/43 [00:01<00:00, 23.22it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 86%|████████▌ | 37/43 [00:01<00:00, 26.20it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 93%|█████████▎| 40/43 [00:01<00:00, 25.49it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([31, 73])


100%|██████████| 43/43 [00:01<00:00, 24.19it/s]


Epoch 4, Loss: 1632581888.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  7%|▋         | 3/43 [00:00<00:01, 21.30it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 20.96it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 21%|██        | 9/43 [00:00<00:01, 20.86it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 20.76it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 37%|███▋      | 16/43 [00:00<00:01, 24.31it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 44%|████▍     | 19/43 [00:00<00:01, 22.70it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 51%|█████     | 22/43 [00:01<00:00, 21.87it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 58%|█████▊    | 25/43 [00:01<00:00, 21.25it/s]

torch.Size([73, 73])


 65%|██████▌   | 28/43 [00:01<00:00, 21.07it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 72%|███████▏  | 31/43 [00:01<00:00, 20.44it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 79%|███████▉  | 34/43 [00:01<00:00, 20.63it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 86%|████████▌ | 37/43 [00:01<00:00, 20.66it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 93%|█████████▎| 40/43 [00:01<00:00, 21.00it/s]

torch.Size([73, 73])


100%|██████████| 43/43 [00:02<00:00, 21.38it/s]


torch.Size([73, 73])
torch.Size([31, 73])
Epoch 5, Loss: 1584922624.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  7%|▋         | 3/43 [00:00<00:01, 25.58it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 22.32it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 21%|██        | 9/43 [00:00<00:01, 21.20it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 21.09it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 37%|███▋      | 16/43 [00:00<00:01, 24.72it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 47%|████▋     | 20/43 [00:00<00:00, 26.82it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 56%|█████▌    | 24/43 [00:00<00:00, 28.19it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 63%|██████▎   | 27/43 [00:01<00:00, 28.23it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 70%|██████▉   | 30/43 [00:01<00:00, 27.89it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 79%|███████▉  | 34/43 [00:01<00:00, 29.00it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 88%|████████▊ | 38/43 [00:01<00:00, 30.55it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 98%|█████████▊| 42/43 [00:01<00:00, 31.12it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([31, 73])


100%|██████████| 43/43 [00:01<00:00, 27.80it/s]


Epoch 6, Loss: 1527169920.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  7%|▋         | 3/43 [00:00<00:01, 28.81it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 29.44it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 23%|██▎       | 10/43 [00:00<00:01, 31.38it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 33%|███▎      | 14/43 [00:00<00:00, 31.63it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 42%|████▏     | 18/43 [00:00<00:00, 31.91it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 51%|█████     | 22/43 [00:00<00:00, 31.93it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 60%|██████    | 26/43 [00:00<00:00, 31.07it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 70%|██████▉   | 30/43 [00:00<00:00, 31.43it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 79%|███████▉  | 34/43 [00:01<00:00, 31.72it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 88%|████████▊ | 38/43 [00:01<00:00, 30.95it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 98%|█████████▊| 42/43 [00:01<00:00, 25.77it/s]

torch.Size([31, 73])


100%|██████████| 43/43 [00:01<00:00, 29.19it/s]


Epoch 7, Loss: 1460333696.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])


  5%|▍         | 2/43 [00:00<00:02, 19.04it/s]

torch.Size([73, 73])
torch.Size([73, 73])


  9%|▉         | 4/43 [00:00<00:02, 19.07it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 18.77it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 19%|█▊        | 8/43 [00:00<00:01, 18.69it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 23%|██▎       | 10/43 [00:00<00:01, 18.56it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 18.17it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 33%|███▎      | 14/43 [00:00<00:01, 18.23it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 37%|███▋      | 16/43 [00:00<00:01, 18.30it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 42%|████▏     | 18/43 [00:00<00:01, 18.57it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 47%|████▋     | 20/43 [00:01<00:01, 18.73it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 51%|█████     | 22/43 [00:01<00:01, 18.88it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 56%|█████▌    | 24/43 [00:01<00:01, 18.88it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 60%|██████    | 26/43 [00:01<00:00, 18.54it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 67%|██████▋   | 29/43 [00:01<00:00, 20.56it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 74%|███████▍  | 32/43 [00:01<00:00, 19.90it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 79%|███████▉  | 34/43 [00:01<00:00, 19.37it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 84%|████████▎ | 36/43 [00:01<00:00, 19.06it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 91%|█████████ | 39/43 [00:02<00:00, 21.36it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


100%|██████████| 43/43 [00:02<00:00, 20.05it/s]


torch.Size([73, 73])
torch.Size([31, 73])
Epoch 8, Loss: 1385633280.0000


  5%|▍         | 2/43 [00:00<00:02, 19.36it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 12%|█▏        | 5/43 [00:00<00:01, 22.68it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 19%|█▊        | 8/43 [00:00<00:01, 20.45it/s]

torch.Size([73, 73])


 26%|██▌       | 11/43 [00:00<00:01, 19.28it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 35%|███▍      | 15/43 [00:00<00:01, 18.78it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 44%|████▍     | 19/43 [00:00<00:01, 19.02it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 53%|█████▎    | 23/43 [00:01<00:01, 19.14it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 63%|██████▎   | 27/43 [00:01<00:00, 19.12it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 79%|███████▉  | 34/43 [00:01<00:00, 24.38it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 95%|█████████▌| 41/43 [00:01<00:00, 27.15it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


100%|██████████| 43/43 [00:01<00:00, 22.22it/s]


torch.Size([73, 73])
torch.Size([31, 73])
Epoch 9, Loss: 1304381824.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])


  7%|▋         | 3/43 [00:00<00:01, 29.52it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 24.11it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 21%|██        | 9/43 [00:00<00:01, 21.63it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 23.42it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 35%|███▍      | 15/43 [00:00<00:01, 22.62it/s]

torch.Size([73, 73])


 42%|████▏     | 18/43 [00:00<00:01, 23.95it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 49%|████▉     | 21/43 [00:00<00:00, 23.73it/s]

torch.Size([73, 73])


 56%|█████▌    | 24/43 [00:01<00:00, 23.35it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 63%|██████▎   | 27/43 [00:01<00:00, 21.76it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 70%|██████▉   | 30/43 [00:01<00:00, 20.61it/s]

torch.Size([73, 73])


 77%|███████▋  | 33/43 [00:01<00:00, 20.08it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 91%|█████████ | 39/43 [00:01<00:00, 21.31it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


100%|██████████| 43/43 [00:01<00:00, 22.00it/s]


torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([31, 73])
Epoch 10, Loss: 1217974016.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])


  5%|▍         | 2/43 [00:00<00:02, 19.12it/s]

torch.Size([73, 73])
torch.Size([73, 73])


  9%|▉         | 4/43 [00:00<00:02, 19.27it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 16%|█▋        | 7/43 [00:00<00:01, 19.36it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 21%|██        | 9/43 [00:00<00:01, 19.30it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 26%|██▌       | 11/43 [00:00<00:01, 19.26it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 30%|███       | 13/43 [00:00<00:01, 19.09it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 35%|███▍      | 15/43 [00:00<00:01, 18.83it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 40%|███▉      | 17/43 [00:00<00:01, 18.71it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 44%|████▍     | 19/43 [00:00<00:01, 18.87it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 49%|████▉     | 21/43 [00:01<00:01, 18.34it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 53%|█████▎    | 23/43 [00:01<00:01, 18.29it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 58%|█████▊    | 25/43 [00:01<00:00, 18.26it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 63%|██████▎   | 27/43 [00:01<00:00, 18.31it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 67%|██████▋   | 29/43 [00:01<00:00, 18.27it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 72%|███████▏  | 31/43 [00:01<00:00, 18.67it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 79%|███████▉  | 34/43 [00:01<00:00, 19.82it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 84%|████████▎ | 36/43 [00:01<00:00, 19.70it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 91%|█████████ | 39/43 [00:02<00:00, 21.82it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


100%|██████████| 43/43 [00:02<00:00, 19.84it/s]


torch.Size([73, 73])
torch.Size([31, 73])
Epoch 11, Loss: 1127836800.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])


  5%|▍         | 2/43 [00:00<00:02, 18.02it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  9%|▉         | 4/43 [00:00<00:02, 17.73it/s]

torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:02, 16.52it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 19%|█▊        | 8/43 [00:00<00:02, 17.13it/s]

torch.Size([73, 73])


 26%|██▌       | 11/43 [00:00<00:01, 21.13it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 35%|███▍      | 15/43 [00:00<00:01, 24.62it/s]

torch.Size([73, 73])


 42%|████▏     | 18/43 [00:00<00:00, 25.27it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 49%|████▉     | 21/43 [00:00<00:00, 22.43it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 56%|█████▌    | 24/43 [00:01<00:00, 21.10it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 63%|██████▎   | 27/43 [00:01<00:00, 20.37it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 70%|██████▉   | 30/43 [00:01<00:00, 19.74it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 77%|███████▋  | 33/43 [00:01<00:00, 20.87it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 84%|████████▎ | 36/43 [00:01<00:00, 20.92it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 91%|█████████ | 39/43 [00:01<00:00, 20.24it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


100%|██████████| 43/43 [00:02<00:00, 20.50it/s]


torch.Size([31, 73])
Epoch 12, Loss: 1035400000.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])


  5%|▍         | 2/43 [00:00<00:02, 18.50it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  9%|▉         | 4/43 [00:00<00:02, 19.01it/s]

torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 19.01it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 19%|█▊        | 8/43 [00:00<00:01, 18.99it/s]

torch.Size([73, 73])


 26%|██▌       | 11/43 [00:00<00:01, 21.66it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 33%|███▎      | 14/43 [00:00<00:01, 21.86it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 40%|███▉      | 17/43 [00:00<00:01, 21.92it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 47%|████▋     | 20/43 [00:00<00:01, 22.10it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 53%|█████▎    | 23/43 [00:01<00:00, 20.80it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 60%|██████    | 26/43 [00:01<00:00, 20.69it/s]

torch.Size([73, 73])


 67%|██████▋   | 29/43 [00:01<00:00, 20.73it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 74%|███████▍  | 32/43 [00:01<00:00, 20.95it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 81%|████████▏ | 35/43 [00:01<00:00, 20.13it/s]

torch.Size([73, 73])


 88%|████████▊ | 38/43 [00:01<00:00, 20.21it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 95%|█████████▌| 41/43 [00:02<00:00, 19.59it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([31, 73])


100%|██████████| 43/43 [00:02<00:00, 20.41it/s]


Epoch 13, Loss: 942065536.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])


  7%|▋         | 3/43 [00:00<00:01, 21.19it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 23.91it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 21%|██        | 9/43 [00:00<00:01, 25.12it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 23.10it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 35%|███▍      | 15/43 [00:00<00:01, 21.72it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 42%|████▏     | 18/43 [00:00<00:01, 23.36it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 49%|████▉     | 21/43 [00:00<00:00, 22.33it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 56%|█████▌    | 24/43 [00:01<00:00, 22.79it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 63%|██████▎   | 27/43 [00:01<00:00, 24.28it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 70%|██████▉   | 30/43 [00:01<00:00, 25.36it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 77%|███████▋  | 33/43 [00:01<00:00, 23.49it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 84%|████████▎ | 36/43 [00:01<00:00, 23.40it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 91%|█████████ | 39/43 [00:01<00:00, 24.20it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


100%|██████████| 43/43 [00:01<00:00, 24.17it/s]


torch.Size([31, 73])
Epoch 14, Loss: 849184832.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])


  7%|▋         | 3/43 [00:00<00:01, 25.92it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 26.76it/s]

torch.Size([73, 73])


 21%|██        | 9/43 [00:00<00:01, 27.33it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 27.63it/s]

torch.Size([73, 73])


 35%|███▍      | 15/43 [00:00<00:01, 27.38it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 42%|████▏     | 18/43 [00:00<00:00, 27.60it/s]

torch.Size([73, 73])


 49%|████▉     | 21/43 [00:00<00:00, 28.22it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 56%|█████▌    | 24/43 [00:00<00:00, 28.30it/s]

torch.Size([73, 73])


 63%|██████▎   | 27/43 [00:00<00:00, 28.06it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 70%|██████▉   | 30/43 [00:01<00:00, 28.60it/s]

torch.Size([73, 73])


 79%|███████▉  | 34/43 [00:01<00:00, 29.37it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 98%|█████████▊| 42/43 [00:01<00:00, 30.00it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([31, 73])


100%|██████████| 43/43 [00:01<00:00, 28.63it/s]


Epoch 15, Loss: 758033408.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])


  5%|▍         | 2/43 [00:00<00:02, 18.82it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  9%|▉         | 4/43 [00:00<00:02, 18.99it/s]

torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 18.79it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 19%|█▊        | 8/43 [00:00<00:01, 18.51it/s]

torch.Size([73, 73])


 23%|██▎       | 10/43 [00:00<00:01, 18.39it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 18.52it/s]

torch.Size([73, 73])


 33%|███▎      | 14/43 [00:00<00:01, 18.52it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 37%|███▋      | 16/43 [00:00<00:01, 18.44it/s]

torch.Size([73, 73])


 42%|████▏     | 18/43 [00:00<00:01, 18.19it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 47%|████▋     | 20/43 [00:01<00:01, 18.12it/s]

torch.Size([73, 73])


 51%|█████     | 22/43 [00:01<00:01, 18.32it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 56%|█████▌    | 24/43 [00:01<00:01, 18.58it/s]

torch.Size([73, 73])


 60%|██████    | 26/43 [00:01<00:00, 18.88it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 65%|██████▌   | 28/43 [00:01<00:00, 19.01it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 72%|███████▏  | 31/43 [00:01<00:00, 19.74it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 77%|███████▋  | 33/43 [00:01<00:00, 19.35it/s]

torch.Size([73, 73])


 81%|████████▏ | 35/43 [00:01<00:00, 19.12it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 86%|████████▌ | 37/43 [00:01<00:00, 19.12it/s]

torch.Size([73, 73])


 91%|█████████ | 39/43 [00:02<00:00, 19.14it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 95%|█████████▌| 41/43 [00:02<00:00, 19.16it/s]

torch.Size([73, 73])


100%|██████████| 43/43 [00:02<00:00, 18.91it/s]


torch.Size([31, 73])
Epoch 16, Loss: 669780928.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])


  5%|▍         | 2/43 [00:00<00:02, 19.29it/s]

torch.Size([73, 73])
torch.Size([73, 73])


  9%|▉         | 4/43 [00:00<00:02, 19.23it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 18.51it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 19%|█▊        | 8/43 [00:00<00:01, 18.46it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 26%|██▌       | 11/43 [00:00<00:01, 21.39it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 33%|███▎      | 14/43 [00:00<00:01, 23.26it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 40%|███▉      | 17/43 [00:00<00:01, 21.55it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 47%|████▋     | 20/43 [00:00<00:01, 20.22it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 53%|█████▎    | 23/43 [00:01<00:01, 19.74it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 60%|██████    | 26/43 [00:01<00:00, 19.54it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 65%|██████▌   | 28/43 [00:01<00:00, 19.25it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 70%|██████▉   | 30/43 [00:01<00:00, 19.15it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 77%|███████▋  | 33/43 [00:01<00:00, 19.52it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 81%|████████▏ | 35/43 [00:01<00:00, 19.36it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 86%|████████▌ | 37/43 [00:01<00:00, 19.30it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 91%|█████████ | 39/43 [00:01<00:00, 19.07it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 95%|█████████▌| 41/43 [00:02<00:00, 18.87it/s]

torch.Size([73, 73])
torch.Size([73, 73])


100%|██████████| 43/43 [00:02<00:00, 19.69it/s]


torch.Size([31, 73])
Epoch 17, Loss: 587295232.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])


  5%|▍         | 2/43 [00:00<00:02, 18.12it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  9%|▉         | 4/43 [00:00<00:02, 18.19it/s]

torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:02, 17.53it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 19%|█▊        | 8/43 [00:00<00:02, 16.93it/s]

torch.Size([73, 73])


 23%|██▎       | 10/43 [00:00<00:01, 16.86it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 30%|███       | 13/43 [00:00<00:01, 18.42it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 37%|███▋      | 16/43 [00:00<00:01, 19.51it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 44%|████▍     | 19/43 [00:01<00:01, 20.07it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 51%|█████     | 22/43 [00:01<00:01, 20.39it/s]

torch.Size([73, 73])


 58%|█████▊    | 25/43 [00:01<00:00, 21.33it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 65%|██████▌   | 28/43 [00:01<00:00, 21.61it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 72%|███████▏  | 31/43 [00:01<00:00, 21.84it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 79%|███████▉  | 34/43 [00:01<00:00, 21.95it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 86%|████████▌ | 37/43 [00:01<00:00, 22.04it/s]

torch.Size([73, 73])


 93%|█████████▎| 40/43 [00:01<00:00, 21.71it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([31, 73])


100%|██████████| 43/43 [00:02<00:00, 20.69it/s]


Epoch 18, Loss: 507048640.0000


  0%|          | 0/43 [00:00<?, ?it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


  7%|▋         | 3/43 [00:00<00:01, 22.77it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 14%|█▍        | 6/43 [00:00<00:01, 22.42it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 21%|██        | 9/43 [00:00<00:01, 18.63it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 28%|██▊       | 12/43 [00:00<00:01, 19.85it/s]

torch.Size([73, 73])
torch.Size([73, 73])


 35%|███▍      | 15/43 [00:00<00:01, 20.12it/s]

torch.Size([73, 73])
torch.Size([73, 73])
torch.Size([73, 73])


 40%|███▉      | 17/43 [00:00<00:01, 20.28it/s]

torch.Size([73, 73])





KeyboardInterrupt: 

In [None]:
test_set.head()

Unnamed: 0,Date,Load
6643,10/4/2022 8:00:00 PM,50943.27
6644,10/4/2022 9:00:00 PM,48196.03
6645,10/4/2022 10:00:00 PM,44963.52
6646,10/4/2022 11:00:00 PM,41779.37
6647,10/5/2022 12:00:00 AM,38623.86


In [None]:
def test_transformer(model, test_loader):
  with torch.no_grad():
    for batch, _ in test_loader:
      #batch = batch.permute(1, 0, 2)
      output = model(batch)
      print(f'Input: {batch}')
      print(f'Output: {output}')
test_transformer(Transformer, train_loader)

TypeError: 'collections.OrderedDict' object is not callable

In [None]:
import plotly.express as px 
def plot_predictions(model, test_loader):
    """Using Plotly, plot predictions of the model on the test set against the actual values for a single batch. The y-axis is 'Load'. The x-axis is the index of the test set."""
    with torch.no_grad():
        for batch, _ in test_loader:
            output = model(batch, batch)
            output = output.flatten()
            batch = batch.flatten()
            df = pd.DataFrame({'Load': output, 'Predicted Load': batch})
            df = df.reset_index()
            df = df.melt(id_vars='index', value_vars=['Load', 'Predicted Load'])
            fig = px.line(df, x='index', y='value', color='variable')
            fig.show()
            break
plot_predictions(Transformer, train_loader)


ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed

In [None]:
def summarize_transformer(model):
  print(model)
summarize_transformer(Transformer)

Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=73, out_features=73, bias=True)
        )
        (linear1): Linear(in_features=73, out_features=73, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=73, out_features=73, bias=True)
        (norm1): LayerNorm((73,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((73,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=73, out_features=73, bias=True)
        )
        (linear1): Linear(in_features=73, out_features=73, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
  