In [1]:
from __future__ import print_function
import argparse
import inspect
import os
import pickle
import random
import shutil
import sys
import time
from collections import OrderedDict
import traceback
from sklearn.metrics import confusion_matrix
import csv
import numpy as np
import glob
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from tqdm import tqdm
from feeders.feeder_ntu import Feeder
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_add_pool
import inspect
from typing import Any, Dict, Optional

import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from mamba_ssm import Mamba
from torch_geometric.utils import degree, sort_edge_index

import wandb

In [2]:
def init_seed(seed):
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # torch.backends.cudnn.enabled = True
    # training speed is too slow if set to True
    torch.backends.cudnn.deterministic = True

    # on cuda 11 cudnn8, the default algorithm is very slow
    # unlike on cuda 10, the default works well
    torch.backends.cudnn.benchmark = False

In [3]:
def load_data(self, phase='train'):
    if phase=='train':
        data_loader = torch.utils.data.DataLoader(
            dataset=Feeder('data/ntu/NTU60_CS.npz', split='train', p_interval =[0.5, 1], window_size=64),
            batch_size=32,
            num_workers=0,
            worker_init_fn=init_seed)
    else:
        data_loader = torch.utils.data.DataLoader(
            dataset=Feeder('data/ntu/NTU60_CS.npz', split='test', p_interval =[0.95], window_size=64),
            batch_size=32,
            num_workers=0,
            worker_init_fn=init_seed)
    return data_loader

In [4]:
class GPSConv(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        conv: Optional[MessagePassing],
        heads: int = 1,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        act: str = 'relu',
        att_type: str = 'transformer',
        order_by_degree: bool = False,
        shuffle_ind: int = 0,
        d_state: int = 16,
        d_conv: int = 4,
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Optional[str] = 'batch_norm',
        norm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        self.channels = channels
        self.conv = conv
        self.heads = heads
        self.dropout = dropout
        self.att_type = att_type
        self.shuffle_ind = shuffle_ind
        self.order_by_degree = order_by_degree
        
        assert (self.order_by_degree==True and self.shuffle_ind==0) or (self.order_by_degree==False), f'order_by_degree={self.order_by_degree} and shuffle_ind={self.shuffle_ind}'
        
        if self.att_type == 'transformer':
            self.attn = torch.nn.MultiheadAttention(
                channels,
                heads,
                dropout=attn_dropout,
                batch_first=True,
            )
        if self.att_type == 'mamba':
            self.self_attn = Mamba(
                d_model=channels,
                d_state=d_state,
                d_conv=d_conv,
                expand=1
            )
            
        self.mlp = Sequential(
            Linear(channels, channels * 2),
            activation_resolver(act, **(act_kwargs or {})),
            Dropout(dropout),
            Linear(channels * 2, channels),
            Dropout(dropout),
        )

        norm_kwargs = norm_kwargs or {}
        self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)

        self.norm_with_batch = False
        if self.norm1 is not None:
            signature = inspect.signature(self.norm1.forward)
            self.norm_with_batch = 'batch' in signature.parameters

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        if self.conv is not None:
            self.conv.reset_parameters()
        self.attn._reset_parameters()
        reset(self.mlp)
        if self.norm1 is not None:
            self.norm1.reset_parameters()
        if self.norm2 is not None:
            self.norm2.reset_parameters()
        if self.norm3 is not None:
            self.norm3.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        batch: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tensor:
        r"""Runs the forward pass of the module."""
        hs = []
        if self.conv is not None:  # Local MPNN.
            h = self.conv(x, edge_index, **kwargs)
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = h + x
            if self.norm1 is not None:
                if self.norm_with_batch:
                    h = self.norm1(h, batch=batch)
                else:
                    h = self.norm1(h)
            hs.append(h)

        ### Global attention transformer-style model.
        if self.att_type == 'transformer':
            h, mask = to_dense_batch(x, batch)
            h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False)
            h = h[mask]
            
        if self.att_type == 'mamba':
            
            if self.order_by_degree:
                deg = degree(edge_index[0], x.shape[0]).to(torch.long)
                order_tensor = torch.stack([batch, deg], 1).T
                _, x = sort_edge_index(order_tensor, edge_attr=x)
                
            if self.shuffle_ind == 0:
                h, mask = to_dense_batch(x, batch)
                h = self.self_attn(h)[mask]
            else:
                mamba_arr = []
                for _ in range(self.shuffle_ind):
                    h_ind_perm = permute_within_batch(x, batch)
                    h_i, mask = to_dense_batch(x[h_ind_perm], batch)
                    h_i = self.self_attn(h_i)[mask][h_ind_perm]
                    mamba_arr.append(h_i)
                h = sum(mamba_arr) / self.shuffle_ind
        ###
        
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = h + x  # Residual connection.
        if self.norm2 is not None:
            if self.norm_with_batch:
                h = self.norm2(h, batch=batch)
            else:
                h = self.norm2(h)
        hs.append(h)

        out = sum(hs)  # Combine local and global outputs.

        out = out + self.mlp(out)
        if self.norm3 is not None:
            if self.norm_with_batch:
                out = self.norm3(out, batch=batch)
            else:
                out = self.norm3(out)

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'conv={self.conv}, heads={self.heads})')

