### Imports

In [1]:
import networkx as nx
import pandas as pd
from transformers import AutoModel,AutoTokenizer
import torch
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv,RGATConv
from torch_geometric.loader import NeighborLoader
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.preprocessing import LabelEncoder
import heapq
from collections import defaultdict
from nltk.metrics.distance import edit_distance
from pickle import load as pkl_load,dump as pkl_dump
import numpy as np
from random import sample as random_sample
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm


### Data preparation

In [2]:
mov_dataset = pd.read_csv("../datasets/final-ds.csv").map(lambda s: s.lower() if type(s) == str else s)
ogl_enc = LabelEncoder()
mov_dataset["original_language"] = ogl_enc.fit_transform(mov_dataset["original_language"])

In [3]:
with open("embeds.pkl","rb") as embfile: embs = pkl_load(embfile)
for k in embs: mov_dataset[k] = embs[k]

### Definition

In [4]:
class SageRecNet(torch.nn.Module):
    def __init__(self,input_dim,hidden_dim,output_dim,num_relations):
        super(SageRecNet,self).__init__()
        self.conv1 = SAGEConv(input_dim,hidden_dim)
        self.conv2 = SAGEConv(hidden_dim,output_dim)
    def forward(self,x,edge_index,edge_type=None):
        x = self.conv1(x,edge_index)
        x = torch.relu(x)
        return self.conv2(x,edge_index)
class RLRecommendationAgent(torch.nn.Module):
    def __init__(self,user_dim,input_dim,num_actions,device):
        super().__init__()
        self.device = device
        self.policy_net = torch.nn.Sequential(
            torch.nn.Linear(input_dim + user_dim,128),
            torch.nn.ReLU(),
            torch.nn.Linear(128,1)
        ).to(device)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(),lr=0.004)
        self.loss_fn = torch.nn.MSELoss()

    def forward(self,user_context,item_embeddings):
        expanded_context = user_context.unsqueeze(0).expand(item_embeddings.size(0),-1)
        combined = torch.cat([expanded_context,item_embeddings],dim=1)
        return self.policy_net(combined)
    def state_dict(self):
        return {
            "policy_net": self.policy_net.state_dict(),
            "optimizer": self.optimizer.state_dict()
        }
    def load_state_dict(self,state_dict):
        self.policy_net.load_state_dict(state_dict["policy_net"])
        self.optimizer.load_state_dict(state_dict["optimizer"])

