### Imports

In [None]:
import networkx as nx
import pandas as pd
from transformers import AutoModel,AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
from nltk.metrics.distance import edit_distance,jaccard_distance
import torch
from torch_geometric.data import Data
from torch_geometric.nn import RGATConv,SAGEConv
from sklearn.cluster import KMeans
import numpy as np
from sklearn.preprocessing import normalize,LabelEncoder
import heapq
from collections import defaultdict
from pickle import load as pkl_load

### Data loading

Dataset source: TMDB Movies Dataset 2023 (930k movies) — https://www.kaggle.com/datasets/asaniczka/tmdb-movies-dataset-2023-930k-movies

In [2]:
# Load the movie table and lowercase every string cell so that later title
# lookups can be case-insensitive; non-string cells are passed through as-is.
mov_dataset = pd.read_csv("./datasets/final-ds.csv").map(
    lambda s: s.lower() if isinstance(s, str) else s
)

In [3]:
# Replace the language column with integer class codes. The fitted encoder is
# kept in `enc` so the codes can be mapped back to the original labels.
enc = LabelEncoder()
mov_dataset["original_language"] = enc.fit_transform(mov_dataset["original_language"])

In [4]:
# Read the pickled embeddings and attach each entry of `embs` to the movie
# table as a column named after its key.
# NOTE(review): unpickling executes arbitrary code; only load a trusted file.
with open("embeds.pkl", "rb") as embfile:
    embs = pkl_load(embfile)

for k in embs:
    mov_dataset[k] = embs[k]

### Training

