<a href="https://colab.research.google.com/github/anarlavrenov/ExpressNet/blob/master/usage.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##ExpressNet Usage Example 😀

Import necessary libraries - we use only torch & torchtext

In [1]:
!git clone https://github.com/anarlavrenov/ExpressNet
%cd ExpressNet

!pip install torchdata
!pip install portalocker>=2.0.0

from ExpressNet.model import ExpressNet

import torch
import torchtext; torchtext.disable_torchtext_deprecation_warning
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import AG_NEWS

from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

Cloning into 'ExpressNet'...
remote: Enumerating objects: 40, done.[K
remote: Counting objects: 100% (40/40), done.[K
remote: Compressing objects: 100% (38/38), done.[K
remote: Total 40 (delta 12), reused 12 (delta 0), pack-reused 0[K
Receiving objects: 100% (40/40), 134.79 KiB | 5.39 MiB/s, done.
Resolving deltas: 100% (12/12), done.
/content/ExpressNet
Collecting torchdata
  Downloading torchdata-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2->torchdata)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2->torchdata)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2->torchdata)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collect



Download data & create simple vocab

In [2]:
train_iter = AG_NEWS(split="train")
valid_iter = AG_NEWS(split="test")

def yield_tokens(data_iter, tokenizer):
  for label, text in data_iter:
    yield tokenizer(text)

tokenizer = get_tokenizer(tokenizer="basic_english")
vocab = build_vocab_from_iterator(iterator=yield_tokens(train_iter, tokenizer), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

Create torch loader

In [3]:
def collate_fn(batch):

  labels, text = zip(*batch)

  tokens_ = []
  labels_ = []

  for label, t in zip(labels, text):
    tokens = torch.IntTensor([vocab[word.lower()] for word in t.split()])
    tokens = torch.nn.functional.pad(tokens, (0, 96-len(tokens)))

    tokens_.append(tokens)
    if label == 1:
      labels_.append(0)
    elif label == 2:
      labels_.append(1)
    elif label == 3:
      labels_.append(2)
    elif label == 4:
      labels_.append(3)

  tokens_ = torch.stack(tokens_, dim=0)
  labels_ = torch.tensor(labels_, dtype=torch.long)

  return tokens_, labels_

In [4]:
train_loader = torch.utils.data.DataLoader(
    list(train_iter),
    batch_size=64,
    shuffle=True,
    num_workers=2,
    drop_last=True,
    collate_fn=collate_fn
)

valid_loader = torch.utils.data.DataLoader(
    list(valid_iter),
    batch_size=32,
    shuffle=True,
    num_workers=2,
    drop_last=True,
    collate_fn=collate_fn
)

Initialize ExpressNet 🤖

In [8]:
model = ExpressNet(
    d_model=256,
    vocab_size=len(vocab),
    classification_type="multiclass",
    n_classes=4
).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

Run training and validation. See accuracy metric

In [9]:
def train(loader):
  model.train()
  total_loss = 0
  total_accuracy = 0

  for batch in tqdm(loader):
    optimizer.zero_grad()
    inputs, targets = batch[0].to(device), batch[1].to(device)

    outputs = model(inputs)

    loss = criterion(outputs, targets)
    total_loss += loss.item()
    accuracy = torch.sum((torch.argmax(outputs, dim=1) == targets).int()) / len(targets)
    total_accuracy += accuracy

    loss.backward()
    optimizer.step()

  return total_loss / len(loader), total_accuracy / len(loader)


def evaluation(loader):
  model.eval()
  total_loss = 0
  total_accuracy = 0

  with torch.no_grad():
    for batch in tqdm(loader):

      inputs, targets = batch[0].to(device), batch[1].to(device)

      outputs = model(inputs)

      loss = criterion(outputs, targets)
      total_loss += loss.item()
      accuracy = torch.sum((torch.argmax(outputs, dim=1) == targets).int()) / len(targets)
      total_accuracy += accuracy

  return total_loss / len(loader), total_accuracy / len(loader)

In [10]:
epochs = 3

for epoch in range(epochs):
  loss, acc = train(train_loader)
  val_loss, val_acc = evaluation(valid_loader)

  print(f"Epoch: {epoch + 1} | loss: {loss:.3f} | acc: {acc:.3f} | val_loss: {val_loss:.3f} | val_acc: {val_acc:.3f}")

100%|██████████| 1875/1875 [02:23<00:00, 13.03it/s]
100%|██████████| 237/237 [00:03<00:00, 65.32it/s]


Epoch: 1 | loss: 0.534 | acc: 0.801 | val_loss: 0.330 | val_acc: 0.889


100%|██████████| 1875/1875 [02:28<00:00, 12.60it/s]
100%|██████████| 237/237 [00:03<00:00, 65.21it/s]


Epoch: 2 | loss: 0.313 | acc: 0.892 | val_loss: 0.280 | val_acc: 0.908


100%|██████████| 1875/1875 [02:28<00:00, 12.61it/s]
100%|██████████| 237/237 [00:03<00:00, 65.89it/s]

Epoch: 3 | loss: 0.255 | acc: 0.911 | val_loss: 0.272 | val_acc: 0.912



