In [106]:
from utils.dataloader import DataLoader as myDataLoader
import torch
from torch.utils.data import DataLoader as torchDataLoader
from torch.utils.data import TensorDataset

import pandas as pd
import numpy as np
import networkx as nx
import torch.nn as nn

In [32]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Params settings

In [69]:
class Settings():
    batch_size = 64
    epochs = 20

    embedding_size = 64
    learning_rate = 0.003
    
    # 100k dataset
    num_users = 943
    num_items = 1682

    # Transformer encoder
    dropout_rate = 0
    num_heads = 4
    d_ff = 4
    num_blocks = 2


    negative_num = 99
    # checkpoint_path_user_task = './Checkpoint/user_task/'
    # checkpoint_path_item_task = './Checkpoint/item_task/'
    verbose = 1

    hidden_dim = 256
    user_epoch = 5
    item_epoch = 25

    second_user_epoch = 10
    second_item_epoch = 10

    third_user_epoch = 10
    third_item_epoch = 10

    train_user_dataset = './models/gnn_embedding/ml_gnn_ebd/initial_user_ebds.csv'
    train_item_dataset = './models/gnn_embedding/ml_gnn_ebd/initial_item_ebds.csv'
    valid_user_dataset = './models/gnn_embedding/ml_gnn_ebd/target_user_ebds.csv'
    valid_item_dataset = './models/gnn_embedding/ml_gnn_ebd/target_item_ebds.csv'

    dataset_size = '100k'

    # set device
    if torch.cuda.is_available():
        print("Using CUDA (Nvidia GPU)")
        device = torch.device('cuda')
    else:
        print("CUDA not available, using CPU")
        device = torch.device('cpu')


settings = Settings()

CUDA not available, using CPU


## Data loading and searching for 1st 2nd 3rd order neighbours

In [None]:
# load the target USER embedding
# ['userid', 'embedding']
initial_user_embedding_path = "./models/gnn_embedding/ml_gnn_ebd/initial_user_ebds.csv"
initial_item_embedding_path = "./models/gnn_embedding/ml_gnn_ebd/initial_item_ebds.csv"
target_user_embedding_path = "./models/gnn_embedding/ml_gnn_ebd/target_user_ebds.csv"
target_item_embedding_path = "./models/gnn_embedding/ml_gnn_ebd/target_item_ebds.csv"

initial_user_embedding_df = pd.read_csv(initial_user_embedding_path)
initial_item_embedding_df = pd.read_csv(initial_item_embedding_path)  
target_user_embedding_df = pd.read_csv(target_user_embedding_path)
target_item_embedding_df = pd.read_csv(target_item_embedding_path)

# load the rating 
# ['userId', 'movieId', 'rating', 'timestamp']
# ratings_df = pd.read_csv("./data/ml-100k/u.data", sep="\t",header=None, names=["user", "item", "rating", "timestamp"])
from utils.dataloader import DataLoader
from utils.data_split import train_test_split
movie_data = DataLoader(size="100k")
data = movie_data.load_ratings()
ratings_df, test_set = train_test_split(data)

user_ids = list(initial_user_embedding_df['user'].unique())
item_ids = list(initial_item_embedding_df['item'].unique())
ratings_df = ratings_df[ratings_df['user'].isin(user_ids) & ratings_df['item'].isin(item_ids)]


------

In [51]:
def build_user_item_graph(df):
    G = nx.Graph()
    
    for _, row in df.iterrows():
        user_node = f'u_{row["user"]}'
        item_node = f'i_{row["item"]}'
        G.add_edge(user_node, item_node)
    
    return G

G = build_user_item_graph(ratings_df)
print(G.number_of_nodes())
print(G.number_of_edges())

1275
55275


In [None]:
def get_neighbors(graph, node):
    first_order = list(graph.neighbors(node))
    
    # 2nd order neighbors
    second_order = []
    for first_neighbor in first_order:
        second_order.extend(list(graph.neighbors(first_neighbor)))
    second_order = list(set(second_order) - set(first_order) - {node})
    
    # 3rd order neighbors
    third_order = []
    for second_neighbor in second_order:
        third_order.extend(list(graph.neighbors(second_neighbor)))
    third_order = list(set(third_order) - set(first_order) - set(second_order) - {node})
    
    first_order = [int(n.split('_')[1]) for n in first_order]
    second_order = [int(n.split('_')[1]) for n in second_order]
    third_order = [int(n.split('_')[1]) for n in third_order]
    
    return [first_order, second_order, third_order]


