In [45]:
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader
from torchtext.datasets import SST2

padding_idx = 1
bos_idx = 0
eos_idx = 2
max_seq_len = 256
xlmr_vocab_path = "https://download.pytorch.org/models/text/xlmr.vocab.pt"
xlmr_spm_model_path = (
    "https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
)

text_transform = T.Sequential(
    T.SentencePieceTokenizer(xlmr_spm_model_path),
    T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)),  # type: ignore
    T.Truncate(max_seq_len - 2),
    T.AddToken(token=bos_idx, begin=True),
    T.AddToken(token=eos_idx, begin=False),
)


batch_size = 3

train_datapipe = SST2(split="train")  # type: ignore
dev_datapipe = SST2(split="dev")  # type: ignore


# Transform the raw dataset using non-batched API (i.e apply transformation line by line)
def apply_transform(x):
    return text_transform(x[0]), x[1]

print("raw pip", next(iter(train_datapipe)))
train_datapipe = train_datapipe.map(apply_transform)
print("transform pip", next(iter(train_datapipe)))
train_datapipe = train_datapipe.batch(batch_size)
print("batch pip", next(iter(train_datapipe)))
train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"])
print("row to clo pip", next(iter(train_datapipe)))
train_dataloader = DataLoader(train_datapipe, batch_size=None)

dev_datapipe = dev_datapipe.map(apply_transform)
dev_datapipe = dev_datapipe.batch(batch_size)
dev_datapipe = dev_datapipe.rows2columnar(["token_ids", "target"])
dev_dataloader = DataLoader(dev_datapipe, batch_size=None)

raw pip ('hide new secretions from the parental units', 0)
transform pip ([0, 1274, 112, 3525, 23410, 17514, 1295, 70, 49129, 289, 25072, 7, 2], 0)
batch pip [([0, 1274, 112, 3525, 23410, 17514, 1295, 70, 49129, 289, 25072, 7, 2], 0), ([0, 70541, 7, 110, 43198, 6, 4, 4734, 27554, 71, 914, 9405, 2], 0), ([0, 450, 5161, 7, 6863, 124850, 136, 6, 127219, 1636, 9844, 43257, 34923, 1672, 14135, 31425, 2], 1)]
row to clo pip defaultdict(<class 'list'>, {'token_ids': [[0, 1274, 112, 3525, 23410, 17514, 1295, 70, 49129, 289, 25072, 7, 2], [0, 70541, 7, 110, 43198, 6, 4, 4734, 27554, 71, 914, 9405, 2], [0, 450, 5161, 7, 6863, 124850, 136, 6, 127219, 1636, 9844, 43257, 34923, 1672, 14135, 31425, 2]], 'target': [0, 0, 1]})




In [29]:
dev_datapipe = SST2(split="dev")  # type: ignore
t1 = T.SentencePieceTokenizer(xlmr_spm_model_path)
t2 = T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path))  # type: ignore
t3 = T.Truncate(max_seq_len - 2)
t4 = T.AddToken(token=bos_idx, begin=True)
t5 = T.AddToken(token=eos_idx, begin=False)

In [30]:
txt = next(iter(dev_datapipe))[0]
print(txt)
txt = t1(txt)
print(txt)
txt = t2(txt)
print(txt)
txt = t3(txt)
print(txt)
txt = t4(txt)
print(txt)
txt = t5(txt)
print(txt)

it 's a charming and often affecting journey .
['▁it', "▁'", 's', '▁a', '▁charm', 'ing', '▁and', '▁often', '▁affect', 'ing', '▁journey', '▁', '.']
[442, 242, 7, 10, 108654, 214, 136, 27983, 52490, 214, 120696, 6, 5]
[442, 242, 7, 10, 108654, 214, 136, 27983, 52490, 214, 120696, 6, 5]
[0, 442, 242, 7, 10, 108654, 214, 136, 27983, 52490, 214, 120696, 6, 5]
[0, 442, 242, 7, 10, 108654, 214, 136, 27983, 52490, 214, 120696, 6, 5, 2]


