In [1]:
# Install required packages and setup
!pip install torch numpy tqdm

import os
import time
import torch
import argparse
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from scipy.spatial.distance import pdist, squareform
from collections import Counter

# Create data directory
!mkdir -p data

# Download MovieLens-1M dataset from the GitHub repository
!wget -O data/ml-1m.txt https://raw.githubusercontent.com/pmixer/SASRec.pytorch/main/python/data/ml-1m.txt

print("Dataset downloaded successfully!")

# Verify the dataset
if os.path.exists('data/ml-1m.txt'):
    with open('data/ml-1m.txt', 'r') as f:
        lines = f.readlines()
    print(f"Dataset loaded with {len(lines)} interactions")
    print("First few lines:")
    for i in range(min(5, len(lines))):
        print(lines[i].strip())
else:
    print("Error: Dataset not found!")

--2025-06-07 02:00:29--  https://raw.githubusercontent.com/pmixer/SASRec.pytorch/main/python/data/ml-1m.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9053831 (8.6M) [text/plain]
Saving to: ‘data/ml-1m.txt’


2025-06-07 02:00:30 (61.3 MB/s) - ‘data/ml-1m.txt’ saved [9053831/9053831]

Dataset downloaded successfully!
Dataset loaded with 999611 interactions
First few lines:
1 1
1 2
1 3
1 4
1 5


In [2]:
import requests

def prepare_and_download_dataset(dataset_name, timeout=30):
    """
    Downloads and prepares amazon-book or yelp2018 datasets for SASRec.

    The `train.txt` and `test.txt` files of the LightGCN-PyTorch repository are
    fetched, each line (`user item item ...`, 0-based ids) is expanded into
    one `user item` pair per line with ids shifted to 1-based, and the result
    is written to `data/<dataset_name>.txt` (train pairs first, then test).

    Args:
        dataset_name: Name of the dataset folder in the LightGCN-PyTorch repo.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        None. If the output file already exists nothing is downloaded. On a
        download error the error is printed and no file is written.
    """
    final_path = f'data/{dataset_name}.txt'
    if os.path.exists(final_path):
        print(f"Dataset '{dataset_name}' already prepared.")
        return

    print(f"Preparing dataset: {dataset_name}...")
    base_url = f"https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/master/data/{dataset_name}"

    all_interactions = []

    # Download and process train.txt and test.txt
    for file_part in ['train', 'test']:
        url = f"{base_url}/{file_part}.txt"
        try:
            # Without an explicit timeout a stalled connection would hang the cell forever.
            res = requests.get(url, timeout=timeout)
            res.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error downloading {url}: {e}")
            return

        for line in res.text.strip().split('\n'):
            parts = line.strip().split()
            if not parts:
                continue

            # Convert to 1-based indexing for SASRec compatibility
            user_id = int(parts[0]) + 1
            for raw_item in parts[1:]:
                all_interactions.append((user_id, int(raw_item) + 1))

    # The data directory may not exist yet if the download cell was skipped.
    os.makedirs(os.path.dirname(final_path), exist_ok=True)

    # Write to the final format required by SASRec
    with open(final_path, 'w') as f:
        for u, i in all_interactions:
            f.write(f'{u} {i}\n')

    print(f"Successfully prepared and saved dataset '{dataset_name}' to {final_path}")

In [3]:
# utils.py
import sys
import copy
import torch
import random
import numpy as np
from collections import defaultdict
from multiprocessing import Process, Queue

def build_index(dataset_name):
    """Build user->items and item->users adjacency lists for a dataset.

    Reads `data/<dataset_name>.txt`, whose rows are `user item` integer pairs.

    Returns:
        (u2i_index, i2u_index): `u2i_index[u]` lists the items of user `u` and
        `i2u_index[i]` lists the users of item `i`, both in file order. The
        lists are indexed directly by id, so index 0 stays empty for 1-based ids.
    """
    # ndmin=2 keeps the (n_rows, n_cols) shape even when the file holds a single
    # interaction; otherwise loadtxt returns a 1-D array and column slicing fails.
    ui_mat = np.loadtxt('data/%s.txt' % dataset_name, dtype=np.int32, ndmin=2)

    n_users = ui_mat[:, 0].max()
    n_items = ui_mat[:, 1].max()

    u2i_index = [[] for _ in range(n_users + 1)]
    i2u_index = [[] for _ in range(n_items + 1)]

    for user, item in zip(ui_mat[:, 0], ui_mat[:, 1]):
        u2i_index[user].append(item)
        i2u_index[item].append(user)

    return u2i_index, i2u_index

