## Script for training and saving models

#### Layout of files

Model.py
- Defines the entire model
    - Layers
    - How it transforms data from input to output

Trainer.py
- Defines how the model is trained 
- Defines curriculum learning
    - How it's trained initially
        - X steps just to produce output similar to the input
        - Make it stable
    - How it's trained when we assume it's getting smarter
        - Teach it to move in the direction of the reward

Utils.py
- Contains all helper functions

States.py
- How to get new random states
    - Output x, y and food
        - x being the random initial state
            - Should have a size in the range (x1, x2)
            - Should be one entity complying with certain rules
        - y being the target output after e epochs
        - food being the desired target location, which determines the direction the CA moves in

#### Imports

In [1]:
import torch
import numpy as np
from Trainer import Trainer
from Model import Complex_CA

#### Setup

In [2]:
# Select the compute device. CPU is used; to switch, replace the active line
# with one of the alternatives below. `is_available` must be *called*: the bare
# method object is always truthy, so the uncalled form would pick the
# accelerator even on machines that do not have one.
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device = torch.device('mps:0' if torch.backends.mps.is_available() else 'cpu')
device = torch.device('cpu')
batch_size = 16

# Build the cellular-automaton model on the chosen device and wrap it in the
# project's Trainer.
model = Complex_CA(device, batch_size)
model = model.to(device)
trainer = Trainer(model, device)
print(device)

cpu


#### Training

In [3]:
# Seed torch's and numpy's global RNGs with the same value so a training run
# can be reproduced, then run the trainer, which returns the model and losses.
seed = 2
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(seed)
model, losses = trainer.train()

