## Environment Setup

In [None]:
! pip install transformers diffusers fastapi uvicorn torch torchvision Pillow accelerate datasets lora diffusers

## Dependencies

In [None]:
import json
import os
import re

import torch
import torch.nn as nn
from datasets import load_dataset
from diffusers import DDPMScheduler
from diffusers import StableDiffusionPipeline
# FIXME(review): `enable_full_lora` does not appear to be exported by `diffusers.utils`;
# if this raises ImportError, LoRA is usually attached via `peft.LoraConfig` + `unet.add_adapter(...)`.
from diffusers.utils import enable_full_lora
# IPython's `Image` must be imported *before* PIL's: both modules export `Image`, the later
# import wins, and the notebook relies on `Image.open(...)`, which only `PIL.Image` provides.
from IPython.display import Image
from IPython.display import display
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.models import vgg19
from torchvision.transforms import Normalize
from torchvision.transforms import Resize
from torchvision.transforms import functional as TF
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import BlipForConditionalGeneration
from transformers import BlipModel
from transformers import BlipProcessor
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer
from transformers import TrainingArguments

# Sanity check: report whether PyTorch's MPS backend can be used from this kernel.
mps_available = torch.backends.mps.is_available()
print(mps_available)

## Global Variables

In [None]:
# NOTE(review): the `decapoda-research` LLaMA repos appear to have been removed from the
# Hugging Face Hub — confirm this id still resolves before running the LLM cells.
llm_model           = "decapoda-research/llama-7b-hf"
# NOTE(review): SDXL is published under ids such as `stabilityai/stable-diffusion-xl-base-1.0`;
# confirm this id resolves.
image_model         = "stabilityai/stable-diffusion-xl"
blip_processor_path = "Salesforce/blip-image-captioning-base"
blip_model_path     = "Salesforce/blip-image-captioning-base"

# Pick the best available accelerator: CUDA first, then Apple's MPS backend, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


## Defining Data Paths

In [None]:
# LinkedIn data locations; relative paths, resolved against the kernel's working directory.
_linkedin_root  = '../data/linkedin_data/'
linkedin_images = _linkedin_root + 'linkedin_images/'
linkedin_texts  = _linkedin_root + 'post_data.json'


## Loading Text Data

In [None]:
# A single JSON file passed as `data_files` is exposed as the "train" split; select it directly.
dataset = load_dataset("json",
                       data_files = "processed_dataset.json",
                       split      = "train")


## Load BLIP model

In [None]:
# Load the BLIP captioning checkpoint: model weights first, then the matching processor.
blip_model     = BlipForConditionalGeneration.from_pretrained(blip_model_path)
blip_processor = BlipProcessor.from_pretrained(blip_processor_path)

# NOTE(review): `extract_blip_features` reads `outputs.text_embeds`; that field exists on
# `BlipModel` outputs but not, as far as I can tell, on `BlipForConditionalGeneration`
# outputs — confirm which model class is intended here.

# Place the BLIP model on the selected device (`nn.Module.to` moves it in place and returns it).
blip_model.to(device)

## BLIP Text and Image Embedding Extraction

In [None]:
def extract_blip_features(image_path:str, prompt_text:str):
    """
    Extracts BLIP text and image embeddings for one image/prompt pair.

    Arguments:
    ----------
        image_path  { str } : Path of the image file to embed.

        prompt_text { str } : Text passed to the BLIP processor alongside the image.

    Raises:
    -------
        Any error raised while opening the image or running the BLIP processor/model
        is propagated unchanged.

    Returns:
    --------
        { tuple } : (text_embedding, image_embedding) as NumPy arrays, taken from the
                    model output's `text_embeds` and `image_embeds` fields.
    """
    # Import PIL's Image locally: the imports cell also does `from IPython.display import Image`,
    # which can shadow `PIL.Image`, and only PIL's version provides `.open`.
    from PIL import Image as PILImage

    # `with` closes the underlying file handle once the RGB copy has been made.
    with PILImage.open(image_path) as raw_image:
        image       = raw_image.convert("RGB")

    inputs          = blip_processor(text           = [prompt_text],
                                     images         = image,
                                     return_tensors = "pt",
                                     padding        = True)

    # The processor returns CPU tensors; they must sit on the same device as the model.
    model_device    = next(blip_model.parameters()).device
    inputs          = {name: tensor.to(model_device) for name, tensor in inputs.items()}

    # Forward pass only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        outputs     = blip_model(**inputs)

    text_embedding  = outputs.text_embeds.detach().cpu().numpy()
    image_embedding = outputs.image_embeds.detach().cpu().numpy()
    return text_embedding, image_embedding
        