# sampler for batch generation
def random_neq(l, r, s):
    """Draw a random integer from [l, r) that is not contained in `s`.

    Uses NumPy's global random state and keeps redrawing until the candidate
    is outside `s`.
    """
    while True:
        candidate = np.random.randint(l, r)
        if candidate not in s:
            return candidate

def sample_function(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED):
    """Endless batch producer: builds training batches and puts them on `result_queue`.

    Seeds NumPy's global RNG with `SEED`, then loops forever. Each batch is
    put on the queue as `zip(*samples)`, i.e. grouped as
    (uids, seqs, positives, negatives). The function never returns.

    Each sample holds right-aligned, zero-padded int32 arrays of length
    `maxlen`: `seq` is the user's history without its last item, `pos` is the
    item that follows each `seq` entry, and `neg` is a random item the user
    has not interacted with.
    """
    def sample(uid):
        # Users with at most one interaction cannot form an (input, target) pair;
        # replace them with a randomly drawn user.
        while len(user_train[uid]) <= 1:
            uid = np.random.randint(1, usernum + 1)

        history = user_train[uid]
        seq = np.zeros([maxlen], dtype=np.int32)
        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen], dtype=np.int32)
        seen = set(history)

        # Walk the history backwards, filling the arrays from the right.
        target = history[-1]
        slot = maxlen - 1
        for item in reversed(history[:-1]):
            seq[slot] = item
            pos[slot] = target
            if target != 0:
                neg[slot] = random_neq(1, itemnum + 1, seen)
            target = item
            slot -= 1
            if slot == -1:
                break

        return (uid, seq, pos, neg)

    np.random.seed(SEED)
    uids = np.arange(1, usernum + 1, dtype=np.int32)
    drawn = 0
    while True:
        # The reshuffle check happens once per batch, at the batch boundary.
        if drawn % usernum == 0:
            np.random.shuffle(uids)
        batch = [sample(uids[(drawn + offset) % usernum]) for offset in range(batch_size)]
        drawn += batch_size
        result_queue.put(zip(*batch))

class WarpSampler(object):
    """Produces training batches in background worker processes.

    `n_workers` daemon processes each run `sample_function` and push batches
    onto a shared bounded queue (capacity `n_workers * 10`).
    """

    def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []
        for _ in range(n_workers):
            # Each worker gets its own seed drawn from the parent's NumPy RNG.
            worker = Process(
                target=sample_function,
                args=(User, usernum, itemnum, batch_size, maxlen,
                      self.result_queue, np.random.randint(2e9)),
            )
            self.processors.append(worker)
            worker.daemon = True
            worker.start()

    def next_batch(self):
        """Block until a batch is available on the queue and return it."""
        return self.result_queue.get()

    def close(self):
        """Terminate every worker process and wait for it to exit."""
        for worker in self.processors:
            worker.terminate()
            worker.join()

# train/val/test data generation
def data_partition(fname):
    """Load `data/<fname>.txt` and split each user's history into train/valid/test.

    Each non-blank line holds a whitespace-separated `user item` pair; user and
    item ids are assumed to start from 1. Interactions are kept in file order.

    Split rule per user:
      * fewer than 3 interactions: everything goes to train, valid/test are empty;
      * otherwise: the last item is the test target, the second to last is the
        validation target, and the rest is the training sequence.

    Returns:
        [user_train, user_valid, user_test, usernum, itemnum] where the first
        three are dicts mapping user id -> list of item ids and the last two
        are the largest user id and item id seen.
    """
    usernum = 0
    itemnum = 0
    User = defaultdict(list)
    user_train = {}
    user_valid = {}
    user_test = {}

    # `with` guarantees the file handle is released even if parsing fails.
    with open('data/%s.txt' % fname, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing newline)
            u, i = fields
            u = int(u)
            i = int(i)
            usernum = max(u, usernum)
            itemnum = max(i, itemnum)
            User[u].append(i)

    for user in User:
        nfeedback = len(User[user])
        if nfeedback < 3:
            user_train[user] = User[user]
            user_valid[user] = []
            user_test[user] = []
        else:
            user_train[user] = User[user][:-2]
            user_valid[user] = [User[user][-2]]
            user_test[user] = [User[user][-1]]
    return [user_train, user_valid, user_test, usernum, itemnum]

