In [None]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from tensorflow.contrib.layers.python.layers import instance_norm
import random
import os
import time

In [None]:
# Expose only CUDA device 0 to TensorFlow. The variable is read when CUDA is
# initialised, which happens at tf.Session() creation further down.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

def bn(x):
    """Normalise `x` with instance normalisation (not batch norm, despite the name).

    Thin wrapper over tf.contrib's `instance_norm` with epsilon=1e-5.

    Args:
        x: input tensor.

    Returns:
        The instance-normalised tensor.
    """
    return instance_norm(x, epsilon=1e-05)

def line(x):
    """Identity ("linear") activation: hand the input back untouched."""
    result = x
    return result

def conv(x, filters, kernel_size=3, stride=1, activation=tf.nn.leaky_relu, padding='valid', norm=True, bias=False):
    """2-D convolution, optionally instance-normalised, followed by `activation`.

    Args:
        x: input tensor for `tf.layers.conv2d`.
        filters: number of output channels.
        kernel_size: int or [h, w] kernel size.
        stride: int or (h, w) stride.
        activation: callable applied to the output (use `line` for identity).
        padding: 'valid' (default) or 'same'.
        norm: if True, apply `bn` (instance normalisation) before the activation.
        bias: whether the convolution has a bias term.

    Returns:
        The output tensor.
    """
    with tf.name_scope('conv'):
        x = tf.layers.conv2d(x, filters, kernel_size=kernel_size, strides=stride,
                             padding=padding, use_bias=bias)
        if norm:
            x = bn(x)
        # Applied regardless of `norm` so the argument is never silently ignored;
        # norm=False callers in this notebook pass the identity `line`.
        x = activation(x)
    return x

def dconv(x, filters, kernel_size=3, stride=1, activation=tf.nn.leaky_relu, padding='valid', norm=True, bias=False):
    """Transposed 2-D convolution, optionally instance-normalised, followed by `activation`.

    Mirror of `conv` built on `tf.layers.conv2d_transpose`.

    Args:
        x: input tensor for `tf.layers.conv2d_transpose`.
        filters: number of output channels.
        kernel_size: int or [h, w] kernel size.
        stride: int or (h, w) stride (upsampling factor with 'valid' padding).
        activation: callable applied to the output (use `line` for identity).
        padding: 'valid' (default) or 'same'.
        norm: if True, apply `bn` (instance normalisation) before the activation.
        bias: whether the layer has a bias term.

    Returns:
        The output tensor.
    """
    with tf.name_scope('dconv'):
        x = tf.layers.conv2d_transpose(x, filters, kernel_size=kernel_size, strides=stride,
                                       padding=padding, use_bias=bias)
        if norm:
            x = bn(x)
        # Applied regardless of `norm` so the argument is never silently ignored;
        # norm=False callers in this notebook pass the identity `line`.
        x = activation(x)
    return x

