In [1]:
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from CellNet import *

In [2]:
# Both MNIST splits go through the same preprocessing: convert to a tensor,
# then normalise with the mean/std constants below.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Only the training split requests a download.
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=mnist_transform)
dataset2 = datasets.MNIST('../data', train=False, transform=mnist_transform)

# One sample per batch; CellNetMNIST.forward flattens its input completely
# and therefore cannot handle larger batches.
train_loader = DataLoader(dataset1, batch_size=1)
test_loader = DataLoader(dataset2, batch_size=1)

In [3]:
# Cell layout of the network: total number of cells, how many of them are
# input cells (InCell) and how many are output cells (OutCell). Whatever is
# left over becomes plain Cell instances.
cellnum, incellnum, outcellnum = 4, 2, 1

In [30]:
class CellNetMNIST(nn.Module):
    """Small convolutional front end feeding a group of interconnected cells.

    The cells are stored in ``self.cells`` in a fixed order: first
    ``incellnum`` ``InCell`` instances, then ``cellnum - incellnum - outcellnum``
    plain ``Cell`` instances, and finally ``outcellnum`` ``OutCell`` instances.

    Args:
        cellnum: Total number of cells.
        incellnum: Number of input cells (these additionally receive the
            flattened convolutional features when an image is supplied).
        outcellnum: Number of output cells. ``forward`` returns the output of
            the last one, so in practice this is expected to be 1.
    """

    def __init__(self, cellnum, incellnum, outcellnum):
        super(CellNetMNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)

        self.cellnum = cellnum
        self.incellnum = incellnum
        self.outcellnum = outcellnum

        # NOTE(review): this is a plain Python list, not an nn.ModuleList, so
        # the cells are NOT registered as submodules: model.parameters(),
        # model.train()/eval() and model.to(...) do not reach them. Other
        # cells of this notebook iterate model.cells explicitly to count and
        # optimise the cell parameters, so switching to nn.ModuleList must be
        # done together with those call sites.
        self.cells = []

        # 9216 == 64 * 12 * 12, the length of the fully flattened conv output
        # for a single 1x28x28 image; presumably InCell's input feature size
        # (confirm against CellNet).
        for _ in range(incellnum):
            self.cells.append(InCell(128, 64, cellnum-1, 9216))

        for _ in range(cellnum - incellnum - outcellnum):
            self.cells.append(Cell(128, 64, cellnum-1))

        for _ in range(outcellnum):
            self.cells.append(OutCell(128, 64, cellnum-1, 10))

    def forward(self, C1, x=None):
        """Advance every cell by one step.

        Args:
            C1: Sequence with the current core of every cell, in the same
                order as ``self.cells``.
            x: Optional single image. When given it is run through the
                convolutional stack, flattened completely and handed to the
                input cells; when ``None`` the input cells only see cores.

        Returns:
            Tuple ``(output, new_cores)``: the log-softmax (over dim 0) of the
            output cell's output, and the list of new cores, one per cell.

        Raises:
            RuntimeError: If no output cell produced an output.
        """
        # Every cell is fed the cores of all *other* cells.
        input_cores_grouped = [
            list(C1[:idx]) + list(C1[idx + 1:]) for idx in range(len(C1))
        ]

        if x is not None:
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            # x = torch.flatten(x, 1)  # needs to be fixed to allow for batches
            x = torch.flatten(x, 0)

        # Route by the counts this instance was built with, not by whatever
        # globals of the same name happen to exist in the notebook.
        first_hidden = self.incellnum
        first_output = self.cellnum - self.outcellnum

        output = None
        new_cores = []
        for idx, cell in enumerate(self.cells):
            if idx < first_hidden:
                if x is None:
                    new_core = cell(input_cores_grouped[idx])
                else:
                    new_core = cell(input_cores_grouped[idx], x)
            elif idx < first_output:
                new_core = cell(input_cores_grouped[idx])
            else:
                # Works only for outcellnum == 1: with several output cells
                # the last one's output overwrites the earlier ones.
                new_core, output = cell(input_cores_grouped[idx], out=True)
            new_cores.append(new_core)

        if output is None:
            raise RuntimeError("CellNetMNIST.forward: no output cell produced an output")

        output = F.log_softmax(output, dim=0)  # dim=1 when batched

        return output, new_cores

