In [1]:
from dataclasses import dataclass
import pandas as pd

In [2]:
df = pd.read_csv("./data/AG_News.csv")
df.shape
df.head()

Unnamed: 0,Class Index,Title,Description
0,3,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli..."
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...
4,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco..."


In [3]:
df["Class Index"].value_counts()

Class Index
3    30000
4    30000
2    30000
1    30000
Name: count, dtype: int64

In [4]:
sdf = df.sample(100)

In [5]:
sdf["Class Index"].value_counts()

Class Index
2    28
3    26
4    25
1    21
Name: count, dtype: int64

In [6]:
from typing import List
@dataclass
class User:
    
    user_name: str

@dataclass
class UserPrefer:
    
    user_name: str
    prefers: List[int]


@dataclass
class Item:
    
    item_id: str
    item_props: dict


@dataclass
class Action:
    
    action_type: str
    action_props: dict


@dataclass
class UserAction:
    
    user: User
    item: Item
    action: Action
    action_time: str

In [7]:
u1 = User("u1")
up1 = UserPrefer("u1", [1, 2])
# sdf.iloc[1] 正好是sport（类别为2）
i1 = Item("i1", {
    "id": 1, 
    "catetory": "sport",
    "title": "Swimming: Shibata Joins Japanese Gold Rush", 
    "description": "\
    ATHENS (Reuters) - Ai Shibata wore down French teen-ager  Laure Manaudou to win the women's 800 meters \
    freestyle gold  medal at the Athens Olympics Friday and provide Japan with  their first female swimming \
    champion in 12 years.", 
    "content": "content"
})
a1 = Action("浏览", {
    "open_time": "2023-04-01 12:00:00", 
    "leave_time": "2023-04-01 14:00:00",
    "type": "close",
    "duration": "2hour"
})
ua1 = UserAction(u1, i1, a1, "2023-04-01 12:00:00")

In [8]:
import os
import zhipuai
import sys
os.environ['ZHIPUAI_API_KEY'] = 'b441f069335a92a7f8679736ec1be3b9.1pnH2AoIw04JHMV1'
zhipuai.api_key = os.environ['ZHIPUAI_API_KEY']


In [10]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

In [12]:
from zhipuai_embedding import ZhipuAIEmbeddings

# 初始化智谱嵌入实例
embedding = ZhipuAIEmbeddings()

# 定义一个函数来获取文本的嵌入向量
def get_zhipu_embedding(text):
    return embedding.embed_query(text)

# 使用apply函数来为DataFrame中每一行的标题和描述获取嵌入向量
# 假设'sdf'是你的DataFrame，并且它有'Title'和'Description'这两列
sdf["embedding"] = sdf.apply(lambda x: get_zhipu_embedding(x.Title + x.Description), axis=1)


In [16]:
import random
class Recall:
    
    def __init__(self, df: pd.DataFrame):
        self.data = df
    
    def user_prefer_recall(self, user, n):
        up = self.get_user_prefers(user)
        idx = random.randrange(0, len(up.prefers))
        return self.pick_by_idx(idx, n)
    
    def hot_recall(self, n):
        # 随机进行示例
        df = self.data.sample(n)
        return df
    
    def user_action_recall(self, user, n):
        actions = self.get_user_actions(user)
        interest = self.get_most_interested_item(actions)
        recoms = self.recommend_by_interest(interest, n)
        return recoms
    
    def get_most_interested_item(self, user_action):
        """
        可以选近一段时间内用户交互时间、次数、评论（相关属性）过的Item
        """
        # 就是sdf的第2行，idx为1的那条作为最喜欢（假设）
        # 是一条游泳相关的Item
        idx = user_action.item.item_props["id"]
        im = self.data.iloc[idx]
        return im
    
    def recommend_by_interest(self, interest, n):
        cate_id = interest["Class Index"]
        q_emb = interest["embedding"]
        # 确定类别
        base = self.data[self.data["Class Index"] == cate_id]
        # 此处可以复用QA那一段代码，用给定embedding计算base中embedding的相似度
        base_arr = np.array(
            [v.embedding for v in base.itertuples()]
        )
        q_arr = np.expand_dims(q_emb, 0)
        sims = cosine_similarity(base_arr, q_arr)
        # 排除掉自己
        idxes = sims.argsort(0).squeeze()[-(n+1):-1]
        return base.iloc[reversed(idxes.tolist())]
    
    def pick_by_idx(self, category, n):
        df = self.data[self.data["Class Index"] == category]
        return df.sample(n)
    
    def get_user_actions(self, user):
        dct = {"u1": ua1}
        return dct[user.user_name]
    
    def get_user_prefers(self, user):
        dct = {"u1": up1}
        return dct[user.user_name]
    
    def run(self, user):
        ur = self.user_action_recall(user, 5)
        if len(ur) == 0:
            ur = self.user_prefer_recall(user, 5)
        hr = self.hot_recall(3)
        return pd.concat([ur, hr], axis=0)

In [17]:
r = Recall(sdf)

In [18]:
rd = r.run(u1)

In [19]:
rd

Unnamed: 0,Class Index,Title,Description,embedding
4107,3,Stocks Sag as Oil Marches Higher,NEW YORK (Reuters) - U.S. stocks tumbled on T...,"[-0.2420201599597931, 0.10054507851600647, 0.0..."
23777,3,Software Stocks Bounce After Oracle Ruling,Software stocks rose on Friday on speculation ...,"[-0.2412630319595337, -0.045222070068120956, 0..."
5272,3,Treasuries Edge Down as Oil Retreat,NEW YORK (Reuters) - U.S. Treasury prices edg...,"[-0.3532193601131439, 0.04000478982925415, 0.0..."
67154,3,"Cingular: subscriber, revenue rise, but operat...",ATLANTA Cingular Wireless reports a rise in su...,"[-0.21263732016086578, 0.18610377609729767, 0...."
29570,3,"Amex, Nasdaq in talks to move QQQ",NEW YORK (CNN/Money) - The Nasdaq 100 index tr...,"[-0.21873332560062408, 0.14688748121261597, -0..."
4869,1,Militants reportedly promise to free kidnapped...,"BAGHDAD, Iraq (AP) A top aide to firebrand Shi...","[-0.40075698494911194, -0.08407636731863022, -..."
71858,4,Red Hat Users Under Phishing Attack (NewsFactor),"NewsFactor - Red Hat (Nasdaq: RHAT), provider ...","[-0.2588844895362854, 0.11933600157499313, 0.0..."
86129,2,Fans Calling for Tickets Get Sex Line (AP),AP - Charlotte Bobcats fans were in for a surp...,"[-0.3724042475223541, 0.2509610056877136, 0.12..."
