In [1]:
#coding=utf8
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import torch
from torch import nn, optim
import numpy as np
import sys
import matplotlib.pyplot as plt
from torchsummary import summary
import matplotlib.pyplot as plt

import torch.nn.functional as F
import time
from utils import *
import argparse

In [4]:
class HSICBottleneck:
    """Trainer that couples an ``MLP`` with an HSIC-based layer-wise objective.

    The model, the HSIC estimator and the kernel function are built from
    ``MLP``, ``compute_HSIC`` and ``compute_kernel``; those come from outside
    this class (presumably ``utils`` -- confirm) and are treated as opaque here.

    Per-iteration loss values are collected in ``iter_loss1/2/3`` by ``step``
    and folded into per-epoch means in ``track_loss1/2/3`` by ``update_loss``.
    """

    def __init__(self, args):
        """Build the model, optimizer and loss from an argparse-style namespace.

        Args:
            args: namespace providing ``loss`` ("mse" or "CE"), ``forward``
                ("x", "f" or "n"), ``batchsize``, ``lambda_``, ``sigma_``,
                ``HSIC``, ``kernel_x``, ``kernel_h`` and ``kernel_y``. An
                optional ``num_classes`` attribute overrides the default of 10.

        Raises:
            ValueError: if ``args.loss`` or ``args.forward`` is not one of the
                supported values.
        """
        # Validate first so a typo fails here instead of surfacing later as a
        # missing ``output_criterion`` / ``total_loss`` in the middle of training.
        if args.loss == "mse":
            output_criterion = nn.MSELoss()
        elif args.loss == "CE":
            output_criterion = nn.CrossEntropyLoss()
        else:
            raise ValueError(
                "unsupported loss {!r}; expected 'mse' or 'CE'".format(args.loss))
        if args.forward not in ("x", "f", "n"):
            raise ValueError(
                "unsupported forward mode {!r}; expected 'x', 'f' or 'n'".format(args.forward))

        self.model      = MLP(args)
        # Capture the notebook-level ``device`` once so later calls do not
        # depend on whatever that global happens to be at call time.
        self.device     = device
        self.model.to(self.device)
        self.batch_size = args.batchsize
        self.lambda_0   = args.lambda_
        self.sigma      = args.sigma_
        self.extractor  = 'hsic'            # set but not read inside this class
        self.last_linear = "output_layer"   # set but not read inside this class
        self.HSIC = compute_HSIC(args.HSIC)
        self.kernel = compute_kernel()
        self.kernel_x = args.kernel_x
        self.kernel_h = args.kernel_h
        self.kernel_y = args.kernel_y
        self.forward = args.forward
        # Width of the one-hot label encoding; 10 was previously hard-coded.
        self.num_classes = getattr(args, "num_classes", 10)

        self.opt = optim.AdamW(self.model.parameters(), lr=0.001)
        self.iter_loss1, self.iter_loss2, self.iter_loss3 = [], [], []
        self.track_loss1, self.track_loss2, self.track_loss3 = [], [], []

        self.loss = args.loss
        self.output_criterion = output_criterion

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a Python float, whether it is a tensor or a plain number."""
        return value.item() if hasattr(value, "item") else float(value)

    def step(self, input_data, labels):
        """Run one optimisation step on the combined objective.

        The objective is the sum of three parts:

        * loss1 -- HSIC between each hidden kernel and either the input kernel
          (``forward == "x"``) or the previous layer's kernel
          (``forward == "f"``); not computed when ``forward == "n"``.
        * loss2 -- ``-lambda_0 * HSIC(hidden kernel, label kernel)`` for every
          hidden output except the last.
        * loss3 -- the output criterion applied to the last hidden output.

        Args:
            input_data: batch of inputs accepted by ``self.model``.
            labels: integer class labels for the batch.
        """
        labels_float = F.one_hot(labels, num_classes=self.num_classes).float()
        # The input kernel is only needed in "x" mode.
        Kx = self.kernel(input_data, self.sigma, self.kernel_x) if self.forward == "x" else None
        Ky = self.kernel(labels_float, self.sigma, self.kernel_y)

        _, hidden_zs = self.model(input_data)
        kernel_list = [self.kernel(feature, self.sigma, self.kernel_h) for feature in hidden_zs]

        # Use the real number of rows in this batch rather than the configured
        # batch size; the two are equal for every full batch.
        n_samples = input_data.shape[0]
        last = len(kernel_list) - 1

        total_loss1, total_loss2, total_loss3 = 0., 0., 0.
        for num, feature in enumerate(kernel_list):
            if self.forward == "x":
                total_loss1 += self.HSIC(feature, Kx, n_samples, self.device)
            elif self.forward == "f" and (num > 0 or num == last):
                # The first layer has no predecessor and is skipped, unless it
                # is also the last layer, in which case index -1 wraps around
                # to the layer itself (behaviour kept from the original code).
                total_loss1 += self.HSIC(feature, kernel_list[num - 1], n_samples, self.device)

            if num == last:
                target = labels_float if self.loss == "mse" else labels
                total_loss3 += self.output_criterion(hidden_zs[-1], target)
            else:
                total_loss2 += - self.lambda_0 * self.HSIC(feature, Ky, n_samples, self.device)

        if self.forward == "n":
            total_loss = total_loss2 + total_loss3
            logged_loss1 = -1   # sentinel: loss1 is not part of the objective
        else:
            total_loss = total_loss1 + total_loss2 + total_loss3
            logged_loss1 = self._as_float(total_loss1)

        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()

        # Log only after the update succeeded so the three lists stay aligned.
        # The accumulators may still be plain floats (e.g. loss2 when the model
        # returns a single hidden output), hence ``_as_float``.
        self.iter_loss1.append(logged_loss1)
        self.iter_loss2.append(self._as_float(total_loss2))
        self.iter_loss3.append(self._as_float(total_loss3))

    def update_loss(self):
        """Append the mean of each per-iteration loss list to its history and reset the lists."""
        self.track_loss1.append(np.mean(self.iter_loss1))
        self.track_loss2.append(np.mean(self.iter_loss2))
        self.track_loss3.append(np.mean(self.iter_loss3))
        self.iter_loss1, self.iter_loss2, self.iter_loss3 = [], [], []

    def tune_output(self, input_data, labels):
        """Run one optimisation step on the output criterion alone.

        NOTE(review): the optimizer covers all model parameters, so this
        updates the whole network, not just an output layer -- confirm that
        this is intended given the method name.

        Args:
            input_data: batch of inputs accepted by ``self.model``.
            labels: integer class labels for the batch.
        """
        self.model.train()
        if self.loss == "mse":
            # MSE needs a float one-hot target; cross-entropy takes class indices.
            labels = F.one_hot(labels, num_classes=self.num_classes).float()

        _, hidden_zs = self.model(input_data)
        total_loss = self.output_criterion(hidden_zs[-1], labels)
        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()
    
def _accuracy(model, loader, target_device):
    """Return the fraction of samples in ``loader`` whose argmax prediction equals the target."""
    hits, seen = 0, 0
    for data, target in loader:
        # Flatten using the batch's own size so a short final batch is
        # reshaped correctly instead of being forced into the global batch size.
        flat = data.view(data.size(0), -1).to(target_device)
        # The model returns a tuple; element 0 holds the class scores.
        scores = model(flat)[0].cpu()
        pred = scores.argmax(dim=1)
        hits += (pred == target).sum().item()
        seen += len(pred)
    # An empty loader yields NaN rather than a ZeroDivisionError.
    return hits / seen if seen else float("nan")


def show_result():
    """Print training and testing accuracy of ``hsic.model`` for the current ``epoch``.

    Reads the notebook-level names ``hsic``, ``train_loader``, ``test_loader``,
    ``device`` and ``epoch``. Switches the model to eval mode and leaves it
    there; the caller is responsible for switching back to train mode.
    """
    hsic.model.eval()
    with torch.no_grad():
        train_acc = _accuracy(hsic.model, train_loader, device)
        test_acc = _accuracy(hsic.model, test_loader, device)
    print("EPOCH {}. \t Training  ACC: {:.4f}. \t Testing ACC: {:.4f}".format(epoch, train_acc, test_acc))

In [6]:
if __name__ == "__main__":
    # Training run with defaults: student kernel on hidden features,
    # sigma_=1, lambda_=100, forward="n", cross-entropy output loss.
    NUM_EPOCHS = 100     # total passes over the training set
    REPORT_EVERY = 10    # print accuracy / elapsed time every this many epochs

    parser = argparse.ArgumentParser()
    parser.add_argument('--loss', type=str, default="CE")
    parser.add_argument('--HSIC', type=str, default="nHSIC")
    parser.add_argument('--kernel_x', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_h', type=str, default="student", choices=["rbf", "student"])
    parser.add_argument('--kernel_y', type=str, default="rbf", choices=["rbf", "student"])
    # float so fractional bandwidths / weights can be passed on the command
    # line; the non-string defaults are used as-is and therefore stay ints.
    parser.add_argument('--sigma_', type=float, default=1)
    parser.add_argument('--lambda_', type=float, default=100)
    parser.add_argument('--batchsize', type=int, default=256)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--bn_affine', type=int, default=0)
    parser.add_argument('--forward', type=str, default="n", choices=["x", "f", "n"])
    # parse_known_args: ignore any extra arguments present in sys.argv.
    args, _ = parser.parse_known_args()

    torch.manual_seed(1)
    # Fall back to the CPU when no CUDA device is available.
    device = "cuda:{}".format(args.device) if torch.cuda.is_available() else "cpu"
    batch_size = args.batchsize
    train_loader, test_loader = load_data(batch_size=args.batchsize)

    hsic = HSICBottleneck(args)
    start = time.time()
    for epoch in range(NUM_EPOCHS):
        hsic.model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Flatten and transfer once per batch; both updates reuse the result.
            inputs = data.view(data.size(0), -1).to(device)
            targets = target.to(device)
            hsic.step(inputs, targets)
            hsic.tune_output(inputs, targets)
        # Fold this epoch's per-iteration losses into the history so the
        # per-iteration lists do not keep growing for the whole run.
        hsic.update_loss()
        if epoch % REPORT_EVERY == 0:
            show_result()
            # Seconds elapsed since the previous report.
            print("{:.2f}".format(time.time()-start))
            start = time.time()

EPOCH 0
Training  ACC: 0.9531 	 Testing ACC: 0.9517
9.29
EPOCH 10
Training  ACC: 0.9797 	 Testing ACC: 0.9728
75.58
EPOCH 20
Training  ACC: 0.9817 	 Testing ACC: 0.9740
78.17
EPOCH 30
Training  ACC: 0.9834 	 Testing ACC: 0.9765
77.30
EPOCH 40
Training  ACC: 0.9842 	 Testing ACC: 0.9754
77.20
EPOCH 50
Training  ACC: 0.9846 	 Testing ACC: 0.9741
78.65
EPOCH 60
Training  ACC: 0.9849 	 Testing ACC: 0.9735
77.26
EPOCH 70
Training  ACC: 0.9847 	 Testing ACC: 0.9752
78.16
EPOCH 80
Training  ACC: 0.9851 	 Testing ACC: 0.9732
70.80
EPOCH 90
Training  ACC: 0.9858 	 Testing ACC: 0.9766
70.49


In [8]:
if __name__ == "__main__":
    # Training run with defaults: rbf kernel on hidden features,
    # sigma_=10, lambda_=100, forward="n", cross-entropy output loss.
    NUM_EPOCHS = 100     # total passes over the training set
    REPORT_EVERY = 10    # print accuracy / elapsed time every this many epochs

    parser = argparse.ArgumentParser()
    parser.add_argument('--loss', type=str, default="CE")
    parser.add_argument('--HSIC', type=str, default="nHSIC")
    parser.add_argument('--kernel_x', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_h', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_y', type=str, default="rbf", choices=["rbf", "student"])
    # float so fractional bandwidths / weights can be passed on the command
    # line; the non-string defaults are used as-is and therefore stay ints.
    parser.add_argument('--sigma_', type=float, default=10)
    parser.add_argument('--lambda_', type=float, default=100)
    parser.add_argument('--batchsize', type=int, default=256)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--bn_affine', type=int, default=0)
    parser.add_argument('--forward', type=str, default="n", choices=["x", "f", "n"])
    # parse_known_args: ignore any extra arguments present in sys.argv.
    args, _ = parser.parse_known_args()

    torch.manual_seed(1)
    # Fall back to the CPU when no CUDA device is available.
    device = "cuda:{}".format(args.device) if torch.cuda.is_available() else "cpu"
    batch_size = args.batchsize
    train_loader, test_loader = load_data(batch_size=args.batchsize)

    hsic = HSICBottleneck(args)
    start = time.time()
    for epoch in range(NUM_EPOCHS):
        hsic.model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Flatten and transfer once per batch; both updates reuse the result.
            inputs = data.view(data.size(0), -1).to(device)
            targets = target.to(device)
            hsic.step(inputs, targets)
            hsic.tune_output(inputs, targets)
        # Fold this epoch's per-iteration losses into the history so the
        # per-iteration lists do not keep growing for the whole run.
        hsic.update_loss()
        if epoch % REPORT_EVERY == 0:
            print("EPOCH %d" % epoch)
            show_result()
            # Seconds elapsed since the previous report.
            print("{:.2f}".format(time.time()-start))
            start = time.time()

EPOCH 0
Training  ACC: 0.9644 	 Testing ACC: 0.9602
9.45
EPOCH 10
Training  ACC: 0.9879 	 Testing ACC: 0.9787
79.39
EPOCH 20
Training  ACC: 0.9884 	 Testing ACC: 0.9774
71.22
EPOCH 30
Training  ACC: 0.9903 	 Testing ACC: 0.9792
71.77
EPOCH 40
Training  ACC: 0.9903 	 Testing ACC: 0.9784
74.82
EPOCH 50
Training  ACC: 0.9906 	 Testing ACC: 0.9788
74.83
EPOCH 60
Training  ACC: 0.9915 	 Testing ACC: 0.9787
75.61
EPOCH 70
Training  ACC: 0.9921 	 Testing ACC: 0.9773
74.30
EPOCH 80
Training  ACC: 0.9920 	 Testing ACC: 0.9787
74.67
EPOCH 90
Training  ACC: 0.9918 	 Testing ACC: 0.9776
73.75


In [9]:
if __name__ == "__main__":
    # Training run with defaults: rbf kernel on hidden features,
    # sigma_=10, lambda_=1000, forward="n", cross-entropy output loss.
    NUM_EPOCHS = 100     # total passes over the training set
    REPORT_EVERY = 10    # print accuracy / elapsed time every this many epochs

    parser = argparse.ArgumentParser()
    parser.add_argument('--loss', type=str, default="CE")
    parser.add_argument('--HSIC', type=str, default="nHSIC")
    parser.add_argument('--kernel_x', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_h', type=str, default="rbf", choices=["rbf", "student"])
    parser.add_argument('--kernel_y', type=str, default="rbf", choices=["rbf", "student"])
    # float so fractional bandwidths / weights can be passed on the command
    # line; the non-string defaults are used as-is and therefore stay ints.
    parser.add_argument('--sigma_', type=float, default=10)
    parser.add_argument('--lambda_', type=float, default=1000)
    parser.add_argument('--batchsize', type=int, default=256)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--bn_affine', type=int, default=0)
    parser.add_argument('--forward', type=str, default="n", choices=["x", "f", "n"])
    # parse_known_args: ignore any extra arguments present in sys.argv.
    args, _ = parser.parse_known_args()

    torch.manual_seed(1)
    # Fall back to the CPU when no CUDA device is available.
    device = "cuda:{}".format(args.device) if torch.cuda.is_available() else "cpu"
    batch_size = args.batchsize
    train_loader, test_loader = load_data(batch_size=args.batchsize)

    hsic = HSICBottleneck(args)
    start = time.time()
    for epoch in range(NUM_EPOCHS):
        hsic.model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Flatten and transfer once per batch; both updates reuse the result.
            inputs = data.view(data.size(0), -1).to(device)
            targets = target.to(device)
            hsic.step(inputs, targets)
            hsic.tune_output(inputs, targets)
        # Fold this epoch's per-iteration losses into the history so the
        # per-iteration lists do not keep growing for the whole run.
        hsic.update_loss()
        if epoch % REPORT_EVERY == 0:
            print("EPOCH %d" % epoch)
            show_result()
            # Seconds elapsed since the previous report.
            print("{:.2f}".format(time.time()-start))
            start = time.time()

EPOCH 0
Training  ACC: 0.9632 	 Testing ACC: 0.9613
8.17
EPOCH 10
Training  ACC: 0.9888 	 Testing ACC: 0.9791
73.34
EPOCH 20
Training  ACC: 0.9892 	 Testing ACC: 0.9798
75.73
EPOCH 30
Training  ACC: 0.9910 	 Testing ACC: 0.9811
75.44
EPOCH 40
Training  ACC: 0.9910 	 Testing ACC: 0.9786
76.82
EPOCH 50
Training  ACC: 0.9908 	 Testing ACC: 0.9789
75.38
EPOCH 60
Training  ACC: 0.9916 	 Testing ACC: 0.9787
47.91
EPOCH 70
Training  ACC: 0.9918 	 Testing ACC: 0.9782
45.17
EPOCH 80
Training  ACC: 0.9917 	 Testing ACC: 0.9777
44.95
EPOCH 90
Training  ACC: 0.9922 	 Testing ACC: 0.9780
44.96
