Pytorch CNN

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CustomCNN(nn.Module):
    def __init__(self):
        super(CustomCNN, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(128 * 4 * 4, 100)  # Ajusta según el tamaño de entrada
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(100, 2)  # 2 clases: ave y no ave

    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3(x))
        x = self.pool3(x)

        x = F.relu(self.conv4(x))
        x = self.pool4(x)

        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

In [2]:
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Transformaciones: ajustar tamaño y normalizar
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.3),
        transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
val_dataset = ImageFolder(root="dataset_split/val", transform=val_transform)


train_dataset = ImageFolder(root="dataset_split/train", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
net = CustomCNN().to(device)



In [11]:
num_epochs = 30

def train(num_epochs,model):
    criterion = nn.CrossEntropyLoss()  # Para clasificación multiclase
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )
    for epoch in range(num_epochs):
    # --- Entrenamiento ---
        model.train()
        running_loss = 0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        train_acc = 100 * correct / total

        # --- Validación ---
        model.eval()
        val_correct = 0
        val_total = 0           
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
        val_acc = 100 * val_correct / val_total

        print(
            f"Epoch {epoch+1}/{num_epochs} - "
            f"Loss: {running_loss/len(train_loader):.4f} - "
            f"Train Acc: {train_acc:.2f}% - "
            f"Val Acc: {val_acc:.2f}%"
        )

        scheduler.step(val_acc) 



In [12]:
import os
from pathlib import Path
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print("Size(KB):", os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

MODEL_FILENAME = 'tinyCNN.pt'

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load (MODEL_FILENAME ) )
    print( ' Loaded model from disk')
else:
    train(num_epochs, net)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)


Epoch 1/30 - Loss: 0.6905 - Train Acc: 52.88% - Val Acc: 71.64%
Epoch 2/30 - Loss: 0.6590 - Train Acc: 66.25% - Val Acc: 69.15%
Epoch 3/30 - Loss: 0.6516 - Train Acc: 61.62% - Val Acc: 68.66%
Epoch 4/30 - Loss: 0.5928 - Train Acc: 68.50% - Val Acc: 76.12%
Epoch 5/30 - Loss: 0.5627 - Train Acc: 73.00% - Val Acc: 80.10%
Epoch 6/30 - Loss: 0.5681 - Train Acc: 71.62% - Val Acc: 70.65%
Epoch 7/30 - Loss: 0.5450 - Train Acc: 74.00% - Val Acc: 79.10%
Epoch 8/30 - Loss: 0.5177 - Train Acc: 75.50% - Val Acc: 79.60%
Epoch 9/30 - Loss: 0.5046 - Train Acc: 75.25% - Val Acc: 75.62%
Epoch 10/30 - Loss: 0.4979 - Train Acc: 75.88% - Val Acc: 79.10%
Epoch 11/30 - Loss: 0.4950 - Train Acc: 76.25% - Val Acc: 81.09%
Epoch 12/30 - Loss: 0.4736 - Train Acc: 78.50% - Val Acc: 78.11%
Epoch 13/30 - Loss: 0.4871 - Train Acc: 77.12% - Val Acc: 79.60%
Epoch 14/30 - Loss: 0.4614 - Train Acc: 79.25% - Val Acc: 80.10%
Epoch 15/30 - Loss: 0.4667 - Train Acc: 79.00% - Val Acc: 81.09%
Epoch 16/30 - Loss: 0.4586 - Train

In [16]:
print('Weight for quantization')
print(net.conv1.weight)
print(net.conv1.weight.dtype)

