In [1]:
# %pip install langchain arxiv pymupdf mlflow pandas

In [2]:
from pathlib import Path

import mlflow.data
import pandas as pd
from langchain.document_loaders import ArxivLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from mlflow.data.pandas_dataset import PandasDataset

# Load documents

In [3]:
# arXiv identifier of the paper to ingest, plus its abstract-page URL.
arxivv = "2307.08621"
dataset_source_url = "https://arxiv.org/abs/" + arxivv

# Build the loader first, then fetch; the identifier is passed as the query.
arxiv_loader = ArxivLoader(query=arxivv)
docs = arxiv_loader.load()

In [4]:
# Check which container type the loader returned.
print(type(docs))

<class 'list'>


In [5]:
# Count how many documents the query produced.
print(len(docs))

1


In [6]:
# Check the type of an individual loaded document.
print(type(docs[0]))

<class 'langchain_core.documents.base.Document'>


In [9]:
# Preview the first 1000 characters of the first document's text.
first_doc = docs[0]
print("sample: ", first_doc.page_content[:1000])

sample:  Retentive Network: A Successor to Transformer
for Large Language Models
Yutao Sun∗†‡
Li Dong∗†
Shaohan Huang†
Shuming Ma†
Yuqing Xia†
Jilong Xue†
Jianyong Wang‡
Furu Wei†⋄
† Microsoft Research
‡ Tsinghua University
https://aka.ms/GeneralAI
Abstract
In this work, we propose Retentive Network (RETNET) as a foundation archi-
tecture for large language models, simultaneously achieving training parallelism,
low-cost inference, and good performance. We theoretically derive the connection
between recurrence and attention. Then we propose the retention mechanism for
sequence modeling, which supports three computation paradigms, i.e., parallel,
recurrent, and chunkwise recurrent. Specifically, the parallel representation allows
for training parallelism. The recurrent representation enables low-cost O(1) infer-
ence, which improves decoding throughput, latency, and GPU memory without
sacrificing performance. The chunkwise recurrent representation facilitates effi-
cient long-sequence mo

In [10]:
# Show the metadata dict attached to the first document.
print('metadata', docs[0].metadata)

metadata {'Published': '2023-08-09', 'Title': 'Retentive Network: A Successor to Transformer for Large Language Models', 'Authors': 'Yutao Sun, Li Dong, Shaohan Huang, Shuming Ma, Yuqing Xia, Jilong Xue, Jianyong Wang, Furu Wei', 'Summary': 'In this work, we propose Retentive Network (RetNet) as a foundation\narchitecture for large language models, simultaneously achieving training\nparallelism, low-cost inference, and good performance. We theoretically derive\nthe connection between recurrence and attention. Then we propose the retention\nmechanism for sequence modeling, which supports three computation paradigms,\ni.e., parallel, recurrent, and chunkwise recurrent. Specifically, the parallel\nrepresentation allows for training parallelism. The recurrent representation\nenables low-cost $O(1)$ inference, which improves decoding throughput, latency,\nand GPU memory without sacrificing performance. The chunkwise recurrent\nrepresentation facilitates efficient long-sequence modeling with

# Split text data into chunks

In [11]:

# Chunking parameters, named so they are easy to find and tune.
# NOTE(review): sizes are presumably measured in characters (the splitter's
# default length function) - confirm if a custom length function is added.
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
)

In [12]:
# Break the loaded documents into chunks using the splitter configured above.
splits = text_splitter.split_documents(docs)

In [13]:
# Report how many chunks the splitter produced.
print("splits: ", len(splits))

splits:  51


In [14]:
# Print the text of the first chunk to sanity-check the split.
print(splits[0].page_content)

Retentive Network: A Successor to Transformer
for Large Language Models
Yutao Sun∗†‡
Li Dong∗†
Shaohan Huang†
Shuming Ma†
Yuqing Xia†
Jilong Xue†
Jianyong Wang‡
Furu Wei†⋄
† Microsoft Research
‡ Tsinghua University
https://aka.ms/GeneralAI
Abstract
In this work, we propose Retentive Network (RETNET) as a foundation archi-
tecture for large language models, simultaneously achieving training parallelism,
low-cost inference, and good performance. We theoretically derive the connection
between recurrence and attention. Then we propose the retention mechanism for
sequence modeling, which supports three computation paradigms, i.e., parallel,
recurrent, and chunkwise recurrent. Specifically, the parallel representation allows
for training parallelism. The recurrent representation enables low-cost O(1) infer-
ence, which improves decoding throughput, latency, and GPU memory without
sacrificing performance. The chunkwise recurrent representation facilitates effi-


# Save dataframe

In [15]:
# Collect the raw text of every chunk, then store it as a one-column frame.
chunk_texts = [chunk.page_content for chunk in splits]
df = pd.DataFrame(chunk_texts, columns=["text"])

In [16]:
# Row count should match the number of chunks; there is a single "text" column.
print(df.shape)

(51, 1)


