In [1]:
import torch

In [2]:
# Display the PyTorch version this notebook was executed with (kept as output for reproducibility).
torch.__version__

'1.13.0'

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import os
import time
import shutil
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from models.ST_Former import GenerateModel
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import datetime
from dataloader.dataset_NIA import train_data_loader, test_data_loader

In [5]:
from runner_helper import *

In [11]:
class Pseudoarg():
    """Stand-in for an argparse namespace holding the training hyper-parameters.

    A notebook has no command line, so the values a script would expose as
    CLI flags live here as attributes. Any attribute can be overridden by
    keyword, e.g. ``Pseudoarg(lr=0.001, epochs=50)``. Unknown names raise
    ``TypeError`` so a typo cannot silently fall back to the default.
    """

    def __init__(self, **overrides):
        self.workers = 1            # DataLoader num_workers
        self.epochs = 100           # total number of epochs (exclusive end of the epoch range)
        self.start_epoch = 0        # first epoch index; overwritten when resuming from a checkpoint
        self.batch_size = 32        # DataLoader batch size for both train and validation
        self.lr = 0.01              # initial SGD learning rate
        self.momentum = 0.9         # SGD momentum
        self.weight_decay = 1e-4    # SGD weight decay
        self.print_freq = 10        # presumably batches between progress prints in train/validate — TODO confirm in runner_helper
        self.resume = None          # checkpoint file to resume from, or None to start fresh
        self.data_set = 0           # dataset split id; passed to the loaders and embedded in output file names

        unknown = set(overrides) - set(vars(self))
        if unknown:
            raise TypeError('Unknown Pseudoarg field(s): ' + ', '.join(sorted(unknown)))
        for name, value in overrides.items():
            setattr(self, name, value)

args = Pseudoarg()

In [12]:
# A timestamp plus the dataset id names every artifact written by this run,
# e.g. log/<MMDD_HHMM_>set<N>-log.txt.
now = datetime.datetime.now()
time_str = now.strftime("%m%d_%H%M_")
# Project root. It can be overridden with the NIA_PROJECT_PATH environment variable
# instead of editing the notebook. os.path.join(..., '') guarantees the trailing
# separator that the string concatenations below rely on.
project_path = os.path.join(os.environ.get('NIA_PROJECT_PATH', "/media/di/data/lee/nia/"), '')
run_tag = time_str + 'set' + str(args.data_set)
log_txt_path = project_path + 'log/' + run_tag + '-log.txt'
log_curve_path = project_path + 'log/' + run_tag + '-log.png'
checkpoint_path = project_path + 'checkpoint/' + run_tag + '-model.pth'
best_checkpoint_path = project_path + 'checkpoint/' + run_tag + '-model_best.pth'


In [13]:
# ---- Run setup: logging, model, loss, optimizer and LR schedule ----
best_acc = 0
recorder = RecorderMeter(args.epochs)
print('The training time: ' + now.strftime("%m-%d %H:%M"))
print('The training set: set ' + str(args.data_set))
# Create the output directories up front so the log writes below work on a fresh
# project directory. The checkpoint directory is created too — presumably
# save_checkpoint (runner_helper) writes there; TODO confirm.
os.makedirs(os.path.dirname(log_txt_path), exist_ok=True)
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
with open(log_txt_path, 'a') as f:
    f.write('The training set: set ' + str(args.data_set) + '\n')

# Create the model. Any pre-trained weight loading presumably happens inside
# GenerateModel — TODO confirm.
model = GenerateModel()
# Fall back to CPU when no GPU is available instead of crashing on an unconditional .cuda().
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.DataParallel(model).to(device)
print(model)

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
# Multiply the LR by 0.1 every 40 scheduler steps; the training loop calls scheduler.step() once per epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)


