In [38]:
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from resnet_cifar import resnet50
import torch
import argparse
import pickle
import os
import random
from PIL import Image
from main import build_dt
from torch.autograd import Variable

# Load Data

In [2]:
# Directory handed to CIFAR10(root=...) below, relative to this notebook.
DATA_PATH = "../../data/cifar"
# Mini-batch size used by both the train and the test DataLoader.
BATCH_SIZE = 64

In [3]:
# Per-channel statistics handed to Normalize (first list: mean, second: std).
_CHANNEL_MEAN = [0.4914, 0.4822, 0.4465]
_CHANNEL_STD = [0.2471, 0.2435, 0.2616]

# Preprocessing applied to every image: convert to a tensor, then normalize
# each channel with the statistics above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

In [89]:
# Training split. It gets its own name because a later cell rebinds `dataset`,
# which makes the notebook's state depend on cell execution order.
train_dataset = CIFAR10(root=DATA_PATH, train=True, transform=transform)
dataset = train_dataset  # backward-compatible alias; prefer `train_dataset`

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    # Explicitly unshuffled: the per-batch label cache built later is indexed
    # by batch position, so the batch order must be identical on every pass.
    shuffle=False,
    num_workers=1,
    drop_last=True,
    pin_memory=True,
)

In [90]:
# Held-out split, under its own name so it cannot be confused with the
# training data that an earlier cell also binds to `dataset`.
test_dataset = CIFAR10(root=DATA_PATH, train=False, transform=transform)
dataset = test_dataset  # backward-compatible alias; prefer `test_dataset`

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
    # Keep the final partial batch: dropping it would silently exclude
    # samples from every evaluation metric.
    drop_last=False,
    pin_memory=True,
)

# Load Model

In [8]:
# Black-box classifier: resnet50 from the local resnet_cifar module, built
# with pretrained=True.
model = resnet50(pretrained=True)
# Put the network in evaluation mode. As the last expression of the cell this
# also echoes the module structure in the cell output.
model.eval() # for evaluation

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

## Test accuracy

In [12]:
import utils
import matplotlib.pyplot as plt

from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

def test_loop(dataloader, model):
    size = len(dataloader.dataset)
    print("Taille ", size)
    num_batches = len(dataloader)
    print("Batchs ", num_batches)
    correct = 0
   
    
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            
            correct+= (pred.argmax(1)==y).type(torch.float).sum().item()
            print(100*correct/size)
    correct/= size
    return 100*correct
# Accuracy of the black-box model on `test_dataloader`.
# NOTE(review): the output saved with this cell reports 50000 samples, the
# size of the training split, so `test_dataloader` was presumably bound to the
# training data when this last ran. Re-run on a fresh kernel to confirm.
test_loop(test_dataloader, model)

Taille  50000
Batchs  781
0.128
0.256
0.384
0.512
0.64
0.766
0.894
1.022
1.15
1.278
1.406
1.532
1.66
1.786
1.914
2.042
2.17
2.298
2.424
2.55
2.678
2.806
2.934
3.062
3.19
3.318
3.446
3.574
3.702
3.83
3.958
4.086
4.214
4.342
4.47
4.598
4.726
4.854
4.982
5.11
5.238
5.366
5.494
5.62
5.748
5.876
6.004
6.132
6.26
6.386
6.514
6.642
6.77
6.898
7.026
7.154
7.282
7.41
7.538
7.666
7.792
7.92
8.046
8.174
8.302
8.43
8.558
8.686
8.814
8.942
9.068
9.194
9.322
9.45
9.578
9.706
9.834
9.96
10.088
10.212
10.34
10.468
10.596
10.724
10.852
10.98
11.108
11.236
11.364
11.492
11.62
11.746
11.874
12.002
12.128
12.256
12.384
12.512
12.64
12.766
12.894
13.022
13.148
13.276
13.404
13.532
13.66
13.788
13.914
14.042
14.17
14.298
14.426
14.554
14.682
14.81
14.938
15.066
15.194
15.322
15.45
15.578
15.706
15.834
15.958
16.082
16.21
16.336
16.464
16.59
16.718
16.846
16.974
17.102
17.23
17.358
17.486
17.614
17.742
17.87
17.998
18.126
18.254
18.382
18.51
18.638
18.766
18.894
19.022
19.15
19.278
19.406
19.534
19.662
19.79

99.67399999999999

# Build and Train DT model

In [156]:
# Pickle file used as the checkpoint of the trained decision-tree model.
MODEL_PATH = "dtmodels/cifar/bestmodel8"

# The settings below are not referenced by any other cell visible in this
# notebook; they are kept in case other code depends on them.
TRAIN_DATA_PATH = "loaders/"
DATA_LOADER_PATH = "loaders/cifar/dataloader.p"
CROP = False

In [91]:
# Predictions of the black-box `model`, one tensor per batch, in loader order.
# Reset every time this cell runs so re-running it does not duplicate entries.
model_labels = []


