In [None]:
%load_ext autoreload
%autoreload 2
import torch
import pytorch_lightning as pl
import os
import sys

In [None]:
# This notebook imports `lib.*` and `config.*` from a local MicKey checkout,
# expected at <home>/map_free_localization/mickey.
# expanduser("~") honours $HOME when it is set and still resolves a home
# directory when it is not (os.environ["HOME"] would raise KeyError there).
HOME = os.path.expanduser("~")
mickey_path = os.path.join(HOME, "map_free_localization/mickey")

if not os.path.exists(mickey_path):
    # Say so now instead of skipping silently: without this path on sys.path
    # the MicKey imports in the following cells cannot be resolved.
    print(f"WARNING: MicKey checkout not found at {mickey_path}; "
          "imports of `lib` / `config` will fail.")
elif mickey_path not in sys.path:
    # The membership guard keeps the cell idempotent: re-running it does not
    # pile up duplicate sys.path entries.
    sys.path.append(mickey_path)


In [None]:
from torch.utils.data import DataLoader
from lib.datasets.mapfree import MapFreeDataset
from config.default import cfg


In [None]:
# Build the dataset configuration (shared `cfg` + config/datasets/mapfree.yaml
# + the notebook overrides below) and create the "val" split dataset from it.
data_dir = os.path.join(mickey_path, "config/datasets")
config_path = os.path.join(data_dir, "mapfree.yaml")

# Machine-specific location of the dataset; adjust when running elsewhere.
dataset_dir = "/media/jprincen/HD/Map Free Localization"

# NOTE: `config` is the same object as the imported `cfg` (no copy is made),
# so every assignment below mutates that shared object.
config = cfg
config.set_new_allowed(True)
config.DEBUG = False

# Fail fast: without the YAML file none of the DATASET overrides below would
# be applied, and the dataset would be built from an incomplete config.
if not os.path.exists(config_path):
    raise FileNotFoundError(f"Config does not exist: {config_path}")

config.merge_from_file(config_path)
# Restrict the dataset to a single scene. Alternative:
# `config.DATASET.SCENES = None`, set explicitly because a value loaded from
# YAML would be a string.
config.DATASET.SCENES = ['s00001']
config.DATASET.AUGMENTATION_TYPE = None
config.DATASET.DATA_ROOT = dataset_dir

dataset = MapFreeDataset(config, "val")


In [None]:
# Training needs a model checkpoint and the MicKey training config.

In [None]:
from lib.models.MicKey.model import MicKeyTrainingModel

# Merge the MicKey training config (config/MicKey/curriculum_learning.yaml)
# into the shared `cfg`, apply the notebook overrides, and build the model.
config_dir = os.path.join(mickey_path, "config/MicKey")
config_path = os.path.join(config_dir, "curriculum_learning.yaml")

# NOTE: `config` is the same object as the imported `cfg` (no copy is made),
# so these assignments mutate that shared object.
config = cfg
config.set_new_allowed(True)
config.DEBUG = False

# Fail fast instead of silently skipping the merge: otherwise the model would
# be constructed without the training configuration from the YAML file.
if not os.path.exists(config_path):
    raise FileNotFoundError(f"Config does not exist: {config_path}")
config.merge_from_file(config_path)

# Notebook-specific overrides.
config.TRAINING.NUM_GPUS = 1
config.BATCH_SIZE = 4
config.NUM_WORKERS = 4
config.SAMPLER = None

# Shorthand for the training sub-config.
tcfg = config.TRAINING
# Checkpoint shipped inside the MicKey checkout; only the path is built here.
checkpoint_path = os.path.join(mickey_path, "mickey_weights/mickey.ckpt")
model = MicKeyTrainingModel(config)

In [None]:
from pytorch_lightning.loggers import TensorBoardLogger

# TensorBoard logger for this experiment: `save_dir` is ../weights and the
# run is grouped under the experiment name.
exp_name = 'overfit_small_seq2'
logger = TensorBoardLogger(
    save_dir='../weights',
    name=exp_name,
)


In [None]:
# Trainer for this run. The config.TRAINING values LOG_INTERVAL, EPOCHS and
# GRAD_CLIP are not applied here: logging happens every step, the run is
# capped at 30 epochs, and gradient_clip_val is left unset.
trainer = pl.Trainer(
    devices=config.TRAINING.NUM_GPUS,
    log_every_n_steps=1,
    max_epochs=30,
    logger=logger,
)


In [None]:
# Wrap the dataset in a DataLoader; batch size and worker count are read
# from `config`.
dataloader = DataLoader(
    dataset,
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
)


In [None]:
# Fit `model` on `dataloader`, passing `checkpoint_path` as `ckpt_path`.
# NOTE(review): per the Lightning docs, `ckpt_path` resumes the full training
# state stored in the checkpoint (not only the weights) — confirm that is
# intended and that the checkpoint's epoch is still below the Trainer's max_epochs.
trainer.fit(model, train_dataloaders=dataloader, ckpt_path=checkpoint_path)

In [None]:
# Write the trainer's checkpoint to ../mickey_weights/overfit.ckpt
# (a relative path, resolved against the current working directory).
overfit_ckpt_path = os.path.join("../mickey_weights", "overfit.ckpt")
trainer.save_checkpoint(overfit_ckpt_path)
