In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from tqdm import tqdm
import convolution_kan
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score


In [2]:
def data_to_binary(mnist_data, binary=False):
    """
    Just keep the 0 and 1 classes
    """
    if binary:
        mnist_data.data = mnist_data.data[(mnist_data.targets == 0) | (mnist_data.targets == 1)]
        mnist_data.targets = mnist_data.targets[(mnist_data.targets == 0) | (mnist_data.targets == 1)]
    return mnist_data



# Transformaciones
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Cargar MNIST y filtrar por dos clases
all_mnist_train = MNIST(root='./data', train=True, download=True, transform=transform)
# mnist_train = data_to_binary(all_mnist_train, binary=True)
mnist_train = data_to_binary(all_mnist_train, binary=False)

all_mnist_test = MNIST(root='./data', train=False, download=True, transform=transform)
# mnist_test = data_to_binary(all_mnist_test, binary=True)
mnist_test = data_to_binary(all_mnist_test, binary=False)

# DataLoader
train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=64, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:04<00:00, 2341950.41it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 164768.32it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1125673.86it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1088352.88it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [10]:
def train(model, device, train_loader, optimizer, epoch, criterion):
    # Set the model to training mode
    model.to(device)
    model.train()
    train_loss = 0
    print("Epoch:", epoch)
    # Process the images in batches
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        # Use the CPU or GPU as appropriate
        # Recall that GPU is optimized for the operations we are dealing with
        data, target = data.to(device), target.to(device)
        
        # Reset the optimizer
        optimizer.zero_grad()
        
        # Push the data forward through the model layers
        output = model(data)
        
        # Get the loss
        loss = criterion(output, target)

        # Keep a running total
        train_loss += loss.item()
        
        # Backpropagate
        loss.backward()
        optimizer.step()
        
        # Print metrics so we see some progress
        # print('\tTraining batch {} Loss: {:.6f}'.format(batch_idx + 1, loss.item()))
            
    # return average loss for the epoch
    avg_loss = train_loss / (batch_idx+1)
    print('Training set: Average loss: {:.6f}'.format(avg_loss))
    return avg_loss

def test(model, device, test_loader, criterion):
    # Switch the model to evaluation mode
    model.eval()
    test_loss = 0
    correct = 0
    all_targets = []
    all_predictions = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            
            # Get the predicted classes for this batch
            output = model(data)
            
            # Calculate the loss for this batch
            test_loss += criterion(output, target).item()
            
            # Calculate the accuracy for this batch
            _, predicted = torch.max(output.data, 1)
            correct += (target == predicted).sum().item()

            # Collect all targets and predictions for metric calculations
            all_targets.extend(target.view_as(predicted).cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    # Calculate overall metrics
    precision = precision_score(all_targets, all_predictions, average='macro')
    recall = recall_score(all_targets, all_predictions, average='macro')
    f1 = f1_score(all_targets, all_predictions, average='macro')

    # Normalize test loss
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%), Precision: {:.2f}, Recall: {:.2f}, F1 Score: {:.2f}\n'.format(
        test_loss, correct, len(test_loader.dataset), accuracy, precision, recall, f1))

    return test_loss, accuracy, precision, recall, f1

In [11]:
model = convolution_kan.CNN_KAN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use an "Adam" optimizer to adjust weights
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Define loss
criterion = nn.CrossEntropyLoss()

# Track metrics in these arrays
epoch_nums = []
training_loss = []
validation_loss = []

# Train over 10 epochs (We restrict to 10 for time issues)
epochs = 10
print('Training on', device)
for epoch in range(1, epochs + 1):
        train_loss = train(model, device, train_loader, optimizer, epoch, criterion)
        test_loss, accuracy, precision, recall, f1 = test(model, device, test_loader, criterion)
        epoch_nums.append(epoch)
        training_loss.append(train_loss)
        validation_loss.append(test_loss)
        scheduler.step()
        print('')
        print("lr: ", optimizer.param_groups[0]['lr'])
        print("test loss: ", test_loss)
        print("accuracy: ", accuracy)
        print("precision: ", precision)
        print("recall: ", recall)
        print("f1: ", f1)
        print('')


Training on cuda
Epoch: 1


100%|██████████| 938/938 [00:21<00:00, 43.45it/s]


Training set: Average loss: 0.189711

Test set: Average loss: 0.0009, Accuracy: 9828/10000 (98%), Precision: 0.98, Recall: 0.98, F1 Score: 0.98

