## Script for training and saving models

#### Layout of files

Model.py
- Defines the entire model
    - Layers
    - How it transforms data from input to output

Trainer.py
- Defines how the model is trained 
- Defines curriculum learning
    - How it's trained initially
        - X steps just to make the output similar to the input
        - Make it stable
    - How it's trained once we assume it's getting smarter
        - Teach it to move in the direction of the reward

Utils.py
- Contains all helper functions

States.py
- How to get new random states
    - Output x, y and food
        - x being the random initial state
            - Should be of size in range of (x1, x2)
            - Should be one entity complying with certain rules
        - y being the target output after e epochs
        - food being the desired target location, which determines the direction the CA should move in

#### Imports

In [1]:
import torch
import numpy as np
from Trainer import Trainer
from Model import Complex_CA

#### Setup

In [2]:
# Select the compute device. The accelerator alternatives are kept for
# reference. NOTE: `is_available` must be *called*; a bare function object is
# always truthy, so `... if torch.cuda.is_available else 'cpu'` would select
# the accelerator even on machines that do not have one.
#device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device = torch.device('mps:0' if torch.backends.mps.is_available() else 'cpu')
device = torch.device('cpu')

# Batch size handed to the model constructor (how Complex_CA uses it is
# defined in Model.py, so confirm there).
batch_size = 16

# Build the cellular-automaton model, move its parameters to the chosen
# device, and wrap it in the Trainer that drives the training schedule.
model = Complex_CA(device, batch_size)
model = model.to(device)
trainer = Trainer(model, device)
print(device)

cpu


#### Training

In [3]:
# Seed NumPy's and torch's global RNGs with the same value so repeated runs
# draw the same random numbers. Python's `random` module is NOT seeded here,
# so if Trainer relies on it, that stays nondeterministic. Then run training
# and keep the trained model and its recorded losses.
seed = 2
np.random.seed(seed)
torch.manual_seed(seed)
model, losses = trainer.train()

