In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import FashionMNIST

import random
import copy
import tqdm
import numpy as np
import matplotlib.pyplot as plt

def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

In [2]:
class Args():
    device = 'cpu'
    # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')

    input_size = 784
    output_size = 10
    num_neurons = 50
    num_committees = 90
    size_committee = 5

    beta = 1
    batch_size = 512
    lr = 1e-2
    weight_decay = 1e-5
    max_epochs = 100

In [3]:
train_data = FashionMNIST(root='./data', train=True, download=False)
test_data = FashionMNIST(root='./data', train=False, download=False)

X_train = train_data.data.unsqueeze(1).float().flatten(start_dim=1)/255
X_test = test_data.data.unsqueeze(1).float().flatten(start_dim=1)/255

y_train = F.one_hot(train_data.targets).float()
y_test = F.one_hot(test_data.targets).float()

train_data = torch.utils.data.TensorDataset(X_train, y_train)
test_data = torch.utils.data.TensorDataset(X_test, y_test)

RuntimeError: Dataset not found. You can use download=True to download it

In [None]:
class NPF(nn.Module):

    def __init__(self, input_size, output_size, num_neurons, size_committee, num_committees, args):
        super().__init__()
        self.args = args
        self.input_size = input_size
        self.output_size = output_size
        self.num_neurons = num_neurons
        self.num_committees = num_committees

        self.neurons = nn.ModuleList([nn.Linear(input_size, output_size, bias=False) for i in range(num_neurons)])
        self.npf = []
        for i in range(num_committees):
            while True:
                choice = torch.randint(0, num_neurons, size=(size_committee,))
                if len(torch.unique(choice)) == size_committee:
                    break
            self.npf.append(choice.tolist())

        self.init('gaussian')

        self.sigmoid = nn.Sigmoid()

    def init(self, type):
        if not hasattr(self, 'type'):
            self.type = type
        else:
            self.type = 'gaussian'

        if self.type == 'gaussian':
            for i in range(self.num_neurons):
                nn.init.normal_(self.neurons[i].weight, mean=0.0, std=1.0)
        else:
            raise NotImplementedError 

    def forward(self, x):
        z = torch.zeros((x.shape[0], self.num_committees, self.output_size)).to(self.args.device)
        for i in range(self.num_committees):
            z[:, i] = torch.prod(torch.stack([self.sigmoid(self.args.beta * (x @ self.neurons[self.npf[i][j]].weight.T)) for j in range(len(self.npf[i]))]), dim=0).squeeze(1)
        return z
    
class NPV(nn.Module):

    def __init__(self, input_size, output_size, num_neurons, size_committee, num_committees, args):
        super().__init__()
        self.args = args
        self.input_size = input_size
        self.output_size = output_size
        self.num_neurons = num_neurons
        self.num_committees = num_committees

        self.npv = nn.Linear(num_committees, 1, bias=False)

        self.init('gaussian')

    def init(self, type):
        if not hasattr(self, 'type'):
            self.type = type
        else:
            self.type = 'gaussian'

        if self.type == 'gaussian':
            nn.init.normal_(self.npv.weight, mean=0.0, std=1.0)
        else:
            raise NotImplementedError 

    def forward(self, z):
        z = self.npv(z.permute(0, 2, 1)) 
        return z.squeeze(-1)

In [None]:
class CommitteesTrainer():

    def __init__(self, data, models, optimzers, criterion, args):
        self.args = args
        
        self.traindata, self.testdata = data
        self.trainloader = torch.utils.data.DataLoader(self.traindata, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
        self.testloader = torch.utils.data.DataLoader(self.testdata, batch_size=self.args.batch_size, shuffle=False, drop_last=False)

        self.npf, self.npv = models
        self.npf, self.npv = self.npf.to(self.args.device), self.npv.to(self.args.device)
        self.npf_opt, self.npv_opt = optimzers
        self.criterion = criterion

        self.loss = []
        self.accuracy = []

    def train_epoch(self, step):
        train_loss = 0
        train_accuracy = 0

        for idx, (x, y) in tqdm.tqdm(enumerate(self.trainloader)):
            self.npf.train()
            self.npv.train()
            x = x.to(self.args.device)
            y = y.to(self.args.device)

            self.npf_opt.zero_grad()
            self.npv_opt.zero_grad()

            if step%5 != 0:
                z = self.npf(x)
                y_pred = self.npv(z)
                loss = self.criterion(y_pred, y)
                loss.backward()
                self.npf_opt.step()

            if step%5 == 0:
                z = self.npf(x)
                y_pred = self.npv(z)
                loss = self.criterion(y_pred, y)
                loss.backward()
                self.npv_opt.step()

            self.npf.eval()
            self.npv.eval()
            with torch.no_grad():
                z = self.npf(x)
                y_pred = self.npv(z)
                loss = self.criterion(y_pred, y)
                accuracy = (y_pred.argmax(dim=1) == y.argmax(dim=1)).float().mean()

            train_loss += loss.item()/len(self.trainloader)
            train_accuracy += accuracy.item()/len(self.trainloader)

        return train_loss, train_accuracy
    
    def test(self):
        self.npf.eval()
        self.npv.eval()
        test_loss = 0
        test_accuracy = 0
        for idx, (x, y) in enumerate(self.testloader):
            x = x.to(self.args.device)
            y = y.to(self.args.device)

            with torch.no_grad():
                z = self.npf(x)
                y_pred = self.npv(z)
                loss = self.criterion(y_pred, y)
                accuracy = (y_pred.argmax(dim=1) == y.argmax(dim=1)).float().mean()

            test_loss += loss.item()/len(self.testloader)
            test_accuracy += accuracy.item()/len(self.testloader)
        
        return loss.item(), accuracy.item()
    
    def train(self, epochs, verbose=True):
        for epoch in range(epochs):
            train_loss, train_accuracy = self.train_epoch(epoch)
            self.loss.append(train_loss)
            self.accuracy.append(train_accuracy)

            test_loss, test_accuracy = self.test()

            if verbose:
                print(f'Epoch: {epoch+1:03d}/{epochs:03d} | Train Loss: {train_loss:.3f} | Train Accuracy: {train_accuracy:.3f} | Test Loss: {test_loss:.3f} | Test Accuracy: {test_accuracy:.3f}')
        
        return test_accuracy

In [None]:
set_seed(0)
args = Args()

npf = NPF(args.input_size, args.output_size, args.num_neurons, args.size_committee, args.num_committees, args)
npv = NPV(args.input_size, args.output_size, args.num_neurons, args.size_committee, args.num_committees, args)

npf_opt = optim.SGD(npf.parameters(), lr=args.lr, weight_decay=0)
npv_opt = optim.SGD(npv.parameters(), lr=args.lr, weight_decay=args.weight_decay)

criterion = nn.CrossEntropyLoss()

trainer = CommitteesTrainer((train_data, test_data), (npf, npv), (npf_opt, npv_opt), criterion, args)
# trainer.train(args.max_epochs)

In [None]:
len(trainer.testloader)

20