# In [None]:
# import cv2 as cv
import torch
import sys
sys.path.append('/home/zlzhu/snn/spikingjelly')
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
from spikingjelly.clock_driven import neuron, encoding, functional,layer
import os
import time
import argparse
import matplotlib.pyplot as plt
from nmnist import nmnist
from LIS_model import LIS_model
%matplotlib inline
def str2bool(value):
    """Convert a command-line string to a bool for argparse.

    ``type=bool`` cannot be used for this: ``bool("False")`` is True because
    any non-empty string is truthy. This accepts explicit spellings instead.

    Args:
        value: the raw argument string (or an already-parsed bool).

    Returns:
        True for 1/true/t/yes/y, False for 0/false/f/no/n (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: for any other spelling, so argparse
            reports a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser(description='train.py')

parser.add_argument('-gpu', type=int, default=0)          # CUDA device index
parser.add_argument('-seed', type=int, default=3154)      # torch / CUDA RNG seed
parser.add_argument('-epoch', type=int, default=1)        # number of training epochs
parser.add_argument('-batch_size', type=int, default=100)
parser.add_argument('-learning_rate', type=float, default=1e-3)
parser.add_argument('-dts', type=str, default='MNIST')    # 'MNIST', 'Fashion-MNIST' or 'NMNIST'
parser.add_argument('-model', type=str, default='LISNN')
# str2bool rather than bool: with type=bool, "-if_lateral False" would yield True.
parser.add_argument('-if_lateral', type=str2bool, default=False)

# parse_known_args ignores unrecognised arguments (e.g. extra flags passed by
# a notebook kernel) instead of exiting.
opt = parser.parse_known_args()[0]

# Select the GPU and seed torch's CPU and CUDA RNGs so runs are repeatable.
torch.cuda.set_device(opt.gpu)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)
# Ask cuDNN for deterministic kernels and disable autotuning, which can pick
# different (non-deterministic) algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Accuracy history (percent). test_scores is seeded with 0.00 so that
# max(test_scores) is defined before the first evaluation.
test_scores = [0.00]
train_scores = []
# Path the best-so-far model is written to.
model_pth = '/home/zlzhu//snn/bsgcn/handcode/mnist_snn/best_mnist.model'
# Per-run output directory, e.g. ./LISNN_MNIST_3154
save_path = './' + opt.model + '_' + opt.dts + '_' + str(opt.seed)
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(save_path, exist_ok=True)

# Build the train/test datasets selected by -dts.
if opt.dts == 'MNIST':
    train_dataset = dsets.MNIST(root = './data/mnist/', train = True, transform = transforms.ToTensor(), download = True)
    test_dataset = dsets.MNIST(root = './data/mnist/', train = False, transform = transforms.ToTensor())
elif opt.dts == 'Fashion-MNIST':
    train_dataset = dsets.FashionMNIST(root = './data/fashion/', train = True, transform = transforms.ToTensor(), download = True)
    test_dataset = dsets.FashionMNIST(root = './data/fashion/', train = False, transform = transforms.ToTensor())
elif opt.dts == 'NMNIST':
    # Project-local event-based dataset; samplingTime/sampleLength presumably
    # control temporal binning (20 steps) -- confirm against nmnist.py.
    train_dataset = nmnist(datasetPath = 'nmnist/Train/', sampleFile = 'nmnist/Train.txt', samplingTime = 1.0, sampleLength = 20)
    test_dataset = nmnist(datasetPath = 'nmnist/Test/', sampleFile = 'nmnist/Test.txt', samplingTime = 1.0, sampleLength = 20)
else:
    # Without this, an unknown name surfaces later as a NameError on train_dataset.
    raise ValueError("Unknown dataset %r; expected 'MNIST', 'Fashion-MNIST' or 'NMNIST'" % opt.dts)

# Training batches are shuffled each epoch; test order is fixed.
train_loader = torch.utils.data.DataLoader(dataset = train_dataset, batch_size = opt.batch_size, shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset, batch_size = opt.batch_size, shuffle = False)

# Constant kept from the original script; not referenced elsewhere in the
# code visible here (presumably the number of simulation steps -- confirm).
time_window = 20
# Outputs are regressed onto one-hot label vectors in train().
loss_function = nn.MSELoss()

# Build the network from the parsed options and move its parameters to the GPU.
# (To resume instead, the checkpoint at model_pth could be loaded with torch.load.)
model = LIS_model(opt)
model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=opt.learning_rate)
# Multiply the learning rate by 0.01 every 50 scheduler steps
# (train() steps the scheduler once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.01, step_size=50)

def train(epoch):
    """Run one training epoch over ``train_loader``.

    For every batch the spiking network's output is regressed (via
    ``loss_function``) onto a one-hot encoding of the labels. The network's
    neuron state is then reset, because the SNN keeps membrane state between
    calls. The loss summed since the last report is printed about six times
    per epoch. The LR scheduler is stepped once at the end.

    Args:
        epoch: zero-based epoch index; used only in the log line.
    """
    model.train()
    start_time = time.time()
    total_loss = 0
    num_steps = len(train_loader)  # includes a final partial batch, if any
    # Report ~6 times per epoch; max(1, ...) avoids a modulo-by-zero when the
    # dataset holds fewer than six batches.
    log_interval = max(1, len(train_dataset) // (opt.batch_size * 6))
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        images = images.cuda()
        labels = labels.cuda()
        # Size the one-hot by the real batch size, not opt.batch_size, so a
        # smaller final batch does not break scatter or the loss shape.
        # Assumes model.fc[-1] is the number of output classes -- confirm in LIS_model.
        one_hot = torch.zeros(labels.size(0), model.fc[-1], device=labels.device)
        one_hot.scatter_(1, labels.unsqueeze(1), 1)
        outputs, conv_image = model(images)
        if i == 0:
            print("conv_image.shape:", conv_image.shape)
        loss = loss_function(outputs, one_hot)
        total_loss += float(loss)
        loss.backward()
        optimizer.step()
        # reset, because snn has memory
        functional.reset_net(model)

        if (i + 1) % log_interval == 0:
            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f, Time: %.2f' % (epoch + 1, opt.epoch, i + 1, num_steps, total_loss, time.time() - start_time))
            start_time = time.time()
            total_loss = 0
    scheduler.step()
    
def _count_correct(loader):
    """Return ``(correct, total)`` top-1 predictions of ``model`` over ``loader``.

    Runs without autograd (no graph is built, saving memory). Resets the
    SNN state after every batch.
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.cuda()
            labels = labels.cuda()
            outputs, _ = model(images)
            pred = outputs.max(1)[1]
            total += labels.size(0)
            correct += int((pred == labels).sum())
            # reset, because snn has memory
            functional.reset_net(model)
    return correct, total


