In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
input_size = (224, 224)  # Input spatial size, used as the last two dims of the sample tensor; adjust to your needs
num_classes = 10         # Number of output classes

model_name = 'rustic'  # Architecture key passed to get_model: 'rustic' (RusticModel) or 'vit' (ViTModel)

In [5]:
class RusticModel(nn.Module):
    """Small four-stage convolutional classifier.

    Each stage is Conv2d -> ReLU -> BatchNorm2d -> MaxPool2d(2, 2). The
    convolutions keep the spatial size and every pooling stage halves it
    (rounding down), so the feature map reaching the classifier is
    ``(256, H // 16, W // 16)``.

    Args:
        num_classes: Number of output logits.
        input_size: ``(height, width)`` of the images the model will receive.
            The default ``(224, 224)`` yields the original ``256 * 14 * 14``
            flattened feature size.
        in_channels: Number of channels of the input images (default 1).

    Raises:
        ValueError: If ``input_size`` is so small that pooling would leave
            no spatial positions (any side below 16).
    """

    # Total spatial reduction of the four MaxPool2d(kernel_size=2, stride=2) stages.
    _DOWNSCALE = 2 ** 4

    def __init__(self, num_classes, input_size=(224, 224), in_channels=1):
        super().__init__()

        # Size of the feature map after the convolutional stack; validated
        # before any layer is allocated so that bad sizes fail fast.
        feat_h = input_size[0] // self._DOWNSCALE
        feat_w = input_size[1] // self._DOWNSCALE
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                "El tamaño de entrada debe ser al menos "
                f"{self._DOWNSCALE}x{self._DOWNSCALE} para RusticModel."
            )
        self.input_size = (input_size[0], input_size[1])

        # Convolutional feature extractor. The layer order is kept unchanged
        # so the indices inside the Sequential (and thus state_dict keys) are stable.
        self.red_conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=5, stride=1, padding=2),  # Stage 1
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # Stage 2
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # Stage 3
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),  # Stage 4
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        # Flatten + fully connected head. fc1's input width is derived from
        # input_size instead of being fixed to 14x14.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(256 * feat_h * feat_w, 256)
        self.dropout = nn.Dropout(0.30)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` must be ``(batch, in_channels, H, W)`` with ``(H, W)`` matching
        the ``input_size`` given at construction, otherwise ``fc1`` receives
        the wrong number of features.
        """
        # Input scaling; presumably x holds raw 0-255 pixel intensities — confirm with the data loader.
        x = x / 255.0
        x = self.red_conv(x)
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


In [6]:
class ViTModel(nn.Module):
    """Minimal Vision-Transformer-style classifier.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution, the resulting tokens go through a stack
    of ``nn.TransformerEncoderLayer`` blocks, are mean-pooled over the patch
    axis and projected to ``num_classes`` logits.

    NOTE(review): no positional encoding is added to the patch tokens, so
    together with the mean pooling the prediction does not depend on where a
    patch sits in the image. Adding one would change the state_dict, so it is
    left as a follow-up decision.

    Args:
        num_classes: Number of output logits.
        embed_dim: Token embedding size (``d_model`` of the encoder layers).
        num_layers: Number of encoder layers.
        num_heads: Attention heads per encoder layer.
        dropout_rate: Dropout used inside the encoder layers.
        input_size: ``(height, width)`` used for the divisibility check and
            for ``num_patches``. When ``None`` (default) the module-level
            ``input_size`` variable is used if it exists, which is what
            earlier versions of this class always did; otherwise
            ``(224, 224)`` is assumed.
        patch_size: Side length of the square patches (default 8).
        in_channels: Number of channels of the input images (default 1).

    Raises:
        ValueError: If either side of ``input_size`` is not a multiple of
            ``patch_size``.
    """

    def __init__(self, num_classes, embed_dim=128, num_layers=6, num_heads=8, dropout_rate=0.1,
                 input_size=None, patch_size=8, in_channels=1):
        super().__init__()

        if input_size is None:
            # Backward compatibility: the size used to be read implicitly from
            # the notebook-level `input_size`; fall back to 224x224 when that
            # variable has not been defined (e.g. config cell not executed).
            input_size = globals().get("input_size", (224, 224))

        # Patch size and validation
        self.patch_size = patch_size
        if input_size[0] % patch_size != 0 or input_size[1] % patch_size != 0:
            raise ValueError("El tamaño de entrada debe ser divisible por el tamaño del parche.")

        self.embed_dim = embed_dim
        self.num_patches = (input_size[0] // patch_size) * (input_size[1] // patch_size)

        # Patch embedding: one conv output position per patch
        self.patch_embeddings = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Transformer encoder. batch_first=True is required because forward()
        # produces (batch, patches, embed_dim); with the default
        # (batch_first=False) attention would run across the batch axis and
        # mix information between different samples.
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=num_heads,
                dropout=dropout_rate,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])

        # Output layer
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x`` of shape ``(batch, in_channels, H, W)``."""
        x = self.patch_embeddings(x)          # (batch, embed_dim, H/ps, W/ps)
        x = x.flatten(2).transpose(1, 2)      # (batch, patches, embed_dim)

        for layer in self.transformer_layers:
            x = layer(x)

        x = x.mean(dim=1)                     # average over the patch axis
        return self.fc(x)


In [7]:
def get_model(model_name, num_classes):
    """Build the network selected by ``model_name``.

    Args:
        model_name: ``'rustic'`` for ``RusticModel`` or ``'vit'`` for ``ViTModel``.
        num_classes: Number of output classes forwarded to the model constructor.

    Returns:
        A freshly constructed model instance.

    Raises:
        ValueError: If ``model_name`` is neither ``'rustic'`` nor ``'vit'``.
    """
    # Reject unknown names first so the happy path stays flat.
    if model_name not in ('rustic', 'vit'):
        raise ValueError("Modelo desconocido: selecciona 'rustic' o 'vit'")

    model_cls = RusticModel if model_name == 'rustic' else ViTModel
    return model_cls(num_classes)


In [None]:
# Instantiate the model selected in the configuration cell
model = get_model(model_name, num_classes)

# Smoke test with simulated input: a batch of 4 single-channel random images.
# NOTE(review): the values are standard-normal noise, while RusticModel.forward
# divides its input by 255.0 — this only checks output shapes, not realistic data.
sample_data = torch.randn(4, 1, *input_size)
output = model(sample_data)
print(f"Modelo {model_name}: Output shape:", output.shape)