In [5]:
class GraphRecommender:
    def __init__(self,dataset,gnn_params=None,use_clusters=True,device=None):
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dataset = dataset
        self.size = len(dataset["title"])
        cn = 12
        self.weights = {
            "is_adult": 1.0,"original_language": 2.0,
            "overview": 3.0,"tagline": 2.0,
            "genres": 4.0,"year": 1.5,
            "keywords":3.5
        }
        self.gnn_params = {
            "hidden_dim": 128,
            "num_classes": cn,
            "num_relations": 4,
            "learning_rate": 0.006,
            "weight_decay": 5e-4,
            "epochs": 64
        }
        self.title_to_index = {self.dataset["title"][idx].lower(): idx for idx in range(self.size)}
        self.le = LabelEncoder()
        scalar_features = np.array([self._get_scalar_features(i) for i in range(self.size)])
        text_features = []
        for feature in ["overview","tagline","genres"]:
            feat_array = np.stack(self.dataset[feature])
            if feat_array.ndim == 3: feat_array = feat_array.squeeze(axis=1)
            text_features.append(feat_array)
        text_features = np.hstack(text_features)

        kmeans = KMeans(n_clusters=cn)
        self.dataset.loc[:,"labels"] = kmeans.fit_predict(
            np.hstack([scalar_features,text_features])
        )
        self.gnn_params["num_classes"] = cn
        self._build_graph()
        self._init_gnn()
        if gnn_params: self.gnn_params.update(gnn_params)
        self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.embmodel = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2").to(self.device)
        with torch.no_grad(): samp_out = self.model(self.graph_data.x,self.graph_data.edge_index)
        self.rl_agent = RLRecommendationAgent(
            user_dim=self.gnn_params["hidden_dim"],
            input_dim=samp_out.size(1),
            num_actions=cn,
            device=self.device
        )
        self.user_feedback = defaultdict(list)
        self.reward_history = []
        self.epsilon = 0.2
    def _get_feature_embeddings(self,feature,batch_size=128):
        sentences = self.dataset[feature]
        embs = []
        for i in range(0,len(sentences),batch_size):
            batch = sentences[i:i + batch_size]
            inputs = self.tokenizer(
                batch,return_tensors="pt",
                padding=True,truncation=True,max_length=256
            ).to(self.device)
            with torch.no_grad():
                outputs = self.embmodel(**inputs)
                batch_embs = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
            embs.extend(batch_embs)
        return np.array(embs)
    def _build_graph(self,thres=0.35,neighbors=70):
        self.graph = nx.Graph()
        scalar_features = np.array([self._get_scalar_features(i) for i in range(self.size)])
        text_features = []
        for feature in ["overview","tagline","genres"]:
            feat_array = np.stack(self.dataset[feature])
            if feat_array.ndim == 3: feat_array = feat_array.squeeze(axis=1)
            text_features.append(feat_array)
        text_features = np.hstack(text_features)
        all_features = np.hstack([scalar_features,text_features])
        all_features = all_features / np.linalg.norm(all_features,axis=1,keepdims=True)
        nbrs = NearestNeighbors(n_neighbors=neighbors,metric="cosine",n_jobs=4).fit(all_features)
        distances,indices = nbrs.kneighbors(all_features)
        similarities = 1 - distances
        edges = []
        for i in range(self.size):
            for j,sim in zip(indices[i],similarities[i]):
                if i != j and sim > thres: edges.append((i,j,sim))
        self.graph.add_weighted_edges_from(edges)
        edge_index = torch.tensor(list(self.graph.edges)).t().contiguous().to(self.device)
        self.graph_data = Data(
            x=self._get_node_features().clone().detach().to(self.device),
            edge_index=edge_index,
            y=torch.tensor(self.dataset["labels"].values,dtype=torch.long).to(self.device)
        )
        train_size = int(1.0 * self.size)
        train_mask = torch.zeros(self.size,dtype=torch.bool)
        train_mask[torch.randperm(self.size)[:train_size]] = True
        self.graph_data.train_mask = train_mask
    def update_cluster_labels(self):
        self._node_clustering()
        self.dataset["labels"] = self.le.fit_transform(self.cluster_labels)
        self.gnn_params["num_classes"] = len(self.le.classes_)
    def _get_scalar_features(self,idx):
        return np.array([
            self.dataset["is_adult"][idx] * self.weights["is_adult"],
            self.dataset["original_language"][idx] * self.weights["original_language"],
            self.dataset["year"][idx] * self.weights["year"]
        ])
    def _get_node_features(self):
        features = []
        for i in range(self.size):
            scalar = np.expand_dims(self._get_scalar_features(i),axis=0)
            text = np.concatenate([
                self.dataset[feature][i].reshape(-1,1) * self.weights[feature]
                for feature in ["overview","tagline","genres"]
            ])
            combined = np.hstack([scalar,text.reshape(1,-1)])
            features.append(combined)
        return torch.nn.functional.normalize(
            torch.tensor(np.vstack(features),dtype=torch.float).to(self.device),
            p=2,dim=1
        )
    def _init_gnn(self):
        self.model = SageRecNet(
            input_dim=self.graph_data.x.size(1),
            hidden_dim=self.gnn_params["hidden_dim"],
            output_dim=self.gnn_params["num_classes"],
            num_relations=self.gnn_params["num_relations"]
        ).to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.gnn_params["learning_rate"],
            weight_decay=self.gnn_params["weight_decay"]
        )
    def _node_clustering(self,clusters=12,use_gnn_embeddings=True):
        if use_gnn_embeddings and not hasattr(self.model,"parameters"): raise ValueError("Train model before GNN-based clustering")
        if use_gnn_embeddings:
            with torch.no_grad():
                embeddings = self.model(
                    self.graph_data.x,
                    self.graph_data.edge_index,
                    # self.graph_data.edge_type
                ).cpu().numpy()
        else: embeddings = self.graph_data.x.cpu().numpy()
        kmeans = KMeans(n_clusters=clusters)
        self.cluster_labels = kmeans.fit_predict(embeddings)
    def _compute_heuristics(self,kn=7,max_weight=0.7,avg_weight=0.3):
        self.heuristics = {}
        for node in self.graph.nodes():
            neighbors = list(self.graph[node].items())
            if not neighbors:
                self.heuristics[node] = 0
                continue
            similarities = [data["weight"] for _,data in neighbors]
            max_sim = max(similarities)
            top_similarities = sorted(similarities,reverse=True)[:kn]
            avg_sim = sum(top_similarities) / len(top_similarities)
            self.heuristics[node] = (max_sim * max_weight) + (avg_sim * avg_weight)
    def train(self,batch_sz=256):
        train_loader = NeighborLoader(
            self.graph_data,
            num_neighbors=[20,10],
            batch_size=batch_sz,
            shuffle=True,
            input_nodes=self.graph_data.train_mask
        )
        loss_hist = []
        for epoch in range(self.gnn_params["epochs"]):
            total_loss = 0
            self.model.train()
            for batch in train_loader:
                self.optimizer.zero_grad()
                out = self.model(batch.x,batch.edge_index)
                loss = torch.nn.functional.cross_entropy(out[batch.train_mask],batch.y[batch.train_mask])
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch+1}/{self.gnn_params['epochs']} | Loss: {avg_loss:.4f}")
            if loss_hist and abs(loss_hist[-1] - loss.item()) < 0.001:
                print(f"Converged at epoch {epoch+1}")
                break
            loss_hist.append(loss.item())
        self.update_cluster_labels()
    def _title_match(self,qtitle,thres=4):
            query = qtitle.strip().lower()
            min_dist = float("inf")
            best_idx = -1
            for idx,title in enumerate(self.dataset["title"]):
                title_clean = str(title).strip().lower()
                dist = edit_distance(query,title_clean)
                if dist < min_dist:
                    min_dist = dist
                    best_idx = idx
            return best_idx if min_dist <= thres else -1
    def gnn_recommendations(self,prompt_titles,k=4):
        self.model.eval()
        if not hasattr(self,"cluster_labels"): self.update_cluster_labels()
        input_indices = []
        for title in prompt_titles:
            idx = self._title_match(title)
            if idx == -1:
                print(f"'{title}' not found in dataset")
                continue
            input_indices.append(idx)
        if not input_indices: return []
        with torch.no_grad(): embeddings = self.model(self.graph_data.x,self.graph_data.edge_index)
        combined_sims = torch.zeros(len(embeddings)).to(self.device)
        for idx in input_indices:
            sims = torch.nn.functional.cosine_similarity(embeddings[idx].unsqueeze(0),embeddings)
            cluster_mask = self.cluster_labels == self.cluster_labels[idx]
            sims[~cluster_mask] *= 0.3
            combined_sims += sims
        combined_sims /= len(input_indices)
        combined_sims[input_indices] = float("-inf")
        top_indices = combined_sims.argsort(descending=True)[:k + len(input_indices)]
        final_recs = [int(i.item()) for i in top_indices if i not in input_indices][:k]
        return [self.dataset["title"].iloc[i] for i in final_recs[:k]]
    def a_star_recommendations(self,prompt_titles,k=4,max_depth=4,score_threshold=0.5,max_exploration=500):
        if hasattr(self,"cluster_labels"): self._node_clustering()  
        start_indices = [self._title_match(title.lower()) for title in prompt_titles if title.lower() in self.title_to_index]
        if not start_indices: return []
        if not hasattr(self,"heuristics"): self._compute_heuristics()
        node_scores = defaultdict(dict)
        exploration_count = 0
        queue = []
        for start_idx in start_indices:
            heapq.heappush(queue,(-self.heuristics[start_idx],start_idx,0.0,0,start_idx))
            node_scores[start_idx][start_idx] = 0.0
        while queue and exploration_count < max_exploration:
            if len(queue) > 1000:
                queue = heapq.nsmallest(1000//2,queue)
                heapq.heapify(queue)   
            current_f_neg,current_node,current_g,depth,origin_idx = heapq.heappop(queue)
            current_f = -current_f_neg
            if current_f < score_threshold or depth > max_depth:
                continue
            exploration_count += 1
            if current_node not in node_scores or \
            current_g > node_scores[current_node].get(origin_idx,-1):
                node_scores[current_node][origin_idx] = current_g
            for neighbor in self.graph.neighbors(current_node):
                edge_weight = self.graph[current_node][neighbor]["weight"]
                new_g = current_g + edge_weight
                new_f = new_g + self.heuristics[neighbor]

                if new_f >= score_threshold:
                    if (neighbor not in node_scores) or \
                    (new_g > node_scores[neighbor].get(origin_idx,-1)):
                        heapq.heappush(queue,(-new_f,neighbor,new_g,depth+1,origin_idx))
        scored_nodes = defaultdict(float)
        for node,origins in node_scores.items():
            if node in start_indices: continue
            avg_score = sum(origins.values()) / len(origins)
            scored_nodes[node] = avg_score
        sorted_nodes = sorted(scored_nodes.items(),key=lambda x: -x[1])
        final_recs = [node for node,score in sorted_nodes[:k + len(start_indices)] if node not in start_indices][:k]
        return [self.dataset["title"].iloc[idx] for idx in final_recs[:k]]
    def get_recommendations(self,titles,k=4,user_id="default"):
        gnn_recs = self.gnn_recommendations(titles,k*2)
        a_star_recs = self.a_star_recommendations(titles,k*2)
        candidates = list(set(gnn_recs + a_star_recs))
        user_context = self._get_user_context(user_id)
        final_recs = self._rl_select_action(user_context,candidates,k)
        self._update_rl_agent(user_id)
        return final_recs
    def update_user_feedback(self,user_id,movies,scores,feedback_type="explicit"):
        weights = {
            "explicit": 1.0,
            "implicit": 0.7,
            "inferred": 0.5
        }
        weighted_scores = [s * weights[feedback_type] for s in scores]
        feedback = {
            "movies": movies,
            "scores": weighted_scores,
            "timestamp": datetime.now(),
            "type": feedback_type
        }
        self.user_feedback[user_id].append(feedback)
        
    def _get_user_context(self,user_id):
        if user_id not in self.user_feedback: return torch.zeros(self.gnn_params["hidden_dim"]).to(self.device)
        all_embs = []
        with torch.no_grad():
            all_embeddings = self.model(self.graph_data.x,self.graph_data.edge_index)
            
        for feedback in self.user_feedback[user_id]:
            movie_indices = [self.title_to_index[m] for m in feedback["movies"]]
            movie_embeddings = all_embeddings[movie_indices]
            weighted_embs = movie_embeddings * torch.tensor(
                feedback["scores"],device=self.device).unsqueeze(1)
            all_embs.append(weighted_embs.mean(dim=0)) 
        return torch.stack(all_embs).mean(dim=0)
    def _rl_select_action(self,user_context,candidates,k=4):
        candidate_indices = [self.title_to_index[m] for m in candidates]
        with torch.no_grad():
            candidate_embeddings = self.model(
                self.graph_data.x,
                self.graph_data.edge_index
            )[candidate_indices]
        k = min(k,candidate_embeddings.shape[0])
        if np.random.random() < self.epsilon:
            selected = np.random.choice(candidate_embeddings.shape[0],k,replace=False)
        else:
            with torch.no_grad():
                q_values = self.rl_agent(user_context,candidate_embeddings).squeeze(1)
            selected = q_values.topk(k).indices.cpu().numpy().flatten()
        return [candidates[i] for i in selected]
    def _update_rl_agent(self,user_id):
        if user_id not in self.user_feedback: return
        batch_size = 32
        feedbacks = self.user_feedback[user_id][-1000:]
        if len(feedbacks) < batch_size:
            return
        batch = random_sample(feedbacks,batch_size)
        user_context = self._get_user_context(user_id)
        for feedback in batch:
            self.rl_agent.optimizer.zero_grad()
            movie_indices = [self.title_to_index[m] for m in feedback["movies"]]
            movie_embeddings = self.graph_data.x[movie_indices]
            predicted_q = self.rl_agent(user_context,movie_embeddings)
            
            target_q = torch.tensor(feedback["scores"],device=self.device,dtype=torch.float32)
            
            loss = self.rl_agent.loss_fn(predicted_q,target_q)
            loss.backward()
            self.rl_agent.optimizer.step()
        self.epsilon = max(0.05,self.epsilon * 0.995)
    def collect_feedback(self,recommendations,user_id):
        print("\nHow relevant are these recommendations? (0.0-1.0)")
        scores = []
        for movie in recommendations:
            while True:
                try:
                    score = float(input(f"{movie}: "))
                    if 0.0 <= score <= 1.0:
                        scores.append(score)
                        break
                    print("Please enter between 0.0 (bad) and 1.0 (perfect)")
                except ValueError:
                    print("Invalid number format")
        self.update_user_feedback(user_id,recommendations,scores)
    def save_model(self,flpath,user_data_path="../model/user_data.pkl"):
        if not flpath.endswith(".pkl"):
            flpath += ".pkl"
        with open(flpath,"wb") as mdlfile:
            pkl_dump({
                "gnn_state": self.model.state_dict(),
                "rl_state": self.rl_agent.state_dict(),
                "cluster_labels": self.cluster_labels,
                "epsilon": self.epsilon,
                "gnn_params": self.gnn_params
            },mdlfile)

        with open(user_data_path,"wb") as userfile:
            pkl_dump({
                "user_feedback": dict(self.user_feedback),
                "label_encoder": self.le
            },userfile)
    @classmethod
    def load_model(cls,flpath,dataset,user_data_path="../model/user_data.pkl",device=None):
        if not flpath.endswith(".pkl"):
            flpath += ".pkl"
        with open(flpath,"rb") as mdlfile: saved_data = pkl_load(mdlfile)
        recommender = cls(dataset,device=device)
        recommender.model.load_state_dict(saved_data["gnn_state"])
        recommender.rl_agent.load_state_dict(saved_data["rl_state"])
        recommender.cluster_labels = saved_data["cluster_labels"]
        recommender.epsilon = saved_data["epsilon"]
        recommender.gnn_params.update(saved_data["gnn_params"])
        try:
            with open(user_data_path,"rb") as userfile:
                user_data = pkl_load(userfile)
                recommender.user_feedback = defaultdict(list,user_data["user_feedback"])
                recommender.le = user_data["label_encoder"]
        except FileNotFoundError: print("No user data found,starting fresh")
        return recommender

### Training

In [6]:
acc_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
recommender = GraphRecommender(mov_dataset.iloc[:1000],device=acc_device)
recommender.train()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.dataset.loc[:,"labels"] = kmeans.fit_predict(
2025-02-02 15:25:04.275224: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738490104.310198   47624 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738490104.320837   47624 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-02 15:25:04.393596: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to u

Epoch 1/64 | Loss: 2.4505
Epoch 2/64 | Loss: 2.3172
Epoch 3/64 | Loss: 2.2303
Epoch 4/64 | Loss: 2.2262
Epoch 5/64 | Loss: 2.2156
Epoch 6/64 | Loss: 2.2088
Epoch 7/64 | Loss: 2.2059
Epoch 8/64 | Loss: 2.2056
Epoch 9/64 | Loss: 2.2039
Epoch 10/64 | Loss: 2.2044
Epoch 11/64 | Loss: 2.2020
Epoch 12/64 | Loss: 2.2053
Epoch 13/64 | Loss: 2.2014
Epoch 14/64 | Loss: 2.2021
Epoch 15/64 | Loss: 2.1987
Epoch 16/64 | Loss: 2.2015
Epoch 17/64 | Loss: 2.2003
Converged at epoch 17


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self.dataset["labels"] = self.le.fit_transform(self.cluster_labels)


### Execution

In [7]:
inp_titles = []
inp = input("Enter a title (leave blank to stop) >>> ")
inp_titles.append(inp)
while inp.strip() != "":
    inp = input("Enter a title >>> ").strip()
    if inp: inp_titles.append(inp)
print(inp_titles)
recs = recommender.get_recommendations(inp_titles)
recommender.collect_feedback(recs,user_id="current_user")
pers_recs = recommender.get_recommendations(inp_titles)
print(pers_recs)

['inception', 'the avengers']

How relevant are these recommendations? (0.0-1.0)
['godzilla: king of the monsters', 'black panther', 'the empire strikes back', 'the matrix resurrections']