## Pre-process Records into Combined LLM Inputs

In [None]:
def preprocess_function(examples):
    """
    Builds the tokenized LLM training input for a single dataset record.

    The record's platform, heading and content are flattened into one prompt string,
    together with the string form of the BLIP text/image embeddings computed for the
    first entry of `image_paths`, and the result is tokenized with `llm_tokenizer`.

    Arguments:
    ----------
        examples { dict } : One record with keys `image_paths`, `post_heading`,
                            `platform_name` and `post_content` (presumably from a
                            non-batched `Dataset.map` call — confirm).

    Returns:
    --------
        Whatever `llm_tokenizer` returns for the combined string, truncated to at most
        512 tokens.
    """
    heading                 = examples["post_heading"]
    text_embed, image_embed = extract_blip_features(examples["image_paths"][0], heading)

    # NOTE(review): the embeddings are NumPy arrays rendered through str(), so the prompt
    # carries their printed digits as plain text and may be cut off by truncation.
    segments                = [f"Platform: {examples['platform_name']}. ",
                               f"Post Heading: {heading}. ",
                               f"Post Content: {examples['post_content']}. ",
                               f"BLIP Text Embedding: {text_embed} ",
                               f"BLIP Image Embedding: {image_embed}",
                              ]

    return llm_tokenizer("".join(segments),
                         truncation = True,
                         padding    = True,
                         max_length = 512)


## Load the pre-trained LLaMA model and tokenizer

In [None]:
# `llm_model` starts out as the checkpoint id from the globals cell and is rebound to the
# loaded model below; resolve the id either way so re-running this cell keeps working.
llm_model_id  = llm_model if isinstance(llm_model, str) else llm_model.name_or_path
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
llm_model     = AutoModelForCausalLM.from_pretrained(llm_model_id)

# LLaMA tokenizers typically ship without a pad token, yet `preprocess_function` pads and
# the LM data collator needs one; reuse EOS as the pad token when none is set.
if llm_tokenizer.pad_token is None:
    llm_tokenizer.pad_token = llm_tokenizer.eos_token

# Move the LLaMA model to the selected device; the trailing `;` suppresses the large module repr.
llm_model.to(device);


## Fine-tuning configuration for LLM

In [None]:
# The Trainer below receives only a training split (no `eval_dataset`), so per-epoch evaluation
# cannot run and makes the Trainer error out; keep the default of no evaluation. To evaluate,
# add an eval split and set `eval_strategy` (`evaluation_strategy` on older transformers).
training_args = TrainingArguments(output_dir                  = "./llama_finetuned",
                                  logging_dir                 = "./logs",
                                  per_device_train_batch_size = 2,
                                  num_train_epochs            = 3,
                                 )


## LLM Trainer

In [None]:
# Turn each raw record into tokenized LLM input; raw columns are dropped so only the
# tokenizer outputs are passed on to the Trainer.
tokenized_dataset = dataset.map(preprocess_function,
                                remove_columns = dataset.column_names)

# Causal-LM collator (mlm=False): pads each batch and derives `labels` from `input_ids`,
# so the model returns a loss for the Trainer to optimize.
data_collator     = DataCollatorForLanguageModeling(tokenizer = llm_tokenizer,
                                                    mlm       = False)

trainer           = Trainer(model         = llm_model,
                            args          = training_args,
                            train_dataset = tokenized_dataset,
                            data_collator = data_collator
                           )

trainer.train()
trainer.save_model("./llama_finetuned")
# Save the tokenizer alongside the weights so the directory can be reloaded on its own.
llm_tokenizer.save_pretrained("./llama_finetuned")


## Load Stable Diffusion XL

