In [1]:
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from CellNet import *

In [2]:
# Both splits go through the same preprocessing: conversion to a tensor,
# then normalisation with mean 0.1307 and std 0.3081.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Only the training split requests a download; the test split reads from the
# same '../data' root.
dataset1 = datasets.MNIST('../data', train=True, download=True, transform=mnist_transform)
dataset2 = datasets.MNIST('../data', train=False, transform=mnist_transform)

# One sample per batch, in dataset order (no shuffling is requested).
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=1)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=1)

In [3]:
# Topology handed to CellNetMNIST: `cellnum` cells in total, of which
# `incellnum` are InCells and `outcellnum` are OutCells; the rest are Cells.
cellnum, incellnum, outcellnum = 4, 2, 1

# Number of passes made by the outer training loop.
epochs = 10

In [21]:
class CellNetMNIST(nn.Module):
    """Conv front-end plus a list of CellNet cells for MNIST.

    The network holds two 3x3 convolutions with dropout, followed by
    ``incellnum`` ``InCell`` instances, ``cellnum - incellnum - outcellnum``
    ``Cell`` instances and ``outcellnum`` ``OutCell`` instances, stored in
    that order in ``self.cells``.

    Args:
        cellnum: Total number of cells.
        incellnum: Number of cells that also receive the conv features.
        outcellnum: Number of cells that produce the network output.
            ``first_step`` uses the output of the last one only.
    """

    def __init__(self, cellnum, incellnum, outcellnum):
        super(CellNetMNIST, self).__init__()
        # Keep the topology on the instance so the methods below do not have
        # to read same-named notebook globals.
        self.cellnum = cellnum
        self.incellnum = incellnum
        self.outcellnum = outcellnum

        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)

        # NOTE(review): this is a plain Python list, so the cells are NOT
        # registered as submodules: self.parameters(), state_dict(), .to() and
        # .train()/.eval() do not reach them. get_n_params_cell in this
        # notebook counts the cells separately and relies on that; switching
        # to nn.ModuleList would need that helper updated at the same time.
        self.cells = []
        # 9216 = 64 * 12 * 12, the flattened conv feature size for a 28x28
        # input. Presumably the InCell input width -- confirm against CellNet.
        for _ in range(incellnum):
            self.cells.append(InCell(64, 64, cellnum - 1, 9216))

        for _ in range(cellnum - incellnum - outcellnum):
            self.cells.append(Cell(64, 64, cellnum - 1))

        # 10 presumably is the number of MNIST classes -- confirm against CellNet.
        for _ in range(outcellnum):
            self.cells.append(OutCell(64, 64, cellnum - 1, 10))

    def _extract_features(self, x):
        """Apply conv1/ReLU, conv2/ReLU, 2x2 max-pool and dropout; no flattening."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        return self.dropout1(x)

    def first_step(self, train_loader):
        """Create the optimizers and run one pass on the first sample of ``train_loader``.

        Every cell is called once with all-zero incoming connections. The loss
        of the last out cell's output is back-propagated and only the conv
        optimizer is stepped; the per-cell optimizers are zeroed before and
        after the backward pass, so the cells are not updated here.

        Args:
            train_loader: Iterable whose first item is an ``(image, target)``
                pair. The features are flattened from dim 0, so the batch size
                must be 1.

        Returns:
            Tuple ``(optimizers, optimizer_base, initialouts)``: one Adam
            optimizer per cell, the Adam optimizer for the conv layers, and
            the list of each cell's first output (for out cells, the first
            element of the returned pair).

        Raises:
            ValueError: If no out cell was evaluated, so there is no output
                to compute a loss from.
        """
        optimizers = [torch.optim.Adam(cell.parameters()) for cell in self.cells]

        # Dropout has no learnable parameters, so it needs no parameter group.
        optimizer_base = torch.optim.Adam([
            {'params': self.conv1.parameters()},
            {'params': self.conv2.parameters()},
        ])

        # Every cell starts from all-zero incoming connections, one vector of
        # width 64 per other cell.
        initialconn = [torch.zeros(64) for _ in range(self.cellnum - 1)]

        sample_img, sample_target = next(iter(train_loader))

        for cell_optim in optimizers:
            cell_optim.zero_grad()
        optimizer_base.zero_grad()

        # Flattening from dim 0 also collapses the batch dimension.
        # TODO: flatten from dim 1 once the cells accept batched input.
        x = torch.flatten(self._extract_features(sample_img), 0)

        first_out_idx = self.cellnum - self.outcellnum
        output = None
        initialouts = []
        for idx, cell in enumerate(self.cells):
            if idx < self.incellnum:
                # In cells additionally receive the conv features.
                initialouts.append(cell(initialconn, x))
            elif idx < first_out_idx:
                initialouts.append(cell(initialconn))
            else:
                # With several out cells only the last output is kept.
                interout, output = cell(initialconn, out=True)
                initialouts.append(interout)

        if output is None:
            raise ValueError(
                "first_step needs at least one out cell to produce an output"
            )

        # The unsqueeze adds the batch dimension that nll_loss expects; it
        # will not be needed once the pipeline is batched.
        loss = F.nll_loss(torch.unsqueeze(output, 0), sample_target)

        loss.backward()
        optimizer_base.step()

        # Discard the gradients the cells received during this first pass.
        for cell_optim in optimizers:
            cell_optim.zero_grad()

        return optimizers, optimizer_base, initialouts

    def forward(self, x=None):
        """Return the log-softmax (over dim 1) of the flattened conv features of ``x``.

        NOTE(review): the cells in ``self.cells`` are not used here, so the
        result has one entry per conv feature, not one per class.

        Raises:
            ValueError: If ``x`` is ``None``; running without an input is not
                implemented.
        """
        if x is None:
            raise ValueError(
                "forward() requires an input tensor; running without input is not implemented"
            )
        x = torch.flatten(self._extract_features(x), 1)
        return F.log_softmax(x, dim=1)

In [39]:
# Build a model from the topology configured in the hyper-parameter cell.
testmodel = CellNetMNIST(
    cellnum=cellnum,
    incellnum=incellnum,
    outcellnum=outcellnum,
)

def get_n_params_cell(model):
    """Count the scalar parameters of ``model`` and of every cell in ``model.cells``.

    The cells are walked separately because ``model.parameters()`` only
    reaches registered submodules. Each parameter tensor is counted once by
    identity, so the total stays correct whether or not the cells are also
    registered on the model.

    Args:
        model: Object exposing ``parameters()`` and an iterable ``cells``
            attribute whose items also expose ``parameters()``.

    Returns:
        Total number of elements over all distinct parameter tensors.
    """
    seen = set()
    total = 0
    for module in (model, *model.cells):
        for param in module.parameters():
            if id(param) in seen:
                continue
            seen.add(id(param))
            total += param.numel()
    return total

# Display the total parameter count of testmodel, cells included.
get_n_params_cell(testmodel)

2555530

In [None]:
# NOTE(review): this cell is a GAN-style training loop and has not been
# executed. It still relies on names that are not defined anywhere in the
# visible notebook: images, z, valid, fake, real_imgs, generator,
# discriminator, adversarial_loss, optimizer_G, optimizer_D,
# opt.sample_interval and save_image. They must be provided (or the loop
# rewritten for CellNetMNIST) before this cell can run.
for epoch in range(epochs):
    for idx, imgs in enumerate(images):

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples.
        # gen_imgs is detached so this loss does not back-propagate into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Progress is reported with the counters this loop actually defines
        # (epochs / idx / images).
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, epochs, idx, len(images), d_loss.item(), g_loss.item())
        )

        # Global batch counter across epochs; a sample grid is saved every
        # opt.sample_interval batches.
        batches_done = epoch * len(images) + idx
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)