## Script for training and saving models

#### Layout of files

Model.py
- Defines the entire model
    - Layers
    - How it transforms data from input to output
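For readers new to the idea, a single cellular-automaton update step can be sketched as below. This is a hypothetical, NumPy-only illustration and not the actual `Complex_CA` in Model.py. The names `perceive` and `ca_step`, the 3x3 wrap-around neighbourhood sum, and the `tanh` residual update are all assumptions chosen to keep the example short.

```python
import numpy as np

def perceive(state: np.ndarray) -> np.ndarray:
    """Concatenate each cell's own channels with the sum over its 3x3 neighbourhood.

    Borders wrap around (toroidal grid) because np.roll is used.
    state: (H, W, C) -> returns (H, W, 2C)
    """
    neighbourhood = np.zeros_like(state)
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            neighbourhood += np.roll(state, shift=(dy, dx), axis=(0, 1))
    return np.concatenate([state, neighbourhood], axis=-1)

def ca_step(state: np.ndarray, weights: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """One hypothetical update: the same small layer is applied to every cell,
    and its output is added to the current state (residual update).

    weights: (2C, C), bias: (C,)
    """
    features = perceive(state)                      # (H, W, 2C)
    delta = np.tanh(features @ weights + bias)      # (H, W, C)
    return state + delta
```

Applying `ca_step` repeatedly is what lets purely local rules produce global behaviour; in this sketch the only learnable parameters are `weights` and `bias`.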

Trainer.py
- Defines how the model is trained
- Defines curriculum learning
    - How it's trained initially
        - X steps in which the target is simply the input, so the output learns to match it
        - Make it stable
    - How it's trained once we assume it is getting smarter
        - Teach it to move in the direction of the reward
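The two-phase curriculum above can be sketched as a switch of the training target. Everything here is hypothetical: `warmup_steps` stands in for the unspecified X steps, the mean squared error is an assumed loss, and the real Trainer.py may switch phases differently (for example gradually rather than at a fixed step).

```python
from typing import Sequence

def curriculum_phase(step: int, warmup_steps: int) -> str:
    """Phase 1 ("stabilise") for the first warmup_steps steps, then phase 2 ("reward")."""
    return "stabilise" if step < warmup_steps else "reward"

def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equal-length sequences."""
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)

def curriculum_loss(step: int, output: Sequence[float], x: Sequence[float],
                    y: Sequence[float], warmup_steps: int) -> float:
    """During warm-up the target is the input x (reproduce it, i.e. stay stable);
    afterwards the target is the reward-directed target y."""
    target = x if curriculum_phase(step, warmup_steps) == "stabilise" else y
    return mse(output, target)
```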

Utils.py
- Contains all helper functions

States.py
- How to get new random states
    - Outputs x, y and food
        - x: the random initial state
            - Its size should lie in the range (x1, x2)
            - It should be a single entity that complies with certain rules
        - y: the target output after e epochs
        - food: the desired target location, which determines the direction the CA should move in
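A rough sketch of such a state generator follows, under loudly stated assumptions: the grid is a boolean mask, the "single entity" is a 4-connected blob grown by random accretion (the actual rules are not given above), the size bounds are treated as inclusive, and y is guessed to be the same blob translated `e` cells toward the food. All function names are hypothetical; this is not the content of States.py.

```python
import numpy as np

def random_blob(h: int, w: int, min_size: int, max_size: int,
                rng: np.random.Generator) -> np.ndarray:
    """Grow one 4-connected entity from the grid centre until it has a random size."""
    size = int(rng.integers(min_size, max_size + 1))  # assumed inclusive bounds; needs size <= h*w
    mask = np.zeros((h, w), dtype=bool)
    cells = [(h // 2, w // 2)]
    mask[cells[0]] = True
    moves = [(1, 0), (-1, 0), (0, 1), (0, -1)]
    while len(cells) < size:
        y, x = cells[rng.integers(len(cells))]        # pick an occupied cell
        dy, dx = moves[rng.integers(4)]               # and a random neighbour
        ny, nx = y + dy, x + dx
        if 0 <= ny < h and 0 <= nx < w and not mask[ny, nx]:
            mask[ny, nx] = True
            cells.append((ny, nx))
    return mask

def translate(mask: np.ndarray, dy: int, dx: int) -> np.ndarray:
    """Shift a mask by (dy, dx) without wrap-around; cells pushed off the grid are lost."""
    h, w = mask.shape
    dy, dx = max(-h, min(h, dy)), max(-w, min(w, dx))
    out = np.zeros_like(mask)
    src = mask[max(0, -dy):h - max(0, dy), max(0, -dx):w - max(0, dx)]
    out[max(0, dy):h - max(0, -dy), max(0, dx):w - max(0, -dx)] = src
    return out

def new_state(h: int, w: int, min_size: int, max_size: int, e: int,
              rng: np.random.Generator):
    """Return (x, y, food): a random blob, a guessed target, and a random food cell."""
    x = random_blob(h, w, min_size, max_size, rng)
    food = (int(rng.integers(h)), int(rng.integers(w)))
    ys, xs = np.nonzero(x)
    dy = int(np.sign(food[0] - ys.mean())) * e        # step toward the food's row
    dx = int(np.sign(food[1] - xs.mean())) * e        # and toward the food's column
    return x, translate(x, dy, dx), food
```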

#### Imports

In [1]:
import torch
import numpy as np
from Trainer import Trainer
from Model import Complex_CA

#### Setup

In [4]:
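# To use a GPU, uncomment one of the lines below (is_available must be called, not just referenced)
#device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
#device = torch.device('mps:0' if torch.backends.mps.is_available() else 'cpu')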
device = torch.device('cpu')
batch_size = 16
model = Complex_CA(device, batch_size)
model = model.to(device)
trainer = Trainer(model, device)
print(device)

cpu


#### Training

In [5]:
seed = 2
torch.manual_seed(seed)
np.random.seed(seed)
model, losses = trainer.train()

0it [00:00, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

0.004258122760802507
0.00011605492181843147
0.0002304360386915505
4.5191551180323586e-05
5.424371192930266e-05
2.081611455651e-05
2.331527321075555e-05
6.050237061572261e-06
8.589723620389123e-06
3.643297532107681e-06
1.233820785273565e-05
3.820583970082225e-06
2.5968506633944344e-06
4.731088665721472e-06
2.478422175045125e-06
1.41071518555691e-06
3.1317326829594094e-06
1.7399679563823156e-06
3.665698841359699e-06
6.92609148700285e-07


  1%|          | 1/100 [00:46<1:16:11, 46.18s/it]

1.8040705072053242e-06
1.0888463748415234e-06
2.2048702703614254e-06
8.296150326714269e-07
1.6619694633845938e-06
4.262553545686387e-07
1.801119196898071e-06
3.6748198795066855e-07
1.1429391406636569e-06
2.0802421829557716e-07
1.026347831611929e-06
3.0850526400172384e-07
1.6859576135175303e-07


  1%|          | 1/100 [01:16<2:06:22, 76.59s/it]


KeyboardInterrupt: 
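The `KeyboardInterrupt` above means `trainer.train()` never returned, so the assignment `model, losses = trainer.train()` in that cell did not happen. One way to keep partial results when stopping a long run by hand is to catch the interrupt inside the loop. A minimal sketch, assuming a hypothetical per-step function `train_step` that returns the loss; the real `Trainer.train` is not shown here and may be structured differently.

```python
from typing import Callable, List

def train_with_interrupt(train_step: Callable[[int], float], n_steps: int) -> List[float]:
    """Run up to n_steps steps, returning the losses collected so far even if interrupted."""
    losses: List[float] = []
    try:
        for step in range(n_steps):
            losses.append(train_step(step))
    except KeyboardInterrupt:
        # Stop cleanly instead of propagating, so the caller still gets a result.
        print(f"Interrupted after {len(losses)} steps; keeping partial losses.")
    return losses

def demo_step(step: int) -> float:
    """Hypothetical stand-in for a training step that gets interrupted at step 3."""
    if step == 3:
        raise KeyboardInterrupt
    return 1.0 / (step + 1)
```

With this pattern, `train_with_interrupt(demo_step, 10)` returns the three losses recorded before the interrupt instead of losing them.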

#### Save model and loss history

In [None]:
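import os

# Save the model weights (torch.save fails if the models/ directory does not exist)
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), 'models/complex_ca5.pth')

# Save the loss history; np.save appends the .npy extension, producing losses.npy
losses = np.asarray(losses)
print(losses.shape)
np.save('losses', losses)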
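To check what the save cell produces, here is a small NumPy round trip. `save_and_reload_losses` is a hypothetical helper, and a temporary directory stands in for the notebook's working directory. Because the filename `'losses'` has no extension, `np.save` writes `losses.npy`, which is the name to pass to `np.load`. The saved weights can be restored the usual PyTorch way with `model.load_state_dict(torch.load('models/complex_ca5.pth'))`.

```python
import os
import tempfile
import numpy as np

def save_and_reload_losses(losses, directory: str) -> np.ndarray:
    """Save a loss history the same way as the cell above, then load it back."""
    path = os.path.join(directory, "losses")
    np.save(path, np.asarray(losses))   # writes "<directory>/losses.npy"
    return np.load(path + ".npy")       # np.load needs the full filename

# Usage: reload into a fresh temporary directory.
reloaded = save_and_reload_losses([0.1, 0.01, 0.001], tempfile.mkdtemp())
```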