In [16]:
import os
import time
import torch
import argparse
import numpy as np

from model import SASRec
from utils import *
from accelerate import Accelerator

In [17]:
args = argparse.Namespace()
args.dataset = 'ml-1m-incremental'
args.train_dir='default'
args.batch_size = 128
args.one_batch = 1
args.lr = 0.0003
args.maxlen = 200
args.hidden_units = 50
args.num_blocks = 2
args.num_epochs = 201
args.num_heads = 1
args.dropout_rate = 0.5
args.l2_emb = 0.0
args.device = 'cuda'
args.inference_only = False
args.state_dict_path = 'ml-1m_default/SASRec.epoch=201.lr=0.001.layer=2.head=1.hidden=50.maxlen=200.pth'
args.incremental_epochs = 5

In [18]:
def data_preparation(fgt_users, dataset):
    [user_train, user_valid, user_test, usernum, itemnum] = copy.deepcopy(dataset)
    fgt_dataset = [{}, {}, {}, 0 ,0]
    remain_dataset = [{}, {}, {}, 0 ,0]

    fgt_items = set()
    remain_items = set()
    for u in user_train.keys():
        if (u in fgt_users) and (u not in fgt_dataset[0].keys()):
            fgt_dataset[0][u] = user_train[u]
            fgt_dataset[1][u] = user_valid[u]
            fgt_dataset[2][u] = user_test[u]
            fgt_items.update(fgt_dataset[0][u])
        else:
            remain_dataset[0][u] = user_train[u]
            remain_dataset[1][u] = user_valid[u]
            remain_dataset[2][u] = user_test[u]
            remain_items.update(remain_dataset[0][u])

    fgt_dataset[3] = len(fgt_dataset[0].keys())
    fgt_dataset[4] = len(fgt_items)

    remain_dataset[3] = len(remain_dataset[0].keys())
    remain_dataset[4] = len(remain_items)

    cc = 0.0
    for u in dataset[0]:
        cc += len(dataset[0][u])
    print('average sequence length old_dataset: %.2f' % (cc / len(dataset[0])))

    cc = 0.0
    for u in fgt_dataset[0]:
        cc += len(fgt_dataset[0][u])
    print('average sequence length fgt_dataset: %.2f' % (cc / len(fgt_dataset[0])))

    cc = 0.0
    for u in remain_dataset[0]:
        cc += len(remain_dataset[0][u])
    print('average sequence length remain_dataset: %.2f' % (cc / len(remain_dataset[0])))

    return fgt_dataset, remain_dataset


In [19]:
def model_init(args, _dataset):
    dataset = copy.deepcopy(_dataset)
    model = SASRec(dataset[3], dataset[4], args).to(args.device) # no ReLU activation in original SASRec implementation?
    
    for name, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except:
            pass # just ignore those failed init layers
    
    # this fails embedding init 'Embedding' object has no attribute 'dim'
    # model.apply(torch.nn.init.xavier_uniform_)
    
    model.train() # enable model training
    
    epoch_start_idx = 1
    if args.state_dict_path is not None:
        try:
            model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
            epoch_start_idx = int(tail[:tail.find('.')]) + 1
        except: # in case your pytorch version is not 1.6 etc., pls debug by pdb if load weights failed
            print('failed loading state_dicts, pls check file path: ', end="")
            print(args.state_dict_path)
            print('pdb enabled for your quick check, pls type exit() if you do not need it')
            import pdb; pdb.set_trace()
    
    return model

In [20]:
def evaluateeee(model, data, name):
    model.eval()
    t_test = evaluate(model, data, args)
    print(name + ' test (NDCG@10: %.4f, HR@10: %.4f)' % (t_test[0], t_test[1]))
    return

