# 👋 Session-based Recommend Using SR-GNN

I referred to following code bases while writing the code below.

- https://colab.research.google.com/drive/1X4uOWv_xkefDu_h-pbJg-fEkMfR7NGz9?usp=sharing
- https://github.com/userbehavioranalysis/SR-GNN_PyTorch-Geometric
- https://rzykov.github.io/notebooks/RetailRocketDatasetNextClick.html


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


#### Instructions on getting the dataset.

**Please follow the steps below to download and use kaggle data within Google Colab:

1. Go to your account, Scroll to API section and Click Expire API Token to remove previous tokens

2. Click on Create New API Token - It will download kaggle.json file on your machine.

3. Run the following cells.**

In [12]:
# Uncomment if first time running
! pip install kaggle
from google.colab import files
files.upload()

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"marcyane7","key":"7aa918d0d5cc83067e389a89473ce5b4"}'}

In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

In [14]:
! kaggle datasets download -d retailrocket/ecommerce-dataset

Downloading ecommerce-dataset.zip to /content
 97% 282M/291M [00:03<00:00, 122MB/s]
100% 291M/291M [00:03<00:00, 95.2MB/s]


In [15]:
# # Uncomment if first time running
# ! mkdir ecommerce-dataset
# ! unzip ecommerce-dataset.zip -d ecommerce-dataset

Archive:  ecommerce-dataset.zip
  inflating: ecommerce-dataset/category_tree.csv  
  inflating: ecommerce-dataset/events.csv  
  inflating: ecommerce-dataset/item_properties_part1.csv  
  inflating: ecommerce-dataset/item_properties_part2.csv  


# ⚙️ Set Up

### You will need to restart the runtime in order to be able to use the newly installed libraries

In [5]:
# Uncomment if you have not run this cell before
! pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
! pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
! pip install torch-geometric
! pip install -q git+https://github.com/snap-stanford/deepsnap.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu113.html
Collecting torch-scatter
  Downloading torch_scatter-2.1.0.tar.gz (106 kB)