0it [00:00, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

0.17085902392864227
0.005538702942430973
0.0002948479959741235
0.0004460290656425059
0.0005173355457372963
0.00046597723849117756
0.00023246303317137063
0.00023489298473577946
0.0004651088092941791
0.0001947474229382351
0.00021267245756462216
0.0001278896233998239
9.112533007282764e-05
0.0001390504330629483
0.00011855008779093623
9.964736818801612e-05
4.906527465209365e-05
7.867500244174153e-05
6.0570666391868144e-05
9.998514724429697e-05


  1%|          | 1/100 [00:17<28:10, 17.08s/it]

0.00012187294487375766
0.00010282566654495895
2.7120184313389473e-05
3.4563010558485985e-05
3.3998865546891466e-05
4.550459561869502e-05
3.244389154133387e-05
5.281270568957552e-05
3.490227391012013e-05
2.1857389583601616e-05
2.2251248083193786e-05
1.1460462701506913e-05
3.6717574403155595e-05
9.622160177968908e-06
1.3937358744442463e-05
2.1623394786729477e-05
3.6631925468100235e-05
1.2027639058942441e-05
5.827923814649694e-06
6.558156201208476e-06


  2%|▏         | 2/100 [00:33<27:28, 16.82s/it]

0.0004010238917544484
0.00015967636136338115
4.63645119452849e-05
0.00010460718476679176
3.904565892298706e-05
5.9232377680018544e-05
5.896153743378818e-05
6.403513543773443e-05
7.503660162910819e-05
6.425953324651346e-05
6.157977622933686e-05
1.9757300833589397e-05
2.6597026590025052e-05
4.368318695924245e-05
3.9701837522443384e-05
7.485983951482922e-05
7.551995076937601e-05
3.3907086617546156e-05
3.122706766589545e-05
2.3092617993825115e-05


  3%|▎         | 3/100 [01:09<41:32, 25.69s/it]

9.41440011956729e-05
3.155552622047253e-05
3.1790175853529945e-05
5.57936800760217e-05
1.9901492123608477e-05
4.182446718914434e-05
4.5947810576763004e-05
2.8971782739972696e-05
2.460997347952798e-05
2.3667689674766734e-05
2.9989687391207553e-05
2.1290028598741628e-05
4.664575681090355e-05
2.6689602236729115e-05
4.17849951190874e-05
1.9273658836027607e-05
4.2897710955003276e-05
4.578607331495732e-05
4.177471419097856e-05
1.996428909478709e-05


  4%|▍         | 4/100 [01:46<48:06, 30.06s/it]

0.002356124110519886
0.0017371225403621793
0.002640234539285302
0.002319033956155181
0.0017955637304112315
0.0019667509477585554
0.001942234463058412
0.0017200157744809985
0.001982334302738309
0.0019282294670119882
0.002320273779332638
0.0020269625820219517
0.002012609736993909
0.002177812159061432
0.0019213373307138681
0.00210069352760911
0.0019571813754737377
0.001779530430212617
0.0022768101189285517
0.0020880787633359432


  5%|▌         | 5/100 [02:42<1:01:59, 39.15s/it]

0.0017164071323350072
0.0020439757499843836
0.0018817669479176402
0.001618660637177527
0.00200754776597023
0.0023321672342717648
0.0021174962166696787
0.002024024026468396
0.0020814863964915276
0.002061933046206832
0.0015410991618409753
0.0017755022272467613
0.0020846922416239977
0.0020759417675435543
0.002147792838513851
0.001816916512325406
0.0018418682739138603
0.0019984054379165173
0.0019486790988594294
0.0023062494583427906


  6%|▌         | 6/100 [03:38<1:10:31, 45.01s/it]

0.006330989301204681
0.006041217595338821
0.004503147676587105
0.0048981462605297565
0.006334742531180382
0.005720558576285839
0.005008941516280174
0.004961512982845306
0.004383762367069721
0.005919408518821001
0.004806213080883026
0.00556267611682415
0.005636135581880808
0.005513615440577269
0.005255354102700949
0.005145779810845852
0.005917354021221399
0.005764767527580261
0.005877942778170109
0.004814246669411659


  7%|▋         | 7/100 [05:04<1:30:24, 58.33s/it]

0.005952142644673586
0.005629259627312422
0.005820804741233587
0.00553083885461092
0.005268787033855915
0.0049942233599722385
0.00491924025118351
0.005567058455199003
0.005063354037702084
0.005168563686311245
0.005746917333453894
0.004499031230807304
0.005004174076020718
0.005155465565621853
0.00594848208129406
0.005877602845430374
0.005141152068972588
0.005117219872772694
0.005404818803071976
0.005130499601364136


  8%|▊         | 8/100 [06:19<1:37:54, 63.86s/it]

0.013982010073959827
0.013096218928694725
0.015318388119339943
0.01614946313202381
0.014477583579719067
0.016050590202212334
0.016017328947782516
0.014450524933636189
0.015273837372660637
0.013537510298192501
0.014272989705204964
0.015088189393281937
0.015017971396446228
0.014397327788174152
0.014077164232730865
0.01358153484761715
0.013729616068303585
0.01596572995185852
0.014514744281768799
0.013824141584336758


  9%|▉         | 9/100 [08:24<2:05:37, 82.83s/it]

0.008934112265706062
0.010498345829546452
0.008567421697080135
0.010685217566788197
0.010579741559922695
0.010043156333267689
0.008009717799723148
0.008730941452085972
0.009223327971994877
0.010704792104661465
0.010053927078843117
0.009846516884863377
0.009837497025728226
0.010738605633378029
0.008043406531214714
0.009499702602624893
0.009840619750320911
0.009099938906729221
0.0099642938002944
0.009214977733790874


 10%|█         | 10/100 [10:00<2:10:26, 86.96s/it]

0.009768632240593433
0.008765934035182
0.00817436445504427
0.009674781933426857
0.009876000694930553
0.008966193534433842
0.009497454389929771
0.00933120772242546
0.009669926017522812
0.008913997560739517
0.008789838291704655
0.010314510203897953
0.009384841658174992
0.008568949066102505
0.009976658970117569
0.01031837984919548
0.008819800801575184
0.00892132893204689
0.010668688453733921
0.011297708377242088


 11%|█         | 11/100 [11:45<2:17:19, 92.58s/it]

0.03557036817073822
0.021309897303581238
0.02272595465183258
0.024540644139051437
0.02738209255039692
0.026856662705540657
0.02750425599515438
0.026375796645879745
0.021743742749094963
0.021333003416657448
0.022856079041957855
0.026030711829662323
0.02516426146030426
0.023024627938866615
0.022703120484948158
0.024274423718452454
0.022452088072896004
0.024655543267726898
0.024523627012968063
0.026234187185764313


 12%|█▏        | 12/100 [14:31<2:48:18, 114.76s/it]

0.006653875112533569
0.005327739752829075
0.005058904644101858
0.005130913574248552
0.0057405391708016396
0.00519480649381876
0.0064508020877838135
0.004841102287173271
0.005131332669407129
0.0061632730066776276
0.0051983557641506195
0.004546290263533592
0.005061428062617779
0.005615870468318462
0.0050707850605249405
0.004674960393458605
0.0056371972896158695
0.00527638616040349
0.004105372820049524
0.005496107041835785


 13%|█▎        | 13/100 [15:56<2:33:16, 105.71s/it]

0.022056862711906433
0.016462119296193123
0.016210440546274185
0.019888507202267647
0.015037376433610916
0.016108406707644463
0.0176394023001194
0.019157705828547478
0.019771598279476166
0.01997322216629982
0.02039501443505287
0.01857498101890087
0.01884431578218937
0.01979130692780018
0.01757182739675045
0.019775526598095894
0.019952429458498955
0.01906837895512581
0.0167890265583992
0.018484458327293396


 14%|█▍        | 14/100 [18:15<2:46:03, 115.85s/it]

0.013141129165887833
0.01377611793577671
0.015402569435536861
0.01651257835328579
0.015117739327251911
0.013280957005918026
0.014551079832017422
0.013991307467222214
0.013689612969756126
0.011476457118988037
0.014434730634093285
0.015661446377635002
0.015831321477890015
0.014774687588214874
0.014066467061638832
0.013767234049737453
0.013014712370932102
0.01409541442990303
0.014620696194469929
0.013314052484929562


 15%|█▌        | 15/100 [20:20<2:48:11, 118.73s/it]

0.2919636368751526
0.05271366983652115
0.052471015602350235
0.0455625019967556
0.04428451880812645
0.04046912118792534
0.04383637383580208
0.043658364564180374
0.04507200047373772
0.03415410965681076
0.04025890678167343
0.03784225136041641
0.03696912154555321
0.04056970775127411
0.04493197053670883
0.040880266577005386
0.03622020408511162
0.04584580659866333
0.03321649879217148
0.03443396836519241


 16%|█▌        | 16/100 [24:17<3:35:56, 154.25s/it]

0.03194671869277954
0.03210625797510147
0.03681671619415283
0.039170704782009125
0.03812430426478386
0.03972329571843147
0.029858702793717384
0.030383041128516197
0.034652452915906906
0.0369427464902401
0.04203972965478897
0.03408914804458618
0.030836960300803185
0.03407521918416023
0.03180397301912308
0.03528757020831108
0.03305690363049507
0.036372337490320206
0.03457316756248474
0.03445916250348091


 17%|█▋        | 17/100 [28:04<4:03:24, 175.95s/it]

0.04547371715307236
0.022692399099469185
0.021799415349960327
0.02299334853887558
0.022835100069642067
0.020458092913031578
0.02301783859729767
0.025274403393268585
0.02166237123310566
0.021467765793204308
0.020042425021529198
0.024727798998355865
0.02130405604839325
0.021519389003515244
0.02089294232428074
0.021948451176285744
0.021356383338570595


#### Save model and loss history

In [None]:
#save model
torch.save(model.state_dict(), 'models/complex_ca5.pth')

#save graph
print(losses.shape)
np.save('losses', losses)