In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt
import numpy as np

In [5]:
from VQVAE import VQVAE, VQVAE_Dataset, Encoder, Decoder, VectorQuantizer

In [6]:
df = pd.read_csv("data/dataset_cifar10_v1.csv") # loading the dataset to pandas df

In [7]:
map = {"A":1.0,"B":2.0,"C":3.0,"D":4.0} # mapping the conv block type to numerical values

In [8]:
for column, dtype in df.dtypes.items(): # applying the mapping to the column and also converting to float32
    if dtype == 'object':
        df[column] = df[column].replace(map).astype('float32')

df = df.astype({col: 'float32' for col in df.select_dtypes('int64').columns})

  df[column] = df[column].replace(map).astype('float32')


In [9]:
df.shape

(1200, 25)

In [10]:
df.head()

Unnamed: 0,out_channel0,M,R1,R2,R3,R4,R5,convblock1,widenfact1,B1,...,B3,convblock4,widenfact4,B4,convblock5,widenfact5,B5,1_day_accuracy,1_day_accuracy_std,AVM
0,117.0,1.0,9.0,0.0,0.0,0.0,0.0,2.0,4.0,11.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.923597,0.057017,7.955268
1,122.0,1.0,3.0,0.0,0.0,0.0,0.0,4.0,4.0,2.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.922466,0.047991,7.772549
2,102.0,1.0,1.0,0.0,0.0,0.0,0.0,3.0,1.0,9.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.876552,0.110986,17.92613
3,32.0,3.0,3.0,1.0,7.0,0.0,0.0,2.0,2.0,5.0,...,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.658484,0.166616,15.33696
4,38.0,3.0,8.0,4.0,11.0,0.0,0.0,4.0,1.0,12.0,...,9.0,0.0,0.0,0.0,0.0,0.0,0.0,0.874308,0.192419,12.943313


In [11]:
# mapping to assign labels to architecture data based on 1 day accuracy
intervals_mapping = {
    (0.6, 0.65): 0,
    (0.65, 0.7): 1,
    (0.7, 0.75): 2,
    (0.75, 0.8): 3,
    (0.8, 0.85): 4,
    (0.85, 0.9): 5,
    (0.9, 0.95): 6,
}

In [12]:
def map_accuracy_to_value(accuracy):
    for interval, value in intervals_mapping.items():
        if interval[0] <= accuracy <= interval[1]:
            return value
    return None  # Returns None if the accuracy doesn't fall within any defined intervals

df['accuracy_mapped'] = df['1_day_accuracy'].apply(map_accuracy_to_value)

In [13]:
data = df.iloc[:,:-4]
data.head()

Unnamed: 0,out_channel0,M,R1,R2,R3,R4,R5,convblock1,widenfact1,B1,...,B2,convblock3,widenfact3,B3,convblock4,widenfact4,B4,convblock5,widenfact5,B5
0,117.0,1.0,9.0,0.0,0.0,0.0,0.0,2.0,4.0,11.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,122.0,1.0,3.0,0.0,0.0,0.0,0.0,4.0,4.0,2.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,102.0,1.0,1.0,0.0,0.0,0.0,0.0,3.0,1.0,9.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,32.0,3.0,3.0,1.0,7.0,0.0,0.0,2.0,2.0,5.0,...,9.0,3.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
4,38.0,3.0,8.0,4.0,11.0,0.0,0.0,4.0,1.0,12.0,...,6.0,4.0,3.0,9.0,0.0,0.0,0.0,0.0,0.0,0.0


In [14]:
labels = df.iloc[:,-1]
labels.head()

0    6
1    6
2    5
3    1
4    5
Name: accuracy_mapped, dtype: int64

In [15]:
data = torch.tensor(data[data.columns].values,dtype=torch.float32)
labels = torch.tensor(labels.values,dtype=torch.int64)
data.shape, labels.shape

(torch.Size([1200, 22]), torch.Size([1200]))

In [16]:
labels

tensor([6, 6, 5,  ..., 6, 6, 2])

In [17]:
if torch.backends.mps.is_available():
  device = 'mps'
else:
  device = 'cpu'
device

'mps'

In [18]:
data = data.to(device)
labels = labels.to(device)

### Hyperparameters

In [46]:
# hyperparameters (TODO: tuning)
x_dim = data.shape[1]
h_nodes = 256
scale = 2
num_layers = 3
embed_dim = 10
dropout = 0.2
num_embeddings = 50
commitment_cost = 0.25
divergence_cost = 0.1
epochs = 200
learning_rate = 1e-4
weight_decay = 1e-3
batch_size = 16

In [34]:
dataset = VQVAE_Dataset(data,labels) # dataset init

In [35]:
train_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)

In [36]:
model = VQVAE(x_dim,embed_dim,dropout=dropout,num_embeddings=num_embeddings,
              commitment_cost=commitment_cost,h_nodes=h_nodes,scale=scale,
              num_layers=num_layers).to(device)

In [37]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate,weight_decay=weight_decay, amsgrad=False) # Works way better than SGD!

In [24]:
from torchsummary import summary
summary(Encoder(x_dim,embed_dim,dropout=dropout,h_nodes=h_nodes,scale=scale,num_layers=num_layers).to('cpu'),(x_dim,))
summary(VectorQuantizer(num_embeddings,embed_dim,commitment_cost).to('cpu'),(embed_dim,))
summary(Decoder(x_dim,embed_dim,h_nodes=h_nodes,dropout=dropout,scale=scale,num_layers=num_layers).to('cpu'),(embed_dim,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 256]           5,888
              ReLU-2                  [-1, 256]               0
            Linear-3                  [-1, 128]          32,896
              ReLU-4                  [-1, 128]               0
            Linear-5                   [-1, 64]           8,256
              ReLU-6                   [-1, 64]               0
            Linear-7                   [-1, 32]           2,080
              ReLU-8                   [-1, 32]               0
            Linear-9                   [-1, 10]             330
Total params: 49,450
Trainable params: 49,450
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.19
Estimated Total Size (MB): 0.20
---------------------------------------------

In [113]:
data_recon

tensor([[66.8076,  3.0430,  8.5017,  6.8811,  0.0000,  3.5551,  0.0000,  2.5475,
          2.4100,  6.5587,  2.0149,  1.9929,  5.1478,  1.5032,  1.5739,  3.9572,
          1.0727,  1.0492,  2.8168,  0.5429,  0.5501,  1.4889],
        [66.8076,  3.0430,  8.5017,  6.8811,  0.0000,  3.5551,  0.0000,  2.5475,
          2.4100,  6.5587,  2.0149,  1.9929,  5.1478,  1.5032,  1.5739,  3.9572,
          1.0727,  1.0492,  2.8168,  0.5429,  0.5501,  1.4889],
        [66.8076,  3.0430,  8.5017,  6.8811,  0.0000,  3.5551,  0.0000,  2.5475,
          2.4100,  6.5587,  2.0149,  1.9929,  5.1478,  1.5032,  1.5739,  3.9572,
          1.0727,  1.0492,  2.8168,  0.5429,  0.5501,  1.4889],
        [66.8076,  3.0430,  8.5017,  6.8811,  0.0000,  3.5551,  0.0000,  2.5475,
          2.4100,  6.5587,  2.0149,  1.9929,  5.1478,  1.5032,  1.5739,  3.9572,
          1.0727,  1.0492,  2.8168,  0.5429,  0.5501,  1.4889],
        [66.8076,  3.0430,  8.5017,  6.8811,  0.0000,  3.5551,  0.0000,  2.5475,
          2.410

In [38]:
model.train()
train_res_recon_error = []
train_res_perplexity = []
epoch_error = []
epoch_perplexity = []

for epoch in tqdm(range(epochs)):

  for i, batch in enumerate(train_loader):  # Iterate over all batches in the dataset
    batch_x,batch_y = batch
    batch_x = batch_x.to(device)
    batch_y = batch_y.to(device)
    optimizer.zero_grad()

    # noise = torch.rand(batch.size()).to(device)
    vq_loss, data_recon, perplexity = model(batch_x)
    recon_error = F.mse_loss(data_recon, batch_x)
    loss = recon_error + vq_loss
    loss.backward()

    # # Gradient checking
    # print("Encoder gradients:")
    # for name, param in model.encoder.named_parameters():
    #     if param.grad is None:
    #         print(f"{name}: No grad")

    # print("\nDecoder gradients:")
    # for name, param in model.decoder.named_parameters():
    #     if param.grad is None:
    #         print(f"{name}: No grad")

    optimizer.step()

    epoch_error.append(recon_error.item())
    epoch_perplexity.append(perplexity.item())

    # if (i+1) % 10 == 0:
    #   print('%d iterations of epochs %d/%d' % (i+1,epoch + 1, epochs))
    #   print('Average recon_error: %.3f' % np.mean(train_res_recon_error[-10:]))
    #   print('Average perplexity: %.3f' % np.mean(train_res_perplexity[-10:]))
    #   print()

  print('Epoch %d/%d' % (epoch + 1, epochs))
  print('Average epoch error: %.3f' % np.mean(epoch_error))
  print('Average epoch perplexity: %.3f' % np.mean(epoch_perplexity))
  print()

 93%|█████████▎| 1115/1200 [15:12<01:09,  1.22it/s]

Epoch 1115/1200
Average epoch error: 23.422
Average epoch perplexity: 2.896



 93%|█████████▎| 1116/1200 [15:13<01:09,  1.20it/s]

Epoch 1116/1200
Average epoch error: 23.423
Average epoch perplexity: 2.896



 93%|█████████▎| 1117/1200 [15:13<01:07,  1.22it/s]

Epoch 1117/1200
Average epoch error: 23.421
Average epoch perplexity: 2.896



 93%|█████████▎| 1118/1200 [15:14<01:07,  1.22it/s]

Epoch 1118/1200
Average epoch error: 23.417
Average epoch perplexity: 2.896



 93%|█████████▎| 1119/1200 [15:15<01:06,  1.22it/s]

Epoch 1119/1200
Average epoch error: 23.414
Average epoch perplexity: 2.895



 93%|█████████▎| 1120/1200 [15:16<01:04,  1.23it/s]

Epoch 1120/1200
Average epoch error: 23.411
Average epoch perplexity: 2.895



 93%|█████████▎| 1121/1200 [15:17<01:04,  1.23it/s]

Epoch 1121/1200
Average epoch error: 23.407
Average epoch perplexity: 2.895



 94%|█████████▎| 1122/1200 [15:18<01:04,  1.21it/s]

Epoch 1122/1200
Average epoch error: 23.404
Average epoch perplexity: 2.895



 94%|█████████▎| 1123/1200 [15:18<01:02,  1.23it/s]

Epoch 1123/1200
Average epoch error: 23.400
Average epoch perplexity: 2.895



 94%|█████████▎| 1124/1200 [15:19<01:02,  1.22it/s]

Epoch 1124/1200
Average epoch error: 23.396
Average epoch perplexity: 2.896



 94%|█████████▍| 1125/1200 [15:20<01:02,  1.21it/s]

Epoch 1125/1200
Average epoch error: 23.395
Average epoch perplexity: 2.895



 94%|█████████▍| 1126/1200 [15:21<01:00,  1.23it/s]

Epoch 1126/1200
Average epoch error: 23.423
Average epoch perplexity: 2.894



 94%|█████████▍| 1127/1200 [15:22<00:59,  1.23it/s]

Epoch 1127/1200
Average epoch error: 23.442
Average epoch perplexity: 2.893



 94%|█████████▍| 1128/1200 [15:22<00:58,  1.22it/s]

Epoch 1128/1200
Average epoch error: 23.462
Average epoch perplexity: 2.892



 94%|█████████▍| 1129/1200 [15:23<00:57,  1.24it/s]

Epoch 1129/1200
Average epoch error: 23.492
Average epoch perplexity: 2.891



 94%|█████████▍| 1130/1200 [15:24<00:56,  1.23it/s]

Epoch 1130/1200
Average epoch error: 23.528
Average epoch perplexity: 2.889



 94%|█████████▍| 1131/1200 [15:25<00:56,  1.22it/s]

Epoch 1131/1200
Average epoch error: 23.566
Average epoch perplexity: 2.887



 94%|█████████▍| 1132/1200 [15:26<00:54,  1.24it/s]

Epoch 1132/1200
Average epoch error: 23.603
Average epoch perplexity: 2.886



 94%|█████████▍| 1133/1200 [15:26<00:54,  1.23it/s]

Epoch 1133/1200
Average epoch error: 23.628
Average epoch perplexity: 2.885



 94%|█████████▍| 1134/1200 [15:27<00:54,  1.22it/s]

Epoch 1134/1200
Average epoch error: 23.631
Average epoch perplexity: 2.884



 95%|█████████▍| 1135/1200 [15:28<00:52,  1.23it/s]

Epoch 1135/1200
Average epoch error: 23.634
Average epoch perplexity: 2.883



 95%|█████████▍| 1136/1200 [15:29<00:51,  1.23it/s]

Epoch 1136/1200
Average epoch error: 23.639
Average epoch perplexity: 2.883



 95%|█████████▍| 1137/1200 [15:30<00:51,  1.23it/s]

Epoch 1137/1200
Average epoch error: 23.653
Average epoch perplexity: 2.882



 95%|█████████▍| 1138/1200 [15:31<00:49,  1.24it/s]

Epoch 1138/1200
Average epoch error: 23.686
Average epoch perplexity: 2.881



 95%|█████████▍| 1139/1200 [15:31<00:49,  1.23it/s]

Epoch 1139/1200
Average epoch error: 23.688
Average epoch perplexity: 2.880



 95%|█████████▌| 1140/1200 [15:32<00:49,  1.21it/s]

Epoch 1140/1200
Average epoch error: 23.687
Average epoch perplexity: 2.880



 95%|█████████▌| 1141/1200 [15:33<00:48,  1.22it/s]

Epoch 1141/1200
Average epoch error: 23.686
Average epoch perplexity: 2.880



 95%|█████████▌| 1142/1200 [15:34<00:48,  1.21it/s]

Epoch 1142/1200
Average epoch error: 23.691
Average epoch perplexity: 2.879



 95%|█████████▌| 1143/1200 [15:35<00:47,  1.20it/s]

Epoch 1143/1200
Average epoch error: 23.692
Average epoch perplexity: 2.879



 95%|█████████▌| 1144/1200 [15:36<00:45,  1.22it/s]

Epoch 1144/1200
Average epoch error: 23.693
Average epoch perplexity: 2.878



 95%|█████████▌| 1145/1200 [15:36<00:45,  1.22it/s]

Epoch 1145/1200
Average epoch error: 23.694
Average epoch perplexity: 2.877



 96%|█████████▌| 1146/1200 [15:37<00:44,  1.21it/s]

Epoch 1146/1200
Average epoch error: 23.696
Average epoch perplexity: 2.876



 96%|█████████▌| 1147/1200 [15:38<00:43,  1.23it/s]

Epoch 1147/1200
Average epoch error: 23.697
Average epoch perplexity: 2.875



 96%|█████████▌| 1148/1200 [15:39<00:42,  1.22it/s]

Epoch 1148/1200
Average epoch error: 23.700
Average epoch perplexity: 2.874



 96%|█████████▌| 1149/1200 [15:40<00:41,  1.22it/s]

Epoch 1149/1200
Average epoch error: 23.707
Average epoch perplexity: 2.873



 96%|█████████▌| 1150/1200 [15:40<00:40,  1.23it/s]

Epoch 1150/1200
Average epoch error: 23.713
Average epoch perplexity: 2.873



 96%|█████████▌| 1151/1200 [15:41<00:39,  1.23it/s]

Epoch 1151/1200
Average epoch error: 23.720
Average epoch perplexity: 2.872



 96%|█████████▌| 1152/1200 [15:42<00:39,  1.22it/s]

Epoch 1152/1200
Average epoch error: 23.724
Average epoch perplexity: 2.872



 96%|█████████▌| 1153/1200 [15:43<00:37,  1.24it/s]

Epoch 1153/1200
Average epoch error: 23.724
Average epoch perplexity: 2.872



 96%|█████████▌| 1154/1200 [15:44<00:37,  1.23it/s]

Epoch 1154/1200
Average epoch error: 23.721
Average epoch perplexity: 2.872



 96%|█████████▋| 1155/1200 [15:44<00:36,  1.22it/s]

Epoch 1155/1200
Average epoch error: 23.719
Average epoch perplexity: 2.872



 96%|█████████▋| 1156/1200 [15:45<00:35,  1.24it/s]

Epoch 1156/1200
Average epoch error: 23.716
Average epoch perplexity: 2.872



 96%|█████████▋| 1157/1200 [15:46<00:34,  1.23it/s]

Epoch 1157/1200
Average epoch error: 23.713
Average epoch perplexity: 2.872



 96%|█████████▋| 1158/1200 [15:47<00:34,  1.22it/s]

Epoch 1158/1200
Average epoch error: 23.711
Average epoch perplexity: 2.872



 97%|█████████▋| 1159/1200 [15:48<00:33,  1.23it/s]

Epoch 1159/1200
Average epoch error: 23.707
Average epoch perplexity: 2.872



 97%|█████████▋| 1160/1200 [15:49<00:32,  1.23it/s]

Epoch 1160/1200
Average epoch error: 23.701
Average epoch perplexity: 2.872



 97%|█████████▋| 1161/1200 [15:49<00:31,  1.22it/s]

Epoch 1161/1200
Average epoch error: 23.695
Average epoch perplexity: 2.873



 97%|█████████▋| 1162/1200 [15:50<00:30,  1.23it/s]

Epoch 1162/1200
Average epoch error: 23.691
Average epoch perplexity: 2.873



 97%|█████████▋| 1163/1200 [15:51<00:30,  1.21it/s]

Epoch 1163/1200
Average epoch error: 23.688
Average epoch perplexity: 2.873



 97%|█████████▋| 1164/1200 [15:52<00:29,  1.22it/s]

Epoch 1164/1200
Average epoch error: 23.683
Average epoch perplexity: 2.874



 97%|█████████▋| 1165/1200 [15:53<00:28,  1.22it/s]

Epoch 1165/1200
Average epoch error: 23.679
Average epoch perplexity: 2.874



 97%|█████████▋| 1166/1200 [15:53<00:27,  1.22it/s]

Epoch 1166/1200
Average epoch error: 23.676
Average epoch perplexity: 2.873



 97%|█████████▋| 1167/1200 [15:54<00:27,  1.21it/s]

Epoch 1167/1200
Average epoch error: 23.673
Average epoch perplexity: 2.873



 97%|█████████▋| 1168/1200 [15:55<00:25,  1.23it/s]

Epoch 1168/1200
Average epoch error: 23.668
Average epoch perplexity: 2.874



 97%|█████████▋| 1169/1200 [15:56<00:25,  1.22it/s]

Epoch 1169/1200
Average epoch error: 23.662
Average epoch perplexity: 2.874



 98%|█████████▊| 1170/1200 [15:57<00:24,  1.22it/s]

Epoch 1170/1200
Average epoch error: 23.655
Average epoch perplexity: 2.875



 98%|█████████▊| 1171/1200 [15:58<00:23,  1.23it/s]

Epoch 1171/1200
Average epoch error: 23.648
Average epoch perplexity: 2.875



 98%|█████████▊| 1172/1200 [15:58<00:22,  1.22it/s]

Epoch 1172/1200
Average epoch error: 23.643
Average epoch perplexity: 2.876



 98%|█████████▊| 1173/1200 [15:59<00:22,  1.22it/s]

Epoch 1173/1200
Average epoch error: 23.644
Average epoch perplexity: 2.875



 98%|█████████▊| 1174/1200 [16:00<00:21,  1.23it/s]

Epoch 1174/1200
Average epoch error: 23.644
Average epoch perplexity: 2.875



 98%|█████████▊| 1175/1200 [16:01<00:20,  1.21it/s]

Epoch 1175/1200
Average epoch error: 23.645
Average epoch perplexity: 2.875



 98%|█████████▊| 1176/1200 [16:02<00:19,  1.21it/s]

Epoch 1176/1200
Average epoch error: 23.643
Average epoch perplexity: 2.875



 98%|█████████▊| 1177/1200 [16:02<00:18,  1.22it/s]

Epoch 1177/1200
Average epoch error: 23.637
Average epoch perplexity: 2.875



 98%|█████████▊| 1178/1200 [16:03<00:18,  1.19it/s]

Epoch 1178/1200
Average epoch error: 23.631
Average epoch perplexity: 2.876



 98%|█████████▊| 1179/1200 [16:04<00:17,  1.22it/s]

Epoch 1179/1200
Average epoch error: 23.628
Average epoch perplexity: 2.876



 98%|█████████▊| 1180/1200 [16:05<00:16,  1.21it/s]

Epoch 1180/1200
Average epoch error: 23.627
Average epoch perplexity: 2.876



 98%|█████████▊| 1181/1200 [16:06<00:15,  1.20it/s]

Epoch 1181/1200
Average epoch error: 23.628
Average epoch perplexity: 2.875



 98%|█████████▊| 1182/1200 [16:07<00:14,  1.22it/s]

Epoch 1182/1200
Average epoch error: 23.626
Average epoch perplexity: 2.875



 99%|█████████▊| 1183/1200 [16:07<00:13,  1.22it/s]

Epoch 1183/1200
Average epoch error: 23.621
Average epoch perplexity: 2.875



 99%|█████████▊| 1184/1200 [16:08<00:13,  1.21it/s]

Epoch 1184/1200
Average epoch error: 23.616
Average epoch perplexity: 2.875



 99%|█████████▉| 1185/1200 [16:09<00:12,  1.23it/s]

Epoch 1185/1200
Average epoch error: 23.621
Average epoch perplexity: 2.875



 99%|█████████▉| 1186/1200 [16:10<00:11,  1.23it/s]

Epoch 1186/1200
Average epoch error: 23.621
Average epoch perplexity: 2.874



 99%|█████████▉| 1187/1200 [16:11<00:10,  1.22it/s]

Epoch 1187/1200
Average epoch error: 23.624
Average epoch perplexity: 2.874



 99%|█████████▉| 1188/1200 [16:11<00:09,  1.24it/s]

Epoch 1188/1200
Average epoch error: 23.624
Average epoch perplexity: 2.874



 99%|█████████▉| 1189/1200 [16:12<00:08,  1.23it/s]

Epoch 1189/1200
Average epoch error: 23.620
Average epoch perplexity: 2.874



 99%|█████████▉| 1190/1200 [16:13<00:08,  1.22it/s]

Epoch 1190/1200
Average epoch error: 23.617
Average epoch perplexity: 2.875



 99%|█████████▉| 1191/1200 [16:14<00:07,  1.24it/s]

Epoch 1191/1200
Average epoch error: 23.616
Average epoch perplexity: 2.875



 99%|█████████▉| 1192/1200 [16:15<00:06,  1.22it/s]

Epoch 1192/1200
Average epoch error: 23.619
Average epoch perplexity: 2.874



 99%|█████████▉| 1193/1200 [16:16<00:05,  1.22it/s]

Epoch 1193/1200
Average epoch error: 23.623
Average epoch perplexity: 2.873



100%|█████████▉| 1194/1200 [16:16<00:04,  1.23it/s]

Epoch 1194/1200
Average epoch error: 23.622
Average epoch perplexity: 2.873



100%|█████████▉| 1195/1200 [16:17<00:04,  1.22it/s]

Epoch 1195/1200
Average epoch error: 23.617
Average epoch perplexity: 2.873



100%|█████████▉| 1196/1200 [16:18<00:03,  1.22it/s]

Epoch 1196/1200
Average epoch error: 23.612
Average epoch perplexity: 2.873



100%|█████████▉| 1197/1200 [16:19<00:02,  1.23it/s]

Epoch 1197/1200
Average epoch error: 23.610
Average epoch perplexity: 2.873



100%|█████████▉| 1198/1200 [16:20<00:01,  1.21it/s]

Epoch 1198/1200
Average epoch error: 23.608
Average epoch perplexity: 2.873



100%|█████████▉| 1199/1200 [16:21<00:00,  1.21it/s]

Epoch 1199/1200
Average epoch error: 23.609
Average epoch perplexity: 2.872



100%|██████████| 1200/1200 [16:21<00:00,  1.22it/s]

Epoch 1200/1200
Average epoch error: 23.609
Average epoch perplexity: 2.872






## Hidden 256 with 3 layers results (1000 epochs) (0.25 commitment_loss)
100%|█████████▉| 999/1000 [13:46<00:00,  1.22it/s]
Epoch 999/1000
Average epoch error: 30.736
Average epoch perplexity: 2.855

100%|██████████| 1000/1000 [13:47<00:00,  1.21it/s]
Epoch 1000/1000
Average epoch error: 30.739
Average epoch perplexity: 2.855

## Hidden 256 with with 3 layers (1200 epochs) (0.35 commitment_loss)
100%|█████████▉| 1196/1200 [16:18<00:03,  1.22it/s]
Epoch 1196/1200
Average epoch error: 23.612
Average epoch perplexity: 2.873

100%|█████████▉| 1197/1200 [16:19<00:02,  1.23it/s]
Epoch 1197/1200
Average epoch error: 23.610
Average epoch perplexity: 2.873

100%|█████████▉| 1198/1200 [16:20<00:01,  1.21it/s]
Epoch 1198/1200
Average epoch error: 23.608
Average epoch perplexity: 2.873

100%|█████████▉| 1199/1200 [16:21<00:00,  1.21it/s]
Epoch 1199/1200
Average epoch error: 23.609
Average epoch perplexity: 2.872

100%|██████████| 1200/1200 [16:21<00:00,  1.22it/s]
Epoch 1200/1200
Average epoch error: 23.609
Average epoch perplexity: 2.872

## Test for different Commitment Cost [0.2 - 0.5 ] + 0.05 [200 epochs]

In [47]:
import pandas as pd

# Initialize results storage
result_loss = pd.DataFrame()
result_perplexity = pd.DataFrame()
best_commitment_cost = None
best_loss = float('inf')


# Range of commitment_cost values to test
commitment_costs = np.arange(0.2, 0.55, 0.05)

for commitment_cost in commitment_costs:
    print(f"Testing commitment_cost: {commitment_cost}")

    # Model and optimizer setup
    model = VQVAE(x_dim, embed_dim, dropout=dropout, num_embeddings=num_embeddings,
                  commitment_cost=commitment_cost, h_nodes=h_nodes, scale=scale,
                  num_layers=num_layers).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, amsgrad=False)


    # Training
    model.train()
    epoch_errors = []
    epoch_perplexities = []

    for epoch in tqdm(range(epochs)):
        batch_errors = []
        batch_perplexities = []
        for i, batch in enumerate(train_loader):
            batch_x, batch_y = batch
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            optimizer.zero_grad()

            vq_loss, data_recon, perplexity = model(batch_x)
            recon_error = nn.functional.mse_loss(data_recon, batch_x)
            loss = recon_error + vq_loss
            loss.backward()
            optimizer.step()

            batch_errors.append(recon_error.item())
            batch_perplexities.append(perplexity.item())

        epoch_errors.append(np.mean(batch_errors))
        epoch_perplexities.append(np.mean(batch_perplexities))
    
    # Save results
    result_loss[commitment_cost] = epoch_errors
    result_perplexity[commitment_cost] = epoch_perplexities
    
    # Update best commitment_cost based on average loss
    avg_loss = np.mean(epoch_errors)
    if avg_loss < best_loss:
        best_loss = avg_loss
        best_commitment_cost = commitment_cost


    print(f"Finished testing commitment_cost: {commitment_cost}. Avg loss: {avg_loss}, Avg perplexity: {np.mean(epoch_perplexities)}")
    

# Save results to disk
result_loss.to_csv('vqvae_loss_results.csv')
result_perplexity.to_csv('vqvae_perplexity_results.csv')
print(f"Results saved. Best commitment_cost: {best_commitment_cost} with loss: {best_loss}")

Testing commitment_cost: 0.2


  0%|          | 0/200 [00:00<?, ?it/s]

 10%|█         | 20/200 [00:16<02:25,  1.24it/s]


KeyboardInterrupt: 

## Test for different embed dimension [10 - 25] + 2 and num_embeddings = [25 - 75] + 5 --> using 0.25 commitment {200 epochs}

In [48]:
import pandas as pd

# Initialize results storage
best_params = {'embed_dim': None, 'num_embeddings': None}
best_loss = float('inf')
# Prepare to store results
results = {
    'embed_dim': [],
    'num_embeddings': [],
    'avg_loss': [],
    'avg_perplexity': []
}


# Range of commitment_cost values to test
embed_dims = np.arange(10, 26, 2)
num_embeddingss = np.arange(25, 76, 5)

for embed_dim in embed_dims:
    for num_embeddings in num_embeddingss:

        print(f"Training with embed_dim: {embed_dim}, num_embeddings: {num_embeddings}")

        # Model and optimizer setup
        model = VQVAE(x_dim, embed_dim, dropout=dropout, num_embeddings=num_embeddings,
                    commitment_cost=commitment_cost, h_nodes=h_nodes, scale=scale,
                    num_layers=num_layers).to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay, amsgrad=False)


        # Training
        model.train()
        epoch_errors = []
        epoch_perplexities = []

        for epoch in range(epochs):
            for i, batch in enumerate(train_loader):
                batch_x, batch_y = batch
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                optimizer.zero_grad()

                vq_loss, data_recon, perplexity = model(batch_x)
                recon_error = nn.functional.mse_loss(data_recon, batch_x)
                loss = recon_error + vq_loss
                loss.backward()
                optimizer.step()

                epoch_errors.append(recon_error.item())
                epoch_perplexities.append(perplexity.item())
            
            print(f"Epoch {epoch + 1}/{epochs}: Avg loss: {np.mean(epoch_errors)}, Avg perplexity: {np.mean(epoch_perplexities)}")

        # Store results
        results['embed_dim'].append(embed_dim)
        results['num_embeddings'].append(num_embeddings)
        results['avg_loss'].append(np.mean(epoch_errors))
        results['avg_perplexity'].append(np.mean(epoch_perplexities))




        
        # Update best commitment_cost based on average loss
        avg_loss = np.mean(epoch_errors)
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_params['embed_dim'] = embed_dim
            best_params['num_embeddings'] = num_embeddings
    
        print(f"Average Loss for embed_dim {embed_dim} and num_embeddings {num_embeddings}: {avg_loss}, Avg perplexity: {np.mean(epoch_perplexities)}")
    
# Convert results to DataFrame
df_results = pd.DataFrame(results)
df_loss = df_results.pivot("embed_dim", "num_embeddings", "avg_loss")
df_perplexity = df_results.pivot("embed_dim", "num_embeddings", "avg_perplexity")

# Save results to CSV
df_loss.to_csv('hyperparam_loss_results.csv')
df_perplexity.to_csv('hyperparam_perplexity_results.csv')
print(f"Best parameters found: Embed_dim {best_params['embed_dim']}, Num_embeddings {best_params['num_embeddings']} with Loss: {best_loss}")

Training with embed_dim: 10, num_embeddings: 25
Epoch 1/200: Avg loss: 280.0791125488281, Avg perplexity: 2.7839516035715737
Epoch 2/200: Avg loss: 279.9283133951823, Avg perplexity: 2.0665637977917988
Epoch 3/200: Avg loss: 279.1691923014323, Avg perplexity: 1.7110425318611993
Epoch 4/200: Avg loss: 274.3326160430908, Avg perplexity: 1.5332818988958994
Epoch 5/200: Avg loss: 262.2686837768555, Avg perplexity: 1.4266255191167196
Epoch 6/200: Avg loss: 242.32931964450412, Avg perplexity: 1.3555212659305997
Epoch 7/200: Avg loss: 218.89858613513766, Avg perplexity: 1.3235923642203922
Epoch 8/200: Avg loss: 197.53814602533976, Avg perplexity: 1.3648380033175151
Epoch 9/200: Avg loss: 179.73925617076733, Avg perplexity: 1.4232386530770196
Epoch 10/200: Avg loss: 165.20167054494223, Avg perplexity: 1.4730454359054566
Epoch 11/200: Avg loss: 153.25631877783573, Avg perplexity: 1.5151683200489392
Epoch 12/200: Avg loss: 143.28238752788968, Avg perplexity: 1.5509638745254941
Epoch 13/200: Avg 

TypeError: DataFrame.pivot() takes 1 positional argument but 4 were given