Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions opensportslib/core/trainer/localization_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def build_trainer(cfg, model=None, default_args=None, resume_from=None):
trainer.best_criterion_valid = checkpoint.get('best_criterion_valid',
0 if cfg.TRAIN.criterion_valid == "map" else float("inf"))
logging.info(f"Restored best epoch: {trainer.best_epoch}")

else:
trainer = Trainer_pl(cfg, default_args["work_dir"])


return trainer
Expand All @@ -147,6 +150,37 @@ def __init__(self):
def train(self):
pass

class Trainer_pl(Trainer):
    """Trainer class used for models that rely on lightning modules.

    Args:
        cfg (dict): Dict config. It should contain the key 'max_epochs' and the key 'GPU'.
    """

    def __init__(self, cfg, work_dir):
        # Imported lazily so that pytorch_lightning is only required when
        # this trainer flavour is actually selected.
        import pytorch_lightning as pl

        from opensportslib.core.utils.lightning import CustomProgressBar, MyCallback

        self.work_dir = work_dir
        self.trainer = pl.Trainer(
            max_epochs=cfg.max_epochs,
            devices=[cfg.GPU],
            callbacks=[MyCallback(), CustomProgressBar(refresh_rate=1)],
            num_sanity_val_steps=0,
        )

    def train(self, **kwargs):
        """Run the lightning fit loop, then persist the best model state.

        Expects ``kwargs`` to contain at least ``model`` (a lightning module
        exposing a ``best_state`` dict with an ``"epoch"`` entry); all kwargs
        are forwarded to ``pl.Trainer.fit``.
        """
        self.trainer.fit(**kwargs)

        # best_state is maintained by MyCallback during validation.
        best_state = kwargs["model"].best_state
        save_path = os.path.join(self.work_dir, "model.pth.tar")

        logging.info("Done training")
        logging.info("Best epoch: {}".format(best_state.get("epoch")))
        torch.save(best_state, save_path)

        logging.info("Model saved")
        logging.info(save_path)


class Trainer_e2e(Trainer):
Expand Down
52 changes: 52 additions & 0 deletions opensportslib/core/utils/lightning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import TQDMProgressBar
import logging


class CustomProgressBar(TQDMProgressBar):
    """Progress bar overriding pytorch lightning's default to tweak its display."""

    def get_metrics(self, trainer, pl_module):
        """Return the standard metrics dict without the ``v_num`` version entry."""
        metrics = super().get_metrics(trainer, pl_module)
        if "v_num" in metrics:
            del metrics["v_num"]
        return metrics


class MyCallback(pl.Callback):
    """Callback customizing pl's behaviour at the end of each validation epoch.

    It tracks the best (lowest) validation loss, snapshots the corresponding
    model/optimizer state on the lightning module (``pl_module.best_state``),
    steps a ReduceLROnPlateau-style scheduler, and requests early stopping
    once the learning rate can no longer be reduced.

    Assumes ``pl_module`` exposes ``losses.avg``, ``best_loss``, ``model``,
    ``optimizer`` and a plateau ``scheduler`` — TODO confirm against the
    lightning modules used with this callback.
    """

    def on_validation_epoch_end(self, trainer, pl_module):
        # Average validation loss accumulated by the module over this epoch.
        loss_validation = pl_module.losses.avg

        # Candidate checkpoint for this epoch (kept in memory, not on disk).
        state = {
            "epoch": trainer.current_epoch + 1,
            "state_dict": pl_module.model.state_dict(),
            "best_loss": pl_module.best_loss,
            "optimizer": pl_module.optimizer.state_dict(),
        }

        # Remember best validation loss and save the matching checkpoint.
        is_better = loss_validation < pl_module.best_loss
        pl_module.best_loss = min(loss_validation, pl_module.best_loss)

        # Save the best model based on loss only if the evaluation frequency too long
        if is_better:
            pl_module.best_state = state
            # torch.save(state, best_model_path)

        # Reduce LR on Plateau after patience reached.
        prevLR = pl_module.optimizer.param_groups[0]["lr"]
        pl_module.scheduler.step(loss_validation)
        currLR = pl_module.optimizer.param_groups[0]["lr"]

        # Bug fix: the original used `currLR is not prevLR`, an *identity*
        # comparison that is unreliable for floats; compare values instead.
        if currLR != prevLR and pl_module.scheduler.num_bad_epochs == 0:
            logging.info("\nPlateau Reached!")
        # When the LR is already at its floor (below 2*eps the scheduler
        # cannot shrink it further) and patience is exhausted, stop training.
        if (
            prevLR < 2 * pl_module.scheduler.eps
            and pl_module.scheduler.num_bad_epochs >= pl_module.scheduler.patience
        ):
            logging.info("\nPlateau Reached and no more reduction -> Exiting Loop")
            trainer.should_stop = True