'''SDconv'''
def sdconv(x,filters,activation=tf.nn.leaky_relu):
    """Sudoku-structured encoder/decoder block ("SDconv").

    Every spatial step uses a non-overlapping 1x3 or 3x1 kernel with an equal
    stride (`conv`/`dconv` default to padding='valid'). The encoder therefore
    divides width or height by 3, and the transposed-conv decoder multiplies
    it back. Assuming a 9x9 input (as fed by this notebook's placeholder), the
    intermediate grids are:

        9x3 / 3x9   triples of adjacent cells within a row / column
        9x1 / 1x9   one position per grid row / column
        3x3         one position per 3x3 box
        3x1 / 1x3   one position per band of three rows / columns
        1x1         the whole grid

    Encoder tensors are projected with 1x1 transposed convs
    (x13, x23, x113, x123, x223, x1213, x2123). They are concatenated back in
    as skip connections at the matching decoder resolution.

    Args:
        x: 4-D tensor, channels last (concat is on axis -1). Assumed to be
            (batch, 9, 9, C); TODO confirm for callers outside this notebook.
        filters: base channel width; the hidden layers output 1x, 2x or 4x it.
        activation: nonlinearity used by every hidden conv/dconv and after the
            final normalisation.

    Returns:
        Tensor with `filters` channels and the same spatial size as `x`
        (required by the final concat with `x`). The last 1x1 transposed conv
        has no normalisation and the identity activation.

    NOTE: tf.layers names its variables by creation order (conv2d, conv2d_1,
    ...). Reordering the layer calls below would likely change variable names
    and break restoring the saved checkpoint, so keep the order as is.
    """
    # --- Encoder level 1 (sizes below assume a 9x9 input):
    # row triples x1 (9x3) and column triples x2 (3x9).
    x1=conv(x,filters*2,kernel_size=[1,3],stride=(1,3),activation=activation)
    x2=conv(x,filters*2,kernel_size=[3,1],stride=(3,1),activation=activation)
    x3=x  # block input, concatenated back in just before the output projection
    # 1x1 skip projections of level 1.
    x13=dconv(x1,filters*1,kernel_size=[1,1],activation=activation,stride=1)
    x23=dconv(x2,filters*1,kernel_size=[1,1],activation=activation,stride=1)
    
    # --- Level 2: rows x11 (9x1), boxes x12 (from x1) and x21 (from x2) at 3x3,
    # columns x22 (1x9).
    x11=conv(x1,filters*2,kernel_size=[1,3],stride=(1,3),activation=activation)
    x12=conv(x1,filters*1,kernel_size=[3,1],stride=(3,1),activation=activation)
    x21=conv(x2,filters*1,kernel_size=[1,3],stride=(1,3),activation=activation)
    x22=conv(x2,filters*2,kernel_size=[3,1],stride=(3,1),activation=activation)
    x12=tf.concat([x12,x21],-1)  # merge the two 3x3 box views
    # 1x1 skip projections of level 2.
    x113=dconv(x11,filters*1,kernel_size=[1,1],activation=activation,stride=1)
    x123=dconv(x12,filters*1,kernel_size=[1,1],activation=activation,stride=1)
    x223=dconv(x22,filters*1,kernel_size=[1,1],activation=activation,stride=1)
    
    # --- Level 3: row bands 3x1 (x112 from rows, x121 from boxes) and
    # column bands 1x3 (x212 from boxes, x221 from columns).
    x112=conv(x11,filters*2,kernel_size=[3,1],stride=(3,1),activation=activation)
    x121=conv(x12,filters*2,kernel_size=[1,3],stride=(1,3),activation=activation)
    x212=conv(x12,filters*2,kernel_size=[3,1],stride=(3,1),activation=activation)
    x221=conv(x22,filters*2,kernel_size=[1,3],stride=(1,3),activation=activation)
    x121=tf.concat([x112,x121],-1)
    x212=tf.concat([x212,x221],-1)
    # 1x1 skip projections of level 3.
    x1213=dconv(x121,filters*2,kernel_size=[1,1],activation=activation,stride=1)
    x2123=dconv(x212,filters*2,kernel_size=[1,1],activation=activation,stride=1)
    
    # --- Bottleneck: both band views collapsed to 1x1 (whole grid) and merged.
    x1212=conv(x121,filters*4,kernel_size=[3,1],stride=(3,1),activation=activation)
    x2121=conv(x212,filters*4,kernel_size=[1,3],stride=(1,3),activation=activation)
    x1212=tf.concat([x1212,x2121],-1)
    
    # --- Decoder. The names x121/x212/x11/x12/x21/x22/x1/x2 are reused for the
    # upsampled tensors; each level is concatenated with its skip projection.
    # 1x1 -> row bands (3x1) and column bands (1x3).
    x121=dconv(x1212,filters*2,kernel_size=[3,1],stride=(3,1),activation=activation)
    x212=dconv(x1212,filters*2,kernel_size=[1,3],stride=(1,3),activation=activation)
    x121=tf.concat([x121,x1213],-1)
    x212=tf.concat([x212,x2123],-1)
    
    # Bands -> rows x11 (9x1), boxes x12 (from row bands) and x21 (from column
    # bands) at 3x3, columns x22 (1x9).
    x11=dconv(x121,filters*1,kernel_size=[3,1],stride=(3,1),activation=activation)
    x12=dconv(x121,filters*1,kernel_size=[1,3],stride=(1,3),activation=activation)
    x21=dconv(x212,filters*1,kernel_size=[3,1],stride=(3,1),activation=activation)
    x22=dconv(x212,filters*1,kernel_size=[1,3],stride=(1,3),activation=activation)
    x11=tf.concat([x11,x113],-1)
    x12=tf.concat([x12,x21,x123],-1)
    x22=tf.concat([x22,x223],-1)
    
    # -> row triples 9x3 (from rows and boxes) and column triples 3x9
    # (from boxes and columns), each merged with its level-1 skip.
    x1=dconv(x11,filters*1,kernel_size=[1,3],stride=(1,3),activation=activation)
    x1_=dconv(x12,filters*1,kernel_size=[3,1],stride=(3,1),activation=activation)
    x2=dconv(x12,filters*1,kernel_size=[1,3],stride=(1,3),activation=activation)
    x2_=dconv(x22,filters*1,kernel_size=[3,1],stride=(3,1),activation=activation)
    x1=tf.concat([x1,x1_,x13],-1)
    x2=tf.concat([x2,x2_,x23],-1)
    
    # -> back to full resolution from both triple views (identity activation, no norm).
    x=dconv(x1,int(filters*1),kernel_size=[1,3],stride=(1,3),activation=line,norm=False)
    x_=dconv(x2,int(filters*1),kernel_size=[3,1],stride=(3,1),activation=line,norm=False)

    # Concat with the block input, normalise + activate, then a linear 1x1
    # projection to `filters` channels.
    x=tf.concat([x3,x,x_],-1)
    x=activation(bn(x))
    x=dconv(x,int(filters*1),kernel_size=[1,1],activation=line,stride=1,norm=False)
    return x

