In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")


In [58]:
transform_train = transforms.Compose([transforms.ToTensor()])
transform_valid = transforms.Compose([transforms.ToTensor()])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform_train, target_transform=lambda x: torch.Tensor([x]).float())
trainloader = DataLoader(trainset, batch_size=20,
                         shuffle=True, num_workers=0)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform_valid, target_transform=lambda x: torch.Tensor([x]).float())
testloader = DataLoader(testset, batch_size=20,
                        shuffle=False, num_workers=0)


In [73]:
class Cnn(nn.Module):
    def __init__(self):
        super(Cnn, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, (5, 5), padding='same')
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, (5, 5), padding='same')
        self.flatten = nn.Flatten(1)
        self.fc1 = nn.Linear(7 * 7 * 64, 1024)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool(x)

        x = self.pool(F.relu(self.conv2(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5)
        x = self.fc2(x)
        return x

metric_func = lambda y_pred, y_true: roc_auc_score(y_true=y_true.data.numpy(), y_score=y_pred.data.numpy())
matric_name = "auc"
def train_step(model: nn.Module, features:torch.Tensor, labels: torch.Tensor):
    features = features.to(device=device)
    labels = labels.to(device)
    
    model.train()
    optimizer.zero_grad()
    y_pred = model(features)
    loss = criterion(y_pred, labels)
    # metric = metric_func(y_pred, labels)
    metric = torch.Tensor(0)
    loss.backward()
    optimizer.step()
    return loss.item(), 0


In [74]:
from sklearn.metrics import roc_auc_score

cnn = Cnn()
cnn.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.0001)

In [75]:
import torchkeras
input_shape = (1, 28, 28)
torchkeras.summary(Cnn(), input_shape=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             832
         MaxPool2d-2           [-1, 32, 14, 14]               0
            Conv2d-3           [-1, 64, 14, 14]          51,264
         MaxPool2d-4             [-1, 64, 7, 7]               0
           Flatten-5                 [-1, 3136]               0
            Linear-6                 [-1, 1024]       3,212,288
            Linear-7                   [-1, 10]          10,250
Total params: 3,274,634
Trainable params: 3,274,634
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.002991
Forward/backward pass size (MB): 0.390701
Params size (MB): 12.491737
Estimated Total Size (MB): 12.885429
----------------------------------------------------------------


In [76]:
(features, labels) = next(iter(trainloader))
# print((features[0], labels[0]))
print(trainloader.batch_sampler)
print(len(trainloader))

<torch.utils.data.sampler.BatchSampler object at 0x000002D27C53A348>
3000


In [80]:
for epoch in range(20):
    loss_sum = []
    metric_sum = 0.0
    for i, data in enumerate(trainloader, 0):
        (featurs, labels) = data
        loss, metric = train_step(cnn, features=features, labels=labels.flatten().type(torch.long))
        loss_sum.append(loss)
        metric_sum += metric
        if i % 25 == 0:
            print('[%d, %5d] loss: %.3f' %
                    (epoch + 1, i + 1, torch.mean(torch.Tensor(loss_sum))))


[1,     1] loss: 2.299
[1,    26] loss: 2.301
[1,    51] loss: 2.303
[1,    76] loss: 2.302
[1,   101] loss: 2.302
[1,   126] loss: 2.302
[1,   151] loss: 2.302
[1,   176] loss: 2.301
[1,   201] loss: 2.301
[1,   226] loss: 2.301
[1,   251] loss: 2.301
[1,   276] loss: 2.300
[1,   301] loss: 2.301
[1,   326] loss: 2.301
[1,   351] loss: 2.301
[1,   376] loss: 2.301
[1,   401] loss: 2.301
[1,   426] loss: 2.301
[1,   451] loss: 2.301
[1,   476] loss: 2.301
[1,   501] loss: 2.301
[1,   526] loss: 2.301
[1,   551] loss: 2.301
[1,   576] loss: 2.301
[1,   601] loss: 2.301
[1,   626] loss: 2.301
[1,   651] loss: 2.301
[1,   676] loss: 2.301
[1,   701] loss: 2.301
[1,   726] loss: 2.301
[1,   751] loss: 2.301
[1,   776] loss: 2.301
[1,   801] loss: 2.301
[1,   826] loss: 2.301
[1,   851] loss: 2.301
[1,   876] loss: 2.301
[1,   901] loss: 2.301
[1,   926] loss: 2.301
[1,   951] loss: 2.301
[1,   976] loss: 2.301
[1,  1001] loss: 2.301
[1,  1026] loss: 2.301
[1,  1051] loss: 2.301
[1,  1076] 

KeyboardInterrupt: 