# Entrenamiento del modelo con las imágenes preprocesadas

In [1]:
import os
import warnings

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class OCRDataset(Dataset):
    """Image/text pairs for OCR training, indexed by a CSV file.

    The CSV must provide a ``Direccion`` column (image path relative to
    ``image_folder``) and a ``Texto`` column (the transcription).

    Args:
        csv_file: Path to the CSV index.
        image_folder: Directory that the ``Direccion`` paths are relative to.
        transform: Optional callable applied to the grayscale PIL image.
        char2idx: Character -> class-index mapping; must contain ``'<unk>'``.
        idx2char: Class-index -> character mapping used by ``decode_label``.

    ``__getitem__`` returns ``None`` for unusable rows (missing/empty label,
    no encodable characters, unreadable image), so the DataLoader's collate
    function has to drop ``None`` samples.
    """

    def __init__(self, csv_file, image_folder, transform=None, char2idx=None, idx2char=None):
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform
        self.char2idx = char2idx
        self.idx2char = idx2char

    def __len__(self):
        return len(self.data)

    def encode_label(self, text):
        """Map ``text`` to a list of class indices.

        Non-printable characters are removed first; characters missing from
        ``char2idx`` are mapped to the ``'<unk>'`` index.
        """
        cleaned_text = "".join(char for char in str(text) if char.isprintable())
        return [self.char2idx.get(char, self.char2idx['<unk>']) for char in cleaned_text]

    def decode_label(self, indices):
        """Map class indices back to a string, skipping unknown indices.

        Note: this is a plain lookup (no CTC repeat-collapsing); special
        tokens such as ``'<blank>'`` are emitted verbatim if present.
        """
        return ''.join([self.idx2char[i] for i in indices if i in self.idx2char])

    def __getitem__(self, idx):
        """Return ``(image, encoded_label)`` for row ``idx``, or ``None`` if unusable."""
        row = self.data.iloc[idx]

        raw_label = row['Texto']
        # Check NaN explicitly: str(NaN) == "nan", which is truthy and would
        # otherwise be trained on as if it were a real transcription.
        if pd.isna(raw_label):
            return None
        label = str(raw_label).strip()
        if not label:
            return None

        encoded = torch.tensor(self.encode_label(label), dtype=torch.long)
        if len(encoded) == 0:  # no encodable characters left
            return None

        try:
            img_path = os.path.join(self.image_folder, row['Direccion'])
            # Context manager so the underlying file handle is always closed.
            with Image.open(img_path) as img:
                image = img.convert('L')
            if self.transform:
                image = self.transform(image)
        except Exception as exc:
            # Best-effort: skip bad samples instead of aborting the epoch, but
            # report it. The message only contains the exception type, so under
            # Python's default warnings filter it is not repeated for every row.
            warnings.warn(f"OCRDataset: skipping unreadable sample(s) ({type(exc).__name__})")
            return None

        return image, encoded

ModuleNotFoundError: No module named 'torch'

In [2]:
# Preprocessing for the grayscale crops: PIL image -> float tensor, then
# normalize with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [3]:
import string
from collections import Counter

# Read the annotations CSV and gather every distinct character in the labels.
df = pd.read_csv("../Data/ImagenTexto.csv")
all_text = "".join(df['Texto'].dropna().astype(str))
char_counts = Counter(all_text)

# Vocabulary = every observed character, sorted. Indices 0 and 1 are reserved,
# so real characters are numbered starting at 2:
#   0 -> '<blank>'  (CTC blank; the training cells use nn.CTCLoss(blank=0))
#   1 -> '<unk>'    (fallback for characters not seen in this CSV)
all_chars = sorted(char_counts.keys())
char2idx = dict(zip(all_chars, range(2, len(all_chars) + 2)))
char2idx.update({'<blank>': 0, '<unk>': 1})

# Inverse mapping, used to turn predicted indices back into text.
idx2char = dict(zip(char2idx.values(), char2idx.keys()))

In [4]:
from torch.utils.data import DataLoader

