In [6]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim
import torchvision

In [509]:
class LateralBlock(nn.Module):
    def __init__(self,col,depth,block,out_shape, in_shapes):
        super(LateralBlock,self).__init__()
        self.col = col
        self.depth = depth 
        self.out_shape = out_shape
        self.block = block
        self.u = nn.ModuleList()
        
        
        if self.depth > 0: 
            self.u.extend([nn.Linear(in_shape,self.out_shape) for in_shape in in_shapes])


    def forward(self,inputs):
        if not isinstance(inputs, list):
            inputs = [inputs]

        cur_column_out = self.block(inputs[-1])
        prev_columns_out = [mod(x) for mod, x in zip(self.u, inputs)]
        res=  F.relu(cur_column_out + sum(prev_columns_out))
        return res
    
        

class ProgNet(nn.Module):
    def __init__(self,depth):
        super(ProgNet,self).__init__()
        
        self.columns = nn.ModuleList([])
        self.depth = depth
        
    
    def forward(self,x,task_id=-1):
        assert self.columns

        inputs = [col[0](x) for col in self.columns]
        for l in range(1,self.depth):
            out = []         
            for i,col in enumerate(self.columns):
                print("l {},i {}".format(l,i))
                out.append(col[l](inputs[:i+1]))

                inputs = out
        return out[task_id]
        
    def new_task(self,new_layers,shapes):
        assert isinstance(new_layers,nn.Sequential)
        assert(len(new_layers) == len(shapes))
        
        task_id = len(self.columns)
        idx =[i for i,layer in enumerate(new_layers) if isinstance(layer,(nn.Conv2d,nn.Linear))] + [len(new_layers)]
        new_blocks = []
        
        for k in range(len(idx) -1): 
            prev_blocks = []
            if k > 0: 
                prev_blocks = [col[k-1] for col in self.columns]
                
            new_blocks.append(LateralBlock(col = task_id,
                                           depth = k,
                                           block = new_layers[idx[k]:idx[k+1]],
                                           out_shape = shapes[idx[k+1]-1],
                                           in_shapes = self._get_out_shape_blocks(prev_blocks)
                                          ))
        
        new_column = nn.ModuleList(new_blocks)
        self.columns.append(new_column)
            
            
        
    def _get_out_shape_blocks(self,blocks):
        assert isinstance(blocks,list)
        assert all(isinstance(block,LateralBlock) for block in blocks)
        return [block.out_shape for block in blocks]
        
        
    
    def freeze_columns(self, skip=None):
        if skip == None:
            skip = []

        for i, c in enumerate(self.columns):
            if i not in skip:
                for params in c.parameters():
                    params.requires_grad = False

In [234]:
net = nn.Sequential(
            nn.Linear(1,16),
            nn.ReLU(),
            nn.Linear(16,32),
            nn.ReLU(),
            nn.Linear(32,1))
shapes = [16,16,32,32,1]
params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(params))

Total number of parameters is: 609


In [235]:
pg = ProgNet(5)
params = sum(p.numel() for p in pg.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(params))

Total number of parameters is: 0


In [236]:
pg.new_task(net,shapes)
params = sum(p.numel() for p in pg.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(params))

Total number of parameters is: 609


In [237]:
pg.new_task(net,shapes)
params = sum(p.numel() for p in pg.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(params))

Total number of parameters is: 1218


In [91]:
pg.columns[1]

ModuleList(
  (0): LateralBlock(
    (block): Sequential(
      (0): Linear(in_features=1, out_features=16, bias=True)
      (1): ReLU()
    )
    (u): ModuleList()
  )
  (1): LateralBlock(
    (block): Sequential(
      (2): Linear(in_features=16, out_features=32, bias=True)
      (3): ReLU()
    )
    (u): ModuleList(
      (0): Linear(in_features=16, out_features=32, bias=True)
    )
  )
  (2): LateralBlock(
    (block): Sequential(
      (4): Linear(in_features=32, out_features=1, bias=True)
    )
    (u): ModuleList(
      (0): Linear(in_features=32, out_features=1, bias=True)
    )
  )
)

