In [1]:
import pandas as pd
import gzip
import json
from tqdm import tqdm

# ---- 1. 定义数据加载函数 ----
# 这个函数可以逐行读取压缩的jsonl文件，避免一次性加载到内存中导致崩溃
def load_data(file_path, limit=None):
    data = []
    with gzip.open(file_path, 'rt', encoding='utf-8') as f:
        for i, line in enumerate(tqdm(f, desc=f"Loading {file_path.split('/')[-1]}")):
            if limit and i >= limit:
                break
            data.append(json.loads(line))
    return pd.DataFrame(data)

# ---- 2. 设置文件路径 ----
# 请确保这个路径与您的项目结构一致
RAW_DATA_PATH = '../data/raw/'

# Handmade 品类的文件路径
handmade_reviews_path = RAW_DATA_PATH + 'Handmade_Products.jsonl.gz'
handmade_meta_path = RAW_DATA_PATH + 'meta_Handmade_Products.jsonl.gz'

# Health & Personal Care 品类的文件路径
health_reviews_path = RAW_DATA_PATH + 'Health_and_Personal_Care.jsonl.gz'
health_meta_path = RAW_DATA_PATH + 'meta_Health_and_Personal_Care.jsonl.gz'


# ---- 3. 加载 Handmade 数据进行初步探索 ----
# 为了快速测试，我们先只加载前10000条记录。如果运行顺利，可以去掉 limit=10000 参数来加载全部数据。
print("--- Loading Handmade Reviews ---")
df_reviews_handmade = load_data(handmade_reviews_path, limit=10000)

print("\n--- Loading Handmade Metadata ---")
df_meta_handmade = load_data(handmade_meta_path, limit=10000)


# ---- 4. 显示加载后的数据信息 ----
print("\n--- Handmade Reviews DataFrame Info ---")
df_reviews_handmade.info()
print("\n--- First 3 rows of Handmade Reviews ---")
display(df_reviews_handmade.head(3))

print("\n--- Handmade Metadata DataFrame Info ---")
df_meta_handmade.info()
print("\n--- First 3 rows of Handmade Metadata ---")
display(df_meta_handmade.head(3))

--- Loading Handmade Reviews ---


Loading Handmade_Products.jsonl.gz: 10000it [00:00, 76473.64it/s]



--- Loading Handmade Metadata ---


Loading meta_Handmade_Products.jsonl.gz: 10000it [00:00, 23189.46it/s]


--- Handmade Reviews DataFrame Info ---
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 10 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   rating             10000 non-null  float64
 1   title              10000 non-null  object 
 2   text               10000 non-null  object 
 3   images             10000 non-null  object 
 4   asin               10000 non-null  object 
 5   parent_asin        10000 non-null  object 
 6   user_id            10000 non-null  object 
 7   timestamp          10000 non-null  int64  
 8   helpful_vote       10000 non-null  int64  
 9   verified_purchase  10000 non-null  bool   
dtypes: bool(1), float64(1), int64(2), object(6)
memory usage: 713.0+ KB

--- First 3 rows of Handmade Reviews ---





Unnamed: 0,rating,title,text,images,asin,parent_asin,user_id,timestamp,helpful_vote,verified_purchase
0,5.0,Beautiful colors,I bought one for myself and one for my grandda...,[],B08GPJ1MSN,B08GPJ1MSN,AF7OANMNHQJC3PD4HRPX2FATECPA,1621607495111,1,True
1,5.0,You simply must order order more than one!,I’ve ordered three bows so far. Have not been ...,[],B084TWHS7W,B084TWHS7W,AGMJ3EMDVL6OWBJF7CA5RGJLXN5A,1587762946965,0,True
2,5.0,Great,As pictured. Used a frame from the dollar stor...,[],B07V3NRQC4,B07V3NRQC4,AEYORY2AVPMCPDV57CE337YU5LXA,1591448951297,0,True



