In [1]:
import torch
import numpy as np
# Sanity check: is a CUDA device visible to PyTorch? (the cell displays the bool)
torch.cuda.is_available()

True

In [4]:
import sys
# Record the Python interpreter version this notebook was executed with.
print(sys.version)

3.9.1 (default, Dec 11 2020, 09:29:25) [MSC v.1916 64 bit (AMD64)]


In [21]:
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import datetime
import os
import torchvision
import time
import copy
from torch.utils.tensorboard import SummaryWriter

from torchsummary import summary

from tqdm import tqdm

In [5]:
# Record the installed PyTorch version.
torch.__version__

'1.7.1+cu110'

In [23]:
# misc functions (https://github.com/choasma/HSIC-bottleneck/blob/master/source/hsicbt/utils/misc.py)

def to_categorical(y, num_classes):
    """One-hot encode integer class labels.

    Args:
        y: integer label(s) usable as an index into an identity matrix,
            e.g. a LongTensor of shape (B,) or (B, 1), or a Python int.
        num_classes: number of classes (width of each one-hot row).

    Returns:
        Float tensor of one-hot rows with size-1 dimensions squeezed out,
        except that a leading batch dimension of size 1 is kept. So (B,) or
        (B, 1) labels give (B, num_classes) for every B >= 1. When ``y`` is a
        tensor, the result lives on ``y``'s device.
    """
    # Build the identity on the labels' device. Indexing a CPU tensor with
    # CUDA indices is rejected by recent PyTorch versions.
    device = y.device if torch.is_tensor(y) else None
    onehot = torch.eye(num_classes, device=device)[y]
    squeezed = torch.squeeze(onehot)
    # torch.squeeze would also drop a batch dimension of size 1 (e.g. a short
    # final batch). Restore it so the batch axis survives.
    if onehot.dim() >= 2 and onehot.size(0) == 1:
        squeezed = squeezed.unsqueeze(0)
    return squeezed

def get_layer_parameters(model, idx_range):
    """Select parameters of ``model`` by their position in ``named_parameters()``.

    Args:
        model: an ``nn.Module``.
        idx_range: container of integer positions to keep (tested with ``in``).

    Returns:
        (params, names): two parallel lists holding the selected parameter
        tensors and their qualified names, in registration order.
    """
    chosen = [
        (pname, tensor)
        for pos, (pname, tensor) in enumerate(model.named_parameters())
        if pos in idx_range
    ]
    params = [tensor for _, tensor in chosen]
    names = [pname for pname, _ in chosen]
    return params, names

# https://github.com/choasma/HSIC-bottleneck/blob/master/source/hsicbt/utils/meter.py
class AverageMeter(object):
    """Tracks the most recent value and a running (count-weighted) mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the running sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean.

        Args:
            val: the new observation (stored as ``self.val``).
            n: weight / number of samples ``val`` stands for.
        """
        self.val = val
        # In-place accumulation (kept as ``+=``) matches the original semantics.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [2]:
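# Gaussian-kernel HSIC utilities: build the centered kernel matrices K_X H and K_Y H.
# Kernel as implemented in kernelmat: k(x, y) = exp(-||x - y||^2 / (2 * sigma**2 * d)),
# where d is the feature dimension (number of columns) of the input.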

def distmat(X):
    """Pairwise squared Euclidean distance matrix between the rows of ``X``.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2. The absolute
    value turns any tiny negative entries from floating-point round-off
    into non-negative ones.

    Args:
        X: 2-D tensor with one sample per row (m rows).

    Returns:
        (m, m) tensor D with D[i, j] = ||X[i] - X[j]||^2.
    """
    sq_norms = torch.sum(X * X, 1).view(-1, 1)   # (m, 1) column of squared row norms
    gram = torch.mm(X, X.t())                     # (m, m) inner products
    # Broadcasting the column / row norms gives the same values as expand_as.
    return (sq_norms - 2 * gram + sq_norms.t()).abs()

def kernelmat(X, sigma):
    """Centered Gaussian kernel matrix K @ H for the rows of ``X``.

    Args:
        X: 2-D tensor, one sample per row (m rows, d features).
        sigma: kernel bandwidth. The effective variance is 2 * sigma**2 * d.

    Returns:
        (m, m) CPU float32 tensor K @ H, where
        K[i, j] = exp(-||x_i - x_j||^2 / (2 * sigma**2 * d)) and
        H = I_m - (1/m) * 1_m 1_m^T is the centering matrix.
    """
    n_rows = int(X.size()[0])    # batch size m
    n_feats = X.size()[1]        # feature dimension d

    # Centering matrix H, built on CPU.
    centering = torch.eye(n_rows) - (1. / n_rows) * torch.ones([n_rows, n_rows])

    sq_dists = distmat(X)

    # Bandwidth scaled by the feature dimension.
    variance = 2. * sigma * sigma * n_feats
    # .type(torch.FloatTensor) moves the kernel to a CPU float32 tensor so it
    # matches the CPU centering matrix.
    gram = torch.exp(-sq_dists / variance).type(torch.FloatTensor)

    return gram @ centering


def hsic_base(x, y, sigma=None, use_cuda=True):
    """Empirical HSIC between paired samples ``x`` and ``y`` (equation 3 in the paper).

    HSIC = (m - 1)^-2 * trace(K_x H K_y H), with m the batch size.

    Args:
        x, y: 2-D tensors with the same number of rows (samples).
        sigma: Gaussian bandwidth forwarded to ``kernelmat`` for both inputs.
        use_cuda: accepted for interface compatibility; not used.

    Returns:
        0-d tensor with the HSIC estimate.
    """
    batch = int(x.size()[0])

    centered_x = kernelmat(x, sigma=sigma)
    centered_y = kernelmat(y, sigma=sigma)

    normaliser = (batch - 1) ** 2
    return torch.trace(torch.mm(centered_x, centered_y)) / normaliser

# taken from HSIC implementation 
# https://github.com/choasma/HSIC-bottleneck/blob/9f1fe2447592d61c0ba524aad0ff0820ae2ba9cb/source/hsicbt/core/train_misc.py#L26
# def hsic_objective(hidden, h_target, h_data, sigma):

#     hsic_hy_val = hsic_base( hidden, h_target, sigma=sigma)
#     hsic_hx_val = hsic_base( hidden, h_data,   sigma=sigma)

#     return hsic_hx_val, hsic_hy_val

def hsic_loss_obj(hidden, h_target, h_data, sigma):
    """Both HSIC terms of the bottleneck objective for one hidden representation.

    Args:
        hidden: 2-D activations of a layer (one row per sample).
        h_target: 2-D encoded targets Y (e.g. one-hot rows).
        h_data: 2-D flattened inputs X.
        sigma: Gaussian kernel bandwidth.

    Returns:
        (hsic_hx, hsic_hy): HSIC(hidden, X) and HSIC(hidden, Y). Callers such
        as hsic_train combine them as hx - lambda_y * hy.
    """
    # HSIC with the inputs is computed first, then with the targets.
    return (
        hsic_base(hidden, h_data, sigma=sigma),
        hsic_base(hidden, h_target, sigma=sigma),
    )

In [6]:
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader

In [7]:
# Prepare CIFAR-100 train / validation loaders. Images are converted to tensors
# and resized to 227x227; the validation split reuses the training transform.
DATA_ROOT = './data/cifar100'
BATCH_SIZE = 128

train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(227, 227)),
])
valid_transform = train_transform

train_set = CIFAR100(DATA_ROOT, train=True, download=True, transform=train_transform)
valid_set = CIFAR100(DATA_ROOT, train=False, download=True, transform=valid_transform)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar100\cifar-100-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar100\cifar-100-python.tar.gz


100.0%

Extracting ./data/cifar100\cifar-100-python.tar.gz to ./data/cifar100
Files already downloaded and verified


In [13]:
# Primitive building blocks: a conv (or linear) layer, a non-affine batch norm,
# and an activation selected by ``atype``.

# Supported activation names (matched case-insensitively).
_ACTIVATIONS = {
    'relu': nn.ReLU,
    'tanh': nn.Tanh,
    'sigmoid': nn.Sigmoid,
    'elu': nn.ELU,
}


def _make_activation(atype):
    """Instantiate the activation named ``atype``; ``None`` falls back to ReLU.

    Raises:
        ValueError: if ``atype`` is not a key of ``_ACTIVATIONS``.
    """
    if atype is None:
        return nn.ReLU()
    try:
        return _ACTIVATIONS[atype.lower()]()
    except (KeyError, AttributeError):
        raise ValueError("Unsupported activation {!r}; expected one of {}".format(
            atype, sorted(_ACTIVATIONS))) from None


def makeblock_conv(in_chs, out_chs, atype, stride=1):
    """Conv2d(5x5, no padding) -> BatchNorm2d(affine=False) -> activation.

    The batch norm has no learnable parameters, so the block's only parameters
    are the conv weight and bias.

    Args:
        in_chs: input channels.
        out_chs: output channels.
        atype: activation name ('relu', 'tanh', 'sigmoid', 'elu').
        stride: conv stride.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels=in_chs, out_channels=out_chs, kernel_size=5, stride=stride),
        nn.BatchNorm2d(out_chs, affine=False),
        _make_activation(atype),
    )


def makeblock_dense(in_dim, out_dim, atype):
    """Linear -> BatchNorm1d(affine=False) -> activation.

    The block's only parameters are the linear weight and bias.

    Args:
        in_dim: input features.
        out_dim: output features.
        atype: activation name ('relu', 'tanh', 'sigmoid', 'elu').
    """
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.BatchNorm1d(out_dim, affine=False),
        _make_activation(atype),
    )

In [18]:
class ModelConv(nn.Module):
    """Conv network for layer-wise HSIC-bottleneck training.

    Layout: ``input_layer`` (conv block, in_ch -> hidden_width), then
    ``n_layers`` conv blocks (hidden_width -> hidden_width) in
    ``sequence_layer``, then a dense ``output_layer`` on the flattened feature
    map. Blocks come from ``makeblock_conv`` / ``makeblock_dense``.

    ``forward`` returns ``(output, hidden_list)``. ``hidden_list`` holds the
    output of every block in order and ends with ``output`` itself.
    """

    # data_code -> (input channels, default square image side, default output width)
    _DATA_SPECS = {
        'mnist': (1, 28, None),
        'cifar10': (3, 32, 10),
        'cifar100': (3, 32, 100),
    }
    # Must match the kernel size used by makeblock_conv (stride 1, no padding).
    _KERNEL_SIZE = 5

    def __init__(self, in_width=784, hidden_width=64, n_layers=5, atype='relu',
        last_hidden_width=None, data_code='cifar10', image_size=None, **kwargs):
        """
        Args:
            in_width: stored as ``self.in_width``; not used to build layers.
            hidden_width: channel count of every conv block.
            n_layers: number of hidden conv blocks after the input block.
            atype: activation name forwarded to the block builders.
            last_hidden_width: width of the dense output block. When None it
                defaults to 10 for 'cifar10', 100 for 'cifar100', and
                ``hidden_width`` for 'mnist'.
            data_code: one of 'mnist', 'cifar10', 'cifar100'.
            image_size: input side (int) or (height, width). Defaults to 28
                for 'mnist' and 32 for CIFAR.
            **kwargs: ignored.

        Raises:
            ValueError: unknown ``data_code``, or an input too small for the
                conv stack.
        """
        super(ModelConv, self).__init__()

        if data_code not in self._DATA_SPECS:
            raise ValueError("Unsupported data_code {!r}; expected one of {}".format(
                data_code, sorted(self._DATA_SPECS)))
        in_ch, default_side, default_out = self._DATA_SPECS[data_code]

        if last_hidden_width is None:
            last_hidden_width = default_out
        last_hw = last_hidden_width if last_hidden_width else hidden_width

        # Hidden blocks are created before the input block (same construction /
        # RNG order as before), but the input block is registered first so
        # named_parameters() order is input -> hidden blocks -> output.
        hidden_blocks = [makeblock_conv(hidden_width, hidden_width, atype)
                         for _ in range(n_layers)]
        self.input_layer = makeblock_conv(in_ch, hidden_width, atype)
        self.sequence_layer = nn.Sequential(*hidden_blocks)

        if image_size is None:
            image_size = default_side
        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size
        # Each of the n_layers + 1 unpadded, stride-1 convs trims
        # (kernel - 1) pixels from every spatial axis.
        shrink = (n_layers + 1) * (self._KERNEL_SIZE - 1)
        out_h, out_w = height - shrink, width - shrink
        if out_h < 1 or out_w < 1:
            raise ValueError("Input of size {}x{} is too small for {} conv layers".format(
                height, width, n_layers + 1))
        dim = hidden_width * out_h * out_w

        self.output_layer = makeblock_dense(dim, last_hw, atype)

        self.is_conv = False
        self.in_width = in_width

    def forward(self, x):
        """Return ``(output, hidden_list)`` with the activation of every block."""
        hidden_list = []

        x = self.input_layer(x)
        hidden_list.append(x)

        for block in self.sequence_layer:
            x = block(x)
            hidden_list.append(x)

        x = torch.flatten(x, 1)

        x = self.output_layer(x)
        hidden_list.append(x)

        return x, hidden_list

In [19]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ModelConv()
model.to(device)

ModelConv(
  (input_layer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (2): ReLU()
  )
  (sequence_layer): Sequential(
    (0): Sequential(
      (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stat

In [None]:
def _resolve_num_classes(data_loader, config_dict):
    """Width of the one-hot target encoding used by hsic_train.

    Priority: ``config_dict['num_classes']`` if present and not None, then
    ``len(data_loader.dataset.classes)`` when that attribute exists
    (torchvision classification datasets expose it), else 10.
    """
    num_classes = config_dict.get('num_classes')
    if num_classes is not None:
        return int(num_classes)
    classes = getattr(getattr(data_loader, 'dataset', None), 'classes', None)
    if classes:
        return len(classes)
    return 10


def hsic_train(cepoch, model, data_loader, config_dict):
    """Train ``model`` for one epoch with the layer-wise HSIC-bottleneck objective.

    For every mini-batch, each hidden representation returned by the model is
    optimized in turn. A fresh SGD optimizer updates only that layer's
    parameters to minimise HSIC(hidden, X) - lambda_y * HSIC(hidden, Y).

    Args:
        cepoch: epoch index, used only in the progress-bar message.
        model: module whose forward returns ``(output, hidden_list)``. Assumes
            hidden ``i`` owns parameters ``2i`` and ``2i + 1`` of
            ``model.named_parameters()`` (weight and bias, as with ModelConv's
            non-affine batch norms). TODO confirm for other models.
        data_loader: sized iterable of ``(data, target)`` batches with integer
            class targets.
        config_dict: must provide 'device', 'batch_size', 'learning_rate',
            'sigma', 'lambda_y' and 'log_batch_interval'. 'num_classes' is
            optional (falls back to the dataset's ``classes``, then 10).

    Returns:
        dict with lists 'batch_acc', 'batch_loss', 'batch_hsic_hx' and
        'batch_hsic_hy' of running averages, appended every
        ``log_batch_interval`` batches. Accuracy and loss stay at the -1
        placeholder because no classifier is evaluated here.
    """
    prec1 = total_loss = hx_l = hy_l = -1

    batch_acc    = AverageMeter()
    batch_loss   = AverageMeter()
    batch_hischx = AverageMeter()
    batch_hischy = AverageMeter()

    batch_log = {
        'batch_acc': [],
        'batch_loss': [],
        'batch_hsic_hx': [],
        'batch_hsic_hy': [],
    }

    device = config_dict['device']
    batch_size = config_dict['batch_size']
    num_classes = _resolve_num_classes(data_loader, config_dict)
    model = model.to(device)

    n_batches = len(data_loader)
    # Nominal sample count for the progress message (the last batch may be shorter).
    n_data = batch_size * n_batches

    pbar = tqdm(enumerate(data_loader), total=n_batches, ncols=120)
    for batch_idx, (data, target) in pbar:

        # Debug mode: only process the first few batches.
        if os.environ.get('HSICBT_DEBUG') == '4' and batch_idx > 5:
            break

        data = data.to(device)
        target = target.to(device)
        _, hiddens = model(data)

        h_target = to_categorical(target.view(-1, 1), num_classes=num_classes).float()
        h_data = data.reshape(data.size(0), -1)

        # Parameter positions (weight, bias) owned by each hidden layer.
        idx_range = [[2 * i, 2 * i + 1] for i in range(len(hiddens))]

        for i in range(len(hiddens)):
            # Fresh forward pass so layer i sees the layers already updated for this batch.
            _, hiddens = model(data)
            params, _ = get_layer_parameters(model=model, idx_range=idx_range[i])  # optimize one layer at a time
            # NOTE: the optimizer is rebuilt every step, so momentum state does
            # not carry over between steps; kept to preserve training behaviour.
            optimizer = optim.SGD(params, lr=config_dict['learning_rate'], momentum=.9, weight_decay=0.001)
            optimizer.zero_grad()

            hidden = hiddens[i]
            if hidden.dim() > 2:
                hidden = hidden.reshape(hidden.size(0), -1)

            hx_l, hy_l = hsic_loss_obj(
                hidden,
                h_target=h_target,
                h_data=h_data,
                sigma=config_dict['sigma'],
            )
            loss = hx_l - config_dict['lambda_y'] * hy_l
            loss.backward()
            optimizer.step()

        batch_acc.update(prec1)
        batch_loss.update(total_loss)
        batch_hischx.update(hx_l.item())
        batch_hischy.update(hy_l.item())

        # # # preparation log information and print progress # # #

        msg = 'Train Epoch: {cepoch} [ {cidx:5d}/{tolidx:5d} ({perc:2d}%)] H_hx:{H_hx:.4f} H_hy:{H_hy:.4f}'.format(
                        cepoch = cepoch,
                        cidx = (batch_idx+1)*batch_size,
                        tolidx = n_data,
                        perc = int(100. * (batch_idx+1)*batch_size/n_data),
                        H_hx = batch_hischx.avg,
                        H_hy = batch_hischy.avg,
                )

        if (batch_idx + 1) % config_dict['log_batch_interval'] == 0:
            batch_log['batch_acc'].append(batch_acc.avg)
            batch_log['batch_loss'].append(batch_loss.avg)
            batch_log['batch_hsic_hx'].append(batch_hischx.avg)
            batch_log['batch_hsic_hy'].append(batch_hischy.avg)

        pbar.set_description(msg)

    pbar.close()
    return batch_log

In [None]:
# create neural network model

class SimpleNet(nn.Module):
    """Baseline CNN classifier: three strided 5x5 convs with max-pooling, then three dense layers.

    ``linear1`` expects 256 * 4 * 4 = 4096 flattened features, which a
    3-channel 227x227 input produces (227 -> 114 -> 57 -> 29 -> 15 -> 8 -> 4).
    The forward pass returns raw logits, intended for ``nn.CrossEntropyLoss``
    (which combines log-softmax and NLL loss).

    Args:
        num_classes: width of the final logits layer (default 100, as for CIFAR-100).
    """

    def __init__(self, num_classes=100):
        super(SimpleNet, self).__init__()

        # Conv feature extractor; attribute names kept for state_dict compatibility.
        self.layer1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5,
                                stride=2, padding=2)
        self.layer2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5,
                                stride=2, padding=2)
        self.layer3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5,
                                stride=2, padding=2)
        self.mxpool = nn.MaxPool2d(2, 2)
        # Padded pool (padding=1), used only after layer2.
        self.mxpool1 = nn.MaxPool2d(2, 2, padding=1)

        self.linear1 = nn.Linear(4096, 1024)
        self.linear2 = nn.Linear(1024, 1024)
        self.linear3 = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Map a (B, 3, H, W) image batch to (B, num_classes) logits."""
        x = F.relu(self.layer1(x))
        x = self.mxpool(x)
        x = F.relu(self.layer2(x))
        x = self.mxpool1(x)
        x = F.relu(self.layer3(x))
        x = self.mxpool(x)

        x = x.reshape(x.shape[0], -1)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        # No softmax here: nn.CrossEntropyLoss applies log-softmax internally.
        return self.linear3(x)


In [None]:
# create training pipeline