# Shared implementation behind evaluate() and evaluate_valid().
def _evaluate_split(model, dataset, args, use_test):
    """Compute NDCG@10 and HR@10 on the test split (`use_test=True`) or validation split.

    For every evaluated user the held-out target item is ranked against 100
    random items drawn from 1..itemnum that are not in the user's training
    history. When `usernum` exceeds 10000, a random sample of 10000 users is
    evaluated. Users with an empty training history or no target are skipped.

    The input sequence is the most recent `args.maxlen` items of the training
    history; for the test split the validation item is appended to the history
    first.

    Returns:
        (NDCG@10, HR@10) averaged over the evaluated users, or (0.0, 0.0) when
        no user could be evaluated.
    """
    # The dataset is only read here, so no defensive deep copy is needed.
    [train, valid, test, usernum, itemnum] = dataset
    targets = test if use_test else valid

    NDCG = 0.0
    HT = 0.0
    valid_user = 0

    if usernum > 10000:
        users = random.sample(range(1, usernum + 1), 10000)
    else:
        users = range(1, usernum + 1)

    for u in users:
        if len(train[u]) < 1 or len(targets[u]) < 1:
            continue

        history = list(train[u])
        if use_test:
            history.append(valid[u][0])

        # Right-align the most recent `maxlen` items; the left side stays zero-padded.
        seq = np.zeros([args.maxlen], dtype=np.int32)
        recent = history[-args.maxlen:] if args.maxlen > 0 else []
        if recent:
            seq[-len(recent):] = recent

        rated = set(train[u])
        rated.add(0)
        item_idx = [targets[u][0]]  # the target is always candidate 0
        for _ in range(100):
            t = np.random.randint(1, itemnum + 1)
            while t in rated:
                t = np.random.randint(1, itemnum + 1)
            item_idx.append(t)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            # Negated so that an ascending argsort ranks the highest score first.
            predictions = -model.predict(*[np.array(l) for l in [[u], [seq], item_idx]])
        predictions = predictions[0]

        # Double argsort turns scores into ranks; index 0 is the target's rank.
        rank = predictions.argsort().argsort()[0].item()

        valid_user += 1

        if rank < 10:
            NDCG += 1 / np.log2(rank + 2)
            HT += 1
        if valid_user % 100 == 0:
            print('.', end="")
            sys.stdout.flush()

    if valid_user == 0:
        # Nothing to average over (e.g. every user has fewer than 3 interactions).
        return 0.0, 0.0
    return NDCG / valid_user, HT / valid_user

# evaluate on test set
def evaluate(model, dataset, args):
    """Return (NDCG@10, HR@10) on the test split; see `_evaluate_split`."""
    return _evaluate_split(model, dataset, args, use_test=True)

# evaluate on val set
def evaluate_valid(model, dataset, args):
    """Return (NDCG@10, HR@10) on the validation split; see `_evaluate_split`."""
    return _evaluate_split(model, dataset, args, use_test=False)

In [4]:
# model.py
import numpy as np
import torch

class PointWiseFeedForward(torch.nn.Module):
    """Position-wise two-layer feed-forward block built from kernel-size-1 convolutions.

    Applies conv1 -> dropout -> ReLU -> conv2 -> dropout along the last
    (feature) dimension; input and output have the same shape.
    """

    def __init__(self, hidden_units, dropout_rate):
        super().__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        # Conv1d expects (N, C, Length), so swap the last two axes on the way in ...
        channels_first = inputs.transpose(-1, -2)
        hidden = self.dropout1(self.conv1(channels_first))
        hidden = self.relu(hidden)
        projected = self.dropout2(self.conv2(hidden))
        # ... and swap them back on the way out.
        return projected.transpose(-1, -2)

# If torch.nn.MultiheadAttention is unavailable in your PyTorch version
# (the original note says "below 1.16") or unsuitable for other reasons,
# use the self-made multi-head attention layer from:
# https://github.com/pmixer/TiSASRec.pytorch/blob/master/model.py

class SASRec(torch.nn.Module):
    """Self-attentive sequential recommendation model.

    Item ids are embedded (index 0 is padding), combined with learned position
    embeddings and passed through `args.num_blocks` blocks of causal
    self-attention followed by a point-wise feed-forward layer.

    `args` must provide: device, norm_first, hidden_units, maxlen,
    dropout_rate, num_blocks and num_heads.
    """

    def __init__(self, user_num, item_num, args):
        super().__init__()

        self.user_num = user_num
        self.item_num = item_num
        self.dev = args.device
        self.norm_first = args.norm_first

        hidden = args.hidden_units

        # TODO: loss += args.l2_emb for regularizing embedding vectors during training
        # https://stackoverflow.com/questions/42704283/adding-l1-l2-regularization-in-pytorch
        self.item_emb = torch.nn.Embedding(self.item_num + 1, hidden, padding_idx=0)
        self.pos_emb = torch.nn.Embedding(args.maxlen + 1, hidden, padding_idx=0)
        self.emb_dropout = torch.nn.Dropout(p=args.dropout_rate)

        self.attention_layernorms = torch.nn.ModuleList()
        self.attention_layers = torch.nn.ModuleList()
        self.forward_layernorms = torch.nn.ModuleList()
        self.forward_layers = torch.nn.ModuleList()

        self.last_layernorm = torch.nn.LayerNorm(hidden, eps=1e-8)

        # One (layernorm, attention, layernorm, feed-forward) group per block.
        for _ in range(args.num_blocks):
            self.attention_layernorms.append(torch.nn.LayerNorm(hidden, eps=1e-8))
            self.attention_layers.append(
                torch.nn.MultiheadAttention(hidden, args.num_heads, args.dropout_rate)
            )
            self.forward_layernorms.append(torch.nn.LayerNorm(hidden, eps=1e-8))
            self.forward_layers.append(PointWiseFeedForward(hidden, args.dropout_rate))

    def log2feats(self, log_seqs):
        """Encode item-id sequences (NumPy array, batch x time) into per-position features."""
        item_ids = torch.LongTensor(log_seqs).to(self.dev)
        seqs = self.item_emb(item_ids)
        seqs *= self.item_emb.embedding_dim ** 0.5

        # Positions run 1..T; padded slots (item id 0) get position 0.
        batch_size, seq_len = log_seqs.shape[0], log_seqs.shape[1]
        positions = np.tile(np.arange(1, seq_len + 1), [batch_size, 1])
        positions = positions * (log_seqs != 0)
        seqs += self.pos_emb(torch.LongTensor(positions).to(self.dev))
        seqs = self.emb_dropout(seqs)

        # True above the diagonal: a position may not attend to later positions.
        time_len = seqs.shape[1]
        causal = torch.tril(torch.ones((time_len, time_len), dtype=torch.bool, device=self.dev))
        attention_mask = ~causal

        blocks = zip(self.attention_layernorms, self.attention_layers,
                     self.forward_layernorms, self.forward_layers)
        for attn_norm, attn, ffn_norm, ffn in blocks:
            # The attention layers are used with time-major input, hence the
            # transpose before and after the attention call.
            seqs = seqs.transpose(0, 1)
            if self.norm_first:
                normed = attn_norm(seqs)
                mha_outputs, _ = attn(normed, normed, normed, attn_mask=attention_mask)
                seqs = seqs + mha_outputs
                seqs = seqs.transpose(0, 1)
                seqs = seqs + ffn(ffn_norm(seqs))
            else:
                mha_outputs, _ = attn(seqs, seqs, seqs, attn_mask=attention_mask)
                seqs = attn_norm(seqs + mha_outputs)
                seqs = seqs.transpose(0, 1)
                seqs = ffn_norm(seqs + ffn(seqs))

        return self.last_layernorm(seqs)

    def forward(self, user_ids, log_seqs, pos_seqs, neg_seqs):
        """Training pass: per-position logits for the positive and negative items.

        `user_ids` is accepted but not used.
        """
        log_feats = self.log2feats(log_seqs)

        pos_embs = self.item_emb(torch.LongTensor(pos_seqs).to(self.dev))
        neg_embs = self.item_emb(torch.LongTensor(neg_seqs).to(self.dev))

        # Dot product between each position's feature and the candidate item embedding.
        pos_logits = (log_feats * pos_embs).sum(dim=-1)
        neg_logits = (log_feats * neg_embs).sum(dim=-1)

        return pos_logits, neg_logits

    def predict(self, user_ids, log_seqs, item_indices):
        """Inference pass: score `item_indices` against the last position of each sequence.

        `user_ids` is accepted but not used.
        """
        log_feats = self.log2feats(log_seqs)

        # Only the feature at the final time step is used for scoring.
        final_feat = log_feats[:, -1, :]

        item_embs = self.item_emb(torch.LongTensor(item_indices).to(self.dev))

        logits = torch.matmul(item_embs, final_feat.unsqueeze(-1)).squeeze(-1)

        return logits