0.0008

test loss:  0.0008964011472126003
accuracy:  98.28
precision:  0.9829414455822884
recall:  0.9828111975147109
f1:  0.9828097059883113

Epoch: 2


100%|██████████| 938/938 [00:21<00:00, 43.72it/s]


Training set: Average loss: 0.052503

Test set: Average loss: 0.0008, Accuracy: 9833/10000 (98%), Precision: 0.98, Recall: 0.98, F1 Score: 0.98

0.00064

test loss:  0.0007634898854827043
accuracy:  98.33
precision:  0.983500784038584
recall:  0.9831711178594276
f1:  0.983259410196518

Epoch: 3


100%|██████████| 938/938 [00:21<00:00, 42.78it/s]


Training set: Average loss: 0.036364

Test set: Average loss: 0.0006, Accuracy: 9863/10000 (99%), Precision: 0.99, Recall: 0.99, F1 Score: 0.99

0.0005120000000000001

test loss:  0.0006414072981890058
accuracy:  98.63
precision:  0.9864387269455748
recall:  0.9860939503410225
f1:  0.9862037974155525

Epoch: 4


100%|██████████| 938/938 [00:21<00:00, 44.16it/s]


Training set: Average loss: 0.028853

Test set: Average loss: 0.0006, Accuracy: 9864/10000 (99%), Precision: 0.99, Recall: 0.99, F1 Score: 0.99

0.0004096000000000001

test loss:  0.0005961154789940338
accuracy:  98.64
precision:  0.9865372288985071
recall:  0.9863145687225611
f1:  0.9863758921560375

Epoch: 5


100%|██████████| 938/938 [00:20<00:00, 46.83it/s]


Training set: Average loss: 0.022136

Test set: Average loss: 0.0005, Accuracy: 9885/10000 (99%), Precision: 0.99, Recall: 0.99, F1 Score: 0.99

0.0003276800000000001

test loss:  0.000535092187745613
accuracy:  98.85
precision:  0.9885502918566502
recall:  0.9883177898636661
f1:  0.9884085188409906

Epoch: 6


100%|██████████| 938/938 [00:20<00:00, 46.83it/s]


Training set: Average loss: 0.017281

Test set: Average loss: 0.0006, Accuracy: 9882/10000 (99%), Precision: 0.99, Recall: 0.99, F1 Score: 0.99

0.0002621440000000001

test loss:  0.0005533403412839107
accuracy:  98.82
precision:  0.9882341866156341
recall:  0.9879748901358913
f1:  0.9880831568938216

Epoch: 7


100%|██████████| 938/938 [00:21<00:00, 44.04it/s]


Training set: Average loss: 0.014065

Test set: Average loss: 0.0006, Accuracy: 9875/10000 (99%), Precision: 0.99, Recall: 0.99, F1 Score: 0.99

0.00020971520000000012

test loss:  0.0005969544560754002
accuracy:  98.75
precision:  0.9875609517632575
recall:  0.9873672779812905
f1:  0.9873923674249194

Epoch: 8


100%|██████████| 938/938 [00:21<00:00, 43.40it/s]


Training set: Average loss: 0.011720

Test set: Average loss: 0.0005, Accuracy: 9891/10000 (99%), Precision: 0.99, Recall: 0.99, F1 Score: 0.99

0.0001677721600000001

test loss:  0.0004944570142113661
accuracy:  98.91
precision:  0.9890585239513726
recall:  0.9889400124572802
f1:  0.9889889771953084

Epoch: 9


100%|██████████| 938/938 [00:21<00:00, 43.11it/s]


Training set: Average loss: 0.009827

Test set: Average loss: 0.0005, Accuracy: 9884/10000 (99%), Precision: 0.99, Recall: 0.99, F1 Score: 0.99

0.00013421772800000008

test loss:  0.000518110146923209
accuracy:  98.84
precision:  0.9884108243235536
recall:  0.9882427870655677
f1:  0.9883082111400767

Epoch: 10


100%|██████████| 938/938 [00:20<00:00, 45.10it/s]


Training set: Average loss: 0.008326

Test set: Average loss: 0.0005, Accuracy: 9895/10000 (99%), Precision: 0.99, Recall: 0.99, F1 Score: 0.99

0.00010737418240000007

test loss:  0.0004997915303645641
accuracy:  98.95
precision:  0.9895057764742223
recall:  0.9893265931720059
f1:  0.9894043251317456

