In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time
import shutil
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from models.ST_Former import GenerateModel
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import datetime
from dataloader.dataset_NIA import train_data_loader, test_data_loader

In [3]:
from runner_helper import *

In [4]:
class Pseudoarg():
    def __init__(self):
        self.workers = 1
        self.epochs = 100
        self.start_epoch = 0
        self.batch_size = 32
        self.lr = 0.01
        self.momentum = 0.9
        self.weight_decay = 1e-4
        self.print_freq = 10
        self.resume = None
        self.data_set = 0
        
args = Pseudoarg()

In [12]:
now = datetime.datetime.now()
time_str = now.strftime("[%m-%d]-[%H:%M]-")
project_path = './nia/'
log_txt_path = project_path + 'log/' + time_str + 'set' + str(args.data_set) + '-log.txt'
log_curve_path = project_path + 'log/' + time_str + 'set' + str(args.data_set) + '-log.png'
checkpoint_path = project_path + 'checkpoint/' + time_str + 'set' + str(args.data_set) + '-model.pth'
best_checkpoint_path = project_path + 'checkpoint/' + time_str + 'set' + str(args.data_set) + '-model_best.pth'


In [13]:
fn_model = "./nia/checkpoint/[08-25]-[08:08]-set0-model_best.pth"
args.resume = fn_model

In [14]:
#def main():
best_acc = 0
#recorder = RecorderMeter(args.epochs)

# create model and load pre_trained parameters
model = GenerateModel()
model = torch.nn.DataParallel(model).cuda()

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)


In [15]:
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        recorder = checkpoint['recorder']
        best_acc = best_acc.cuda()
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
cudnn.benchmark = True

=> loading checkpoint './nia/checkpoint/[08-25]-[08:08]-set0-model_best.pth'
=> loaded checkpoint './nia/checkpoint/[08-25]-[08:08]-set0-model_best.pth' (epoch 48)


In [16]:
# Data loading code
train_data = train_data_loader(project_dir=project_path, 
                               data_set=args.data_set)
test_data = test_data_loader(project_dir=project_path,
                             data_set=args.data_set)

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.workers,
                                           pin_memory=True,
                                           drop_last=True)
val_loader = torch.utils.data.DataLoader(test_data,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=True)

video number:2800
video number:799


In [24]:
tt =[]
for i, (images, target) in enumerate(val_loader):
    tt.extend(target)

In [29]:
np.histogram(np.squeeze(tt))

(array([213,   0, 130,   0,   0, 159,   0, 177,   0, 120]),
 array([0. , 0.4, 0.8, 1.2, 1.6, 2. , 2.4, 2.8, 3.2, 3.6, 4. ]))

In [25]:
len(tt)

799

In [66]:
for epoch in range(args.start_epoch, args.epochs):
    inf = '********************' + str(epoch) + '********************'

    # evaluate on validation set
    val_acc, val_los = validate(val_loader, model, criterion, args)

    # remember best acc and save checkpoint
    is_best = val_acc > best_acc
    best_acc = max(val_acc, best_acc)
    save_checkpoint({'epoch': epoch + 1,
                     'state_dict': model.state_dict(),
                     'best_acc': best_acc,
                     'optimizer': optimizer.state_dict(),
                     'recorder': recorder}, is_best)

    # print and save log
    epoch_time = time.time() - start_time
    recorder.update(epoch, train_los, train_acc, val_los, val_acc)
    recorder.plot_curve(log_curve_path)

    print('The best accuracy: {:.3f}'.format(best_acc.item()))
    print('An epoch time: {:.1f}s'.format(epoch_time))
    with open(log_txt_path, 'a') as f:
        f.write('The best accuracy: ' + str(best_acc.item()) + '\n')
        f.write('An epoch time: {:.1f}s' + str(epoch_time) + '\n')

********************0********************
Current learning rate:  0.01
Epoch: [0][ 0/87]	Loss 2.3868 (2.3868)	Accuracy 12.500 (12.500)
Epoch: [0][10/87]	Loss 4.5241 (3.5132)	Accuracy 15.625 (19.318)
Epoch: [0][20/87]	Loss 2.0793 (2.9263)	Accuracy 21.875 (18.304)
Epoch: [0][30/87]	Loss 1.6212 (2.5622)	Accuracy 31.250 (19.556)
Epoch: [0][40/87]	Loss 2.0750 (2.3503)	Accuracy 21.875 (19.970)
Epoch: [0][50/87]	Loss 1.6269 (2.2201)	Accuracy 25.000 (19.547)
Epoch: [0][60/87]	Loss 1.6939 (2.1245)	Accuracy 18.750 (20.031)
Epoch: [0][70/87]	Loss 1.5758 (2.0535)	Accuracy 28.125 (20.202)
Epoch: [0][80/87]	Loss 1.6554 (2.0040)	Accuracy 12.500 (20.216)
Test: [ 0/25]	Loss 1.5443 (1.5443)	Accuracy 31.250 (31.250)
Test: [10/25]	Loss 1.5183 (1.5857)	Accuracy 12.500 (22.159)
Test: [20/25]	Loss 1.6525 (1.5996)	Accuracy 15.625 (22.321)
Current Accuracy: 23.029
The best accuracy: 23.029
An epoch time: 68.0s
********************1********************
Current learning rate:  0.01
Epoch: [1][ 0/87]	Loss 1.6111 

Epoch: [9][10/87]	Loss 1.5804 (1.5999)	Accuracy 21.875 (22.159)
Epoch: [9][20/87]	Loss 1.5926 (1.5983)	Accuracy 25.000 (21.429)
Epoch: [9][30/87]	Loss 1.5335 (1.5984)	Accuracy 28.125 (21.774)
Epoch: [9][40/87]	Loss 1.6392 (1.6020)	Accuracy 15.625 (21.418)
Epoch: [9][50/87]	Loss 1.5473 (1.5974)	Accuracy 25.000 (22.304)
Epoch: [9][60/87]	Loss 1.5688 (1.5956)	Accuracy 28.125 (23.156)
Epoch: [9][70/87]	Loss 1.5821 (1.5971)	Accuracy 34.375 (22.931)
Epoch: [9][80/87]	Loss 1.5919 (1.5967)	Accuracy 21.875 (22.762)
Test: [ 0/25]	Loss 1.6214 (1.6214)	Accuracy 18.750 (18.750)
Test: [10/25]	Loss 1.4297 (1.5780)	Accuracy 53.125 (30.966)
Test: [20/25]	Loss 1.7708 (1.6144)	Accuracy 15.625 (25.595)
Current Accuracy: 24.781
The best accuracy: 26.658
An epoch time: 64.5s
********************10********************
Current learning rate:  0.01
Epoch: [10][ 0/87]	Loss 1.6374 (1.6374)	Accuracy 25.000 (25.000)
Epoch: [10][10/87]	Loss 1.5728 (1.5910)	Accuracy 18.750 (23.011)
Epoch: [10][20/87]	Loss 1.6636 (1.

Epoch: [18][20/87]	Loss 1.3855 (1.0954)	Accuracy 28.125 (45.685)
Epoch: [18][30/87]	Loss 1.0655 (1.1565)	Accuracy 43.750 (43.750)
Epoch: [18][40/87]	Loss 1.3343 (1.1628)	Accuracy 43.750 (43.750)
Epoch: [18][50/87]	Loss 1.2647 (1.1649)	Accuracy 43.750 (43.995)
Epoch: [18][60/87]	Loss 1.1902 (1.1525)	Accuracy 37.500 (45.031)
Epoch: [18][70/87]	Loss 1.0865 (1.1526)	Accuracy 50.000 (45.114)
Epoch: [18][80/87]	Loss 1.1941 (1.1456)	Accuracy 43.750 (45.216)
Test: [ 0/25]	Loss 0.8570 (0.8570)	Accuracy 68.750 (68.750)
Test: [10/25]	Loss 1.0433 (1.0248)	Accuracy 50.000 (48.295)
Test: [20/25]	Loss 1.1731 (1.0068)	Accuracy 53.125 (50.000)
Current Accuracy: 49.437
The best accuracy: 49.437
An epoch time: 64.0s
********************19********************
Current learning rate:  0.01
Epoch: [19][ 0/87]	Loss 1.2433 (1.2433)	Accuracy 40.625 (40.625)
Epoch: [19][10/87]	Loss 1.0753 (1.0820)	Accuracy 53.125 (50.000)
Epoch: [19][20/87]	Loss 1.1841 (1.0769)	Accuracy 43.750 (50.446)
Epoch: [19][30/87]	Loss 1.

Epoch: [27][30/87]	Loss 1.1034 (1.0482)	Accuracy 37.500 (49.899)
Epoch: [27][40/87]	Loss 1.0579 (1.0293)	Accuracy 50.000 (50.991)
Epoch: [27][50/87]	Loss 1.0741 (1.0279)	Accuracy 34.375 (50.919)
Epoch: [27][60/87]	Loss 1.2411 (1.0258)	Accuracy 40.625 (50.820)
Epoch: [27][70/87]	Loss 0.9543 (1.0250)	Accuracy 50.000 (50.836)
Epoch: [27][80/87]	Loss 1.0814 (1.0243)	Accuracy 50.000 (51.003)
Test: [ 0/25]	Loss 1.0299 (1.0299)	Accuracy 43.750 (43.750)
Test: [10/25]	Loss 1.0889 (1.0313)	Accuracy 37.500 (47.159)
Test: [20/25]	Loss 1.0959 (1.0035)	Accuracy 46.875 (49.851)
Current Accuracy: 49.937
The best accuracy: 57.071
An epoch time: 63.9s
********************28********************
Current learning rate:  0.01
Epoch: [28][ 0/87]	Loss 1.1982 (1.1982)	Accuracy 40.625 (40.625)
Epoch: [28][10/87]	Loss 1.0837 (1.0010)	Accuracy 43.750 (51.136)
Epoch: [28][20/87]	Loss 1.1249 (0.9852)	Accuracy 46.875 (53.571)
Epoch: [28][30/87]	Loss 1.0089 (1.0080)	Accuracy 56.250 (53.327)
Epoch: [28][40/87]	Loss 0.

Epoch: [36][40/87]	Loss 1.1470 (0.9427)	Accuracy 43.750 (57.470)
Epoch: [36][50/87]	Loss 0.8186 (0.9296)	Accuracy 68.750 (58.027)
Epoch: [36][60/87]	Loss 0.9978 (0.9184)	Accuracy 53.125 (58.402)
Epoch: [36][70/87]	Loss 1.1411 (0.9337)	Accuracy 56.250 (57.614)
Epoch: [36][80/87]	Loss 0.9803 (0.9409)	Accuracy 46.875 (57.253)
Test: [ 0/25]	Loss 0.8911 (0.8911)	Accuracy 46.875 (46.875)
Test: [10/25]	Loss 1.4177 (1.0333)	Accuracy 31.250 (48.580)
Test: [20/25]	Loss 0.8566 (0.9822)	Accuracy 56.250 (51.042)
Current Accuracy: 51.189
The best accuracy: 61.202
An epoch time: 64.2s
********************37********************
Current learning rate:  0.01
Epoch: [37][ 0/87]	Loss 0.9037 (0.9037)	Accuracy 62.500 (62.500)
Epoch: [37][10/87]	Loss 0.7495 (0.9359)	Accuracy 65.625 (55.114)
Epoch: [37][20/87]	Loss 0.8312 (0.9331)	Accuracy 68.750 (57.589)
Epoch: [37][30/87]	Loss 1.3173 (0.9372)	Accuracy 43.750 (57.359)
Epoch: [37][40/87]	Loss 1.0828 (0.9516)	Accuracy 43.750 (56.326)
Epoch: [37][50/87]	Loss 0.

Epoch: [45][50/87]	Loss 0.5904 (0.7898)	Accuracy 78.125 (64.828)
Epoch: [45][60/87]	Loss 0.7971 (0.7927)	Accuracy 59.375 (64.600)
Epoch: [45][70/87]	Loss 0.5212 (0.7895)	Accuracy 78.125 (65.229)
Epoch: [45][80/87]	Loss 0.7021 (0.7843)	Accuracy 75.000 (65.702)
Test: [ 0/25]	Loss 0.7087 (0.7087)	Accuracy 75.000 (75.000)
Test: [10/25]	Loss 1.0618 (0.9003)	Accuracy 53.125 (59.943)
Test: [20/25]	Loss 0.7917 (0.8516)	Accuracy 59.375 (61.310)
Current Accuracy: 61.202
The best accuracy: 62.954
An epoch time: 63.9s
********************46********************
Current learning rate:  0.001
Epoch: [46][ 0/87]	Loss 0.8343 (0.8343)	Accuracy 56.250 (56.250)
Epoch: [46][10/87]	Loss 0.9526 (0.8005)	Accuracy 46.875 (63.068)
Epoch: [46][20/87]	Loss 0.7853 (0.8436)	Accuracy 59.375 (61.607)
Epoch: [46][30/87]	Loss 0.6743 (0.8024)	Accuracy 65.625 (64.113)
Epoch: [46][40/87]	Loss 0.6900 (0.8074)	Accuracy 71.875 (64.253)
Epoch: [46][50/87]	Loss 0.7532 (0.8023)	Accuracy 71.875 (64.400)
Epoch: [46][60/87]	Loss 0

Epoch: [54][60/87]	Loss 0.7635 (0.7490)	Accuracy 62.500 (66.342)
Epoch: [54][70/87]	Loss 0.7577 (0.7374)	Accuracy 59.375 (66.681)
Epoch: [54][80/87]	Loss 0.7498 (0.7374)	Accuracy 65.625 (66.782)
Test: [ 0/25]	Loss 0.7276 (0.7276)	Accuracy 75.000 (75.000)
Test: [10/25]	Loss 1.0743 (0.9421)	Accuracy 59.375 (60.511)
Test: [20/25]	Loss 0.8390 (0.8946)	Accuracy 59.375 (62.202)
Current Accuracy: 62.078
The best accuracy: 63.579
An epoch time: 63.8s
********************55********************
Current learning rate:  0.001
Epoch: [55][ 0/87]	Loss 0.8346 (0.8346)	Accuracy 68.750 (68.750)
Epoch: [55][10/87]	Loss 0.6612 (0.7392)	Accuracy 78.125 (69.318)
Epoch: [55][20/87]	Loss 0.8001 (0.7669)	Accuracy 56.250 (68.155)
Epoch: [55][30/87]	Loss 0.5605 (0.7605)	Accuracy 75.000 (67.742)
Epoch: [55][40/87]	Loss 0.8604 (0.7672)	Accuracy 65.625 (66.921)
Epoch: [55][50/87]	Loss 0.7041 (0.7616)	Accuracy 59.375 (66.605)
Epoch: [55][60/87]	Loss 0.6615 (0.7412)	Accuracy 75.000 (67.469)
Epoch: [55][70/87]	Loss 0

Epoch: [63][70/87]	Loss 0.7946 (0.6974)	Accuracy 75.000 (68.882)
Epoch: [63][80/87]	Loss 0.5405 (0.6909)	Accuracy 75.000 (69.213)
Test: [ 0/25]	Loss 0.7804 (0.7804)	Accuracy 68.750 (68.750)
Test: [10/25]	Loss 1.0437 (0.9451)	Accuracy 53.125 (58.807)
Test: [20/25]	Loss 0.8062 (0.8854)	Accuracy 56.250 (61.012)
Current Accuracy: 61.327
The best accuracy: 63.579
An epoch time: 63.7s
********************64********************
Current learning rate:  0.001
Epoch: [64][ 0/87]	Loss 0.9003 (0.9003)	Accuracy 53.125 (53.125)
Epoch: [64][10/87]	Loss 0.6145 (0.7373)	Accuracy 71.875 (67.045)
Epoch: [64][20/87]	Loss 0.7329 (0.7094)	Accuracy 68.750 (68.601)
Epoch: [64][30/87]	Loss 0.8472 (0.7172)	Accuracy 62.500 (68.044)
Epoch: [64][40/87]	Loss 0.7192 (0.7055)	Accuracy 56.250 (69.207)
Epoch: [64][50/87]	Loss 0.5911 (0.6998)	Accuracy 71.875 (69.485)
Epoch: [64][60/87]	Loss 0.9848 (0.7147)	Accuracy 43.750 (69.057)
Epoch: [64][70/87]	Loss 0.8522 (0.7035)	Accuracy 62.500 (69.938)
Epoch: [64][80/87]	Loss 0

Epoch: [72][80/87]	Loss 0.9402 (0.6869)	Accuracy 68.750 (69.753)
Test: [ 0/25]	Loss 0.7460 (0.7460)	Accuracy 65.625 (65.625)
Test: [10/25]	Loss 1.1418 (0.9845)	Accuracy 50.000 (56.534)
Test: [20/25]	Loss 0.9281 (0.9307)	Accuracy 56.250 (58.185)
Current Accuracy: 58.448
The best accuracy: 63.579
An epoch time: 63.8s
********************73********************
Current learning rate:  0.001
Epoch: [73][ 0/87]	Loss 0.9373 (0.9373)	Accuracy 68.750 (68.750)
Epoch: [73][10/87]	Loss 0.5622 (0.7450)	Accuracy 81.250 (67.898)
Epoch: [73][20/87]	Loss 0.5977 (0.6910)	Accuracy 78.125 (71.429)
Epoch: [73][30/87]	Loss 0.6183 (0.6853)	Accuracy 71.875 (70.766)
Epoch: [73][40/87]	Loss 0.8068 (0.6819)	Accuracy 62.500 (70.274)
Epoch: [73][50/87]	Loss 0.5414 (0.6850)	Accuracy 78.125 (70.282)
Epoch: [73][60/87]	Loss 0.6007 (0.6807)	Accuracy 71.875 (70.441)
Epoch: [73][70/87]	Loss 0.5420 (0.6726)	Accuracy 75.000 (70.863)
Epoch: [73][80/87]	Loss 0.6534 (0.6802)	Accuracy 71.875 (70.563)
Test: [ 0/25]	Loss 0.7506

Test: [ 0/25]	Loss 0.7518 (0.7518)	Accuracy 68.750 (68.750)
Test: [10/25]	Loss 1.1313 (0.9627)	Accuracy 50.000 (57.386)
Test: [20/25]	Loss 0.9032 (0.9018)	Accuracy 62.500 (59.226)
Current Accuracy: 59.700
The best accuracy: 63.579
An epoch time: 63.9s
********************82********************
Current learning rate:  0.0001
Epoch: [82][ 0/87]	Loss 0.5751 (0.5751)	Accuracy 71.875 (71.875)
Epoch: [82][10/87]	Loss 0.8466 (0.6778)	Accuracy 65.625 (70.455)
Epoch: [82][20/87]	Loss 0.7997 (0.6708)	Accuracy 62.500 (69.196)
Epoch: [82][30/87]	Loss 0.8147 (0.6564)	Accuracy 65.625 (70.665)
Epoch: [82][40/87]	Loss 0.6808 (0.6508)	Accuracy 71.875 (71.113)
Epoch: [82][50/87]	Loss 0.7730 (0.6495)	Accuracy 71.875 (71.752)
Epoch: [82][60/87]	Loss 0.6555 (0.6559)	Accuracy 75.000 (71.516)
Epoch: [82][70/87]	Loss 0.7114 (0.6470)	Accuracy 65.625 (71.523)
Epoch: [82][80/87]	Loss 0.7771 (0.6417)	Accuracy 71.875 (71.335)
Test: [ 0/25]	Loss 0.7660 (0.7660)	Accuracy 65.625 (65.625)
Test: [10/25]	Loss 1.1165 (0.

Test: [ 0/25]	Loss 0.7561 (0.7561)	Accuracy 62.500 (62.500)
Test: [10/25]	Loss 1.1514 (0.9447)	Accuracy 56.250 (58.239)
Test: [20/25]	Loss 0.8588 (0.8811)	Accuracy 62.500 (60.565)
Current Accuracy: 60.451
The best accuracy: 63.579
An epoch time: 64.2s
********************91********************
Current learning rate:  0.0001
Epoch: [91][ 0/87]	Loss 0.4445 (0.4445)	Accuracy 84.375 (84.375)
Epoch: [91][10/87]	Loss 0.5779 (0.6080)	Accuracy 68.750 (73.864)
Epoch: [91][20/87]	Loss 0.7065 (0.6032)	Accuracy 71.875 (73.363)
Epoch: [91][30/87]	Loss 0.7234 (0.6382)	Accuracy 78.125 (72.379)
Epoch: [91][40/87]	Loss 0.9303 (0.6461)	Accuracy 59.375 (72.104)
Epoch: [91][50/87]	Loss 0.6845 (0.6372)	Accuracy 68.750 (72.488)
Epoch: [91][60/87]	Loss 0.6733 (0.6380)	Accuracy 65.625 (72.336)
Epoch: [91][70/87]	Loss 1.0818 (0.6404)	Accuracy 62.500 (72.491)
Epoch: [91][80/87]	Loss 0.4746 (0.6290)	Accuracy 78.125 (72.917)
Test: [ 0/25]	Loss 0.7584 (0.7584)	Accuracy 65.625 (65.625)
Test: [10/25]	Loss 1.1563 (0.

Test: [ 0/25]	Loss 0.7422 (0.7422)	Accuracy 65.625 (65.625)
Test: [10/25]	Loss 1.1398 (0.9567)	Accuracy 53.125 (57.102)
Test: [20/25]	Loss 0.8787 (0.8915)	Accuracy 65.625 (59.524)
Current Accuracy: 59.825
The best accuracy: 63.579
An epoch time: 64.2s


In [19]:
# switch to evaluate mode
model.eval()

with torch.no_grad():
    for i, (images, target) in enumerate(val_loader):
        output = model(images)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
        
    

KeyboardInterrupt: 

In [21]:
target

tensor([3, 0, 0, 0, 4, 0, 0, 1, 4, 3, 2, 2, 1, 3, 0, 1, 0, 1, 1, 2, 0, 2, 2, 0,
        2, 2, 0, 4, 3, 3, 0, 2], device='cuda:0')

In [20]:
output

tensor([[ 3.5110,  3.3695,  3.8335,  4.5357, -2.8376, -5.9850, -6.0877],
        [ 3.3959,  2.8180,  3.7116,  4.6528, -2.7243, -5.8329, -5.9907],
        [ 4.4607,  0.2919,  4.1597,  1.8861, -0.1025, -5.3760, -5.1569],
        [ 3.4314,  0.6431,  2.0497,  1.2253,  1.7846, -4.5564, -4.7125],
        [ 0.1662,  0.5564, -3.8358, -5.7132, 13.0709, -1.6205, -2.9633],
        [ 2.4302,  1.7143,  1.7459,  1.5953,  1.1204, -4.3854, -4.6642],
        [ 4.3323,  1.2517,  3.7745,  2.0542, -0.3597, -5.6882, -5.6275],
        [ 0.8303,  9.5929,  1.9384,  4.8597, -2.9944, -6.2197, -8.0811],
        [ 1.5242, -1.0064, -2.3783, -4.5763, 12.0179, -2.1769, -3.2284],
        [ 2.8859,  2.2543,  3.0043,  3.6587, -1.1717, -5.0004, -5.4509],
        [ 2.6237,  3.5168,  2.6947,  4.1646, -1.7337, -5.1878, -5.7990],
        [ 3.1057,  1.6394,  2.9939,  3.8415, -1.0835, -5.0740, -5.2911],
        [ 2.0512,  9.3580,  3.2799,  4.7642, -3.4347, -7.1573, -8.2486],
        [ 2.7611,  1.8076,  2.9718,  4.2924, -1.510

In [17]:
def validate(val_loader, model, criterion, args):
    losses = AverageMeter('Loss', ':.4f')
    top1 = AverageMeter('Accuracy', ':6.3f')
    progress = ProgressMeter(len(val_loader),
                             [losses, top1],
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            images = images.cuda()
            target = target.cuda()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, _ = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))

            if i % args.print_freq == 0:
                progress.display(i, log_txt_path)

        # TODO: this should also be done with the ProgressMeter
        print('Current Accuracy: {top1.avg:.3f}'.format(top1=top1))
        with open(log_txt_path, 'a') as f:
            f.write('Current Accuracy: {top1.avg:.3f}'.format(top1=top1) + '\n')
            

    return top1.avg, losses.avg


def save_checkpoint(state, is_best):
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, best_checkpoint_path)
