In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets

In [2]:
# Select the compute device once for the whole notebook: the first visible
# GPU when PyTorch reports CUDA support, otherwise the CPU.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if USE_CUDA else "cpu")

In [57]:
# Training schedule: number of full passes over the training set, and the
# mini-batch size used by both the train and the test DataLoader.
EPOCHS, BATCH_SIZE = 40, 60

In [4]:
DATA_ROOT = './.data'

# Both splits must receive identical preprocessing, so the pipeline is built
# once and shared. Normalize takes per-channel sequences; with one channel
# both mean and std are 1-tuples.
# NOTE(review): (0.1307, 0.3081) look like the commonly quoted MNIST mean/std
# rather than FashionMNIST statistics; kept unchanged to preserve results —
# consider recomputing them from the FashionMNIST training split.
_fmnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(DATA_ROOT, train=True, download=True,
                          transform=_fmnist_transform),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# download=True makes this cell independent of the train cell having run.
# Evaluation sums loss and correct counts over the whole split, so the sample
# order is irrelevant and shuffling is unnecessary.
test_loader = torch.utils.data.DataLoader(
    datasets.FashionMNIST(DATA_ROOT, train=False, download=True,
                          transform=_fmnist_transform),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [48]:
class CNN(nn.Module):
    """Two-convolution image classifier producing 10 raw logits per sample.

    Expects single-channel 28x28 inputs of shape (N, 1, 28, 28). Shape trace:
      conv1 (k=5): 28 -> 24, max-pool 2 -> 12      (10 channels)
      conv2 (k=5): 12 -> 8,  max-pool 2 -> 4       (20 channels)
      flatten: 20 * 4 * 4 = 320 features -> fc1 (50) -> fc2 (10)
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Channel-wise dropout on conv2's feature maps (used in forward()).
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch ``x`` of shape (N, 1, 28, 28)."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # flatten(x, 1) preserves the batch dimension; view(-1, 320) would
        # silently fold a wrongly sized input into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # Functional dropout needs the mode passed explicitly; tie it to
        # the module's train()/eval() state.
        x = F.dropout(x, training=self.training)
        return self.fc2(x)

In [49]:
# Build the network directly on the selected device, then attach SGD with
# momentum over all of its parameters.
model = CNN().to(DEVICE)
optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.5)

In [50]:
def train(model, train_loader, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Each mini-batch is moved to ``DEVICE``, scored with cross-entropy on the
    model's outputs, and used for a single ``optimizer`` step. Every 200th
    batch (batch 0 included) a progress line tagged with ``epoch`` is printed.

    Returns None; parameters and optimizer state are updated in place.
    """
    model.train()  # training mode (affects layers such as dropout)

    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        batch_loss = F.cross_entropy(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        if step % 200 != 0:
            continue
        seen = step * len(inputs)
        pct = 100. * step / len(train_loader)
        print(f'Train Epoch : {epoch} [{seen}/{len(train_loader.dataset)} '
              f'({pct:.0f}%)]\tLoss:{batch_loss.item():.6f}')

In [55]:
def evaluate(model, test_loader):
    """Score ``model`` on every batch of ``test_loader`` without gradient tracking.

    Returns:
        (mean_loss, accuracy_pct): the cross-entropy summed over all samples
        and divided by ``len(test_loader.dataset)``, and the percentage of
        samples whose highest-scoring output index equals the target.
    """
    model.eval()  # inference mode (affects layers such as dropout)
    summed_loss = 0
    n_correct = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            logits = model(inputs)

            summed_loss += F.cross_entropy(logits, labels, reduction='sum').item()

            _, predicted = logits.max(1, keepdim=True)
            n_correct += predicted.eq(labels.view_as(predicted)).sum().item()

    n_samples = len(test_loader.dataset)
    return summed_loss / n_samples, 100 * n_correct / n_samples

In [58]:
# Alternate one training epoch with a full test-set evaluation and log the
# held-out loss and accuracy after each epoch (epochs are numbered from 1).
for epoch in range(1, EPOCHS + 1):
    train(model, train_loader, optimizer, epoch)
    test_loss, test_accuracy = evaluate(model, test_loader)
    print(f'[{epoch}] Test Loss : {test_loss:.4f}, Accuracy: {test_accuracy:.2f}%')

[1] Test Loss : 0.3166, Accuracy: 88.83%
[2] Test Loss : 0.3186, Accuracy: 88.49%
[3] Test Loss : 0.3175, Accuracy: 88.26%
[4] Test Loss : 0.3148, Accuracy: 88.51%
[5] Test Loss : 0.3124, Accuracy: 88.84%
[6] Test Loss : 0.3161, Accuracy: 88.80%
[7] Test Loss : 0.3126, Accuracy: 88.80%
[8] Test Loss : 0.3213, Accuracy: 88.33%
[9] Test Loss : 0.3127, Accuracy: 88.77%
[10] Test Loss : 0.3059, Accuracy: 89.19%
[11] Test Loss : 0.3093, Accuracy: 88.81%
[12] Test Loss : 0.3070, Accuracy: 89.06%
[13] Test Loss : 0.3085, Accuracy: 88.90%
[14] Test Loss : 0.3117, Accuracy: 88.84%
[15] Test Loss : 0.3058, Accuracy: 89.07%
[16] Test Loss : 0.3111, Accuracy: 88.84%
[17] Test Loss : 0.3048, Accuracy: 88.84%
[18] Test Loss : 0.3061, Accuracy: 89.30%
[19] Test Loss : 0.3220, Accuracy: 88.36%
[20] Test Loss : 0.3067, Accuracy: 89.01%
[21] Test Loss : 0.3115, Accuracy: 88.67%
[22] Test Loss : 0.3079, Accuracy: 88.98%
[23] Test Loss : 0.3081, Accuracy: 88.70%
[24] Test Loss : 0.3031, Accuracy: 89.35%
[

[29] Test Loss : 0.3028, Accuracy: 89.42%
[30] Test Loss : 0.3042, Accuracy: 89.12%
[31] Test Loss : 0.2993, Accuracy: 89.29%
[32] Test Loss : 0.3041, Accuracy: 89.08%
[33] Test Loss : 0.3038, Accuracy: 89.51%
[34] Test Loss : 0.3016, Accuracy: 89.27%
[35] Test Loss : 0.3013, Accuracy: 89.35%
[36] Test Loss : 0.3058, Accuracy: 88.81%
[37] Test Loss : 0.2981, Accuracy: 89.26%
[38] Test Loss : 0.3064, Accuracy: 88.99%
[39] Test Loss : 0.3021, Accuracy: 89.34%
[40] Test Loss : 0.2988, Accuracy: 89.09%
[41] Test Loss : 0.3028, Accuracy: 88.95%
[42] Test Loss : 0.3007, Accuracy: 89.48%
[43] Test Loss : 0.3011, Accuracy: 89.37%
[44] Test Loss : 0.3024, Accuracy: 89.28%
[45] Test Loss : 0.2969, Accuracy: 89.54%
[46] Test Loss : 0.2995, Accuracy: 89.18%
[47] Test Loss : 0.2987, Accuracy: 89.37%
[48] Test Loss : 0.2984, Accuracy: 89.26%
[49] Test Loss : 0.2984, Accuracy: 89.34%
[50] Test Loss : 0.2983, Accuracy: 89.38%
[51] Test Loss : 0.2981, Accuracy: 89.13%
[52] Test Loss : 0.2961, Accuracy:

[57] Test Loss : 0.2977, Accuracy: 89.37%
[58] Test Loss : 0.3030, Accuracy: 89.10%
[59] Test Loss : 0.3008, Accuracy: 89.34%
[60] Test Loss : 0.3055, Accuracy: 89.21%
[61] Test Loss : 0.3015, Accuracy: 89.20%
[62] Test Loss : 0.2984, Accuracy: 89.45%
[63] Test Loss : 0.3006, Accuracy: 89.04%
[64] Test Loss : 0.2964, Accuracy: 89.46%
[65] Test Loss : 0.2982, Accuracy: 89.21%
[66] Test Loss : 0.2991, Accuracy: 89.17%
[67] Test Loss : 0.2972, Accuracy: 89.14%
[68] Test Loss : 0.2950, Accuracy: 89.56%
[69] Test Loss : 0.2951, Accuracy: 89.32%
[70] Test Loss : 0.2997, Accuracy: 89.40%
[71] Test Loss : 0.3028, Accuracy: 88.98%
[72] Test Loss : 0.2983, Accuracy: 89.34%
[73] Test Loss : 0.2986, Accuracy: 89.51%
[74] Test Loss : 0.3015, Accuracy: 89.09%
[75] Test Loss : 0.2951, Accuracy: 89.24%
[76] Test Loss : 0.2970, Accuracy: 89.23%
[77] Test Loss : 0.2940, Accuracy: 89.62%
[78] Test Loss : 0.2920, Accuracy: 89.53%
[79] Test Loss : 0.2923, Accuracy: 89.50%
[80] Test Loss : 0.2961, Accuracy: