In [1]:
import os
import time
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn
from torch import optim
from torchvision import datasets, transforms, models

In [None]:
# Training hyperparameters
tr_batchsize = 32          # batch size of the training loader
val_test_batchsize = 16    # batch size of the validation and test loaders
epochs = 60
lr = 0.00005               # optimizer learning rate

# Seed torch's RNG (CPU and every CUDA device) so model initialisation, the
# random training augmentations and the training-set shuffle repeat across runs.
# Bit-exact GPU determinism additionally needs deterministic cuDNN settings.
seed = 42
torch.manual_seed(seed);

In [None]:
# Pick the compute device: the first CUDA GPU when one is available, else the CPU.
# To use a different GPU on a multi-GPU machine, change the index in 'cuda:0'.
if torch.cuda.is_available():
    print(f'Found {torch.cuda.device_count()} GPUs.')
    deviceFlag = torch.device('cuda:0')
else:
    deviceFlag = torch.device('cpu')

print(f'Now the device is set to {deviceFlag}')

# Data Loading and Transformations

In [None]:
# Per-channel (RGB) mean and standard deviation estimated on ImageNet.
# NOTE(review): presumably chosen because FlowerModel wraps an ImageNet-pretrained
# backbone — confirm in model_flower.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: random augmentation, so each epoch sees a slightly
# different version of every flower image.
training_transforms = transforms.Compose([
    # Rotate by an angle drawn uniformly from [-90, 90] degrees
    transforms.RandomRotation(90),
    # With probability 0.5, sharpen the image by a factor of 1.5
    transforms.RandomAdjustSharpness(sharpness_factor=1.5, p=0.5),
    # Crop a region of random scale and aspect ratio, then resize it to 224x224
    transforms.RandomResizedCrop(224),
    # Flip horizontally with probability 0.5
    transforms.RandomHorizontalFlip(),
    # Flip vertically with probability 0.5
    transforms.RandomVerticalFlip(),
    # PIL image -> float tensor of shape (C, H, W) with values in [0, 1]
    transforms.ToTensor(),
    # Standardize each channel with the ImageNet statistics
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Deterministic evaluation pipeline shared by validation and test: resize the
# shorter side to 256 pixels, then take the central 224x224 crop.
evaluation_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Separate names are kept for the dataset code; both refer to the same pipeline,
# so validation and test preprocessing cannot drift apart.
validation_transforms = evaluation_transforms
testing_transforms = evaluation_transforms

# Oxford Flowers102 comes with official 'train' / 'val' / 'test' splits.
# With download=True, torchvision fetches any split missing from ./dataset.
train_dataset = datasets.Flowers102(
    root='./dataset',
    split='train',
    download=True,
    transform=training_transforms,
)
valid_dataset = datasets.Flowers102(
    root='./dataset',
    split='val',
    download=True,
    transform=validation_transforms,
)
test_dataset = datasets.Flowers102(
    root='./dataset',
    split='test',
    download=True,
    transform=testing_transforms,
)

# Only the training loader shuffles; validation and test batches follow the
# dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=tr_batchsize, shuffle=True)
validate_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=val_test_batchsize)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=val_test_batchsize)

In [None]:
import model_flower

In [None]:
# Build the classifier and move it onto the selected device in one step;
# nn.Module.to() returns the module itself, so the call can be chained.
model = model_flower.FlowerModel().to(deviceFlag)
model

In [None]:
# Optional (disabled): freeze parameters so the optimizer does not update them.
# NOTE(review): the attribute is spelled `requires_grad`. A misspelled attribute
# name is silently accepted on a tensor and freezes nothing. Freezing *every*
# parameter as written would leave nothing to train. Presumably the
# classification head should be re-enabled afterwards — confirm against
# model_flower.
# for params in model.parameters():
#     params.requires_grad = False

# Define Loss Function and Optimizer

In [None]:
# nn.CrossEntropyLoss combines log-softmax with negative log-likelihood, so it
# takes raw logits. Alternative (commented out): nn.NLLLoss(), which expects
# log-probabilities instead.
# criterion = nn.NLLLoss()
criterion = nn.CrossEntropyLoss()

# Adam over all model parameters. Alternative (commented out):
# optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay = 0.005, momentum = 0.9)
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
import model_train
import model_test

In [None]:
# Train the model. NOTE(review): judging by the flag names, validation runs every
# 100 steps and once at the end, but not after each epoch — confirm in model_train.
model_train.train_classifier(
    model,
    train_loader,
    validate_loader,
    optimizer,
    criterion,
    device_flag=deviceFlag,
    epochs=epochs,
    validate_steps=100,
    validate_stepped=True,
    validate_epoch=False,
    validate_end=True,
)

In [None]:
# Measure accuracy on the held-out Flowers102 test split
model_test.test_accuracy(
    model,
    test_loader,
    device_flag=deviceFlag,
)

In [None]:
import datetime

In [None]:
# Save the trained weights as a timestamped checkpoint tagged with the main
# hyperparameters. The "models" directory is created first, so a missing folder
# cannot make the save fail after a long training run.
os.makedirs("models", exist_ok=True)
# ':' is not allowed in Windows filenames, so it is replaced in the timestamp.
checkpoint_path = ("models/" + str(datetime.datetime.now()).replace(":", "-")
                   + f" b{tr_batchsize}-e{epochs}-lr{lr}-model.pt")
torch.save(model.state_dict(), checkpoint_path)
# Display where the checkpoint was written
checkpoint_path

In [None]:
# Stop "Run All" here. The cells below are development utilities: module
# reloading and commented-out legacy loops. Unlike `assert False`, an explicit
# raise is not stripped when Python runs with -O, and its message says why
# execution stopped.
raise RuntimeError('Stopped intentionally: the cells below are for interactive development only.')

In [None]:
# Re-import the project modules after editing their source files, without
# restarting the kernel.
# NOTE: reload() re-executes a module and rebinds its names, but objects that
# were already created from it are not updated. `model` was built from the
# previous FlowerModel class, and `optimizer` holds that model's parameters.
# Re-run the model and optimizer cells to pick up model_flower changes.
from importlib import reload

# `import` is a no-op for modules already in sys.modules; these lines only
# matter on a fresh kernel where the modules have not been imported yet.
import model_flower
import model_train
import model_test

# If these modules import one another, reload order matters: `from ... import`
# bindings in a module reloaded later see the new definitions of one reloaded
# earlier, not vice versa. TODO confirm whether model_train/model_test import
# model_flower. The last reload's return value (the module) is the cell output.
reload(model_flower)
reload(model_train)
reload(model_test)

In [None]:
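# Legacy inline training loop, kept for reference only; it appears to be superseded by model_train.train_classifier.
# NOTE: it never calls model.train()/model.eval(), so any dropout/batch-norm layers would stay in training mode during validation.
# total_step = len(train_loader)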

# for epoch in range(epochs):
#     for i, (images, labels) in enumerate(train_loader):  
#         # Move tensors to the configured device
#         images = images.to(deviceFlag)
#         labels = labels.to(deviceFlag)
        
#         # Forward pass
#         outputs = model(images)
#         loss = criterion(outputs, labels)
        
#         # Backward and optimize
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#     print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
#                    .format(epoch+1, epochs, i+1, total_step, loss.item()))
            
#     # Validation
#     with torch.no_grad():
#         correct = 0
#         total = 0
#         for images, labels in validate_loader:
#             images = images.to(deviceFlag)
#             labels = labels.to(deviceFlag)
#             outputs = model(images)
#             _, predicted = torch.max(outputs.data, 1)
#             total += labels.size(0)
#             correct += (predicted == labels).sum().item()
#             del images, labels, outputs
    
#         print('Accuracy of the network on the {} validation images: {} %'.format(total, 100 * correct / total)) 


In [None]:
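# Legacy inline test-accuracy loop, kept for reference only; it appears to be superseded by model_test.test_accuracy.
# NOTE: it never calls model.eval(), so any dropout/batch-norm layers would stay in training mode during testing.
# with torch.no_grad():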
#     correct = 0
#     total = 0
#     for images, labels in test_loader:
#         images = images.to(deviceFlag)
#         labels = labels.to(deviceFlag)
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
#         del images, labels, outputs

#     print('Accuracy of the network on the {} test images: {} %'.format(total, 100 * correct / total))   