## Define network

In [1]:
import torch
from torch import nn
from src.transformers import ViT

torch.manual_seed(0)

class Net(nn.Module):
    def __init__(self, n_classes=10, embed_dim=8):
        super(Net, self).__init__()

        self.encoder = ViT(embed_dim=embed_dim, n_blocks=2, n_heads=2)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, n_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.mlp(x[:, 0, :])

        return x

## Define dataset anda dataloader

In [2]:
from src.datasets import PolyMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

train = PolyMNIST(csv_file="mnist/train/polygon-mnist.csv",
                  transform=ToTensor(), return_poly=False)

test = PolyMNIST(csv_file="mnist/test/polygon-mnist.csv",
                  transform=ToTensor(), return_poly=False)

train_loader = DataLoader(train, batch_size=60, shuffle=True)
test_loader = DataLoader(test, batch_size=60, shuffle=False)

## Train

In [3]:
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm.notebook import trange, tqdm

N_EPOCHS = 5
LR = 0.005

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_name = torch.cuda.get_device_name(
    device) if torch.cuda.is_available() else "cpu"
print("Using device: ", device, f"({device_name})")

model = Net().to(device)
optimizer = Adam(model.parameters(), lr=LR)
criterion = CrossEntropyLoss()

model.train()
train_loss = 0.0
for epoch in trange(N_EPOCHS, desc="Training"):

    desc = f"Last loss: {train_loss}\nEpoch {epoch + 1} in training"

    train_loss = 0.0

    for batch in tqdm(train_loader, desc=desc, leave=False):
        image, label = batch

        image = image.to(device)
        label = label.to(device)

        pred = model(image)
        loss = criterion(pred, label)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_loss += loss.item() / len(train_loader)

    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

Using device:  cuda (Quadro M5000)


Training:   0%|          | 0/5 [00:00<?, ?it/s]

Last loss: 0.0
Epoch 1 in training:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 1/5 loss: 1.84


Last loss: 1.8391025855541234
Epoch 2 in training:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 2/5 loss: 1.71


Last loss: 1.7120792683362964
Epoch 3 in training:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 3/5 loss: 1.64


Last loss: 1.6396469342708586
Epoch 4 in training:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 4/5 loss: 1.62


Last loss: 1.6168654135465603
Epoch 5 in training:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 5/5 loss: 1.61


In [4]:
with torch.no_grad():
    correct, total = 0, 0
    test_loss = 0.0
    for batch in tqdm(test_loader, desc="Testing"):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        test_loss += loss.detach().cpu().item() / len(test_loader)

        correct += torch.sum(torch.argmax(y_hat, dim=1)
                             == y).detach().cpu().item()
        total += len(x)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

Testing:   0%|          | 0/167 [00:00<?, ?it/s]

Test loss: 1.59
Test accuracy: 86.65%


## Save

In [5]:
path = "checkpoints/classification_001.pth"
torch.save(model.state_dict(), path)