In [22]:
import sys
sys.path.append("../")
import os

In [23]:
from library.deep_dream import *
from library.dict_network.dict_net import *
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
import cv2

In [24]:
def createInputImage(input_size=(32,128)):
    """Create a random-noise 8-bit grayscale image used as the optimisation seed.

    Parameters
    ----------
    input_size : tuple of int, optional
        Shape ``(rows, cols)`` of the noise array handed to
        ``Image.fromarray``. Defaults to ``(32, 128)``, the size that was
        previously hard-coded, so existing zero-argument calls are unaffected.

    Returns
    -------
    Image
        A mode ``'L'`` image built from uniformly distributed noise.
    """
    # np.random.random draws from [0, 1); scaling by 255 and truncating to
    # uint8 yields pixel values in 0..254.
    noise_np = np.random.random(input_size)*255
    return Image.fromarray(noise_np.astype('uint8'),'L')

In [25]:
def prepInputImage(inputImage):
    """Turn an input image into the normalised tensor fed to the network.

    The image is passed through ``Grayscale``, ``ToTensor`` and
    ``Normalize`` with mean 0.47 and std 0.14 (the same constants that
    ``postProcess`` uses to undo the normalisation).

    Parameters
    ----------
    inputImage
        Image accepted by the ``transforms`` pipeline below.

    Returns
    -------
    The transformed image as produced by the composed pipeline.
    """
    steps = [
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.47,),(0.14,)),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(inputImage)

In [26]:
def setGaussianFilter(kernelSize=3,sigma=0.5):
    """Build a fixed (non-trainable) single-channel Gaussian blur layer.

    Parameters
    ----------
    kernelSize : int, optional
        Side length of the square kernel (default 3).
    sigma : float, optional
        Standard deviation of the Gaussian, in pixels (default 0.5).

    Returns
    -------
    nn.Conv2d
        A 1-in / 1-out convolution without bias whose weight is the
        normalised Gaussian kernel, padded by ``kernelSize // 2`` and with
        ``requires_grad`` disabled on the weight.
    """
    centre = (kernelSize - 1)/2.
    var = sigma**2.

    # Squared offset of every tap from the kernel centre along one axis.
    sq_offset = (torch.arange(kernelSize).float() - centre)**2.

    # Broadcasting a row against a column gives the squared 2-D distance
    # (column term + row term) for every position of the kernel.
    sq_dist = sq_offset.view(1, -1) + sq_offset.view(-1, 1)

    # Separable 2-D Gaussian, then renormalised so the taps sum to 1.
    kernel = (1./(2.*math.pi*var)) * torch.exp(-sq_dist /(2*var))
    kernel = kernel / torch.sum(kernel)

    # Conv2d expects weights shaped (out_channels, in_channels, H, W).
    kernel = kernel.view(1, 1, kernelSize, kernelSize)

    blur = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernelSize,
                     padding=kernelSize // 2, groups=1, bias=False)

    # Overwrite the randomly initialised weight and freeze it.
    blur.weight.data = kernel
    blur.weight.requires_grad = False

    return blur

In [27]:
def postProcess(image):
    """Convert a normalised network-input tensor back into a grayscale image.

    Parameters
    ----------
    image : torch.Tensor
        Normalised image tensor; all singleton dimensions are squeezed away
        before conversion, so what remains must be a 2-D array for mode 'L'.

    Returns
    -------
    Image
        A mode ``'L'`` image built from the de-normalised pixel values.
    """
    image_tensor = torch.squeeze(image.detach()) # drop singleton (batch/channel) dimensions
    image_tensor = image_tensor*0.14 + 0.47 # std and mean for mjsynth
    # Clamp to the displayable range before the uint8 cast: values outside
    # [0, 1] would otherwise wrap around (and the cast is undefined for
    # negatives) instead of saturating to black/white.
    image_tensor = torch.clamp(image_tensor, 0, 1).cpu() # back to host

    pixels = (image_tensor.numpy()*255).astype('uint8')
    return Image.fromarray(pixels, 'L') # array -> image

In [28]:
# Build the default Gaussian smoothing filter (kernelSize=3, sigma=0.5) once at
# module level; dream() reads this global when its g_filter flag is enabled.
gaussian_filter = setGaussianFilter()

- **Loss type 1 : out[label]**
- **Loss type 2 : log(softmax(out)[label])**
- **Loss type 3 : log(sigmoid(out)[label])**