In [5]:
class TemporalConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
        super(TemporalConv, self).__init__()
        # adjust padding for kernel size so that it will be equal to out_channe;s
        pad = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            # kernel_size, 1 so that we look only for spatial
            # 3 time steps windows of only 1 node
            kernel_size=(kernel_size, 1),
            padding=(pad, 0),
            stride=(stride, 1),
            dilation=(dilation, 1),
        )

        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x

class MultiScale_TemporalConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=[1, 2, 3, 4],
                 residual=False,
                 residual_kernel_size=1):

        super().__init__()
        assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'

        # Multiple branches of temporal convolution
        # + 2 because we have additional 2 branches for max and 1x1 branch
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches
        if type(kernel_size) == list:
            assert len(kernel_size) == len(dilations)
        else:
            kernel_size = [kernel_size] * len(dilations)
        # Temporal Convolution branches
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    branch_channels,
                    kernel_size=1,
                    padding=0
                ),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
                TemporalConv(
                    branch_channels,
                    branch_channels,
                    kernel_size=ks,
                    stride=stride,
                    dilation=dilation
                ),
            )
            # checking for each dilation so that we will look for global context
            for ks, dilation in zip(kernel_size, dilations)
        ])

        # Additional Max & 1x1 branch
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
            nn.BatchNorm2d(branch_channels)  # 为什么还要加bn
        ))

        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride, 1)),
            nn.BatchNorm2d(branch_channels)
        ))

        # Residual connection
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)
        # print(len(self.branches))

    def forward(self, x):
        x = x.permute(0,3,1,2)
        # Input dim: (N,C,T,V)
        res = self.residual(x)
        branch_outs = []
        for tempconv in self.branches:
            out = tempconv(x)
            branch_outs.append(out)

        out = torch.cat(branch_outs, dim=1)
        out += res
        out = out.permute(0, 2, 3, 1)
        return out

In [6]:
class GraphEnt(nn.Module):
    def __init__(self, dim_in, dim, num_points=25):
        super().__init__()
        self.dim = dim
        self.dim_in = dim_in
        nn1 = Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(),
                nn.Linear(dim_in, dim_in),
            )
        self.conv = GPSConv(dim_in, GINEConv(nn1), heads=4, attn_dropout=0.5,
                               att_type='mamba',
                               shuffle_ind=0,
                               order_by_degree=True,
                               d_state=16, d_conv=4)
        self.lin = nn.Linear(dim_in, dim)
        self.edge_emb = nn.Embedding(num_points*2, dim_in)
        self_link = [(i, i) for i in range(25)]
        inward_ori_index = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6),
                    (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), (13, 1),
                    (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18),
                    (20, 19), (22, 8), (23, 8), (24, 12), (25, 12)]
        inward = [(i - 1, j - 1) for (i, j) in inward_ori_index]
        outward = [(j, i) for (i, j) in inward]
        self.neighbor = inward + outward
        global device
        self.edge_index = self.convert_neighbor_to_edge_index(self.neighbor, device)

    def forward(self, x, dims):
        # N*M, T, V, C
        N, C, T, V, M = dims
        edge_attr = torch.ones(self.edge_index.size(1), dtype=torch.int, device=x.device)
        edge_attr = self.edge_emb(edge_attr)
        batch = self.create_batch_array(N, M, T, V, x.device)
        '''N * M - number of video sequences with person number
        C - number of channels (3d position of points)
        T - number of frames
        V - number of skeleton points (25)
        order is: N*M, T, V, C
        '''
        _, T, V, C = x.size()
        x = x.view(-1, C)
        x = self.conv(x, self.edge_index, batch, edge_attr=edge_attr)
        x = x.view(N*M, T, V, C)
        if self.dim != self.dim_in:
            x = self.lin(x)
        return x


    def convert_neighbor_to_edge_index(self, neighbor, device):
        indices = torch.tensor(neighbor, dtype=torch.int64, device=device).t()
        return indices

    def create_batch_array(self, N, M, T, V, device):
        # Total number of unique indices
        num_indices = N * M 
        # Create a tensor of shape (num_indices, V) where each row contains the same index
        batch = torch.arange(num_indices, dtype=torch.int64, device=device).repeat_interleave(V*T)
        return batch

In [7]:
class GraphTCN(nn.Module):
    def __init__(self, dim_in, dim):
        super().__init__()
        self.dim_in = dim_in
        self.dim = dim
        self.conv = GraphEnt(dim_in, dim)
        if dim_in != dim:
            self.lin = nn.Linear(dim_in, dim)
        self.tcn = MultiScale_TemporalConv(dim, dim, kernel_size=5, stride=1,
                                            dilations=[1,2],
                                            # residual=True has worse performance in the end
                                            residual=False)
        self.act = nn.ReLU()
        self.norm = nn.LayerNorm(dim)


    def forward(self, x, dims):
        if self.dim_in != self.dim:
            x = self.act(self.tcn(self.conv(x, dims))+self.lin(x))
        else:
            x = self.act(self.tcn(self.conv(self.norm(x), dims)) + x)
        return x

In [8]:
class GraphModel(nn.Module):
    def __init__(self, dim_in, dim):
        super().__init__()
        self.convs = nn.ModuleList()
        for _ in range(9):
            conv = GraphTCN(dim, dim)
            self.convs.append(conv)
        self.graph_tcn = GraphTCN(dim_in, dim)
        self.fc1 = nn.Linear(dim, 60)
        self.mlp = nn.Sequential(
            self.fc1,
        )
    
    def forward(self, x):
        N, C, T, V, M = x.size()
        dims = x.size()
        x = x.permute(0, 4, 2, 3, 1).contiguous().view(N, M, V, C, T).contiguous().view(N * M, T, V, C)
        # N*M, T, V, C
        x = self.graph_tcn(x, dims)
        for conv in self.convs:
            x = conv(x, dims)
        '''
        order is: N*M, T, V, C
        '''
        x = x.permute(0,3,1,2)
        _, C, T, V = x.size()
        x = x.view(N, M, C, -1)
        # order is: N, M, C, T*V
        x = x.mean(3).mean(1)
        x = self.mlp(x)
        #print("RESULT", x)
        return x

In [9]:
def train(epoch):
    model.train()
    loss_value = []
    acc_value = []
    train_loader = load_data('train')
    process = tqdm(train_loader, ncols=80)
    for batch_idx, (data, label, index) in enumerate(process):
        with torch.no_grad():
            data = data.float().to(device)
            label = label.long().to(device)
        with torch.amp.autocast('cuda', enabled=use_amp):
            out = model(data)
            loss = lossC(out, target=label)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        loss_value.append(loss.data.item())
        value, predict_label = torch.max(out.data, 1)
        acc = torch.mean((predict_label == label.data).float())
        acc_value.append(acc.data.item())
    if epoch % 10 == 0:
        state_dict = model.state_dict()
        weights = OrderedDict([[k.split('module.')[-1], v.cpu()] for k, v in state_dict.items()])
        torch.save(weights, 'mainruns/Model-' + str(epoch) + '.pt')
    return np.nanmean(loss_value), np.nanmean(acc_value)*100