def modify_loader(ds, loader):
    """Append the black-box model's predicted class for each batch of `loader`.

    The predictions (argmax over dimension 1 of ``model(X)``) are appended to
    the module-level ``model_labels`` list, one tensor per batch, in the order
    the loader yields them. No gradients are tracked.

    Args:
        ds: unused; kept so existing two-argument calls keep working.
        loader: iterable of ``(inputs, labels)`` batches; the labels are
            ignored.
    """
    with torch.no_grad():
        for X, _ in loader:
            model_labels.append(model(X).argmax(1))


# The first argument is ignored by modify_loader, so the loader's own dataset
# is passed; this keeps the call independent of any other notebook variable.
modify_loader(train_dataloader.dataset, train_dataloader)

### Model parameters

In [None]:
# Training settings.
# Every help string interpolates %(default)s so the documented default can
# never drift away from the value that is really used.
parser = argparse.ArgumentParser(description='DT model')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--input-dim', type=int, default=32 * 32 * 3, metavar='N',
                    help='input dimension size (default: %(default)s)')
parser.add_argument('--output-dim', type=int, default=10, metavar='N',
                    help='output dimension size (default: %(default)s)')
parser.add_argument('--max-depth', type=int, default=8, metavar='N',
                    help='maximum depth of tree (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: %(default)s)')
parser.add_argument('--lmbda', type=float, default=0.01, metavar='LR',
                    help='temperature rate (default: %(default)s)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: %(default)s)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: %(default)s)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')

# Parse an empty argument list: inside a notebook sys.argv belongs to the
# kernel, so only the defaults above apply.
args = parser.parse_args([])

# CUDA is forced off for this notebook; the commented line is the automatic
# detection it replaces.
args.cuda = False
# args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)


### Model training

In [None]:
def train_(tree, train_loader, epoch, model_labels):
    tree.train()
    tree.define_extras(tree.args.batch_size)
    for batch_idx, (data, _) in enumerate(train_loader):
        correct = 0
        target = model_labels[batch_idx]
        #data = torch.cat(list(map(img_to_tensor, data)))
        if tree.args.cuda:
            data, target = data.cuda(), target.cuda()
        # data = data.view(self.args.batch_size,-1)
        target = Variable(target)
        target_ = target.view(-1, 1)
        batch_size = target_.size()[0]
        data = data.view(batch_size, -1)
        ##convert int target to one-hot vector
        data = Variable(data)
        
        if not batch_size == tree.args.batch_size:  # because we have to initialize parameters for batch_size, tensor not matches with batch size cannot be trained
            tree.define_extras(batch_size)
        tree.target_onehot.data.zero_()

        tree.target_onehot.scatter_(1, target_, 1.)
        tree.optimizer.zero_grad()

        loss, output = tree.cal_loss(data, tree.target_onehot)
        # loss.backward(retain_variables=True)
        loss.backward()
        tree.optimizer.step()
        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()
        accuracy = 100. * correct / len(data)

        if batch_idx % tree.args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accuracy: {}/{} ({:.4f}%)'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.data.item(),
                correct, len(data),
                accuracy))

In [None]:
# Reuse a previously trained tree when a checkpoint exists; otherwise train
# one from scratch and save it.
# NOTE: pickle.load can execute arbitrary code embedded in the file. Only load
# checkpoints that this notebook (or another trusted source) produced.
if os.path.isfile(MODEL_PATH):
    with open(MODEL_PATH, "rb") as checkpoint_file:
        model_tree = pickle.load(checkpoint_file)
else:
    # Create the checkpoint directory before training, so a missing folder
    # cannot throw away a finished training run at save time.
    checkpoint_dir = os.path.dirname(MODEL_PATH)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    model_tree = build_dt(args)
    for epoch in range(1, args.epochs + 1):
        train_(model_tree, train_dataloader, epoch, model_labels)

    with open(MODEL_PATH, "wb") as checkpoint_file:
        pickle.dump(model_tree, checkpoint_file)

### Model fidelity

In [140]:
def test_(tree, test_loader):
    tree.eval()
    tree.define_extras(tree.args.batch_size)
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        #data = torch.cat(list(map(partial(utils.img_to_tensor, crop=crop), data)))

        if tree.args.cuda:
            data, target = data.cuda(), target.cuda()
        target = Variable(target)
        target_ = target.view(-1, 1)
        batch_size = target_.size()[0]
        data = data.view(batch_size, -1)
        ##convert int target to one-hot vector
        data = Variable(data)
        if not batch_size == tree.args.batch_size:  # because we have to initialize parameters for batch_size, tensor not matches with batch size cannot be trained
            tree.define_extras(batch_size)
        tree.target_onehot.data.zero_()
        tree.target_onehot.scatter_(1, target_, 1.)
        _, output = tree.cal_loss(data, tree.target_onehot)
        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()
    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.4f}%)\n'.format(
        correct, len(test_loader.dataset),
        accuracy))

