In [None]:
import torch
from omegaconf.omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

In [None]:
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank
from nemo.collections.nlp.parts.nlp_overrides import GradScaler, NLPDDPPlugin, NLPSaveRestoreConnector
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager
from nemo.utils.config_utils import update_model_config

In [None]:
# Load the base configuration from YAML; every override below is applied
# in place on this object.
cfg = OmegaConf.load("conf/megatron_gpt_config.yaml")

# --- Trainer ---------------------------------------------------------------
cfg.trainer.gpus = 1
cfg.trainer.max_steps = 3000

# --- Checkpoint to restore the model from ----------------------------------
cfg.restore_from_path = 'megatron_gpt.nemo'

# --- Model / soft-prompt settings ------------------------------------------
cfg.model.encoder_seq_length = 2048
cfg.model.use_soft_prompts = True
cfg.model.prompt_length = 10

# --- Optimizer and learning-rate schedule ----------------------------------
cfg.model.optim.lr = 2e-4
cfg.model.optim.sched.min_lr = 2e-6
cfg.model.optim.sched.warmup_steps = 100
cfg.model.optim.sched.constant_steps = 1000

# --- Data: prompt-tuning NER splits ----------------------------------------
cfg.model.data.data_prefix = None
cfg.model.data.batch_size = 32
cfg.model.data.train_ds = 'prompt_tuning_ner_train.json'
cfg.model.data.valid_ds = 'prompt_tuning_ner_val.json'
cfg.model.data.test_ds = 'prompt_tuning_ner_test.json'

In [None]:
# Build the Lightning plugin list, starting with NeMo's DDP plugin configured
# with the node count from the trainer config.
plugins = [NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)]

# Only configs that declare cluster_type == 'BCP' get the TorchElastic
# cluster environment. It is imported here, inside the branch, because this
# is the only place that needs it and no import cell brings the name into
# scope.
if cfg.get('cluster_type', None) == 'BCP':
    from pytorch_lightning.plugins.environments import TorchElasticEnvironment

    plugins.append(TorchElasticEnvironment())

# Every key under cfg.trainer is forwarded to Trainer as a keyword argument.
trainer = Trainer(plugins=plugins, **cfg.trainer)

# Hand the trainer and the exp_manager section of the config to NeMo's
# experiment manager.
# NOTE(review): presumably sets up logging/checkpoint directories — confirm
# against nemo.utils.exp_manager.
exp_manager(trainer, cfg.exp_manager)

# Restore the model from the path set in cfg.restore_from_path, passing the
# (overridden) model config and attaching the trainer created above.
# NOTE(review): the second positional argument is assumed to act as a config
# override for the restored checkpoint — confirm against restore_from.
model = MegatronGPTModel.restore_from(cfg.restore_from_path, cfg.model, trainer=trainer)

In [None]:
# Show the model's prompt table; as the last expression in the cell, the
# return value is rendered as the cell output.
# NOTE(review): presumably lists the soft prompts currently registered on the
# model — confirm against MegatronGPTModel.get_prompt_table.
model.get_prompt_table()

In [None]:
# Initialise a prompt from text.
# NOTE(review): the first argument looks like the prompt's tag/name and the
# second the text used to initialise it — argument meaning is inferred from
# the method name; confirm against MegatronGPTModel.init_prompt_from_text.
model.init_prompt_from_text("NER-Yes-No", "named entities yes or no")

In [None]:
# Display the prompt table at this point in the notebook (after the
# "NER-Yes-No" prompt initialisation) so its effect can be inspected.
model.get_prompt_table()

In [None]:
# Initialise a second prompt from text.
# NOTE(review): the first argument is assumed to be the prompt's tag/name and
# the second its initialisation text — confirm against
# MegatronGPTModel.init_prompt_from_text.
model.init_prompt_from_text("NER-Complete", "name chemicals entities in context")

In [None]:
# Display the prompt table at this point in the notebook (after the
# "NER-Complete" prompt initialisation) so its effect can be inspected.
model.get_prompt_table()

In [None]:
# Put the model into its prompt-tuning frozen state before training.
# NOTE(review): presumably disables gradients on everything except the soft
# prompt parameters — confirm against MegatronGPTModel.prompt_tuning_freeze.
model.prompt_tuning_freeze()

In [None]:
# Sanity check: print every parameter that still has requires_grad set, in
# the order model.parameters() yields them.
for param in filter(lambda p: p.requires_grad, model.parameters()):
    print(param)

In [None]:
# Start training: run the Lightning trainer built earlier on the restored
# model. Step count and devices come from the cfg.trainer values that were
# unpacked into Trainer(...).
trainer.fit(model)