# Outline
- Load validation dataset - Movielens-1M
- Load recommendation model that was finetuned on specific embedding spce (TBC)
- Create a self-reflection mechanism on the validation dataset
  - Predict the next sequence for a specific user by prompting the llm with the data about the user
    - Generate prompt to encode specific features
  - reflect on the prediction vs the ground truth
  - suggest a better feature and reflect on why the mistake happened
  - continue to the next sequence

## Imports

In [73]:
import sys
import re
sys.path.append(r'C:\Projects\TAU\DeepLearning\Open-World-Knowledge-Augmented-Recommendation\knowledge_encoding')
from lm_encoding import get_text_data_loader 

In [74]:
import numpy as np
import pandas as pd
import torch
import torch.utils.data as Data

sys.path.append(r'C:\Projects\TAU\DeepLearning\Open-World-Knowledge-Augmented-Recommendation\RS')
from dataset import AmzDataset
from main_ctr import eval
from utils import load_json

In [75]:
class ShortDataLoader:
    def __init__(self, dataloader, num_batches):
        self.dataloader = dataloader
        self.num_batches = num_batches

    def __iter__(self):
        count = 0
        for batch in self.dataloader:
            if count >= self.num_batches:
                break
            yield batch
            count += 1

    def __len__(self):
        return min(len(self.dataloader), self.num_batches)

