<a href="https://colab.research.google.com/github/SoonerTuran/DNVA/blob/main/NC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn

In [3]:
class ALU(nn.Module):
    """Two-layer MLP acting as a learned ALU that carries flag bits between calls.

    Each call consumes two operands ``A`` and ``B`` (``bit_size`` values each),
    an ``Opcode`` (``opcode_size`` values) and the flags stored by the previous
    call (``state_size`` values). It produces ``bit_size`` result values
    followed by ``state_size`` new flag values.

    Args:
        bit_size: Number of elements in each operand and in the result.
        opcode_size: Number of elements in the opcode vector.
        state_size: Number of flag/status elements carried between calls.
    """

    def __init__(self, bit_size=8, opcode_size=4, state_size=4):
        super(ALU, self).__init__()

        self.bit_size = bit_size
        self.opcode_size = opcode_size
        self.state_size = state_size

        # Width of the network input: A, B, the opcode and the carried flags.
        input_size = 2*bit_size + opcode_size + state_size

        # ALU internal state: only the state_size flag/status bits. A, B and the
        # opcode arrive as arguments of forward(), so storing room for them here
        # would make the concatenated input wider than the first Linear layer.
        # Registered as a buffer so it follows .to(device) and is never treated
        # as a trainable leaf (in-place writes to such a leaf raise in autograd).
        self.register_buffer('state', torch.zeros(state_size))

        # Model definition
        self.main = nn.Sequential(
            nn.Linear(input_size, 2*input_size),
            nn.ReLU(),
            nn.Linear(2*input_size, bit_size + state_size)
        )

    def forward(self, A, B, Opcode):
        """Run one ALU step and remember the produced flags.

        Args:
            A: 1-D tensor with ``bit_size`` elements.
            B: 1-D tensor with ``bit_size`` elements.
            Opcode: 1-D tensor with ``opcode_size`` elements.

        Returns:
            1-D tensor with ``bit_size + state_size`` elements: the result
            values followed by the new flag values.
        """
        # Combine A, B, Opcode and the ALU's internal state
        combined_input = torch.cat((A, B, Opcode, self.state), dim=0)

        # Get the result from the neural network
        result = self.main(combined_input)

        # Store the freshly produced flags for the next call. Slicing from
        # bit_size (rather than -state_size) also behaves when state_size == 0,
        # and no_grad keeps the update out of the autograd graph.
        with torch.no_grad():
            self.state.copy_(result[self.bit_size:])

        return result

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# dataset preparation
# ToTensor followed by Normalize with mean 0.5 and std 0.5 on the single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, )),
])

# Both MNIST splits live under the same root and are downloaded on demand.
mnist_options = dict(download=True, transform=transform)
trainset = datasets.MNIST('dataset/', train=True, **mnist_options)
testset = datasets.MNIST('dataset/', train=False, **mnist_options)

