In [1]:
# %% [code]
import os
import time
import sys
import numpy as np
import tensorflow as tf
import scipy.sparse
from pathlib import Path
import argparse

# Set random seeds for reproducibility
random_seed = 2019
np.random.seed(random_seed)
tf.compat.v1.set_random_seed(random_seed)

# Define training arguments and hyperparameters
class Args:
    dataset = 'dummy'
    verbose = 10
    batch_size = 512
    epochs = 101  # adjust epochs as needed; 101 used here for demonstration
    embed_size = 64
    lr = 0.05
    dropout = 0.9
    negative_weight = 0.05
    topK = [5, 10, 20]

args = Args()

###############################################################################
# Data Loader
###############################################################################
class LoadData(object):
    def __init__(self, DATA_ROOT):
        self.trainfile = os.path.join(DATA_ROOT, 'train.csv')
        self.testfile = os.path.join(DATA_ROOT, 'test.csv')
        self.user_field_M, self.item_field_M = self.get_length()
        print("user_field_M", self.user_field_M)
        print("item_field_M", self.item_field_M)
        print("field_M", self.user_field_M + self.item_field_M)
        self.item_bind_M = self.bind_item()  # assign a unique id for each item feature string
        self.user_bind_M = self.bind_user()  # assign a unique id for each user feature string
        print("item_bind_M", len(self.binded_items))
        print("user_bind_M", len(self.binded_users))
        self.item_map_list = []
        # Build an item feature list from the map (using the string split on '-' to get the individual features)
        for itemid in self.item_map:
            self.item_map_list.append([int(feature) for feature in self.item_map[itemid].strip().split('-')])
        # Append a dummy feature vector for padding
        self.item_map_list.append([int(feature) for feature in self.item_map[0].strip().split('-')])
        self.user_positive_list = self.get_positive_list(self.trainfile)  # map: user_id -> list of positive item ids
        self.Train_data, self.Test_data = self.construct_data()
        self.user_train, self.item_train = self.get_train_instances()
        self.user_test = self.get_test()

    def get_length(self):
        length_user = 0
        length_item = 0
        with open(self.trainfile, 'r') as f:
            line = f.readline()
            while line:
                user_features = line.strip().split(',')[0].split('-')
                item_features = line.strip().split(',')[1].split('-')
                for uf in user_features:
                    feature = int(uf)
                    if feature > length_user:
                        length_user = feature
                for itf in item_features:
                    feature = int(itf)
                    if feature > length_item:
                        length_item = feature
                line = f.readline()
        return length_user + 1, length_item + 1

    def bind_item(self):
        self.binded_items = {}  # map: item feature string -> id
        self.item_map = {}      # map: id -> item feature string
        self.bind_i(self.trainfile)
        self.bind_i(self.testfile)
        return len(self.binded_items)

    def bind_i(self, file):
        with open(file, 'r') as f:
            line = f.readline()
            i = len(self.binded_items)
            while line:
                features = line.strip().split(',')
                item_features = features[1]
                if item_features not in self.binded_items:
                    self.binded_items[item_features] = i
                    self.item_map[i] = item_features
                    i += 1
                line = f.readline()

    def bind_user(self):
        self.binded_users = {}  # map: user feature string -> id
        self.user_map = {}      # map: id -> user feature string
        self.bind_u(self.trainfile)
        self.bind_u(self.testfile)
        return len(self.binded_users)

    def bind_u(self, file):
        with open(file, 'r') as f:
            line = f.readline()
            i = len(self.binded_users)
            while line:
                features = line.strip().split(',')
                user_features = features[0]
                if user_features not in self.binded_users:
                    self.binded_users[user_features] = i
                    self.user_map[i] = user_features
                    i += 1
                line = f.readline()

    def get_positive_list(self, file):
        self.max_positive_len = 0
        user_positive_list = {}
        with open(file, 'r') as f:
            line = f.readline()
            while line:
                features = line.strip().split(',')
                user_id = self.binded_users[features[0]]
                item_id = self.binded_items[features[1]]
                if user_id in user_positive_list:
                    user_positive_list[user_id].append(item_id)
                else:
                    user_positive_list[user_id] = [item_id]
                line = f.readline()
        for u in user_positive_list:
            if len(user_positive_list[u]) > self.max_positive_len:
                self.max_positive_len = len(user_positive_list[u])
        return user_positive_list

    def get_train_instances(self):
        user_train, item_train = [], []
        for u in self.user_positive_list:
            # Transform user features into a list of ints
            u_train = [int(feature) for feature in self.user_map[u].strip().split('-')]
            user_train.append(u_train)
            temp = self.user_positive_list[u][:]
            # Pad the list with a dummy item id (which is self.item_bind_M) to reach max_positive_len
            while len(temp) < self.max_positive_len:
                temp.append(self.item_bind_M)
            item_train.append(temp)
        user_train = np.array(user_train)
        item_train = np.array(item_train)
        return user_train, item_train

    def construct_data(self):
        X_user, X_item = self.read_data(self.trainfile)
        Train_data = self.construct_dataset(X_user, X_item)
        print("# of training:", len(X_user))
        X_user, X_item = self.read_data(self.testfile)
        Test_data = self.construct_dataset(X_user, X_item)
        print("# of test:", len(X_user))
        return Train_data, Test_data

    def construct_dataset(self, X_user, X_item):
        user_id = []
        for one in X_user:
            key = "-".join([str(item) for item in one])
            user_id.append(self.binded_users[key])
        item_id = []
        for one in X_item:
            key = "-".join([str(item) for item in one])
            item_id.append(self.binded_items[key])
        count = np.ones(len(X_user))
        sparse_matrix = scipy.sparse.csr_matrix((count, (user_id, item_id)), dtype=np.int16,
                                                shape=(self.user_bind_M, self.item_bind_M))
        return sparse_matrix

    def get_test(self):
        X_user, _ = self.read_data(self.testfile)
        return X_user

    def read_data(self, file):
        X_user = []
        X_item = []
        with open(file, 'r') as f:
            line = f.readline()
            while line:
                features = line.strip().split(',')
                user_features = features[0].split('-')
                X_user.append([int(item) for item in user_features])
                item_features = features[1].split('-')
                X_item.append([int(item) for item in item_features])
                line = f.readline()
        return X_user, X_item