def check(out,q_):
    """Flag candidate grids that are NOT valid Sudoku solutions.

    Args:
        out: integer grids of shape (batch, 9, 9); the solver passes
            argmax predictions + 1, i.e. digits 1..9.
        q_: puzzles of the same shape with 0 for empty cells. May be an
            ndarray or a list of 9x9 arrays (the solver loop passes a list
            after its first round).

    Returns:
        Boolean array of shape (batch,). True where the grid disagrees with a
        non-zero cell of `q_`, or some row, column or 3x3 box is not a
        permutation of 1..9. False where the grid is a valid solution.
    """
    out = np.asarray(out)
    # With a plain list, `q_ != 0` is one Python bool rather than an
    # elementwise mask, which made every empty cell count as a mismatch.
    q_ = np.asarray(q_)
    digits = np.arange(1, 10)
    keeps_clues = np.all((q_ == 0) | (out == q_), axis=(1, 2))
    rows_ok = np.all(np.sort(out, axis=-1) == digits, axis=(1, 2))
    cols_ok = np.all(np.sort(out, axis=-2) == digits[:, None], axis=(1, 2))
    # (b, band, r, stack, c) -> (b, band, stack, r, c): each row of `boxes` is one 3x3 box.
    boxes = np.reshape(np.transpose(np.reshape(out, [-1, 3, 3, 3, 3]), [0, 1, 3, 2, 4]), [-1, 9, 9])
    boxes_ok = np.all(np.sort(boxes, axis=-1) == digits, axis=(1, 2))
    # Plain bool result; `np.bool` no longer exists in NumPy >= 1.24.
    return ~(keeps_clues & rows_ok & cols_ok & boxes_ok)

In [None]:
def _shared_sdconv(net, scope, filters=256):
    """Apply `sdconv` inside variable scope `scope`, sharing weights across calls.

    The first call for a scope creates its variables; later calls reuse them.
    The graph-building cell calls each stage several times in a row.
    """
    # Match "scope/" rather than the bare scope name, so that a scope such as
    # "step1" can never pick up variables belonging to "step10".
    reuse = any(v.name.startswith(scope + '/') for v in tf.global_variables())
    with tf.variable_scope(scope, reuse=reuse):
        return sdconv(net, filters)


def sdconv0(net):
    """Weight-shared `sdconv` stage in variable scope 'step0'."""
    return _shared_sdconv(net, 'step0')


def sdconv1(net):
    """Weight-shared `sdconv` stage in variable scope 'step1'."""
    return _shared_sdconv(net, 'step1')


def sdconv2(net):
    """Weight-shared `sdconv` stage in variable scope 'step2'."""
    return _shared_sdconv(net, 'step2')


def sdconv3(net):
    """Weight-shared `sdconv` stage in variable scope 'step3'."""
    return _shared_sdconv(net, 'step3')


def sdconv4(net):
    """Weight-shared `sdconv` stage in variable scope 'step4'."""
    return _shared_sdconv(net, 'step4')


def sdconv5(net):
    """Weight-shared `sdconv` stage in variable scope 'step5'."""
    return _shared_sdconv(net, 'step5')


def sdconv6(net):
    """Weight-shared `sdconv` stage in variable scope 'step6'."""
    return _shared_sdconv(net, 'step6')


