In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import os
import struct
import numpy as np
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

In [61]:
n_bins = 2. **8

In [2]:

def load_mnist(path, kind='train'):
    """Load MNIST data from `path`"""
    labels_path = os.path.join(path,
                               '%s-labels-idx1-ubyte'
                               % kind)
    images_path = os.path.join(path,
                               '%s-images-idx3-ubyte'
                               % kind)
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II',
                                 lbpath.read(8))
        labels = np.fromfile(lbpath,
                             dtype=np.uint8)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII',
                                               imgpath.read(16))
        images = np.fromfile(imgpath,
                             dtype=np.uint8).reshape((len(labels),28,28))

    return images, labels

In [3]:
X_train, y_train = load_mnist('./data')
                              
X_test, y_test = load_mnist('./data',kind='t10k')
print(X_test.shape,y_test.shape)

(10000, 28, 28) (10000,)


In [4]:
def img_show(img):
    plt.figure(figsize=(1,1))
    plt.imshow(img,cmap='Greys',interpolation='nearest')
    

In [5]:
def imgs_show(imgs,row,col):
    fig,ax = plt.subplots(nrows=row,ncols=col,sharex=True,sharey=True)
    ax = ax.flatten()
    for i in range(row*col):
        img = imgs[i]
        ax[i].imshow(img,cmap='Greys', interpolation='nearest')
    
    ax[0].set_xticks([])
    ax[0].set_yticks([])
    plt.tight_layout()
    plt.show()

In [6]:
def normal_minist(img):
    return (img / 255.0) -0.5

In [7]:
def expend_HWC(img):
    return np.expand_dims(img,axis=3)


In [8]:
def to_CHW(img):
    return np.transpose(img,(0,3,1,2))
def to_HWC(img):
    return np.transpose(img,(0,2,3,1))

In [9]:
#只能运行一次
X_train_nor = expend_HWC(X_train)
X_train_nor = normal_minist(X_train_nor)
X_train_nor = to_CHW(X_train_nor)
print(X_train_nor.shape)

X_test_nor = expend_HWC(X_test)
X_test_nor = normal_minist(X_test_nor)
X_test_nor = to_CHW(X_test_nor)
print(X_test_nor.shape)




(60000, 1, 28, 28)
(10000, 1, 28, 28)


In [10]:
eval_index = 0

In [11]:
test_postitive = X_test_nor[y_test == eval_index]
test_nagivate = X_test_nor[y_test != eval_index]
print(test_postitive.shape)
print(test_nagivate.shape)

train_positive = X_train_nor[y_train == eval_index][[0,55,1111,1009,327,128,5000,469,2000,3001]]
train_nagivate = X_train_nor[y_train != eval_index]

print(train_nagivate.shape)
print(train_positive.shape)

(980, 1, 28, 28)
(9020, 1, 28, 28)
(54077, 1, 28, 28)
(10, 1, 28, 28)


In [13]:
use_cuda = torch.cuda.is_available()
torch.manual_seed(666)
# device = torch.device("cuda" if use_cuda else "cpu")

<torch._C.Generator at 0x7f5360f2d070>

In [62]:
inputs_posi = Variable(torch.from_numpy(train_positive)).cuda()
print(inputs_posi.size())
# shape = inputs_posi.mean(0,keepdim = True).mean(2,keepdim = True).mean(3,keepdim = True)
# a = shape.clone()
# b = shape.data.new()
# a = a+1
# print(shape)
# print(a)
# print(b)
objective = torch.zeros_like(inputs_posi)[:,0,0,0]
objective += -np.log(n_bins) * np.prod(inputs_posi.size()[1:])
print(objective)

torch.Size([10, 1, 28, 28])
Variable containing:
-4347.4191
-4347.4191
-4347.4191
-4347.4191
-4347.4191
-4347.4191
-4347.4191
-4347.4191
-4347.4191
-4347.4191
[torch.cuda.DoubleTensor of size 10 (GPU 0)]



In [36]:
class Actnorm(nn.Module):
    def __init__(self):
        super(Actnorm,self).__init__()
        self.register_parameter("center",None)
    def reset_parameters(self,x):
        shape = x.size()
        assert len(shape) == 2 or len(shape) == 4
        if len(shape) == 2:
            x_mean = x.mean(0,keepdim =True)
            self.center = nn.Parameter(x_mean.data.clone()) 
        if len(shape) == 4:
            x_mean = x.mean(0,keepdim = True).mean(2,keepdim = True).mean(3,keepdim = True)
            self.center =  nn.Parameter(x_mean.data.clone())  #Variable(x_mean.data.clone(),requires_grad=True) 
        print('self.center = ',self.center)
        print(self.center.grad)
        
    def forward(self, x,reverse = False):
        if self.center is None:
            self.reset_parameters(x)
        if reverse:
            x -= self.center
        else:
            x += self.center
        return x

