In [206]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import copy
import os

batch_size = 32
learning_rate = 0.01
epochs = 50

transform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size)


In [207]:

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(28*28, 10)

    def forward(self, x):
        x = self.flatten(x)
        return self.linear(x)

model = SimpleNN()

In [None]:
p = 1 << 31
f = 1 << 8

def encode_to_fixed_point(n):
    return int(round(n * f))   

def encode_to_field(n):
    if n > 0:
        return n % p
    else:
        return (p + n) % p
    
def encode(x):
    return encode_to_field(encode_to_fixed_point(x))

def field_scalar_vec_dot(m, x):
    sum = 0
    for e in m:
        sum += (e * x) % p 
    return sum % p

def decode_from_field(n):
    if n < p/2:
        return n
    else:
        return n - p
    
def decode_from_fixed_point(n):
    return n / f

In [209]:

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

model_path = "models/mnist.pth"

if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print("Model loaded from file.")
else:
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    torch.save(model.state_dict(), model_path)
    print("Model saved.")
    

Model loaded from file.


  model.load_state_dict(torch.load(model_path))


In [210]:
class CustomLinear(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.weight = linear.weight
        self.bias = linear.bias
        
    def forward(self, x):
        out = []
        for sample in x:
            sample_out = []
            for w, b in zip(self.weight, self.bias):  
                dot = (sample * w).sum() + b  
                sample_out.append(dot)
            out.append(torch.stack(sample_out))
        return torch.stack(out)
    
class FieldLinear(nn.Module):
    def __init__(self, linear):
        super().__init__()
        encoded_weight = torch.tensor([[encode(x.item()) for x in row] for row in linear.weight])
        encoded_bias = torch.tensor([encode(x.item()) for x in linear.bias])
        self.weight = encoded_weight
        self.bias = encoded_bias
        
    def forward(self, x):
        out = []
        for sample in x:
            encoded_sample = torch.tensor([encode(x.item()) for x in sample])
            
            sample_out = []
            for w, b in zip(self.weight, self.bias):  
                y_field = (torch.sum((w * encoded_sample) % p) + b) % p  
                y_trunc = torch.round(decode_from_field(y_field) / f)
                y = decode_from_fixed_point(y_trunc)
                sample_out.append(y)
            out.append(torch.stack(sample_out))
        return torch.stack(out)
    
real_linear = copy.deepcopy(model.linear)
model.linear = FieldLinear(real_linear)
print('FieldLinear')
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
print(f'Accuracy: {correct / total * 100:.2f}%')

model.linear = CustomLinear(real_linear)
print('CustomLinear')

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {correct / total * 100:.2f}%')

FieldLinear
Accuracy: 87.80%
CustomLinear
Accuracy: 91.55%
