In [None]:
import numpy as np
import pickle
import torch
from torch.utils.data import Dataset
import sys
import torch.nn as nn
import torch.nn.functional as F
import time
import glob
import scipy.io
import os
import math

from blitz.modules import BayesianLinear, BayesianConv2d, BayesianLSTM
from blitz.utils import variational_estimator

# --- Skeleton / data layout ---------------------------------------------------
num_joint = 20        # joints per skeleton (see the joint list in the graph cell)
max_frame = 125       # every sequence is zero-padded to this many frames
input_feature = 6     # per joint: 3 position channels + 3 velocity channels
num_feature = 16      # channels produced by each graph-convolution layer
hidden_size = 256     # LSTM hidden size (= size of the feature fed to the heads)
# --- Optimisation --------------------------------------------------------------
batch_size = 32
learning_rate = 0.001
momentum = 0.9        # NOTE(review): not referenced by the Adam optimizers in this notebook
decay_rate = 0.9      # StepLR gamma
decay_step = 100      # StepLR step size, in epochs
epochs = 1500
seed = 0              # fixed seed so weight init / shuffling / MC sampling are reproducible
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
path = "UTD_AP/"

torch.manual_seed(seed)  # also seeds all CUDA devices
np.random.seed(seed)

class UTDDataset(Dataset):
    """Skeleton sequences loaded eagerly from the ``*.mat`` files of one directory.

    Each file must contain a ``d_skel`` array laid out as ``(#joint, C, #frame)``;
    it is transposed to ``(C, #frame, #joint)``. The code assumes C = 3 coordinate
    channels. Every frame is re-centred on joint index 2 (spine), and per-frame
    velocities ``x(t) - x(t-1)`` (zero at t = 0) are appended as channels 3-5.
    Sequences are zero-padded along the frame axis to ``max_frame``.

    The label is parsed from the file-name prefix: ``a<k>_....mat`` -> ``k - 1``.

    Items are ``(data, label, num_frame)``. ``data`` is a float32 tensor of shape
    ``(input_feature, max_frame, num_joint)``.

    Args:
        data_path: directory prefix; ``'*.mat'`` is appended directly, so it should
            end with a path separator.

    Raises:
        ValueError: if a sequence has more than ``max_frame`` frames.
    """
    def __init__(self, data_path):
        super(UTDDataset, self).__init__()
        self.data_path = data_path
        self.load_data()

    def load_data(self):
        # Sorted so the index -> file mapping is reproducible; glob's own order is
        # filesystem-dependent.
        files_list = sorted(glob.glob(self.data_path + '*.mat'))
        self.data = torch.zeros((len(files_list), input_feature, max_frame, num_joint), dtype=torch.float32)
        self.labels = []
        self.num_frame = []
        for i, file_name in enumerate(files_list):
            action = os.path.basename(file_name).split('_')[0]  # e.g. "a12"
            mat = scipy.io.loadmat(file_name)['d_skel'].astype("float32")
            mat = mat.transpose((1, 2, 0))  # transpose to (C, #frame, #joint)
            mat -= np.expand_dims(mat[:, :, 2], axis=2)  # set center at spine of body

            # channels 0-2: position; channels 3-5: speed = x(t) - x(t-1), 0 at t = 0
            aug_mat = np.concatenate((mat, mat), axis=0)
            aug_mat[3:, :, :] -= np.roll(mat, 1, axis=1)
            aug_mat[3:, 0, :] = 0

            frame = aug_mat.shape[1]
            if frame > max_frame:
                raise ValueError("{} has {} frames, more than max_frame={}".format(file_name, frame, max_frame))
            self.data[i, :, :frame] = torch.from_numpy(aug_mat)
            self.num_frame.append(frame)
            self.labels.append(int(action[1:]) - 1)

    def __getitem__(self, index):
        data = self.data[index]
        label = self.labels[index]
        f = self.num_frame[index]
        return data, label, f

    def __len__(self):
        return len(self.labels)

    def __iter__(self):
        # Yield items in index order. (Returning ``self`` without a ``__next__``
        # made ``iter(dataset)`` raise TypeError.)
        for index in range(len(self)):
            yield self[index]
    
# Subject-wise split: train = subjects 1, 3, 5; valid = subject 7; test = subjects 2, 4, 6, 8.
train_dataset, valid_dataset, test_dataset = (
    UTDDataset(data_path=path + split) for split in ('train/', 'valid/', 'test/')
)

# The test loader serves the whole test split as a single batch.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

In [None]:
# Sanity check on one validation batch: data must be (N, input_feature, max_frame,
# num_joint), and labels / frame counts must be one per sample.
# (Replaces a debugging loop that printed 100 batch-size lines.)
sample_data, sample_label, sample_f = next(iter(valid_loader))
assert sample_data.shape[1:] == (input_feature, max_frame, num_joint)
assert sample_label.shape == sample_f.shape == (sample_data.shape[0],)
assert 0 < int(sample_f.min()) and int(sample_f.max()) <= max_frame
print(sample_data.shape, sample_f.size())

In [None]:
"""
1. head; 
2. shoulder_center;
3. spine;
4. hip_center;
5. left_shoulder;
6. left_elbow;
7. left_wrist;
8. left_hand;
9. right_shoulder;
10. right_elbow;
11. right_wrist;
12. right_hand;
13. left_hip;
14. left_knee;
15. left_ankle;
16. left_foot;
17. right_hip;
18. right_knee;
19. right_ankle;
20. right_foot;
"""
inward_ori_index = [(1,2),(3,2),(4,3),(5,2),(6,5),(7,6),(8,7),(9,2),(10,9),(11,10),
                    (12,11),(13,4),(14,13),(15,14),(16,15),(17,4),(18,17),(19,18),(20,19)]
inward = [(i - 1, j - 1) for (i, j) in inward_ori_index]
outward = [(j, i) for (i, j) in inward]
five_key_point = [1,7,11,15,19]

def normalize(A):
    """Symmetrically normalise a square matrix: ``D^-1/2 @ A @ D^-1/2``.

    ``D`` is the diagonal matrix of the sums of ``A`` along dim 0 (column sums).

    Args:
        A: 2-D square tensor. Floating inputs keep their dtype; other dtypes
            are converted to float32.

    Returns:
        The normalised matrix, with the same shape as ``A``.

    A column that sums to zero gets a scale of 0 instead of ``inf``. That node then
    contributes zeros rather than turning the result into NaNs.
    """
    if not A.is_floating_point():
        A = A.float()
    col_sum = torch.sum(A, 0)
    d_inv_sqrt = torch.pow(col_sum, -0.5)
    # pow(0, -0.5) is inf; map it to 0 so isolated nodes do not produce inf * 0 = NaN.
    d_inv_sqrt = torch.where(col_sum == 0, torch.zeros_like(d_inv_sqrt), d_inv_sqrt)
    # Build D in A's dtype (a hard-coded .float() broke float64 inputs in mm).
    d_mat = torch.diag(d_inv_sqrt).to(A.dtype)
    return torch.mm(torch.mm(d_mat, A), d_mat)
    
def gen_adj():
    """Build the three normalised (num_joint x num_joint) adjacency matrices.

    Graph 0 holds the ``inward`` edges and graph 1 the ``outward`` edges; for an
    edge ``(src, dst)`` the entry ``[dst, src]`` is set to 1. Graph 2 fully
    connects the ``five_key_point`` joints with each other. Every graph also gets
    self-loops, and each one is then passed through ``normalize``.

    Returns:
        Float tensor of shape (3, num_joint, num_joint).
    """
    adj = torch.zeros(3, num_joint, num_joint, dtype=torch.float)
    for graph_idx, edges in enumerate((inward, outward)):
        for src, dst in edges:
            adj[graph_idx, dst, src] = 1
    for p in five_key_point:
        adj[2, p, five_key_point] = 1
    joints = torch.arange(num_joint)
    adj[:, joints, joints] = 1
    return torch.stack([normalize(graph) for graph in adj])

In [None]:
from torch.nn.parameter import Parameter
class GraphConvolution(nn.Module):
    """One spatial graph-convolution layer over skeleton joints.

    There are ``num_graph`` branches, one per adjacency matrix. Each matrix is
    first scaled element-wise by a learnable ``mask`` (initialised to ones).
    Joint features are then aggregated as ``x @ a`` and mixed across channels by
    a 1x1 ``BayesianConv2d``. The branch outputs are summed, batch-normalised
    and passed through ReLU.

    Args:
        num_graph: number of adjacency matrices / branches (must be >= 1).
        in_feature: input channels.
        out_feature: output channels.

    Raises:
        ValueError: if ``num_graph < 1``.
    """
    def __init__(self, num_graph, in_feature, out_feature):
        super(GraphConvolution, self).__init__()
        if num_graph < 1:
            # forward() sums one branch per graph; with no graphs there is no output.
            raise ValueError("num_graph must be >= 1, got {}".format(num_graph))
        self.num_graph = num_graph
        self.in_feature = in_feature
        self.out_feature = out_feature

        self.mask = nn.Parameter(torch.ones(num_graph, num_joint, num_joint))

        self.gcn_list = nn.ModuleList([
            BayesianConv2d(
                self.in_feature,
                self.out_feature,
                kernel_size=(1, 1)) for _ in range(self.num_graph)
        ])

        self.bn = nn.BatchNorm2d(out_feature)
        self.act = nn.ReLU()

    def forward(self, adj, x):
        """Apply the layer.

        Args:
            adj: (num_graph, V, V) adjacency matrices.
            x: (N, in_feature, T, V) features.

        Returns:
            ReLU(BatchNorm(sum of branch outputs)), with ``out_feature`` channels.
        """
        masked_adj = adj * self.mask
        y = None
        for conv, a in zip(self.gcn_list, masked_adj):
            # out[..., w] = sum_v x[..., v] * a[v, w]. torch.matmul broadcasts over the
            # leading (N, C, T) dims and, unlike x.view(-1, V), accepts non-contiguous x.
            branch = conv(torch.matmul(x, a))
            y = branch if y is None else y + branch
        return self.act(self.bn(y))
    
