In [1]:
import os

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import torch
from torchvision import datasets, models, transforms
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import torchvision

from PIL import Image
from tqdm import tqdm

In [2]:
# Seed PyTorch's global RNG so that data shuffling, the random augmentations and
# the new classifier head's initial weights repeat on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Data root: set this to the directory holding the train/ and validation/
# folders (each with one sub-folder per class).
path = 'data/'

In [14]:
# Peek at the width (PIL's size is (width, height)) of one raw training image
# before the transforms below resize everything to 224x224. Build the path from
# the configured data root instead of hard-coding it. The `with` block closes
# the file handle that Image.open keeps open.
with Image.open(os.path.join(path, 'train', 'alien', '2.jpg')) as sample_image:
    sample_width = sample_image.size[0]
sample_width

295

In [4]:
# Channel-wise ImageNet mean/std. The pretrained torchvision ResNet weights
# expect their inputs normalized with these statistics.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

BATCH_SIZE = 32  # shared by the train and validation loaders

data_transforms = {
    # Training: resize, then light geometric augmentation (no rotation, shear
    # within +/-10 degrees, zoom 0.8-1.2x) and a random horizontal flip.
    'train':
    transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]),
    # Validation: deterministic resize only, so metrics are comparable across epochs.
    'validation':
    transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ]),
}

# ImageFolder reads one sub-directory per class under <path>/train and
# <path>/validation. os.path.join works whether or not `path` ends in a slash.
image_datasets = {
    phase: datasets.ImageFolder(os.path.join(path, phase), data_transforms[phase])
    for phase in ('train', 'validation')
}

# Only the training loader shuffles; validation order stays fixed.
dataloaders = {
    phase: torch.utils.data.DataLoader(image_datasets[phase],
                                       batch_size=BATCH_SIZE,
                                       shuffle=(phase == 'train'))
    for phase in ('train', 'validation')
}

In [5]:
# Pick the fastest available backend: NVIDIA CUDA first, then Apple's MPS,
# and finally the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")

Using mps device


In [6]:
# ResNet-18 backbone with ImageNet weights. Since torchvision 0.13, passing
# anything other than a weights enum or None (e.g. `weights=True`) is deprecated.
# The legacy shim mapped `True` to these same IMAGENET1K_V1 weights.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze every pretrained parameter so the backbone acts as a fixed feature extractor.
for param in model.parameters():
    param.requires_grad = False

# Replace the classifier head. It is created after the freeze loop, so its
# parameters keep requires_grad=True and only the head is trained. Input width
# and class count come from the existing layer and the dataset instead of the
# hard-coded 512 and 2.
num_features = model.fc.in_features
num_classes = len(image_datasets['train'].classes)
model.fc = nn.Sequential(
    nn.Linear(num_features, 128),
    nn.ReLU(inplace=True),
    nn.Linear(128, num_classes),
)

# Move backbone and head together.
model = model.to(device)



In [7]:
# Classification loss over the head's raw logits (the head ends in a Linear
# layer; CrossEntropyLoss applies the log-softmax itself).
criterion = nn.CrossEntropyLoss()

# Adam with default hyperparameters, updating only the replacement head;
# the backbone's parameters were frozen when the model was built.
optimizer = optim.Adam(params=model.fc.parameters())

## Train

In [8]:
def train_model(model, criterion, optimizer, num_epochs=3, loaders=None):
    """Train ``model`` for ``num_epochs``, evaluating on the validation split each epoch.

    Each epoch runs a 'train' phase (forward, backward, optimizer step) and then a
    'validation' phase (forward only). After each phase it prints the mean loss and
    accuracy over that phase's dataset.

    Args:
        model: network to train; its parameters are updated in place.
        criterion: loss callable, invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer over the parameters that should be updated.
        num_epochs: number of train + validation passes.
        loaders: optional mapping with 'train' and 'validation' DataLoaders.
            Defaults to the notebook-level ``dataloaders``.

    Returns:
        The same ``model`` object, after training.
    """
    if loaders is None:
        loaders = dataloaders
    # Send batches to wherever the model's parameters live.
    device = next(model.parameters()).device

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            # NOTE(review): train mode also updates the running statistics of any
            # BatchNorm layers in a frozen backbone (ResNet has them); switch the
            # backbone to eval() if that is not intended.
            model.train(is_train)
            loader = loaders[phase]

            running_loss = 0.0
            running_corrects = 0  # plain int, so no device tensor is accumulated

            for inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Record the autograd graph only when training. Validation
                # forward passes don't need it, and building it wastes memory.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                _, preds = torch.max(outputs, 1)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            num_samples = len(loader.dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples

            print('{} loss: {:.4f}, acc: {:.4f}'.format(phase,
                                                        epoch_loss,
                                                        epoch_acc))
    return model

In [9]:
# Fine-tune for 30 epochs. train_model updates `model` in place and returns the
# same object, so `model_trained` and `model` name one network.
model_trained = train_model(
    model,
    criterion,
    optimizer,
    num_epochs=30,
)

Epoch 1/30
----------
train loss: 0.5272, acc: 0.7161
validation loss: 0.6043, acc: 0.6350
Epoch 2/30
----------
train loss: 0.3031, acc: 0.8905
validation loss: 0.2886, acc: 0.8850
Epoch 3/30
----------
train loss: 0.2623, acc: 0.8977
validation loss: 0.2583, acc: 0.9000
Epoch 4/30
----------
train loss: 0.2310, acc: 0.9121
validation loss: 0.2464, acc: 0.8950
Epoch 5/30
----------
train loss: 0.2136, acc: 0.9049
validation loss: 0.2449, acc: 0.9050
Epoch 6/30
----------
train loss: 0.1977, acc: 0.9135
validation loss: 0.2253, acc: 0.9100
Epoch 7/30
----------
train loss: 0.1721, acc: 0.9323
validation loss: 0.2134, acc: 0.9200
Epoch 8/30
----------
train loss: 0.1796, acc: 0.9236
validation loss: 0.2177, acc: 0.9050
Epoch 9/30
----------
train loss: 0.1588, acc: 0.9308
validation loss: 0.2033, acc: 0.8950
Epoch 10/30
----------
train loss: 0.1806, acc: 0.9308
validation loss: 0.2227, acc: 0.9000
Epoch 11/30
----------
train loss: 0.1530, acc: 0.9395
validation loss: 0.2190, acc: 0.89