In [2]:
%load_ext autoreload
%autoreload 2
import os
import sys
# Add project root to path - adjust the number of parent dirs (..) based on where your notebook is located
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
import torch
from torch.optim import Adam
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import src.datamodule as datamodule
import src.models as models

In [3]:
# Build the trajectory data module (series_type=['position'], timesteps
# included, no game-state variables) and pull one training batch so the
# tensor shapes can be inspected before building the model.
data_module = datamodule.TrajectoryDataModule(
    data_folder='../data/',
    batch_size=32,
    max_sequence_length=None,
    series_type=['position'],
    include_game_state_vars=False,
    include_timesteps=True,
)

data_module.setup()
train_loader = data_module.train_dataloader()
data = next(iter(train_loader))

# Printed explicitly: a bare expression that is not the last line of a cell
# is never displayed by Jupyter.
print(f"train samples: {len(data_module.train_dataset)}")

print(data['trajectory'].shape)
print(data['mask'].shape)
print(data['game_id'].shape)

# Invariants: the mask covers the (batch, time) dims of the trajectory, and
# there is one game_id per sequence in the batch.
assert data['mask'].shape == data['trajectory'].shape[:2], "mask/trajectory shape mismatch"
assert data['game_id'].shape[0] == data['trajectory'].shape[0], "game_id/batch size mismatch"

# One example sequence; trailing rows may be zero padding (see mask).
print(data['trajectory'][1])


torch.Size([32, 6320, 3])
torch.Size([32, 6320])
torch.Size([32])
tensor([[ 0.1900,  0.0000, -9.5000],
        [ 0.2300,  0.0000, -9.5000],
        [ 0.2900,  0.0000, -9.5000],
        ...,
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]])


In [4]:
# Instantiate the LSTM autoencoder using the dimensions reported by the data module.
model = models.LSTMAutoencoder(
    input_dim=data_module.feature_dimension,
    latent_dim=32,  # latent size; tune as needed
    sequence_length=data_module.sequence_length,
    learning_rate=1e-3,
)

# Smoke-test a single forward pass on the sample batch. no_grad() prevents
# building an autograd graph through the full-length sequence; without it the
# globals X_hat / z would keep that graph (and its activations) alive.
with torch.no_grad():
    X_hat, z = model(data['trajectory'])

# An autoencoder's reconstruction must have the same shape as its input.
assert X_hat.shape == data['trajectory'].shape, "reconstruction shape != input shape"

print(X_hat.shape)
print(z.shape)


torch.Size([32, 6320, 3])
torch.Size([32, 32])


In [None]:
# Keep the three best checkpoints, ranked by validation loss (lower is better).
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    dirpath='checkpoints',
    filename='autoencoder-{epoch:02d}-{val_loss:.2f}',
)

# Trainer settings gathered in one place: accelerator='auto' lets Lightning
# pick an available GPU/accelerator and otherwise fall back to CPU; validation
# runs after every epoch.
trainer_config = dict(
    max_epochs=100,
    accelerator='auto',
    callbacks=[checkpoint_callback],
    enable_progress_bar=True,
    check_val_every_n_epoch=1,
)
trainer = pl.Trainer(**trainer_config)

# Build the loaders (train first, then validation) and launch training.
fit_train_loader = data_module.train_dataloader()
fit_val_loader = data_module.val_dataloader()
trainer.fit(
    model=model,
    train_dataloaders=fit_train_loader,
    val_dataloaders=fit_val_loader,
)