class GCLayers(nn.Module):
    """Four stacked GraphConvolution layers with two skip connections.

    gc1 maps ``input_feature`` -> ``num_feature`` channels; gc2-gc4 keep
    ``num_feature``. With h1 = gc1(x) and h2 = gc2(h1), the output is
    ``gc4(gc3(h2) + h1) + h2``.
    """
    def __init__(self, num_feature, num_graph):
        super(GCLayers, self).__init__()
        self.num_feature = num_feature
        widths = [input_feature] + [num_feature] * 4
        # Registers gc1..gc4 in that order (nn.Module.__setattr__ tracks submodules).
        for layer_no in range(1, 5):
            setattr(self, 'gc{}'.format(layer_no),
                    GraphConvolution(num_graph, widths[layer_no - 1], widths[layer_no]))

    def forward(self, adj, x):
        h1 = self.gc1(adj, x)
        h2 = self.gc2(adj, h1)
        h3 = self.gc3(adj, h2) + h1
        return self.gc4(adj, h3) + h2

@variational_estimator
class GC_LSTM(nn.Module):
    """Spatio-temporal feature extractor: graph convolutions per frame, then an LSTM.

    Args:
        num_feature: channels produced by the graph-convolution stack.
        hidden_size: LSTM hidden size, i.e. the size of the returned feature.

    forward(x, num_frame):
        x: (N, input_feature, T, num_joint) zero-padded batch.
        num_frame: number of valid frames per sample (tensor on any device, or list).
        Returns (N, hidden_size): the LSTM hidden state after each sample's last
        valid frame, with dropout applied.
    """
    def __init__(self, num_feature, hidden_size):
        super(GC_LSTM, self).__init__()
        # Non-persistent buffer: it moves with the module on .to()/.cuda() and is not
        # written to state_dict().
        self.register_buffer('adj', gen_adj(), persistent=False)
        self.num_graph = self.adj.shape[0]
        self.num_feature = num_feature
        self.hidden_size = hidden_size
        self.output_feature = num_feature

        # NOTE(review): not used in forward(); kept so parameter lists and pickled
        # models stay compatible.
        self.datat_bn = nn.BatchNorm1d(input_feature * num_joint)
        self.gclayers = GCLayers(num_feature,self.num_graph)
        self.dropout = nn.Dropout(0.25)
        self.lstm = nn.LSTM(self.output_feature*num_joint,hidden_size)

    def forward(self, x, num_frame):
        x = self.gclayers(self.adj,x)
        x = self.dropout(x)

        N,C,T,V = x.size()
        # (N, C, T, V) -> (N, T, C*V): one feature vector per frame.
        x = x.permute(0,2,1,3).contiguous().view(N,T,C*V)
        # A single batched LSTM call. Packing skips the zero-padded frames, so h_n is
        # the hidden state after each sample's last valid frame. That is the same as
        # running the sample alone and taking its final output. With
        # enforce_sorted=False, h_n comes back in the original batch order.
        lengths = torch.as_tensor(num_frame, dtype=torch.int64).cpu()  # lengths must live on the CPU
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)
        return self.dropout(h_n[-1])

@variational_estimator
class Classifier(nn.Module):
    """Bayesian linear head mapping a ``hidden_size`` feature to class logits.

    ``forward`` returns raw logits because the training loss is
    ``nn.CrossEntropyLoss``, which applies log-softmax itself. Feeding it softmax
    probabilities applies softmax twice and flattens the gradients. Softmax is
    monotonic, so ``output.max(1)`` predictions are unchanged. Use
    ``predict_proba`` when probabilities are needed.

    Args:
        hidden_size: input feature size.
        num_classes: number of output classes (default 27).
    """
    def __init__(self, hidden_size, num_classes=27):
        super(Classifier, self).__init__()
        self.fc = BayesianLinear(hidden_size, num_classes)
        self.act = nn.Softmax(dim=1)  # used by predict_proba

    def forward(self, x):
        return self.fc(x)

    def predict_proba(self, x):
        """Class probabilities, shape (N, num_classes)."""
        return self.act(self.fc(x))
    
@variational_estimator
class Discriminator(nn.Module):
    """Bayesian logistic head: a single BayesianLinear unit followed by a sigmoid.

    Maps an (N, hidden_size) feature to (N, 1) values in (0, 1).
    """
    def __init__(self, hidden_size):
        super(Discriminator, self).__init__()
        self.fc = BayesianLinear(hidden_size, 1)
        self.act = nn.Sigmoid()

    def forward(self, x):
        return self.act(self.fc(x))
            
# Feature extractor, class head and discriminator, all moved to `device`.
net = GC_LSTM(num_feature, hidden_size).to(device)
classifier = Classifier(hidden_size).to(device)
discriminator = Discriminator(hidden_size).to(device)

CE_criterion = nn.CrossEntropyLoss()
BCE_criterion = nn.BCELoss()

# One Adam optimizer and one StepLR schedule per module, created in the same order.
optimizer, optimizer_C, optimizer_D = (
    torch.optim.Adam(module.parameters(), lr=learning_rate)
    for module in (net, classifier, discriminator)
)
scheduler, scheduler_C, scheduler_D = (
    torch.optim.lr_scheduler.StepLR(opt, step_size=decay_step, gamma=decay_rate)
    for opt in (optimizer, optimizer_C, optimizer_D)
)

In [None]:
kl_weight = 1. / len(train_dataset)  # weight of the Bayesian layers' KL term in every loss

M = 10 # for Monte Carlo estimation: forward passes per batch

test_interval = 10  # evaluate on the test set every `test_interval` epochs
early_stop = 0.92   # save the models and stop once test accuracy exceeds this
training_Dloss = []
training_Gloss = []
start = time.time()
net.train()
classifier.train()
discriminator.train()

validloader_iter = iter(valid_loader)

