## Lab 8: Write a program for Logistic Regression with 
# Tensorflow

In [1]:
from tensorflow.keras.datasets import fashion_mnist
from sklearn.model_selection import train_test_split
import tensorflow as tf

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train, x_test = x_train/255., x_test/255.

x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.15)
x_train = tf.reshape(x_train, shape=(-1, 784))
x_test  = tf.reshape(x_test, shape=(-1, 784))

weights = tf.Variable(tf.random.normal(shape=(784, 10), dtype=tf.float64))
biases  = tf.Variable(tf.random.normal(shape=(10,), dtype=tf.float64))

def logistic_regression(x):
    lr = tf.add(tf.matmul(x, weights), biases)
    
    return lr


def cross_entropy(y_true, y_pred):
    y_true = tf.one_hot(y_true, 10)
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
    return tf.reduce_mean(loss)

def accuracy(y_true, y_pred):
    y_true = tf.cast(y_true, dtype=tf.int32)
    preds = tf.cast(tf.argmax(y_pred, axis=1), dtype=tf.int32)
    preds = tf.equal(y_true, preds)
    return tf.reduce_mean(tf.cast(preds, dtype=tf.float32))

def grad(x, y):
    with tf.GradientTape() as tape:
        y_pred = logistic_regression(x)
        loss_val = cross_entropy(y, y_pred)
    return tape.gradient(loss_val, [weights, biases])

n_batches = 10000
learning_rate = 0.01
batch_size = 128

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.repeat().shuffle(x_train.shape[0]).batch(batch_size)

optimizer = tf.optimizers.SGD(learning_rate)

for batch_numb, (batch_xs, batch_ys) in enumerate(dataset.take(n_batches), 1):
    gradients = grad(batch_xs, batch_ys)
    optimizer.apply_gradients(zip(gradients, [weights, biases]))

    y_pred = logistic_regression(batch_xs)
    loss = cross_entropy(batch_ys, y_pred)
    acc = accuracy(batch_ys, y_pred)
    print("Batch number: %i, loss: %f, accuracy: %f" % (batch_numb, loss, acc))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Batch number: 5001, loss: 1.649218, accuracy: 0.664062
Batch number: 5002, loss: 1.331714, accuracy: 0.726562
Batch number: 5003, loss: 1.477354, accuracy: 0.679688
Batch number: 5004, loss: 2.027618, accuracy: 0.593750
Batch number: 5005, loss: 0.955636, accuracy: 0.765625
Batch number: 5006, loss: 1.894405, accuracy: 0.671875
Batch number: 5007, loss: 1.739773, accuracy: 0.617188
Batch number: 5008, loss: 1.408901, accuracy: 0.695312
Batch number: 5009, loss: 1.842604, accuracy: 0.656250
Batch number: 5010, loss: 1.668541, accuracy: 0.664062
Batch number: 5011, loss: 2.317048, accuracy: 0.601562
Batch number: 5012, loss: 2.238409, accuracy: 0.632812
Batch number: 5013, loss: 1.269212, accuracy: 0.718750
Batch number: 5014, loss: 1.411858, accuracy: 0.718750
Batch number: 5015, loss: 1.192763, accuracy: 0.703125
Batch number: 5016, loss: 1.410926, accuracy: 0.703125
Batch number: 5017, loss: 1.863972, accuracy: 0.671875