In [None]:
# NOTE(review): `image_model` names an SDXL checkpoint, but `StableDiffusionPipeline` is the
# SD 1.x/2.x pipeline; SDXL is normally loaded with `StableDiffusionXLPipeline` — confirm.
# Load the pipeline and place it on the selected device in one step
# (`DiffusionPipeline.to` returns the pipeline itself).
pipe = StableDiffusionPipeline.from_pretrained(image_model).to(device)

# Trade some speed for lower peak memory by computing attention in slices.
pipe.enable_attention_slicing()

# Attach LoRA to the UNet for fine-tuning.
# NOTE(review): see the imports cell — `enable_full_lora` may not exist in `diffusers.utils`.
enable_full_lora(pipe.unet)


## Perceptual Loss (Feature Loss)

In [None]:
class PerceptualLoss(nn.Module):
    """
    Feature-space (perceptual) loss: MSE between frozen VGG-19 activations of two images.
    """
    def __init__(self):
        super().__init__()
        # Keep only the first 16 modules of the pretrained VGG-19 feature stack, in eval mode.
        feature_extractor = vgg19(pretrained=True).features[:16].eval()
        # Freeze the extractor: it only scores images and is never trained.
        for weight in feature_extractor.parameters():
            weight.requires_grad = False
        self.vgg       = feature_extractor
        self.criterion = nn.MSELoss()

    def _features(self, image):
        # Per-channel normalization with mean 0.5 / std 0.5, then VGG activations.
        return self.vgg(TF.normalize(image, mean=[0.5]*3, std=[0.5]*3))

    def forward(self, generated_image, target_image):
        return self.criterion(self._features(generated_image),
                              self._features(target_image))


## Cross-Modal Alignment Loss (BLIP Encoder)

In [None]:
class BLIPAlignmentLoss(nn.Module):
    """
    Cross-modal alignment loss between a generated image and reference BLIP embeddings.

    The loss is the mean of (1 - cosine similarity) over the text branch and the image
    branch, each averaged over the batch.
    """
    def __init__(self, model_path:str="Salesforce/blip-image-captioning-base"):
        """
        Loads the BLIP model and processor used to embed generated images.

        Arguments:
        ----------
            model_path { str } : Checkpoint used for both `BlipModel` and `BlipProcessor`.
        """
        super().__init__()

        self.blip_model        = BlipModel.from_pretrained(model_path)
        self.processor         = BlipProcessor.from_pretrained(model_path)
        self.cosine_similarity = nn.CosineSimilarity(dim=1)

    def forward(self, generated_image, text_embedding, image_embedding):
        """
        Arguments:
        ----------
            generated_image             : Image(s) accepted by the BLIP processor.

            text_embedding  { Tensor }  : Reference text embedding to align with.

            image_embedding { Tensor }  : Reference image embedding to align with.

        Returns:
        --------
            { Tensor } : Scalar loss, (text_loss + image_loss) / 2.
        """
        # Embed the generated image with this module's own BLIP model and processor.
        # NOTE(review): the text input is a fixed placeholder, and BLIP's processor converts
        # images outside autograd, so this loss likely provides no gradient to the image
        # generator — confirm this is acceptable.
        inputs                    = self.processor(text           = ["generated text"],
                                                   images         = generated_image,
                                                   return_tensors = "pt",
                                                   padding        = True)

        # Processor tensors live on CPU; move them to wherever the BLIP model is.
        model_device              = next(self.blip_model.parameters()).device
        inputs                    = {name: tensor.to(model_device) for name, tensor in inputs.items()}

        outputs                   = self.blip_model(**inputs)

        # Compare on the references' device to avoid cross-device errors.
        generated_text_embedding  = outputs.text_embeds.to(text_embedding.device)
        generated_image_embedding = outputs.image_embeds.to(image_embedding.device)

        text_loss                 = 1 - self.cosine_similarity(text_embedding, generated_text_embedding).mean()
        image_loss                = 1 - self.cosine_similarity(image_embedding, generated_image_embedding).mean()

        # Average of the text and image alignment terms.
        return (text_loss + image_loss) / 2

## Total Loss Function with BLIP

