In [1]:
#coding=utf8

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os, time, argparse

from utils import *
from models import *

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
class Backprop:

    
    def __init__(self, args):
        if args.model == "MLP":
            self.model  = MLP(args)
        if args.model == "signMLP":
            self.model  = signMLP(args)
        if args.model == "CNN":
            self.model  = CNN(args)
        if args.model == "VGG":
            self.model  = VGG(args)
        if args.model == 'KAN':
            self.model = MNISTChebyKAN(degree = 20)
        
        self.model.to(device)
        self.batch_size = args.batchsize
        self.lambda_0   = args.lambda_
        self.sigma      = args.sigma_
        self.last_linear = "output_layer"

        
        self.opt = optim.AdamW(self.model.parameters(), lr=0.001)
#         self.opt = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.01)
        self.iter_loss1, self.iter_loss2, self.iter_loss3 = [], [], []
        self.track_loss1, self.track_loss2, self.track_loss3 = [], [], []
        
        self.loss = args.loss
        if self.loss == "mse": self.output_criterion = nn.MSELoss()#y_pred, labels_float)
        elif self.loss == "CE": self.output_criterion = nn.CrossEntropyLoss()#y_pred, label)
        
    def step(self, input_data, labels):
        self.opt.zero_grad()

        labels_float = F.one_hot(labels, num_classes=10).float()
        
        y_pred, hidden_zs = self.model(input_data)

        if self.loss == "mse": 
            l = self.output_criterion(y_pred, labels_float)
        elif self.loss == "CE": 
            l = self.output_criterion(y_pred, labels)

        l.backward()
        self.opt.step()
        return(l)


In [10]:
if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="mnist")
    parser.add_argument('--model', type=str, default="KAN")
    parser.add_argument('--loss', type=str, default="CE")
    parser.add_argument('--BP', type=int, default=0)
    parser.add_argument('--HSIC', type=str, default="nHSIC")
    parser.add_argument('--kernel_x', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_h', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_y', type=str, default="student", choices=["rbf", "student"])
    parser.add_argument('--sigma_', type=int, default=10)
    parser.add_argument('--lambda_', type=int, default=100)
    parser.add_argument('--batchsize', type=int, default=128)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--bn_affine', type=int, default=1)
    parser.add_argument('--forward', type=str, default="n", choices=["x", "h", "n"])
    
    # Testing.
    parser.add_argument('--Latinb', type=int, default=0, choices=[0, 1])
    parser.add_argument('--Latinb_lambda', type=float, default=1.)
    parser.add_argument('--Latinb_type', type=str, default="f", choices=["f", "n"])
        
    args, _ = parser.parse_known_args()
    filename = 'kan_results_bp_degree20.csv'#get_filename(args)
    print(filename)
    
    torch.manual_seed(1)
    device = "cuda:{}".format(args.device)
    batch_size = args.batchsize
    train_loader, test_loader = load_data(args)
    
    logs = list()
    backprop = Backprop(args)
    
    get_loss = list()
    print("Model trainable parameters: ", sum(p.numel() for p in backprop.model.parameters() if p.requires_grad))

    for epoch in range(50):
        backprop.model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(args.batchsize, -1)
            start = time.time()
            loss = backprop.step(data.view(args.batchsize, -1).to(device), target.to(device))
        if epoch % 2 == 0:
            print("Input shape: ", data.shape)
            print("Target shape: ", target.shape)
            show_result(backprop, train_loader, test_loader, epoch, logs, device)
            logs[epoch//2].append(time.time()-start)
            print("{:.2f}".format(time.time()-start))
            start = time.time()

    txt_path = os.path.join(".\\", filename+".csv")
    df = pd.DataFrame(logs)
    df.to_csv(txt_path,index=False)

kan_results_bp_degree20.csv
Model trainable parameters:  541056
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 0. 	 Training  ACC: 0.9117. 	 Testing ACC: 0.9071
16.13
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 2. 	 Training  ACC: 0.9584. 	 Testing ACC: 0.9440
17.35
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 4. 	 Training  ACC: 0.9753. 	 Testing ACC: 0.9534
14.75
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 6. 	 Training  ACC: 0.9749. 	 Testing ACC: 0.9505
14.68
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 8. 	 Training  ACC: 0.9772. 	 Testing ACC: 0.9542
14.69
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 10. 	 Training  ACC: 0.9762. 	 Testing ACC: 0.9477
14.26
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 12. 	 Training  ACC: 0.9882. 	 Testing ACC: 0.9566
15.82
Input shap

In [16]:
if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="mnist")
    parser.add_argument('--model', type=str, default="MLP")
    parser.add_argument('--loss', type=str, default="CE")
    parser.add_argument('--BP', type=int, default=0)
    parser.add_argument('--HSIC', type=str, default="nHSIC")
    parser.add_argument('--kernel_x', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_h', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_y', type=str, default="student", choices=["rbf", "student"])
    parser.add_argument('--sigma_', type=int, default=10)
    parser.add_argument('--lambda_', type=int, default=100)
    parser.add_argument('--batchsize', type=int, default=128)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--bn_affine', type=int, default=1)
    parser.add_argument('--forward', type=str, default="n", choices=["x", "h", "n"])
    
    # Testing.
    parser.add_argument('--Latinb', type=int, default=0, choices=[0, 1])
    parser.add_argument('--Latinb_lambda', type=float, default=1.)
    parser.add_argument('--Latinb_type', type=str, default="f", choices=["f", "n"])
        
    args, _ = parser.parse_known_args()
    filename = 'mlp_results_bp.csv'#get_filename(args)
    print(filename)
    
    torch.manual_seed(1)
    device = "cuda:{}".format(args.device)
    batch_size = args.batchsize
    train_loader, test_loader = load_data(args)
    
    logs = list()
    backprop = Backprop(args)
    
    get_loss = list()
    print("Model trainable parameters: ", sum(p.numel() for p in backprop.model.parameters() if p.requires_grad))

    for epoch in range(50):
        backprop.model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(args.batchsize, -1)
            start = time.time()
            loss = backprop.step(data.view(args.batchsize, -1).to(device), target.to(device))
        if epoch % 2 == 0:
            print("Input shape: ", data.shape)
            print("Target shape: ", target.shape)
            show_result(backprop, train_loader, test_loader, epoch, logs, device)
            logs[epoch//2].append(time.time()-start)
            print("{:.2f}".format(time.time()-start))
            start = time.time()

    txt_path = os.path.join(".\\", filename+".csv")
    df = pd.DataFrame(logs)
    df.to_csv(txt_path,index=False)

mlp_results_bp.csv
Model trainable parameters:  252682
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 0. 	 Training  ACC: 0.9756. 	 Testing ACC: 0.9692
14.27
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 2. 	 Training  ACC: 0.9916. 	 Testing ACC: 0.9796
13.85
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 4. 	 Training  ACC: 0.9939. 	 Testing ACC: 0.9805
14.16
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 6. 	 Training  ACC: 0.9966. 	 Testing ACC: 0.9816
19.02
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 8. 	 Training  ACC: 0.9965. 	 Testing ACC: 0.9817
14.73
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 10. 	 Training  ACC: 0.9971. 	 Testing ACC: 0.9827
14.07
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 12. 	 Training  ACC: 0.9971. 	 Testing ACC: 0.9804
13.81
Input shape:  torch

In [17]:
sum(p.numel() for p in hsic.model.parameters() if p.requires_grad)

252682