Skip to content

Commit

Permalink
Additional code for repros
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasGeiping committed Jun 13, 2024
1 parent 698e994 commit 2e7e57a
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 29 deletions.
23 changes: 23 additions & 0 deletions cramming/backend/optimizers/schedulers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Misc. optimizer implementations."""

import transformers
import math

Expand Down Expand Up @@ -182,6 +183,13 @@ def get_schedule_fn(initial_time, cfg_train):
base_percentage=0.25,
initial_time=initial_time,
)
elif cfg_train.scheduler == "triangle2":
scheduler_fn = partial(
get_triangle,
num_training_steps=cfg_train.steps,
falloff=0.25,
base_percentage=0.25,
)
elif cfg_train.scheduler in [
"linear",
"cosine",
Expand Down Expand Up @@ -493,3 +501,18 @@ def lr_lambda(current_step: int):
return decay / lr_init # as LambdaLR multiplies by lr_init

return LambdaLR(optimizer, lr_lambda, -1)


def get_triangle(optimizer, num_training_steps, base_percentage=0.5, falloff=0.5):
    """Triangle LR schedule: linear warmup from a fraction of the base LR, then linear decay to zero.

    The multiplier starts at ``base_percentage`` at step 0, rises linearly to 1.0 at step
    ``(1 - falloff) * num_training_steps``, then decays linearly to 0.0 at ``num_training_steps``.
    Example shape (base_percentage=0.5, falloff=0.25, 1000 steps):
    plot min(0.5 + x * (1 - 0.5)/(1-0.25) / 1000, 1/0.25 - x / (1000 * 0.25)) from 0 to 1000 in the plot range 0 to 1

    Args:
        optimizer: torch optimizer whose base LR is scaled by the schedule multiplier.
        num_training_steps: total number of scheduler steps.
        base_percentage: starting multiplier at step 0 (fraction of the base LR).
        falloff: fraction of training spent in the final linear decay to zero.

    Returns:
        A ``LambdaLR`` wrapping ``optimizer`` with the triangle multiplier.
    """

    def lr_lambda(current_step):
        warmup = base_percentage + current_step * (1 - base_percentage) / (1 - falloff) / num_training_steps
        decay = float(1 / falloff - current_step / (num_training_steps * falloff))
        # Clamp at 0 so any step past num_training_steps cannot yield a negative learning rate.
        return max(0.0, min(warmup, decay))

    return LambdaLR(optimizer, lr_lambda, -1)
14 changes: 2 additions & 12 deletions cramming/backend/torch_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
that were tested but ultimately discarded, so read that part only if you're interested.
"""

import torch
import torch._inductor.utils

Expand All @@ -23,7 +24,6 @@
from safetensors.torch import load_file, save_file
from transformers.utils.generic import working_or_temp_dir

from torch.distributed.optim import ZeroRedundancyOptimizer

from .utils import group_parameters, prepare_pretraining_dataloader, update_ema, updated_latest_weight_average
from .optimizers.schedulers import get_schedule_fn
Expand Down Expand Up @@ -565,17 +565,7 @@ def _load_optimizer(model, cfg_train, cfg_impl, initial_time):
if cfg_impl.foreach_optimizer and cfg_train.optim.type != "Shampoo":
optimizer_args["foreach"] = True

if torch.distributed.is_initialized() and cfg_impl.zero_redundancy_optimizer:
# The overlap option is a whole bucket of problems in itself for now...
optimizer = ZeroRedundancyOptimizer(
grouped_parameters,
optimizer_class=optimizer_class,
parameters_as_bucket_view=True,
overlap_with_ddp=False,
**optimizer_args,
)
else:
optimizer = optimizer_class(grouped_parameters, **optimizer_args)
optimizer = optimizer_class(grouped_parameters, **optimizer_args)

if cfg_train.optim_mod.name == "none":
optimizer_to_schedule = optimizer
Expand Down
2 changes: 1 addition & 1 deletion cramming/config/cfg_pretrain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defaults:
- arch: hf-bert-tiny
- data: sanity-check-2
- impl: torch-default
- wandb: none
- wandb: default
- train: bert-base
- _self_
- override hydra/job_logging: custom
Expand Down
2 changes: 1 addition & 1 deletion cramming/config/data/sources/ag_news.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# For sanity testing
ag_news:
provider: huggingface
partition: default
# partition: default
split: train