125 changes: 125 additions & 0 deletions opensportslib/core/utils/video_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,131 @@ def oneHotToShifts(onehot, params):
Shifts[:, i] = shifts

return Shifts

def timestamps2long(output_spotting, video_size, chunk_size, receptive_field):
    """Method to transform the timestamps to vectors.

    ``output_spotting`` is (batch, predictions, 2 + classes): column 0 is the
    confidence, column 1 the relative position inside the chunk, the rest the
    per-class scores. Chunks are stitched into a single (video_size, classes)
    tensor; frames with no prediction keep the sentinel value -1.
    """
    half_rf = receptive_field // 2
    num_classes = output_spotting.size(-1) - 2
    device = output_spotting.device

    # Whole-video tensor, -1 everywhere a chunk writes nothing.
    timestamps_long = torch.full(
        (video_size, num_classes), -1.0, dtype=torch.float, device=device
    )

    start = 0
    last = False
    for b in range(output_spotting.size(0)):

        # Per-chunk tensor: scatter each prediction's confidence at the
        # (frame, class) cell it points to.
        chunk_ts = torch.full(
            (chunk_size, num_classes), -1.0, dtype=torch.float, device=device
        )
        for p in range(output_spotting.size(1)):
            frame = torch.floor(
                output_spotting[b, p, 1] * (chunk_size - 1)
            ).type(torch.int)
            cls = torch.argmax(output_spotting[b, p, 2:]).type(torch.int)
            chunk_ts[frame, cls] = output_spotting[b, p, 0]

        # ------------------------------------------
        # Store the result of the chunk in the video
        # ------------------------------------------
        if start == 0:
            # First chunk: keep everything except the trailing half field.
            timestamps_long[: chunk_size - half_rf] = chunk_ts[: chunk_size - half_rf]
        elif last:
            # Last chunk: keep everything except the leading half field.
            timestamps_long[start + half_rf : start + chunk_size] = chunk_ts[half_rf:]
            break
        else:
            # Interior chunk: drop both borders of the receptive field.
            timestamps_long[start + half_rf : start + chunk_size - half_rf] = (
                chunk_ts[half_rf : chunk_size - half_rf]
            )

        # Advance by the usable (non-border) part of a chunk; clamp the final
        # chunk so it ends exactly at the video boundary.
        start += chunk_size - 2 * half_rf
        if start + chunk_size >= video_size:
            start = video_size - chunk_size
            last = True
    return timestamps_long


def batch2long(output_segmentation, video_size, chunk_size, receptive_field):
    """Method to transform the batches to vectors.

    ``output_segmentation`` is (batch, chunk_size, classes); each chunk is
    converted to a per-frame one-hot of its argmax class, then the chunks are
    stitched into one (video_size, classes) tensor.
    """
    half_rf = receptive_field // 2
    num_classes = output_segmentation.size(-1)

    segmentation_long = torch.zeros(
        (video_size, num_classes),
        dtype=torch.float,
        device=output_segmentation.device,
    )

    start = 0
    last = False
    for b in range(output_segmentation.size(0)):

        # One-hot encode the winning class of every frame in this chunk.
        chunk_seg = torch.nn.functional.one_hot(
            torch.argmax(output_segmentation[b], dim=-1),
            num_classes=num_classes,
        )

        # ------------------------------------------
        # Store the result of the chunk in the video
        # ------------------------------------------
        if start == 0:
            # First chunk: keep everything except the trailing half field.
            segmentation_long[: chunk_size - half_rf] = chunk_seg[: chunk_size - half_rf]
        elif last:
            # Last chunk: keep everything except the leading half field.
            segmentation_long[start + half_rf : start + chunk_size] = chunk_seg[half_rf:]
            break
        else:
            # Interior chunk: drop both borders of the receptive field.
            segmentation_long[start + half_rf : start + chunk_size - half_rf] = (
                chunk_seg[half_rf : chunk_size - half_rf]
            )

        # Advance by the usable (non-border) part of a chunk; clamp the final
        # chunk so it ends exactly at the video boundary.
        start += chunk_size - 2 * half_rf
        if start + chunk_size >= video_size:
            start = video_size - chunk_size
            last = True
    return segmentation_long

# import torch
# import numpy as np
# import decord
Expand Down
Loading