###############################################################################
# ENSFM Model
###############################################################################
class ENSFM:
    def __init__(self, item_attribute, user_field_M, item_field_M, embedding_size, max_item_pu, args):
        self.embedding_size = embedding_size
        self.max_item_pu = max_item_pu
        self.user_field_M = user_field_M
        self.item_field_M = item_field_M
        self.weight1 = args.negative_weight
        self.item_attribute = item_attribute  # a list of item feature vectors
        self.lambda_bilinear = [0.0, 0.0]

    def _create_placeholders(self):
        self.input_u = tf.compat.v1.placeholder(tf.int32, [None, None], name="input_u_feature")
        self.input_ur = tf.compat.v1.placeholder(tf.int32, [None, self.max_item_pu], name="input_ur")
        self.dropout_keep_prob = tf.compat.v1.placeholder(tf.float32, name="dropout_keep_prob")

    def _create_variables(self):
        self.uidW = tf.Variable(tf.random.truncated_normal([self.user_field_M, self.embedding_size],
                                                             mean=0.0, stddev=0.01), name="uidW")
        self.iidW = tf.Variable(tf.random.truncated_normal([self.item_field_M+1, self.embedding_size],
                                                             mean=0.0, stddev=0.01), name="iidW")
        self.H_i = tf.Variable(tf.constant(0.01, shape=[self.embedding_size, 1]), name="H_i")
        self.H_s = tf.Variable(tf.constant(0.01, shape=[self.embedding_size, 1]), name="H_s")
        self.u_bias = tf.Variable(tf.random.truncated_normal([self.user_field_M, 1],
                                                               mean=0.0, stddev=0.01), name="u_bias")
        self.i_bias = tf.Variable(tf.random.truncated_normal([self.item_field_M, 1],
                                                               mean=0.0, stddev=0.01), name="i_bias")
        self.bias = tf.Variable(tf.constant(0.0), name='bias')

    def _create_vectors(self):
        self.user_feature_emb = tf.nn.embedding_lookup(self.uidW, self.input_u)
        self.summed_user_emb = tf.reduce_sum(self.user_feature_emb, axis=1)
        # Apply dropout to H_i and H_s
        self.H_i = tf.nn.dropout(self.H_i, rate=1 - self.dropout_keep_prob)
        self.H_s = tf.nn.dropout(self.H_s, rate=1 - self.dropout_keep_prob)
        self.all_item_feature_emb = tf.nn.embedding_lookup(self.iidW, self.item_attribute)
        self.summed_all_item_emb = tf.reduce_sum(self.all_item_feature_emb, axis=1)
        # Compute cross terms for user and item
        self.user_cross = 0.5 * (tf.square(self.summed_user_emb) - tf.reduce_sum(tf.square(self.user_feature_emb), axis=1))
        self.item_cross = 0.5 * (tf.square(self.summed_all_item_emb) - tf.reduce_sum(tf.square(self.all_item_feature_emb), axis=1))
        self.user_cross_score = tf.matmul(self.user_cross, self.H_s)
        self.item_cross_score = tf.matmul(self.item_cross, self.H_s)
        self.user_bias = tf.reduce_sum(tf.nn.embedding_lookup(self.u_bias, self.input_u), axis=1)
        self.item_bias = tf.reduce_sum(tf.nn.embedding_lookup(self.i_bias, self.item_attribute), axis=1)
        self.I = tf.ones(shape=(tf.shape(self.input_u)[0], 1))
        self.p_emb = tf.concat([self.summed_user_emb, self.user_cross_score + self.user_bias + self.bias, self.I], axis=1)
        self.I = tf.ones(shape=(tf.shape(self.summed_all_item_emb)[0], 1))
        self.q_emb = tf.concat([self.summed_all_item_emb, self.I, self.item_cross_score + self.item_bias], axis=1)
        self.H_i_emb = tf.concat([self.H_i, [[1.0]], [[1.0]]], axis=0)

    def _create_inference(self):
        self.pos_item = tf.nn.embedding_lookup(self.q_emb, self.input_ur)
        # Filter out padding values (assumed to equal self.item_bind_M)
        self.pos_num_r = tf.cast(tf.not_equal(self.input_ur, data.item_bind_M), tf.float32)
        self.pos_item = tf.einsum('ab,abc->abc', self.pos_num_r, self.pos_item)
        self.pos_r = tf.einsum('ac,abc->abc', self.p_emb, self.pos_item)
        self.pos_r = tf.einsum('ajk,kl->ajl', self.pos_r, self.H_i_emb)
        self.pos_r = tf.reshape(self.pos_r, [-1, self.max_item_pu])

    def _pre(self):
        dot = tf.einsum('ac,bc->abc', self.p_emb, self.q_emb)
        pre = tf.einsum('ajk,kl->aj', dot, self.H_i_emb)
        return pre

    def _create_loss(self):
        self.loss1 = self.weight1 * tf.reduce_sum(
            tf.reduce_sum(tf.reduce_sum(tf.einsum('ab,ac->abc', self.q_emb, self.q_emb), axis=0)
                          * tf.reduce_sum(tf.einsum('ab,ac->abc', self.p_emb, self.p_emb), axis=0)
                          * tf.matmul(self.H_i_emb, self.H_i_emb, transpose_b=True), axis=0), axis=0)
        self.loss1 += tf.reduce_sum((1.0 - self.weight1) * tf.square(self.pos_r) - 2.0 * self.pos_r)
        self.l2_loss0 = tf.nn.l2_loss(self.uidW)
        self.l2_loss1 = tf.nn.l2_loss(self.iidW)
        self.loss = self.loss1 + self.lambda_bilinear[0] * self.l2_loss0 + self.lambda_bilinear[1] * self.l2_loss1
        self.reg_loss = self.lambda_bilinear[0] * self.l2_loss0 + self.lambda_bilinear[1] * self.l2_loss1

    def _build_graph(self):
        self._create_placeholders()
        self._create_variables()
        self._create_vectors()
        self._create_inference()
        self._create_loss()
        self.pre = self._pre()