Weight for quantization
Parameter containing:
tensor([[[[-1.6515e-01,  2.1807e-01, -1.7042e-02],
          [-1.8199e-02,  8.6198e-02,  1.9746e-03],
          [-7.7824e-02,  2.1910e-02, -2.6301e-02]],

         [[-4.4202e-02, -5.2773e-02, -1.6734e-01],
          [-2.2692e-02, -1.7671e-01, -1.1619e-01],
          [ 1.0103e-01,  5.2116e-02, -8.0960e-02]],

         [[-1.6551e-01,  1.9344e-01, -2.7482e-02],
          [-1.7917e-01,  1.7208e-01, -1.5584e-01],
          [-8.8777e-02,  9.3557e-02,  4.1366e-02]]],


        [[[-1.3175e-01,  5.3676e-03,  1.1814e-01],
          [ 1.2859e-01, -1.3982e-01,  3.0599e-02],
          [ 7.9608e-02, -2.1340e-03,  1.5894e-01]],

         [[-1.5243e-01, -8.2030e-03,  1.2523e-01],
          [ 7.2440e-02, -4.3358e-02,  8.2299e-02],
          [ 2.7266e-02,  1.0743e-01, -1.3667e-01]],

         [[ 8.1704e-02,  1.4884e-02,  1.0262e-01],
          [ 4.3082e-02, -6.7385e-02,  1.2704e-01],
          [ 2.4735e-02, -4.9994e-02,  1.5812e-02]]],


        [[[ 2.5023e-

In [53]:
print_size_of_model(net)

Size(KB): 1788.469


In [90]:
class Quant_simple_CustomCNN(nn.Module):
    def __init__(self):
        super(Quant_simple_CustomCNN, self).__init__()

        self.quant = torch.quantization.QuantStub()
        
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(128 * 4 * 4, 100)  # Ajusta según el tamaño de entrada
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(100, 2)  # 2 clases: ave y no ave

        self.dequant = torch.quantization.DeQuantStub()
        
    def forward(self, x):

        x = self.quant(x)

        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3(x))
        x = self.pool3(x)

        x = F.relu(self.conv4(x))
        x = self.pool4(x)

        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        x = self.dequant(x)
    
        return x

In [91]:
net_quant = Quant_simple_CustomCNN().to(device)

net_quant.load_state_dict(net.state_dict())
net_quant.eval()

net_quant.qconfig = torch.ao.quantization.default_qconfig
net_quant = torch.ao.quantization.prepare(net_quant) 
net_quant

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  net_quant = torch.ao.quantization.prepare(net_quant)


Quant_simple_CustomCNN(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (conv1): Conv2d(
    3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(
    32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(
    64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(
    128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): MinMaxObserver(min_val=inf, max

In [92]:
import os
from PIL import Image
import torch


def test(model, folder="Images_test"):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for class_name in os.listdir(folder):
            class_folder = os.path.join(folder, class_name)
            if not os.path.isdir(class_folder):
                continue

            for filename in os.listdir(class_folder):
                if filename.endswith((".jpg", ".png", ".jpeg")):
                    img_path = os.path.join(class_folder, filename)
                    img = Image.open(img_path)
                    img = transform(img).unsqueeze(0).to(device)

                    output = model(img)
                    pred_class = output.argmax(dim=1).item()

                    # Obtenemos el índice de la clase real según train_dataset.classes
                    true_class = train_dataset.classes.index(class_name)

                    if pred_class == true_class:
                        correct += 1
                    total += 1


    accuracy = 100 * correct / total if total > 0 else 0
    print(f"Precisión total: {accuracy:.2f}%")

In [93]:
test(net)

Precisión total: 78.26%


In [96]:
test(net_quant)

Precisión total: 78.26%


In [97]:
print(f'Check stadistics')
net_quant

Check stadistics


Quant_simple_CustomCNN(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-1.0, max_val=1.0)
  )
  (conv1): Conv2d(
    3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): MinMaxObserver(min_val=-1.575791597366333, max_val=1.4939677715301514)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(
    32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): MinMaxObserver(min_val=-1.544890284538269, max_val=1.4271245002746582)
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(
    64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (activation_post_process): MinMaxObserver(min_val=-0.953361988067627, max_val=1.1487722396850586)
  )
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(
    128, 128, kernel_size=(3, 3), strid

In [98]:
net_quant = torch.ao.quantization.convert(net_quant)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  net_quant = torch.ao.quantization.convert(net_quant)


In [99]:
print(f"Check Stadistics")
net_quant

Check Stadistics


Quant_simple_CustomCNN(
  (quant): Quantize(scale=tensor([0.0157]), zero_point=tensor([64]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.024171333760023117, zero_point=65, padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): QuantizedConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.023401692509651184, zero_point=66, padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.016552237793803215, zero_point=58, padding=(1, 1))
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.023344218730926514, zero_point=76, padding=(1, 1))
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): QuantizedLinear(

In [100]:
# Weights after quantization
print("Weihts quantization")
print(torch.int_repr(net_quant.conv1.weight()))

Weihts quantization
tensor([[[[ -81,  108,   -8],
          [  -9,   43,    1],
          [ -38,   11,  -13]],

         [[ -22,  -26,  -83],
          [ -11,  -87,  -57],
          [  50,   26,  -40]],

         [[ -82,   95,  -14],
          [ -88,   85,  -77],
          [ -44,   46,   20]]],


        [[[ -65,    3,   58],
          [  63,  -69,   15],
          [  39,   -1,   78]],

         [[ -75,   -4,   62],
          [  36,  -21,   41],
          [  13,   53,  -67]],

         [[  40,    7,   51],
          [  21,  -33,   63],
          [  12,  -25,    8]]],


        [[[  12,   76,   98],
          [ -61,  -52,  -57],
          [  -2,    9,  -31]],

         [[ -68,  -64,   70],
          [  26,  -14,    9],
          [  39,   35,    2]],

         [[ -76,  -33,   25],
          [  40,    7,  -57],
          [  39,  -78,  -85]]],


        [[[ -44,   -9,  -24],
          [ -79,  -25,    0],
          [ -62,  -57,  -25]],

         [[ -89,   37,   37],
          [  62,   86,  

In [101]:
# Compare Weights
print('Original')
print(net.conv1.weight)
print('')
print(f'Dequantize')
print(torch.dequantize(net_quant.conv1.weight()))
print('')

Original
Parameter containing:
tensor([[[[-1.6515e-01,  2.1807e-01, -1.7042e-02],
          [-1.8199e-02,  8.6198e-02,  1.9746e-03],
          [-7.7824e-02,  2.1910e-02, -2.6301e-02]],

         [[-4.4202e-02, -5.2773e-02, -1.6734e-01],
          [-2.2692e-02, -1.7671e-01, -1.1619e-01],
          [ 1.0103e-01,  5.2116e-02, -8.0960e-02]],

         [[-1.6551e-01,  1.9344e-01, -2.7482e-02],
          [-1.7917e-01,  1.7208e-01, -1.5584e-01],
          [-8.8777e-02,  9.3557e-02,  4.1366e-02]]],


        [[[-1.3175e-01,  5.3676e-03,  1.1814e-01],
          [ 1.2859e-01, -1.3982e-01,  3.0599e-02],
          [ 7.9608e-02, -2.1340e-03,  1.5894e-01]],

         [[-1.5243e-01, -8.2030e-03,  1.2523e-01],
          [ 7.2440e-02, -4.3358e-02,  8.2299e-02],
          [ 2.7266e-02,  1.0743e-01, -1.3667e-01]],

         [[ 8.1704e-02,  1.4884e-02,  1.0262e-01],
          [ 4.3082e-02, -6.7385e-02,  1.2704e-01],
          [ 2.4735e-02, -4.9994e-02,  1.5812e-02]]],


        [[[ 2.5023e-02,  1.5390e-01

In [102]:
print("Size after quantization")
print_size_of_model(net_quant)

Size after quantization
Size(KB): 456.117


In [103]:
print('Testing after quantization')
test(net_quant)

Testing after quantization
Precisión total: 73.91%


Brevita

In [None]:
from torch import nn
import torch.nn.functional as F

import brevitas.nn as qnn
from brevitas import Int8WeightPerTensorFloat

class QuantTinyCNN(nn.Module):
    def __init__(self):
        super(QuantTinyCNN,self).__init__()
        quantinfo = {'weight_quant': Int8WeightPerTensorFloat, 'weight_bit_width': 8 }

        self.conv1 = qnn.QuantConv2d(3,32,kernel_size=3,padding=1,**quantinfo)
        self.relu1 = qnn.QuantReLU(bit_width = 8)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = qnn.QuantConv2d(32, 64, kernel_size=3, padding=1, **quantinfo)
        self.relu2 = qnn.QuantReLU(bit_width=8)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = qnn.QuantConv2d(64, 128, kernel_size=3, padding=1, **quantinfo)
        self.relu3 = qnn.QuantReLU(bit_width=8)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = qnn.QuantConv2d(128, 128, kernel_size=3, padding=1, **quantinfo)
        self.relu4 = qnn.QuantReLU(bit_width=8)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.fc1 = qnn.QuantLinear(128*4*4,100,**quantinfo)
        self.relu_fc1 = qnn.QuantReLU(bit_width = 8)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = qnn.QuantLinear(100,2,**quantinfo)
    
    def forward(self,x):
        
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.pool3(self.relu3(self.conv3(x)))
        x = self.pool4(self.relu4(self.conv4(x)))
        
        x = x.view(x.size(0),-1)
        x = self.relu_fc1(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x