In [5]:
# main.py
import os
import time
import torch
import argparse

# Import our modules
# from model import SASRec
# from utils import *

def str2bool(s):
    """Convert the exact strings 'true' / 'false' to a bool.

    Raises:
        ValueError: if `s` is neither 'true' nor 'false'.
    """
    accepted = {'false', 'true'}
    if s in accepted:
        return s == 'true'
    raise ValueError('Not a valid boolean string')

# For Colab, we'll set arguments directly instead of using argparse
class Args:
    """Plain attribute container holding the run configuration (stands in for argparse)."""

    def __init__(self):
        defaults = {
            # data / output location
            'dataset': 'ml-1m',
            'train_dir': 'default',
            # optimisation
            'batch_size': 128,
            'lr': 0.001,
            # model size
            'maxlen': 200,
            'hidden_units': 50,
            'num_blocks': 2,
            'num_epochs': 201,  # Reduced for demo
            'num_heads': 1,
            'dropout_rate': 0.2,
            'l2_emb': 0.0,
            # runtime
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
            'inference_only': False,
            'state_dict_path': None,
            'norm_first': False,
        }
        for key, value in defaults.items():
            setattr(self, key, value)

args = Args()

# Create the output directory for the default configuration and record the
# arguments there as sorted `key,value` lines.
output_dir = args.dataset + '_' + args.train_dir
# exist_ok avoids the check-then-create race of a separate isdir() test.
os.makedirs(output_dir, exist_ok=True)

# The `with` block closes the file; no explicit close() is needed afterwards.
with open(os.path.join(output_dir, 'args.txt'), 'w') as f:
    f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))