###############################################################################
# List of dataset directories
###############################################################################
# Change the base directory to where your datasets are stored.
base_data_dir = "/media/leo/Huy/Project/CARS/AMZ"
dataset_dirs = [
    os.path.join(base_data_dir, "Arts & Photography"),
    os.path.join(base_data_dir, "Genre Fiction"),
    os.path.join(base_data_dir, "History")
]

###############################################################################
# Training Loop
###############################################################################
for DATA_ROOT in dataset_dirs:
    print("==============================================")
    print("Training on dataset:", DATA_ROOT)
    
    # Load data for the current dataset
    data = LoadData(DATA_ROOT)
    
    # Start a new TensorFlow graph for each dataset
    with tf.Graph().as_default():
        session_conf = tf.compat.v1.ConfigProto()
        session_conf.gpu_options.allow_growth = True
        with tf.compat.v1.Session(config=session_conf) as sess:
            # Build the model
            deep = ENSFM(data.item_map_list, data.user_field_M, data.item_field_M,
                         args.embed_size, data.max_positive_len, args)
            deep._build_graph()
            train_op1 = tf.compat.v1.train.AdagradOptimizer(
                learning_rate=args.lr, initial_accumulator_value=1e-8
            ).minimize(deep.loss)
            sess.run(tf.compat.v1.global_variables_initializer())

            # Define inner helper functions that use the current session and model
            def train_step1(u_batch, y_batch):
                feed_dict = {
                    deep.input_u: u_batch,
                    deep.input_ur: y_batch,
                    deep.dropout_keep_prob: args.dropout,
                }
                _, loss, loss1, loss2, _ = sess.run(
                    [train_op1, deep.loss, deep.loss1, deep.reg_loss, deep.p_emb],
                    feed_dict)
                return loss, loss1, loss2

            def evaluate():
                eva_batch = 128
                recall_all, ndcg_all = [[] for _ in range(3)], [[] for _ in range(3)]
                user_features = data.user_test
                ll = int(len(user_features) / eva_batch) + 1
                for batch_num in range(ll):
                    start_index = batch_num * eva_batch
                    end_index = min((batch_num + 1) * eva_batch, len(user_features))
                    u_batch_eval = user_features[start_index:end_index]
                    batch_users = end_index - start_index
                    feed_dict_eval = { deep.input_u: u_batch_eval, deep.dropout_keep_prob: 1.0 }
                    pre = sess.run(deep.pre, feed_dict_eval)
                    pre = np.array(pre)
                    # Remove the dummy padding dimension
                    pre = np.delete(pre, -1, axis=1)
                    user_ids = [data.binded_users["-".join(map(str, one))] for one in u_batch_eval]
                    
                    idx = np.zeros_like(pre, dtype=bool)
                    idx[data.Train_data[user_ids].nonzero()] = True
                    pre[idx] = -np.inf

                    for i, kj in enumerate(args.topK):
                        idx_topk_part = np.argpartition(-pre, kj, axis=1)
                        pre_bin = np.zeros_like(pre, dtype=bool)
                        pre_bin[np.arange(batch_users)[:, None], idx_topk_part[:, :kj]] = True
                        true_bin = np.zeros_like(pre, dtype=bool)
                        true_bin[data.Test_data[user_ids].nonzero()] = True
                        tmp = (np.logical_and(true_bin, pre_bin).sum(axis=1)).astype(np.float32)
                        recall_all[i].append(tmp / np.minimum(kj, true_bin.sum(axis=1)))
                        topk_part = pre[np.arange(batch_users)[:, None], idx_topk_part[:, :kj]]
                        idx_part = np.argsort(-topk_part, axis=1)
                        idx_topk = idx_topk_part[np.arange(batch_users)[:, None], idx_part]
                        tp = np.log(2) / np.log(np.arange(2, kj + 2))
                        test_batch = data.Test_data[user_ids]
                        DCG = (test_batch[np.arange(batch_users)[:, None], idx_topk].toarray() * tp).sum(axis=1)
                        IDCG = np.array([(tp[:min(n, kj)]).sum() for n in test_batch.getnnz(axis=1)])
                        ndcg_all[i].append(DCG / IDCG)
                for i in range(3):
                    recall_all[i] = np.hstack(recall_all[i])
                    ndcg_all[i] = np.hstack(ndcg_all[i])
                    print(f"Top {args.topK[i]} Recall: {np.mean(recall_all[i]):.4f}, NDCG: {np.mean(ndcg_all[i]):.4f}")

            # Initial evaluation
            print("Initial evaluation:")
            evaluate()

            # Begin training epochs
            for epoch in range(args.epochs):
                print(f"\nEpoch {epoch} for dataset {DATA_ROOT}")
                # Shuffle training data at each epoch
                shuffle_indices = np.random.permutation(len(data.user_train))
                data.user_train = data.user_train[shuffle_indices]
                data.item_train = data.item_train[shuffle_indices]
                ll = int(len(data.user_train) / args.batch_size)
                loss_sum = np.zeros(3)
                for batch_num in range(ll):
                    start_index = batch_num * args.batch_size
                    end_index = min((batch_num + 1) * args.batch_size, len(data.user_train))
                    u_batch = data.user_train[start_index:end_index]
                    i_batch = data.item_train[start_index:end_index]
                    loss, loss1, loss2 = train_step1(u_batch, i_batch)
                    loss_sum += np.array([loss, loss1, loss2])
                print(f"Epoch {epoch} losses: loss={loss_sum[0]/ll:.4f}, loss1={loss_sum[1]/ll:.4f}, reg={loss_sum[2]/ll:.4f}")

                if epoch % args.verbose == 0:
                    print("Evaluation:")
                    evaluate()


