## Imports

In [1]:
# Enable automatic extension autoreloading
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import torch

from accelerate import Accelerator

from trl import set_seed, PPOTrainer

from watermark import watermark

import sys
sys.path.insert(0, str(Path.cwd().parent.resolve()))
from configs import get_script_args, get_ppo_config
from dataset import get_dataset, collator
from reward_model import get_reward_model
from model import get_model
from utils import get_tokenizer
from trainer import train

import os

# Root of the local `elk` checkout. Override with the ELK_PATH environment variable
# instead of editing this absolute, machine-specific default.
ELK_PATH = Path(os.environ.get("ELK_PATH", "/fsx/home-augustas/elk/"))
modules = [
    ELK_PATH,
    ELK_PATH / "elk" / "promptsource",
]


def add_to_sys_path(paths):
    """Prepend each path to ``sys.path`` unless it is already present.

    The membership test uses exactly the (resolved) string that gets inserted, so
    re-running this cell never stacks duplicate entries, even when a path
    component is a symlink. Paths are inserted at position 0 in iteration order,
    so the last path ends up first.

    Args:
        paths: Iterable of ``str`` or ``pathlib.Path`` objects.
    """
    for path in paths:
        entry = str(Path(path).resolve())
        if entry not in sys.path:
            sys.path.insert(0, entry)


add_to_sys_path(modules)

print(sys.path[:2])

from templates import DatasetTemplates

['/fsx/home-augustas/elk/elk/promptsource', '/fsx/home-augustas/elk/elk/training']


In [3]:
# info = watermark(
#     packages="torch,transformers,datasets,peft,trl,tensorboard,accelerate",
#     python=True, conda=True, gpu=True,
#     current_date=True, current_time=True,
# )
# print(info)

In [4]:
# Run configuration, rendered as `--key=value` flags for get_script_args.
# Every value is kept as a string so each flag is exactly the text that would be
# typed on the command line.
# To switch the subject model, set both model_name and tokenizer_name to e.g.
# "databricks/dolly-v2-3b".
_script_arg_values = {
    "model_name": "gpt2",
    "tokenizer_name": "gpt2",
    "reward_model_output_path": "/fsx/home-augustas/logs/unifiedqa-v2-t5-3b-1363200_custom_data_v4_all_20230629_120158_21789",
    "dataset_name": "AugustasM/burns-datasets-VINC-ppo-training-v4",
    "remove_unused_columns": "False",
    "log_with": "tensorboard",
    "logging_dir": "/fsx/home-augustas/ppo_logs",
    "learning_rate": "1.4e-5",
    "batch_size": "1",
    "mini_batch_size": "1",
    "gradient_accumulation_steps": "64",
    "steps": "192",
    "ppo_epochs": "4",
    "early_stopping": "True",
    "reward_baseline": "0.0",
    "target_kl": "0.1",
    "init_kl_coef": "0.2",
    "adap_kl_ctrl": "True",
    "seed": "0",
    "save_freq": "2",
    "output_dir": "/fsx/home-augustas/ppo_runs/test",
}
script_args = get_script_args(
    [f"--{flag}={value}" for flag, value in _script_arg_values.items()]
)
print(script_args.dataset_name)

# Derive the PPO trainer configuration from the parsed arguments.
config = get_ppo_config(script_args)
print(config.total_ppo_epochs)
print(config.project_kwargs)

AugustasM/burns-datasets-VINC-ppo-training-v4
192
{'logging_dir': '/fsx/home-augustas/ppo_logs'}


In [5]:
# Tokenizer for the policy (subject) model.
tokenizer = get_tokenizer(script_args.tokenizer_name)

# PPO training prompts with the "piqa" subset removed. get_dataset also reports the
# maximum prompt and response lengths it found.
(
    train_dataset,
    prompt_max_len,
    response_max_len,
) = get_dataset(
    script_args.dataset_name,
    tokenizer,
    subsets_to_delete=["piqa"],
)

Loading tokenizer gpt2...


Using pad_token, but it is not set yet.


Loaded tokenizer.

Loading dataset...