if __name__ == '__main__':
    # Train and periodically evaluate SASRec on each dataset in turn.
    # Per dataset this writes, into `<dataset>_<train_dir>/`:
    #   args.txt  - the run configuration
    #   log.txt   - validation/test metrics every 20 epochs
    #   *.pth     - the final model weights (picked up by the analysis cell)

    # Define all datasets to run
    datasets_to_run = ['ml-1m', 'amazon-book', 'yelp2018']

    for dataset_name in datasets_to_run:
        print(f"\n\n{'='*60}")
        print(f"RUNNING SASREC ON: {dataset_name.upper()}")
        print(f"{'='*60}\n")

        # Set arguments for the current run
        args = Args()
        args.dataset = dataset_name

        # Adjust maxlen for larger datasets if necessary
        if dataset_name in ['amazon-book', 'yelp2018']:
            args.maxlen = 50 # These datasets have shorter average sequences
        else:
            args.maxlen = 200 # Original value for ml-1m

        # Prepare the dataset if it's not ml-1m
        if dataset_name != 'ml-1m':
            prepare_and_download_dataset(dataset_name)

        # prepare_and_download_dataset only prints on a download error, so check
        # for the file explicitly and skip this dataset instead of crashing the run.
        data_path = os.path.join('data', args.dataset + '.txt')
        if not os.path.exists(data_path):
            print(f"Dataset file '{data_path}' is missing. Skipping {dataset_name}.")
            continue

        # Create output directory for the current dataset
        log_dir = args.dataset + '_' + args.train_dir
        os.makedirs(log_dir, exist_ok=True)
        with open(os.path.join(log_dir, 'args.txt'), 'w') as f:
            f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))

        # Load and partition the dataset
        dataset = data_partition(args.dataset)
        [user_train, user_valid, user_test, usernum, itemnum] = dataset
        num_batch = (len(user_train) - 1) // args.batch_size + 1

        cc = sum(len(user_train[u]) for u in user_train)
        print('average sequence length: %.2f' % (cc / len(user_train)))

        # Setup logging
        log_file = open(os.path.join(log_dir, 'log.txt'), 'w')
        sampler = None
        try:
            log_file.write('epoch (val_ndcg, val_hr) (test_ndcg, test_hr)\n')

            # Initialize sampler and model
            sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3)
            model = SASRec(usernum, itemnum, args).to(args.device)

            # Initialize weights
            for name, param in model.named_parameters():
                try:
                    torch.nn.init.xavier_normal_(param.data)
                except ValueError:
                    # Xavier init needs at least 2 dimensions; biases and
                    # LayerNorm parameters keep their default initialisation.
                    pass

            # The loop above also overwrote the padding rows (index 0) of both
            # embedding tables. They receive no gradient (padding_idx=0), so
            # they would stay at those random values; reset them to zero so
            # padded positions contribute nothing.
            model.pos_emb.weight.data[0, :] = 0
            model.item_emb.weight.data[0, :] = 0

            model.train()

            bce_criterion = torch.nn.BCEWithLogitsLoss()
            adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))

            T = 0.0
            t0 = time.time()

            # Start training loop for the current dataset
            for epoch in range(1, args.num_epochs + 1):
                total_loss = 0.0
                for step in range(num_batch):
                    u, seq, pos, neg = sampler.next_batch()
                    u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)

                    pos_logits, neg_logits = model(u, seq, pos, neg)
                    pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(neg_logits.shape, device=args.device)

                    adam_optimizer.zero_grad()
                    # Only non-padded positions contribute to the loss.
                    indices = np.where(pos != 0)
                    loss = bce_criterion(pos_logits[indices], pos_labels[indices])
                    loss += bce_criterion(neg_logits[indices], neg_labels[indices])
                    for param in model.item_emb.parameters(): loss += args.l2_emb * torch.norm(param)
                    loss.backward()
                    adam_optimizer.step()
                    total_loss += loss.item()

                print(f"Epoch: {epoch}, Average Loss: {total_loss / num_batch:.4f}")

                if epoch % 20 == 0:
                    model.eval()
                    # T accumulates training time only; evaluation time is excluded.
                    t1 = time.time() - t0
                    T += t1
                    print('Evaluating', end='')

                    t_test = evaluate(model, dataset, args)
                    t_valid = evaluate_valid(model, dataset, args)

                    print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)'
                            % (epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))

                    log_file.write(str(epoch) + ' ' + str(t_valid) + ' ' + str(t_test) + '\n')
                    log_file.flush()
                    t0 = time.time()
                    model.train()

            # Persist the trained weights; without this the analysis cell has
            # nothing to load and skips every dataset.
            model_fname = 'SASRec.epoch={}.lr={}.layer={}.head={}.hidden={}.maxlen={}.pth'.format(
                args.num_epochs, args.lr, args.num_blocks, args.num_heads, args.hidden_units, args.maxlen)
            model_path = os.path.join(log_dir, model_fname)
            torch.save(model.state_dict(), model_path)
            print(f"Saved model to {model_path}")
        finally:
            # Always release the log file and stop the sampler workers, even
            # when training is interrupted or fails.
            log_file.close()
            if sampler is not None:
                sampler.close()

        print(f"Done training on {dataset_name}")

    print("\n\nAll dataset runs completed.")