from torch.utils.data import Dataset
class DictDataset(Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __getitem__(self, index):
        return self.data_list[index]

    def __len__(self):
        return len(self.data_list)

## Loading Test Set
- Identifying Classification mistakes

In [76]:
inference_model = torch.load(r'C:\Projects\TAU\DeepLearning\Open-World-Knowledge-Augmented-Recommendation\RS\model\ml-1m\ctr\DIN\WDA_Emb32_epoch20_bs256_lr1e-4_cosine_cnvt_arch_128,32_cnvt_type_HEA_eprt_2_wd0_drop0.0_hl200,80_cl3_augment_True\DIN.pt'
                   ,map_location=torch.device('cpu'))
test_set = AmzDataset(r'C:\Projects\TAU\DeepLearning\Open-World-Knowledge-Augmented-Recommendation\data\ml-1m\proc_data', 'test', 'ctr', 5, True, 'bert_avg')
test_loader = Data.DataLoader(dataset=test_set, batch_size=2, shuffle=False)
metric_scope = [1, 3, 5, 7]

In [77]:
test_set

<dataset.AmzDataset at 0x210b3c00750>

In [78]:
short_test = ShortDataLoader(test_loader, num_batches=100)
len(short_test)

100

In [79]:
auc, ll, loss, eval_time, labels, preds = eval(inference_model, short_test)
print("test loss: %.5f, test time: %.5f, auc: %.5f, logloss: %.5f" % (loss, eval_time, auc, ll))

test loss: 0.55809, test time: 3.09757, auc: 0.78708, logloss: 0.55809


In [80]:
preds_r = [1 if pred[0] > 0.5 else 0 for pred in preds]
mistake_indexes = [i for i, (label, pred) in enumerate(zip(labels, preds_r)) if label != pred]
mistake_indexes[:10]

[4, 9, 10, 14, 15, 17, 18, 22, 25, 26]

In [81]:
preds[4] , labels[4]

([0.18554836511611938], 1)

In [82]:
user_vec_dict = load_json(r'C:\Projects\TAU\DeepLearning\Open-World-Knowledge-Augmented-Recommendation\data\ml-1m\proc_data\bert_avg_augment.hist')
user_vec_dict.keys()

dict_keys(['2179', '3813', '2108', '778', '4232', '2903', '2449', '5286', '4235', '3020', '4866', '3172', '2516', '1989', '1878', '2673', '3497', '1061', '647', '1143', '3517', '231', '258', '5831', '982', '3908', '1216', '2971', '398', '3126', '415', '1848', '4661', '3427', '1053', '5374', '2311', '1042', '1737', '5192', '5647', '5545', '2476', '4616', '2507', '2872', '4229', '4215', '511', '2193', '2684', '5937', '1650', '5450', '4682', '2364', '1295', '3148', '5582', '903', '5263', '1821', '5227', '4339', '3456', '5062', '905', '1113', '1728', '254', '1781', '3376', '5270', '395', '1466', '1613', '1518', '623', '365', '5211', '1889', '879', '3401', '1347', '4010', '5168', '1354', '1100', '1572', '859', '2422', '609', '5959', '520', '4192', '5990', '83', '2197', '4850', '1408', '1327', '2744', '4382', '5982', '1304', '1108', '5393', '734', '6015', '4802', '4064', '5597', '440', '1366', '3822', '2282', '2114', '3768', '3218', '5922', '1623', '1780', '2539', '4298', '3029', '4177', '30

In [83]:
item_vec_dict = load_json(r'C:\Projects\TAU\DeepLearning\Open-World-Knowledge-Augmented-Recommendation\data\ml-1m\proc_data\bert_avg_augment.item')
item_vec_dict.keys()

dict_keys(['1', '2', '4', '3', '5', '6', '7', '9', '8', '11', '10', '13', '12', '15', '14', '16', '17', '18', '20', '19', '22', '21', '23', '24', '25', '26', '27', '29', '30', '32', '31', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '52', '54', '55', '56', '57', '58', '59', '61', '62', '63', '65', '66', '67', '68', '69', '70', '71', '74', '73', '76', '75', '78', '77', '79', '81', '80', '83', '82', '85', '84', '86', '87', '88', '90', '89', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '112', '111', '114', '113', '115', '116', '117', '118', '120', '119', '121', '122', '123', '124', '126', '125', '128', '127', '130', '129', '131', '132', '133', '134', '136', '135', '137', '139', '138', '141', '140', '142', '144', '143', '145', '146', '147', '148', '149', '150', '151', '152', '153', '154', '156', '155', '157', '158', '159', '160', '161', '163', '162'

## Extracting metadata for LLM 

In [84]:
klg_path = r'C:\Projects\TAU\DeepLearning\Open-World-Knowledge-Augmented-Recommendation\data\ml-1m\knowledge'
hist_loader, hist_idxes, item_loader, item_idxes = get_text_data_loader(klg_path, 2)

chatgpt.hist 1 Given a male user who is aged 35-44 and an executive/managerial, this user's movie viewing history over time is listed below. Bridge on the River Kwai, The (1957), 5 stars; Chinatown (1974), 4 stars; Duck Soup (1933), 5 stars; Toy Story (1995), 5 stars; Arsenic and Old Lace (1944), 4 stars; Dances with Wolves (1990), 5 stars; Cool Hand Luke (1967), 5 stars; Young Frankenstein (1974), 5 stars; High Noon (1952), 5 stars; Rear Window (1954), 5 stars; Being There (1979), 5 stars; Some Like It Hot (1959), 5 stars; Casablanca (1942), 5 stars; Wizard of Oz, The (1939), 5 stars; Gone with the Wind (1939), 5 stars; Moonstruck (1987), 5 stars; It's a Wonderful Life (1946), 5 stars; Mr. Smith Goes to Washington (1939), 5 stars; 39 Steps, The (1935), 5 stars; Malcolm X (1992), 5 stars; Witness (1985), 4 stars; Animal House (1978), 4 stars; Do the Right Thing (1989), 5 stars; Frankenstein (1931), 5 stars; Modern Times (1936), 4 stars; War of the Worlds, The (1953), 5 stars; Hollywood

In [85]:
key = '4810'
index = hist_idxes.index(key)
print(index)

hist_loader.dataset[index]

4160


"Given a male user who is aged 18-24 and a programmer, this user's movie viewing history over time is listed below. Honey, I Shrunk the Kids (1989), 1 stars; Negotiator, The (1998), 4 stars; Terminator 2: Judgment Day (1991), 3 stars; Jumanji (1995), 2 stars; Heat (1995), 5 stars; Insider, The (1999), 5 stars; Last of the Mohicans, The (1992), 3 stars; Michael (1996), 3 stars; Batman (1989), 3 stars; Jackie Chan's First Strike (1996), 4 stars; Rocketeer, The (1991), 3 stars; Maltese Falcon, The (1941), 4 stars; Messenger: The Story of Joan of Arc, The (1999), 3 stars; 2001: A Space Odyssey (1968), 3 stars; Jewel of the Nile, The (1985), 2 stars; Romancing the Stone (1984), 3 stars; Saint, The (1997), 5 stars; Sleepy Hollow (1999), 2 stars; World Is Not Enough, The (1999), 2 stars; Anaconda (1997), 1 stars; Titanic (1953), 3 stars; Waterworld (1995), 2 stars; Lost in Space (1998), 3 stars; Teenage Mutant Ninja Turtles (1990), 1 stars; Payback (1999), 3 stars; Red Dawn (1984), 3 stars; R

In [86]:
key = '3489'
index = item_idxes.index(key)
print(index)

item_loader.dataset[index]

3423


"Hook is a 1991 fantasy adventure film directed by Steven Spielberg and starring Robin Williams, Dustin Hoffman, and Julia Roberts. The movie was produced by Amblin Entertainment and TriStar Pictures and was primarily filmed in California, USA.\n\nThe film is a loose adaptation of J.M. Barrie's classic play and novel, Peter Pan, and follows a grown-up Peter Pan, who has forgotten his past as the boy who never grew up, and is now a successful but unhappy corporate lawyer in London. When his children are kidnapped by the villainous Captain Hook, Peter must return to Neverland and reclaim his lost identity as a hero in order to save his family.\n\nThe film is classified as a fantasy adventure and family movie. The director, Steven Spielberg, is a renowned filmmaker known for his works in the adventure and sci-fi genres. The cast features some of the biggest names in Hollywood, including Robin Williams as Peter Pan, Dustin Hoffman as Captain Hook, and Julia Roberts as Tinkerbell.\n\nThe mo

In [87]:
def get_user_hist(user_vec,hist_idxes):
    user_vec = user_vec.tolist()
    for key, v in user_vec_dict.items():
        if v == user_vec:
            user_index = hist_idxes.index(key)
            return key, hist_loader.dataset[user_index]
    return None

def get_item_desc(item_vec, item_idxes):
    item_vec = item_vec.tolist()
    for key,v in item_vec_dict.items():
        if v == item_vec:
            item_index = item_idxes.index(key)
            return key, item_loader.dataset[item_index]
    return None

In [88]:
data_list = []
for i in mistake_indexes:
    data = test_set[i]
    user_idx, user_hist = get_user_hist(data['hist_aug_vec'], hist_idxes)
    item_idx, item_desc = get_item_desc(data['item_aug_vec'], item_idxes)
    label = labels[i]
    pred = preds[i][0]
    data_list.append({'user_idx': user_idx, 'user_hist': user_hist, 'item_idx': item_idx, 'item_desc': item_desc, 'label': label, 'pred': pred})
    
df_restuls = pd.DataFrame(data_list)
df_restuls


Unnamed: 0,user_idx,user_hist,item_idx,item_desc,label,pred
0,4810,Given a male user who is aged 18-24 and a prog...,3489,Hook is a 1991 fantasy adventure film directed...,1,0.185548
1,4810,Given a male user who is aged 18-24 and a prog...,1127,"""The Abyss"" is a 1989 science fiction film dir...",1,0.422078
2,4810,Given a male user who is aged 18-24 and a prog...,2115,Indiana Jones and the Temple of Doom is a 1984...,1,0.355027
3,4810,Given a male user who is aged 18-24 and a prog...,2140,The Dark Crystal is a classic fantasy adventur...,1,0.460649
4,4810,Given a male user who is aged 18-24 and a prog...,2143,Legend is a fantasy adventure film released in...,1,0.305864
5,4810,Given a male user who is aged 18-24 and a prog...,2161,The NeverEnding Story is a fantasy film direct...,1,0.432104
6,4810,Given a male user who is aged 18-24 and a prog...,2167,Blade is a 1998 action/horror film directed by...,1,0.349888
7,4810,Given a male user who is aged 18-24 and a prog...,1527,The Fifth Element is a science-fiction action ...,1,0.417244
8,4810,Given a male user who is aged 18-24 and a prog...,1552,Con Air is an American action-thriller film re...,1,0.199968
9,4810,Given a male user who is aged 18-24 and a prog...,3174,Man on the Moon is a biographical comedy-drama...,1,0.385676


## Encoding the text to vectors with BERT

In [89]:
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
from lm_encoding import inference

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased',  trust_remote_code=True)
encoding_model = AutoModel.from_pretrained('bert-base-uncased',  trust_remote_code=True).half() #.cuda()

### Validating the encoder

In [90]:
items = df_restuls['item_desc'].tolist()
item_loader = DataLoader(items[:4], 1, shuffle=False)

In [91]:
users = df_restuls['user_hist'].tolist()
user_loader = DataLoader(users[:1], 1, shuffle=False)

In [92]:
new_item_vec = inference(encoding_model, tokenizer, item_loader, 'bert', 'avg')
len(new_item_vec)

100%|██████████| 4/4 [06:10<00:00, 92.66s/it] 


4

In [93]:
new_user_vec = inference(encoding_model, tokenizer, user_loader, 'bert', 'avg')
len(new_user_vec)

100%|██████████| 1/1 [02:22<00:00, 142.92s/it]


1

In [94]:
n_item_vec = [i for i in new_item_vec]
len(n_item_vec)

4

In [95]:
for i in range(4):
    a = np.array(item_vec_dict[str(df_restuls['item_idx'][i])])
    b = np.array(new_item_vec[i])
    print(np.linalg.norm(a - b))

0.0038722404570977653
0.003397571428196932
0.004200493061843059
0.003428613497758212


In [96]:
for i in range(1):
    a = np.array(user_vec_dict[str(df_restuls['user_idx'][i])])
    b = np.array(new_user_vec[i])
    print(np.linalg.norm(a - b))

0.003666466025117689


## Creating a Simple Reflexion Mechanism

In [97]:
import datetime
import os
from langchain_groq import ChatGroq

from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_openai import ChatOpenAI
from langsmith import traceable

from collections import defaultdict
from typing import List

from langchain.output_parsers.openai_tools import (
    JsonOutputToolsParser,
    PydanticToolsParser,
)
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.prebuilt.tool_executor import ToolExecutor, ToolInvocation

from dotenv import load_dotenv
load_dotenv(r'C:\Projects\TAU\DeepLearning\Relfexion_explore\.env')


True

In [98]:
parser = JsonOutputToolsParser(return_id=True)

In [99]:
actor_prompt_template = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are expert researcher.
            Current time: {time}

            1. {first_instruction}
            2. Reflect and critique your answer. Be severe to maximize improvement.
            3. Recommend search queries to research information and improve your answer.""",
        ),
            MessagesPlaceholder(variable_name="messages"),
            ("system", "Answer the user's question above using the required format."),
        ]
).partial(
    time=lambda: datetime.datetime.now().isoformat(),
)


class Reflection(BaseModel):
    missing: str = Field(description="Critique of what is missing.")
    superfluous: str = Field(description="Critique of what is superfluous")


class AnswerQuestion(BaseModel):
    """Answer the question."""

    answer: str = Field(description="~250 word detailed answer to the question.")
    reflection: Reflection = Field(description="Your reflection on the initial answer.")
    search_queries: List[str] = Field(
        description="1-3 search queries for researching improvements to address the critique of your current answer."
    )

In [100]:
llm = ChatGroq(groq_api_key = os.getenv('GROQ_API_KEY'),model = 'llama3-8b-8192')
# llm = ChatOpenAI(api_key=os.getenv('OPENAI_API_KEY'), model='gpt-4-turbo-2024-04-09')

initial_answer_chain = actor_prompt_template.partial(
    first_instruction="Provide a detailed answer."
) | llm.bind_tools(tools=[AnswerQuestion], tool_choice="AnswerQuestion")
validator = PydanticToolsParser(tools=[AnswerQuestion])

In [101]:
class ResponderWithRetries:
    def __init__(self, runnable, validator):
        self.runnable = runnable
        self.validator = validator

    @traceable
    def respond(self, state: List[BaseMessage]):
        response = []
        for attempt in range(3):
            try:
                response = self.runnable.invoke({"messages": state})
                self.validator.invoke(response)
                return response
            except ValidationError as e:
                state = state + [HumanMessage(content=repr(e))]
        return response
    
first_responder = ResponderWithRetries(
    runnable=initial_answer_chain, validator=validator
)

In [102]:
df_restuls.iloc[0]['user_hist']

"Given a male user who is aged 18-24 and a programmer, this user's movie viewing history over time is listed below. Honey, I Shrunk the Kids (1989), 1 stars; Negotiator, The (1998), 4 stars; Terminator 2: Judgment Day (1991), 3 stars; Jumanji (1995), 2 stars; Heat (1995), 5 stars; Insider, The (1999), 5 stars; Last of the Mohicans, The (1992), 3 stars; Michael (1996), 3 stars; Batman (1989), 3 stars; Jackie Chan's First Strike (1996), 4 stars; Rocketeer, The (1991), 3 stars; Maltese Falcon, The (1941), 4 stars; Messenger: The Story of Joan of Arc, The (1999), 3 stars; 2001: A Space Odyssey (1968), 3 stars; Jewel of the Nile, The (1985), 2 stars; Romancing the Stone (1984), 3 stars; Saint, The (1997), 5 stars; Sleepy Hollow (1999), 2 stars; World Is Not Enough, The (1999), 2 stars; Anaconda (1997), 1 stars; Titanic (1953), 3 stars; Waterworld (1995), 2 stars; Lost in Space (1998), 3 stars; Teenage Mutant Ninja Turtles (1990), 1 stars; Payback (1999), 3 stars; Red Dawn (1984), 3 stars; R

In [103]:
def generate_prompt_from_df(df_restuls,idx):
    user_hist = df_restuls['user_hist'][idx]
    item_desc = df_restuls['item_desc'][idx]
    real_label = df_restuls['label'][idx]
    pred = df_restuls['pred'][idx]
    prompt = f"""   User {df_restuls['user_idx'][idx]} has the following history: {user_hist}.
                    Item {df_restuls['item_idx'][idx]} has the following description: {item_desc}.
                    A mistake was made in the model's prediction.
                    The real label is {real_label} and the model predicted {pred}.
                    How would you rephrase the user history and item description to improve the model's prediction?
                    use two lines at the end of user history and at the end of item description.
                    Remember that you new rephrasing should help the model to predict better next time but without overfitting"""
    return prompt

In [104]:
def generate_prompt_from_df(df_restuls,idx):
    user_hist = df_restuls['user_hist'][idx]
    item_desc = df_restuls['item_desc'][idx]
    real_label = df_restuls['label'][idx]
    pred = df_restuls['pred'][idx]

    prompt =    f"""Task:Rephrase the user history and item description to better match the prediction to the actual label. Assume that the model's prediction is either a match or a mismatch to the label, and modify the descriptions to improve the fit between user preferences (as inferred from the user history) and the characteristics of the movie (as described in the item description).

                Output Format:
                Revised User History:
                Try to rephrase the user history to better reflect the user's preferences. 
                You can adjust the length, tone, and content of the user history to better align with the user's likely interests.
                Finish the Revised User History with ;
                Revised Item Description:
                The item description should be rephrased to highlight aspects of the movie that are more aligned with the user's adjusted preferences.
                Key elements to focus on might include genre, notable performances, thematic elements, and any particular production features.
                Finish the Revised User Item Description with ;

                Here is the user history: {user_hist}.
                Here is the item description: {item_desc}.
                A mistake was made in the model's prediction.
                The real label is {real_label} and the model predicted {pred}.
                How would you rephrase the user history and item description to improve the model's prediction?

                """
    return prompt

In [105]:
example_question = generate_prompt_from_df(df_restuls,0)

In [106]:
example_question

"Task:Rephrase the user history and item description to better match the prediction to the actual label. Assume that the model's prediction is either a match or a mismatch to the label, and modify the descriptions to improve the fit between user preferences (as inferred from the user history) and the characteristics of the movie (as described in the item description).\n\n                Output Format:\n                Revised User History:\n                Try to rephrase the user history to better reflect the user's preferences. \n                You can adjust the length, tone, and content of the user history to better align with the user's likely interests.\n                Finish the Revised User History with ;\n                Revised Item Description:\n                The item description should be rephrased to highlight aspects of the movie that are more aligned with the user's adjusted preferences.\n                Key elements to focus on might include genre, notable performan

In [107]:
system = "You are a helpful assistant."
human = "{text}"
prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])

chain = prompt | llm
initial = chain.invoke({"text": example_question})
initial.content

"Revised User History:\nAs a young programmer with a passion for action and adventure movies, I've developed a taste for high-octane films that transport me to new worlds. My movie history reveals a fondness for 80s and 90s classics, particularly those with memorable heroes, thrilling plots, and impressive production values. I appreciate strong performances, intricate costumes, and memorable soundtracks. Despite some weaker efforts, I crave movies that deliver on action, suspense, and emotional resonance. As a fan of fantasy and adventure, I'm always on the lookout for films that spark my imagination and keep me on the edge of my seat;\n\nRevised Item Description:\nHook is a 1991 fantasy adventure film that whisks viewers away to the magical world of Neverland. Directed by the legendary Steven Spielberg, this beloved classic follows Peter Pan, a successful but unhappy corporate lawyer who must reclaim his lost identity as a hero to save his family from the clutches of the villainous Ca

In [108]:
import re
revised_prompt = initial.content + ";"
# String containing the revised User History and Revised Item Description
# Extracting the revised User History
user_history = re.search(r"User History:\n(.*?)\n\n", revised_prompt, re.DOTALL)
if user_history:
    user_history = user_history.group(1).strip()

# Extracting the Revised Item Description
item_description = re.search(r"Item Description:\n(.*?)\n\n", revised_prompt, re.DOTALL)
if item_description:
    item_description = item_description.group(1).strip()