def sdconv7(net):
    """Weight-shared `sdconv` stage in variable scope 'step7'."""
    return _shared_sdconv(net, 'step7')

In [None]:
# Input: one-hot grids. Channel k is 1 where the cell holds digit k (0 = empty),
# matching the feed built in the solver cells.
x = tf.placeholder(tf.float32, [None, 9, 9, 10])
net0 = dconv(x, 256, kernel_size=[1, 1], activation=line, stride=1, norm=False)

# Eight weight-shared stages, each applied CALLS_PER_STAGE times in a row.
# Stage order is kept so variables are created in the same order as before.
STAGES = [sdconv0, sdconv1, sdconv2, sdconv3, sdconv4, sdconv5, sdconv6, sdconv7]
CALLS_PER_STAGE = 4
stage_outputs = [net0]
for stage in STAGES:
    h = stage_outputs[-1]
    for _ in range(CALLS_PER_STAGE):
        h = stage(h)
    stage_outputs.append(h)
net1, net2, net3, net4, net5, net6, net7, net8 = stage_outputs[1:]

# The last 9 channels are the per-cell scores for digits 1..9 (argmax + 1 downstream).
net = net8[:, :, :, -9:]

In [None]:
# Initialise every variable, then overwrite the trainable ones from the
# checkpoint. The saver is restricted to trainable variables, so any other
# variable keeps its initializer value.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(tf.trainable_variables())
# os.path.join instead of the hard-coded Windows separator in r'save\log',
# so the checkpoint prefix resolves on any OS (unchanged on Windows).
saver.restore(sess, os.path.join('save', 'log'))

In [None]:
# Iterative solve over dataset/question.npy. Each round runs the network on
# every unsolved grid. Grids whose prediction passes `check` count as solved;
# the rest get their single most confident empty cell filled and are retried.
time_start = time.time()
q = np.load(os.path.join('dataset', 'question.npy'))
lenq = len(q)
# Clue count of each puzzle, carried alongside its grid so that solved puzzles
# can be histogrammed by clue count in `show`.
unsolve = np.sum(q > 0, (1, 2))
show = []
solve = []
while len(q) > 0:
    wronglist = np.array(unsolve)
    unsolve = []
    qlist = q
    q = []
    for i in range(0, len(qlist), 128):
        # After the first round the grids live in a Python list. Convert so the
        # elementwise comparisons here and in `check` act on an array
        # (`list != 0` is a single Python bool, not a mask).
        q_ = np.asarray(qlist[i:i + 128])
        feed = (np.reshape(q_, [-1, 9, 9, 1]) == np.arange(0, 10)).astype(np.float32)
        wrong = wronglist[i:i + 128]
        out = sess.run(net, feed_dict={x: feed})
        pred = np.argmax(out, -1) + 1
        mask = check(pred, q_)  # True = not solved yet
        solve.extend(wrong[~mask])
        wrong = wrong[mask]
        # Best-digit score of each empty cell. Filled cells get -inf so they can
        # never be picked, even when every empty cell's (raw) score is negative.
        conf = np.where(q_ == 0, np.max(out, -1), -np.inf)
        pick = conf == np.max(conf, (1, 2), keepdims=True)
        n = (pick * pred + q_)[mask]
        # NOTE(review): cells equal to 1 are not counted here. A fully filled but
        # failing grid is therefore retried once more, and dropped the round after
        # (once the prediction has been added to every cell). Confirm whether
        # `n > 0` was intended.
        mask = np.sum(n > 1, (1, 2)) < 81
        q.extend(n[mask])
        unsolve.extend(wrong[mask])
    # Cumulative number of solved puzzles per clue count 17..34 after this round.
    show.append([np.sum(np.array(solve) == k) for k in range(17, 35)])
time_end = time.time()
elapsed = time_end - time_start
print('totally cost', elapsed)
# `elapsed` is in seconds; scale by 1000 to match the "ms" label.
print('%fms each question' % (elapsed / max(lenq, 1) * 1000))
print('accuracy={}'.format(len(solve) / max(lenq, 1)))

