In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import models

In [2]:
# Automatically re-import edited modules (e.g. the local three_d_shapes_ds)
# before each cell executes, so changes to them are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [3]:
# Pick the first CUDA GPU when one is visible; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [4]:
# Resize to 224x224, convert to a tensor, then normalise every channel as
# (x - 0.5) / 0.5.
# NOTE(review): `transform` is not referenced by the 3dshapes loaders in this
# notebook (they build their own pipeline) -- confirm it is still needed.
_resize_224 = transforms.Resize((224, 224))
_channel_stats = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    _resize_224,
    transforms.ToTensor(),
    transforms.Normalize(_channel_stats, _channel_stats),
])


In [5]:
import argparse
import os

# Command-line style options; in the notebook they are parsed from an empty
# argv so every option takes its default value.
parser = argparse.ArgumentParser()
parser.add_argument("--input_layer", type=str, help="Layer to disentangle", default='cl0')
parser.add_argument("--output_layer", type=str, help="Layer to disentangle", default='cl3')
parser.add_argument("--blocks", type=int, help="Number of blocks", default=4)
parser.add_argument("--layer_size", type=int, help="Size of the disentangled layer", default=400)
parser.add_argument("--prune_by", type=int, help="How many neurons we want to remove", default=200)
parser.add_argument("--data_dir", type=str, help="Directory to load data from", default='data')
parser.add_argument("--load_model", type=str, help="")
parser.add_argument("--deterministic", dest="deterministic", action="store_true")
parser.add_argument("--no_dt_labels", dest="dt_labels", action="store_false")
parser.add_argument("--test_dt_tree", dest="test_dt_tree", action="store_true")
parser.add_argument("--homebrew_model", dest="homebrew_model", action="store_true")
parser.add_argument("--not_pretrained", dest="pretrained", action="store_false")
parser.add_argument("--filtered", help="Filter 3dshapes dataset (otherwise the decision tree labels are used)",
                    dest="filtered", action="store_true")
parser.add_argument("--gpus", type=str, help="", default=None)
parser.add_argument("--batch_size", type=int, help="", default=32)
parser.add_argument("--n_epochs", type=int, help="", default=30)
parser.add_argument("--img_size", type=int, help="", default=32)
parser.add_argument("--dropout_p", type=float, help="Probability of block dropout", default=0.5)
parser.add_argument("--optimizer", type=str, help="Optimizer", choices=["SGD", "Adam"], default="SGD")
parser.add_argument("--br_coef", type=float, help="Block regularizer coefficient", default=0)

# set_defaults only affects values produced by a *later* parse_args call, so it
# must come before parsing. These values match the implicit store_true /
# store_false defaults, so the parsed configuration is the same either way.
parser.set_defaults(filtered=False, deterministic=False, dt_labels=True, homebrew_model=False, pretrained=True, test_dt_tree=False)

# Parse an explicit empty argv: inside Jupyter, sys.argv holds the kernel's own
# arguments, which argparse would reject.
args = parser.parse_args([])

# Plain-dict copy of the parsed options, used by the rest of the notebook.
config = dict(vars(args))
    

In [6]:
#%pdb on

In [7]:

from three_d_shapes_ds import ThreeDShapes


def _make_3dshapes_loader(train, shuffle):
    """Build a DataLoader over one split of the 3dshapes HDF5 file.

    Both splits share the same preprocessing (ToPILImage -> resize to
    config["img_size"] squared -> ToTensor) and the same label options from
    `config`; only the split flag and shuffling differ.

    NOTE(review): no ImageNet mean/std normalisation is applied although the
    model used later is an ImageNet-pretrained VGG16 -- confirm this is intended.
    """
    dataset = ThreeDShapes(filename=os.path.join(config["data_dir"], "3dshapes.h5"),
                           transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToPILImage(),
                               torchvision.transforms.Resize((config["img_size"], config["img_size"])),
                               torchvision.transforms.ToTensor()]),
                           train=train,
                           filtered=config["filtered"],
                           dt_labels=config["dt_labels"],
                           test_dt_labels=config["test_dt_tree"])
    return torch.utils.data.DataLoader(dataset, batch_size=config["batch_size"], shuffle=shuffle)


trainloader = _make_3dshapes_loader(train=True, shuffle=True)
# Evaluation results do not depend on sample order, so keep the test pass deterministic.
testloader = _make_3dshapes_loader(train=False, shuffle=False)

