In [1]:
import pandas as pd
import numpy as np
labels = pd.read_csv("/Users/mingliangang/Desktop/cifar10/trainLabels.csv")
labels.head()

Unnamed: 0,id,label
0,1,frog
1,2,truck
2,3,truck
3,4,deer
4,5,automobile


In [2]:
label2id = {x:i for i,x in enumerate(list(set(labels["label"])))}
id2label = {i:x for i,x in enumerate(list(set(labels["label"])))}
labelcon = lambda x: label2id[x]
labels['label'] = labels['label'].map(labelcon)
labels_ = labels['label'].as_matrix()

In [3]:
labels['label'].head()

0    1
1    5
2    5
3    7
4    0
Name: label, dtype: int64

In [4]:
labels_

array([1, 5, 5, ..., 5, 0, 0])

# Sanity check

In [5]:
labels['label'].map(lambda x: id2label[x]).head()

0          frog
1         truck
2         truck
3          deer
4    automobile
Name: label, dtype: object

# Tensorflow input pipeline

In [6]:
import os
from sklearn.model_selection import train_test_split
import cv2
train = np.array([cv2.imread("/Users/mingliangang/Desktop/cifar10/train/"+i) for i in os.listdir("/Users/mingliangang/Desktop/cifar10/train")]).astype(np.float32)
X_train, X_val, y_train, y_val = train_test_split(train, labels_, test_size=0.33, random_state=42)

In [7]:
X_train.shape

(33500, 32, 32, 3)

# Resnet Model (CIFAR 10)

In [8]:
from models import resnet
import tensorflow as tf

batch_size = 128
GPU = True
summaries_dir = "/Users/mingliangang/Desktop/resnet2/"
decay = 2e-4
if GPU:
    with tf.device('/device:GPU:0'):
        X = tf.placeholder("float", [batch_size, 32, 32, 3])
        Y = tf.placeholder("float", [batch_size, 10])
        label = tf.placeholder("int64",[batch_size])
        learning_rate = tf.placeholder("float", [])
        
        global_step_tensor = tf.Variable(0, trainable=False, name='global_step')
        
        net = resnet(X,20)
        
        cross_entropy = -tf.reduce_sum(Y*tf.log(net))
        loss = cross_entropy + decay*tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()])
        tf.summary.scalar('cross_entropy', cross_entropy)
        correct_prediction = tf.equal(label,tf.argmax(net,axis=1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
        
        opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        train_op = opt.minimize(loss,global_step=global_step_tensor)
        
    merged = tf.summary.merge_all()
    config = tf.ConfigProto(allow_soft_placement = True)
    sess = tf.Session(config = config)
    writer = tf.summary.FileWriter(summaries_dir + '/gpu', sess.graph)
    sess.run(tf.global_variables_initializer())
    
else:
    with tf.device('/device:CPU:0'):
        X = tf.placeholder("float", [batch_size, 32, 32, 3])
        Y = tf.placeholder("float", [batch_size, 10])
        label = tf.placeholder("int64",[batch_size])
        learning_rate = tf.placeholder("float", [])
        
        global_step = tf.Variable(0, trainable=False, name='global_step')
        
        net = resnet(X,20)
        
        cross_entropy = -tf.reduce_sum(Y*tf.log(net))
        tf.summary.scalar('cross_entropy', cross_entropy)
        correct_prediction = tf.equal(label,tf.argmax(net,axis=1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
        
        opt = tf.train.MomentumOptimizer(learning_rate, 0.9)
        train_op = opt.minimize(cross_entropy)
        
    merged = tf.summary.merge_all()
    sess = tf.Session()
    writer = tf.summary.FileWriter(summaries_dir + '/cpu', sess.graph)
    sess.run(tf.global_variables_initializer())

Tensor("Shape:0", shape=(4,), dtype=int32, device=/device:GPU:0)


In [9]:
#Multiple epochs
from tqdm import tqdm
n_epochs = 1
for _ in tqdm(range(n_epochs)):
  while True:
    try:
      randix = np.random.randint(len(X_train),size=batch_size)
      x_batch = X_train[randix]
      y_batch = y_train[randix]
      summary,loss,acc,_ = sess.run([merged,cross_entropy,accuracy,train_op], feed_dict={label: y_batch,X: x_batch, Y: np.eye(10)[y_batch],learning_rate: 0.005})
      steps = tf.train.global_step(sess, global_step_tensor)
      if steps % 500 == 0:
          print("loss : {} accuracy : {}".format(loss,acc))
          writer.add_summary(summary, steps)
    except ValueError:
      print(loss)
      break

  0%|          | 0/1 [00:00<?, ?it/s]

loss : 294.34454345703125 accuracy : 0.1171875
loss : 296.0260314941406 accuracy : 0.078125
loss : 299.44464111328125 accuracy : 0.109375
loss : 299.57916259765625 accuracy : 0.0859375
loss : 296.91839599609375 accuracy : 0.0625
loss : 296.551025390625 accuracy : 0.125
loss : 296.4378967285156 accuracy : 0.109375
loss : 300.7614440917969 accuracy : 0.0703125
loss : 293.95721435546875 accuracy : 0.1328125
loss : 292.79217529296875 accuracy : 0.125
loss : 296.4313659667969 accuracy : 0.140625
loss : 297.61474609375 accuracy : 0.1015625
loss : 298.6258239746094 accuracy : 0.109375
loss : 293.48748779296875 accuracy : 0.1171875
loss : 296.3218994140625 accuracy : 0.0625
loss : 299.706787109375 accuracy : 0.0859375
loss : 296.920166015625 accuracy : 0.0625
loss : 295.7039794921875 accuracy : 0.09375
loss : 296.4648742675781 accuracy : 0.1171875
loss : 298.9210510253906 accuracy : 0.109375
loss : 296.5628356933594 accuracy : 0.0703125
loss : 296.032958984375 accuracy : 0.0703125
loss : 297.0

loss : 293.7008056640625 accuracy : 0.1171875
loss : 295.8150939941406 accuracy : 0.0703125
loss : 292.630126953125 accuracy : 0.140625
loss : 295.2437744140625 accuracy : 0.1015625
loss : 294.1252746582031 accuracy : 0.1171875
loss : 300.8419189453125 accuracy : 0.09375
loss : 295.815185546875 accuracy : 0.109375
loss : 294.3324279785156 accuracy : 0.109375
loss : 295.1236877441406 accuracy : 0.078125
loss : 294.95025634765625 accuracy : 0.1171875
loss : 295.55859375 accuracy : 0.0859375
loss : 297.3970031738281 accuracy : 0.0703125
loss : 296.3607177734375 accuracy : 0.0859375
loss : 295.67767333984375 accuracy : 0.109375
loss : 293.70947265625 accuracy : 0.1328125
loss : 296.41241455078125 accuracy : 0.1015625
loss : 295.9024658203125 accuracy : 0.0703125
loss : 294.0887451171875 accuracy : 0.140625
loss : 295.7059631347656 accuracy : 0.09375
loss : 298.8115234375 accuracy : 0.0703125
loss : 298.3303527832031 accuracy : 0.0703125
loss : 293.5090637207031 accuracy : 0.1171875
loss : 

KeyboardInterrupt: 

From what i read online, it seems that this implementation of resnet is not very good. It is not that the loss will collapse it is that it is not really learning anything at all.  