In [34]:
net = nn.Sequential(
        nn.Conv2d(1,4,3,stride = 1, padding = 1), #out: (B,4,28,28)
        nn.ReLU(), 
        nn.BatchNorm2d(4),
        nn.MaxPool2d(2,stride = 2), #out: (B,4,14,14)
        nn.Conv2d(4,16,3,stride = 1, padding = 1), #out: (B,16,14,14)
        nn.ReLU(), 
        nn.BatchNorm2d(16),
        nn.MaxPool2d(2,stride = 2), #out: (B,16,7,7)
        nn.Conv2d(16,32,3,stride = 1, padding = 1), #out: (B,32,7,7)
        nn.ReLU(), 
        nn.BatchNorm2d(32),
        nn.MaxPool2d(2,stride = 2), #out: (B,32,3,3)
        nn.Flatten(),
        nn.Linear(in_features = 32*3*3,out_features = 128),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(in_features = 128, out_features = 64), 
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(in_features = 64, out_features = 10))

shapes = [(4,28,28),
          (4,28,28),
          (4,28,28),
          (4,14,14),
          (16,14,14),
          (16,14,14),
          (16,14,14),
          (16,7,7),
          (32,7,7),
          (32,7,7),
          (32,7,7),
          (32,3,3),
          (128),
          (64),
          (10)]

In [24]:
idx = [i for i,layer in enumerate(net) if isinstance(layer,(nn.Conv2d,nn.Linear))]
for k in range(len(idx) -1): 
    print(net[idx[k]:idx[k+1]])

