In [1]:
from datasets import load_dataset

ds = load_dataset("ylecun/mnist")

In [2]:
ds = ds.with_format("torch")

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import time

#from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
torch.set_default_device(device)

cuda


In [5]:
data_train = torch.utils.data.DataLoader(ds['train'], batch_size=32, shuffle = False)

In [43]:
data_dev = torch.utils.data.DataLoader(ds['test'], batch_size=32, shuffle = False)

In [36]:
x_data_dev = ds['test']['image']
x_data_dev = x_data_dev.float()
y_data_dev = ds['test']['label']

In [30]:
class canNet(nn.Module):
    def __init__(self):
        super(canNet, self).__init__()

        self.pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.drop = nn.Dropout(p = 0.5)

        self.act = nn.LeakyReLU()
        
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels = 32, kernel_size = 3, stride = 1, padding = 1, bias = True)
        self.conv2 = nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1, padding = 1, bias = True)
        self.conv3 = nn.Conv2d(in_channels = 64, out_channels = 128, kernel_size = 3, stride = 1, padding = 1, bias = True)
        #self.conv4 = nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 1, padding = 1, bias = True)
        #self.conv5 = nn.Conv2d(in_channels = 16, out_channels = 16, kernel_size = 2, stride = 1, padding = 1, bias = True)

        self.linear1 = nn.Linear(in_features = 1152, out_features = 512, bias = True)
        self.linear2 = nn.Linear(in_features = 512, out_features = 256, bias = True)
        self.linear3 = nn.Linear(in_features = 256, out_features = 10, bias = True)

        self.batch1 = nn.BatchNorm2d(32)
        self.batch2 = nn.BatchNorm2d(64)
        self.batch3 = nn.BatchNorm2d(128)

    def forward(self, x):
        out = self.conv1(x)
        out = self.batch1(out)
        out = self.act(out)
        out = self.pool(out)
        
        out = self.conv2(out)
        out = self.batch2(out)
        out = self.act(out)
        out = self.pool(out)

        out = self.conv3(out)
        out = self.batch3(out)
        out = self.act(out)
        out = self.pool(out)

        #out = self.conv4(out)
        #out = self.batch3(out)
        #out = self.act(out)
        #out = self.pool(out)

        #print(out.shape)

        out = out.view(-1, 1152)
        
        out = self.linear1(out)
        out = self.act(out)

        out = self.linear2(out)
        out = self.act(out)

        out = self.linear3(out)
        out = F.softmax(out, dim = -1)

        return out

In [31]:
model = canNet()

cross = F.cross_entropy

optimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay=1e-5)
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

In [61]:
def eval_model():
    model.eval() 
    with torch.no_grad():
        results = []
        true_labels = []
        
        for element in data_dev:
            image = element['image']
            image = image.float()
            labels = element['label']
            
            outputs = model(image)
            _, predicted = torch.max(outputs, 1)
            
            results.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
    
    report = classification_report(true_labels, results, zero_division=0, digits= 4)
    print(report)


In [58]:
for epoch in range(10):
    start_time = time.time()
    for element in data_train:
        image = element['image']
        image = image.float()
        label = element['label']
        #print(image.shape)
        #FORWARD 
        outputs = model(image)
        loss = cross(outputs, label)

        #BACKWARD AND OPTIM
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    scheduler.step()

    print(f'Epoch número {epoch}')
    eval_model()
    end_time = time.time()
    epoch_time = end_time - start_time

    print(f'Epoch número {epoch}, ha tardat {epoch_time:.2f}')
    

Epoch número 0
              precision    recall  f1-score   support

           0     0.9894    0.9951    0.9923      5923
           1     0.9939    0.9938    0.9938      6742
           2     0.9848    0.9916    0.9882      5958
           3     0.9886    0.9931    0.9909      6131
           4     0.9924    0.9873    0.9899      5842
           5     0.9931    0.9865    0.9898      5421
           6     0.9922    0.9910    0.9916      5918
           7     0.9874    0.9904    0.9889      6265
           8     0.9850    0.9884    0.9867      5851
           9     0.9898    0.9785    0.9841      5949

    accuracy                         0.9897     60000
   macro avg     0.9897    0.9896    0.9896     60000
weighted avg     0.9897    0.9897    0.9897     60000

Epoch número 0, ha tardat 21.75
Epoch número 1
              precision    recall  f1-score   support

           0     0.9894    0.9951    0.9923      5923
           1     0.9939    0.9938    0.9938      6742
           2    

In [60]:
torch.save(model, './model/model.9897.pth')

In [64]:
!ls model -lh

total 3,2M
-rw-r--r-- 1 mhurben guest 3,2M jul 25 21:39 model.9897.pth