for epoch in range(epochs):
    G_LOSS = 0
    D_LOSS = 0
    print("{:3d} epoch".format(epoch+1),end=", ")
    correct = 0
    for i,(data, label, num_frame) in enumerate(train_loader):
        # Next validation batch, restarting the validation loader when it runs out.
        try:
            valid_data, valid_label, valid_f = next(validloader_iter)
        except StopIteration:
            validloader_iter = iter(valid_loader)
            valid_data, valid_label, valid_f = next(validloader_iter)

        data, label, num_frame = data.to(device), label.to(device), num_frame.to(device)
        valid_data, valid_label, valid_f = valid_data.to(device), valid_label.to(device), valid_f.to(device)

        positive = torch.ones(label.size()).to(device)
        valid_positive = torch.ones(valid_label.size()).to(device)
        valid_negative = torch.zeros(valid_label.size()).to(device)

        # Train the discriminator: target 1 for training features, 0 for validation ones.
        # Outputs are flattened with view(-1) instead of squeeze() so a batch of size 1
        # still matches the (B,) target shape.
        optimizer_D.zero_grad()
        for m in range(M):
            # Only the discriminator is updated in this step, so build no graph through `net`.
            with torch.no_grad():
                feature = net(data,num_frame)
            output = discriminator(feature).view(-1)
            D_positive_loss = BCE_criterion(output,positive)

            with torch.no_grad():
                valid_feature = net(valid_data,valid_f)
            valid_output = discriminator(valid_feature).view(-1)
            D_negative_loss = BCE_criterion(valid_output,valid_negative)

            D_kl_loss = discriminator.nn_kl_divergence() * kl_weight

            # Divided by M so the accumulated gradient is the Monte Carlo average.
            D_loss = (D_positive_loss + D_negative_loss + D_kl_loss) / M

            D_loss.backward()
            D_LOSS += D_loss.item()
        optimizer_D.step()

        # Train GC-LSTM and Classifier: classification loss on training data plus an
        # adversarial term pushing the discriminator to score validation features as 1.
        # This backward also fills discriminator grads; optimizer_D.zero_grad() clears them.
        optimizer.zero_grad()
        optimizer_C.zero_grad()
        for m in range(M):
            feature = net(data,num_frame)
            output = classifier(feature)
            class_loss = CE_criterion(output, label)

            valid_feature = net(valid_data,valid_f)
            valid_output = discriminator(valid_feature).view(-1)
            adversarial_loss = BCE_criterion(valid_output,valid_positive)

            G_kl_loss = net.nn_kl_divergence() * kl_weight
            C_kl_loss = classifier.nn_kl_divergence() * kl_weight

            G_loss = (class_loss + adversarial_loss + G_kl_loss + C_kl_loss) / M
            G_loss.backward()
            G_LOSS += G_loss.item()

            _, pred = output.max(1)
            correct += pred.eq(label).sum().item()
        optimizer.step()
        optimizer_C.step()
    # `correct` was counted on each of the M passes. Each batch added one
    # M-averaged loss, so divide by the number of batches, not samples.
    correct /= M
    num_batches = len(train_loader)
    training_Dloss.append(D_LOSS/num_batches)
    training_Gloss.append(G_LOSS/num_batches)
    print("D loss:{:6.4f}, G loss:{:6.4f}, training acc:{:6.2f}%, time:{:.2f}s"
          .format(training_Dloss[-1],training_Gloss[-1],correct/len(train_dataset)*100.,time.time()-start))

    scheduler.step()
    scheduler_C.step()
    scheduler_D.step()
    if (epoch+1) % test_interval == 0:
        # eval(): BatchNorm uses its running statistics instead of the test batch's own
        # (and does not fold test data into them), and Dropout is disabled.
        # NOTE(review): assumes the blitz Bayesian layers still sample weights in eval
        # mode, so the M passes remain a Monte Carlo estimate — confirm.
        net.eval()
        classifier.eval()
        correct = 0
        with torch.no_grad():
            for (data, label, num_frame) in test_loader:
                data, label, num_frame = data.to(device), label.to(device), num_frame.to(device)
                for m in range(M):
                    feature = net(data,num_frame)
                    output = classifier(feature)
                    _, pred = output.max(1)
                    correct += pred.eq(label).sum().item()
        correct /= M
        test_acc = correct/len(test_dataset)
        print("test acc: {:5.2f}%, time:{:7.2f}s"
              .format(test_acc*100.,time.time()-start))
        if test_acc > early_stop:
            torch.save(net,"Bayesian_GC_LSTM.pkl")
            # The classifier head is needed to turn the saved features into predictions.
            torch.save(classifier,"Bayesian_GC_LSTM_classifier.pkl")
            break
        net.train()
        classifier.train()

  1 epoch, D loss:0.2258, G loss:0.1248, training acc:  3.37%, time:26.56s
  2 epoch, D loss:0.2692, G loss:0.1214, training acc:  6.16%, time:52.48s
  3 epoch, D loss:0.2304, G loss:0.1213, training acc:  5.26%, time:78.39s
  4 epoch, D loss:0.1931, G loss:0.1211, training acc:  6.81%, time:104.35s
  5 epoch, D loss:0.1644, G loss:0.1212, training acc:  7.21%, time:130.85s
  6 epoch, D loss:0.1523, G loss:0.1211, training acc:  7.62%, time:156.74s
  7 epoch, D loss:0.1563, G loss:0.1212, training acc:  6.56%, time:182.62s
  8 epoch, D loss:0.1584, G loss:0.1209, training acc:  7.00%, time:208.63s
  9 epoch, D loss:0.1551, G loss:0.1210, training acc:  7.52%, time:235.24s
 10 epoch, D loss:0.1299, G loss:0.1214, training acc:  7.40%, time:261.43s
