In [1]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [2]:
# Sanity check: move a small random tensor onto the GPU (raises if CUDA is unavailable).
torch.randn(5).to('cuda')

tensor([-0.0968,  2.1419, -0.9588,  1.0083,  0.0039], device='cuda:0')

In [3]:
# Baseline MLP: flattened 28x28 image -> 4096 hidden units (ReLU) -> 10 logits.
# NOTE: a later cell reassigns `model`, so this Sequential is not the one trained.
model = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=4096),
    nn.ReLU(),
    nn.Linear(in_features=4096, out_features=10),
)

In [44]:
# A more flexible model: subclass nn.Module so layers and forward() can be customized

In [40]:
class ResNet(nn.Module):
    """Two-layer MLP classifier: Linear -> ReLU -> Dropout -> Linear.

    Despite its name this is not a residual network (there are no skip
    connections); the name is kept so existing references keep working.

    Args:
        in_features: size of each (already flattened) input vector;
            defaults to 28*28 for MNIST images.
        hidden: width of the hidden layer.
        n_classes: number of output logits.
        p_drop: dropout probability applied to the hidden activations.
            The default of 0 makes the dropout layer a no-op, matching the
            original hard-coded configuration.
    """

    def __init__(self, in_features=28*28, hidden=4096, n_classes=10, p_drop=0):
        super().__init__()
        # Attribute names (l1, l4, do) and their registration order are kept
        # so state_dict keys stay compatible with previously saved weights.
        self.l1 = nn.Linear(in_features, hidden)
        self.l4 = nn.Linear(hidden, n_classes)
        self.do = nn.Dropout(p_drop)

    def forward(self, x):
        """Map inputs with last dimension `in_features` to unnormalized logits."""
        h1 = nn.functional.relu(self.l1(x))
        # dropout is only active in train mode (and only if p_drop > 0)
        logits = self.l4(self.do(h1))
        return logits

# Instantiate the MLP with its default sizes and move its parameters to the GPU.
model = ResNet().to('cuda')

In [41]:
# Plain SGD (no momentum) over every parameter of the current `model`.
params = model.parameters()
optimizer = optim.SGD(params=params, lr=0.01)

In [42]:
# Classification criterion on raw logits; averages the per-sample loss over the batch.
loss = nn.CrossEntropyLoss(reduction='mean')

In [43]:
# MNIST training images converted to tensors, split into 55k images for training
# and 5k held out for validation.
train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
# A seeded generator makes the train/val split identical on every re-run,
# so validation numbers are comparable across kernel restarts.
train, val = random_split(train_data, [55000, 5000], generator=torch.Generator().manual_seed(0))
# Shuffle the training set each epoch; without it SGD sees the exact same
# batch sequence every epoch. Validation order does not affect the metrics.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
val_loader = DataLoader(val, batch_size=32)

In [44]:
n_epochs = 80


def run_epoch(net, loader, criterion, optimizer=None, device='cuda'):
    """Run one pass over `loader`; train `net` when an optimizer is given.

    Args:
        net: model mapping flattened inputs to class logits.
        loader: iterable of (inputs, integer labels) batches.
        criterion: loss function; assumed to use 'mean' reduction (the
            default of nn.CrossEntropyLoss), so it is rescaled by batch size.
        optimizer: if not None, gradients are computed and a step is taken
            per batch (train mode); if None, runs in eval mode without grads.
        device: device the batches are moved to.

    Returns:
        (mean loss, accuracy), both averaged per sample rather than per batch
        so that a smaller final batch is not over-weighted. Returns
        (nan, nan) if the loader yields no samples.
    """
    training = optimizer is not None
    net.train(training)
    total_loss, total_correct, n_seen = 0.0, 0, 0
    # gradients are only tracked during the training pass
    with torch.set_grad_enabled(training):
        for x, y in loader:
            b = x.size(0)
            # flatten each image into a vector for the Linear layers
            x = x.view(b, -1).to(device)
            y = y.to(device)

            logits = net(x)
            J = criterion(logits, y)

            if training:
                optimizer.zero_grad()
                J.backward()
                optimizer.step()

            total_loss += J.item() * b
            total_correct += (logits.argmax(dim=1) == y).sum().item()
            n_seen += b
    if n_seen == 0:
        return float('nan'), float('nan')
    return total_loss / n_seen, total_correct / n_seen


for epoch in range(n_epochs):
    train_loss, train_acc = run_epoch(model, train_loader, loss, optimizer)
    print(f'Epoch {epoch + 1}', end = ',')
    print(f'train loss:{train_loss:.2f}', end = ',')
    print(f'train accuracy:{train_acc:.2f}')

    val_loss, val_acc = run_epoch(model, val_loader, loss)
    print(f'Epoch {epoch + 1}',end= ',')
    print(f'validation loss:{val_loss:.2f}',end = ',')
    print(f'validation accuracy: {val_acc:.2f}')

Epoch 1,train loss:0.68,train accuracy:0.85
Epoch 1,validation loss:0.38,validation accuracy: 0.90
Epoch 2,train loss:0.34,train accuracy:0.91
Epoch 2,validation loss:0.31,validation accuracy: 0.91
Epoch 3,train loss:0.29,train accuracy:0.92
Epoch 3,validation loss:0.27,validation accuracy: 0.92
Epoch 4,train loss:0.26,train accuracy:0.93
Epoch 4,validation loss:0.25,validation accuracy: 0.93
Epoch 5,train loss:0.24,train accuracy:0.93
Epoch 5,validation loss:0.23,validation accuracy: 0.94
Epoch 6,train loss:0.22,train accuracy:0.94
Epoch 6,validation loss:0.21,validation accuracy: 0.94
Epoch 7,train loss:0.20,train accuracy:0.94
Epoch 7,validation loss:0.20,validation accuracy: 0.94
Epoch 8,train loss:0.19,train accuracy:0.95
Epoch 8,validation loss:0.19,validation accuracy: 0.95
Epoch 9,train loss:0.18,train accuracy:0.95
Epoch 9,validation loss:0.18,validation accuracy: 0.95
Epoch 10,train loss:0.16,train accuracy:0.95
Epoch 10,validation loss:0.17,validation accuracy: 0.95
Epoch 11

In [45]:
# Held-out MNIST test split, iterated in a fixed order for evaluation.
test_data = datasets.MNIST(
    'data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
test_loader = DataLoader(test_data, batch_size=100, shuffle=False)

In [46]:
# Final evaluation on the test split. Loss and accuracy are averaged per
# sample (not per batch) so the result stays correct even when the last
# batch is smaller than the loader's batch size.
model.eval()
total_loss, total_correct, n_seen = 0.0, 0, 0
with torch.no_grad():
    for x, y in test_loader:
        b = x.size(0)
        # flatten each image into a vector for the Linear layers
        x = x.view(b, -1).cuda()
        y = y.cuda()

        logits = model(x)
        # `loss` is nn.CrossEntropyLoss() with default mean reduction, so
        # rescale by batch size before accumulating
        total_loss += loss(logits, y).item() * b
        total_correct += (logits.argmax(dim=1) == y).sum().item()
        n_seen += b

assert n_seen > 0, 'test_loader yielded no samples'
print(f'test loss:{total_loss / n_seen:.2f}',end = ',')
print(f'test accuracy: {total_correct / n_seen:.2f}')

test loss:0.06,test accuracy: 0.98