streaming: False
Expand Down
2 changes: 1 addition & 1 deletion cramming/config/impl/_default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ path: data

# data implementation:
local_staging_dir: # Optionally copy a preprocessed dataset into this folder before loading it for training
forbid_dataset_preprocessing: True
forbid_dataset_preprocessing: False
temporary_corpus: False # Save data directly into local staging dir, forget after use
max_raw_chunk_size: 8e6

Expand Down
2 changes: 1 addition & 1 deletion cramming/config/wandb/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
enabled: True
entity: YOURNAMEHERE
entity: jonasgeiping # change this obviously ;>
project: cramming-pretrain
tags: []
59 changes: 46 additions & 13 deletions sanity_checks2024.sh
Original file line number Diff line number Diff line change
@@ -1,17 +1,50 @@


# Sanity checks for pytorch issue https://github.com/pytorch/pytorch/issues/96693
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_base arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_simplecomp arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_gemm arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null +impl._inductor_vars.max_autotune_gemm=True seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_pw arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null +impl._inductor_vars.max_autotune_pointwise=True seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_default arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_no_cudagraphs arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune-no-cudagraphs seed=233

python pretrain.py name=DA6000amp_b8192_cb_o4_premade_base_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl.tf32_allowed=False seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_simplecomp_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.tf32_allowed=False seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_gemm_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null +impl._inductor_vars.max_autotune_gemm=True impl.tf32_allowed=False seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_pw_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null +impl._inductor_vars.max_autotune_pointwise=True impl.tf32_allowed=False seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_default_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune impl.tf32_allowed=False seed=233
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_no_cudagraphs_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune-no-cudagraphs impl.tf32_allowed=False seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_base arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_simplecomp arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_gemm arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null +impl._inductor_vars.max_autotune_gemm=True seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_pw arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null +impl._inductor_vars.max_autotune_pointwise=True seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_default arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_no_cudagraphs arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune-no-cudagraphs seed=233

# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_base_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl.tf32_allowed=False seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_simplecomp_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.tf32_allowed=False seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_gemm_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null +impl._inductor_vars.max_autotune_gemm=True impl.tf32_allowed=False seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_pw_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null +impl._inductor_vars.max_autotune_pointwise=True impl.tf32_allowed=False seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_default_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune impl.tf32_allowed=False seed=233
# python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_no_cudagraphs_tfoff arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune-no-cudagraphs impl.tf32_allowed=False seed=233

# all follow the same curve:
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune seed=233 train.steps=40000 budget=2.4 impl.deterministic=True train.batch_size_ramp=0 train.scheduler=triangle2
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_max_autotune_no_cudagraphs_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune-no-cudagraphs seed=233 train.steps=40000 budget=2.4 impl.deterministic=True train.batch_size_ramp=0 train.scheduler=triangle2
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_default_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=default seed=233 train.steps=40000 budget=2.4 impl.deterministic=True train.batch_size_ramp=0 train.scheduler=triangle2
python pretrain.py name=DA6000amp_b8192_cb_o4_premade_reduce_overhead_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=reduce-overhead seed=233 train.steps=40000 budget=2.4 impl.deterministic=True train.batch_size_ramp=0 train.scheduler=triangle2


CUDA_VISIBLE_DEVICES=3 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_cb_o4_with_nondet_max_autotune_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune seed=233 train.steps=40000 budget=2.4 impl.deterministic=False train.batch_size_ramp=0 train.scheduler=triangle2
CUDA_VISIBLE_DEVICES=4 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_cb_o4_with_nondet_max_autotune_no_cudagraphs_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune-no-cudagraphs seed=233 train.steps=40000 budget=2.4 impl.deterministic=False train.batch_size_ramp=0 train.scheduler=triangle2
CUDA_VISIBLE_DEVICES=5 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_cb_o4_with_nondet_default_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=default seed=233 train.steps=40000 budget=2.4 impl.deterministic=False train.batch_size_ramp=0 train.scheduler=triangle2
CUDA_VISIBLE_DEVICES=6 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_cb_o4_with_nondet_reduce_overhead_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=reduce-overhead seed=233 train.steps=40000 budget=2.4 impl.deterministic=False train.batch_size_ramp=0 train.scheduler=triangle2


CUDA_VISIBLE_DEVICES=5 python pretrain.py name=DA6000amp_b8192_cb_o4_with_ramp_max_autotune_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune seed=233 train.steps=40000 budget=2.4 impl.deterministic=True train.batch_size_ramp=20000 train.scheduler=triangle2
CUDA_VISIBLE_DEVICES=5 python pretrain.py name=DA6000amp_b8192_cb_o4_with_ramp_max_autotune_no_cudagraphs_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=max-autotune-no-cudagraphs seed=233 train.steps=40000 budget=2.4 impl.deterministic=True train.batch_size_ramp=20000 train.scheduler=triangle2
CUDA_VISIBLE_DEVICES=6 python pretrain.py name=DA6000amp_b8192_cb_o4_with_ramp_default_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=default seed=233 train.steps=40000 budget=2.4 impl.deterministic=True train.batch_size_ramp=20000 train.scheduler=triangle2
CUDA_VISIBLE_DEVICES=6 python pretrain.py name=DA6000amp_b8192_cb_o4_with_ramp_reduce_overhead_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null impl.mode=reduce-overhead seed=233 train.steps=40000 budget=2.4 impl.deterministic=True train.batch_size_ramp=20000 train.scheduler=triangle2



# invoke cache skip + cudagraphs: +impl._inductor_vars.autotune_local_cache=False +impl._inductor_vars.triton.cudagraphs=True

CUDA_VISIBLE_DEVICES=3 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_cb_o4_with_det_reduce_overhead_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null seed=233 train.steps=40000 budget=100 impl.deterministic=True train.batch_size_ramp=0 train.scheduler=triangle2 +impl._inductor_vars.autotune_local_cache=False +impl._inductor_vars.triton.cudagraphs=True
CUDA_VISIBLE_DEVICES=4 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_cb_o4_with_nondet_reduce_overhead_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null seed=233 train.steps=40000 budget=100 impl.deterministic=False train.batch_size_ramp=0 train.scheduler=triangle2 +impl._inductor_vars.autotune_local_cache=False +impl._inductor_vars.triton.cudagraphs=True


CUDA_VISIBLE_DEVICES=5 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_const_with_det_default_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null seed=233 train.steps=40000 budget=100 impl.deterministic=True train.batch_size_ramp=0 train.scheduler=constant +impl._inductor_vars.autotune_local_cache=False
CUDA_VISIBLE_DEVICES=5 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_const_with_det_reduce_overhead_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null seed=233 train.steps=40000 budget=100 impl.deterministic=True train.batch_size_ramp=0 train.scheduler=constant +impl._inductor_vars.autotune_local_cache=False +impl._inductor_vars.triton.cudagraphs=True
CUDA_VISIBLE_DEVICES=6 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_const_with_nondet_default_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null seed=233 train.steps=40000 budget=100 impl.deterministic=False train.batch_size_ramp=0 train.scheduler=constant +impl._inductor_vars.autotune_local_cache=False
CUDA_VISIBLE_DEVICES=6 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 python pretrain.py name=DA6000amp_b8192_const_with_nondet_reduce_overhead_40k arch=crammed-bert train=bert-o4 data=pile-readymade data.hf_location=JonasGeiping/the_pile_WordPiecex32768_53b28db05413b6497e702f178268e1e2 impl.microbatch_size=512 impl._inductor_vars=null seed=233 train.steps=40000 budget=100 impl.deterministic=False train.batch_size_ramp=0 train.scheduler=constant +impl._inductor_vars.autotune_local_cache=False +impl._inductor_vars.triton.cudagraphs=True

# torch._dynamo.reset()?
# torch.compiler.reset?

0 comments on commit 2e7e57a

Please sign in to comment.