In [9]:
import json
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset

class SocialMediaPostGenerator:
    """Fine-tune a T5 checkpoint on curated posts and generate new posts.

    The tokenizer is loaded with ``T5Tokenizer``, which needs the
    ``sentencepiece`` package to be installed; construction fails with an
    ``ImportError`` otherwise.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        max_length: Token limit used when truncating prompts and targets and
            as the length cap for generation.
    """

    def __init__(self, model_name='google/flan-t5-large', max_length=512):
        self.model_name = model_name
        self.max_length = max_length
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)

    @staticmethod
    def _format_entry(entry):
        """Turn one curated record into an ``input_text``/``target_text`` pair.

        ``platform``, ``post_heading`` and ``post_content`` are required
        (``KeyError`` if absent). ``hashtags`` and ``emoji`` are lists of
        strings; a missing or null value is treated as an empty list.
        """
        heading = entry['post_heading']
        content = entry['post_content']
        hashtags = ', '.join(entry.get('hashtags') or [])
        emojis = ', '.join(entry.get('emoji') or [])

        # NOTE(review): the prompt already contains the heading and content
        # that the target repeats, so the task is close to copying; confirm
        # this is the intended training signal.
        prompt = (
            "Generate a social media post with the following details: "
            f"Platform: {entry['platform']} "
            f"Post Heading: {heading} "
            f"Post Content: {content} "
            f"Hashtags: {hashtags} "
            f"Emojis: {emojis}"
        )
        response = (
            f"Post Heading: {heading}\n"
            f"Post Content: {content}\n"
            f"Hashtags: {hashtags}\n"
            f"Emojis: {emojis}"
        )
        return {"input_text": prompt, "target_text": response}

    def load_data(self, json_path):
        """Read a JSON array of post records and return a ``Dataset``.

        Args:
            json_path: Path to a UTF-8 JSON file holding a list of records.

        Returns:
            A ``Dataset`` with ``input_text`` and ``target_text`` columns.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return Dataset.from_list([self._format_entry(entry) for entry in data])

    def _tokenize_batch(self, batch):
        """Tokenize a batch of prompt/target pairs without padding.

        Padding is deliberately left to the data collator so that each batch
        is padded only to its own longest sequence and label padding is
        filled with the collator's ignore index instead of the pad token id.
        """
        model_inputs = self.tokenizer(
            batch["input_text"], truncation=True, max_length=self.max_length
        )
        targets = self.tokenizer(
            batch["target_text"], truncation=True, max_length=self.max_length
        )
        # Keeps input_ids and attention_mask from the prompt encoding.
        model_inputs["labels"] = targets["input_ids"]
        return model_inputs

    def train(self, train_data):
        """Fine-tune ``self.model`` on ``train_data`` in place.

        Args:
            train_data: ``Dataset`` with ``input_text`` and ``target_text``
                columns, as produced by :meth:`load_data`.
        """
        tokenized_data = train_data.map(
            self._tokenize_batch,
            batched=True,
            remove_columns=train_data.column_names,
        )

        # No evaluation strategy is configured: this method receives no
        # evaluation split, so there is nothing to evaluate on.
        training_args = TrainingArguments(
            output_dir="./flan_t5_trained",
            save_strategy="epoch",
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            num_train_epochs=3,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            save_total_limit=2
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_data,
            tokenizer=self.tokenizer,
            data_collator=DataCollatorForSeq2Seq(self.tokenizer, model=self.model)
        )

        trainer.train()

    def generate_post(self, prompt):
        """Generate a post for ``prompt`` with 5-beam search.

        Args:
            prompt: Instruction text; truncated to ``max_length`` tokens.

        Returns:
            The decoded text of the best beam, special tokens removed.
        """
        # Training leaves the model in train mode; switch off dropout.
        self.model.eval()
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        # The model may have been moved to an accelerator during training;
        # the input tensors have to live on the same device.
        inputs = inputs.to(self.model.device)
        output_ids = self.model.generate(
            **inputs, max_length=self.max_length, num_beams=5, early_stopping=True
        )
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

if __name__ == "__main__":
    # Script entry point: fine-tune on the curated data set, then print one
    # sample generation. The inference prompt does not depend on the model,
    # so it is defined up front.
    prompt = """Generate a high-quality, engaging social media post for a business in a descriptive format. Follow the example structure and ensure clarity, creativity, and audience engagement.

    Context:
    - Platform: Facebook  
    - Theme: Product Launch of a fitness software called FitTune
    - Target Audience: fitness enthusiasts  
    - Tone: Friendly
    - Language: English  
    - Word Limit: 500 words  

    Requirements:
    - Craft a compelling opening that grabs attention.  
    - Highlight key details about the business or occasion.  
    - Maintain a consistent and engaging tone throughout.  
    - Use persuasive language and storytelling where applicable.  
    - Include a strong call to action (CTA) to encourage engagement.  
    
    Now, generate a social media post using the provided context.
    """

    # Load the checkpoint with the class defaults and fine-tune it.
    generator = SocialMediaPostGenerator()
    # NOTE(review): the path starts with ".data", not "./data" — confirm the
    # hidden-directory spelling is intended.
    train_dataset = generator.load_data(".data/curated_data/curated_data.json")
    generator.train(train_dataset)

    # Sample from the fine-tuned model and show the result.
    generated_post = generator.generate_post(prompt)
    print("Generated Post:\n", generated_post)


ImportError: 
T5Tokenizer requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
