In [1]:
import import_ipynb
from CustomDataset import ControlsDataset

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter

# Ignore warnings
import warnings
import time
warnings.filterwarnings("ignore")


from tqdm.notebook import tqdm_notebook

plt.ion()   # interactive mode

importing Jupyter notebook from CustomDataset.ipynb


In [3]:
if __name__ == "__main__":
    dataset = ControlsDataset()
    #dataloader = DataLoader(dataset, batch_size = 256, shuffle = True, num_workers = 0)

Total training stacks 138
Total validation stacks 35


In [6]:
class ConvNet(nn.Module):
    def __init__(self,outputs,image_shape):
        super(ConvNet, self).__init__()
        img_size = list(image_shape)
        img_size = torch.Size([1] + img_size)
        empty = torch.zeros(img_size)
        # Conv2d(in_channels, out_channels, kernelSize, strides)
        
        channels1,channels2,channels3 = 16,32,64
        kernel1, kernel2, kernel3 = 11, 5, 3
        padding1, padding2, padding3 = (kernel1-1)//2,(kernel2-1)//2,(kernel3-1)//2
        stride1, stride2, stride3 = 2,2,1
        
        self.conv1 = nn.Sequential(nn.Conv2d(image_shape[0], channels1, kernel1, stride1, padding1),
                                  nn.BatchNorm2d(channels1),
                                  nn.Tanh(),
                                  nn.MaxPool2d(2))
        
        self.conv2 = nn.Sequential(nn.Conv2d(channels1, channels2, kernel2, stride2, padding2),
                                  nn.BatchNorm2d(channels2),
                                  nn.Tanh(),
                                  nn.MaxPool2d(2))
        
        self.conv3 = nn.Sequential(nn.Conv2d(channels2, channels3, kernel3, stride3, padding3),
                                  nn.BatchNorm2d(channels3),
                                  nn.Tanh())
        
        out = self.conv1(empty)
        out = self.conv2(out)
        units = self.conv3(out).numel()
        
        print("units after conv", units)
        self.fc = nn.Sequential(nn.Linear(units, units//40),
                                nn.BatchNorm1d(units//40),
                                nn.Tanh(),
                                nn.Linear(units//40, outputs)) # <-- Returning predictions over classes
        
        print("conv parameters: ", sum(p.numel() for p in self.conv1.parameters())+
                                   sum(p.numel() for p in self.conv2.parameters())+
                                   sum(p.numel() for p in self.conv3.parameters()))
        print("fc parameters: ",sum(p.numel() for p in self.fc.parameters()))
    
    def forward(self, x):
        #x: batch, channel, height, width
        batch_size = x.shape[0]
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = out.reshape((batch_size,-1))
        out = self.fc(out)
        #print(out)
        return out
        
    def load_weights(self,path):
        checkpoint = torch.load(path)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.start_epoch = checkpoint['epoch']
        
    def save_weights(self,optimizer,epoch,path):
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, 
            path)
       
    def current_snapshot_name():
        from time import gmtime, strftime
        import socket

        hostname = socket.gethostname()

        date = strftime("%b%d_", gmtime())
        clock = strftime("%X", gmtime())
        now = clock.split(":")
        now = date+'-'.join(now)

        name = now+"_"+hostname
        return name
    
if __name__ == "__main__":
    net = ConvNet(3, 1, dataset)

units after conv 1008
conv parameters:  24896
fc parameters:  25301


In [7]:
if __name__ == "__main__":
    for i, batch in enumerate(dataset.dataloader):
        if i > 0:
            break

        imgs = batch['image'].float()
        print("input", imgs.shape)
        out = net(imgs)
        print("output", out.shape)

input torch.Size([64, 3, 480, 640])
output torch.Size([64, 1])
