In [9]:
import torch
import torch_directml
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import os

In [10]:
data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# To be Improved

In [11]:
data_dir = 'data/animals10/raw-img'

image_datasets = datasets.ImageFolder(root=data_dir, transform=data_transforms)

In [12]:
#dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=4, shuffle=True, num_workers=8) 
dataset_sizes = len(image_datasets)
print(dataset_sizes)

class_names = image_datasets.classes
print(class_names)

26179
['cane', 'cavallo', 'elefante', 'farfalla', 'gallina', 'gatto', 'mucca', 'pecora', 'ragno', 'scoiattolo']


In [13]:
train_size = int(dataset_sizes*0.75)

train_set, val_set = torch.utils.data.random_split(image_datasets, (train_size, dataset_sizes-train_size))

In [14]:
train_set_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=8)
val_set_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=True, num_workers=8)

In [15]:
def set_device(dev: str):
    """
    Select the Device based on input
    Default:
        DML: DirectML
    """
    if torch_directml.device_count() > 0 and dev=="dml":
        return torch_directml.device(torch_directml.default_device())
    else:
        return torch.device(dev)

In [16]:
def evaluate_model_test_set(model, val_loader):
    """
    Evaluate Model on Test Set
    Args:
        model: Model to be evaluated
        val_loader: DataLoader for Val Set
    Returns:
        epoch_acc: Accuracy of Model on Test Set
    """
    model.eval()
    predicted_correctly = 0
    total = 0
    device = set_device("dml")
    
    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)
            
            outputs = model(images)
            _, preds = torch.max(outputs.data, 1)
            
            predicted_correctly += torch.sum(preds==labels).sum()
        
    epoch_acc = 100.00 * predicted_correctly/total
    print(" -- Validating Dataset -- Got %d out of %d images correctly. (%.3f%%)" % (predicted_correctly, total, epoch_acc))
    return epoch_acc

In [17]:
def save_checkpoint(model, epoch, optimizer, best_acc):
    """
    Function to Save Model Checkpoint
    Args:
        model: Model to be saved
        epoch: Epoch Number
        optimizer: Optimizer State
        best_acc: Best Accuracy
    """
    state = {
        'epoch': epoch + 1,
        'model': model,
        'best_accuracy': best_acc,
        'optimizer': optimizer.state_dict()
    }
    torch.save(state, 'best_checkpoint.pth.tar')

In [18]:
def train_nn(model, train_loader, val_loader, criterion, optimizer, n_epochs):
    """
    Model Training Function
    Args:
        model: Model to be trained
        train_loader: DataLoader for Training Set
        val_loader: DataLoader for Validation Set
        criterion: Loss Function
        optimizer: Optimizer
        n_epochs: Number of Epochs
    Returns:
        model: Trained Model
    """
    device = set_device("dml")
    best_acc = 0

    for epoch in range(n_epochs):
        print(f"Epoch {epoch}/{n_epochs}")
        model.train()
        running_loss = 0.0
        running_corrects = 0
        total = 0

        for data in train_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)

            optimizer.zero_grad()

            outputs = model(images)
            _, preds = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)

            loss.backward() # Backpropagation

            optimizer.step()

            running_loss += loss.item()
            running_corrects += torch.sum(labels==preds)

        epoch_loss = running_loss/len(train_loader)
        epoch_acc = 100.00 * running_corrects/total

        print(" -- Training Dataset -- Got %d out of %d images correctly. (%.3f%%). Epoch Loss: %.3f" % (running_corrects, total, epoch_acc, epoch_loss))

        test_data_acc = evaluate_model_test_set(model, val_loader)

        if test_data_acc > best_acc:
            best_acc = test_data_acc
            save_checkpoint(model, epoch, optimizer, best_acc)

    print("Finished")
    return model

In [19]:
model = models.resnet18(pretrained=True)

num_features = model.fc.in_features
num_classes = len(class_names)
model.fc = nn.Linear(num_features, num_classes)

device = set_device("dml")
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
n_epochs = 10



In [20]:
train_nn(model, train_set_loader, val_set_loader, loss_fn, optimizer, n_epochs)

Epoch 0/10


  device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)


In [10]:
chk = torch.load('best_checkpoint.pth.tar')

In [13]:
chk['model'].state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-4.1524e-03,  1.5612e-03, -5.3124e-03,  ...,  1.8387e-02,
                          2.9245e-05, -1.3898e-02],
                        [ 1.6930e-02,  2.7938e-02, -6.5564e-02,  ..., -1.9489e-01,
                         -8.3636e-02,  1.9225e-02],
                        [-1.0117e-03,  5.5736e-02,  2.3927e-01,  ...,  4.1408e-01,
                          2.1463e-01,  6.8278e-02],
                        ...,
                        [-4.0008e-02, -7.9718e-03,  2.1279e-02,  ..., -2.9247e-01,
                         -3.4861e-01, -2.2325e-01],
                        [ 1.2861e-02,  2.7938e-02,  4.9594e-02,  ...,  3.1317e-01,
                          2.9616e-01,  1.1854e-01],
                        [-1.1667e-02,  1.9270e-03, -1.1782e-02,  ..., -1.0650e-01,
                         -5.8413e-02, -7.7553e-03]],
              
                       [[-5.3923e-04, -7.8047e-03, -2.2752e-02,  ...,  8.0206e-03,
                         -8.3953

In [14]:
saved_model = models.resnet18()
num_features = saved_model.fc.in_features
num_classes = len(class_names)
saved_model.fc = nn.Linear(num_features, num_classes)
saved_model.load_state_dict(chk['model'].state_dict())

torch.save(saved_model, 'best_model.pth')


In [16]:
chk['epoch']

5

In [19]:
from PIL import Image

In [15]:
class_names

['cane',
 'cavallo',
 'elefante',
 'farfalla',
 'gallina',
 'gatto',
 'mucca',
 'pecora',
 'ragno',
 'scoiattolo']

In [16]:
model = torch.load('best_model.pth')

In [20]:
def classify(model, img_tranforms, img_path, classes):
    model = model.eval()

    image = Image.open(img_path)
    image = img_tranforms(image)
    image = image.unsqueeze(0)

    output = model(image)
    _, pred = torch.max(output.data, 1)

    print(pred)
    print(classes[pred])

In [25]:
classify(model, data_transforms, 'dog.jpg', class_names)

tensor([0])
cane
