# In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

# In [2]:
class TransformerEncoder(nn.Module):
    """Token embedding + sinusoidal positions + a stack of encoder layers.

    Inputs are batch-first: ``forward`` reads the sequence length from
    dimension 1 of its input and every layer is built with
    ``batch_first=True``.

    Args:
        num_layers: Number of encoder layers to stack.
        d_model: Embedding / hidden size. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads per layer.
        dff: Width of the position-wise feed-forward network.
        input_vocab_size: Number of rows in the embedding table.
        maximum_position_encoding: Longest sequence length supported by the
            precomputed positional-encoding table.
        dropout_rate: Dropout probability used after the embedding and inside
            each encoder layer.
    """

    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding, dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = nn.Embedding(input_vocab_size, d_model)

        # Registered as a non-persistent buffer so the table follows the module
        # through ``.to(device)`` / ``.half()`` without adding a key to
        # ``state_dict``. It is still reachable as ``self.pos_encoding``.
        self.register_buffer(
            "pos_encoding",
            self.positional_encoding(maximum_position_encoding, d_model),
            persistent=False,
        )

        self.enc_layers = nn.ModuleList(
            [self.encoder_layer(d_model, num_heads, dff, dropout_rate) for _ in range(num_layers)]
        )

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):
        """Encode a batch of token-id sequences.

        Args:
            x: Integer token ids of shape ``(batch, seq_len)``.
            mask: Optional key-padding mask of shape ``(batch, seq_len)``.
                ``True`` marks positions that attention must ignore. It is
                forwarded to every layer as ``src_key_padding_mask``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional-encoding table.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds maximum_position_encoding {max_len}"
            )

        # Out-of-place arithmetic: avoids mutating the embedding output that
        # autograd tracks.
        x = self.embedding(x) * (self.d_model ** 0.5)
        x = x + self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x)

        for layer in self.enc_layers:
            x = layer(x, src_key_padding_mask=mask)

        return x

    def encoder_layer(self, d_model, num_heads, dff, dropout_rate):
        """Build one post-norm encoder layer (self-attention + ReLU feed-forward).

        ``nn.Sequential`` cannot be used here: attention needs query/key/value
        plus the mask, returns a tuple, and the layer needs residual
        connections. ``nn.TransformerEncoderLayer`` provides all of that.
        """
        return nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=dff,
            dropout=dropout_rate,
            batch_first=True,
        )

    def positional_encoding(self, position, d_model):
        """Return the sinusoidal table of shape ``(1, position, d_model)``.

        Even feature indices hold ``sin`` of the angle, odd indices hold
        ``cos``.
        """
        angle_rads = self.get_angles(
            torch.arange(position).unsqueeze(1),
            torch.arange(d_model).unsqueeze(0),
            d_model,
        )

        pos_encoding = torch.empty_like(angle_rads)
        # apply sin to even indices in the array; 2i
        pos_encoding[:, 0::2] = torch.sin(angle_rads[:, 0::2])
        # apply cos to odd indices in the array; 2i+1
        pos_encoding[:, 1::2] = torch.cos(angle_rads[:, 1::2])

        return pos_encoding.unsqueeze(0)

    def get_angles(self, pos, i, d_model):
        """Return ``pos / 10000 ** (2 * (i // 2) / d_model)`` as float32.

        ``pos`` and ``i`` are tensors that broadcast against each other.
        """
        pair_index = torch.div(i, 2, rounding_mode="floor")
        exponent = (2 * pair_index).to(torch.float32) / d_model
        angle_rates = 1.0 / torch.pow(10000.0, exponent)
        return pos * angle_rates

# In [7]:
class Trainer:
    """Minimal training / evaluation loop around a model, loss and optimizer.

    Args:
        model: Module called as ``model(inputs)``.
        loss_fun: Callable invoked as ``loss_fun(outputs, labels)``; its result
            must support ``.backward()`` and ``.item()``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        device: Device the model and every batch are moved to.
    """

    def __init__(self, model, loss_fun, optimizer, device='cpu'):
        self.model = model
        self.loss_fun = loss_fun
        self.optimizer = optimizer
        self.device = device

    def train(self, train_loader, num_epochs):
        """Run ``num_epochs`` passes over ``train_loader`` and print the loss.

        ``train_loader`` may be any iterable of ``(inputs, labels)`` batches;
        it does not need a ``.dataset`` attribute. The reported epoch loss is
        weighted by the batch sizes actually processed.
        """
        self.model.train()
        self.model.to(self.device)

        for epoch in range(num_epochs):
            running_loss = 0.0
            num_samples = 0
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                self.optimizer.zero_grad()

                outputs = self.model(inputs)
                loss = self.loss_fun(outputs, labels)
                loss.backward()
                self.optimizer.step()

                # Weight by batch size so a smaller final batch is not
                # over-represented (assumes a mean-reduced loss -- confirm).
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                num_samples += batch_size

            # Divide by the samples actually seen rather than
            # len(train_loader.dataset): the two differ with drop_last or a
            # sampler. Guard against an empty loader.
            epoch_loss = running_loss / num_samples if num_samples else 0.0
            print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    def evaluate(self, data_loader):
        """Return (and print) classification accuracy over ``data_loader``.

        Predictions are the argmax over dimension 1 of the model output. The
        model is left in eval mode afterwards.

        Returns:
            Fraction of correct predictions, or ``0.0`` if ``data_loader``
            yields no samples.
        """
        self.model.eval()
        self.model.to(self.device)

        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in data_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                outputs = self.model(inputs)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # An empty loader would otherwise raise ZeroDivisionError.
        accuracy = correct / total if total else 0.0
        print(f'Accuracy: {accuracy:.4f}')

        return accuracy