In [1]:
#coding=utf8
"""
# Author : Jianbai(Gus) Ye
# created at Feb 2 2019
# pytorch implementation of HSIC bottleneck method
# reference : https://github.com/forin-xyz/Keras-HSIC-Bottleneck
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import torch
from torch import nn, optim
import numpy as np
import sys
import matplotlib.pyplot as plt
from torchsummary import summary
# from collections import Iterable

import torch.nn.functional as F
import time
from utils import *

In [2]:
# Mini-batch size shared by the loaders and by every cell that reshapes a batch.
batch_size = 128
# Seed torch's global RNG before the loaders are created.
torch.manual_seed(1)
# load_data is expected to come from `utils` (star-imported in the first cell).
train_loader, test_loader = load_data(batch_size=batch_size)

In [3]:
class Block(nn.Module):
    """Fully connected unit: linear map, then GELU, then batch normalisation.

    Args:
        inplane: number of input features of the linear layer.
        outplane: number of output features of the linear layer.
    """

    def __init__(self, inplane, outplane):
        super().__init__()
        self.linear = nn.Linear(inplane, outplane)
        # affine=False: the normalisation carries no learnable scale or shift.
        self.bn = nn.BatchNorm1d(outplane, affine=False)
        self.act = nn.GELU()

    def forward(self, x):
        """Return ``bn(act(linear(x)))`` (the activation comes before the norm)."""
        return self.bn(self.act(self.linear(x)))
        
class Test_model(nn.Module):
    """Stack of ``Block`` layers (784 -> 256 -> 128 -> 128) with a 10-way head.

    ``forward`` evaluates every layer twice: once on a detached input, so the
    returned per-layer feature only back-propagates into that one layer, and
    once on the attached input, so the returned prediction keeps the full graph.
    """

    def __init__(self):
        super().__init__()

        self.units = [784, 256, 128, 128]
        # One Block per consecutive (in, out) pair of widths.
        self.module_list = nn.ModuleList(
            [Block(n_in, n_out) for n_in, n_out in zip(self.units[:-1], self.units[1:])]
        )

        self.f3 = nn.Dropout(p=0.2)
        self.output_layer = nn.Linear(self.units[-1], 10)
        self.act2 = nn.ReLU()

    def forward(self, data):
        """Run the network.

        Args:
            data: input batch fed to the first Block.

        Returns:
            ``(prediction, features)``. ``prediction`` is the head output
            computed through the attached graph. ``features`` holds one tensor
            per Block followed by the head output, each computed from a
            detached input.
        """
        hidden = data
        local_features = []
        for block in self.module_list:
            # Detached pass first (gradient stops at this block), then the attached pass.
            local_features.append(block(hidden.detach()))
            hidden = block(hidden)

        hidden = self.f3(hidden)
        local_features.append(self.act2(self.output_layer(hidden.detach())))
        prediction = self.act2(self.output_layer(hidden))
        return prediction, local_features
    
class HSICBottleneck:
    """Trainer combining per-layer HSIC objectives with an output loss.

    The wrapped ``model`` must return ``(prediction, features)`` where
    ``features`` is a sequence of tensors. Every entry except the last gets
    the HSIC objective; the last entry gets the output criterion.
    ``kernel_matrix`` and ``HSIC`` are not defined in this notebook; they are
    expected to be provided by ``utils``.

    Args:
        model: network to train; all of its parameters go to the optimizer.
        batch_size: value forwarded as the third argument of ``HSIC``.
        lambda_0: weight of the ``HSIC(Kz, Ky)`` term.
        sigma: value forwarded as the second argument of ``kernel_matrix``.
        multi_sigma: optional iterable; stored but not used by this class.
        lr: stored as ``self.lr``. NOTE(review): the optimizer is created with
            a fixed learning rate of 0.001 and does not read this value; it is
            left that way so existing runs keep their numerics.

    Raises:
        AssertionError: if ``multi_sigma`` is given but is not iterable.
    """

    def __init__(self, model, batch_size, lambda_0, sigma, multi_sigma=None,lr=0.01):
        # Imported locally: the notebook's module-level import of Iterable is
        # commented out, and `collections.Iterable` no longer exists on
        # Python >= 3.10, so the name would otherwise be undefined here.
        from collections.abc import Iterable

        assert multi_sigma is None or isinstance(multi_sigma, Iterable)

        self.model      = model
        self.batch_size = batch_size
        self.lambda_0   = lambda_0
        self.sigma      = sigma
        self.extractor  = 'hsic'
        self.last_linear = "output_layer"
        self.lr         = lr
        self.multi_sigma = multi_sigma

        self.opt = optim.AdamW(self.model.parameters(), lr=0.001)
        # Per-step history of the three loss terms returned by step().
        self.track_loss1 = []
        self.track_loss2 = []
        self.track_loss3 = []

        self.loss = "CE"
        if self.loss == "mse":
            self.output_criterion = nn.MSELoss()
        elif self.loss == "CE":
            self.output_criterion = nn.CrossEntropyLoss()

    def step(self, input_data, labels):
        """Run one optimisation step on the combined objective.

        Args:
            input_data: batch passed to the model and to ``kernel_matrix``.
            labels: integer class labels (one-hot encoded with 10 classes).

        Returns:
            Tuple of Python floats ``(sum of HSIC(Kz, Kx),
            sum of -lambda_0 * HSIC(Kz, Ky), output loss)``.
        """
        one_hot_labels = F.one_hot(labels, num_classes=10)
        if self.loss == "mse":
            labels = one_hot_labels.float()

        Kx = kernel_matrix(input_data, self.sigma)
        Ky = kernel_matrix(one_hot_labels, self.sigma)

        total_loss1 = 0.
        total_loss2 = 0.
        total_loss3 = 0.
        _, hidden_zs = self.model(input_data)

        last = len(hidden_zs) - 1
        for num, feature in enumerate(hidden_zs):
            if num == last:
                # The final entry is the model output: supervised loss only.
                total_loss3 += self.output_criterion(feature, labels)
            else:
                Kz = kernel_matrix(feature, self.sigma)
                total_loss1 += HSIC(Kz, Kx, self.batch_size)
                total_loss2 += - self.lambda_0*HSIC(Kz, Ky, self.batch_size)

        total_loss = total_loss1 + total_loss2 + total_loss3
        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()

        # float() accepts both tensors and the plain 0. accumulators that remain
        # when the model returns no hidden features (where .item() would fail).
        loss1, loss2, loss3 = float(total_loss1), float(total_loss2), float(total_loss3)
        self.track_loss1.append(loss1)
        self.track_loss2.append(loss2)
        self.track_loss3.append(loss3)

        return loss1, loss2, loss3

    def tune_output(self, input_data, labels):
        """Run one optimisation step on the output criterion alone.

        The loss is computed on the last entry of the model's feature list.

        Returns:
            The loss value as a Python float.
        """
        if self.loss == "mse":
            labels = F.one_hot(labels, num_classes=10).float()

        _, hidden_zs = self.model(input_data)
        total_loss = self.output_criterion(hidden_zs[-1], labels)
        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()

        return total_loss.item()
    
def show_result():
    """Print train and test accuracy of ``hsic.model``.

    Reads the notebook globals ``hsic``, ``train_loader``, ``test_loader`` and
    ``device``. The model is switched to eval mode and left there; callers
    that keep training must switch it back to train mode themselves.
    """
    net = hsic.model
    net.eval()

    def accuracy(loader):
        """Fraction of samples in ``loader`` whose arg-max prediction is correct."""
        correct, seen = 0, 0
        for data, target in loader:
            # Flatten by the batch's own size rather than the global batch_size,
            # so a shorter final batch keeps its real feature width.
            inputs = data.view(data.size(0), -1).to(device)
            pred = net(inputs)[0].cpu().argmax(dim=1)
            correct += (pred == target).sum().item()
            seen += len(pred)
        # An empty loader has no defined accuracy; report NaN instead of dividing by zero.
        return correct / seen if seen else float("nan")

    with torch.no_grad():
        train_acc = accuracy(train_loader)
        test_acc = accuracy(test_loader)
    print("Training  ACC: {:.2f} \t Testing ACC: {:.2f}".format(train_acc, test_acc))

In [4]:
# Prefer the first GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = Test_model()
model.to(device)
model.train()
HSIC_epochs = 100
lambda_0 = 500

hsic = HSICBottleneck(model, batch_size=batch_size, lambda_0=lambda_0, sigma=10.)

for epoch in range(HSIC_epochs):
    # show_result() leaves the model in eval mode, so re-enter train mode every epoch.
    model.train()
    start = time.time()
    total_loss1, total_loss2, total_loss3, total_loss_tune = 0, 0, 0, 0
    for batch_idx, (data, target) in enumerate(train_loader):
        # Flatten and move to the device once; both update calls reuse the result.
        # The reshape keeps exactly `batch_size` rows because HSICBottleneck
        # forwards that same fixed value to HSIC.
        inputs = data.view(batch_size, -1).to(device)
        targets = target.to(device)
        loss1, loss2, loss3 = hsic.step(inputs, targets)
        total_loss_tune += hsic.tune_output(inputs, targets)
        total_loss1 += loss1
        total_loss2 += loss2
        total_loss3 += loss3
    # Report every 10th epoch, however many epochs are configured.
    if epoch % 10 == 0:
        num_batches = batch_idx + 1
        print("===============================")
        print("EPOCH %d" % epoch)
        model.eval()
        show_result()
        # Per-batch means; the second term is additionally rescaled by 100 / lambda_0.
        sys.stdout.write("{:.3f}, {:.3f}, {:.3f}, {:.3f}\n".format(
            total_loss1/num_batches,
            total_loss2/lambda_0*100/num_batches,
            total_loss3/num_batches,
            total_loss_tune/num_batches))
        sys.stdout.flush()
        # Elapsed wall time for this epoch, including the evaluation above.
        print("{:.2f}".format(time.time()-start))

EPOCH 0
Training  ACC: 0.89 	 Testing ACC: 0.89
0.027, -0.160, 0.563, 0.552
7.64
EPOCH 10
Training  ACC: 0.92 	 Testing ACC: 0.92
0.027, -0.181, 0.327, 0.322
7.53
EPOCH 20
Training  ACC: 0.92 	 Testing ACC: 0.92
0.026, -0.182, 0.316, 0.311
8.61
EPOCH 30
Training  ACC: 0.92 	 Testing ACC: 0.92
0.026, -0.182, 0.313, 0.309
8.75
EPOCH 40
Training  ACC: 0.92 	 Testing ACC: 0.92
0.026, -0.182, 0.315, 0.311
8.72
EPOCH 50
Training  ACC: 0.92 	 Testing ACC: 0.92
0.026, -0.182, 0.312, 0.309
8.53
EPOCH 60
Training  ACC: 0.92 	 Testing ACC: 0.92
0.026, -0.183, 0.311, 0.308
8.65
EPOCH 70
Training  ACC: 0.92 	 Testing ACC: 0.92
0.026, -0.183, 0.309, 0.305
8.49
EPOCH 80
Training  ACC: 0.92 	 Testing ACC: 0.92
0.026, -0.183, 0.310, 0.306
8.64
EPOCH 90
Training  ACC: 0.92 	 Testing ACC: 0.92
0.026, -0.183, 0.310, 0.304
8.79


In [5]:
class PostTrained:
    """Fine-tunes only the ``output_layer`` child of a model.

    On construction the model is put in train mode, the parameters of every
    direct child other than ``output_layer`` get ``requires_grad = False``, and
    an Adam optimizer is built over the ``output_layer`` parameters only.

    Args:
        model: network whose ``forward`` returns ``(prediction, features)``.
        criterion: callable ``criterion(prediction, labels)`` returning the loss.
        lr: stored as ``self.lr``. NOTE(review): the optimizer is created with
            a fixed learning rate of 0.01 and does not read this value; it is
            left that way so existing runs keep their numerics.

    Raises:
        ValueError: raised by the optimizer when the model has no
            ``output_layer`` child with parameters (empty parameter list).
    """

    def __init__(self, model : nn.Module, criterion,lr=0.1):
        model.train()
        trainable = []
        for name, layer in model.named_children():
            if name == "output_layer":
                trainable.extend(layer.parameters())
            else:
                # Freeze everything that is not the output layer.
                for params in layer.parameters():
                    params.requires_grad = False
        # Give the optimizer only the parameters that can still receive gradients,
        # instead of every model parameter including the frozen ones.
        self.opt   = optim.Adam(trainable, lr=0.01)
        self.model = model
        self.lr    = lr
        self.criterion = criterion

    def step(self, input_data, labels):
        """Run one optimisation step on the model prediction; return the loss as a float."""
        output_data, _ = self.model(input_data)
        loss = self.criterion(output_data, labels)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()
    
def show_result_post():
    """Print train accuracy, then test accuracy, of ``post.model``.

    Reads the notebook globals ``post``, ``train_loader``, ``test_loader`` and
    ``device``. The model is switched to eval mode and left there; callers
    that keep training must switch it back to train mode themselves.
    """
    net = post.model
    net.eval()

    def accuracy(loader):
        """Fraction of samples in ``loader`` whose arg-max prediction is correct."""
        correct, seen = 0, 0
        for data, target in loader:
            # Flatten by the batch's own size rather than the global batch_size,
            # so a shorter final batch keeps its real feature width.
            inputs = data.view(data.size(0), -1).to(device)
            pred = net(inputs)[0].cpu().argmax(dim=1)
            correct += (pred == target).sum().item()
            seen += len(pred)
        # An empty loader has no defined accuracy; report NaN instead of dividing by zero.
        return correct / seen if seen else float("nan")

    with torch.no_grad():
        print("Training ACC: {:.3f}".format(accuracy(train_loader)))
        print("Testing  ACC: {:.3f}".format(accuracy(test_loader)))

In [6]:
POST_epochs = 10
criterion = nn.MSELoss()
post = PostTrained(model, criterion=criterion)
post.model.to(device)
for epoch in range(POST_epochs):
    # show_result_post() leaves the model in eval mode. Re-enter train mode so
    # every epoch trains under the same mode as the first one, as the HSIC
    # training loop does.
    post.model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # MSELoss needs float one-hot targets of the same shape as the prediction.
        target = F.one_hot(target, num_classes=10).float()
        # Flatten by the batch's own size so a shorter final batch keeps its feature width.
        inputs = data.view(data.size(0), -1).to(device)
        loss = post.step(inputs, target.to(device))
    # Report after the first epoch and then after every fifth one.
    if (epoch+1) % 5 == 0 or epoch == 0:
        print("===============================")
        print("POST EPOCH %d" % epoch)
        show_result_post()

POST EPOCH 0
Training ACC: 0.922
Testing  ACC: 0.920
POST EPOCH 4
Training ACC: 0.924
Testing  ACC: 0.922
POST EPOCH 9
Training ACC: 0.924
Testing  ACC: 0.921