test acc:  8.33%, time: 263.99s
 11 epoch, D loss:0.1072, G loss:0.1222, training acc:  8.30%, time:289.94s
 12 epoch, D loss:0.0943, G loss:0.1235, training acc:  7.21%, time:316.22s
 13 epoch, D loss:0.0847, G loss:0.1245, training acc:  6.8

104 epoch, D loss:0.0485, G loss:0.1216, training acc: 64.92%, time:2751.35s
105 epoch, D loss:0.0484, G loss:0.1211, training acc: 66.50%, time:2777.94s
106 epoch, D loss:0.0485, G loss:0.1236, training acc: 61.95%, time:2803.95s
107 epoch, D loss:0.0481, G loss:0.1247, training acc: 61.02%, time:2829.94s
108 epoch, D loss:0.0482, G loss:0.1224, training acc: 65.54%, time:2856.00s
109 epoch, D loss:0.0477, G loss:0.1244, training acc: 65.39%, time:2882.44s
110 epoch, D loss:0.0477, G loss:0.1229, training acc: 67.46%, time:2908.45s
test acc: 50.67%, time:2911.00s
111 epoch, D loss:0.0478, G loss:0.1215, training acc: 66.19%, time:2936.99s
112 epoch, D loss:0.0484, G loss:0.1225, training acc: 66.04%, time:2962.95s
113 epoch, D loss:0.0485, G loss:0.1216, training acc: 66.28%, time:2989.53s
114 epoch, D loss:0.0482, G loss:0.1215, training acc: 68.98%, time:3015.44s
115 epoch, D loss:0.0484, G loss:0.1209, training acc: 68.92%, time:3041.44s
116 epoch, D loss:0.0485, G loss:0.1197, tra

207 epoch, D loss:0.0488, G loss:0.1129, training acc: 85.08%, time:5468.69s
208 epoch, D loss:0.0482, G loss:0.1140, training acc: 83.81%, time:5494.61s
209 epoch, D loss:0.0486, G loss:0.1142, training acc: 78.30%, time:5521.10s
210 epoch, D loss:0.0491, G loss:0.1129, training acc: 82.79%, time:5547.10s
test acc: 66.74%, time:5549.66s
211 epoch, D loss:0.0488, G loss:0.1137, training acc: 82.97%, time:5575.83s
212 epoch, D loss:0.0485, G loss:0.1129, training acc: 84.24%, time:5601.79s
213 epoch, D loss:0.0488, G loss:0.1124, training acc: 83.99%, time:5628.33s
214 epoch, D loss:0.0492, G loss:0.1115, training acc: 86.04%, time:5654.35s
215 epoch, D loss:0.0495, G loss:0.1128, training acc: 84.92%, time:5680.41s
216 epoch, D loss:0.0495, G loss:0.1121, training acc: 86.35%, time:5706.34s
217 epoch, D loss:0.0493, G loss:0.1124, training acc: 86.38%, time:5732.98s
218 epoch, D loss:0.0487, G loss:0.1143, training acc: 85.63%, time:5758.81s
219 epoch, D loss:0.0491, G loss:0.1124, tra

310 epoch, D loss:0.0614, G loss:0.1056, training acc: 88.24%, time:8205.35s
test acc: 69.09%, time:8207.95s
311 epoch, D loss:0.0624, G loss:0.1034, training acc: 88.61%, time:8234.13s
312 epoch, D loss:0.0624, G loss:0.1036, training acc: 88.95%, time:8260.29s
313 epoch, D loss:0.0602, G loss:0.1063, training acc: 89.01%, time:8286.91s
314 epoch, D loss:0.0641, G loss:0.1031, training acc: 88.76%, time:8313.12s
315 epoch, D loss:0.0618, G loss:0.1052, training acc: 88.61%, time:8339.32s
316 epoch, D loss:0.0605, G loss:0.1044, training acc: 88.89%, time:8365.49s
317 epoch, D loss:0.0586, G loss:0.1037, training acc: 89.04%, time:8392.13s
318 epoch, D loss:0.0595, G loss:0.1055, training acc: 88.20%, time:8418.34s
319 epoch, D loss:0.0574, G loss:0.1071, training acc: 88.64%, time:8444.53s
320 epoch, D loss:0.0597, G loss:0.1043, training acc: 88.54%, time:8470.90s
test acc: 68.14%, time:8473.49s
321 epoch, D loss:0.0613, G loss:0.1051, training acc: 87.86%, time:8500.15s
322 epoch, D

