# Paper
## Imports
Source paper (IEEE Xplore): https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=10863523

BibTeX reference key: `10863523`

In [4]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from tensorflow.keras.preprocessing.image import ImageDataGenerator # type: ignore

from src.config import *
from src.data import *
from src.models.efficientnet import EfficientNetB5Custom
from src.utils import *
from src.data import OriginalOAIDataset
from src.train import train, train_model
from src.trainers.classification import Classification


# Class balance

In [None]:
# Output root for the augmented copy of the training data used for class balancing.
# Relative path: it is resolved against the notebook's current working directory.
NEW_OAI_DATASET = 'dataset/mendeleyOAI_dataset/augmented_dataset_1'


In [None]:
# Explore the original (non-augmented) split dataset.
# `explorar_split_data` and `MENDELEY_OAI_224_SPLIT_PATH` come from the `src.*`
# star imports in the first cell; what `data` holds is not visible from this
# notebook — TODO confirm against the helper's definition.
data = explorar_split_data(MENDELEY_OAI_224_SPLIT_PATH)

In [None]:
# Class balancing by offline augmentation: every readable training image gets
# a class-dependent number of augmented copies (more for the higher labels).
ORIGINAL_TRAIN_PATH = os.path.join(MENDELEY_OAI_224_SPLIT_PATH, 'train')

# Writing the augmented PNGs is slow and not idempotent, so it is opt-in.
# With the flag off the cell is a dry run that only counts what would be
# produced (the generation code was previously disabled by wrapping it in a
# string literal, while the summary still claimed images had been generated).
GENERATE_AUGMENTED_IMAGES = False

# Augmented copies per source image, indexed by the integer class label
# (i.e. the class directory name).
augmentation_classes = [1, 2, 5, 10, 20]

# Only purely numeric directory names are class labels; this skips stray
# folders such as `.ipynb_checkpoints`, which would make `int()` raise below.
# Sorting makes the processing order deterministic.
classes = sorted(
    d for d in os.listdir(ORIGINAL_TRAIN_PATH)
    if d.isdecimal() and os.path.isdir(os.path.join(ORIGINAL_TRAIN_PATH, d))
)

data_gen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# `exist_ok=True` replaces the racy "check then create" pattern.
os.makedirs(NEW_OAI_DATASET, exist_ok=True)

for class_name in classes:
    class_dir = os.path.join(ORIGINAL_TRAIN_PATH, class_name)
    num_augmentations = augmentation_classes[int(class_name)]
    print(f"Generando imágenes aumentadas para la clase {class_name}...")
    print(f"Se generarán {num_augmentations} imágenes")
    print(f"Directorio original de la clase: {class_dir}")

    # Mirror the source layout (one sub-folder per class) so the augmented
    # images keep their labels instead of being mixed in a single flat folder.
    class_output_dir = os.path.join(NEW_OAI_DATASET, class_name)
    if GENERATE_AUGMENTED_IMAGES:
        os.makedirs(class_output_dir, exist_ok=True)

    imagenes_generadas = 0
    for img_name in sorted(os.listdir(class_dir)):
        img_path = os.path.join(class_dir, img_name)
        img = cv2.imread(img_path)

        # cv2.imread signals an unreadable/corrupt file by returning None.
        if img is None:
            print(f"Error al leer la imagen {img_path}. Puede que no sea una imagen válida o esté dañada.")
            continue

        # OpenCV decodes to BGR while Keras saves through PIL, which expects
        # RGB; convert so the written files keep the right channel order.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_array = img_rgb[np.newaxis, ...]  # add the leading batch dimension

        if GENERATE_AUGMENTED_IMAGES:
            img_stem = os.path.splitext(img_name)[0]
            for i in range(num_augmentations):
                # Drawing one batch from the iterator transforms the image and
                # writes it to disk. The prefix carries the source file name
                # and the copy index because Keras only appends the in-batch
                # index and a short random number, so a constant prefix lets
                # files overwrite one another.
                next(data_gen.flow(
                    img_array,
                    batch_size=1,
                    save_to_dir=class_output_dir,
                    save_prefix=f'aug_{img_stem}_{i}',
                    save_format='png',
                ))
        imagenes_generadas += num_augmentations

    if GENERATE_AUGMENTED_IMAGES:
        print(f"Se han generado {imagenes_generadas} imágenes aumentadas para la clase {class_name}\n-----------------------------------\n")
    else:
        # Dry run: report the planned amount without claiming files exist.
        print(f"[Simulación] Se generarían {imagenes_generadas} imágenes aumentadas para la clase {class_name}\n-----------------------------------\n")

# Train Model

In [5]:
# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 20
LEARNING_RATE = 0.001
BETAS = (0.9, 0.999)
# Presumably learning-rate scheduler settings (reduction factor / patience);
# they are simply forwarded to `Classification` — TODO confirm there.
FACTOR = 0.001
PATIENCE = 5
# L1 and L2 regularization strengths handed to the trainer.
L1 = 0.001
L2 = 0.001
# NOTE(review): in the logged run further down, the training loss (~1e3) is
# orders of magnitude above the validation loss (~1.61, close to ln(5) for
# 5 classes). Presumably the L1/L2 penalty is added to the training loss only
# and dominates it — confirm in `Classification` and reconsider L1/L2.

# Select the compute device once (previously this was evaluated twice).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Images are only converted to tensors; no further preprocessing is applied here.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# --- Data, model and trainer -----------------------------------------------
train_dataset = OriginalOAIDataset('train', batch_size=BATCH_SIZE, transform=transform, local=True)
val_dataset = OriginalOAIDataset('val', batch_size=BATCH_SIZE, transform=transform, local=True)
model = EfficientNetB5Custom(num_classes = 5)
trainer = Classification(model, device, L1=L1, L2=L2, lr=LEARNING_RATE, factor=FACTOR, patience=PATIENCE, betas=BETAS)

LOCAL MODE ENABLED
LOCAL MODE ENABLED




In [8]:
# Put the network on the selected device and show which device is in use.
model.to(device)
print(device)

# Short two-epoch training run over the train/validation datasets.
train_model(
    model,
    trainer,
    train_dataset,
    val_dataset,
    epochs=2,
    device=device,
)



cpu


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

Training Epoch [1/2]: 100%|█| 25/25 [07:04<00:00, 16.96s/it, curr_train_loss=1116.19480675, val_loss
Validation Epoch [1/2]: 100%|██████████████████| 22/22 [01:43<00:00,  4.71s/it, val_loss=1.61081486]
Training Epoch [2/2]: 100%|█| 25/25 [07:15<00:00, 17.41s/it, curr_train_loss=738.76972840, val_loss=
Validation Epoch [2/2]: 100%|██████████████████| 22/22 [01:44<00:00,  4.74s/it, val_loss=1.61107863]
