In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

# Number of samples per mini-batch produced by the loader.
batch_size = 25

# Shuffled loader over the MNIST training split stored under ../data
# (fetched on first use because download=True); every image is passed
# through transforms.ToTensor().
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(
        root='../data',
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=True,
)

from random import randint

print("kk")


##### TODO ######
# Circuit picking (including all the tensor reshapes it needs)
# Circuit dropout
# Initial connection

kk


In [2]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Compare the device *type*: a torch.device never equals a plain string, so
# `device != 'cpu'` was always True and requested the CUDA default tensor
# type even on CPU-only machines.
if device.type != 'cpu':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Number of target classes (the MNIST digits loaded above).
num_classes=10

In [3]:
class CircuitNet(nn.Module):
    """Image classifier that routes features through per-sample selected 1-D conv "circuits".

    Pipeline implemented by ``forward``:

    1. ``circuit_picker`` (2-D CNN + linear) scores ``n_circuits`` candidates for
       each of ``circuits_to_use`` slots; the arg-max per slot is the chosen
       circuit index.
    2. The chosen indices are appended to the image as one extra row and fed to
       ``data_pipe`` (2-D CNN + linear), which emits one ``circuit_in_dim``-long
       vector per slot.
    3. Each slot vector is run through its chosen circuit (a small 1-D CNN).
       While training, a circuit's output may be replaced by zeros (random
       dropout or a per-forward "ban" list).
    4. The concatenated circuit outputs go through two linear layers to give
       10 logits.

    The linear layer sizes (``8*7*7``, ``32*6*6``, ``16*30``) fix the expected
    input to ``(batch, 1, 28, 28)``.

    Args:
        circuit_dropout: probability-like threshold; a circuit output is zeroed
            when ``randint(1, 100) < circuit_dropout * 100``.
        circuit_ban_rate: fraction of ``n_circuits`` used as the number of ban
            draws per forward pass (drawn with replacement, so the number of
            distinct banned circuits may be smaller).
        num_classes: accepted for interface compatibility but currently unused;
            the output layer is hard-coded to 10 logits.
    """

    def __init__(self, circuit_dropout=0.5, circuit_ban_rate=0.3, num_classes=12):
        super(CircuitNet, self).__init__()

        self.circuit_ban_rate = circuit_ban_rate
        self.circuit_dropout = circuit_dropout
        # Sizes consumed only by init_hidden(); no LSTM layer is defined here.
        self.state_size = 100
        self.num_lstm_layers = 3

        # Bookkeeping table intended to monitor circuit usage: column 0 holds
        # the row index, column 1 starts at zero. Nothing in this class reads
        # or updates it after construction.
        self.used_circuits = np.zeros([40, 2])
        self.used_circuits[:, 0] = np.arange(40)

        self.circuits_to_use = 10   # slots filled per sample
        self.n_circuits = 20        # size of the circuit pool
        self.circuit_in_dim = 25    # length of the vector fed to each circuit

        # Circuit picker: (B, 1, 28, 28) -> (B, 8, 7, 7) -> slot/circuit scores.
        self.circuit_picker = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(16, 8, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(8),
            nn.MaxPool2d(2, stride=2)
        )
        self.circuit_picker_out = nn.Linear(8*7*7, (self.n_circuits * self.circuits_to_use))

        # Data pipe: (B, 1, 29, 28) -> (B, 32, 6, 6) -> one input vector per slot.
        self.data_pipe = nn.Sequential(
            nn.Conv2d(1, 16, 4, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(16, 32, 4, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, stride=2)
        )
        self.data_pipe_out = nn.Linear(32*6*6, (self.circuits_to_use * self.circuit_in_dim))

        # The pool of circuits. nn.ModuleList (instead of a plain Python list)
        # registers every circuit as a submodule, so its weights show up in
        # parameters()/state_dict() and follow .to(), .train() and .eval().
        self.circuits = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(1, 8, kernel_size=5, stride=1, padding=0),
                nn.BatchNorm1d(8),
                nn.ReLU(True),
                nn.MaxPool1d(kernel_size=2, stride=2),

                nn.Conv1d(8, 16, kernel_size=5, stride=1, padding=0),
                nn.BatchNorm1d(16),
                nn.ReLU(True),
                nn.MaxPool1d(kernel_size=2, stride=2))
            for _ in range(self.n_circuits)
        ])

        # Final head: concatenated circuit outputs (16 channels x 30) -> 10 logits.
        self.output = nn.Sequential(
            nn.Linear(16*30, 400),
            nn.Linear(400, 10))

    def init_hidden(self):
        """Return a zeroed ``(h, c)`` pair shaped ``(num_lstm_layers, 1, state_size)``.

        The tensors are created on the device that holds the model parameters
        rather than on a notebook-level global.
        """
        param_device = next(self.parameters()).device
        shape = (self.num_lstm_layers, 1, self.state_size)
        return (torch.zeros(*shape, device=param_device),
                torch.zeros(*shape, device=param_device))

    def forward(self, x, verbose):
        """Compute class logits for a batch of images.

        Args:
            x: float tensor of shape ``(batch, 1, 28, 28)``.
            verbose: when truthy, print the chosen circuit indices of every slot.

        Returns:
            Tensor of shape ``(batch, 10)`` with unnormalised class scores.
        """
        n_samples = x.size(0)

        # Score every (slot, circuit) pair and keep the best circuit per slot.
        # The arg-max yields integer indices, so no gradient flows back into
        # the picker through this selection.
        circuit_pick = self.circuit_picker(x)
        circuit_pick = circuit_pick.view(n_samples, -1)
        circuit_pick = self.circuit_picker_out(circuit_pick)
        circuit_pick = circuit_pick.view(n_samples, self.circuits_to_use, self.n_circuits)
        circuit_pick = circuit_pick.max(2)[1]

        # Pad the index row to the image width and stack it on top of the image
        # as an extra row; match the input's dtype/device instead of forcing CUDA.
        pad_width = x.size(3) - self.circuits_to_use
        circuit_pick_padded = F.pad(circuit_pick, (0, pad_width)).unsqueeze(1).unsqueeze(1)
        circuit_pick_padded = circuit_pick_padded.to(dtype=x.dtype)

        data_pipe_in = torch.cat((circuit_pick_padded, x), 2)
        data_pipe = self.data_pipe(data_pipe_in)
        data_pipe = data_pipe.view(n_samples, -1)

        data_pipe = self.data_pipe_out(data_pipe)
        data_pipe = data_pipe.view(n_samples, self.circuits_to_use, self.circuit_in_dim)

        # Circuits banned for this forward pass (training only).
        banned = set()
        if self.training:
            bans = np.random.randint(self.n_circuits,
                                     size=int(self.circuit_ban_rate*self.n_circuits))
            banned = set(bans.tolist())

        # One host transfer for all indices instead of one per sample and slot.
        picks = circuit_pick.tolist()

        circuit_out = []
        for slot in range(self.circuits_to_use):
            if verbose:
                print(str(circuit_pick[:, slot]))

            per_sample = []
            # Iterate over the actual batch, not a notebook-level batch_size.
            for b in range(n_samples):
                idx = picks[b][slot]
                circuit = self.circuits[idx](data_pipe[b][slot].unsqueeze(0).unsqueeze(0))

                # Dropout / ban: regularisation only, so skipped in eval mode.
                if self.training and (randint(1, 100) < self.circuit_dropout*100
                                      or idx in banned):
                    circuit = torch.zeros_like(circuit)

                per_sample.append(circuit)
            circuit_out.append(torch.cat(per_sample, 0))

        # (B, 16, 3) per slot -> (B, 16, 3 * circuits_to_use) -> flat features.
        circuit_out = torch.cat(circuit_out, 2)
        circuit_out_flat = circuit_out.view(circuit_out.size(0), -1)
        out = self.output(circuit_out_flat)
        return out
    
        