412 epoch, D loss:0.0566, G loss:0.1079, training acc: 96.25%, time:10912.29s
413 epoch, D loss:0.0574, G loss:0.1053, training acc: 96.28%, time:10939.03s
414 epoch, D loss:0.0594, G loss:0.1029, training acc: 96.22%, time:10965.26s
415 epoch, D loss:0.0584, G loss:0.1029, training acc: 96.25%, time:10991.58s
416 epoch, D loss:0.0591, G loss:0.1035, training acc: 96.16%, time:11017.87s
417 epoch, D loss:0.0604, G loss:0.1021, training acc: 96.25%, time:11044.61s
418 epoch, D loss:0.0636, G loss:0.1000, training acc: 96.25%, time:11070.90s
419 epoch, D loss:0.0625, G loss:0.1004, training acc: 96.16%, time:11097.34s
420 epoch, D loss:0.0580, G loss:0.1039, training acc: 96.22%, time:11123.67s
test acc: 71.49%, time:11126.26s
421 epoch, D loss:0.0559, G loss:0.1073, training acc: 96.22%, time:11153.09s
422 epoch, D loss:0.0543, G loss:0.1073, training acc: 96.25%, time:11179.34s
423 epoch, D loss:0.0548, G loss:0.1055, training acc: 96.01%, time:11205.74s
424 epoch, D loss:0.0568, G los

513 epoch, D loss:0.0549, G loss:0.1030, training acc: 96.25%, time:13609.77s
514 epoch, D loss:0.0525, G loss:0.1060, training acc: 96.19%, time:13636.02s
515 epoch, D loss:0.0532, G loss:0.1065, training acc: 96.04%, time:13662.39s
516 epoch, D loss:0.0538, G loss:0.1077, training acc: 96.28%, time:13688.58s
517 epoch, D loss:0.0555, G loss:0.1056, training acc: 96.19%, time:13715.39s
518 epoch, D loss:0.0529, G loss:0.1072, training acc: 96.04%, time:13741.60s
519 epoch, D loss:0.0505, G loss:0.1092, training acc: 96.25%, time:13767.84s
520 epoch, D loss:0.0493, G loss:0.1081, training acc: 96.13%, time:13794.24s
test acc: 69.09%, time:13796.85s
521 epoch, D loss:0.0506, G loss:0.1066, training acc: 95.98%, time:13823.59s
522 epoch, D loss:0.0512, G loss:0.1048, training acc: 95.98%, time:13849.75s
523 epoch, D loss:0.0546, G loss:0.1025, training acc: 96.01%, time:13875.94s
524 epoch, D loss:0.0564, G loss:0.1017, training acc: 96.25%, time:13902.12s
525 epoch, D loss:0.0576, G los

614 epoch, D loss:0.0561, G loss:0.1034, training acc: 96.16%, time:16291.92s
615 epoch, D loss:0.0554, G loss:0.1042, training acc: 96.28%, time:16317.97s
616 epoch, D loss:0.0539, G loss:0.1057, training acc: 96.25%, time:16343.88s
617 epoch, D loss:0.0554, G loss:0.1038, training acc: 96.19%, time:16370.36s
618 epoch, D loss:0.0583, G loss:0.1023, training acc: 96.25%, time:16396.14s
619 epoch, D loss:0.0563, G loss:0.1047, training acc: 96.22%, time:16422.09s
620 epoch, D loss:0.0534, G loss:0.1086, training acc: 96.25%, time:16447.95s
test acc: 73.58%, time:16450.50s
621 epoch, D loss:0.0517, G loss:0.1101, training acc: 96.25%, time:16476.84s
622 epoch, D loss:0.0512, G loss:0.1086, training acc: 96.28%, time:16502.90s
623 epoch, D loss:0.0501, G loss:0.1088, training acc: 96.25%, time:16528.84s
624 epoch, D loss:0.0530, G loss:0.1057, training acc: 96.19%, time:16554.81s
625 epoch, D loss:0.0571, G loss:0.1023, training acc: 96.25%, time:16581.03s
626 epoch, D loss:0.0577, G los

715 epoch, D loss:0.0503, G loss:0.1065, training acc: 95.98%, time:18953.26s
716 epoch, D loss:0.0525, G loss:0.1024, training acc: 96.04%, time:18979.29s
717 epoch, D loss:0.0527, G loss:0.1015, training acc: 95.82%, time:19005.92s
718 epoch, D loss:0.0524, G loss:0.1022, training acc: 96.10%, time:19031.84s
719 epoch, D loss:0.0517, G loss:0.1033, training acc: 95.98%, time:19057.76s
720 epoch, D loss:0.0519, G loss:0.1032, training acc: 95.94%, time:19083.82s
test acc: 70.58%, time:19086.36s
721 epoch, D loss:0.0536, G loss:0.1029, training acc: 95.94%, time:19112.93s
722 epoch, D loss:0.0520, G loss:0.1045, training acc: 95.70%, time:19138.86s
723 epoch, D loss:0.0515, G loss:0.1052, training acc: 95.82%, time:19164.83s
724 epoch, D loss:0.0510, G loss:0.1052, training acc: 95.91%, time:19190.70s
725 epoch, D loss:0.0512, G loss:0.1052, training acc: 95.57%, time:19217.36s
726 epoch, D loss:0.0489, G loss:0.1062, training acc: 95.88%, time:19243.23s
727 epoch, D loss:0.0491, G los