In [31]:
from torchtext.datasets import SST2

batch_size = 16

train_datapipe = SST2(split="train")
dev_datapipe = SST2(split="dev")


# Transform the raw dataset using non-batched API (i.e apply transformation line by line)
def apply_transform(x):
    return text_transform(x[0]), x[1]


train_datapipe = train_datapipe.map(apply_transform)
train_datapipe = train_datapipe.batch(batch_size)
train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"])
train_dataloader = DataLoader(train_datapipe, batch_size=None)

dev_datapipe = dev_datapipe.map(apply_transform)
dev_datapipe = dev_datapipe.batch(batch_size)
dev_datapipe = dev_datapipe.rows2columnar(["token_ids", "target"])
dev_dataloader = DataLoader(dev_datapipe, batch_size=None)


In [32]:
num_classes = 2
input_dim = 768
DEVICE = "cpu"

from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER

classifier_head = RobertaClassificationHead(
    num_classes=num_classes, input_dim=input_dim
)
model = XLMR_BASE_ENCODER.get_model(head=classifier_head)
model.to(DEVICE)

Downloading: "https://download.pytorch.org/models/text/xlmr.base.encoder.pt" to /Users/cg/.cache/torch/hub/checkpoints/xlmr.base.encoder.pt
100%|██████████| 1.03G/1.03G [01:26<00:00, 12.8MB/s]


RobertaModel(
  (encoder): RobertaEncoder(
    (transformer): TransformerEncoder(
      (token_embedding): Embedding(250002, 768, padding_idx=1)
      (layers): TransformerEncoder(
        (layers): ModuleList(
          (0-11): 12 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (linear1): Linear(in_features=768, out_features=3072, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear2): Linear(in_features=3072, out_features=768, bias=True)
            (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=0.1, inplace=False)
            (dropout2): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (positional_embedding): PositionalEmbedding(
        (embedding): Embedding(5

In [33]:
import torchtext.functional as F
from torch.optim import AdamW
from torch import nn
import torch

learning_rate = 1e-5
optim = AdamW(model.parameters(), lr=learning_rate)
criteria = nn.CrossEntropyLoss()


def train_step(input, target):
    output = model(input)
    loss = criteria(output, target)
    optim.zero_grad()
    loss.backward()
    optim.step()


def eval_step(input, target):
    output = model(input)
    loss = criteria(output, target).item()
    return float(loss), (output.argmax(1) == target).type(torch.float).sum().item()


def evaluate():
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    counter = 0
    with torch.no_grad():
        for batch in dev_dataloader:
            input = F.to_tensor(batch["token_ids"], padding_value=padding_idx).to(
                DEVICE
            )
            target = torch.tensor(batch["target"]).to(DEVICE)
            loss, predictions = eval_step(input, target)
            total_loss += loss
            correct_predictions += predictions
            total_predictions += len(target)
            counter += 1

    return total_loss / counter, correct_predictions / total_predictions

In [None]:
num_epochs = 1

for e in range(num_epochs):
    for batch in train_dataloader:
        input = F.to_tensor(batch["token_ids"], padding_value=padding_idx).to(DEVICE)
        target = torch.tensor(batch["target"]).to(DEVICE)
        train_step(input, target)

    loss, accuracy = evaluate()
    print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy))

In [39]:
from torchinfo import summary

summary(model, input=(batch["token_ids"],), device="cpu")

Layer (type:depth-idx)                                                      Param #
RobertaModel                                                                --
├─RobertaEncoder: 1-1                                                       --
│    └─TransformerEncoder: 2-1                                              --
│    │    └─Embedding: 3-1                                                  192,001,536
│    │    └─TransformerEncoder: 3-2                                         85,054,464
│    │    └─PositionalEmbedding: 3-3                                        394,752
│    │    └─LayerNorm: 3-4                                                  1,536
│    │    └─Dropout: 3-5                                                    --
├─RobertaClassificationHead: 1-2                                            --
│    └─Linear: 2-2                                                          590,592
│    └─Dropout: 2-3                                                         --
│    └─Linear: 2-