In [None]:
# Re-import edited project modules (the src.* imports below) automatically before
# each cell executes, so library changes take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
from src.training import *
from src.pipeline.loaders import collate_numpy_matrices_without_conditions
from src.pipeline.datasets import RotationalDataset, prepare_and_save_seed_dataset, MatrixTupleDataset
import warnings
import os
import random

import numpy as np
import torch

warnings.filterwarnings("ignore", message="Unknown entity 'ee-infinity-loader'")

# Seed every RNG the pipeline may draw from so the train/val/test split below
# (and model initialisation in later cells) is reproducible across kernel restarts.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

filename = 'data/processed/paths.npz'
dims = (20, 20)

# Build the seed dataset once and cache it on disk; later runs just load the file.
if not os.path.exists(filename):
    # Ensure the cache directory exists before the dataset is written into it.
    os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
    prepare_and_save_seed_dataset(filename, dims, repr_version=5, seed_paths=10)

matrix_dataset = MatrixTupleDataset(filename)
rotational_datalist = RotationalDataset(matrix_dataset)

# NOTE(review): collate_numpy_matrices_without_conditions is imported explicitly but
# never used; collate_numpy_matrices presumably comes from `from src.training import *`.
# Confirm which collate function is intended for this dataset.
dataloader = DataLoader(
    rotational_datalist,
    batch_size=256,
    collate_fn=collate_numpy_matrices,
    num_workers=0,
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only hosts.
    pin_memory=torch.cuda.is_available(),
)
# Presumably val_split is the held-out fraction: 20% -> test, then 10% of the rest -> val.
train_dataloader, test_dataloader = split_dataloader(dataloader, val_split=0.2)
train_dataloader, val_dataloader = split_dataloader(train_dataloader, val_split=0.1)

input_size = dims[0]

print(f"Dataset size is: {len(rotational_datalist)}")

In [None]:
from src.model import BinaryMatrixTransformCNN, AttentiveBinaryMatrixTransformCNN, FactorioCNN_PixelOutput
from src.training import train_model
import itertools

from datetime import datetime

import torch

input_size = dims[0]

# Fall back to CPU so the notebook still runs without a GPU; on CUDA machines this
# is identical to the previous hard-coded 'cuda'.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training budget shared by every run below.
NUM_EPOCHS = 50


def _train_and_evaluate(model, num_epochs=NUM_EPOCHS):
    """Move ``model`` to ``device``, train it, then evaluate it on the test split.

    Reads the notebook globals ``device``, ``train_dataloader``, ``val_dataloader``
    and ``test_dataloader``. Each run logs to ``runs/<timestamp>_<model.filename>``.

    Args:
        model: model exposing a ``filename`` attribute (used to name the log dir),
            accepted by ``train_model`` / ``test_model``.
        num_epochs: number of epochs passed to ``train_model``.

    Returns:
        The log directory used for this run.
    """
    run_log_dir = f'runs/{datetime.now().strftime("%Y%m%d-%H%M%S")}_{model.filename}'
    # Applied to every model so both sweeps are placed on `device` consistently;
    # presumably these are torch.nn.Module subclasses, whose .to() moves in place.
    model = model.to(device=device)
    train_model(model, train_loader=train_dataloader,
                val_loader=val_dataloader, log_dir=run_log_dir,
                integrity_weight=0.0, num_epochs=num_epochs, device=device)
    test_model(model, test_dataloader, device=device, log_dir=run_log_dir)
    return run_log_dir


# --- Baseline CNNs -------------------------------------------------------------
model1 = BinaryMatrixTransformCNN(matrix_size=input_size)
model2 = AttentiveBinaryMatrixTransformCNN(matrix_size=input_size)
models = [model1, model2]

for net in models:
    log_dir = _train_and_evaluate(net)
    del net
# NOTE(review): `del net` only unbinds the loop name; model1, model2 and `models`
# still reference the trained models, so they (and any device memory) stay alive.
# Run `del model1, model2, models` if they are not needed later.

# --- FactorioCNN_PixelOutput: presence-gate bias sweep --------------------------
# Trailing underscore so as not to shadow the built-in `str`.
str_ = 0.5  # presence-gate strength, fixed across the sweep
for bias in [-1.5, -0.5]:
    # NOTE(review): the meaning of these positional constructor arguments is not
    # visible here; consider passing them by keyword once confirmed.
    net = FactorioCNN_PixelOutput(4, 4, 5, 3, 3, 2, 32, 21,
                                  presence_gate_strength=str_,
                                  presence_gate_bias=bias)
    net.filename += f"_pgst{str_}_bias{bias}"
    log_dir = _train_and_evaluate(net)
    # Drop this cell's reference so the model can be garbage-collected
    # (assuming train_model/test_model keep no reference to it).
    del net