2025-04-11 22:49:32.935485: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744400972.947025   16632 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744400972.950490   16632 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744400972.960507   16632 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744400972.960518   16632 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744400972.960519   16632 computation_placer.cc:177] computation placer alr

Training on dataset: /media/leo/Huy/Project/CARS/AMZ/Arts & Photography
user_field_M 689607
item_field_M 201822
field_M 891429
item_bind_M 143815
user_bind_M 147825
# of training: 346226
# of test: 147825


I0000 00:00:1744400988.487532   16632 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 5678 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4060, pci bus id: 0000:08:00.0, compute capability: 8.9


Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


I0000 00:00:1744400989.005670   16632 mlir_graph_optimization_pass.cc:425] MLIR V1 optimization pass is not enabled


Initial evaluation:


2025-04-11 22:49:59.995804: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.53GiB (rounded to 4859830272)requested by op einsum_5/Einsum
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
Current allocation summary follows.
2025-04-11 22:49:59.995818: I external/local_xla/xla/tsl/framework/bfc_allocator.cc:1058] BFCAllocator dump for GPU_0_bfc
2025-04-11 22:49:59.995823: I external/local_xla/xla/tsl/framework/bfc_allocator.cc:1065] Bin (256): 	Total Chunks: 10, Chunks in use: 10. 2.5KiB allocated for chunks. 2.5KiB in use in bin. 800B client-requested in use in bin.
2025-04-11 22:49:59.995825: I external/local_xla/xla/tsl/framework/bfc_allocator.cc:1065] Bin (512): 	Total Chunks: 2, Chunks in use: 1. 1.0KiB allocated for chunks. 512B in use in bin. 264B client-requested in use in bin.
2025-04-11 

