In [None]:
from model import CNN
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

# --- Configuration -----------------------------------------------------------
CHANNELS_D = 3  # RGB input channels (NOTE(review): not read by any later cell)

# Seed torch's global RNG so weight init, DataLoader shuffling, random_split
# (when no generator is passed) and the random eval sample are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

img_size = 256  # images are resized to img_size x img_size
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "model.pth"  # NOTE(review): train() currently hard-codes "model.pth"
data_dir = "~/Documents/datasets/archive/caltech101_classification/"

# Display names indexed by dataset label.
# NOTE(review): assumes this order matches ImageFolder's label order
# (torchvision sorts class folder names) — TODO confirm against the folders.
classes = ["Motorcycle", "Airplane", "Schooner"]

In [None]:
def norm_transforms(dir, size=None):
    """Build per-channel Normalize / un-normalize transforms for an image folder.

    Statistics are computed over every image ``ImageFolder`` finds under
    ``dir``, after ``Resize((size, size))`` and ``ToTensor``.

    Args:
        dir: root folder laid out for ``torchvision.datasets.ImageFolder``.
        size: side length images are resized to; defaults to the global
            ``img_size`` so existing callers behave as before.

    Returns:
        ``(norm, unorm)``: ``norm`` maps x -> (x - mean) / std per channel;
        ``unorm`` is its inverse, x -> x * std + mean, used to display
        normalized tensors.

    Raises:
        ValueError: if no pixel values were read or a channel has zero std.

    NOTE(review): the statistics include images that ``load_data`` later puts
    in the test split (mild leakage).
    """
    size = img_size if size is None else size
    transform = transforms.Compose(
        [
            transforms.Resize((size, size)),
            transforms.ToTensor(),
        ]
    )
    dataset = datasets.ImageFolder(root=dir, transform=transform)

    # Accumulate per-channel sums one image at a time (in float64 for
    # precision) instead of stacking the whole dataset into a single tensor,
    # which needs N * C * size * size floats of memory at once.
    channel_sum = channel_sq_sum = None
    n_values = 0
    for img_t, _ in dataset:
        flat = img_t.reshape(img_t.shape[0], -1).double()
        if channel_sum is None:
            channel_sum = torch.zeros(flat.shape[0], dtype=torch.float64)
            channel_sq_sum = torch.zeros_like(channel_sum)
        channel_sum += flat.sum(dim=1)
        channel_sq_sum += (flat * flat).sum(dim=1)
        n_values += flat.shape[1]

    if n_values == 0:
        raise ValueError(f"no image data found under {dir!r}")

    mean_t = channel_sum / n_values
    # Unbiased (N - 1) variance, matching torch.std's default used previously.
    var_t = (channel_sq_sum - n_values * mean_t ** 2) / max(n_values - 1, 1)
    std_t = var_t.clamp_min(0.0).sqrt()  # clamp guards tiny negative rounding error

    mean = mean_t.tolist()
    std = std_t.tolist()
    if any(s == 0.0 for s in std):
        raise ValueError("a channel has zero standard deviation; cannot normalize")

    norm = transforms.Normalize(mean=mean, std=std)
    # Normalize(-m/s, 1/s) computes (x + m/s) * s = x * s + m, inverting `norm`.
    unorm = transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    )
    return norm, unorm


norm, unorm = norm_transforms(data_dir)

In [None]:
def load_data(dir, train_frac=0.8, seed=None):
    """Load an ImageFolder dataset and randomly split it into train/test subsets.

    Each image is resized to ``img_size`` x ``img_size``, converted with
    ``ToTensor`` and normalized with the global ``norm`` transform.

    Args:
        dir: root folder laid out for ``torchvision.datasets.ImageFolder``.
        train_frac: fraction of images assigned to the training subset
            (default 0.8, the previously hard-coded value); must be in [0, 1].
        seed: if given, the split uses a private generator seeded with it, so
            re-running this cell reproduces the same split. ``None`` keeps the
            old behaviour of drawing from torch's global RNG.

    Returns:
        ``[train_subset, test_subset]`` as produced by ``random_split``.

    Raises:
        ValueError: if ``train_frac`` is outside [0, 1].
    """
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f"train_frac must be in [0, 1], got {train_frac}")

    transform = transforms.Compose(
        [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            norm,
        ]
    )
    dataset = datasets.ImageFolder(root=dir, transform=transform)
    train_size = int(train_frac * len(dataset))
    test_size = len(dataset) - train_size
    lengths = [train_size, test_size]

    if seed is None:
        return torch.utils.data.random_split(dataset, lengths)
    # A private generator makes the split independent of how much of the
    # global RNG earlier cells have consumed.
    return torch.utils.data.random_split(
        dataset, lengths, generator=torch.Generator().manual_seed(seed)
    )


# Use the configured data_dir (the path used to be duplicated here) and a fixed
# seed, so re-running this cell after training cannot move training images
# into the test split.
train_data, test_data = load_data(data_dir, seed=0)

In [None]:
class CNN(nn.Module):
    """Small two-conv-layer image classifier producing raw class logits.

    NOTE(review): this redefines the ``CNN`` imported from ``model`` in the
    first cell; this notebook definition is the one used by later cells.

    Args:
        img_d: side length of the square input images; it fixes the input
            size of the first fully connected layer.
        features_d: output channels of the first conv layer (the second conv
            layer outputs ``3 * features_d``).
        num_classes: number of output logits (default 3, as before).
        in_channels: channels of the input images (default 3, as before).

    Input is ``(batch, in_channels, img_d, img_d)``; output is
    ``(batch, num_classes)`` unnormalized logits, as ``CrossEntropyLoss`` expects.
    """

    def __init__(self, img_d=400, features_d=6, num_classes=3, in_channels=3):
        super().__init__()
        self.img_d = img_d
        self.features_d = features_d
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, self.features_d, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(self.features_d, self.features_d * 3, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(1),
            # kernel 5 / padding 2 keeps the spatial size; each MaxPool2d(2, 2)
            # floors it by half: img_d -> img_d // 2 -> img_d // 4.
            # Flattened size = (img_d // 4) ** 2 * (3 * features_d).
            self._block(((img_d // 4) ** 2) * 3 * features_d, 128),
            self._block(128, 64),
            # Output layer is a plain Linear: the former trailing ReLU clamped
            # logits to >= 0 and zeroed their gradient when negative. It stays
            # wrapped in a Sequential so state_dict keys (seq.9.0.weight/bias)
            # match previously saved checkpoints.
            nn.Sequential(nn.Linear(64, num_classes)),
        )

    def _block(self, in_channels, out_channels):
        """Hidden fully connected layer: Linear followed by ReLU."""
        return nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        return self.seq(x)

In [None]:
def train(model, train_data, batch, epochs, learning_rate=1e-4, num_workers=2,
          save_path="model.pth"):
    """Train ``model`` with SGD and cross-entropy, then save its weights.

    Args:
        model: ``nn.Module`` producing class logits; each batch is moved to
            the device of the model's parameters.
        train_data: map-style dataset yielding ``(input, label)`` pairs.
        batch: mini-batch size.
        epochs: number of passes over ``train_data``.
        learning_rate: SGD step size (default 1e-4, the previous constant).
        num_workers: DataLoader worker processes (default 2, as before).
        save_path: file that receives ``model.state_dict()`` after training.

    Returns:
        The same ``model`` object, trained in place.

    Prints the mean per-batch loss of each epoch.
    """
    loader = DataLoader(train_data, batch_size=batch, shuffle=True,
                        num_workers=num_workers)
    # Follow the model's own device so inputs and weights always agree.
    device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for imgs, labels in loader:
            inputs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Each loss.item() is already a per-batch mean, so average over the
        # number of batches (dividing by the batch size was meaningless).
        n_batches = max(len(loader), 1)
        print(f'epoch: {epoch + 1} loss: {running_loss / n_batches:.3f}')

    torch.save(model.state_dict(), save_path)
    return model

In [None]:
# Build the classifier on the configured device and train it; train() also
# writes the learned weights to model.pth.
model = train(
    CNN(img_d=img_size, features_d=6).to(device),
    train_data,
    batch=64,
    epochs=50,
)

In [None]:
def eval(model, test_data, cols=4, rows=4):
    """Plot random test images with true labels and print the model's predictions.

    NOTE(review): the name shadows Python's built-in ``eval``; it is kept so
    existing calls keep working.

    Args:
        model: classifier returning one row of logits per input image.
        test_data: dataset yielding ``(normalized_image_tensor, label)`` pairs.
        cols, rows: grid shape; up to ``cols * rows`` distinct samples are drawn.

    Prints one "Prediction / Actual" line per sample, then the accuracy over
    the drawn samples. This is a quick sanity check, not a full test-set
    score. It relies on the globals ``device``, ``unorm`` and ``classes``.
    The model's train/eval mode is restored on exit.
    """
    n_wanted = cols * rows
    # Sample without replacement so no image is shown or counted twice.
    indices = torch.randperm(len(test_data))[:n_wanted].tolist()
    if not indices:
        print("test_data is empty; nothing to evaluate")
        return

    was_training = model.training
    model.eval()
    figure = plt.figure(figsize=(10, 8))
    correct = 0
    try:
        with torch.no_grad():
            for i, idx in enumerate(indices, start=1):
                img, label = test_data[idx]
                logits = model(img.to(device).unsqueeze(0)).cpu()
                pred = int(logits.argmax(dim=1).item())

                ax = figure.add_subplot(rows, cols, i)
                # Undo normalization for display; clamp rounding overshoot so
                # imshow receives values in [0, 1].
                npimg = unorm(img).clamp(0.0, 1.0).permute(1, 2, 0).numpy()
                ax.set_title(f"({i}) {classes[label]}")
                ax.axis("off")
                ax.imshow(npimg)

                print(f"({i}) Prediction: {classes[pred]}, Actual: {classes[label]}")
                correct += pred == label
    finally:
        model.train(was_training)

    n_shown = len(indices)
    # Report a real percentage (previously the fraction was printed with "%").
    print(f"\n {correct} / {n_shown} correct -> {100 * correct / n_shown:.1f} %")
    plt.show()

eval(model, test_data)