Comparing changes

base repository: pmixer/SASRec.pytorch (base branch: master)
head repository: chorus12/SASRec.pytorch (head branch: master)
The branches cannot be merged automatically, but the pull request can still be created.
  • 10 commits
  • 18 files changed
  • 2 contributors

Commits on Jan 17, 2022

  1. 1eabd11
  2. Update README.md (chorus12, 7a04761)

Commits on Jan 26, 2022

  1. d494dfb
  2. remove old stuff (chorus12, 15d9930)
  3. Update README.md (chorus12, aecf376)
  4. 9c79658
  5. Update README.md (chorus12, 2f8bb96)
  6. added docker support (chorus12, 1de6a82)

Commits on Feb 3, 2022

  1. added parameters for flexible hitrate and ndcg limit calculation - can now specify K and how many other items to sample, or take the whole population and score; also added preprocessing of MovieLens data from raw csv files available at movielens.org (chorus12, e1cd898)
  2. 670be0f
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
.DS_Store
__pycache__/
*_default/
.ipynb_checkpoints/
lightning_logs/
runs/
*.pt
284 changes: 284 additions & 0 deletions DataHelper.py
@@ -0,0 +1,284 @@
'''
helper classes to process train, validation and test datasets for the SASRec model
author: bazman, 2021
'''
import torch
import numpy as np
from collections import defaultdict
import random

class SequenceDataValidationFullLength(torch.utils.data.Dataset):
'''
Dataset for validation - similar to the SequenceDataValidation class, but puts all items except the validation item into the validation sequence.
It can therefore be used for exact NDCG and hit-rate metrics rather than sampling 100 items per validation sequence.
train -> **valid** -> test
Produces the validation data.
Input:
- user_train : known sequence of items for the user (train data)
- user_valid : the one item that follows the user_train sequence
- usernum : number of users in user_train/user_valid
- itemnum : number of items in user_train/user_valid
- maxlen : max length of sequence
Returns:
- user_train is kept as is
- user_valid is extended with every other item; the model then scores all of them, and the logit
for the 0-th element (the true validation item) should rank near the top of the scores
'''
def __init__(self, user_train, user_valid, usernum, itemnum, maxlen):
'''
Input:
- user_train: dict of user training sequences
- user_valid: dict with one item per user for validation
- usernum - number of users in dataset
- itemnum - number of items in dataset
- maxlen - max length of sequence for truncation
Output:
self.seq - maxlen-long sequence for train
self.valid - itemnum-long sequence for validation (true item first, then all other items)
'''
from tqdm import tqdm
super(SequenceDataValidationFullLength, self).__init__()

# make a list of users to validate on
# limit users max to 10000 or to whatever we have in case less than 10000
if usernum > 10_000:
users = random.sample(range(1, usernum + 1), 10_000)
else:
users = range(1, usernum + 1)

# build a validation sequence: the true validation item first, then every other item in the catalogue
valid_seq = torch.zeros((len(users), itemnum), dtype=torch.int)

# make a matrix from train sequence (batch, maxlen)
final_seq = torch.zeros((len(users), maxlen), dtype=torch.int)

with tqdm(total=len(users)) as pbar:
for ii,_u in enumerate(users):
# truncate seq to maxlen
idx = min(maxlen, len(user_train[_u]))
final_seq[ii, -idx:] = torch.as_tensor(user_train[_u][-idx:])

all_items_set = set(range(1, itemnum+1)) # set of all possible items
validation_items_set = all_items_set - set(user_valid[_u])

valid_seq[ii,0] = user_valid[_u][0] # get true next element from validation set
valid_seq[ii,1:] = torch.from_numpy(np.array(list(validation_items_set))) # all items except validation one
pbar.update(1)

self.seq = final_seq # store training seq
self.valid = valid_seq # store validation seq
self.users = users # store validation users

def __getitem__(self, index):
return self.seq[index], self.valid[index]

def __len__(self):
return len(self.seq)


class SequenceDataValidation(torch.utils.data.Dataset):
'''
Dataset for validation
train -> **valid** -> test
Produces the validation data.
Input:
- user_train : known sequence of items for the user (train data)
- user_valid : the one item that follows the user_train sequence
Returns:
- user_train is kept as is
- user_valid is extended with 100 random items that are not in user_train; the 101 items are then scored with the model, and the logit
for the 0-th element (the true validation item) should be somewhere in the top-10 scores
'''
def __init__(self, user_train, user_valid, usernum, itemnum, maxlen, ndcg_samples=100):
'''
Input:
- user_train: dict of user training sequences
- user_valid: dict with one item per user for validation
- usernum - number of users in dataset
- itemnum - number of items in dataset
- maxlen - max length of sequence for truncation
- ndcg_samples - how many random items to sample for hit-rate and ndcg calculation
Output:
self.seq - maxlen-long sequence for train
self.valid - (ndcg_samples + 1)-long sequence for validation
'''
from tqdm import tqdm
super(SequenceDataValidation, self).__init__()

# make a list of users to validate on
# limit users max to 10000 or to whatever we have in case less than 10000
if usernum > 10_000:
users = random.sample(range(1, usernum + 1), 10_000)
else:
users = range(1, usernum + 1)

# build a validation sequence: the true validation item first, then ndcg_samples random items
# drawn from items that do not appear in the (truncated) training sequence
valid_seq = torch.zeros((len(users), ndcg_samples+1), dtype=torch.int)

# make a matrix from train sequence (batch, maxlen)
final_seq = torch.zeros((len(users), maxlen), dtype=torch.int)

with tqdm(total=len(users)) as pbar:
for ii,_u in enumerate(users):
# truncate seq to maxlen
idx = min(maxlen, len(user_train[_u]))
final_seq[ii, -idx:] = torch.as_tensor(user_train[_u][-idx:])

items_not_in_seq = np.array(list(set(range(1,itemnum+1)) - set(final_seq[ii].numpy().flatten()))) # random stuff not in final_seq
valid_seq[ii,0] = user_valid[_u][0] # get true next element from validation set
valid_seq[ii,1:] = torch.from_numpy(items_not_in_seq[np.random.randint(0, len(items_not_in_seq), ndcg_samples)]) # fill the rest with random stuff
pbar.update(1)

self.seq = final_seq # store training seq
self.valid = valid_seq # store validation seq
self.users = users # store validation users

def __getitem__(self, index):
return self.seq[index], self.valid[index]

def __len__(self):
return len(self.seq)

class SequenceDataTest(SequenceDataValidation):
'''
Dataset for test
train -> valid -> **test**
Produces the test data.
Same as the SequenceDataValidation class but uses the test item as the target,
and also appends the validation item to the end of the training sequence
(e.g. if the train sequence ends with [..., a, b] and the validation item is v, the model input becomes [..., b, v]).
'''
def __init__(self, user_train, user_valid, user_test, usernum, itemnum, maxlen, ndcg_samples):
super().__init__(user_train, user_test, usernum, itemnum, maxlen, ndcg_samples)
# now we need to shift self.seq one item back
self.seq[:,:-1] = self.seq[:,1:]
# this is an extra item that will be the last in training seq
extra_valid_item = torch.as_tensor([user_valid[_u][0] for _u in self.users])
self.seq[:,-1] = extra_valid_item


class SequenceData(torch.utils.data.Dataset):
'''
dataset for training the network
'''
def __init__(self, user_seq, usernum, itemnum):
'''
user_seq is a dict keyed by userid, numbered sequentially from 1 to the number of users (usernum)
itemnum - number of items in the vocabulary of selected movies
Sets up the following properties on the object:
seq - all elements of the user sequence except the last one
pos - all elements except the first one (shifted one time step ahead)
neg - a sequence of the same length made of random items the user never interacted with
Resulting data looks like this:
seq = [250, 13, 251, 70, 252, 81, 237, 150, 253, 27, 143, 254, 236,
196, 229, 255, 256, 179, 167, 172, 157, 257, 39, 199, 258]
pos = [ 13, 251, 70, 252, 81, 237, 150, 253, 27, 143, 254, 236, 196,
229, 255, 256, 179, 167, 172, 157, 257, 39, 199, 258, 29]
neg = [928, 3404, 821, 2505, 1931, 2588, 1365, 527, 3140, 1615, 1649,
1981, 450, 1175, 1576, 1787, 1425, 2698, 1916, 729, 3390, 2503,
2751, 1481, 2422]
'''
from tqdm import tqdm
super(SequenceData, self).__init__()
self.usernum = usernum
self.userids = np.array(list(user_seq.keys())) # store userids in a property
self.seq, self.pos, self.neg = dict(), dict(), dict()
with tqdm(total=len(user_seq)) as pbar:
for userid, _user_seq in user_seq.items():
self.seq[userid] = np.array(_user_seq[:-1]) # all but last element
self.pos[userid] = np.array(_user_seq[1:]) # shifted one time slot ahead
# negative sequence
items_not_in_seq = np.array(list(set(range(1,itemnum+1)) - set(_user_seq))) # all items from vocab that are out of user_seq
self.neg[userid] = items_not_in_seq[np.random.randint(0, len(items_not_in_seq), len(self.seq[userid]))] # select random items from above array
pbar.update(1)

def __getitem__(self, index):
userid = self.userids[index]
return userid, self.seq[userid], self.pos[userid], self.neg[userid]

def __len__(self):
return len(self.seq)

def tokenize_batch(batch, max_len=200):
'''
collate function: casts the dataset tuples to tensors and pads/truncates each sequence to max_len - nothing else.
This could have been done in the dataset itself, but is kept here as a collate_fn.
'''
u = []
seq_list = []
pos_list = []
neg_list = []

# torch.zeros(max_len, dtype=torch.int)
# torch.zeros_like(seq_batch)
# torch.zeros_like(seq_batch)

for _u, seq, pos, neg in batch:
# fixed size tensor of max_len
seq_holder = torch.zeros(max_len, dtype=torch.int)
pos_holder = torch.zeros_like(seq_holder)
neg_holder = torch.zeros_like(seq_holder)

idx = min(max_len, len(seq))
seq_holder[-idx:] = torch.from_numpy(seq[-idx:])
pos_holder[-idx:] = torch.from_numpy(pos[-idx:])
neg_holder[-idx:] = torch.from_numpy(neg[-idx:])

seq_list.append(seq_holder.unsqueeze(dim=0))
pos_list.append(pos_holder.unsqueeze(dim=0))
neg_list.append(neg_holder.unsqueeze(dim=0))
u.append(_u)
return u, torch.cat(seq_list, dim=0), torch.cat(pos_list, dim=0), torch.cat(neg_list, dim=0)


# train/val/test data generation
def data_partition(fname):
'''
Partition the data into train, valid and test sets.
Input : file in format
user_id<space>item_selected
...
user_id<space>item_selected
All items appear in time order.
Returns:
user_train - dict with key = userid and value = list of all items selected, in time order
user_valid - dict with the same structure as above but holding only the penultimate item
user_test - same as above but holding only the last item
e.g. for user 5 with items 1,29,34,15,8 the result is
user_train[5] = [1,29,34], user_valid[5] = [15], user_test[5] = [8]
usernum - number of users
itemnum - number of items
'''
usernum = 0
itemnum = 0
User = defaultdict(list)
user_train = {}
user_valid = {}
user_test = {}
# assume user/item index starting from 1
f = open('data/%s.txt' % fname, 'r')
for line in f:
u, i = line.rstrip().split(' ')
u = int(u)
i = int(i)
usernum = max(u, usernum)
itemnum = max(i, itemnum)
User[u].append(i)

for user in User:
nfeedback = len(User[user])
if nfeedback < 3:
user_train[user] = User[user]
user_valid[user] = []
user_test[user] = []
else:
user_train[user] = User[user][:-2]
user_valid[user] = []
user_valid[user].append(User[user][-2])
user_test[user] = []
user_test[user].append(User[user][-1])
return [user_train, user_valid, user_test, usernum, itemnum]
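
For orientation, here is a minimal usage sketch of how these helpers fit together (not part of the diff; it mirrors what `SASRecMain.py` does below, with the dataset name, batch size and max_len as placeholders):

```
import torch
import DataHelper as DH

# expects data/ml-1m.txt with "user_id item_id" lines in time order
user_train, user_valid, user_test, usernum, itemnum = DH.data_partition('ml-1m')

# training batches: (user, seq, pos, neg), padded/truncated to max_len by tokenize_batch
train_loader = torch.utils.data.DataLoader(
    DH.SequenceData(user_train, usernum, itemnum),
    batch_size=128, shuffle=True,
    collate_fn=lambda batch: DH.tokenize_batch(batch, max_len=200))

# validation batches: each row of `valid` holds [true next item, 100 sampled negatives]
val_loader = torch.utils.data.DataLoader(
    DH.SequenceDataValidation(user_train, user_valid, usernum, itemnum, 200, ndcg_samples=100),
    batch_size=128, shuffle=False)

users, seq, pos, neg = next(iter(train_loader))
print(seq.shape, pos.shape, neg.shape)  # each (batch_size, max_len)
```
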
32 changes: 20 additions & 12 deletions README.md
@@ -1,21 +1,29 @@
update: with a few lines of manual initialization code added, it converges as fast as the tf version. BTW, I strongly recommend checking the repo's issues from time to time to keep up with new updates and details :)
Implementation of SASRec model via pytorch/lightning.
Originally based on [this code](https://github.com/pmixer/SASRec.pytorch) but rewritten completely to achieve the same metric values as reported in the paper.
![NDCG@10 on Movie Lens 1M](./ndcg.png)
[Implementation by authors of paper](https://github.com/kang205/SASRec)

---

update: a pretrained model has been added, pls run the command below to test its performance (current perf is still not as good as the paper's reported results even after training for more epochs, maybe due to the leaky causal attention weights issue fixed by using PyTorch 1.6's MultiheadAttention; pls help identify the root cause if you are interested):

Command for multi-GPU training:
```
python main.py --device=cuda --dataset=ml-1m --train_dir=default --state_dict_path='ml-1m_default/SASRec.epoch=601.lr=0.001.layer=2.head=1.hidden=50.maxlen=200.pth' --inference_only=true --maxlen=200
PL_TORCH_DISTRIBUTED_BACKEND=nccl python SASRecMain.py --dataset=ml-1m --maxlen=200 --dropout_rate=0.2 --d_model=50 --num_blocks=2 --num_heads=1 --ndcg_samples=100 --top_k=10 --opt=AdamW --lr=0.001 --weight_decay=1 --batch_size=1024 --num_epochs=300 --use_swa=True --swa_epoch_start=0.65 --swa_annealing_epochs=10 --xavier_init=True --strategy=ddp_spawn --precision=16 --accelerator=auto --devices=auto --l2_pe_reg=1
```
Don't forget to run TensorBoard as well:
```
tensorboard --logdir ./lightning_logs/ --host 0.0.0.0
```
To run in inference mode:
```
python SASRecMain.py --dataset=ml-1m --inference_only=True --checkpoint_path=./sasrec.ckpt --accelerator=auto
```
This will produce metrics on the validation dataset similar to these (how they can be computed is sketched after the output below):
```
DATALOADER:0 VALIDATE RESULTS
{'hr_val': 0.8273178935050964, 'ndcg_val': 0.5920551419258118}
```
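
For reference, a minimal sketch (an assumption about the metric definitions, not code from this repo) of how `hr_val` and `ndcg_val` can be obtained once the model has scored a validation batch, given that the true next item sits in column 0 of each row:

```
import torch

def hr_and_ndcg(scores: torch.Tensor, k: int = 10):
    # scores: (batch, num_candidates); column 0 is the held-out true item
    ranks = (scores > scores[:, :1]).sum(dim=1) + 1  # rank of the true item, 1 = best
    hits = (ranks <= k).float()                      # HR@k per user
    ndcg = torch.where(ranks <= k,
                       1.0 / torch.log2(ranks.float() + 1.0),  # one relevant item: 1/log2(rank+1)
                       torch.zeros_like(hits))
    return hits.mean().item(), ndcg.mean().item()
```
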

---

modified from the [paper author's TensorFlow implementation](https://github.com/kang205/SASRec), switched to PyTorch (v1.6) for simplicity; executable with:

```python main.py --dataset=ml-1m --train_dir=default --maxlen=200 --dropout_rate=0.2 --device=cuda```
To run the interactive version, use the [notebook](./SASRec_interactive.ipynb)

pls check the paper author's [repo](https://github.com/kang205/SASRec) for a detailed intro and a more complete README, and here's the paper's bib entry FYI :)

```
@inproceedings{kang2018self,
234 changes: 234 additions & 0 deletions SASRecMain.py
@@ -0,0 +1,234 @@
"""
Implementation of Self-attentive sequential recommendation paper:
@inproceedings{kang2018self,
title={Self-attentive sequential recommendation},
author={Kang, Wang-Cheng and McAuley, Julian},
booktitle={2018 IEEE International Conference on Data Mining (ICDM)},
pages={197--206},
year={2018},
organization={IEEE}
}
Originally taken from [this code](https://github.com/pmixer/SASRec.pytorch); the model class was rewritten and PyTorch Lightning was used.
For multi-GPU training run with the command:
PL_TORCH_DISTRIBUTED_BACKEND=nccl python SASRecMain.py --dataset=ml-1m --maxlen=200 --dropout_rate=0.2 --d_model=50 --num_blocks=2 --num_heads=1 --ndcg_samples=100 --top_k=10 --opt=AdamW --lr=0.001 --weight_decay=1 --batch_size=1024 --num_epochs=300 --use_swa=True --swa_epoch_start=0.65 --swa_annealing_epochs=10 --xavier_init=True --strategy=ddp_spawn --precision=16 --accelerator=auto --devices=auto --l2_pe_reg=1
To compute validation metrics run with:
python SASRecMain.py --dataset=ml-1m --inference_only=True --checkpoint_path=./sasrec.ckpt --accelerator=auto
don't forget to launch tensorboard with:
tensorboard --logdir ./lightning_logs/ --host 0.0.0.0
Author: Sergei Bazhin
Date: 2021-DEC - JAN-2022
"""
import os
import numpy as np
import torch
import pytorch_lightning as pl
import argparse
# module with datasets definition = train, validation and test
import DataHelper as DH
import SASRecModel as SASRec
import torch.optim as optim
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

from torch.nn import MultiheadAttention, LayerNorm, Dropout, Conv1d, Embedding, BCEWithLogitsLoss
from SASRecModel import PointWiseFF, SASRecEncoderLayer, PositinalEncoder, SASRecEncoder


# setup command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='ml-1m',
required=True,
help="dataset to use : Beauty, ml-1m(default), Steam or Video")

parser.add_argument('--maxlen', default=50, type=int,
help="truncate input sequence to last maxlen items, default 50")
parser.add_argument('--hidden_units', default=50, type=int, help="synonym for d_model") # synonym for d_model
parser.add_argument('--d_model', default=50, type=int,
help="Transformer internal dimension") # same as hidden_units
parser.add_argument('--num_blocks', default=2, type=int, help="Number of blocks in Transformer")
parser.add_argument('--num_heads', default=1, type=int, help="Number of heads in self-attention")
parser.add_argument('--dropout_rate', default=0.5, type=float, help="Dropout rate for Transformer")
parser.add_argument('--l2_pe_reg', default=0.1, type=float, help="Regularization for positional embedding")


parser.add_argument('--ndcg_samples', default=100, type=int,
help="How many random items to sample for hit-rate and ndcg calculation (default 100); if set to -1, score all items on validation (together with inference_only set to true)")
parser.add_argument('--top_k', default=10, type=int,
help="How many items with high scores to pick for hit-rate and ndcg calculation, default 10")
parser.add_argument('--opt', default='Adam', type=str, help="Optimizer to use: Adam(default), AdamW, FusedAdam(requires apex library)")
parser.add_argument('--lr', default=0.001, type=float,
help="learning rate, default 0.001")
parser.add_argument('--weight_decay', default=0.001, type=float, help="Weight decay for AdamW")
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--warmup_proportion', default=0.2, type=float, help="Fraction of total optimization steps to increase learning rate from zero to max value")
# for different optimizers - regular Adam uses num_epochs and LAMB uses max_iters
parser.add_argument('--max_iters', default=10000, type=int, help="Optimization budget in update iterations")
parser.add_argument('--num_epochs', default=201, type=int, help="Number of epochs to train")
# swa parameters
parser.add_argument('--use_swa', default=False, type=bool, help="Use Stochastic Weight Averaging algorithm")
parser.add_argument('--swa_epoch_start', default=0.8, type=float, help="Start SWA after that part of total epochs")
parser.add_argument('--swa_annealing_epochs', default=10, type=int, help="Number of epochs in the annealing phase of SWA")

# xavier init
parser.add_argument('--xavier_init', default=True, type=bool, help="Use xavier normal to init the model")

parser.add_argument('--inference_only', default=False, type=bool)
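# note: argparse's type=bool turns any non-empty string into True (so "--use_swa=False" still
# enables SWA); omit these boolean flags or pass an empty value to keep them False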
parser.add_argument('--checkpoint_path', default=None, type=str, help="Path to lightning checkpoint file")

# Torch Lightning settings
# https://pytorch-lightning.readthedocs.io/en/stable/advanced/multi_gpu.html
# Data Parallel (strategy='dp') (multiple-gpus, 1 machine)
# DistributedDataParallel (strategy='ddp') (multiple-gpus across many machines (python script based)).
# DistributedDataParallel (strategy='ddp_spawn') (multiple-gpus across many machines (spawn based)).
# DistributedDataParallel 2 (strategy='ddp2') (DP in a machine, DDP across machines).
# Horovod (strategy='horovod') (multi-machine, multi-gpu, configured at runtime)
# TPUs (tpu_cores=8|x) (tpu or TPU pod)
parser.add_argument('--strategy', default='ddp_spawn', type=str, help="Lightning parallel training strategy dp, ddp, ddp_spawn(default), ddp2, etc ")
parser.add_argument('--precision', default=16, type=int, help="Lightning precision for model data during training 16(default) or 32")
parser.add_argument('--accelerator', default="auto", type=str, help="Lightning accelerator auto(default), cpu, gpu, tpu")
parser.add_argument('--devices', default="auto", type=str,
help="Lightning devices to use - see https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#devices")

# args = parser.parse_args(['--dataset=ml-1m', '--train_dir=default',
# '--maxlen=200', '--dropout_rate=0.2', '--device=cuda'])

args = parser.parse_args()
args = vars(args)


if __name__ == '__main__':
# read dataset
dataset = DH.data_partition(args['dataset'])

print('\nRuntime parameters\n',*[(k, v) for (k, v) in args.items()], sep="\n")

[user_train, user_valid, user_test, usernum, itemnum] = dataset

# batches are sliced by user, i.e. each batch accumulates BATCH_SIZE user sequences of items selected/bought
BATCH_SIZE = args['batch_size']
num_batch = len(user_train) // BATCH_SIZE # number of batches

user_train_lens = list(map(len, [v for k, v in user_train.items()]))
print(
f'average sequence length: {sum(user_train_lens)/len(user_train):.1f}')

print(f"\nBatch size is - {BATCH_SIZE}\n")



callbacks_list = []
# save checkpoints
callbacks_list.append(ModelCheckpoint(monitor="ndcg_val", mode="max",
filename="sasrec_{epoch:05d}_{step}_{ndcg_val:.4f}"))

# use SWA
if args['use_swa']:
callbacks_list.append(StochasticWeightAveraging(swa_epoch_start=args['swa_epoch_start'],
annealing_epochs=args['swa_annealing_epochs']))

# https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
trainer = pl.Trainer(strategy=args['strategy'],
accelerator=args['accelerator'],
devices=args['devices'],
max_epochs=args['num_epochs'],
reload_dataloaders_every_n_epochs=1,
accumulate_grad_batches=4,
val_check_interval=1.0,
callbacks=callbacks_list,
# log a few times per epoch
log_every_n_steps=int(
len(user_train) / args['batch_size'] / 3),
# log_every_n_steps=1,
# limit_val_batches=0, How much of validation dataset to check. Useful when debugging or testing something that happens at the end of an epoch.
num_sanity_val_steps=1)

# no training but only validation metrics
if args['inference_only']:
model = SASRecEncoder.load_from_checkpoint(args['checkpoint_path'])
model.hparams.top_k = args['top_k']
if args['ndcg_samples'] == -1 :
val_loader = torch.utils.data.DataLoader(dataset=DH.SequenceDataValidationFullLength(user_train,
user_valid,
usernum,
itemnum,
model.hparams.maxlen),
batch_size=128,
shuffle=False,
drop_last=False)
else:
val_loader = torch.utils.data.DataLoader(dataset=DH.SequenceDataValidation(user_train,
user_valid,
usernum,
itemnum,
model.hparams.maxlen,
model.hparams.ndcg_samples),
batch_size=128,
shuffle=False,
drop_last=False)
trainer.validate(model, dataloaders=val_loader)
else: # start training routine
if args['ndcg_samples'] == -1 :
val_loader = torch.utils.data.DataLoader(dataset=DH.SequenceDataValidationFullLength(user_train,
user_valid,
usernum,
itemnum,
args['maxlen']),
batch_size=args['batch_size'],
shuffle=False,
drop_last=False)
else:
val_loader = torch.utils.data.DataLoader(dataset=DH.SequenceDataValidation(user_train,
user_valid,
usernum,
itemnum,
args['maxlen'],
args['ndcg_samples']),
batch_size=args['batch_size'],
shuffle=True,
drop_last=True)

# test_loader = torch.utils.data.DataLoader(dataset=DH.SequenceDataTest(user_train,
# user_valid,
# user_test,
# usernum,
# itemnum,
# args['maxlen'],
# args['ndcg_samples']),
# batch_size=args['batch_size'], shuffle=False,
# drop_last=True)

train_loader = torch.utils.data.DataLoader(dataset=DH.SequenceData(user_train, usernum, itemnum),
batch_size=args['batch_size'],
shuffle=True,
collate_fn=DH.tokenize_batch)


if args['opt'] == 'FusedAdam':
try:
import apex
except ModuleNotFoundError:
print("\n >>>No apex installed - switching to simple Adam<<<\n")
args['opt'] = 'Adam'
model = SASRecEncoder(itemnum, **args)

if args['xavier_init']:
# weight initialization
print("\nRunning weights initialization with xavier normal...\n")
for name, param in model.named_parameters():
try:
torch.nn.init.xavier_normal_(param.data)
print(f"{name:<40} sucess")
except:
print(f"{name:<40} failure")

trainer.fit(model, train_loader, val_loader)

torch.save(model.state_dict(), f"sasrec_{trainer.logger.version}.pt")

# metrics on test dataset
# trainer.test(model, test_loader)
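
Besides the Lightning checkpoints, the script saves a raw state dict at the end of training. A minimal sketch of reloading it for evaluation (an illustration only; the file name is a placeholder and the constructor call mirrors `SASRecEncoder(itemnum, **args)` above):

```
import torch
import pytorch_lightning as pl
from SASRecModel import SASRecEncoder

# rebuild the model with the same hyperparameters that were used for training
model = SASRecEncoder(itemnum, **args)
model.load_state_dict(torch.load("sasrec_0.pt", map_location="cpu"))  # placeholder file name

# reuse a validation dataloader built as above to reproduce hr_val / ndcg_val
pl.Trainer(accelerator="auto", devices="auto").validate(model, dataloaders=val_loader)
```
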
315 changes: 315 additions & 0 deletions SASRecModel.py

Large diffs are not rendered by default.

4,382 changes: 4,382 additions & 0 deletions SASRec_interactive.ipynb

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions docker/Dockerfile
@@ -0,0 +1,38 @@
FROM pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel

ENV TZ 'Europe/Moscow'
RUN echo $TZ > /etc/timezone
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y locales language-pack-ru sudo nano unzip wget git
RUN sed -i -e 's/# ru_RU.UTF-8 UTF-8/ru_RU.UTF-8 UTF-8/' /etc/locale.gen && \
dpkg-reconfigure --frontend=noninteractive locales && \
update-locale LANG=ru_RU.UTF-8
ENV LANG ru_RU.UTF-8
ENV LANGUAGE ru_RU
ENV LC_ALL ru_RU.UTF-8
RUN apt-get clean

RUN useradd -ms /bin/bash testuser && \
echo "testuser:testuser" | chpasswd && \
usermod -aG sudo testuser && \
chmod 777 -R /root

WORKDIR /home/testuser

USER testuser
ENV PATH="/home/testuser/.local/bin:${PATH}"

COPY requirements.txt .
RUN pip install -r requirements.txt

RUN mkdir -p /home/testuser/.jupyter/lab/user-settings/@jupyterlab/terminal-extension
COPY --chown=testuser plugin.jupyterlab-settings /home/testuser/.jupyter/lab/user-settings/@jupyterlab/terminal-extension/
COPY --chown=testuser jupyter_lab_config.py /home/testuser/.jupyter/
ENV JUPYTER_TOKEN="<your token here>"

RUN openssl req -x509 -sha256 -nodes -days 3650 -newkey rsa:4096 -keyout .jupyter/jupyter.key -out .jupyter/jupyter.pem \
-subj "/C=RU/ST=Uranopolis/L=SkyCity/O=Space/OU=DS/CN=heaven.is"

ENV SHELL="/bin/bash"

RUN git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
CMD jupyter lab -e JUPYTER_TOKEN=$JUPYTER_TOKEN
996 changes: 996 additions & 0 deletions docker/jupyter_lab_config.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions docker/plugin.jupyterlab-settings
@@ -0,0 +1,2 @@
{"fontFamily": "Ubuntu mono",
"fontSize": 16}
16 changes: 16 additions & 0 deletions docker/readme.md
@@ -0,0 +1,16 @@
# How to build an image for this project
This docker image contains a bit more than is needed to run the model, but it will be handy for all your pytorch/lightning projects.
Before building the image you might want to customize it:
- change the operating system locale (it is set to Russian)
- set the token used to access JupyterLab in line 30 of the Dockerfile: `ENV JUPYTER_TOKEN="<your token here>"`
- in case you ever need a username and password for the image, both are `testuser`
- install the Ubuntu font on your client PC (JupyterLab is configured to use it), or change the font in `plugin.jupyterlab-settings`

```sh
docker build . -t sasrec_torch
```
This builds an image named `sasrec_torch`.
Run the image with:
```
docker run -it --name torch --gpus all --rm --privileged -p 8888:8888 -p 6006:6006 -v <host path>:/home/testuser/shared_folder sasrec_torch jupyter lab
```
21 changes: 21 additions & 0 deletions docker/requirements.txt
@@ -0,0 +1,21 @@
matplotlib
pandas
Pillow
scikit-learn
scipy
statsmodels
tornado
tqdm
jupyterhub==0.9.3
seaborn
tables
h5py
jupytext
jupyterlab
ipywidgets
transformers
tensorboard
awscli
pytorch-lightning==1.5.9
pyarrow
plotly
130 changes: 0 additions & 130 deletions main.py

This file was deleted.

Binary file not shown.
14 changes: 0 additions & 14 deletions ml-1m_default/args.txt

This file was deleted.

121 changes: 0 additions & 121 deletions model.py

This file was deleted.

Binary file added ndcg.png
Binary file added sasrec-loss-function.png
200 changes: 0 additions & 200 deletions utils.py

This file was deleted.