In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F

import torchvision
from torchvision import transforms

from torchsummary import summary

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
def check_images(PATH):
    """Return True if PIL can open the file at ``PATH`` as an image.

    Intended for ``ImageFolder(is_valid_file=...)`` so that files PIL cannot
    identify (non-image files, unreadable paths) are skipped when the
    dataset is indexed, rather than failing later inside the DataLoader.

    ``Image.open`` is lazy: it identifies the file from its header and does
    not decode pixel data, so this is a cheap check. It does not detect
    truncated pixel data.

    Args:
        PATH: Path to a candidate image file.

    Returns:
        True if ``Image.open`` succeeds, False if it raises.
    """
    try:
        # Context manager closes the file handle; ImageFolder calls this once
        # per file, so leaked handles would otherwise accumulate.
        with Image.open(PATH):
            return True
    except Exception:
        # Ordinary errors only (e.g. UnidentifiedImageError, OSError); a bare
        # ``except`` would also swallow KeyboardInterrupt / SystemExit.
        return False

In [4]:
# Per-image preprocessing shared by every dataset split, applied in order.
image_transforms = transforms.Compose(
    [
        # Fixed 64x64 size so every sample has the same number of values.
        transforms.Resize((64, 64)),
        # PIL image -> float tensor laid out as (C, H, W), scaled to [0, 1].
        transforms.ToTensor(),
        # Per-channel standardization with the widely used ImageNet mean/std.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [8]:
# Root folder for each split; each is passed to ImageFolder as ``root``,
# which expects one sub-folder per class underneath.
train_data_path, val_data_path, test_data_path = './train/', './val/', './test/'

# Mini-batch size used by every DataLoader.
batch_size = 64

In [9]:
# One ImageFolder per split. Each dataset variable is built from the path
# variable with the matching name; check_images filters out files PIL cannot
# open, and image_transforms is applied to every sample.
train_data = torchvision.datasets.ImageFolder(
    root = train_data_path,
    transform = image_transforms,
    is_valid_file = check_images
)

val_data = torchvision.datasets.ImageFolder(
    root = val_data_path,
    transform = image_transforms,
    is_valid_file = check_images
)

test_data = torchvision.datasets.ImageFolder(
    root = test_data_path,
    transform = image_transforms,
    is_valid_file = check_images
)

In [10]:
# ImageFolder indexes samples grouped by class, so the training loader must
# shuffle; otherwise consecutive batches are (nearly) all one class, which
# badly skews each gradient step.
train_data_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size = batch_size,
    shuffle = True
)

# Evaluation loaders keep a fixed order so metrics and predictions are
# reproducible from run to run.
val_data_loader = torch.utils.data.DataLoader(
    val_data,
    batch_size = batch_size
)

test_data_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size = batch_size
)

In [11]:
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(1000, 84)
        self.fc2 = nn.Linear(84, 64)
        self.fc3 = nn.Linear(64, 2)
        
    def forward(self, x):
        x = x.view(-1, 1000)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [16]:
simplenet_model = SimpleNet()
summary(simplenet_model, (1000, 84))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                   [-1, 84]          84,084
            Linear-2                   [-1, 64]           5,440
            Linear-3                    [-1, 2]             130
Total params: 89,654
Trainable params: 89,654
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.32
Forward/backward pass size (MB): 0.00
Params size (MB): 0.34
Estimated Total Size (MB): 0.66
----------------------------------------------------------------
