In [1]:
import import_ipynb
from CustomDataset import ControlsDataset

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode

importing Jupyter notebook from CustomDataset.ipynb


In [2]:
if __name__ == "__main__":
    # Build the project dataset and configure it before any loaders are created.
    dataset = ControlsDataset()
    # Ask for labels in "categorical" form. NOTE(review): presumably integer class
    # indices, since the training code casts them with .long() — confirm in CustomDataset.
    dataset.labels.set_type("categorical")
    # Enable grayscale images; the model cell below picks 1 input channel when this flag is set.
    dataset.images.set_grayscale(True)
    # NOTE(review): presumably this creates dataset.dataloader / dataset.validloader,
    # which the model code reads later — confirm in CustomDataset.
    dataset.make_dataloaders()
    # Earlier manual alternative to make_dataloaders(), kept for reference:
    #dataloader = DataLoader(dataset, batch_size = 256, shuffle = True, num_workers = 0)

Total training stacks 35
Total validation stacks 9


In [None]:
class ConvNet(nn.Module):
    """Convolutional classifier with built-in training, scoring and checkpointing.

    The network is a three-layer convolutional feature extractor followed by a
    three-layer fully connected head that emits one score per output class.
    Batches read from the dataset's loaders are dicts with an 'image' tensor
    and a 'control' tensor (the target class).
    """

    def __init__(self, color_channels, outputs, dataset, input_size=(480, 640)):
        """Build the layers and size the dense head from a dummy forward pass.

        Args:
            color_channels: number of channels in the input images.
            outputs: number of classes predicted by the final linear layer.
            dataset: object exposing ``labels.infotype``, ``dataloader`` and
                ``validloader`` (the latter is read by ``score``).
            input_size: ``(height, width)`` of the input images, used only to
                size the first fully connected layer. Defaults to 480x640.
        """
        super(ConvNet, self).__init__()

        self.channels = color_channels
        self.report_period = 20  # iterations between TensorBoard logs / snapshots
        self.writer = SummaryWriter()
        self.start_epoch = 0
        self.infotype = dataset.labels.infotype
        self.dataset = dataset
        self.dataloader = dataset.dataloader
        # The optimizer is attached by fit() or load(); a checkpoint loaded
        # before that parks its optimizer state here until fit() applies it.
        self.optimizer = None
        self._pending_optimizer_state = None

        height, width = input_size
        img_size = torch.Size([1, color_channels, height, width])  # [batch_size, channels, height, width]

        # Conv2d(in_channels, out_channels, kernel_size, stride)
        # stride=3 ==> the filter moves 3 pixels between successive applications of the kernel
        self.conv = nn.Sequential(nn.Conv2d(color_channels, 16, 11, stride=3),
                                  nn.MaxPool2d(2),
                                  nn.ReLU(),
                                  nn.Conv2d(16, 32, 11, stride=3),
                                  nn.MaxPool2d(2),
                                  nn.ReLU(),
                                  nn.Conv2d(32, 64, 7, stride=3),
                                  nn.ReLU())

        # Probe the conv stack with one blank image to learn the flattened
        # feature count; no gradients are needed for a shape probe.
        with torch.no_grad():
            units = self.conv(torch.zeros(img_size)).numel()
        print("units after conv", units)

        self.fc = nn.Sequential(nn.Linear(units, units//2),
                                nn.ReLU(),
                                nn.Linear(units//2, units//4),
                                nn.ReLU(),
                                nn.Linear(units//4, outputs))  # <-- Returning predictions over classes

        print("conv parameters: ", sum(p.numel() for p in self.conv.parameters()))
        print("fc parameters: ", sum(p.numel() for p in self.fc.parameters()))

    def forward(self, x):
        """Return class scores of shape (batch, outputs) for x: (batch, channel, height, width)."""
        batch_size = x.shape[0]
        out = self.conv(x)
        out = out.reshape((batch_size, -1))  # flatten per-sample features for the dense head
        out = self.fc(out)
        return out

    def load(self, path, optimizer=None, map_location="cpu"):
        """Restore model weights, optimizer state and the epoch to resume from.

        May be called before ``fit``: when no optimizer is attached yet, the
        saved optimizer state is kept and applied to the optimizer that is
        later passed to ``fit``.

        Args:
            path: checkpoint written by ``save``.
            optimizer: optional optimizer to attach and restore immediately.
            map_location: forwarded to ``torch.load``; the "cpu" default lets
                checkpoints written on a GPU be opened on any host
                (``load_state_dict`` copies values onto the live parameters).
        """
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])

        if optimizer is not None:
            self.optimizer = optimizer
        attached = getattr(self, "optimizer", None)
        if attached is not None:
            attached.load_state_dict(checkpoint['optimizer_state_dict'])
            self._pending_optimizer_state = None
        else:
            # No optimizer yet: defer until fit() supplies one.
            self._pending_optimizer_state = checkpoint['optimizer_state_dict']

        self.start_epoch = checkpoint['epoch']

    def save(self, epoch, path):
        """Write a checkpoint (epoch, model weights, optimizer state) to ``path``.

        Missing parent directories of ``path`` are created.

        Raises:
            AttributeError: if no optimizer has been attached yet.
        """
        attached = getattr(self, "optimizer", None)
        if attached is None:
            raise AttributeError(
                "no optimizer attached; call fit() or load(..., optimizer=...) before save()")

        if isinstance(path, (str, os.PathLike)):
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)

        torch.save({
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': attached.state_dict()},
            path)

    @staticmethod
    def _cuda_memory_gb(device):
        """Return (allocated, reserved) CUDA memory in GB, or zeros without CUDA."""
        if not torch.cuda.is_available():
            return 0.0, 0.0
        target = torch.device(device)
        if target.type != "cuda":
            target = torch.device("cuda", 0)
        # memory_cached is the pre-1.4 name of memory_reserved.
        reserved_fn = getattr(torch.cuda, "memory_reserved", None) or torch.cuda.memory_cached
        gigabyte = 1024 ** 3
        return (round(torch.cuda.memory_allocated(target) / gigabyte, 3),
                round(reserved_fn(target) / gigabyte, 3))

    def fit(self, device, epochs, optimizer, criterion):
        """Train from ``self.start_epoch`` up to ``epochs``, reporting after every batch.

        Every ``self.report_period`` iterations the loss is logged to
        TensorBoard and a snapshot is saved under ``snapshots/``. After each
        batch the accuracy on one validation batch is printed.

        Args:
            device: device the batches are moved to; the model itself is not
                moved here, so the caller must place it on the same device.
            epochs: epoch index at which training stops (exclusive).
            optimizer: optimizer updating this model's parameters.
            criterion: loss called as ``criterion(prediction, controls)``.
        """
        self.train()
        self.optimizer = optimizer
        pending = getattr(self, "_pending_optimizer_state", None)
        if pending is not None:
            # Finish a load() that happened before an optimizer existed.
            optimizer.load_state_dict(pending)
            self._pending_optimizer_state = None

        iter_no = 0
        for epoch in range(self.start_epoch, epochs):
            for sampled_batch in self.dataloader:
                # inputs
                images = sampled_batch['image'].to(device).float()
                controls = sampled_batch['control'].to(device).long()
                controls = torch.flatten(controls)

                # forward and backward pass
                optimizer.zero_grad()
                prediction = self(images)
                print(torch.argmax(prediction, dim=1))
                print(controls)
                loss = criterion(prediction, controls)
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                iter_no += 1
                if iter_no % self.report_period == 0:
                    self.writer.add_scalar("Loss", loss_value, iter_no)
                    print("saved to tensorboard")
                    self.save(epoch, "snapshots/{:.3f}_model.pt".format(loss_value))

                allocated_gb, cached_gb = self._cuda_memory_gb(device)
                out = "{0},{1}\tLoss:{2}\tAllocated:{3}GB\tCached:{4}GB\n"
                print(out.format(str(epoch),
                                 str(iter_no),
                                 round(loss_value, 5),
                                 allocated_gb,
                                 cached_gb))

                # score() restores training mode itself before returning.
                total, correct = self.score(device, self.dataset, True)
                accuracy = correct / total * 100 if total else 0.0  # empty validation loader
                print("Accuracy: {}%".format(accuracy))

    def score(self, device, dataset, single_batch=False):
        """Count correct predictions on ``dataset.validloader``.

        Runs in eval mode without gradient tracking and restores the mode the
        model was in on entry, so it is safe to call from inside ``fit``.

        Args:
            device: device the batches are moved to.
            dataset: object exposing a ``validloader`` iterable of batches.
            single_batch: when True, stop after the first batch.

        Returns:
            ``(total, correct)``: samples seen and how many were classified correctly.
        """
        was_training = self.training
        self.eval()
        total = 0
        correct = 0
        try:
            with torch.no_grad():
                for sampled_batch in dataset.validloader:
                    images = sampled_batch['image'].to(device).float()
                    controls = sampled_batch['control'].to(device).long()
                    controls = torch.flatten(controls)
                    prediction = self(images)

                    maximum = torch.argmax(prediction, dim=1)
                    correct += int((maximum == controls).sum().item())
                    total += len(controls)
                    if single_batch:
                        break
        finally:
            self.train(was_training)
        return (total, correct)
    
if __name__ == "__main__":
    # One input channel for grayscale images, three otherwise.
    channels = 1 if dataset.images.grayscale else 3

    # 21 is the number of outputs (classes) the network predicts over.
    net = ConvNet(channels, 21, dataset)

In [None]:
if __name__ == "__main__":
    # Smoke test: push the first training batch through the network and
    # print the input/output shapes. Breaking at the end of the body avoids
    # loading a second batch just to leave the loop, and an empty loader
    # still results in no output.
    for i, batch in enumerate(dataset.dataloader):
        imgs = batch['image'].float()
        print("input", imgs.shape)
        # Only the shape is inspected, so skip building the autograd graph.
        with torch.no_grad():
            out = net(imgs)
        print("output", out.shape)
        break