def compute_user_neighbors(user_graph, target_user_ids, target_embeddings_df):
    data = []
    for user_id in target_user_ids:
        neighbors = get_neighbors(user_graph, f"u_{user_id}")
        embedding = target_embeddings_df.loc[target_embeddings_df['user'] == user_id, 'embedding'].values[0]
        data.append({
            'userid': user_id,
            '1st_order': neighbors[0],
            '2nd_order': neighbors[1],
            '3rd_order': neighbors[2],
            'oracle_embedding': embedding
        })

    return pd.DataFrame(data)

def compute_item_neighbors(item_graph, target_item_ids, target_embeddings_df):
    data = []
    for item_id in target_item_ids:
        neighbors = get_neighbors(item_graph, f"i_{item_id}")
        embedding = target_embeddings_df.loc[target_embeddings_df['item'] == item_id, 'embedding'].values[0]
        data.append({
            'itemid': item_id,
            '1st_order': neighbors[0],
            '2nd_order': neighbors[1],
            '3rd_order': neighbors[2],
            'oracle_embedding': embedding
        })

    return pd.DataFrame(data)

In [None]:

target_user_input_df = compute_user_neighbors(G, user_ids, target_user_embedding_df)
target_item_input_df = compute_item_neighbors(G, item_ids, target_item_embedding_df)

target_user_input_df.to_csv("./models/gnn_embedding/ml_gnn_ebd/gnn_user_input.csv", index=False)
target_item_input_df.to_csv("./models/gnn_embedding/ml_gnn_ebd/gnn_item_input.csv", index=False)


## Train First Embedding with 1rd, 2nd, 3rd order user item relationship

In [90]:
from models.gnn_embedding.GeneralGNN import GeneralGNN
from models.gnn_embedding.train_helper import train_first_order_task, train_second_order_task, train_third_order_task

In [78]:

init_user_embedding_path="./models/gnn_embedding/ml_gnn_ebd/initial_user_ebds.csv"
init_item_embedding_path="./models/gnn_embedding/ml_gnn_ebd/initial_item_ebds.csv"

model = GeneralGNN(name="GraphSAGE", settings=settings, init_user_embedding_path=init_user_embedding_path,init_item_embedding_path=init_item_embedding_path)

In [112]:
from ast import literal_eval

def csv_to_dataloader(path, task='user'):

    df = pd.read_csv(path)
    column_specs = {
        "id": torch.long,
        "1st_order": torch.long,         # 填充后的 1 阶邻居列表
        "2nd_order": torch.long,         # 填充后的 2 阶邻居列表
        "3rd_order": torch.long,         # 填充后的 3 阶邻居列表
        "oracle_embedding": torch.float32,  # 目标嵌入向量
    }

    tensor_list = []
    for col, dtype in column_specs.items():
        if col == "id":
            if task =='user':
                tensor = torch.tensor(df["userid"].values, dtype=dtype)
            else:
                tensor = torch.tensor(df["itemid"].values, dtype=dtype)
        elif col in ["1st_order", "2nd_order", "3rd_order"]:
            tensor = torch.tensor(df[col].apply(literal_eval).tolist(), dtype=dtype)
        elif col == "oracle_embedding":
            tensor = torch.tensor(df[col].apply(literal_eval).tolist(), dtype=dtype)
        else:
            raise ValueError(f"未支持的列: {col}")
        
        tensor_list.append(tensor)

    dataset = TensorDataset(*tensor_list)

    return dataset

In [113]:



train_user_dataset = csv_to_dataloader("./models/gnn_embedding/ml_gnn_ebd/gnn_user_input.csv","user")
train_item_dataset = csv_to_dataloader("./models/gnn_embedding/ml_gnn_ebd/gnn_item_input.csv","item")
valid_user_dataset = train_user_dataset.copy()
valid_item_dataset = train_item_dataset.copy()