The training time: 12-27 05:05
The training set: set 0
DataParallel(
  (module): GenerateModel(
    (s_former): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), str

In [14]:
# Optionally resume model / optimizer / LR-schedule state from a checkpoint.
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        # NOTE(review): torch.load unpickles arbitrary objects (the checkpoint holds
        # the recorder object), so only resume from trusted files.
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch']
        best_acc = checkpoint['best_acc']
        recorder = checkpoint['recorder']
        # best_acc is expected to be a tensor (the training loop calls .item() on it).
        # Move it to the active device rather than forcing .cuda().
        best_acc = torch.as_tensor(best_acc).to(device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Restore the LR schedule. Otherwise StepLR restarts counting at 0 and the
        # decays happen step_size epochs after the resume point, not at the original epochs.
        if 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            # Older checkpoints carry no scheduler state. One scheduler.step() runs per
            # completed epoch, and checkpoint['epoch'] counts completed epochs.
            scheduler.last_epoch = args.start_epoch
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
cudnn.benchmark = True

In [15]:
# Data loading code
train_data = train_data_loader(project_dir=project_path,
                               data_set=args.data_set)
test_data = test_data_loader(project_dir=project_path,
                             data_set=args.data_set)

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.workers,
                                           pin_memory=True,
                                           drop_last=True)
val_loader = torch.utils.data.DataLoader(test_data,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.workers,
                                         pin_memory=True)

# Train / validate / checkpoint once per epoch, resuming at args.start_epoch if set.
for epoch in range(args.start_epoch, args.epochs):
    inf = '********************' + str(epoch) + '********************'
    start_time = time.time()
    # Read the live LR directly instead of rebuilding the whole optimizer state_dict.
    current_learning_rate = optimizer.param_groups[0]['lr']

    with open(log_txt_path, 'a') as f:
        f.write(inf + '\n')
        f.write('Current learning rate: ' + str(current_learning_rate) + '\n')

    print(inf)
    print('Current learning rate: ', current_learning_rate)

    # train for one epoch
    train_acc, train_los = train(train_loader, model, criterion, optimizer, epoch, args)

    # evaluate on validation set
    val_acc, val_los = validate(val_loader, model, criterion, args)

    scheduler.step()

    # Record this epoch BEFORE checkpointing, so the recorder stored in the
    # checkpoint already contains it (previously it lagged one epoch behind).
    recorder.update(epoch, train_los, train_acc, val_los, val_acc)

    # remember best acc and save checkpoint
    is_best = val_acc > best_acc
    best_acc = max(val_acc, best_acc)
    save_checkpoint({'epoch': epoch + 1,
                     'state_dict': model.state_dict(),
                     'best_acc': best_acc,
                     'optimizer': optimizer.state_dict(),
                     # Saved so a resumed run can restore the exact LR schedule.
                     'scheduler': scheduler.state_dict(),
                     'recorder': recorder}, is_best)

    # print and save log
    epoch_time = time.time() - start_time
    recorder.plot_curve(log_curve_path)

    print('The best accuracy: {:.3f}'.format(best_acc.item()))
    print('An epoch time: {:.1f}s'.format(epoch_time))
    with open(log_txt_path, 'a') as f:
        f.write('The best accuracy: ' + str(best_acc.item()) + '\n')
        # Apply the format spec; the old code concatenated the raw '{:.1f}s' template with the float.
        f.write('An epoch time: {:.1f}s'.format(epoch_time) + '\n')

video number:2801
video number:800
********************0********************
Current learning rate:  0.01
Epoch: [0][ 0/87]	Loss 1.5449 (1.5449)	Accuracy 31.250 (31.250)
Epoch: [0][10/87]	Loss 5.0881 (4.0296)	Accuracy 28.125 (23.864)
Epoch: [0][20/87]	Loss 1.6990 (3.2211)	Accuracy 21.875 (22.470)
Epoch: [0][30/87]	Loss 1.6571 (2.8030)	Accuracy  9.375 (20.968)
Epoch: [0][40/87]	Loss 1.6023 (2.5302)	Accuracy 18.750 (21.037)
Epoch: [0][50/87]	Loss 1.6378 (2.3635)	Accuracy 21.875 (20.772)
Epoch: [0][60/87]	Loss 1.5337 (2.2386)	Accuracy 31.250 (21.107)
Epoch: [0][70/87]	Loss 1.6574 (2.1570)	Accuracy 12.500 (21.083)
Epoch: [0][80/87]	Loss 1.6386 (2.0904)	Accuracy 31.250 (21.721)
Test: [ 0/25]	Loss 1.6151 (1.6151)	Accuracy 25.000 (25.000)
Test: [10/25]	Loss 1.6140 (1.6367)	Accuracy 25.000 (19.318)
Test: [20/25]	Loss 1.6885 (1.6500)	Accuracy  6.250 (19.196)
Current Accuracy: 18.875
The best accuracy: 18.875
An epoch time: 66.3s
********************1********************
Current learning rate:  

Epoch: [9][ 0/87]	Loss 1.6959 (1.6959)	Accuracy 12.500 (12.500)
Epoch: [9][10/87]	Loss 1.6269 (1.6186)	Accuracy 21.875 (20.455)
Epoch: [9][20/87]	Loss 1.6051 (1.6056)	Accuracy 21.875 (22.321)
Epoch: [9][30/87]	Loss 1.6431 (1.6089)	Accuracy 15.625 (21.472)
Epoch: [9][40/87]	Loss 1.6009 (1.6041)	Accuracy 28.125 (22.256)
Epoch: [9][50/87]	Loss 1.5951 (1.6020)	Accuracy 28.125 (23.407)
Epoch: [9][60/87]	Loss 1.6183 (1.6059)	Accuracy 15.625 (22.439)
Epoch: [9][70/87]	Loss 1.6184 (1.6067)	Accuracy 12.500 (22.315)
Epoch: [9][80/87]	Loss 1.6788 (1.6063)	Accuracy  6.250 (22.299)
Test: [ 0/25]	Loss 1.5921 (1.5921)	Accuracy 31.250 (31.250)
Test: [10/25]	Loss 1.6025 (1.6035)	Accuracy 21.875 (21.875)
Test: [20/25]	Loss 1.6035 (1.6138)	Accuracy 21.875 (19.643)
Current Accuracy: 19.875
The best accuracy: 25.375
An epoch time: 63.1s
********************10********************
Current learning rate:  0.01
Epoch: [10][ 0/87]	Loss 1.5768 (1.5768)	Accuracy 12.500 (12.500)
Epoch: [10][10/87]	Loss 1.5749 (1.5

Epoch: [18][10/87]	Loss 1.6429 (1.5978)	Accuracy 21.875 (23.011)
Epoch: [18][20/87]	Loss 1.6211 (1.6013)	Accuracy 18.750 (23.363)
Epoch: [18][30/87]	Loss 1.5521 (1.5927)	Accuracy 40.625 (25.504)
Epoch: [18][40/87]	Loss 1.5665 (1.5953)	Accuracy 25.000 (25.229)
Epoch: [18][50/87]	Loss 1.5475 (1.5933)	Accuracy 18.750 (25.429)
Epoch: [18][60/87]	Loss 1.6163 (1.5960)	Accuracy 21.875 (24.795)
Epoch: [18][70/87]	Loss 1.6006 (1.5971)	Accuracy 15.625 (24.032)
Epoch: [18][80/87]	Loss 1.6061 (1.5958)	Accuracy 21.875 (24.306)
Test: [ 0/25]	Loss 1.5821 (1.5821)	Accuracy 21.875 (21.875)
Test: [10/25]	Loss 1.6164 (1.5916)	Accuracy 18.750 (25.284)
Test: [20/25]	Loss 1.6137 (1.6068)	Accuracy 34.375 (25.298)
Current Accuracy: 25.000
The best accuracy: 25.375
An epoch time: 61.9s
********************19********************
Current learning rate:  0.01
Epoch: [19][ 0/87]	Loss 1.5327 (1.5327)	Accuracy 34.375 (34.375)
Epoch: [19][10/87]	Loss 1.5851 (1.5685)	Accuracy 37.500 (28.409)
Epoch: [19][20/87]	Loss 1.

Epoch: [27][20/87]	Loss 1.6013 (1.5604)	Accuracy 21.875 (24.702)
Epoch: [27][30/87]	Loss 1.6578 (1.5700)	Accuracy 12.500 (24.798)
Epoch: [27][40/87]	Loss 1.6384 (1.5773)	Accuracy 15.625 (23.857)
Epoch: [27][50/87]	Loss 1.5121 (1.5776)	Accuracy 34.375 (24.265)
Epoch: [27][60/87]	Loss 1.6018 (1.5746)	Accuracy 15.625 (24.846)
Epoch: [27][70/87]	Loss 1.6055 (1.5724)	Accuracy 25.000 (25.132)
Epoch: [27][80/87]	Loss 1.4655 (1.5709)	Accuracy 37.500 (25.694)
Test: [ 0/25]	Loss 1.5970 (1.5970)	Accuracy 25.000 (25.000)
Test: [10/25]	Loss 1.6453 (1.5761)	Accuracy 21.875 (24.432)
Test: [20/25]	Loss 1.5435 (1.5802)	Accuracy 37.500 (25.893)
Current Accuracy: 25.875
The best accuracy: 27.500
An epoch time: 61.6s
********************28********************
Current learning rate:  0.01
Epoch: [28][ 0/87]	Loss 1.5843 (1.5843)	Accuracy 21.875 (21.875)
Epoch: [28][10/87]	Loss 1.5387 (1.5705)	Accuracy 37.500 (27.841)
Epoch: [28][20/87]	Loss 1.3822 (1.5570)	Accuracy 53.125 (29.018)
Epoch: [28][30/87]	Loss 1.

Epoch: [36][30/87]	Loss 1.2777 (1.2617)	Accuracy 37.500 (39.617)
Epoch: [36][40/87]	Loss 1.1858 (1.2484)	Accuracy 43.750 (39.787)
Epoch: [36][50/87]	Loss 1.1708 (1.2550)	Accuracy 50.000 (39.032)
Epoch: [36][60/87]	Loss 1.1010 (1.2604)	Accuracy 56.250 (38.986)
Epoch: [36][70/87]	Loss 1.3829 (1.2667)	Accuracy 37.500 (38.644)
Epoch: [36][80/87]	Loss 1.3741 (1.2729)	Accuracy 37.500 (38.735)
Test: [ 0/25]	Loss 1.3546 (1.3546)	Accuracy 34.375 (34.375)
Test: [10/25]	Loss 1.3156 (1.2482)	Accuracy 25.000 (40.057)
Test: [20/25]	Loss 1.2787 (1.2665)	Accuracy 37.500 (39.137)
Current Accuracy: 38.375
The best accuracy: 40.375
An epoch time: 61.9s
********************37********************
Current learning rate:  0.01
Epoch: [37][ 0/87]	Loss 1.2055 (1.2055)	Accuracy 40.625 (40.625)
Epoch: [37][10/87]	Loss 1.1769 (1.2451)	Accuracy 50.000 (37.500)
Epoch: [37][20/87]	Loss 1.1809 (1.2340)	Accuracy 34.375 (38.393)
Epoch: [37][30/87]	Loss 1.1037 (1.2170)	Accuracy 46.875 (40.020)
Epoch: [37][40/87]	Loss 1.

Epoch: [45][40/87]	Loss 1.0088 (1.1804)	Accuracy 40.625 (43.979)
Epoch: [45][50/87]	Loss 0.9643 (1.1924)	Accuracy 53.125 (42.892)
Epoch: [45][60/87]	Loss 1.1044 (1.1917)	Accuracy 59.375 (43.033)
Epoch: [45][70/87]	Loss 1.2936 (1.1910)	Accuracy 34.375 (43.046)
Epoch: [45][80/87]	Loss 1.0878 (1.1928)	Accuracy 62.500 (42.824)
Test: [ 0/25]	Loss 1.1621 (1.1621)	Accuracy 43.750 (43.750)
Test: [10/25]	Loss 1.2794 (1.1697)	Accuracy 37.500 (41.193)
Test: [20/25]	Loss 1.1847 (1.1831)	Accuracy 43.750 (40.923)
Current Accuracy: 41.500
The best accuracy: 43.000
An epoch time: 61.8s
********************46********************
Current learning rate:  0.001
Epoch: [46][ 0/87]	Loss 1.1522 (1.1522)	Accuracy 50.000 (50.000)
Epoch: [46][10/87]	Loss 1.5644 (1.2406)	Accuracy 37.500 (45.455)
Epoch: [46][20/87]	Loss 1.0762 (1.1986)	Accuracy 53.125 (46.577)
Epoch: [46][30/87]	Loss 1.1131 (1.2065)	Accuracy 53.125 (44.859)
Epoch: [46][40/87]	Loss 1.1574 (1.2042)	Accuracy 50.000 (44.665)
Epoch: [46][50/87]	Loss 1

Epoch: [54][50/87]	Loss 1.2996 (1.1905)	Accuracy 40.625 (42.279)
Epoch: [54][60/87]	Loss 1.1593 (1.1876)	Accuracy 40.625 (41.957)
Epoch: [54][70/87]	Loss 1.1873 (1.1792)	Accuracy 43.750 (42.474)
Epoch: [54][80/87]	Loss 1.2148 (1.1748)	Accuracy 43.750 (42.670)
Test: [ 0/25]	Loss 1.1497 (1.1497)	Accuracy 50.000 (50.000)
Test: [10/25]	Loss 1.2288 (1.1436)	Accuracy 34.375 (46.023)
Test: [20/25]	Loss 1.1655 (1.1653)	Accuracy 37.500 (42.560)
Current Accuracy: 42.750
The best accuracy: 44.000
An epoch time: 61.6s
********************55********************
Current learning rate:  0.001
Epoch: [55][ 0/87]	Loss 1.2770 (1.2770)	Accuracy 40.625 (40.625)
Epoch: [55][10/87]	Loss 1.1191 (1.1440)	Accuracy 46.875 (44.318)
Epoch: [55][20/87]	Loss 1.2055 (1.1781)	Accuracy 34.375 (40.923)
Epoch: [55][30/87]	Loss 1.0558 (1.1695)	Accuracy 40.625 (42.036)
Epoch: [55][40/87]	Loss 1.3002 (1.1736)	Accuracy 37.500 (41.692)
Epoch: [55][50/87]	Loss 1.1075 (1.1746)	Accuracy 43.750 (42.034)
Epoch: [55][60/87]	Loss 1

Epoch: [63][60/87]	Loss 1.0777 (1.1574)	Accuracy 53.125 (46.107)
Epoch: [63][70/87]	Loss 0.9169 (1.1656)	Accuracy 53.125 (45.379)
Epoch: [63][80/87]	Loss 1.1752 (1.1651)	Accuracy 40.625 (45.332)
Test: [ 0/25]	Loss 1.0749 (1.0749)	Accuracy 50.000 (50.000)
Test: [10/25]	Loss 1.2045 (1.1015)	Accuracy 37.500 (48.011)
Test: [20/25]	Loss 1.2139 (1.1169)	Accuracy 37.500 (45.685)
Current Accuracy: 45.750
The best accuracy: 45.750
An epoch time: 61.9s
********************64********************
Current learning rate:  0.001
Epoch: [64][ 0/87]	Loss 1.0784 (1.0784)	Accuracy 56.250 (56.250)
Epoch: [64][10/87]	Loss 1.1294 (1.1707)	Accuracy 43.750 (44.034)
Epoch: [64][20/87]	Loss 1.0481 (1.1734)	Accuracy 50.000 (43.899)
Epoch: [64][30/87]	Loss 1.1964 (1.1819)	Accuracy 34.375 (42.742)
Epoch: [64][40/87]	Loss 1.1066 (1.1734)	Accuracy 40.625 (43.369)
Epoch: [64][50/87]	Loss 0.7657 (1.1605)	Accuracy 71.875 (44.240)
Epoch: [64][60/87]	Loss 1.3196 (1.1620)	Accuracy 40.625 (44.826)
Epoch: [64][70/87]	Loss 1

Epoch: [72][70/87]	Loss 0.9870 (1.1405)	Accuracy 53.125 (47.095)
Epoch: [72][80/87]	Loss 1.0927 (1.1374)	Accuracy 40.625 (46.528)
Test: [ 0/25]	Loss 1.0171 (1.0171)	Accuracy 53.125 (53.125)
Test: [10/25]	Loss 1.2060 (1.0479)	Accuracy 34.375 (50.568)
Test: [20/25]	Loss 1.0549 (1.0642)	Accuracy 43.750 (48.065)
Current Accuracy: 47.625
The best accuracy: 48.000
An epoch time: 63.2s
********************73********************
Current learning rate:  0.001
Epoch: [73][ 0/87]	Loss 1.3373 (1.3373)	Accuracy 28.125 (28.125)
Epoch: [73][10/87]	Loss 1.1199 (1.1881)	Accuracy 40.625 (40.341)
Epoch: [73][20/87]	Loss 1.0650 (1.1683)	Accuracy 46.875 (41.369)
Epoch: [73][30/87]	Loss 1.1740 (1.1586)	Accuracy 50.000 (43.548)
Epoch: [73][40/87]	Loss 1.3272 (1.1676)	Accuracy 46.875 (43.750)
Epoch: [73][50/87]	Loss 1.2322 (1.1668)	Accuracy 37.500 (44.240)
Epoch: [73][60/87]	Loss 1.0885 (1.1475)	Accuracy 50.000 (45.645)
Epoch: [73][70/87]	Loss 1.0996 (1.1442)	Accuracy 53.125 (45.819)
Epoch: [73][80/87]	Loss 1

Epoch: [81][80/87]	Loss 1.1867 (1.0586)	Accuracy 40.625 (49.614)
Test: [ 0/25]	Loss 0.9908 (0.9908)	Accuracy 50.000 (50.000)
Test: [10/25]	Loss 1.1902 (1.0029)	Accuracy 50.000 (52.557)
Test: [20/25]	Loss 0.9309 (0.9859)	Accuracy 50.000 (51.339)
Current Accuracy: 51.375
The best accuracy: 51.375
An epoch time: 62.3s
********************82********************
Current learning rate:  0.0001
Epoch: [82][ 0/87]	Loss 0.8698 (0.8698)	Accuracy 65.625 (65.625)
Epoch: [82][10/87]	Loss 0.9542 (1.0997)	Accuracy 50.000 (46.591)
Epoch: [82][20/87]	Loss 0.8792 (1.0574)	Accuracy 56.250 (50.298)
Epoch: [82][30/87]	Loss 1.1767 (1.0667)	Accuracy 43.750 (50.302)
Epoch: [82][40/87]	Loss 1.2896 (1.0667)	Accuracy 40.625 (50.152)
Epoch: [82][50/87]	Loss 1.0260 (1.0667)	Accuracy 56.250 (50.368)
Epoch: [82][60/87]	Loss 1.1022 (1.0686)	Accuracy 46.875 (50.051)
Epoch: [82][70/87]	Loss 1.1176 (1.0768)	Accuracy 50.000 (49.736)
Epoch: [82][80/87]	Loss 1.0465 (1.0858)	Accuracy 53.125 (49.498)
Test: [ 0/25]	Loss 0.992

Epoch: [90][80/87]	Loss 1.0359 (1.0699)	Accuracy 53.125 (50.231)
Test: [ 0/25]	Loss 0.9747 (0.9747)	Accuracy 46.875 (46.875)
Test: [10/25]	Loss 1.1620 (0.9888)	Accuracy 43.750 (52.557)
Test: [20/25]	Loss 0.9017 (0.9747)	Accuracy 53.125 (51.637)
Current Accuracy: 51.875
The best accuracy: 52.875
An epoch time: 61.7s
********************91********************
Current learning rate:  0.0001
Epoch: [91][ 0/87]	Loss 0.8890 (0.8890)	Accuracy 56.250 (56.250)
Epoch: [91][10/87]	Loss 1.2047 (1.0527)	Accuracy 40.625 (48.580)
Epoch: [91][20/87]	Loss 1.1371 (1.0763)	Accuracy 37.500 (48.065)
Epoch: [91][30/87]	Loss 0.9760 (1.0647)	Accuracy 59.375 (49.899)
Epoch: [91][40/87]	Loss 1.1228 (1.0701)	Accuracy 59.375 (49.771)
Epoch: [91][50/87]	Loss 1.0134 (1.0759)	Accuracy 56.250 (49.632)
Epoch: [91][60/87]	Loss 1.1776 (1.0740)	Accuracy 43.750 (50.102)
Epoch: [91][70/87]	Loss 1.0818 (1.0703)	Accuracy 53.125 (50.176)
Epoch: [91][80/87]	Loss 1.0093 (1.0711)	Accuracy 53.125 (50.077)
Test: [ 0/25]	Loss 0.963

Epoch: [99][80/87]	Loss 1.0152 (1.0692)	Accuracy 56.250 (49.537)
Test: [ 0/25]	Loss 0.9708 (0.9708)	Accuracy 50.000 (50.000)
Test: [10/25]	Loss 1.1440 (0.9806)	Accuracy 50.000 (53.409)
Test: [20/25]	Loss 0.8870 (0.9639)	Accuracy 53.125 (52.530)
Current Accuracy: 52.500
The best accuracy: 53.250
An epoch time: 61.8s


In [16]:
# Echo the project root used by this run (log/ and checkpoint/ files live under it; it is also passed to the data loaders as project_dir).
project_path

'/media/di/data/lee/nia/'