In [None]:
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from skimage import io, transform
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms, utils
import matplotlib.pyplot as plt
import time
import os
import copy
from torch.utils.data import Dataset, DataLoader
from skimage.color import rgba2rgb

plt.ion()   # turn on matplotlib interactive mode for the figures drawn in this notebook

In [None]:
# Root of the image dataset. Set the FRUIT_DATASET_ROOT environment variable
# to point elsewhere instead of editing this machine-specific default path.
dataset_root = os.environ.get("FRUIT_DATASET_ROOT", "/Users/shawon/Codes/Python/V6")
dataset_path = os.path.join(dataset_root, "train")

# Expected layout: <dataset_path>/<class label>/<image file>.
dataset_info = []  # (label, image_path) tuples

# Class labels are the visible sub-directories of dataset_path. Hidden entries
# such as '.DS_Store' are skipped whether or not they exist, and sorting makes
# the label order (and hence the class indices derived from it) reproducible
# across machines, since os.listdir returns entries in arbitrary order.
class_labels = sorted(
    entry for entry in os.listdir(dataset_path)
    if not entry.startswith(".")
    and os.path.isdir(os.path.join(dataset_path, entry))
)

for label in class_labels:
    class_dir = os.path.join(dataset_path, label)
    for img_name in sorted(os.listdir(class_dir)):
        # skip hidden files (e.g. '.DS_Store') that are not images
        if img_name.startswith("."):
            continue
        dataset_info.append((label, os.path.join(class_dir, img_name)))

In [None]:

# a function to show an image

def show_image(image_path):
    """Read the image file at `image_path` and draw it with plt.imshow."""
    pixels = io.imread(image_path)
    plt.imshow(pixels)

In [None]:

# create torch dataset
class FruitImageDataset(Dataset):
    """Map-style dataset over a list of (label, image_path) pairs.

    Each item is the dict ``{"label": <label as stored>, "image": <io.imread result>}``.
    When `transform` is truthy it is called on that dict and its result is returned.
    """

    def __init__(self, dataset_info, transform=None):
        self.dataset_info = dataset_info
        self.transform = transform

    def __len__(self):
        return len(self.dataset_info)

    def __getitem__(self, idx):
        # accept tensor indices by converting them to plain Python values
        if torch.is_tensor(idx):
            idx = idx.tolist()

        label = self.dataset_info[idx][0]
        path = self.dataset_info[idx][1]
        sample = {"label": label, "image": io.imread(path)}

        return self.transform(sample) if self.transform else sample

In [None]:
# transforms