In [31]:
# Build a network from the cell counts configured above.
model = CellNetMNIST(cellnum=cellnum, incellnum=incellnum, outcellnum=outcellnum)
def get_n_params_cell(model):
    """Return the number of scalar parameters in ``model`` and its cells.

    Parameters are gathered from ``model.parameters()`` and from every entry
    of ``model.cells``. Each parameter tensor is counted once, by identity, so
    the total stays correct whether or not the cells are also registered as
    submodules of ``model``.

    Args:
        model: Module exposing ``parameters()`` and an iterable ``cells``
            attribute whose items expose ``parameters()`` as well.

    Returns:
        int: Total number of scalar parameter entries.
    """
    # Materialise the list first: holding references keeps id() values stable.
    candidates = list(model.parameters())
    for cell in model.cells:
        candidates.extend(cell.parameters())

    seen_ids = set()
    total = 0
    for param in candidates:
        if id(param) in seen_ids:
            continue
        seen_ids.add(id(param))
        total += param.numel()
    return total

# Bare last expression so the notebook displays the total parameter count.
get_n_params_cell(model)

2636618

In [39]:
def train(model, train_loader, optimizer, epoch, prev_cores, log_interval=6000):
    """Run one pass over ``train_loader``, taking one optimizer step per batch.

    For every batch the model is called as ``model(prev_cores, data)`` and is
    expected to return ``(output, new_cores)``. The objective is the NLL loss
    of the output plus, for every cell, the MSE between its new core and the
    corresponding entry of ``prev_cores``.

    ``prev_cores`` is used unchanged for every batch; the new cores are
    discarded after the step.

    Args:
        model: Callable module returning ``(output, new_cores)``.
        train_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and has a ``dataset`` attribute (used for logging only).
        optimizer: Optimizer holding the parameters to update.
        epoch: Epoch number, only used in the progress message.
        prev_cores: Sequence of cores fed to the model and used as MSE targets.
        log_interval: Print progress every this many batches; a falsy value
            disables logging. Defaults to the previously hard-coded 6000.

    Returns:
        None.
    """
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()

        output, new_cores = model(prev_cores, data)

        # The model returns an unbatched output; add the batch dimension that
        # nll_loss expects. The unsqueeze won't be needed once batched.
        loss = F.nll_loss(torch.unsqueeze(output, 0), target)

        # Fold the per-cell core losses into the objective so the graph is
        # traversed once, instead of once per loss with retain_graph=True.
        # The accumulated gradients are the same sum as before.
        core_loss = sum(
            F.mse_loss(new_core, prev_cores[idx])
            for idx, new_core in enumerate(new_cores)
        )
        total_loss = loss + core_loss
        # retain_graph is kept so that callers passing cores which still
        # carry autograd history can backpropagate through it on every batch.
        total_loss.backward(retain_graph=True)

        optimizer.step()

        if log_interval and batch_idx % log_interval == 0:
            # Only the classification (NLL) part of the loss is reported.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [40]:
model = CellNetMNIST(cellnum, incellnum, outcellnum)

# The optimizer is given one parameter group per cell (taken from model.cells
# directly) plus one per convolution. The dropout layer is left out: it has
# no learnable parameters, so its group would be empty.
all_params = [{'params': cell.parameters()} for cell in model.cells]
all_params += [
    {'params': model.conv1.parameters()},
    {'params': model.conv2.parameters()},
]

optimizer = torch.optim.Adam(all_params)

# Constant all-ones starting cores, one per cell. The length 128 presumably
# matches the core size the cells are built with -- confirm against CellNet.
fake_cores = [torch.ones(128) for _ in range(cellnum)]

# train() returns nothing; binding to `_` just keeps the cell output empty.
_ = train(model, train_loader, optimizer, 1, fake_cores)

0
