In [10]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import time

In [2]:
!unzip -q dataset.zip

In [3]:
dataset = datasets.ImageFolder(
    'dataset',
    transforms.Compose([
        transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
)

In [4]:
ptlo = 10 # percent to leave out
#train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset)*(1-ptlo/100), len(dataset)*ptlo/100]) 
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 80, 80])

In [5]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

In [6]:
model = models.alexnet(pretrained=True)

In [7]:
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, 2)

In [8]:
device = torch.device('cuda')
model = model.to(device)

In [11]:
NUM_EPOCHS = 30
BEST_MODEL_PATH = 'best_model.pth'
best_accuracy = 0.0

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(NUM_EPOCHS):
    init_time = time.time()
    for images, labels in iter(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()
    
    test_error_count = 0.0
    for images, labels in iter(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        test_error_count += float(torch.sum(torch.abs(labels - outputs.argmax(1))))
    
    test_accuracy = 1.0 - float(test_error_count) / float(len(test_dataset))
    print('%d: %f' % (epoch, test_accuracy))
    if test_accuracy > best_accuracy:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_accuracy = test_accuracy
    print('Fit Time: ', time.time()-init_time)

0: 0.875000
Fit Time:  168.70462822914124
1: 0.925000
Fit Time:  63.92973041534424
2: 0.837500
Fit Time:  30.449626445770264
3: 0.962500
Fit Time:  65.13450932502747
4: 0.925000
Fit Time:  30.40734553337097
5: 0.962500
Fit Time:  27.121387481689453
6: 0.950000
Fit Time:  27.07621479034424
7: 0.925000
Fit Time:  27.085599660873413
8: 0.950000
Fit Time:  27.0825834274292
9: 0.950000
Fit Time:  27.034441709518433
10: 0.962500
Fit Time:  27.09162712097168
11: 0.962500
Fit Time:  27.042799949645996
12: 0.975000
Fit Time:  66.8077118396759
13: 0.925000
Fit Time:  30.105292558670044
14: 0.975000
Fit Time:  27.35999870300293
15: 0.975000
Fit Time:  27.065980911254883
16: 0.975000
Fit Time:  27.089840412139893
17: 0.975000
Fit Time:  27.03125023841858
18: 0.962500
Fit Time:  27.092516899108887
19: 0.950000
Fit Time:  27.043670892715454
20: 0.950000
Fit Time:  27.13512372970581
21: 0.987500
Fit Time:  49.14266061782837
22: 0.962500
Fit Time:  30.508421659469604
23: 0.950000
Fit Time:  27.1362655