In [1]:
import torchvision
import torchvision.transforms as transforms
import os
from torch.utils.data import DataLoader
import torch
from torch.utils.data import random_split
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loading the Fashion-MNIST dataset

In [3]:
# Training split with only ToTensor applied (no normalisation yet); the data
# is downloaded into the current working directory if it is not already there.
fmnist_train = torchvision.datasets.FashionMNIST(
    root=os.getcwd(),
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /content/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:02<00:00, 9435415.34it/s] 


Extracting /content/FashionMNIST/raw/train-images-idx3-ubyte.gz to /content/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /content/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 169839.12it/s]


Extracting /content/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /content/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /content/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 3100223.70it/s]


Extracting /content/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /content/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /content/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 14901502.41it/s]

Extracting /content/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /content/FashionMNIST/raw






In [4]:
# One pass over the un-normalised training images to measure the per-channel
# mean and standard deviation.
# Sample order does not affect these statistics, so the loader does not
# shuffle, and a larger batch keeps the pass short.
loader = DataLoader(fmnist_train,
                    batch_size=256,
                    num_workers=0,
                    shuffle=False)

channels_sum, channels_squared_sum, num_batches, num_samples = 0, 0, 0, 0

for data, _ in loader:
    batch_n = data.size(0)
    # Weight every batch by its sample count: a plain average of per-batch
    # means is only exact when all batches have the same size, which does not
    # hold for the final (smaller) batch.
    channels_sum += torch.mean(data, dim=[0, 2, 3]) * batch_n
    channels_squared_sum += torch.mean(data**2, dim=[0, 2, 3]) * batch_n
    num_samples += batch_n
    num_batches += 1

mean = channels_sum/num_samples
# std = sqrt(E[x^2] - E[x]^2), taken over all pixels of each channel.
std = (channels_squared_sum/num_samples - mean**2)**0.5
print(mean, std)

tensor([0.2860]) tensor([0.3530])


In [5]:
# Convert images to tensors, then standardise them with the mean (0.2860) and
# standard deviation (0.3530) measured on the training images.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.2860, ), std=(0.3530, )),
])