0it [00:00, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

0.06907821446657181
0.001110689714550972
0.0004709625500254333
0.0003999376785941422
0.000303095206618309
0.0004710940702352673
0.0003977002343162894
0.00034802925074473023
0.00048760566278360784
0.0003964136412832886
0.0004222887509968132
0.0005446759751066566
0.00027731238515116274
0.0002596256381366402
0.0003525568754412234
0.00033268469269387424
0.0001891341235022992
0.0001903492520796135
0.00012597520253621042
0.00013063231017440557


  1%|          | 1/100 [00:16<27:34, 16.71s/it]

0.00024957209825515747
0.00015755306230857968
0.0001028290280373767
5.9294434322509915e-05
5.411235906649381e-05
6.098274752730504e-05
3.5601846320787445e-05
2.796073749777861e-05
3.0485098250210285e-05
8.984930900624022e-06
1.3706646313949022e-05
2.7068870622315444e-05
1.619232352823019e-05
1.1863063264172524e-05
1.9529423298081383e-05
9.90628086583456e-06
1.0963049135170877e-05
8.454186172457412e-06
4.9843483793665655e-06
8.261963557743002e-06


  2%|▏         | 2/100 [00:33<27:10, 16.64s/it]

0.00045908265747129917
0.00015808352327439934
9.336251969216391e-05
4.143644036957994e-05
4.719496791949496e-05
3.4796179534168914e-05
3.084863783442415e-05
3.9147973438957706e-05
5.627542486763559e-05
2.4609509637230076e-05
1.981198147404939e-05
4.2541381844785064e-05
5.1347735279705375e-05
4.397736847749911e-05
2.037683589151129e-05
9.133087587542832e-05
5.9305381000740454e-05
3.9189228118630126e-05
4.0099796024151146e-05
1.6984027752187103e-05


  3%|▎         | 3/100 [01:08<40:44, 25.20s/it]

3.402208676561713e-05
3.0255745514295995e-05
2.2063912183512002e-05
2.2154545149533078e-05
4.804220225196332e-05
3.7970727134961635e-05
2.8611524612642825e-05
5.0684018788160756e-05
1.605963370820973e-05
1.465569766878616e-05
2.9222146622487344e-05
2.5427971195313148e-05
2.2360920411301777e-05
1.5352838090620935e-05
3.654447209555656e-05
1.0299065252183937e-05
2.206204953836277e-05
2.1909019778831862e-05
9.564486390445381e-06
2.0346049495856278e-05


  4%|▍         | 4/100 [01:44<46:49, 29.27s/it]

0.005501160863786936
0.0019939763005822897
0.0024124188348650932
0.0022055329754948616
0.0018377293599769473
0.0019466785015538335
0.0019151524174958467
0.0017221489688381553
0.002028411952778697
0.00186790875159204
0.0023490088060498238
0.0019234907813370228
0.001979649066925049
0.002114793984219432
0.0018684010719880462
0.0020811629947274923
0.0019152277382090688
0.001862656557932496
0.0022814555559307337
0.0021338490769267082


  5%|▌         | 5/100 [02:42<1:02:50, 39.69s/it]

0.0018104681512340903
0.0020765960216522217
0.001904208562336862
0.0015751896426081657
0.0020337533205747604
0.002294218400493264
0.0021406651940196753
0.0019950950518250465
0.0021062318701297045
0.002025752095505595
0.001517349504865706
0.0016739919083192945
0.0020100371912121773
0.002050605369731784
0.002128081861883402
0.001798024633899331
0.0018015870591625571
0.0019692028872668743
0.0019456229638308287
0.0023164262529462576


  6%|▌         | 6/100 [03:41<1:12:47, 46.46s/it]

0.023054232820868492
0.007142051123082638
0.006280178669840097
0.005315849091857672
0.006840042769908905
0.005986063275486231
0.004933252930641174
0.005091633182018995
0.0042870501056313515
0.006092655472457409
0.004774353466928005
0.005658967420458794
0.005573849193751812
0.005582842044532299
0.005493054166436195
0.0053717694245278835
0.0059796292334795
0.005790463648736477
0.005980882793664932
0.004944181535393


  7%|▋         | 7/100 [05:12<1:34:06, 60.72s/it]

0.0071549564599990845
0.0055794548243284225
0.005844946019351482
0.005598665215075016
0.0053162043914198875
0.004950679838657379
0.004838288761675358
0.0056084999814629555
0.00509205786511302
0.005207043141126633
0.005747062619775534
0.00450622383505106
0.0049562701024115086
0.005239211954176426
0.006014835089445114
0.005907974671572447
0.005107834003865719
0.0050434693694114685
0.005546415690332651
0.005129710305482149


  8%|▊         | 8/100 [06:34<1:43:33, 67.53s/it]

0.13885878026485443
0.019334685057401657
0.023429738357663155
0.019633550196886063
0.020452331751585007
0.017903022468090057
0.019149789586663246
0.016075558960437775
0.015850352123379707
0.014700856991112232
0.014726277440786362
0.01557092647999525
0.015554478392004967
0.014917691238224506
0.013984083198010921
0.014064366929233074
0.014166682958602905
0.01631757616996765
0.015476829372346401
0.01526717934757471


  9%|▉         | 9/100 [08:49<2:14:29, 88.67s/it]

0.02825339324772358
0.01076938584446907


  9%|▉         | 9/100 [08:56<1:30:28, 59.66s/it]


KeyboardInterrupt: 

#### Save model and losses graph

In [None]:
from pathlib import Path

# NOTE: if the training cell was interrupted (KeyboardInterrupt), the
# assignment `model, losses = trainer.train()` never ran, so `losses` is
# undefined here and this cell raises NameError.

# Save the model weights. torch.save does not create missing parent
# directories, so ensure 'models/' exists first.
model_path = Path('models') / 'complex_ca5.pth'
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)

# Save the recorded loss values (the data for a loss graph, not a plot).
# np.save appends the '.npy' suffix, so this writes 'losses.npy'.
# np.asarray gives `.shape` even if trainer.train() returns a plain list
# -- NOTE(review): confirm the return type of Trainer.train().
losses_arr = np.asarray(losses)
print(losses_arr.shape)
np.save('losses', losses_arr)