RUNNING SASREC ON: ML-1M

average sequence length: 163.50
Epoch: 1, Average Loss: 1.1686
Epoch: 2, Average Loss: 1.0023
Epoch: 3, Average Loss: 0.9566
Epoch: 4, Average Loss: 0.8812
Epoch: 5, Average Loss: 0.8251
Epoch: 6, Average Loss: 0.7888
Epoch: 7, Average Loss: 0.7465
Epoch: 8, Average Loss: 0.7072
Epoch: 9, Average Loss: 0.6878
Epoch: 10, Average Loss: 0.6652
Epoch: 11, Average Loss: 0.6445
Epoch: 12, Average Loss: 0.6371
Epoch: 13, Average Loss: 0.6212
Epoch: 14, Average Loss: 0.6054
Epoch: 15, Average Loss: 0.5977
Epoch: 16, Average Loss: 0.5873
Epoch: 17, Average Loss: 0.5746
Epoch: 18, Average Loss: 0.5696
Epoch: 19, Average Loss: 0.5605
Epoch: 20, Average Loss: 0.5508
Evaluating........................................................................................................................epoch:20, time: 61.861300(s), valid (NDCG@10: 0.5382, HR@10: 0.7967), test (NDCG@10: 0.5197, HR@10: 0.7687)
Epoch: 21, Average Loss: 0.5489
Epoch: 22, Average Loss: 0.5369
Epoch: 

In [6]:
class SASRecAnalyzer:
    """Visualises the item embeddings of a trained SASRec model.

    The model is moved to the CPU and the rows of `model.item_emb.weight`
    after index 0 (the padding item) are exposed as a NumPy array.
    """

    def __init__(self, model, item_num):
        self.model = model.to(torch.device('cpu')) # Move to CPU for analysis
        self.item_num = item_num
        # Exclude padding item 0
        self.item_embeddings = self.model.item_emb.weight.detach().cpu().numpy()[1:]

    def _calculate_metrics(self, high_dim_embeds, low_dim_embeds, k=10):
        """Return the mean neighbourhood preservation between two embeddings.

        For every point, the k nearest neighbours are computed in both spaces
        (the first neighbour column, expected to be the point itself, is
        dropped) and the fraction of shared neighbours is averaged over all
        points. The original author notes this mirrors the helper used for KGAT.
        """
        nn_high = NearestNeighbors(n_neighbors=k+1).fit(high_dim_embeds)
        high_dim_neighbors = nn_high.kneighbors(high_dim_embeds, return_distance=False)[:, 1:]
        nn_low = NearestNeighbors(n_neighbors=k+1).fit(low_dim_embeds)
        low_dim_neighbors = nn_low.kneighbors(low_dim_embeds, return_distance=False)[:, 1:]
        np_scores = [
            len(set(high_dim_neighbors[i]) & set(low_dim_neighbors[i])) / k
            for i in range(len(high_dim_embeds))
        ]
        return np.mean(np_scores)

    def plot_tsne_and_metrics(self, n_samples=1000, k=10, n_clusters=10):
        """
        Generates and plots a t-SNE visualization of item embeddings and calculates metrics.

        Args:
            n_samples: Maximum number of item embeddings to sample.
            k: Neighbourhood size for the preservation metric.
            n_clusters: Number of K-Means clusters used to colour the plot.

        The perplexity, `k` and `n_clusters` are capped by the actual sample
        size so that small item catalogues do not make scikit-learn raise.
        """
        total_items = self.item_embeddings.shape[0]
        sample_size = min(n_samples, total_items)
        if sample_size < 2:
            print(f"Need at least 2 item embeddings for t-SNE, found {sample_size}. Skipping.")
            return

        print(f"Running t-SNE on a sample of {sample_size} item embeddings...")

        # Sample a subset of items for efficient analysis
        sample_indices = np.random.choice(total_items, sample_size, replace=False)
        sampled_embeds_high_dim = self.item_embeddings[sample_indices]

        # Perform t-SNE. The iteration count is left at the library default
        # (1000) instead of passing `n_iter`, because newer scikit-learn
        # releases renamed that keyword to `max_iter`.
        perplexity = min(40, sample_size - 1)  # must be smaller than the sample size
        tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
        embeds_low_dim = tsne.fit_transform(sampled_embeds_high_dim)

        # Calculate Neighborhood Preservation
        k_eff = min(k, sample_size - 1)  # a point has at most sample_size - 1 neighbours
        np_metric = self._calculate_metrics(sampled_embeds_high_dim, embeds_low_dim, k=k_eff)
        print("\n--- Embedding Quality Metrics ---")
        print(f"Neighborhood Preservation (NP@{k_eff}): {np_metric:.4f}")
        print("---------------------------------\n")

        # Perform K-Means clustering to find item groups
        clusters_eff = min(n_clusters, sample_size)
        kmeans = KMeans(n_clusters=clusters_eff, random_state=42, n_init=10).fit(sampled_embeds_high_dim)
        cluster_labels = kmeans.labels_

        # Plotting
        plt.figure(figsize=(12, 10))
        sns.scatterplot(
            x=embeds_low_dim[:, 0], y=embeds_low_dim[:, 1],
            hue=cluster_labels,
            palette=sns.color_palette("viridis", clusters_eff),
            legend='full',
            alpha=0.7
        )
        plt.title(f't-SNE Visualization of SASRec Item Embeddings ({clusters_eff} Clusters)')
        plt.xlabel('t-SNE Dimension 1')
        plt.ylabel('t-SNE Dimension 2')
        plt.show()

