In [1]:
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from CellNet import *

In [2]:
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
dataset2 = datasets.MNIST('../data', train=False,transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))

train_loader = torch.utils.data.DataLoader(dataset1, batch_size=1)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=1)

In [3]:
cellnum = 4
incellnum = 2
outcellnum = 1
epochs = 10

In [8]:
class CellNetMNIST(nn.Module):
    def __init__(self, cellnum, incellnum, outcellnum):
        super(CellNetMNIST, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)

        self.cellnum = cellnum
        self.incellnum = incellnum
        self.outcellnum = outcellnum

        self.cells = []
        for _ in range(incellnum):
            self.cells.append(InCell(64, 64, cellnum-1, 9216))

        for _ in range(cellnum - incellnum - outcellnum):
            self.cells.append(Cell(64, 64, cellnum-1))

        for _ in range(outcellnum):
            self.cells.append(OutCell(64, 64, cellnum-1, 10))


    def first_step(self, train_loader):

        optimizers = []
        for cell in self.cells:
            optimizers.append(torch.optim.Adam(cell.parameters()))

        optimizer_base = torch.optim.Adam([
                {'params': self.conv1.parameters()},
                {'params': self.conv2.parameters()},
                {'params': self.dropout1.parameters()}
            ])

        initialconn = []
        for _ in range(cellnum-1):
            initialconn.append(torch.zeros(64))

        sample_img, sample_target = next(iter(train_loader))

        for cell_optim in optimizers:
            cell_optim.zero_grad()

        optimizer_base.zero_grad()

        x = self.conv1(sample_img)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # x = torch.flatten(x, 1) # Needes to be fixed to allow for batches
        x = torch.flatten(x, 0)

        initialouts = []
        new_cores = []
        for idx, cell in enumerate(self.cells):
            if idx < incellnum:
                initialout, new_core = cell(initialconn, x)
                initialouts.append(initialout)
                new_cores.append(new_core)
            elif idx >= incellnum and idx < cellnum - outcellnum:
                initialout, new_core = cell(initialconn)
                initialouts.append(initialout)
                new_cores.append(new_core)
            else: #works only for outcellnum = 1
                initialout, new_core, output = cell(initialconn, out=True)
                initialouts.append(initialout)
                new_cores.append(new_core)
        
        output = F.log_softmax(output, dim=0) # dim=1 when batched
        loss = F.nll_loss(torch.unsqueeze(output,0), sample_target) # Unsqueeze wont be needed when batched

        print(new_cores)

        print(self.cells[0].core)

        loss.backward()
        optimizer_base.step()

        print(self.cells[0].core)

        for cell_optim in optimizers:
            cell_optim.zero_grad()

        # Group new cores by which cell they will be input into
        # TODO: Optimise this
        initialouts_grouped = [[] for _ in initialouts]
        for idx, iniout in enumerate(initialouts):
            for grpidx, i in enumerate(iniout):
                if grpidx < idx:
                    initialouts_grouped[grpidx].append(i)
                elif grpidx >= idx:
                    initialouts_grouped[grpidx+1].append(i)

        return optimizers, optimizer_base, initialouts_grouped, new_cores

    def forward(self, prev_conn, prev_cores, x=None):
        if x is not None:
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            # x = torch.flatten(x, 1) # Needes to be fixed to allow for batches
            x = torch.flatten(x, 0)

        new_conns = []
        new_cores = []
        for idx, cell in enumerate(self.cells):
            if idx < incellnum:
                new_conn, new_core = cell(prev_conn[idx], x)
                new_conns.append(new_conn)
                new_cores.append(new_core)
            elif idx >= incellnum and idx < cellnum - outcellnum:
                new_conn, new_core = cell(prev_conn[idx])
                new_conns.append(new_conn)
                new_cores.append(new_core)
            else: #works only for outcellnum = 1
                new_conn, new_core, output = cell(prev_conn[idx], out=True)
                new_conns.append(new_conn)
                new_cores.append(new_core)

        output = F.log_softmax(output, dim=0) # dim=1 when batched

        new_conns_grouped = [[] for _ in initialouts]
        for idx, iniout in enumerate(initialouts):
            for grpidx, i in enumerate(iniout):
                if grpidx < idx:
                    new_conns_grouped[grpidx].append(i)
                elif grpidx >= idx:
                    new_conns_grouped[grpidx+1].append(i)
        
        return output, new_conns_grouped, new_cores

In [10]:
model = CellNetMNIST(cellnum, incellnum, outcellnum)

_ = model.first_step(train_loader)

[[tensor([-0.1430, -0.1388, -0.2437, -0.0997,  0.1904,  0.1540,  0.0597,  0.0656,
         0.1246,  0.1975,  0.2180, -0.1246, -0.3351,  0.0435, -0.1144,  0.2654,
         0.0874,  0.0527,  0.1728,  0.0604, -0.0550,  0.1212,  0.0264, -0.2357,
         0.3122,  0.0952,  0.3136, -0.1442,  0.0408, -0.1833,  0.1451, -0.1762,
        -0.0483, -0.1619,  0.0973, -0.0690, -0.1009, -0.3599,  0.0120, -0.3232,
        -0.1335,  0.3842,  0.0113, -0.0067,  0.0260,  0.1846, -0.5403, -0.2982,
         0.0487,  0.2373,  0.1381,  0.0897,  0.0401, -0.2979,  0.1204,  0.0130,
        -0.2373,  0.2124, -0.2652,  0.1871,  0.0500, -0.1070,  0.0321,  0.3206],
       grad_fn=<ViewBackward0>), tensor([-0.1454, -0.1430, -0.2474, -0.0959,  0.1872,  0.1526,  0.0575,  0.0641,
         0.1254,  0.1974,  0.2202, -0.1270, -0.3399,  0.0431, -0.1130,  0.2668,
         0.0875,  0.0534,  0.1767,  0.0585, -0.0534,  0.1201,  0.0272, -0.2325,
         0.3144,  0.0970,  0.3103, -0.1502,  0.0379, -0.1784,  0.1441, -0.1785,
    

In [11]:
model = CellNetMNIST(cellnum, incellnum, outcellnum)
def get_n_params_cell(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    for cell in model.cells:
        for p in list(cell.parameters()):
            nn=1
            for s in list(p.size()):
                nn = nn*s
            pp += nn
    return pp

get_n_params_cell(model)

2555530

In [None]:
model = CellNetMNIST(cellnum, incellnum, outcellnum)

optimizers, optimizer_base, initialouts = model.first_step(train_loader)



for epoch in range(epochs):
    for idx, imgs in enumerate(images):

        

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)