In [373]:
import numpy as np
import tensorflow as tf
from sklearn.datasets import make_regression

In [374]:
class DenseLayer:
    def __init__(self, input_size, output_size):
        tf.random.set_seed(42)
        np.random.seed(42)
        self.weights = tf.Variable(tf.random.normal(shape=(input_size, output_size)), trainable=True)
        self.biases = tf.Variable(tf.random.normal(shape=(output_size,)), trainable=True)
    
    def forward(self, inputs):
        self.inputs = inputs
        self.output = tf.matmul(tf.convert_to_tensor(self.inputs, dtype=tf.float32), self.weights) + self.biases
        
        return self.output
    
    def backward(self, grad_output, learning_rate):
        
        grad_weights = tf.matmul(tf.transpose(tf.convert_to_tensor(self.inputs, dtype=tf.float32)), tf.convert_to_tensor(grad_output, dtype=tf.float32))
        grad_biases = tf.reduce_sum(grad_output, axis=0)
        
        grad_input = tf.matmul(tf.convert_to_tensor(grad_output, dtype=tf.float32), tf.transpose(self.weights))

        self.weights = self.weights - learning_rate * grad_weights
        self.biases = self.biases - learning_rate * grad_biases
        
        return grad_input

In [375]:
class DenseNetwork:
    def __init__(self):
        self.layers = []
    
    def add_layer(self, layer):
        self.layers.append(layer)
    
    def forward(self, inputs):
        for layer in self.layers:
            inputs = layer.forward(inputs)
        return inputs
    
    def backward(self, grad_output, learning_rate):
        for layer in reversed(self.layers):
            grad_output = layer.backward(grad_output, learning_rate)

In [376]:
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler
import numpy as np

# Generate synthetic dataset
X, y = make_regression(n_samples=10, n_features=10, noise=0.5, random_state=42)

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

y_train = y_train.reshape(-1,1)
y_test = y_test.reshape(-1,1)


# Standardize the input features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train scikit-learn's LinearRegression model
lr_model = LinearRegression()
lr_model.fit(X_train_scaled, y_train)

# Predict with scikit-learn's LinearRegression model
y_pred_lr = lr_model.predict(X_test_scaled)

# Train the DenseNetwork implemented from scratch
dense_net = DenseNetwork()
dense_net.add_layer(DenseLayer(10, 10))
dense_net.add_layer(DenseLayer(10, 1))

# Train the DenseNetwork using gradient descent
learning_rate = 0.00001
num_epochs = 5000
for epoch in range(num_epochs):
    # Forward pass
    y_pred = dense_net.forward(X_train_scaled)
    
    # Compute loss (mean squared error)
    loss = np.mean((y_pred - y_train) ** 2)
#     print(loss)
    print(f'epoch {epoch}:{loss}')
    # Backward pass
    grad_output = 2 * (y_pred - y_train) / len(X_train_scaled)
    dense_net.backward(grad_output, learning_rate)

# Predict with the DenseNetwork
y_pred_dense = dense_net.forward(X_test_scaled)

# Compare the results
print("Mean Squared Error (sklearn LinearRegression):", mean_squared_error(y_test, y_pred_lr))
print("Mean Squared Error (DenseNetwork implemented from scratch):", mean_squared_error(y_test, y_pred_dense))

epoch 0:56154.1171875
epoch 1:56093.50390625
epoch 2:56032.515625
epoch 3:55971.140625
epoch 4:55909.37890625
epoch 5:55847.2109375
epoch 6:55784.640625
epoch 7:55721.65625
epoch 8:55658.2578125
epoch 9:55594.4296875
epoch 10:55530.171875
epoch 11:55465.4765625
epoch 12:55400.3359375
epoch 13:55334.74609375
epoch 14:55268.69921875
epoch 15:55202.18359375
epoch 16:55135.203125
epoch 17:55067.7421875
epoch 18:54999.80078125
epoch 19:54931.37109375
epoch 20:54862.4453125
epoch 21:54793.0078125
epoch 22:54723.07421875
epoch 23:54652.609375
epoch 24:54581.63671875
epoch 25:54510.1328125
epoch 26:54438.08984375
epoch 27:54365.5078125
epoch 28:54292.37890625
epoch 29:54218.6953125
epoch 30:54144.44921875
epoch 31:54069.640625
epoch 32:53994.2578125
epoch 33:53918.2890625
epoch 34:53841.75
epoch 35:53764.609375
epoch 36:53686.87109375
epoch 37:53608.52734375
epoch 38:53529.5703125
epoch 39:53450.0078125
epoch 40:53369.8125
epoch 41:53288.9921875
epoch 42:53207.53515625
epoch 43:53125.4375
epoc