Sequential(
  (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU()
  (2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
Sequential(
  (4): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): ReLU()
  (6): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
Sequential(
  (8): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU()
  (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (12): Flatten()
)
Sequential(
  (13): Linear(in_features=288, out_features=128, bias=True)
  (14): ReLU()
  (15): Dropout(p=0.5, inplace=False)
)
Sequential(
  (16): Linear(in_features=128, ou

In [35]:
type(net[0:4])

torch.nn.modules.container.Sequential

# MNIST Classification

In [15]:
#data loading
batch_size_train = 64
batch_size_test = 1000
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./data', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./data/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [16]:
data = next(iter(train_loader))

In [18]:
img,_ = data
net(img).shape

torch.Size([64, 10])

In [362]:
#model definition
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        
        self.conv_layers = nn.Sequential(
        nn.Conv2d(1,4,3,stride = 1, padding = 1), #out: (B,4,28,28)
        nn.ReLU(), 
        nn.BatchNorm2d(4),
        nn.MaxPool2d(2,stride = 2), #out: (B,4,14,14)
        nn.Conv2d(4,16,3,stride = 1, padding = 1), #out: (B,16,14,14)
        nn.ReLU(), 
        nn.BatchNorm2d(16),
        nn.MaxPool2d(2,stride = 2), #out: (B,16,7,7)
        nn.Conv2d(16,32,3,stride = 1, padding = 1), #out: (B,32,7,7)
        nn.ReLU(), 
        nn.BatchNorm2d(32),
        nn.MaxPool2d(2,stride = 2)) #out: (B,32,3,3)
        
        self.classifier = nn.Sequential(
        nn.Linear(in_features = 32*3*3,out_features = 128),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(in_features = 128, out_features = 64), 
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(in_features = 64, out_features = 10))
        
    def forward(self,x):
        x = self.conv_layers(x)
        x = x.view(x.shape[0],-1)
        x = self.classifier(x)
        return x

In [77]:
model = MyNet()
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(params))
print(model)

Total number of parameters is: 51274
MyNet(
  (conv_layers): Sequential(
    (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(4, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=288, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inp

In [78]:
#Training
learning_rate = 0.001
num_epochs = 10

optimizer = optim.Adam(model.parameters(),
                       lr = learning_rate)
criterion = F.cross_entropy


In [79]:
train_losses = []
model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        img, y = data
        optimizer.zero_grad()
        # forward
        y_pred = model(img)
        loss = criterion(y_pred, y,reduction = 'mean')
        # backward
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, train_loss / len(train_loader)))

    train_losses.append(train_loss/ len(train_loader))




epoch [1/10], loss:0.3411
epoch [2/10], loss:0.1091
epoch [3/10], loss:0.0876
epoch [4/10], loss:0.0727
epoch [5/10], loss:0.0657
epoch [6/10], loss:0.0600
epoch [7/10], loss:0.0523
epoch [8/10], loss:0.0496
epoch [9/10], loss:0.0469
epoch [10/10], loss:0.0454


## Classification with ProgNet: fully connected network

In [497]:
def make_column(): 
    return nn.Sequential(nn.Linear(28*28*1,256),
                      nn.ReLU(),
                      nn.Linear(256,128),
                      nn.ReLU(), 
                      nn.Linear(128,64),
                      nn.ReLU(),
                      nn.Linear(64,10))

shapes = [256,256,128,128,64,64,10]



In [498]:
#data loading
batch_size_train = 64
batch_size_test = 1000

dataset = torchvision.datasets.MNIST('./data', train=True, download=True,
                                     transform = torchvision.transforms.Compose([
                                                 torchvision.transforms.ToTensor(),
                                                 torchvision.transforms.Normalize(
                                                 (0.1307,), (0.3081,))]))

data = dataset.data
targets = dataset.targets

In [499]:
split_idx = [(dataset.targets ==i).tolist() for i in range(10)]

### Training phase  1 
Column 0, digits 0 to 4

In [500]:
mask_digits_0_4 = [any(tupl) for tupl in zip(split_idx[0],split_idx[1],split_idx[2],split_idx[3],split_idx[4])]

dataset.data = data[mask_digits_0_4]
dataset.targets = targets[mask_digits_0_4]

train_loader = torch.utils.data.DataLoader(
  dataset,
  batch_size=batch_size_train, shuffle=True)

In [501]:
prog_net = ProgNet(depth = 4)
prog_net.new_task(make_column(),shapes)
dummy_net = make_column() 

params = sum(p.numel() for p in prog_net.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(params))
print(prog_net)

Total number of parameters is: 242762
ProgNet(
  (columns): ModuleList(
    (0): ModuleList(
      (0): LateralBlock(
        (block): Sequential(
          (0): Linear(in_features=784, out_features=256, bias=True)
          (1): ReLU()
        )
        (u): ModuleList()
      )
      (1): LateralBlock(
        (block): Sequential(
          (2): Linear(in_features=256, out_features=128, bias=True)
          (3): ReLU()
        )
        (u): ModuleList()
      )
      (2): LateralBlock(
        (block): Sequential(
          (4): Linear(in_features=128, out_features=64, bias=True)
          (5): ReLU()
        )
        (u): ModuleList()
      )
      (3): LateralBlock(
        (block): Sequential(
          (6): Linear(in_features=64, out_features=10, bias=True)
        )
        (u): ModuleList()
      )
    )
  )
)


In [502]:
#Training
learning_rate = 0.001
num_epochs = 10

optimizer = optim.Adam(prog_net.parameters(),
                       lr = learning_rate)
optimizer_d = optim.Adam(dummy_net.parameters(),
                       lr = learning_rate)
criterion = F.cross_entropy

In [503]:
train_losses = []
train_losses_d = []
prog_net.train()
dummy_net.train()
for epoch in range(num_epochs):
    train_loss = 0
    train_loss_d = 0
    for batch_idx, data in enumerate(train_loader):
        img, y = data
        img = img.view(img.shape[0],-1)
        optimizer.zero_grad()
        optimizer_d.zero_grad()
        # forward
        y_pred = prog_net(img)
        y_pred_d = dummy_net(img)
        loss = criterion(y_pred, y,reduction = 'mean')
        loss_d = criterion(y_pred_d, y,reduction = 'mean')
        # backward
        loss.backward()
        loss_d.backward()
        train_loss += loss.item()
        train_loss_d += loss_d.item()
        optimizer.step()
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, train_loss / len(train_loader)))

    train_losses.append(train_loss/ len(train_loader))
    train_losses_d.append(train_loss_d/ len(train_loader))

epoch [1/10], loss:0.3564
epoch [2/10], loss:0.0477
epoch [3/10], loss:0.0314
epoch [4/10], loss:0.0184
epoch [5/10], loss:0.0159
epoch [6/10], loss:0.0147
epoch [7/10], loss:0.0151
epoch [8/10], loss:0.0080
epoch [9/10], loss:0.0118
epoch [10/10], loss:0.0070


### Training Phase 2
Column 1, digits 5 to 9

In [504]:
#data loading
batch_size_train = 64
batch_size_test = 1000

dataset = torchvision.datasets.MNIST('./data', train=True, download=True,
                                     transform = torchvision.transforms.Compose([
                                                 torchvision.transforms.ToTensor(),
                                                 torchvision.transforms.Normalize(
                                                 (0.1307,), (0.3081,))]))

data = dataset.data
targets = dataset.targets

In [505]:
mask_digits_5_9 = [any(tupl) for tupl in zip(split_idx[5],split_idx[6],split_idx[7],split_idx[8],split_idx[9])]

dataset.data = data[mask_digits_5_9]
dataset.targets = targets[mask_digits_5_9]

train_loader = torch.utils.data.DataLoader(
  dataset,
  batch_size=batch_size_train, shuffle=True)

In [506]:
prog_net.freeze_columns()
prog_net.new_task(make_column(),shapes)

params = sum(p.numel() for p in prog_net.parameters() if p.requires_grad)
print("Total number of parameters is: {}".format(params))
print(prog_net)

Total number of parameters is: 284564
ProgNet(
  (columns): ModuleList(
    (0): ModuleList(
      (0): LateralBlock(
        (block): Sequential(
          (0): Linear(in_features=784, out_features=256, bias=True)
          (1): ReLU()
        )
        (u): ModuleList()
      )
      (1): LateralBlock(
        (block): Sequential(
          (2): Linear(in_features=256, out_features=128, bias=True)
          (3): ReLU()
        )
        (u): ModuleList()
      )
      (2): LateralBlock(
        (block): Sequential(
          (4): Linear(in_features=128, out_features=64, bias=True)
          (5): ReLU()
        )
        (u): ModuleList()
      )
      (3): LateralBlock(
        (block): Sequential(
          (6): Linear(in_features=64, out_features=10, bias=True)
        )
        (u): ModuleList()
      )
    )
    (1): ModuleList(
      (0): LateralBlock(
        (block): Sequential(
          (0): Linear(in_features=784, out_features=256, bias=True)
          (1): ReLU()
        )

In [507]:
#Training
learning_rate = 0.001
num_epochs = 10

optimizer = optim.Adam(prog_net.parameters(),
                       lr = learning_rate)
optimizer_d = optim.Adam(dummy_net.parameters(),
                       lr = learning_rate)
criterion = F.cross_entropy

In [508]:
prog_net.train()
dummy_net.train()
for epoch in range(num_epochs):
    train_loss = 0
    train_loss_d = 0
    for batch_idx, data in enumerate(train_loader):
        img, y = data
        img = img.view(img.shape[0],-1)
        optimizer.zero_grad()
        optimizer_d.zero_grad()
        # forward
        y_pred = prog_net(img)
        y_pred_d = dummy_net(img)
        loss = criterion(y_pred, y,reduction = 'mean')
        loss_d = criterion(y_pred_d, y,reduction = 'mean')
        # backward
        loss.backward()
        loss_d.backward()
        train_loss += loss.item()
        train_loss_d += loss_d.item()
        optimizer.step()
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, num_epochs, train_loss / len(train_loader)))

    train_losses.append(train_loss/ len(train_loader))
    train_losses_d.append(train_loss_d/ len(train_loader))

RuntimeError: size mismatch, m1: [64 x 128], m2: [256 x 128] at /tmp/pip-req-build-p5q91txh/aten/src/TH/generic/THTensorMath.cpp:752