# GNN Recommendation System Example

This example is based on the [RecSys Challenge 2015](https://www.kaggle.com/datasets/chadgostopp/recsys-challenge-2015/download?datasetVersionNumber=1). The dataset has been constructed by YOOCHOOSE GmbH, and contains a collection of sessions from a retailer, where each session contains clicks performed by a user in the session. For some sessions there are also buy events, meaning that the session ended with the user buying something from the web shop. Available datasets are as follows:

* Clicks dataset
  * File `yoochoose-clicks.dat` contains the clicks performed by the users. Data is as follows:
    * Session ID - ID of the session
    * Timestamp - date and time when the click took place
    * Item ID - unique identifier of the item that has been clicked
    * Category - context of the click: `S` indicates a special offer, `0` indicates a missing value, a number from 1 to 12 is a real category identifier, and any other number represents a brand
* Buys dataset
  * File `yoochoose-buys.dat` contains the buy events of the users. Data is as follows:
    * Session ID - ID of the session
    * Timestamp - date and time when the buy event took place
    * Item ID - unique identifier of the item that has been bought
    * Price - price of the item that has been bought
    * Quantity - quantity of the items bought
* Test dataset
  * File `yoochoose-test.dat` contains only clicks of users over time. This is the test file used in the challenge.
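As a minimal sketch of the record layout listed above, the following parses a few click records with Python's standard library. It assumes that the `.dat` files are comma-separated and have no header row; the two sample rows are taken from the click table shown later in this example, and the helper name `parse_clicks` is hypothetical.

```python
import csv
import io
from datetime import datetime

# Two sample click records: session ID, timestamp, item ID, category
SAMPLE_CLICKS = """1,2014-04-07T10:54:09.868Z,214536500,0
1,2014-04-07T10:54:46.998Z,214536506,0
"""


def parse_clicks(text):
    """Parse comma-separated click records into a list of dictionaries."""
    clicks = []
    for session_id, timestamp, item_id, category in csv.reader(io.StringIO(text)):
        clicks.append({
            "session_id": int(session_id),
            "timestamp": datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ"),
            "item_id": int(item_id),
            # Kept as a string because the category can be non-numeric ('S')
            "category": category,
        })
    return clicks


clicks = parse_clicks(SAMPLE_CLICKS)
print(clicks[0]["session_id"], clicks[0]["item_id"], clicks[0]["category"])
```

The category is deliberately left as a string, which is also why the dataset code later reads that column with the `str` dtype.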

The idea is to associate click streams with items, so that a user entering the web shop can be recommended items of likely interest based on the clicks performed. The GNN approach taken in this example is inspired by [Hands-on Graph Neural Networks with PyTorch & PyTorch Geometric](https://towardsdatascience.com/hands-on-graph-neural-networks-with-pytorch-pytorch-geometric-359487e221a8) (code available from [here](https://github.com/khuangaf/PyTorch-Geometric-YooChoose/blob/master/YooChooseClick.ipynb)). The network uses a SAGEConv layer from [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216), which is defined as follows:

$$\begin{align}
h_{N(v)}^k &\leftarrow \text{AGGREGATE}_{k}\left(\left\{h_u^{k-1}, \forall u \in N(v)\right\}\right) \\
h_v^k &\leftarrow \sigma \left(W^k \cdot \text{CONCAT}\left(h_v^{k-1}, h_{N(v)}^k\right) \right)
\end{align}$$

where $\sigma$ is a non-linear activation function and $N(v)$ is the set of neighbors of node $v$. If we use max-pooling as the aggregation method, the right-hand side of the first equation can be written as follows:

$$ \text{AGGREGATE}_{k}\left(\left\{h_u^{k-1}, \forall u \in N(v)\right\}\right) = \max\left(\left\{\sigma \left(W_\text{pool}h_{u_i}^{k-1} + b \right), \forall u_i \in N(v) \right\}\right)$$

Each neighboring node's embedding is multiplied by the weight matrix $W_\text{pool}$, a bias $b$ is added, and an activation function is applied. The element-wise maximum over all neighbors then gives the aggregated message.
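The two steps can be sketched without any dependencies as follows. This is an illustration only, not the layer used later in this example: the helper names are hypothetical, ReLU is assumed for $\sigma$, and the toy weights are chosen so that the result is easy to check by hand.

```python
def linear(weights, bias, vector):
    """Compute W·x + b for a weight matrix given as a list of rows."""
    return [sum(w * x for w, x in zip(row, vector)) + b
            for row, b in zip(weights, bias)]


def relu(vector):
    """Element-wise ReLU, used here as the non-linearity sigma (an assumption)."""
    return [max(0.0, value) for value in vector]


def max_pool_aggregate(neighbor_embeddings, w_pool, b_pool):
    """Transform every neighbor embedding, then take the element-wise maximum."""
    transformed = [relu(linear(w_pool, b_pool, h)) for h in neighbor_embeddings]
    return [max(column) for column in zip(*transformed)]


def update(h_v, h_neighbors, w):
    """Concatenate the node's own embedding with the aggregate and transform it."""
    concatenated = h_v + h_neighbors
    zero_bias = [0.0] * len(w)
    return relu(linear(w, zero_bias, concatenated))


# Toy example: identity pooling weights and two neighbors
w_pool = [[1.0, 0.0], [0.0, 1.0]]
b_pool = [0.0, 0.0]
neighbors = [[1.0, -2.0], [0.5, 3.0]]
h_neighbors = max_pool_aggregate(neighbors, w_pool, b_pool)

# Update weights that add the node's embedding to the aggregated message
w = [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]]
h_new = update([0.5, 0.5], h_neighbors, w)
print(h_neighbors, h_new)
```

Note that the negative component of the first neighbor is clipped by the activation before the maximum is taken, so it cannot contribute to the aggregate.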

In [1]:
import os
import torch
from torch.nn import Linear, Parameter, ReLU
import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, degree
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.nn import GraphConv, TopKPooling, GatedGraphConv, global_mean_pool, global_max_pool
from torch_geometric.loader import DataLoader
import opendatasets as od
import pandas as pd
from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
import joblib
import math

## Data Loader


The `RecSys2015Dataset` class inherits from the base class [InMemoryDataset](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.InMemoryDataset.html) and implements the functions for downloading and processing the input files. The click and purchase datasets described below are processed by `RecSys2015Dataset` into a format better suited to the network.

### Click Data

Click data consists of the session ID, timestamp, item ID and category of each click.

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>session_id</th>
      <th>timestamp</th>
      <th>item_id</th>
      <th>category</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>1</td>
      <td>2014-04-07T10:54:09.868Z</td>
      <td>214536500</td>
      <td>0</td>
    </tr>
    <tr>
      <th>1</th>
      <td>1</td>
      <td>2014-04-07T10:54:46.998Z</td>
      <td>214536506</td>
      <td>0</td>
    </tr>
    <tr>
      <th>2</th>
      <td>1</td>
      <td>2014-04-07T10:57:00.306Z</td>
      <td>214577561</td>
      <td>0</td>
    </tr>
    <tr>
      <th>3</th>
      <td>2</td>
      <td>2014-04-07T13:56:37.614Z</td>
      <td>214662742</td>
      <td>0</td>
    </tr>
    <tr>
      <th>4</th>
      <td>2</td>
      <td>2014-04-07T13:57:19.373Z</td>
      <td>214662742</td>
      <td>0</td>
    </tr>
  </tbody>
</table>

### Purchase Data

Purchase data consists of the session ID, timestamp, item ID, price and quantity of each purchase. Purchases can be related to the click streams using the session IDs.

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>session_id</th>
      <th>timestamp</th>
      <th>item_id</th>
      <th>price</th>
      <th>quantity</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>0</th>
      <td>420374</td>
      <td>2014-04-06T18:44:58.325Z</td>
      <td>214537850</td>
      <td>10471</td>
      <td>1</td>
    </tr>
    <tr>
      <th>1</th>
      <td>281626</td>
      <td>2014-04-06T09:40:13.032Z</td>
      <td>214535653</td>
      <td>1883</td>
      <td>1</td>
    </tr>
    <tr>
      <th>2</th>
      <td>420368</td>
      <td>2014-04-04T06:13:28.848Z</td>
      <td>214530572</td>
      <td>6073</td>
      <td>1</td>
    </tr>
    <tr>
      <th>3</th>
      <td>420368</td>
      <td>2014-04-04T06:13:28.858Z</td>
      <td>214835025</td>
      <td>2617</td>
      <td>1</td>
    </tr>
    <tr>
      <th>4</th>
      <td>140806</td>
      <td>2014-04-07T09:22:28.132Z</td>
      <td>214668193</td>
      <td>523</td>
      <td>1</td>
    </tr>
  </tbody>
</table>
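Relating purchases to click streams through the session ID amounts to a set-membership test, which is also how the label is derived later in `process()`. A minimal sketch, assuming toy values: the purchase session IDs are copied from the table above, while the list of click sessions is hypothetical.

```python
# Session IDs that have at least one buy event (from the purchase table above)
purchase_session_ids = {420374, 281626, 420368, 140806}

# Hypothetical click sessions to be labeled
click_session_ids = [1, 2, 420374]

# A session is labeled True if it ended in at least one purchase
labels = {session_id: session_id in purchase_session_ids
          for session_id in click_session_ids}
print(labels)
```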

### Converting Click-Stream into a Graph

Each click stream, grouped by session ID, represents a graph. The following table shows an example:

<table border="1" class="dataframe">
  <thead>
    <tr style="text-align: right;">
      <th></th>
      <th>session_id</th>
      <th>timestamp</th>
      <th>item_id</th>
      <th>category</th>
      <th>label</th>
      <th>node_index</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <th>743958</th>
      <td>239223</td>
      <td>2014-04-07T17:18:56.403Z</td>
      <td>75</td>
      <td>0</td>
      <td>False</td>
      <td>1</td>
    </tr>
    <tr>
      <th>743959</th>
      <td>239223</td>
      <td>2014-04-07T17:22:01.439Z</td>
      <td>85</td>
      <td>0</td>
      <td>False</td>
      <td>2</td>
    </tr>
    <tr>
      <th>743960</th>
      <td>239223</td>
      <td>2014-04-07T17:22:05.276Z</td>
      <td>75</td>
      <td>0</td>
      <td>False</td>
      <td>1</td>
    </tr>
    <tr>
      <th>743961</th>
      <td>239223</td>
      <td>2014-04-07T17:22:08.746Z</td>
      <td>65</td>
      <td>0</td>
      <td>False</td>
      <td>0</td>
    </tr>
    <tr>
      <th>743962</th>
      <td>239223</td>
      <td>2014-04-07T17:22:13.399Z</td>
      <td>91</td>
      <td>0</td>
      <td>False</td>
      <td>3</td>
    </tr>
    <tr>
      <th>743963</th>
      <td>239223</td>
      <td>2014-04-07T17:22:15.920Z</td>
      <td>92</td>
      <td>0</td>
      <td>False</td>
      <td>4</td>
    </tr>
  </tbody>
</table>

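Item IDs are first encoded into consecutive integers over the whole dataset using an [OrdinalEncoder](https://scikit-learn.org/0.20/modules/generated/sklearn.preprocessing.OrdinalEncoder.html). Node indices are then obtained per session by encoding the session's item IDs with a `LabelEncoder`, so that a repeated item maps to the same node. In the above example, node connectivity is: 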

$$1 \rightarrow 2 \rightarrow 1 \rightarrow 0 \rightarrow 3 \rightarrow 4$$
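The mapping from a click stream to node indices and edges can be sketched in plain Python as follows. This is a minimal illustration rather than the dataset code itself: the helper `session_to_graph` is hypothetical, and indices are assigned in sorted order of the unique item IDs, which mirrors how a label encoder orders its classes.

```python
def session_to_graph(item_ids):
    """Map the clicked items of one session to node indices and directed edges."""
    # One node per unique item, indexed in sorted order of the item IDs
    index_of = {item: index for index, item in enumerate(sorted(set(item_ids)))}
    node_index = [index_of[item] for item in item_ids]
    # Consecutive clicks are connected by an edge: source -> target
    edges = list(zip(node_index[:-1], node_index[1:]))
    return node_index, edges


# Encoded item IDs of session 239223 from the table above
node_index, edges = session_to_graph([75, 85, 75, 65, 91, 92])
print(node_index)
print(edges)
```

The six clicks yield five nodes and five edges, because item 75 is clicked twice and is represented by a single node.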

In [2]:
class RecSys2015Dataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

        # Load parameters
        self.data, self.slices = torch.load(self.processed_paths[0])

        # Load the label-encoder that was created earlier
        self.item_id_encoder: OrdinalEncoder = joblib.load(os.path.join(self.processed_dir, "item_id_encoder.joblib"))

    @property
    def raw_file_names(self):
        return ["yoochoose-buys.dat",
                "yoochoose-clicks.dat",
                "yoochoose-test.dat"]

    @property
    def processed_file_names(self):
        return ["data.pt", "item_id_encoder.joblib"]

    def download(self):
        """
        Downloads the data, unless it exists already.
        """
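        # Download the data and move the files to self.raw_dir
        download_dir = 'recsys-challenge-2015'
        od.download("https://www.kaggle.com/datasets/chadgostopp/recsys-challenge-2015/download?datasetVersionNumber=1")
        for file_name in os.listdir(download_dir):
            os.replace(os.path.join(download_dir, file_name), os.path.join(self.raw_dir, file_name))
        # Remove the now-empty download directory
        os.rmdir(download_dir)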

    def get_max_item_id(self):
        """
        Get maximum item id

        Returns
        -------
        int
            Maximum item id
        """
        return self._data.x.max().item()

    def encode_item_ids(self, x):
        """
        Encode item IDs

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Item IDs to be encoded

        Returns
        -------
        np.ndarray
            Encoded item IDs
        """
        return self.item_id_encoder.transform(x)

    def decode_item_ids(self, x):
        """
        Decode encoded item IDs

        Parameters
        ----------
        x : array-like of shape (n_samples, n_features)
            Encoded item IDs

        Returns
        -------
        np.ndarray
            Decoded item IDs
        """
        return self.item_id_encoder.inverse_transform(x)

    def process(self):
        """
        Processes the input files and converts these into node features, connectivity information for the
        edges and target information.

        Raises
        ------
        FileNotFoundError
            Raised if the file yoochoose-buys.dat is not found
        FileNotFoundError
            Raised if the file yoochoose-clicks.dat is not found
        """
        # Read the csv files
        buy_data_path = os.path.join(self.raw_dir, 'yoochoose-buys.dat')
        click_data_path = os.path.join(self.raw_dir, 'yoochoose-clicks.dat')

        # Verify that the files exist
        if not os.path.isfile(buy_data_path):
            raise FileNotFoundError(buy_data_path)

        if not os.path.isfile(click_data_path):
            raise FileNotFoundError(click_data_path)

        # Read in the data files (the files have no header row, so the first line must not be used as one)
        purchases = pd.read_csv(buy_data_path, header=None, dtype={2: 'int'})
        purchases.columns = ["session_id", "timestamp", "item_id", "price", "quantity"]

        clicks = pd.read_csv(click_data_path, header=None, dtype={2: 'int', 3: 'str'})
        clicks.columns = ["session_id", "timestamp", "item_id", "category"]

        # Print contents of purchases and clicks
        # clicks_small = clicks.head(5)
        # clicks_small.to_html('clicks.html')
        # clicks_small.to_markdown('clicks.md')
        # purchases_small = purchases.head(5)
        # purchases_small.to_html('purchases.html')
        # purchases_small.to_markdown('purchases.md')

        # Choose 1M random (unique) sessions from the data
        sampled_session_id = np.random.choice(clicks.session_id.unique(), 1000000, replace=False)
        # Copy the selection so that the column assignments below do not act on a view
        clicks = clicks.loc[clicks.session_id.isin(sampled_session_id)].copy()

        # Encode the item_id:s
        item_id_encoder = OrdinalEncoder()
        item_id_encoder.fit(clicks.item_id.to_numpy().reshape(-1, 1))
        clicks.item_id = item_id_encoder.transform(clicks.item_id.to_numpy().reshape(-1, 1))

        # Encode the categories
        category_encoder = OrdinalEncoder()
        category_encoder.fit(clicks.category.to_numpy().reshape(-1, 1))
        clicks.category = category_encoder.transform(clicks.category.to_numpy().reshape(-1, 1))
        
        # Keep only those click-sessions that have more than 2 clicks
        clicks = clicks[clicks['session_id'].map(clicks['session_id'].value_counts() > 2)]

        # Did the clicking session lead to a purchase event?
        clicks['label'] = clicks.session_id.isin(purchases['session_id'])

        data_list = []

        # Iterate over groups of session IDs
        for session_id, group in tqdm(clicks.groupby('session_id')):
            
            # Encode item_id of each group. Encoded item_id is the node_index
            encoder = LabelEncoder()
            group['node_index'] = encoder.fit_transform(group['item_id'])
            
            # Node features -> [node_index, item_id, category], one row per unique node and sorted
            # by node_index, so that row i of 'x' describes the node i referenced by edge_index
            node_features = (group[['node_index', 'item_id', 'category']]
                             .drop_duplicates('node_index')
                             .sort_values('node_index')
                             .reset_index(drop=True))

            # Target- and source nodes
            target_nodes = group.node_index[1:].reset_index(drop=True)
            source_nodes = group.node_index[:-1].reset_index(drop=True)

            # Edge indices [source -> target]
            edge_index = torch.tensor(np.vstack((source_nodes.to_numpy(), target_nodes.to_numpy())), dtype=torch.long)

            # 'x' consists of the node features [node_index, item_id, category]
            x = torch.IntTensor(node_features.to_numpy())
            # 'y' marks if there was a buying decision (True/False) at the end of the clicking session
            y = torch.FloatTensor([int(group.label.iloc[0])])

            data = Data(x=x, edge_index=edge_index, y=y)
            data_list.append(data)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        joblib.dump(item_id_encoder, os.path.join(self.processed_dir, "item_id_encoder.joblib"))

In [3]:
class SAGEConv(MessagePassing):
    """
    Implementation of the GraphSAGE operator from https://arxiv.org/abs/1706.02216
    """
    def __init__(self, in_channels, out_channels):
        """
        SAGEConv constructor.

        Parameters
        ----------
        in_channels : int
            Size of input
        out_channels : int
            Size of output
        """
        super(SAGEConv, self).__init__(aggr='max')

        self.linear_aggregate = Linear(in_channels, out_channels)
        self.activation_aggregate = ReLU()
        self.update_linear = Linear(in_channels + out_channels, in_channels, bias=False)
        self.update_activation = ReLU()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        """
        Runs the forward pass of the module.

        Parameters
        ----------
        x : torch.Tensor
            Node embeddings
        edge_index : torch.Tensor
            Edge indices

        Returns
        -------
        torch.Tensor
            Updated node embeddings
        """
        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x)

    def update(self, aggr_out, x):
        """
        Updates node embeddings. Takes in the output of aggregation as first argument and any argument
        which was initially passed to propagate(). Concatenates the aggregated message and the current
        node's embeddings, and then applies linear- and activation functions on the concatenated tensors.

        Parameters
        ----------
        aggr_out : torch.tensor
            Aggregated message
        x : torch.tensor
            Current node's embedding

        Returns
        -------
        torch.tensor
            Updated node embedding
        """

        new_embedding = torch.cat([aggr_out, x], dim=1)
        new_embedding = self.update_linear(new_embedding)
        new_embedding = self.update_activation(new_embedding)

        return new_embedding
    
    def message(self, x_j):
        """
        Constructs messages from node j to node i for each edge in edge_index. Applies linear- and
        activation functions for each neighboring node's embeddings.

        Parameters
        ----------
        x_j : torch.Tensor
            Embeddings of the neighboring (source) nodes, one row per edge

        Returns
        -------
        torch.tensor
            Constructed message
        """

        x_j = self.linear_aggregate(x_j)
        x_j = self.activation_aggregate(x_j)

        return x_j
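The self-loop handling at the start of `forward()` can be illustrated without PyTorch Geometric. The sketch below is an assumption-laden stand-in, not the library implementation: the helper `normalize_self_loops` is hypothetical and works on a plain list of `(source, target)` pairs instead of an `edge_index` tensor.

```python
def normalize_self_loops(edges, num_nodes):
    """Drop existing self-loops, then add exactly one self-loop per node."""
    without_loops = [(source, target) for source, target in edges if source != target]
    self_loops = [(node, node) for node in range(num_nodes)]
    return without_loops + self_loops


# A repeated click on the same item produces a self-loop (0 -> 0)
edges = [(0, 0), (0, 1)]
normalized = normalize_self_loops(edges, num_nodes=2)
print(normalized)
```

With exactly one self-loop per node, each node's own transformed embedding takes part in the max aggregation once, regardless of how often the item was clicked twice in a row.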

In [4]:
class GNN(torch.nn.Module):
    def __init__(self, num_embeddings, embedding_dimension = 16, dropout_probability = 0.5):
        super(GNN, self).__init__()

        self.embedding_dimension = embedding_dimension
        self.num_embeddings = num_embeddings
        self.dropout_probability = dropout_probability

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(2*self.embedding_dimension, self.embedding_dimension),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(p=self.dropout_probability),
            torch.nn.Linear(self.embedding_dimension, 8),
            torch.nn.LeakyReLU(),
            torch.nn.Dropout(p=self.dropout_probability),
            torch.nn.Linear(8, 1),
            torch.nn.Sigmoid()
        )

        self.conv1 = SAGEConv(in_channels=self.embedding_dimension, out_channels=self.embedding_dimension)
        self.pool1 = TopKPooling(in_channels=self.embedding_dimension, ratio=0.8)
        self.conv2 = SAGEConv(in_channels=self.embedding_dimension, out_channels=self.embedding_dimension)
        self.pool2 = TopKPooling(self.embedding_dimension, ratio=0.8)
        self.conv3 = SAGEConv(in_channels=self.embedding_dimension, out_channels=self.embedding_dimension)
        self.pool3 = TopKPooling(self.embedding_dimension, ratio=0.8)
        self.item_embedding = torch.nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dimension)
        self.dropout = torch.nn.Dropout(p=self.dropout_probability)

    def forward(self, data: torch_geometric.data.Batch):
        """
        Inference

        Parameters
        ----------
        data : torch_geometric.data.Batch
                x: torch.tensor
                    [node_index, item_id, category]
                edge_index: torch.tensor
                    Node connectivity
                batch: torch.tensor
                    Batch information

        Returns
        -------
        torch.tensor
            Probability of a buying decision [0...1]
        """

        x, edge_index, batch = data.x[:, 1], data.edge_index, data.batch

        # Calculate embeddings
        x = self.item_embedding(x)
        x = x.squeeze(1)

        # 1st message passing round
        x = self.conv1(x, edge_index)
        x = torch.nn.functional.relu(x)
        x = self.dropout(x)
        x, edge_index, _, batch, *rest = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        # 2nd message passing round
        x = self.conv2(x, edge_index)
        x = torch.nn.functional.relu(x)
        x = self.dropout(x)
        x, edge_index, _, batch, *rest = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        # 3rd message passing round
        x = self.conv3(x, edge_index)
        x = torch.nn.functional.relu(x)
        x = self.dropout(x)
        x, edge_index, _, batch, *rest = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x = x1 + x2 + x3
        
        x = self.mlp(x)

        # Squeeze only the last dimension so that a batch with a single graph keeps shape [1]
        return x.squeeze(-1)
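The readout after each pooling step concatenates a per-graph maximum and a per-graph mean, which is why the first layer of the MLP takes `2*embedding_dimension` inputs. A dependency-free sketch of that readout, assuming a hypothetical helper `readout` and toy embeddings, where `batch` assigns every node to a graph in the same way as `data.batch`:

```python
def readout(node_embeddings, batch, num_graphs):
    """Concatenate the element-wise max and mean of the node embeddings per graph."""
    result = []
    for graph in range(num_graphs):
        rows = [embedding for embedding, owner in zip(node_embeddings, batch)
                if owner == graph]
        max_part = [max(column) for column in zip(*rows)]
        mean_part = [sum(column) / len(rows) for column in zip(*rows)]
        result.append(max_part + mean_part)
    return result


# Three nodes with 2-dimensional embeddings; the first two belong to graph 0
node_embeddings = [[1.0, 4.0], [3.0, 2.0], [5.0, 6.0]]
batch = [0, 0, 1]
graph_vectors = readout(node_embeddings, batch, num_graphs=2)
print(graph_vectors)
```

Each graph ends up with a vector of twice the embedding size, independent of how many nodes it contains, so the three readouts can be summed element-wise.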

In [5]:
# Create / load the dataset
dataset = RecSys2015Dataset('RecSys2015_data')

# Print data from the dataset
print(f"{dataset.get(0)=}")
print(f"{dataset.get(1)=}")

dataset.get(0)=Data(x=[10, 3], edge_index=[2, 9], y=[1])
dataset.get(1)=Data(x=[3, 3], edge_index=[2, 2], y=[1])


In [6]:
# Divide the dataset into training- and validation datasets
indices = torch.randperm(len(dataset))
train_size = int(0.8*len(dataset))
train_indices = indices[:train_size]
validation_indices = indices[train_size:]

training_dataset = dataset[train_indices]
validation_dataset = dataset[validation_indices]

# Create training- and validation data loaders
batch_size = 32
training_loader = DataLoader(training_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, num_workers=0)
# Number of training steps (batches) per epoch
total_epochs_per_epoch = math.ceil(len(training_dataset) / batch_size)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
max_item_id = dataset.get_max_item_id()

print(f"Device: {device}")
print(f"Number of embeddings: {max_item_id}")

model = GNN(num_embeddings = max_item_id+1, dropout_probability=0.5).to(device)
# As the criterion we use binary cross-entropy loss -> buy / no buy
criteria = torch.nn.BCELoss()
# AdamW optimizer -> decouples weight decay from the gradient-based update, unlike Adam's L2 regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

Device: cuda
Number of embeddings: 35661
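For reference, the binary cross-entropy used as the criterion averages $-\left[y \log p + (1 - y) \log (1 - p)\right]$ over the batch. The sketch below is a plain-Python illustration with a hypothetical helper name; unlike `torch.nn.BCELoss`, it does not guard against predictions of exactly 0 or 1.

```python
import math


def binary_cross_entropy(predictions, targets):
    """Mean binary cross-entropy for probabilities strictly between 0 and 1."""
    losses = [-(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
              for p, y in zip(predictions, targets)]
    return sum(losses) / len(losses)


# An uninformed prediction of 0.5 costs log(2) regardless of the label
uninformed = binary_cross_entropy([0.5], [1.0])
# Confident and correct predictions give a much lower loss
confident = binary_cross_entropy([0.9, 0.2], [1.0, 0.0])
print(f"{uninformed:.4f} {confident:.4f}")
```

A loss close to $\log 2 \approx 0.693$ therefore means that the model is not yet doing better than guessing 0.5 for every session.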


In [8]:
num_epochs = 10
print(f"Dataset length: {len(dataset)}")
print(f"Samples in the training dataloader: {len(training_loader)}")
print(f"Samples in the validation dataloader: {len(validation_loader)}")

Dataset length: 478925
Samples in the training dataloader: 11974
Samples in the validation dataloader: 2994
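The loader lengths printed above are numbers of batches rather than numbers of samples. They follow from the dataset length, the 80/20 split and the batch size of 32, as this short check of the arithmetic shows:

```python
import math

dataset_length = 478925
batch_size = 32

# Same split as used for the training and validation indices
train_size = int(0.8 * dataset_length)
validation_size = dataset_length - train_size

# A DataLoader yields ceil(samples / batch_size) batches when the last batch is kept
training_batches = math.ceil(train_size / batch_size)
validation_batches = math.ceil(validation_size / batch_size)

print(train_size, validation_size)            # 383140 95785
print(training_batches, validation_batches)   # 11974 2994
```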


In [9]:
lowest_loss = math.inf

for epoch in range(num_epochs):
    epoch_training_loss = 0
    epoch_training_total_correct = 0
    epoch_training_total_samples = 0
    epoch_validation_loss = 0
    epoch_validation_total_correct = 0
    epoch_validation_total_samples = 0

    # Training
    model.train()
    for index, data in enumerate(training_loader):

        # Send data to device
        data = data.to(device)
        
        # Forward pass
        prediction = model(data)

        # Loss and accuracy
        loss = criteria(prediction, data.y)
        epoch_training_loss += loss.item()
        # Number of graphs (sessions) in the batch
        epoch_training_total_samples += data.num_graphs
        predicted_class = torch.round(prediction)
        epoch_training_total_correct += (data.y == predicted_class).sum().item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (index) % 1000 == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], step [{index+1}/{total_epochs_per_epoch}], training loss: {loss.item():.4f}")

    # Validation (no gradients are needed, so autograd is disabled to save memory and time)
    model.eval()
    with torch.no_grad():
        for data in validation_loader:

            # Send data to device
            data = data.to(device)

            # Forward pass
            prediction = model(data)

            # Loss and accuracy
            loss = criteria(prediction, data.y)
            epoch_validation_loss += loss.item()
            # Number of graphs (sessions) in the batch
            epoch_validation_total_samples += data.num_graphs
            predicted_class = torch.round(prediction)
            epoch_validation_total_correct += (data.y == predicted_class).sum().item()

    epoch_training_loss /= len(training_loader)
    epoch_validation_loss /= len(validation_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], training accuracy: {(epoch_training_total_correct/epoch_training_total_samples):.4f}, training loss: {epoch_training_loss:.4f}")
    print(f"Epoch [{epoch + 1}/{num_epochs}], validation accuracy: {(epoch_validation_total_correct/epoch_validation_total_samples):.4f}, validation loss: {epoch_validation_loss:.4f}")

    # Save the model if the validation loss is lower than before
    if epoch_validation_loss < lowest_loss:
        torch.save(model.state_dict(), 'GNN_recommendation_systemrecsys_2015.pth')
        lowest_loss = epoch_validation_loss
        print(f"Saved the model, lowest validation loss: {lowest_loss:.4f}")
    print("----------------------------")

Epoch [1/10], step [1/11974], training loss: 0.6182
Epoch [1/10], step [1001/11974], training loss: 0.3498
Epoch [1/10], step [2001/11974], training loss: 0.2901
Epoch [1/10], step [3001/11974], training loss: 0.3965
Epoch [1/10], step [4001/11974], training loss: 0.1424
Epoch [1/10], step [5001/11974], training loss: 0.4774
Epoch [1/10], step [6001/11974], training loss: 0.2735
Epoch [1/10], step [7001/11974], training loss: 0.3502
Epoch [1/10], step [8001/11974], training loss: 0.3518
Epoch [1/10], step [9001/11974], training loss: 0.3975
Epoch [1/10], step [10001/11974], training loss: 0.2031
Epoch [1/10], step [11001/11974], training loss: 0.2083
Epoch [1/10], training accuracy: 0.9150, training loss: 0.2940
Epoch [1/10], validation accuracy: 0.9143, validation loss: 0.2831
Saved the model, lowest validation loss: 0.2831
----------------------------
Epoch [2/10], step [1/11974], training loss: 0.2033
Epoch [2/10], step [1001/11974], training loss: 0.4983
Epoch [2/10], step [2001/11