Found cached dataset parquet (/admin/home-augustas/.cache/huggingface/datasets/AugustasM___parquet/AugustasM--burns-datasets-VINC-ppo-training-v4-278eadae0cef7ee6/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
Loading cached processed dataset at /admin/home-augustas/.cache/huggingface/datasets/AugustasM___parquet/AugustasM--burns-datasets-VINC-ppo-training-v4-278eadae0cef7ee6/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-bc44c211aecff264.arrow


Deleting subset: piqa


Loading cached processed dataset at /admin/home-augustas/.cache/huggingface/datasets/AugustasM___parquet/AugustasM--burns-datasets-VINC-ppo-training-v4-278eadae0cef7ee6/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-87b3b92eba672a07_*_of_00012.arrow
Loading cached processed dataset at /admin/home-augustas/.cache/huggingface/datasets/AugustasM___parquet/AugustasM--burns-datasets-VINC-ppo-training-v4-278eadae0cef7ee6/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-565481452045ac3d_*_of_00012.arrow



Max prompt length: 550

Max response length: 11

Remaining columns: ['prompt', 'input_ids', 'attention_mask', 'response_len']

Total number of examples: 8615

Processing finished.



In [6]:
# Seed every RNG before any model (and its value head) is created, so that
# initialisation and evaluation are reproducible.
set_seed(config.seed)

# Index of this process on the local node; the reward model here and the policy
# model in the next cell are both placed on it.
current_device = Accelerator().local_process_index
print(f"Current device: {current_device}\n")

# Reward model: get_reward_model returns the model together with its name, and
# that name is used to load the matching tokenizer.
reward_model, reward_model_name = get_reward_model(
    script_args.reward_model_output_path,
    current_device,
)
reward_model_tokenizer = get_tokenizer(reward_model_name)

Current device: 0

The current device is 0.

Loading reward model from allenai/unifiedqa-v2-t5-3b-1363200.
is_bf16_possible=True
Loaded reward model with 2,851,598,336 parameters.
Reward model dtype: torch.bfloat16

Loading reporter from /fsx/home-augustas/VINC-logs/allenai/unifiedqa-v2-t5-3b-1363200/AugustasM/burns-datasets-VINC-v4/festive-feistel/reporters/layer_19.pt
Loaded reporter.

Loading tokenizer allenai/unifiedqa-v2-t5-3b-1363200...
Falling back to slow tokenizer; fast one failed: 'No such file or directory (os error 2)'
Loaded tokenizer.



In [7]:
# Policy ("subject") model that PPO fine-tunes, placed on this process's device.
model = get_model(script_args.model_name, current_device)

Loading subject model...

is_bf16_possible=True


Loaded subject model with 124,440,577 parameters.
Model dtype: torch.bfloat16



In [8]:
dataset_template_path = "AugustasM/burns-datasets-VINC/first"

# Load the prompt templates and re-key them by template name. If two templates
# share a name, the later one wins.
dataset_templates = DatasetTemplates(dataset_template_path)
dataset_templates.templates = dict(
    (tmpl.name, tmpl) for tmpl in dataset_templates.templates.values()
)
print(len(dataset_templates.templates))

# The first template (in insertion order) is the one passed to `train` below.
template = [*dataset_templates.templates.values()][0]
template

1


<elk.promptsource.templates.Template at 0x7efa8805f9a0>

In [9]:
# Optimizer: none is supplied, so PPOTrainer falls back to its own default
# (per the TRL docs, Adam with config.learning_rate).
# TODO: consider whether adding Adafactor back in is a good idea
optimizer = None

# Build the PPO trainer around the policy model. No reference model is passed
# (ref_model=None); TRL presumably creates the frozen reference copy itself.
# TODO: confirm for the installed TRL version.
ppo_trainer = PPOTrainer(
    config,
    model,
    tokenizer=tokenizer,
    dataset=train_dataset,
    data_collator=collator,
    optimizer=optimizer,
    ref_model=None,
)

In [10]:
# Sampling arguments forwarded to PPOTrainer.generate, which wraps the policy
# model's own `generate`.
# TODO: put this into config
# TODO: check whether there are better settings for this
generation_kwargs = dict(
    # top_k=0 together with top_p=1.0 is meant to switch off top-k / nucleus
    # filtering, i.e. sample from the full distribution.
    top_k=0,
    top_p=1.0,
    do_sample=True,
    pad_token_id=tokenizer.pad_token_id,
    # Open question: why 100_000? Presumably it is an id outside the tokenizer's
    # vocabulary, so EOS is never sampled and responses are not cut short.
    # TODO confirm.
    eos_token_id=100_000,
    # TODO: double-check, but this seems to work and to be faster.
    # NOTE(review): not a standard `model.generate` argument; presumably consumed
    # by TRL's batched generation path. Confirm for the installed TRL version.
    pad_to_multiple_of=8,
)

# Run the PPO training loop.
train(
    ppo_trainer=ppo_trainer,
    tokenizer=tokenizer,
    generation_kwargs=generation_kwargs,
    reward_model=reward_model,
    reward_model_tokenizer=reward_model_tokenizer,
    template=template,
    script_args=script_args,
    config=config,
    device=current_device,
)

0it [00:00, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

1it [00:06,  6.72s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

2it [00:09,  4.28s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

3it [00:13,  4.08s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

4it [00:16,  3.64s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

5it [00:19,  3.45s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

6it [00:22,  3.31s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

7it [00:25,  3.26s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

8it [00:27,  3.03s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

9it [00:31,  3.05s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

10it [00:33,  2.98s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

11it [00:37,  3.22s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

12it [00:40,  3.10s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

13it [00:44,  3.28s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

14it [00:46,  3.05s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

15it [00:50,  3.19s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

16it [00:52,  3.02s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

17it [00:56,  3.11s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

18it [00:59,  3.07s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

19it [01:02,  3.07s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (606 > 512). Running this sequence through the model will result in indexing errors
20it [01:04,  2.94s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

21it [01:08,  3.04s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

22it [01:11,  3.11s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

23it [01:14,  3.10s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

24it [01:16,  2.95s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

25it [01:21,  3.32s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

26it [01:24,  3.21s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

27it [01:27,  3.36s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

28it [01:30,  3.28s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

29it [01:34,  3.37s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

30it [01:37,  3.21s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

31it [01:40,  3.29s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

32it [01:43,  3.18s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

33it [01:47,  3.26s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

34it [01:50,  3.24s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

35it [01:53,  3.34s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

36it [01:57,  3.26s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

37it [02:00,  3.42s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

38it [02:03,  3.30s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

39it [02:07,  3.43s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

40it [02:10,  3.21s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

41it [02:13,  3.22s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

42it [02:16,  3.12s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

43it [02:19,  3.17s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

44it [02:22,  3.10s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

45it [02:25,  3.13s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

46it [02:29,  3.14s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

47it [02:32,  3.22s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

48it [02:35,  3.18s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

49it [02:38,  3.20s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

50it [02:41,  3.16s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

51it [02:45,  3.24s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

52it [02:48,  3.27s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

53it [02:52,  3.36s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

54it [02:55,  3.24s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

55it [02:58,  3.36s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

56it [03:01,  3.29s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

57it [03:05,  3.47s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

58it [03:08,  3.33s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

59it [03:12,  3.48s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

60it [03:15,  3.25s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

61it [03:18,  3.36s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

62it [03:21,  3.19s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

63it [03:25,  3.26s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

64it [03:28,  3.15s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

65it [03:31,  3.33s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

66it [03:35,  3.31s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

67it [03:38,  3.43s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

68it [03:41,  3.31s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

69it [03:45,  3.46s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

70it [03:48,  3.34s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

71it [03:52,  3.49s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

72it [03:55,  3.36s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

73it [03:59,  3.48s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

74it [04:02,  3.42s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

75it [04:06,  3.44s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

76it [04:08,  3.27s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

77it [04:12,  3.30s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

78it [04:15,  3.20s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

79it [04:19,  3.35s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

80it [04:22,  3.28s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

81it [04:25,  3.28s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

82it [04:28,  3.22s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

83it [04:32,  3.33s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

84it [04:35,  3.23s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

85it [04:38,  3.39s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

86it [04:41,  3.27s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

87it [04:45,  3.44s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

88it [04:48,  3.30s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

89it [04:51,  3.30s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

90it [04:54,  3.17s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

91it [04:58,  3.32s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

92it [05:01,  3.18s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

93it [05:05,  3.37s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

94it [05:07,  3.20s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

95it [05:11,  3.23s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

96it [05:14,  3.14s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

97it [05:17,  3.25s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

98it [05:21,  3.33s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

99it [05:24,  3.36s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

100it [05:27,  3.31s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

101it [05:31,  3.39s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

102it [05:34,  3.26s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

103it [05:38,  3.42s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

104it [05:41,  3.35s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

105it [05:44,  3.42s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

106it [05:47,  3.25s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

107it [05:51,  3.32s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

108it [05:54,  3.27s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

109it [05:58,  3.38s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

110it [06:01,  3.26s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

111it [06:04,  3.27s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

112it [06:07,  3.12s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

113it [06:11,  3.55s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

114it [06:14,  3.47s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

115it [06:18,  3.56s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

116it [06:21,  3.41s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

117it [06:25,  3.47s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

118it [06:28,  3.37s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

119it [06:32,  3.58s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

120it [06:35,  3.39s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

121it [06:38,  3.36s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

122it [06:42,  3.34s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

123it [06:45,  3.51s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

124it [06:49,  3.45s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

125it [06:53,  3.56s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

126it [06:56,  3.49s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

127it [07:00,  3.56s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

128it [07:03,  3.48s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

129it [07:07,  3.58s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

130it [07:10,  3.52s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

131it [07:14,  3.62s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

132it [07:17,  3.43s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

133it [07:21,  3.49s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

134it [07:24,  3.43s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

135it [07:28,  3.52s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

136it [07:31,  3.48s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

137it [07:35,  3.58s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

138it [07:38,  3.41s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

139it [07:41,  3.44s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

140it [07:44,  3.27s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

141it [07:48,  3.28s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

142it [07:50,  3.12s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

143it [07:54,  3.18s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

144it [07:56,  3.06s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

145it [08:00,  3.15s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

146it [08:03,  3.03s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

147it [08:06,  3.12s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

148it [08:09,  3.01s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

149it [08:12,  3.12s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

150it [08:15,  3.02s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

151it [08:18,  3.10s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

152it [08:21,  3.01s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

153it [08:24,  3.13s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

154it [08:27,  3.05s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

155it [08:30,  3.13s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

156it [08:33,  3.02s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

157it [08:37,  3.12s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

158it [08:39,  3.04s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

159it [08:43,  3.16s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

160it [08:46,  3.05s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

161it [08:49,  3.16s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

162it [08:52,  3.06s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

163it [08:56,  3.25s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

164it [08:59,  3.16s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

165it [09:02,  3.23s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

166it [09:05,  3.10s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

167it [09:08,  3.19s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

168it [09:11,  3.06s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

169it [09:14,  3.16s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

170it [09:17,  3.09s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

171it [09:21,  3.32s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

172it [09:24,  3.29s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

173it [09:28,  3.38s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

174it [09:31,  3.23s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

175it [09:34,  3.31s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

176it [09:38,  3.39s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

177it [09:42,  3.54s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

178it [09:45,  3.43s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

179it [09:49,  3.58s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

180it [09:52,  3.43s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

181it [09:56,  3.51s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

182it [09:59,  3.37s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

183it [10:03,  3.57s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

184it [10:06,  3.43s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

185it [10:10,  3.57s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

186it [10:13,  3.46s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

187it [10:17,  3.62s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

188it [10:20,  3.50s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

189it [10:24,  3.58s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

190it [10:27,  3.41s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

191it [10:31,  3.52s/it]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

192it [10:34,  3.30s/it]


In [11]:
def format_gib(num_bytes: float) -> str:
    """Render a byte count in gibibytes (1 GiB = 1024**3 bytes) with two decimals."""
    return f"{num_bytes / 1024**3:.2f} GiB"


# Snapshot of GPU memory usage. Guarded so the cell does not crash on a CPU-only kernel.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
    print(f"Device: {device}")

    gpu_properties = torch.cuda.get_device_properties(device)
    print(f"GPU Name: {gpu_properties.name}")
    print(f"Total GPU Memory: {format_gib(gpu_properties.total_memory)}")
    # Memory currently occupied by tensors.
    print(f"Allocated GPU Memory: {format_gib(torch.cuda.memory_allocated(device))}")
    # Memory held by PyTorch's caching allocator (allocated tensors plus cached blocks).
    print(f"Reserved (cached) GPU Memory: {format_gib(torch.cuda.memory_reserved(device))}")
else:
    print("CUDA is not available; skipping GPU memory report.")

Device: 0
GPU Name: NVIDIA A100-SXM4-40GB
Total GPU Memory: 39.56402587890625 GB
Allocated GPU Memory: 6.586625099182129 GB
Cached GPU Memory: 6.896484375 GB


In [20]:
import time


def your_function():
    """Placeholder workload used to try out the duration formatting below."""
    time.sleep(3)  # simulate a delay of 3 seconds


def format_duration(seconds: float) -> str:
    """Format a non-negative duration in seconds as ``HH:MM:SS``.

    Fractional seconds are truncated. Unlike
    ``datetime.utcfromtimestamp(s).strftime("%H:%M:%S")`` (deprecated since
    Python 3.12), hours are not wrapped modulo 24, so durations of a day or
    more are reported correctly.
    """
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


# perf_counter is monotonic, so the measurement is unaffected by wall-clock changes.
start = time.perf_counter()
your_function()
elapsed_seconds = time.perf_counter() - start
print(elapsed_seconds)

elapsed_time = format_duration(elapsed_seconds)
print(f"Duration: {elapsed_time}")

3.003301
Duration: 00:00:03
