In [1]:
#%env CUDA_VISIBLE_DEVICES=2
import setGPU
import numpy as np
import matplotlib.pyplot as plt

import torch.cuda
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import models #,transforms, utils,

#import math
import time
import shutil

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

setGPU: Setting GPU to: 2


In [2]:
from JupyterLoader import NotebookFinder
import sys
sys.meta_path.append(NotebookFinder())
from dataloaders.GreyColor_dataloader import ImageNetGenerator

data_gen = ImageNetGenerator(data_folder = '/home/frati/Grasping/ImageNet/')

importing Jupyter notebook from /home/frati/new_Grasping/code/dataloaders/GreyColor_dataloader.ipynb


In [3]:
from utilities import AverageMeter, accuracy, get_trainable_parameters, varify

importing Jupyter notebook from utilities.ipynb


In [4]:
#from models.colorNet import ColorNet

In [5]:
class ColorNet(nn.Module):
    def __init__(self, hooks=False, pretrained=None):
        super(ColorNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 1)
        self.conv2 = nn.Conv2d(10, 3, 1)
        self.conv3 = nn.Conv2d(6, 3, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(3,affine=True)
        self.bn1 = nn.BatchNorm2d(10,affine=True)

        if pretrained is None:
            self.init_weights()
        else:
            state = torch.load(pretrained, map_location=lambda storage, loc: storage)
            self.load_state_dict(state['state_dict'])
        
        def printgradnorm(self, grad_input, grad_output):
            print('Inside ' + self.__class__.__name__ + ' backward')
            print('{} -> {}'.format(grad_input[0].size(),grad_output[0].size()))
            print('grad_in norm: {}'.format(grad_input[0].data.norm()))
            print('grad_out norm: {}'.format(grad_output[0].data.norm()))
                  
        if hooks:
            self.conv1.register_backward_hook(printgradnorm)
            self.conv2.register_backward_hook(printgradnorm)
     
    def forward(self, x):
        
        residual = torch.cat([x,x,x],dim=1)
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x) 
        x = torch.cat([x,residual],dim=1)
        x = self.conv3(x)
        
        return x
    
    def init_weights(self):
        """Initialize the weights."""
        self.conv1.weight.data.normal_(0, 0.02)
        self.conv1.bias.data.fill_(0)
        self.conv2.weight.data.normal_(0, 0.02)
        self.conv2.bias.data.fill_(0)

In [6]:
t = ColorNet()
fake_batch = Variable(torch.rand(4,1,224,224))
t(fake_batch).shape

torch.Size([4, 3, 224, 224])

