In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt

import time
import random
import os
import argparse
import enum
import logging
from datetime import datetime
import json
from models import block_regularizer, compute_layer_blocks_in, compute_layer_blocks_out
from spectral_utils import normalize_w

from three_d_shapes_ds import ThreeDShapes
from col_mnist import ColMNIST
from models import DisentangledLinear, BlockDropout

In [2]:
dateTimeObj = datetime.now()
timestampStr = dateTimeObj.strftime("%d.%m.%Y(%H:%M:%S)")

In [3]:
class SupportedDatasets(enum.Enum):
    THREEDSHAPES = 0,
    COL_MNIST = 1

In [4]:
#!export CUDA_VISIBLE_DEVICES=2,5
#!echo $CUDA_VISIBLE_DEVICES
torch.cuda.device_count()

2

In [5]:
parser = argparse.ArgumentParser()

# Common params
parser.add_argument("--dataset", type=str, choices=[ds.name for ds in SupportedDatasets], help="", 
                    default=SupportedDatasets.COL_MNIST.name)
parser.add_argument("--input_layer", type=str, help="Layer to disentangle", default='cl0')
parser.add_argument("--output_layer", type=str, help="Layer to disentangle", default='cl3')
parser.add_argument("--blocks", type=int, help="Number of blocks", default=2)
parser.add_argument("--layer_size", type=int, help="Size of the disentangled layer", default=4096)
parser.add_argument("--prune_by", type=int, help="How many neurons we want to remove", default=2048)
parser.add_argument("--data_dir", type=str, help="Directory to load data from", default='data')
parser.add_argument("--load_model", type=str, help="")
parser.add_argument("--save_dir", type=str, help="Directory to save models, logs and plots to", 
                    default=os.path.join("outputs", timestampStr))
parser.add_argument("--deterministic", type=bool, help="", default=False)
parser.add_argument("--gpus", type=str, help="", default=None)
parser.add_argument("--batch_size", type=int, help="", default=8)
parser.add_argument("--n_epochs", type=int, help="", default=30)
parser.add_argument("--dropout_p", type=float, help="Probability of block dropout", default=0.8)
parser.add_argument("--optimizer", type=str, help="Optimizer", choices=["SGD", "Adam"], default="SGD")

args = parser.parse_args(['--gpus', '2'])  # important to put '' in Jupyter otherwise it will complain

config = dict()
# Wrapping configuration into a dictionary
for arg in vars(args):
    config[arg] = getattr(args, arg)
    
if not os.path.exists(config["save_dir"]):
    os.makedirs(config["save_dir"])
    
#logging.basicConfig(filename=os.path.join(config["save_dir"], "run.log"), level=logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)

config['data_dir'] = os.path.basename(config['data_dir'])  # handle absolute and relative paths
config['save_dir'] = os.path.basename(config['save_dir'])
print("Saving and logging to {}".format(config['save_dir']))

if config['deterministic']:
    torch.manual_seed(123)
    torch.cuda.manual_seed(123)
    np.random.seed(123)
    random.seed(123)
    torch.backends.cudnn.enabled=False
    torch.backends.cudnn.deterministic=True

device = "cpu"
if len(config["gpus"]) > 0:
    os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"]=",".join(config["gpus"])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #change to actual args
    
logging.info("Config {}".format(json.dumps(config, indent=4)))
logging.info("Using {}".format(device))

INFO:root:Config {
    "dataset": "COL_MNIST",
    "input_layer": "cl0",
    "output_layer": "cl3",
    "blocks": 2,
    "layer_size": 4096,
    "prune_by": 2048,
    "data_dir": "data",
    "load_model": null,
    "save_dir": "28.09.2021(10:04:20)",
    "deterministic": false,
    "gpus": "2",
    "batch_size": 8,
    "n_epochs": 30,
    "dropout_p": 0.8,
    "optimizer": "SGD"
}
INFO:root:Using cuda


Saving and logging to 28.09.2021(10:04:20)