--- Handmade Metadata DataFrame Info ---
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 14 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   main_category    10000 non-null  object 
 1   title            10000 non-null  object 
 2   average_rating   10000 non-null  float64
 3   rating_number    10000 non-null  int64  
 4   features         10000 non-null  object 
 5   description      10000 non-null  object 
 6   price            6514 non-null   float64
 7   images           10000 non-null  object 
 8   videos           10000 non-null  object 
 9   store            9930 non-null   object 
 10  categories       10000 non-null  object 
 11  details          10000 non-null  object 
 12  parent_asin      10000 non-null  object 
 13  bought_together  0 non-null      object 
dtypes: float64(2), int64(1), object(11)
memory usage: 1.1+ MB

--- First 3 rows of Handmade Metadata ---


Unnamed: 0,main_category,title,average_rating,rating_number,features,description,price,images,videos,store,categories,details,parent_asin,bought_together
0,Handmade,Daisy Keychain Wristlet Gray Fabric Key fob La...,4.5,12,"[High Quality Fabrics, Antique Brass Metallic ...",[This charming Daisy Fabric Keychain wristlet ...,,[{'thumb': 'https://m.media-amazon.com/images/...,[],Generic,"[Handmade Products, Clothing, Shoes & Accessor...",{'Package Dimensions': '8 x 4 x 0.85 inches; 0...,B07NTK7T5P,
1,Handmade,Anemone Jewelry Beauteous November Birthstone ...,4.1,10,"[Stunning gemstone and detailed design, Bands ...",[Anemone brings this November birthstone ring ...,69.0,[{'thumb': 'https://m.media-amazon.com/images/...,[],Anemone Jewelry,"[Handmade Products, Jewelry, Rings, Statement]","{'Department': 'womens', 'Date First Available...",B0751M85FV,
2,Handmade,Silver Triangle Earrings with Chevron Pattern,5.0,1,[],[These large silver triangles are stamped with...,,[{'thumb': 'https://m.media-amazon.com/images/...,[],Zoë Noelle Designs,"[Handmade Products, Jewelry, Earrings, Drop & ...","{'Department': 'Women', 'Date First Available'...",B01HYNE114,


In [3]:
# ---- 5. 筛选核心列 ----
# 我们不需要所有列，只选择对模型有用的，可以节省大量内存
cols_reviews = ['user_id', 'parent_asin', 'text', 'timestamp']
cols_meta = ['parent_asin', 'title', 'description', 'images']

df_reviews_handmade_slim = df_reviews_handmade[cols_reviews]
df_meta_handmade_slim = df_meta_handmade[cols_meta]

# ---- 6. 合并评论和元数据 ----
# 使用 'inner' 合并，确保我们只保留同时拥有评论和元数据的商品交互记录
df_handmade_merged = pd.merge(
    df_reviews_handmade_slim,
    df_meta_handmade_slim,
    on='parent_asin',
    how='inner'
)

print(f"--- Merged DataFrame for Handmade ---")
print(f"Original reviews count: {len(df_reviews_handmade_slim)}")
print(f"Merged count: {len(df_handmade_merged)}")
df_handmade_merged.info()


# ---- 7. 数据清洗与特征提取 ----

# 7.1) 定义一个函数来安全地提取高清图片URL
def extract_image_url(image_list):
    if not isinstance(image_list, list) or len(image_list) == 0:
        return None
    for img in image_list:
        if isinstance(img, dict) and img.get('hi_res'):
            return img['hi_res']
    if isinstance(image_list[0], dict) and image_list[0].get('large'):
        return image_list[0]['large']
    return None

# 7.2) 应用函数，创建新的 'image_url' 列
df_handmade_merged['image_url'] = df_handmade_merged['images'].apply(extract_image_url)

# 【--- 错误修正处 ---】
# 在合并文本前，将 description 列中的列表转换为字符串
def join_if_list(entry):
    if isinstance(entry, list):
        return ' '.join(str(item) for item in entry) # 将列表元素用空格连接成一个字符串
    return entry

# 将这个转换函数应用到 description 列
df_handmade_merged['description'] = df_handmade_merged['description'].apply(join_if_list)
# 【--- 修正结束 ---】


# 7.3) 组合文本特征 (标题 + 描述 + 评论)
# fillna('') 确保即使某部分文本缺失，也不会导致错误
df_handmade_merged['combined_text'] = (
    df_handmade_merged['title'].fillna('') + ' . ' +
    df_handmade_merged['description'].fillna('') + ' . ' + # 现在这一步是安全的
    df_handmade_merged['text'].fillna('')
)

# 7.4) 移除空值和不再需要的列
# 确保所有核心信息都存在，特别是 image_url
df_handmade_final = df_handmade_merged.dropna(subset=['user_id', 'parent_asin', 'timestamp', 'image_url', 'combined_text'])

# 只保留最终需要的列
final_cols = ['user_id', 'parent_asin', 'timestamp', 'combined_text', 'image_url']
df_handmade_final = df_handmade_final[final_cols]


# ---- 8. 显示最终处理好的数据 ----
print("\n--- Final Processed DataFrame for Handmade (Top 5 rows) ---")
display(df_handmade_final.head())

print(f"\nTotal clean records for Handmade: {len(df_handmade_final)}")

--- Merged DataFrame for Handmade ---
Original reviews count: 10000
Merged count: 1021
<class 'pandas.core.frame.DataFrame'>
Int64Index: 1021 entries, 0 to 1020
Data columns (total 7 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   user_id      1021 non-null   object
 1   parent_asin  1021 non-null   object
 2   text         1021 non-null   object
 3   timestamp    1021 non-null   int64 
 4   title        1021 non-null   object
 5   description  1021 non-null   object
 6   images       1021 non-null   object
dtypes: int64(1), object(6)
memory usage: 63.8+ KB

--- Final Processed DataFrame for Handmade (Top 5 rows) ---


Unnamed: 0,user_id,parent_asin,timestamp,combined_text,image_url
0,AF2FTNKCY6XY67BKBUO4BNJRZ4XQ,B09GK2JJDZ,1669860506507,Cats Dogs ID Tags Personalized Lovely Symbols ...,https://m.media-amazon.com/images/I/61pFszChat...
1,AHPBIP5JNVD4ZDTULYEAAX2PBDGQ,B09GK2JJDZ,1658635895538,Cats Dogs ID Tags Personalized Lovely Symbols ...,https://m.media-amazon.com/images/I/61pFszChat...
2,AEUKSETY42JHTLQY734DRPDGMSXA,B09GK2JJDZ,1663898449664,Cats Dogs ID Tags Personalized Lovely Symbols ...,https://m.media-amazon.com/images/I/61pFszChat...
3,AEAE2ZYPWPDUJHAAMGQIL7LNVTPA,B09GK2JJDZ,1677866085425,Cats Dogs ID Tags Personalized Lovely Symbols ...,https://m.media-amazon.com/images/I/61pFszChat...
4,AGQOAL23M4GPIYK2WKSEMW3Q27EQ,B09GK2JJDZ,1644892809092,Cats Dogs ID Tags Personalized Lovely Symbols ...,https://m.media-amazon.com/images/I/61pFszChat...



Total clean records for Handmade: 1021


In [4]:
# ---- 9. 将处理流程封装成一个函数 ----
def process_category(review_path, meta_path, limit=None):
    """加载、合并和清洗指定品类的数据"""
    print(f"\n--- Processing category: {review_path.split('/')[-1]} ---")
    
    # 加载数据
    df_reviews = load_data(review_path, limit=limit)
    df_meta = load_data(meta_path, limit=limit)
    
    # 筛选核心列
    cols_reviews = ['user_id', 'parent_asin', 'text', 'timestamp']
    cols_meta = ['parent_asin', 'title', 'description', 'images']
    df_reviews_slim = df_reviews[cols_reviews]
    df_meta_slim = df_meta[cols_meta]
    
    # 合并
    df_merged = pd.merge(df_reviews_slim, df_meta_slim, on='parent_asin', how='inner')
    
    # 清洗与特征提取
    df_merged['image_url'] = df_merged['images'].apply(extract_image_url)
    
    def join_if_list(entry):
        if isinstance(entry, list):
            return ' '.join(str(item) for item in entry)
        return entry
    df_merged['description'] = df_merged['description'].apply(join_if_list)
    
    df_merged['combined_text'] = (
        df_merged['title'].fillna('') + ' . ' +
        df_merged['description'].fillna('') + ' . ' +
        df_merged['text'].fillna('')
    )
    
    final_cols = ['user_id', 'parent_asin', 'timestamp', 'combined_text', 'image_url']
    df_final = df_merged.dropna(subset=final_cols)[final_cols]
    
    print(f"Finished processing. Found {len(df_final)} clean records.")
    return df_final

# ---- 10. 处理所有品类并合并 ----
# 注意：为了节省时间和内存，我们继续使用 limit=10000。
# 在您准备好进行完整训练时，可以移除 limit 参数。
df_handmade_final = process_category(handmade_reviews_path, handmade_meta_path, limit=10000)
df_health_final = process_category(health_reviews_path, health_meta_path, limit=10000)

# 合并两个品类的 DataFrame
df_all = pd.concat([df_handmade_final, df_health_final], ignore_index=True)
print(f"\n--- Combined DataFrame ---")
print(f"Total records from all categories: {len(df_all)}")


# ---- 11. 构建用户行为序列 ----
print("\n--- Building User Sequences ---")

# 11.1) 按用户ID和时间戳排序，这是构建序列最关键的一步
df_all_sorted = df_all.sort_values(by=['user_id', 'timestamp'])

# 11.2) 按 user_id 分组，并将每个用户的交互信息聚合为列表
df_sequences = df_all_sorted.groupby('user_id').agg({
    'parent_asin': list,
    'combined_text': list,
    'image_url': list,
    'timestamp': list
}).reset_index()

# 11.3) 筛选掉只有一个交互的用户（无法形成序列进行预测）
df_sequences['sequence_length'] = df_sequences['parent_asin'].apply(len)
df_sequences = df_sequences[df_sequences['sequence_length'] > 1].reset_index(drop=True)


# ---- 12. 显示最终的序列化数据 ----
print(f"\n--- Final Sequential DataFrame ---")
print(f"Total number of users with sequences: {len(df_sequences)}")
print("Each row represents a user, with their historical interactions aggregated into lists.")
display(df_sequences.head())


# ---- 13. 保存处理好的数据 ----
# 使用 pickle 格式可以完整地保存 DataFrame 的结构，包括列表
PROCESSED_DATA_PATH = '../data/processed/'
output_filename = PROCESSED_DATA_PATH + 'sequential_data_sample.pkl'
df_sequences.to_pickle(output_filename)

print(f"\nSuccessfully saved processed sequential data to: {output_filename}")


--- Processing category: Handmade_Products.jsonl.gz ---


Loading Handmade_Products.jsonl.gz: 10000it [00:00, 66793.92it/s]
Loading meta_Handmade_Products.jsonl.gz: 10000it [00:00, 22048.47it/s]


Finished processing. Found 1021 clean records.

--- Processing category: Health_and_Personal_Care.jsonl.gz ---


Loading Health_and_Personal_Care.jsonl.gz: 10000it [00:00, 30797.90it/s]
Loading meta_Health_and_Personal_Care.jsonl.gz: 10000it [00:00, 23895.31it/s]


Finished processing. Found 2546 clean records.

--- Combined DataFrame ---
Total records from all categories: 3567

--- Building User Sequences ---

--- Final Sequential DataFrame ---
Total number of users with sequences: 238
Each row represents a user, with their historical interactions aggregated into lists.


Unnamed: 0,user_id,parent_asin,combined_text,image_url,timestamp,sequence_length
0,AE2YT5NODO5H7VNHTO6CYQAE4VQA,"[B0C5BPS25X, B0C5BPS25X]",[Cascade Complete Powder Dishwasher Detergent ...,[https://m.media-amazon.com/images/I/71Ryg2c2J...,"[1490230355000, 1490230355000]",2
1,AE3FC3VW5FA434GIZ22WTTOFTH5A,"[B083T9VD7J, B01M4NB3OJ]",[Connoisseurs Silver Jewelry Cleaner with Clea...,[https://m.media-amazon.com/images/I/81gfjJ2+n...,"[1595800056992, 1614091185943]",2
2,AE3KLVXGZPANXE5XLXYKHTVAZ3FQ,"[B075J1D3QD, B07XVVVB8W, B08M5BX9VK, B08R9DFCG...","[7 pcs Ear Picks Earwax Removal Care Kit, Ear ...",[https://m.media-amazon.com/images/I/610t4QvtQ...,"[1511899508596, 1572203039766, 1613918829108, ...",5
3,AE3VBP5V3YQVBMHQLV5C3AX4JBTA,"[B07D6NVJCX, B0759GBX47]",[Navya Craft Labradorite Round Silver Ring 925...,[https://m.media-amazon.com/images/I/61B5KngPW...,"[1599573787239, 1600900435210]",2
4,AE4JZ7OUABFFH2GV7Z7DJVOICS7A,"[B0C1HHXQ94, B00RKNVZFE]",[ASUTRA Natural & Organic Yoga Mat Cleaner (Mi...,[https://m.media-amazon.com/images/I/81qEnhizz...,"[1522340932338, 1536925603760]",2



Successfully saved processed sequential data to: ../data/processed/sequential_data_sample.pkl


In [5]:
# ---- 测试我们自定义的 Dataset 和 DataLoader ----

# 由于 data_loader.py 在 src 目录下，我们需要将 src 目录添加到系统路径中
import sys
sys.path.append('../src')

# 从 data_loader.py 文件中导入 RecSysDataset 类
from data_loader import RecSysDataset
from torch.utils.data import DataLoader

# 1. 创建 Dataset 实例
print("--- Initializing Dataset ---")
PROCESSED_DATA_PATH = '../data/processed/sequential_data_sample.pkl'
dataset = RecSysDataset(data_path=PROCESSED_DATA_PATH, max_seq_len=50)

# 2. 验证 __len__ 方法
print(f"\nTotal users in dataset: {len(dataset)}")

# 3. 验证 __getitem__ 方法
print("\n--- Testing __getitem__ for the first user (index 0) ---")
sample_user_data = dataset[0]

# 打印样本数据的结构和内容
for key, value in sample_user_data.items():
    # 为了美观，只打印序列的前5个和后5个元素
    if isinstance(value, list):
        print(f"\n- {key} (length: {len(value)}):")
        if len(value) > 10:
            print(f"  First 5: {value[:5]}")
            print(f"  Last 5:  {value[-5:]}")
        else:
            print(f"  Value: {value}")
    else:
        print(f"\n- {key}:\n  Value: {value}")

# 4. 测试 DataLoader
# DataLoader 会自动将多条数据打包成一个批次(batch)
print("\n--- Testing DataLoader ---")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
first_batch = next(iter(dataloader))

print("Successfully created a batch of size 2.")
print("Keys in the batch:", first_batch.keys())
print("Shape of input_texts in the batch:", len(first_batch['input_texts']), "users x", len(first_batch['input_texts'][0]), "items")

--- Initializing Dataset ---
Loading data from: ../data/processed/sequential_data_sample.pkl
Data loaded. Total users: 238

Total users in dataset: 238

--- Testing __getitem__ for the first user (index 0) ---

- input_item_ids (length: 50):
  First 5: ['<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
  Last 5:  ['<PAD>', '<PAD>', '<PAD>', '<PAD>', 'B0C5BPS25X']

- input_texts (length: 50):
  First 5: ['', '', '', '', '']
  Last 5:  ['', '', '', '', 'Cascade Complete Powder Dishwasher Detergent - Fresh Scent - 75oz (Pack of 2) .  . I received Cascade Complete as listed in the product description.  Good product.']

- input_images (length: 50):
  First 5: ['', '', '', '', '']
  Last 5:  ['', '', '', '', 'https://m.media-amazon.com/images/I/71Ryg2c2JmL._AC_SL1280_.jpg']

- target_item_id:
  Value: B0C5BPS25X

- target_text:
  Value: Cascade Complete Powder Dishwasher Detergent - Fresh Scent - 75oz (Pack of 2) .  . I received Cascade Complete as listed in the product description.  Good product

In [6]:
# ---- 测试我们自定义的 MultimodalEncoder ----

# 确保 src 目录在系统路径中
import sys
if '../src' not in sys.path:
    sys.path.append('../src')

# 从 models.py 文件中导入 MultimodalEncoder 类
from models import MultimodalEncoder
import torch

# 1. 确定运行设备
# 检查服务器是否有可用的CUDA GPU，如果有，就使用GPU，否则使用CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Using device: {device} ---")


# 2. 初始化多模态编码器
# 将模型实例化的同时，直接移动到我们确定的设备上
try:
    encoder = MultimodalEncoder(device=device)
except Exception as e:
    print(f"Error initializing MultimodalEncoder: {e}")
    # 如果出错，可能是网络问题或Hugging Face Hub访问问题，这里先假设能成功


# 3. 从 DataLoader 中获取一个真实的数据批次
# 我们之前已经创建了 dataloader 实例
first_batch = next(iter(dataloader))

# 我们只需要目标物品（target item）的文本和图片URL来测试编码器
sample_texts = list(first_batch['target_text'])
sample_image_urls = list(first_batch['target_image'])

print(f"\n--- Testing Encoder with a batch of size: {len(sample_texts)} ---")
print("Sample Text:", sample_texts[0][:80] + "...") # 打印第一个文本的前80个字符
print("Sample Image URL:", sample_image_urls[0])


# 4. 执行前向传播
# 将数据送入编码器，得到输出的嵌入向量
# 使用 torch.no_grad() 是一个好习惯，因为我们只是在测试（推理），不需要计算梯度
with torch.no_grad():
    # 将编码器设置为评估模式
    encoder.eval() 
    
    # 得到输出
    output_embeddings = encoder(texts=sample_texts, image_urls=sample_image_urls)


# 5. 验证输出
print("\n--- Verifying Output ---")
print(f"Shape of the output tensor: {output_embeddings.shape}")
print(f"Data type of the output tensor: {output_embeddings.dtype}")
print(f"Output tensor is on device: {output_embeddings.device}")

# 检查输出向量的范数（长度），因为我们代码里做了L2归一化，所以它的范数应该约等于1
# 我们只检查第一个向量
first_vector_norm = torch.linalg.norm(output_embeddings[0])
print(f"Norm of the first output vector: {first_vector_norm.item():.4f} (should be close to 1.0)")

--- Using device: cuda ---
Loading CLIP model and processor...


preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

CLIP model loaded successfully.

--- Testing Encoder with a batch of size: 2 ---
Sample Text: UCARI Intolerance & Food Sensitivity Test Kit for Adults & Kids | 1500+ Food, En...
Sample Image URL: https://m.media-amazon.com/images/I/61edIYIp8UL._AC_SL1500_.jpg

--- Verifying Output ---
Shape of the output tensor: torch.Size([2, 512])
Data type of the output tensor: torch.float32
Output tensor is on device: cuda:0
Norm of the first output vector: 1.0000 (should be close to 1.0)


In [11]:
# ---- 测试完整的 GenerativeRecSysModel (修正版) ----
import pandas as pd
import torch
import importlib

# 确保 src 目录在系统路径中并强制重载
import sys
if '../src' not in sys.path:
    sys.path.append('../src')
import models
importlib.reload(models)
from models import GenerativeRecSysModel

# 1. 确定运行设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"--- Using device: {device} ---")

# 2. 初始化完整的模型
print("\n--- Initializing the full GenerativeRecSysModel ---")
model = GenerativeRecSysModel(embed_dim=512, nhead=8, num_layers=3, device=device)
model.to(device)

# 3. 初始化模型的 Item ID 嵌入层
print("\n--- Initializing Item ID Embeddings ---")
full_data_df = pd.read_pickle('../data/processed/sequential_data_sample.pkl')
all_unique_item_ids = full_data_df['parent_asin'].explode().unique()
model.create_item_embeddings(all_unique_item_ids)
print(f"Item ID embedding layer created for {len(all_unique_item_ids)} unique items.")

# 4. 从 DataLoader 中获取一个真实的数据批次
first_batch = next(iter(dataloader))

# 【--- 错误修正处 ---】
# DataLoader 输出的批次是 (seq_len, batch_size)，我们需要将其转置为 (batch_size, seq_len)
# 使用 zip(*) 可以非常优雅地实现列表的转置
batch_input_ids = list(zip(*first_batch['input_item_ids']))
batch_input_texts = list(zip(*first_batch['input_texts']))
batch_input_images = list(zip(*first_batch['input_images']))
print(f"\nCorrected batch shape: {len(batch_input_ids)} users x {len(batch_input_ids[0])} sequence length")
# 【--- 修正结束 ---】

# 5. 执行完整模型的前向传播
print("\n--- Performing a forward pass through the full model ---")
with torch.no_grad():
    model.eval()
    user_interest_vector = model(
        input_item_ids=batch_input_ids,
        input_texts=batch_input_texts,
        input_images=batch_input_images
    )

# 6. 验证最终输出
print("\n--- Verifying Final Output ---")
print(f"Shape of the final user interest vector: {user_interest_vector.shape}")
print(f"Output vector is on device: {user_interest_vector.device}")
has_nan = torch.isnan(user_interest_vector).any()
print(f"Does the output contain NaN values? {has_nan.item()}")

--- Using device: cuda ---

--- Initializing the full GenerativeRecSysModel ---

--- Initializing Item ID Embeddings ---
Item ID embedding layer created for 415 unique items.

Corrected batch shape: 2 users x 50 sequence length

--- Performing a forward pass through the full model ---

--- Verifying Final Output ---
Shape of the final user interest vector: torch.Size([2, 512])
Output vector is on device: cuda:0
Does the output contain NaN values? False


In [13]:
import pandas as pd

# 加载我们处理好的序列数据
df_sequences = pd.read_pickle('../data/processed/sequential_data_sample.pkl')

# --- 智能寻找一个有效的、非重复的测试历史 ---
found_user = None
for index, user in df_sequences.iterrows():
    # 检查该用户的购买历史中，是否包含至少2个【不重复】的商品ID
    if len(set(user['parent_asin'])) >= 2:
        found_user = user
        break # 找到第一个就停止

if found_user is not None:
    print("--- Found a valid and meaningful user history for testing ---")
    
    # 提取前两个【不重复】的商品作为我们的测试输入
    distinct_items = []
    seen_ids = set()
    
    # 将用户的历史记录打包在一起遍历
    for item_id, text, image in zip(found_user['parent_asin'], found_user['combined_text'], found_user['image_url']):
        if item_id not in seen_ids:
            distinct_items.append({'parent_asin': item_id, 'combined_text': text, 'image_url': image})
            seen_ids.add(item_id)
        if len(distinct_items) == 2:
            break
            
    # 打印成我们可以直接复制的 jsonl 格式
    print("\n=== Please copy the following content into test_user_history.txt ===\n")
    for item in distinct_items:
        # 使用 json.dumps 来处理文本中的特殊字符，避免格式错误
        import json
        print(json.dumps(item))

else:
    print("Could not find a user with at least two different items in their history in the current data sample.")

--- Found a valid and meaningful user history for testing ---

=== Please copy the following content into test_user_history.txt ===

{"parent_asin": "B083T9VD7J", "combined_text": "Connoisseurs Silver Jewelry Cleaner with Cleaning Basket and Polishing Cloth 8 oz. .  . I had a cross made of silver and it was completely black from tarnishing, I follow the directions and dip it in the solution for just a few seconds when It came out I wiped it and wonderful. Great product highly recommend .", "image_url": "https://m.media-amazon.com/images/I/81gfjJ2+n3L._AC_SL1500_.jpg"}
{"parent_asin": "B01M4NB3OJ", "combined_text": "Microwave Dish Cozies, Set of 3, 1 Small Bowl Cozy, 1 Medium Bowl Cozy, and 1 Dinner Plate Cozy, Kitchen Motif . Microwave cozy set includes 3 dish cozies. 1 small bowl cozy (About 6.5\" square, pictured with 5\u201d diameter bowl and a mug), 1 medium bowl cozy (About 8\" square, pictured with 6\u201d diameter bowl), 1 dinner plate cozy (About 10\" square, pictured with 10.5