In [1]:
import os

import torch
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from common.file_paths import BASE_DIR
from modules import (GenerateSchematicCallback,
                     LightningTransformerMinecraftStructureGenerator,
                     MinecraftDataModule)

# Allow reduced-precision float32 matmuls (trades a little accuracy for speed
# on GPUs that support it).
torch.set_float32_matmul_precision('medium')

# Fix the global seed; workers=True also propagates it to dataloader workers.
seed_everything(0, workers=True)

experiment_name = "real_run"
experiment_version = 0
checkpoint_dir = "lightning_logs"

# Create the logger. It pins the run directory to
# <checkpoint_dir>/<experiment_name>/version_<experiment_version>, which is
# also where the ModelCheckpoint callbacks write (they have no explicit
# dirpath). It MUST be passed to the Trainer below; otherwise Lightning falls
# back to a default logger with an auto-incremented version and the
# checkpoints never end up at `latest_checkpoint`.
logger = TensorBoardLogger(checkpoint_dir, name=experiment_name, version=experiment_version)

# Locate the resumable checkpoint (if any) and build the model
latest_checkpoint = os.path.join(checkpoint_dir, experiment_name, f'version_{experiment_version}', 'checkpoints', 'last.ckpt')
# None means "start from scratch"; Trainer.fit accepts ckpt_path=None.
resume_checkpoint = latest_checkpoint if os.path.exists(latest_checkpoint) else None
if resume_checkpoint is not None:
    print(f"Loading checkpoint {latest_checkpoint}")
    # Build the module from the checkpoint so its hyperparameters match the
    # saved weights. This alone restores only the model; optimizer state,
    # epoch/step counters and callback state are restored by passing
    # ckpt_path to trainer.fit below.
    lightning_model = LightningTransformerMinecraftStructureGenerator.load_from_checkpoint(
        resume_checkpoint)
else:
    print("Creating new model")
    lightning_model = LightningTransformerMinecraftStructureGenerator(
        num_classes=10000,
        max_sequence_length=1331,
        embedding_dropout=0.2,
        model_dim=768,
        num_heads=12,
        num_layers=12,
        decoder_dropout=0.2,
        learning_rate=1e-4
    )

hdf5_file = os.path.join(BASE_DIR, 'data.h5')
data_module = MinecraftDataModule(
    file_path=hdf5_file,
    batch_size=1,
    # num_workers=4
)

# Callback to save the latest checkpoint for resuming training
# (full training state, written as last.ckpt in the logger's run directory)
latest_checkpoint_callback = ModelCheckpoint(
    save_last=True,
    save_weights_only=False
)
# Callback to save the best models (weights only), ranked by lowest val_loss
best_checkpoint_callback = ModelCheckpoint(
    filename='best-{epoch}-{val_loss:.2f}',
    save_top_k=2,
    monitor='val_loss',
    mode='min',
    save_weights_only=True
)
# Stop once val_loss has not improved for `patience` consecutive validation
# checks. NOTE(review): with val_check_interval=0.1 validation runs several
# times per epoch, so patience=50 is counted in validation runs, not epochs.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=50,
    mode='min'
)
# Project callback that writes sample schematics during training; the exact
# semantics of its arguments are defined in `modules`.
generate_schematic_callback = GenerateSchematicCallback(
    masked_path='schematic_viewer/public/schematics/masked/',
    filled_path='schematic_viewer/public/schematics/filled/',
    data_module=data_module,
    generate_train=False,
    generate_val=True,
    generate_all_datasets=False,
    generate_every_n_epochs=1,
    temperature=0.7
)

trainer = Trainer(
    max_epochs=5000,
    # profiler='simple',
    # Use the pinned logger so logs and checkpoints go to the directory that
    # `latest_checkpoint` points at.
    logger=logger,
    gradient_clip_val=1.0,
    log_every_n_steps=5,
    val_check_interval=0.1,
    callbacks=[
        latest_checkpoint_callback,
        best_checkpoint_callback,
        early_stop_callback,
        generate_schematic_callback
    ]
)

# ckpt_path restores the full training state (optimizer, epoch/step, callback
# state) when resuming; with None a fresh run is started.
trainer.fit(lightning_model, datamodule=data_module, ckpt_path=resume_checkpoint)

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 0


Creating new model


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs





LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                                   | Params
-----------------------------------------------------------------
0 | model | TransformerMinecraftStructureGenerator | 110 M 
-----------------------------------------------------------------
110 M     Trainable params
0         Non-trainable params
110 M     Total params
443.707   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\mmmfr\Documents\Repositories\minecraft-schematic-generator\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


                                                                           

c:\Users\mmmfr\Documents\Repositories\minecraft-schematic-generator\.venv\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 0:   8%|▊         | 4/48 [00:02<00:22,  1.91it/s, v_num=3]