In [None]:
# Iterative solve over dataset/testq.npy. Each round runs the network on every
# unsolved grid. Grids whose prediction passes `check` count as solved; the
# rest get their single most confident empty cell filled and are retried.
time_start = time.time()
q = np.load(os.path.join('dataset', 'testq.npy'))
lenq = len(q)
# Clue count of each puzzle, carried alongside its grid so that solved puzzles
# can be histogrammed by clue count in `show`.
unsolve = np.sum(q > 0, (1, 2))
show = []
solve = []
while len(q) > 0:
    wronglist = np.array(unsolve)
    unsolve = []
    qlist = q
    q = []
    for i in range(0, len(qlist), 128):
        # After the first round the grids live in a Python list. Convert so the
        # elementwise comparisons here and in `check` act on an array
        # (`list != 0` is a single Python bool, not a mask).
        q_ = np.asarray(qlist[i:i + 128])
        feed = (np.reshape(q_, [-1, 9, 9, 1]) == np.arange(0, 10)).astype(np.float32)
        wrong = wronglist[i:i + 128]
        out = sess.run(net, feed_dict={x: feed})
        pred = np.argmax(out, -1) + 1
        mask = check(pred, q_)  # True = not solved yet
        solve.extend(wrong[~mask])
        wrong = wrong[mask]
        # Best-digit score of each empty cell. Filled cells get -inf so they can
        # never be picked, even when every empty cell's (raw) score is negative.
        conf = np.where(q_ == 0, np.max(out, -1), -np.inf)
        pick = conf == np.max(conf, (1, 2), keepdims=True)
        n = (pick * pred + q_)[mask]
        # NOTE(review): cells equal to 1 are not counted here. A fully filled but
        # failing grid is therefore retried once more, and dropped the round after
        # (once the prediction has been added to every cell). Confirm whether
        # `n > 0` was intended.
        mask = np.sum(n > 1, (1, 2)) < 81
        q.extend(n[mask])
        unsolve.extend(wrong[mask])
    # Cumulative number of solved puzzles per clue count 17..34 after this round.
    show.append([np.sum(np.array(solve) == k) for k in range(17, 35)])
time_end = time.time()
elapsed = time_end - time_start
print('totally cost', elapsed)
# `elapsed` is in seconds; scale by 1000 to match the "ms" label.
print('%fms each question' % (elapsed / max(lenq, 1) * 1000))
print('accuracy={}'.format(len(solve) / max(lenq, 1)))

In [None]:
# Iterative solve over dataset/validq.npy. Each round runs the network on every
# unsolved grid. Grids whose prediction passes `check` count as solved; the
# rest get their single most confident empty cell filled and are retried.
time_start = time.time()
q = np.load(os.path.join('dataset', 'validq.npy'))
lenq = len(q)
# Clue count of each puzzle, carried alongside its grid so that solved puzzles
# can be histogrammed by clue count in `show`.
unsolve = np.sum(q > 0, (1, 2))
show = []
solve = []
while len(q) > 0:
    wronglist = np.array(unsolve)
    unsolve = []
    qlist = q
    q = []
    for i in range(0, len(qlist), 128):
        # After the first round the grids live in a Python list. Convert so the
        # elementwise comparisons here and in `check` act on an array
        # (`list != 0` is a single Python bool, not a mask).
        q_ = np.asarray(qlist[i:i + 128])
        feed = (np.reshape(q_, [-1, 9, 9, 1]) == np.arange(0, 10)).astype(np.float32)
        wrong = wronglist[i:i + 128]
        out = sess.run(net, feed_dict={x: feed})
        pred = np.argmax(out, -1) + 1
        mask = check(pred, q_)  # True = not solved yet
        solve.extend(wrong[~mask])
        wrong = wrong[mask]
        # Best-digit score of each empty cell. Filled cells get -inf so they can
        # never be picked, even when every empty cell's (raw) score is negative.
        conf = np.where(q_ == 0, np.max(out, -1), -np.inf)
        pick = conf == np.max(conf, (1, 2), keepdims=True)
        n = (pick * pred + q_)[mask]
        # NOTE(review): cells equal to 1 are not counted here. A fully filled but
        # failing grid is therefore retried once more, and dropped the round after
        # (once the prediction has been added to every cell). Confirm whether
        # `n > 0` was intended.
        mask = np.sum(n > 1, (1, 2)) < 81
        q.extend(n[mask])
        unsolve.extend(wrong[mask])
    # Cumulative number of solved puzzles per clue count 17..34 after this round.
    show.append([np.sum(np.array(solve) == k) for k in range(17, 35)])
time_end = time.time()
elapsed = time_end - time_start
print('totally cost', elapsed)
# `elapsed` is in seconds; scale by 1000 to match the "ms" label.
print('%fms each question' % (elapsed / max(lenq, 1) * 1000))
print('accuracy={}'.format(len(solve) / max(lenq, 1)))