In [1]:
import torch
import os 
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
import mlflow
import mlflow.pytorch

In [2]:
class Config:
    """Hyperparameters and runtime switches shared by the rest of the notebook."""

    # Reproducibility
    SEED = 42

    # Hardware: "cuda" when torch can see a CUDA device, otherwise "cpu"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Training schedule
    EPOCHS = 10
    LR = 0.01
    GAMMA = 0.7

    # Batch sizes for the training and test loaders
    BATCH_SIZE = 32
    TEST_BATCH_SIZE = 1000

    # Run control
    LOG_INTERVAL = 10
    DRY_RUN = True

In [3]:
# Shared settings object; later cells read the seed, device and batch sizes from it.
config = Config()

In [4]:
class ConvNet(nn.Module):
    """Small CNN: two 3x3 convolutions, max-pooling, then two linear layers.

    The first linear layer takes 9216 features, which is 64 channels * 12 * 12,
    i.e. what the conv/pool stack produces for a single-channel 28x28 input.
    The forward pass returns log-probabilities over 10 outputs.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # parameter initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return log-softmax scores (over dim 1) for the batch `x`."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Collapse everything except the batch dimension before the linear head.
        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
        logits = self.fc2(self.dropout2(hidden))

        return F.log_softmax(logits, dim=1)

In [5]:
def train(config, model, device, train_loader, optimizer, epoch):
    """Run one training pass of `model` over `train_loader`.

    Args:
        config: Object exposing `LOG_INTERVAL` (number of batches between
            progress lines) and `DRY_RUN` (stop after the first logged batch).
        model: Module whose output is log-probabilities, as expected by
            `F.nll_loss`.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable of `(data, target)` batches that supports
            `len()` and exposes a sized `.dataset` (used in the progress line).
        optimizer: Optimizer constructed over `model`'s parameters.
        epoch: Epoch number; only used in the progress line.

    Returns:
        None. `model` is updated in place and left in training mode.
    """
    model.train()

    # Guard against a zero interval, which would make the modulo below fail.
    log_every = max(1, config.LOG_INTERVAL)
    num_samples = len(train_loader.dataset)
    num_batches = len(train_loader)

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        loss = F.nll_loss(model(data), target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_every == 0:
            seen = batch_idx * len(data)
            percent = 100.0 * batch_idx / num_batches
            print(
                f"Train Epoch: {epoch} [{seen}/{num_samples} ({percent:.0f}%)]"
                f"\tLoss: {loss.item():.6f}"
            )
            # A dry run only checks that one optimisation step works end to end.
            if config.DRY_RUN:
                break

In [6]:
def test(model, device, test_loader):
    pass

In [7]:
# Seed torch's default random number generator with the configured seed.
torch.manual_seed(config.SEED)

<torch._C.Generator at 0x19745d863d0>

In [8]:
# Keyword arguments that are later unpacked into the two DataLoader constructors.
train_kwargs = dict(batch_size=config.BATCH_SIZE)
test_kwargs = dict(batch_size=config.TEST_BATCH_SIZE)

In [9]:
# Extra loader options that are only applied when running on a CUDA device.
if config.DEVICE == "cuda":
    cuda_kwargs = {"num_workers": 1, "pin_memory": True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

# Shuffle the training data on every device. Previously "shuffle" lived in
# cuda_kwargs, so CPU runs trained on unshuffled data and the test loader was
# shuffled for no benefit.
train_kwargs["shuffle"] = True

In [10]:
# This cell rebinds the name `transforms` from the torchvision module to the
# composed pipeline (the next cell passes it as `transform=`). Building the
# pipeline through a private module alias keeps the cell re-runnable: the
# original `transforms.Compose(...)` raised AttributeError on a second run
# because `transforms` was no longer the module.
from torchvision import transforms as _tv_transforms

transforms = _tv_transforms.Compose(
    [_tv_transforms.ToTensor()]
)

In [11]:
# The datasets get their own names so they do not overwrite the train() and
# test() functions defined earlier in the notebook.
train_dataset = datasets.MNIST("../data", train=True, download=True, transform=transforms)
test_dataset = datasets.MNIST("../data", train=False, transform=transforms)

train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ../data\MNIST\raw\train-images-idx3-ubyte.gz to ../data\MNIST\raw


102.8%


Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data\MNIST\raw\train-labels-idx1-ubyte.gz



1.4%

Extracting ../data\MNIST\raw\train-labels-idx1-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ../data\MNIST\raw\t10k-images-idx3-ubyte.gz to ../data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


22.5%

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz


112.7%

Extracting ../data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../data\MNIST\raw




