In [None]:
# import packages
import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms, utils
from torch.utils.data import TensorDataset
from tqdm import tqdm
from torch_geometric.nn import TransformerConv, global_max_pool, GATv2Conv, PointNetConv, ClusterGCNConv, PointTransformerConv, global_mean_pool
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import time
import torch
from torch_geometric.data import Data, InMemoryDataset, download_url
import torch_geometric.transforms as T
import networkx as nx
from torch_cluster import knn_graph
from torch_geometric.utils import from_networkx

from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer; the training loop later in this notebook logs per-epoch
# loss and accuracy scalars through it via `writer.add_scalar`.
writer = SummaryWriter()


In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


**Helpers and Metrics**


In [9]:
# Dataset class
import torch
from torch_geometric.data import Data, InMemoryDataset, download_url
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

class Quark_Gluon_Dataset(InMemoryDataset):
    """In-memory graph dataset built from the quark/gluon jet-image HDF5 file.

    Every jet image is converted into one graph:

    * each pixel with at least one non-zero value becomes a node whose ``x``
      attribute is the pixel's value vector and whose ``pos`` is
      ``(row, col, 0)``;
    * two nodes are connected when their pixels are neighbours in the
      8-connected (horizontal, vertical, diagonal) neighbourhood.

    The per-jet entries ``y``, ``m0`` and ``pt`` of the HDF5 file are attached
    to each graph as ``y`` (long), ``m`` (float) and ``p`` (float).

    Args:
        root: Dataset root directory, forwarded to ``InMemoryDataset``.
        transform: Forwarded to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``; additionally applied
            to every graph in ``process`` before the collated data is saved.
    """

    # Location of the raw HDF5 file read by ``process``.
    RAW_HDF5_PATH = '/hdfs1/Data/Shrutimoy/quark-gluon_data-set_n139306.hdf5'

    # Neighbour offsets (d_row, d_col). The order matches the order in which
    # edges were historically inserted, so node/edge ordering of the resulting
    # graphs is unchanged.
    _NEIGHBOUR_OFFSETS = (
        (-1, 0), (0, -1), (1, 0), (0, 1),
        (-1, -1), (1, 1), (-1, 1), (1, -1),
    )

    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # Load the collated graphs written by ``process``.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """No raw files are declared; the HDF5 source lives at ``RAW_HDF5_PATH``."""
        return []

    @property
    def processed_file_names(self):
        """Name of the single file that stores the collated graphs."""
        return ['data.pt']

    def download(self):
        """Nothing to download."""
        pass

    def create_graph(self, image):
        """Build a ``networkx.Graph`` from the non-zero pixels of ``image``.

        Args:
            image: Array indexed as ``image[row][col]``; a pixel is considered
                active when any of its values is non-zero.

        Returns:
            An undirected graph with one node ``(row, col)`` per active pixel
            (attributes ``x`` and ``pos``) and edges between active pixels
            that are 8-connected neighbours.
        """
        height, width = image.shape[0], image.shape[1]

        # Compute the "active pixel" mask once instead of calling ``.any()``
        # up to nine times per pixel inside the loops.
        pixels = np.asarray(image)
        if pixels.ndim > 2:
            active = pixels.any(axis=tuple(range(2, pixels.ndim)))
        else:
            active = pixels != 0

        graph = nx.Graph()
        for i in range(height):
            for j in range(width):
                if not active[i, j]:
                    continue
                # Node (i, j) carries the pixel values and a 3-D position.
                graph.add_node((i, j), x=image[i][j], pos=(i, j, 0))
                for di, dj in self._NEIGHBOUR_OFFSETS:
                    ni, nj = i + di, j + dj
                    if 0 <= ni < height and 0 <= nj < width and active[ni, nj]:
                        graph.add_edge((i, j), (ni, nj))

        return graph

    def process(self):
        """Convert every jet image into a graph and save the collated result."""
        # The context manager guarantees the HDF5 handle is closed even if
        # reading fails; ``np.asarray`` copies the datasets into memory so
        # they stay usable after the file is closed.
        with h5py.File(self.RAW_HDF5_PATH, 'r') as f:
            X_jets = np.asarray(f['X_jets'])
            m0 = np.asarray(f['m0'])
            pt = np.asarray(f['pt'])
            y = np.asarray(f['y'])

        data_list = []
        for i in tqdm(range(len(X_jets))):
            data = from_networkx(self.create_graph(X_jets[i]))
            data.y = torch.tensor(y[i], dtype=torch.long)
            data.m = torch.tensor(m0[i], dtype=torch.float)
            data.p = torch.tensor(pt[i], dtype=torch.float)
            # Honour the ``pre_transform`` accepted by the constructor.
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])


In [4]:
def train_test_split(X, y, test_size):
    """Split ``X`` and ``y`` sequentially (no shuffling) into train and test parts.

    The first ``int(len(X) * (1 - test_size))`` items form the training part
    and the remainder forms the test part; the same cut index is applied to
    both ``X`` and ``y``.

    Returns:
        Tuple ``(train_X, test_X, train_y, test_y)``.
    """
    cut = int(len(X)*(1-test_size))
    head, tail = slice(None, cut), slice(cut, None)
    return X[head], X[tail], y[head], y[tail]


In [5]:
def accuracy(output, labels):
    """Fraction of rows in ``output`` whose highest-scoring column equals the label.

    Args:
        output: 2-D score tensor; the prediction for a row is the index of
            its maximum along dimension 1.
        labels: Tensor of target class indices, one per row of ``output``.

    Returns:
        A 0-dim float64 tensor: number of correct predictions / ``len(labels)``.
    """
    _, predicted = output.max(1)
    predicted = predicted.type_as(labels)
    hits = predicted.eq(labels).double().sum()
    return hits / len(labels)


class AverageMeter(object):
    """Tracks the most recent value and the count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, mean, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


**Model**


In [10]:
# Define the model
class GraphClassifier(nn.Module):
    """Graph-level classifier that runs three graph convolutions in parallel.

    Node features are first projected to ``hidden_dim`` by ``fc0``. Each of
    the ``num_layers`` stages feeds the same node features to a
    ``PointTransformerConv``, a ``ClusterGCNConv`` and a ``GATv2Conv``,
    applies ELU to each result and concatenates them (tripling the feature
    width), followed by ``BatchNorm1d`` and a ``Linear`` layer. Node
    embeddings are mean-pooled per graph, concatenated with two per-graph
    scalars (``m0`` and ``pt``) and passed through a two-layer head.

    ``forward`` returns softmax class *probabilities*, not logits, so losses
    that expect logits (e.g. ``F.cross_entropy``) must not be applied to the
    output directly.

    Args:
        input_dim: Size of the raw node feature vector.
        hidden_dim: Width after the input projection; stage ``i`` works on
            ``hidden_dim * 3**i`` features.
        output_dim: Number of classes.
        num_layers: Number of parallel-convolution stages.
        dropout: Dropout probability applied in the classification head.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout):
        super(GraphClassifier, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        # NOTE: attribute names are kept as-is (they determine state_dict
        # keys): ``convs_trans`` holds GATv2Conv layers and ``convs_attn``
        # holds ClusterGCNConv layers.
        self.convs_trans = nn.ModuleList()
        self.convs_point = nn.ModuleList()
        self.convs_attn = nn.ModuleList()
        self.fc1 = nn.ModuleList()
        self.batch_norm = nn.ModuleList()
        for i in range(num_layers):
            width_in = hidden_dim*(3**i)
            # Concatenating the three branch outputs triples the width.
            width_out = hidden_dim*(3**(i+1))
            self.convs_trans.append(GATv2Conv(width_in, width_in, add_self_loops=True))
            self.convs_point.append(PointTransformerConv(width_in, width_in,  add_self_loops=True))
            self.convs_attn.append(ClusterGCNConv(width_in, width_in, add_self_loops=True))
            self.batch_norm.append(nn.BatchNorm1d(width_out))
            self.fc1.append(nn.Linear(width_out, width_out))
        self.fc0 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # +2 for the per-graph scalars appended after pooling.
        head_dim = hidden_dim*(3**num_layers)+2
        self.batch_norm1 = nn.BatchNorm1d(head_dim)
        self.fc2 = nn.Linear(head_dim, head_dim)
        self.fc3 = nn.Linear(head_dim, output_dim)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, edge_index, pos, batch, m0, pt):
        """Classify every graph in the batch.

        Args:
            x: Node features, projected by ``fc0`` (last dim ``input_dim``).
            edge_index: Edge index tensor passed to the graph convolutions.
            pos: Node positions (cast to float) used by ``PointTransformerConv``.
            batch: Node-to-graph assignment used for mean pooling.
            m0: One scalar per graph; reshaped to a column.
            pt: One scalar per graph; reshaped to a column.

        Returns:
            Tensor of class probabilities with one row per graph.
        """
        pos = pos.float()
        x = self.fc0(x)
        for i in range(self.num_layers):
            x1 = self.convs_point[i](x, pos, edge_index)
            x1 = F.elu(x1)
            x2 = self.convs_attn[i](x, edge_index)
            x2 = F.elu(x2)
            x3 = self.convs_trans[i](x, edge_index)
            x3 = F.elu(x3)
            x = torch.cat([x1, x2, x3], dim=1)
            x = self.batch_norm[i](x)
            x = self.fc1[i](x)
        x = global_mean_pool(x, batch)
        m0 = m0.reshape(-1, 1)
        pt = pt.reshape(-1, 1)
        x = torch.cat([x, m0, pt], dim=1)
        x = self.fc2(x)
        x = F.elu(x)
        x = self.batch_norm1(x)
        # Previously the dropout module was constructed but never used, so the
        # ``dropout`` hyper-parameter had no effect. It is an identity for
        # p=0 and in eval mode.
        x = self.dropout(x)
        x = self.fc3(x)
        return self.softmax(x)


