In [1]:
#coding=utf8
"""
# Author : Jianbai(Gus) Ye
# created at Feb 2 2019
# pytorch implementation of HSIC bottleneck method
# reference : https://github.com/forin-xyz/Keras-HSIC-Bottleneck
"""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals
import torch
from torch import nn, optim
import numpy as np
import sys
import matplotlib.pyplot as plt
from torchsummary import summary
# from collections import Iterable

import torch.nn.functional as F
import time
from utils import *

In [2]:
torch.manual_seed(1)
batch_size = 128
train_loader, test_loader = load_data(batch_size=batch_size)

In [3]:
class Extractor(nn.Module):
    def __init__(self, model : nn.Module):
        super(Extractor, self).__init__()
        self.extractor_pre = 'hsic'
        self.output_pre = 'output_layer'
        self.all_layers = []
        for name, layer in model.named_children():
            setattr(self, name, layer)
            self.all_layers.append(name)
        self.all_layers = tuple(self.all_layers)
    
    def forward(self, data):
        x = data
        hidden = {}
        for name in self.all_layers:
            name : str
            layer = getattr(self, name)
            x_ = layer(x.detach())
            x = layer(x)
            if name.startswith(self.extractor_pre):
                hidden[name] = x_
            if name.startswith(self.output_pre):
                hidden[name] = torch.sigmoid(x_)
        return x, hidden

class Test_model(nn.Module):
    def __init__(self):
        super(Test_model, self).__init__()
        self.hsic1 = nn.Linear(784, 512)
        self.hsic2 = nn.Linear(512, 512)
        
        self.f3 = nn.Dropout(p=0.2)
        self.output_layer  = nn.Linear(512, 10)
        
        self.act1 = nn.GELU()
        self.act2 = nn.Sigmoid()
        
    def forward(self, data):
        x = self.act1(self.hsic1(data))
        x = self.f3(x)
        x = self.act1(self.hsic2(x))
        x = self.f3(x)
        x = self.act2(self.output_layer(x))
        return x
    
class HSICBottleneck:
    def __init__(self, model, batch_size, lambda_0, sigma, multi_sigma=None,lr=0.01):
        self.model      = Extractor(model)
        self.batch_size = batch_size
        self.lambda_0   = lambda_0
        self.sigma      = sigma
        self.extractor  = 'hsic'
        self.last_linear = "output_layer"
        self.lr         = lr
        self.multi_sigma = multi_sigma
        assert isinstance(self.multi_sigma, Iterable) if  multi_sigma is not None else True
        
        self.opt = optim.Adam(self.model.parameters(), lr=0.01)
        self.track_loss1 = []
        self.track_loss2 = []
        self.track_loss3 = []
        
        self.output_criterion = nn.CrossEntropyLoss()
        
    def step(self, input_data, labels):
        
        one_hot_labels = F.one_hot(labels, num_classes=10)
#         labels = F.one_hot(labels, num_classes=10).float()
        
        Kx  = kernel_matrix(input_data, self.sigma)
        Ky = kernel_matrix(one_hot_labels, self.sigma)
        
        total_loss1 = 0.
        total_loss2 = 0.
        total_loss3 = 0.
        y_pred, hidden_zs = self.model(input_data)
        
#         print(hidden_zs.keys())
        for name, layer in self.model.named_children():
            if self.extractor in name:
                hidden_z = hidden_zs[name]
                Kz = kernel_matrix(hidden_z, self.sigma)
                loss1 = HSIC(Kz, Kx, self.batch_size) 
                loss2 = - self.lambda_0*HSIC(Kz,Ky, self.batch_size)
                total_loss1 += loss1
                total_loss2 += loss2
            if self.last_linear in name:
                hidden_z = hidden_zs[name]
                total_loss3 += self.output_criterion(hidden_z, labels)
                
        total_loss = total_loss1 + total_loss2 + total_loss3
        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()
        
        self.track_loss1.append(total_loss1.item())
        self.track_loss2.append(total_loss2.item())
        self.track_loss3.append(total_loss3.item())
                
        return total_loss1.item(), total_loss2.item(), total_loss3.item()
    
    
    def tune_output(self, input_data, labels):
        
        total_loss3 = 0.
        y_pred, hidden_zs = self.model(input_data)
        
        for name, layer in self.model.named_children():
            if self.last_linear in name:
                hidden_z = hidden_zs[name]
                total_loss3 += self.output_criterion(hidden_z, labels)
                
        total_loss = total_loss3
        self.opt.zero_grad()
        total_loss.backward()
        self.opt.step()
    
        self.track_loss3.append(total_loss3.item())
        return total_loss3.item()
    
def show_result():
    model.eval()
    with torch.no_grad():
        counts, correct, counts_test, correct_test = 0, 0, 0, 0        
        for batch_idx, (data, target) in enumerate(train_loader): 
            output = hsic.model.forward(data.view(batch_size, -1).to(device))[0].cpu()
            pred = output.argmax(dim=1, keepdim=True)
            correct += (pred[:,0] == target).float().sum()
            counts += len(pred)
        for batch_idx, (data, target) in enumerate(test_loader): 
            output = hsic.model.forward(data.view(batch_size, -1).to(device))[0].cpu()
            pred = output.argmax(dim=1, keepdim=True)
            correct_test += (pred[:,0] == target).float().sum()
            counts_test += len(pred)
        print("Testing  ACC: {:.2f} \t Training ACC: {:.2f}".format(correct/counts, 
                                                                    correct_test/counts_test))

In [4]:
device = "cuda"
model = Test_model()
# model = nn.DataParallel(model,device_ids=[0,1])
model.to(device)
model.train()
HSIC_epochs = 100
lambda_0 = 500

