In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import Tensor
import math


In [21]:
mnist_transform = transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])

data_loader = DataLoader(MNIST('data', train=True, download=True, transform=mnist_transform), batch_size=64, shuffle=True)

In [27]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return x


In [124]:
class FF_block(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.fc2 = nn.Linear(d_model, d_model)
    
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = self.fc1(x)
        x = F.relu(x)
        x = self.norm(x)
        x = self.fc2(x)
        return x

class CrossAttentionModel(nn.Module):
    def __init__(self, d_model: int, n_numbers: int, num_heads = 4, n_classes=10):
        super().__init__()
        self.embedder = nn.Sequential(
            nn.Linear(1, d_model), 
            PositionalEncoding(d_model)
        )  
        queries = torch.randn(1, n_numbers, d_model)
        self.register_parameter('queries', nn.Parameter(queries))
        self.mh_attention = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.key = FF_block(d_model)
        self.value = FF_block(d_model)

        self.cls_head = nn.Linear(d_model, n_classes)
    
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = self.embedder(x)
        key = self.key(x)
        value = self.value(x)

        queries = self.queries.repeat(x.shape[0], 1, 1)
        x, _ = self.mh_attention(queries, key, value)
        x = self.cls_head(x)
        x = F.softmax(x, dim=-1)
        return x


In [125]:
    
d_model = 64
n_numbers = 4
model = CrossAttentionModel(d_model, n_numbers)
#model = nn.Linear(n_numbers * 28 * 28, n_numbers * 10)
optimizer = optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

In [126]:
for epoch in range(10):
    correct = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        data = data.view(-1, n_numbers * 28 * 28, 1)
        output = model(data)
        output = output.view(-1, 10)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            #print(output[0].tolist(), target[0])
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(data_loader.dataset),
                100. * batch_idx / len(data_loader), loss.item()))
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

    print('\nAccuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))


Accuracy: 6726/60000 (11%)


Accuracy: 6743/60000 (11%)


Accuracy: 7453/60000 (12%)


Accuracy: 8846/60000 (15%)


Accuracy: 9008/60000 (15%)


Accuracy: 9398/60000 (16%)

