In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Simple_CustomCNN(nn.Module):
    def __init__(self):
        super(Simple_CustomCNN, self).__init__()

        self.quant = torch.quantization.QuantStub()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(128 * 4 * 4, 100)  # Ajusta según el tamaño de entrada
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(100, 2)  # 2 clases: ave y no ave

        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):

        x = self.quant(x)

        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3(x))
        x = self.pool3(x)

        x = F.relu(self.conv4(x))
        x = self.pool4(x)

        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        x = self.dequant(x)

        return x
    
net = Simple_CustomCNN().to(device='cpu')

In [6]:
net.qconfig = torch.ao.quantization.default_qconfig
net.train()
net_quant = torch.ao.quantization.prepare_qat(net)
net_quant

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  net_quant = torch.ao.quantization.prepare_qat(net)


Simple_CustomCNN(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (conv1): Conv2d(
    3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(
    32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(
    64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (pool3): MaxPool2d(kernel_size=2, 

In [7]:
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Transformaciones: ajustar tamaño y normalizar
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.3, saturation=0.3),
        transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

val_transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
val_dataset = ImageFolder(root="dataset_split/val", transform=val_transform)


train_dataset = ImageFolder(root="dataset_split/train", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"

num_epochs = 30


def train(num_epochs, model):
    criterion = nn.CrossEntropyLoss()  # Para clasificación multiclase
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=3
    )
    for epoch in range(num_epochs):
        # --- Entrenamiento ---
        model.train()
        running_loss = 0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        train_acc = 100 * correct / total

        # --- Validación ---
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
        val_acc = 100 * val_correct / val_total

        print(
            f"Epoch {epoch+1}/{num_epochs} - "
            f"Loss: {running_loss/len(train_loader):.4f} - "
            f"Train Acc: {train_acc:.2f}% - "
            f"Val Acc: {val_acc:.2f}%"
        )

        scheduler.step(val_acc)

In [9]:
train(num_epochs,net_quant)

Epoch 1/30 - Loss: 0.6681 - Train Acc: 59.25% - Val Acc: 68.16%
Epoch 2/30 - Loss: 0.6655 - Train Acc: 60.50% - Val Acc: 71.14%
Epoch 3/30 - Loss: 0.6320 - Train Acc: 66.12% - Val Acc: 74.63%
Epoch 4/30 - Loss: 0.6028 - Train Acc: 67.00% - Val Acc: 76.62%
Epoch 5/30 - Loss: 0.5982 - Train Acc: 66.38% - Val Acc: 78.61%
Epoch 6/30 - Loss: 0.5747 - Train Acc: 70.12% - Val Acc: 81.59%
Epoch 7/30 - Loss: 0.5335 - Train Acc: 73.75% - Val Acc: 75.12%
Epoch 8/30 - Loss: 0.5374 - Train Acc: 74.75% - Val Acc: 76.12%
Epoch 9/30 - Loss: 0.5336 - Train Acc: 73.88% - Val Acc: 82.09%
Epoch 10/30 - Loss: 0.5051 - Train Acc: 76.88% - Val Acc: 80.10%
Epoch 11/30 - Loss: 0.4995 - Train Acc: 76.38% - Val Acc: 83.08%
Epoch 12/30 - Loss: 0.5122 - Train Acc: 76.12% - Val Acc: 83.08%
Epoch 13/30 - Loss: 0.4999 - Train Acc: 76.50% - Val Acc: 78.11%
Epoch 14/30 - Loss: 0.5128 - Train Acc: 74.62% - Val Acc: 78.11%
Epoch 15/30 - Loss: 0.4453 - Train Acc: 80.75% - Val Acc: 78.61%
Epoch 16/30 - Loss: 0.4386 - Train

In [10]:
import os
from PIL import Image
import torch


def test(model, folder="Images_test"):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for class_name in os.listdir(folder):
            class_folder = os.path.join(folder, class_name)
            if not os.path.isdir(class_folder):
                continue

            for filename in os.listdir(class_folder):
                if filename.endswith((".jpg", ".png", ".jpeg")):
                    img_path = os.path.join(class_folder, filename)
                    img = Image.open(img_path)
                    img = transform(img).unsqueeze(0).to(device)

                    output = model(img)
                    pred_class = output.argmax(dim=1).item()

                    # Obtenemos el índice de la clase real según train_dataset.classes
                    true_class = train_dataset.classes.index(class_name)

                    if pred_class == true_class:
                        correct += 1
                    total += 1

    accuracy = 100 * correct / total if total > 0 else 0
    print(f"Precisión total: {accuracy:.2f}%")

In [11]:
net_quant

Simple_CustomCNN(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-1.0, max_val=1.0)
  )
  (conv1): Conv2d(
    3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=-0.2487485557794571, max_val=0.2315288484096527)
    (activation_post_process): MinMaxObserver(min_val=-1.8321338891983032, max_val=1.8092478513717651)
  )
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(
    32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=-0.14295434951782227, max_val=0.14749263226985931)
    (activation_post_process): MinMaxObserver(min_val=-2.678623676300049, max_val=2.7000174522399902)
  )
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(
    64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (weight_fake_quant): MinMaxObserver(min_val=-0.16141840

In [12]:
net_quant.eval()
net_quant = torch.ao.quantization.convert(net_quant)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  net_quant = torch.ao.quantization.convert(net_quant)


In [13]:
print(f"Check statistics of the various layers")
net_quant

Check statistics of the various layers


Simple_CustomCNN(
  (quant): Quantize(scale=tensor([0.0157]), zero_point=tensor([64]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.028672296553850174, zero_point=64, padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): QuantizedConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.042351506650447845, zero_point=63, padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.07993029803037643, zero_point=60, padding=(1, 1))
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.15773199498653412, zero_point=78, padding=(1, 1))
  (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): QuantizedLinear(in_featu

In [14]:
# Print the weights matrix of the model before quantization
print("Weights before quantization")
print(torch.int_repr(net_quant.conv1.weight()))

Weights before quantization
tensor([[[[ -13,   53,  -56],
          [ -81,   17,   29],
          [ -64,   99,   50]],

         [[  44,   11,  -70],
          [   2,  -48,   14],
          [ -23,  -42,  -45]],

         [[  19,   78,    0],
          [  18,  -84,   31],
          [  52,  -61,   40]]],


        [[[  45,  -34,   38],
          [ -29,  -16,   22],
          [ -53,  -53,   31]],

         [[  55,  -78,   90],
          [ -84,  106,  -89],
          [  80,   14,  -46]],

         [[ -61,   26,  -19],
          [ -43,  -37,   65],
          [  24,   86,  -45]]],


        [[[  79,   60,  -36],
          [  75,  -21,   74],
          [  77,  -12,   44]],

         [[ -93,  -32,  -69],
          [  13,   -7,  -82],
          [ -81,  -93,  -93]],

         [[  30,   41,   -4],
          [  62,    1,  -62],
          [  73,   -7,   35]]],


        [[[  68,   24,   14],
          [  55,  -67,  -86],
          [ -48,  -86,   49]],

         [[  85,   93,   61],
          [ -54,

In [18]:
test(net_quant)

Precisión total: 82.61%
