In [1]:
import time

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from torch.optim import lr_scheduler
from torchvision import transforms, models
from torch.utils.data import DataLoader
from src.products_dataset import ProductsDataset
from src.transforms import Resize, CenterCrop, Normalize, ToTensor
from src.multitask_model import MultitaskModel

import os
# CUDA-related environment configuration for this process, applied in one call.
# NOTE(review): presumably pins the run to the first GPU in PCI bus order -- confirm
# against the CUDA environment-variable documentation.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [2]:
# Preprocessing pipeline handed to the dataset: resize, centre-crop, tensor
# conversion, then normalisation with the given per-channel mean/std.
preprocessing = transforms.Compose([
    Resize(224),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Product samples described by the spreadsheet, with images under root_dir.
dataset = ProductsDataset(
    xlsx_filepath='./data/products.xlsx',
    root_dir='./data/images',
    transform=preprocessing,
)

# Shuffled mini-batches of 32 samples, loaded by 8 worker processes.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [3]:
# Instantiate the multitask network (imported from src.multitask_model) using
# its default constructor arguments.
model = MultitaskModel()

In [4]:
# Cross-entropy loss object that is later handed to train_model.
criterion = nn.CrossEntropyLoss()

# Place the model on the selected device, then build the SGD optimizer over
# its parameters.
model = model.to(device)
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

In [5]:
# Total number of samples reported by len(dataset).
dataset_size = len(dataset)

def train_model(model, criterion, optimizer, num_epochs=10,
                data_loader=None, target_device=None):
    """Train a two-output model and print loss/accuracy after every epoch.

    Each batch is a mapping with the keys 'image', 'category' and 'condition'.
    The model is called on the images and must return a pair
    ``(category_outputs, condition_outputs)``; ``criterion`` is applied to each
    output against its labels and the two losses are summed before the
    backward pass.

    Args:
        model: Network to train; updated in place.
        criterion: Loss callable taking ``(outputs, targets)``.
        optimizer: Optimizer that steps the model's parameters.
        num_epochs: Number of passes over the loader.
        data_loader: Iterable of batches. When ``None`` the notebook-level
            ``dataloader`` is used, which keeps existing calls working.
        target_device: Device the batch tensors are moved to. When ``None``
            the notebook-level ``device`` is used.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Fall back to the notebook globals only when no explicit override is given.
    loader = dataloader if data_loader is None else data_loader
    run_device = device if target_device is None else target_device

    tic = time.time()

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Set model to training mode
        model.train()

        running_loss = 0.0
        cat_correct = 0
        cond_correct = 0
        # Count what was actually iterated instead of trusting a global size, so
        # the averages stay right for subsets, samplers or drop_last loaders.
        samples_seen = 0

        for batch in loader:
            inputs = batch['image'].to(run_device)
            gt_categories = batch['category'].to(run_device)
            gt_conditions = batch['condition'].to(run_device)
            batch_size = inputs.size(0)

            optimizer.zero_grad()

            # Kept so gradients are tracked even if the caller disabled them.
            with torch.set_grad_enabled(True):
                out_categories, out_conditions = model(inputs)

                loss_category = criterion(out_categories, gt_categories)
                loss_condition = criterion(out_conditions, gt_conditions)
                loss = loss_category + loss_condition

                loss.backward()
                optimizer.step()

            # Predicted class = index of the largest output along dim 1.
            _, cat_predictions = torch.max(out_categories.detach(), 1)
            _, cond_predictions = torch.max(out_conditions.detach(), 1)

            # statistics (plain Python numbers, so an empty loader cannot
            # leave an int where a tensor method is later expected)
            # assumes criterion returns a per-batch mean -- TODO confirm
            running_loss += loss.item() * batch_size
            cat_correct += (cat_predictions == gt_categories).sum().item()
            cond_correct += (cond_predictions == gt_conditions).sum().item()
            samples_seen += batch_size

        if samples_seen:
            epoch_loss = running_loss / samples_seen
            epoch_cat_accuracy = cat_correct / samples_seen
            epoch_cond_accuracy = cond_correct / samples_seen
        else:
            # Nothing was iterated: report NaN rather than dividing by zero.
            epoch_loss = epoch_cat_accuracy = epoch_cond_accuracy = float('nan')

        print('Loss: {:.4f} Categories acc: {:.4f}, conditions acc: {:.4f}'.format(
            epoch_loss, epoch_cat_accuracy, epoch_cond_accuracy))

    time_elapsed = time.time() - tic
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return model

In [6]:
# Train with the default num_epochs=10. train_model returns the model, and since
# this call is the cell's last expression, the model's repr is displayed after
# the training log.
train_model(model, criterion, optimizer)

Epoch 0/9
----------
Loss: 2.7703 Categories acc: 0.5140, conditions acc: 0.4436
Epoch 1/9
----------
Loss: 2.1830 Categories acc: 0.6533, conditions acc: 0.5271
Epoch 2/9
----------
Loss: 1.8472 Categories acc: 0.7216, conditions acc: 0.5875
Epoch 3/9
----------
Loss: 1.5005 Categories acc: 0.7919, conditions acc: 0.6715
Epoch 4/9
----------
Loss: 1.1045 Categories acc: 0.8559, conditions acc: 0.7689
Epoch 5/9
----------
Loss: 0.6948 Categories acc: 0.9153, conditions acc: 0.8825
Epoch 6/9
----------
Loss: 0.4165 Categories acc: 0.9551, conditions acc: 0.9419
Epoch 7/9
----------
Loss: 0.2669 Categories acc: 0.9749, conditions acc: 0.9662
Epoch 8/9
----------
Loss: 0.1716 Categories acc: 0.9862, conditions acc: 0.9826
Epoch 9/9
----------
Loss: 0.1207 Categories acc: 0.9912, conditions acc: 0.9890
Training complete in 4m 5s


MultitaskModel(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stat