# Base Image Classifier

> This notebook attempts to classify the pneumonia X-ray images with a basic CNN implemented in PyTorch.

In [None]:
import torch
from torchvision import transforms
import os

import numpy as np
import matplotlib.pyplot as plt

import XRay_utils

> Define global configuration: the data directories and the image rescaling/padding parameters.

In [None]:
# Root of the dataset; each split lives in its own sub-directory.
DATA_DIR = '../Data/'

# Per-split directories, each keeping a trailing slash like DATA_DIR.
TRAIN_DATA, TEST_DATA, VAL_DATA = (f'{DATA_DIR}{split}/' for split in ('train', 'test', 'val'))

# Scale factor handed to the image rescaling transform.
RESCALE_FACTOR = 0.1
# Padding target: (3000, 2800) scaled by RESCALE_FACTOR and truncated to ints.
# NOTE(review): (3000, 2800) presumably bounds the raw image size — TODO confirm.
MAX_IMAGE_SIZE = tuple(np.multiply((3000, 2800), RESCALE_FACTOR).astype(int))

> Load the training, test, and validation datasets.

In [None]:
# One preprocessing pipeline shared by every split: XRay_utils.Rescale with
# RESCALE_FACTOR, then XRay_utils.Pad to MAX_IMAGE_SIZE with fill=0.
# NOTE(review): assumes Pad brings every image to the same size — confirm in XRay_utils.
transf = transforms.Compose([
    XRay_utils.Rescale(RESCALE_FACTOR),
    XRay_utils.Pad(MAX_IMAGE_SIZE, fill=0),
])

# Apply the same transform to test/val so evaluation images are preprocessed
# exactly like the training images the model will see.
train_dataset = XRay_utils.XRayDataset(TRAIN_DATA, transform=transf)
test_dataset = XRay_utils.XRayDataset(TEST_DATA, transform=transf)
val_dataset = XRay_utils.XRayDataset(VAL_DATA, transform=transf)

print('Training Samples: {}'.format(len(train_dataset)))
print('Testing Samples: {}'.format(len(test_dataset)))
print('Validation Samples: {}'.format(len(val_dataset)))

> View a few training samples with their class labels.

In [None]:
# Preview the first few training samples side by side, titled with their class.
N_PREVIEW = 4  # number of panels in the preview row

# Never index past the end of a small dataset; extra panels stay blank.
n_shown = min(N_PREVIEW, len(train_dataset))
fig, axes = plt.subplots(1, N_PREVIEW, squeeze=False)

for ax in axes.flat:
    ax.axis('off')

for i in range(n_shown):
    sample = train_dataset[i]
    print(i, np.asarray(sample['image']).shape)

    ax = axes[0, i]
    ax.set_title('{} #{}'.format(sample['class'], i))
    ax.imshow(sample['image'], cmap='gray')

# Lay out and show once, after all panels are drawn, regardless of dataset size.
fig.tight_layout()
plt.show()