In [2]:
import import_ipynb
from CustomDataset import ControlsDataset

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

# Ignore warnings.
# NOTE(review): filterwarnings("ignore") with no category/module/message arguments installs a
# catch-all filter, so every warning raised for the rest of this kernel session is hidden
# (including deprecation notices from torch / skimage). Consider narrowing it with
# `category=` / `module=` once the specific noisy warning is identified.
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # matplotlib interactive mode: figures redraw automatically and plt.show() does not block

In [3]:
# Loader settings — tune here instead of editing the DataLoader call.
BATCH_SIZE = 256
NUM_WORKERS = 0  # 0 = batches are loaded in the main process
SEED = 0

# Seed the RNGs so the shuffled batch order (and the random weight init in the model cell)
# is reproducible under Restart & Run All. NumPy is seeded too in case the dataset's
# transforms draw from it — NOTE(review): confirm whether ControlsDataset uses any randomness.
torch.manual_seed(SEED)
np.random.seed(SEED)

dataset = ControlsDataset()
# NOTE(review): presumably converts the raw `Angle` values into class labels — confirm in CustomDataset.
dataset.convertTOClass()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Bare last expression -> Jupyter renders the frame as a rich table.
dataset.data_frame.head(10)

   Number  Angle
0       0      9
1       1      9
2       2      9
3       3      9
4       4      9
5       5      9
6       6      9
7       7      9
8       8      9
9       9      9


In [4]:
class ConvNet(nn.Module):
    """Small CNN classifier: three strided conv stages followed by a 3-layer MLP head.

    The number of flattened conv features is not hard-coded. It is measured once at
    construction by pushing a zero image of shape ``(1, in_channels, height, width)``
    through ``self.conv``. The first ``nn.Linear`` is therefore sized for that input
    resolution. Images whose spatial size yields a different flattened feature count
    will fail at the MLP head, so feed images of the size given here.

    Args:
        in_channels: Number of input image channels (default 3).
        height: Input image height in pixels (default 480).
        width: Input image width in pixels (default 640).
        num_classes: Number of output scores per sample (default 20).
        verbose: If True (the default, matching the original behaviour), print the
            flattened conv feature count and the parameter counts of ``conv`` and ``fc``.
    """

    def __init__(self, in_channels=3, height=480, width=640, num_classes=20, verbose=True):
        super().__init__()

        # Conv2d(in_channels, out_channels, kernel_size, stride).
        # stride=3: the kernel window moves 3 pixels between successive applications, so each
        # conv shrinks the spatial size roughly 3x; each MaxPool2d(2) halves it again.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, 11, stride=3),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(16, 32, 11, stride=3),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(32, 64, 7, stride=3),
            nn.ReLU(),
        )

        # Probe with a single dummy image [batch_size, channels, height, width] to get the
        # per-sample feature count; no autograd graph is needed for this measurement.
        dummy = torch.zeros(1, in_channels, height, width)
        with torch.no_grad():
            units = self.conv(dummy).numel()
        if verbose:
            print("units after conv", units)

        # MLP head: units -> units/2 -> units/4 -> num_classes raw scores (no softmax).
        self.fc = nn.Sequential(
            nn.Linear(units, units // 2),
            nn.ReLU(),
            nn.Linear(units // 2, units // 4),
            nn.ReLU(),
            nn.Linear(units // 4, num_classes),
        )

        if verbose:
            print("conv parameters: ", sum(p.numel() for p in self.conv.parameters()))
            print("fc parameters: ", sum(p.numel() for p in self.fc.parameters()))

    def forward(self, x):
        """Map images ``x`` of shape (batch, channels, height, width) to scores of shape
        (batch, num_classes). The scores are unnormalised; no softmax is applied."""
        features = self.conv(x)
        # Flatten all but the batch dimension. Unlike reshape((batch_size, -1)) this does not
        # need the batch size and stays well-defined for an empty batch.
        features = torch.flatten(features, start_dim=1)
        return self.fc(features)
# Build the (untrained, randomly initialised) model; construction prints the flattened
# conv feature count and the parameter counts of the conv and fc sub-networks.
net = ConvNet()

units after conv 512
conv parameters:  168224
fc parameters:  166804


In [5]:
# Smoke test: push exactly one batch through the untrained network to check tensor shapes.
# next(iter(...)) fetches a single batch, so no extra batch is decoded just to reach a `break`,
# and `batch`, `imgs` and `out` all refer to the same batch afterwards.
batch = next(iter(dataloader))
# NOTE(review): .float() only changes dtype; confirm whether pixel values are already normalised.
imgs = batch['image'].float()
print("input", imgs.shape)

with torch.no_grad():  # shape check only — no autograd graph needed
    out = net(imgs)
print("output", out.shape)

# One row of class scores per image.
assert out.shape == (imgs.shape[0], net.fc[-1].out_features)

input torch.Size([256, 3, 480, 640])
tensor([[-3.5550e-01, -3.7449e-01,  5.0028e-02,  ..., -1.2383e-03,
         -6.0888e-01,  8.8046e-01],
        [-4.6440e-01, -2.6299e-01, -1.8469e-01,  ..., -1.0739e-01,
         -6.3015e-01,  1.2373e+00],
        [-5.9545e-01, -5.3449e-01, -4.1917e-01,  ..., -4.0859e-01,
         -9.5157e-01,  1.6170e+00],
        ...,
        [-4.9995e-01, -6.2073e-01, -5.0466e-01,  ..., -2.3809e-01,
         -1.2555e+00,  1.2835e+00],
        [-3.9473e-01, -2.6195e-01,  9.0336e-02,  ..., -3.1958e-01,
         -8.4353e-01,  1.1560e+00],
        [-3.7653e-01, -8.3684e-01, -1.4487e-01,  ..., -1.1076e-01,
         -1.1628e-01,  1.1601e+00]], grad_fn=<AddmmBackward>)
output torch.Size([256, 20])