# Construct the model with CUDA as the default tensor type (only when a GPU
# is actually in use) so every tensor allocated during construction defaults
# to the GPU, then restore the CPU default afterwards.
# Check device.type: comparing a torch.device to the string 'cpu' is always
# unequal, which made this branch run on CPU-only machines too.
if device.type != 'cpu':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
model = CircuitNet(circuit_dropout=0.4, circuit_ban_rate = 0.4).to(device)
torch.set_default_tensor_type('torch.FloatTensor')

# Smoke test: push a single batch through the freshly built model.
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    output = model(data,False)
    break

In [4]:
# One pass over the training loader with Adam and cross-entropy loss.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Build the loss module once rather than on every iteration.
criterion = nn.CrossEntropyLoss()
model.train()
iter_print = 25  # report progress every `iter_print` batches
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    # Per-slot circuit index printing stays off; the original set it to False
    # on both branches of its periodic check, so the check was dead code.
    verbose = False
    output = model(data,verbose)
    loss = criterion(output, target)
    # A new graph is built each iteration and backward runs once per graph,
    # so retaining it only holds on to memory.
    loss.backward()
    optimizer.step()

    model.hidden = model.init_hidden()
    if(batch_idx % iter_print == 0):
        # loss is a 0-dim tensor: use .item() (indexing with [0] raises on
        # current PyTorch).
        print("Index : {}  Loss : {:.5f}  Real : {}  Pred : {}".format(batch_idx,loss.item(),target[0:4],output[0:4].max(1)[1]))



