In [1]:
#coding=utf8

from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os, time, argparse

from utils import *
from models import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class HSICBottleneck:
    def __init__(self, args):
        if args.model == "MLP":
            self.model  = MLP(args)
        if args.model == "signMLP":
            self.model  = signMLP(args)
        if args.model == "CNN":
            self.model  = CNN(args)
        if args.model == "VGG":
            self.model  = VGG(args)
        if args.model == 'KAN':
            self.model = MNISTChebyKAN2(degree=8)
        
        self.model.to(device)
        self.batch_size = args.batchsize
        self.lambda_0   = args.lambda_
        self.sigma      = args.sigma_
        self.extractor  = 'hsic'
        self.last_linear = "output_layer"
        self.HSIC = compute_HSIC(args.HSIC)
        self.kernel = compute_kernel()
        self.kernel_x = args.kernel_x
        self.kernel_h = args.kernel_h
        self.kernel_y = args.kernel_y
        self.forward = args.forward
        
        self.opt = optim.AdamW(self.model.parameters(), lr=0.001)
#         self.opt = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.01)
        self.iter_loss1, self.iter_loss2, self.iter_loss3 = [], [], []
        self.track_loss1, self.track_loss2, self.track_loss3 = [], [], []
        
        self.loss = args.loss
        if self.loss == "mse": self.output_criterion = nn.MSELoss()
        elif self.loss == "CE": self.output_criterion = nn.CrossEntropyLoss()
        
    def step(self, input_data, labels):
        
        labels_float = F.one_hot(labels, num_classes=10).float()
        if self.forward == "x": Kx  = self.kernel(input_data, self.sigma, self.kernel_x)
        Ky = self.kernel(labels_float, self.sigma, self.kernel_y)
        
        kernel_list = list()
        y_pred, hidden_zs = self.model(input_data)
        #print(y_pred.shape, [h.shape for h in hidden_zs])
        
        loss_LI = 0.
        for num, feature in enumerate(hidden_zs): 
            kernel_list.append(self.kernel(feature, self.sigma, self.kernel_h))
            ## Testing new features.
            if args.Latinb == 1:
                if num == (len(hidden_zs)-1) or feature.size(2) >= 4: continue
                loss_LI += spatial_contrast(feature, args)*args.Latinb_lambda
        
        total_loss1, total_loss2, total_loss3 = 0., 0., 0.
        for num, feature in enumerate(kernel_list):
            if num == (len(hidden_zs)-1): 
                if self.forward == "h": total_loss1 += self.HSIC(feature, kernel_list[num-1], self.batch_size, device)
                elif self.forward == "x": total_loss1 += self.HSIC(feature, Kx, self.batch_size, device)
                if self.loss == "mse": total_loss3 += self.output_criterion(hidden_zs[-1], labels_float)
                elif self.loss == "CE": total_loss3 += self.output_criterion(hidden_zs[-1], labels)
            elif num == 0:
                if self.forward == "x": total_loss1 += self.HSIC(feature, Kx, self.batch_size, device)
                total_loss2 += - self.lambda_0*self.HSIC(feature, Ky, self.batch_size, device)
            else:
                if self.forward == "h": total_loss1 += self.HSIC(feature, kernel_list[num-1], self.batch_size, device)
                elif self.forward == "x": total_loss1 += self.HSIC(feature, Kx, self.batch_size, device)
                total_loss2 += - self.lambda_0*self.HSIC(feature, Ky, self.batch_size, device)
        
        if self.forward == "h" or self.forward == "x": 
            total_loss = total_loss1 + total_loss2 + total_loss3 + loss_LI
            self.iter_loss1.append(total_loss1.item())
        if self.forward == "n": 
            total_loss = total_loss2 + total_loss3 + loss_LI
            self.iter_loss1.append(-1)
        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()
                
        self.iter_loss2.append(total_loss2.item())
        self.iter_loss3.append(total_loss3.item())
        
    def update_loss(self):
        self.track_loss1.append(np.mean(self.iter_loss1))
        self.track_loss2.append(np.mean(self.iter_loss2))
        self.track_loss3.append(np.mean(self.iter_loss3))
        self.iter_loss1, self.iter_loss2, self.iter_loss3 = [], [], []
    
    def tune_output(self, input_data, labels):
        self.model.train()
        if self.loss == "mse":
            one_hot_labels = F.one_hot(labels, num_classes=10)
            labels = F.one_hot(labels, num_classes=10).float()
        
        y_pred, hidden_zs = self.model(input_data)
        total_loss = self.output_criterion(hidden_zs[-1], labels)
        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()