class Rescale(object):
    """Resize the image in a sample and make sure it has 3 colour channels.

    Args:
        output_size (tuple or int): Desired (height, width) passed to
            ``skimage.transform.resize``. If int, a square output of that
            size is produced (same convention as RandomCrop).
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample["image"], sample["label"]

        # resize needs a shape tuple; fold the int case into a square.
        size = self.output_size
        if isinstance(size, int):
            size = (size, size)

        img = transform.resize(image, size)

        # Single-channel stored with an explicit channel axis: drop that axis
        # so it is handled by the greyscale branch below.
        if img.ndim == 3 and img.shape[2] == 1:
            img = img[:, :, 0]

        if img.ndim == 2:
            # Greyscale: replicate the channel so ToTensor's HWC->CHW transpose
            # and the 3-channel model input both work.
            img = np.stack((img, img, img), axis=-1)
        elif img.shape[2] == 4:
            # RGBA (e.g. PNG with alpha): drop alpha via rgba2rgb
            img = rgba2rgb(img)

        return {"label": label, "image": img}

In [None]:
class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.

    Raises:
        ValueError: (from ``__call__``) if the crop is larger than the image.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                "Crop size {} is larger than image size {}".format((new_h, new_w), (h, w)))

        # np.random.randint excludes its upper bound: +1 makes the last valid
        # offset reachable and allows a crop the same size as the image.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        return {'image': image, 'label': label}


In [None]:
# class labels to ints: map each label in `class_labels` to its position in
# that list; ToTensor looks labels up in this dict.
class_idx = {label: idx for idx, label in enumerate(class_labels)}

In [None]:
class ToTensor(object):
    """Convert ndarrays in sample to Tensors and labels to class indices.

    The image (H x W x C ndarray) becomes a C x H x W float32 tensor; the label
    is replaced by its integer index (a plain Python int, not a tensor).

    Args:
        class_to_idx (dict, optional): mapping from label to int index.
            When omitted, the notebook-level ``class_idx`` dict is used,
            looked up each time a sample is converted.
    """

    def __init__(self, class_to_idx=None):
        self.class_to_idx = class_to_idx

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        mapping = self.class_to_idx if self.class_to_idx is not None else class_idx

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1)).astype(np.float32)

        # Label stays an int; train_model reads batch["label"] as a tensor, so the
        # DataLoader's collation is expected to batch these ints into one.
        return {'image': torch.from_numpy(image),
                'label': mapping[label]}

In [None]:
# Resize to 256x256, take a random 224x224 crop, then convert to tensors.
dataset = FruitImageDataset(
    dataset_info=dataset_info,
    transform=transforms.Compose([
        Rescale((256, 256)),
        RandomCrop((224, 224)),
        ToTensor(),
    ]),
)

In [None]:
# Use the first CUDA device when CUDA is available; fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Shuffled mini-batches of two samples each.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=2)

In [None]:
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; kept as-is for compatibility with the installed version.
model_ft = models.resnet18(pretrained=True)

# Freeze every existing parameter so only the new head below is trained.
for param in model_ft.parameters():
    param.requires_grad = False

# Replace the final fully-connected layer with one that outputs one logit per
# class label. Newly created layers have requires_grad=True by default.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_labels))

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Only the parameters of the new final layer are handed to the optimizer.
optimizer_conv = optim.SGD(model_ft.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

In [None]:
def _train_one_epoch(model, criterion, optimizer, data_loader, target_device):
    """Run one optimisation pass over `data_loader`; return (mean loss, accuracy).

    Both values are averaged over the number of samples actually seen, and are
    0.0 when the loader yields no batches.
    """
    model.train()

    running_loss = 0.0
    running_corrects = 0
    n_samples = 0

    for batch in data_loader:
        inputs = batch["image"].to(target_device)
        # labels must be on the same device as the model outputs for the loss
        labels = batch["label"].to(target_device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize; grad is forced on in case the caller
        # has disabled it globally
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # statistics
        batch_size = inputs.size(0)
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        n_samples += batch_size

    if n_samples == 0:
        return 0.0, 0.0
    return running_loss / n_samples, running_corrects / n_samples


def train_model(model, criterion, optimizer, scheduler, num_epochs=25,
                data_loader=None, target_device=None):
    """Train `model` and return it with the weights of its best epoch loaded.

    There is no validation split here, so "best" is the epoch with the highest
    training accuracy (on ties, the later epoch wins).

    Args:
        model: the nn.Module to train.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the trainable parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs (int): number of passes over the data.
        data_loader: iterable of ``{"image": Tensor, "label": Tensor}`` batches.
            Defaults to the notebook-level ``dataloader``.
        target_device: device batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        `model`, with the best epoch's state dict loaded.
    """
    if data_loader is None:
        data_loader = dataloader
    if target_device is None:
        target_device = device

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        epoch_loss, epoch_acc = _train_one_epoch(
            model, criterion, optimizer, data_loader, target_device)
        scheduler.step()

        print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        print()

        # Snapshot the weights whenever this epoch matches or beats the best so
        # far; otherwise the final load_state_dict would restore the
        # untrained initial weights.
        if epoch_acc >= best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best train Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


In [None]:
# Train with the SGD optimizer built for model_ft.fc (`optimizer_conv`);
# `optimizer_ft` was never defined and raised NameError.
model_ft = train_model(model_ft, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25)