In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from ILIFModel import *
from SCNN_Layer import *
import os
import time

In [8]:
class nmnistDataset(Dataset):
    """Dataset over pre-processed N-MNIST samples stored as one file per sample.

    Every entry of ``<dataPath>/<mode>`` is treated as a sample. A sample is read
    with ``torch.load`` and its label is the single digit located four characters
    before the end of the file name (e.g. ``'12_7.pt'`` -> 7).
    """

    def __init__(self, dataPath, mode):
        """Index the sample files of one split.

        Args:
            dataPath: Root directory that contains one sub-directory per split,
                e.g. '../../../DATA/N-MNIST/processed/NMNISTsmall'. A trailing
                path separator is optional.
            mode: Name of the split sub-directory (this notebook uses 'train'
                and 'test').
        """
        # Join with os.path.join rather than '+', so a dataPath without a
        # trailing '/' no longer produces a wrong directory name.
        self.dataPath = os.path.join(dataPath, mode)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.samples = sorted(os.listdir(self.dataPath))
        self.mode = mode

    def __getitem__(self, index):
        """Return ``(data, label)`` for the sample at ``index``.

        ``label`` is a 0-dim int64 tensor parsed from the file name; ``data`` is
        whatever object was stored in the file.
        """
        sample = self.samples[index]
        # The class digit is the character right before a 3-character suffix
        # such as '.pt'; int() raises ValueError if that character is not a digit.
        label = torch.tensor(int(sample[-4]), dtype=torch.int64)

        # NOTE(review): torch.load unpickles its input - only use trusted files.
        data = torch.load(os.path.join(self.dataPath, sample))

        return data, label

    def __len__(self):
        """Number of sample files found in the split directory."""
        return len(self.samples)

In [9]:
# Root directory of the pre-processed N-MNIST subset; the split name
# ('train' / 'test') is appended to it by nmnistDataset.
dataPath = '../../../DATA/N-MNIST/processed/NMNISTsmall/'

In [10]:
# Training split, served in mini-batches of 8 from the main process.
trainingSet = nmnistDataset(dataPath=dataPath, mode='train')
# shuffle=True: draw a fresh random sample order every epoch. With
# shuffle=False every epoch replays identical batches in identical order,
# which biases the gradient updates.
trainLoader = DataLoader(dataset=trainingSet, batch_size=8, shuffle=True, num_workers=0)

In [11]:
# Evaluation split; order is kept fixed (no shuffling) and batches hold 8 samples.
testSet = nmnistDataset(mode='test', dataPath=dataPath)
testLoader = DataLoader(
    testSet,
    batch_size=8,
    num_workers=0,
    shuffle=False,
)

In [12]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [13]:
# Display which device was selected.
device

device(type='cuda')

In [14]:
# ---- run configuration ----
names = 'STBPmodelN-MNIST'  # base name of the checkpoint file
num_classes = 10            # width of the one-hot target vectors
batch_size = 8
num_epochs = 100            # upper bound on the number of epochs
start_epoch = 0             # 0 for a fresh run, or the epoch of the last checkpoint
acc_record = []             # test accuracy, one entry appended per epoch

In [15]:
# Initial learning rate handed to the optimizer.
learning_rate = 1e-3

In [16]:
# Loss: mean squared error between the network output and the target.
criterion = nn.MSELoss()

# Network (SCNN comes from the star-imported project modules), moved to the chosen device.
snn = SCNN()
snn.to(device)

# Adam over all network parameters, starting at the configured learning rate.
optimizer = torch.optim.Adam(params=snn.parameters(), lr=learning_rate)