In [10]:
def test():
    model.eval()
    test_loader = load_data('test')
    loss_value = []
    score_frag = []
    process = tqdm(test_loader, ncols=80)
    for batch_idx, (data, label, index) in enumerate(process):
        with torch.no_grad():
            data = data.float().to(device)
            label = label.long().to(device)
            out = model(data)
            loss = lossC(out, target=label)
            #print(out.data.cpu().numpy())
            value, predict_label = torch.max(out.data, 1)
            score_frag.append(out.data.cpu().numpy())
            loss_value.append(loss.data.item())
    score = np.concatenate(score_frag)
    loss = np.nanmean(loss_value)
    best_acc = 0
    for k in [1, 5]:
        acc = test_loader.dataset.top_k(score, k) * 100
        if acc > best_acc:
            best_acc = acc
        print('\tTop{}: {:.2f}%'.format(
            k, acc))
    return best_acc, loss

In [11]:
init_seed(2)
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')

lr = 0.001
epochs = 50
out_channels = 20

model = GraphModel(3, out_channels).to(device)
'''devices = [3,2]
device = devices[0]
model = nn.DataParallel(GraphModel(3, 1024).to(device),
                        device_ids=devices,
                        output_device=device)'''
optimizer = optim.SGD(
                model.parameters(),
                lr=lr)
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Parameters:", count_parameters(model))
lossC = nn.CrossEntropyLoss().to(device)
use_amp = True
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3,
                              min_lr=0.00001)


wandb.login()
wandb.init(
    # set the wandb project where this run will be logged
    project="spanchsan-hong-kong-polytechnic-university",

    # track hyperparameters and run metadata
    config={
    "learning_rate": lr,
    "architecture": "Mamba",
    "dataset": "NTU RGB+D 60",
    "epochs": epochs,
    "out_channels": out_channels,
    "batch_size": 32
    }
)

for epoch in range(1, epochs+1):
    trn = train(epoch)
    tst = test()
    scheduler.step(tst[1])
    wandb.log({"train_loss": trn[0], "train_acc": trn[1], "test_acc": tst[0], "test_loss": tst[1]})
    print(f'Epoch: {epoch:02d}, Loss: {trn[0]:.4f}, Train Acc: {trn[1]:.4f}, Accuracy: {tst[0]:.4f}, Test_Loss: {tst[1]:.4f}')

Parameters: 67535