In [174]:
# Scores the tree on `train_dataloader` (the training split), even though the
# line printed by test_ is labelled "Test set".
# NOTE(review): test_ compares the tree with the labels yielded by the loader,
# not with the ResNet predictions cached in model_labels. Confirm which of the
# two this "fidelity" figure is meant to report.
test_(model_tree, train_dataloader)


Test set: Accuracy: 20657/50000 (41.3140%)



# Model visualization

In [185]:
import graphviz
import queue
import matplotlib.pyplot as plt
import matplotlib
def buildTree(tree):
    """Build a graphviz ``Digraph`` picturing every node of `tree`.

    Inner nodes are drawn with the weight image that ``weights_viz`` writes
    to ``weights_img/weights{index}.png``; leaves show the first row of their
    ``forward()`` output as a text label on a green box. Each inner node gets
    an edge to its right child and one to its left child.

    Args:
        tree: model whose ``root`` node is walked by ``gather_nodes``; nodes
            expose ``leaf``, and either ``fc.weight``/``left``/``right``
            (inner nodes) or ``forward()`` (leaves).

    Returns:
        graphviz.Digraph: the assembled graph, not yet rendered.
    """
    # "rankdir" is the Graphviz attribute for layout direction; the misspelt
    # key used previously was not a layout attribute.
    Tree = graphviz.Digraph(format='png', graph_attr={"rankdir": "LR"},
                            node_attr={'shape': "box"})
    all_nodes = gather_nodes(tree)

    # First pass: declare every node (this also writes the weight images).
    for node in all_nodes:
        if not node.leaf:
            weights_viz(tree, node.fc.weight, node.index)
            Tree.node(str(node.index), image="weights_img/weights{}.png".format(node.index),
                      label="", style="rounded,filled")
        else:
            Tree.node(str(node.index),
                      label=str(node.forward().detach().numpy()[0]),
                      fillcolor="green", style="rounded,filled")

    # Second pass: connect inner nodes to their children (right, then left).
    for node in all_nodes:
        if not node.leaf:
            Tree.edge(str(node.index), str(node.right.index))
            Tree.edge(str(node.index), str(node.left.index))
    return Tree

def gather_nodes(tree):
    """Return every node of `tree` in breadth-first order, numbering them.

    Each visited node gets an ``index`` attribute equal to its position in
    the returned list (root is 0). For inner nodes (``leaf`` is falsy) the
    left child is visited before the right child.

    Args:
        tree: object whose ``root`` attribute is the starting node.

    Returns:
        list: the nodes in visiting order.
    """
    # The result list doubles as the FIFO: `cursor` marks the next node to
    # expand, and children are appended behind it.
    ordered = [tree.root]
    cursor = 0
    while cursor < len(ordered):
        current = ordered[cursor]
        current.index = cursor
        if not current.leaf:
            ordered.append(current.left)
            ordered.append(current.right)
        cursor += 1
    return ordered

In [238]:
import cv2
import numpy as np
def weights_viz(tree, weights, index, side=32):
    """Save a grayscale picture of one node's weights as a PNG.

    The weights are shifted so their minimum is 0 and scaled so their maximum
    is 255. The first ``side * side`` values of the first row are then drawn
    as a ``side`` x ``side`` image and written to
    ``weights_img/weights{index}.png``. Only this first plane is drawn.

    NOTE(review): presumably this plane is the first colour channel of a
    flattened 3 x 32 x 32 input. Confirm against the tree's input layout.

    Args:
        tree: unused; kept for backward compatibility with existing callers.
        weights: 2-D tensor whose first row holds at least ``side * side``
            values.
        index: node index, used in the output file name.
        side: edge length of the drawn image (default 32).
    """
    values = weights.detach()
    shifted = values - torch.min(values)
    peak = torch.max(shifted)
    # Constant weights would make the scale factor 255 / 0; leave them at 0.
    if peak > 0:
        shifted = shifted * (255 / peak)
    first_row = shifted.cpu().numpy()[0]
    first_plane = first_row[:side * side].reshape(side, side)

    os.makedirs("weights_img", exist_ok=True)

    # A fresh figure per call: drawing on pyplot's shared current figure
    # stacks one more image onto the same axes every time this is called.
    figure, axes = plt.subplots(figsize=(32, 32))
    axes.imshow(first_plane, cmap='gray', vmin=0, vmax=255)
    figure.savefig("weights_img/weights{}.png".format(index))
    # Release the figure; one is created per inner node of the tree.
    plt.close(figure)

In [None]:
# Assemble the graphviz graph of the trained tree. buildTree also writes one
# weight image per inner node, which makes this cell slow.
# NOTE(review): the output saved with this cell is a KeyboardInterrupt, so the
# last recorded run did not finish.
tree = buildTree(model_tree)

KeyboardInterrupt: 

In [None]:
# Render the graph built above under the file name 'dt_viz'; view=True asks
# graphviz to open the result in the system viewer.
tree.render('dt_viz', view=True)  