In [41]:
# import packages
import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms, utils
from torch.utils.data import TensorDataset
from tqdm import tqdm
from torch_geometric.nn import TransformerConv, global_max_pool, GATv2Conv, ClusterGCNConv, PointTransformerConv, global_mean_pool, dense_mincut_pool
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import time
import torch 
from torch_geometric.data import Data, InMemoryDataset, download_url
import torch_geometric.transforms as T
import networkx as nx
from torch_cluster import knn_graph
from torch_geometric.utils import from_networkx, to_dense_adj, dense_to_sparse, to_dense_batch

from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer; log_dir=None lets it pick a fresh timestamped run directory.
# NOTE(review): no visible cell logs to `writer`; add writer.add_scalar(...) calls or drop it.
writer = SummaryWriter(log_dir=None)

In [42]:
# Use the GPU when CUDA is visible to PyTorch, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [43]:
# Dataset class
import torch
from torch_geometric.data import Data, InMemoryDataset, download_url
import torch_geometric.transforms as T

class Quark_Gluon_Dataset(InMemoryDataset):
    """In-memory PyG dataset of pixel graphs built from the quark/gluon HDF5 file.

    Each entry of ``X_jets`` becomes one graph. Every pixel ``(i, j)`` for which
    ``image[i][j].any()`` is true becomes a node with attributes ``x = image[i][j]``
    and ``pos = (i, j, 0)``. Nodes are linked to their non-empty 8-neighbours.
    Per-graph attributes ``y`` (long), ``m`` (from ``m0``) and ``p`` (from ``pt``)
    are copied from the file.

    Args:
        root: PyG root directory; the collated graphs are cached in
            ``processed/data.pt`` below it.
        transform, pre_transform: standard PyG hooks.
        h5_path: HDF5 source read by :meth:`process`. Defaults to ``H5_PATH``.
            It is only read when the processed cache does not exist yet.
    """

    # Default location of the source HDF5 file; override per instance with ``h5_path``.
    H5_PATH = '/hdfs1/Data/Shrutimoy/quark-gluon_data-set_n139306.hdf5'

    # 8-neighbourhood offsets, listed in the same order the edges were originally
    # added. This keeps node/edge insertion order, and hence the tensors
    # produced by from_networkx, unchanged.
    _NEIGHBOUR_OFFSETS = ((-1, 0), (0, -1), (1, 0), (0, 1),
                          (-1, -1), (1, 1), (-1, 1), (1, -1))

    def __init__(self, root, transform=None, pre_transform=None, h5_path=None):
        # Must be set before super().__init__, which calls process() when the cache is missing.
        self.h5_path = self.H5_PATH if h5_path is None else h5_path
        super(Quark_Gluon_Dataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # The HDF5 file is read directly from ``self.h5_path``; PyG manages no raw files.
        return []

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Nothing to download: the source file must already exist at ``self.h5_path``.
        pass

    def create_graph(self, image):
        """Build an undirected networkx pixel graph from one image.

        Assumes ``image[row][col]`` yields the pixel's channel vector (TODO confirm
        the layout of ``X_jets``). A pixel is a node when any of its channels is
        non-zero. It is connected to each in-bounds, non-empty neighbour among
        its 8 surrounding pixels.

        Returns:
            ``nx.Graph`` with node keys ``(i, j)`` and attributes ``x`` and ``pos``.
        """
        height, width = image.shape[0], image.shape[1]
        G = nx.Graph()
        for i in range(height):
            for j in range(width):
                if not image[i][j].any():
                    continue
                G.add_node((i, j), x=image[i][j], pos=(i, j, 0))
                for di, dj in self._NEIGHBOUR_OFFSETS:
                    ni, nj = i + di, j + dj
                    if 0 <= ni < height and 0 <= nj < width and image[ni][nj].any():
                        # A neighbour not yet visited is created here without
                        # attributes; it receives them when the scan reaches it.
                        G.add_edge((i, j), (ni, nj))
        return G

    def process(self):
        """Convert every image in the HDF5 file to a graph and save the collated result."""
        # Use a context manager so the HDF5 handle is released once the arrays are read.
        # NOTE(review): this materialises the whole X_jets array in memory at once.
        with h5py.File(self.h5_path, 'r') as f:
            X_jets = np.asarray(f['X_jets'])
            m0 = np.asarray(f['m0'])
            pt = np.asarray(f['pt'])
            y = np.asarray(f['y'])

        data_list = []
        for i in tqdm(range(len(X_jets))):
            data = from_networkx(self.create_graph(X_jets[i]))
            data.y = torch.tensor(y[i], dtype=torch.long)
            data.m = torch.tensor(m0[i], dtype=torch.float)
            data.p = torch.tensor(pt[i], dtype=torch.float)
            data_list.append(data)

        # Apply the pre_transform hook that __init__ accepts (previously it was silently ignored).
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [44]:
def train_test_split(X, y, test_size):
    """Split ``X`` and ``y`` in their existing order (no shuffling).

    The first ``int(len(X) * (1 - test_size))`` items of each go to the train part
    and the remainder to the test part. ``y`` is cut at the same index as ``X``.

    Returns:
        ``(train_X, test_X, train_y, test_y)``
    """
    cut = int(len(X) * (1 - test_size))
    return X[:cut], X[cut:], y[:cut], y[cut:]


In [109]:
def loss_function(recon_x, x, mu, logvar, mc1_loss, o1_loss, mc2_loss, o2_loss, mc3_loss, o3_loss):
    """Graph-VAE objective plus the auxiliary losses of the three mincut-pooling stages.

    Returns ``mse(recon_x, x) + KL + mc1 + o1 + mc2 + o2 + mc3 + o3``, summed left to right.
    The reconstruction term is a *mean* squared error, while the KL term is a *sum*
    over every latent entry. Their relative weight therefore depends on tensor sizes.
    """
    recon_term = F.mse_loss(recon_x, x)
    # KL divergence of N(mu, exp(logvar)) from N(0, I).
    kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    total = recon_term + kl_term
    for aux_term in (mc1_loss, o1_loss, mc2_loss, o2_loss, mc3_loss, o3_loss):
        total = total + aux_term
    return total

class AverageMeter(object):
    """Track the most recent value and a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, average, running sum and count back to zero."""
        self.val = self.avg = self.Sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running average."""
        self.val = val
        self.Sum += val * n
        self.count += n
        self.avg = self.Sum / self.count

In [110]:
#Graph Autoencoder Model
class GAE(nn.Module):
    """Variational graph autoencoder with three dense mincut-pooling stages.

    Encoder, for each of 4 levels: a GATv2Conv and a ClusterGCNConv run in parallel on
    the same graph. Their outputs are concatenated and mixed by Linear -> ReLU -> BatchNorm.
    Levels 1-3 are followed by ``dense_mincut_pool`` with assignment logits from
    ``mlp1``/``mlp2``/``mlp3``, giving ``cluster_dim[0]``/``[1]``/``[2]`` clusters.
    ``fc1``/``fc2`` map the last level to ``mu`` and ``logvar``.

    Decoder: ``fc3`` maps the latent back to ``hidden_dim[4]``. The same conv-pair
    pattern follows, with node features expanded by the encoder's assignment matrices,
    and then ``fc4`` plus a softmax over the ``input_dim`` output features.

    Args:
        input_dim: node feature size; ``pos`` (3 values) is concatenated at the input.
        hidden_dim: six widths. ``[0..4]`` are the conv-level widths and ``[5]`` is the latent size.
        cluster_dim: three cluster counts for the pooling stages.
        dropout: forwarded to every GATv2Conv / ClusterGCNConv constructor.
        batch_size: stored as ``self.batch_size``; not used by the forward pass.

    NOTE(review): structural issues, left unchanged because fixing them changes the model:
      * ``to_dense_batch(x)`` / ``to_dense_adj(edge_index)`` are called without the
        ``batch`` vector. All graphs of a mini-batch are therefore pooled as one graph,
        so clusters can mix nodes of different samples.
      * In ``decode`` the features are expanded (``s @ z``) but the adjacency is
        contracted (``s.T @ A @ s``). After each unpooling step, edge_index therefore
        only spans ``s.size(1)`` of the ``s.size(0)`` rows of ``z``.
      * The decoder uses the raw assignment logits, while ``dense_mincut_pool``
        softmaxes ``s`` internally.
    """

    def __init__(self, input_dim, hidden_dim, cluster_dim, dropout, batch_size):
        super(GAE, self).__init__()
        self.batch_size = batch_size
        # Input projection: node features + 3 position coordinates.
        self.fc0 = nn.Linear(input_dim+3, hidden_dim[0])
        self.batch_norm0 = nn.BatchNorm1d(hidden_dim[0])
        # Encoder levels: conv pair, then fc/bn mixing the concatenated (2x width) outputs.
        self.conv1 = nn.ModuleList()
        self.fc11 = nn.Linear(hidden_dim[0]*2, hidden_dim[1])
        self.batch_norm11 = nn.BatchNorm1d(hidden_dim[1])
        self.mlp1 = nn.Sequential()
        self.conv2 = nn.ModuleList()
        self.fc12 = nn.Linear(hidden_dim[1]*2, hidden_dim[2])
        self.batch_norm12 = nn.BatchNorm1d(hidden_dim[2])
        self.mlp2 = nn.Sequential()
        self.conv3 = nn.ModuleList()
        self.fc13 = nn.Linear(hidden_dim[2]*2, hidden_dim[3])
        self.batch_norm13 = nn.BatchNorm1d(hidden_dim[3])
        self.mlp3 = nn.Sequential()
        self.conv4 = nn.ModuleList()
        self.fc14 = nn.Linear(hidden_dim[3]*2, hidden_dim[4])
        self.batch_norm14 = nn.BatchNorm1d(hidden_dim[4])

        # Level 1 works on the raw pixel graph, which has no edge attributes (no edge_dim).
        self.conv1.append(GATv2Conv(hidden_dim[0], hidden_dim[0], add_self_loops=True, dropout=dropout))
        self.conv1.append(ClusterGCNConv(hidden_dim[0], hidden_dim[0], add_self_loops=True, dropout=dropout))
        # mlp1..mlp3 produce the cluster-assignment logits for the pooling stages.
        self.mlp1.append(nn.Linear(hidden_dim[1], hidden_dim[1]))
        self.mlp1.append(nn.ReLU())
        self.mlp1.append(nn.Linear(hidden_dim[1], cluster_dim[0]))

        # Later levels see pooled graphs whose scalar edge weights feed GATv2 (edge_dim=1).
        self.conv2.append(GATv2Conv(hidden_dim[1], hidden_dim[1], add_self_loops=True, dropout=dropout, edge_dim=1))
        self.conv2.append(ClusterGCNConv(hidden_dim[1], hidden_dim[1], add_self_loops=True, dropout=dropout))
        self.mlp2.append(nn.Linear(hidden_dim[2], hidden_dim[2]))
        self.mlp2.append(nn.ReLU())
        self.mlp2.append(nn.Linear(hidden_dim[2], cluster_dim[1]))
    
        self.conv3.append(GATv2Conv(hidden_dim[2], hidden_dim[2], add_self_loops=True, dropout=dropout, edge_dim=1))
        self.conv3.append(ClusterGCNConv(hidden_dim[2], hidden_dim[2], add_self_loops=True, dropout=dropout))
        self.mlp3.append(nn.Linear(hidden_dim[3], hidden_dim[3]))
        self.mlp3.append(nn.ReLU())
        self.mlp3.append(nn.Linear(hidden_dim[3], cluster_dim[2]))

        self.conv4.append(GATv2Conv(hidden_dim[3], hidden_dim[3], add_self_loops=True, dropout=dropout, edge_dim=1))
        self.conv4.append(ClusterGCNConv(hidden_dim[3], hidden_dim[3], add_self_loops=True, dropout=dropout))

        # Latent heads (mu, logvar) and the decoder's latent -> feature projection.
        self.fc1 = nn.Linear(hidden_dim[4], hidden_dim[5])
        self.fc2 = nn.Linear(hidden_dim[4], hidden_dim[5])
        self.fc3 = nn.Linear(hidden_dim[5], hidden_dim[4])

        # Decoder levels mirror the encoder widths in reverse.
        self.conv5 = nn.ModuleList()
        self.fc21 = nn.Linear(hidden_dim[4]*2, hidden_dim[3])
        self.batch_norm21 = nn.BatchNorm1d(hidden_dim[3])
        self.conv6 = nn.ModuleList()
        self.fc22 = nn.Linear(hidden_dim[3]*2, hidden_dim[2])
        self.batch_norm22 = nn.BatchNorm1d(hidden_dim[2])
        self.conv7 = nn.ModuleList()
        self.fc23 = nn.Linear(hidden_dim[2]*2, hidden_dim[1])
        self.batch_norm23 = nn.BatchNorm1d(hidden_dim[1])
        self.conv8 = nn.ModuleList()
        self.fc24 = nn.Linear(hidden_dim[1]*2, hidden_dim[0])
        self.batch_norm24 = nn.BatchNorm1d(hidden_dim[0])

        self.conv5.append(GATv2Conv(hidden_dim[4], hidden_dim[4], add_self_loops=True, dropout=dropout, edge_dim=1))    
        self.conv5.append(ClusterGCNConv(hidden_dim[4], hidden_dim[4], add_self_loops=True, dropout=dropout))

        self.conv6.append(GATv2Conv(hidden_dim[3], hidden_dim[3], add_self_loops=True, dropout=dropout, edge_dim=1))
        self.conv6.append(ClusterGCNConv(hidden_dim[3], hidden_dim[3], add_self_loops=True, dropout=dropout))

        self.conv7.append(GATv2Conv(hidden_dim[2], hidden_dim[2], add_self_loops=True, dropout=dropout, edge_dim=1))
        self.conv7.append(ClusterGCNConv(hidden_dim[2], hidden_dim[2], add_self_loops=True, dropout=dropout))

        self.conv8.append(GATv2Conv(hidden_dim[1], hidden_dim[1], add_self_loops=True, dropout=dropout, edge_dim=1))
        self.conv8.append(ClusterGCNConv(hidden_dim[1], hidden_dim[1], add_self_loops=True, dropout=dropout))

        # Output projection back to the node feature size.
        self.fc4 = nn.Linear(hidden_dim[0], input_dim)

    def _conv_level(self, convs, fc, bn, x, edge_index, edge_attr=None):
        """Run convs[0] (GATv2Conv) and convs[1] (ClusterGCNConv) in parallel, then concat -> fc -> ReLU -> bn.

        Only the GATv2Conv receives ``edge_attr``; None means "no edge attributes".
        """
        x_gat = convs[0](x, edge_index, edge_attr=edge_attr)
        x_gcn = convs[1](x, edge_index)
        x = fc(torch.cat((x_gat, x_gcn), dim=1))
        return bn(F.relu(x))

    def _pool(self, x, edge_index, mlp):
        """Mincut-pool the full node set (no batch vector, see class NOTE).

        Returns:
            ``(x_flat, edge_index, edge_attr, mincut_loss, ortho_loss, s)``.
            ``x_flat`` is the pooled features reshaped to 2D, and ``s`` is the
            assignment logits produced by ``mlp``.
        """
        adj = to_dense_adj(edge_index, max_num_nodes=x.size(0))
        x_dense, mask = to_dense_batch(x)
        s = mlp(x_dense)
        x, adj, mincut_loss, ortho_loss = dense_mincut_pool(x_dense, adj, s, mask)
        x = x.reshape(x.size(0) * x.size(1), -1)
        edge_index, edge_attr = dense_to_sparse(adj)
        return x, edge_index, edge_attr, mincut_loss, ortho_loss, s

    def encode(self, x, edge_index, pos, batch):
        """Encode a (batched) graph to ``mu``, ``logvar``, pooling losses, assignments and the coarsest graph.

        ``batch`` is accepted but currently unused (see class NOTE).
        """
        x = torch.cat((x, pos.float()), dim=1)
        x = self.batch_norm0(F.relu(self.fc0(x)))

        x = self._conv_level(self.conv1, self.fc11, self.batch_norm11, x, edge_index)
        x, edge_index, edge_attr, mc1_loss, o1_loss, s1 = self._pool(x, edge_index, self.mlp1)

        x = self._conv_level(self.conv2, self.fc12, self.batch_norm12, x, edge_index, edge_attr)
        x, edge_index, edge_attr, mc2_loss, o2_loss, s2 = self._pool(x, edge_index, self.mlp2)

        x = self._conv_level(self.conv3, self.fc13, self.batch_norm13, x, edge_index, edge_attr)
        x, edge_index, edge_attr, mc3_loss, o3_loss, s3 = self._pool(x, edge_index, self.mlp3)
        # The coarsest graph is reused as the decoder's starting graph.
        last_edge, last_attr = edge_index, edge_attr

        x = self._conv_level(self.conv4, self.fc14, self.batch_norm14, x, edge_index, edge_attr)
        return self.fc1(x), self.fc2(x), mc1_loss, o1_loss, mc2_loss, o2_loss, mc3_loss, o3_loss, s1, s2, s3, last_edge, last_attr

    def reparameterize(self, mu, sigma):
        """Sample ``mu + eps * exp(0.5 * sigma)`` with ``eps ~ N(0, I)``; ``sigma`` is the log-variance."""
        std = torch.exp(0.5 * sigma)
        # randn_like matches std's device/dtype, so no global `device` is needed.
        eps = torch.randn_like(std)
        return mu + eps * std

    def _unpool(self, z, s_logits, edge_index, edge_attr=None):
        """Expand features with an encoder assignment matrix and rebuild a sparse graph.

        NOTE(review): the adjacency is computed as ``s.T @ A @ s`` (cluster space), not
        ``s @ A @ s.T``. The returned edge_index therefore covers only the first
        ``s.size(1)`` of the ``z.size(0)`` expanded nodes; see class docstring.
        """
        s = torch.squeeze(s_logits)
        z = torch.matmul(s, z)
        adj = to_dense_adj(edge_index, edge_attr=edge_attr, max_num_nodes=z.size(0))
        adj_pool = torch.matmul(torch.matmul(s.T, adj), s)
        edge_index, edge_attr = dense_to_sparse(adj_pool)
        return z, edge_index, edge_attr

    def decode(self, z, last_edge, last_attr, S1, S2, S3):
        """Decode latent node vectors back to per-node feature distributions (softmax over ``input_dim``)."""
        z = self.fc3(z)
        z = self._conv_level(self.conv5, self.fc21, self.batch_norm21, z, last_edge, last_attr)
        # Only the first unpooling step feeds the edge weights into the dense adjacency.
        z, edge_index, edge_attr = self._unpool(z, S3, last_edge, last_attr)

        z = self._conv_level(self.conv6, self.fc22, self.batch_norm22, z, edge_index, edge_attr)
        z, edge_index, edge_attr = self._unpool(z, S2, edge_index)

        z = self._conv_level(self.conv7, self.fc23, self.batch_norm23, z, edge_index, edge_attr)
        z, edge_index, edge_attr = self._unpool(z, S1, edge_index)

        z = self._conv_level(self.conv8, self.fc24, self.batch_norm24, z, edge_index, edge_attr)
        # Explicit dim=1 (what the deprecated implicit choice picked for 2D input).
        return F.softmax(self.fc4(z), dim=1)

    def forward(self, data):
        """Encode ``data``, sample a latent, and decode.

        Returns the reconstruction, ``mu``, ``logvar`` and the six pooling losses.
        """
        x, edge_index, pos, batch = data.x, data.edge_index, data.pos, data.batch
        (mu, sigma, mc1_loss, o1_loss, mc2_loss, o2_loss, mc3_loss, o3_loss,
         s1, s2, s3, last_edge, last_attr) = self.encode(x, edge_index, pos, batch)
        z = self.reparameterize(mu, sigma)
        x_new = self.decode(z, last_edge, last_attr, s1, s2, s3)
        return x_new, mu, sigma, mc1_loss, o1_loss, mc2_loss, o2_loss, mc3_loss, o3_loss

In [47]:
# PyG caches the collated graphs under <root>/processed/data.pt; they are built from the HDF5 file if missing.
DATASET_ROOT = '/hdfs1/Data/Shubhajit/Quark_Gluon_Data_1/'
dataset = Quark_Gluon_Dataset(root=DATASET_ROOT)

In [48]:
# Ordered (unshuffled) 80/10/10 split into train / validation / test.
# NOTE(review): if the source file is ordered by label, this split is not stratified — confirm.
n_graphs = len(dataset)
train_end, val_end = int(0.8 * n_graphs), int(0.9 * n_graphs)
train_dataset = dataset[:train_end]
val_dataset = dataset[train_end:val_end]
test_dataset = dataset[val_end:]

# Mini-batch loaders; only training data is reshuffled each epoch.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [49]:
# Same device choice as the setup cell, repeated so this section can run on its own.
device = torch.device('cpu' if not torch.cuda.is_available() else 'cuda')


In [113]:
# Model, optimizer and LR schedule.
# hidden_dim: conv-level widths [0..4] plus the latent size [5]; cluster_dim: clusters per pooling stage.
# batch_size is stored on the model only; keep it equal to the DataLoaders' batch_size.
model = GAE(
    input_dim=3,
    hidden_dim=[16, 32, 64, 128, 256, 512],
    cluster_dim=[256, 128, 64],
    dropout=0,
    batch_size=16,
).to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
# Halves the learning rate every 10 calls to scheduler.step().
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.50)

In [114]:
loss_train, loss_test, loss_val = [], [], []


# training loop
def train(epoch=0):
    """Run one training epoch over ``train_loader``.

    Args:
        epoch: zero-based epoch index, used only in the progress message.

    Returns:
        Sum of the per-batch losses divided by the number of training graphs.
    """
    model.train()
    train_loss = 0
    step = 0
    for data in tqdm(train_loader):
        optimizer.zero_grad()
        data = data.to(device)
        # The reconstruction target is a detached copy of the input node features.
        orig_x = data.x.clone().detach()
        z, mu, sigma, mc1_loss, o1_loss, mc2_loss, o2_loss, mc3_loss, o3_loss = model(data)
        loss = loss_function(z, orig_x, mu, sigma, mc1_loss, o1_loss, mc2_loss, o2_loss, mc3_loss, o3_loss)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        step += 1
        if step % 100 == 0:
            # Same normalisation as the return value, using the loader's real batch size.
            print("Epoch: {} Train loss: {:.4f}".format(
                epoch + 1, train_loss / (step * train_loader.batch_size)))
    return train_loss / len(train_loader.dataset)


def _evaluate(loader):
    """Return the summed per-batch loss over ``loader`` divided by its number of graphs.

    Runs in eval mode under ``torch.no_grad()``, so no autograd graph is built.
    Warns when any batch produces a non-finite loss.
    """
    model.eval()
    total_loss = 0
    n_nonfinite = 0
    with torch.no_grad():
        for data in tqdm(loader):
            data = data.to(device)
            orig_x = data.x.clone().detach()
            z, mu, sigma, mc1_loss, o1_loss, mc2_loss, o2_loss, mc3_loss, o3_loss = model(data)
            loss = loss_function(z, orig_x, mu, sigma, mc1_loss, o1_loss, mc2_loss, o2_loss, mc3_loss, o3_loss)
            if not torch.isfinite(loss):
                n_nonfinite += 1
            total_loss += loss.item()
    if n_nonfinite:
        print("warning: {} of {} batches gave a non-finite loss".format(n_nonfinite, len(loader)))
    return total_loss / len(loader.dataset)


# validation loop
def validation():
    """Evaluation loss on ``val_loader`` (see ``_evaluate``)."""
    return _evaluate(val_loader)


# test loop
def test():
    """Evaluation loss on ``test_loader`` (see ``_evaluate``)."""
    return _evaluate(test_loader)


# training loop
# NOTE(review): the test set is evaluated every epoch; use it only for the final report.
for epoch in range(10):
    train_loss = train(epoch)
    val_loss = validation()
    test_loss = test()
    # NOTE(review): the StepLR scheduler is never stepped, so the LR stays at its initial value.
    #scheduler.step()
    print('====> Epoch: {} Train loss: {:.4f}'.format(
          epoch+1, train_loss))
    print('====> Epoch: {} Validation loss: {:.4f}'.format(
          epoch+1, val_loss))
    print('====> Epoch: {} Test loss: {:.4f}'.format(
            epoch+1, test_loss))
    loss_train.append(train_loss)
    loss_test.append(test_loss)
    loss_val.append(val_loss)

  1%|▏         | 100/6966 [00:32<39:32,  2.89it/s]

Epoch: 1 Train loss: 19.2400


  3%|▎         | 200/6966 [01:06<40:18,  2.80it/s]

Epoch: 1 Train loss: 9.6792


  4%|▍         | 300/6966 [01:39<38:15,  2.90it/s]

Epoch: 1 Train loss: 6.5009


  6%|▌         | 400/6966 [02:13<40:34,  2.70it/s]

Epoch: 1 Train loss: 4.8985


  7%|▋         | 500/6966 [02:46<33:17,  3.24it/s]

Epoch: 1 Train loss: 3.9386


  9%|▊         | 600/6966 [03:19<31:56,  3.32it/s]

Epoch: 1 Train loss: 3.2969


 10%|█         | 700/6966 [03:53<31:31,  3.31it/s]

Epoch: 1 Train loss: 2.8380


 11%|█▏        | 800/6966 [04:26<34:36,  2.97it/s]

Epoch: 1 Train loss: 2.4936


 13%|█▎        | 900/6966 [05:00<32:43,  3.09it/s]

Epoch: 1 Train loss: 2.2256


 14%|█▍        | 1000/6966 [05:34<34:48,  2.86it/s]

Epoch: 1 Train loss: 2.0112


 16%|█▌        | 1100/6966 [06:07<35:10,  2.78it/s]

Epoch: 1 Train loss: 1.8357


 17%|█▋        | 1200/6966 [06:41<32:24,  2.97it/s]

Epoch: 1 Train loss: 1.6894


 19%|█▊        | 1300/6966 [07:15<33:47,  2.80it/s]

Epoch: 1 Train loss: 1.5656


 20%|██        | 1400/6966 [07:49<32:05,  2.89it/s]

Epoch: 1 Train loss: 1.4595


 22%|██▏       | 1500/6966 [08:22<28:36,  3.18it/s]

Epoch: 1 Train loss: 1.3674


 23%|██▎       | 1600/6966 [08:56<29:21,  3.05it/s]

Epoch: 1 Train loss: 1.2870


 24%|██▍       | 1700/6966 [09:29<27:56,  3.14it/s]

Epoch: 1 Train loss: 1.2160


 26%|██▌       | 1800/6966 [10:02<27:03,  3.18it/s]

Epoch: 1 Train loss: 1.1528


 27%|██▋       | 1900/6966 [10:36<29:17,  2.88it/s]

Epoch: 1 Train loss: 1.0962


 29%|██▊       | 2000/6966 [11:09<29:20,  2.82it/s]

Epoch: 1 Train loss: 1.0454


 30%|███       | 2100/6966 [11:44<28:28,  2.85it/s]

Epoch: 1 Train loss: 0.9993


 32%|███▏      | 2200/6966 [12:17<26:23,  3.01it/s]

Epoch: 1 Train loss: 0.9575


 33%|███▎      | 2300/6966 [12:51<26:31,  2.93it/s]

Epoch: 1 Train loss: 0.9193


 34%|███▍      | 2400/6966 [13:25<25:18,  3.01it/s]

Epoch: 1 Train loss: 0.8843


 36%|███▌      | 2500/6966 [13:59<26:16,  2.83it/s]

Epoch: 1 Train loss: 0.8521


 37%|███▋      | 2600/6966 [14:33<24:24,  2.98it/s]

Epoch: 1 Train loss: 0.8224


 39%|███▉      | 2700/6966 [15:07<24:20,  2.92it/s]

Epoch: 1 Train loss: 0.7949


 40%|████      | 2800/6966 [15:41<23:32,  2.95it/s]

Epoch: 1 Train loss: 0.7695


 42%|████▏     | 2900/6966 [16:13<21:46,  3.11it/s]

Epoch: 1 Train loss: 0.7457


 43%|████▎     | 3000/6966 [16:47<23:32,  2.81it/s]

Epoch: 1 Train loss: 0.7236


 45%|████▍     | 3100/6966 [17:21<22:51,  2.82it/s]

Epoch: 1 Train loss: 0.7030


 46%|████▌     | 3200/6966 [17:55<20:42,  3.03it/s]

Epoch: 1 Train loss: 0.6836


 47%|████▋     | 3300/6966 [18:29<20:25,  2.99it/s]

Epoch: 1 Train loss: 0.6655


 49%|████▉     | 3400/6966 [19:03<19:52,  2.99it/s]

Epoch: 1 Train loss: 0.6484


 50%|█████     | 3500/6966 [19:37<19:41,  2.93it/s]

Epoch: 1 Train loss: 0.6324


 52%|█████▏    | 3600/6966 [20:11<18:35,  3.02it/s]

Epoch: 1 Train loss: 0.6172


 53%|█████▎    | 3700/6966 [20:47<19:22,  2.81it/s]

Epoch: 1 Train loss: 0.6029


 55%|█████▍    | 3800/6966 [21:20<16:48,  3.14it/s]

Epoch: 1 Train loss: 0.5894


 56%|█████▌    | 3900/6966 [21:53<16:15,  3.14it/s]

Epoch: 1 Train loss: 0.5766


 57%|█████▋    | 4000/6966 [22:27<15:54,  3.11it/s]

Epoch: 1 Train loss: 0.5644


 59%|█████▉    | 4100/6966 [23:01<16:22,  2.92it/s]

Epoch: 1 Train loss: 0.5528


 60%|██████    | 4200/6966 [23:35<15:35,  2.96it/s]

Epoch: 1 Train loss: 0.5424


 62%|██████▏   | 4300/6966 [24:08<14:34,  3.05it/s]

Epoch: 1 Train loss: 0.5321


 63%|██████▎   | 4400/6966 [24:42<14:52,  2.88it/s]

Epoch: 1 Train loss: 0.5220


 65%|██████▍   | 4500/6966 [25:16<13:18,  3.09it/s]

Epoch: 1 Train loss: 0.5124


 66%|██████▌   | 4600/6966 [25:50<12:36,  3.13it/s]

Epoch: 1 Train loss: 0.5032


 67%|██████▋   | 4700/6966 [26:24<12:22,  3.05it/s]

Epoch: 1 Train loss: 0.4945


 69%|██████▉   | 4800/6966 [26:58<12:26,  2.90it/s]

Epoch: 1 Train loss: 0.4861


 70%|███████   | 4900/6966 [27:32<11:07,  3.10it/s]

Epoch: 1 Train loss: 0.4780


 72%|███████▏  | 5000/6966 [28:06<10:59,  2.98it/s]

Epoch: 1 Train loss: 0.4703


 73%|███████▎  | 5100/6966 [28:40<10:46,  2.89it/s]

Epoch: 1 Train loss: 0.4629


 75%|███████▍  | 5200/6966 [29:14<09:56,  2.96it/s]

Epoch: 1 Train loss: 0.4557


 76%|███████▌  | 5300/6966 [29:49<09:33,  2.90it/s]

Epoch: 1 Train loss: 0.4488


 78%|███████▊  | 5400/6966 [30:23<09:07,  2.86it/s]

Epoch: 1 Train loss: 0.4422


 79%|███████▉  | 5500/6966 [30:57<08:19,  2.94it/s]

Epoch: 1 Train loss: 0.4358


 80%|████████  | 5600/6966 [31:32<08:11,  2.78it/s]

Epoch: 1 Train loss: 0.4296


 82%|████████▏ | 5700/6966 [32:05<07:24,  2.85it/s]

Epoch: 1 Train loss: 0.4237


 83%|████████▎ | 5800/6966 [32:40<06:35,  2.95it/s]

Epoch: 1 Train loss: 0.4179


 85%|████████▍ | 5900/6966 [33:13<05:56,  2.99it/s]

Epoch: 1 Train loss: 0.4124


 86%|████████▌ | 6000/6966 [33:47<05:12,  3.09it/s]

Epoch: 1 Train loss: 0.4070


 88%|████████▊ | 6100/6966 [34:21<04:59,  2.89it/s]

Epoch: 1 Train loss: 0.4018


 89%|████████▉ | 6200/6966 [34:55<04:20,  2.94it/s]

Epoch: 1 Train loss: 0.3968


 90%|█████████ | 6300/6966 [35:29<03:46,  2.94it/s]

Epoch: 1 Train loss: 0.3919


 92%|█████████▏| 6400/6966 [36:03<03:22,  2.79it/s]

Epoch: 1 Train loss: 0.3872


 93%|█████████▎| 6500/6966 [36:37<02:33,  3.04it/s]

Epoch: 1 Train loss: 0.3826


 95%|█████████▍| 6600/6966 [37:11<02:01,  3.02it/s]

Epoch: 1 Train loss: 0.3781


 96%|█████████▌| 6700/6966 [37:44<01:29,  2.97it/s]

Epoch: 1 Train loss: 0.3738


 98%|█████████▊| 6800/6966 [38:18<00:55,  2.97it/s]

Epoch: 1 Train loss: 0.3696


 99%|█████████▉| 6900/6966 [38:52<00:22,  2.98it/s]

Epoch: 1 Train loss: 0.3656


100%|██████████| 6966/6966 [39:13<00:00,  2.96it/s]
100%|██████████| 871/871 [04:31<00:00,  3.21it/s]
100%|██████████| 871/871 [04:30<00:00,  3.22it/s]


====> Epoch: 1 Train loss: 0.3630
====> Epoch: 1 Validation loss: nan
====> Epoch: 1 Test loss: nan


  1%|▏         | 100/6966 [00:33<37:10,  3.08it/s]

Epoch: 2 Train loss: 0.0874


  3%|▎         | 200/6966 [01:07<36:29,  3.09it/s]

Epoch: 2 Train loss: 0.0886


  4%|▍         | 300/6966 [01:40<37:21,  2.97it/s]

Epoch: 2 Train loss: 0.0883


  6%|▌         | 400/6966 [02:14<37:15,  2.94it/s]

Epoch: 2 Train loss: 0.0881


  7%|▋         | 500/6966 [02:47<33:55,  3.18it/s]

Epoch: 2 Train loss: 0.0883


  9%|▊         | 598/6966 [03:20<35:53,  2.96it/s]