epoch 352:7150.68896484375
epoch 353:7070.7763671875
epoch 354:6991.765625
epoch 355:6913.64697265625
epoch 356:6836.41357421875
epoch 357:6760.05908203125
epoch 358:6684.5732421875
epoch 359:6609.95263671875
epoch 360:6536.1904296875
epoch 361:6463.2763671875
epoch 362:6391.2041015625
epoch 363:6319.9677734375
epoch 364:6249.55712890625
epoch 365:6179.9697265625
epoch 366:6111.19384765625
epoch 367:6043.220703125
epoch 368:5976.046875
epoch 369:5909.6630859375
epoch 370:5844.0654296875
epoch 371:5779.2412109375
epoch 372:5715.18310546875
epoch 373:5651.8876953125
epoch 374:5589.34326171875
epoch 375:5527.5498046875
epoch 376:5466.4892578125
epoch 377:5406.162109375
epoch 378:5346.55908203125
epoch 379:5287.6708984375
epoch 380:5229.49169921875
epoch 381:5172.013671875
epoch 382:5115.23095703125
epoch 383:5059.1328125
epoch 384:5003.71435546875
epoch 385:4948.96826171875
epoch 386:4894.8876953125
epoch 387:4841.4658203125
epoch 388:4788.6923828125
epoch 389:4736.56201171875
epoch 390:4

epoch 707:362.6109924316406
epoch 708:360.49615478515625
epoch 709:358.3985595703125
epoch 710:356.317138671875
epoch 711:354.252685546875
epoch 712:352.2031555175781
epoch 713:350.1702575683594
epoch 714:348.1524353027344
epoch 715:346.15008544921875
epoch 716:344.1637268066406
epoch 717:342.1931457519531
epoch 718:340.2370300292969
epoch 719:338.2968444824219
epoch 720:336.3705749511719
epoch 721:334.46014404296875
epoch 722:332.5640869140625
epoch 723:330.6826477050781
epoch 724:328.8149719238281
epoch 725:326.9620666503906
epoch 726:325.12353515625
epoch 727:323.29901123046875
epoch 728:321.48822021484375
epoch 729:319.69158935546875
epoch 730:317.90838623046875
epoch 731:316.1383056640625
epoch 732:314.3828430175781
epoch 733:312.639892578125
epoch 734:310.9105529785156
epoch 735:309.19366455078125
epoch 736:307.4903564453125
epoch 737:305.7994384765625
epoch 738:304.12213134765625
epoch 739:302.4568786621094
epoch 740:300.8045654296875
epoch 741:299.1637878417969
epoch 742:297.53

epoch 1015:96.8136215209961
epoch 1016:96.51634216308594
epoch 1017:96.22034454345703
epoch 1018:95.92560577392578
epoch 1019:95.63221740722656
epoch 1020:95.34019470214844
epoch 1021:95.04983520507812
epoch 1022:94.76074981689453
epoch 1023:94.47290802001953
epoch 1024:94.18614959716797
epoch 1025:93.90117645263672
epoch 1026:93.6173324584961
epoch 1027:93.33463287353516
epoch 1028:93.0536880493164
epoch 1029:92.77367401123047
epoch 1030:92.49519348144531
epoch 1031:92.21770477294922
epoch 1032:91.94151306152344
epoch 1033:91.66650390625
epoch 1034:91.39328002929688
epoch 1035:91.1211166381836
epoch 1036:90.8498764038086
epoch 1037:90.5802230834961
epoch 1038:90.31127166748047
epoch 1039:90.0440902709961
epoch 1040:89.77775573730469
epoch 1041:89.51288604736328
epoch 1042:89.24878692626953
epoch 1043:88.98603820800781
epoch 1044:88.72471618652344
epoch 1045:88.46446228027344
epoch 1046:88.20521545410156
epoch 1047:87.94742584228516
epoch 1048:87.6905517578125
epoch 1049:87.43459320068

