In [None]:
import sys
import os
from datetime import datetime
import warnings
from pathlib import Path
import json

# Data manipulation libraries
import pandas as pd
import numpy as np

# Deep learning framework
import torch

# NLP & Transformers
import nltk  # Useful for tokenization or post-processing if needed
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline  # Megatron GPT2-compatible

In [None]:
# --- Model identifiers -------------------------------------------------------
# Hugging Face Hub id of the model this notebook is built around.
GPT_MODEL_NAME = "nvidia/megatron-gpt2-345m"
# Checkpoint file expected in the project's Data Fabric folder.
GPT_MODEL_DATAFABRIC_PATH = "/home/jovyan/datafabric/Megatron_GPT2_345M/megatron_gpt_345m.nemo"

# --- Local artifact / output locations (relative to this notebook) ----------
TOKENIZER_DIR = "../artifacts/tokenizer"
EMBEDDINGS_OUTPUT_PATH = "../data/processed/"

In [None]:
# Record whether the GPT checkpoint is present at the configured Data Fabric
# path; later cells can consult this flag before attempting to load it.
is_gpt_model_available = Path(GPT_MODEL_DATAFABRIC_PATH).exists()

# Report the outcome with a single print of whichever message applies.
print(
    "GPT model is properly configured."
    if is_gpt_model_available
    else (
        "GPT model is not properly configured. Please create and download the required assets "
        "in your project on AI Studio."
    )
)

In [None]:
# Fail fast with an actionable message when the availability check above did
# not find the checkpoint, instead of surfacing an opaque loader error.
# FileNotFoundError is an OSError, matching what a failed load would raise.
if not is_gpt_model_available:
    raise FileNotFoundError(
        f"GPT model not found at {GPT_MODEL_DATAFABRIC_PATH}. "
        "Please create and download the required assets in your project on AI Studio."
    )

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading GPT model")

# NOTE(review): `from_pretrained` is pointed at a `.nemo` file. transformers
# normally expects a model directory or a Hub id (e.g. GPT_MODEL_NAME);
# confirm this checkpoint is in a format transformers can actually load.
tokenizer = AutoTokenizer.from_pretrained(GPT_MODEL_DATAFABRIC_PATH)
model = AutoModelForCausalLM.from_pretrained(GPT_MODEL_DATAFABRIC_PATH).to(device)

# Give the pipeline the same device so its input tensors are placed next to
# the model weights; previously `device` was computed but never used.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)


In [None]:
def generate_sentence(prompt="Generate a phrase", max_length=25, temperature=0.9, do_sample=True):
    """Generate one continuation of ``prompt`` with the module-level ``generator`` pipeline.

    Args:
        prompt: Text the model continues from.
        max_length: Forwarded as ``max_length``. In transformers this bounds the
            total sequence length, with the prompt's tokens included.
        temperature: Sampling temperature. It only takes effect when
            ``do_sample`` is True; under greedy decoding it is ignored.
        do_sample: Sample from the model's distribution rather than decode
            greedily. Defaults to True so that ``temperature`` is honored.

    Returns:
        The ``generated_text`` of the single returned sequence. By default the
        text-generation pipeline includes the prompt in this text.
    """
    result = generator(
        prompt,
        max_length=max_length,
        temperature=temperature,
        # Without do_sample=True, generation is greedy and temperature is a no-op.
        do_sample=do_sample,
        num_return_sequences=1,
    )
    return result[0]["generated_text"]


In [None]:
# def generate_by_difficulty(level="easy"):
#     settings = {
#         "easy": {"max_length": 20, "temperature": 0.7},
#         "medium": {"max_length": 30, "temperature": 0.9},
#         "hard": {"max_length": 50, "temperature": 1.0},
#     }
#     config = settings.get(level, settings["medium"])
#     return generate_sentence(prompt="Let's practice typing", **config)