## Entrenamiento del modelo

In [1]:
# Libraries
import os

import numpy as np
import torch
import torch.optim as opt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import nn as nn
from torch.utils.data import DataLoader

from MODEL import REDCN1

## Activación de la GPU (si hay una disponible)

In [2]:
# Device selection: use the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Carga de los sets de datos

In [3]:
# Data load: read the dictionary of prepared DataLoaders from disk.
PATH = "./DATA/DATA.pth"

# The file stores pickled DataLoader objects (see the printed dict below),
# not plain tensors. The restricted unpickler that is the default from
# PyTorch 2.6 rejects such objects, so full unpickling is requested
# explicitly (the `weights_only` argument needs torch >= 1.13).
# SECURITY: full unpickling can execute arbitrary code -- only load files
# produced by a trusted source.
train = torch.load(PATH, weights_only=False)
print(train)

{'train_dataloader': <torch.utils.data.dataloader.DataLoader object at 0x000001A07E4934D0>, 'test_dataloader': <torch.utils.data.dataloader.DataLoader object at 0x000001A07E4A4650>, 'validation_dataloader': <torch.utils.data.dataloader.DataLoader object at 0x000001A07E4A4B50>, 'classes': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']}


## Carga de la RED

In [4]:
# Model setup: instantiate the network, the loss and the optimizer.
LEARNING_RATE = 0.001
MOMENTUM = 0.9

RED = REDCN1()
print(RED)

# NOTE(review): the printed architecture ends in a 20-unit layer while the
# loaded data lists 10 classes -- confirm this is intended in MODEL.REDCN1.
criterion = nn.CrossEntropyLoss()
optimizer = opt.SGD(
    params=RED.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
)

REDCN1(
  (relu_conv): Sequential(
    (0): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(20, 30, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): Conv2d(30, 40, kernel_size=(5, 5), stride=(1, 1))
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Flatten(start_dim=1, end_dim=-1)
  )
  (relu_linear): Sequential(
    (0): Linear(in_features=360, out_features=700, bias=True)
    (1): SELU()
    (2): Linear(in_features=700, out_features=300, bias=True)
    (3): SELU()
    (4): Linear(in_features=300, out_features=150, bias=True)
    (5): SELU()
    (6): Linear(in_features=150, out_features=80, bias=True)
    (7): SELU()
    (8): Linear(in_features=80, out_features=20, bias=True)
  )
)


## Ciclo de entrenamiento

In [5]:
# Training cycle: optimise RED on the training loader and, after every epoch,
# measure the mean loss on the validation loader.
NUM_EPOCHS = 10

graphic1 = []  # mean training loss, one entry per epoch
graphic2 = []  # mean validation loss, one entry per epoch
EPOCH = []  # zero-based epoch index, stored as float

# Run on the device chosen in the GPU cell. The model is moved here because
# the optimizer was already built from RED.parameters(); Module.to() converts
# those parameters in place, so the optimizer keeps tracking them.
RED.to(device)

for epoch in range(NUM_EPOCHS):
    train_loss = 0.0
    valid_loss = 0.0

    # eval() further down persists across iterations, so training mode has to
    # be restored at the start of every epoch.
    RED.train()
    for inputs, labels in train["train_dataloader"]:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = RED(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    RED.eval()
    # Validation only needs forward passes; skip building the autograd graph.
    with torch.no_grad():
        for inputs, labels in train["validation_dataloader"]:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = RED(inputs)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()

    # Average the accumulated batch losses over the number of batches.
    train_loss = train_loss / len(train["train_dataloader"])
    print("[Para la epoca", epoch + 1, "] loss:", train_loss)
    graphic1.append(train_loss)

    valid_loss = valid_loss / len(train["validation_dataloader"])
    print("[Para la epoca", epoch + 1, "] val_loss:", valid_loss)
    graphic2.append(valid_loss)

    EPOCH.append(float(epoch))

[Para la epoca 1 ] loss: 1.642210458651185
[Para la epoca 1 ] val_loss: 1.322207680940628
[Para la epoca 2 ] loss: 1.2438485040061176
[Para la epoca 2 ] val_loss: 1.0966651159003378
[Para la epoca 3 ] loss: 1.0793774187099188
[Para la epoca 3 ] val_loss: 0.9709988304153084
[Para la epoca 4 ] loss: 0.9756159084467217
[Para la epoca 4 ] val_loss: 0.8481776649467647
[Para la epoca 5 ] loss: 0.8915466266402975
[Para la epoca 5 ] val_loss: 0.7657714611440897
[Para la epoca 6 ] loss: 0.8222235562787857
[Para la epoca 6 ] val_loss: 0.738847901029978
[Para la epoca 7 ] loss: 0.7613086444426095
[Para la epoca 7 ] val_loss: 0.6477920900890604
[Para la epoca 8 ] loss: 0.7079946303341887
[Para la epoca 8 ] val_loss: 0.6707857513458002
[Para la epoca 9 ] loss: 0.6515025270202197
[Para la epoca 9 ] val_loss: 0.5344303504514973
[Para la epoca 10 ] loss: 0.6039271916635814
[Para la epoca 10 ] val_loss: 0.49589015344984366


## Guardado de la red entrenada

In [6]:
# Save the trained model: write a checkpoint with the weights, the optimizer
# state and the per-epoch loss history.
PATH = "./trained_model/RED_entrenada.pth"

# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(os.path.dirname(PATH), exist_ok=True)

torch.save(
    {
        "epoch": EPOCH,
        "model_state_dict": RED.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        # `loss` is whatever the training cell assigned last, i.e. the loss
        # tensor of the final validation batch -- not an epoch average.
        "loss": loss,
        "loss_epoch": graphic1,
        "loss_epoch_validation": graphic2,
    },
    PATH,
)