In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
from random import randint

import matplotlib.pyplot as plt
%matplotlib inline

# Number of samples fed to the model per optimisation step.
batch_size = 15

# MNIST training split stored under ../data (downloaded when missing),
# with every image converted to a tensor.
mnist_train = datasets.MNIST(
    '../data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# The loader hands out shuffled chunks of 10000 samples; the cells further
# down cut `batch_size`-sized mini-batches out of each chunk themselves.
train_loader = torch.utils.data.DataLoader(
    mnist_train,
    batch_size=10000,
    shuffle=True,
)

from random import randint

In [2]:
# Pick the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# A torch.device never compares equal to a plain string, so the check has to
# look at `device.type`; comparing `device != 'cpu'` is always True and would
# request CUDA tensors even on a CPU-only machine.
if device.type != 'cpu':
    # Make newly created tensors/parameters live on the GPU by default.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Number of digit classes in MNIST.
num_classes=10

class EncoderRNN(nn.Module):
    """Bidirectional multi-layer GRU encoder.

    The outputs of the two directions are summed, so the feature size of the
    returned tensor is ``hidden_size`` (not ``2 * hidden_size``).

    The GRU hidden state is kept on ``self.hidden`` and fed back into the GRU
    on the next call. Set ``hidden`` to ``None`` to start from a zero state.
    """

    def __init__(self, input_size, hidden_size, n_layers=3, dropout=0):
        """
        Args:
            input_size: feature size of each input step.
            hidden_size: GRU hidden size (per direction).
            n_layers: number of stacked GRU layers.
            dropout: dropout probability passed to ``nn.GRU``.
        """
        super(EncoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.gru = nn.GRU(input_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True)
        # Hidden state carried over between forward() calls; None means "start fresh".
        self.hidden = None

    def forward(self, input_seqs, input_lengths):
        """Encode a batch of sequences.

        Args:
            input_seqs: tensor handed straight to ``nn.GRU``; since
                ``batch_first`` is not set this is (seq_len, batch, input_size).
            input_lengths: accepted for interface compatibility but not used
                (sequence packing is disabled).

        Returns:
            Tensor whose last dimension is ``hidden_size``: the sum of the
            forward and backward GRU outputs.
        """
        outputs, self.hidden = self.gru(input_seqs, self.hidden)
        # The GRU concatenates both directions on the last axis; fold them by summing.
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs

    def init_weights(self):
        """Zero all GRU biases and Xavier-normal initialise all GRU weights."""
        for name, param in self.gru.named_parameters():
            if 'bias' in name:
                # In-place variants; the suffix-less names are deprecated aliases.
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)
class DecoderRNN(nn.Module):
    """Bidirectional GRU decoder that emits one output step per call.

    Each call consumes only the last time step of the encoder output. The GRU
    hidden state is stored on ``self.hidden`` and fed back in on the next
    call, so repeated calls continue from the previous state. Set ``hidden``
    to ``None`` to start from a zero state.
    """

    def __init__(self,output_size,hidden_size,enc_hidden_size,n_layers=3):
        """
        Args:
            output_size: feature size of the produced output.
            hidden_size: GRU hidden size (per direction).
            enc_hidden_size: feature size of the encoder output fed to the GRU.
            n_layers: number of stacked GRU layers.
        """
        super(DecoderRNN, self).__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.enc_hidden_size = enc_hidden_size
        self.n_layers = n_layers

        self.gru = nn.GRU(enc_hidden_size, hidden_size, n_layers, bidirectional=True)
        # Hidden state carried over between forward() calls; None means "start fresh".
        self.hidden = None

        # Projects the concatenated forward/backward GRU features back to hidden_size.
        self.concat = nn.Linear(hidden_size*2,hidden_size)
        self.out = nn.Linear(hidden_size,output_size)

    def forward(self, encoder_outputs):
        """Run one decoding step.

        Args:
            encoder_outputs: tensor indexed on its first axis; only the last
                entry is used, re-expanded to a length-1 sequence.

        Returns:
            Tensor with a leading axis of size 1 and ``output_size`` features
            on the last axis.
        """
        rnn_in = encoder_outputs[-1].unsqueeze(0)

        outputs, self.hidden = self.gru(rnn_in, self.hidden)
        outputs = self.concat(outputs)
        return self.out(outputs)

    def init_weights(self):
        """Zero all GRU biases and Xavier-normal initialise all GRU weights."""
        for name, param in self.gru.named_parameters():
            if 'bias' in name:
                # In-place variants; the suffix-less names are deprecated aliases.
                nn.init.constant_(param, 0.0)
            elif 'weight' in name:
                nn.init.xavier_normal_(param)

In [3]:
class CircuitNet(nn.Module):
    """Classifier that routes data through a small pool of "circuit" sub-networks.

    An encoder/decoder GRU pair looks at the current feature slots and, for
    each of ``circuits_to_use`` steps, emits (a) scores over the ``n_circuits``
    circuits and (b) a vector of ``hidden_data_num`` "hidden instructions".
    The arg-max circuit is applied to a projection of the flattened slots
    concatenated with the instructions. ``forward`` runs this routing twice
    and classifies the flattened result.
    """

    def __init__(self,num_classes=12,num_lstm_layers=3):
        """
        Args:
            num_classes: size of the final linear layer's output.
            num_lstm_layers: layer count of the (currently unused) picker LSTM.
        """
        super(CircuitNet, self).__init__()

        self.num_lstm_layers = num_lstm_layers

        self.circuits_to_use = 3   # circuits applied per connection() call
        self.max_circuits = 6      # slot axis is zero-padded up to this size
        self.n_circuits = 10       # size of the circuit pool
        self.circuit_in_dim = 4    # length of the vector fed to a circuit
        self.n_init_slots = 5      # slots produced from the raw image

        self.circuit_out_shape = 16*3
        self.init_layer =  nn.Linear(28*28, self.circuit_out_shape*self.n_init_slots)

        # The circuit picker can also send extra "instructions" to the data pipe.
        self.hidden_data_num = 10

        hidden_size = 3
        self.encoder = EncoderRNN(self.circuit_out_shape,hidden_size,n_layers = 3)
        self.decoder = DecoderRNN(self.n_circuits+self.hidden_data_num,hidden_size,hidden_size,n_layers = 3)

        # NOTE: circuit_picker, circuit_picker_out, data_pipe and connect are
        # not used by connection()/forward(); they are kept so the module's
        # parameter set and initialisation order stay unchanged.
        self.circuit_picker = nn.Sequential(
            nn.Conv1d(1, 16, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm1d(16),
            nn.MaxPool1d(2, stride=2),
            nn.Conv1d(16, 8, 3, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm1d(8),
            nn.MaxPool1d(2, stride=2)
        )

        self.circuit_picker_out = nn.LSTM(1568,(self.n_circuits *  self.circuits_to_use), self.num_lstm_layers)
        self.hidden = None

        self.data_pipe = nn.Sequential(
            nn.Conv1d(1, 4, 4, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm1d(4),
            nn.Conv1d(4, 8, 4, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm1d(8),
            nn.MaxPool1d(2, stride=2),

            nn.Conv1d(8, 4, 4, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm1d(4),
            nn.Conv1d(4, 8, 4, stride=1, padding=1),
            nn.ReLU(True),
            nn.BatchNorm1d(8),
            nn.MaxPool1d(2, stride=2)
        )

        # Input = flattened padded slots + hidden instructions
        # (max_circuits * circuit_out_shape + hidden_data_num = 6*48 + 10 = 298).
        self.data_pipe_out = nn.Linear(
            self.max_circuits*self.circuit_out_shape + self.hidden_data_num,
            self.circuit_in_dim)

        # Circuit pool. This must be an nn.ModuleList: with a plain Python
        # list the circuits are not registered as sub-modules, so their
        # parameters are missing from model.parameters() (never optimised),
        # are not moved by .to(device) and are not saved in state_dict().
        self.circuits = nn.ModuleList()
        for _ in range(self.n_circuits):
            self.circuits.append(nn.Sequential(
                nn.Conv1d(1, 8, kernel_size=2, stride=1, padding=1),
                nn.BatchNorm1d(8),
                nn.ReLU(True),

                nn.Conv1d(8, 16, kernel_size=2, stride=1, padding=1),
                nn.BatchNorm1d(16),
                nn.ReLU(True),

                nn.MaxPool1d(kernel_size=2, stride=2)))

        # connect circuits (this is temporary)
        self.connect = nn.Sequential(
            nn.Linear(1440, 28*28))

        # Final prediction layer. 144 = 3 * 48 is the flattened size of what
        # connection() returns with the layer settings above.
        self.output = nn.Sequential(
            nn.Linear(144, num_classes))

    def _reset_recurrent_state(self):
        """Drop the hidden state remembered by the encoder and decoder GRUs."""
        self.encoder.hidden = None
        self.decoder.hidden = None

    def connection(self,x,verbose,circuit_dropout,circuit_ban_rate):
        """Pick circuits for every sample and apply them.

        Args:
            x: tensor of shape [batch_size, num_slots, circuit_out_shape].
            verbose: when truthy, print the chosen circuit indices per step.
            circuit_dropout: a circuit's output is replaced by zeros when
                ``randint(1, 100) < circuit_dropout * 100``.
            circuit_ban_rate: ``int(circuit_ban_rate * n_circuits)`` circuit
                ids are drawn at random (with replacement) per call; outputs
                of banned circuits are replaced by zeros.

        Returns:
            The per-step circuit outputs concatenated along the channel axis
            and transposed, i.e. (batch, 3, 48) with the current layers.
        """
        n_samples = x.size(0)

        # Zero-pad the slot axis up to max_circuits, then flatten per sample.
        x_pad = F.pad(x.transpose(1,2),(0,self.max_circuits-x.size(1))).transpose(1,2)
        x_flat = x_pad.contiguous().view(n_samples,1,-1)

        # The encoder consumes the slots as a sequence: (slots, batch, features).
        enc_in = x.transpose(0,1)
        enc_out = self.encoder(enc_in,[x.size(1)]*n_samples)

        # One decoder step per circuit to use; the decoder keeps its hidden
        # state between calls, so successive steps continue from each other.
        pick_scores = []
        hidden_instructions = []
        for _ in range(self.circuits_to_use):
            dec_out = self.decoder(enc_out)
            pick_scores.append(dec_out[:,:,0:self.n_circuits])
            hidden_instructions.append(dec_out[:,:,self.n_circuits:])

        all_dec_out = torch.cat(pick_scores,0).transpose(0,1)
        hidden_instructions = torch.cat(hidden_instructions,0).transpose(0,1)

        # Hard arg-max over the circuit scores. The result is an index tensor,
        # so no gradient reaches the scores through this choice.
        circuit_pick = all_dec_out.max(2)[1]

        # Circuits banned for this call, as plain ints for cheap membership tests.
        bans = set(np.random.randint(self.n_circuits,size=int(circuit_ban_rate*self.n_circuits)).tolist())

        circuit_out = []
        for i in range(self.circuits_to_use):
            circuit_idx = circuit_pick[:,i]
            if(verbose):
                print(str(circuit_idx.cpu().numpy()))

            batch = []
            # Iterate over the actual batch dimension rather than a global
            # batch size, so partial batches work.
            for b in range(n_samples):
                idx = int(circuit_idx[b])
                data_in = torch.cat((x_flat[b],hidden_instructions[b,i].unsqueeze(0)),1)
                data_pipe = self.data_pipe_out(data_in)
                circuit = self.circuits[idx](data_pipe.unsqueeze(0))
                # Circuit-level dropout / ban: blank the whole circuit output.
                if randint(1,100) < circuit_dropout*100 or idx in bans:
                    circuit = torch.zeros_like(circuit)

                batch.append(circuit)
            circuit_out.append(torch.cat(batch,0))
        if(verbose):
            print(" ")
        circuit_out = torch.cat(circuit_out,1).transpose(1,2)
        return circuit_out

    def forward(self, x, verbose,circuit_dropout=0.5,circuit_ban_rate = 0.5):
        """Classify a batch of images.

        Args:
            x: image batch; everything after the first axis is flattened and
                must total 28*28 values per sample.
            verbose: forwarded to connection() to print chosen circuits.
            circuit_dropout: forwarded to connection().
            circuit_ban_rate: forwarded to connection().

        Returns:
            Tensor of shape (batch, num_classes) produced by the output layer.
        """
        # Never reuse GRU state left over from a previous forward pass (it may
        # belong to a different batch size or an already-freed graph).
        self._reset_recurrent_state()

        x_flat = x.view(x.size(0),-1)
        init = self.init_layer(x_flat)

        # Integer division: view() requires int sizes.
        init = init.view(init.size(0),self.n_init_slots,init.size(1)//self.n_init_slots)
        circuit_out1 = self.connection(init,verbose,circuit_dropout,circuit_ban_rate)

        # The second routing pass starts from a fresh recurrent state.
        self._reset_recurrent_state()

        circuit_out2 = self.connection(circuit_out1,verbose,circuit_dropout,circuit_ban_rate)
        circuit_out_flat2 = circuit_out2.contiguous().view(circuit_out2.size(0), -1)
        out = self.output(circuit_out_flat2)

        return out
    
        

# Build the model with CUDA as the default tensor type (when a GPU is used)
# so every tensor created during construction lands on the GPU, then restore
# the CPU default for the rest of the notebook (e.g. the DataLoader).
# `device.type` is compared because a torch.device never equals a plain str,
# which made the old `device != 'cpu'` check always True.
if device.type != 'cpu':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
model = CircuitNet().to(device)
torch.set_default_tensor_type('torch.FloatTensor')

# Smoke test: push one mini-batch through the freshly built model.
for batch_idx, (data, target) in enumerate(train_loader):
    # Slice first so only `batch_size` samples are moved to the device.
    data, target = data[0:batch_size].to(device), target[0:batch_size].to(device)
    output = model(data,False)
    break

In [4]:
# Two optimizers over the same parameters: the second phase starts with fresh
# Adam statistics instead of the ones accumulated in the first phase.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
optimizer2 = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
iter_print = 50
criterion = nn.CrossEntropyLoss()


def _train_phase(pool_data, pool_target, keep_mask, opt, n_iters, print_every,
                 verbose_on_print, **model_kwargs):
    """Run `n_iters` optimisation steps on samples selected by `keep_mask`.

    Args:
        pool_data: tensor of samples to draw mini-batches from.
        pool_target: labels matching `pool_data` along the first axis.
        keep_mask: boolean mask over the pool; only these samples are used.
        opt: optimizer to step.
        n_iters: number of optimisation steps.
        print_every: a progress line is printed when `i % print_every == 0`.
        verbose_on_print: if True, the model also prints its circuit choices
            on the iterations where a progress line is printed.
        **model_kwargs: extra keyword arguments forwarded to the model.
    """
    candidates = torch.nonzero(keep_mask).view(-1)
    if candidates.numel() < batch_size:
        print("Skipping phase: only {} matching samples, need {}".format(
            candidates.numel(), batch_size))
        return
    # Largest start offset that still yields a full mini-batch. A fixed upper
    # bound could run past the end of `candidates` and produce short batches.
    max_start = candidates.numel() - batch_size

    for i in range(n_iters):
        start = randint(0, max_start)
        indices = candidates[start:start + batch_size]
        data = pool_data[indices].to(device)
        target = pool_target[indices].to(device)

        report = (i % print_every == 0)
        opt.zero_grad()
        try:
            output = model(data, verbose_on_print and report, **model_kwargs)
            loss = criterion(output, target)
            loss.backward()
            opt.step()
        except RuntimeError as err:
            # Keep training best-effort, but report the failure instead of
            # silently swallowing every exception.
            print("Index : {}  skipped after error: {}".format(i, err))
            continue
        finally:
            # Always drop recurrent state so the next step starts clean.
            model.hidden = None
            model.encoder.hidden = None
            model.decoder.hidden = None

        if report:
            print("Index : {}  Loss : {:.5f}  Real : {}  Pred : {}".format(
                i,
                loss.item(), target[0:4].cpu().numpy(),
                output[0:4].max(1)[1].cpu().numpy()))


for batch_idx, (data_, target_) in enumerate(train_loader):
    # Phase 1: train on digits 0-5. Phase 2: train on digits 7-9 to check
    # adaptability.
    # NOTE(review): digit 6 matches neither mask below - confirm whether
    # `<= 6` was intended for the first phase.
    _train_phase(data_, target_, target_ < 6, optimizer, 1000, iter_print, True,
                 circuit_dropout=0.7, circuit_ban_rate=0.5)
    print("switch")
    # Phase 2 uses the model's default dropout/ban rates, prints a progress
    # line every step and never prints the circuit choices.
    _train_phase(data_, target_, target_ > 6, optimizer2, 500, 1, False)

[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 0  Loss : 2.56511  Real : [1 3 1 1]  Pred : [9 7 7 7]




[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 50  Loss : 2.23239  Real : [0 4 1 4]  Pred : [1 1 1 1]




[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 100  Loss : 1.76903  Real : [1 1 5 2]  Pred : [1 2 4 1]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 150  Loss : 2.06402  Real : [2 5 1 5]  Pred : [1 1 1 1]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 200  Loss : 1.82606  Real : [5 0 1 4]  Pred : [3 4 1 0]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 250  Loss : 1.67

[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 700  Loss : 1.77208  Real : [5 1 1 2]  Pred : [1 5 0 1]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 750  Loss : 1.79972  Real : [3 5 2 0]  Pred : [1 5 5 3]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 800  Loss : 1.79719  Real : [1 3 5 3]  Pred : [1 1 1 1]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 850  Loss : 1.46

[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 300  Loss : 2.14194  Real : [4 5 3 2]  Pred : [1 3 4 4]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 350  Loss : 1.41543  Real : [5 1 4 5]  Pred : [3 3 4 0]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 400  Loss : 1.80660  Real : [0 5 1 3]  Pred : [5 5 5 0]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 450  Loss : 1.80

[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 900  Loss : 1.79718  Real : [3 1 4 5]  Pred : [1 1 1 1]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 950  Loss : 1.52956  Real : [1 4 0 2]  Pred : [1 1 0 3]
switch
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 0  Loss : 1.78133  Real : [0 0 1 1]  Pred : [1 1 1 1]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
[8 8 8 8 8 8 8 8 8 8 8 8 8 8 8]
 
Index : 50  Loss : 

In [5]:
# Prints a fixed marker string; NOTE(review): looks like a leftover scratch cell - confirm it can be removed.
print("fsdg")

fsdg


In [None]:
+