In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb

tf.random.set_seed(0)

In [5]:
train_X = np.array([[1, 2, 1, 1], [2, 1, 3, 2], [3, 1, 3, 4], [4, 1, 5, 5], [1, 7, 5, 5], [1, 2, 5, 6], [1, 6, 6, 6], [1, 7, 7, 7]], dtype=np.float32)
train_y = np.array([[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0]], dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices((train_X, train_y)).batch(len(train_X))

n_fts = train_X.shape[1]
n_cls = train_y.shape[1]

W = tf.Variable(tf.random.normal([n_fts, n_cls]), name="weight")
b = tf.Variable(tf.random.normal([1, n_cls]), name="bias")
vars_ = [W, b]

def hyp(features):
    h = tf.nn.softmax(tf.matmul(features, W) + b)
    return h

def loss_fn(features, labels):
    h = hyp(features)
    loss = tf.reduce_mean(-tf.reduce_sum(labels * tf.math.log(h), axis=1))
    return loss

def grads_fn(features, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(features, labels)
        grads = tape.gradient(loss, vars_)
        return grads

def fit(X, Y, n_epochs=2000, verbose=100):
    optimizer =  tf.keras.optimizers.SGD(lr=0.1)
    for i in range(n_epochs):
        for features, labels in dataset:
            loss = loss_fn(X, Y)
            grads = grads_fn(features, labels)
            optimizer.apply_gradients(zip(grads, vars_))
            print("{:>4d} | {:6.4f}".format(i+1, loss))

In [6]:
iter(dataset)

<tensorflow.python.data.ops.iterator_ops.IteratorV2 at 0x1d1eb04bcc8>

In [7]:
fit(train_X, train_y)

   1 | 8.6570
   2 | 5.3045
   3 | 2.7035
   4 | 1.7814
   5 | 1.5306
   6 | 1.3813
   7 | 1.2904
   8 | 1.2120
   9 | 1.1370
  10 | 1.0688
  11 | 1.0055
  12 | 0.9492
  13 | 0.8986
  14 | 0.8556
  15 | 0.8185
  16 | 0.7898
  17 | 0.7664
  18 | 0.7541
  19 | 0.7452
  20 | 0.7575
  21 | 0.7602
  22 | 0.8052
  23 | 0.7880
  24 | 0.8507
  25 | 0.7929
  26 | 0.8616
  27 | 0.7878
  28 | 0.8604
  29 | 0.7818
  30 | 0.8569
  31 | 0.7757
  32 | 0.8523
  33 | 0.7696
  34 | 0.8470
  35 | 0.7636
  36 | 0.8412
  37 | 0.7577
  38 | 0.8351
  39 | 0.7520
  40 | 0.8290
  41 | 0.7465
  42 | 0.8228
  43 | 0.7412
  44 | 0.8166
  45 | 0.7361
  46 | 0.8106
  47 | 0.7312
  48 | 0.8048
  49 | 0.7265
  50 | 0.7991
  51 | 0.7220
  52 | 0.7936
  53 | 0.7176
  54 | 0.7884
  55 | 0.7134
  56 | 0.7834
  57 | 0.7094
  58 | 0.7786
  59 | 0.7056
  60 | 0.7740
  61 | 0.7019
  62 | 0.7696
  63 | 0.6984
  64 | 0.7654
  65 | 0.6950
  66 | 0.7614
  67 | 0.6917
  68 | 0.7576
  69 | 0.6886
  70 | 0.7540
  71 | 0.6855
  72 |

 593 | 0.4106
 594 | 0.4529
 595 | 0.4099
 596 | 0.4520
 597 | 0.4092
 598 | 0.4511
 599 | 0.4085
 600 | 0.4502
 601 | 0.4078
 602 | 0.4493
 603 | 0.4071
 604 | 0.4484
 605 | 0.4064
 606 | 0.4475
 607 | 0.4057
 608 | 0.4466
 609 | 0.4050
 610 | 0.4457
 611 | 0.4043
 612 | 0.4448
 613 | 0.4036
 614 | 0.4439
 615 | 0.4029
 616 | 0.4430
 617 | 0.4022
 618 | 0.4421
 619 | 0.4015
 620 | 0.4412
 621 | 0.4008
 622 | 0.4404
 623 | 0.4001
 624 | 0.4395
 625 | 0.3994
 626 | 0.4386
 627 | 0.3987
 628 | 0.4377
 629 | 0.3980
 630 | 0.4368
 631 | 0.3973
 632 | 0.4359
 633 | 0.3966
 634 | 0.4350
 635 | 0.3959
 636 | 0.4341
 637 | 0.3952
 638 | 0.4332
 639 | 0.3945
 640 | 0.4323
 641 | 0.3938
 642 | 0.4314
 643 | 0.3931
 644 | 0.4305
 645 | 0.3924
 646 | 0.4296
 647 | 0.3917
 648 | 0.4287
 649 | 0.3910
 650 | 0.4279
 651 | 0.3903
 652 | 0.4270
 653 | 0.3896
 654 | 0.4261
 655 | 0.3889
 656 | 0.4252
 657 | 0.3883
 658 | 0.4243
 659 | 0.3876
 660 | 0.4234
 661 | 0.3869
 662 | 0.4225
 663 | 0.3862
 664 |

1181 | 0.2270
1182 | 0.2269
1183 | 0.2268
1184 | 0.2267
1185 | 0.2265
1186 | 0.2264
1187 | 0.2263
1188 | 0.2262
1189 | 0.2261
1190 | 0.2260
1191 | 0.2259
1192 | 0.2258
1193 | 0.2256
1194 | 0.2255
1195 | 0.2254
1196 | 0.2253
1197 | 0.2252
1198 | 0.2251
1199 | 0.2250
1200 | 0.2249
1201 | 0.2248
1202 | 0.2247
1203 | 0.2245
1204 | 0.2244
1205 | 0.2243
1206 | 0.2242
1207 | 0.2241
1208 | 0.2240
1209 | 0.2239
1210 | 0.2238
1211 | 0.2237
1212 | 0.2236
1213 | 0.2234
1214 | 0.2233
1215 | 0.2232
1216 | 0.2231
1217 | 0.2230
1218 | 0.2229
1219 | 0.2228
1220 | 0.2227
1221 | 0.2226
1222 | 0.2225
1223 | 0.2224
1224 | 0.2222
1225 | 0.2221
1226 | 0.2220
1227 | 0.2219
1228 | 0.2218
1229 | 0.2217
1230 | 0.2216
1231 | 0.2215
1232 | 0.2214
1233 | 0.2213
1234 | 0.2212
1235 | 0.2211
1236 | 0.2210
1237 | 0.2209
1238 | 0.2207
1239 | 0.2206
1240 | 0.2205
1241 | 0.2204
1242 | 0.2203
1243 | 0.2202
1244 | 0.2201
1245 | 0.2200
1246 | 0.2199
1247 | 0.2198
1248 | 0.2197
1249 | 0.2196
1250 | 0.2195
1251 | 0.2194
1252 |

1768 | 0.1753
1769 | 0.1752
1770 | 0.1751
1771 | 0.1751
1772 | 0.1750
1773 | 0.1749
1774 | 0.1749
1775 | 0.1748
1776 | 0.1747
1777 | 0.1747
1778 | 0.1746
1779 | 0.1745
1780 | 0.1745
1781 | 0.1744
1782 | 0.1743
1783 | 0.1743
1784 | 0.1742
1785 | 0.1741
1786 | 0.1741
1787 | 0.1740
1788 | 0.1739
1789 | 0.1739
1790 | 0.1738
1791 | 0.1737
1792 | 0.1736
1793 | 0.1736
1794 | 0.1735
1795 | 0.1734
1796 | 0.1734
1797 | 0.1733
1798 | 0.1732
1799 | 0.1732
1800 | 0.1731
1801 | 0.1730
1802 | 0.1730
1803 | 0.1729
1804 | 0.1728
1805 | 0.1728
1806 | 0.1727
1807 | 0.1726
1808 | 0.1726
1809 | 0.1725
1810 | 0.1724
1811 | 0.1724
1812 | 0.1723
1813 | 0.1722
1814 | 0.1722
1815 | 0.1721
1816 | 0.1720
1817 | 0.1720
1818 | 0.1719
1819 | 0.1718
1820 | 0.1718
1821 | 0.1717
1822 | 0.1716
1823 | 0.1716
1824 | 0.1715
1825 | 0.1714
1826 | 0.1714
1827 | 0.1713
1828 | 0.1712
1829 | 0.1712
1830 | 0.1711
1831 | 0.1710
1832 | 0.1710
1833 | 0.1709
1834 | 0.1709
1835 | 0.1708
1836 | 0.1707
1837 | 0.1707
1838 | 0.1706
1839 |