In [21]:
def kl_loss(pretrained_model, current_model, sampler):
    """
    sampler = Norlmal Sampler
    """
    u, seq, pos, neg = sampler.next_batch() # tuples to ndarray
    u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)
    cur_pos_logits, cur_neg_logits = current_model(u, seq, pos, neg)
    pre_pos_logits, pre_neg_logits = pretrained_model(u, seq, pos, neg)


    # P: pretrained model; Q: current model.
    prob_p = torch.nn.functional.softmax(pre_pos_logits, dim=-1)
    prob_q = torch.nn.functional.softmax(cur_pos_logits, dim=-1)

    log_prob_q = torch.log(prob_q + 1e-12)
    kl_loss = torch.nn.functional.kl_div(log_prob_q, prob_p, reduction='batchmean')
    
    return kl_loss

In [22]:
def normal_loss(current_model, sampler, optimizer, lossfunc):
    """
    sampler = Unlearning Sampler
    """
    u, seq, pos, neg = sampler.next_batch() # tuples to ndarray
    u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)
    pos_logits, neg_logits = current_model(u, seq, pos, neg)
    pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(neg_logits.shape, device=args.device)
    # print("\neye ball check raw_logits:"); print(pos_logits); print(neg_logits) # check pos_logits > 0, neg_logits < 0
    optimizer.zero_grad()
    indices = np.where(pos != 0)
    loss = lossfunc(pos_logits[indices], pos_labels[indices])
    loss += lossfunc(neg_logits[indices], neg_labels[indices])

    return loss

In [23]:
def random_neq(l, r, s):  
    t = np.random.randint(l, r)
    while t in s:
        t = np.random.randint(l, r)
    return t

class tSimpleWarpSampler(object):
    def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10):
        self.User = User
        self.usernum = usernum
        self.itemnum = itemnum
        self.batch_size = batch_size
        self.maxlen = maxlen
        self.UserList = list(User.keys())

    def tsample(self):
        user = np.random.choice(self.UserList)
        while len(self.User[user]) <= 1: user = np.random.choice(self.UserList)
        
        seq = np.zeros([self.maxlen], dtype=np.int32)
        pos = np.zeros([self.maxlen], dtype=np.int32)
        neg = np.zeros([self.maxlen], dtype=np.int32)
        nxt = self.User[user][-1]
        idx = self.maxlen - 1

        ts = set(self.User[user])
        for i in reversed(self.User[user][:-1]):
            seq[idx] = i
            pos[idx] = nxt
            if nxt != 0: neg[idx] = random_neq(1, self.itemnum + 1, ts)
            nxt = i
            idx -= 1
            if idx == -1: break

        return (user, seq, pos, neg)

    def next_batch(self):
        one_batch = []
        for i in range(self.batch_size):
            one_batch.append(self.tsample())
        
        return zip(*one_batch)

def get_sampler(dataset, n_batch):
    [train, valid, test, usernum, itemnum] = dataset
    num_batch = len(train) // n_batch # tail? + ((len(user_train) % args.batch_size) != 0)
    sampler = tSimpleWarpSampler(train, usernum, itemnum, batch_size=args.batch_size, maxlen=args.maxlen)
    return num_batch, sampler

In [24]:
def unlearning(_max_steps, _unlerning_sampler, _remain_sampler, _pre_model, _cur_model):

    epoch_start_idx = 1
    bce_criterion = torch.nn.BCEWithLogitsLoss() # torch.nn.BCELoss()
    adam_optimizer = torch.optim.Adam(_cur_model.parameters(), lr=args.lr, betas=(0.9, 0.98))

    T = 0.0
    t0 = time.time()

    max_steps = _max_steps
    _normal_loss = 0.0

    bad_weight = 0.5

    # for epoch in range(epoch_start_idx, args.incremental_epochs + 1):
    move_step = 0
    while move_step < max_steps: # tqdm(range(num_batch), total=num_batch, ncols=70, leave=False, unit='b'):
        _kl_loss = kl_loss(_pre_model, _cur_model, _remain_sampler)
        _normal_loss = normal_loss(_cur_model, _unlerning_sampler, adam_optimizer, bce_criterion)
        loss = bad_weight * _kl_loss - bad_weight * _normal_loss
        # for param in model.item_emb.parameters(): loss += args.l2_emb * torch.norm(param)
        loss.backward()
        
        adam_optimizer.step()
        move_step += 1
        # print("loss in epoch {} iteration {}: {}".format(epoch_start_idx, move_step, loss.item())) # expected 0.4~0.6 after init few epochs
            
    print("loss in step {}: {}".format(move_step, -loss.item())) # expected 0.4~0.6 after init few epochs
    print("Done")
    return _pre_model, _cur_model