In [6]:
if config["dataset"] == SupportedDatasets.THREEDSHAPES.name:
    trainloader = torch.utils.data.DataLoader(
                                          ThreeDShapes(filename=os.path.join(config["data_dir"], "3dshapes.h5"),
                                                       transform=torchvision.transforms.Compose([
                                                           torchvision.transforms.ToPILImage(), 
                                                           torchvision.transforms.Resize((32, 32)),
                                                           torchvision.transforms.ToTensor()]), filtered = True),
                                          batch_size=config["batch_size"], shuffle=True)

    testloader = torch.utils.data.DataLoader(
                                          ThreeDShapes(filename=os.path.join(config["data_dir"], "3dshapes.h5"),
                                                       transform=torchvision.transforms.Compose([
                                                           torchvision.transforms.ToPILImage(), 
                                                           torchvision.transforms.Resize((32, 32)),
                                                           torchvision.transforms.ToTensor()]), filtered = True),
                                          batch_size=config["batch_size"], shuffle=True)

    n_classes = 16
    def target_vec_to_class(vec):
        labels = (vec[:, 0] == 0).int()*(2**3) + (vec[:, 1] == 0).int()*(2**2) + (vec[:, 2] == 0)*2 + (vec[:, 4] == 0)
        return labels.long()
    
elif config["dataset"] == SupportedDatasets.COL_MNIST.name:
    trainloader = torch.utils.data.DataLoader(
      ColMNIST(os.path.join(config["data_dir"], "mnist"), train=True, download=True,
                                 transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                 ])),
      batch_size=config["batch_size"], shuffle=True)

    testloader = torch.utils.data.DataLoader(
      ColMNIST(os.path.join(config["data_dir"], "mnist"), train=False, download=True,
                                 transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor()
                                 ])),
      batch_size=config["batch_size"], shuffle=True)
    
    n_classes = 30
    def target_vec_to_class(tpl):
        (target, dclr_idx, bclr_idx) = tpl
        target += bclr_idx*10
        return target.long()

else:
    logging.error("Dataset not supported")
    
for data, target in testloader:
    break
img_shape = data.shape[1:]

In [7]:
#config["load_model"] = "models/vgg16disen_colmnist_e29.pt"
if config["load_model"] is not None:
    vgg16 = torch.load(config["load_model"])
else:
    vgg16 = models.vgg16(pretrained=True)
vgg16.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [8]:
ncc = config["blocks"] #number of connected components
#TODO: add option of not specifying layer size in case we're loading

if config["input_layer"]=="cl3" and config["output_layer"]=="cl6":
    vgg16.classifier[3] = DisentangledLinear(vgg16.classifier[3].in_features, config["layer_size"]).to(device)
    vgg16.classifier[6] = DisentangledLinear(config["layer_size"], n_classes).to(device)
    vgg16.classifier[5] = BlockDropout(vgg16.classifier[6], ncc=ncc, p=config["dropout_p"], apply_to="in")
elif config["input_layer"]=="cl0" and config["output_layer"]=="cl3":
    # disentangle layers right after convolutions
    vgg16.classifier[0] = DisentangledLinear(vgg16.classifier[0].in_features, config["layer_size"]).to(device)
    vgg16.classifier[3] = DisentangledLinear(config["layer_size"], vgg16.classifier[3].out_features).to(device)
    vgg16.classifier[2] = BlockDropout(vgg16.classifier[3], ncc=ncc, p=config["dropout_p"], apply_to="in")
else:
    logging.error("Layer combination not supported")

for param in vgg16.features.parameters():
    param.requires_grad = True
    
logging.info(vgg16)

INFO:root:VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, 

In [9]:
if config["optimizer"] == "Adam":
    optimizer = optim.Adam(vgg16.classifier.parameters(), lr=0.001)
else:
    optimizer = optim.SGD(vgg16.classifier.parameters(), lr=0.001, momentum=0.9)
# loss function
criterion = nn.CrossEntropyLoss()

In [10]:
# validation function
def validate(model, test_dataloader):
    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    for int, data in enumerate(test_dataloader):
        data, target = data[0], data[1]
        target = target_vec_to_class(target)
        
        data = data.to(device)
        target = target.to(device)
        output = model(data)
        loss = criterion(output, target)
        
        val_running_loss += loss.item()
        _, preds = torch.max(output.data, 1)
        val_running_correct += (preds == target).sum().item()
    
    val_loss = val_running_loss/len(test_dataloader.dataset)
    val_accuracy = 100. * val_running_correct/len(test_dataloader.dataset)
    logging.info(f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}')
    
    return val_loss, val_accuracy

