In [2]:
import torch
import torchvision
from torch.utils.data import DataLoader

In [3]:
print(dir(torchvision))

['__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', '_image_backend', 'datasets', 'get_image_backend', 'models', 'set_image_backend', 'transforms', 'utils']


In [4]:
print(dir(torchvision.datasets)) # torchvision.datasets includes famous CV datasets

['CIFAR10', 'CIFAR100', 'Cityscapes', 'CocoCaptions', 'CocoDetection', 'DatasetFolder', 'EMNIST', 'FakeData', 'FashionMNIST', 'Flickr30k', 'Flickr8k', 'ImageFolder', 'KMNIST', 'LSUN', 'LSUNClass', 'MNIST', 'Omniglot', 'PhotoTour', 'SBU', 'SEMEION', 'STL10', 'SVHN', 'VOCDetection', 'VOCSegmentation', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'cifar', 'cityscapes', 'coco', 'fakedata', 'flickr', 'folder', 'lsun', 'mnist', 'omniglot', 'phototour', 'sbu', 'semeion', 'stl10', 'svhn', 'utils', 'voc']


In [27]:
print(dir(torchvision.models))  # torchvision.models includes famous CV models

['AlexNet', 'DenseNet', 'Inception3', 'ResNet', 'SqueezeNet', 'VGG', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'alexnet', 'densenet', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'inception', 'inception_v3', 'resnet', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'squeezenet', 'squeezenet1_0', 'squeezenet1_1', 'vgg', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn']


In [29]:
print(dir(torchvision.transforms)) # torchvision.transforms includes common transformations on images

['CenterCrop', 'ColorJitter', 'Compose', 'FiveCrop', 'Grayscale', 'Lambda', 'LinearTransformation', 'Normalize', 'Pad', 'RandomAffine', 'RandomApply', 'RandomChoice', 'RandomCrop', 'RandomGrayscale', 'RandomHorizontalFlip', 'RandomOrder', 'RandomResizedCrop', 'RandomRotation', 'RandomSizedCrop', 'RandomVerticalFlip', 'Resize', 'Scale', 'TenCrop', 'ToPILImage', 'ToTensor', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'functional', 'transforms']


In [5]:
mnist_transformation = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))  # normalize: each pixel=(pixel-0.5)/0.5
    ]
)

In [12]:
mnist_dataset = torchvision.datasets.mnist
training_set = mnist_dataset.MNIST('./torchvision_data', train=True, transform=mnist_transformation)  # if not downloaded, set download=True
test_set = mnist_dataset.MNIST('./torchvision_data', train=False, transform=mnist_transformation) # apply transform to each img in the dataset
print(training_set), print(test_set)

Dataset MNIST
    Number of datapoints: 60000
    Split: train
    Root Location: ./torchvision_data
    Transforms (if any): Compose(
                             ToTensor()
                             Normalize(mean=(0.5,), std=(0.5,))
                         )
    Target Transforms (if any): None
Dataset MNIST
    Number of datapoints: 10000
    Split: test
    Root Location: ./torchvision_data
    Transforms (if any): Compose(
                             ToTensor()
                             Normalize(mean=(0.5,), std=(0.5,))
                         )
    Target Transforms (if any): None


(None, None)

In [181]:
class PrepareMNISTData:
    batch = 32
    training_data = DataLoader(
        dataset=training_set,
        batch_size=PrepareMNISTData.batch,
        shuffle=True,
    )
    test_data = DataLoader(
        dataset=test_set,
        batch_size=PrepareMNISTData.batch,
        shuffle=True
    )
    @staticmethod
    def get_training_set():
        return iter(PrepareMNISTData.training_data)
    @staticmethod
    def get_test_set():
        return iter(PrepareMNISTData.test_data)
    

In [121]:
a = next(PrepareMNISTData.get_training_set())

In [122]:
print(a[0].shape, a[1].shape)

torch.Size([32, 1, 28, 28]) torch.Size([32])


In [263]:
class FourLayerLinearNeuralNetwork:
    def __init__(self, lr=5e-2, epoch=20):
        
        self.network = torch.nn.Sequential(
            torch.nn.Linear(784, 400),
            torch.nn.ReLU(),
            torch.nn.Linear(400, 200),
            torch.nn.ReLU(),
            torch.nn.Linear(200, 100),
            torch.nn.ReLU(),
            torch.nn.Linear(100, 10),
            torch.nn.ReLU(),
        )
        for layer in self.network:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.xavier_uniform_(layer.weight)
        
        self.epoch = epoch
        
        self.optimizer = torch.optim.SGD(self.network.parameters(), lr = lr)
        self.loss_func = torch.nn.CrossEntropyLoss()
        
    def train(self, get_training_data_iter_function):
        self.network.train()
        for e in range(self.epoch):
            loss_in_epoch = 0
            batch_count = 0
            items_in_epoch = 0
            correct_items_in_epoch = 0
            
            for data_batch in get_training_data_iter_function():
                imgs, lbls = data_batch
                imgs = imgs.reshape(imgs.shape[0], -1)
                lbls = lbls.long() #torch.nn.functional.one_hot(lbls).float()
                
                out = self.network(Variable(imgs))
                loss = self.loss_func(out, Variable(lbls))
                
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                
                _,pred = out.max(1)
                correct_items_in_epoch+= (pred==lbls).sum().item()
                items_in_epoch += imgs.shape[0]
                
                loss_in_epoch += loss.data.item()
                batch_count += 1
                
            print("epoch{}/{}, training loss={}, accuracy={}".format(e, self.epoch, 
                                                                     loss_in_epoch/batch_count, 
                                                                     correct_items_in_epoch/items_in_epoch))
    
    def evaluate(self, get_test_data_iter_function):
        self.network.eval()
        loss_in_epoch = 0
        batch_count = 0
        items_in_epoch = 0
        correct_items_in_epoch = 0
        
        for data_batch in get_test_data_iter_function():
            imgs, lbls = data_batch
            imgs = imgs.reshape(imgs.shape[0], -1)
            lbls = lbls.long()
            
            out = self.network(Variable(imgs))
            loss = self.loss_func(out, Variable(lbls))
            
            _,pred = out.max(1)
            correct_items_in_epoch+= (pred==lbls).sum().item()
            items_in_epoch += imgs.shape[0]
            
            loss_in_epoch += loss.data.item()
            batch_count += 1
        
        print("eval loss={}, accuracy={}".format(loss_in_epoch/batch_count, correct_items_in_epoch/items_in_epoch))
    
    def predict(self, img_var):
        return self.network(img_var).max(1)[1]
        
            

In [264]:
net =FourLayerLinearNeuralNetwork()

In [265]:
net.train(PrepareMNISTData.get_training_set)

epoch0/20, training loss=0.917797791258494, accuracy=0.6579166666666667
epoch1/20, training loss=0.40618761148899796, accuracy=0.8505333333333334
epoch2/20, training loss=0.3237276238349577, accuracy=0.8776333333333334
epoch3/20, training loss=0.3036727732191483, accuracy=0.8824333333333333
epoch4/20, training loss=0.2891933809529679, accuracy=0.88595
epoch5/20, training loss=0.2785459199167322, accuracy=0.8885
epoch6/20, training loss=0.2705649865209125, accuracy=0.8904833333333333
epoch7/20, training loss=0.2643465009219789, accuracy=0.8920166666666667
epoch8/20, training loss=0.25939889989804166, accuracy=0.8932833333333333
epoch9/20, training loss=0.25308714145893074, accuracy=0.89505
epoch10/20, training loss=0.25129529921258026, accuracy=0.8955333333333333
epoch11/20, training loss=0.24846673242559772, accuracy=0.8961666666666667
epoch12/20, training loss=0.24554003199531385, accuracy=0.8968333333333334
epoch13/20, training loss=0.2419633799520069, accuracy=0.89775
epoch14/20, tr

In [266]:
net.evaluate(PrepareMNISTData.get_test_set)

eval loss=0.3059302164982831, accuracy=0.8844