[K     |████████████████████████████████| 106 kB 5.0 MB/s 
[?25hBuilding wheels for collected packages: torch-scatter
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
  Created wheel for torch-scatter: filename=torch_scatter-2.1.0-cp38-cp38-linux_x86_64.whl size=3398139 sha256=f0422cf701903c848dabe4b25c49d8e7391e2eaf959a165270fa8fca409c31be
  Stored in directory: /root/.cache/pip/wheels/41/7f/4f/cf072bea3b6efe4561de2db3603ebbd8718c134c24caab8281
Successfully built torch-scatter
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.10.0+cu113.html
Co

  Building wheel for deepsnap (setup.py) ... [?25l[?25hdone


In [32]:
! mkdir raw

In [16]:
# Import Python built-in libraries
import copy
import pickle
import random
import time

In [17]:
# Import pip libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm, trange

# Import torch packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data

# Import PyG packages
import torch_geometric as pyg
import torch_geometric.data as pyg_data
from torch_geometric.typing import Adj, OptTensor
import torch_sparse

# ⚗️ Data Preprocessing

## Download

You can download the data from 🔗[this kaggle competition](https://www.kaggle.com/retailrocket/ecommerce-dataset). We will only use the `events.csv` file.

## Inspection

Let's have a look at the dataset.

In [18]:
CSV_PATH = "ecommerce-dataset/events.csv"

In [19]:
# Load and have a peek into the dataset
events_df = pd.read_csv(CSV_PATH)
print(events_df.head())
print(f'There are {len(events_df)} rows in the raw data.')

       timestamp  visitorid event  itemid  transactionid
0  1433221332117     257597  view  355908            NaN
1  1433224214164     992329  view  248676            NaN
2  1433221999827     111016  view  318965            NaN
3  1433221955914     483717  view  253185            NaN
4  1433221337106     951259  view  367447            NaN
There are 2756101 rows in the raw data.


In [20]:
# Print the summarized statistics of the dataset
events_df.describe()

Unnamed: 0,timestamp,visitorid,itemid,transactionid
count,2756101.0,2756101.0,2756101.0,22457.0
mean,1436424000000.0,701922.9,234922.5,8826.497796
std,3366312000.0,405687.5,134195.4,5098.99629
min,1430622000000.0,0.0,3.0,0.0
25%,1433478000000.0,350566.0,118120.0,4411.0
50%,1436453000000.0,702060.0,236067.0,8813.0
75%,1439225000000.0,1053437.0,350715.0,13224.0
max,1442545000000.0,1407579.0,466867.0,17671.0


In [21]:
# Check the maximum unique id of items.
# We will use this as the total number of items to create item embeddings.
max(events_df['itemid'])

466867

In [15]:
display(events_df)

Unnamed: 0,timestamp,visitorid,event,itemid,transactionid
0,1433221332117,257597,view,355908,
1,1433224214164,992329,view,248676,
2,1433221999827,111016,view,318965,
3,1433221955914,483717,view,253185,
4,1433221337106,951259,view,367447,
...,...,...,...,...,...
2756096,1438398785939,591435,view,261427,
2756097,1438399813142,762376,view,115946,
2756098,1438397820527,1251746,view,78144,
2756099,1438398530703,1184451,view,283392,


## Separating log data into sessions

Let's load and break all log data into sessions for all users. Note that these preprocessing logics were taken from [here](https://rzykov.github.io/notebooks/RetailRocketDatasetNextClick.html) where the code is written in Scala.

Summary of what we do in the cell below:

1. Filter only the 'view' events. Among the three types of events provided by the dataset, we will only use the 'view' events to predict and recommend the next 'viewed' item.
2. Filter out visitors with single clicks.
3. Group events by visitor id.
4. Within the grouped events from a visitor, break and generate sessions

In [22]:
# Filter only the 'view' events.
events_df_filtered = events_df[events_df['event'] == 'view']
print(f'There are {len(events_df_filtered)} `view` events in the raw data.')

There are 2664312 `view` events in the raw data.


In [23]:
events_df_filtered

Unnamed: 0,timestamp,visitorid,event,itemid,transactionid
0,1433221332117,257597,view,355908,
1,1433224214164,992329,view,248676,
2,1433221999827,111016,view,318965,
3,1433221955914,483717,view,253185,
4,1433221337106,951259,view,367447,
...,...,...,...,...,...
2756096,1438398785939,591435,view,261427,
2756097,1438399813142,762376,view,115946,
2756098,1438397820527,1251746,view,78144,
2756099,1438398530703,1184451,view,283392,


In [24]:
# Filter out visitors with single clicks.
visit_counts_per_visitor = events_df_filtered['visitorid'].value_counts(dropna=False)
display(visit_counts_per_visitor)

1150086    6479
530559     3623
895999     2368
152963     2304
163561     2194
           ... 
908147        1
258979        1
551045        1
218233        1
1184451       1
Name: visitorid, Length: 1404179, dtype: int64

In [25]:
visitors_with_significant_visits = visit_counts_per_visitor[visit_counts_per_visitor > 1].index
events_df_filtered = events_df_filtered[events_df_filtered['visitorid'].isin(visitors_with_significant_visits)]
display(events_df_filtered)

Unnamed: 0,timestamp,visitorid,event,itemid,transactionid
0,1433221332117,257597,view,355908,
1,1433224214164,992329,view,248676,
2,1433221999827,111016,view,318965,
3,1433221955914,483717,view,253185,
5,1433224086234,972639,view,22556,
...,...,...,...,...,...
2756092,1438398473572,709520,view,104512,
2756094,1438399289446,701750,view,296172,
2756095,1438400574346,289041,view,156947,
2756098,1438397820527,1251746,view,78144,


In [26]:
# Let's group events and their timing data by their visitorid.
visits_by_visitors = {}
for _, row in enumerate(tqdm(events_df_filtered.iterrows())):
    timestamp, visitorid, event, itemid, transactionid = row[1].values

    if visitorid not in visits_by_visitors:
        visits_by_visitors[visitorid] = {'itemids': [], 'timestamps': []}
    visits_by_visitors[visitorid]['itemids'].append(itemid)
    visits_by_visitors[visitorid]['timestamps'].append(timestamp)

print()
assert len(visits_by_visitors) == events_df_filtered['visitorid'].nunique()
print(f'There are {len(visits_by_visitors)} visitors left.')

1656582it [01:10, 23358.70it/s]



There are 396449 visitors left.


In [27]:
# We will separate sessions by 2 hours.
delay = 2 * 3600 * 1000 # since timestamps are in seconds, we have 2(hours) * 3600(seconds) * 3600(miliseconds) 

# Let's group events from visitors into sessions.
sessions_by_visitors = {}
for visitorid, visitor_dict in visits_by_visitors.items():
    sessions = [[]]
    events_sorted = sorted(zip(visitor_dict['timestamps'],
                               visitor_dict['itemids']))
    for i in range(len(events_sorted) - 1):
        sessions[-1].append(events_sorted[i][1])
        if (events_sorted[i+1][0] - events_sorted[i][0]) > delay:
            sessions.append([])
    sessions[-1].append(events_sorted[len(events_sorted) - 1][1])
    sessions_by_visitors[visitorid] = sessions

print()
print(f'There are {len(sessions_by_visitors)} sessions. That is the same number of visitors.')


There are 396449 sessions. That is the same number of visitors.


## Splitting train and test dataset

Now let's split the train, validation, and test dataset. We will split the dataset by user ids. Each user will only be found in one of the three splits! We split like this because the model may 'cheat' from partial sessions of a user during training and use that information during test time.

In [28]:
# Adjsut sampling rate ([0, 1]) to generate smaller datasets.
# Setting `sampling_rate` to 1 will lead to a full dataset split.
sampling_rate = 1

# We use random seed for reproducibility.
seed = 42
all_visitors = list(sessions_by_visitors.keys())
random.Random(seed).shuffle(all_visitors)

num_train = int(len(all_visitors) * 0.8 * sampling_rate)
num_val = int(len(all_visitors) * 0.1 * sampling_rate)
num_test = int(len(all_visitors) * 0.1 * sampling_rate)

train_visitors = all_visitors[:num_train]
val_visitors = all_visitors[num_train : num_train+num_val]
test_visitors = all_visitors[num_train+num_val:num_train+num_val+num_test]

Let's check the size of each split and pickle the preprocessed data.

In [29]:
def extract_subsessions(sessions):
    """Extracts all partial sessions from the sessions given.

    For example, a session (1, 2, 3) should be augemnted to produce two
    separate sessions (1, 2) and (1, 2, 3).
    """
    all_sessions = []
    for session in sessions:
        for i in range(1, len(session)):
            all_sessions.append(session[:i+1])
    return all_sessions

In [30]:
# Check the number of visitors in each split
print(f'train, val, and test visitors: {len(train_visitors), len(val_visitors), len(test_visitors)}')

# Get sessions of each visitor, generate subsessions of each session, and put
# all the generated subsessions into right splits. We generate subsessions
# according to the dataset generation policy suggested by the original SR-GNN
# paper.
train_sessions, val_sessions, test_sessions = [], [], []
for visitor in train_visitors:
    train_sessions.extend(extract_subsessions(sessions_by_visitors[visitor]))
for visitor in val_visitors:
    val_sessions.extend(extract_subsessions(sessions_by_visitors[visitor]))
for visitor in test_visitors:
    test_sessions.extend(extract_subsessions(sessions_by_visitors[visitor]))

train, val, and test visitors: (317159, 39644, 39644)


In [33]:
# Check the number of (sub)sessions in each split
print(f'train, val, and test sessions: {len(train_sessions), len(val_sessions), len(test_sessions)}')

# Save the processed files.
with open('raw/train.txt', 'wb') as f:
    pickle.dump(train_sessions, f)
with open('raw/val.txt', 'wb') as f:
    pickle.dump(val_sessions, f)
with open('raw/test.txt', 'wb') as f:
    pickle.dump(test_sessions, f)

train, val, and test sessions: (781928, 91317, 96896)


# 📦 Data Pipeline

For data ingestion, we use PyTorch's `dataloader` and PyG's `Data` class. To learn more about the `Data` class, check out the documentation [here](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#module-torch_geometric.data).

In [34]:
class GraphDataset(pyg_data.InMemoryDataset):
    def __init__(self, root, file_name, transform=None, pre_transform=None):
        self.file_name = file_name
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [f'{self.file_name}.txt']

    @property
    def processed_file_names(self):
        return [f'{self.file_name}.pt']

    def download(self):
        pass

    def process(self):
        raw_data_file = f'{self.raw_dir}/{self.raw_file_names[0]}'
        with open(raw_data_file, 'rb') as f:
            sessions = pickle.load(f)
        data_list = []

        for session in sessions:
            session, y = session[:-1], session[-1]
            codes, uniques = pd.factorize(session)
            senders, receivers = codes[:-1], codes[1:]

            # Build Data instance
            edge_index = torch.tensor([senders, receivers], dtype=torch.long)
            x = torch.tensor(uniques, dtype=torch.long).unsqueeze(1)
            y = torch.tensor([y], dtype=torch.long)
            data_list.append(pyg_data.Data(x=x, edge_index=edge_index, y=y))

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

# 🔮 Model Design

Our gated session graph layer has two main parts: (1) message propagation to create an adjacency matrix (`self.propagate`) and (2) the GRU cell (`self.gru`). We will put these inside the `forward()` function.

We only use one layer for our `GatedSessionGraphConv` implementation for simplicity. Also, our sessions have average length < 5, so we do not need a large receptive field.

In [35]:
class GatedSessionGraphConv(pyg.nn.conv.MessagePassing):
    def __init__(self, out_channels, aggr: str = 'add', **kwargs):
        super().__init__(aggr=aggr, **kwargs)

        self.out_channels = out_channels

        self.gru = torch.nn.GRUCell(out_channels, out_channels, bias=False)

    def forward(self, x, edge_index):
        m = self.propagate(edge_index, x=x, size=None)
        x = self.gru(m, x)
        return x

    def message(self, x_j):
        return x_j

    def message_and_aggregate(self, adj_t, x):
        return matmul(adj_t, x, reduce=self.aggr)

In [36]:
class SRGNN(nn.Module):
    def __init__(self, hidden_size, n_items):
        super(SRGNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_items = n_items

        self.embedding = nn.Embedding(self.n_items, self.hidden_size)
        self.gated = GatedSessionGraphConv(self.hidden_size)

        self.q = nn.Linear(self.hidden_size, 1)
        self.W_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.W_2 = nn.Linear(self.hidden_size, self.hidden_size)
        self.W_3 = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, data):
        x, edge_index, batch_map = data.x, data.edge_index, data.batch

        # (0)
        embedding = self.embedding(x).squeeze()

        # (1)-(5)
        v_i = self.gated(embedding, edge_index)

        # Divide nodes by session
        # For the detailed explanation of what is happening below, please refer
        # to the Medium blog post.
        sections = list(torch.bincount(batch_map).cpu())
        v_i_split = torch.split(v_i, sections)

        v_n, v_n_repeat = [], []
        for session in v_i_split:
            v_n.append(session[-1])
            v_n_repeat.append(
                session[-1].view(1, -1).repeat(session.shape[0], 1))
        v_n, v_n_repeat = torch.stack(v_n), torch.cat(v_n_repeat, dim=0)

        q1 = self.W_1(v_n_repeat)
        q2 = self.W_2(v_i)

        # (6)
        alpha = self.q(F.sigmoid(q1 + q2))
        s_g_split = torch.split(alpha * v_i, sections)

        s_g = []
        for session in s_g_split:
            s_g_session = torch.sum(session, dim=0)
            s_g.append(s_g_session)
        s_g = torch.stack(s_g)

        # (7)
        s_l = v_n
        s_h = self.W_3(torch.cat([s_l, s_g], dim=-1))

        # (8)
        z = torch.mm(self.embedding.weight, s_h.T).T
        return z

# 🚂 Model Training

We can now start model training.

In [37]:
# Define the hyperparameters.
args = {
    'batch_size': 100,
    'hidden_dim': 32,
    'epochs': 100,
    'l2_penalty': 0.00001,
    'weight_decay': 0.1,
    'step': 30,
    'lr': 0.001,
    'num_items': 466868}

class objectview(object):
    def __init__(self, d): 
        self.__dict__ = d

args = objectview(args)

In [38]:
def train(args):
    # Prepare data pipeline
    train_dataset = GraphDataset('./', 'train')
    train_loader = pyg_data.DataLoader(train_dataset,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       drop_last=True)
    val_dataset = GraphDataset('./', 'val')
    val_loader = pyg_data.DataLoader(val_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     drop_last=True)

    # Build model
    model = SRGNN(args.hidden_dim, args.num_items).to('cuda')

    # Get training components
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.l2_penalty)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.step,
                                          gamma=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # Train
    losses = []
    test_accs = []
    top_k_accs = []

    best_acc = 0
    best_model = None

    for epoch in range(args.epochs):
        total_loss = 0
        model.train()
        for _, batch in enumerate(tqdm(train_loader)):
            batch.to('cuda')
            optimizer.zero_grad()

            pred = model(batch)
            label = batch.y
            loss = criterion(pred, label)

            loss.backward()
            optimizer.step()
            total_loss += loss.item() * batch.num_graphs

        total_loss /= len(train_loader.dataset)
        losses.append(total_loss)

        scheduler.step()

        if epoch % 1 == 0:
          test_acc, top_k_acc = test(val_loader, model, is_validation=True)
          print(test_acc)
          test_accs.append(test_acc)
          top_k_accs.append(top_k_acc)
          if test_acc > best_acc:
            best_acc = test_acc
            best_model = copy.deepcopy(model)
        else:
          test_accs.append(test_accs[-1])
  
    return test_accs, top_k_accs, losses, best_model, best_acc, val_loader

In [39]:
def test(loader, test_model, is_validation=False, save_model_preds=False):
    test_model.eval()

    # Define K for Hit@K metrics.
    k = 20
    correct = 0
    top_k_correct = 0

    for _, data in enumerate(tqdm(loader)):
        data.to('cuda')
        with torch.no_grad():
            # max(dim=1) returns values, indices tuple; only need indices
            score = test_model(data)
            pred = score.max(dim=1)[1]
            label = data.y

        if save_model_preds:
          data = {}
          data['pred'] = pred.view(-1).cpu().detach().numpy()
          data['label'] = label.view(-1).cpu().detach().numpy()

          df = pd.DataFrame(data=data)
          # Save locally as csv
          df.to_csv('pred.csv', sep=',', index=False)
            
        correct += pred.eq(label).sum().item()

        # We calculate Hit@K accuracy only at test time.
        if not is_validation:
            score = score.cpu().detach().numpy()
            for row in range(pred.size(0)):
                top_k_pred = np.argpartition(score[row], -k)[-k:]
                if label[row].item() in top_k_pred:
                    top_k_correct += 1
    
    if not is_validation:
        return correct / len(loader), top_k_correct / len(loader)
    else:
        return correct / len(loader), 0

In [None]:
test_accs, top_k_accs, losses, best_model, best_acc, test_loader = train(args) 

print(test_accs, top_k_accs)
print("Maximum test set accuracy: {0}".format(max(test_accs)))
print("Minimum loss: {0}".format(min(losses)))

# plt.title(dataset.name)
plt.plot(losses, label="training loss" + " - ")
plt.plot(test_accs, label="test accuracy" + " - ")
plt.legend()
plt.show()

Processing...
  edge_index = torch.tensor([senders, receivers], dtype=torch.long)
Done!
Processing...
Done!
100%|██████████| 7819/7819 [07:32<00:00, 17.29it/s]
100%|██████████| 913/913 [00:15<00:00, 58.99it/s]


12.812705366922234


100%|██████████| 7819/7819 [06:27<00:00, 20.18it/s]
100%|██████████| 913/913 [00:10<00:00, 87.90it/s]


13.99014238773275


100%|██████████| 7819/7819 [06:32<00:00, 19.94it/s]
100%|██████████| 913/913 [00:10<00:00, 85.79it/s]


15.462212486308871


100%|██████████| 7819/7819 [06:28<00:00, 20.13it/s]
100%|██████████| 913/913 [00:11<00:00, 78.52it/s]


16.305585980284775


100%|██████████| 7819/7819 [06:28<00:00, 20.14it/s]
100%|██████████| 913/913 [00:10<00:00, 87.82it/s]


16.806133625410734


100%|██████████| 7819/7819 [06:29<00:00, 20.07it/s]
100%|██████████| 913/913 [00:10<00:00, 86.93it/s]


17.11500547645126


100%|██████████| 7819/7819 [06:32<00:00, 19.91it/s]
100%|██████████| 913/913 [00:10<00:00, 85.63it/s]


17.304490690032857


100%|██████████| 7819/7819 [06:31<00:00, 19.99it/s]
100%|██████████| 913/913 [00:10<00:00, 86.83it/s]


17.562979189485212


100%|██████████| 7819/7819 [06:31<00:00, 19.96it/s]
100%|██████████| 913/913 [00:10<00:00, 86.56it/s]


17.710843373493976


100%|██████████| 7819/7819 [06:34<00:00, 19.82it/s]
100%|██████████| 913/913 [00:10<00:00, 86.49it/s]


17.79408543263965


100%|██████████| 7819/7819 [06:31<00:00, 19.97it/s]
100%|██████████| 913/913 [00:10<00:00, 87.50it/s]


17.92552026286966


100%|██████████| 7819/7819 [06:32<00:00, 19.94it/s]
100%|██████████| 913/913 [00:10<00:00, 87.15it/s]


18.052573932092006


100%|██████████| 7819/7819 [06:29<00:00, 20.06it/s]
100%|██████████| 913/913 [00:10<00:00, 88.42it/s]


18.146768893756846


100%|██████████| 7819/7819 [06:29<00:00, 20.07it/s]
100%|██████████| 913/913 [00:10<00:00, 87.90it/s]


18.219058050383353


100%|██████████| 7819/7819 [06:28<00:00, 20.11it/s]
100%|██████████| 913/913 [00:10<00:00, 89.80it/s]


18.31434830230011


100%|██████████| 7819/7819 [06:28<00:00, 20.14it/s]
100%|██████████| 913/913 [00:10<00:00, 89.42it/s]


18.343921139101862


100%|██████████| 7819/7819 [06:31<00:00, 19.96it/s]
100%|██████████| 913/913 [00:10<00:00, 88.35it/s]


18.384446878422782


 28%|██▊       | 2218/7819 [01:50<04:37, 20.19it/s]

# 🧪 Evaluation

In [None]:
# Save the best model
torch.save(best_model.state_dict(), 'model')

In [None]:
# Run test for our best model to save the predictions!
test_dataset = GraphDataset('./', 'test')
test_loader = pyg_data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  drop_last=True)

test(test_loader, best_model, is_validation=False, save_model_preds=True)