In [5]:
if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="mnist")
    parser.add_argument('--model', type=str, default="KAN")
    parser.add_argument('--loss', type=str, default="CE")
    parser.add_argument('--BP', type=int, default=0)
    parser.add_argument('--HSIC', type=str, default="nHSIC")
    parser.add_argument('--kernel_x', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_h', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_y', type=str, default="student", choices=["rbf", "student"])
    parser.add_argument('--sigma_', type=int, default=10)
    parser.add_argument('--lambda_', type=int, default=100)
    parser.add_argument('--batchsize', type=int, default=128)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--bn_affine', type=int, default=1)
    parser.add_argument('--forward', type=str, default="n", choices=["x", "h", "n"])
    
    # Testing.
    parser.add_argument('--Latinb', type=int, default=0, choices=[0, 1])
    parser.add_argument('--Latinb_lambda', type=float, default=1.)
    parser.add_argument('--Latinb_type', type=str, default="f", choices=["f", "n"])
        
    args, _ = parser.parse_known_args()
    filename = 'kan2_results_degree_eight.csv'#get_filename(args)
    print(filename)
    
    torch.manual_seed(1)
    device = "cuda:{}".format(args.device)
    batch_size = args.batchsize
    train_loader, test_loader = load_data(args)
    
    logs = list()
    hsic = HSICBottleneck(args)
    start = time.time()
    get_loss = list()
    print("Model trainable parameters: ", sum(p.numel() for p in hsic.model.parameters() if p.requires_grad))

    for epoch in range(50):
        hsic.model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(args.batchsize, -1)
            hsic.step(data.view(args.batchsize, -1).to(device), target.to(device))
            hsic.tune_output(data.view(args.batchsize, -1).to(device), target.to(device))
        if epoch % 2 == 0:
            print("Input shape: ", data.shape)
            print("Target shape: ", target.shape)
            show_result(hsic, train_loader, test_loader, epoch, logs, device)
            print("{:.2f}".format(time.time()-start))
            start = time.time()

    txt_path = os.path.join(".\\", filename+".csv")
    df = pd.DataFrame(logs)
    #df.to_csv(txt_path,index=False)

kan2_results_degree_eight.csv
Model trainable parameters:  2249994
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 0. 	 Training  ACC: 0.8089. 	 Testing ACC: 0.8084
27.61
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 2. 	 Training  ACC: 0.8485. 	 Testing ACC: 0.8516
41.08
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 4. 	 Training  ACC: 0.8623. 	 Testing ACC: 0.8664
41.91
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 6. 	 Training  ACC: 0.8687. 	 Testing ACC: 0.8702
40.60
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 8. 	 Training  ACC: 0.8794. 	 Testing ACC: 0.8802
40.17
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 10. 	 Training  ACC: 0.8890. 	 Testing ACC: 0.8920
40.77
Input shape:  torch.Size([128, 784])
Target shape:  torch.Size([128])
EPOCH 12. 	 Training  ACC: 0.8982. 	 Testing ACC: 0.8989
41.39
Input s

In [10]:
txt_path

'.\\kan_results.csv.csv'

