In [None]:
import warnings
from pathlib import Path

import torch

from src.training import *
from src.pipeline.datasets import ChunkedDiskCachedDatasetWrapper

warnings.filterwarnings("ignore", message="Unknown entity 'ee-infinity-loader'")


# Build (or reload) the augmented dataset, then train the QCNN.
#
# `prepare_dataset` is slow, so the augmented dataset is written to CACHE_DIR by
# ChunkedDiskCachedDatasetWrapper. When that directory already holds data, it is
# reloaded with `from_cache` instead of reprocessing everything.

CACHE_DIR = 'dataset_cache'
force_rebuild = False  # Set True to ignore any existing cache and rebuild it.

# Matrix dimensions / representation version passed to prepare_dataset and get_matrix.
# Defined outside the branch below so they exist even when loading from the cache.
dims = (20, 20)
rv = 4

# Derived from the disk state rather than hard-coded, so a re-run (or a fresh kernel)
# actually reuses an existing cache instead of always reprocessing.
# NOTE(review): any non-empty CACHE_DIR is treated as a complete cache; if a previous
# build was interrupted, set force_rebuild = True.
_cache_path = Path(CACHE_DIR)
has_been_cached = (
    not force_rebuild
    and _cache_path.is_dir()
    and any(_cache_path.iterdir())
)

if has_been_cached:
    cached_dataset = ChunkedDiskCachedDatasetWrapper.from_cache(cache_dir=CACHE_DIR)
else:
    datalist = prepare_dataset(dims=dims, repr_version=rv)
    rotational_datalist = AugmentedListDataset(*datalist)
    # NOTE(review): these rebind datalist[0] / datalist[2] *after*
    # AugmentedListDataset(*datalist) already received the original list objects,
    # so (unless get_matrix mutates each sample in place) they do not affect
    # rotational_datalist or the cache. Kept in case a later cell reads `datalist`;
    # otherwise this is wasted work — confirm and remove.
    datalist[0] = [d.get_matrix(dims, rv) for d in datalist[0]]
    datalist[2] = [d.get_matrix(dims, rv) for d in datalist[2]]
    cached_dataset = ChunkedDiskCachedDatasetWrapper(
        rotational_datalist,
        force_rebuild=force_rebuild,
        cache_dir=CACHE_DIR,
    )

# Pick the best available accelerator instead of hard-coding 'mps', so the cell
# also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

BATCH_SIZE = 32
NUM_EPOCHS = 15
LEARNING_RATE = 0.001

# NOTE(review): no shuffle=True here, so batches come in dataset order each epoch
# unless split_dataloader re-wraps the dataset with shuffling — confirm.
dataloader = DataLoader(
    cached_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=collate_numpy_matrices,
)
train_dataloader, val_dataloader = split_dataloader(dataloader)

# `.float()` casts all parameters to float32 (presumably because the MPS backend
# does not support float64).
model = DeepQCNN(input_channels=8, output_channels=8).float()

trainer = QCNNTrainer(
    model=model,
    device=DEVICE,
    learning_rate=LEARNING_RATE,
    log_dir='runs/qcnn',
)

# Train model; the trainer writes the best checkpoint to save_path.
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    num_epochs=NUM_EPOCHS,
    save_path='models/qcnn_best.pth',
)

Epoch 1/15
  Train: base_loss: 0.5413
  Val: base_loss: 0.5364
  Saved best model to models/qcnn_best.pth
Epoch 2/15
  Train: base_loss: 0.5374
  Val: base_loss: 0.5364
Epoch 3/15
  Train: base_loss: 0.5371
  Val: base_loss: 0.5364


KeyboardInterrupt: 

In [None]:
# Print the dtype of every parameter of a freshly constructed DeepQCNN, built
# WITHOUT the `.float()` cast used in the training cell (presumably to check what
# dtype the model defaults to).
# Bound to its own name rather than `model` so running this cell does not replace
# the trained model variable with an untrained one for later cells.
dtype_probe_model = DeepQCNN(input_channels=8, output_channels=8)

for name, param in dtype_probe_model.named_parameters():
    print(f"Parameter {name}: {param.dtype}")