[34m[1mwandb[0m: Currently logged in as: [33mspanchsan[0m ([33mspanchsan-hong-kong-polytechnic-university[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


100%|███████████████████████████████████████| 1253/1253 [02:25<00:00,  8.63it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.75it/s]


	Top1: 1.80%
	Top5: 11.14%
Epoch: 01, Loss: 4.2597, Train Acc: 2.2471, Accuracy: 11.1372, Test_Loss: 4.1666


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.72it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.72it/s]


	Top1: 2.44%
	Top5: 11.53%
Epoch: 02, Loss: 4.0217, Train Acc: 4.7745, Accuracy: 11.5338, Test_Loss: 4.1224


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.64it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.74it/s]


	Top1: 4.05%
	Top5: 15.79%
Epoch: 03, Loss: 3.9390, Train Acc: 5.9766, Accuracy: 15.7916, Test_Loss: 4.0191


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.73it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.64it/s]


	Top1: 5.11%
	Top5: 20.44%
Epoch: 04, Loss: 3.8600, Train Acc: 6.6579, Accuracy: 20.4435, Test_Loss: 3.8980


100%|███████████████████████████████████████| 1253/1253 [02:25<00:00,  8.61it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.67it/s]


	Top1: 5.54%
	Top5: 22.02%
Epoch: 05, Loss: 3.7769, Train Acc: 7.6730, Accuracy: 22.0224, Test_Loss: 3.8368


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.66it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.85it/s]


	Top1: 6.98%
	Top5: 25.00%
Epoch: 06, Loss: 3.6948, Train Acc: 8.7060, Accuracy: 25.0031, Test_Loss: 3.7275


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.72it/s]
100%|███████████████████████████████████████| 1253/1253 [01:02<00:00, 20.02it/s]


	Top1: 7.49%
	Top5: 26.94%
Epoch: 07, Loss: 3.6170, Train Acc: 10.0951, Accuracy: 26.9412, Test_Loss: 3.6589


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.74it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.79it/s]


	Top1: 7.89%
	Top5: 28.32%
Epoch: 08, Loss: 3.5397, Train Acc: 11.4349, Accuracy: 28.3231, Test_Loss: 3.6043


100%|███████████████████████████████████████| 1253/1253 [02:21<00:00,  8.83it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.64it/s]


	Top1: 8.38%
	Top5: 30.39%
Epoch: 09, Loss: 3.4597, Train Acc: 12.3531, Accuracy: 30.3859, Test_Loss: 3.5298


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.70it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.70it/s]


	Top1: 9.24%
	Top5: 32.36%
Epoch: 10, Loss: 3.3818, Train Acc: 13.2072, Accuracy: 32.3589, Test_Loss: 3.4733


100%|███████████████████████████████████████| 1253/1253 [02:22<00:00,  8.80it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.85it/s]


	Top1: 10.36%
	Top5: 35.06%
Epoch: 11, Loss: 3.3070, Train Acc: 14.4068, Accuracy: 35.0577, Test_Loss: 3.4144


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.76it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.75it/s]


	Top1: 11.17%
	Top5: 37.36%
Epoch: 12, Loss: 3.2385, Train Acc: 15.5176, Accuracy: 37.3650, Test_Loss: 3.3637


100%|███████████████████████████████████████| 1253/1253 [02:21<00:00,  8.86it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.76it/s]


	Top1: 11.95%
	Top5: 39.15%
Epoch: 13, Loss: 3.1752, Train Acc: 16.5417, Accuracy: 39.1484, Test_Loss: 3.3192


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.72it/s]
100%|███████████████████████████████████████| 1253/1253 [01:01<00:00, 20.45it/s]


	Top1: 12.39%
	Top5: 40.56%
Epoch: 14, Loss: 3.1181, Train Acc: 17.8037, Accuracy: 40.5627, Test_Loss: 3.2862


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.76it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.61it/s]


	Top1: 13.58%
	Top5: 42.84%
Epoch: 15, Loss: 3.0654, Train Acc: 18.8167, Accuracy: 42.8400, Test_Loss: 3.2399


100%|███████████████████████████████████████| 1253/1253 [02:25<00:00,  8.60it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.89it/s]


	Top1: 14.63%
	Top5: 44.36%
Epoch: 16, Loss: 3.0167, Train Acc: 19.5998, Accuracy: 44.3591, Test_Loss: 3.1917


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.73it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.81it/s]


	Top1: 15.40%
	Top5: 45.75%
Epoch: 17, Loss: 2.9703, Train Acc: 20.5455, Accuracy: 45.7509, Test_Loss: 3.1556


100%|███████████████████████████████████████| 1253/1253 [02:22<00:00,  8.78it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.94it/s]


	Top1: 15.81%
	Top5: 46.55%
Epoch: 18, Loss: 2.9272, Train Acc: 21.4318, Accuracy: 46.5466, Test_Loss: 3.1343


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.75it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.86it/s]


	Top1: 16.71%
	Top5: 47.93%
Epoch: 19, Loss: 2.8881, Train Acc: 22.2589, Accuracy: 47.9285, Test_Loss: 3.0886


100%|███████████████████████████████████████| 1253/1253 [02:25<00:00,  8.62it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.89it/s]


	Top1: 16.96%
	Top5: 49.03%
Epoch: 20, Loss: 2.8477, Train Acc: 23.1148, Accuracy: 49.0260, Test_Loss: 3.0625


100%|███████████████████████████████████████| 1253/1253 [02:28<00:00,  8.44it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.82it/s]


	Top1: 17.56%
	Top5: 49.95%
Epoch: 21, Loss: 2.8099, Train Acc: 23.8401, Accuracy: 49.9464, Test_Loss: 3.0403


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.67it/s]
100%|███████████████████████████████████████| 1253/1253 [01:01<00:00, 20.54it/s]


	Top1: 18.15%
	Top5: 51.49%
Epoch: 22, Loss: 2.7726, Train Acc: 24.8008, Accuracy: 51.4854, Test_Loss: 2.9967


100%|███████████████████████████████████████| 1253/1253 [02:22<00:00,  8.80it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.65it/s]


	Top1: 18.39%
	Top5: 52.40%
Epoch: 23, Loss: 2.7370, Train Acc: 25.5340, Accuracy: 52.4033, Test_Loss: 2.9673


100%|███████████████████████████████████████| 1253/1253 [02:22<00:00,  8.78it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.84it/s]


	Top1: 18.89%
	Top5: 53.42%
Epoch: 24, Loss: 2.7032, Train Acc: 26.2901, Accuracy: 53.4210, Test_Loss: 2.9398


100%|███████████████████████████████████████| 1253/1253 [02:22<00:00,  8.80it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.71it/s]


	Top1: 19.43%
	Top5: 54.87%
Epoch: 25, Loss: 2.6677, Train Acc: 27.1655, Accuracy: 54.8652, Test_Loss: 2.8973


100%|███████████████████████████████████████| 1253/1253 [02:26<00:00,  8.57it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.81it/s]


	Top1: 20.39%
	Top5: 56.39%
Epoch: 26, Loss: 2.6344, Train Acc: 27.7317, Accuracy: 56.3867, Test_Loss: 2.8582


100%|███████████████████████████████████████| 1253/1253 [02:22<00:00,  8.79it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.67it/s]


	Top1: 20.94%
	Top5: 57.41%
Epoch: 27, Loss: 2.6036, Train Acc: 28.5103, Accuracy: 57.4119, Test_Loss: 2.8255


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.68it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.69it/s]


	Top1: 20.77%
	Top5: 57.58%
Epoch: 28, Loss: 2.5736, Train Acc: 29.1687, Accuracy: 57.5765, Test_Loss: 2.8160


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.68it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.66it/s]


	Top1: 21.53%
	Top5: 58.72%
Epoch: 29, Loss: 2.5459, Train Acc: 29.8546, Accuracy: 58.7239, Test_Loss: 2.7866


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.65it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.73it/s]


	Top1: 21.74%
	Top5: 58.76%
Epoch: 30, Loss: 2.5165, Train Acc: 30.5134, Accuracy: 58.7563, Test_Loss: 2.7752


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.68it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.74it/s]


	Top1: 22.51%
	Top5: 59.96%
Epoch: 31, Loss: 2.4895, Train Acc: 31.2387, Accuracy: 59.9611, Test_Loss: 2.7420


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.67it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.88it/s]


	Top1: 22.96%
	Top5: 60.87%
Epoch: 32, Loss: 2.4637, Train Acc: 31.7824, Accuracy: 60.8690, Test_Loss: 2.7113


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.70it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.81it/s]


	Top1: 23.68%
	Top5: 61.81%
Epoch: 33, Loss: 2.4406, Train Acc: 32.2701, Accuracy: 61.8119, Test_Loss: 2.6915


100%|███████████████████████████████████████| 1253/1253 [02:20<00:00,  8.91it/s]
100%|███████████████████████████████████████| 1253/1253 [01:02<00:00, 19.96it/s]


	Top1: 24.03%
	Top5: 62.12%
Epoch: 34, Loss: 2.4140, Train Acc: 32.8438, Accuracy: 62.1187, Test_Loss: 2.6857


100%|███████████████████████████████████████| 1253/1253 [02:25<00:00,  8.62it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.90it/s]


	Top1: 24.41%
	Top5: 62.86%
Epoch: 35, Loss: 2.3915, Train Acc: 33.4203, Accuracy: 62.8620, Test_Loss: 2.6666