In [None]:
class TotalBLIPLoss(nn.Module):
    """
    Weighted sum of the VGG perceptual loss and the BLIP cross-modal alignment loss.
    """
    def __init__(self, perceptual_loss_weight:float=0.5, blip_loss_weight:float=0.5):
        """
        Arguments:
        ----------
            perceptual_loss_weight { float } : Weight of the perceptual (VGG feature) term.

            blip_loss_weight       { float } : Weight of the BLIP alignment term.
        """
        super(TotalBLIPLoss, self).__init__()

        self.blip_loss_fn           = BLIPAlignmentLoss()
        self.perceptual_loss_fn     = PerceptualLoss()
        self.perceptual_loss_weight = perceptual_loss_weight
        self.blip_loss_weight       = blip_loss_weight
        self.resize_transform       = Resize((224, 224))
        self.normalize_transform    = Normalize(mean = [0.485, 0.456, 0.406],
                                                std  = [0.229, 0.224, 0.225])

    def preprocess_image_for_vgg(self, image):
        """
        Resizes an image tensor to 224x224 and applies ImageNet mean/std normalization.
        """
        image = self.resize_transform(image)
        return self.normalize_transform(image)

    def forward(self, generated_image, target_image, text_embedding, image_embedding):
        """
        Returns perceptual_loss_weight * perceptual_loss + blip_loss_weight * blip_loss.
        """
        # The BLIP processor does its own resizing/normalization, so it receives the raw
        # generated image rather than the VGG-preprocessed tensor.
        blip_loss       = self.blip_loss_fn(generated_image,
                                            text_embedding,
                                            image_embedding)

        # NOTE(review): PerceptualLoss applies a second normalization (mean/std 0.5) on top
        # of the ImageNet normalization done here — confirm which one VGG should see.
        vgg_generated   = self.preprocess_image_for_vgg(generated_image)
        vgg_target      = self.preprocess_image_for_vgg(target_image)
        perceptual_loss = self.perceptual_loss_fn(vgg_generated, vgg_target)

        total_loss      = self.perceptual_loss_weight * perceptual_loss + self.blip_loss_weight * blip_loss
        return total_loss


## Prepare optimizer and scheduler for LoRA fine-tuning

In [None]:
# Optimize only UNet parameters that are still trainable. With LoRA adapters the base weights
# are presumably frozen, so this selects just the adapter weights; when nothing has been
# frozen it is identical to passing every UNet parameter.
trainable_unet_params = [param for param in pipe.unet.parameters() if param.requires_grad]

optimizer = torch.optim.Adam(params = trainable_unet_params,
                             lr     = 5e-5)

# NOTE(review): `scheduler` is not referenced by the fine-tuning loop in this notebook.
scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

loss_fn   = TotalBLIPLoss()

## Fine-tune the Image Model with the Combined Loss

In [None]:
# Fine-tuning loop: for each batch, generate an image from the prompt, score it against the
# target image and BLIP embeddings with `loss_fn`, and step the optimizer.
# FIXME(review): `training_dataset` is not defined in any cell of this notebook (only `dataset`
# is loaded), so this line raises NameError as written.
dataloader = DataLoader(training_dataset, 
                        batch_size = 1, 
                        shuffle    = True)

for step, batch in enumerate(dataloader):
    optimizer.zero_grad()
    
    input_image_path        = batch["image_paths"][0]
    target_image            = batch["target_image"].to(device)
    # NOTE(review): with default collation, string fields likely arrive as 1-element lists —
    # confirm what `extract_blip_features` and `pipe` expect for `prompt_text`.
    prompt_text             = batch["prompt_text"]

    # Extract embeddings from BLIP
    text_embed, image_embed = extract_blip_features(input_image_path, prompt_text)
    text_embed              = torch.tensor(text_embed).to(device)
    image_embed             = torch.tensor(image_embed).to(device)

    # Generate image on the selected device
    # FIXME(review): diffusers pipelines return PIL images in `.images` by default, which have
    # no `.to()`; pipeline inference also runs without gradient tracking, so `loss.backward()`
    # likely has no path back to the UNet. A noise-prediction training step (using the DDPM
    # scheduler) is probably what is needed here — confirm before relying on this loop.
    generated_image         = pipe(prompt_text).images[0].to(device)

    # Compute loss
    loss                    = loss_fn(generated_image = generated_image, 
                                      target_image    = target_image, 
                                      text_embedding  = text_embed, 
                                      image_embedding = image_embed)
    loss.backward()
    optimizer.step()

    # Log progress every 10 steps.
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss.item()}")

    
print("Image fine-tuning complete!")


