In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNGs this notebook relies on. The torch RNG drives weight initialisation,
# DataLoader shuffling and torch.randn, so Restart & Run All gives repeatable runs.
# (Bit-exact GPU results would additionally need deterministic cuDNN settings.)
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Preprocessing pipeline: resize every image to 64x64, convert it to a float tensor
# scaled to [0, 1], then map each channel to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# Both splits share the same `transform` (resize to 64x64, normalise to [-1, 1]).
# The CNN's conv stack and classifier head are sized for 64x64 inputs, and the
# test images must be preprocessed exactly like the training images.
train_dataset = torchvision.datasets.CIFAR10(root='.', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='.', train=False, transform=transform, download=True)

# Short display names, in the same index order as train_dataset.classes.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
assert len(classes) == len(train_dataset.classes), "display names must cover every dataset class"

Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Peek at what the CIFAR10 dataset object stores: its attribute names, the class
# names, and the container types holding the raw images and the labels.
dataset_summary = (
    ("dataset keys:", train_dataset.__dict__.keys()),
    ("dataset classes:", train_dataset.classes),
    ("dataset data type:", type(train_dataset.data)),
    ("dataset target type:", type(train_dataset.targets)),
)
for caption, value in dataset_summary:
    print(caption, value)

dataset keys: dict_keys(['root', 'transform', 'target_transform', 'transforms', 'train', 'data', 'targets', 'classes', 'class_to_idx'])
dataset classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
dataset data type: <class 'numpy.ndarray'>
dataset target type: <class 'list'>


In [None]:
batch_size = 128

# One loader per split: training batches are reshuffled every epoch, evaluation
# order stays fixed. pin_memory speeds up host-to-GPU copies when CUDA is used.
# NOTE(review): the following cell reassigns `dataloader`, so these batch-size-128
# loaders are not the ones passed to `fit` at the end of the notebook.
dataloader = {
    split: torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=(split == 'train'), pin_memory=True)
    for split, dataset in (('train', train_dataset), ('test', test_dataset))
}

In [None]:
# CIFAR-10's own per-channel statistics. The 0.1307/0.3081 pair is MNIST's
# single-channel grayscale mean/std and does not describe RGB CIFAR images.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# One transform shared by both splits so train and test are preprocessed identically:
# resize to 64x64 (the input size the CNN's classifier head is sized for), convert to
# a tensor, then standardise each channel.
cifar_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((64, 64)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# NOTE(review): this cell replaces the `dataloader` dict built in the previous cell;
# these batch-size-64 loaders are the ones handed to `fit` later on.
dataloader = {
    'train': torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10('../data', train=True, download=True, transform=cifar_transform),
        batch_size=64, shuffle=True, pin_memory=True),
    'test': torch.utils.data.DataLoader(
        torchvision.datasets.CIFAR10('../data', train=False, download=True, transform=cifar_transform),
        batch_size=64, shuffle=False, pin_memory=True),
}

Files already downloaded and verified


In [None]:
def block(c_in, c_out, k=5, p=0, s=2, pk=2, ps=1):
    """Build one Conv2d -> ReLU -> MaxPool2d stage.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        k: convolution kernel size.
        p: convolution padding.
        s: convolution stride.
        pk: max-pool kernel size.
        ps: max-pool stride. With the default of 1 the pooling windows overlap and
            each spatial side shrinks by only pk - 1.

    Returns:
        torch.nn.Sequential applying the three layers in that order.
    """
    return torch.nn.Sequential(
        torch.nn.Conv2d(c_in, c_out, k, padding=p, stride=s),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(pk, stride=ps),
    )


def block2(c_in, c_out, k=5, p=3, s=2, pk=2, ps=1):
    """Same stage as `block`, but with padding 3 by default.

    In `CNN` this stage receives 3x3 feature maps, which are smaller than the 5x5
    kernel. Without the extra padding the convolution could not be applied.
    """
    return block(c_in, c_out, k=k, p=p, s=s, pk=pk, ps=ps)


