In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import numpy as np
from random import randint

import matplotlib.pyplot as plt
%matplotlib inline

batch_size = 15
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.ToTensor()),
        batch_size = 10000, shuffle=True)

from random import randint

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device != 'cpu':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

num_classes=10

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=3, dropout=0):
        super(EncoderRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.gru = nn.GRU(input_size, hidden_size, n_layers, dropout=self.dropout, bidirectional=True)
        self.hidden = None
        
    def forward(self, input_seqs, input_lengths):
       # input_seqs = input_seqs.type(torch.FloatTensor)
        packed = torch.nn.utils.rnn.pack_padded_sequence(input_seqs, input_lengths)
        outputs, self.hidden = self.gru(packed, self.hidden)
        outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs) # unpack (back to padded)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:] # Sum bidirectional outputs 
        return outputs

class DecoderRNN(nn.Module):
    def __init__(self,output_size,hidden_size,enc_hidden_size,n_layers=3):
        super(DecoderRNN, self).__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.enc_hidden_size = enc_hidden_size

        self.gru = nn.GRU(enc_hidden_size, hidden_size, n_layers, bidirectional=True)
        self.hidden = None
       
        self.concat = nn.Linear(hidden_size*2,hidden_size)
        self.out = nn.Linear(hidden_size,output_size)
        
        
    def forward(self, encoder_outputs):
        rnn_in = encoder_outputs[-1].unsqueeze(0)
           
        outputs, self.hidden = self.gru(rnn_in, self.hidden)
        outputs = self.concat(outputs)
        out = self.out(outputs)
        
        return F.softmax(out)

In [3]:
class CircuitNet(nn.Module):
    def __init__(self,num_classes=12,num_lstm_layers=3):
        super(CircuitNet, self).__init__()
        
        self.num_lstm_layers = num_lstm_layers

        self.circuits_to_use = 3
        self.n_circuits = 10
        self.circuit_in_dim = 4
        
        self.circuit_out_shape = 16*3
        self.init_layer =  nn.Linear(28*28, self.circuit_out_shape*5)
        
        #circuit picker able to send instructions to data pipe
        self.hidden_data_num = 10
        
        hidden_size = 3
        self.encoder = EncoderRNN(self.circuit_out_shape,hidden_size,n_layers = 3)
        self.decoder = DecoderRNN(self.n_circuits+self.hidden_data_num,hidden_size,hidden_size,n_layers = 3)
        
        self.circuit_picker = nn.Sequential(
            nn.Conv1d(1, 16, 3, stride=1, padding=1),  
            nn.ReLU(True),
            nn.BatchNorm1d(16),
            nn.MaxPool1d(2, stride=2),  
            nn.Conv1d(16, 8, 3, stride=1, padding=1),  
            nn.ReLU(True),
            nn.BatchNorm1d(8),
            nn.MaxPool1d(2, stride=2)
        )
        
        self.circuit_picker_out = nn.LSTM(1568,(self.n_circuits *  self.circuits_to_use), self.num_lstm_layers)
        self.hidden = None
        
        #data pipe
        self.data_pipe = nn.Sequential(
            nn.Conv1d(1, 4, 4, stride=1, padding=1),  
            nn.ReLU(True),
            nn.BatchNorm1d(4),
            nn.Conv1d(4, 8, 4, stride=1, padding=1),  
            nn.ReLU(True),
            nn.BatchNorm1d(8),
            nn.MaxPool1d(2, stride=2), 
            
            nn.Conv1d(8, 4, 4, stride=1, padding=1),  
            nn.ReLU(True),
            nn.BatchNorm1d(4),
            nn.Conv1d(4, 8, 4, stride=1, padding=1),  
            nn.ReLU(True),
            nn.BatchNorm1d(8),
            nn.MaxPool1d(2, stride=2)
        )
        
        #self.data_pipe_out = nn.Linear(245, (self.circuits_to_use * self.circuit_in_dim))
        self.data_pipe_out = nn.Linear(250, self.circuit_in_dim)
        
        #define circuits
        self.circuits = []
        for i in range(self.n_circuits):
            self.circuits.append(nn.Sequential(
                nn.Conv1d(1, 8, kernel_size=2, stride=1, padding=1),
                nn.BatchNorm1d(8),
                nn.ReLU(True),
                #nn.MaxPool2d(kernel_size=2, stride=2)))
            
                nn.Conv1d(8, 16, kernel_size=2, stride=1, padding=1),
                nn.BatchNorm1d(16),
                nn.ReLU(True),
                
                nn.MaxPool1d(kernel_size=2, stride=2)))
        
        #connect circuits (this is temporary)
        self.connect = nn.Sequential(
            nn.Linear(1440, 28*28))
        
        #final circuit to prediction
        self.output = nn.Sequential(
            nn.Linear(144, num_classes))
        
    #connection input shape [batch_size,num_circuits,circuit_out_shape]
    def connection(self,x,verbose,circuit_dropout,circuit_ban_rate):
        x_flat = x.view(x.size(0),1,-1)
        ###pick circuits##################
        #circuit_pick = self.circuit_picker(x_flat)
        #circuit_pick = circuit_pick.view(circuit_pick.size(0),1, -1)
        
        #circuit_pick, self.hidden = self.circuit_picker_out(circuit_pick,self.hidden)
        
        
        #circuit_pick = circuit_pick.view(circuit_pick.size(0),self.circuits_to_use,self.n_circuits)
        #circuit_pick = circuit_pick.max(2)[1]
        
        enc_in = x.transpose(0,1)
        enc_out = self.encoder(enc_in,[5]*batch_size)
        all_dec_out = []
        hidden_instructions = []
        for i in range(self.circuits_to_use):
            dec_out = self.decoder(enc_out)
            all_dec_out.append(dec_out[:,:,0:self.n_circuits])
            hidden_instructions.append(dec_out[:,:,self.n_circuits:])
            
        all_dec_out = torch.cat(all_dec_out,0).transpose(0,1)
        hidden_instructions = torch.cat(hidden_instructions,0).transpose(0,1)

        circuit_pick = all_dec_out
        circuit_pick = circuit_pick.max(2)[1]
        data_pipe_in = torch.cat((circuit_pick.unsqueeze(1).type(torch.cuda.FloatTensor),x_flat),2)
        
        #print(all_dec_out.shape)
        ###################################

        #data_pipe = self.data_pipe(data_pipe_in)
        #data_pipe = data_pipe.view(data_pipe.size(0), -1)
        #data_pipe = data_pipe_in.squeeze(1)
        #data_pipe = self.data_pipe_out(data_pipe)
        #data_pipe = data_pipe.view(data_pipe.size(0),self.circuits_to_use,self.circuit_in_dim)
        bans = np.random.randint(self.n_circuits,size=int(circuit_ban_rate*self.n_circuits))
        circuit_out = []
        for i in range(self.circuits_to_use):
            circuit_idx = circuit_pick[:,i]
            if(verbose):
                print(str(circuit_idx.cpu().numpy()))
                
            batch = []
            for b in range(batch_size):
                
                hid_instruct = hidden_instructions[b,i]
                #data = data_pipe[b][i]
                data_in = torch.cat((x_flat[b],hidden_instructions[b,i].unsqueeze(0)),1)
                data_pipe = self.data_pipe_out(data_in)
                #data = data.unsqueeze(0).unsqueeze(0)
                circuit = self.circuits[circuit_idx[b]](data_pipe.unsqueeze(0))
                #apply dropout
                if randint(1,100) < circuit_dropout*100 or circuit_idx[b] in bans:
                    circuit = torch.zeros(1,16,3,device=device)  
                
                batch.append(circuit)
            batch = torch.cat(batch,0)
            circuit_out.append(batch)
        
        circuit_out = torch.cat(circuit_out,1)
        
        return circuit_out
    
    def forward(self, x, verbose,circuit_dropout=0.5,circuit_ban_rate = 0.5):
        x_flat = x.view(x.size(0),-1)
        init = self.init_layer(x_flat)
        
        init = init.view(init.size(0),5,init.size(1)/5)
        circuit_out1 = self.connection(init,verbose,circuit_dropout,circuit_ban_rate)
        circuit_out_flat1 = circuit_out1.view(circuit_out1.size(0), -1)
        
        #connect = self.connect(circuit_out_flat1)
        #connect = connect.view(connect.size(0),1,28,28)
        
       # circuit_out2 = self.connection(connect,verbose)
        #circuit_out_flat2 = circuit_out2.view(circuit_out2.size(0), -1)
        out = self.output(circuit_out_flat1)
        
        return out
    
        

if device != 'cpu':
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
model = CircuitNet().to(device)
torch.set_default_tensor_type('torch.FloatTensor')

for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device)[0:batch_size], target.to(device)[0:batch_size]
    output = model(data,False)
    break



In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
optimizer2 = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
iter_print = 50

for batch_idx, (data_, target_) in enumerate(train_loader):
    #print(target)
   
    
    
    #train on digits 1-6, then train on 7-9 to check adaptability
    for i in range(1000):
        try:
            rand = randint(0, 3000)
            indices = np.where(target_<6)[0][rand:rand+batch_size]
            #indices = 
            data = np.take(data_,indices,axis = 0)
            target = np.take(target_,indices)

           # print(np.where(target_<6)[0][0:batch_size])
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            verbose = False
            if(i % iter_print == 0):
                verbose = True
            output = model(data,verbose,circuit_dropout=0.7,circuit_ban_rate = 0.5)
            loss =  nn.CrossEntropyLoss()(output, target)
            #loss = (output-target)**2
            loss.backward()
            optimizer.step()

            model.hidden = None
            model.encoder.hidden = None
            model.decoder.hidden = None
            if(i % iter_print == 0):
                print("Index : {}  Loss : {:.5f}  Real : {}  Pred : {}".format(i,
                                                                               loss[0],target[0:4].cpu().numpy(),
                                                                               output[0:4].max(1)[1].cpu().numpy()))
        except:
            continue
    print("switch")
    for i in range(500):
        try:
            rand = randint(0, 3000)
            indices = np.where(target_>6)[0][rand:rand+batch_size]
            #indices = 
            data = np.take(data_,indices,axis = 0)
            target = np.take(target_,indices)

           # print(np.where(target_<6)[0][0:batch_size])
            data, target = data.to(device), target.to(device)

            optimizer2.zero_grad()
            verbose = False
            if(batch_idx % iter_print == 0):
                verbose = False
            output = model(data,verbose)
            loss =  nn.CrossEntropyLoss()(output, target)
            #loss = (output-target)**2
            loss.backward()
            optimizer2.step()

            model.hidden = None
            model.encoder.hidden = None
            model.decoder.hidden = None
            if(i % 1 == 0):
                print("Index : {}  Loss : {:.5f}  Real : {}  Pred : {}".format(i,
                                                                               loss[0],target[0:4].cpu().numpy(),
                                                                               output[0:4].max(1)[1].cpu().numpy()))
        except:
            continue



[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 0  Loss : 2.48238  Real : [5 0 3 2]  Pred : [2 2 2 2]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 50  Loss : 1.39711  Real : [2 5 3 3]  Pred : [2 0 3 2]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 100  Loss : 2.09663  Real : [1 5 2 1]  Pred : [1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 150  Loss : 1.94991  Real : [2 0 0 2]  Pred : [1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 200  Loss : 1.14499  Real : [1 0 0 4]  Pred : [1 1 0 4]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 250  Loss : 0.88265  Real : [4 4 2 5]  Pred : [4 4 1 3]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 



Index : 3  Loss : 5.72577  Real : [7 7 7 7]  Pred : [1 1 1 1]
Index : 4  Loss : 11.02013  Real : [7 7 7 8]  Pred : [5 5 1 5]
Index : 5  Loss : 11.62502  Real : [8 8 9 7]  Pred : [2 0 2 5]
Index : 6  Loss : 5.68566  Real : [8 9 9 8]  Pred : [1 1 1 1]
Index : 7  Loss : 9.22644  Real : [9 7 8 9]  Pred : [4 0 1 4]
Index : 8  Loss : 11.04920  Real : [7 9 7 7]  Pred : [0 2 5 0]
Index : 9  Loss : 8.55365  Real : [9 9 8 7]  Pred : [2 5 0 4]
Index : 10  Loss : 9.34093  Real : [9 7 9 8]  Pred : [5 0 5 2]
Index : 11  Loss : 7.68800  Real : [9 9 7 8]  Pred : [0 0 2 2]
Index : 12  Loss : 7.83788  Real : [8 8 8 7]  Pred : [2 1 3 5]
Index : 13  Loss : 8.00094  Real : [8 8 9 9]  Pred : [0 2 5 1]
Index : 14  Loss : 5.52204  Real : [7 7 7 7]  Pred : [1 1 1 1]
Index : 15  Loss : 6.75434  Real : [8 8 7 9]  Pred : [2 2 2 4]
Index : 16  Loss : 5.50901  Real : [8 9 9 9]  Pred : [5 0 2 2]
Index : 17  Loss : 5.18859  Real : [8 7 9 8]  Pred : [5 4 5 2]
Index : 18  Loss : 4.58444  Real : [9 8 8 8]  Pred : [4 2 2

Index : 134  Loss : 0.88085  Real : [9 8 7 8]  Pred : [7 7 7 8]
Index : 135  Loss : 4.07781  Real : [8 8 7 9]  Pred : [2 2 2 2]
Index : 136  Loss : 0.52408  Real : [7 8 8 8]  Pred : [7 8 7 8]
Index : 137  Loss : 4.04365  Real : [8 8 9 7]  Pred : [2 2 2 2]
Index : 139  Loss : 1.48989  Real : [7 7 7 9]  Pred : [9 7 2 7]
Index : 140  Loss : 4.01397  Real : [8 9 8 7]  Pred : [2 2 2 2]
Index : 141  Loss : 3.99288  Real : [9 8 9 7]  Pred : [2 2 2 2]
Index : 142  Loss : 3.98286  Real : [9 7 8 7]  Pred : [2 2 2 2]
Index : 143  Loss : 3.95444  Real : [8 7 9 8]  Pred : [2 2 2 2]
Index : 144  Loss : 1.51564  Real : [8 7 9 7]  Pred : [2 9 2 9]
Index : 145  Loss : 3.92396  Real : [8 9 9 7]  Pred : [2 2 2 2]
Index : 146  Loss : 0.94581  Real : [8 8 8 7]  Pred : [8 8 8 7]
Index : 147  Loss : 0.54491  Real : [9 8 9 7]  Pred : [9 8 7 7]
Index : 148  Loss : 3.86765  Real : [7 7 7 9]  Pred : [2 2 2 2]
Index : 149  Loss : 0.97934  Real : [9 8 8 8]  Pred : [9 7 8 8]
Index : 150  Loss : 1.21888  Real : [7 7

Index : 272  Loss : 2.63571  Real : [7 7 9 9]  Pred : [2 2 2 2]
Index : 273  Loss : 2.62619  Real : [8 9 8 9]  Pred : [2 2 2 2]
Index : 274  Loss : 0.90198  Real : [8 7 9 7]  Pred : [8 7 8 9]
Index : 275  Loss : 2.59226  Real : [7 7 7 8]  Pred : [2 2 2 2]
Index : 276  Loss : 0.59561  Real : [9 7 9 8]  Pred : [9 7 9 2]
Index : 277  Loss : 0.88151  Real : [7 8 8 9]  Pred : [7 8 8 9]
Index : 278  Loss : 0.46737  Real : [9 7 8 8]  Pred : [9 7 8 8]
Index : 279  Loss : 0.73482  Real : [9 9 9 8]  Pred : [9 9 9 8]
Index : 281  Loss : 0.54682  Real : [8 9 9 7]  Pred : [8 9 9 7]
Index : 282  Loss : 2.54426  Real : [8 7 9 7]  Pred : [2 2 2 2]
Index : 283  Loss : 0.70749  Real : [8 8 8 9]  Pred : [8 8 8 7]
Index : 284  Loss : 0.73131  Real : [9 7 7 8]  Pred : [9 7 7 8]
Index : 285  Loss : 2.51809  Real : [7 7 7 9]  Pred : [2 2 2 2]
Index : 286  Loss : 0.51377  Real : [8 9 8 8]  Pred : [8 9 8 8]
Index : 287  Loss : 0.44758  Real : [9 7 7 9]  Pred : [9 7 9 9]
Index : 288  Loss : 2.49860  Real : [9 7

Index : 408  Loss : 0.35967  Real : [9 7 9 8]  Pred : [9 7 9 8]
Index : 409  Loss : 0.19912  Real : [8 9 8 8]  Pred : [8 9 8 8]
Index : 410  Loss : 0.46132  Real : [9 8 9 8]  Pred : [9 8 9 8]
Index : 411  Loss : 1.74813  Real : [8 9 8 7]  Pred : [7 7 7 7]
Index : 412  Loss : 0.33227  Real : [7 9 8 8]  Pred : [7 9 8 8]
Index : 413  Loss : 1.70977  Real : [9 7 8 7]  Pred : [7 7 7 7]
Index : 414  Loss : 0.21484  Real : [9 8 8 9]  Pred : [9 8 8 9]
Index : 415  Loss : 1.71860  Real : [9 7 7 9]  Pred : [7 7 7 7]
Index : 416  Loss : 1.73236  Real : [8 9 8 8]  Pred : [7 7 7 7]
Index : 417  Loss : 1.72049  Real : [7 9 9 8]  Pred : [7 7 7 7]
Index : 418  Loss : 0.15885  Real : [9 8 8 9]  Pred : [9 8 8 7]
Index : 419  Loss : 0.15554  Real : [7 7 7 8]  Pred : [7 7 7 8]
Index : 420  Loss : 0.31697  Real : [7 8 7 8]  Pred : [7 8 7 8]
Index : 421  Loss : 0.53848  Real : [7 9 7 9]  Pred : [7 9 7 9]
Index : 422  Loss : 0.59548  Real : [9 7 7 9]  Pred : [7 7 7 9]
Index : 423  Loss : 1.68204  Real : [8 9

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 800  Loss : 1.81500  Real : [5 1 1 1]  Pred : [1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 850  Loss : 0.93007  Real : [5 3 5 4]  Pred : [1 4 5 4]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 900  Loss : 1.17456  Real : [0 1 2 1]  Pred : [0 1 2 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 950  Loss : 1.80345  Real : [5 1 1 1]  Pred : [1 1 1 1]
switch
Index : 0  Loss : 5.36106  Real : [7 9 7 8]  Pred : [1 1 1 1]
Index : 1  Loss : 5.35783  Real : [8 7 8 9]  Pred : [1 1 1 1]
Index : 2  Loss : 5.35257  Real : [7 7 8 7]  Pred : [1 1 1 1]
Index : 3  Loss : 22.48224  Real : [8 9 9 9]  Pred : [1 4 4 1]
Index : 4  Loss : 5.26686  Real : [7 8 7 7]  Pred : [1 1 1 1]
Index : 5  Loss : 19.89157  Real : [8 7 7 

Index : 123  Loss : 0.46207  Real : [7 8 8 9]  Pred : [7 8 8 9]
Index : 124  Loss : 3.19601  Real : [9 7 9 7]  Pred : [2 2 2 2]
Index : 125  Loss : 0.68891  Real : [8 7 9 9]  Pred : [8 7 9 8]
Index : 126  Loss : 0.57535  Real : [8 9 9 9]  Pred : [8 2 9 9]
Index : 127  Loss : 0.17102  Real : [7 8 7 7]  Pred : [7 8 7 7]
Index : 128  Loss : 3.17397  Real : [9 7 7 9]  Pred : [2 2 2 2]
Index : 129  Loss : 0.26855  Real : [8 9 7 7]  Pred : [8 9 2 7]
Index : 130  Loss : 3.09410  Real : [8 7 9 8]  Pred : [2 2 2 2]
Index : 131  Loss : 3.12542  Real : [8 8 9 9]  Pred : [2 2 2 2]
Index : 132  Loss : 3.13489  Real : [8 9 9 8]  Pred : [2 2 2 2]
Index : 133  Loss : 0.69096  Real : [8 7 8 8]  Pred : [8 7 8 8]
Index : 134  Loss : 1.74983  Real : [7 7 8 8]  Pred : [2 9 2 8]
Index : 135  Loss : 0.72786  Real : [9 7 8 7]  Pred : [9 7 8 9]
Index : 136  Loss : 0.59433  Real : [8 7 8 8]  Pred : [8 7 8 8]
Index : 137  Loss : 0.48925  Real : [8 9 9 9]  Pred : [8 9 9 9]
Index : 138  Loss : 0.75846  Real : [9 8

Index : 256  Loss : 0.69619  Real : [9 9 9 7]  Pred : [9 9 9 7]
Index : 257  Loss : 0.66350  Real : [8 8 9 8]  Pred : [8 8 9 8]
Index : 258  Loss : 0.32397  Real : [7 8 8 7]  Pred : [7 8 8 7]
Index : 259  Loss : 1.96694  Real : [8 7 9 7]  Pred : [8 8 8 8]
Index : 260  Loss : 0.42018  Real : [8 9 8 8]  Pred : [8 9 8 8]
Index : 261  Loss : 0.58373  Real : [9 8 7 7]  Pred : [9 8 8 8]
Index : 262  Loss : 1.93487  Real : [7 9 7 9]  Pred : [8 8 8 8]
Index : 263  Loss : 0.31060  Real : [8 7 9 9]  Pred : [8 7 9 9]
Index : 264  Loss : 1.95927  Real : [7 9 7 8]  Pred : [8 8 8 8]
Index : 265  Loss : 0.81086  Real : [7 9 7 7]  Pred : [7 9 8 8]
Index : 266  Loss : 1.91503  Real : [8 8 9 9]  Pred : [8 8 8 8]
Index : 267  Loss : 1.95338  Real : [7 7 7 9]  Pred : [8 8 8 8]
Index : 268  Loss : 0.90599  Real : [9 7 7 8]  Pred : [9 7 7 8]
Index : 269  Loss : 1.93211  Real : [8 7 9 9]  Pred : [8 8 8 8]
Index : 270  Loss : 1.92592  Real : [9 9 7 7]  Pred : [8 8 8 8]
Index : 271  Loss : 0.47060  Real : [7 7

Index : 385  Loss : 1.53223  Real : [7 9 9 8]  Pred : [8 8 8 8]
Index : 386  Loss : 1.52035  Real : [9 8 8 9]  Pred : [8 8 8 8]
Index : 387  Loss : 0.85007  Real : [7 8 9 7]  Pred : [7 8 9 8]
Index : 388  Loss : 1.50703  Real : [8 8 7 8]  Pred : [8 8 8 8]
Index : 389  Loss : 0.26542  Real : [9 8 8 9]  Pred : [9 8 8 9]
Index : 390  Loss : 0.42285  Real : [7 9 7 9]  Pred : [7 9 7 9]
Index : 391  Loss : 1.51615  Real : [8 9 7 8]  Pred : [8 8 8 8]
Index : 392  Loss : 0.80069  Real : [9 7 7 9]  Pred : [9 9 7 8]
Index : 393  Loss : 1.53898  Real : [8 9 8 7]  Pred : [8 8 8 8]
Index : 394  Loss : 0.79891  Real : [9 9 7 8]  Pred : [9 8 7 8]
Index : 395  Loss : 0.27961  Real : [9 9 9 7]  Pred : [9 9 9 9]
Index : 396  Loss : 1.54851  Real : [7 7 9 7]  Pred : [8 8 8 8]
Index : 397  Loss : 1.52978  Real : [8 9 8 9]  Pred : [8 8 8 8]
Index : 398  Loss : 1.52330  Real : [8 7 7 7]  Pred : [8 8 8 8]
Index : 399  Loss : 1.49880  Real : [8 9 7 8]  Pred : [8 8 8 8]
Index : 400  Loss : 0.33817  Real : [9 8

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 300  Loss : 0.63850  Real : [4 0 0 1]  Pred : [5 0 0 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 350  Loss : 0.90354  Real : [0 0 1 4]  Pred : [3 0 1 3]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 400  Loss : 0.50808  Real : [3 5 5 3]  Pred : [3 5 5 3]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 450  Loss : 1.83795  Real : [5 4 1 0]  Pred : [0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 500  Loss : 1.81151  Real : [1 3 1 5]  Pred : [3 3 3 3]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 550  Loss : 1.76534  Real : [3 3 1 1]  Pred : [1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0

Index : 97  Loss : 0.37701  Real : [9 8 9 7]  Pred : [9 8 9 7]
Index : 98  Loss : 0.88387  Real : [8 9 9 7]  Pred : [1 9 9 7]
Index : 99  Loss : 0.62569  Real : [7 7 7 7]  Pred : [7 7 7 7]
Index : 100  Loss : 3.11705  Real : [7 8 7 8]  Pred : [1 1 1 1]
Index : 101  Loss : 3.25260  Real : [8 7 8 9]  Pred : [1 1 1 1]
Index : 102  Loss : 3.17938  Real : [8 8 7 8]  Pred : [1 1 1 1]
Index : 103  Loss : 3.13288  Real : [8 8 9 9]  Pred : [1 1 1 1]
Index : 104  Loss : 0.83351  Real : [7 9 9 7]  Pred : [1 9 9 7]
Index : 105  Loss : 0.42244  Real : [9 8 8 8]  Pred : [1 8 8 8]
Index : 106  Loss : 3.11533  Real : [9 9 8 8]  Pred : [1 1 1 1]
Index : 107  Loss : 0.75024  Real : [7 8 7 9]  Pred : [7 8 7 9]
Index : 108  Loss : 3.11635  Real : [9 9 7 7]  Pred : [1 1 1 1]
Index : 109  Loss : 3.00664  Real : [7 8 8 7]  Pred : [1 1 1 1]
Index : 110  Loss : 2.98963  Real : [7 8 7 9]  Pred : [1 1 1 1]
Index : 111  Loss : 0.74049  Real : [8 9 9 9]  Pred : [1 9 9 9]
Index : 112  Loss : 1.42668  Real : [8 9 7 

Index : 227  Loss : 0.44673  Real : [7 7 7 9]  Pred : [7 7 8 9]
Index : 228  Loss : 0.77826  Real : [7 8 7 9]  Pred : [8 8 7 8]
Index : 229  Loss : 0.44425  Real : [7 8 8 9]  Pred : [7 8 8 9]
Index : 230  Loss : 0.36271  Real : [9 9 8 7]  Pred : [9 9 8 7]
Index : 231  Loss : 0.57718  Real : [9 7 9 7]  Pred : [9 7 9 7]
Index : 232  Loss : 0.25299  Real : [8 8 8 8]  Pred : [8 8 8 8]
Index : 233  Loss : 2.05561  Real : [9 9 7 7]  Pred : [8 8 8 8]
Index : 234  Loss : 2.00106  Real : [9 9 7 9]  Pred : [8 8 8 8]
Index : 235  Loss : 0.60105  Real : [7 9 8 8]  Pred : [7 7 8 8]
Index : 236  Loss : 1.94163  Real : [8 8 7 9]  Pred : [8 8 8 8]
Index : 237  Loss : 0.45044  Real : [7 8 8 9]  Pred : [7 8 8 9]
Index : 238  Loss : 2.00142  Real : [9 9 8 7]  Pred : [8 8 8 8]
Index : 239  Loss : 0.07134  Real : [7 8 7 7]  Pred : [7 8 7 7]
Index : 240  Loss : 0.51601  Real : [8 7 9 7]  Pred : [8 7 9 7]
Index : 241  Loss : 1.86914  Real : [7 9 8 7]  Pred : [8 8 8 8]
Index : 242  Loss : 0.65893  Real : [8 7

Index : 356  Loss : 1.51972  Real : [8 9 9 7]  Pred : [8 8 8 8]
Index : 357  Loss : 1.52097  Real : [7 7 8 9]  Pred : [8 8 8 8]
Index : 358  Loss : 0.14167  Real : [9 8 7 8]  Pred : [9 8 7 8]
Index : 359  Loss : 1.54445  Real : [9 9 8 7]  Pred : [8 8 8 8]
Index : 360  Loss : 1.51984  Real : [7 7 9 8]  Pred : [8 8 8 8]
Index : 361  Loss : 0.31572  Real : [7 8 7 7]  Pred : [7 8 7 7]
Index : 362  Loss : 1.55477  Real : [9 9 7 9]  Pred : [8 8 8 8]
Index : 363  Loss : 1.53643  Real : [9 7 7 7]  Pred : [8 8 8 8]
Index : 364  Loss : 0.31601  Real : [7 8 8 7]  Pred : [7 8 8 7]
Index : 365  Loss : 0.33404  Real : [9 8 9 8]  Pred : [9 8 7 8]
Index : 366  Loss : 0.23010  Real : [7 8 7 8]  Pred : [8 8 7 8]
Index : 367  Loss : 1.52694  Real : [9 7 9 7]  Pred : [8 8 8 8]
Index : 368  Loss : 1.52691  Real : [9 8 9 9]  Pred : [8 8 8 8]
Index : 369  Loss : 0.46213  Real : [9 8 7 9]  Pred : [9 8 7 9]
Index : 370  Loss : 0.14429  Real : [9 8 8 9]  Pred : [9 8 8 9]
Index : 371  Loss : 0.21856  Real : [8 7

Index : 487  Loss : 0.23055  Real : [8 7 9 8]  Pred : [8 7 9 8]
Index : 488  Loss : 0.36613  Real : [8 8 8 7]  Pred : [8 8 8 7]
Index : 489  Loss : 0.20589  Real : [8 9 7 7]  Pred : [8 9 7 7]
Index : 490  Loss : 0.19896  Real : [7 7 7 7]  Pred : [7 7 7 7]
Index : 491  Loss : 1.33280  Real : [9 8 7 9]  Pred : [8 8 8 8]
Index : 492  Loss : 1.32647  Real : [7 9 7 9]  Pred : [8 8 8 8]
Index : 493  Loss : 0.17210  Real : [8 7 8 9]  Pred : [8 7 8 9]
Index : 494  Loss : 1.32870  Real : [7 7 8 9]  Pred : [8 8 8 8]
Index : 495  Loss : 0.30332  Real : [9 7 8 7]  Pred : [9 9 8 7]
Index : 496  Loss : 0.39655  Real : [7 8 9 8]  Pred : [7 8 9 8]
Index : 497  Loss : 1.31325  Real : [8 7 9 7]  Pred : [8 8 8 8]
Index : 498  Loss : 0.24211  Real : [8 8 7 7]  Pred : [8 8 7 7]
Index : 499  Loss : 0.22211  Real : [9 9 7 7]  Pred : [8 9 9 7]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
Index : 0  Loss : 3.41379  Real : [4 1 5 0]  Pred : [8 8 8 8]
[0 0 0 0 0

Index : 69  Loss : 0.64815  Real : [9 9 8 8]  Pred : [9 9 8 8]
Index : 70  Loss : 0.06988  Real : [8 7 9 7]  Pred : [8 7 9 7]
Index : 71  Loss : 2.13248  Real : [7 7 9 9]  Pred : [8 9 9 7]
Index : 72  Loss : 0.62115  Real : [8 7 7 9]  Pred : [8 9 7 9]
Index : 73  Loss : 1.03083  Real : [9 8 9 7]  Pred : [1 8 9 7]
Index : 74  Loss : 0.58310  Real : [8 8 7 8]  Pred : [1 8 7 8]
Index : 75  Loss : 3.74276  Real : [9 7 7 9]  Pred : [1 1 1 1]
Index : 76  Loss : 3.69508  Real : [7 9 7 7]  Pred : [1 1 1 1]
Index : 78  Loss : 1.08612  Real : [9 8 9 8]  Pred : [9 9 9 8]
Index : 79  Loss : 0.41054  Real : [8 9 9 7]  Pred : [8 7 9 7]
Index : 80  Loss : 0.29520  Real : [7 9 7 8]  Pred : [7 9 7 8]
Index : 81  Loss : 0.29076  Real : [7 8 7 9]  Pred : [7 8 7 9]
Index : 82  Loss : 3.66992  Real : [9 9 9 9]  Pred : [1 1 1 1]
Index : 83  Loss : 0.08962  Real : [7 7 8 9]  Pred : [7 7 8 9]
Index : 84  Loss : 0.30706  Real : [9 7 7 9]  Pred : [9 7 7 9]
Index : 85  Loss : 0.70255  Real : [7 7 7 8]  Pred : [1

Index : 200  Loss : 0.34930  Real : [8 8 8 9]  Pred : [8 8 8 9]
Index : 201  Loss : 2.16908  Real : [9 9 8 8]  Pred : [8 8 8 8]
Index : 202  Loss : 0.63933  Real : [8 7 8 7]  Pred : [7 7 8 7]
Index : 203  Loss : 2.28009  Real : [7 7 9 7]  Pred : [8 8 8 8]
Index : 204  Loss : 0.21659  Real : [7 7 8 7]  Pred : [7 7 8 7]
Index : 205  Loss : 0.56669  Real : [8 8 7 7]  Pred : [8 8 7 7]
Index : 206  Loss : 2.16452  Real : [9 8 9 9]  Pred : [8 8 8 8]
Index : 207  Loss : 2.16213  Real : [8 9 8 9]  Pred : [8 8 8 8]
Index : 208  Loss : 2.11758  Real : [8 8 7 8]  Pred : [8 8 8 8]
Index : 209  Loss : 2.11668  Real : [8 7 8 9]  Pred : [8 8 8 8]
Index : 210  Loss : 0.38354  Real : [8 8 7 7]  Pred : [8 8 8 7]
Index : 211  Loss : 0.62983  Real : [8 9 7 9]  Pred : [8 9 8 9]
Index : 212  Loss : 2.13657  Real : [8 7 9 9]  Pred : [8 8 8 8]
Index : 213  Loss : 2.16035  Real : [7 9 9 9]  Pred : [8 8 8 8]
Index : 215  Loss : 2.03552  Real : [9 9 7 8]  Pred : [8 8 8 8]
Index : 216  Loss : 0.69992  Real : [8 7

Index : 337  Loss : 0.23920  Real : [8 7 9 8]  Pred : [8 7 9 8]
Index : 338  Loss : 0.39870  Real : [7 7 9 7]  Pred : [7 7 9 7]
Index : 339  Loss : 0.25578  Real : [8 8 7 7]  Pred : [8 8 7 8]
Index : 340  Loss : 0.33484  Real : [8 8 9 9]  Pred : [8 8 9 9]
Index : 341  Loss : 1.50736  Real : [7 7 7 7]  Pred : [8 8 8 8]
Index : 342  Loss : 0.16673  Real : [7 8 9 8]  Pred : [7 8 9 8]
Index : 343  Loss : 0.23819  Real : [9 8 8 7]  Pred : [9 8 8 7]
Index : 344  Loss : 1.34242  Real : [8 8 8 9]  Pred : [9 8 8 9]
Index : 345  Loss : 0.41245  Real : [9 9 7 8]  Pred : [9 9 7 8]
Index : 346  Loss : 1.50709  Real : [7 9 7 8]  Pred : [8 8 8 8]
Index : 347  Loss : 0.39064  Real : [9 9 8 9]  Pred : [7 9 8 9]
Index : 348  Loss : 0.64743  Real : [8 9 9 8]  Pred : [8 8 9 8]
Index : 349  Loss : 0.37522  Real : [8 7 7 8]  Pred : [8 7 7 8]
Index : 350  Loss : 1.49184  Real : [7 9 8 9]  Pred : [8 8 8 8]
Index : 351  Loss : 1.45891  Real : [8 9 8 9]  Pred : [8 8 8 8]
Index : 352  Loss : 0.16192  Real : [8 7

Index : 471  Loss : 0.33971  Real : [7 7 7 9]  Pred : [7 7 7 9]
Index : 472  Loss : 1.32264  Real : [8 7 9 7]  Pred : [8 8 8 8]
Index : 473  Loss : 1.32150  Real : [7 7 9 8]  Pred : [8 8 8 8]
Index : 474  Loss : 0.24692  Real : [8 7 9 7]  Pred : [8 7 9 7]
Index : 475  Loss : 1.32056  Real : [9 7 9 9]  Pred : [8 8 8 8]
Index : 476  Loss : 1.31839  Real : [7 8 8 9]  Pred : [8 8 8 8]
Index : 477  Loss : 0.31488  Real : [9 8 7 7]  Pred : [9 8 8 7]
Index : 479  Loss : 0.38334  Real : [8 8 9 7]  Pred : [8 8 9 7]
Index : 480  Loss : 0.12988  Real : [9 8 7 9]  Pred : [9 8 9 9]
Index : 481  Loss : 0.68955  Real : [8 7 8 9]  Pred : [8 8 8 9]
Index : 482  Loss : 0.39755  Real : [8 7 9 7]  Pred : [8 8 9 7]
Index : 483  Loss : 1.31488  Real : [9 9 9 9]  Pred : [8 8 8 8]
Index : 484  Loss : 1.31192  Real : [7 7 9 9]  Pred : [8 8 8 8]
Index : 485  Loss : 0.45003  Real : [9 7 8 8]  Pred : [9 7 8 8]
Index : 486  Loss : 1.31204  Real : [7 8 7 8]  Pred : [8 8 8 8]
Index : 487  Loss : 0.35404  Real : [7 7

In [None]:
print("fsdg")