# Retrieve Graph

#### Node Dictionaries
+ track_dict
+ playlist_dict

#### Edge Dictionaries
+ shared_album_edges
+ shared_artist_edges
+ shared_genre_edges
+ cosine_similarity_edges
+ contains_edges

In [1]:
from py2neo import Graph, Node, Relationship
import numpy as np
from collections import defaultdict

uri = "..."
username = "..."
password = "..."
graph = Graph(uri, auth=(username, password))

# Retrieve track nodes and their properties
track_query = '''
MATCH (t:Track)
RETURN t.acousticness, t.album_id, t.artist_ids, t.danceability, t.duration_ms, t.energy, t.explicit,
       t.genre, t.id, t.key, t.liveness, t.loudness, t.mode, t.popularity, t.speechiness, t.tempo, t.valence
'''

track_result = graph.run(track_query).data()

# Create dictionary for tracks
track_dict = {row['t.id']: {
    'acousticness': row['t.acousticness'],
    'album_id': row['t.album_id'],
    'artist_ids': row['t.artist_ids'],
    'danceability': row['t.danceability'],
    'duration_ms': row['t.duration_ms'],
    'energy': row['t.energy'],
    'explicit': row['t.explicit'],
    'genre': row['t.genre'],
    'key': row['t.key'],
    'liveness': row['t.liveness'],
    'loudness': row['t.loudness'],
    'mode': row['t.mode'],
    'popularity': row['t.popularity'],
    'speechiness': row['t.speechiness'],
    'tempo': row['t.tempo'],
    'valence': row['t.valence'],
} for row in track_result}

# Retrieve playlist nodes and their properties
playlist_query = '''
MATCH (p:Playlist)
RETURN p.playlist_id, p.tracklist
'''

playlist_result = graph.run(playlist_query).data()

# Create dictionary playlists
playlist_dict = {row['p.playlist_id']: row['p.tracklist'] for row in playlist_result}

# Create defaultdicts for relationships
shared_album_edges = defaultdict(list)
shared_artist_edges = defaultdict(list)
shared_genre_edges = defaultdict(list)
cosine_similarity_edges = defaultdict(list)
contains_edges = defaultdict(list)

# Retrieve SHARED_ALBUM relationships
shared_album_query = '''
MATCH (t1:Track)-[:SHARED_ALBUM]->(t2:Track)
RETURN t1.id, t2.id
'''
shared_album_result = graph.run(shared_album_query).data()

for row in shared_album_result:
    track1_id, track2_id = row['t1.id'], row['t2.id']
    shared_album_edges[track1_id].append(track2_id)
    shared_album_edges[track2_id].append(track1_id)

# Retrieve SHARED_ARTIST relationships
shared_artist_query = '''
MATCH (t1:Track)-[:SHARED_ARTIST]->(t2:Track)
RETURN t1.id, t2.id
'''
shared_artist_result = graph.run(shared_artist_query).data()

for row in shared_artist_result:
    track1_id, track2_id = row['t1.id'], row['t2.id']
    shared_artist_edges[track1_id].append(track2_id)
    shared_artist_edges[track2_id].append(track1_id)

# Retrieve SHARED_GENRE relationships
shared_genre_query = '''
MATCH (t1:Track)-[:SHARED_GENRE]->(t2:Track)
RETURN t1.id, t2.id
'''
shared_genre_result = graph.run(shared_genre_query).data()

for row in shared_genre_result:
    track1_id, track2_id = row['t1.id'], row['t2.id']
    shared_genre_edges[track1_id].append(track2_id)
    shared_genre_edges[track2_id].append(track1_id)
    
# Retrieve COSINE_SIMILARITY relationships
cosine_similarity_query = '''
MATCH (t1:Track)-[r:COSINE_SIMILARITY]->(t2:Track)
RETURN t1.id, t2.id, r.value
'''
cosine_similarity_result = graph.run(cosine_similarity_query).data()

for row in cosine_similarity_result:
    track1_id, track2_id, similarity = row['t1.id'], row['t2.id'], row['r.value']
    cosine_similarity_edges[track1_id].append(track2_id)
    cosine_similarity_edges[track2_id].append(track1_id)

# Retrieve CONTAINS relationships
contains_query = '''
MATCH (p:Playlist)-[:CONTAINS]->(t:Track)
RETURN p.playlist_id, t.id
'''
contains_result = graph.run(contains_query).data()

for row in contains_result:
    playlist_id, track_id = row['p.playlist_id'], row['t.id']
    contains_edges[playlist_id].append(track_id)

# Creating Adjacency Matrices

In [2]:
from scipy.sparse import lil_matrix

# Get unique track IDs
unique_track_ids = list(track_dict.keys())

# Create a track ID to index mapping
track_id_to_index = {track_id: index for index, track_id in enumerate(unique_track_ids)}

#print(track_id_to_index)

# Get unique playlist IDs
unique_playlist_ids = list(playlist_dict.keys())

# Create a playlist ID to index mapping
playlist_id_to_index = {playlist_id: index for index, playlist_id in enumerate(unique_playlist_ids)}

#print(playlist_id_to_index)

# Create adjacency matrices for each relationship
n_tracks = len(unique_track_ids)
shared_album_adj = lil_matrix((n_tracks, n_tracks), dtype=np.float32)
shared_artist_adj = lil_matrix((n_tracks, n_tracks), dtype=np.float32)
shared_genre_adj = lil_matrix((n_tracks, n_tracks), dtype=np.float32)
cosine_similarity_adj = lil_matrix((n_tracks, n_tracks), dtype=np.float32)

# Fill in the shared_album_adj matrix
for track_id, connected_track_ids in shared_album_edges.items():
    track_index = track_id_to_index[track_id]
    for connected_track_id in connected_track_ids:
        connected_track_index = track_id_to_index[connected_track_id]
        shared_album_adj[track_index, connected_track_index] = .25

# Fill in the shared_artist_adj matrix
for track_id, connected_track_ids in shared_artist_edges.items():
    track_index = track_id_to_index[track_id]
    for connected_track_id in connected_track_ids:
        connected_track_index = track_id_to_index[connected_track_id]
        shared_artist_adj[track_index, connected_track_index] = .25

# Fill in the shared_genre_adj matrix
for track_id, connected_track_ids in shared_genre_edges.items():
    track_index = track_id_to_index[track_id]
    for connected_track_id in connected_track_ids:
        connected_track_index = track_id_to_index[connected_track_id]
        shared_genre_adj[track_index, connected_track_index] = .5

# Fill in the cosine_similarity_adj matrix
for track_id, connected_track_tuples in cosine_similarity_edges.items():
    track_index = track_id_to_index[track_id]
    for connected_track_id in connected_track_tuples:
        connected_track_index = track_id_to_index[connected_track_id]
        cosine_similarity_adj[track_index, connected_track_index] = .95


# KGAT Model Implementation: PyTorch

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

class KGATLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(KGATLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        #self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.weight = nn.Parameter(torch.randn(in_features, out_features) * 0.1)  # multiply by 0.1 to increase values
        self.reset_parameters()

#    def reset_parameters(self):
#        gain = nn.init.calculate_gain('relu') * math.sqrt(3)
#        nn.init.uniform_(self.weight, -10, 10)
#        self.weight.data.mul_(gain)

    def reset_parameters(self):
        #print(self.weight)
        #nn.Parameter(torch.randn(11, out_features) * 0.1)  # multiply by 0.1 to increase values
        nn.init.kaiming_uniform_(self.weight, mode='fan_in', nonlinearity='relu')
        #print(self.weight)

    def forward(self, adjacency_matrix, input_features):
        output_features = torch.mm(input_features, self.weight)
        return torch.mm(adjacency_matrix, output_features)

class KGAT(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_layers):
        super(KGAT, self).__init__()
        self.input_bn = nn.BatchNorm1d(hidden_features)
        self.hidden_bns = nn.ModuleList([nn.BatchNorm1d(hidden_features) for _ in range(num_layers - 1)])
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.num_layers = num_layers

        # Define the input layer
        self.input_layer = KGATLayer(in_features, hidden_features)

        # Define the hidden layers
        self.hidden_layers = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.hidden_layers.append(KGATLayer(hidden_features, hidden_features))

        # Define the output layer
        self.output_layer = KGATLayer(hidden_features, out_features)

    def forward(self, adjacency_matrix, input_features):
        x = self.input_layer(adjacency_matrix, input_features)
        x = x.unsqueeze(2)  # Add an extra dimension for Batch Normalization
        x = F.relu(self.input_bn(x))
        x = x.squeeze(2)  # Remove the extra dimension
        for i, layer in enumerate(self.hidden_layers):
            x = layer(adjacency_matrix, x)
            x = x.unsqueeze(2)  # Add an extra dimension for Batch Normalization
            x = F.relu(self.hidden_bns[i](x))
            x = x.squeeze(2)  # Remove the extra dimension

        # Pass through the output layer
        x = self.output_layer(adjacency_matrix, x)

        return x        

hidden_features = 64  # Number of hidden features in the KGAT model
out_features = 1  # Number of output features in the KGAT model
num_layers = 2  # Number of layers in the KGAT model

numeric_keys = [
    'acousticness', 'danceability', 'duration_ms', 'energy', 'explicit',
    'liveness', 'loudness', 'popularity', 'speechiness', 'tempo', 'valence'
]

model = KGAT(num_layers=num_layers, in_features=len(numeric_keys), hidden_features=hidden_features, out_features=out_features)



# Prepare Model

In [11]:
import random
from sklearn.model_selection import train_test_split
from tqdm import tqdm

contains_edges_flat = [(playlist_id, track_id) for playlist_id, track_ids in contains_edges.items() for track_id in track_ids]
train_edges, val_edges = train_test_split(contains_edges_flat, test_size=0.2, random_state=42)

from torch.utils.data import Dataset, DataLoader

class MusicDataset(Dataset):
    def __init__(self, edges, playlist_id_to_index, track_id_to_index):
        self.edges = edges
        self.playlist_id_to_index = playlist_id_to_index
        self.track_id_to_index = track_id_to_index

    def __len__(self):
        return len(self.edges)

    def __getitem__(self, idx):
        playlist_id, track_id = self.edges[idx]
        playlist_index = self.playlist_id_to_index[playlist_id]
        track_index = self.track_id_to_index[track_id]
        return playlist_index, track_index

train_dataset = MusicDataset(train_edges, playlist_id_to_index, track_id_to_index)
val_dataset = MusicDataset(val_edges, playlist_id_to_index, track_id_to_index)

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

loss_function = nn.BCEWithLogitsLoss()  # Binary cross-entropy loss with logits for link prediction
#optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Adam optimizer with a learning rate of 'lr'
optimizer = torch.optim.SGD(model.parameters(), lr=0.0002, momentum=0.90)
num_epochs = 50

# Convert adjacency matrices to PyTorch tensors
shared_album_adj_tensor = torch.FloatTensor(shared_album_adj.toarray())
shared_artist_adj_tensor = torch.FloatTensor(shared_artist_adj.toarray())
shared_genre_adj_tensor = torch.FloatTensor(shared_genre_adj.toarray())
cosine_similarity_adj_tensor = torch.FloatTensor(cosine_similarity_adj.toarray())

# Function to handle tensor conversion for either array or string values
def process_feature(feature):
    if isinstance(feature, (list, tuple, np.ndarray)):
        return np.array(feature, dtype=np.float32)
    else:
        return np.float32(feature)

# Convert track features to PyTorch tensor
track_features = torch.FloatTensor([[process_feature(track_dict[track_id][key]) for key in numeric_keys] for track_id in unique_track_ids])

# Set device to use
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move model and data to the selected device
model = model.to(device)
shared_album_adj_tensor = shared_album_adj_tensor.to(device)
shared_artist_adj_tensor = shared_artist_adj_tensor.to(device)
shared_genre_adj_tensor = shared_genre_adj_tensor.to(device)
cosine_similarity_adj_tensor = cosine_similarity_adj_tensor.to(device)
track_features = track_features.to(device)

# Train KGAT Model

vanishing gradient problem in training this model.

Weights are initialized, but after passing each batch through the training loop the weights are equal to zero for the input layer, hidden layer, and output layers.

In [6]:
# Training loop
for epoch in tqdm(range(num_epochs)):
    model.train()
    total_loss = 0
    
    count = 0
    
    for batch_idx, (playlist_indices, track_indices) in enumerate(train_loader):
        playlist_indices = playlist_indices.to(device)
        track_indices = track_indices.to(device)

        optimizer.zero_grad()
        
        # Obtain embeddings from the KGAT model
        track_embeddings_shared_album = model(shared_album_adj_tensor, track_features)
        track_embeddings_shared_artist = model(shared_artist_adj_tensor, track_features)
        track_embeddings_shared_genre = model(shared_genre_adj_tensor, track_features)
        track_embeddings_cosine_similarity = model(cosine_similarity_adj_tensor, track_features)

        # Combine the embeddings from different relationships
        track_embeddings = track_embeddings_shared_album + track_embeddings_shared_artist + track_embeddings_shared_genre + track_embeddings_cosine_similarity

        # Compute the predicted probability of a track belonging to a playlist
        predicted_probs = torch.sigmoid(torch.sum(track_embeddings[track_indices] * track_embeddings[playlist_indices], dim=1))

        # Compute the ground truth for the current batch
        ground_truth = torch.tensor([1.0] * len(track_indices), device=device)

        # Compute the loss
        loss = loss_function(predicted_probs, ground_truth)

        # Perform backpropagation and optimization
        loss.backward()
        
        if ((count % 100) == 0) or ((count % 100) == 1) or ((count % 100) == 2):
            for name, param in model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    print(f'Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Gradient of {name}: {param.grad.mean()}')
        count += 1
        
        optimizer.step()

        total_loss += loss.item()

    # Compute the average loss for this epoch
    avg_loss = total_loss / len(train_loader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}')

    # Evaluate the model on the validation set
    model.eval()
    val_loss = 0

    with torch.no_grad():
        for playlist_indices, track_indices in val_loader:
            playlist_indices = playlist_indices.to(device)
            track_indices = track_indices.to(device)

            # Obtain embeddings from the KGAT model
            track_embeddings_shared_album = model(shared_album_adj_tensor, track_features)
            track_embeddings_shared_artist = model(shared_artist_adj_tensor, track_features)
            track_embeddings_shared_genre = model(shared_genre_adj_tensor, track_features)
            track_embeddings_cosine_similarity = model(cosine_similarity_adj_tensor, track_features)

            # Combine the embeddings from different relationships
            track_embeddings = track_embeddings_shared_album + track_embeddings_shared_artist + track_embeddings_shared_genre + track_embeddings_cosine_similarity

            # Compute the predicted probability of a track belonging to a playlist
            predicted_probs = torch.sigmoid(torch.sum(track_embeddings[track_indices] * track_embeddings[playlist_indices], dim=1))

            # Compute the ground truth for the current batch
            ground_truth = torch.tensor([1.0] * len(track_indices), device=device)

            # Compute the loss
            loss = loss_function(predicted_probs, ground_truth)

            val_loss += loss.item()

    # Compute the average validation loss for this epoch
    avg_val_loss = val_loss / len(val_loader)
    print(f'Validation Loss: {avg_val_loss:.4f}')


  0%|                                                    | 0/50 [00:00<?, ?it/s]

Epoch [1/50], Batch [1/184], Gradient of input_bn.weight: -3.808236215263605e-08
Epoch [1/50], Batch [1/184], Gradient of input_bn.bias: -0.000841515779029578
Epoch [1/50], Batch [1/184], Gradient of hidden_bns.0.weight: 0.00025379256112501025
Epoch [1/50], Batch [1/184], Gradient of hidden_bns.0.bias: 0.0005878241499885917
Epoch [1/50], Batch [1/184], Gradient of input_layer.weight: 8.912380508263595e-06
Epoch [1/50], Batch [1/184], Gradient of hidden_layers.0.weight: -0.0007725466275587678
Epoch [1/50], Batch [1/184], Gradient of output_layer.weight: 0.0027055959217250347
Epoch [1/50], Batch [2/184], Gradient of input_bn.weight: -1.4295801520347595e-07
Epoch [1/50], Batch [2/184], Gradient of input_bn.bias: -0.00036191369872540236
Epoch [1/50], Batch [2/184], Gradient of hidden_bns.0.weight: 0.0002734675072133541
Epoch [1/50], Batch [2/184], Gradient of hidden_bns.0.bias: 0.0007899461779743433
Epoch [1/50], Batch [2/184], Gradient of input_layer.weight: -1.3507545190805104e-05
Epoch 

  2%|▊                                        | 1/50 [09:24<7:41:05, 564.59s/it]

Validation Loss: 0.3339
Epoch [2/50], Batch [1/184], Gradient of input_bn.weight: 1.3348221727937926e-06
Epoch [2/50], Batch [1/184], Gradient of input_bn.bias: -1.364789204671979e-05
Epoch [2/50], Batch [1/184], Gradient of hidden_bns.0.weight: 4.159519448876381e-05
Epoch [2/50], Batch [1/184], Gradient of hidden_bns.0.bias: -0.00016994297038763762
Epoch [2/50], Batch [1/184], Gradient of input_layer.weight: 4.793313337358995e-07
Epoch [2/50], Batch [1/184], Gradient of hidden_layers.0.weight: -3.384886804269627e-05
Epoch [2/50], Batch [1/184], Gradient of output_layer.weight: 0.0012177529279142618
Epoch [2/50], Batch [2/184], Gradient of input_bn.weight: 7.86027612775797e-07
Epoch [2/50], Batch [2/184], Gradient of input_bn.bias: -5.620004958473146e-05
Epoch [2/50], Batch [2/184], Gradient of hidden_bns.0.weight: 3.4782078728312626e-05
Epoch [2/50], Batch [2/184], Gradient of hidden_bns.0.bias: 0.00019568981952033937
Epoch [2/50], Batch [2/184], Gradient of input_layer.weight: -1.852

  4%|█▋                                       | 2/50 [18:47<7:30:41, 563.37s/it]

Validation Loss: 0.3324
Epoch [3/50], Batch [1/184], Gradient of input_bn.weight: 9.706027412903495e-07
Epoch [3/50], Batch [1/184], Gradient of input_bn.bias: -4.602269473252818e-05
Epoch [3/50], Batch [1/184], Gradient of hidden_bns.0.weight: 2.1219471818767488e-05
Epoch [3/50], Batch [1/184], Gradient of hidden_bns.0.bias: 0.00011594851093832403
Epoch [3/50], Batch [1/184], Gradient of input_layer.weight: 5.5219800287886756e-08
Epoch [3/50], Batch [1/184], Gradient of hidden_layers.0.weight: -3.638220005086623e-05
Epoch [3/50], Batch [1/184], Gradient of output_layer.weight: 0.0007321105804294348
Epoch [3/50], Batch [2/184], Gradient of input_bn.weight: 7.354992703767493e-07
Epoch [3/50], Batch [2/184], Gradient of input_bn.bias: -3.553083661245182e-05
Epoch [3/50], Batch [2/184], Gradient of hidden_bns.0.weight: 3.1379819120047614e-05
Epoch [3/50], Batch [2/184], Gradient of hidden_bns.0.bias: 0.00023245456395670772
Epoch [3/50], Batch [2/184], Gradient of input_layer.weight: 1.247

  6%|██▍                                      | 3/50 [28:10<7:21:26, 563.55s/it]

Validation Loss: 0.3314
Epoch [4/50], Batch [1/184], Gradient of input_bn.weight: 4.712364898296073e-07
Epoch [4/50], Batch [1/184], Gradient of input_bn.bias: -5.8763248489412945e-06
Epoch [4/50], Batch [1/184], Gradient of hidden_bns.0.weight: 5.682967639586423e-06
Epoch [4/50], Batch [1/184], Gradient of hidden_bns.0.bias: 2.104491795762442e-05
Epoch [4/50], Batch [1/184], Gradient of input_layer.weight: -4.1117484528285786e-08
Epoch [4/50], Batch [1/184], Gradient of hidden_layers.0.weight: -0.00011322428326820955
Epoch [4/50], Batch [1/184], Gradient of output_layer.weight: 0.00018000754062086344
Epoch [4/50], Batch [2/184], Gradient of input_bn.weight: 6.750451575499028e-07
Epoch [4/50], Batch [2/184], Gradient of input_bn.bias: -8.80432162375655e-06
Epoch [4/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.944655923580285e-05
Epoch [4/50], Batch [2/184], Gradient of hidden_bns.0.bias: 0.00015291596355382353
Epoch [4/50], Batch [2/184], Gradient of input_layer.weight: 1.827

  8%|███▎                                     | 4/50 [37:34<7:12:08, 563.66s/it]

Validation Loss: 0.3308
Epoch [5/50], Batch [1/184], Gradient of input_bn.weight: 5.742112989537418e-07
Epoch [5/50], Batch [1/184], Gradient of input_bn.bias: 1.4839635696262121e-05
Epoch [5/50], Batch [1/184], Gradient of hidden_bns.0.weight: -3.905783160007559e-06
Epoch [5/50], Batch [1/184], Gradient of hidden_bns.0.bias: -1.4182267477735877e-05
Epoch [5/50], Batch [1/184], Gradient of input_layer.weight: 6.927241003040763e-08
Epoch [5/50], Batch [1/184], Gradient of hidden_layers.0.weight: -0.00012684731336776167
Epoch [5/50], Batch [1/184], Gradient of output_layer.weight: 0.00016547575069125742
Epoch [5/50], Batch [2/184], Gradient of input_bn.weight: 6.10751158092171e-07
Epoch [5/50], Batch [2/184], Gradient of input_bn.bias: 2.978740303660743e-05
Epoch [5/50], Batch [2/184], Gradient of hidden_bns.0.weight: -5.093953404866625e-06
Epoch [5/50], Batch [2/184], Gradient of hidden_bns.0.bias: -0.00023410373250953853
Epoch [5/50], Batch [2/184], Gradient of input_layer.weight: -2.2

 10%|████                                     | 5/50 [46:58<7:02:42, 563.61s/it]

Validation Loss: 0.3306
Epoch [6/50], Batch [1/184], Gradient of input_bn.weight: 4.995090421289206e-07
Epoch [6/50], Batch [1/184], Gradient of input_bn.bias: 2.9018956411164254e-06
Epoch [6/50], Batch [1/184], Gradient of hidden_bns.0.weight: 2.035294528468512e-07
Epoch [6/50], Batch [1/184], Gradient of hidden_bns.0.bias: -3.0357077775988728e-06
Epoch [6/50], Batch [1/184], Gradient of input_layer.weight: -5.36750803803443e-07
Epoch [6/50], Batch [1/184], Gradient of hidden_layers.0.weight: -0.0002927831083070487
Epoch [6/50], Batch [1/184], Gradient of output_layer.weight: 0.00016322790179401636
Epoch [6/50], Batch [2/184], Gradient of input_bn.weight: 4.3645741243381053e-07
Epoch [6/50], Batch [2/184], Gradient of input_bn.bias: 1.9440822143224068e-05
Epoch [6/50], Batch [2/184], Gradient of hidden_bns.0.weight: 4.5044548642181326e-06
Epoch [6/50], Batch [2/184], Gradient of hidden_bns.0.bias: 6.98606891091913e-06
Epoch [6/50], Batch [2/184], Gradient of input_layer.weight: 5.8448

 12%|████▉                                    | 6/50 [56:21<6:53:20, 563.65s/it]

Validation Loss: 0.3304
Epoch [7/50], Batch [1/184], Gradient of input_bn.weight: 2.633037183841225e-07
Epoch [7/50], Batch [1/184], Gradient of input_bn.bias: 2.1415569790406153e-05
Epoch [7/50], Batch [1/184], Gradient of hidden_bns.0.weight: 1.4121360436547548e-07
Epoch [7/50], Batch [1/184], Gradient of hidden_bns.0.bias: 5.383425741456449e-05
Epoch [7/50], Batch [1/184], Gradient of input_layer.weight: 5.565748395497394e-08
Epoch [7/50], Batch [1/184], Gradient of hidden_layers.0.weight: -2.5462755729677156e-05
Epoch [7/50], Batch [1/184], Gradient of output_layer.weight: 0.00021921374718658626
Epoch [7/50], Batch [2/184], Gradient of input_bn.weight: 2.534998202463612e-07
Epoch [7/50], Batch [2/184], Gradient of input_bn.bias: 4.8592992243357e-06
Epoch [7/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.2151454029663e-06
Epoch [7/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.770736707840115e-05
Epoch [7/50], Batch [2/184], Gradient of input_layer.weight: 3.8307966576

 14%|█████▍                                 | 7/50 [1:05:44<6:43:46, 563.42s/it]

Validation Loss: 0.3303
Epoch [8/50], Batch [1/184], Gradient of input_bn.weight: 3.5825905797537416e-07
Epoch [8/50], Batch [1/184], Gradient of input_bn.bias: 2.5694658688735217e-05
Epoch [8/50], Batch [1/184], Gradient of hidden_bns.0.weight: -6.993291208345909e-06
Epoch [8/50], Batch [1/184], Gradient of hidden_bns.0.bias: 5.323992809280753e-05
Epoch [8/50], Batch [1/184], Gradient of input_layer.weight: -9.338127426872234e-08
Epoch [8/50], Batch [1/184], Gradient of hidden_layers.0.weight: 7.054176967358217e-05
Epoch [8/50], Batch [1/184], Gradient of output_layer.weight: 0.00024572928668931127
Epoch [8/50], Batch [2/184], Gradient of input_bn.weight: 1.4151737559586763e-09
Epoch [8/50], Batch [2/184], Gradient of input_bn.bias: 7.94402149040252e-06
Epoch [8/50], Batch [2/184], Gradient of hidden_bns.0.weight: -6.380903869285248e-06
Epoch [8/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.365713615086861e-05
Epoch [8/50], Batch [2/184], Gradient of input_layer.weight: -6.4770

 16%|██████▏                                | 8/50 [1:15:08<6:34:26, 563.49s/it]

Validation Loss: 0.3302
Epoch [9/50], Batch [1/184], Gradient of input_bn.weight: -2.117303665727377e-09
Epoch [9/50], Batch [1/184], Gradient of input_bn.bias: 1.5258105122484267e-05
Epoch [9/50], Batch [1/184], Gradient of hidden_bns.0.weight: -9.598570613889024e-07
Epoch [9/50], Batch [1/184], Gradient of hidden_bns.0.bias: 4.21161312260665e-05
Epoch [9/50], Batch [1/184], Gradient of input_layer.weight: -2.4802136522339424e-07
Epoch [9/50], Batch [1/184], Gradient of hidden_layers.0.weight: 0.00012793665518984199
Epoch [9/50], Batch [1/184], Gradient of output_layer.weight: 0.00019976455951109529
Epoch [9/50], Batch [2/184], Gradient of input_bn.weight: 6.78028300171718e-08
Epoch [9/50], Batch [2/184], Gradient of input_bn.bias: 1.0252535503241234e-05
Epoch [9/50], Batch [2/184], Gradient of hidden_bns.0.weight: 3.3480200727353804e-06
Epoch [9/50], Batch [2/184], Gradient of hidden_bns.0.bias: 5.013913687434979e-05
Epoch [9/50], Batch [2/184], Gradient of input_layer.weight: 3.3181

 18%|███████                                | 9/50 [1:24:31<6:24:55, 563.31s/it]

Validation Loss: 0.3301
Epoch [10/50], Batch [1/184], Gradient of input_bn.weight: -1.935950422193855e-08
Epoch [10/50], Batch [1/184], Gradient of input_bn.bias: 8.752641406317707e-06
Epoch [10/50], Batch [1/184], Gradient of hidden_bns.0.weight: 3.1177394248516066e-06
Epoch [10/50], Batch [1/184], Gradient of hidden_bns.0.bias: 2.008497904171236e-05
Epoch [10/50], Batch [1/184], Gradient of input_layer.weight: -1.5065877434494723e-08
Epoch [10/50], Batch [1/184], Gradient of hidden_layers.0.weight: 6.107502849772573e-05
Epoch [10/50], Batch [1/184], Gradient of output_layer.weight: 5.9377020079409704e-05
Epoch [10/50], Batch [2/184], Gradient of input_bn.weight: 3.4781805879902095e-07
Epoch [10/50], Batch [2/184], Gradient of input_bn.bias: 1.4870195627736393e-05
Epoch [10/50], Batch [2/184], Gradient of hidden_bns.0.weight: 4.081247880094452e-06
Epoch [10/50], Batch [2/184], Gradient of hidden_bns.0.bias: 3.90399873140268e-06
Epoch [10/50], Batch [2/184], Gradient of input_layer.wei

 20%|███████▌                              | 10/50 [1:33:54<6:15:32, 563.32s/it]

Validation Loss: 0.3300
Epoch [11/50], Batch [1/184], Gradient of input_bn.weight: -4.986142130292137e-07
Epoch [11/50], Batch [1/184], Gradient of input_bn.bias: -5.171541488380171e-07
Epoch [11/50], Batch [1/184], Gradient of hidden_bns.0.weight: -6.353145181492437e-06
Epoch [11/50], Batch [1/184], Gradient of hidden_bns.0.bias: 4.061673462274484e-05
Epoch [11/50], Batch [1/184], Gradient of input_layer.weight: -1.5084635833773063e-06
Epoch [11/50], Batch [1/184], Gradient of hidden_layers.0.weight: 1.008653543976834e-05
Epoch [11/50], Batch [1/184], Gradient of output_layer.weight: 0.00019026042718905956
Epoch [11/50], Batch [2/184], Gradient of input_bn.weight: 5.041874828748405e-08
Epoch [11/50], Batch [2/184], Gradient of input_bn.bias: 1.3271036550577264e-05
Epoch [11/50], Batch [2/184], Gradient of hidden_bns.0.weight: 5.845718078489881e-06
Epoch [11/50], Batch [2/184], Gradient of hidden_bns.0.bias: 5.540752681554295e-05
Epoch [11/50], Batch [2/184], Gradient of input_layer.we

 22%|████████▎                             | 11/50 [1:43:18<6:06:09, 563.32s/it]

Validation Loss: 0.3300
Epoch [12/50], Batch [1/184], Gradient of input_bn.weight: 4.093410552741261e-07
Epoch [12/50], Batch [1/184], Gradient of input_bn.bias: 9.36205469770357e-06
Epoch [12/50], Batch [1/184], Gradient of hidden_bns.0.weight: 1.7085994841181673e-07
Epoch [12/50], Batch [1/184], Gradient of hidden_bns.0.bias: -1.0985291737597436e-05
Epoch [12/50], Batch [1/184], Gradient of input_layer.weight: 3.533287795676188e-08
Epoch [12/50], Batch [1/184], Gradient of hidden_layers.0.weight: -9.64964692684589e-06
Epoch [12/50], Batch [1/184], Gradient of output_layer.weight: 0.00013395554560702294
Epoch [12/50], Batch [2/184], Gradient of input_bn.weight: 9.136274456977844e-07
Epoch [12/50], Batch [2/184], Gradient of input_bn.bias: 8.504845027346164e-06
Epoch [12/50], Batch [2/184], Gradient of hidden_bns.0.weight: 2.0749612303916365e-06
Epoch [12/50], Batch [2/184], Gradient of hidden_bns.0.bias: -0.00022173258184920996
Epoch [12/50], Batch [2/184], Gradient of input_layer.wei

 24%|█████████                             | 12/50 [1:52:40<5:56:37, 563.09s/it]

Validation Loss: 0.3299
Epoch [13/50], Batch [1/184], Gradient of input_bn.weight: 2.906399458879605e-07
Epoch [13/50], Batch [1/184], Gradient of input_bn.bias: 1.0701709470595233e-05
Epoch [13/50], Batch [1/184], Gradient of hidden_bns.0.weight: 2.5283879949711263e-06
Epoch [13/50], Batch [1/184], Gradient of hidden_bns.0.bias: -7.990040467120707e-05
Epoch [13/50], Batch [1/184], Gradient of input_layer.weight: 5.047905915489537e-08
Epoch [13/50], Batch [1/184], Gradient of hidden_layers.0.weight: -7.73692227085121e-05
Epoch [13/50], Batch [1/184], Gradient of output_layer.weight: 6.557972665177658e-05
Epoch [13/50], Batch [2/184], Gradient of input_bn.weight: 1.432490535080433e-07
Epoch [13/50], Batch [2/184], Gradient of input_bn.bias: 6.580229637620505e-06
Epoch [13/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.4891797945892904e-06
Epoch [13/50], Batch [2/184], Gradient of hidden_bns.0.bias: 3.1558760383632034e-05
Epoch [13/50], Batch [2/184], Gradient of input_layer.weig

 26%|█████████▉                            | 13/50 [2:02:04<5:47:26, 563.41s/it]

Validation Loss: 0.3299
Epoch [14/50], Batch [1/184], Gradient of input_bn.weight: 5.605306796496734e-07
Epoch [14/50], Batch [1/184], Gradient of input_bn.bias: 1.1135524800920393e-05
Epoch [14/50], Batch [1/184], Gradient of hidden_bns.0.weight: 3.108847522526048e-06
Epoch [14/50], Batch [1/184], Gradient of hidden_bns.0.bias: 4.9143734941026196e-05
Epoch [14/50], Batch [1/184], Gradient of input_layer.weight: -1.5226584437755264e-08
Epoch [14/50], Batch [1/184], Gradient of hidden_layers.0.weight: -5.0072329031536356e-05
Epoch [14/50], Batch [1/184], Gradient of output_layer.weight: 0.00021510380611289293
Epoch [14/50], Batch [2/184], Gradient of input_bn.weight: 4.965295374859124e-07
Epoch [14/50], Batch [2/184], Gradient of input_bn.bias: -3.34552169078961e-07
Epoch [14/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.392164449498523e-06
Epoch [14/50], Batch [2/184], Gradient of hidden_bns.0.bias: 4.5780740038026124e-05
Epoch [14/50], Batch [2/184], Gradient of input_layer.w

 28%|██████████▋                           | 14/50 [2:11:28<5:38:00, 563.35s/it]

Validation Loss: 0.3298
Epoch [15/50], Batch [1/184], Gradient of input_bn.weight: 1.9514118321239948e-07
Epoch [15/50], Batch [1/184], Gradient of input_bn.bias: 1.3385997590376064e-05
Epoch [15/50], Batch [1/184], Gradient of hidden_bns.0.weight: 2.9490784072550014e-06
Epoch [15/50], Batch [1/184], Gradient of hidden_bns.0.bias: 5.794630487798713e-05
Epoch [15/50], Batch [1/184], Gradient of input_layer.weight: -1.660568038630572e-08
Epoch [15/50], Batch [1/184], Gradient of hidden_layers.0.weight: 2.3540853362646885e-05
Epoch [15/50], Batch [1/184], Gradient of output_layer.weight: 0.00018321526295039803
Epoch [15/50], Batch [2/184], Gradient of input_bn.weight: 2.745837264228612e-07
Epoch [15/50], Batch [2/184], Gradient of input_bn.bias: 7.959901267895475e-06
Epoch [15/50], Batch [2/184], Gradient of hidden_bns.0.weight: -3.3363448892487213e-07
Epoch [15/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.2611606254940853e-05
Epoch [15/50], Batch [2/184], Gradient of input_layer.

 30%|███████████▍                          | 15/50 [2:20:51<5:28:35, 563.29s/it]

Validation Loss: 0.3298
Epoch [16/50], Batch [1/184], Gradient of input_bn.weight: 1.8689024727791548e-07
Epoch [16/50], Batch [1/184], Gradient of input_bn.bias: 4.771240128320642e-06
Epoch [16/50], Batch [1/184], Gradient of hidden_bns.0.weight: -2.028746621363098e-06
Epoch [16/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.3767948985332623e-05
Epoch [16/50], Batch [1/184], Gradient of input_layer.weight: -2.4059174208446166e-08
Epoch [16/50], Batch [1/184], Gradient of hidden_layers.0.weight: 7.368755905190483e-05
Epoch [16/50], Batch [1/184], Gradient of output_layer.weight: 7.148238364607096e-05
Epoch [16/50], Batch [2/184], Gradient of input_bn.weight: 9.347422746941447e-08
Epoch [16/50], Batch [2/184], Gradient of input_bn.bias: -1.0037797437689733e-06
Epoch [16/50], Batch [2/184], Gradient of hidden_bns.0.weight: 6.607126579183387e-07
Epoch [16/50], Batch [2/184], Gradient of hidden_bns.0.bias: 9.144989235210232e-06
Epoch [16/50], Batch [2/184], Gradient of input_layer.we

 32%|████████████▏                         | 16/50 [2:30:14<5:19:17, 563.45s/it]

Validation Loss: 0.3298
Epoch [17/50], Batch [1/184], Gradient of input_bn.weight: 3.467721398919821e-07
Epoch [17/50], Batch [1/184], Gradient of input_bn.bias: 6.983500497881323e-06
Epoch [17/50], Batch [1/184], Gradient of hidden_bns.0.weight: -1.972919562831521e-06
Epoch [17/50], Batch [1/184], Gradient of hidden_bns.0.bias: 6.311391189228743e-06
Epoch [17/50], Batch [1/184], Gradient of input_layer.weight: -1.0659809390745068e-07
Epoch [17/50], Batch [1/184], Gradient of hidden_layers.0.weight: 6.113076960900798e-05
Epoch [17/50], Batch [1/184], Gradient of output_layer.weight: 0.00012269265425857157
Epoch [17/50], Batch [2/184], Gradient of input_bn.weight: -1.0262738214805722e-07
Epoch [17/50], Batch [2/184], Gradient of input_bn.bias: 7.467942850780673e-06
Epoch [17/50], Batch [2/184], Gradient of hidden_bns.0.weight: -1.8523064682085533e-06
Epoch [17/50], Batch [2/184], Gradient of hidden_bns.0.bias: -2.5294972147094086e-05
Epoch [17/50], Batch [2/184], Gradient of input_layer

 34%|████████████▉                         | 17/50 [2:39:38<5:09:50, 563.36s/it]

Validation Loss: 0.3297
Epoch [18/50], Batch [1/184], Gradient of input_bn.weight: 1.879748197097797e-07
Epoch [18/50], Batch [1/184], Gradient of input_bn.bias: 1.585221070854459e-06
Epoch [18/50], Batch [1/184], Gradient of hidden_bns.0.weight: -2.7316743853589287e-06
Epoch [18/50], Batch [1/184], Gradient of hidden_bns.0.bias: 7.072448170220014e-06
Epoch [18/50], Batch [1/184], Gradient of input_layer.weight: 6.950268272021276e-08
Epoch [18/50], Batch [1/184], Gradient of hidden_layers.0.weight: 1.8488412024453282e-05
Epoch [18/50], Batch [1/184], Gradient of output_layer.weight: 6.985709478612989e-05
Epoch [18/50], Batch [2/184], Gradient of input_bn.weight: 5.23546987096779e-07
Epoch [18/50], Batch [2/184], Gradient of input_bn.bias: 2.8899066819576547e-05
Epoch [18/50], Batch [2/184], Gradient of hidden_bns.0.weight: -2.206837962148711e-06
Epoch [18/50], Batch [2/184], Gradient of hidden_bns.0.bias: 3.7116624298505485e-06
Epoch [18/50], Batch [2/184], Gradient of input_layer.weig

 36%|█████████████▋                        | 18/50 [2:49:01<5:00:25, 563.31s/it]

Validation Loss: 0.3297
Epoch [19/50], Batch [1/184], Gradient of input_bn.weight: 2.5410645321244374e-07
Epoch [19/50], Batch [1/184], Gradient of input_bn.bias: 6.735758688591886e-06
Epoch [19/50], Batch [1/184], Gradient of hidden_bns.0.weight: -4.352547875896562e-06
Epoch [19/50], Batch [1/184], Gradient of hidden_bns.0.bias: -2.2760204956284724e-05
Epoch [19/50], Batch [1/184], Gradient of input_layer.weight: -3.731013009655726e-08
Epoch [19/50], Batch [1/184], Gradient of hidden_layers.0.weight: -3.2826246751938015e-05
Epoch [19/50], Batch [1/184], Gradient of output_layer.weight: 7.938381895655766e-05
Epoch [19/50], Batch [2/184], Gradient of input_bn.weight: 1.5938894648570567e-07
Epoch [19/50], Batch [2/184], Gradient of input_bn.bias: 7.05244474374922e-07
Epoch [19/50], Batch [2/184], Gradient of hidden_bns.0.weight: -1.3630588000523858e-06
Epoch [19/50], Batch [2/184], Gradient of hidden_bns.0.bias: 8.296373380289879e-06
Epoch [19/50], Batch [2/184], Gradient of input_layer.

 38%|██████████████▍                       | 19/50 [2:58:25<4:51:09, 563.52s/it]

Validation Loss: 0.3297
Epoch [20/50], Batch [1/184], Gradient of input_bn.weight: 7.309472493943758e-08
Epoch [20/50], Batch [1/184], Gradient of input_bn.bias: 3.1860358831181657e-06
Epoch [20/50], Batch [1/184], Gradient of hidden_bns.0.weight: 1.1410738807171583e-06
Epoch [20/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.3237175153335556e-05
Epoch [20/50], Batch [1/184], Gradient of input_layer.weight: -1.6094436006142132e-08
Epoch [20/50], Batch [1/184], Gradient of hidden_layers.0.weight: -1.730658914311789e-05
Epoch [20/50], Batch [1/184], Gradient of output_layer.weight: 5.497224992723204e-05
Epoch [20/50], Batch [2/184], Gradient of input_bn.weight: -3.152963472530246e-07
Epoch [20/50], Batch [2/184], Gradient of input_bn.bias: 3.7145382520975545e-06
Epoch [20/50], Batch [2/184], Gradient of hidden_bns.0.weight: -1.7571483112988062e-06
Epoch [20/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.511331538495142e-05
Epoch [20/50], Batch [2/184], Gradient of input_layer

 40%|███████████████▏                      | 20/50 [3:07:48<4:41:43, 563.46s/it]

Validation Loss: 0.3297
Epoch [21/50], Batch [1/184], Gradient of input_bn.weight: 2.0041989046148956e-07
Epoch [21/50], Batch [1/184], Gradient of input_bn.bias: 4.904994511889527e-06
Epoch [21/50], Batch [1/184], Gradient of hidden_bns.0.weight: 1.321946911048144e-06
Epoch [21/50], Batch [1/184], Gradient of hidden_bns.0.bias: 4.295600228942931e-05
Epoch [21/50], Batch [1/184], Gradient of input_layer.weight: 9.117091082089246e-09
Epoch [21/50], Batch [1/184], Gradient of hidden_layers.0.weight: 4.653566065826453e-05
Epoch [21/50], Batch [1/184], Gradient of output_layer.weight: 0.00013104041863698512
Epoch [21/50], Batch [2/184], Gradient of input_bn.weight: 1.570333552081138e-07
Epoch [21/50], Batch [2/184], Gradient of input_bn.bias: 1.9884988432750106e-05
Epoch [21/50], Batch [2/184], Gradient of hidden_bns.0.weight: 2.0570741980918683e-06
Epoch [21/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.111992216669023e-05
Epoch [21/50], Batch [2/184], Gradient of input_layer.weigh

 42%|███████████████▉                      | 21/50 [3:17:12<4:32:23, 563.58s/it]

Validation Loss: 0.3296
Epoch [22/50], Batch [1/184], Gradient of input_bn.weight: 3.4699496609391645e-08
Epoch [22/50], Batch [1/184], Gradient of input_bn.bias: 7.237132422233117e-07
Epoch [22/50], Batch [1/184], Gradient of hidden_bns.0.weight: 1.4879542504786514e-06
Epoch [22/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.1723149327735882e-05
Epoch [22/50], Batch [1/184], Gradient of input_layer.weight: -2.9845885762824764e-08
Epoch [22/50], Batch [1/184], Gradient of hidden_layers.0.weight: 1.469076505600242e-05
Epoch [22/50], Batch [1/184], Gradient of output_layer.weight: 4.0547984099248424e-05
Epoch [22/50], Batch [2/184], Gradient of input_bn.weight: 6.334539648378268e-08
Epoch [22/50], Batch [2/184], Gradient of input_bn.bias: 4.1435359889874235e-06
Epoch [22/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.807247826945968e-07
Epoch [22/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.918058842420578e-05
Epoch [22/50], Batch [2/184], Gradient of input_layer.we

 44%|████████████████▋                     | 22/50 [3:26:36<4:23:00, 563.59s/it]

Validation Loss: 0.3296
Epoch [23/50], Batch [1/184], Gradient of input_bn.weight: 3.767149792111013e-08
Epoch [23/50], Batch [1/184], Gradient of input_bn.bias: 4.620546860678587e-06
Epoch [23/50], Batch [1/184], Gradient of hidden_bns.0.weight: -1.8998207451659255e-07
Epoch [23/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.1746020391001366e-05
Epoch [23/50], Batch [1/184], Gradient of input_layer.weight: -2.143772981355596e-08
Epoch [23/50], Batch [1/184], Gradient of hidden_layers.0.weight: -1.1329760127409827e-05
Epoch [23/50], Batch [1/184], Gradient of output_layer.weight: 4.970489317202009e-05
Epoch [23/50], Batch [2/184], Gradient of input_bn.weight: 2.009392119362019e-07
Epoch [23/50], Batch [2/184], Gradient of input_bn.bias: 4.90136926600826e-06
Epoch [23/50], Batch [2/184], Gradient of hidden_bns.0.weight: -1.3606327229354065e-06
Epoch [23/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.632029307074845e-05
Epoch [23/50], Batch [2/184], Gradient of input_layer.we

 46%|█████████████████▍                    | 23/50 [3:35:59<4:13:33, 563.48s/it]

Validation Loss: 0.3296
Epoch [24/50], Batch [1/184], Gradient of input_bn.weight: 2.5767667466425337e-07
Epoch [24/50], Batch [1/184], Gradient of input_bn.bias: 1.272796725970693e-05
Epoch [24/50], Batch [1/184], Gradient of hidden_bns.0.weight: 2.1976591142447433e-06
Epoch [24/50], Batch [1/184], Gradient of hidden_bns.0.bias: -2.097174001391977e-05
Epoch [24/50], Batch [1/184], Gradient of input_layer.weight: -5.150340953719024e-08
Epoch [24/50], Batch [1/184], Gradient of hidden_layers.0.weight: -8.56227234180551e-06
Epoch [24/50], Batch [1/184], Gradient of output_layer.weight: 6.629393465118483e-05
Epoch [24/50], Batch [2/184], Gradient of input_bn.weight: 1.94378117157612e-07
Epoch [24/50], Batch [2/184], Gradient of input_bn.bias: 8.182996680261567e-06
Epoch [24/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.685410097707063e-06
Epoch [24/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.6391715689678676e-05
Epoch [24/50], Batch [2/184], Gradient of input_layer.weigh

 48%|██████████████████▏                   | 24/50 [3:45:22<4:04:11, 563.52s/it]

Validation Loss: 0.3296
Epoch [25/50], Batch [1/184], Gradient of input_bn.weight: 2.8074737201677635e-07
Epoch [25/50], Batch [1/184], Gradient of input_bn.bias: 1.3536549886339344e-05
Epoch [25/50], Batch [1/184], Gradient of hidden_bns.0.weight: 4.8002916628320236e-06
Epoch [25/50], Batch [1/184], Gradient of hidden_bns.0.bias: -4.8763773520477116e-05
Epoch [25/50], Batch [1/184], Gradient of input_layer.weight: 2.447138847117003e-08
Epoch [25/50], Batch [1/184], Gradient of hidden_layers.0.weight: 5.585956751019694e-05
Epoch [25/50], Batch [1/184], Gradient of output_layer.weight: 9.002200386021286e-05
Epoch [25/50], Batch [2/184], Gradient of input_bn.weight: 4.2825649870792404e-07
Epoch [25/50], Batch [2/184], Gradient of input_bn.bias: 1.6923229850362986e-05
Epoch [25/50], Batch [2/184], Gradient of hidden_bns.0.weight: -1.7133406799985096e-06
Epoch [25/50], Batch [2/184], Gradient of hidden_bns.0.bias: -4.910056304652244e-06
Epoch [25/50], Batch [2/184], Gradient of input_layer

 50%|███████████████████                   | 25/50 [3:54:46<3:54:48, 563.56s/it]

Validation Loss: 0.3296
Epoch [26/50], Batch [1/184], Gradient of input_bn.weight: 2.6584643819660414e-07
Epoch [26/50], Batch [1/184], Gradient of input_bn.bias: 1.0631611985445488e-05
Epoch [26/50], Batch [1/184], Gradient of hidden_bns.0.weight: 4.527333203441231e-06
Epoch [26/50], Batch [1/184], Gradient of hidden_bns.0.bias: 2.5710782210808247e-05
Epoch [26/50], Batch [1/184], Gradient of input_layer.weight: -8.026272979577698e-08
Epoch [26/50], Batch [1/184], Gradient of hidden_layers.0.weight: 1.2261548363312613e-05
Epoch [26/50], Batch [1/184], Gradient of output_layer.weight: 0.00012742448598146439
Epoch [26/50], Batch [2/184], Gradient of input_bn.weight: 5.62963577976916e-08
Epoch [26/50], Batch [2/184], Gradient of input_bn.bias: 6.024185950082028e-07
Epoch [26/50], Batch [2/184], Gradient of hidden_bns.0.weight: 2.1758007733296836e-06
Epoch [26/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.7777076209313236e-05
Epoch [26/50], Batch [2/184], Gradient of input_layer.we

 52%|███████████████████▊                  | 26/50 [4:04:09<3:45:21, 563.42s/it]

Validation Loss: 0.3295
Epoch [27/50], Batch [1/184], Gradient of input_bn.weight: 3.210002432751935e-07
Epoch [27/50], Batch [1/184], Gradient of input_bn.bias: 9.04152329894714e-06
Epoch [27/50], Batch [1/184], Gradient of hidden_bns.0.weight: 2.5283443392254412e-06
Epoch [27/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.0203961210208945e-05
Epoch [27/50], Batch [1/184], Gradient of input_layer.weight: -4.776862638777857e-08
Epoch [27/50], Batch [1/184], Gradient of hidden_layers.0.weight: -2.6237948986818083e-05
Epoch [27/50], Batch [1/184], Gradient of output_layer.weight: 9.304793638875708e-05
Epoch [27/50], Batch [2/184], Gradient of input_bn.weight: 3.649961399787571e-08
Epoch [27/50], Batch [2/184], Gradient of input_bn.bias: 3.328320644868654e-06
Epoch [27/50], Batch [2/184], Gradient of hidden_bns.0.weight: 2.355042397539364e-06
Epoch [27/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.3452341338270344e-05
Epoch [27/50], Batch [2/184], Gradient of input_layer.weig

 54%|████████████████████▌                 | 27/50 [4:13:33<3:35:59, 563.45s/it]

Validation Loss: 0.3295
Epoch [28/50], Batch [1/184], Gradient of input_bn.weight: 3.5845050661009736e-07
Epoch [28/50], Batch [1/184], Gradient of input_bn.bias: 1.1012662980647292e-05
Epoch [28/50], Batch [1/184], Gradient of hidden_bns.0.weight: -5.8832379181694705e-06
Epoch [28/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.8807550077326596e-05
Epoch [28/50], Batch [1/184], Gradient of input_layer.weight: -8.183002364603453e-08
Epoch [28/50], Batch [1/184], Gradient of hidden_layers.0.weight: -1.9194494598195888e-05
Epoch [28/50], Batch [1/184], Gradient of output_layer.weight: 0.00011670928506646305
Epoch [28/50], Batch [2/184], Gradient of input_bn.weight: 2.692195266718045e-07
Epoch [28/50], Batch [2/184], Gradient of input_bn.bias: 6.037171260686591e-06
Epoch [28/50], Batch [2/184], Gradient of hidden_bns.0.weight: 4.208936843497213e-07
Epoch [28/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.650740134413354e-05
Epoch [28/50], Batch [2/184], Gradient of input_layer.

 56%|█████████████████████▎                | 28/50 [4:22:56<3:26:37, 563.54s/it]

Validation Loss: 0.3295
Epoch [29/50], Batch [1/184], Gradient of input_bn.weight: 2.0245977339072851e-07
Epoch [29/50], Batch [1/184], Gradient of input_bn.bias: 1.1470768185972702e-05
Epoch [29/50], Batch [1/184], Gradient of hidden_bns.0.weight: 3.918770289601525e-06
Epoch [29/50], Batch [1/184], Gradient of hidden_bns.0.bias: -3.215148899471387e-05
Epoch [29/50], Batch [1/184], Gradient of input_layer.weight: 2.9111879573662236e-09
Epoch [29/50], Batch [1/184], Gradient of hidden_layers.0.weight: 4.035429810755886e-06
Epoch [29/50], Batch [1/184], Gradient of output_layer.weight: 7.575067138532177e-05
Epoch [29/50], Batch [2/184], Gradient of input_bn.weight: 9.157565727946348e-08
Epoch [29/50], Batch [2/184], Gradient of input_bn.bias: 5.200608029554132e-06
Epoch [29/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.955830839506234e-06
Epoch [29/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.27810705837328e-05
Epoch [29/50], Batch [2/184], Gradient of input_layer.weight

 58%|██████████████████████                | 29/50 [4:32:20<3:17:16, 563.65s/it]

Validation Loss: 0.3295
Epoch [30/50], Batch [1/184], Gradient of input_bn.weight: -1.5941623132675886e-07
Epoch [30/50], Batch [1/184], Gradient of input_bn.bias: 2.8415679480531253e-06
Epoch [30/50], Batch [1/184], Gradient of hidden_bns.0.weight: -1.34810215968173e-06
Epoch [30/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.3001495972275734e-05
Epoch [30/50], Batch [1/184], Gradient of input_layer.weight: -7.664335726076388e-07
Epoch [30/50], Batch [1/184], Gradient of hidden_layers.0.weight: -3.915307024726644e-05
Epoch [30/50], Batch [1/184], Gradient of output_layer.weight: 9.660443902248517e-05
Epoch [30/50], Batch [2/184], Gradient of input_bn.weight: 9.124346433964092e-08
Epoch [30/50], Batch [2/184], Gradient of input_bn.bias: 5.049309947935399e-06
Epoch [30/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.8426371752866544e-06
Epoch [30/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.2126923795440234e-05
Epoch [30/50], Batch [2/184], Gradient of input_layer.w

 60%|██████████████████████▊               | 30/50 [4:41:44<3:07:50, 563.53s/it]

Validation Loss: 0.3295
Epoch [31/50], Batch [1/184], Gradient of input_bn.weight: -4.960111255059019e-08
Epoch [31/50], Batch [1/184], Gradient of input_bn.bias: 2.1913747332291678e-06
Epoch [31/50], Batch [1/184], Gradient of hidden_bns.0.weight: -2.597056663944386e-06
Epoch [31/50], Batch [1/184], Gradient of hidden_bns.0.bias: -4.186911610304378e-06
Epoch [31/50], Batch [1/184], Gradient of input_layer.weight: -6.97664518156671e-07
Epoch [31/50], Batch [1/184], Gradient of hidden_layers.0.weight: -4.6753753849770874e-05
Epoch [31/50], Batch [1/184], Gradient of output_layer.weight: 7.141631067497656e-05
Epoch [31/50], Batch [2/184], Gradient of input_bn.weight: 3.336649569973815e-07
Epoch [31/50], Batch [2/184], Gradient of input_bn.bias: 6.3716506701894104e-06
Epoch [31/50], Batch [2/184], Gradient of hidden_bns.0.weight: 2.4822861632856075e-07
Epoch [31/50], Batch [2/184], Gradient of hidden_bns.0.bias: -1.2873890227638185e-05
Epoch [31/50], Batch [2/184], Gradient of input_layer

 62%|███████████████████████▌              | 31/50 [4:51:08<2:58:30, 563.70s/it]

Validation Loss: 0.3295
Epoch [32/50], Batch [1/184], Gradient of input_bn.weight: 7.099430376911187e-09
Epoch [32/50], Batch [1/184], Gradient of input_bn.bias: -1.285547114093788e-06
Epoch [32/50], Batch [1/184], Gradient of hidden_bns.0.weight: 9.513244094705442e-07
Epoch [32/50], Batch [1/184], Gradient of hidden_bns.0.bias: 8.913717465475202e-06
Epoch [32/50], Batch [1/184], Gradient of input_layer.weight: -9.048114257836914e-09
Epoch [32/50], Batch [1/184], Gradient of hidden_layers.0.weight: 1.1128757932965527e-06
Epoch [32/50], Batch [1/184], Gradient of output_layer.weight: 2.4464639864163473e-05
Epoch [32/50], Batch [2/184], Gradient of input_bn.weight: 1.1879978956130799e-07
Epoch [32/50], Batch [2/184], Gradient of input_bn.bias: 1.0979510079778265e-05
Epoch [32/50], Batch [2/184], Gradient of hidden_bns.0.weight: -3.1591639526595827e-06
Epoch [32/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.392439273535274e-05
Epoch [32/50], Batch [2/184], Gradient of input_layer.w

 64%|████████████████████████▎             | 32/50 [5:00:31<2:49:06, 563.71s/it]

Validation Loss: 0.3295
Epoch [33/50], Batch [1/184], Gradient of input_bn.weight: 2.8639078664127737e-07
Epoch [33/50], Batch [1/184], Gradient of input_bn.bias: 8.248505764640868e-06
Epoch [33/50], Batch [1/184], Gradient of hidden_bns.0.weight: 2.7102569220005535e-06
Epoch [33/50], Batch [1/184], Gradient of hidden_bns.0.bias: 3.9268736145459116e-05
Epoch [33/50], Batch [1/184], Gradient of input_layer.weight: -6.326759205421695e-08
Epoch [33/50], Batch [1/184], Gradient of hidden_layers.0.weight: 8.06960160844028e-05
Epoch [33/50], Batch [1/184], Gradient of output_layer.weight: 0.00013937440235167742
Epoch [33/50], Batch [2/184], Gradient of input_bn.weight: 5.188121576793492e-08
Epoch [33/50], Batch [2/184], Gradient of input_bn.bias: 3.981836016464513e-06
Epoch [33/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.068274286808446e-07
Epoch [33/50], Batch [2/184], Gradient of hidden_bns.0.bias: 5.704357135982718e-06
Epoch [33/50], Batch [2/184], Gradient of input_layer.weigh

 66%|█████████████████████████             | 33/50 [5:09:55<2:39:40, 563.57s/it]

Validation Loss: 0.3295
Epoch [34/50], Batch [1/184], Gradient of input_bn.weight: 9.52236405282747e-08
Epoch [34/50], Batch [1/184], Gradient of input_bn.bias: 5.804243301099632e-06
Epoch [34/50], Batch [1/184], Gradient of hidden_bns.0.weight: -7.986602668097476e-07
Epoch [34/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.3398554074228741e-05
Epoch [34/50], Batch [1/184], Gradient of input_layer.weight: 1.4460737496335696e-08
Epoch [34/50], Batch [1/184], Gradient of hidden_layers.0.weight: -1.9589224393712357e-05
Epoch [34/50], Batch [1/184], Gradient of output_layer.weight: 5.6672444770811126e-05
Epoch [34/50], Batch [2/184], Gradient of input_bn.weight: 2.4600694814580493e-09
Epoch [34/50], Batch [2/184], Gradient of input_bn.bias: -2.8513784400274744e-07
Epoch [34/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.9931007955165114e-09
Epoch [34/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.7412525039617321e-06
Epoch [34/50], Batch [2/184], Gradient of input_layer

 68%|█████████████████████████▊            | 34/50 [5:19:18<2:30:17, 563.62s/it]

Validation Loss: 0.3294
Epoch [35/50], Batch [1/184], Gradient of input_bn.weight: 2.5352528609801084e-07
Epoch [35/50], Batch [1/184], Gradient of input_bn.bias: 6.205384579516249e-06
Epoch [35/50], Batch [1/184], Gradient of hidden_bns.0.weight: -1.5663831618439872e-06
Epoch [35/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.2108435839763843e-05
Epoch [35/50], Batch [1/184], Gradient of input_layer.weight: -3.227951950179886e-08
Epoch [35/50], Batch [1/184], Gradient of hidden_layers.0.weight: 6.545776705024764e-05
Epoch [35/50], Batch [1/184], Gradient of output_layer.weight: 8.427750435657799e-05
Epoch [35/50], Batch [2/184], Gradient of input_bn.weight: 1.434868863725569e-07
Epoch [35/50], Batch [2/184], Gradient of input_bn.bias: 3.6351598282635678e-06
Epoch [35/50], Batch [2/184], Gradient of hidden_bns.0.weight: -1.953229457285488e-06
Epoch [35/50], Batch [2/184], Gradient of hidden_bns.0.bias: -1.1945786354772281e-05
Epoch [35/50], Batch [2/184], Gradient of input_layer.

 70%|██████████████████████████▌           | 35/50 [5:28:42<2:20:55, 563.72s/it]

Validation Loss: 0.3294
Epoch [36/50], Batch [1/184], Gradient of input_bn.weight: 3.6830442695645615e-08
Epoch [36/50], Batch [1/184], Gradient of input_bn.bias: 3.5068779880020884e-07
Epoch [36/50], Batch [1/184], Gradient of hidden_bns.0.weight: 7.646497124369489e-07
Epoch [36/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.2527515536930878e-05
Epoch [36/50], Batch [1/184], Gradient of input_layer.weight: -1.5959287225086882e-08
Epoch [36/50], Batch [1/184], Gradient of hidden_layers.0.weight: 4.014471869595582e-06
Epoch [36/50], Batch [1/184], Gradient of output_layer.weight: 3.8191748899407685e-05
Epoch [36/50], Batch [2/184], Gradient of input_bn.weight: 1.2250438885530457e-08
Epoch [36/50], Batch [2/184], Gradient of input_bn.bias: 1.9523272385413293e-06
Epoch [36/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.0193267598879174e-06
Epoch [36/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.020695981424069e-05
Epoch [36/50], Batch [2/184], Gradient of input_layer.

 72%|███████████████████████████▎          | 36/50 [5:38:06<2:11:31, 563.71s/it]

Validation Loss: 0.3294
Epoch [37/50], Batch [1/184], Gradient of input_bn.weight: -1.3451995073410217e-09
Epoch [37/50], Batch [1/184], Gradient of input_bn.bias: -1.0191357091571263e-08
Epoch [37/50], Batch [1/184], Gradient of hidden_bns.0.weight: -6.699501398088614e-09
Epoch [37/50], Batch [1/184], Gradient of hidden_bns.0.bias: 3.758782440854702e-07
Epoch [37/50], Batch [1/184], Gradient of input_layer.weight: -1.6305334860078347e-09
Epoch [37/50], Batch [1/184], Gradient of hidden_layers.0.weight: -2.036362502622069e-06
Epoch [37/50], Batch [1/184], Gradient of output_layer.weight: 1.7412024817531346e-06
Epoch [37/50], Batch [2/184], Gradient of input_bn.weight: 4.2721467252704315e-08
Epoch [37/50], Batch [2/184], Gradient of input_bn.bias: 2.7926575967285316e-06
Epoch [37/50], Batch [2/184], Gradient of hidden_bns.0.weight: 7.723039061602321e-07
Epoch [37/50], Batch [2/184], Gradient of hidden_bns.0.bias: 7.668481885048095e-06
Epoch [37/50], Batch [2/184], Gradient of input_laye

 74%|████████████████████████████          | 37/50 [5:47:29<2:02:06, 563.61s/it]

Validation Loss: 0.3294
Epoch [38/50], Batch [1/184], Gradient of input_bn.weight: 1.0863641364267096e-07
Epoch [38/50], Batch [1/184], Gradient of input_bn.bias: 6.235647560970392e-06
Epoch [38/50], Batch [1/184], Gradient of hidden_bns.0.weight: 2.0530451365630142e-06
Epoch [38/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.4447985449805856e-05
Epoch [38/50], Batch [1/184], Gradient of input_layer.weight: -8.126834316612985e-09
Epoch [38/50], Batch [1/184], Gradient of hidden_layers.0.weight: 5.14540406584274e-05
Epoch [38/50], Batch [1/184], Gradient of output_layer.weight: 4.1109298763331026e-05
Epoch [38/50], Batch [2/184], Gradient of input_bn.weight: 9.907307685352862e-08
Epoch [38/50], Batch [2/184], Gradient of input_bn.bias: 6.291613317443989e-06
Epoch [38/50], Batch [2/184], Gradient of hidden_bns.0.weight: 2.4587857296864968e-06
Epoch [38/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.9023178538191132e-05
Epoch [38/50], Batch [2/184], Gradient of input_layer.wei

 76%|████████████████████████████▉         | 38/50 [5:56:53<1:52:44, 563.71s/it]

Validation Loss: 0.3294
Epoch [39/50], Batch [1/184], Gradient of input_bn.weight: -3.659442882053554e-08
Epoch [39/50], Batch [1/184], Gradient of input_bn.bias: 2.3196016627480276e-06
Epoch [39/50], Batch [1/184], Gradient of hidden_bns.0.weight: 3.6501523936749436e-07
Epoch [39/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.812267510103993e-05
Epoch [39/50], Batch [1/184], Gradient of input_layer.weight: -9.149530910690373e-08
Epoch [39/50], Batch [1/184], Gradient of hidden_layers.0.weight: -5.5078660807339475e-05
Epoch [39/50], Batch [1/184], Gradient of output_layer.weight: 8.182731835404411e-05
Epoch [39/50], Batch [2/184], Gradient of input_bn.weight: 3.4366894396953285e-07
Epoch [39/50], Batch [2/184], Gradient of input_bn.bias: 8.437080396106467e-06
Epoch [39/50], Batch [2/184], Gradient of hidden_bns.0.weight: -7.153039405238815e-08
Epoch [39/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.425233651592862e-05
Epoch [39/50], Batch [2/184], Gradient of input_layer.w

 78%|█████████████████████████████▋        | 39/50 [6:06:18<1:43:22, 563.87s/it]

Validation Loss: 0.3294
Epoch [40/50], Batch [1/184], Gradient of input_bn.weight: 1.9113599591946695e-12
Epoch [40/50], Batch [1/184], Gradient of input_bn.bias: -6.101805638536462e-08
Epoch [40/50], Batch [1/184], Gradient of hidden_bns.0.weight: 6.441602806717128e-09
Epoch [40/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.5423901800204476e-07
Epoch [40/50], Batch [1/184], Gradient of input_layer.weight: -3.4204800103410093e-10
Epoch [40/50], Batch [1/184], Gradient of hidden_layers.0.weight: -1.1879978956130799e-07
Epoch [40/50], Batch [1/184], Gradient of output_layer.weight: 4.620363540652761e-07
Epoch [40/50], Batch [2/184], Gradient of input_bn.weight: 4.184298063591996e-08
Epoch [40/50], Batch [2/184], Gradient of input_bn.bias: 5.897740038562915e-07
Epoch [40/50], Batch [2/184], Gradient of hidden_bns.0.weight: -8.559768502891529e-07
Epoch [40/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.237931767012924e-06
Epoch [40/50], Batch [2/184], Gradient of input_layer.w

 80%|██████████████████████████████▍       | 40/50 [6:15:41<1:33:56, 563.68s/it]

Validation Loss: 0.3294
Epoch [41/50], Batch [1/184], Gradient of input_bn.weight: 6.416075848392211e-08
Epoch [41/50], Batch [1/184], Gradient of input_bn.bias: 5.77413993596565e-06
Epoch [41/50], Batch [1/184], Gradient of hidden_bns.0.weight: 5.743534075008938e-07
Epoch [41/50], Batch [1/184], Gradient of hidden_bns.0.bias: -6.834672603872605e-06
Epoch [41/50], Batch [1/184], Gradient of input_layer.weight: -2.5207175013974847e-08
Epoch [41/50], Batch [1/184], Gradient of hidden_layers.0.weight: 1.4659474800282624e-05
Epoch [41/50], Batch [1/184], Gradient of output_layer.weight: 3.42439379892312e-05
Epoch [41/50], Batch [2/184], Gradient of input_bn.weight: 3.63270373782143e-07
Epoch [41/50], Batch [2/184], Gradient of input_bn.bias: 3.51237667928217e-06
Epoch [41/50], Batch [2/184], Gradient of hidden_bns.0.weight: -4.992702997697052e-06
Epoch [41/50], Batch [2/184], Gradient of hidden_bns.0.bias: 9.95658410829492e-07
Epoch [41/50], Batch [2/184], Gradient of input_layer.weight: -

 82%|███████████████████████████████▏      | 41/50 [6:25:05<1:24:34, 563.86s/it]

Validation Loss: 0.3294
Epoch [42/50], Batch [1/184], Gradient of input_bn.weight: 9.007408152683638e-09
Epoch [42/50], Batch [1/184], Gradient of input_bn.bias: 4.7305985617640545e-07
Epoch [42/50], Batch [1/184], Gradient of hidden_bns.0.weight: -7.651314604117943e-08
Epoch [42/50], Batch [1/184], Gradient of hidden_bns.0.bias: -7.623548299307004e-07
Epoch [42/50], Batch [1/184], Gradient of input_layer.weight: -5.141727044133404e-10
Epoch [42/50], Batch [1/184], Gradient of hidden_layers.0.weight: 3.3715195968397893e-06
Epoch [42/50], Batch [1/184], Gradient of output_layer.weight: 4.37453354606987e-06
Epoch [42/50], Batch [2/184], Gradient of input_bn.weight: 1.9537765183486044e-07
Epoch [42/50], Batch [2/184], Gradient of input_bn.bias: 7.647297024959698e-06
Epoch [42/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.2383211469568778e-06
Epoch [42/50], Batch [2/184], Gradient of hidden_bns.0.bias: -2.8605547413462773e-05
Epoch [42/50], Batch [2/184], Gradient of input_layer.w

 84%|███████████████████████████████▉      | 42/50 [6:34:28<1:15:07, 563.45s/it]

Validation Loss: 0.3294
Epoch [43/50], Batch [1/184], Gradient of input_bn.weight: 3.0616865842603147e-07
Epoch [43/50], Batch [1/184], Gradient of input_bn.bias: 7.091552106430754e-06
Epoch [43/50], Batch [1/184], Gradient of hidden_bns.0.weight: -2.4179767024179455e-06
Epoch [43/50], Batch [1/184], Gradient of hidden_bns.0.bias: 2.2944979718886316e-05
Epoch [43/50], Batch [1/184], Gradient of input_layer.weight: -5.0104689286456505e-09
Epoch [43/50], Batch [1/184], Gradient of hidden_layers.0.weight: 7.354082299571019e-06
Epoch [43/50], Batch [1/184], Gradient of output_layer.weight: 9.90370026556775e-05
Epoch [43/50], Batch [2/184], Gradient of input_bn.weight: 2.021920408878941e-10
Epoch [43/50], Batch [2/184], Gradient of input_bn.bias: 2.5391997837687086e-08
Epoch [43/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.0611179490638278e-08
Epoch [43/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.1956407774960098e-07
Epoch [43/50], Batch [2/184], Gradient of input_layer.w

 86%|████████████████████████████████▋     | 43/50 [6:43:52<1:05:45, 563.69s/it]

Validation Loss: 0.3294
Epoch [44/50], Batch [1/184], Gradient of input_bn.weight: 5.3940084399073385e-08
Epoch [44/50], Batch [1/184], Gradient of input_bn.bias: -7.967358897076338e-07
Epoch [44/50], Batch [1/184], Gradient of hidden_bns.0.weight: -9.165386245513218e-07
Epoch [44/50], Batch [1/184], Gradient of hidden_bns.0.bias: 8.970197995950002e-06
Epoch [44/50], Batch [1/184], Gradient of input_layer.weight: -2.503422358302032e-08
Epoch [44/50], Batch [1/184], Gradient of hidden_layers.0.weight: -3.3910696402017493e-06
Epoch [44/50], Batch [1/184], Gradient of output_layer.weight: 3.356536035425961e-05
Epoch [44/50], Batch [2/184], Gradient of input_bn.weight: 6.1272658058442175e-09
Epoch [44/50], Batch [2/184], Gradient of input_bn.bias: 4.2153510548814666e-07
Epoch [44/50], Batch [2/184], Gradient of hidden_bns.0.weight: 1.0723488230723888e-07
Epoch [44/50], Batch [2/184], Gradient of hidden_bns.0.bias: 6.782398031646153e-06
Epoch [44/50], Batch [2/184], Gradient of input_layer.

 88%|███████████████████████████████████▏    | 44/50 [6:53:15<56:21, 563.55s/it]

Validation Loss: 0.3294
Epoch [45/50], Batch [1/184], Gradient of input_bn.weight: 3.6359779187478125e-08
Epoch [45/50], Batch [1/184], Gradient of input_bn.bias: 3.2888099212868838e-06
Epoch [45/50], Batch [1/184], Gradient of hidden_bns.0.weight: 8.003613629625761e-07
Epoch [45/50], Batch [1/184], Gradient of hidden_bns.0.bias: 8.551850442017894e-06
Epoch [45/50], Batch [1/184], Gradient of input_layer.weight: -9.658632116327226e-09
Epoch [45/50], Batch [1/184], Gradient of hidden_layers.0.weight: -2.0779498299816623e-05
Epoch [45/50], Batch [1/184], Gradient of output_layer.weight: 2.4973434847197495e-05
Epoch [45/50], Batch [2/184], Gradient of input_bn.weight: 3.05862158711534e-07
Epoch [45/50], Batch [2/184], Gradient of input_bn.bias: 6.8192130129318684e-06
Epoch [45/50], Batch [2/184], Gradient of hidden_bns.0.weight: -5.35136950929882e-06
Epoch [45/50], Batch [2/184], Gradient of hidden_bns.0.bias: 8.443577826255932e-06
Epoch [45/50], Batch [2/184], Gradient of input_layer.wei

 90%|████████████████████████████████████    | 45/50 [7:02:40<46:59, 563.80s/it]

Validation Loss: 0.3294
Epoch [46/50], Batch [1/184], Gradient of input_bn.weight: 2.387951099080965e-07
Epoch [46/50], Batch [1/184], Gradient of input_bn.bias: 5.104132469568867e-06
Epoch [46/50], Batch [1/184], Gradient of hidden_bns.0.weight: -2.575607140897773e-07
Epoch [46/50], Batch [1/184], Gradient of hidden_bns.0.bias: 2.1988351363688707e-05
Epoch [46/50], Batch [1/184], Gradient of input_layer.weight: -4.2313903492186e-08
Epoch [46/50], Batch [1/184], Gradient of hidden_layers.0.weight: 6.951342948013917e-05
Epoch [46/50], Batch [1/184], Gradient of output_layer.weight: 8.274237188743427e-05
Epoch [46/50], Batch [2/184], Gradient of input_bn.weight: 3.211415133819173e-09
Epoch [46/50], Batch [2/184], Gradient of input_bn.bias: 7.297309139175923e-07
Epoch [46/50], Batch [2/184], Gradient of hidden_bns.0.weight: 5.818883437314071e-07
Epoch [46/50], Batch [2/184], Gradient of hidden_bns.0.bias: 8.402066669077612e-06
Epoch [46/50], Batch [2/184], Gradient of input_layer.weight: 

 92%|████████████████████████████████████▊   | 46/50 [7:12:03<37:34, 563.66s/it]

Validation Loss: 0.3294
Epoch [47/50], Batch [1/184], Gradient of input_bn.weight: 6.350524017761927e-08
Epoch [47/50], Batch [1/184], Gradient of input_bn.bias: 1.9116409930575173e-06
Epoch [47/50], Batch [1/184], Gradient of hidden_bns.0.weight: -2.123742206094903e-06
Epoch [47/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.7049619600584265e-06
Epoch [47/50], Batch [1/184], Gradient of input_layer.weight: -2.0024801372642287e-08
Epoch [47/50], Batch [1/184], Gradient of hidden_layers.0.weight: 8.431684364040848e-06
Epoch [47/50], Batch [1/184], Gradient of output_layer.weight: 2.0283478079363704e-05
Epoch [47/50], Batch [2/184], Gradient of input_bn.weight: -1.2426198736648075e-08
Epoch [47/50], Batch [2/184], Gradient of input_bn.bias: 6.642871539952466e-06
Epoch [47/50], Batch [2/184], Gradient of hidden_bns.0.weight: -4.1140560824715067e-07
Epoch [47/50], Batch [2/184], Gradient of hidden_bns.0.bias: 2.0068217054358684e-05
Epoch [47/50], Batch [2/184], Gradient of input_laye

 94%|█████████████████████████████████████▌  | 47/50 [7:21:27<28:11, 563.76s/it]

Validation Loss: 0.3294
Epoch [48/50], Batch [1/184], Gradient of input_bn.weight: 4.233925210428424e-07
Epoch [48/50], Batch [1/184], Gradient of input_bn.bias: 9.946202226274181e-06
Epoch [48/50], Batch [1/184], Gradient of hidden_bns.0.weight: -1.124014715969679e-06
Epoch [48/50], Batch [1/184], Gradient of hidden_bns.0.bias: 5.012187102693133e-06
Epoch [48/50], Batch [1/184], Gradient of input_layer.weight: -9.119349186903491e-08
Epoch [48/50], Batch [1/184], Gradient of hidden_layers.0.weight: 2.2422698748414405e-05
Epoch [48/50], Batch [1/184], Gradient of output_layer.weight: 9.359008254250512e-05
Epoch [48/50], Batch [2/184], Gradient of input_bn.weight: 2.2534300114784855e-07
Epoch [48/50], Batch [2/184], Gradient of input_bn.bias: 1.0951589501928538e-05
Epoch [48/50], Batch [2/184], Gradient of hidden_bns.0.weight: -3.5353846215002704e-06
Epoch [48/50], Batch [2/184], Gradient of hidden_bns.0.bias: -1.2898375643999316e-05
Epoch [48/50], Batch [2/184], Gradient of input_layer.

 96%|██████████████████████████████████████▍ | 48/50 [7:30:50<18:47, 563.63s/it]

Validation Loss: 0.3294
Epoch [49/50], Batch [1/184], Gradient of input_bn.weight: 2.3490883904742077e-07
Epoch [49/50], Batch [1/184], Gradient of input_bn.bias: 7.15633041181718e-06
Epoch [49/50], Batch [1/184], Gradient of hidden_bns.0.weight: -6.812336323491763e-06
Epoch [49/50], Batch [1/184], Gradient of hidden_bns.0.bias: -4.684927625930868e-05
Epoch [49/50], Batch [1/184], Gradient of input_layer.weight: 5.6615679255855866e-08
Epoch [49/50], Batch [1/184], Gradient of hidden_layers.0.weight: -1.0200595170317683e-05
Epoch [49/50], Batch [1/184], Gradient of output_layer.weight: 7.912825094535947e-05
Epoch [49/50], Batch [2/184], Gradient of input_bn.weight: 1.5356317817349918e-07
Epoch [49/50], Batch [2/184], Gradient of input_bn.bias: 4.034543508169008e-06
Epoch [49/50], Batch [2/184], Gradient of hidden_bns.0.weight: -5.289507498673629e-09
Epoch [49/50], Batch [2/184], Gradient of hidden_bns.0.bias: -7.917155016912147e-06
Epoch [49/50], Batch [2/184], Gradient of input_layer.w

 98%|███████████████████████████████████████▏| 49/50 [7:40:13<09:23, 563.50s/it]

Validation Loss: 0.3293
Epoch [50/50], Batch [1/184], Gradient of input_bn.weight: 3.0268893169704825e-08
Epoch [50/50], Batch [1/184], Gradient of input_bn.bias: 3.837144504359458e-06
Epoch [50/50], Batch [1/184], Gradient of hidden_bns.0.weight: -1.927667199197458e-07
Epoch [50/50], Batch [1/184], Gradient of hidden_bns.0.bias: 1.3225001566752326e-05
Epoch [50/50], Batch [1/184], Gradient of input_layer.weight: -1.3810705468131346e-08
Epoch [50/50], Batch [1/184], Gradient of hidden_layers.0.weight: -2.9994362193974666e-05
Epoch [50/50], Batch [1/184], Gradient of output_layer.weight: 4.059905404574238e-05
Epoch [50/50], Batch [2/184], Gradient of input_bn.weight: 9.573386705596931e-08
Epoch [50/50], Batch [2/184], Gradient of input_bn.bias: 3.623121529017226e-06
Epoch [50/50], Batch [2/184], Gradient of hidden_bns.0.weight: 2.104438408423448e-06
Epoch [50/50], Batch [2/184], Gradient of hidden_bns.0.bias: 1.4189294233801775e-05
Epoch [50/50], Batch [2/184], Gradient of input_layer.w

100%|████████████████████████████████████████| 50/50 [7:49:37<00:00, 563.55s/it]

Validation Loss: 0.3293





In [7]:
torch.save(model.state_dict(), 'trained_kgat_model_v100.pth')


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math

class KGATLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(KGATLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        #self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.weight = nn.Parameter(torch.randn(in_features, out_features) * 0.1)  # multiply by 0.1 to increase values
        self.reset_parameters()

#    def reset_parameters(self):
#        gain = nn.init.calculate_gain('relu') * math.sqrt(3)
#        nn.init.uniform_(self.weight, -10, 10)
#        self.weight.data.mul_(gain)

    def reset_parameters(self):
        #print(self.weight)
        #nn.Parameter(torch.randn(11, out_features) * 0.1)  # multiply by 0.1 to increase values
        nn.init.kaiming_uniform_(self.weight, mode='fan_in', nonlinearity='relu')
        #print(self.weight)

    def forward(self, adjacency_matrix, input_features):
        output_features = torch.mm(input_features, self.weight)
        return torch.mm(adjacency_matrix, output_features)

class KGAT(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_layers):
        super(KGAT, self).__init__()
        self.input_bn = nn.BatchNorm1d(hidden_features)
        self.hidden_bns = nn.ModuleList([nn.BatchNorm1d(hidden_features) for _ in range(num_layers - 1)])
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.num_layers = num_layers

        # Define the input layer
        self.input_layer = KGATLayer(in_features, hidden_features)

        # Define the hidden layers
        self.hidden_layers = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.hidden_layers.append(KGATLayer(hidden_features, hidden_features))

        # Define the output layer
        self.output_layer = KGATLayer(hidden_features, out_features)

    def forward(self, adjacency_matrix, input_features):
        x = self.input_layer(adjacency_matrix, input_features)
        x = x.unsqueeze(2)  # Add an extra dimension for Batch Normalization
        x = F.relu(self.input_bn(x))
        x = x.squeeze(2)  # Remove the extra dimension
        for i, layer in enumerate(self.hidden_layers):
            x = layer(adjacency_matrix, x)
            x = x.unsqueeze(2)  # Add an extra dimension for Batch Normalization
            x = F.relu(self.hidden_bns[i](x))
            x = x.squeeze(2)  # Remove the extra dimension

        # Pass through the output layer
        x = self.output_layer(adjacency_matrix, x)

        return x        

hidden_features = 64  # Number of hidden features in the KGAT model
out_features = 1  # Number of output features in the KGAT model
num_layers = 2  # Number of layers in the KGAT model

numeric_keys = [
    'acousticness', 'danceability', 'duration_ms', 'energy', 'explicit',
    'liveness', 'loudness', 'popularity', 'speechiness', 'tempo', 'valence'
]

model = KGAT(num_layers=num_layers, in_features=len(numeric_keys), hidden_features=hidden_features, out_features=out_features)

In [4]:
loaded_model = KGAT(num_layers=num_layers, in_features=len(numeric_keys), hidden_features=hidden_features, out_features=out_features)
loaded_model.load_state_dict(torch.load('trained_kgat_model_v100.pth'))


<All keys matched successfully>

In [5]:
from pymongo import MongoClient

# connect to MongoDB
username = "..."
password = "..."
cluster_name = "..."
dbname = "..."
client = MongoClient(f"mongodb+srv://{username}:{password}@{cluster_name}.mongodb.net/{dbname}?retryWrites=true&w=majority")
db = client[dbname]
collection = db["tune-users"]

# define the query
query = { "user_id": "..." }

result = collection.find(query, { "top_tracks": 1, "_id": 0 })

track_ids = []

for document in result:
    for track in document['top_tracks']:
        track_ids.append(track['id'])

In [18]:
required_features = [
    't.acousticness', 't.danceability', 't.duration_ms', 't.energy', 't.explicit',
    't.liveness', 't.loudness', 't.popularity', 't.speechiness', 't.tempo', 't.valence'
]
track_features = []
for track in track_result:
    x_features = [track[feature] for feature in required_features]
    track_features.append(x_features)
track_features_tensor = torch.FloatTensor(track_features)

In [19]:
cosine_similarity_adj_tensor = torch.FloatTensor(cosine_similarity_adj.toarray())
def get_ranked_recommendations(user_songs, model, unique_track_ids, track_features):
    # Move the model to the evaluation mode
    model.eval()

    # Create an input tensor containing user's input songs' features
    user_song_indices = [unique_track_ids.index(track_id) for track_id in user_songs if track_id in unique_track_ids]
    user_song_features = track_features[user_song_indices]

    # Calculate the embeddings of the user's input songs using the trained model
    user_song_embeddings = model(cosine_similarity_adj_tensor[:len(user_song_indices), :len(user_song_indices)], user_song_features)
    user_song_embeddings += model(shared_album_adj_tensor[:len(user_song_indices), :len(user_song_indices)], user_song_features)
    user_song_embeddings += model(shared_artist_adj_tensor[:len(user_song_indices), :len(user_song_indices)], user_song_features)
    user_song_embeddings += model(shared_genre_adj_tensor[:len(user_song_indices), :len(user_song_indices)], user_song_features)
    
    # Calculate the average of the input songs' embeddings
    avg_embedding = torch.mean(user_song_embeddings, dim=0, keepdim=True)

    # Calculate the similarity score between the average embedding and all tracks in the dataset
    similarity_scores = torch.mm(track_embeddings, avg_embedding.t()).squeeze()

    # Sort similarity scores in descending order and get the corresponding indices
    sorted_indices = torch.argsort(similarity_scores, descending=True)

    # Convert the sorted indices to track IDs
    ranked_track_ids = [unique_track_ids[idx] for idx in sorted_indices.tolist()]

    # Remove the user's input songs from the ranked recommendations
    recommendations = [track_id for track_id in ranked_track_ids if track_id not in user_songs]

    return recommendations


In [21]:
# Convert adjacency matrices to PyTorch tensors
shared_album_adj_tensor = torch.FloatTensor(shared_album_adj.toarray())
shared_artist_adj_tensor = torch.FloatTensor(shared_artist_adj.toarray())
shared_genre_adj_tensor = torch.FloatTensor(shared_genre_adj.toarray())
cosine_similarity_adj_tensor = torch.FloatTensor(cosine_similarity_adj.toarray())

# Obtain embeddings from the KGAT model
track_embeddings_shared_album = model(shared_album_adj_tensor, track_features_tensor)
track_embeddings_shared_artist = model(shared_artist_adj_tensor, track_features_tensor)
track_embeddings_shared_genre = model(shared_genre_adj_tensor, track_features_tensor)
track_embeddings_cosine_similarity = model(cosine_similarity_adj_tensor, track_features_tensor)

# Combine the embeddings from different relationships
track_embeddings = track_embeddings_shared_album + track_embeddings_shared_artist + track_embeddings_shared_genre + track_embeddings_cosine_similarity

In [22]:
recommendations = get_ranked_recommendations(track_ids, model, unique_track_ids, track_features_tensor)

In [23]:
recommendations[:50]

['1yRl361oS9xXUIDN3GF7pd',
 '19dSGEa2qpyGM60nxoLv3q',
 '2qywCWdMFK2ew9ALHA4UfD',
 '4lIIOHPXeh74fbPUV74zRn',
 '7pN1FnWdKOx3GAciqfVVf7',
 '1Eq8p4OHjhVrSW0Wa815B1',
 '1Q5cjSmt4WTJyJ0EgAFfgg',
 '4o4nIPbrEAoCAq3zqWVA4E',
 '3msjyQTZr7gHspgJCzl4v5',
 '1HW3I6KUzjYPEMqZrlFN66',
 '20uX9Nz9dJBKD7MstFDFs0',
 '1ah3n5wpOQhFQzl9t3qGIu',
 '0gIsM9tn5pP1jnTSrkdHCL',
 '7f32ap55N5PhZBMEr8TnKV',
 '6rOVIb2PyUMBw8b6hUdxHj',
 '2eh3gUtf5GwwlEHPQOgTmV',
 '0OdX5KLktgTp3FAmUqWvLt',
 '2AYtjqogao6Fj3N7cx39Of',
 '0xuvCPMUrbr6xk3myTCqwF',
 '59YiIzi4C1KJwClBBRFucR',
 '4nwwj1bE8KHaBixtEtMu4d',
 '4uew5SER845c4iNj4sl8jG',
 '5deBKQ6qZaTHBRtRNZtL9X',
 '22sdeHdy5DcCVV7dJw4uRn',
 '5gAPTm517kIXqq9YUBLehU',
 '5O71be8zUY85FRBQJsQRWd',
 '2k4O9OqM4RmuGwOgrU272J',
 '4VRcCHaxMHugywsyc33xZa',
 '6JvK5IGgZ1rmenDKHufzIx',
 '0hESC8513XQKkf82EMXpiI',
 '2gc6W4zP2EYcimOCg0tg2v',
 '5eYO9lIshL5wbYDO0WWbbX',
 '2eIAizioh2hRmYRagndFb8',
 '4VfdRKABHsrA0ii7wGDjJm',
 '7uiJYByjnDGY8Z1sAsCSFC',
 '17uRlIf3SeM3VMptNbsZtP',
 '37oH03cQRwoyBUTyGDyixV',
 

In [17]:
from spotipy.oauth2 import SpotifyClientCredentials
import spotipy

client_credentials_manager = SpotifyClientCredentials(client_id='...', client_secret="...")
sp = spotipy.Spotify(client_credentials_manager=client_credentials_manager)

def get_track_info(track_id):
    track_info = sp.track(track_id)
    track_name = track_info['name']
    artist_info = track_info['artists']
    artist_names = [artist['name'] for artist in artist_info]
    return track_name, artist_names

for track_id in recommendations[:50]:
    track_name, artist_names = get_track_info(track_id)
    print(f'Track name: {track_name}')
    print(f'Artist(s): {", ".join(artist_names)}')
    print()

Track name: Euryanthe, J. 291: Overture
Artist(s): Carl Maria von Weber, Tapiola Sinfonietta, Jean-Jacques Kantorow

Track name: Der Freischütz, J. 277 / Act III: "Und ob die Wolke sie verhülle"
Artist(s): Carl Maria von Weber, Gundula Janowitz, Staatskapelle Dresden, Carlos Kleiber

Track name: St Matthew Passion: Iii. Psalm
Artist(s): Bent Sørensen, The Norwegian Soloists' Choir, Grete Pedersen, Ensemble Allegria

Track name: Drei Romanzen, Op. 94: III. Nicht schnell
Artist(s): Robert Schumann, Kim Dami

Track name: Les berceaux, Op. 23, No. 1
Artist(s): Gabriel Fauré, Mischa Maisky, Daria Hovora

Track name: Les Pêcheurs de perles: Je crois entendre encore (Nadir)
Artist(s): Georges Bizet, Nicolai Gedda, Pierre Dervaux, Orchestre du Theatre National de IOpera-Comique, Orchestre Du Theatre National De L'opera-comique

Track name: Il Signor Bruschino: No. 1: Introduzione - Deh tu m’assisti amore
Artist(s): Gioachino Rossini, Massimiliano Barbolini, Alessandro Codeluppi, Clara Giangasp