In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (such as the mylib helpers below) are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import os

import mylib.data_loaders as data_loaders
import mylib.data_transformers as data_transformers
import mylib.models_repo as models_repo
import mylib.optimizer_repo as optimizer_repo
import mylib.scheduler_repo as scheduler_repo
import mylib.trainer as trainer

In [None]:
from torchvision.datasets.folder import ImageFolder

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Dataset root. Set the DOGSCATS_DATA_DIR environment variable to point the notebook at
# another location; the original hard-coded path is kept as the fallback so existing
# setups keep working unchanged.
data_path = os.environ.get("DOGSCATS_DATA_DIR", "/home/as/datasets/fastai.dogscats")
num_classes = 2       # Cats & Dogs
img_size  = 224       # H and W are expected to be at least 224 for PyTorch model zoo models
scale_img_size = 300  # During data augmentation, we first scale the image to this value,
                      # then we take a Random Crop of size (img_size x img_size) from within that image
batch_size = 256      # Set as per your GPU RAM

# The crop of size img_size is taken from inside the scaled image, so the scaled
# image must be at least as large as the crop.
assert scale_img_size >= img_size, "scale_img_size must be >= img_size"

Let's get the transformers

In [None]:
# Normaliser exposed by data_transformers; shared by every pipeline built below.
norm = data_transformers.pytorch_zoo_normaliser

# Positional arguments common to all three transformer factories.
_sizing_args = (img_size, scale_img_size, norm)

# Training pipelines. The trailing positional flag is presumably the augmentation
# switch (off for `trans`, on for `trans_aug`) -- confirm against get_transformer.
trans = data_transformers.get_transformer(*_sizing_args, False)
trans_aug = data_transformers.get_transformer(*_sizing_args, True)

# Pipeline used for the validation / test images.
trans_valid = data_transformers.get_test_valid_transformer(*_sizing_args)

Let's create the datasets with the given transformers. Note that ImageFolder() is a utility class in torchvision which can read images which are segregated into class folders.

In [None]:
# Both training datasets read the same folder; only the transform differs.
train_dir = f'{data_path}/train'
valid_dir = f'{data_path}/valid'

# Training images with the plain transform
train_images = ImageFolder(train_dir, transform=trans)

# Training images with the augmenting transform
train_images_aug = ImageFolder(train_dir, transform=trans_aug)

# Validation images with the valid/test transform
valid_images = ImageFolder(valid_dir, transform=trans_valid)

In [None]:
# Report how many samples each split contains.
for split_name, split_images in (('train', train_images), ('valid', valid_images)):
    print(f'Number of {split_name} instances', len(split_images))

In [None]:
# Show the class names and the class-name -> index mapping exposed by the training dataset.
class_names = train_images.classes
class_index = train_images.class_to_idx
print('Classes', class_names)
print('Class index', class_index)

Let's create the loaders. We will iterate over these during training; they will give us our batches.

In [None]:
# One loader per dataset, all built with the same batch size
# (order: plain train, augmented train, validation).
train_loader, train_loader_aug, valid_loader = (
    data_loaders.get_data_loader(dataset, batch_size)
    for dataset in (train_images, train_images_aug, valid_images)
)

### Let's try with a vanilla pretrained ResNet, with no augmentation

We just replace the last FC layer to account for the num_classes, that's all

In [None]:
# Experiment toggle: leave False to skip this run, set to True to execute it.
run_vanilla_no_aug = False

if run_vanilla_no_aug:
    # ResNet from models_repo, built for num_classes outputs
    model = models_repo.model_resnet_vanilla(num_classes)

    # Loss function and SGD optimizer; only the parameters of model.fc are handed over
    criteria, optimizer = optimizer_repo.sgd(
        model, 0.01, model.fc.parameters(), momentum=0.9, weight_decay=1e-4
    )

    # Step learning-rate scheduler
    scheduler = scheduler_repo.step_lr(optimizer, step_size=5, gamma=0.1)

    # List the type of every direct child module of the model
    print('Layers in the model')
    for layer in model.children():
        print(type(layer))
    print('Training starts')

    # Train on the non-augmented loader, validate on the validation loader.
    # The trailing 5 is presumably the number of epochs -- confirm against trainer.train.
    best_model = trainer.train(
        model, criteria, optimizer, scheduler, train_loader, valid_loader, 5
    )

### Let's try with a vanilla pretrained ResNet, with data augmentation

In [None]:
# Experiment toggle: leave False to skip this run, set to True to execute it.
run_vanilla_with_aug = False

if run_vanilla_with_aug:
    # ResNet from models_repo, built for num_classes outputs
    model = models_repo.model_resnet_vanilla(num_classes)

    # Loss function and SGD optimizer; only the parameters of model.fc are handed over
    criteria, optimizer = optimizer_repo.sgd(
        model, 0.01, model.fc.parameters(), momentum=0.9, weight_decay=1e-4
    )

    # Step learning-rate scheduler
    scheduler = scheduler_repo.step_lr(optimizer, step_size=5, gamma=0.1)

    # List the type of every direct child module of the model
    print('Layers in the model')
    for layer in model.children():
        print(type(layer))
    print('Training starts')

    # Train on the augmented loader, validate on the validation loader.
    # The trailing 5 is presumably the number of epochs -- confirm against trainer.train.
    best_model = trainer.train(
        model, criteria, optimizer, scheduler, train_loader_aug, valid_loader, 5
    )

### Let's try with a pretrained ResNet with extra FC layers, with data augmentation

In [None]:
# Optional: uncomment to build the model with debug=True instead (kept disabled for normal runs).
# model, params_to_optimize = models_repo.resnet34_extra_layers(num_classes, top_layers_to_freeze=6, debug=True)

In [None]:
# Build the ResNet-34 variant from models_repo; it also returns the parameters
# that should be passed to the optimizer.
model, params_to_optimize = models_repo.resnet34_extra_layers(
    num_classes,
    top_layers_to_freeze=6,
)

# Loss function and SGD optimizer restricted to the returned parameters
criteria, optimizer = optimizer_repo.sgd(
    model,
    0.01,
    params_to_optimize=params_to_optimize,
)

# Step LR scheduler with the helper's default settings
scheduler = scheduler_repo.step_lr(optimizer)

# Train on the augmented loader, validate on the validation loader.
# The trailing 20 is presumably the number of epochs -- confirm against trainer.train.
best_model = trainer.train(
    model, criteria, optimizer, scheduler,
    train_loader_aug, valid_loader,
    20,
)