Skip to content

Commit

Permalink
Fix metadata.min.json
Browse files Browse the repository at this point in the history
  • Loading branch information
ControlNet committed Mar 28, 2024
1 parent c4b4d32 commit 7d84e6d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
7 changes: 6 additions & 1 deletion evaluate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os

import toml
import torch
Expand All @@ -8,7 +9,7 @@
from metrics import AP, AR
from model import Batfd, BatfdPlus
from post_process import post_process
from utils import read_json
from utils import generate_metadata_min, read_json

parser = argparse.ArgumentParser(description="BATFD evaluation")
parser.add_argument("--config", type=str)
Expand Down Expand Up @@ -125,6 +126,10 @@ def evaluate_lavdf(config, args):

if __name__ == '__main__':
args = parser.parse_args()

if os.path.exists(os.path.join(args.data_root, "metadata.min.json")):
generate_metadata_min(args.data_root)

config = toml.load(args.config)
torch.backends.cudnn.benchmark = True
if config["dataset"] == "lavdf":
Expand Down
6 changes: 5 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import argparse
import os

import toml
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from dataset.lavdf import LavdfDataModule
from model import Batfd, BatfdPlus
from utils import LrLogger, EarlyStoppingLR
from utils import LrLogger, EarlyStoppingLR, generate_metadata_min

parser = argparse.ArgumentParser(description="BATFD training")
parser.add_argument("--config", type=str)
Expand All @@ -24,6 +25,9 @@
args = parser.parse_args()
config = toml.load(args.config)

if os.path.exists(os.path.join(args.data_root, "metadata.min.json")):
generate_metadata_min(args.data_root)

learning_rate = config["optimizer"]["learning_rate"]
gpus = args.gpus
total_batch_size = args.batch_size * gpus
Expand Down
11 changes: 11 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from importlib import metadata
import json
import os
import re
from abc import ABC
from typing import List, Tuple, Optional
Expand Down Expand Up @@ -219,3 +221,12 @@ def _run_early_stop_checking(self, trainer: Trainer) -> None:
elif self.mode == "any":
if any(lr <= self.lr_threshold for lr in all_lr):
trainer.should_stop = True


def generate_metadata_min(data_root: str):
metadata_full = read_json(os.path.join(data_root, "metadata.json"))
metadata_min = []
for meta in metadata_full:
del meta["timestamps"]
with open(os.path.join(data_root, "metadata.min.json"), "w") as f:
json.dump(metadata_min, f)

0 comments on commit 7d84e6d

Please sign in to comment.