In [29]:
def dream(net,label,nItr,loss_type=1,lr=0.1,g_filter=True):
    """Optimise a random image so that ``net`` responds strongly to ``label``.

    Starting from the noise image produced by ``createInputImage`` and
    ``prepInputImage``, runs ``nItr`` steps of gradient ascent on the pixels.
    After every step the image is clamped to [-1, 1] (in normalised units)
    and, if requested, smoothed with the module-level ``gaussian_filter``.
    The softmax probability of ``label`` is printed before and after.

    Parameters
    ----------
    net : torch.nn.Module
        Classifier; it is switched to eval mode. Its output is indexed as
        ``out[0, label]``.
    label : int
        Index of the class to maximise.
    nItr : int
        Number of gradient-ascent steps.
    loss_type : int, optional
        1 -> ``out[label]``, 2 -> ``log(softmax(out)[label])``,
        3 -> ``log(sigmoid(out)[label])``.
    lr : float, optional
        Step size of the ascent update.
    g_filter : bool, optional
        Apply ``gaussian_filter`` to the image after every step.

    Returns
    -------
    torch.Tensor
        The optimised image with a leading batch dimension, as a leaf tensor
        with ``requires_grad=True``.

    Raises
    ------
    ValueError
        If ``loss_type`` is not 1, 2 or 3.
    """
    # Validate before doing any work; the previous behaviour of returning the
    # int 0 only surfaced later as an unrelated error in the caller.
    if loss_type not in (1, 2, 3):
        raise ValueError("Loss type not recognized")

    im = createInputImage()
    im = prepInputImage(im)
    im = im.unsqueeze(0).requires_grad_(True)  # add the batch dimension

    net.eval()
    with torch.no_grad():  # reporting only, no graph needed
        prob = F.softmax(net(im),dim=1)[0,label]
    print("Probablity of correct label given random image: ",prob.item())

    for _ in range(nItr):
        out = net(im)
        if loss_type == 1:
            loss = out[0,label]
        elif loss_type == 2:
            # log_softmax avoids the -inf produced by log(softmax) underflow
            loss = F.log_softmax(out,dim=1)[0,label]
        else:
            # logsigmoid avoids the -inf produced by log(sigmoid) underflow
            loss = F.logsigmoid(out)[0,label]

        # Differentiate w.r.t. the image only, so no gradients pile up in the
        # network's parameters across iterations and calls.
        (grad,) = torch.autograd.grad(loss, im)

        with torch.no_grad():
            # Gradient *ascent* step followed by the range constraint.
            updated = torch.clamp(im + lr * grad, -1, 1)
            if g_filter:
                updated = gaussian_filter(updated)

        # Fresh leaf tensor for the next iteration.
        im = updated.requires_grad_(True)

    with torch.no_grad():
        prob = F.softmax(net(im),dim=1)[0,label]
    print("Probablity of correct label given optimized image: ",prob.item())
    return im

In [30]:
# Trained checkpoint, addressed relative to this notebook's directory.
checkpoint_path = "../code/train_dict_network/out5/net_1000_0.001_200_0.0.pth"

net = DictNet(1000)  # presumably 1000 output classes -- TODO confirm against DictNet

# map_location forces every tensor in the checkpoint onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(checkpoint_path,map_location=torch.device('cpu'))
net.load_state_dict(state_dict)

<All keys matched successfully>

In [31]:
# Class index -> word, and loss_type code -> tag used in the output file name.
labels = {0:"unfix",1:"pluck",2:"toked",3:"brawl"}
loss_types = {1:"activation",2:"log_softmax",3:"log_sigmoid"}

# Generate and save one image per (word, loss type) combination.
for label_num in labels:
    label_word = labels[label_num]
    for loss_num in loss_types:
        loss_string = loss_types[loss_num]
        file_name = "_".join([label_word, loss_string]) + ".png"

        img = dream(net, label_num, 100, loss_type=loss_num, g_filter=True)
        img_pil = postProcess(img)
        img_pil.save(file_name)  # written to the current working directory
        print("{} saved ".format(file_name))

Probablity of correct label given random image:  9.422208790965669e-08
Probablity of correct label given random image:  1.0
unfix_activation.png saved 
Probablity of correct label given random image:  1.0606801197354798e-06
Probablity of correct label given random image:  0.9549229741096497
unfix_log_softmax.png saved 
Probablity of correct label given random image:  1.6513338323420612e-06
Probablity of correct label given random image:  0.9934980869293213
unfix_log_sigmoid.png saved 
Probablity of correct label given random image:  7.928912236820906e-05
Probablity of correct label given random image:  1.0
pluck_activation.png saved 
Probablity of correct label given random image:  8.628621372963607e-09
Probablity of correct label given random image:  0.9583163857460022
pluck_log_softmax.png saved 
Probablity of correct label given random image:  2.383078708589892e-06
Probablity of correct label given random image:  0.9916452765464783
pluck_log_sigmoid.png saved 
Probablity of correct 

In [32]:
# img = dream(net,label=1,nItr=100,g_filter=True) # word 'pluck'
# img_pil = postProcess(img)
# show(img_pil)

In [33]:
# img = dream(net,label=2,nItr=100,g_filter=True) # word 'toked'
# img_pil = postProcess(img)
# show(img_pil)

In [34]:
# img = dream(net,label=3,nItr=100,g_filter=True) # word 'brawl'
# img_pil = postProcess(img)
# show(img_pil)

In [35]:
# img

In [36]:
# torch.max(img)