In [11]:
if __name__ == "__main__":
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="mnist")
    parser.add_argument('--model', type=str, default="MLP")
    parser.add_argument('--loss', type=str, default="CE")
    parser.add_argument('--BP', type=int, default=0)
    parser.add_argument('--HSIC', type=str, default="nHSIC")
    parser.add_argument('--kernel_x', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_h', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_y', type=str, default="student", choices=["rbf", "student"])
    parser.add_argument('--sigma_', type=int, default=10)
    parser.add_argument('--lambda_', type=int, default=100)
    parser.add_argument('--batchsize', type=int, default=128)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--bn_affine', type=int, default=1)
    parser.add_argument('--forward', type=str, default="n", choices=["x", "h", "n"])
    
    # Testing.
    parser.add_argument('--Latinb', type=int, default=0, choices=[0, 1])
    parser.add_argument('--Latinb_lambda', type=float, default=1.)
    parser.add_argument('--Latinb_type', type=str, default="f", choices=["f", "n"])
        
    args, _ = parser.parse_known_args()
    filename = 'mlp_results'#get_filename(args)
    print(filename)
    
    torch.manual_seed(1)
    device = "cuda:{}".format(args.device)
    batch_size = args.batchsize
    train_loader, test_loader = load_data(args)
    
    logs = list()
    hsic = HSICBottleneck(args)
    start = time.time()
    get_loss = list()
    print("Model trainable parameters: ", sum(p.numel() for p in hsic.model.parameters() if p.requires_grad))
    for epoch in range(50):
        hsic.model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(args.batchsize, -1)
            hsic.step(data.view(args.batchsize, -1).to(device), target.to(device))
            hsic.tune_output(data.view(args.batchsize, -1).to(device), target.to(device))
        if epoch % 2 == 0:
            show_result(hsic, train_loader, test_loader, epoch, logs, device)
            print("{:.2f}".format(time.time()-start))
            start = time.time()

    txt_path = os.path.join(".\\", filename+".csv")
    df = pd.DataFrame(logs)
    
    #df.to_csv(txt_path,index=False)

mlp_results
EPOCH 0. 	 Training  ACC: 0.9042. 	 Testing ACC: 0.9103
25.75
EPOCH 2. 	 Training  ACC: 0.9334. 	 Testing ACC: 0.9342
38.61
EPOCH 4. 	 Training  ACC: 0.9401. 	 Testing ACC: 0.9400
38.58
EPOCH 6. 	 Training  ACC: 0.9440. 	 Testing ACC: 0.9486
38.14
EPOCH 8. 	 Training  ACC: 0.9462. 	 Testing ACC: 0.9470
37.17
EPOCH 10. 	 Training  ACC: 0.9496. 	 Testing ACC: 0.9490
37.64
EPOCH 12. 	 Training  ACC: 0.9502. 	 Testing ACC: 0.9494
38.00
EPOCH 14. 	 Training  ACC: 0.9521. 	 Testing ACC: 0.9523
37.77
EPOCH 16. 	 Training  ACC: 0.9523. 	 Testing ACC: 0.9500
37.45
EPOCH 18. 	 Training  ACC: 0.9535. 	 Testing ACC: 0.9526
37.83
EPOCH 20. 	 Training  ACC: 0.9531. 	 Testing ACC: 0.9516
37.59
EPOCH 22. 	 Training  ACC: 0.9538. 	 Testing ACC: 0.9497
37.41
EPOCH 24. 	 Training  ACC: 0.9566. 	 Testing ACC: 0.9544
37.59
EPOCH 26. 	 Training  ACC: 0.9567. 	 Testing ACC: 0.9541
38.13
EPOCH 28. 	 Training  ACC: 0.9569. 	 Testing ACC: 0.9552
37.52
EPOCH 30. 	 Training  ACC: 0.9582. 	 Testing ACC

In [17]:
sum(p.numel() for p in hsic.model.parameters() if p.requires_grad)

252682