In [1]:
import copy

import torch
import torch.nn as nn # Loss Functions and Neural Network Components
import torch.optim as optim # Optimizers
from torch.utils.data import DataLoader, Subset, random_split # Utilities for data loaders and batching
from torchvision import datasets, models, transforms # Image specific tasks

In [7]:
# Training-time preprocessing pipeline:
#   1. resize every image to a fixed 224x224 (aspect ratio is not preserved),
#   2. randomly mirror it left/right as light augmentation,
#   3. convert to a float tensor,
#   4. normalize each RGB channel with the standard ImageNet mean/std, matching
#      the ImageNet-pretrained ResNet loaded later.
# Note: because of the random flip, this pipeline is intended for training data only.
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [8]:
data_dir = './data'
SPLIT_SEED = 42  # fixed seed so the train/val split is identical on every run

# The training pipeline contains a random horizontal flip. Validation images must be
# preprocessed deterministically, so derive an eval pipeline with the same steps
# minus the random flip (update this filter if more random transforms are added).
eval_transforms = transforms.Compose(
    [t for t in data_transforms.transforms
     if not isinstance(t, transforms.RandomHorizontalFlip)]
)

# Two views over the same folder: augmented for training, deterministic for validation.
# A single dataset passed to random_split would force the training transform onto the
# validation subset as well. ImageFolder sorts class folders and file names, so index i
# refers to the same image in both views.
full_dataset = datasets.ImageFolder(data_dir, transform=data_transforms)
eval_dataset = datasets.ImageFolder(data_dir, transform=eval_transforms)

train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

# Shuffle the indices once with a seeded generator, then carve out disjoint train/val sets.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_indices = torch.randperm(len(full_dataset), generator=split_generator).tolist()
train_dataset = Subset(full_dataset, shuffled_indices[:train_size])
val_dataset = Subset(eval_dataset, shuffled_indices[train_size:])

assert len(val_dataset) == val_size
assert len(train_dataset) + len(val_dataset) == len(full_dataset)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Class names are the sub-folder names of data_dir, in ImageFolder's label order.
class_names = full_dataset.classes

print(f'Training set size: {len(train_dataset)} images')
print(f'Validation set size: {len(val_dataset)} images')

Training set size: 1086 images
Validation set size: 272 images


In [10]:
# Transfer learning: start from an ImageNet-pretrained ResNet-34 and swap its final
# fully-connected layer for a new one sized to our classes. The weights enum is reached
# through the `models` module imported in the first cell, so no mid-notebook import is needed.
weights = models.ResNet34_Weights.IMAGENET1K_V1
model = models.resnet34(weights=weights)
num_ftrs = model.fc.in_features  # width of the feature vector feeding the original head

# New, randomly initialised head: one output logit per class folder found in data_dir.
model.fc = nn.Linear(num_ftrs, len(class_names))

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [11]:
# Multi-class classification loss over the raw logits produced by the new fc head.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over *all* parameters: the whole network is fine-tuned,
# not only the replaced head.
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

# Multiply the learning rate by 0.1 every 7 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [12]:
# Phase name -> DataLoader; train_model iterates the 'train' and 'val' phases by these keys.
dataloaders = dict(train=train_loader, val=val_loader)

def _run_phase(model, loader, criterion, optimizer, device, train):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: the network being trained/evaluated.
        loader: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: loss function applied to ``(outputs, labels)``.
        optimizer: stepped once per batch when ``train`` is True; unused otherwise.
        device: device the batches are moved to.
        train: True for a training pass (train mode, gradients, optimizer steps),
            False for an evaluation pass (eval mode, no gradients).

    Returns:
        Tuple of Python floats: loss averaged per sample, and fraction of correct predictions.
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            loss = criterion(outputs, labels)
            if train:
                loss.backward()
                optimizer.step()

        # Assumes criterion averages over the batch (CrossEntropyLoss default), so
        # re-weight by batch size to obtain a per-sample mean over the whole epoch.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    n_samples = len(loader.dataset)
    return running_loss / n_samples, running_corrects / n_samples


def train_model(model, criterion, optimizer, scheduler, num_epochs=25, loaders=None, device=None):
    """Train ``model`` and return it loaded with its best-validation-accuracy weights.

    Each epoch runs a training pass, steps ``scheduler`` once, then runs a validation
    pass. Whenever validation accuracy strictly improves, a snapshot of the weights is
    kept. That snapshot is loaded back into ``model`` before returning.

    Args:
        model: network to train (updated in place and also returned).
        criterion: loss function.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch after training.
        num_epochs: number of train+val epochs to run.
        loaders: dict with ``'train'`` and ``'val'`` DataLoaders. Defaults to the
            notebook-level ``dataloaders`` dict.
        device: device to move batches to. Defaults to the device of the model's
            first parameter.

    Returns:
        ``model`` with the weights from the epoch with the highest validation accuracy.
    """
    if loaders is None:
        loaders = dataloaders
    if device is None:
        device = next(model.parameters()).device

    # deepcopy is essential: state_dict() returns references to the live parameter
    # tensors, so an un-copied "snapshot" would silently track the latest weights and
    # the final load would restore the last epoch instead of the best one.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            epoch_loss, epoch_acc = _run_phase(
                model, loaders[phase], criterion, optimizer, device, is_train
            )
            if is_train:
                scheduler.step()  # once per epoch, after this epoch's optimizer steps

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print('Training complete')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model

# Fine-tune for 25 epochs; train_model prints per-epoch train/val loss and accuracy and returns the model.
model = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=25,
)

Epoch 0/24
----------
train Loss: 1.7604 Acc: 0.3720
val Loss: 1.1997 Acc: 0.6250
Epoch 1/24
----------
train Loss: 0.9451 Acc: 0.7468
val Loss: 0.8605 Acc: 0.7463
Epoch 2/24
----------
train Loss: 0.6052 Acc: 0.8435
val Loss: 0.7075 Acc: 0.7868
Epoch 3/24
----------
train Loss: 0.3767 Acc: 0.9254
val Loss: 0.6233 Acc: 0.8088
Epoch 4/24
----------
train Loss: 0.2562 Acc: 0.9448
val Loss: 0.5851 Acc: 0.8235
Epoch 5/24
----------
train Loss: 0.1610 Acc: 0.9797
val Loss: 0.5653 Acc: 0.8088
Epoch 6/24
----------
train Loss: 0.1192 Acc: 0.9816
val Loss: 0.5500 Acc: 0.8088
Epoch 7/24
----------
train Loss: 0.0871 Acc: 0.9890
val Loss: 0.5387 Acc: 0.8162
Epoch 8/24
----------
train Loss: 0.0756 Acc: 0.9963
val Loss: 0.5349 Acc: 0.8235
Epoch 9/24
----------
train Loss: 0.0788 Acc: 0.9954
val Loss: 0.5334 Acc: 0.8309
Epoch 10/24
----------
train Loss: 0.0825 Acc: 0.9917
val Loss: 0.5284 Acc: 0.8235
Epoch 11/24
----------
train Loss: 0.0829 Acc: 0.9908
val Loss: 0.5258 Acc: 0.8235
Epoch 12/24
--

In [13]:
# Persist only the learned parameters (the state_dict), not the module object.
# Reloading requires rebuilding the same ResNet-34 with a matching fc head, then load_state_dict.
torch.save(obj=model.state_dict(), f='./model.pth')