# Choose how raw dataset targets are turned into integer class ids for
# nn.CrossEntropyLoss, and how many classes that yields.
if config["filtered"]:
    # The filtered dataset supplies raw latent vectors, not decision-tree labels.
    # Raise instead of assert so the check survives `python -O`.
    if config["dt_labels"]:
        raise ValueError("filtered=True requires dt_labels=False (pass --no_dt_labels)")
    n_classes = 16

    def target_vec_to_class(vec):
        """Pack four "latent == 0" tests into a 4-bit class id in [0, 16).

        Bit 3 <- vec[:, 0] == 0, bit 2 <- vec[:, 1] == 0,
        bit 1 <- vec[:, 2] == 0, bit 0 <- vec[:, 4] == 0.
        Assumes `vec` is (batch, n_latents) -- TODO confirm against ThreeDShapes.
        Returns an int64 tensor of shape (batch,).
        """
        # Cast every comparison to int64 up front so the arithmetic has a single dtype.
        return ((vec[:, 0] == 0).long() * 8
                + (vec[:, 1] == 0).long() * 4
                + (vec[:, 2] == 0).long() * 2
                + (vec[:, 4] == 0).long())
else:  # decision tree labels
    if not config["dt_labels"]:
        raise ValueError("filtered=False requires dt_labels=True (decision-tree labels)")
    # NOTE(review): assumes the decision-tree labels lie in [0, 8) -- confirm.
    n_classes = 8

    def target_vec_to_class(tpl):
        """Return the label half of a (latents, labels) target pair as int64."""
        latents, labels = tpl
        return labels.long()

[0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, 136, 140, 144, 148, 152, 156, 160, 164, 168, 172, 176, 180, 184, 188, 192, 196, 200, 204, 208, 212, 216, 220, 224, 228, 232, 236, 240, 244, 248, 252, 256, 260, 264, 268, 272, 276, 280, 284, 288, 292, 296, 300, 304, 308, 312, 316, 320, 324, 328, 332, 336, 340, 344, 348, 352, 356, 360, 364, 368, 372, 376, 380, 384, 388, 392, 396]


KeyboardInterrupt: 

In [None]:
# Load VGG16, honouring --not_pretrained (config["pretrained"], True by default)
# instead of always loading the pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=`; migrate once the torchvision version is pinned.
vgg16 = models.vgg16(pretrained=config["pretrained"])
vgg16.to(device)

In [None]:
# Swap the last classifier layer (classifier[6]) for a fresh Linear layer whose
# output size matches our label set; its input width is kept unchanged.
head_in_features = vgg16.classifier[6].in_features
vgg16.classifier[6] = nn.Linear(head_in_features, n_classes)

# Freeze the convolutional feature extractor so only the classifier gets gradients.
vgg16.features.requires_grad_(False)

vgg16 = vgg16.to(device)
print(vgg16)

In [None]:
# optimizer
# Optimise only the classifier parameters (the feature extractor is frozen in
# an earlier cell). Honour the --optimizer choice; "SGD" is the default.
if config["optimizer"] == "Adam":
    optimizer = optim.Adam(vgg16.classifier.parameters(), lr=0.001)
else:
    optimizer = optim.SGD(vgg16.classifier.parameters(), lr=0.001, momentum=0.9)
# loss function
criterion = nn.CrossEntropyLoss()

In [None]:
# Smoke test: pull every training batch once (any loading error raises here),
# showing the batch index as a counter overwritten in place.
for batch_idx, _batch in enumerate(trainloader):
    print(batch_idx, end='\r')
    

In [None]:

# Smoke test: pull every test batch once (any loading error raises here),
# showing the batch index as a counter overwritten in place.
for batch_idx, _batch in enumerate(testloader):
    print(batch_idx, end='\r')

