In [1]:
import torch
from longformer.longformer import Longformer, LongformerConfig
from longformer.sliding_chunks import pad_to_window_size
from transformers import RobertaTokenizer

In [2]:
config = LongformerConfig.from_pretrained('longformer-base-4096/') 
# choose the attention mode 'n2', 'tvm' or 'sliding_chunks'
# 'n2': for regular n2 attantion
# 'tvm': a custom CUDA kernel implementation of our sliding window attention
# 'sliding_chunks': a PyTorch implementation of our sliding window attention
config.attention_mode = 'sliding_chunks'

In [3]:
model = Longformer.from_pretrained('longformer-base-4096/', config=config)

In [4]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer.model_max_length = model.config.max_position_embeddings

In [5]:
SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000)  # long input document

In [6]:
input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0)  # batch of size 1

In [7]:
# model = model.cuda(); input_ids = input_ids.cuda()

In [8]:
# Attention mask values -- 0: no attention, 1: local attention, 2: global attention
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
attention_mask[:, [1, 4, 21,]] =  2  # Set global attention based on the task. For example,
                                     # classification: the <s> token
                                     # QA: question tokens

In [9]:
# padding seqlen to the nearest multiple of 512. Needed for the 'sliding_chunks' attention
input_ids, attention_mask = pad_to_window_size(
        input_ids, attention_mask, config.attention_window[0], tokenizer.pad_token_id)

In [10]:
output = model(input_ids, attention_mask=attention_mask)[0]

In [11]:
output

tensor([[[-0.0473, -0.0016,  0.0404,  ..., -0.0328, -0.0960, -0.0306],
         [-0.2415,  0.2863,  0.1896,  ...,  0.0980,  0.0838,  0.1481],
         [-0.0653,  0.0544,  0.1262,  ..., -0.3759, -0.1101,  0.3571],
         ...,
         [-0.0124,  0.0684, -0.0095,  ..., -0.1089, -0.0324, -0.0683],
         [-0.0124,  0.0684, -0.0095,  ..., -0.1089, -0.0324, -0.0683],
         [-0.0124,  0.0684, -0.0095,  ..., -0.1089, -0.0324, -0.0683]]],
       grad_fn=<NativeLayerNormBackward>)