def custom_collate_fn(batch):
    """Collate ``(image, label)`` samples into a CTC training batch.

    ``OCRDataset.__getitem__`` returns ``None`` for unusable rows; those are
    dropped here, since unpacking a ``None`` sample would crash the whole
    DataLoader iteration.

    Returns:
        images: Tensor stacked along a new batch dimension 0 (all images in
            the batch must share the same shape).
        labels: Tuple of 1-D label tensors (variable length), same order.
        If every sample was dropped, returns ``(torch.empty(0), ())``.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return torch.empty(0), ()
    images, labels = zip(*samples)
    return torch.stack(images, dim=0), labels

# Full dataset and its shuffled training loader.
dataset = OCRDataset(
    csv_file="../Data/ImagenTexto.csv",
    image_folder="../Data/Anotaciones",
    transform=transform,
    char2idx=char2idx,
    idx2char=idx2char,
)
dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

In [5]:
import torch
import torch.nn as nn

class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-column class scores for CTC.

    A CNN maps a ``(B, num_channels, H, W)`` batch to a feature map whose
    height must collapse to 1. Each remaining column is one time step for two
    stacked bidirectional LSTMs, followed by a linear classifier.

    Args:
        img_height: Expected input height. It is stored for reference only;
            the actual height is checked in ``forward``. With the pooling
            below, input heights 32..47 collapse to a height of 1.
        num_channels: Number of input channels (1 for grayscale).
        num_classes: Number of output classes, including the CTC blank.
        rnn_hidden_size: Hidden size of each LSTM direction.

    Submodule names (``cnn``, ``lstm1``, ``lstm2``, ``fc``) and the layer order
    inside ``cnn`` define the state_dict keys; keep them stable so saved
    checkpoints keep loading.
    """

    def __init__(self, img_height, num_channels, num_classes, rnn_hidden_size=256):
        super().__init__()
        self.img_height = img_height

        # Feature extractor (CNN backbone)
        self.cnn = nn.Sequential(
            nn.Conv2d(num_channels, 64, 3, 1, 1),  # output: (64, H, W)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                   # output: (64, H/2, W/2)

            nn.Conv2d(64, 128, 3, 1, 1),          # output: (128, H/2, W/2)
            nn.ReLU(),
            nn.MaxPool2d(2, 2),                   # output: (128, H/4, W/4)

            nn.Conv2d(128, 256, 3, 1, 1),         # output: (256, H/4, W/4)
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 256, 3, 1, 1),         # output: (256, H/4, W/4)
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),         # output: (256, H/8, W/4) - halves height only

            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.Conv2d(512, 512, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1), (2, 1)),         # output: (512, H/16, W/4)

            nn.Conv2d(512, 512, 2, 1, 0),         # output: (512, H/16 - 1, W/4 - 1)
            nn.ReLU()
        )

        # Two stacked bidirectional LSTMs over the column sequence.
        self.lstm1 = nn.LSTM(512, rnn_hidden_size, bidirectional=True, batch_first=True)
        self.lstm2 = nn.LSTM(2 * rnn_hidden_size, rnn_hidden_size, bidirectional=True, batch_first=True)

        # Final classifier (2 * hidden because both LSTM directions are concatenated).
        self.fc = nn.Linear(2 * rnn_hidden_size, num_classes)

    def forward(self, x):
        """Return raw class scores shaped ``(T, B, num_classes)``.

        ``x`` is ``(B, C, H, W)`` and ``T = W // 4 - 1``. The scores are
        logits; apply ``log_softmax(2)`` before ``nn.CTCLoss``.

        Raises:
            ValueError: if the CNN output height is not 1 (input height
                outside 32..47).
        """
        conv_out = self.cnn(x)  # (B, 512, h, T)
        h = conv_out.size(2)
        if h != 1:
            # Explicit exception rather than `assert`, which is stripped under `python -O`.
            raise ValueError(
                f"Height must be 1 after CNN, got {h} "
                f"(input height {x.size(2)}; expected 32-47 pixels)"
            )

        seq = conv_out.squeeze(2).permute(0, 2, 1)  # (B, T, 512)
        seq, _ = self.lstm1(seq)
        seq, _ = self.lstm2(seq)                    # (B, T, 2 * hidden)

        logits = self.fc(seq)                       # (B, T, num_classes)
        return logits.permute(1, 0, 2)              # (T, B, num_classes) for CTC loss

In [None]:
# === QUICK TESTS (SEPARATE CELL) ===
# Trains for a couple of epochs on a small random subset to check that data
# loading, the forward pass and the CTC loss work before a full run.
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn
import torch
import os


# 1. Small random subset of the dataset.
subset_size = 5000  # adjust as needed
# Convert the permutation to a NumPy array so Subset receives plain indices.
indices = torch.randperm(len(dataset))[:subset_size].cpu().numpy()
train_subset = Subset(dataset, indices)