Index : 0  Loss : 2.30596  Real : tensor([ 3,  6,  3,  3], device='cuda:0')  Pred : tensor([ 2,  6,  8,  0], device='cuda:0')
Index : 25  Loss : 2.02217  Real : tensor([ 4,  2,  3,  8], device='cuda:0')  Pred : tensor([ 7,  3,  4,  9], device='cuda:0')
Index : 50  Loss : 1.57218  Real : tensor([ 3,  3,  2,  6], device='cuda:0')  Pred : tensor([ 6,  6,  6,  6], device='cuda:0')
Index : 75  Loss : 1.48286  Real : tensor([ 2,  5,  5,  4], device='cuda:0')  Pred : tensor([ 2,  3,  5,  4], device='cuda:0')
Index : 100  Loss : 1.01163  Real : tensor([ 4,  8,  9,  6], device='cuda:0')  Pred : tensor([ 9,  6,  9,  0], device='cuda:0')
Index : 125  Loss : 0.86691  Real : tensor([ 7,  8,  3,  9], device='cuda:0')  Pred : tensor([ 7,  1,  3,  9], device='cuda:0')
Index : 150  Loss : 0.81922  Real : tensor([ 5,  3,  4,  9], device='cuda:0')  Pred : tensor([ 5,  3,  4,  9], device='cuda:0')
Index : 175  Loss : 0.75782  Real : tensor([ 3,  6,  9,  1], device='cuda:0')  Pred : tensor([ 3,  0,  9,  1]

Index : 1600  Loss : 0.24281  Real : tensor([ 1,  6,  2,  0], device='cuda:0')  Pred : tensor([ 2,  8,  2,  0], device='cuda:0')
Index : 1625  Loss : 0.14062  Real : tensor([ 2,  6,  9,  8], device='cuda:0')  Pred : tensor([ 2,  6,  9,  8], device='cuda:0')
Index : 1650  Loss : 0.19818  Real : tensor([ 5,  7,  1,  5], device='cuda:0')  Pred : tensor([ 5,  7,  1,  5], device='cuda:0')
Index : 1675  Loss : 0.13553  Real : tensor([ 3,  8,  3,  6], device='cuda:0')  Pred : tensor([ 3,  8,  3,  6], device='cuda:0')
Index : 1700  Loss : 0.22006  Real : tensor([ 1,  0,  7,  2], device='cuda:0')  Pred : tensor([ 1,  0,  7,  2], device='cuda:0')
Index : 1725  Loss : 0.27368  Real : tensor([ 1,  0,  9,  9], device='cuda:0')  Pred : tensor([ 1,  0,  9,  9], device='cuda:0')
Index : 1750  Loss : 0.05708  Real : tensor([ 8,  9,  1,  3], device='cuda:0')  Pred : tensor([ 8,  9,  1,  3], device='cuda:0')
Index : 1775  Loss : 0.09944  Real : tensor([ 3,  2,  1,  1], device='cuda:0')  Pred : tensor([ 3