Skip to content

Commit

Permalink
bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
yuh-yang committed May 14, 2023
1 parent 4ae6879 commit 2338233
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 18 deletions.
2 changes: 1 addition & 1 deletion recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def _fill_nan(self):
elif ftype == FeatureType.FLOAT:
feat[field].fillna(value=feat[field].mean(), inplace=True)
else:
dtype = np.int64 if ftype == FeatureType.TOKEN_SEQ else np.float
dtype = np.int64 if ftype == FeatureType.TOKEN_SEQ else np.float64
feat[field] = feat[field].apply(lambda x: np.array([], dtype=dtype) if isinstance(x, float) else x)

def _normalize(self):
Expand Down
2 changes: 1 addition & 1 deletion recbole/data/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _convert_to_tensor(data):
torch.Tensor: Converted tensor from `data`.
"""
elem = data[0]
if isinstance(elem, (float, int, np.float, np.int64)):
if isinstance(elem, (float, int, np.float64, np.int64)):
new_data = torch.as_tensor(data)
elif isinstance(elem, (list, tuple, pd.Series, np.ndarray, torch.Tensor)):
seq_data = [torch.as_tensor(d) for d in data]
Expand Down
10 changes: 5 additions & 5 deletions recbole/evaluator/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def calculate_metric(self, dataobject):

def metric_info(self, pos_index):
idxs = pos_index.argmax(axis=1)
result = np.zeros_like(pos_index, dtype=np.float)
result = np.zeros_like(pos_index, dtype=np.float64)
for row, idx in enumerate(idxs):
if pos_index[row, idx] > 0:
result[row, idx:] = 1 / (idx + 1)
Expand Down Expand Up @@ -125,10 +125,10 @@ def calculate_metric(self, dataobject):

def metric_info(self, pos_index, pos_len):
pre = pos_index.cumsum(axis=1) / np.arange(1, pos_index.shape[1] + 1)
sum_pre = np.cumsum(pre * pos_index.astype(np.float), axis=1)
sum_pre = np.cumsum(pre * pos_index.astype(np.float64), axis=1)
len_rank = np.full_like(pos_len, pos_index.shape[1])
actual_len = np.where(pos_len > len_rank, len_rank, pos_len)
result = np.zeros_like(pos_index, dtype=np.float)
result = np.zeros_like(pos_index, dtype=np.float64)
for row, lens in enumerate(actual_len):
ranges = np.arange(1, pos_index.shape[1] + 1)
ranges[lens:] = ranges[lens - 1]
Expand Down Expand Up @@ -187,13 +187,13 @@ def metric_info(self, pos_index, pos_len):
len_rank = np.full_like(pos_len, pos_index.shape[1])
idcg_len = np.where(pos_len > len_rank, len_rank, pos_len)

iranks = np.zeros_like(pos_index, dtype=np.float)
iranks = np.zeros_like(pos_index, dtype=np.float64)
iranks[:, :] = np.arange(1, pos_index.shape[1] + 1)
idcg = np.cumsum(1.0 / np.log2(iranks + 1), axis=1)
for row, idx in enumerate(idcg_len):
idcg[row, idx:] = idcg[row, idx - 1]

ranks = np.zeros_like(pos_index, dtype=np.float)
ranks = np.zeros_like(pos_index, dtype=np.float64)
ranks[:, :] = np.arange(1, pos_index.shape[1] + 1)
dcg = 1.0 / np.log2(ranks + 1)
dcg = np.cumsum(np.where(pos_index, dcg, 0), axis=1)
Expand Down
11 changes: 11 additions & 0 deletions recbole/model/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def sim(self, z1: torch.Tensor, z2: torch.Tensor):
z1 = fn.normalize(z1)
z2 = fn.normalize(z2)
return torch.mm(z1, z2.t())

def pair_sim(self, z1, z2):
z1 = fn.normalize(z1)
z2 = fn.normalize(z2)
return torch.sum(z1*z2, dim=1)

def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor):
f = lambda x: torch.exp(x / self.tau)
Expand Down Expand Up @@ -78,6 +83,12 @@ def vanilla_loss(self, z1: torch.Tensor, z2: torch.Tensor):
f = lambda x: torch.exp(x / self.tau)
pos_pairs = f(self.sim(z1, z2)).diag()
neg_pairs = f(self.sim(z1, z2)).sum(1)
return -torch.log(1e-8 + pos_pairs / neg_pairs)

def vanilla_loss_overall(self, z1, z2, z_2_all):
f = lambda x: torch.exp(x / self.tau)
pos_pairs = f(self.pair_sim(z1, z2))
neg_pairs = f(self.sim(z1, z_2_all)).sum(1)
return -torch.log(pos_pairs / neg_pairs)

def vanilla_loss_with_one_negative(self, z1: torch.Tensor, z2: torch.Tensor):
Expand Down
57 changes: 49 additions & 8 deletions recbole/model/sequential_recommender/dcrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@
import dgl
from dgl.nn.pytorch import GraphConv

def cal_kl(target, input):
### log with sigmoid
target = torch.sigmoid(target)
input = torch.sigmoid(input)
target = torch.log(target + 1e-8)
input = torch.log(input + 1e-8)
return F.kl_div(input, target, reduction='batchmean', log_target=True)

def cal_kl_1(target, input):
target[target<1e-8] = 1e-8
target = torch.log(target + 1e-8)
input = torch.log_softmax(input + 1e-8, dim=0)
return F.kl_div(input, target, reduction='batchmean', log_target=True)


def graph_dual_neighbor_readout(g: dgl.DGLGraph, aug_g: dgl.DGLGraph, node_ids, features):
_, all_neighbors = g.out_edges(node_ids)
Expand Down Expand Up @@ -98,9 +112,10 @@ def __init__(self, in_dim, out_dim, dropout_prob=0.7):
super(GCN, self).__init__()
self.dropout_prob = dropout_prob
self.layer = GraphConv(in_dim, out_dim, weight=False,
bias=False, allow_zero_in_degree=True)
bias=False, allow_zero_in_degree=False)

def forward(self, graph, feature):
graph = dgl.add_self_loop(graph)
origin_w, graph = graph_dropout(graph, 1-self.dropout_prob)
embs = [feature]
for i in range(2):
Expand Down Expand Up @@ -351,20 +366,45 @@ def calculate_loss_graphcl(self, interaction):
# filtering those len=1, set weight=0.5
mainstream_weights[item_seq_len == 1] = 0.5

expected_weights_distribution = torch.normal(
self.config["weight_mean"], 0.1, size=mainstream_weights.size()).sort()[0].to(self.device)
kl_loss = self.config["kl_weight"] * F.kl_div(F.log_softmax(
mainstream_weights, dim=0).sort()[0], expected_weights_distribution, reduction="batchmean")
expected_weights_distribution = torch.normal(self.config["weight_mean"], 0.1, size=mainstream_weights.size()).to(self.device)
# kl_loss = self.config["kl_weight"] * cal_kl(expected_weights_distribution.sort()[0], mainstream_weights.sort()[0])

# apply log_softmax for input
kl_loss = self.config["kl_weight"] * cal_kl_1(expected_weights_distribution.sort()[0], mainstream_weights.sort()[0])

if torch.isnan(kl_loss):
logging.info("kl_loss: {}".format(kl_loss))
logging.info("mainstream_weights: {}".format(
mainstream_weights.cpu().tolist()))
logging.info("expected_weights_distribution: {}".format(
expected_weights_distribution.cpu().tolist()))
raise ValueError("kl loss is nan")

personlization_weights = mainstream_weights.max() - mainstream_weights

# contrastive learning
if self.config["cl_ablation"] == "full":
cl_loss_adj = self.contrastive_learning_layer.grace_loss(
# cl_loss_adj = self.contrastive_learning_layer.grace_loss(
# aug_seq_output, iadj_graph_output_seq)
# cl_loss_a2s = self.contrastive_learning_layer.grace_loss(
# iadj_graph_output_seq, isim_graph_output_seq)
# cl_loss_adj = self.contrastive_learning_layer.vanilla_loss_overall(
# aug_seq_output, iadj_graph_output_seq, iadj_graph_output_raw)
# cl_loss_a2s = self.contrastive_learning_layer.vanilla_loss_overall(
# iadj_graph_output_seq, isim_graph_output_seq, isim_graph_output_raw)
cl_loss_adj = self.contrastive_learning_layer.vanilla_loss(
aug_seq_output, iadj_graph_output_seq)
cl_loss_a2s = self.contrastive_learning_layer.grace_loss(
cl_loss_a2s = self.contrastive_learning_layer.vanilla_loss(
iadj_graph_output_seq, isim_graph_output_seq)
cl_loss = (self.config["graphcl_coefficient"] * (mainstream_weights *
cl_loss_adj + personlization_weights * cl_loss_a2s)).mean()
if torch.isnan(cl_loss):
logging.error("cl_loss_adj: {}".format(cl_loss_adj.cpu().tolist()))
logging.error("cl_loss_a2s: {}".format(cl_loss_a2s.cpu().tolist()))
logging.error("mainstream_weights: {}".format(mainstream_weights.cpu().tolist()))
logging.error("personlization_weights: {}".format(personlization_weights.cpu().tolist()))
logging.error("cl loss is nan")
raise ValueError("cl loss is nan")
# Fusion After CL
if self.config["graph_view_fusion"]:
# 3, N_mask, dim
Expand All @@ -378,9 +418,10 @@ def calculate_loss_graphcl(self, interaction):
# [item_num, H]
test_item_emb = self.item_embedding.weight
logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
loss = self.loss_fct(logits, pos_items)
loss = self.loss_fct(logits+1e-8, pos_items)

if torch.isnan(loss):
logging.error("cl_loss: {}".format(cl_loss))
logging.error("loss is nan")
return loss, cl_loss, kl_loss

Expand Down
8 changes: 6 additions & 2 deletions recbole/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ def __init__(self, config, model):
self.best_valid_result = None
self.train_loss_dict = dict()
self.optimizer = self._build_optimizer(self.model.parameters())
if 'schedule_step' in config:
self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=config['schedule_step'], gamma=0.2)

self.eval_type = config['eval_type']
self.eval_collector = Collector(config)
self.evaluator = Evaluator(config)
Expand Down Expand Up @@ -162,8 +165,6 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=Fals
tuple which includes the sum of loss in each part.
"""
self.model.train()
# 每个epoch都清零global encoding
# nn.init.zeros_(self.model.global_encodings)
loss_func = loss_func or self.model.calculate_loss
total_loss = None
iter_data = (
Expand Down Expand Up @@ -387,6 +388,9 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre

valid_step+=1

if hasattr(self, 'scheduler'):
self.scheduler.step()

# self._add_hparam_to_tensorboard(self.best_valid_score)
return self.best_valid_score, self.best_valid_result

Expand Down
3 changes: 2 additions & 1 deletion run_DCRec.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,8 @@ def get_args():
config_dict['train_batch_size'] = 512

if args.dataset == "beauty":
config_dict["attn_dropout_prob"] = 0.1
config_dict["schedule_step"] = 30
config_dict["attn_dropout_prob"] = 0.1
config_dict['train_batch_size'] = 2048
config_dict["sim_group"] = 4
config_dict["weight_mean"] = 0.5
Expand Down

0 comments on commit 2338233

Please sign in to comment.