# Identical loader settings for both splits: shuffled mini-batches of 64.
loader_options = dict(batch_size=64, shuffle=True)
trainloader = torch.utils.data.DataLoader(trainset, **loader_options)
testloader = torch.utils.data.DataLoader(testset, **loader_options)
# defining networks
class STEFunction(autograd.Function):
    """Binarising step whose backward pass passes a clipped gradient through.

    Forward maps every element to 1 where it is strictly positive and to 0
    otherwise. Backward ignores the step itself and returns the incoming
    gradient clamped by ``F.hardtanh`` (default range [-1, 1]).
    """

    @staticmethod
    def forward(ctx, input):
        # Keep the caller's floating dtype (float64, half, bfloat16, ...) so the
        # binarised code can feed layers of the same precision; non-floating
        # inputs still yield float32, as before.
        out_dtype = input.dtype if input.is_floating_point() else torch.float32
        return (input > 0).to(out_dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: hand the upstream gradient back, clipped by hardtanh.
        return F.hardtanh(grad_output)
class StraightThroughEstimator(nn.Module):
    """Parameter-free module wrapper that applies ``STEFunction`` to its input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``STEFunction.apply(x)``."""
        return STEFunction.apply(x)
class Autoencoder(nn.Module):
    """Convolutional autoencoder whose bottleneck is binarised.

    The encoder stacks four stride-2 convolutions (1 -> 64 -> 128 -> 256 -> 4
    channels) and ends with Tanh followed by ``StraightThroughEstimator``. The
    decoder mirrors it with four stride-2 transposed convolutions
    (4 -> 256 -> 128 -> 64 -> 1 channels) and a final Tanh.
    """

    def __init__(self):
        super().__init__()

        def down(c_in, c_out):
            # 4x4 convolution, stride 2, padding 1.
            return nn.Conv2d(c_in, c_out, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))

        def up(c_in, c_out, k):
            # k x k transposed convolution, stride 2, padding 1.
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=(k, k), stride=(2, 2), padding=(1, 1))

        # Layers are created in the same order as they are applied and kept in
        # flat Sequential containers, so their indices stay stable.
        encoder_layers = [
            down(1, 64), nn.ReLU(),
            down(64, 128), nn.BatchNorm2d(128), nn.ReLU(),
            down(128, 256), nn.BatchNorm2d(256), nn.ReLU(),
            down(256, 4), nn.BatchNorm2d(4), nn.Tanh(),
            StraightThroughEstimator(),
        ]
        decoder_layers = [
            up(4, 256, 5), nn.BatchNorm2d(256), nn.ReLU(),
            up(256, 128, 5), nn.BatchNorm2d(128), nn.ReLU(),
            up(128, 64, 4), nn.BatchNorm2d(64), nn.ReLU(),
            up(64, 1, 4), nn.Tanh(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, encode=False, decode=False):
        """Encode, decode, or do both.

        Args:
            x: Input to the encoder, or a code for the decoder when ``decode``
                is set and ``encode`` is not.
            encode: When true, return only the encoder output (takes
                precedence over ``decode``).
            decode: When true (and ``encode`` is false), return only the
                decoder output for ``x``.

        Returns:
            The encoder output, the decoder output, or the decoder applied to
            the encoder output when neither flag is set.
        """
        if encode:
            return self.encoder(x)
        if decode:
            return self.decoder(x)
        return self.decoder(self.encoder(x))
# Model, reconstruction criterion and optimiser, all placed on `device`.
net = Autoencoder().to(device)
criterion_MSE = nn.MSELoss().to(device)
optimizer = optim.Adam(net.parameters(), betas=(0.5, 0.999), lr=0.001)

# train loop: the target of every batch is the batch itself; labels are unused
epoch = 5
for e in range(epoch):
    print(f'Starting epoch {e} of {epoch}')
    progress = tqdm(trainloader)
    for X, y in progress:
        X = X.to(device)
        optimizer.zero_grad()
        reconstruction = net(X)
        loss = criterion_MSE(reconstruction, X)
        loss.backward()
        optimizer.step()
    # Only the loss of the epoch's final mini-batch is reported here.
    print(f'Loss: {loss.item()}')
# test loop
# Show five test images next to their reconstructions (first item of each of
# the first five batches drawn from `testloader`).
# Switch to inference mode first: otherwise the BatchNorm layers normalise with
# per-batch statistics and keep updating their running statistics from test data.
net.eval()
i = 1
fig = plt.figure(figsize=(10, 10))
# No gradients are needed for visualisation, so skip building autograd graphs.
with torch.no_grad():
    for X, y in testloader:
        # Check the budget before the forward pass so no batch is reconstructed
        # only to be thrown away.
        if i >= 10:
            break
        X_in = X.to(device)
        recon = net(X_in).cpu().numpy()
        # Left column: original image; right column: its reconstruction.
        ax = fig.add_subplot(5, 2, i)
        ax.set_title('Original')
        ax.imshow(X[0].reshape((28, 28)), cmap="gray")
        ax = fig.add_subplot(5, 2, i+1)
        ax.set_title('Reconstruction')
        ax.imshow(recon[0].reshape((28, 28)), cmap="gray")
        i += 2
fig.tight_layout()
plt.show()

Starting epoch 0 of 5


100%|██████████| 938/938 [00:19<00:00, 47.06it/s]


Loss: -0.0
Starting epoch 1 of 5


100%|██████████| 938/938 [00:16<00:00, 55.31it/s]


Loss: -0.0
Starting epoch 2 of 5


 72%|███████▏  | 679/938 [00:14<00:05, 46.89it/s]


KeyboardInterrupt: ignored