ResourceExhaustedError: Graph execution error:

Detected at node 'einsum_5/Einsum' defined at (most recent call last):
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/runpy.py", line 87, in _run_code
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/asyncio/events.py", line 80, in _run
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    File "/tmp/ipykernel_16632/2218588336.py", line 304, in <module>
    File "/tmp/ipykernel_16632/2218588336.py", line 273, in _build_graph
    File "/tmp/ipykernel_16632/2218588336.py", line 252, in _pre
Node: 'einsum_5/Einsum'
Detected at node 'einsum_5/Einsum' defined at (most recent call last):
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/runpy.py", line 87, in _run_code
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/asyncio/events.py", line 80, in _run
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in process_one
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
    File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    File "/tmp/ipykernel_16632/2218588336.py", line 304, in <module>
    File "/tmp/ipykernel_16632/2218588336.py", line 273, in _build_graph
    File "/tmp/ipykernel_16632/2218588336.py", line 252, in _pre
Node: 'einsum_5/Einsum'
2 root error(s) found.
  (0) RESOURCE_EXHAUSTED: OOM when allocating tensor with shape[66,128,143816] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node einsum_5/Einsum}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

	 [[einsum_6/Einsum/_9]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

  (1) RESOURCE_EXHAUSTED: OOM when allocating tensor with shape[66,128,143816] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node einsum_5/Einsum}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.

0 successful operations.
0 derived errors ignored.

Original stack trace for 'einsum_5/Einsum':
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/runpy.py", line 197, in _run_module_as_main
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/runpy.py", line 87, in _run_code
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tornado/platform/asyncio.py", line 205, in start
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/asyncio/base_events.py", line 601, in run_forever
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/asyncio/base_events.py", line 1905, in _run_once
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/asyncio/events.py", line 80, in _run
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 534, in process_one
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3048, in run_cell
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3103, in _run_cell
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3308, in run_cell_async
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3490, in run_ast_nodes
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
  File "/tmp/ipykernel_16632/2218588336.py", line 304, in <module>
  File "/tmp/ipykernel_16632/2218588336.py", line 273, in _build_graph
  File "/tmp/ipykernel_16632/2218588336.py", line 252, in _pre
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 150, in error_handler
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tensorflow/python/util/dispatch.py", line 1260, in op_dispatch_handler
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tensorflow/python/ops/special_math_ops.py", line 763, in einsum
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tensorflow/python/ops/special_math_ops.py", line 1200, in _einsum_v2
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tensorflow/python/ops/gen_linalg_ops.py", line 1134, in einsum
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tensorflow/python/framework/op_def_library.py", line 796, in _apply_op_helper
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 2705, in _create_op_internal
  File "/home/leo/miniconda3/envs/ENSFM/lib/python3.9/site-packages/tensorflow/python/framework/ops.py", line 1200, in from_node_def