816 epoch, D loss:0.0494, G loss:0.1063, training acc: 95.73%, time:21639.03s
817 epoch, D loss:0.0493, G loss:0.1076, training acc: 94.37%, time:21665.68s
818 epoch, D loss:0.0495, G loss:0.1075, training acc: 95.79%, time:21691.77s
819 epoch, D loss:0.0488, G loss:0.1067, training acc: 95.85%, time:21717.87s
820 epoch, D loss:0.0488, G loss:0.1076, training acc: 96.16%, time:21743.89s
test acc: 70.93%, time:21746.49s
821 epoch, D loss:0.0484, G loss:0.1070, training acc: 96.01%, time:21773.13s
822 epoch, D loss:0.0482, G loss:0.1074, training acc: 96.04%, time:21799.20s
823 epoch, D loss:0.0491, G loss:0.1075, training acc: 95.73%, time:21825.35s
824 epoch, D loss:0.0486, G loss:0.1065, training acc: 96.10%, time:21851.42s
825 epoch, D loss:0.0491, G loss:0.1064, training acc: 96.13%, time:21878.00s
826 epoch, D loss:0.0493, G loss:0.1073, training acc: 95.94%, time:21904.40s
827 epoch, D loss:0.0490, G loss:0.1066, training acc: 95.94%, time:21930.58s
828 epoch, D loss:0.0491, G los

917 epoch, D loss:0.0511, G loss:0.1029, training acc: 99.60%, time:24327.91s
918 epoch, D loss:0.0527, G loss:0.1013, training acc: 99.85%, time:24354.06s
919 epoch, D loss:0.0526, G loss:0.1028, training acc: 99.78%, time:24380.17s
920 epoch, D loss:0.0520, G loss:0.1037, training acc: 99.54%, time:24406.40s
test acc: 67.51%, time:24409.00s
921 epoch, D loss:0.0536, G loss:0.1027, training acc: 99.94%, time:24435.92s
922 epoch, D loss:0.0537, G loss:0.1026, training acc: 99.94%, time:24462.25s
923 epoch, D loss:0.0531, G loss:0.1031, training acc: 99.88%, time:24488.58s
924 epoch, D loss:0.0536, G loss:0.1023, training acc: 99.88%, time:24514.86s
925 epoch, D loss:0.0545, G loss:0.1016, training acc: 99.88%, time:24541.57s
926 epoch, D loss:0.0556, G loss:0.1018, training acc: 99.81%, time:24567.91s
927 epoch, D loss:0.0547, G loss:0.1033, training acc: 99.81%, time:24594.15s
928 epoch, D loss:0.0555, G loss:0.1025, training acc: 99.91%, time:24620.42s
929 epoch, D loss:0.0567, G los

1018 epoch, D loss:0.0536, G loss:0.1023, training acc:100.00%, time:27017.83s
1019 epoch, D loss:0.0558, G loss:0.1009, training acc: 99.97%, time:27044.04s
1020 epoch, D loss:0.0566, G loss:0.1011, training acc: 99.91%, time:27070.25s
test acc: 76.51%, time:27072.85s
1021 epoch, D loss:0.0536, G loss:0.1043, training acc:100.00%, time:27099.52s
1022 epoch, D loss:0.0525, G loss:0.1055, training acc: 99.94%, time:27125.72s
1023 epoch, D loss:0.0522, G loss:0.1053, training acc:100.00%, time:27151.86s
1024 epoch, D loss:0.0522, G loss:0.1052, training acc:100.00%, time:27178.27s
1025 epoch, D loss:0.0521, G loss:0.1052, training acc: 99.94%, time:27205.17s
1026 epoch, D loss:0.0509, G loss:0.1065, training acc:100.00%, time:27231.46s
1027 epoch, D loss:0.0509, G loss:0.1070, training acc: 99.97%, time:27257.65s
1028 epoch, D loss:0.0521, G loss:0.1053, training acc:100.00%, time:27283.74s
1029 epoch, D loss:0.0538, G loss:0.1033, training acc: 99.78%, time:27310.48s
1030 epoch, D loss:

1118 epoch, D loss:0.0543, G loss:0.1004, training acc: 99.94%, time:29654.92s
1119 epoch, D loss:0.0538, G loss:0.1007, training acc: 99.94%, time:29681.17s
1120 epoch, D loss:0.0563, G loss:0.0993, training acc:100.00%, time:29707.35s
test acc: 77.93%, time:29709.95s
1121 epoch, D loss:0.0560, G loss:0.0991, training acc:100.00%, time:29736.80s
1122 epoch, D loss:0.0553, G loss:0.0996, training acc: 99.78%, time:29762.98s
1123 epoch, D loss:0.0565, G loss:0.0982, training acc: 99.94%, time:29789.04s
1124 epoch, D loss:0.0552, G loss:0.0988, training acc: 99.91%, time:29815.21s
1125 epoch, D loss:0.0537, G loss:0.1003, training acc: 99.97%, time:29842.01s
1126 epoch, D loss:0.0520, G loss:0.1021, training acc: 99.97%, time:29868.16s
1127 epoch, D loss:0.0515, G loss:0.1032, training acc: 99.41%, time:29894.35s
1128 epoch, D loss:0.0503, G loss:0.1041, training acc: 99.97%, time:29920.52s
1129 epoch, D loss:0.0523, G loss:0.1047, training acc: 99.35%, time:29947.23s
1130 epoch, D loss:

