In [13]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
import torchvision.datasets as datasets #for mnist datasets
import torchvision.transforms as transforms #transformations for dataset
print(torch.__version__)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

training_set = datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

training_loader = DataLoader(training_set, batch_size=4, shuffle=True, num_workers=2)
validation_loader = DataLoader(validation_set, batch_size=4, shuffle=False, num_workers=2)

classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

1.10.2+cpu
Training set has 60000 instances
Validation set has 10000 instances


In [14]:
def matplotlib_imshow(img, one_channel=False):
    if one_channel:
        img = img.mean(dim=0)
    img = img/2+.5
    npimg = img.numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        plt.imshow(np.transpose(npimg, (1, 2, 0)))

    dataiter = iter(training_loader)
    images, labels =dataiter.next()

    img_grid = torchvision.utils.make_grid(images)
    matplotlib_imshow(img_grid, one_channel=True)
    print(' . '.join(classes[labels[j]] for j in range(4)))


In [17]:
class GarmentClassifier(nn.Module):
        def __init__(self):
                super(GarmentClassifier, self).__init__()
                self.conv1 = nn.Conv2d(1, 6, 5)
                self.pool = nn.MaxPool2d(2, 2)
                self.conv2 = nn.Conv2d(6, 16, 5)
                self.fc1 = nn.Linear(16*4*4, 120)
                self.fc2 = nn.Linear(120, 84)
                self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
                x = self.pool(F.relu(self.conv1(x)))
                x = self.pool(F.relu(self.conv2(x)))
                x = x.view(-1, 16 * 4 * 4)
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return x

model = GarmentClassifier()

In [18]:
loss_fn = torch.nn.CrossEntropyLoss()

dummy_outputs = torch.rand(4, 10)
dummy_labels = torch.tensor([1, 5, 3, 7])
print(dummy_outputs)
print(dummy_labels)

loss = loss_fn(dummy_outputs, dummy_labels)
print('Total loss for this batch: {}'.format(loss.item()))

tensor([[0.6366, 0.9752, 0.1671, 0.3145, 0.9011, 0.4340, 0.4700, 0.9473, 0.4764,
         0.3053],
        [0.2012, 0.7745, 0.8610, 0.7032, 0.8858, 0.2090, 0.2340, 0.9887, 0.0874,
         0.4331],
        [0.2638, 0.0517, 0.2714, 0.0062, 0.7773, 0.5595, 0.4343, 0.6807, 0.6777,
         0.0453],
        [0.6054, 0.9609, 0.3227, 0.8184, 0.8384, 0.5905, 0.3082, 0.6318, 0.4370,
         0.2165]])
tensor([1, 5, 3, 7])
Total loss for this batch: 2.3984508514404297


In [19]:
optimizer = optim.SGD(model.parameters(), lr=.001, momentum=.9)


In [20]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    for i, data in enumerate(training_loader):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()
        running_loss += loss.item()
        if i % 1000 == 999:
            last_loss = running_loss / 1000
            print(' batch {} loss: {}'.format(i+1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

        return last_loss

In [21]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

In [22]:
EPOCHS = 5
best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    model.train(False)

    running_vloss = 0.0
    for i, vdata in enumerate(validation_loader):
        vinputs, vlabels = vdata
        voutputs = model(vinputs)
        vloss = loss_fn(voutputs, vlabels)
        running_vloss += vloss
            
        avg_vloss = running_vloss / (i+1)
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

        writer.add_scalars('Training vs. Validation Loss', { 'Training' : avg_loss, 'Validation' : avg_vloss }, epoch_number + 1)
        writer.flush()

        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = 'model_{}_{}'.format(timestamp, epoch_number)
            torch.save(model.state_dict(), model_path)

        epoch_number += 1
        

EPOCH 1:
LOSS train 0.0 valid 2.271998405456543
LOSS train 0.0 valid 2.2838845252990723
LOSS train 0.0 valid 2.291489839553833
LOSS train 0.0 valid 2.293400764465332
LOSS train 0.0 valid 2.2986483573913574
LOSS train 0.0 valid 2.297105073928833
LOSS train 0.0 valid 2.2966995239257812
LOSS train 0.0 valid 2.30568265914917
LOSS train 0.0 valid 2.3080015182495117
LOSS train 0.0 valid 2.3066580295562744
LOSS train 0.0 valid 2.304849147796631
LOSS train 0.0 valid 2.302363634109497
LOSS train 0.0 valid 2.3022775650024414
LOSS train 0.0 valid 2.3026559352874756
LOSS train 0.0 valid 2.306633472442627
LOSS train 0.0 valid 2.3082945346832275
LOSS train 0.0 valid 2.3060450553894043
LOSS train 0.0 valid 2.3073275089263916
LOSS train 0.0 valid 2.3059489727020264
LOSS train 0.0 valid 2.3067970275878906
LOSS train 0.0 valid 2.3073151111602783
LOSS train 0.0 valid 2.306293249130249
LOSS train 0.0 valid 2.3056674003601074
LOSS train 0.0 valid 2.306182861328125
LOSS train 0.0 valid 2.3055474758148193
LO