In [40]:
import torch, torchvision
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.autograd import Variable as var 
import torch.nn.functional as F
from torchvision import transforms
from torch.utils import data

In [41]:
# training params (MNIST pre-training phase)
n_batch = 64
learning_rate = 0.002
n_epoch = 15
n_print = 10      # print loss/accuracy every n_print batches
dropout_p = 0.75  # NOTE(review): myCNN does not read this value (it uses its own dropout default) -- confirm intent
log_interval = 1 # epochs
num_hidden_units = 50
num_classes = 10 # MNIST
decay_rate = 0.9999
max_grad_norm = 5.0

# aux training params (fine-tuning on the digits extracted from protocols)
aux_n_batch = 64
aux_learning_rate = 0.001
aux_n_epoch = 200
aux_n_print = 10
aux_log_interval = 1 # epochs
# NOTE(review): learning_rate, log_interval, decay_rate, max_grad_norm and
# aux_log_interval are not referenced by any cell shown in this notebook.

# Both aux splits live under one root directory.
aux_dataset_root = "E:/#AI&CV camp/#Project/#1Protocols/Protocols/dataset/"
aux_train_path = aux_dataset_root + "train/"
aux_test_path = aux_dataset_root + "test/"

In [42]:
#### Part I : Loading Your Data
# Grayscale(1) guarantees a single channel for every source, so MNIST and the
# aux images (possibly RGB -- not checked here) both fit the 1-channel network.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

train_data = torchvision.datasets.MNIST('mnist_data', train=True, download=True, transform=transform)
val_data = torchvision.datasets.MNIST('mnist_data', train=False, download=True, transform=transform)

# Training loaders reshuffle every epoch; evaluation loaders keep a fixed order
# because accuracy does not depend on sample order.
train_dl = data.DataLoader(train_data, batch_size=n_batch, shuffle=True)
val_dl = data.DataLoader(val_data, batch_size=n_batch)

# NOTE(review): ImageFolder derives labels from the sorted class-folder names,
# so the aux folders must sort as the digits 0..9 to line up with MNIST labels -- confirm.
aux_train_data = torchvision.datasets.ImageFolder(root=aux_train_path, transform=transform)
aux_train_dl = data.DataLoader(aux_train_data, batch_size=aux_n_batch, shuffle=True, num_workers=4)

aux_test_data = torchvision.datasets.ImageFolder(root=aux_test_path, transform=transform)
aux_test_dl = data.DataLoader(aux_test_data, batch_size=aux_n_batch, shuffle=False, num_workers=4)

In [43]:
#### Part II : Writing the Network
class myCNN(nn.Module):
    """Small two-convolution CNN that maps 1-channel images to class logits.

    Layout: conv(1->32, 5x5) -> maxpool 2 -> ReLU -> conv(32->64, 5x5) ->
    maxpool 2 -> ReLU -> flatten -> fc(1024->n_hidden) -> ReLU -> dropout ->
    fc(n_hidden->n_classes).

    The fixed fc1 input of 1024 = 64 * 4 * 4 holds for 28x28 inputs
    (28 -conv-> 24 -pool-> 12 -conv-> 8 -pool-> 4). Other image sizes raise a
    shape error in fc1 instead of being silently reshaped.

    Args:
        n_hidden: width of the hidden fully connected layer. Defaults to the
            notebook's `num_hidden_units` config value.
        n_classes: number of output logits. Defaults to `num_classes`.
        p_drop: dropout probability applied before fc2 in training mode. The
            default 0.5 equals F.dropout's default, which the network used
            before. It is a plain attribute, so it does not appear in
            state_dict() and previously saved weights still load.
    """

    def __init__(self, n_hidden=None, n_classes=None, p_drop=0.5):
        super(myCNN, self).__init__()
        # Read the config globals only when no explicit size is given, so
        # myCNN() keeps following the config cell.
        if n_hidden is None:
            n_hidden = num_hidden_units
        if n_classes is None:
            n_classes = num_classes
        self.p_drop = p_drop
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, n_classes) for x of shape (batch, 1, 28, 28)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        # Flatten per sample. view(-1, 1024) would regroup values across the
        # batch whenever the feature map is not 4x4.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.p_drop, training=self.training)
        return self.fc2(x)

In [None]:
#### Part III : Writing the main Training loop

# Use the GPU when present, otherwise fall back to CPU so the notebook still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mycnn = myCNN().to(device)
cec = nn.CrossEntropyLoss()

# lr=1 equals Adadelta's default; the `learning_rate` config value is not used here.
optimizer = optim.Adadelta(mycnn.parameters(), lr = 1)
# scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.3)

def validate(model, data):
    """Return the classification accuracy of `model` on `data`, in percent.

    The model is run in eval mode under torch.no_grad(), so no autograd graph
    is built. Afterwards it is put back into the train/eval mode it was in
    before the call, even if evaluation raises.

    Args:
        model: nn.Module returning per-class scores of shape (batch, classes).
            Inputs are moved to the device of its first parameter (CPU if it
            has none).
        data: iterable of (images, labels) batches, e.g. a DataLoader.

    Returns:
        float: 100 * correct / total over every sample yielded by `data`.

    Raises:
        ValueError: if `data` yields no samples.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    was_training = model.training
    model.eval()
    total = 0
    correct = 0
    try:
        with torch.no_grad():
            for images, labels in data:
                scores = model(images.to(device))
                pred = scores.argmax(dim=1)
                total += scores.size(0)
                correct += (pred == labels.to(device)).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError('validate(): `data` yielded no samples')
    return correct * 100. / total

def _train_phase(model, loader, optimizer, loss_fn, epochs, print_every,
                 eval_loader, first_epoch=0, scheduler=None):
    """Train `model` on `loader` for `epochs` epochs and log progress.

    Every `print_every` batches it prints the current batch loss together with
    `validate(model, eval_loader)`. Logged epoch numbers start at
    `first_epoch + 1`, so a later phase can continue an earlier phase's count.
    If `scheduler` is given, `scheduler.step()` is called once per epoch.
    """
    # Send batches to whatever device the model's parameters live on.
    device = next(model.parameters()).device
    for e in range(first_epoch, first_epoch + epochs):
        model.train()
        for i, (images, labels) in enumerate(loader):
            images = images.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            optimizer.step()

            if (i + 1) % print_every == 0:
                accuracy = float(validate(model, eval_loader))
                # validate() may leave the model in eval mode; switch back so
                # dropout is active for the remaining batches of this epoch.
                model.train()
                print('Epoch :', e+1, 'Batch :', i+1, 'Loss :', loss.item(), 'Accuracy :', accuracy,'%')
        if scheduler is not None:
            scheduler.step()


# Phase 1: train on MNIST.
# NOTE(review): progress is measured on the protocol test set (aux_test_dl);
# MNIST's own val_dl is never evaluated -- confirm this is intended.
_train_phase(mycnn, train_dl, optimizer, cec, n_epoch, n_print, aux_test_dl)

### Auxiliary training on data extracted from protocols

# NOTE(review): aux_learning_rate (0.001) is 1000x smaller than phase 1's
# Adadelta lr=1 (Adadelta's default) -- confirm this is intended.
optimizer = optim.Adadelta(mycnn.parameters(), lr = aux_learning_rate)
# scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.7)  # pass scheduler=scheduler to enable
_train_phase(mycnn, aux_train_dl, optimizer, cec, aux_n_epoch, aux_n_print,
             aux_test_dl, first_epoch=n_epoch)

Epoch : 1 Batch : 10 Loss : 2.099766492843628 Accuracy : 22.0 %
Epoch : 1 Batch : 20 Loss : 1.8508740663528442 Accuracy : 19.0 %
Epoch : 1 Batch : 30 Loss : 1.2241307497024536 Accuracy : 33.0 %
Epoch : 1 Batch : 40 Loss : 1.048595666885376 Accuracy : 40.0 %
Epoch : 1 Batch : 50 Loss : 0.6518250703811646 Accuracy : 41.0 %
Epoch : 1 Batch : 60 Loss : 0.5701473951339722 Accuracy : 43.0 %
Epoch : 1 Batch : 70 Loss : 0.6350931525230408 Accuracy : 42.0 %
Epoch : 1 Batch : 80 Loss : 0.5901870727539062 Accuracy : 44.0 %
Epoch : 1 Batch : 90 Loss : 0.387764573097229 Accuracy : 44.0 %
Epoch : 1 Batch : 100 Loss : 0.22018644213676453 Accuracy : 48.0 %
Epoch : 1 Batch : 110 Loss : 0.3818364143371582 Accuracy : 45.0 %
Epoch : 1 Batch : 120 Loss : 0.5816485285758972 Accuracy : 43.0 %
Epoch : 1 Batch : 130 Loss : 0.4417090117931366 Accuracy : 51.0 %
Epoch : 1 Batch : 140 Loss : 0.5368501543998718 Accuracy : 45.0 %
Epoch : 1 Batch : 150 Loss : 0.3706173598766327 Accuracy : 45.0 %
Epoch : 1 Batch : 160

Epoch : 2 Batch : 320 Loss : 0.053501348942518234 Accuracy : 72.0 %
Epoch : 2 Batch : 330 Loss : 0.05427974462509155 Accuracy : 68.0 %
Epoch : 2 Batch : 340 Loss : 0.1366855949163437 Accuracy : 65.0 %
Epoch : 2 Batch : 350 Loss : 0.03376862406730652 Accuracy : 57.0 %
Epoch : 2 Batch : 360 Loss : 0.06552333384752274 Accuracy : 68.0 %
Epoch : 2 Batch : 370 Loss : 0.04491547495126724 Accuracy : 71.0 %
Epoch : 2 Batch : 380 Loss : 0.08075080811977386 Accuracy : 65.0 %
Epoch : 2 Batch : 390 Loss : 0.08515806496143341 Accuracy : 71.0 %
Epoch : 2 Batch : 400 Loss : 0.1324656903743744 Accuracy : 60.0 %
Epoch : 2 Batch : 410 Loss : 0.05780044198036194 Accuracy : 60.0 %
Epoch : 2 Batch : 420 Loss : 0.1730393022298813 Accuracy : 59.0 %
Epoch : 2 Batch : 430 Loss : 0.07309067994356155 Accuracy : 65.0 %
Epoch : 2 Batch : 440 Loss : 0.0499749556183815 Accuracy : 68.0 %
Epoch : 2 Batch : 450 Loss : 0.15976735949516296 Accuracy : 61.0 %
Epoch : 2 Batch : 460 Loss : 0.07453291863203049 Accuracy : 64.0 

Epoch : 3 Batch : 620 Loss : 0.1664947271347046 Accuracy : 68.0 %
Epoch : 3 Batch : 630 Loss : 0.05161431431770325 Accuracy : 70.0 %
Epoch : 3 Batch : 640 Loss : 0.0452851764857769 Accuracy : 68.0 %
Epoch : 3 Batch : 650 Loss : 0.10368375480175018 Accuracy : 65.0 %
Epoch : 3 Batch : 660 Loss : 0.12066226452589035 Accuracy : 65.0 %
Epoch : 3 Batch : 670 Loss : 0.09913128614425659 Accuracy : 64.0 %
Epoch : 3 Batch : 680 Loss : 0.036417894065380096 Accuracy : 65.0 %
Epoch : 3 Batch : 690 Loss : 0.026975875720381737 Accuracy : 67.0 %
Epoch : 3 Batch : 700 Loss : 0.08559785783290863 Accuracy : 59.0 %
Epoch : 3 Batch : 710 Loss : 0.06147833913564682 Accuracy : 64.0 %
Epoch : 3 Batch : 720 Loss : 0.06729133427143097 Accuracy : 72.0 %
Epoch : 3 Batch : 730 Loss : 0.037714581936597824 Accuracy : 69.0 %
Epoch : 3 Batch : 740 Loss : 0.11360634863376617 Accuracy : 68.0 %
Epoch : 3 Batch : 750 Loss : 0.03846532106399536 Accuracy : 72.0 %
Epoch : 3 Batch : 760 Loss : 0.009667742997407913 Accuracy : 

Epoch : 4 Batch : 910 Loss : 0.0012781135737895966 Accuracy : 72.0 %
Epoch : 4 Batch : 920 Loss : 0.0020286422222852707 Accuracy : 73.0 %
Epoch : 4 Batch : 930 Loss : 0.01703992486000061 Accuracy : 71.0 %
Epoch : 5 Batch : 10 Loss : 0.09115712344646454 Accuracy : 74.0 %
Epoch : 5 Batch : 20 Loss : 0.06953588128089905 Accuracy : 72.0 %
Epoch : 5 Batch : 30 Loss : 0.018771618604660034 Accuracy : 70.0 %
Epoch : 5 Batch : 40 Loss : 0.015120120719075203 Accuracy : 70.0 %
Epoch : 5 Batch : 50 Loss : 0.05607856437563896 Accuracy : 67.0 %
Epoch : 5 Batch : 60 Loss : 0.009270288050174713 Accuracy : 70.0 %
Epoch : 5 Batch : 70 Loss : 0.10752341151237488 Accuracy : 73.0 %
Epoch : 5 Batch : 80 Loss : 0.03503851592540741 Accuracy : 70.0 %
Epoch : 5 Batch : 90 Loss : 0.08593016117811203 Accuracy : 71.0 %
Epoch : 5 Batch : 100 Loss : 0.029174692928791046 Accuracy : 71.0 %
Epoch : 5 Batch : 110 Loss : 0.07202176749706268 Accuracy : 76.0 %
Epoch : 5 Batch : 120 Loss : 0.012088261544704437 Accuracy : 70

Epoch : 6 Batch : 270 Loss : 0.014555494301021099 Accuracy : 74.0 %
Epoch : 6 Batch : 280 Loss : 0.047878824174404144 Accuracy : 70.0 %
Epoch : 6 Batch : 290 Loss : 0.05064048618078232 Accuracy : 72.0 %
Epoch : 6 Batch : 300 Loss : 0.00857884343713522 Accuracy : 71.0 %
Epoch : 6 Batch : 310 Loss : 0.014366690069437027 Accuracy : 68.0 %
Epoch : 6 Batch : 320 Loss : 0.003934628330171108 Accuracy : 68.0 %
Epoch : 6 Batch : 330 Loss : 0.04615829512476921 Accuracy : 71.0 %
Epoch : 6 Batch : 340 Loss : 0.02942677028477192 Accuracy : 72.0 %
Epoch : 6 Batch : 350 Loss : 0.0038251206278800964 Accuracy : 71.0 %
Epoch : 6 Batch : 360 Loss : 0.03134482726454735 Accuracy : 71.0 %
Epoch : 6 Batch : 370 Loss : 0.00274865236133337 Accuracy : 73.0 %
Epoch : 6 Batch : 380 Loss : 0.026758993044495583 Accuracy : 71.0 %
Epoch : 6 Batch : 390 Loss : 0.15464510023593903 Accuracy : 68.0 %
Epoch : 6 Batch : 400 Loss : 0.0792866200208664 Accuracy : 71.0 %
Epoch : 6 Batch : 410 Loss : 0.005964359268546104 Accura

Epoch : 7 Batch : 560 Loss : 0.00023956684162840247 Accuracy : 76.0 %
Epoch : 7 Batch : 570 Loss : 0.14308315515518188 Accuracy : 71.0 %
Epoch : 7 Batch : 580 Loss : 0.025577982887625694 Accuracy : 74.0 %
Epoch : 7 Batch : 590 Loss : 0.05170514062047005 Accuracy : 73.0 %
Epoch : 7 Batch : 600 Loss : 0.09068860113620758 Accuracy : 71.0 %
Epoch : 7 Batch : 610 Loss : 0.032439909875392914 Accuracy : 74.0 %
Epoch : 7 Batch : 620 Loss : 0.008350266143679619 Accuracy : 76.0 %
Epoch : 7 Batch : 630 Loss : 0.0250201728194952 Accuracy : 75.0 %
Epoch : 7 Batch : 640 Loss : 0.014293119311332703 Accuracy : 77.0 %
Epoch : 7 Batch : 650 Loss : 0.08851881325244904 Accuracy : 73.0 %
Epoch : 7 Batch : 660 Loss : 0.020484674721956253 Accuracy : 68.0 %
Epoch : 7 Batch : 670 Loss : 0.050484344363212585 Accuracy : 71.0 %
Epoch : 7 Batch : 680 Loss : 0.038467809557914734 Accuracy : 67.0 %
Epoch : 7 Batch : 690 Loss : 0.02384745329618454 Accuracy : 70.0 %
Epoch : 7 Batch : 700 Loss : 0.018245283514261246 Acc

Epoch : 8 Batch : 850 Loss : 0.00845804437994957 Accuracy : 71.0 %
Epoch : 8 Batch : 860 Loss : 0.00021242909133434296 Accuracy : 74.0 %
Epoch : 8 Batch : 870 Loss : 0.007442634552717209 Accuracy : 77.0 %
Epoch : 8 Batch : 880 Loss : 0.00017025135457515717 Accuracy : 75.0 %
Epoch : 8 Batch : 890 Loss : 0.010076884180307388 Accuracy : 76.0 %
Epoch : 8 Batch : 900 Loss : 0.005333541892468929 Accuracy : 77.0 %
Epoch : 8 Batch : 910 Loss : 0.06756992638111115 Accuracy : 73.0 %
Epoch : 8 Batch : 920 Loss : 0.0001927679404616356 Accuracy : 76.0 %
Epoch : 8 Batch : 930 Loss : 0.01730284094810486 Accuracy : 72.0 %
Epoch : 9 Batch : 10 Loss : 0.08740653097629547 Accuracy : 71.0 %
Epoch : 9 Batch : 20 Loss : 0.046170447021722794 Accuracy : 71.0 %
Epoch : 9 Batch : 30 Loss : 0.00104800914414227 Accuracy : 74.0 %
Epoch : 9 Batch : 40 Loss : 0.0557820200920105 Accuracy : 76.0 %
Epoch : 9 Batch : 50 Loss : 0.06609217822551727 Accuracy : 77.0 %
Epoch : 9 Batch : 60 Loss : 0.04505664110183716 Accuracy

Epoch : 10 Batch : 200 Loss : 0.014746776781976223 Accuracy : 68.0 %
Epoch : 10 Batch : 210 Loss : 0.007855905219912529 Accuracy : 68.0 %
Epoch : 10 Batch : 220 Loss : 0.06161900609731674 Accuracy : 68.0 %
Epoch : 10 Batch : 230 Loss : 0.05787210538983345 Accuracy : 72.0 %
Epoch : 10 Batch : 240 Loss : 0.013518786057829857 Accuracy : 72.0 %
Epoch : 10 Batch : 250 Loss : 0.059828732162714005 Accuracy : 74.0 %
Epoch : 10 Batch : 260 Loss : 0.0033005960285663605 Accuracy : 73.0 %
Epoch : 10 Batch : 270 Loss : 0.014368044212460518 Accuracy : 78.0 %
Epoch : 10 Batch : 280 Loss : 0.02097972109913826 Accuracy : 74.0 %
Epoch : 10 Batch : 290 Loss : 0.08303897827863693 Accuracy : 75.0 %
Epoch : 10 Batch : 300 Loss : 0.00253827846609056 Accuracy : 78.0 %
Epoch : 10 Batch : 310 Loss : 0.008530429564416409 Accuracy : 78.0 %
Epoch : 10 Batch : 320 Loss : 0.0019215866923332214 Accuracy : 76.0 %
Epoch : 10 Batch : 330 Loss : 0.00438348576426506 Accuracy : 70.0 %
Epoch : 10 Batch : 340 Loss : 0.002281

Epoch : 11 Batch : 470 Loss : 0.0741688534617424 Accuracy : 62.0 %
Epoch : 11 Batch : 480 Loss : 0.002241319976747036 Accuracy : 68.0 %
Epoch : 11 Batch : 490 Loss : 0.016419192776083946 Accuracy : 69.0 %
Epoch : 11 Batch : 500 Loss : 0.12048196792602539 Accuracy : 66.0 %
Epoch : 11 Batch : 510 Loss : 0.0009904573671519756 Accuracy : 72.0 %
Epoch : 11 Batch : 520 Loss : 0.005746121052652597 Accuracy : 72.0 %
Epoch : 11 Batch : 530 Loss : 0.013818386010825634 Accuracy : 75.0 %
Epoch : 11 Batch : 540 Loss : 0.025255190208554268 Accuracy : 71.0 %
Epoch : 11 Batch : 550 Loss : 0.05585009604692459 Accuracy : 70.0 %
Epoch : 11 Batch : 560 Loss : 0.0021756934002041817 Accuracy : 66.0 %
Epoch : 11 Batch : 570 Loss : 0.1973111629486084 Accuracy : 68.0 %
Epoch : 11 Batch : 580 Loss : 0.0982196182012558 Accuracy : 72.0 %
Epoch : 11 Batch : 590 Loss : 0.06743358820676804 Accuracy : 67.0 %
Epoch : 11 Batch : 600 Loss : 0.13507580757141113 Accuracy : 72.0 %
Epoch : 11 Batch : 610 Loss : 0.0541611164

Epoch : 12 Batch : 740 Loss : 0.07665864378213882 Accuracy : 62.0 %
Epoch : 12 Batch : 750 Loss : 0.0021662074141204357 Accuracy : 67.0 %
Epoch : 12 Batch : 760 Loss : 0.013604462146759033 Accuracy : 77.0 %
Epoch : 12 Batch : 770 Loss : 0.003914410714060068 Accuracy : 73.0 %
Epoch : 12 Batch : 780 Loss : 0.14827747642993927 Accuracy : 72.0 %
Epoch : 12 Batch : 790 Loss : 0.0026986906304955482 Accuracy : 74.0 %
Epoch : 12 Batch : 800 Loss : 0.03475883975625038 Accuracy : 67.0 %
Epoch : 12 Batch : 810 Loss : 0.0001885984092950821 Accuracy : 68.0 %
Epoch : 12 Batch : 820 Loss : 0.007938742637634277 Accuracy : 73.0 %
Epoch : 12 Batch : 830 Loss : 0.07678468525409698 Accuracy : 68.0 %
Epoch : 12 Batch : 840 Loss : 0.013829343020915985 Accuracy : 75.0 %
Epoch : 12 Batch : 850 Loss : 0.0003105385694652796 Accuracy : 71.0 %
Epoch : 12 Batch : 860 Loss : 0.012835677713155746 Accuracy : 75.0 %
Epoch : 12 Batch : 870 Loss : 0.047731976956129074 Accuracy : 78.0 %
Epoch : 12 Batch : 880 Loss : 0.09

Epoch : 14 Batch : 80 Loss : 0.0071789976209402084 Accuracy : 77.0 %
Epoch : 14 Batch : 90 Loss : 0.0032338458113372326 Accuracy : 75.0 %
Epoch : 14 Batch : 100 Loss : 0.0021546771749854088 Accuracy : 66.0 %
Epoch : 14 Batch : 110 Loss : 0.006392132490873337 Accuracy : 80.0 %
Epoch : 14 Batch : 120 Loss : 0.023358361795544624 Accuracy : 74.0 %
Epoch : 14 Batch : 130 Loss : 0.0011906777508556843 Accuracy : 74.0 %
Epoch : 14 Batch : 140 Loss : 0.006179208867251873 Accuracy : 75.0 %
Epoch : 14 Batch : 150 Loss : 0.01461968943476677 Accuracy : 72.0 %
Epoch : 14 Batch : 160 Loss : 0.08070934563875198 Accuracy : 75.0 %
Epoch : 14 Batch : 170 Loss : 0.00029580551199615 Accuracy : 73.0 %
Epoch : 14 Batch : 180 Loss : 0.027075210586190224 Accuracy : 70.0 %
Epoch : 14 Batch : 190 Loss : 0.005861657205969095 Accuracy : 69.0 %
Epoch : 14 Batch : 200 Loss : 0.008169567212462425 Accuracy : 71.0 %
Epoch : 14 Batch : 210 Loss : 0.0031743464060127735 Accuracy : 73.0 %
Epoch : 14 Batch : 220 Loss : 0.00

Epoch : 15 Batch : 340 Loss : 0.07018724083900452 Accuracy : 71.0 %
Epoch : 15 Batch : 350 Loss : 0.001749723101966083 Accuracy : 75.0 %
Epoch : 15 Batch : 360 Loss : 0.0012819916009902954 Accuracy : 75.0 %
Epoch : 15 Batch : 370 Loss : 0.0023094923235476017 Accuracy : 76.0 %
Epoch : 15 Batch : 380 Loss : 2.404465340077877e-05 Accuracy : 71.0 %
Epoch : 15 Batch : 390 Loss : 0.06958498060703278 Accuracy : 64.0 %
Epoch : 15 Batch : 400 Loss : 0.012746118940412998 Accuracy : 71.0 %
Epoch : 15 Batch : 410 Loss : 0.0013242216082289815 Accuracy : 66.0 %
Epoch : 15 Batch : 420 Loss : 0.02887227199971676 Accuracy : 65.0 %
Epoch : 15 Batch : 430 Loss : 0.09462690353393555 Accuracy : 66.0 %
Epoch : 15 Batch : 440 Loss : 0.07492117583751678 Accuracy : 71.0 %
Epoch : 15 Batch : 450 Loss : 0.013774299994111061 Accuracy : 70.0 %
Epoch : 15 Batch : 460 Loss : 0.0002717655152082443 Accuracy : 70.0 %
Epoch : 15 Batch : 470 Loss : 0.04651201516389847 Accuracy : 67.0 %
Epoch : 15 Batch : 480 Loss : 0.001

In [21]:
### Save model
# Persist only the learned parameters (state_dict), not the pickled module;
# reload them into a freshly constructed myCNN.
torch.save(obj=mycnn.state_dict(),
           f='E:/#AI&CV camp/#Project/#1Protocols/mnist_digits.pt')

In [23]:
### Load model
# Uncomment to rebuild the network and restore the saved weights into `model`,
# the name the inference cells below use. On a machine without a GPU, replace
# .cuda() with .to('cpu') and pass map_location='cpu' to torch.load.
# model = myCNN().cuda()
# model.load_state_dict(torch.load('E:/#AI&CV camp/#Project/#1Protocols/mnist_digits.pt'))
# model.eval()

<All keys matched successfully>

In [10]:
# Prepare one image file for inference with the same `transform` used in training.
# NOTE(review): `transform` has no Resize, so the image must already be 28x28
# to fit myCNN's fc1 -- confirm for new inputs.
# from PIL import Image
# img = Image.open("test_6_3.jpg")

# img_t = transform(img).to(device='cuda')
# batch_t = torch.unsqueeze(img_t, 0)  # add a batch dimension of size 1

In [11]:
# Class probabilities for the single prepared image, then the predicted digit.
# (numpy is never imported in this notebook, so take the argmax in torch.)
# out = F.softmax(model(batch_t), dim=1)
# print(out.argmax(dim=1).item())