# Create DataLoaders
train_user_loader = torchDataLoader(train_user_dataset, batch_size=settings.batch_size, shuffle=True)
valid_user_loader = torchDataLoader(valid_user_dataset, batch_size=settings.batch_size, shuffle=False)

train_item_loader = torchDataLoader(train_item_dataset, batch_size=settings.batch_size, shuffle=True)
valid_item_loader = torchDataLoader(valid_item_dataset, batch_size=settings.batch_size, shuffle=False)

ValueError: expected sequence of length 344 at dim 1 (got 373)

In [105]:
for batch in train_user_loader:
    target_ids, support_1st, support_2nd, support_3rd, oracle_embeddings = batch
    print(support_1st)


RuntimeError: stack expects each tensor to be equal size, but got [92] at entry 0 and [94] at entry 1

In [None]:
# Train for user tasks
num_epochs = settings.epochs
device = settings.device

print("Training user tasks...")
print(" -> Training 1st order user tasks...")
train_first_order_task(
    model=model,
    train_loader=train_user_loader,
    valid_loader=valid_user_loader,
    epochs=num_epochs,
    device=device,
    task="user",
)

print(" -> Training 2nd order user tasks...")
train_second_order_task(
    model=model,
    train_loader=train_user_loader,
    valid_loader=valid_user_loader,
    epochs=num_epochs,
    device=device,
    task="user",
)

print(" -> Training 3rd order user tasks...")
train_third_order_task(
    model=model,
    train_loader=train_user_loader,
    valid_loader=valid_user_loader,
    epochs=num_epochs,
    device=device,
    task="user",
)

# Train for item tasks
print("Training item tasks...")
print(" -> Training 1st order item tasks...")
train_first_order_task(
    model=model,
    train_loader=train_item_loader,
    valid_loader=valid_item_loader,
    epochs=num_epochs,
    device=device,
    task="item",
)

print(" -> Training 2nd order item tasks...")
train_second_order_task(
    model=model,
    train_loader=train_item_loader,
    valid_loader=valid_item_loader,
    epochs=num_epochs,
    device=device,
    task="item",
)

print(" -> Training 3rd order item tasks...")
train_third_order_task(
    model=model,
    train_loader=train_item_loader,
    valid_loader=valid_item_loader,
    epochs=num_epochs,
    device=device,
    task="item",
)

print("Training completed.")

Training user tasks...
 -> Training 1st order user tasks...


ValueError: malformed node or string: [274, 126, 198, 670, 736, 53, 66, 52, 224, 344, 14, 179, 945, 185, 347, 550, 303, 58, 51, 727, 744, 509, 959, 685, 476, 27, 517, 628, 707, 280, 576, 252, 134, 159, 1016, 208, 1008, 77, 792, 284, 936, 481, 672, 213, 1017, 728, 97, 117, 271, 1067, 572, 1197, 393, 321, 291, 286, 183, 923, 25, 402, 20, 172, 875, 129, 330, 38, 187, 60, 228, 610, 629, 717, 778, 739, 806, 880, 433, 64, 740, 239, 92, 86, 354, 44, 234, 529, 294, 1221, 293, 167, 346, 654, 475, 319, 772, 19, 326, 612, 979, 50, 547, 694, 831, 216, 161, 221, 604, 518, 684, 800, 190, 345, 385, 81, 306, 87, 324, 249, 750, 165, 257, 387, 215, 265, 207, 152, 302, 919, 275, 89, 770, 921, 188, 405, 887, 974, 425, 248, 1, 483, 396, 195, 65, 332, 1170, 662, 125, 258, 813, 371, 855, 311, 582, 882, 845, 328, 282, 403, 181, 203, 212, 591, 135, 724, 729, 233, 270, 325, 789, 55, 209, 520, 655, 287, 238, 251, 1267, 644, 301, 79, 171, 356, 709, 250, 781, 28, 642, 214, 939, 9, 204, 5, 458, 11, 467, 240, 636, 1022, 428, 578, 178, 1014, 735, 423, 511, 315, 211, 276, 896, 751, 796, 144, 242, 581, 82, 186, 340, 128, 255, 23, 900, 47, 312, 1063, 902, 118, 566, 658, 942, 31, 657, 874, 514, 411, 742, 12, 127, 640, 160, 785, 262, 761, 723, 469, 1101, 313, 200, 660, 1135, 464, 121, 246, 963, 223, 220, 741, 218, 650, 7, 219, 503, 995, 568, 584, 638, 480, 686, 731, 722, 504, 676, 462, 451, 76, 43, 844, 15, 176, 775, 674, 607, 295, 304, 1018, 699, 132, 435, 507, 191, 502, 226, 972, 327, 712, 1044, 59, 479, 378, 448, 515, 1090, 197, 1136, 639, 944, 471, 372, 111, 357, 603, 137, 298, 1118, 196, 559, 716, 22, 762, 57, 222, 1011, 673, 847, 961, 956, 153, 45, 528, 382, 367, 690, 1142, 815, 307, 410, 192, 4, 155, 558, 143, 21, 69, 692, 872]