1218 epoch, D loss:0.0531, G loss:0.1033, training acc: 99.94%, time:32324.40s
1219 epoch, D loss:0.0500, G loss:0.1049, training acc: 99.85%, time:32350.57s
1220 epoch, D loss:0.0490, G loss:0.1052, training acc: 99.91%, time:32376.69s
test acc: 74.09%, time:32379.29s
1221 epoch, D loss:0.0500, G loss:0.1031, training acc: 99.94%, time:32406.07s
1222 epoch, D loss:0.0510, G loss:0.1022, training acc:100.00%, time:32432.40s
1223 epoch, D loss:0.0521, G loss:0.1014, training acc:100.00%, time:32458.62s
1224 epoch, D loss:0.0516, G loss:0.1018, training acc: 99.97%, time:32484.68s
1225 epoch, D loss:0.0511, G loss:0.1030, training acc: 99.75%, time:32511.45s
1226 epoch, D loss:0.0496, G loss:0.1038, training acc: 99.97%, time:32537.65s
1227 epoch, D loss:0.0503, G loss:0.1032, training acc: 99.94%, time:32563.87s
1228 epoch, D loss:0.0528, G loss:0.1018, training acc: 99.91%, time:32590.14s
1229 epoch, D loss:0.0533, G loss:0.1014, training acc: 99.91%, time:32616.96s
1230 epoch, D loss:

1318 epoch, D loss:0.0485, G loss:0.1060, training acc: 99.54%, time:34970.98s
1319 epoch, D loss:0.0484, G loss:0.1063, training acc: 99.35%, time:34997.10s
1320 epoch, D loss:0.0481, G loss:0.1062, training acc: 98.95%, time:35023.17s
test acc: 75.23%, time:35025.77s
1321 epoch, D loss:0.0475, G loss:0.1059, training acc: 99.54%, time:35052.33s
1322 epoch, D loss:0.0478, G loss:0.1062, training acc: 99.60%, time:35078.71s
1323 epoch, D loss:0.0474, G loss:0.1057, training acc: 99.78%, time:35104.93s
1324 epoch, D loss:0.0479, G loss:0.1058, training acc: 99.72%, time:35131.08s
1325 epoch, D loss:0.0476, G loss:0.1057, training acc: 99.66%, time:35157.65s
1326 epoch, D loss:0.0481, G loss:0.1053, training acc: 99.66%, time:35183.54s
1327 epoch, D loss:0.0476, G loss:0.1053, training acc: 99.88%, time:35209.50s
1328 epoch, D loss:0.0476, G loss:0.1055, training acc: 99.47%, time:35235.54s
1329 epoch, D loss:0.0479, G loss:0.1058, training acc: 99.94%, time:35262.08s
1330 epoch, D loss:

1418 epoch, D loss:0.0480, G loss:0.1044, training acc: 99.75%, time:37610.96s
1419 epoch, D loss:0.0489, G loss:0.1046, training acc: 99.60%, time:37637.34s
1420 epoch, D loss:0.0491, G loss:0.1050, training acc: 99.47%, time:37663.53s
test acc: 75.30%, time:37666.12s
1421 epoch, D loss:0.0487, G loss:0.1055, training acc: 99.10%, time:37693.05s
1422 epoch, D loss:0.0486, G loss:0.1060, training acc: 99.69%, time:37719.27s
1423 epoch, D loss:0.0490, G loss:0.1074, training acc: 99.32%, time:37745.53s
1424 epoch, D loss:0.0487, G loss:0.1062, training acc: 99.85%, time:37771.73s
1425 epoch, D loss:0.0484, G loss:0.1063, training acc: 99.81%, time:37798.61s
1426 epoch, D loss:0.0483, G loss:0.1067, training acc: 99.81%, time:37824.66s
1427 epoch, D loss:0.0478, G loss:0.1071, training acc: 99.75%, time:37850.98s
1428 epoch, D loss:0.0473, G loss:0.1073, training acc: 99.75%, time:37877.32s
1429 epoch, D loss:0.0481, G loss:0.1061, training acc: 99.78%, time:37904.21s
1430 epoch, D loss:

In [None]:
# Evaluate a saved feature extractor plus classifier head on the test set,
# averaging accuracy over M Monte Carlo passes.
M = 10
checkpoint = "Bayesian_GC_LSTM_86.pkl"
# The classifier is looked for next to the net checkpoint: "<stem>_classifier.pkl".
classifier_checkpoint = os.path.splitext(checkpoint)[0] + "_classifier.pkl"

def _load_module(file_path):
    """Load a fully pickled nn.Module onto `device`.

    torch.load unpickles arbitrary objects, so only use it on trusted checkpoints.
    """
    try:
        # Newer torch defaults to weights_only=True, which cannot load whole modules.
        return torch.load(file_path, map_location=device, weights_only=False)
    except TypeError:  # older torch without the `weights_only` argument
        return torch.load(file_path, map_location=device)

net = _load_module(checkpoint).to(device)
if os.path.exists(classifier_checkpoint):
    classifier = _load_module(classifier_checkpoint).to(device)
else:
    print("warning: {} not found; using the in-memory classifier".format(classifier_checkpoint))
# eval() must be called after loading: a pickled module keeps the train/eval flag
# it was saved with.
net.eval()
classifier.eval()
correct = 0
with torch.no_grad():
    for (data, label, num_frame) in test_loader:
        data, label, num_frame = data.to(device), label.to(device), num_frame.to(device)
        for _ in range(M):
            # net yields hidden features; the classifier maps them to class scores.
            output = classifier(net(data,num_frame))
            _, pred = output.max(1)
            correct += pred.eq(label).sum().item()
print("test acc: {:5.2f}%".format((correct/M)/len(test_dataset)*100.))