In [None]:
actnet = Actnorm()
x = actnet(inputs_posi)
print(x)

In [40]:
for p in actnet.parameters():
    print(p)

Parameter containing:
(0 ,0 ,.,.) = 
 -0.3241
[torch.cuda.DoubleTensor of size 1x1x1x1 (GPU 0)]



In [42]:
def actnorm_center(x,reverse = False):
    shape = x.size()
    assert len(shape) == 2 or len(shape) == 4
    if len(shape) == 2:
        x_mean = x.mean(0,keepdim =True)
        b = -x_mean.clone()
    if len(shape) == 4:
        x_mean = x.mean(0,keepdim = True).mean(2,keepdim = True).mean(3,keepdim = True)
        b = -x_mean.clone()
    print('b = ',b)
    print(b.grad)
    if reverse:
        x -= b
    else:
        x += b
    print('x = ',x)
    return x

In [54]:
def squeeze2d(x, factor=2):
    assert factor >=1
    if factor ==1:
        return x
    shape = x.size()
    height = int(shape[2])
    width = int(shape[3])
    n_channels = int(shape[1])
    assert height % factor ==0 and width % factor == 0
    x = x.view(-1,n_channels,height // factor,factor,width//factor,factor)
    print(x.size())
    #     x = 
    x.permute(0,1,3,5,2,4)
#     x = x.transpose
    print(x.size())
    x = x.view(-1,n_channels * factor * factor ,height // factor ,width//factor)
    print(x.size())
    return x
    
    
    

In [83]:
print(inputs_posi.size())
inputs =  squeeze2d(inputs_posi).float()

print(inputs_posi.size()[2] * inputs_posi.size()[3])

torch.Size([10, 1, 28, 28])
torch.Size([10, 1, 14, 2, 14, 2])
torch.Size([10, 1, 14, 2, 14, 2])
torch.Size([10, 4, 14, 14])
784


In [101]:
class Invertible(nn.Module):
    def __init__(self):
        super(Invertible,self).__init__()
        
        shape = [-1,4,14,14]
        w_shape = [shape[1],shape[1]]
        w_init = np.linalg.qr(np.random.randn(*w_shape))[0]
        _w = torch.from_numpy(w_init).float()
        _w = _w.view(4,4,1,1) # or variable
        
        
        self.conv1 = nn.Conv2d(4,4,(1,1),stride=(1,1),bias=False) # bias is bug
        self.conv1.data = _w  # = or clone or deep copy
        
    def forward(self,x,reverse = False):
        shape = x.size()
        w_sq =  torch.squeeze(self.conv1)
        dlogdet = torch.log(torch.potrf(w_sq).diag().prod()) * 
        
        if not reverse:
            
        
        return self.conv1(x)
        

In [108]:
invernet = Invertible().cuda()
outputs = invernet(inputs)
print(outputs.size())
w_sq11 =  torch.squeeze(inputs_posi)
print(w_sq11.size())

torch.Size([10, 4, 14, 14])
torch.Size([10, 28, 28])


In [103]:
for p in invernet.parameters():
    
    print(p.data)


(0 ,0 ,.,.) = 
  0.3999

(0 ,1 ,.,.) = 
  0.1624

(0 ,2 ,.,.) = 
  0.0436

(0 ,3 ,.,.) = 
  0.3629

(1 ,0 ,.,.) = 
 -0.4253

(1 ,1 ,.,.) = 
  0.0332

(1 ,2 ,.,.) = 
 -0.3253

(1 ,3 ,.,.) = 
 -0.1894

(2 ,0 ,.,.) = 
 -0.3643

(2 ,1 ,.,.) = 
 -0.4892

(2 ,2 ,.,.) = 
 -0.3443

(2 ,3 ,.,.) = 
 -0.4984

(3 ,0 ,.,.) = 
 -0.0707

(3 ,1 ,.,.) = 
  0.1803

(3 ,2 ,.,.) = 
 -0.3030

(3 ,3 ,.,.) = 
 -0.2147
[torch.cuda.FloatTensor of size 4x4x1x1 (GPU 0)]



In [None]:
def invertible_1x1_conv(name,input,logdet,reverse = False):
    # [batch,w,h,c]
    shape = input.size()
    
    # [chanel,chanel],[12,12]
    w_shape = [shape[1],shape[1]]
    w_init = np.linalg.qr(np.random.randn(*w_shape))[0]
    w = Variable(torch.from_numpy(w_init))