**Main**


In [11]:
# Instantiate the graph dataset; the constructor loads the collated graphs
# from the processed file under `root`.
# NOTE(review): presumably the InMemoryDataset base class runs `process()` first
# when that file is missing — confirm. The path is cluster-specific.
dataset = Quark_Gluon_Dataset(root='/hdfs1/Data/Shubhajit/Quark_Gluon_Data_1/')


In [12]:
# Sequential 80% / 10% / 10% split into train, validation and test
# (the dataset is sliced in stored order; nothing is shuffled before the cut).
num_graphs = len(dataset)
train_end = int(0.8*num_graphs)
val_end = int(0.9*num_graphs)

train_dataset = dataset[:train_end]
val_dataset = dataset[train_end:val_end]
test_dataset = dataset[val_end:]

# Mini-batch loaders; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


In [13]:
# Build the model, then the optimizer and learning-rate schedule on top of it.
model = GraphClassifier(
    input_dim=3,
    hidden_dim=16,
    output_dim=2,
    num_layers=3,
    dropout=0,
)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.5 every 10 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.50)

In [14]:
# Evaluates to the bound `parameters` method (it is not called); its repr
# embeds the module's layer-by-layer summary shown in the cell output.
model.parameters

<bound method Module.parameters of GraphClassifier(
  (convs_trans): ModuleList(
    (0): GATv2Conv(16, 16, heads=1)
    (1): GATv2Conv(48, 48, heads=1)
    (2): GATv2Conv(144, 144, heads=1)
  )
  (convs_point): ModuleList(
    (0): PointTransformerConv(16, 16)
    (1): PointTransformerConv(48, 48)
    (2): PointTransformerConv(144, 144)
  )
  (convs_attn): ModuleList(
    (0): ClusterGCNConv(16, 16, diag_lambda=0.0)
    (1): ClusterGCNConv(48, 48, diag_lambda=0.0)
    (2): ClusterGCNConv(144, 144, diag_lambda=0.0)
  )
  (fc1): ModuleList(
    (0): Linear(in_features=48, out_features=48, bias=True)
    (1): Linear(in_features=144, out_features=144, bias=True)
    (2): Linear(in_features=432, out_features=432, bias=True)
  )
  (batch_norm): ModuleList(
    (0): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): BatchNorm1d(432, eps=1e-05, momentum=0.1, affine=T

In [15]:
# training loop
def train(data):
    optimizer.zero_grad()
    data = data.to(device)
    output = model(data.x, data.edge_index, data.pos, data.batch, data.m, data.p)
    loss_train = F.cross_entropy(output, data.y)
    loss_train.backward()
    optimizer.step()
    return output, loss_train


def test(data):
    data = data.to(device)
    output = model(data.x, data.edge_index, data.pos, data.batch, data.m, data.p)
    loss_test = F.cross_entropy(output, data.y)
    return output, loss_test


In [16]:


NUM_EPOCHS = 100


def _evaluate_loader(loader):
    """Evaluate the model on every batch of ``loader``.

    Runs under ``torch.no_grad()`` and assumes the caller has already put the
    model in eval mode.

    Returns:
        Tuple ``(loss_meter, acc_meter)`` of ``AverageMeter`` objects holding
        the sample-weighted mean loss and accuracy.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    with torch.no_grad():
        for data in tqdm(loader):
            output, loss = test(data)
            loss_meter.update(loss.item(), output.size(0))
            acc_meter.update(float(accuracy(output, data.y)), output.size(0))
    return loss_meter, acc_meter


for epoch in range(NUM_EPOCHS):
    start = time.time()
    # Train for one epoch
    model.train()
    train_loss = AverageMeter()
    train_acc = AverageMeter()
    for data in tqdm(train_loader):
        output, loss = train(data)
        train_loss.update(loss.item(), output.size(0))
        # Store plain floats so the meters do not accumulate tensors.
        train_acc.update(float(accuracy(output, data.y)), output.size(0))

    # Evaluate on the validation and test sets
    model.eval()
    val_loss, val_acc = _evaluate_loader(val_loader)
    test_loss, test_acc = _evaluate_loader(test_loader)

    print(f"epoch:{epoch + 1:03d} train_loss={train_loss.avg:.5f} train_acc={train_acc.avg:.5f} "
          f"val_acc={val_acc.avg:.5f} test_acc={test_acc.avg:.5f} time={time.time() - start:.5f} ")

    scheduler.step()

    # log the results to tensorboard
    writer.add_scalar('train_loss', train_loss.avg, epoch)
    writer.add_scalar('train_acc', train_acc.avg, epoch)
    writer.add_scalar('val_loss', val_loss.avg, epoch)
    writer.add_scalar('val_acc', val_acc.avg, epoch)
    writer.add_scalar('test_loss', test_loss.avg, epoch)
    writer.add_scalar('test_acc', test_acc.avg, epoch)


100%|██████████| 6966/6966 [05:17<00:00, 21.93it/s]
100%|██████████| 871/871 [00:18<00:00, 46.76it/s]
100%|██████████| 871/871 [00:17<00:00, 48.43it/s]


epoch:001 train_loss=0.58725 train_acc=0.70798 val_acc=0.57333 test_acc=0.57419 time=354.22800 


100%|██████████| 6966/6966 [04:53<00:00, 23.76it/s]
100%|██████████| 871/871 [00:15<00:00, 56.19it/s]
100%|██████████| 871/871 [00:12<00:00, 67.93it/s]


epoch:002 train_loss=0.58253 train_acc=0.71266 val_acc=0.68430 test_acc=0.67913 time=321.55234 


100%|██████████| 6966/6966 [04:59<00:00, 23.24it/s]
100%|██████████| 871/871 [00:15<00:00, 56.63it/s]
100%|██████████| 871/871 [00:13<00:00, 65.65it/s]


epoch:003 train_loss=0.57949 train_acc=0.71622 val_acc=0.49185 test_acc=0.48941 time=328.39590 


100%|██████████| 6966/6966 [05:06<00:00, 22.74it/s]
100%|██████████| 871/871 [00:15<00:00, 57.16it/s]
100%|██████████| 871/871 [00:15<00:00, 55.13it/s]


epoch:004 train_loss=0.57976 train_acc=0.71579 val_acc=0.50039 test_acc=0.50197 time=337.42938 


100%|██████████| 6966/6966 [04:55<00:00, 23.55it/s]
100%|██████████| 871/871 [00:14<00:00, 59.35it/s]
100%|██████████| 871/871 [00:13<00:00, 63.90it/s]


epoch:005 train_loss=0.57855 train_acc=0.71638 val_acc=0.73182 test_acc=0.72730 time=324.08411 


100%|██████████| 6966/6966 [05:09<00:00, 22.48it/s]
100%|██████████| 871/871 [00:14<00:00, 60.00it/s]
100%|██████████| 871/871 [00:14<00:00, 58.65it/s]


epoch:006 train_loss=0.57730 train_acc=0.71793 val_acc=0.52459 test_acc=0.51533 time=339.19987 


100%|██████████| 6966/6966 [05:09<00:00, 22.48it/s]
100%|██████████| 871/871 [00:16<00:00, 54.26it/s]
100%|██████████| 871/871 [00:14<00:00, 60.34it/s]


epoch:007 train_loss=0.57649 train_acc=0.72001 val_acc=0.49537 test_acc=0.49846 time=340.43804 


100%|██████████| 6966/6966 [04:56<00:00, 23.51it/s]
100%|██████████| 871/871 [00:14<00:00, 59.51it/s]
100%|██████████| 871/871 [00:12<00:00, 68.10it/s]


epoch:008 train_loss=0.57706 train_acc=0.71923 val_acc=0.67045 test_acc=0.66743 time=323.77615 


100%|██████████| 6966/6966 [04:59<00:00, 23.30it/s]
100%|██████████| 871/871 [00:14<00:00, 61.16it/s]
100%|██████████| 871/871 [00:12<00:00, 67.76it/s]


epoch:009 train_loss=0.57698 train_acc=0.72004 val_acc=0.65279 test_acc=0.65437 time=326.12413 


100%|██████████| 6966/6966 [05:00<00:00, 23.15it/s]
100%|██████████| 871/871 [00:14<00:00, 61.56it/s]
100%|██████████| 871/871 [00:13<00:00, 64.19it/s]


epoch:010 train_loss=0.57712 train_acc=0.71885 val_acc=0.49882 test_acc=0.50004 time=328.65580 


100%|██████████| 6966/6966 [05:26<00:00, 21.32it/s]
100%|██████████| 871/871 [00:15<00:00, 57.64it/s]
100%|██████████| 871/871 [00:12<00:00, 67.70it/s]


epoch:011 train_loss=0.57471 train_acc=0.72098 val_acc=0.50111 test_acc=0.50312 time=354.75698 


100%|██████████| 6966/6966 [04:52<00:00, 23.78it/s]
100%|██████████| 871/871 [00:16<00:00, 52.63it/s]
100%|██████████| 871/871 [00:13<00:00, 65.21it/s]


epoch:012 train_loss=0.57398 train_acc=0.72260 val_acc=0.72012 test_acc=0.71366 time=322.90349 


100%|██████████| 6966/6966 [05:01<00:00, 23.07it/s]
100%|██████████| 871/871 [00:16<00:00, 51.82it/s]
100%|██████████| 871/871 [00:15<00:00, 56.11it/s]


epoch:013 train_loss=0.57331 train_acc=0.72415 val_acc=0.52990 test_acc=0.52918 time=334.31725 


100%|██████████| 6966/6966 [05:04<00:00, 22.89it/s]
100%|██████████| 871/871 [00:13<00:00, 62.92it/s]
100%|██████████| 871/871 [00:12<00:00, 67.35it/s]


epoch:014 train_loss=0.57279 train_acc=0.72409 val_acc=0.65860 test_acc=0.65250 time=331.11558 


100%|██████████| 6966/6966 [05:09<00:00, 22.53it/s]
100%|██████████| 871/871 [00:13<00:00, 64.08it/s]
100%|██████████| 871/871 [00:12<00:00, 67.10it/s]


epoch:015 train_loss=0.57351 train_acc=0.72379 val_acc=0.65846 test_acc=0.65566 time=335.80770 


100%|██████████| 6966/6966 [05:03<00:00, 22.94it/s]
100%|██████████| 871/871 [00:15<00:00, 56.59it/s]
100%|██████████| 871/871 [00:13<00:00, 62.65it/s]


epoch:016 train_loss=0.57289 train_acc=0.72361 val_acc=0.57849 test_acc=0.58029 time=332.97676 


100%|██████████| 6966/6966 [04:56<00:00, 23.49it/s]
100%|██████████| 871/871 [00:14<00:00, 59.58it/s]
100%|██████████| 871/871 [00:14<00:00, 61.69it/s]


epoch:017 train_loss=0.57238 train_acc=0.72470 val_acc=0.59637 test_acc=0.59364 time=325.33168 


100%|██████████| 6966/6966 [05:06<00:00, 22.73it/s]
100%|██████████| 871/871 [00:14<00:00, 61.61it/s]
100%|██████████| 871/871 [00:14<00:00, 60.73it/s]


epoch:018 train_loss=0.57258 train_acc=0.72361 val_acc=0.65903 test_acc=0.65135 time=335.00768 


100%|██████████| 6966/6966 [05:13<00:00, 22.21it/s]
100%|██████████| 871/871 [00:15<00:00, 57.01it/s]
100%|██████████| 871/871 [00:12<00:00, 67.58it/s]


epoch:019 train_loss=0.57212 train_acc=0.72478 val_acc=0.49896 test_acc=0.50047 time=341.79981 


100%|██████████| 6966/6966 [04:53<00:00, 23.70it/s]
100%|██████████| 871/871 [00:15<00:00, 57.52it/s]
100%|██████████| 871/871 [00:15<00:00, 54.90it/s]


epoch:020 train_loss=0.57170 train_acc=0.72542 val_acc=0.63635 test_acc=0.63262 time=324.95634 


100%|██████████| 6966/6966 [04:55<00:00, 23.57it/s]
100%|██████████| 871/871 [00:13<00:00, 64.42it/s]
100%|██████████| 871/871 [00:14<00:00, 59.68it/s]


epoch:021 train_loss=0.56978 train_acc=0.72811 val_acc=0.66298 test_acc=0.65911 time=323.71458 


100%|██████████| 6966/6966 [05:09<00:00, 22.52it/s]
100%|██████████| 871/871 [00:14<00:00, 59.15it/s]
100%|██████████| 871/871 [00:15<00:00, 56.37it/s]


epoch:022 train_loss=0.56917 train_acc=0.72908 val_acc=0.72192 test_acc=0.71689 time=339.45887 


100%|██████████| 6966/6966 [05:04<00:00, 22.88it/s]
100%|██████████| 871/871 [00:15<00:00, 54.77it/s]
100%|██████████| 871/871 [00:16<00:00, 53.73it/s]


epoch:023 train_loss=0.56871 train_acc=0.72909 val_acc=0.60599 test_acc=0.59622 time=336.60309 


100%|██████████| 6966/6966 [05:06<00:00, 22.75it/s]
100%|██████████| 871/871 [00:14<00:00, 60.84it/s]
100%|██████████| 871/871 [00:14<00:00, 60.43it/s]


epoch:024 train_loss=0.56845 train_acc=0.72868 val_acc=0.70038 test_acc=0.69098 time=334.89687 


100%|██████████| 6966/6966 [05:06<00:00, 22.73it/s]
100%|██████████| 871/871 [00:14<00:00, 59.84it/s]
100%|██████████| 871/871 [00:15<00:00, 55.48it/s]


epoch:025 train_loss=0.56870 train_acc=0.72880 val_acc=0.73254 test_acc=0.72615 time=336.67604 


100%|██████████| 6966/6966 [04:55<00:00, 23.56it/s]
100%|██████████| 871/871 [00:14<00:00, 60.91it/s]
100%|██████████| 871/871 [00:12<00:00, 67.31it/s]


epoch:026 train_loss=0.56832 train_acc=0.72901 val_acc=0.72701 test_acc=0.71825 time=322.87162 


100%|██████████| 6966/6966 [04:57<00:00, 23.42it/s]
100%|██████████| 871/871 [00:16<00:00, 52.50it/s]
100%|██████████| 871/871 [00:14<00:00, 60.59it/s]


epoch:027 train_loss=0.56785 train_acc=0.73005 val_acc=0.72586 test_acc=0.71998 time=328.41092 


100%|██████████| 6966/6966 [04:52<00:00, 23.83it/s]
100%|██████████| 871/871 [00:13<00:00, 64.03it/s]
100%|██████████| 871/871 [00:13<00:00, 64.95it/s]


epoch:028 train_loss=0.56820 train_acc=0.72956 val_acc=0.60261 test_acc=0.59874 time=319.35844 


100%|██████████| 6966/6966 [05:03<00:00, 22.96it/s]
100%|██████████| 871/871 [00:15<00:00, 54.70it/s]
100%|██████████| 871/871 [00:12<00:00, 68.27it/s]


epoch:029 train_loss=0.56819 train_acc=0.72906 val_acc=0.73168 test_acc=0.72615 time=332.05788 


100%|██████████| 6966/6966 [05:05<00:00, 22.81it/s]
100%|██████████| 871/871 [00:16<00:00, 53.39it/s]
100%|██████████| 871/871 [00:16<00:00, 52.66it/s]


epoch:030 train_loss=0.56742 train_acc=0.73035 val_acc=0.66083 test_acc=0.65753 time=338.20588 


100%|██████████| 6966/6966 [04:52<00:00, 23.78it/s]
100%|██████████| 871/871 [00:14<00:00, 61.46it/s]
100%|██████████| 871/871 [00:12<00:00, 67.96it/s]


epoch:031 train_loss=0.56603 train_acc=0.73189 val_acc=0.70289 test_acc=0.69421 time=319.97563 


100%|██████████| 6966/6966 [05:06<00:00, 22.70it/s]
100%|██████████| 871/871 [00:16<00:00, 54.12it/s]
100%|██████████| 871/871 [00:12<00:00, 67.49it/s]


epoch:032 train_loss=0.56577 train_acc=0.73187 val_acc=0.70749 test_acc=0.69772 time=335.90450 


100%|██████████| 6966/6966 [04:56<00:00, 23.53it/s]
100%|██████████| 871/871 [00:14<00:00, 59.43it/s]
100%|██████████| 871/871 [00:16<00:00, 53.01it/s]


epoch:033 train_loss=0.56553 train_acc=0.73229 val_acc=0.71825 test_acc=0.71100 time=327.19108 


100%|██████████| 6966/6966 [04:54<00:00, 23.63it/s]
100%|██████████| 871/871 [00:13<00:00, 64.66it/s]
100%|██████████| 871/871 [00:13<00:00, 63.81it/s]


epoch:034 train_loss=0.56573 train_acc=0.73192 val_acc=0.67676 test_acc=0.67691 time=321.87838 


100%|██████████| 6966/6966 [05:00<00:00, 23.17it/s]
100%|██████████| 871/871 [00:16<00:00, 53.12it/s]
100%|██████████| 871/871 [00:13<00:00, 65.52it/s]


epoch:035 train_loss=0.56591 train_acc=0.73231 val_acc=0.73426 test_acc=0.72823 time=330.28742 


100%|██████████| 6966/6966 [04:54<00:00, 23.62it/s]
100%|██████████| 871/871 [00:15<00:00, 57.77it/s]
100%|██████████| 871/871 [00:16<00:00, 52.98it/s]


epoch:036 train_loss=0.56564 train_acc=0.73239 val_acc=0.72443 test_acc=0.71524 time=326.39231 


100%|██████████| 6966/6966 [04:53<00:00, 23.72it/s]
100%|██████████| 871/871 [00:16<00:00, 51.83it/s]
100%|██████████| 871/871 [00:16<00:00, 52.53it/s]


epoch:037 train_loss=0.56522 train_acc=0.73376 val_acc=0.71567 test_acc=0.71036 time=327.02683 


100%|██████████| 6966/6966 [04:57<00:00, 23.45it/s]
100%|██████████| 871/871 [00:15<00:00, 54.71it/s]
100%|██████████| 871/871 [00:14<00:00, 60.18it/s]


epoch:038 train_loss=0.56499 train_acc=0.73434 val_acc=0.71567 test_acc=0.70907 time=327.43487 


100%|██████████| 6966/6966 [05:02<00:00, 23.05it/s]
100%|██████████| 871/871 [00:16<00:00, 52.10it/s]
100%|██████████| 871/871 [00:13<00:00, 65.41it/s]


epoch:039 train_loss=0.56520 train_acc=0.73297 val_acc=0.72292 test_acc=0.71581 time=332.31679 


100%|██████████| 6966/6966 [04:50<00:00, 23.96it/s]
100%|██████████| 871/871 [00:13<00:00, 62.86it/s]
100%|██████████| 871/871 [00:12<00:00, 67.76it/s]


epoch:040 train_loss=0.56512 train_acc=0.73315 val_acc=0.72773 test_acc=0.72177 time=317.40601 


100%|██████████| 6966/6966 [04:59<00:00, 23.27it/s]
100%|██████████| 871/871 [00:14<00:00, 61.55it/s]
100%|██████████| 871/871 [00:12<00:00, 68.26it/s]


epoch:041 train_loss=0.56401 train_acc=0.73393 val_acc=0.73814 test_acc=0.73139 time=326.23573 


100%|██████████| 6966/6966 [05:06<00:00, 22.69it/s]
100%|██████████| 871/871 [00:16<00:00, 52.81it/s]
100%|██████████| 871/871 [00:15<00:00, 56.60it/s]


epoch:042 train_loss=0.56432 train_acc=0.73359 val_acc=0.73699 test_acc=0.73053 time=338.83885 


100%|██████████| 6966/6966 [05:05<00:00, 22.84it/s]
100%|██████████| 871/871 [00:13<00:00, 63.35it/s]
100%|██████████| 871/871 [00:14<00:00, 61.47it/s]


epoch:043 train_loss=0.56393 train_acc=0.73395 val_acc=0.73232 test_acc=0.72816 time=332.96728 


100%|██████████| 6966/6966 [05:00<00:00, 23.14it/s]
100%|██████████| 871/871 [00:13<00:00, 63.90it/s]
100%|██████████| 871/871 [00:13<00:00, 66.07it/s]


epoch:044 train_loss=0.56419 train_acc=0.73425 val_acc=0.68753 test_acc=0.68301 time=327.79741 


100%|██████████| 6966/6966 [05:12<00:00, 22.26it/s]
100%|██████████| 871/871 [00:13<00:00, 63.50it/s]
100%|██████████| 871/871 [00:15<00:00, 56.29it/s]


epoch:045 train_loss=0.56379 train_acc=0.73412 val_acc=0.71825 test_acc=0.71388 time=342.19829 


100%|██████████| 6966/6966 [05:00<00:00, 23.17it/s]
100%|██████████| 871/871 [00:13<00:00, 62.80it/s]
100%|██████████| 871/871 [00:13<00:00, 66.46it/s]


epoch:046 train_loss=0.56408 train_acc=0.73429 val_acc=0.69327 test_acc=0.68811 time=327.63903 


100%|██████████| 6966/6966 [05:02<00:00, 22.99it/s]
100%|██████████| 871/871 [00:16<00:00, 54.43it/s]
100%|██████████| 871/871 [00:12<00:00, 67.90it/s]


epoch:047 train_loss=0.56346 train_acc=0.73536 val_acc=0.71847 test_acc=0.71388 time=331.82744 


100%|██████████| 6966/6966 [05:00<00:00, 23.16it/s]
100%|██████████| 871/871 [00:14<00:00, 59.78it/s]
100%|██████████| 871/871 [00:14<00:00, 59.21it/s]


epoch:048 train_loss=0.56366 train_acc=0.73486 val_acc=0.73455 test_acc=0.72931 time=330.02696 


100%|██████████| 6966/6966 [05:02<00:00, 23.04it/s]
100%|██████████| 871/871 [00:13<00:00, 65.07it/s]
100%|██████████| 871/871 [00:13<00:00, 65.45it/s]


epoch:049 train_loss=0.56388 train_acc=0.73388 val_acc=0.73555 test_acc=0.73017 time=328.98708 


100%|██████████| 6966/6966 [05:01<00:00, 23.08it/s]
100%|██████████| 871/871 [00:18<00:00, 45.99it/s]
100%|██████████| 871/871 [00:19<00:00, 45.61it/s]


epoch:050 train_loss=0.56348 train_acc=0.73471 val_acc=0.72026 test_acc=0.71481 time=339.90290 


100%|██████████| 6966/6966 [05:04<00:00, 22.87it/s]
100%|██████████| 871/871 [00:14<00:00, 59.37it/s]
100%|██████████| 871/871 [00:14<00:00, 59.01it/s]


epoch:051 train_loss=0.56313 train_acc=0.73545 val_acc=0.73713 test_acc=0.73225 time=334.01418 


  7%|▋         | 489/6966 [00:20<04:41, 23.01it/s]