In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import glob
import torch
from torchvision import datasets, transforms
import utils

# load images

In [None]:
# Define the transforms.
# Resize(64) alone only scales the *shorter* edge to 64 and keeps the aspect
# ratio, so non-square images would become 64xW tensors: they could not be
# stacked into one batch and would not match the 64x64 input the network
# below is sized for. CenterCrop(64) makes every image exactly 64x64 and is a
# no-op for images that are already square.

train_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    # Random augmentation is applied to the training split only.
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5 for the three RGB channels.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Deterministic pipeline shared by the validation and test splits.
valid_test_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [None]:
# Build one ImageFolder dataset per split. The training split gets the
# augmenting transform; validation and test share the deterministic one.
train_data, valid_data, test_data = (
    datasets.ImageFolder(split_dir, transform=split_transform)
    for split_dir, split_transform in (
        ("../data/train", train_transform),
        ("../data/valid", valid_test_transform),
        ("../data/test", valid_test_transform),
    )
)

In [None]:
# Number of images per training batch.
batch_size = 10

In [None]:
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
# Each evaluation loader wraps its own split (valid_data / test_data), which
# uses the non-augmenting valid_test_transform; order is kept fixed.
valid_loader = torch.utils.data.DataLoader(
    valid_data,
    batch_size=5,  # the validation dir holds only 5 images
)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=5,  # the test dir holds only 5 images
)


In [None]:
# Take the class names straight from the training dataset so that
# classes[i] is exactly the name ImageFolder mapped to integer label i;
# a separate directory listing could disagree with that mapping.
classes = list(train_data.classes)

In [None]:
# Preview training images via the project helper; presumably plots one batch
# labelled with `classes` (batch_size likely sets the grid size) — TODO confirm in utils.
utils.show_batch(train_loader, batch_size, classes)

In [None]:
# Pull a single training batch to check the per-image tensor shape
# (presumably (channels, height, width) — confirm from the output).
dataiter = iter(train_loader)
first_batch = next(dataiter)
images, labels = first_batch
images[0].shape

In [None]:
# Calculate dims of first fc layer.
# MaxPool2d(2, 2) floors odd sizes, and flooring once per pool layer equals a
# single floor division by pool_size ** n_pool_layers. Using `//` keeps the
# result an int; `/` gave a float and hid sizes that don't divide evenly.
# Height and width are handled separately so non-square inputs are correct.

image_height, image_width = images[0].shape[-2:]
image_height_width = image_width  # original name kept; equals image_height only for square inputs
n_pool_layers = 3
pool_size = 2
n_filters_final_layer = 64

pool_factor = pool_size ** n_pool_layers
fc_layer_units = (
    n_filters_final_layer
    * (image_height // pool_factor)
    * (image_width // pool_factor)
)

print(fc_layer_units)

# Define the network

In [None]:
import torch.nn as nn
import torch.nn.functional as F

## Define the CNN architecture
class cnn(nn.Module):
    """Three-block convolutional classifier for RGB images.

    Each block is conv -> ReLU -> 2x2 max-pool, followed by two fully
    connected layers with dropout in between. ``forward`` returns
    log-probabilities (``log_softmax`` over dim 1), e.g. for use with
    ``nn.NLLLoss``.

    Args:
        n_classes: number of output classes (default 10).
        image_size: height/width of the square input images (default 64).
    """

    def __init__(self, n_classes=10, image_size=64):
        super().__init__()

        # padding = kernel_size // 2 with stride 1 keeps H and W unchanged;
        # only the max-pool layers shrink the feature maps.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, padding=3)

        self.maxpool = nn.MaxPool2d(2, 2)

        # Three 2x2 pools, each flooring: e.g. 64 -> 32 -> 16 -> 8.
        pooled_size = image_size // 2 // 2 // 2
        self.fc1 = nn.Linear(64 * pooled_size * pooled_size, 500)
        self.fc2 = nn.Linear(500, n_classes)

        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Return class log-probabilities.

        Args:
            x: image batch, assumed (batch, 3, image_size, image_size); a single
               unbatched (3, image_size, image_size) image is also accepted.

        Returns:
            Tensor of shape (batch, n_classes) with log_softmax scores
            (batch is 1 for an unbatched input).
        """
        if x.dim() == 3:
            # Unbatched (C, H, W) image: add the batch dim so the output is
            # (1, n_classes), as the previous view(-1, ...) produced.
            x = x.unsqueeze(0)

        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = self.maxpool(F.relu(self.conv3(x)))

        # Flatten everything but the batch dim. Unlike view(-1, 64 * 8 * 8),
        # this never folds a mis-sized input into a different batch size;
        # fc1 raises on a feature-count mismatch instead.
        x = x.flatten(1)

        x = self.dropout(F.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)

# initialize the NN
model = cnn()
print(model)