epoch 1339:42.29218673706055
epoch 1340:42.19747543334961
epoch 1341:42.1029052734375
epoch 1342:42.00871276855469
epoch 1343:41.91470718383789
epoch 1344:41.82117462158203
epoch 1345:41.72772216796875
epoch 1346:41.63443374633789
epoch 1347:41.54140853881836
epoch 1348:41.44860076904297
epoch 1349:41.35636520385742
epoch 1350:41.26408386230469
epoch 1351:41.17216491699219
epoch 1352:41.08034133911133
epoch 1353:40.988929748535156
epoch 1354:40.89763259887695
epoch 1355:40.80680465698242
epoch 1356:40.71623229980469
epoch 1357:40.62571716308594
epoch 1358:40.53539276123047
epoch 1359:40.445335388183594
epoch 1360:40.355751037597656
epoch 1361:40.2661247253418
epoch 1362:40.1769905090332
epoch 1363:40.087867736816406
epoch 1364:39.999000549316406
epoch 1365:39.910560607910156
epoch 1366:39.822208404541016
epoch 1367:39.734214782714844
epoch 1368:39.646427154541016
epoch 1369:39.558692932128906
epoch 1370:39.471282958984375
epoch 1371:39.384185791015625
epoch 1372:39.29743194580078
epoch

epoch 1628:22.898277282714844
epoch 1629:22.851947784423828
epoch 1630:22.805561065673828
epoch 1631:22.759353637695312
epoch 1632:22.71330451965332
epoch 1633:22.667327880859375
epoch 1634:22.621383666992188
epoch 1635:22.575557708740234
epoch 1636:22.52996253967285
epoch 1637:22.484283447265625
epoch 1638:22.438852310180664
epoch 1639:22.393421173095703
epoch 1640:22.348134994506836
epoch 1641:22.302997589111328
epoch 1642:22.257904052734375
epoch 1643:22.213001251220703
epoch 1644:22.16814422607422
epoch 1645:22.123363494873047
epoch 1646:22.078731536865234
epoch 1647:22.03418731689453
epoch 1648:21.989749908447266
epoch 1649:21.9453125
epoch 1650:21.900936126708984
epoch 1651:21.856998443603516
epoch 1652:21.81284523010254
epoch 1653:21.76889419555664
epoch 1654:21.725011825561523
epoch 1655:21.68128204345703
epoch 1656:21.63759422302246
epoch 1657:21.594045639038086
epoch 1658:21.550546646118164
epoch 1659:21.50716781616211
epoch 1660:21.463871002197266
epoch 1661:21.4208030700683

epoch 1964:11.874593734741211
epoch 1965:11.852241516113281
epoch 1966:11.829841613769531
epoch 1967:11.807562828063965
epoch 1968:11.7853364944458
epoch 1969:11.763110160827637
epoch 1970:11.740955352783203
epoch 1971:11.7186861038208
epoch 1972:11.696707725524902
epoch 1973:11.674654006958008
epoch 1974:11.652624130249023
epoch 1975:11.63070297241211
epoch 1976:11.608816146850586
epoch 1977:11.586977005004883
epoch 1978:11.565193176269531
epoch 1979:11.543439865112305
epoch 1980:11.521727561950684
epoch 1981:11.500041961669922
epoch 1982:11.478384971618652
epoch 1983:11.456833839416504
epoch 1984:11.435245513916016
epoch 1985:11.413796424865723
epoch 1986:11.392341613769531
epoch 1987:11.370943069458008
epoch 1988:11.349504470825195
epoch 1989:11.3281831741333
epoch 1990:11.307022094726562
epoch 1991:11.285669326782227
epoch 1992:11.264591217041016
epoch 1993:11.243396759033203
epoch 1994:11.222371101379395
epoch 1995:11.201240539550781
epoch 1996:11.180257797241211
epoch 1997:11.159