ValueError: malformed node or string: [274, 126, 198, 670, 736, 53, 66, 52, 224, 344, 14, 179, 945, 185, 347, 550, 303, 58, 51, 727, 744, 509, 959, 685, 476, 27, 517, 628, 707, 280, 576, 252, 134, 159, 1016, 208, 1008, 77, 792, 284, 936, 481, 672, 213, 1017, 728, 97, 117, 271, 1067, 572, 1197, 393, 321, 291, 286, 183, 923, 25, 402, 20, 172, 875, 129, 330, 38, 187, 60, 228, 610, 629, 717, 778, 739, 806, 880, 433, 64, 740, 239, 92, 86, 354, 44, 234, 529, 294, 1221, 293, 167, 346, 654, 475, 319, 772, 19, 326, 612, 979, 50, 547, 694, 831, 216, 161, 221, 604, 518, 684, 800, 190, 345, 385, 81, 306, 87, 324, 249, 750, 165, 257, 387, 215, 265, 207, 152, 302, 919, 275, 89, 770, 921, 188, 405, 887, 974, 425, 248, 1, 483, 396, 195, 65, 332, 1170, 662, 125, 258, 813, 371, 855, 311, 582, 882, 845, 328, 282, 403, 181, 203, 212, 591, 135, 724, 729, 233, 270, 325, 789, 55, 209, 520, 655, 287, 238, 251, 1267, 644, 301, 79, 171, 356, 709, 250, 781, 28, 642, 214, 939, 9, 204, 5, 458, 11, 467, 240, 636, 1022, 428, 578, 178, 1014, 735, 423, 511, 315, 211, 276, 896, 751, 796, 144, 242, 581, 82, 186, 340, 128, 255, 23, 900, 47, 312, 1063, 902, 118, 566, 658, 942, 31, 657, 874, 514, 411, 742, 12, 127, 640, 160, 785, 262, 761, 723, 469, 1101, 313, 200, 660, 1135, 464, 121, 246, 963, 223, 220, 741, 218, 650, 7, 219, 503, 995, 568, 584, 638, 480, 686, 731, 722, 504, 676, 462, 451, 76, 43, 844, 15, 176, 775, 674, 607, 295, 304, 1018, 699, 132, 435, 507, 191, 502, 226, 972, 327, 712, 1044, 59, 479, 378, 448, 515, 1090, 197, 1136, 639, 944, 471, 372, 111, 357, 603, 137, 298, 1118, 196, 559, 716, 22, 762, 57, 222, 1011, 673, 847, 961, 956, 153, 45, 528, 382, 367, 690, 1142, 815, 307, 410, 192, 4, 155, 558, 143, 21, 69, 692, 872]


## Movie Data aggragation

# Inference Part

In [17]:
full_user_embedding_init_path = "./models/gnn_embedding/ml_gnn_ebd/full_user_init_ebds.csv"
full_item_embedding_init_path = "./models/gnn_embedding/ml_gnn_ebd/full_user_item_ebds.csv"

full_user_init_embedding = pd.read_csv(full_user_embedding_init_path)
full_item_init_embedding = pd.read_csv(full_item_embedding_init_path)  

In [21]:
full_user_init_embedding.head(),full_item_init_embedding.head()