def neuron_wise_br(model, layer, blocks, examples, ncc):
    model.eval()
    br_wo_neuron = [np.inf]*layer.out_features
    for n in range(layer.out_features):
        if layer.out_mask is not None:
            mask = torch.clone(layer.out_mask)
        else:
            mask = torch.ones(layer.out_features, dtype=torch.bool)
        if mask[n] > 0:
            mask[n] = 0
            a_n = normalize_w(layer.weight[mask])
            _, s, _ = torch.svd(a_n)
            br_wo_neuron[n] = ncc - torch.sum(s[:ncc]).detach().cpu()
    return br_wo_neuron

def plot_blocked_weights(layer):
    plt.figure(figsize=(20, 7))
    blocks_in = compute_layer_blocks_in(layer, ncc)
    blocks_out = compute_layer_blocks_out(layer, ncc)
    plt.imshow(layer.weight[np.argsort(blocks_out)][:, np.argsort(blocks_in)].cpu().detach().numpy())
    plt.show()
    
def prune(model, layer_out, layer_in, ncc):
    blocks = compute_layer_blocks_in(layer_out, ncc)
    for batch_features in testloader:
        batch_features = batch_features[0]
        test_examples = batch_features.to(device)
        break
    re = neuron_wise_br(model, layer_in, blocks, test_examples, ncc)
    removal_mask = layer_in.out_mask
    if removal_mask is None:
        removal_mask = torch.ones(layer_in.out_features, dtype=torch.bool)
    removal_mask[np.argmin(re)] = 0
    layer_out.turn_input_neurons_off(removal_mask)
    layer_in.turn_output_neurons_off(removal_mask)
    logging.info("Pruned to {} neurons".format(layer_out.in_mask.sum().item()))
    
def fit(model, train_dataloader, prune_every_n_steps):
    model.train()
    train_running_loss = 0.0
    train_running_correct = 0
    for i, data in enumerate(train_dataloader):
        data, target = data[0], data[1]
        target = target_vec_to_class(target)
        
        #data = data.to(device)
        #target = target.to(device)
        optimizer.zero_grad()
        output = model(data)
        block_reg = block_regularizer(model.module.classifier[3], ncc)
        loss = criterion(output.cpu(), target)# + block_reg
        #loss = block_reg
        train_running_loss += loss.item()
        _, preds = torch.max(output.data, 1)
        train_running_correct += (preds.cpu() == target).sum().item()
        loss.backward()
        optimizer.step()
        if (i)%prune_every_n_steps == 0:
            logging.info("Block regularizer "+str(block_reg.item()))
            #plot_blocked_weights(vgg16.classifier[6])
            #plot_blocked_weights(vgg16.classifier[3])
            if config["input_layer"]=="cl3" and config["output_layer"]=="cl6":
                prune(model.module, model.module.classifier[6], model.module.classifier[3], ncc)
            if config["input_layer"]=="cl0" and config["output_layer"]=="cl3":
                prune(model.module, model.module.classifier[3], model.module.classifier[0], ncc)
            
    train_loss = train_running_loss/len(train_dataloader.dataset)
    train_accuracy = 100. * train_running_correct/len(train_dataloader.dataset)
    logging.info(f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}')
    
    return train_loss, train_accuracy, block_reg.item()

In [11]:
n_epochs = config["n_epochs"]
total_batches = len(trainloader)*n_epochs
layer_size_reduction = config["prune_by"]
prune_every_n_steps = int(round(total_batches/(layer_size_reduction)))

In [12]:
vgg16_parallel = nn.DataParallel(vgg16, device_ids = [0,1])

In [13]:
train_loss , train_accuracy = [], []
val_loss , val_accuracy, br = [], [], []

start = time.time()
for epoch in range(n_epochs):
    start_e = time.time()
    logging.info("Epoch {}".format(epoch))
    train_epoch_loss, train_epoch_accuracy, block_reg = fit(vgg16_parallel, trainloader, prune_every_n_steps)
    val_epoch_loss, val_epoch_accuracy = validate(vgg16_parallel, testloader)
    train_loss.append(train_epoch_loss)
    train_accuracy.append(train_epoch_accuracy)
    br.append(block_reg)
    val_loss.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)
    torch.save(vgg16_parallel.module, os.path.join(config["save_dir"], "model.pt"))
    end_e = time.time()
    logging.info('Epoch {} took {} minutes '.format(epoch+1, (end-start)/60))
    
end = time.time()
logging.info('{} minutes in total'.format((end-start)/60))

INFO:root:Epoch 0
INFO:root:Block regularizer 0.9821493625640869


KeyboardInterrupt: 