# 2. Temporary DataLoader for the test run.
test_dataloader = DataLoader(
    train_subset,
    batch_size=128,
    shuffle=True,
    num_workers=0,
    collate_fn=custom_collate_fn,
    # pin_memory=True,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CRNN(img_height=32, num_channels=1, num_classes=len(char2idx)).to(device)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Check once, up front, that the DataLoader can produce a batch.
try:
    test_batch = next(iter(test_dataloader))
    print(f"Batch de prueba - Imágenes: {test_batch[0].shape}, Labels: {len(test_batch[1])}")
except Exception as e:
    print(f"Error al cargar el batch de prueba: {e}")

# 3. Short training loop on the subset.
num_test_epochs = 2
for epoch in range(num_test_epochs):
    model.train()
    loop = tqdm(test_dataloader, desc=f"Epoch [{epoch + 1}/{num_test_epochs}]", leave=True)
    total_loss = 0.0
    num_batches = 0

    # Iterate the tqdm wrapper (not the raw DataLoader) so the bar advances.
    for images, labels in loop:
        # Skip batches with no usable samples (e.g. all rows were filtered out).
        if len(labels) == 0:
            continue

        images = images.to(device)
        label_lengths = torch.tensor([len(t) for t in labels], dtype=torch.long)
        targets = torch.cat(labels).to(device)

        # Single forward pass; CTCLoss expects (T, B, C) log-probabilities.
        log_probs = model(images).log_softmax(2)
        input_lengths = torch.full(
            size=(images.size(0),),
            fill_value=log_probs.size(0),
            dtype=torch.long,
        ).to(device)

        loss = ctc_loss(log_probs, targets, input_lengths, label_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        # VRAM usage is only meaningful when CUDA is available.
        used_vram = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0.0
        loop.set_postfix(
            loss=f"{batch_loss:.4f}",
            avg_loss=f"{total_loss / num_batches:.4f}",
            vram=f"{used_vram:.2f}GB",
        )

    # Report the epoch average, not just the last batch's loss.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Época {epoch + 1} - Loss: {avg_loss:.4f}")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import os

# --- 1. Checkpoint helpers ---
def save_checkpoint(epoch, model, optimizer, loss, path):
    """Save model/optimizer state plus bookkeeping to ``path``.

    Args:
        epoch: Epoch number to store (the training loop stores the next
            epoch to train, so resuming continues from there).
        model: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        loss: Loss value recorded alongside the weights.
        path: Destination file.

    The data is first written to ``path + ".tmp"`` and then moved over
    ``path`` with ``os.replace``. An interruption mid-write therefore leaves
    the previous checkpoint intact instead of a truncated file.
    """
    tmp_path = f"{path}.tmp"
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        },
        tmp_path,
    )
    os.replace(tmp_path, path)

def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model and optimizer state from ``path``.

    Args:
        model: Module to load ``model_state_dict`` into.
        optimizer: Optimizer to load ``optimizer_state_dict`` into.
        path: Checkpoint file written by ``save_checkpoint``.
        map_location: Where to place the loaded tensors. Defaults to the
            device of the model's first parameter (CPU if it has none). This
            lets a checkpoint saved on GPU be resumed on a CPU-only machine.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else 'cpu'
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

# --- 2. Initial setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CRNN(img_height=32, num_channels=1, num_classes=len(char2idx)).to(device)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Folder for checkpoints.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# --- 3. Resume from a previous checkpoint if one exists ---
# The stored epoch is the next epoch to train (see the save calls below).
start_epoch = 0
checkpoint_path = os.path.join(checkpoint_dir, "last_checkpoint.pth")
if os.path.exists(checkpoint_path):
    start_epoch, _ = load_checkpoint(model, optimizer, checkpoint_path)
    print(f"Reanudando entrenamiento desde epoch {start_epoch + 1}")

# --- 4. Training loop ---
num_epochs = 10
for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0

    loop = tqdm(dataloader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)

    for images, labels in loop:
        # Skip batches with no usable samples (e.g. all rows were filtered out).
        if len(labels) == 0:
            continue

        images = images.to(device)
        label_lengths = torch.tensor([len(t) for t in labels], dtype=torch.long)
        targets = torch.cat(labels).to(device)

        log_probs = model(images).log_softmax(2)  # (T, B, num_classes)
        # All samples in a batch share one width (they are stacked), so every
        # sample uses the full output length T.
        input_lengths = torch.full(
            size=(images.size(0),),
            fill_value=log_probs.size(0),
            dtype=torch.long,
        ).to(device)

        loss = ctc_loss(log_probs, targets, input_lengths, label_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single host sync per step
        total_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss=batch_loss)

    # Average over the batches actually trained on.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}: Avg Loss = {avg_loss:.4f}")

    # --- 5. Save checkpoints ---
    # Store epoch + 1 (the next epoch to train) and overwrite the "last" checkpoint.
    save_checkpoint(epoch + 1, model, optimizer, avg_loss, checkpoint_path)
    # Also keep one snapshot per epoch.
    epoch_checkpoint = os.path.join(checkpoint_dir, f"epoch_{epoch+1}.pth")
    save_checkpoint(epoch + 1, model, optimizer, avg_loss, epoch_checkpoint)

Epoch [1/10]:   0%|                                                       | 1/1925 [00:18<9:47:12, 18.31s/it, loss=118]

In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118