(   user                                          embedding
 0   196  [-0.013786688446998596, 0.03217056393623352, -...
 1   186  [0.05683184415102005, 0.03737185150384903, 0.0...
 2    22  [-0.05237126722931862, -0.026388995349407196, ...
 3   244  [-0.04823482781648636, -0.06648404896259308, 0...
 4   166  [-0.05344872921705246, -0.0514201745390892, 0....,
    item                                          embedding
 0   242  [0.030565787106752396, 0.025797907263040543, -...
 1   302  [-0.04676782712340355, -0.029111390933394432, ...
 2   377  [0.03179219737648964, 0.026234377175569534, -0...
 3    51  [-0.04509979113936424, 0.054844241589307785, -...
 4   346  [0.03948422148823738, 0.024974819272756577, 0....)

In [23]:
full_user_init_embedding['embedding']

0      [-0.013786688446998596, 0.03217056393623352, -...
1      [0.05683184415102005, 0.03737185150384903, 0.0...
2      [-0.05237126722931862, -0.026388995349407196, ...
3      [-0.04823482781648636, -0.06648404896259308, 0...
4      [-0.05344872921705246, -0.0514201745390892, 0....
                             ...                        
938    [0.037584856152534485, 0.060177914798259735, -...
939    [0.040449537336826324, -0.03403152897953987, 0...
940    [-0.02798338606953621, 0.07018252462148666, -0...
941    [0.06474433094263077, 0.003956958651542664, -0...
942    [0.06258333474397659, -0.021894250065088272, -...
Name: embedding, Length: 943, dtype: object

In [56]:
full_user_ids = ratings_df['user'].unique()
full_item_ids = ratings_df['item'].unique()

In [57]:
full_user_graph = nx.Graph()
full_user_graph.add_nodes_from(full_user_ids, bipartite="user")
full_user_graph.add_nodes_from(ratings_df['item'].unique(), bipartite="item")
full_user_graph.add_edges_from(zip(ratings_df['user'], ratings_df['item']))

full_item_graph = nx.Graph()
full_item_graph.add_nodes_from(full_item_ids, bipartite="item")
full_item_graph.add_nodes_from(ratings_df['user'].unique(), bipartite="user")
full_item_graph.add_edges_from(zip(ratings_df['item'], ratings_df['user']))


print("Nodes in user_graph:", full_user_graph.nodes())
print("Edges in user_graph:", full_user_graph.edges())
print("Nodes in item_graph:", full_item_graph.nodes())
print("Edges in item_graph:", full_item_graph.edges())

Nodes in user_graph: [196, 186, 22, 244, 166, 298, 115, 253, 305, 6, 62, 286, 200, 210, 224, 303, 122, 194, 291, 234, 119, 167, 299, 308, 95, 38, 102, 63, 160, 50, 301, 225, 290, 97, 157, 181, 278, 276, 7, 10, 284, 201, 287, 246, 242, 249, 99, 178, 251, 81, 260, 25, 59, 72, 87, 42, 292, 20, 13, 138, 60, 57, 223, 189, 243, 92, 241, 254, 293, 127, 222, 267, 11, 8, 162, 279, 145, 28, 135, 32, 90, 216, 250, 271, 265, 198, 168, 110, 58, 237, 94, 128, 44, 264, 41, 82, 262, 174, 43, 84, 269, 259, 85, 213, 121, 49, 155, 68, 172, 19, 268, 5, 80, 66, 18, 26, 130, 256, 1, 56, 15, 207, 232, 52, 161, 148, 125, 83, 272, 151, 54, 16, 91, 294, 229, 36, 70, 14, 295, 233, 214, 192, 100, 307, 297, 193, 113, 275, 219, 218, 123, 158, 302, 23, 296, 33, 154, 77, 270, 187, 170, 101, 184, 112, 133, 215, 69, 104, 240, 144, 191, 61, 142, 177, 203, 21, 197, 134, 180, 236, 263, 109, 64, 114, 239, 117, 65, 137, 257, 111, 285, 96, 116, 73, 221, 235, 164, 281, 182, 129, 45, 131, 230, 126, 231, 280, 288, 152, 217, 79,

In [58]:
len(full_user_ids),len(full_item_ids)

(943, 1682)

In [59]:
inf_user_input_df = compute_user_neighbors(full_user_graph, full_user_ids, full_user_init_embedding)
inf_item_input_df = compute_item_neighbors(full_item_graph, full_item_ids, full_item_init_embedding)

In [78]:
inf_user_input_df.head(),inf_item_input_df.head()

(   userid                                          1st_order  \
 0     196  [1, 514, 7, 8, 11, 12, 13, 18, 530, 533, 23, 5...   
 1     186  [1, 5, 6, 7, 10, 12, 13, 14, 18, 20, 22, 25, 3...   
 2      22  [1, 2, 4, 6, 7, 8, 10, 11, 13, 14, 16, 17, 18,...   
 3     244  [1, 3, 1028, 7, 521, 9, 13, 15, 1039, 17, 527,...   
 4     166  [1, 6, 7, 523, 13, 18, 535, 557, 565, 566, 60,...   
 
                                            2nd_order  \
 0  [1, 2, 3, 5, 6, 7, 8, 10, 11, 13, 14, 15, 16, ...   
 1  [1, 2, 3, 5, 6, 7, 8, 10, 11, 13, 14, 15, 16, ...   
 2  [1, 2, 3, 5, 6, 7, 8, 10, 11, 13, 14, 15, 16, ...   
 3  [1, 2, 3, 5, 6, 7, 8, 10, 11, 13, 14, 15, 16, ...   
 4  [1, 2, 3, 5, 6, 7, 8, 10, 11, 13, 14, 15, 16, ...   
 
                                            3rd_order  \
 0  [1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 1...   
 1  [1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 1...   
 2  [1, 2, 3, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 1...   
 3  [1, 2, 3, 4, 5, 7, 8, 9, 10, 11,

In [61]:
from utils.dataloader import DataLoader
from utils.data_split import train_test_split

movie_data = DataLoader(size="100k")
data = movie_data.load_ratings()
train_list, test_list = train_test_split(data)
ratings = pd.concat([train_list, test_list], axis=0, ignore_index=True)

user_list = ratings['user'].unique().tolist()
item_list = ratings['item'].unique().tolist()

user2idx = {user: idx for idx, user in enumerate(user_list)}
idx2user = {idx: user for user, idx in user2idx.items()}

item2idx = {item:idx for idx, item in enumerate(item_list)}
idx2item = {idx: item for item, idx in item2idx.items()}

In [62]:
import pandas as pd

def reorder_dataframe(df, user2idx,column):
    """
    Reorder a DataFrame based on the user2idx mapping.

    Parameters:
        df (pd.DataFrame): DataFrame with columns ['user', 'embedding'].
        user2idx (dict): Dictionary mapping users to their new indices.

    Returns:
        pd.DataFrame: Reordered DataFrame with new indices based on user2idx.
    """
    mapping = user2idx if column == 'user' else item2idx
    # Shuffle the rows of the DataFrame
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    
    # Map the 'user' column to the new index using user2idx
    df['new_index'] = df[column].map(mapping)
    
    # Sort the DataFrame by the new index
    df = df.sort_values(by='new_index').set_index('new_index')
    
    # Drop the 'new_index' column if you want only ['user', 'embedding']
    df = df[[column, 'embedding']]
    
    return df


In [63]:
full_user_init_embedding

Unnamed: 0,user,embedding
0,196,"[-0.013786688446998596, 0.03217056393623352, -..."
1,186,"[0.05683184415102005, 0.03737185150384903, 0.0..."
2,22,"[-0.05237126722931862, -0.026388995349407196, ..."
3,244,"[-0.04823482781648636, -0.06648404896259308, 0..."
4,166,"[-0.05344872921705246, -0.0514201745390892, 0...."
...,...,...
938,939,"[0.037584856152534485, 0.060177914798259735, -..."
939,936,"[0.040449537336826324, -0.03403152897953987, 0..."
940,930,"[-0.02798338606953621, 0.07018252462148666, -0..."
941,920,"[0.06474433094263077, 0.003956958651542664, -0..."


In [65]:
reordered_user_df = reorder_dataframe(full_user_init_embedding,user2idx,'user')
reordered_item_df = reorder_dataframe(full_item_init_embedding,user2idx,'item')

In [64]:
import os

output_folder = "ml_gnn_ebd"
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

reordered_user_df.to_csv(os.path.join(output_folder, "full_initial_user_ebds.csv"), index=False)
reordered_item_df.to_csv(os.path.join(output_folder, "full_initial_item_ebds.csv"), index=False)

In [69]:
import ast
# Assuming your DataFrame is named `df` and has columns "user" and "embedding" for users and "item" and "embedding" for items

# Extract user embeddings and convert to tensor
user_embeddings = torch.tensor(reordered_user_df['embedding'].apply(ast.literal_eval).tolist(), dtype=torch.float32)

# Save the user embeddings tensor
# torch.save(user_embeddings, 'full_user_embeddings.pt')
# print("User embeddings saved to 'user_embeddings.pt'")


item_embeddings = torch.tensor(reordered_item_df['embedding'].apply(ast.literal_eval).tolist(), dtype=torch.float32)

# Save the item embeddings tensor
# torch.save(item_embeddings, 'full_item_embeddings.pt')
# print("Item embeddings saved to 'item_embeddings.pt'")


In [72]:
reordered_user_df

Unnamed: 0_level_0,user,embedding
new_index,Unnamed: 1_level_1,Unnamed: 2_level_1
0,97,"[0.013621360063552856, 0.006962865591049194, -..."
1,266,"[0.04026229679584503, -0.03364080563187599, 0...."
2,811,"[0.05681542307138443, -0.0004775002598762512, ..."
3,24,"[-0.011574454605579376, -0.024325575679540634,..."
4,31,"[-0.05590250715613365, -0.02796020731329918, -..."
...,...,...
938,107,"[0.002338327467441559, 0.027677029371261597, 0..."
939,271,"[-0.0360712856054306, -0.04051988571882248, -0..."
940,861,"[-0.018500830978155136, 0.01041443645954132, -..."
941,436,"[-0.05177813768386841, -0.04470661282539368, -..."


In [75]:
new_idx2user = reordered_user_df['user'].to_dict()
new_user2idx = {user:idx for idx,user in new_idx2user.items()}
new_user2idx

{97: 0,
 266: 1,
 811: 2,
 24: 3,
 31: 4,
 281: 5,
 569: 6,
 260: 7,
 332: 8,
 324: 9,
 423: 10,
 468: 11,
 287: 12,
 894: 13,
 869: 14,
 639: 15,
 539: 16,
 500: 17,
 482: 18,
 335: 19,
 849: 20,
 771: 21,
 926: 22,
 40: 23,
 364: 24,
 765: 25,
 308: 26,
 445: 27,
 714: 28,
 705: 29,
 77: 30,
 818: 31,
 596: 32,
 372: 33,
 166: 34,
 756: 35,
 251: 36,
 883: 37,
 437: 38,
 240: 39,
 108: 40,
 68: 41,
 175: 42,
 159: 43,
 140: 44,
 755: 45,
 732: 46,
 307: 47,
 533: 48,
 695: 49,
 803: 50,
 64: 51,
 236: 52,
 766: 53,
 900: 54,
 87: 55,
 693: 56,
 724: 57,
 779: 58,
 517: 59,
 661: 60,
 588: 61,
 327: 62,
 219: 63,
 210: 64,
 657: 65,
 71: 66,
 929: 67,
 650: 68,
 469: 69,
 847: 70,
 200: 71,
 480: 72,
 448: 73,
 522: 74,
 676: 75,
 941: 76,
 137: 77,
 209: 78,
 626: 79,
 214: 80,
 573: 81,
 356: 82,
 622: 83,
 73: 84,
 478: 85,
 45: 86,
 892: 87,
 591: 88,
 614: 89,
 507: 90,
 267: 91,
 169: 92,
 333: 93,
 452: 94,
 310: 95,
 261: 96,
 607: 97,
 939: 98,
 823: 99,
 434: 100,
 328: 101,

In [67]:
model.reload_embedding()
infered_embeddings = model.forward()

NameError: name 'model' is not defined