100%|███████████████████████████████████████| 1253/1253 [02:22<00:00,  8.80it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.94it/s]


	Top1: 25.20%
	Top5: 63.84%
Epoch: 36, Loss: 2.3694, Train Acc: 33.9062, Accuracy: 63.8448, Test_Loss: 2.6409


100%|███████████████████████████████████████| 1253/1253 [02:25<00:00,  8.60it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.89it/s]


	Top1: 25.98%
	Top5: 64.65%
Epoch: 37, Loss: 2.3484, Train Acc: 34.2633, Accuracy: 64.6479, Test_Loss: 2.6197


100%|███████████████████████████████████████| 1253/1253 [02:27<00:00,  8.51it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.64it/s]


	Top1: 26.61%
	Top5: 65.59%
Epoch: 38, Loss: 2.3274, Train Acc: 34.8938, Accuracy: 65.5883, Test_Loss: 2.5855


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.75it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.95it/s]


	Top1: 26.94%
	Top5: 65.75%
Epoch: 39, Loss: 2.3064, Train Acc: 35.2036, Accuracy: 65.7479, Test_Loss: 2.5844


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.65it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.69it/s]


	Top1: 27.37%
	Top5: 66.65%
Epoch: 40, Loss: 2.2881, Train Acc: 35.7377, Accuracy: 66.6459, Test_Loss: 2.5591


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.68it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.68it/s]


	Top1: 28.05%
	Top5: 67.13%
Epoch: 41, Loss: 2.2684, Train Acc: 36.0570, Accuracy: 67.1348, Test_Loss: 2.5415


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.69it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.93it/s]


	Top1: 28.37%
	Top5: 67.56%
Epoch: 42, Loss: 2.2504, Train Acc: 36.5204, Accuracy: 67.5563, Test_Loss: 2.5284


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.75it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.89it/s]


	Top1: 28.53%
	Top5: 67.96%
Epoch: 43, Loss: 2.2336, Train Acc: 36.9269, Accuracy: 67.9579, Test_Loss: 2.5194


100%|███████████████████████████████████████| 1253/1253 [02:22<00:00,  8.77it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.67it/s]


	Top1: 28.98%
	Top5: 68.61%
Epoch: 44, Loss: 2.2169, Train Acc: 37.2772, Accuracy: 68.6064, Test_Loss: 2.4946


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.69it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.83it/s]


	Top1: 29.07%
	Top5: 68.45%
Epoch: 45, Loss: 2.2003, Train Acc: 37.6638, Accuracy: 68.4468, Test_Loss: 2.5075


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.72it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.81it/s]


	Top1: 29.27%
	Top5: 68.76%
Epoch: 46, Loss: 2.1885, Train Acc: 37.6817, Accuracy: 68.7611, Test_Loss: 2.4885


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.73it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.83it/s]


	Top1: 29.75%
	Top5: 69.56%
Epoch: 47, Loss: 2.1702, Train Acc: 38.1605, Accuracy: 69.5642, Test_Loss: 2.4637


100%|███████████████████████████████████████| 1253/1253 [02:24<00:00,  8.66it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.65it/s]


	Top1: 29.47%
	Top5: 69.24%
Epoch: 48, Loss: 2.1548, Train Acc: 38.3780, Accuracy: 69.2375, Test_Loss: 2.4839


100%|███████████████████████████████████████| 1253/1253 [02:23<00:00,  8.71it/s]
100%|███████████████████████████████████████| 1253/1253 [01:00<00:00, 20.86it/s]


	Top1: 30.05%
	Top5: 69.72%
Epoch: 49, Loss: 2.1409, Train Acc: 38.6240, Accuracy: 69.7214, Test_Loss: 2.4619


100%|███████████████████████████████████████| 1253/1253 [02:25<00:00,  8.63it/s]
100%|███████████████████████████████████████| 1253/1253 [00:59<00:00, 20.93it/s]


	Top1: 30.19%
	Top5: 69.95%
Epoch: 50, Loss: 2.1275, Train Acc: 39.0076, Accuracy: 69.9509, Test_Loss: 2.4593