In [7]:
def analyze_sasrec_models():
    """
    Loads trained SASRec models and runs the t-SNE and metrics analysis.

    For each dataset the directory `<dataset>_<train_dir>` is searched for
    `.pth` files; the most recently modified one is loaded into a fresh SASRec
    model on the CPU and handed to `SASRecAnalyzer`. Datasets without a model
    directory or without a saved model are reported and skipped.
    """
    datasets_to_analyze = ['ml-1m', 'amazon-book', 'yelp2018']

    for dataset_name in datasets_to_analyze:
        print(f"\n\n{'='*60}")
        print(f"ANALYZING SASREC MODEL FOR: {dataset_name.upper()}")
        print(f"{'='*60}")

        # Setup args for this dataset (same maxlen rule as in training)
        args = Args()
        args.dataset = dataset_name
        if dataset_name in ['amazon-book', 'yelp2018']:
            args.maxlen = 50

        # Look for a saved model first: parsing the dataset is only worth
        # doing (and the data file only needs to exist) when there is a model.
        model_dir = args.dataset + '_' + args.train_dir
        if not os.path.isdir(model_dir):
            print(f"Model directory not found for {dataset_name}. Skipping.")
            continue

        saved_models = [f for f in os.listdir(model_dir) if f.endswith('.pth')]
        if not saved_models:
            print(f"No saved .pth model found for {dataset_name}. Skipping.")
            continue

        # os.listdir returns entries in arbitrary order, so pick the newest
        # checkpoint by modification time.
        model_path = max(
            (os.path.join(model_dir, f) for f in saved_models),
            key=os.path.getmtime,
        )
        print(f"Loading model from: {model_path}")

        # The model dimensions depend on the number of users/items in the dataset.
        _, _, _, usernum, itemnum = data_partition(args.dataset)

        # Load model and state
        model = SASRec(usernum, itemnum, args)
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.eval()

        # Create analyzer and run the analysis
        analyzer = SASRecAnalyzer(model, itemnum)
        analyzer.plot_tsne_and_metrics()

# Run the analysis; this expects the training cell above to have finished and
# saved its models.
if __name__ == '__main__':
    analyze_sasrec_models()



ANALYZING SASREC MODEL FOR: ML-1M
No saved .pth model found for ml-1m. Skipping.


ANALYZING SASREC MODEL FOR: AMAZON-BOOK
No saved .pth model found for amazon-book. Skipping.


ANALYZING SASREC MODEL FOR: YELP2018
No saved .pth model found for yelp2018. Skipping.