In [5]:
class GraphRecommender:
    """Movie recommender backed by a similarity graph and a small GNN.

    Every row of ``dataset`` becomes a graph node. Two nodes are linked when
    the sum of their scalar-feature and text-embedding cosine similarities
    exceeds a threshold. Recommendations come either from the GNN output
    space (``gnn_recommendations``) or from a best-first walk over the
    weighted graph (``a_star_recommendations``).

    ``dataset`` must be indexable as ``dataset[column][row]`` and provide the
    columns ``title``, ``is_adult``, ``original_language``, ``overview``,
    ``tagline`` and ``genres``; the last three are expected to hold one
    numeric vector per row (assumption based on the ``.flatten()`` calls —
    confirm against the embedding file).
    """

    _TEXT_FEATURES = ("overview", "tagline", "genres")

    def __init__(self, dataset, gnn_params=None, device=None):
        """Build the graph, create the GNN and load the sentence encoder.

        Args:
            dataset: Column-indexable table of movies (see class docstring).
            gnn_params: Optional overrides merged into the default
                hyper-parameter dictionary.
            device: Torch device; defaults to CUDA when available, else CPU.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dataset = dataset
        self.size = len(dataset["title"])
        # Per-feature multipliers. "year" and "keywords" are declared but not
        # read anywhere in this class.
        self.weights = {
            "is_adult": 1.0, "original_language": 2.0,
            "overview": 3.0, "tagline": 2.0,
            "genres": 4.0, "year": 1.5,
            "keywords": 3.5,
        }
        self.gnn_params = {
            "hidden_dim": 32,
            "num_classes": 12,
            "num_relations": 4,
            "learning_rate": 0.006,
            "weight_decay": 5e-4,
            "epochs": 32,
        }
        if gnn_params:
            self.gnn_params.update(gnn_params)
        self._build_graph()
        self._init_gnn()
        # The encoder is only needed by _get_feature_embeddings, which the
        # graph construction above does not call.
        self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.embmodel = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2").to(self.device)

    def _title_index(self):
        """Return a mapping from lowercased title to row index."""
        return {self.dataset["title"][idx].lower(): idx for idx in range(self.size)}

    def _get_feature_embeddings(self, feature, batch_size=128):
        """Encode one text column with the sentence encoder.

        Args:
            feature: Name of the text column to encode.
            batch_size: Number of sentences encoded per forward pass.

        Returns:
            A NumPy array with one mean-pooled embedding per row.
        """
        # The tokenizer accepts a str or a list of str, not a pandas Series,
        # so materialise the column as a plain list before batching.
        sentences = list(self.dataset[feature])
        embs = []
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            inputs = self.tokenizer(
                batch, return_tensors="pt",
                padding=True, truncation=True, max_length=256
            ).to(self.device)
            with torch.no_grad():
                outputs = self.embmodel(**inputs)
                # NOTE(review): this is a plain mean over all token positions,
                # padding included; it is kept as-is so that new vectors stay
                # comparable with any embeddings produced earlier.
                batch_embs = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
            embs.extend(batch_embs)
        return np.array(embs)

    def _build_graph(self, thres=0.35):
        """Create the similarity graph and its tensor form.

        Sets ``self.graph`` (networkx graph with a ``weight`` per edge) and
        ``self.graph_data`` (node features, edge index, normalised edge
        weights, labels, train mask and edge types).

        Args:
            thres: Minimum combined similarity for an edge to be created.
        """
        self.graph = nx.Graph()
        for i in range(self.size):
            self.graph.add_node(i, title=self.dataset["title"][i])

        for i in range(self.size):
            scalar_i = [self._get_scalar_features(i)]  # invariant for the inner loop
            for j in range(i + 1, self.size):
                scalar_sim = cosine_similarity(
                    scalar_i,
                    [self._get_scalar_features(j)]
                )[0][0]
                # NOTE(review): cosine similarity is scale-invariant, so
                # multiplying both vectors by the same weight does not change
                # the result; the three text terms therefore count equally.
                text_sim = sum(
                    cosine_similarity(
                        [self.dataset[feature][i].flatten() * self.weights[feature]],
                        [self.dataset[feature][j].flatten() * self.weights[feature]]
                    )[0][0]
                    for feature in self._TEXT_FEATURES
                )
                score = scalar_sim + text_sim
                if score > thres:
                    self.graph.add_edge(i, j, weight=score)

        edges = list(self.graph.edges)
        # reshape(-1, 2) keeps the (2, E) layout valid even when no pair
        # passed the threshold and the edge list is empty.
        # NOTE(review): each undirected edge appears once as (u, v); confirm
        # whether the reverse direction should be added for message passing.
        edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous().to(self.device)
        edge_weight = torch.tensor(
            [self.graph[u][v]["weight"] for u, v in edges],
            dtype=torch.float
        ).to(self.device)
        self.graph_data = Data(
            x=self._get_node_features(),
            edge_index=edge_index,
            edge_attr=torch.nn.functional.normalize(edge_weight, p=2, dim=0),
            # Labels are drawn uniformly at random; there is no ground truth here.
            y=torch.randint(0, self.gnn_params["num_classes"], (self.size,), dtype=torch.long).to(self.device),
            train_mask=torch.ones(self.size, dtype=torch.bool).to(self.device)
        )
        # Edge types are random as well.
        self.graph_data.edge_type = torch.randint(
            0, self.gnn_params["num_relations"],
            (edge_index.size(1),), dtype=torch.long
        ).to(self.device)

    def _get_scalar_features(self, idx):
        """Return the weighted ``is_adult`` and ``original_language`` values of row ``idx``."""
        return np.array([
            self.dataset["is_adult"][idx] * self.weights["is_adult"],
            self.dataset["original_language"][idx] * self.weights["original_language"],
        ])

    def _get_node_features(self):
        """Return the L2-normalised node feature matrix as a float tensor.

        Each row is the scalar features followed by the weighted overview,
        tagline and genres vectors of that movie, flattened and concatenated.
        """
        rows = []
        for i in range(self.size):
            # The vectors live in the dataset columns (the same source
            # _build_graph reads); the class never defines an `embs` attribute.
            text = np.concatenate([
                np.asarray(self.dataset[feature][i]).ravel() * self.weights[feature]
                for feature in self._TEXT_FEATURES
            ])
            rows.append(np.concatenate([self._get_scalar_features(i), text]))
        return torch.nn.functional.normalize(
            torch.tensor(np.vstack(rows), dtype=torch.float).to(self.device),
            p=2, dim=1
        )

    def _init_gnn(self):
        """Instantiate the GNN and its Adam optimizer from ``self.gnn_params``."""
        self.model = SageRecNet(
            input_dim=self.graph_data.x.size(1),
            hidden_dim=self.gnn_params["hidden_dim"],
            output_dim=self.gnn_params["num_classes"],
            num_relations=self.gnn_params["num_relations"]
        ).to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.gnn_params["learning_rate"],
            weight_decay=self.gnn_params["weight_decay"]
        )

    def _node_clustering(self, clusters=8, use_gnn_embeddings=True):
        """Cluster the nodes with k-means and store labels in ``self.cluster_labels``.

        Args:
            clusters: Number of k-means clusters.
            use_gnn_embeddings: Cluster the GNN outputs when true, otherwise
                the raw node feature matrix.
        """
        if use_gnn_embeddings:
            with torch.no_grad():
                embeddings = self.model(
                    self.graph_data.x,
                    self.graph_data.edge_index,
                    self.graph_data.edge_type
                ).cpu().numpy()
        else:
            embeddings = self.graph_data.x.cpu().numpy()
        kmeans = KMeans(n_clusters=clusters)
        self.cluster_labels = kmeans.fit_predict(embeddings)

    def _compute_heuristics(self, kn=7):
        """Store, per node, the mean weight of its ``kn`` heaviest edges.

        Nodes without edges (or any node when ``kn`` is 0) get a heuristic of
        0. The result is written to ``self.heuristics``.
        """
        self.heuristics = {}
        for node in self.graph.nodes():
            strongest = sorted(
                (data["weight"] for data in self.graph[node].values()),
                reverse=True,
            )[:kn]
            self.heuristics[node] = sum(strongest) / len(strongest) if strongest else 0

    def train(self):
        """Run full-batch training of the GNN with cross-entropy loss.

        Prints the loss after every epoch and stops early once two
        consecutive losses differ by less than 0.01.
        """
        loss_hist = []
        total_epochs = self.gnn_params["epochs"]
        for epoch in range(total_epochs):
            self.model.train()
            self.optimizer.zero_grad()
            out = self.model(
                self.graph_data.x,
                self.graph_data.edge_index,
                self.graph_data.edge_type
            )
            loss = torch.nn.functional.cross_entropy(
                out[self.graph_data.train_mask],
                self.graph_data.y[self.graph_data.train_mask]
            )
            loss.backward()
            self.optimizer.step()
            print(f'Epoch {epoch+1}/{self.gnn_params["epochs"]}, Loss: {loss.item():.4f}')
            if loss_hist and abs(loss_hist[-1] - loss.item()) < 0.01:
                print(f"Converged at epoch {epoch+1}")
                break
            loss_hist.append(loss.item())

    def gnn_recommendations(self, prompt_titles, k=4):
        """Recommend the ``k`` nearest movies in GNN output space per prompt.

        Args:
            prompt_titles: Titles to look up (case-insensitive).
            k: Number of recommendations per title.

        Returns:
            Dict mapping each prompt title that was found to a list of up to
            ``k`` recommended titles; unknown titles are skipped.
        """
        self.model.eval()
        title_to_index = self._title_index()
        with torch.no_grad():
            embeddings = self.model(
                self.graph_data.x,
                self.graph_data.edge_index,
                self.graph_data.edge_type
            )
        results = {}
        for title in prompt_titles:
            idx = title_to_index.get(title.lower())
            if idx is None:
                continue
            sims = torch.nn.functional.cosine_similarity(
                embeddings[idx].unsqueeze(0),
                embeddings
            )
            ranked = sims.argsort(descending=True)[:k + 1].cpu().numpy()
            # Drop the query itself by index rather than assuming it is ranked
            # first: another node with an identical embedding may precede it.
            top_k = [i for i in ranked if i != idx][:k]
            results[title] = [self.dataset["title"][i] for i in top_k]
        return results

    def a_star_recommendations(self, prompt_titles, k=4, max_depth=3,score_threshold=0.5, max_exploration=500):
        """Recommend movies reachable from all prompts via a best-first search.

        Starting from every prompt node, the search expands neighbours in
        order of ``accumulated edge weight + heuristic`` and records the best
        accumulated weight seen per (node, origin). Nodes reached from every
        origin are ranked by ``average * minimum`` of those weights.

        Args:
            prompt_titles: Titles to start from (case-insensitive; unknown
                titles and duplicates are ignored).
            k: Maximum number of recommendations returned.
            max_depth: Nodes deeper than this are not expanded.
            score_threshold: Minimum priority for a node to be queued/expanded.
            max_exploration: Maximum number of queue pops.

        Returns:
            List of up to ``k`` recommended titles, or ``[]`` when no prompt
            title is known.
        """
        title_to_index = self._title_index()
        # The keys are lowercased, so the membership test must lowercase too.
        # dict.fromkeys de-duplicates while keeping order; repeated prompts
        # would otherwise make the "reached from every origin" filter below
        # impossible to satisfy.
        start_indices = list(dict.fromkeys(
            title_to_index[title.lower()]
            for title in prompt_titles
            if title.lower() in title_to_index
        ))
        if not start_indices:
            return []
        if not hasattr(self, 'heuristics'):
            self._compute_heuristics()

        node_scores = defaultdict(dict)
        exploration_count = 0
        queue = []
        for start_idx in start_indices:
            # heapq is a min-heap, so priorities are stored negated.
            heapq.heappush(queue, (-self.heuristics[start_idx], start_idx, 0.0, 0, start_idx))
            node_scores[start_idx][start_idx] = 0.0

        while queue and exploration_count < max_exploration:
            current_f_neg, current_node, current_g, depth, origin_idx = heapq.heappop(queue)
            current_f = -current_f_neg
            exploration_count += 1
            if current_f < score_threshold or depth > max_depth:
                continue
            if current_node not in node_scores or \
                    current_g > node_scores[current_node].get(origin_idx, -1):
                node_scores[current_node][origin_idx] = current_g
            for neighbor in self.graph.neighbors(current_node):
                edge_weight = self.graph[current_node][neighbor]['weight']
                new_g = current_g + edge_weight
                new_f = new_g + self.heuristics[neighbor]
                if new_f < score_threshold:
                    continue
                # Test membership first so the defaultdict does not create an
                # empty entry for a node that has not been visited yet.
                if (neighbor not in node_scores) or \
                        (new_g > node_scores[neighbor].get(origin_idx, -1)):
                    heapq.heappush(queue, (-new_f, neighbor, new_g, depth + 1, origin_idx))

        scored_nodes = []
        for node, origins in node_scores.items():
            # Keep only nodes that were reached from every start node.
            if len(origins) < len(start_indices):
                continue
            min_score = min(origins.values())
            avg_score = sum(origins.values()) / len(origins)
            scored_nodes.append((node, avg_score * min_score, min_score, avg_score))
        scored_nodes.sort(key=lambda x: (-x[1], -x[2], -x[3]))

        excluded = set(start_indices)
        final_recs = [n for n, *_ in scored_nodes if n not in excluded][:k]
        return [self.dataset['title'][idx] for idx in final_recs]
class SageRecNet(torch.nn.Module):
    """Two stacked SAGEConv layers with a ReLU in between.

    ``num_relations`` and ``edge_type`` are accepted so the call sites can
    pass them, but neither is used by this network.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_relations):
        super().__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, output_dim)

    def forward(self, x, edge_index, edge_type=None):
        """Return the second layer's output for node features ``x``."""
        hidden = torch.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

### Execution

In [None]:
# Constructing the recommender builds the similarity graph and the GNN;
# train() then fits the GNN and prints the loss for each epoch.
recommender = GraphRecommender(mov_dataset)
recommender.train()