In [308]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy
import os

batch_size = 32
learning_rate = 0.01
epochs = 15

transform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size)


In [309]:

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x
    
    def get_linear_layers(self) -> list[tuple[str, nn.Module]]:
        return copy.deepcopy([('fc1', self.fc1), ('fc2', self.fc2), ('fc3', self.fc3)])

model = SimpleNN()

In [310]:

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

model_path = "models/mnist.pth"

if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print("Model loaded from file.")
else:
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    torch.save(model.state_dict(), model_path)
    print("Model saved.")
    

Model loaded from file.


  model.load_state_dict(torch.load(model_path))


In [311]:
class FieldLinear(nn.Module):
    def __init__(self, linear, p, f , padding = None):
        super().__init__()
        self.p = p
        self.f = f
        self.weight = self.encode(linear.weight)
        self.bias = self.encode(linear.bias)
        
        if padding is not None:
            self.padding = padding
            self.weight = self.weight + self.padding

    def encode(self, x):
        x_fixed = torch.round(x * self.f).long()
        return x_fixed % self.p

    def decode_from_field(self, n):
        n = torch.where(n < self.p // 2, n, n - self.p)
        return n

    def decode_from_fixed_point(self, n):
        return n.float() / self.f

    def forward(self, x):
        x_encoded = self.encode(x)        
        prod = torch.matmul(x_encoded, self.weight.t()) % self.p
        
        if self.padding is not None:
            padding_prod = torch.matmul(x_encoded, self.padding.t()) % self.p
            prod = (prod - padding_prod) % self.p
        
        prod = (prod + self.bias) % self.p
        y_field = self.decode_from_field(prod)
        y_trunc = torch.round(y_field.float() / self.f)
        y = self.decode_from_fixed_point(y_trunc)
        return y

def padding_tensor(x: torch.Tensor, p: int) -> torch.Tensor:
    return torch.randint(low=0, high=p, size=x.shape, dtype=torch.long, device=x.device)


    
linear_layers = model.get_linear_layers()

p = (1 << 61) - 1
f = 1 << 10

for name, layer in linear_layers:
    setattr(model, name, FieldLinear(layer, p=p, f=f, padding=padding_tensor(layer.weight, p)))

print('Field linear')
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
print(f'Accuracy: {correct / total * 100:.2f}%')

for name, layer in linear_layers:
    setattr(model, name, layer)

print('Real Linear')

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {correct / total * 100:.2f}%')

Field linear
Accuracy: 95.80%
Real Linear
Accuracy: 96.27%