## Preprocessing

In [25]:
old_dataset = data_partition('ml-1m') # {train}, {valid}, {test}, usernum, itemnum

In [26]:
pre_model = model_init(args, old_dataset)
cur_model = model_init(args, old_dataset)
cur_model.train()

SASRec(
  (item_emb): Embedding(3417, 50, padding_idx=0)
  (pos_emb): Embedding(200, 50)
  (emb_dropout): Dropout(p=0.5, inplace=False)
  (attention_layernorms): ModuleList(
    (0-1): 2 x LayerNorm((50,), eps=1e-08, elementwise_affine=True)
  )
  (attention_layers): ModuleList(
    (0-1): 2 x MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=50, out_features=50, bias=True)
    )
  )
  (forward_layernorms): ModuleList(
    (0-1): 2 x LayerNorm((50,), eps=1e-08, elementwise_affine=True)
  )
  (forward_layers): ModuleList(
    (0-1): 2 x PointWiseFeedForward(
      (conv1): Conv1d(50, 50, kernel_size=(1,), stride=(1,))
      (dropout1): Dropout(p=0.5, inplace=False)
      (relu): ReLU()
      (conv2): Conv1d(50, 50, kernel_size=(1,), stride=(1,))
      (dropout2): Dropout(p=0.5, inplace=False)
    )
  )
  (last_layernorm): LayerNorm((50,), eps=1e-08, elementwise_affine=True)
)

## Unlearning

In [27]:
fgt_users = [5,100,500]
fgt_dataset, remain_dataset = data_preparation(fgt_users, old_dataset)
fgt_num_batch, fgt_sampler = get_sampler(fgt_dataset, args.one_batch)
remain_num_batch, remain_sampler = get_sampler(remain_dataset, args.batch_size)

average sequence length old_dataset: 163.50
average sequence length fgt_dataset: 162.00
average sequence length remain_dataset: 163.50


In [28]:
evaluateeee(cur_model, old_dataset, '\n cur_old ')
evaluateeee(cur_model, fgt_dataset, '\n cur_fgt ')
evaluateeee(cur_model, remain_dataset, '\n cur_remain ')

............................................................
 cur_old  test (NDCG@10: 0.5888, HR@10: 0.8219)

 cur_fgt  test (NDCG@10: 0.8333, HR@10: 1.0000)
............................................................
 cur_remain  test (NDCG@10: 0.5859, HR@10: 0.8196)


In [29]:
pre_model, cur_model = unlearning(60, fgt_sampler, remain_sampler, pre_model, cur_model)

loss in step 60: 2.124324321746826
Done


In [30]:
evaluateeee(cur_model, old_dataset, '\n cur_old ')
evaluateeee(cur_model, fgt_dataset, '\n cur_fgt ')
evaluateeee(cur_model, remain_dataset, '\n cur_remain ')
evaluateeee(pre_model, old_dataset, '\n pre_old ')
evaluateeee(pre_model, fgt_dataset, '\n pre_fgt ')
evaluateeee(pre_model, remain_dataset, '\n pre_remain ')

............................................................
 cur_old  test (NDCG@10: 0.4688, HR@10: 0.7124)

 cur_fgt  test (NDCG@10: 0.0000, HR@10: 0.0000)
............................................................
 cur_remain  test (NDCG@10: 0.4716, HR@10: 0.7207)
............................................................
 pre_old  test (NDCG@10: 0.5932, HR@10: 0.8263)

 pre_fgt  test (NDCG@10: 1.0000, HR@10: 1.0000)
............................................................
 pre_remain  test (NDCG@10: 0.5909, HR@10: 0.8219)
