<a href="https://colab.research.google.com/github/AWorldOfChaos/SoC-2024-Robust-ML/blob/main/Uday/bnn_pt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split, TensorDataset
from sklearn.model_selection import train_test_split
import numpy as np

# Define custom binarization function
class BinarizeFunction(torch.autograd.Function):
    """Binarize a tensor to {-1, +1} with a straight-through gradient.

    Forward: element-wise sign, except that exact zeros are mapped to +1.
    Plain ``torch.sign`` returns 0 for a zero input, which would make the
    output ternary rather than binary.

    Backward: the incoming gradient is passed through unchanged, as if the
    forward pass had been the identity function.
    """

    @staticmethod
    def forward(ctx, input):
        """Return a tensor shaped like ``input`` whose entries are -1 or +1."""
        # Non-zero entries keep exactly the value torch.sign gives them;
        # only exact zeros are redirected to +1.
        return torch.where(input == 0, torch.ones_like(input), torch.sign(input))

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` straight through to the input."""
        return grad_output


# Functional entry point used by the layers and the model below.
binarize = BinarizeFunction.apply

# Define Binarized Linear layer
class BinarizedLinear(nn.Module):
    """Fully connected layer whose weights are binarized on every forward pass.

    The layer stores a real-valued ``weight`` parameter of shape
    ``(out_features, in_features)`` and an optional real-valued ``bias`` of
    shape ``(out_features,)``. Only the weight goes through ``binarize``;
    the bias is used as stored.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: when False, no bias parameter is created (``self.bias`` is None).
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        # Registering None keeps ``self.bias`` defined for bias-free layers.
        self.register_parameter(
            'bias',
            nn.Parameter(torch.Tensor(out_features)) if bias else None,
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights from N(0, 0.1^2) and zero the bias if there is one."""
        nn.init.normal_(self.weight, mean=0, std=0.1)
        if self.bias is None:
            return
        nn.init.constant_(self.bias, val=0)

    def forward(self, input):
        """Apply the affine map using the binarized weight matrix."""
        return nn.functional.linear(input, binarize(self.weight), self.bias)

# Define BNN model in PyTorch
class BNNModel(nn.Module):
    """Binarized multilayer perceptron for 28x28 single-channel images.

    Architecture: flatten -> five (BinarizedLinear -> BatchNorm1d -> binarize)
    stages of widths 512, 256, 128, 64, 32 -> a final BinarizedLinear with
    10 outputs.

    ``forward`` returns raw, unnormalized class scores (logits). It
    deliberately does NOT apply softmax: ``nn.CrossEntropyLoss`` applies
    log-softmax internally, so feeding it probabilities would apply softmax
    twice, which bounds the loss away from zero and weakens the gradients.
    Use ``torch.softmax(model(x), dim=1)`` when probabilities are needed;
    the argmax over classes is the same either way.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Attribute names are kept as-is so state_dict keys stay stable.
        self.dense1 = BinarizedLinear(28*28, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dense2 = BinarizedLinear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dense3 = BinarizedLinear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.dense4 = BinarizedLinear(128, 64)
        self.bn4 = nn.BatchNorm1d(64)
        self.dense5 = BinarizedLinear(64, 32)
        self.bn5 = nn.BatchNorm1d(32)
        self.dense6 = BinarizedLinear(32, 10)

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for a batch of images."""
        x = self.flatten(x)
        hidden_stages = (
            (self.dense1, self.bn1),
            (self.dense2, self.bn2),
            (self.dense3, self.bn3),
            (self.dense4, self.bn4),
            (self.dense5, self.bn5),
        )
        for dense, batch_norm in hidden_stages:
            # Batch-normalize the pre-activation, then binarize it.
            x = binarize(batch_norm(dense(x)))
        # Output layer: no batch norm, no binarization, no softmax.
        return self.dense6(x)

# Load MNIST dataset
# ToTensor yields pixel values in [0, 1]; Normalize with mean 0.5 and std 0.5
# then rescales them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Arguments shared by the train and test splits.
mnist_kwargs = {'root': './data', 'transform': transform, 'download': True}
train_dataset = MNIST(train=True, **mnist_kwargs)
test_dataset = MNIST(train=False, **mnist_kwargs)

# Split train dataset into train and validation (90% / 10%)
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Only the training loader shuffles; validation and test keep a fixed order.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Instantiate the model, loss function, and optimizer
model = BNNModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Early stopping class
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the current validation loss and
    the model. It remembers the best loss seen so far together with an
    independent copy of the model weights at that point. After ``patience``
    consecutive epochs without improvement, ``early_stop`` is set to True
    and the best weights are loaded back into the model.

    Args:
        patience: number of consecutive non-improving epochs tolerated
            before stopping.
        delta: slack added to the best loss; an epoch only counts as
            non-improving when ``val_loss > best_loss + delta``.

    Attributes:
        counter: consecutive non-improving epochs seen so far.
        best_loss: lowest accepted validation loss, or None before the
            first call.
        early_stop: True once patience has been exhausted.
        best_state_dict: deep copy of the model's state_dict taken when
            ``best_loss`` was recorded.
    """

    def __init__(self, patience=6, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_state_dict = None

    @staticmethod
    def _snapshot(model):
        """Return an independent copy of the model's current state_dict."""
        import copy

        # state_dict() hands back references to the live parameter tensors.
        # Without a deep copy the "best" snapshot would be mutated by every
        # later optimizer step and restoring it would be a no-op.
        return copy.deepcopy(model.state_dict())

    def __call__(self, val_loss, model):
        """Record this epoch's validation loss and update the stop decision."""
        if self.best_loss is None:
            # First epoch: nothing to compare against yet.
            self.best_loss = val_loss
            self.best_state_dict = self._snapshot(model)
        elif val_loss > self.best_loss + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                # Roll the model back to the best weights observed.
                model.load_state_dict(self.best_state_dict)
        else:
            self.best_loss = val_loss
            self.best_state_dict = self._snapshot(model)
            self.counter = 0

# Train the model with early stopping
def train(model, train_loader, val_loader, optimizer, criterion, epochs=100):
    """Train ``model`` for up to ``epochs`` epochs, validating after each one.

    Each epoch runs one optimisation pass over ``train_loader`` and one
    gradient-free pass over ``val_loader``, then prints the mean per-batch
    loss and the accuracy (in percent) for both. The validation loss is
    handed to an ``EarlyStopping`` instance (patience 6); the loop ends as
    soon as that instance sets ``early_stop``.

    Args:
        model: module to optimise; mutated in place.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        val_loader: iterable of ``(inputs, labels)`` validation batches.
        optimizer: optimizer bound to ``model``'s parameters.
        criterion: loss callable taking ``(outputs, labels)``.
        epochs: maximum number of epochs to run.
    """
    early_stopping = EarlyStopping(patience=6)

    for epoch in range(epochs):
        # ---- optimisation pass ----
        model.train()
        loss_sum, hits, seen = 0.0, 0, 0
        for batch_x, batch_y in train_loader:
            optimizer.zero_grad()
            scores = model(batch_x)
            batch_loss = criterion(scores, batch_y)
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()
            # Index of the highest score per sample is the predicted class.
            guesses = torch.max(scores.data, 1)[1]
            seen += batch_y.size(0)
            hits += (guesses == batch_y).sum().item()

        train_loss = loss_sum / len(train_loader)
        train_acc = hits / seen * 100

        # ---- validation pass (no gradients) ----
        model.eval()
        val_loss_sum, val_hits, val_seen = 0.0, 0, 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                scores = model(batch_x)
                val_loss_sum += criterion(scores, batch_y).item()
                guesses = torch.max(scores.data, 1)[1]
                val_seen += batch_y.size(0)
                val_hits += (guesses == batch_y).sum().item()

        val_loss = val_loss_sum / len(val_loader)
        val_acc = val_hits / val_seen * 100

        print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

# Train the model
# Runs for at most 100 epochs; train() breaks out earlier when its
# early-stopping check fires.
train(model, train_loader, val_loader, optimizer, criterion, epochs=100)

# Evaluate the model
# eval() switches the BatchNorm layers to their running statistics.
model.eval()
test_loss = 0.0  # running sum of per-batch losses
correct = 0      # number of correctly classified test samples
total = 0        # number of test samples seen
with torch.no_grad():  # inference only, no gradients are tracked
    for inputs, labels in test_loader:
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        # Predicted class = index of the largest output along dim 1.
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Mean of the per-batch losses (not weighted by batch size) and accuracy in percent.
test_loss /= len(test_loader)
test_acc = correct / total * 100
print(f'Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%')

# Save the PyTorch model
# Only the state_dict is written, so loading it requires constructing a
# BNNModel first.
torch.save(model.state_dict(), 'bnn_model.pt')


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 146289813.35it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 8164433.09it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 49604017.99it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2861728.82it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [1/100], Train Loss: 1.6908, Train Accuracy: 77.05%, Val Loss: 1.6153, Val Accuracy: 84.63%
Epoch [2/100], Train Loss: 1.6093, Train Accuracy: 85.20%, Val Loss: 1.5875, Val Accuracy: 87.33%
Epoch [3/100], Train Loss: 1.5951, Train Accuracy: 86.61%, Val Loss: 1.5835, Val Accuracy: 87.68%
Epoch [4/100], Train Loss: 1.5871, Train Accuracy: 87.40%, Val Loss: 1.5756, Val Accuracy: 88.50%
Epoch [5/100], Train Loss: 1.5810, Train Accuracy: 88.00%, Val Loss: 1.5672, Val Accuracy: 89.50%
Epoch [6/100], Train Loss: 1.5773, Train Accuracy: 88.31%, Val Loss: 1.5658, Val Accuracy: 89.63%
Epoch [7/100], Train Loss: 1.5717, Train Accuracy: 88.98%, Val Loss: 1.5594, Val Accuracy: 90.08%
Epoch [8/100], Train Loss: 1.5690, Train Accuracy: 89.22%, Val Loss: 1.5676, Val Accuracy: 89.35%
Epoch [9/100], Train Loss: 1.5664, Train Accuracy: 89.48%, Val Loss: 1.5574, Val Accuracy: 90.23%
Epoch [10/100], Train Loss: 1.5643, Train A

In [None]:
# Save the PyTorch model
# Same call as the one at the end of the training cell; re-running it
# overwrites 'bnn_model.pt' with the model's current state_dict.
torch.save(model.state_dict(), 'bnn_model.pt')

In [None]:
# google.colab is only importable inside a Colab runtime, so this cell is
# Colab-specific.
from google.colab import files


# Download the saved model to local machine
# NOTE(review): presumably this triggers a browser download of the file
# written by the save cell above; confirm against the Colab docs.
files.download('bnn_model.pt')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:

# torchsummary is a third-party package; it is not imported with the other
# dependencies at the top of the notebook.
from torchsummary import summary

# Print the model summary
# input_size is given without a batch dimension (channels, height, width);
# the printed output shapes show the batch dimension as -1.
summary(model, input_size=(1, 28, 28))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1                  [-1, 784]               0
   BinarizedLinear-2                  [-1, 512]         401,920
       BatchNorm1d-3                  [-1, 512]           1,024
   BinarizedLinear-4                  [-1, 256]         131,328
       BatchNorm1d-5                  [-1, 256]             512
   BinarizedLinear-6                  [-1, 128]          32,896
       BatchNorm1d-7                  [-1, 128]             256
   BinarizedLinear-8                   [-1, 64]           8,256
       BatchNorm1d-9                   [-1, 64]             128
  BinarizedLinear-10                   [-1, 32]           2,080
      BatchNorm1d-11                   [-1, 32]              64
  BinarizedLinear-12                   [-1, 10]             330
Total params: 578,794
Trainable params: 578,794
Non-trainable params: 0
-------------------------------