In [17]:
# Peek at the first few rows of the chunk table.
print(df.head())

                                                text
0  Retentive Network: A Successor to Transformer\...
1  ence, which improves decoding throughput, late...
2  Figure 1: Retentive network (RetNet) achieves ...
3  els [BMR+20], which was initially proposed\nto...
4  “impossible triangle” as shown in Figure 2.\nT...


In [18]:
# Show the full, untruncated text stored in the second row.
print("sample: ", df.iloc[1].values)

sample:  ['ence, which improves decoding throughput, latency, and GPU memory without\nsacrificing performance. The chunkwise recurrent representation facilitates effi-\ncient long-sequence modeling with linear complexity, where each chunk is encoded\nparallelly while recurrently summarizing the chunks. Experimental results on\nlanguage modeling show that RETNET achieves favorable scaling results, parallel\ntraining, low-cost deployment, and efficient inference. The intriguing properties\nmake RETNET a strong successor to Transformer for large language models. Code\nwill be available at https://aka.ms/retnet.\n0\n20\n40\n0\n150\n300\n0\n150\n300\nGPU Memory↓\n(GB)\nThroughput↑\n(wps)\nLatency↓\n(ms)\n3.4X\n15.6X\n8.4X\nInference Cost\nScaling Curve\nRetNet\nTransformer\n1\n3\n7\nLM Perplexity\nModel Size (B)\nFigure 1: Retentive network (RetNet) achieves low-cost inference (i.e., GPU memory, throughput,\nand latency), training parallelism, and favorable scaling curves compared with Tran

In [19]:
# Destination CSV for the chunked text; the arXiv id is embedded in the file name.
csv_data_path = f"text_data/text_train_{arxivv}.csv"

# DataFrame.to_csv does not create missing parent directories, so make sure the
# target folder exists first (a no-op when it is already there). The path itself
# stays a plain string because it is reused below as the MLflow dataset source.
Path(csv_data_path).parent.mkdir(parents=True, exist_ok=True)

# index=False: write only the "text" column, not the integer row index.
df.to_csv(csv_data_path, index=False)

# Version dataset with MLflow

In [20]:

# Wrap the chunk table as an MLflow dataset, recording the CSV path as its source.
dataset: PandasDataset = mlflow.data.from_pandas(
    df,
    source=csv_data_path,
)


  return _dataset_source_registry.resolve(
  return _dataset_source_registry.resolve(


In [50]:

with mlflow.start_run():
    # Log the dataset as an input of this MLflow run. The "text_training" context
    # labels how the dataset is used, and the "arxiv" tag records which paper the
    # text chunks were extracted from.
    mlflow.log_input(dataset, context="text_training", tags={"arxiv": arxivv})


# Retrieve the run, including dataset information


In [51]:
# Look up the most recently active run and keep its id for later retrieval.
last_run = mlflow.last_active_run()
run_id = last_run.info.run_id
run_id

'21c21a83b30749df8752300eec2ec4b6'

In [52]:
# Fetch the run by the `run_id` captured in the previous cell rather than asking
# MLflow for the last active run a second time, so both cells are guaranteed to
# refer to the same run.
run = mlflow.get_run(run_id)

# Only one dataset was logged to this run, so the first input is the one we want.
dataset_info = run.inputs.dataset_inputs[0].dataset
print(f"Dataset name: {dataset_info.name}")
print(f"Dataset digest: {dataset_info.digest}")
print(f"Dataset profile: {dataset_info.profile}")
print(f"Dataset schema: {dataset_info.schema}")


Dataset name: dataset
Dataset digest: a42f2dca
Dataset profile: {"num_rows": 51, "num_elements": 51}
Dataset schema: {"mlflow_colspec": [{"type": "string", "name": "text", "required": true}]}


In [53]:
# Display the complete dataset record attached to the run, including its source.
run.inputs.dataset_inputs[0].dataset

<Dataset: digest='a42f2dca', name='dataset', profile='{"num_rows": 51, "num_elements": 51}', schema='{"mlflow_colspec": [{"type": "string", "name": "text", "required": true}]}', source='{"uri": "text_data/text_train_2307.08621.csv"}', source_type='local'>

In [54]:
# Resolve the source object for the logged dataset and show it as a dict.
dataset_source = mlflow.data.get_source(dataset_info)
dataset_source.to_dict()

{'uri': 'text_data/text_train_2307.08621.csv'}

# Load data from the run_id

In [55]:
# Reload the chunk table from the URI recorded as the dataset's source.
versioned_df = pd.read_csv(dataset_source.uri)

In [56]:
# Peek at the reloaded rows; they should match the frame saved earlier.
print(versioned_df.head())

                                                text
0  Retentive Network: A Successor to Transformer\...
1  ence, which improves decoding throughput, late...
2  Figure 1: Retentive network (RetNet) achieves ...
3  els [BMR+20], which was initially proposed\nto...
4  “impossible triangle” as shown in Figure 2.\nT...