In [None]:
# validation function
def validate(model, test_dataloader, loss_fn=None, to_class=None):
    """Evaluate `model` on every batch of `test_dataloader` without updating it.

    Args:
        model: network to evaluate; switched to eval mode. Must have at least
            one parameter (its device decides where batches are sent).
        test_dataloader: yields batches whose items 0 and 1 are the inputs and
            the raw targets; must expose `.dataset` (used for the sample count).
        loss_fn: loss callable; defaults to the notebook-level `criterion`.
        to_class: maps raw targets to class-index tensors; defaults to the
            notebook-level `target_vec_to_class`.

    Returns:
        (val_loss, val_accuracy). val_loss is the sum of per-batch mean losses
        divided by the number of samples -- the same convention as fit(), so
        the two loss curves stay comparable (NOTE(review): this is roughly
        mean loss / batch_size, not a per-sample mean; change both together).
        val_accuracy is the percentage of samples whose argmax prediction
        matches the target.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    to_class = target_vec_to_class if to_class is None else to_class
    # Send batches to wherever the model lives; no hard-coded .cuda(), so
    # CPU-only runs work too.
    model_device = next(model.parameters()).device

    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    # Evaluation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for batch in test_dataloader:
            inputs = batch[0].to(model_device)
            target = to_class(batch[1]).to(model_device)
            output = model(inputs)
            loss = loss_fn(output, target)

            val_running_loss += loss.item()
            preds = output.argmax(dim=1)
            val_running_correct += (preds == target).sum().item()

    val_loss = val_running_loss / len(test_dataloader.dataset)
    val_accuracy = 100. * val_running_correct / len(test_dataloader.dataset)
    print(f'Val Acc: {val_accuracy:.2f}')

    return val_loss, val_accuracy

In [None]:
# training function
def fit(model, train_dataloader, opt=None, loss_fn=None, to_class=None):
    """Train `model` for one epoch over `train_dataloader`, updating it in place.

    Args:
        model: network to train; switched to train mode. Must have at least
            one parameter (its device decides where batches are sent).
        train_dataloader: yields batches whose items 0 and 1 are the inputs and
            the raw targets; must expose `.dataset` (used for the sample count).
        opt: optimizer to step; defaults to the notebook-level `optimizer`.
        loss_fn: loss callable; defaults to the notebook-level `criterion`.
        to_class: maps raw targets to class-index tensors; defaults to the
            notebook-level `target_vec_to_class`.

    Returns:
        (train_loss, train_accuracy). train_loss is the sum of per-batch mean
        losses divided by the number of samples -- the same convention as
        validate(), so the two loss curves stay comparable (NOTE(review): this
        is roughly mean loss / batch_size; change both together).
        train_accuracy is the percentage of samples whose argmax prediction
        (from the forward pass before each update) matches the target.
    """
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    to_class = target_vec_to_class if to_class is None else to_class
    # Send batches to wherever the model lives; no hard-coded .cuda(), so
    # CPU-only runs work too.
    model_device = next(model.parameters()).device

    model.train()
    train_running_loss = 0.0
    train_running_correct = 0
    for batch in train_dataloader:
        inputs = batch[0].to(model_device)
        target = to_class(batch[1]).to(model_device)
        opt.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, target)
        train_running_loss += loss.item()
        preds = output.detach().argmax(dim=1)
        train_running_correct += (preds == target).sum().item()
        loss.backward()
        opt.step()
    train_loss = train_running_loss / len(train_dataloader.dataset)
    train_accuracy = 100. * train_running_correct / len(train_dataloader.dataset)
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}')

    return train_loss, train_accuracy

In [None]:
# Train for a fixed number of epochs, collecting per-epoch metrics for plotting.
# NOTE(review): config["n_epochs"] (default 30) is not used here.
epochs_to_run = 10

train_loss, train_accuracy = [], []
val_loss, val_accuracy = [], []

start = time.time()
for epoch in range(epochs_to_run):
    print(epoch)
    epoch_train_stats = fit(vgg16, trainloader)
    epoch_val_stats = validate(vgg16, testloader)
    train_loss.append(epoch_train_stats[0])
    train_accuracy.append(epoch_train_stats[1])
    val_loss.append(epoch_val_stats[0])
    val_accuracy.append(epoch_val_stats[1])
end = time.time()

elapsed_minutes = (end - start) / 60
print(elapsed_minutes, 'minutes')

In [None]:
# Per-epoch train vs. validation accuracy (percent); also saved to accuracy.png.
plt.figure(figsize=(10, 7))
plt.plot(train_accuracy, color='green', label='train accuracy')
plt.plot(val_accuracy, color='blue', label='validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy (%)')
plt.legend()
plt.savefig('accuracy.png')
plt.show()

In [None]:
# Per-epoch train vs. validation loss; also saved to loss.png.
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(val_loss, color='red', label='validation loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.savefig('loss.png')
plt.show()