epoch 2276:6.687979221343994
epoch 2277:6.676034927368164
epoch 2278:6.664028644561768
epoch 2279:6.652049541473389
epoch 2280:6.640134811401367
epoch 2281:6.628169536590576
epoch 2282:6.616355895996094
epoch 2283:6.6044816970825195
epoch 2284:6.592587471008301
epoch 2285:6.58079719543457
epoch 2286:6.569033622741699
epoch 2287:6.557214736938477
epoch 2288:6.54547643661499
epoch 2289:6.53372859954834
epoch 2290:6.522067546844482
epoch 2291:6.510306358337402
epoch 2292:6.498701095581055
epoch 2293:6.487079620361328
epoch 2294:6.475424766540527
epoch 2295:6.463841915130615
epoch 2296:6.452306270599365
epoch 2297:6.440723896026611
epoch 2298:6.429203987121582
epoch 2299:6.417736053466797
epoch 2300:6.4062180519104
epoch 2301:6.394782066345215
epoch 2302:6.383334159851074
epoch 2303:6.371867656707764
epoch 2304:6.36051082611084
epoch 2305:6.349109649658203
epoch 2306:6.337827682495117
epoch 2307:6.326476573944092
epoch 2308:6.315171241760254
epoch 2309:6.303837776184082
epoch 2310:6.292641

epoch 2568:3.9956326484680176
epoch 2569:3.988726854324341
epoch 2570:3.9817962646484375
epoch 2571:3.974911689758301
epoch 2572:3.967978000640869
epoch 2573:3.9611570835113525
epoch 2574:3.954293727874756
epoch 2575:3.9474167823791504
epoch 2576:3.940610885620117
epoch 2577:3.9337825775146484
epoch 2578:3.9269113540649414
epoch 2579:3.9201598167419434
epoch 2580:3.913348913192749
epoch 2581:3.906599521636963
epoch 2582:3.8998117446899414
epoch 2583:3.8930766582489014
epoch 2584:3.886368989944458
epoch 2585:3.879585027694702
epoch 2586:3.872899293899536
epoch 2587:3.866194725036621
epoch 2588:3.8595099449157715
epoch 2589:3.8527989387512207
epoch 2590:3.846174478530884
epoch 2591:3.839524507522583
epoch 2592:3.8329129219055176
epoch 2593:3.8262381553649902
epoch 2594:3.8196563720703125
epoch 2595:3.812994956970215
epoch 2596:3.8064465522766113
epoch 2597:3.7999141216278076
epoch 2598:3.793257236480713
epoch 2599:3.786755323410034
epoch 2600:3.7802295684814453
epoch 2601:3.7736477851867

epoch 2886:2.3194034099578857
epoch 2887:2.3154592514038086
epoch 2888:2.3115389347076416
epoch 2889:2.307661294937134
epoch 2890:2.303783893585205
epoch 2891:2.299863338470459
epoch 2892:2.2960000038146973
epoch 2893:2.2921102046966553
epoch 2894:2.2882471084594727
epoch 2895:2.284419536590576
epoch 2896:2.280559778213501
epoch 2897:2.2767386436462402
epoch 2898:2.2728800773620605
epoch 2899:2.269052505493164
epoch 2900:2.265244483947754
epoch 2901:2.2614126205444336
epoch 2902:2.257571220397949
epoch 2903:2.2537856101989746
epoch 2904:2.2499892711639404
epoch 2905:2.2461891174316406
epoch 2906:2.2424421310424805
epoch 2907:2.238636016845703
epoch 2908:2.2348849773406982
epoch 2909:2.231106996536255
epoch 2910:2.227332592010498
epoch 2911:2.2235991954803467
epoch 2912:2.2198379039764404
epoch 2913:2.2161331176757812
epoch 2914:2.212347984313965
epoch 2915:2.208674192428589
epoch 2916:2.2048850059509277
epoch 2917:2.2012133598327637
epoch 2918:2.197521448135376
epoch 2919:2.19382524490

