In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
# Prefer the first GPU when one is available. Always build a torch.device (the CPU
# fallback used to be the bare string 'cpu'), so `device` has the same type everywhere.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [3]:
class LeNet(nn.Module):
    """LeNet-5-style classifier with sigmoid activations and average pooling.

    Shape arithmetic for a 1x28x28 input:
    conv1 (pad 2) keeps 28x28 -> pool 14x14 -> conv2 10x10 -> pool 5x5,
    giving 16 * 5 * 5 = 400 features for fc1. Any other spatial size will
    not match fc1's 400 input features.

    Returns raw (unnormalised) logits over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in the original order so parameter order,
        # state_dict keys and the printed repr are unchanged.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
        # Note: despite the "max" name this is average pooling; the attribute
        # name is kept for state_dict / checkpoint compatibility.
        self.max1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
        self.max2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch of images (N, 1, 28, 28) to class logits (N, 10)."""
        features = self.max1(torch.sigmoid(self.conv1(x)))
        features = self.max2(torch.sigmoid(self.conv2(features)))

        flat = features.view(features.size(0), -1)  # (N, 400)
        hidden = torch.sigmoid(self.fc2(torch.sigmoid(self.fc1(flat))))
        return self.out(hidden)



In [4]:
# Seed torch's global RNG so weight initialisation is reproducible across runs.
# DataLoader shuffling (unless given its own generator) also draws from this RNG.
SEED = 0
torch.manual_seed(SEED)

model = LeNet().to(device)
model

LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (max1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (max2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (out): Linear(in_features=84, out_features=10, bias=True)
)

In [5]:
# class LeNet(nn.Module):

#     def __init__(self):
#         super().__init__()
#         self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2, stride=1)
#         self.max1 = nn.AvgPool2d(kernel_size=2, stride=2)
#         self.conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=0, stride=1)
#         self.max2 = nn.AvgPool2d(kernel_size=2, stride=2)
#         self.fc1 = nn.Linear(400, 120)
#         self.fc2 = nn.Linear(120, 84)
#         self.out = nn.Linear(84, 10)

#     def forward(self, x):
#         x = nn.Sigmoid()(self.conv1(x))
#         x = self.max1(x)
#         x = nn.Sigmoid()(self.conv2(x))
#         x = self.max2(x)
#         # x = torch.view(-1, 1)
#         x = nn.Flatten()(x)
#         x = nn.Sigmoid()(self.fc1(x)) 
#         x = nn.Sigmoid()(self.fc2(x))
#         x = self.out(x)

#         return x


In [6]:
# MNIST training split (train=True) and held-out test split (train=False), downloaded
# into the current directory if missing and converted to tensors with ToTensor.
train_data, valid_data = (
    datasets.MNIST(root='.', train=is_train, download=True, transform=transforms.ToTensor())
    for is_train in (True, False)
)

In [7]:
# Mini-batch loaders. Only the training set is shuffled: accuracy on the validation
# set does not depend on the order it is visited, so shuffling it is pointless.
BATCH_SIZE = 32
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# Adam over all of the model's parameters, learning rate 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [9]:
# Cross-entropy on raw logits (LeNet.forward applies no softmax), averaged over the batch.
criterion = nn.CrossEntropyLoss(reduction='mean')

In [10]:
# Number of passes over the training set, plus the split sizes used to turn
# correct-prediction counts into accuracies.
epochs = 25
num_train_data, num_valid_data = len(train_data), len(valid_data)

In [11]:
# Train for `epochs` epochs, printing train/validation accuracy after each one.
# A checkpoint is written after every CHECKPOINT_EVERY completed epochs and after
# the final epoch. It is written AFTER training, so it holds that epoch's weights.
# Uses model, optimizer, criterion, device, the loaders and dataset sizes from earlier cells.
CHECKPOINT_PATH = 'chechpoint.pth'  # NOTE(review): misspelled name kept so existing loaders still find the file
CHECKPOINT_EVERY = 2

for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    correct_train = 0

    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        yhat = model(x)
        # Accuracy is measured on the logits computed before this batch's update.
        # Kept as a device tensor to avoid a host sync on every batch.
        correct_train += (yhat.argmax(dim=1) == y).sum()

        loss = criterion(yhat, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # int() works for both a 0-dim tensor and the initial Python 0 (empty loader).
    train_acc = int(correct_train) / num_train_data
    print(f'train_acc epoch: {epoch}: {train_acc}')

    # ---- validation pass: inference only, so no autograd graph is built ----
    model.eval()
    correct_valid = 0

    with torch.no_grad():
        for x_val, y_val in valid_loader:
            x_val = x_val.to(device)
            y_val = y_val.to(device)

            yhat_val = model(x_val)
            correct_valid += (yhat_val.argmax(dim=1) == y_val).sum()

    valid_acc = int(correct_valid) / num_valid_data
    print(f'valid_acc epoch: {epoch}: {valid_acc}')

    # ---- checkpoint ----
    # 'epoch' stores the number of completed epochs, i.e. the epoch index to resume from
    # (the same meaning the value had when the snapshot was taken at the start of an epoch).
    completed_epochs = epoch + 1
    if completed_epochs % CHECKPOINT_EVERY == 0 or completed_epochs == epochs:
        torch.save(
            {
                'model_state': model.state_dict(),
                'optim_state': optimizer.state_dict(),
                'epoch': completed_epochs,
            },
            CHECKPOINT_PATH,
        )


train_acc epoch: 0: 0.7311333417892456
valid_acc epoch: 0: 0.9228999614715576
train_acc epoch: 1: 0.9394500255584717
valid_acc epoch: 1: 0.953499972820282
train_acc epoch: 2: 0.9601666927337646
valid_acc epoch: 2: 0.9705999493598938
train_acc epoch: 3: 0.9714833498001099
valid_acc epoch: 3: 0.9713999629020691
train_acc epoch: 4: 0.9763000011444092
valid_acc epoch: 4: 0.979699969291687
train_acc epoch: 5: 0.9798166751861572
valid_acc epoch: 5: 0.9812999963760376
train_acc epoch: 6: 0.9828166961669922
valid_acc epoch: 6: 0.9840999841690063
train_acc epoch: 7: 0.9851666688919067
valid_acc epoch: 7: 0.983299970626831
train_acc epoch: 8: 0.9866333603858948
valid_acc epoch: 8: 0.9846999645233154
train_acc epoch: 9: 0.9881666898727417
valid_acc epoch: 9: 0.9835000038146973
train_acc epoch: 10: 0.9891499876976013
valid_acc epoch: 10: 0.9854999780654907
train_acc epoch: 11: 0.9901000261306763
valid_acc epoch: 11: 0.9850999712944031
train_acc epoch: 12: 0.990766704082489
valid_acc epoch: 12: 0.9