In [1]:
import torch
import torchvision

In [2]:
# --- Configuration: everything a reader might want to tune -------------------
n_classes = 4          # outputs of the new classifier head; should match the number of class folders
batch_size = 4         # images per batch for both training and validation
n_epochs = 5           # full passes over the training split
path_best_model = "best_model.pth"  # state_dict of the best-validation-accuracy epoch is saved here

# Seed torch's global RNG so the train/valid split, DataLoader shuffling and the
# initialisation of the replaced classifier layer are reproducible on Restart & Run All.
seed = 0
torch.manual_seed(seed)

In [4]:
path_dataset = ""        # root folder in ImageFolder layout: <path_dataset>/<class_name>/<image>
ratio_train_val = 2/3    # fraction of images used for training; the remainder is for validation

if not path_dataset:
    raise ValueError("Set `path_dataset` to the root folder of the ImageFolder dataset.")

# Mean/std used by torchvision's ImageNet-pretrained models.
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training images get random colour jitter as augmentation ...
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize,
])
# ... validation images do not, so validation accuracy is not perturbed by random augmentation.
valid_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    normalize,
])

# Two views of the same folder that differ only in their transform.
dataset = torchvision.datasets.ImageFolder(path_dataset, train_transform)
dataset_eval = torchvision.datasets.ImageFolder(path_dataset, valid_transform)
# Guarantee index i refers to the same image in both views before sharing indices.
assert dataset.samples == dataset_eval.samples
assert len(dataset.classes) == n_classes, (
    f"found {len(dataset.classes)} class folders but n_classes = {n_classes}"
)

n_train = int(ratio_train_val * len(dataset))
n_valid = len(dataset) - n_train

# One random permutation split into disjoint train / valid index sets.
permutation = torch.randperm(len(dataset)).tolist()
train_dataset = torch.utils.data.Subset(dataset, permutation[:n_train])
valid_dataset = torch.utils.data.Subset(dataset_eval, permutation[n_train:])

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Evaluation order does not affect accuracy, so there is no need to shuffle.
valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Load AlexNet with pretrained weights, then replace the layer at classifier[6]
# with a freshly initialised Linear layer producing n_classes logits.
# All other layers keep their pretrained weights.
model = torchvision.models.alexnet(pretrained=True)
old_head = model.classifier[6]
new_head = torch.nn.Linear(old_head.in_features, n_classes)
model.classifier[6] = new_head

In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


def train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean per-sample training loss.

    Updates ``model``'s parameters in place. Each batch is moved to the device
    holding the model's parameters, so this works for CPU and GPU models alike.
    Returns NaN if ``loader`` yields no samples.
    """
    device = next(model.parameters()).device
    model.train()  # enable training-mode behaviour (e.g. dropout)
    loss_sum, n_seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(images), labels)
        loss.backward()
        optimizer.step()
        # cross_entropy returns the batch mean; re-weight by batch size for a per-sample epoch mean.
        loss_sum += loss.item() * labels.size(0)
        n_seen += labels.size(0)
    return loss_sum / n_seen if n_seen else float("nan")


@torch.no_grad()
def evaluate_accuracy(model, loader):
    """Return the fraction of samples in ``loader`` whose top-scoring class equals the label.

    Switches ``model`` to eval mode (e.g. disables dropout) and runs without
    gradient tracking. Returns 0.0 if ``loader`` yields no samples.
    """
    device = next(model.parameters()).device
    model.eval()
    n_correct, n_seen = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predictions = model(images).argmax(dim=1)
        # Count exact matches. Summing |label - prediction| would weight a
        # mistake by the distance between class indices, which is wrong for > 2 classes.
        n_correct += int((predictions == labels).sum())
        n_seen += labels.size(0)
    return n_correct / n_seen if n_seen else 0.0


# Start below any achievable accuracy so the first epoch always writes a checkpoint.
best_accuracy = float("-inf")

for epoch in range(n_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    valid_accuracy = evaluate_accuracy(model, valid_loader)
    print('epoch %d/%d  train loss %.4f  valid accuracy %.4f'
          % (epoch + 1, n_epochs, train_loss, valid_accuracy))
    if valid_accuracy > best_accuracy:
        torch.save(model.state_dict(), path_best_model)
        best_accuracy = valid_accuracy