In [1]:
import import_ipynb
from CustomDataset import ControlsDataset

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# Silence every warning for the rest of the session.
# NOTE(review): this blanket filter also hides deprecation and numerical
# warnings raised by the libraries imported above — consider narrowing it.
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # switch matplotlib to interactive mode

importing Jupyter notebook from CustomDataset.ipynb


In [2]:
# Wrap the controls dataset in a loader that serves shuffled mini-batches of 4.
dataset = ControlsDataset()
dataloader = DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,  # load samples in the main process, no worker subprocesses
)

In [3]:
class ConvNet(nn.Module):
    """Convolutional regressor mapping an image batch to one scalar per image.

    Three strided convolutions (max-pooling follows the first two) feed a
    three-layer fully connected head. The width of the head is derived from
    the flattened size of the convolutional output, so it adapts to the
    requested input shape.

    Args:
        input_shape: ``(channels, height, width)`` of a single input image.
            Defaults to ``(3, 480, 640)``. The image must be large enough to
            survive all three convolutions, otherwise PyTorch raises a
            ``RuntimeError`` while the feature size is being probed.
        verbose: when True (default), print the flattened feature size and
            the parameter counts of both sub-networks at construction time.
    """

    def __init__(self, input_shape=(3, 480, 640), verbose=True):
        super().__init__()
        in_channels = input_shape[0]

        self.conv = nn.Sequential(nn.Conv2d(in_channels, 16, 11, stride=3),
                                  nn.MaxPool2d(2),
                                  nn.ReLU(),
                                  nn.Conv2d(16, 32, 11, stride=3),
                                  nn.MaxPool2d(2),
                                  nn.ReLU(),
                                  nn.Conv2d(32, 64, 7, stride=3),
                                  nn.ReLU()
                                 )

        # Push one dummy image through the conv stack to learn how many
        # features reach the linear head. No gradients are needed for this
        # probe, so skip building an autograd graph.
        with torch.no_grad():
            units = self.conv(torch.zeros(1, *input_shape)).numel()

        # The head narrows units -> units/2 -> units/4 -> 1.
        self.fc = nn.Sequential(nn.Linear(units, units//2),
                                nn.ReLU(),
                                nn.Linear(units//2, units//4),
                                nn.ReLU(),
                                nn.Linear(units//4, 1))

        if verbose:
            print("units after conv", units)
            print("conv parameters: ", sum(p.numel() for p in self.conv.parameters()))
            print("fc parameters: ", sum(p.numel() for p in self.fc.parameters()))

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: tensor laid out as (batch, channel, height, width), with the
                per-image shape the model was constructed for.

        Returns:
            Tensor of shape (batch, 1).
        """
        features = self.conv(x)
        # Collapse everything except the batch dimension before the head.
        return self.fc(features.flatten(start_dim=1))
# Build the model; the constructor prints the flattened conv output size and
# the parameter counts of the conv and fc sub-networks.
net = ConvNet()

units after conv 512
conv parameters:  168224
fc parameters:  164353


In [4]:
# Smoke test: push a single batch through the network and check the shapes.
for batch in dataloader:
    imgs = batch['image'].float()
    print("input", imgs.shape)
    out = net(imgs)
    print("output", out.shape)

    # One scalar prediction is expected per input image.
    assert out.shape == (imgs.shape[0], 1), out.shape

    # Break at the end of the body so no second batch is loaded just to be
    # thrown away.
    break

input torch.Size([4, 3, 480, 640])
output torch.Size([4, 1])