# Final Check

## Generate Social Media Posts with Integrated Models

In [None]:
def extract_post_details(text):
    """
    Splits generated post text into a caption, a hashtag string and an emoji string.

    Arguments:
    ----------
        text { str | None } : Generated post text; `None` is treated as empty text.

    Returns:
    --------
        { tuple } : (caption, hashtags, emojis) where
                    - caption  is the text before the first "." with surrounding whitespace
                      removed, or "Generated Caption" when that part is empty;
                    - hashtags is every `#word` token, joined by single spaces;
                    - emojis   is a fixed placeholder string.
    """
    text     = text or ""
    hashtags = " ".join(re.findall(r"#\w+", text))
    emojis   = "🎉🔥"  # You can use an emoji extractor if needed
    # Only the first sentence is needed, so split at most once.
    caption  = text.split(".", 1)[0].strip() or "Generated Caption"
    return caption, hashtags, emojis
    

In [None]:
def generate_social_media_post(occasion: str, subject: str, platform: str, text_length: int, image_size: str, num_images: int,
                               max_new_tokens: int = 150, output_dir: str = "./generated_images"):
    """
    Generate a social media post with text and related images.

    Arguments:
    ----------
        occasion       { str } : Occasion the post is for (e.g. a holiday).

        subject        { str } : What the post is about.

        platform       { str } : Target platform name.

        text_length    { int } : Desired length in words; only passed to the LLM inside the prompt.

        image_size     { str } : Resolution hint such as "1080x1080"; only used in the image prompt text.

        num_images     { int } : Number of images to generate.

        max_new_tokens { int } : Maximum number of tokens the LLM may generate.

        output_dir     { str } : Directory the images are saved to; created if missing.

    Returns:
    --------
        { dict } : Keys `post_text`, `caption`, `hashtags`, `emojis` and `images`
                   (list of saved image paths).
    """
    # Generate Post Text
    prompt         = (f"Occasion: {occasion}. Subject: {subject}. "
                      f"Platform: {platform}. Desired length: {text_length} words.")

    # Tokenized inputs must be on the same device as the model.
    inputs         = llm_tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    text_outputs   = llm_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # For a causal LM, `generate` returns prompt + continuation; decode only the new tokens.
    prompt_length  = inputs["input_ids"].shape[1]
    generated_text = llm_tokenizer.decode(text_outputs[0][prompt_length:], skip_special_tokens=True)

    # Extract caption, hashtags, and emojis from the generated text
    caption, hashtags, emojis = extract_post_details(generated_text)

    # Generate Related Images
    os.makedirs(output_dir, exist_ok=True)
    image_prompt     = f"{subject} for {platform} on {occasion} in {image_size} resolution"
    generated_images = list()

    for i in range(num_images):
        generated_image = pipe(image_prompt).images[0]
        # NOTE(review): file names restart at 1 on every call, so earlier images are overwritten.
        image_path      = os.path.join(output_dir, f"post_image_{i+1}.jpg")

        generated_image.save(image_path)
        generated_images.append(image_path)

    output_dict = {"post_text" : generated_text,
                   "caption"   : caption,
                   "hashtags"  : hashtags,
                   "emojis"    : emojis,
                   "images"    : generated_images
                  }

    return output_dict


## Example Usage

In [None]:
# Demo request: the keyword arguments for one generated post, unpacked into the generator.
demo_post_request = {"occasion"    : "Valentine's Day",
                     "subject"     : "Romantic Getaway Packages",
                     "platform"    : "Instagram",
                     "text_length" : 100,
                     "image_size"  : "1080x1080",
                     "num_images"  : 2,
                    }

result = generate_social_media_post(**demo_post_request)