In [6]:
# Re-create the train and test datasets, this time with the normalising
# transform attached. Both live under the current working directory.
fmnist_train, fmnist_test = (
    torchvision.datasets.FashionMNIST(os.getcwd(), train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [7]:
# Hold out 5,000 training images for validation and keep 55,000 for training.
# The seeded generator makes the partition identical on every run, so
# validation numbers are comparable between runs.
# NOTE: this cell rebinds `fmnist_train`, so it must be run exactly once after
# the dataset-creation cell.
fmnist_train, fmnist_val = random_split(
    fmnist_train,
    [55000, 5000],
    generator=torch.Generator().manual_seed(42),
)

In [8]:
# Wrap each split in a DataLoader. Only the training loader shuffles:
# evaluation does not benefit from a random order, and because the reported
# loss is an average of per-batch losses, a fixed order keeps the validation
# and test numbers reproducible.
fmnist_train = DataLoader(fmnist_train, batch_size=64, num_workers=4, pin_memory=True, shuffle=True)
fmnist_val = DataLoader(fmnist_val, batch_size=64, num_workers=4, pin_memory=True, shuffle=False)
fmnist_test = DataLoader(fmnist_test, batch_size=64, num_workers=4, pin_memory=True, shuffle=False)



# Defining network

In [9]:
class Net(nn.Module):
    """Small CNN classifier for single-channel 28x28 images with 10 classes.

    Layout: conv(1->64, 3x3) -> ReLU -> 2x2 max-pool ->
    conv(64->128, 3x3) -> dropout -> ReLU -> 2x2 max-pool ->
    flatten -> linear(3200->128) -> ReLU -> linear(128->10).

    For a 28x28 input the spatial size goes 28 -> 26 -> 13 -> 11 -> 5, so the
    flattened feature vector has 128 * 5 * 5 = 3200 entries.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, 3, 1)
        self.conv2 = nn.Conv2d(64, 128, 3, 1)
        self.dropout = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(3200, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return unnormalised class scores (logits) of shape (batch, 10).

        Args:
            x: image batch of shape (batch, 1, 28, 28).

        Raises:
            RuntimeError: if the spatial size of ``x`` does not produce 3200
                features per sample (raised by ``fc1``).
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.dropout(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        # Flatten everything except the batch dimension. Unlike
        # ``x.view(-1, 3200)``, this cannot silently fold a feature-size
        # mismatch into the batch dimension; fc1 will raise instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [10]:
# Instantiate the network and place its parameters on the selected device.
model = Net().to(DEVICE)

In [11]:
# RMSprop over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.RMSprop(
    params=model.parameters(),
    lr=1e-3,
)

In [12]:
# Print a per-layer table of output shapes and parameter counts for a
# single-channel 28x28 input.
# NOTE(review): this import sits mid-notebook; consider moving it to the
# import cell at the top so all dependencies are visible in one place.
from torchsummary import summary
summary(model, input_size=(1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 26, 26]             640
            Conv2d-2          [-1, 128, 11, 11]          73,856
           Dropout-3          [-1, 128, 11, 11]               0
            Linear-4                  [-1, 128]         409,728
            Linear-5                   [-1, 10]           1,290
Total params: 485,514
Trainable params: 485,514
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.57
Params size (MB): 1.85
Estimated Total Size (MB): 2.42
----------------------------------------------------------------


# Helper functions

In [13]:
def common_compute(model, batch):
    """Run a forward pass on one batch and compute its cross-entropy loss.

    Args:
        model: the network to evaluate.
        batch: an ``(inputs, targets)`` pair; both are moved to ``DEVICE``.

    Returns:
        A ``(logits, loss, targets)`` tuple, where ``targets`` is the
        device-resident copy of the labels.
    """
    inputs, targets = batch
    inputs = inputs.to(DEVICE)
    targets = targets.to(DEVICE)
    scores = model(inputs)
    return scores, F.cross_entropy(scores, targets), targets

In [14]:
def train_batch(model, optimizer, batch):
    """Perform one optimisation step on a single batch.

    Args:
        model: the network being trained.
        optimizer: optimiser holding ``model``'s parameters.
        batch: an ``(inputs, targets)`` pair as accepted by ``common_compute``.

    Returns:
        The batch loss as a tensor detached from the autograd graph, so
        callers can accumulate it for logging without keeping graph
        references alive.
    """
    _, loss, _ = common_compute(model, batch)
    loss.backward()
    optimizer.step()
    # Clear gradients right after the step so the next batch starts from zero.
    optimizer.zero_grad()
    return loss.detach()

In [15]:
def validate_batch(model, batch):
    """Return the cross-entropy loss of ``model`` on ``batch``.

    No parameters are updated. Gradient tracking is not disabled here; the
    caller is expected to wrap the call in ``torch.no_grad()`` if desired.
    """
    # common_compute returns (logits, loss, targets); only the loss is needed.
    return common_compute(model, batch)[1]

In [16]:
def test_batch(model, batch):
    """Evaluate one batch and count correct top-1 predictions.

    Args:
        model: the network to evaluate.
        batch: an ``(inputs, targets)`` pair as accepted by ``common_compute``.

    Returns:
        A ``(batch_size, num_correct, loss)`` tuple: the number of samples in
        the batch, how many were classified correctly (Python int), and the
        batch loss tensor.
    """
    logits, loss, y = common_compute(model, batch)
    # ``.detach()`` replaces the legacy ``.data`` attribute: same values, but
    # without bypassing autograd's safety checks.
    _, predicted = torch.max(logits.detach(), 1)
    return y.size(0), (predicted == y).sum().item(), loss

# Training loop

In [17]:
num_epochs = 15
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()  # enable dropout
    train_loss = []
    bar = tqdm(fmnist_train, position=0, leave=False, desc='epoch %d'%epoch)
    for batch in bar:
        loss = train_batch(model, optimizer, batch)
        # Store a detached copy: the values are only used for logging, and
        # keeping grad-attached tensors would hold autograd references for
        # every batch of the epoch.
        train_loss.append(loss.detach())
    # Unweighted mean of the per-batch losses.
    avg_train_loss = torch.stack(train_loss).mean()
    print('train_loss', avg_train_loss.item())

    # ---- validation pass ----
    model.eval()  # disable dropout
    with torch.no_grad():
        val_loss = []
        for batch in fmnist_val:
            loss = validate_batch(model, batch)
            val_loss.append(loss)
        avg_val_loss = torch.stack(val_loss).mean()
        print('val_loss', avg_val_loss.item())

                                                          

train_loss 0.5687328577041626




val_loss 0.38607266545295715


                                                          

train_loss 0.3502223789691925




val_loss 0.359352707862854


                                                          

train_loss 0.3020690083503723




val_loss 0.30992591381073


                                                          

train_loss 0.275563508272171




val_loss 0.3165651857852936


                                                          

train_loss 0.25584250688552856




val_loss 0.2598448693752289


                                                          

train_loss 0.2413945347070694




val_loss 0.27394595742225647


                                                          

train_loss 0.22929425537586212




val_loss 0.24525509774684906




train_loss 0.21992388367652893
val_loss 0.23827221989631653


                                                          

train_loss 0.2107696533203125




val_loss 0.2359306961297989


                                                          

train_loss 0.201670840382576




val_loss 0.2358035147190094


                                                           

train_loss 0.19287459552288055




val_loss 0.23590590059757233


                                                           

train_loss 0.18770372867584229




val_loss 0.22306309640407562


                                                           

train_loss 0.181253582239151




val_loss 0.2175786942243576


                                                           

train_loss 0.1767544001340866




val_loss 0.22451826930046082


                                                           

train_loss 0.16948078572750092




val_loss 0.2429293692111969


# Testing

In [18]:
# Evaluate the trained model on the held-out test set.
# Put the model in eval mode explicitly so dropout is disabled regardless of
# which cell ran last; do not rely on the state left by the training loop.
model.eval()
bar = tqdm(fmnist_test, position=0, leave=False, desc='test')
test_loss = []
correct = 0
total = 0
with torch.no_grad():
    for batch in bar:
        batch_size, batch_correct, loss = test_batch(model, batch)
        total += batch_size
        correct += batch_correct
        test_loss.append(loss)
    # Unweighted mean of the per-batch losses.
    avg_test_loss = torch.stack(test_loss).mean()
    print('test_loss', avg_test_loss.item())
    print('Accuracy %.2f%%' % (100 * float(correct) / total))

                                                       

test_loss 0.25632283091545105
Accuracy 91.00%




In [19]:
# Persist the trained network to 'model_fmnist.pth' in the working directory.
# This saves the whole module object rather than just `model.state_dict()`.
# NOTE(review): loading this file will presumably need the `Net` class to be
# importable; saving the state_dict is the more portable option - confirm
# what downstream consumers expect before changing the format.
torch.save(model, 'model_fmnist.pth')