# LVLM-based Recommender Prompt Generation

This notebook shows how to generate prompts for an LVLM-based (large vision-language model) recommender system. We will follow these steps:

1. **Import Necessary Modules**: Import all the necessary modules and set up the project path.
2. **Define the `PromptGenerator` Class**: Define the class that samples users, combines their item images, and builds prompts.
3. **Initialization**: Initialize the `PromptGenerator` class with the dataset.
4. **Image Combination**: Combine images for history and target sequences.
5. **Prompt Generation**: Generate prompts using different templates and explain their meanings.

## Step 1: Import Necessary Modules

First, we import all the necessary modules and set up the project path.


In [1]:
import csv
import json
import os
import random
import sys

import matplotlib.pyplot as plt
from PIL import Image

# Make the project root importable so the local modules below can be found.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from dataset import MRecDataset
from prompt_templates import templates



## Step 2: Define the `PromptGenerator` Class

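Next, we define the `PromptGenerator` class. It filters and samples users whose items all have local images, saves combined history and target images, and builds prompts from a chosen template, writing them to a JSON file.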

In [12]:
class PromptGenerator(object):
    def __init__(self, dataset, dataset_name, sample_size=None, max_seq_len=10, neg_num=29):
        """
        Initialize the PromptGenerator class.

        Parameters:
        - dataset (object): The dataset object containing the data instances.
        - dataset_name (str): The name of the dataset. Must be one of ['toys', 'beauty', 'sports', 'clothing'].
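        - sample_size (int, optional): The number of users (interaction sequences) to sample. If None, all interactions are kept. Default is None.
        - max_seq_len (int, optional): The maximum number of history items per user; the last item of each sequence is the target. Default is 10.
        - neg_num (int, optional): The number of negative candidates per user. Default is 29.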
        """
        super(PromptGenerator, self).__init__()
        dataset_choices = ['toys', 'beauty', 'sports', 'clothing']
        if dataset_name not in dataset_choices:
            raise ValueError(f'Invalid dataset: [{dataset_name}]. Only {dataset_choices} are supported.')
        self.max_seq_len = max_seq_len
        self.dataset_name = dataset_name
        self.dataset = dataset
        self.interaction = {str(index+1): sublist for index, sublist in enumerate(self.dataset.instances())}

        self.base_local_path = '../../datasets/Amazon_Review_Plus/photos/{}/'
        self.base_online_path = 'https://hmdataset.oss-cn-guangzhou.aliyuncs.com/{}_{}_combine/'
        self.sampled_photos_path = './sampled_photos/{}_{}/'
        self.sampled_prompts_path = './sampled_prompts/{}_{}/'
        
        self.online_image_path = self.base_online_path.format(self.dataset_name, self.max_seq_len)
        self.local_image_path = self.sampled_photos_path.format(self.dataset_name, self.max_seq_len)
        self.saved_prompt_path = self.sampled_prompts_path.format(self.dataset_name, self.max_seq_len)

        self.sample_size = sample_size
        self.neg_num = neg_num
        
        if sample_size:
            self.sample_interactions()

    def read_tsv(self, path, datamap, item_pool_full):
        """
        Read a TSV file and extract user-item interactions.

        Parameters:
        - path (str): The file path to the TSV file.
        - datamap (dict): A dictionary mapping user and item IDs.
        - item_pool_full (dict): A dictionary containing item details.

        Returns:
        - dict: A dictionary with user IDs as keys and a list of item titles as values.
        """
        # Expected layout: a header row, then one row per user with the user in the
        # first column and space-separated item IDs in the second.
        with open(path, 'r', newline='') as file:
            tsv_reader = csv.reader(file, delimiter='\t')
            next(tsv_reader, None)  # skip the header row
            result = {}
            for row in tsv_reader:
                if not row:
                    continue  # ignore blank lines
                user_id = datamap['user2id'][row[0]]
                item_titles = []
                for item in row[1].split(" "):
                    item_id = datamap['item2id'][item]
                    # Items without a title are kept as empty strings.
                    item_info = item_pool_full[item_id]
                    item_titles.append(item_info['title'] if 'title' in item_info else "")
                result[user_id] = item_titles
            return result

    def is_image_available(self, item_id):
        """
        Check if an image is available for a given item ID.

        Parameters:
        - item_id (str): The item ID.

        Returns:
        - bool: True if the image is available, False otherwise.
        """
        img_path = self.base_local_path.format(self.dataset_name) + self.dataset.id2img(item_id)
        return os.path.exists(img_path) and 'No_Image_Available.jpg' not in img_path

    def sample_interactions(self, seed=42):
        """
        Sample interactions from the dataset.

        Parameters:
        - seed (int, optional): The random seed for sampling. Default is 42.
        """
        if seed is not None:
            random.seed(seed)

        original_count = len(self.interaction)
        valid_keys = [key for key in self.interaction.keys() if all(self.is_image_available(item) for item in self.interaction[key])]
        valid_count = len(valid_keys)

        print(f"Total interactions before filtering: {original_count}")
        print(f"Total valid interactions after filtering: {valid_count}")

        if valid_count < self.sample_size:
            raise ValueError(f"Only {valid_count} valid interactions available, but {self.sample_size} samples are requested.")

        sampled_keys = random.sample(valid_keys, self.sample_size)
        print(sampled_keys)
        self.interaction = {key: self.interaction[key] for key in sampled_keys}

    def save_image(self, all_images, label):
        """
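        Render each list of images side by side in one figure and save it as
        '{index}_{label}.png' under self.local_image_path.

        Parameters:
        - all_images (dict): Maps each user index to a list of image file paths.
        - label (str): Suffix for the file name, e.g. 'history' or 'target'.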
        """
        if not os.path.exists(self.local_image_path):
            os.makedirs(self.local_image_path)

        for index, img_path_list in all_images.items():
            images = [Image.open(img_path) for img_path in img_path_list]
            num_images = len(images)

            fig, axes = plt.subplots(1, num_images, figsize=(3*num_images, 3))
            if num_images == 1:
                axes = [axes]

            for i, image in enumerate(images):
                axes[i].imshow(image)
                axes[i].axis('off')

            plt.savefig(self.local_image_path + '{}_{}.png'.format(index, label))
            plt.close(fig)

    def combine_images(self):
        """
        Save one combined image of each user's history items (the last max_seq_len
        items before the target) and one image of the target item.
        """
        history_image_dict = {}
        target_image_dict = {}
        for index, seq in self.interaction.items():
            seq_images = [self.base_local_path.format(self.dataset_name) + self.dataset.id2img(_) for _ in seq]
            history_image_dict[index] = seq_images[-self.max_seq_len-1:-1]
            target_image_dict[index] = [seq_images[-1]]
        self.save_image(history_image_dict, 'history')
        self.save_image(target_image_dict, 'target')

    def get_image_path_dict(self):
        """
        Get a dictionary of image paths for history sequences.

        Returns:
        - dict: A dictionary with image paths.
        """
        history_image_dict = {}
        for index, seq in self.interaction.items():
            seq_images = [self.base_local_path.format(self.dataset_name) + self.dataset.id2img(_) for _ in seq]
            history_image_dict[index] = seq_images[-self.max_seq_len-1:-1]
        return history_image_dict

    def generate_prompts(self, template_id, templates, tsv_path=None, lvlm="claude3"):
        """
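        Generate prompts for every sampled user with one template and save them to JSON.

        Parameters:
        - template_id (str): The template to use: 's-1-image', 's-1-title-image', 's-1-title', 's-3', 's-4' or 's-5'.
        - templates (dict): Maps template IDs to template strings.
        - tsv_path (str, optional): Path to a TSV file of pre-ranked candidates; required for 's-3' and 's-5'.
        - lvlm (str, optional): The LVLM whose generated image descriptions are used by 's-4' and 's-5'. Default is "claude3".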
        """
        if not os.path.exists(self.saved_prompt_path):
            os.makedirs(self.saved_prompt_path)
        
        template = templates[template_id].strip()
        
        prompts = {}

        # 's-2' has no prompt-building branch, so it is not accepted here.
        if template_id in ['s-1-image', 's-1-title-image', 's-1-title', 's-3', 's-4', 's-5']:
            # Rerankers (s-3, s-5) take their candidates from a pre-ranked TSV file;
            # read it once rather than once per user.
            if template_id in ['s-3', 's-5']:
                if tsv_path is None:
                    raise ValueError(f"Template '{template_id}' requires tsv_path.")
                user_id2item_titles = self.read_tsv(tsv_path, self.dataset.datamaps, self.dataset.item_pool_full)

            history_image_dict = self.get_image_path_dict()
            for index, seq in self.interaction.items():
                history_items = seq[-self.max_seq_len-1:-1]
                history_titles = []
                for item in history_items:
                    try:
                        title = self.dataset.item2side(item)['title']
                    except KeyError:
                        title = "no_title"
                    history_titles.append(title)

                # (title | description) pairs for s-4/s-5, rebuilt for every user so
                # descriptions from previously processed users do not leak into this prompt.
                title_des = []
                if template_id in ['s-4', 's-5']:
                    for item_id, title in zip(history_items, history_titles):
                        des = self.dataset.item2side(item_id)["image"]["image_description"][lvlm].replace('\n', '')
                        title_des.append("(title: {} | description: {})".format(title, des))

                # Recommender candidates: the target plus neg_num negatives, shuffled.
                target_title = self.dataset.item2side(seq[-1]).get('title', 'no_title')
                candidate_titles = [target_title]
                for item in self.dataset._user2neg[index][:self.neg_num]:
                    try:
                        title = self.dataset.item2side(item)['title']
                    except KeyError:
                        title = "no_title"
                    candidate_titles.append(title)
                random.shuffle(candidate_titles)
                history_len = len(history_items)
                candidate_len = self.neg_num + 1

                if template_id == 's-1-image':
                    prompt = template.format(
                        history_len, candidate_len, str(candidate_titles), candidate_len, candidate_len
                    )
                elif template_id in ['s-1-title-image', 's-1-title']:
                    prompt = template.format(
                        history_len, str(history_titles), candidate_len,
                        str(candidate_titles), candidate_len, candidate_len
                    )
                elif template_id == 's-4':
                    prompt = template.format(
                        history_len, str(title_des), candidate_len,
                        str(candidate_titles), candidate_len, candidate_len
                    )
                else:
                    # 's-3' and 's-5': candidates are replaced by the pre-ranked list.
                    candidate_titles = user_id2item_titles[index]
                    history_info = history_titles if template_id == 's-3' else title_des
                    prompt = template.format(
                        history_len, str(history_info), str(candidate_titles), candidate_len, candidate_len
                    )
                prompts[index] = {
                    'prompt': prompt,
                    'history': {
                        'local_combined_image_path': self.local_image_path + '{}_{}.png'.format(index, 'history'),
                        'online_combined_image_path': self.online_image_path + '{}_{}.png'.format(index, 'history'),
                        'id': seq[-self.max_seq_len-1:-1],
                        'titles': history_titles,
                        'original_images_path': history_image_dict[index]
                    },
                    'candidate': {
                        'titles': candidate_titles,
                    },
                    'target': {
                        'local_image_path': self.local_image_path + '{}_{}.png'.format(index, 'target'),
                        'online_image_path': self.online_image_path + '{}_{}.png'.format(index, 'target'),
                        'id': [seq[-1]],
                        'titles': [target_title]
                    }
                }
        else:
            raise ValueError("Invalid template ID provided.")
        
        if template_id in ["s-4", "s-5"]:
            json_file_path = self.saved_prompt_path + 'prompts_{}_{}.json'.format(template_id, lvlm)
        else:
            json_file_path = self.saved_prompt_path + 'prompts_{}.json'.format(template_id)
        
        with open(json_file_path, 'w', encoding='utf-8') as f:
            json.dump(prompts, f, ensure_ascii=False, indent=4)
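
The class slices every sequence the same way: `seq[-self.max_seq_len-1:-1]` is the history and `seq[-1]` is the target. The sketch below isolates that split so it can be checked on its own. It assumes each user's sequence is a chronologically ordered list of item IDs. `split_history_target` is a hypothetical helper and is not part of the notebook.

```python
from typing import List, Tuple


def split_history_target(seq: List[str], max_seq_len: int) -> Tuple[List[str], str]:
    """Hypothetical helper mirroring PromptGenerator's slicing.

    The last item is the target; up to max_seq_len items before it form the history.
    """
    if not seq:
        raise ValueError("sequence must contain at least one item")
    history = seq[-max_seq_len - 1:-1]
    target = seq[-1]
    return history, target


# A short sequence keeps every earlier item as history.
print(split_history_target(["a", "b", "c"], max_seq_len=10))  # -> (['a', 'b'], 'c')
# A long sequence keeps only the max_seq_len items just before the target ('4'..'13').
print(split_history_target([str(i) for i in range(15)], max_seq_len=10))
```

Because Python slices clamp out-of-range bounds, users with fewer than `max_seq_len + 1` items simply get a shorter history. For that reason, `generate_prompts` computes `history_len` from the slice and does not use `max_seq_len` directly.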

## Step 3: Initialization

Next, we load the `toys` dataset from `Amazon_Review_Plus` and initialize `PromptGenerator`, sampling 10 users whose items all have local images.

In [6]:
dataset_name = 'toys'
dataset = MRecDataset(root='../../datasets/Amazon_Review_Plus', dataset=dataset_name)

sample_size = 10
max_seq_len = 10
neg_num = 29
pg = PromptGenerator(dataset=dataset, dataset_name=dataset_name, sample_size=sample_size, neg_num=neg_num, max_seq_len=max_seq_len)

Total interactions before filtering: 19412
Total valid interactions after filtering: 19017
['3732', '843', '9200', '8194', '7469', '4674', '3437', '18242', '2918', '14108']
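
The filtering step in `sample_interactions` and the seeded draw that follows can be sketched with toy data. In the sketch, `has_image` stands in for `is_image_available`. The sketch uses a dedicated `random.Random(seed)` instance instead of seeding the global `random` module as the notebook does. This is a deliberate substitution: the draw then stays reproducible regardless of other calls to `random`.

```python
import random
from typing import Callable, Dict, List


def sample_valid_users(
    interaction: Dict[str, List[str]],
    has_image: Callable[[str], bool],
    sample_size: int,
    seed: int = 42,
) -> Dict[str, List[str]]:
    """Keep users whose items all have images, then sample sample_size of them."""
    valid_keys = [k for k, items in interaction.items() if all(has_image(i) for i in items)]
    if len(valid_keys) < sample_size:
        raise ValueError(
            f"Only {len(valid_keys)} valid interactions available, "
            f"but {sample_size} samples are requested."
        )
    rng = random.Random(seed)  # local RNG: the same seed gives the same sample
    sampled = rng.sample(valid_keys, sample_size)
    return {k: interaction[k] for k in sampled}


# Toy data: item "x" has no image, so user "2" is filtered out.
toy = {"1": ["a", "b"], "2": ["a", "x"], "3": ["c"], "4": ["b", "c"]}
picked = sample_valid_users(toy, has_image=lambda item: item != "x", sample_size=2)
print(sorted(picked))
```

In the output above, 395 of the 19412 users were dropped by this filter (19412 - 19017) before 10 were sampled.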


## Step 4: Image Combination

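For each sampled user, we save two figures to `./sampled_photos/toys_10/`: `<index>_history.png`, which shows the history images side by side, and `<index>_target.png`, which shows the target image.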
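`generate_prompts` later refers to these files by name, both locally and at an online URL. The sketch below reproduces that naming scheme with the path templates copied from `PromptGenerator.__init__`. `combined_image_paths` is a hypothetical helper. The notebook itself only writes the local files, so the online URL works only if the same files were uploaded to that bucket separately.

```python
from typing import Dict

# Path templates copied from PromptGenerator.__init__.
SAMPLED_PHOTOS_PATH = './sampled_photos/{}_{}/'
BASE_ONLINE_PATH = 'https://hmdataset.oss-cn-guangzhou.aliyuncs.com/{}_{}_combine/'


def combined_image_paths(dataset_name: str, max_seq_len: int, index: str, label: str) -> Dict[str, str]:
    """Return the local and online paths of the combined image for one user and label."""
    filename = '{}_{}.png'.format(index, label)  # label is 'history' or 'target'
    return {
        'local': SAMPLED_PHOTOS_PATH.format(dataset_name, max_seq_len) + filename,
        'online': BASE_ONLINE_PATH.format(dataset_name, max_seq_len) + filename,
    }


print(combined_image_paths('toys', 10, '3732', 'history'))
```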


In [7]:
pg.combine_images()

## Step 5: Prompt Generation

Finally, we generate prompts with each template. Each template ID corresponds to a different role for the LVLM:

- **s-1-title**: The LVLM acts as a recommender and sees only item titles.
- **s-1-title-image**: The LVLM acts as a recommender and sees both item titles and images.
- **s-1-image**: The LVLM acts as a recommender and sees the history only as images; the candidates are still given as titles.
- **s-3**: The LVLM reranks candidates pre-ranked by another recommender, using both titles and images.
- **s-4**: The LVLM first enhances items with captions generated from their images, then acts as a recommender.
- **s-5**: The LVLM enhances items with captions generated from their images, then reranks pre-ranked candidates.
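For the recommender templates, the candidate list is the target plus the first `neg_num` negatives, shuffled so the target's position gives nothing away. The values are then passed positionally to `str.format`. The real template strings live in `prompt_templates.py`, which is not shown here, so `TOY_TEMPLATE` below is a hypothetical stand-in. It only follows the six-argument order used for `s-1-title`: history length, history titles, candidate count, candidate titles, candidate count, candidate count.

```python
import random
from typing import List

# Hypothetical stand-in for templates['s-1-title']; only the argument order mirrors the notebook.
TOY_TEMPLATE = (
    "The user bought {} items: {}. "
    "Rank these {} candidates: {}. "
    "Return a ranked list of all {} candidates, numbered 1 to {}."
)


def build_candidates(target_title: str, negative_titles: List[str], neg_num: int,
                     rng: random.Random) -> List[str]:
    """Target plus the first neg_num negatives, shuffled so the target position is random."""
    candidates = [target_title] + negative_titles[:neg_num]
    rng.shuffle(candidates)
    return candidates


history_titles = ["Wooden Train Set", "Puzzle Mat"]
candidates = build_candidates("Toy Drum", ["Kite", "Yo-yo", "Marbles"], neg_num=2, rng=random.Random(0))
candidate_len = len(candidates)
prompt = TOY_TEMPLATE.format(
    len(history_titles), str(history_titles),
    candidate_len, str(candidates),
    candidate_len, candidate_len,
)
print(prompt)
```

The notebook sets `candidate_len = neg_num + 1` instead of measuring the list. The two agree whenever a user has at least `neg_num` negatives.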

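Note: The reranker templates (s-3 and s-5) require `tsv_path`, a TSV file of candidates pre-ranked by another recommender such as SASRec. `read_tsv` expects a header row followed by one row per user: the user ID in the first column and space-separated item IDs in the second. The recommender templates (s-1-title, s-1-title-image, s-1-image and s-4) instead use the target item plus `neg_num` negatives, shuffled. For s-4 and s-5, `lvlm` selects which model's image descriptions are used, and its name is added to the output file name (e.g. `prompts_s-4_claude3.json`).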
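The TSV layout that `read_tsv` expects can be parsed with the standard `csv` module. The sketch below reads an in-memory sample and maps raw item IDs to titles. The header names `user` and `items` and the item IDs are hypothetical. The notebook's `read_tsv` additionally translates IDs through `datamap['user2id']` and `datamap['item2id']`, which this sketch skips.

```python
import csv
import io
from typing import Dict, List, TextIO


def read_ranked_tsv(file_obj: TextIO) -> Dict[str, List[str]]:
    """Parse a pre-ranked TSV: header row, then `user<TAB>item1 item2 ...` per row."""
    reader = csv.reader(file_obj, delimiter='\t')
    next(reader, None)  # skip the header row, as read_tsv does
    return {row[0]: row[1].split(" ") for row in reader if row}


sample = "user\titems\nU1\tB003 B001 B002\nU2\tB002 B001\n"
ranked = read_ranked_tsv(io.StringIO(sample))
titles = {"B001": "Kite", "B002": "Yo-yo", "B003": "Toy Drum"}
print({u: [titles.get(i, "") for i in items] for u, items in ranked.items()})
# -> {'U1': ['Toy Drum', 'Kite', 'Yo-yo'], 'U2': ['Yo-yo', 'Kite']}
```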


In [11]:
pg.generate_prompts('s-1-title', templates)
pg.generate_prompts('s-1-title-image', templates)
pg.generate_prompts('s-1-image', templates)
pg.generate_prompts('s-3', templates, tsv_path="tsv_files/test_resultv20_toys_64_0.0001_VitConcatTitle-20240606-141603.tsv")
pg.generate_prompts('s-4', templates)
pg.generate_prompts('s-5', templates, tsv_path="tsv_files/test_resultv20_toys_64_0.0001_VitConcatTitle-20240606-141603.tsv")
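
Each call writes one JSON file to `./sampled_prompts/toys_10/`: `prompts_<template_id>.json`, or `prompts_<template_id>_claude3.json` for s-4 and s-5. A quick sanity check for the reranker files is how often the pre-ranked candidate list contains the target at all. `target_coverage` is a hypothetical helper, and the toy records below only copy the nesting of the saved JSON. Titles are compared as strings, so duplicate titles or `"no_title"` placeholders can produce false matches.

```python
import json
from typing import Any, Dict


def target_coverage(prompts: Dict[str, Dict[str, Any]]) -> float:
    """Fraction of users whose target title appears among their candidate titles."""
    if not prompts:
        return 0.0
    hits = sum(
        1 for rec in prompts.values()
        if rec['target']['titles'][0] in rec['candidate']['titles']
    )
    return hits / len(prompts)


# Toy records with the same nesting as the JSON written by generate_prompts.
toy_json = json.dumps({
    "1": {"prompt": "toy prompt", "candidate": {"titles": ["Kite", "Toy Drum"]},
          "target": {"titles": ["Toy Drum"]}},
    "2": {"prompt": "toy prompt", "candidate": {"titles": ["Kite", "Yo-yo"]},
          "target": {"titles": ["Marbles"]}},
})
print(target_coverage(json.loads(toy_json)))  # -> 0.5
```

To check a real file, load it with `json.load` (e.g. from `./sampled_prompts/toys_10/prompts_s-3.json`) and pass the resulting dict. For the recommender templates, the coverage should be 1.0 by construction.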