def eval(epoch, if_test):
    """Measure accuracy on the test set (``if_test``) or the training set.

    On the test set the result is appended to ``test_scores``. If it beats
    every previous score, the whole model is saved to ``model_pth``. On the
    training set the result is appended to ``train_scores``.

    Args:
        epoch: current epoch index; not used, kept for interface compatibility.
        if_test: True to evaluate ``test_loader``, False for ``train_loader``.
    """
    model.eval()
    if if_test:
        correct, total = _count_correct(test_loader)
        # Guard against an empty loader instead of dividing by zero.
        acc = 100.0 * correct / total if total else 0.0
        print('Test correct: %d Accuracy: %.2f%%' % (correct, acc))

        if acc > max(test_scores):
            # NOTE(review): the per-run save_path directory is created but the
            # checkpoint still goes to the fixed model_pth; confirm intent.
            torch.save(model, model_pth)
            print("Model saved!")
        test_scores.append(acc)
    else:
        correct, total = _count_correct(train_loader)
        acc = 100.0 * correct / total if total else 0.0
        print('Train correct: %d Accuracy: %.2f%%' % (correct, acc))
        train_scores.append(acc)

def main(test_every=1, train_eval_every=20):
    """Train for ``opt.epoch`` epochs with periodic evaluation.

    Args:
        test_every: evaluate on the test set every this many epochs
            (default 1, i.e. every epoch, as before).
        train_eval_every: every this many epochs, also measure training-set
            accuracy and print the best test/train accuracy so far
            (default 20, as before).
    """
    for epoch in range(opt.epoch):
        start_time = time.time()
        train(epoch)
        print('Time of one epoch:', time.time() - start_time)
        if (epoch + 1) % test_every == 0:
            eval(epoch, if_test=True)
        # One check instead of two identical ones; train_scores is non-empty
        # here because the train-set eval just appended to it.
        if (epoch + 1) % train_eval_every == 0:
            eval(epoch, if_test=False)
            print('Best Test Accuracy in %d: %.2f%%' % (epoch + 1, max(test_scores)))
            print('Best Train Accuracy in %d: %.2f%%' % (epoch + 1, max(train_scores)))


if __name__ == '__main__':
    main()