epoch 3200:1.3718059062957764
epoch 3201:1.3695234060287476
epoch 3202:1.3672401905059814
epoch 3203:1.364974856376648
epoch 3204:1.3627474308013916
epoch 3205:1.360451340675354
epoch 3206:1.358219861984253
epoch 3207:1.3559513092041016
epoch 3208:1.3537030220031738
epoch 3209:1.3514649868011475
epoch 3210:1.3492214679718018
epoch 3211:1.3469717502593994
epoch 3212:1.3447504043579102
epoch 3213:1.3425309658050537
epoch 3214:1.3403100967407227
epoch 3215:1.3380868434906006
epoch 3216:1.335845708847046
epoch 3217:1.3336634635925293
epoch 3218:1.3314757347106934
epoch 3219:1.3292423486709595
epoch 3220:1.3270821571350098
epoch 3221:1.3248438835144043
epoch 3222:1.3226557970046997
epoch 3223:1.3204624652862549
epoch 3224:1.3182713985443115
epoch 3225:1.3160731792449951
epoch 3226:1.3139073848724365
epoch 3227:1.3117576837539673
epoch 3228:1.3095784187316895
epoch 3229:1.307410478591919
epoch 3230:1.3052481412887573
epoch 3231:1.3030837774276733
epoch 3232:1.300925850868225
epoch 3233:1.298

epoch 3535:0.7899409532546997
epoch 3536:0.7886587381362915
epoch 3537:0.7873635292053223
epoch 3538:0.7860822677612305
epoch 3539:0.7847974300384521
epoch 3540:0.7835143804550171
epoch 3541:0.782259464263916
epoch 3542:0.7809499502182007
epoch 3543:0.7796534299850464
epoch 3544:0.7783796787261963
epoch 3545:0.7771199941635132
epoch 3546:0.7758417129516602
epoch 3547:0.7745700478553772
epoch 3548:0.7732954025268555
epoch 3549:0.7720526456832886
epoch 3550:0.770771861076355
epoch 3551:0.7695131301879883
epoch 3552:0.7682713270187378
epoch 3553:0.7669944763183594
epoch 3554:0.7657495737075806
epoch 3555:0.7644914388656616
epoch 3556:0.7632424235343933
epoch 3557:0.7620128393173218
epoch 3558:0.7607417106628418
epoch 3559:0.7594990730285645
epoch 3560:0.7582579255104065
epoch 3561:0.75701904296875
epoch 3562:0.755791187286377
epoch 3563:0.7545530796051025
epoch 3564:0.7533082962036133
epoch 3565:0.7520924806594849
epoch 3566:0.7508504390716553
epoch 3567:0.7496163249015808
epoch 3568:0.74

epoch 3840:0.48024308681488037
epoch 3841:0.4794784188270569
epoch 3842:0.4786907434463501
epoch 3843:0.4779086410999298
epoch 3844:0.47712230682373047
epoch 3845:0.47635146975517273
epoch 3846:0.4755781888961792
epoch 3847:0.47480523586273193
epoch 3848:0.4740349054336548
epoch 3849:0.47326740622520447
epoch 3850:0.47249534726142883
epoch 3851:0.47173523902893066
epoch 3852:0.47095954418182373
epoch 3853:0.47018587589263916
epoch 3854:0.469429075717926
epoch 3855:0.4686780571937561
epoch 3856:0.46791672706604004
epoch 3857:0.4671488106250763
epoch 3858:0.4663963317871094
epoch 3859:0.46563321352005005
epoch 3860:0.46487194299697876
epoch 3861:0.46412456035614014
epoch 3862:0.4633725583553314
epoch 3863:0.4626283049583435
epoch 3864:0.4618738293647766
epoch 3865:0.46111443638801575
epoch 3866:0.46036040782928467
epoch 3867:0.45961499214172363
epoch 3868:0.45887404680252075
epoch 3869:0.45812949538230896
epoch 3870:0.45738428831100464
epoch 3871:0.4566435217857361
epoch 3872:0.455901563

