In [2]:
import torch
import torch_directml
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import os

In [3]:
data_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# To be Improved

In [4]:
data_dir = 'data/animals10/raw-img'

image_datasets = datasets.ImageFolder(root=data_dir, transform=data_transforms)

In [8]:
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=4, shuffle=True, num_workers=8) 
dataset_sizes = len(image_datasets)
print(dataset_sizes)

class_names = image_datasets.classes
print(class_names)

26179
['cane', 'cavallo', 'elefante', 'farfalla', 'gallina', 'gatto', 'mucca', 'pecora', 'ragno', 'scoiattolo']


In [9]:
train_size = int(dataset_sizes*0.75)

train_set, val_set = torch.utils.data.random_split(image_datasets, (train_size, dataset_sizes-train_size))

In [37]:
train_set_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=8)
val_set_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=True, num_workers=8)

In [38]:
def set_device(dev: str):
    """
    Select the Device based on input
    Default:
        DML: DirectML
    """
    if torch_directml.device_count() > 0 and dev=="dml":
        return torch_directml.device(torch_directml.default_device())
    else:
        torch.device(dev)

In [39]:
def evaluate_model_test_set(model, val_loader):
    """
    Evaluate Model on Test Set
    Args:
        model: Model to be evaluated
        val_loader: DataLoader for Val Set
    """
    model.eval()
    predicted_correctly = 0
    total = 0
    device = set_device("dml")
    
    with torch.no_grad():
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)
            
            outputs = model(images)
            _, preds = torch.max(outputs.data, 1)
            
            predicted_correctly += torch.sum(preds==labels).sum()
        
    epoch_acc = 100.00 * predicted_correctly/total
    print(" -- Validating Dataset -- Got %d out of %d images correctly. (%.3f%%)" % (predicted_correctly, total, epoch_acc))

In [40]:
def train_nn(model, train_loader, val_loader, criterion, optimizer, n_epochs):
    """
    Model Training Function
    Args:
        model: Model to be trained
        train_loader: DataLoader for Training Set
        val_loader: DataLoader for Validation Set
        criterion: Loss Function
        optimizer: Optimizer
        n_epochs: Number of Epochs
    Returns:
        model: Trained Model
    """
    device = set_device("dml")

    for epoch in range(n_epochs):
        print(f"Epoch {epoch}/{n_epochs}")
        model.train()
        running_loss = 0.0
        running_corrects = 0
        total = 0

        for data in train_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)

            optimizer.zero_grad()

            outputs = model(images)
            _, preds = torch.max(outputs.data, 1)
            loss = criterion(outputs, labels)

            loss.backward() # Backpropagation

            optimizer.step()

            running_loss += loss.item()
            running_corrects += torch.sum(labels==preds)

        epoch_loss = running_loss/len(train_loader)
        epoch_acc = 100.00 * running_corrects/total

        print(" -- Training Dataset -- Got %d out of %d images correctly. (%.3f%%). Epoch Loss: %.3f" % (running_corrects, total, epoch_acc, epoch_loss))

        evaluate_model_test_set(model, val_loader)
    print("Finished")
    return model

In [41]:
model = models.resnet18(pretrained=True)

num_features = model.fc.in_features
num_classes = len(class_names)
model.fc = nn.Linear(num_features, num_classes)

device = set_device("dml")
model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
n_epochs = 10

In [42]:
train_nn(model, train_set_loader, val_set_loader, loss_fn, optimizer, n_epochs)

Epoch 0/10
 -- Training Dataset -- Got 14638 out of 19634 images correctly. (74.554%). Epoch Loss: 0.800
Epoch 1/10
 -- Training Dataset -- Got 16169 out of 19634 images correctly. (82.352%). Epoch Loss: 0.541
Epoch 2/10
 -- Training Dataset -- Got 16530 out of 19634 images correctly. (84.191%). Epoch Loss: 0.477
Epoch 3/10
 -- Training Dataset -- Got 16608 out of 19634 images correctly. (84.588%). Epoch Loss: 0.473
Epoch 4/10
 -- Training Dataset -- Got 16790 out of 19634 images correctly. (85.515%). Epoch Loss: 0.448
Epoch 5/10
 -- Training Dataset -- Got 16733 out of 19634 images correctly. (85.225%). Epoch Loss: 0.440
Epoch 6/10
 -- Training Dataset -- Got 16746 out of 19634 images correctly. (85.291%). Epoch Loss: 0.444
Epoch 7/10
 -- Training Dataset -- Got 16812 out of 19634 images correctly. (85.627%). Epoch Loss: 0.438
Epoch 8/10
 -- Training Dataset -- Got 16786 out of 19634 images correctly. (85.495%). Epoch Loss: 0.443
Epoch 9/10
 -- Training Dataset -- Got 16833 out of 196

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  