In [17]:
# Decay learning_rate
def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50, decay_factor=0.1):
    """Scale the learning rate by ``decay_factor`` every ``lr_decay_epoch`` epochs.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts that each
            hold an ``'lr'`` entry. It is modified in place.
        epoch: Current epoch number. The decay is applied when it is a positive
            multiple of ``lr_decay_epoch``; epoch 0 never decays.
        init_lr: Unused. Kept so existing positional callers keep working.
        lr_decay_epoch: Number of epochs between two decays.
        decay_factor: Multiplier applied to every group's learning rate
            (defaults to 0.1).

    Returns:
        The same ``optimizer`` object that was passed in.
    """
    # `epoch > 0` rather than `epoch > 1`: with lr_decay_epoch == 1 the decay
    # must also fire at epoch 1, while epoch 0 is still excluded.
    if epoch % lr_decay_epoch == 0 and epoch > 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * decay_factor
    return optimizer

In [18]:
# Main loop: per epoch, one pass over the training set followed by a full
# evaluation on the test set. Starts at start_epoch so a resumed run does not
# repeat epochs that were already trained.
for epoch in range(start_epoch, num_epochs):
    # ---------------- training ----------------
    snn.train()  # make sure train-only layers (if any) are active
    train_loss = 0.0
    start_time = time.time()
    for i, (train_images, train_labels) in enumerate(trainLoader):
        # The optimizer was built from snn.parameters(), so this clears every
        # gradient of the network.
        optimizer.zero_grad()

        train_images = train_images.float().to(device)
        train_predicts = snn(train_images)
        # One-hot targets for the MSE loss: along dim 1, write 1.0 at the
        # column given by each label. Built on the predictions' device so the
        # predictions do not have to be copied back to the host every step.
        train_onehot = torch.zeros(
            train_labels.shape[0], num_classes, device=train_predicts.device
        ).scatter_(1, train_labels.to(train_predicts.device).view(-1, 1), 1.)
        loss = criterion(train_predicts, train_onehot)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

        if (i + 1) % 25 == 0:  # report every 25 batches
            # train_loss is the SUM of the batch losses since the last report.
            print('Epoch [%d/%d], Train Step [%d/%d], Train Loss = %.5f'
                  % (epoch + 1, num_epochs, i + 1, len(trainLoader), train_loss))
            train_loss = 0.0
            print('Time elasped:', time.time() - start_time)

    # lr_scheduler (defined earlier in the notebook) lowers the learning rate
    # when the 0-based epoch index hits a positive multiple of 40.
    optimizer = lr_scheduler(optimizer, epoch, learning_rate, 40)

    # ---------------- evaluation ----------------
    snn.eval()  # switch off train-only behaviour (if any) while measuring accuracy
    correct = 0
    total = 0
    with torch.no_grad():
        for j, (test_images, test_targets) in enumerate(testLoader):
            test_images = test_images.float().to(device)
            test_predicts = snn(test_images)

            # Predicted class = index of the largest output; compared on the
            # CPU, where the targets already live.
            _, predicted = test_predicts.cpu().max(1)
            total += test_targets.size(0)
            correct += predicted.eq(test_targets).sum().item()
            if (j + 1) % 25 == 0:
                acc = 100. * correct / total
                print('Test Step [%d/%d], Acc: %.5f' % (j + 1, len(testLoader), acc))

    # Guard against an empty test set instead of dividing by zero.
    acc = 100. * correct / total if total else 0.0
    print('Epoch [%d/%d] \t Test Accuracy Over Test Dataset: %.3f' % (epoch + 1, num_epochs, acc))
    acc_record.append(acc)

    # Checkpoint every 5th epoch, and always after the last one so the final
    # weights are not lost (num_epochs - 1 is generally not a multiple of 5).
    if epoch % 5 == 0 or epoch == num_epochs - 1:
        print('Saving......')
        state = {
            'net': snn.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'acc_record': acc_record,
        }
        os.makedirs('./history', exist_ok=True)
        torch.save(state, './history/' + names + '.pk')

Epoch [1/100], Train Step [25/125], Train Loss = 2.50000
Time elasped: 22.147284984588623


KeyboardInterrupt: 