In [7]:
def distill(data_loaders, model, reference_model, criterion, optimizer, epochs):
    since = time.time()
    def loss_fn_kd(outputs, labels, teacher_outputs, temperature=3., balance = 0.4):
        """
        Compute the knowledge-distillation (KD) loss given outputs, labels.
        "Hyperparameters": temperature and alpha
        """
        T = temperature
        alpha = balance
        KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                                 F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + F.cross_entropy(outputs, labels) * (1. - alpha)

        return KD_loss
    
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    num_epochs = epochs
    history = []
    
    best_model_wts = model.state_dict()
    best_acc = 0.0
    
    reference_model.eval()
    
    dataset_sizes = {key:len(val.dataset) for key,val in data_loaders.items()}
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)
        epoch_time = time.time()
        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # Set model to training mode
                volatile_inputs = False
            else:
                model.eval()  # Set model to evaluate mode
                volatile_inputs = True
                
            losses.reset()
            top1.reset()
            top5.reset()
            
            batches = len(data_loaders[phase])
            # Iterate over data.
            for idx,(data) in enumerate(data_loaders[phase]):
                # prepare the inputs
                color_images, grey_images, labels = data
                grey_ims = Variable(grey_images,requires_grad = True).cuda()
                color_ims = Variable(color_images,volatile = True).cuda()
                targets = Variable(labels,volatile = True).cuda()
                
                # zero the parameter gradients
                optimizer.zero_grad() 

                # forward
                reference_outputs = reference_model(color_ims)
                _, reference_preds = torch.max(reference_outputs,1)
                
                outputs = model(grey_ims)
                _,preds = torch.max(outputs,1)
                
                if phase == 'train':
                    #loss = criterion(outputs, reference_preds)
                    loss = loss_fn_kd(outputs,targets,reference_outputs)

                    # measure accuracy and record loss
                    prec1, prec5 = accuracy(outputs, reference_preds, topk=(1, 5))
                    losses.update(loss.data[0], targets.size(0))
                    top1.update(float(prec1[0]), targets.size(0))
                    top5.update(float(prec5[0]), targets.size(0))
                else:
                    loss = criterion(outputs, targets)
                    #loss = loss_fn_kd(outputs,targets,reference_outputs)

                    # measure accuracy and record loss
                    prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
                    losses.update(loss.data[0], targets.size(0))
                    top1.update(float(prec1[0]), targets.size(0))
                    top5.update(float(prec5[0]), targets.size(0))
                
                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                else:
                    model.eval()
                    color = model[0](grey_ims)
                    original = data_gen.plot_color(color_ims[0].data.cpu())
                    colored = data_gen.plot_color(color[0].data.cpu())
                    label = data_gen.translate(labels[0])
                    history.append((original,colored,label))
                    model.train()
                
                # progress
                print("\r                                                                  ",end="")
                #print("\r{}: {}/{} - acc: {:.2f} - loss: {:.2f}".format(phase,idx,batches,correct/num_tags,loss.data[0]),end="")
                print("\r{}: {}/{}".format(phase,idx,batches),end="")
                

            print()
            print('{}: '
              'Loss {loss.avg:.4f}\t'
              'Prec@1 {top1.avg:.3f}\t'
              'Prec@5 {top5.avg:.3f}'.format(phase, loss=losses, top1=top1, top5=top5))
            print('Epoch time: {:.3f}'.format(time.time() - epoch_time))
            print()
            # deep copy the model
            epoch_acc = top1.avg
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = model.state_dict()

        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format( time_elapsed // 60, time_elapsed % 60))
    print('Best validation acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model,history

In [6]:
def test(val_loader, model, criterion, color=False):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (color_images, grey_images, target) in enumerate(val_loader):
        target = target.cuda(async=True)
        if color == True:
            input_var = torch.autograd.Variable(color_images, volatile=True).cuda()
        else:
            input_var = torch.autograd.Variable(grey_images, volatile=True).cuda()
            input_var = torch.cat([input_var,input_var,input_var],dim=1)
            
        target_var = torch.autograd.Variable(target, volatile=True).cuda()

        # compute output
        output = model(input_var)
        loss = criterion(output, target_var)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.data[0], target.size(0))
        top1.update(prec1[0], target.size(0))
        top5.update(prec5[0], target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        
        print('Test: [{0}/{1}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
              'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
               i, len(val_loader), batch_time=batch_time, loss=losses,
               top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(top1=top1, top5=top5))

    return top1.avg

In [8]:
# Hyper Parameters
num_epochs = 5
batch_size = 32
#learning_rate = 0.001
learning_rate = 0.01
best_prec1 = 0

In [9]:
color_net = ColorNet()
color_net.cuda()

ColorNet(
  (conv1): Conv2d(1, 10, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(10, 3, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(6, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True)
  (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True)
)

In [10]:
vgg16 = models.vgg16(pretrained=True)
vgg16.cuda()
# Freeze model
for param in vgg16.parameters():
    param.requires_grad = False

In [11]:
get_trainable_parameters(vgg16)

Trainable parameters:	           0
Frozen parameters:	   138357544


[]

In [12]:
model = nn.Sequential(color_net, vgg16)

In [13]:
color_parameters = get_trainable_parameters(model)

Trainable parameters:	         244
Frozen parameters:	   138357544


In [14]:
# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(color_parameters, lr=learning_rate)

In [15]:
loaders = data_gen.get_loaders(batch_size=batch_size,shuffle=True,num_workers=8)

In [16]:
for phase,loader in loaders.items():
    print(phase,len(loader))

train 1330
test 143
valid 63


In [17]:
loaders['train'] = loaders['test']

In [18]:
model, history = distill(data_loaders=loaders, model=model, reference_model=vgg16,criterion=criterion, optimizer=optimizer, epochs=70)

Epoch 1/70
----------
train: 142/143                                                    
train: Loss 2.7024	Prec@1 21.699	Prec@5 42.303
Epoch time: 37.615

valid: 62/63                                                      
valid: Loss 4.2391	Prec@1 22.645	Prec@5 45.240
Epoch time: 49.866


Epoch 2/70
----------
train: 142/143                                                    
train: Loss 2.5669	Prec@1 23.626	Prec@5 45.763
Epoch time: 34.537

valid: 62/63                                                      
valid: Loss 4.1666	Prec@1 23.347	Prec@5 44.088
Epoch time: 47.384


Epoch 3/70
----------
train: 142/143                                                    
train: Loss 2.5501	Prec@1 23.757	Prec@5 46.354
Epoch time: 34.528

valid: 62/63                                                      
valid: Loss 4.1207	Prec@1 23.848	Prec@5 45.691
Epoch time: 47.095


Epoch 4/70
----------
train: 142/143                                                    
train: Loss 2.5353	Prec@1 24.699	Prec@

Process Process-60:
Process Process-57:
Process Process-62:
Process Process-64:
Process Process-61:
Process Process-58:
Process Process-63:
Process Process-59:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/

                                                                  valid: 12/63

KeyboardInterrupt
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 50, in _worker_loop
    r = index_queue.get()
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/frati/miniconda3/envs/pytorch3.1/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)


KeyboardInterrupt: 

In [88]:
test(loaders['test'],vgg16,criterion,color=True)

Test: [0/143]	Time 1.018 (1.018)	Loss 1.4996 (1.4996)	Prec@1 59.375 (59.375)	Prec@5 87.500 (87.500)
Test: [1/143]	Time 0.082 (0.550)	Loss 0.6304 (1.0650)	Prec@1 84.375 (71.875)	Prec@5 96.875 (92.188)
Test: [2/143]	Time 0.091 (0.397)	Loss 0.8773 (1.0024)	Prec@1 78.125 (73.958)	Prec@5 90.625 (91.667)
Test: [3/143]	Time 0.085 (0.319)	Loss 1.1305 (1.0345)	Prec@1 87.500 (77.344)	Prec@5 87.500 (90.625)
Test: [4/143]	Time 0.094 (0.274)	Loss 1.6643 (1.1604)	Prec@1 62.500 (74.375)	Prec@5 81.250 (88.750)
Test: [5/143]	Time 0.084 (0.242)	Loss 1.1573 (1.1599)	Prec@1 68.750 (73.438)	Prec@5 90.625 (89.062)
Test: [6/143]	Time 0.083 (0.219)	Loss 1.4287 (1.1983)	Prec@1 71.875 (73.214)	Prec@5 90.625 (89.286)
Test: [7/143]	Time 0.085 (0.202)	Loss 0.8921 (1.1600)	Prec@1 78.125 (73.828)	Prec@5 96.875 (90.234)
Test: [8/143]	Time 0.082 (0.189)	Loss 1.9480 (1.2476)	Prec@1 56.250 (71.875)	Prec@5 87.500 (89.931)
Test: [9/143]	Time 0.081 (0.178)	Loss 1.0805 (1.2309)	Prec@1 62.500 (70.938)	Prec@5 93.750 (90.312)


Test: [84/143]	Time 0.077 (0.092)	Loss 2.0212 (1.2818)	Prec@1 53.125 (68.529)	Prec@5 81.250 (88.603)
Test: [85/143]	Time 0.078 (0.092)	Loss 1.2959 (1.2820)	Prec@1 71.875 (68.568)	Prec@5 87.500 (88.590)
Test: [86/143]	Time 0.080 (0.092)	Loss 1.1333 (1.2803)	Prec@1 78.125 (68.678)	Prec@5 84.375 (88.542)
Test: [87/143]	Time 0.080 (0.092)	Loss 1.7885 (1.2860)	Prec@1 56.250 (68.537)	Prec@5 87.500 (88.530)
Test: [88/143]	Time 0.083 (0.092)	Loss 2.0240 (1.2943)	Prec@1 62.500 (68.469)	Prec@5 84.375 (88.483)
Test: [89/143]	Time 0.089 (0.092)	Loss 1.0242 (1.2913)	Prec@1 62.500 (68.403)	Prec@5 96.875 (88.576)
Test: [90/143]	Time 0.085 (0.092)	Loss 2.0016 (1.2991)	Prec@1 43.750 (68.132)	Prec@5 81.250 (88.496)
Test: [91/143]	Time 0.102 (0.092)	Loss 1.4286 (1.3006)	Prec@1 53.125 (67.969)	Prec@5 87.500 (88.485)
Test: [92/143]	Time 0.077 (0.092)	Loss 1.7996 (1.3059)	Prec@1 71.875 (68.011)	Prec@5 81.250 (88.407)
Test: [93/143]	Time 0.080 (0.092)	Loss 1.2670 (1.3055)	Prec@1 65.625 (67.985)	Prec@5 87.500

67.98773812464611

In [31]:
test(loaders['test'],vgg16,criterion,color=False)

Test: [0/143]	Time 0.820 (0.820)	Loss 1.9504 (1.9504)	Prec@1 53.125 (53.125)	Prec@5 78.125 (78.125)
Test: [1/143]	Time 0.077 (0.448)	Loss 1.9913 (1.9709)	Prec@1 65.625 (59.375)	Prec@5 78.125 (78.125)
Test: [2/143]	Time 0.096 (0.331)	Loss 1.5757 (1.8392)	Prec@1 53.125 (57.292)	Prec@5 93.750 (83.333)
Test: [3/143]	Time 0.085 (0.269)	Loss 2.1791 (1.9242)	Prec@1 37.500 (52.344)	Prec@5 84.375 (83.594)
Test: [4/143]	Time 0.091 (0.234)	Loss 2.8618 (2.1117)	Prec@1 43.750 (50.625)	Prec@5 65.625 (80.000)
Test: [5/143]	Time 0.083 (0.209)	Loss 1.9720 (2.0884)	Prec@1 62.500 (52.604)	Prec@5 75.000 (79.167)
Test: [6/143]	Time 0.090 (0.192)	Loss 1.8885 (2.0598)	Prec@1 56.250 (53.125)	Prec@5 81.250 (79.464)
Test: [7/143]	Time 0.078 (0.177)	Loss 1.4413 (1.9825)	Prec@1 75.000 (55.859)	Prec@5 84.375 (80.078)
Test: [8/143]	Time 0.088 (0.168)	Loss 2.0481 (1.9898)	Prec@1 53.125 (55.556)	Prec@5 75.000 (79.514)
Test: [9/143]	Time 0.076 (0.158)	Loss 2.3105 (2.0219)	Prec@1 43.750 (54.375)	Prec@5 71.875 (78.750)


Test: [84/143]	Time 0.074 (0.089)	Loss 2.3153 (2.0487)	Prec@1 43.750 (52.794)	Prec@5 78.125 (77.610)
Test: [85/143]	Time 0.075 (0.089)	Loss 1.7025 (2.0447)	Prec@1 50.000 (52.762)	Prec@5 84.375 (77.689)
Test: [86/143]	Time 0.074 (0.089)	Loss 2.4964 (2.0499)	Prec@1 34.375 (52.550)	Prec@5 78.125 (77.694)
Test: [87/143]	Time 0.074 (0.089)	Loss 1.5056 (2.0437)	Prec@1 62.500 (52.663)	Prec@5 84.375 (77.770)
Test: [88/143]	Time 0.074 (0.088)	Loss 2.8244 (2.0524)	Prec@1 31.250 (52.423)	Prec@5 56.250 (77.528)
Test: [89/143]	Time 0.074 (0.088)	Loss 2.5306 (2.0578)	Prec@1 50.000 (52.396)	Prec@5 71.875 (77.465)
Test: [90/143]	Time 0.074 (0.088)	Loss 2.2441 (2.0598)	Prec@1 50.000 (52.370)	Prec@5 68.750 (77.370)
Test: [91/143]	Time 0.075 (0.088)	Loss 2.3531 (2.0630)	Prec@1 43.750 (52.276)	Prec@5 71.875 (77.310)
Test: [92/143]	Time 0.074 (0.088)	Loss 2.8555 (2.0715)	Prec@1 43.750 (52.184)	Prec@5 65.625 (77.184)
Test: [93/143]	Time 0.074 (0.088)	Loss 1.5988 (2.0665)	Prec@1 65.625 (52.327)	Prec@5 81.250

52.02539960169182