hsic = HSICBottleneck(model, batch_size=batch_size, lambda_0=lambda_0, sigma=5.)

for epoch in range(HSIC_epochs):
    model.train()
    start = time.time()
    total_loss1, total_loss2, total_loss3, total_loss_tune = 0, 0, 0, 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(batch_size, -1)
        loss1, loss2, loss3 = hsic.step(data.view(batch_size, -1).to(device), target.to(device))
        total_loss1 += loss1
        total_loss2 += loss2
        total_loss3 += loss3
    if epoch in range(0, 100, 10):
        print("===============================")
        print("EPOCH %d" % epoch)
        show_result()
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(batch_size, -1)
            loss_tune = hsic.tune_output(data.view(batch_size, -1).to(device), target.to(device))
            total_loss_tune += loss_tune
        show_result()
        for batch_idx, (data, target) in enumerate(train_loader):
            data = data.view(batch_size, -1)
            loss_tune = hsic.tune_output(data.view(batch_size, -1).to(device), target.to(device))
            total_loss_tune += loss_tune
        model.eval()
        show_result()
        sys.stdout.write("{:.3f}, {:.3f}, {:.3f}, {:.3f}".format(total_loss1/(batch_idx+1), 
                                                                 total_loss2/lambda_0*100/(batch_idx+1), 
                                                                 total_loss3/(batch_idx+1),
                                                                 total_loss_tune/(batch_idx+1)))
        sys.stdout.flush()
        sys.stdout.write('\n')
        print("{:.2f}".format(time.time()-start))

EPOCH 0
Testing  ACC: 0.33 	 Training ACC: 0.32
Testing  ACC: 0.59 	 Training ACC: 0.58
Testing  ACC: 0.58 	 Training ACC: 0.58
0.026, -0.300, 2.107, 3.548
16.16
EPOCH 10
Testing  ACC: 0.45 	 Training ACC: 0.45
Testing  ACC: 0.51 	 Training ACC: 0.52
Testing  ACC: 0.53 	 Training ACC: 0.54
0.027, -0.338, 1.831, 3.595
15.21
EPOCH 20
Testing  ACC: 0.53 	 Training ACC: 0.53
Testing  ACC: 0.66 	 Training ACC: 0.66
Testing  ACC: 0.68 	 Training ACC: 0.69
0.027, -0.342, 1.827, 3.532
15.07
EPOCH 30
Testing  ACC: 0.57 	 Training ACC: 0.56
Testing  ACC: 0.74 	 Training ACC: 0.73
Testing  ACC: 0.75 	 Training ACC: 0.75
0.027, -0.344, 1.831, 3.416
14.94
EPOCH 40
Testing  ACC: 0.64 	 Training ACC: 0.64
Testing  ACC: 0.72 	 Training ACC: 0.72
Testing  ACC: 0.75 	 Training ACC: 0.74
0.027, -0.344, 1.823, 3.432
15.27
EPOCH 50
Testing  ACC: 0.71 	 Training ACC: 0.71
Testing  ACC: 0.75 	 Training ACC: 0.76
Testing  ACC: 0.78 	 Training ACC: 0.77
0.027, -0.344, 1.822, 3.370
15.15
EPOCH 60
Testing  ACC: 

In [5]:
class PostTrained:
    def __init__(self, model : nn.Module, criterion,lr=0.1):
        parameters = []
        model.train()
        for name, layer in model.named_children():
            if name == "output_layer":
                for params in layer.parameters():
                    parameters.append(params)
            else:
                for params in layer.parameters():
                    params.requires_grad = False
        self.opt   = optim.Adam(model.parameters(), lr=0.01)
        self.model = model
        self.lr    = lr
        self.criterion = criterion
#         summary(self.model, (batch_size, 784))

    def step(self, input_data, labels):
        output_data = self.model(input_data)
        loss = self.criterion(output_data, labels)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return loss.item()
    
def show_result_post():
    post.model.eval()
    with torch.no_grad():
        counts = 0
        correct = 0
        for batch_idx, (data, target) in enumerate(train_loader): 
            output = post.model.forward(data.view(batch_size, -1).to(device)).cpu()
            pred = output.argmax(dim=1, keepdim=True)
            correct += (pred[:,0] == target).float().sum()
            counts += len(pred)
        print("Training ACC: {:.3f}".format(correct/counts))
        
        counts = 0
        correct = 0        
        for batch_idx, (data, target) in enumerate(test_loader): 
            output = post.model.forward(data.view(batch_size, -1).to(device)).cpu()
            pred = output.argmax(dim=1, keepdim=True)
            correct += (pred[:,0] == target).float().sum()
            counts += len(pred)
        print("Testing  ACC: {:.3f}".format(correct/counts))

In [6]:
POST_epochs = 30
criterion = nn.CrossEntropyLoss()
post = PostTrained(model, criterion=criterion)
post.model.to(device)
for epoch in range(POST_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        loss = post.step(data.view(batch_size, -1).to(device), target.to(device))
    if (epoch+1) % 5 == 0 or epoch == 0: 
        print("===============================")
        print("POST EPOCH %d" % epoch)
        show_result_post()

POST EPOCH 0
Training ACC: 0.485
Testing  ACC: 0.480
POST EPOCH 4
Training ACC: 0.720
Testing  ACC: 0.720
POST EPOCH 9
Training ACC: 0.726
Testing  ACC: 0.725
POST EPOCH 14
Training ACC: 0.734
Testing  ACC: 0.734
POST EPOCH 19
Training ACC: 0.737
Testing  ACC: 0.737
POST EPOCH 24
Training ACC: 0.741
Testing  ACC: 0.739
POST EPOCH 29
Training ACC: 0.743
Testing  ACC: 0.741