epoch 4143:0.29376220703125
epoch 4144:0.29329851269721985
epoch 4145:0.29280954599380493
epoch 4146:0.2923489809036255
epoch 4147:0.2918747067451477
epoch 4148:0.29140862822532654
epoch 4149:0.2909296154975891
epoch 4150:0.2904558479785919
epoch 4151:0.2899799644947052
epoch 4152:0.2895166873931885
epoch 4153:0.2890421748161316
epoch 4154:0.28857266902923584
epoch 4155:0.288112998008728
epoch 4156:0.2876511216163635
epoch 4157:0.2871825397014618
epoch 4158:0.2867216467857361
epoch 4159:0.28626376390457153
epoch 4160:0.2858009338378906
epoch 4161:0.2853345572948456
epoch 4162:0.2848706841468811
epoch 4163:0.2844099700450897
epoch 4164:0.2839525640010834
epoch 4165:0.28349459171295166
epoch 4166:0.28303831815719604
epoch 4167:0.2825683653354645
epoch 4168:0.2821202576160431
epoch 4169:0.2816507816314697
epoch 4170:0.2812004089355469
epoch 4171:0.28075653314590454
epoch 4172:0.28031009435653687
epoch 4173:0.27984943985939026
epoch 4174:0.2793952226638794
epoch 4175:0.27894070744514465
ep

epoch 4415:0.18926602602005005
epoch 4416:0.188946932554245
epoch 4417:0.18865706026554108
epoch 4418:0.1883464753627777
epoch 4419:0.1880398392677307
epoch 4420:0.18774911761283875
epoch 4421:0.18744415044784546
epoch 4422:0.18714284896850586
epoch 4423:0.18683281540870667
epoch 4424:0.1865374743938446
epoch 4425:0.18623998761177063
epoch 4426:0.18593262135982513
epoch 4427:0.18563750386238098
epoch 4428:0.18533316254615784
epoch 4429:0.18503345549106598
epoch 4430:0.18474039435386658
epoch 4431:0.18443886935710907
epoch 4432:0.18414533138275146
epoch 4433:0.1838466078042984
epoch 4434:0.1835538148880005
epoch 4435:0.18326625227928162
epoch 4436:0.1829601228237152
epoch 4437:0.1826661080121994
epoch 4438:0.1823708862066269
epoch 4439:0.1820748746395111
epoch 4440:0.1817762553691864
epoch 4441:0.18148455023765564
epoch 4442:0.18120050430297852
epoch 4443:0.1809036135673523
epoch 4444:0.18061107397079468
epoch 4445:0.18031829595565796
epoch 4446:0.18002541363239288
epoch 4447:0.17974165

epoch 4723:0.11517034471035004
epoch 4724:0.11498789489269257
epoch 4725:0.11480240523815155
epoch 4726:0.11461327970027924
epoch 4727:0.11442780494689941
epoch 4728:0.11424607038497925
epoch 4729:0.11405741423368454
epoch 4730:0.11388230323791504
epoch 4731:0.1136959046125412
epoch 4732:0.11351527273654938
epoch 4733:0.11333083361387253
epoch 4734:0.11314967274665833
epoch 4735:0.11296574771404266
epoch 4736:0.1127820834517479
epoch 4737:0.11260366439819336
epoch 4738:0.11241667717695236
epoch 4739:0.11224610358476639
epoch 4740:0.11205708235502243
epoch 4741:0.11188097298145294
epoch 4742:0.11169640719890594
epoch 4743:0.11152151226997375
epoch 4744:0.11134248971939087
epoch 4745:0.11116550117731094
epoch 4746:0.1109815314412117
epoch 4747:0.11080017685890198
epoch 4748:0.11062674969434738
epoch 4749:0.11044846475124359
epoch 4750:0.11027126014232635
epoch 4751:0.11009305715560913
epoch 4752:0.10991667211055756
epoch 4753:0.10973696410655975
epoch 4754:0.10956362634897232
epoch 4755: