In [1]:
import torch
from torch import nn
import numpy as np
from defconv.models import ConvNet, DeformConvNet
from argus import Model
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

In [2]:
# Reproducibility: seed torch's global RNG before the DataLoaders shuffle and
# before the models are constructed, so weight initialisation and batch order
# repeat across Restart & Run All. (GPU kernels may still be nondeterministic.)
SEED = 42
torch.manual_seed(SEED)

# Normalisation constants for the single-channel MNIST images.
# NOTE(review): these are the values commonly cited as the MNIST training-set
# pixel mean/std (e.g. in the PyTorch MNIST example) — re-derive if the data changes.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# Convert each image to a tensor, then standardise it with the constants above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

In [3]:
# Batch size shared by the training and validation loaders.
BATCH_SIZE = 64

# NOTE: despite the `*set` names, both objects are DataLoaders (not datasets);
# the names are kept because later cells refer to them.

# Training batches are shuffled so each epoch sees a different batch order.
mnist_trainset = DataLoader(
    datasets.MNIST(root='./data', train=True, download=True, transform=transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# The held-out split (train=False) is only evaluated, never trained on, so it is
# read in a fixed order: validation runs and collected predictions are then
# reproducible and follow the dataset's own indexing.
mnist_valset = DataLoader(
    datasets.MNIST(root='./data', train=False, download=True, transform=transform),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [4]:
# NOTE(review): disabled sanity check of the two networks' output shapes.
# It builds the models with padding=1, whereas PARAMS (used for the real training
# runs below) does not pass `padding` — so it does not check the configuration
# that is actually trained. Either re-enable it as an assert using PARAMS, or
# delete this cell; commented-out code should not stay in the final notebook.
# in_channels = 1
# channels = [32,32,64,64, 128, 128, 128 ,128]
# n_classes = 10
# convnet = ConvNet(in_channels, channels, n_classes, padding=1)
# defconvnet = DeformConvNet(in_channels, channels, n_classes, padding=1)

# print(convnet(torch.rand((1,1,28,28))).shape)

# print(defconvnet(torch.rand(64,1,28,28)).shape)

In [5]:
class ConvNetArgus(Model):
    """argus wrapper that trains the plain ``ConvNet`` with cross-entropy loss.

    NOTE(review): every loss logged in this notebook levels off around 1.46-1.5.
    For 10 classes, CrossEntropyLoss applied to outputs that are *already*
    softmax probabilities cannot drop below ~1.461 (= log(e + 9) - 1), so the
    plateau suggests the nn_module ends in a softmax. TODO confirm in
    defconv.models; if so, remove that softmax (CrossEntropyLoss expects raw
    logits) or switch to NLLLoss on log-probabilities.
    """
    loss = 'CrossEntropyLoss'
    nn_module = ConvNet


class DeformConvNetArgus(Model):
    """argus wrapper that trains ``DeformConvNet`` with cross-entropy loss.

    Configured exactly like ``ConvNetArgus`` except for the wrapped network,
    so the two training runs differ only in the nn_module (presumably regular
    vs. deformable convolutions).
    """
    loss = 'CrossEntropyLoss'
    nn_module = DeformConvNet

In [9]:
# Hyper-parameters shared by both argus models below.
# NOTE(review): model1's training log reports "LR: 0.01" while this cell sets
# lr=1e-3, and the execution counters In [9]/In [10] appear twice — model1 was
# trained against an earlier version of this cell. Restart the kernel and run
# all cells before comparing the two models.
PARAMS = {
    'nn_module': {
        'in_channels': 1,  # MNIST images are single-channel
        'channels': [32, 32, 64, 64, 128, 128, 128, 128],
        'n_classes': 10,
    },
    'optimizer': ('Adam', {'lr': 1e-3}),
    # Fall back to CPU so the notebook still runs on machines without CUDA;
    # on a CUDA machine this is 'cuda', exactly as before.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

In [10]:
# Deformable-convolution model, built from the shared PARAMS dict.
# NOTE(review): `model2` is numbered by creation order rather than named by what
# it is; the prediction cell below refers to it, so rename both together.
model2 = DeformConvNetArgus(PARAMS)

In [11]:
# Train the deformable model for 20 epochs; the log shows a validation pass on
# mnist_valset before training and after every epoch.
# NOTE(review): the saved log output is truncated during epoch 7, so the final
# metrics of this run are not recorded in the notebook.
model2.fit(mnist_trainset, val_loader=mnist_valset, max_epochs=20)

2019-08-10 20:35:11,557 INFO Validation, val_loss: 2.302618
2019-08-10 20:35:52,717 INFO Train - Epoch: 1, LR: 0.001, train_loss: 1.577252
2019-08-10 20:35:56,072 INFO Validation - Epoch: 1, val_loss: 1.505638
2019-08-10 20:36:37,522 INFO Train - Epoch: 2, LR: 0.001, train_loss: 1.504117
2019-08-10 20:36:40,925 INFO Validation - Epoch: 2, val_loss: 1.494042
2019-08-10 20:37:23,469 INFO Train - Epoch: 3, LR: 0.001, train_loss: 1.501095
2019-08-10 20:37:26,864 INFO Validation - Epoch: 3, val_loss: 1.494587
2019-08-10 20:38:07,781 INFO Train - Epoch: 4, LR: 0.001, train_loss: 1.501452
2019-08-10 20:38:11,149 INFO Validation - Epoch: 4, val_loss: 1.489388
2019-08-10 20:38:51,611 INFO Train - Epoch: 5, LR: 0.001, train_loss: 1.493299
2019-08-10 20:38:55,013 INFO Validation - Epoch: 5, val_loss: 1.485181
2019-08-10 20:39:37,036 INFO Train - Epoch: 6, LR: 0.001, train_loss: 1.492843
2019-08-10 20:39:40,433 INFO Validation - Epoch: 6, val_loss: 1.488932
2019-08-10 20:40:21,905 INFO Train - Epo

In [24]:
# NOTE(review): this import belongs in the imports cell at the top of the
# notebook; `to_var` is not used in any visible cell.
from dpipe.torch.utils import to_var, to_np

# Collect per-sample class scores and ground-truth labels for the whole
# validation set. Both lists are filled in the same loop, so row i of
# `prediction` always corresponds to `targets[i]`.
targets = []
prediction = []

for images, labels in mnist_valset:
    # Move the batch to the configured device instead of hard-coding `.cuda()`,
    # so this cell follows PARAMS['device'] (and `images` no longer shadows the
    # builtin `input`).
    batch_scores = model2.predict(images.to(PARAMS['device']))
    prediction.extend(to_np(batch_scores))
    targets.extend(to_np(labels))

In [25]:
# Stack the collected per-sample rows into arrays: `prediction` gets one row of
# class scores per image, `targets` one label per image.
prediction = np.array(prediction)
targets = np.array(targets)

# Scores and labels were gathered in lock-step, so their lengths must agree;
# fail loudly here rather than silently mis-scoring in the accuracy cell.
assert len(prediction) == len(targets), (prediction.shape, targets.shape)

In [26]:
# Expected layout: (number of validation images, number of classes).
np.shape(prediction)

(10000, 10)

In [None]:
# Top-1 predicted class per image, scored against the ground-truth labels.
# (`accuracy_score` is imported in the first cell but was never used, so the
# notebook previously reported no metric at all.) argmax gives the same labels
# whether the model outputs logits or probabilities.
predicted_labels = prediction.argmax(axis=1)
accuracy_score(targets, predicted_labels)

In [9]:
# Plain-convolution baseline, built from the same PARAMS dict as model2.
# NOTE(review): this cell and the next carry execution counts In [9]/In [10],
# which also label the PARAMS and model2 cells above. model1's log is also
# timestamped ~20:12 versus ~20:35 for model2, so it ran first, against an
# older PARAMS (its log shows LR 0.01, not the lr=1e-3 set above). Restart
# the kernel and run all cells before comparing the two models.
model1 = ConvNetArgus(PARAMS)

In [10]:
# Train the baseline with the same loaders and epoch budget as model2
# (20 epochs, validation after each), so the two val_loss curves are comparable.
# NOTE(review): the saved log output is truncated during epoch 7.
model1.fit(mnist_trainset, val_loader=mnist_valset, max_epochs=20)

2019-08-10 20:12:02,247 INFO Validation, val_loss: 2.302584
2019-08-10 20:12:25,823 INFO Train - Epoch: 1, LR: 0.01, train_loss: 1.550499
2019-08-10 20:12:27,996 INFO Validation - Epoch: 1, val_loss: 1.496974
2019-08-10 20:12:51,671 INFO Train - Epoch: 2, LR: 0.01, train_loss: 1.490642
2019-08-10 20:12:53,868 INFO Validation - Epoch: 2, val_loss: 1.492389
2019-08-10 20:13:18,197 INFO Train - Epoch: 3, LR: 0.01, train_loss: 1.485896
2019-08-10 20:13:20,428 INFO Validation - Epoch: 3, val_loss: 1.481224
2019-08-10 20:13:44,058 INFO Train - Epoch: 4, LR: 0.01, train_loss: 1.483933
2019-08-10 20:13:46,316 INFO Validation - Epoch: 4, val_loss: 1.49637
2019-08-10 20:14:10,377 INFO Train - Epoch: 5, LR: 0.01, train_loss: 1.48008
2019-08-10 20:14:12,560 INFO Validation - Epoch: 5, val_loss: 1.476188
2019-08-10 20:14:36,708 INFO Train - Epoch: 6, LR: 0.01, train_loss: 1.479211
2019-08-10 20:14:38,941 INFO Validation - Epoch: 6, val_loss: 1.474259
2019-08-10 20:15:02,936 INFO Train - Epoch: 7, L