class CNN(torch.nn.Module):
    """Four conv stages followed by a linear classifier, sized for 64x64 inputs.

    Spatial side length for a 64x64 input (after conv -> after pool):
        conv1 30 -> 29, conv2 13 -> 12, conv3 4 -> 3, conv4 3 -> 2,
    so the flattened feature vector has 64 * 2 * 2 = 256 entries. Other input
    sizes will not match `self.fc`.

    NOTE(review): the 784/392/196 channel widths make conv2 very expensive
    (~1.3 GMAC per 64x64 image). They look like leftovers from an MNIST setup
    (28 * 28 = 784); confirm they are intended.

    Args:
        n_channels: channels of the input image (3 for RGB).
        n_outputs: number of output logits (one per class).
    """

    def __init__(self, n_channels=3, n_outputs=10):
        super().__init__()
        self.conv1 = block(n_channels, 784)
        self.conv2 = block(784, 392)
        self.conv3 = block(392, 196)
        self.conv4 = block2(196, 64)
        self.fc = torch.nn.Linear(64 * 2 * 2, n_outputs)

    def forward(self, x):
        """Map a (batch, n_channels, 64, 64) tensor to (batch, n_outputs) logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C x H x W into features
        return self.fc(x)

In [None]:
# Smoke-test the architecture: a dummy batch of 64x64 RGB images must come out as
# one logit per CIFAR-10 class. No autograd graph is needed for a shape check,
# and two samples are enough to confirm the batch dimension is preserved.
model = CNN()
with torch.no_grad():
    output = model(torch.randn(2, 3, 64, 64))
assert output.shape == (2, 10), f"unexpected output shape {tuple(output.shape)}"

In [None]:
from tqdm import tqdm
import numpy as np

def _run_epoch(model, loader, criterion, device, optimizer=None, prefix="", progress=True):
    """Make one pass over `loader` and return sample-weighted (mean loss, accuracy).

    When `optimizer` is given, every batch also takes a gradient step. Otherwise the
    pass only evaluates; the caller is responsible for model.train()/model.eval() and
    torch.no_grad(). Averages are weighted by batch size so a short final batch does
    not skew the epoch figures. Returns (nan, nan) when `loader` yields nothing.
    """
    total_loss, total_correct, total_seen = 0.0, 0, 0
    batches = tqdm(loader) if progress else loader
    for X, y in batches:
        X, y = X.to(device), y.to(device)
        if optimizer is not None:
            optimizer.zero_grad()
        y_hat = model(X)
        loss = criterion(y_hat, y)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        batch_n = y.size(0)
        # criterion uses mean reduction, so multiply back to a per-batch sum.
        total_loss += loss.item() * batch_n
        total_correct += (torch.argmax(y_hat, dim=1) == y).sum().item()
        total_seen += batch_n
        if progress:
            batches.set_description(
                f"{prefix}loss {total_loss / total_seen:.5f} {prefix}acc {total_correct / total_seen:.5f}"
            )
    if total_seen == 0:
        return float("nan"), float("nan")
    return total_loss / total_seen, total_correct / total_seen


def fit(model, dataloader, epochs=5, lr=1e-5, device=None, progress=True):
    """Train `model` with Adam + cross-entropy, evaluating after every epoch.

    Args:
        model: torch.nn.Module that returns class logits.
        dataloader: dict with 'train' and 'test' iterables of (inputs, integer labels).
        epochs: number of passes over dataloader['train'].
        lr: Adam learning rate.
        device: where to run. Defaults to "cuda" when available, else "cpu".
        progress: show a tqdm bar with running metrics for each pass.

    Returns:
        dict of per-epoch lists under "loss", "acc", "val_loss" and "val_acc".
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    history = {"loss": [], "acc": [], "val_loss": [], "val_acc": []}
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss, train_acc = _run_epoch(
            model, dataloader['train'], criterion, device,
            optimizer=optimizer, prefix="", progress=progress)
        model.eval()
        with torch.no_grad():
            val_loss, val_acc = _run_epoch(
                model, dataloader['test'], criterion, device,
                optimizer=None, prefix="val_", progress=progress)
        history["loss"].append(train_loss)
        history["acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        print(f"Epoch {epoch}/{epochs} loss {train_loss:.5f} val_loss {val_loss:.5f} acc {train_acc:.5f} val_acc {val_acc:.5f}")
    return history

In [None]:
# Build a freshly initialised CNN and train it for 20 epochs on `dataloader`
# (whichever definition of it ran last above).
model = CNN()
fit(model, dataloader=dataloader, epochs=20)

loss 1.82693 acc 0.33082: 100%|██████████| 782/782 [02:58<00:00,  4.37it/s]
val_loss 1.58162 val_acc 0.42605: 100%|██████████| 157/157 [00:12<00:00, 12.77it/s]


Epoch 1/20 loss 1.82693 val_loss 1.58162 acc 0.33082 val_acc 0.42605


loss 1.50451 acc 0.45806: 100%|██████████| 782/782 [02:57<00:00,  4.41it/s]
val_loss 1.45643 val_acc 0.47572: 100%|██████████| 157/157 [00:12<00:00, 12.54it/s]


Epoch 2/20 loss 1.50451 val_loss 1.45643 acc 0.45806 val_acc 0.47572


loss 1.39083 acc 0.50416: 100%|██████████| 782/782 [02:57<00:00,  4.42it/s]
val_loss 1.38981 val_acc 0.50717: 100%|██████████| 157/157 [00:12<00:00, 12.66it/s]


Epoch 3/20 loss 1.39083 val_loss 1.38981 acc 0.50416 val_acc 0.50717


loss 1.31788 acc 0.53489: 100%|██████████| 782/782 [02:57<00:00,  4.41it/s]
val_loss 1.29482 val_acc 0.53782: 100%|██████████| 157/157 [00:12<00:00, 12.44it/s]


Epoch 4/20 loss 1.31788 val_loss 1.29482 acc 0.53489 val_acc 0.53782


loss 1.25673 acc 0.55752: 100%|██████████| 782/782 [02:58<00:00,  4.39it/s]
val_loss 1.23596 val_acc 0.56469: 100%|██████████| 157/157 [00:12<00:00, 12.27it/s]


Epoch 5/20 loss 1.25673 val_loss 1.23596 acc 0.55752 val_acc 0.56469


loss 1.20853 acc 0.57511: 100%|██████████| 782/782 [02:57<00:00,  4.40it/s]
val_loss 1.23360 val_acc 0.56141: 100%|██████████| 157/157 [00:12<00:00, 12.70it/s]


Epoch 6/20 loss 1.20853 val_loss 1.23360 acc 0.57511 val_acc 0.56141


loss 1.16404 acc 0.59433: 100%|██████████| 782/782 [02:56<00:00,  4.42it/s]
val_loss 1.16319 val_acc 0.59395: 100%|██████████| 157/157 [00:12<00:00, 12.70it/s]


Epoch 7/20 loss 1.16404 val_loss 1.16319 acc 0.59433 val_acc 0.59395


loss 1.12497 acc 0.60949: 100%|██████████| 782/782 [02:56<00:00,  4.43it/s]
val_loss 1.13866 val_acc 0.59604: 100%|██████████| 157/157 [00:12<00:00, 12.77it/s]


Epoch 8/20 loss 1.12497 val_loss 1.13866 acc 0.60949 val_acc 0.59604


loss 1.08917 acc 0.61980: 100%|██████████| 782/782 [02:56<00:00,  4.42it/s]
val_loss 1.11041 val_acc 0.61286: 100%|██████████| 157/157 [00:12<00:00, 12.74it/s]


Epoch 9/20 loss 1.08917 val_loss 1.11041 acc 0.61980 val_acc 0.61286


loss 1.05887 acc 0.63307: 100%|██████████| 782/782 [02:56<00:00,  4.42it/s]
val_loss 1.10367 val_acc 0.60798: 100%|██████████| 157/157 [00:12<00:00, 12.66it/s]


Epoch 10/20 loss 1.05887 val_loss 1.10367 acc 0.63307 val_acc 0.60798


loss 1.03013 acc 0.64266: 100%|██████████| 782/782 [02:56<00:00,  4.42it/s]
val_loss 1.09475 val_acc 0.61873: 100%|██████████| 157/157 [00:12<00:00, 12.69it/s]


Epoch 11/20 loss 1.03013 val_loss 1.09475 acc 0.64266 val_acc 0.61873


loss 1.00188 acc 0.65235: 100%|██████████| 782/782 [02:57<00:00,  4.41it/s]
val_loss 1.05370 val_acc 0.63217: 100%|██████████| 157/157 [00:12<00:00, 12.68it/s]


Epoch 12/20 loss 1.00188 val_loss 1.05370 acc 0.65235 val_acc 0.63217


loss 0.97919 acc 0.66145: 100%|██████████| 782/782 [02:57<00:00,  4.41it/s]
val_loss 1.08406 val_acc 0.61236: 100%|██████████| 157/157 [00:12<00:00, 12.42it/s]


Epoch 13/20 loss 0.97919 val_loss 1.08406 acc 0.66145 val_acc 0.61236


loss 0.95461 acc 0.67076: 100%|██████████| 782/782 [02:57<00:00,  4.42it/s]
val_loss 1.01512 val_acc 0.64441: 100%|██████████| 157/157 [00:12<00:00, 12.71it/s]


Epoch 14/20 loss 0.95461 val_loss 1.01512 acc 0.67076 val_acc 0.64441


loss 0.93446 acc 0.67757: 100%|██████████| 782/782 [02:57<00:00,  4.41it/s]
val_loss 1.06748 val_acc 0.62838: 100%|██████████| 157/157 [00:12<00:00, 12.50it/s]


Epoch 15/20 loss 0.93446 val_loss 1.06748 acc 0.67757 val_acc 0.62838


loss 0.91159 acc 0.68702: 100%|██████████| 782/782 [02:57<00:00,  4.41it/s]
val_loss 1.00568 val_acc 0.64809: 100%|██████████| 157/157 [00:12<00:00, 12.63it/s]


Epoch 16/20 loss 0.91159 val_loss 1.00568 acc 0.68702 val_acc 0.64809


loss 0.89616 acc 0.69242: 100%|██████████| 782/782 [02:56<00:00,  4.42it/s]
val_loss 0.99010 val_acc 0.65207: 100%|██████████| 157/157 [00:12<00:00, 12.58it/s]


Epoch 17/20 loss 0.89616 val_loss 0.99010 acc 0.69242 val_acc 0.65207


loss 0.87591 acc 0.69887: 100%|██████████| 782/782 [02:57<00:00,  4.41it/s]
val_loss 0.96447 val_acc 0.66521: 100%|██████████| 157/157 [00:12<00:00, 12.62it/s]


Epoch 18/20 loss 0.87591 val_loss 0.96447 acc 0.69887 val_acc 0.66521


loss 0.85680 acc 0.70520: 100%|██████████| 782/782 [02:57<00:00,  4.41it/s]
val_loss 0.95875 val_acc 0.67197: 100%|██████████| 157/157 [00:12<00:00, 12.69it/s]


Epoch 19/20 loss 0.85680 val_loss 0.95875 acc 0.70520 val_acc 0.67197


loss 0.84052 acc 0.71180: 100%|██████████| 782/782 [02:57<00:00,  4.41it/s]
val_loss 0.95139 val_acc 0.67317: 100%|██████████| 157/157 [00:12<00:00, 12.64it/s]

Epoch 20/20 loss 0.84052 val_loss 0.95139 acc 0.71180 val_acc 0.67317