# Printing the extracted information
print("Revised User History:", user_history)
print("Revised Item Description:", item_description)


Revised User History: As a young programmer with a passion for action and adventure movies, I've developed a taste for high-octane films that transport me to new worlds. My movie history reveals a fondness for 80s and 90s classics, particularly those with memorable heroes, thrilling plots, and impressive production values. I appreciate strong performances, intricate costumes, and memorable soundtracks. Despite some weaker efforts, I crave movies that deliver on action, suspense, and emotional resonance. As a fan of fantasy and adventure, I'm always on the lookout for films that spark my imagination and keep me on the edge of my seat;
Revised Item Description: Hook is a 1991 fantasy adventure film that whisks viewers away to the magical world of Neverland. Directed by the legendary Steven Spielberg, this beloved classic follows Peter Pan, a successful but unhappy corporate lawyer who must reclaim his lost identity as a hero to save his family from the clutches of the villainous Captain 

In [109]:
# user_hist = "This user is a young adult male who enjoys action-packed movies from the 80s and 90s. He has a penchant for films with strong male protagonists and high production quality. He tends to dislike movies with weak plot or character development, romance-heavy films, and those with poor visual effects. His favorite genres include action, adventure, and fantasy, with a soft spot for movies with historical or fantastical elements."
# item_desc = "In this beloved fantasy adventure film, a grown-up Peter Pan must rediscover his forgotten past as the boy who never grew up. With stunning visual effects, intricate costume designs, and a memorable soundtrack, Hook is a timeless classic that has captivated audiences worldwide. The movie features an all-star cast, including Robin Williams, Dustin Hoffman, and Julia Roberts, and boasts impressive production values. While it received mixed reviews upon release, Hook remains a beloved classic that explores themes of identity, nostalgia, and the power of imagination."
data_l = DataLoader([user_history,item_description],21, shuffle=False)
new_llm_vec = inference(encoding_model, tokenizer, data_l, 'bert', 'avg')

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [01:12<00:00, 72.59s/it]


In [110]:
df_restuls

Unnamed: 0,user_idx,user_hist,item_idx,item_desc,label,pred
0,4810,Given a male user who is aged 18-24 and a prog...,3489,Hook is a 1991 fantasy adventure film directed...,1,0.185548
1,4810,Given a male user who is aged 18-24 and a prog...,1127,"""The Abyss"" is a 1989 science fiction film dir...",1,0.422078
2,4810,Given a male user who is aged 18-24 and a prog...,2115,Indiana Jones and the Temple of Doom is a 1984...,1,0.355027
3,4810,Given a male user who is aged 18-24 and a prog...,2140,The Dark Crystal is a classic fantasy adventur...,1,0.460649
4,4810,Given a male user who is aged 18-24 and a prog...,2143,Legend is a fantasy adventure film released in...,1,0.305864
5,4810,Given a male user who is aged 18-24 and a prog...,2161,The NeverEnding Story is a fantasy film direct...,1,0.432104
6,4810,Given a male user who is aged 18-24 and a prog...,2167,Blade is a 1998 action/horror film directed by...,1,0.349888
7,4810,Given a male user who is aged 18-24 and a prog...,1527,The Fifth Element is a science-fiction action ...,1,0.417244
8,4810,Given a male user who is aged 18-24 and a prog...,1552,Con Air is an American action-thriller film re...,1,0.199968
9,4810,Given a male user who is aged 18-24 and a prog...,3174,Man on the Moon is a biographical comedy-drama...,1,0.385676


In [111]:
l = []
cnt= 0
for item in test_set:
    l.append(item)
    cnt+=1
    if cnt>200:
        break
test_dataset = DictDataset(l)

In [112]:
test_dataset[4]['hist_aug_vec'][-10:]

tensor([-0.2448, -0.0048, -0.0591, -0.1279, -0.2581, -0.2496, -0.0857,  0.0883,
         0.3794,  0.0489])

In [113]:
test_dataset[4]['hist_aug_vec'] = 0
test_dataset[4]['hist_aug_vec'] = torch.tensor(new_llm_vec[0])
test_dataset[4]['item_aug_vec'] = 0
test_dataset[4]['item_aug_vec'] = torch.tensor(new_llm_vec[1])

In [114]:
test_dataset[4]['hist_aug_vec'][-10:]

tensor([-0.3096, -0.2128,  0.0754, -0.1779, -0.0845, -0.3159,  0.3977, -0.1024,
         0.0948, -0.1135])

In [115]:
test_loader = Data.DataLoader(dataset=test_dataset, batch_size=2, shuffle=False)
short_test = ShortDataLoader(test_loader, num_batches=100)
auc, ll, loss, eval_time, labels, preds = eval(inference_model, short_test)

In [116]:
for batch,data in enumerate(short_test):
    if batch==2:
        print(data)
        break

{'iid': tensor([398, 246]), 'aid': tensor([[2],
        [4]]), 'lb': tensor([1, 0]), 'hist_iid_seq': tensor([[394, 395, 173, 500, 447],
        [395, 173, 500, 447, 398]]), 'hist_aid_seq': tensor([[[5],
         [4],
         [4],
         [3],
         [4]],

        [[4],
         [4],
         [3],
         [4],
         [2]]]), 'hist_rate_seq': tensor([[4, 3, 4, 5, 1],
        [3, 4, 5, 1, 5]]), 'hist_seq_len': tensor([5, 5]), 'item_aug_vec': tensor([[-0.0222,  0.1312,  0.3264,  ...,  0.0266,  0.2008, -0.0326],
        [-0.1589,  0.1256,  0.0196,  ..., -0.0551,  0.1310, -0.0424]]), 'hist_aug_vec': tensor([[ 0.0822,  0.0660,  0.2781,  ..., -0.1024,  0.0948, -0.1135],
        [-0.4312, -0.0414,  0.1912,  ...,  0.0883,  0.3794,  0.0489]])}


In [117]:
labels[:5], preds[:5]

([0, 1, 1, 0, 1],
 [[0.16122116148471832],
  [0.7664511799812317],
  [0.6553716659545898],
  [0.09310935437679291],
  [0.46843433380126953]])