In [2]:
import csv
import os

import torch
import torch.nn as nn
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from torchvision import models

def get_in_features(base_model, input_size=(3, 224, 224)):
    """Return the number of features ``base_model`` produces per sample, after flattening.

    One random input of shape ``(1, *input_size)`` is pushed through the model in
    eval mode with gradients disabled. The output is flattened over every
    non-batch dimension, which is what an ``nn.Flatten()`` placed after the model
    (as in ``LitClassifier``) hands to the next layer.

    The train/eval flag of every submodule is restored afterwards, so probing a
    model does not leave it in eval mode for a subsequent training run.

    Args:
        base_model: ``nn.Module`` that accepts image batches. The dummy input is
            created on the device and with the dtype of its first parameter. If it
            has no parameters, the dummy uses the CPU and torch's default dtype.
        input_size: per-sample input shape ``(channels, height, width)``;
            defaults to ``(3, 224, 224)``.

    Returns:
        int: flattened feature count per sample.
    """
    first_param = next(base_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    dtype = first_param.dtype if first_param is not None else None

    # Remember each submodule's own mode. A plain .train(flag) afterwards would
    # overwrite submodules that were deliberately left in eval mode.
    saved_modes = [(module, module.training) for module in base_model.modules()]
    base_model.eval()
    try:
        with torch.no_grad():
            dummy = torch.randn(1, *input_size, device=device, dtype=dtype)
            out = base_model(dummy)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    # Flatten all non-batch dims: e.g. a (1, C, H, W) output feeds C*H*W features.
    return out.flatten(1).shape[1]

class LitClassifier(LightningModule):
    """Transfer-learning image classifier: pretrained backbone plus a new MLP head.

    The backbone is built with ``base_model_fn(weights=weights)``. Its ``fc``
    attribute (or, failing that, ``classifier``) is replaced by ``nn.Identity``,
    and the following head is appended::

        Flatten -> Dropout(0.3) -> Linear(in_features, 256) -> ReLU
                -> Dropout(0.3) -> Linear(256, num_classes)

    Freezing works by walking ``named_parameters()`` in order:

    - Every backbone parameter is frozen until the first one that lives inside
      the module ``freeze_layer_name``.
    - From there on, everything is trainable.
    - The new head is always trainable.

    Passing the name of the layer that gets replaced (``"fc"`` for ResNet)
    therefore leaves the entire pretrained backbone frozen.

    Args:
        base_model_fn: model builder accepting a ``weights=`` keyword.
        weights: pretrained weights forwarded to ``base_model_fn``.
        freeze_layer_name: dotted module path as it appears in
            ``named_parameters()`` (e.g. ``"layer3"``, ``"features.7"``) at which
            unfreezing starts.
        num_classes: number of output logits.
        lr: Adam learning rate.

    Raises:
        ValueError: if no backbone parameter lies inside ``freeze_layer_name``, or
            the backbone has neither an ``fc`` nor a ``classifier`` attribute.

    Logged epoch-level metrics ``train_loss``, ``val_loss`` and ``val_acc`` are
    averaged over samples, not over batches.
    """

    def __init__(self, base_model_fn, weights, freeze_layer_name, num_classes=3, lr=3e-5):
        super().__init__()
        self.save_hyperparameters()
        # Per-batch sums collected by the *_step hooks. The matching
        # on_*_epoch_end hook reduces and then clears them.
        self.training_outputs = []
        self.validation_outputs = []

        base_model = base_model_fn(weights=weights)
        self._freeze_backbone_until(base_model, freeze_layer_name)
        in_features = self._strip_backbone_head(base_model)

        self.model = nn.Sequential(
            base_model,
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        self.loss_fn = nn.CrossEntropyLoss()
        self.lr = lr

    @staticmethod
    def _freeze_backbone_until(base_model, freeze_layer_name):
        """Set requires_grad=False on every parameter before ``freeze_layer_name``."""
        prefix = freeze_layer_name + "."
        unfreeze = False
        for name, param in base_model.named_parameters():
            # Match whole path components. A plain substring test would let "fc"
            # also hit names such as "block.1.fc1.weight".
            if name == freeze_layer_name or name.startswith(prefix):
                unfreeze = True
            param.requires_grad = unfreeze
        if not unfreeze:
            raise ValueError(
                f"freeze_layer_name {freeze_layer_name!r} matches no parameter of "
                f"{type(base_model).__name__}; the whole backbone would stay frozen."
            )
        # NOTE(review): requires_grad=False does not stop BatchNorm layers in the
        # frozen part from updating their running statistics while in train mode.
        # If the frozen layers must stay completely fixed, keep their BN modules
        # in eval mode during training.

    @staticmethod
    def _strip_backbone_head(base_model):
        """Replace the backbone's head with nn.Identity and return the feature size."""
        if hasattr(base_model, "fc"):
            base_model.fc = nn.Identity()
        elif hasattr(base_model, "classifier"):
            # NOTE(review): in some torchvision models ``classifier`` contains more
            # than the final Linear layer (e.g. a normalization layer), and all of
            # it is discarded here. Confirm this is intended.
            base_model.classifier = nn.Identity()
        else:
            raise ValueError(
                f"{type(base_model).__name__} has neither an 'fc' nor a 'classifier' "
                "attribute to replace."
            )
        in_features = get_in_features(base_model)
        # The shape probe may leave the backbone in eval mode, so put it back in
        # train mode. Recent Lightning 2.x releases can keep submodules in
        # whatever mode they have at fit() start, which would silently disable
        # dropout and BatchNorm batch statistics in the backbone.
        base_model.train()
        return in_features

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, _):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        # CrossEntropyLoss returns the batch mean. Store sum + count so the epoch
        # average weights every sample equally, even when the last batch is short.
        batch_size = y.size(0)
        self.training_outputs.append((loss.detach() * batch_size, batch_size))
        return loss

    def on_train_epoch_end(self):
        if not self.training_outputs:
            return
        loss_sum = torch.stack([s for s, _ in self.training_outputs]).sum()
        count = sum(n for _, n in self.training_outputs)
        self.log("train_loss", loss_sum / count, prog_bar=True, on_step=False, on_epoch=True)
        self.training_outputs.clear()

    def validation_step(self, batch, _):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        batch_size = y.size(0)
        self.validation_outputs.append({
            "loss_sum": loss.detach() * batch_size,
            "correct": (logits.argmax(1) == y).sum().detach(),
            "count": batch_size,
        })
        return None

    def on_validation_epoch_end(self):
        if not self.validation_outputs:
            return
        count = sum(out["count"] for out in self.validation_outputs)
        loss_sum = torch.stack([out["loss_sum"] for out in self.validation_outputs]).sum()
        correct = torch.stack([out["correct"] for out in self.validation_outputs]).sum()
        self.log("val_loss", loss_sum / count, prog_bar=True, on_step=False, on_epoch=True)
        self.log("val_acc", correct.float() / count, prog_bar=True, on_step=False, on_epoch=True)
        self.validation_outputs.clear()

    def configure_optimizers(self):
        # Frozen parameters never receive gradients, so give Adam only the trainable ones.
        trainable = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable, lr=self.lr)


def train_and_test_models(train_loader, val_loader, test_loader):
    """Fine-tune four pretrained backbones at three learning rates and score them on a test set.

    For every ``(lr, backbone)`` pair the function:

    1. builds a 3-class ``LitClassifier``, unfreezing from the layer named in
       ``model_defs``;
    2. trains it for at most 30 epochs with early stopping on ``val_loss``
       (``min_delta=0.001``, ``patience=5``);
    3. writes the last-epoch weights to ``checkpoints/<model>_lr<lr>_final.ckpt``;
    4. reloads the weights of the epoch with the lowest ``val_loss`` and predicts
       ``test_loader`` with them.

    Args:
        train_loader, val_loader, test_loader: loaders yielding ``(images, labels)``
            batches (labels presumably class indices; TODO confirm).

    Returns:
        ``(results, history)``, both keyed by ``"<model>_lr<lr>"``.

        - ``results[key]`` is ``{"preds": Tensor, "labels": Tensor}`` for the test set.
        - ``history[key]`` is a list of metric dicts, one per row of the run's
          CSVLogger ``metrics.csv``, with empty cells omitted.
    """

    def _to_number(text):
        # metrics.csv holds text: epoch/step parse as int, metric values as float.
        try:
            return int(text)
        except ValueError:
            return float(text)

    def _read_logged_metrics(csv_logger):
        # Lightning >= 2.1's CSV writer clears its in-memory ``metrics`` list each
        # time it flushes to disk, so after fit() ``logger.experiment.metrics`` can
        # be empty. The metrics.csv file is the complete record.
        path = os.path.join(csv_logger.log_dir, "metrics.csv")
        if not os.path.isfile(path):
            return []
        with open(path, newline="") as fh:
            return [
                {key: _to_number(value) for key, value in row.items() if value}
                for row in csv.DictReader(fh)
            ]

    model_defs = [
        (models.resnet18, models.ResNet18_Weights.DEFAULT, "fc"),
        (models.efficientnet_b0, models.EfficientNet_B0_Weights.DEFAULT, "features.8"),
        (models.mobilenet_v3_large, models.MobileNet_V3_Large_Weights.DEFAULT, "features.8"),
        (models.convnext_tiny, models.ConvNeXt_Tiny_Weights.DEFAULT, "features.7"),
    ]

    learning_rates = [1e-4, 3e-5, 1e-5]
    use_gpu = torch.cuda.is_available()
    eval_device = torch.device("cuda" if use_gpu else "cpu")

    os.makedirs("checkpoints", exist_ok=True)
    results = {}
    history = {}

    for lr in learning_rates:
        for fn, w, freeze_name in model_defs:
            name = fn.__name__
            run_key = f"{name}_lr{lr}"
            print(f"Entrenando {name} con lr={lr}")
            logger = CSVLogger("lightning_logs", name=run_key)
            early_stop_callback = EarlyStopping(
                monitor="val_loss",
                min_delta=0.001,
                patience=5,
                verbose=True,
                mode="min"
            )
            # EarlyStopping only stops training; it does not roll the weights back.
            # Keep the lowest-val_loss weights so the test set is scored with them,
            # rather than with weights up to `patience` epochs past the optimum.
            best_ckpt_callback = ModelCheckpoint(
                dirpath="checkpoints",
                filename=f"{run_key}_best",
                monitor="val_loss",
                mode="min",
                save_top_k=1,
            )

            model = LitClassifier(fn, w, freeze_name, num_classes=3, lr=lr)
            trainer = Trainer(
                max_epochs=30,
                accelerator="gpu" if use_gpu else "cpu",
                logger=logger,
                callbacks=[early_stop_callback, best_ckpt_callback]
            )
            trainer.fit(model, train_loader, val_loader)

            # Last-epoch weights, saved as before.
            trainer.save_checkpoint(f"checkpoints/{run_key}_final.ckpt")

            history[run_key] = _read_logged_metrics(logger)

            if best_ckpt_callback.best_model_path:
                # Checkpoint written by this very run (trusted). It also carries the
                # saved hyperparameters (a builder function and a weights object),
                # which weights_only=True loading may reject.
                checkpoint = torch.load(
                    best_ckpt_callback.best_model_path, map_location="cpu", weights_only=False
                )
                model.load_state_dict(checkpoint["state_dict"])

            # Score on the same kind of device that was used for training.
            model.to(eval_device).eval()
            all_preds, all_labels = [], []
            with torch.inference_mode():
                for x, y in test_loader:
                    all_preds.append(model(x.to(eval_device)).argmax(1).cpu())
                    all_labels.append(y.cpu())
            results[run_key] = {
                "preds": torch.cat(all_preds),
                "labels": torch.cat(all_labels)
            }

    return results, history

In [27]:
from torchinfo import summary

# 1. ResNet18 with its whole pretrained backbone frozen. Unfreezing would start
#    at "fc", but the original fc layer is then replaced by nn.Identity, so only
#    the newly added classification head is trainable.
resnet_model = LitClassifier(
    models.resnet18,
    models.ResNet18_Weights.DEFAULT,
    "fc",
    num_classes=3,
)
total_params = sum(map(torch.numel, resnet_model.model.parameters()))
trainable_params = sum(
    torch.numel(p) for p in resnet_model.model.parameters() if p.requires_grad
)
print(f"ResNet18 total params: {total_params:,}")
print(f"ResNet18 trainable params: {trainable_params:,}")


ResNet18 total params: 11,308,611
ResNet18 trainable params: 132,099


In [26]:
# 2. EfficientNet-B0: parameters before features.8 stay frozen; features.8,
#    everything after it and the new head are trainable.
efficient_model = LitClassifier(
    models.efficientnet_b0,
    models.EfficientNet_B0_Weights.DEFAULT,
    "features.8",
    num_classes=3,
)
total_params = sum(map(torch.numel, efficient_model.model.parameters()))
trainable_params = sum(
    torch.numel(p) for p in efficient_model.model.parameters() if p.requires_grad
)
print(f"EfficientNet-B0 total params: {total_params:,}")
print(f"EfficientNet-B0 trainable params: {trainable_params:,}")


EfficientNet-B0 total params: 4,336,255
EfficientNet-B0 trainable params: 740,867


In [33]:
# 3. MobileNetV3-Large: frozen up to (not including) features.8; features.8
#    onwards plus the new head are trainable.
mobilenet_model = LitClassifier(
    models.mobilenet_v3_large,
    models.MobileNet_V3_Large_Weights.DEFAULT,
    "features.8",
    num_classes=3,
)
total_params = sum(map(torch.numel, mobilenet_model.model.parameters()))
trainable_params = sum(
    torch.numel(p) for p in mobilenet_model.model.parameters() if p.requires_grad
)
print(f"MobileNetV3-Large total params: {total_params:,}")
print(f"MobileNetV3-Large trainable params: {trainable_params:,}")
summary(mobilenet_model.model, input_size=(1, 3, 224, 224))

MobileNetV3-Large total params: 3,218,739
MobileNetV3-Large trainable params: 3,125,539


Layer (type:depth-idx)                                  Output Shape              Param #
Sequential                                              [1, 3]                    --
├─MobileNetV3: 1-1                                      [1, 960]                  --
│    └─Sequential: 2-1                                  [1, 960, 7, 7]            --
│    │    └─Conv2dNormActivation: 3-1                   [1, 16, 112, 112]         (464)
│    │    └─InvertedResidual: 3-2                       [1, 16, 112, 112]         (464)
│    │    └─InvertedResidual: 3-3                       [1, 24, 56, 56]           (3,440)
│    │    └─InvertedResidual: 3-4                       [1, 24, 56, 56]           (4,440)
│    │    └─InvertedResidual: 3-5                       [1, 40, 28, 28]           (10,328)
│    │    └─InvertedResidual: 3-6                       [1, 40, 28, 28]           (20,992)
│    │    └─InvertedResidual: 3-7                       [1, 40, 28, 28]           (20,992)
│    │    └─InvertedResidu

In [32]:
# 4. ConvNeXt-Tiny: frozen up to (not including) features.7; features.7
#    onwards plus the new head are trainable.
convnext_model = LitClassifier(
    models.convnext_tiny,
    models.ConvNeXt_Tiny_Weights.DEFAULT,
    "features.7",
    num_classes=3,
)
total_params = sum(map(torch.numel, convnext_model.model.parameters()))
trainable_params = sum(
    torch.numel(p) for p in convnext_model.model.parameters() if p.requires_grad
)
print(f"ConvNeXt-Tiny total params: {total_params:,}")
print(f"ConvNeXt-Tiny trainable params: {trainable_params:,}")
summary(convnext_model.model, input_size=(1, 3, 224, 224))

ConvNeXt-Tiny total params: 28,016,227
ConvNeXt-Tiny trainable params: 14,487,043


Layer (type:depth-idx)                             Output Shape              Param #
Sequential                                         [1, 3]                    --
├─ConvNeXt: 1-1                                    [1, 768, 1, 1]            --
│    └─Sequential: 2-1                             [1, 768, 7, 7]            --
│    │    └─Conv2dNormActivation: 3-1              [1, 96, 56, 56]           (4,896)
│    │    └─Sequential: 3-2                        [1, 96, 56, 56]           (237,888)
│    │    └─Sequential: 3-3                        [1, 192, 28, 28]          (74,112)
│    │    └─Sequential: 3-4                        [1, 192, 28, 28]          (918,144)
│    │    └─Sequential: 3-5                        [1, 384, 14, 14]          (295,680)
│    │    └─Sequential: 3-6                        [1, 384, 14, 14]          (10,817,280)
│    │    └─Sequential: 3-7                        [1, 768, 7, 7]            (1,181,184)
│    │    └─Sequential: 3-8                        [1, 768, 7, 7

In [34]:
def train_and_test_new_models(train_loader, val_loader, test_loader):
    """Second experiment round: compare backbones and unfreeze points at lr=1e-4.

    For every ``(lr, backbone, unfreeze point)`` entry the function:

    1. builds a 3-class ``LitClassifier``;
    2. trains it for at most 40 epochs with early stopping on ``val_loss``
       (``min_delta=0.001``, ``patience=5``);
    3. writes the last-epoch weights to ``checkpoints/<run>_lr<lr>_final.ckpt``;
    4. reloads the weights of the epoch with the lowest ``val_loss`` and predicts
       ``test_loader`` with them.

    Run names are ``<model>_<unfreeze point with dots replaced by underscores>``,
    because the same backbone appears several times with different unfreeze points.

    Args:
        train_loader, val_loader, test_loader: loaders yielding ``(images, labels)``
            batches (labels presumably class indices; TODO confirm).

    Returns:
        ``(results, history)``, both keyed by ``"<run>_lr<lr>"``.

        - ``results[key]`` is ``{"preds": Tensor, "labels": Tensor}`` for the test set.
        - ``history[key]`` is a list of metric dicts, one per row of the run's
          CSVLogger ``metrics.csv``, with empty cells omitted.
    """

    def _to_number(text):
        # metrics.csv holds text: epoch/step parse as int, metric values as float.
        try:
            return int(text)
        except ValueError:
            return float(text)

    def _read_logged_metrics(csv_logger):
        # Lightning >= 2.1's CSV writer clears its in-memory ``metrics`` list each
        # time it flushes to disk, so after fit() ``logger.experiment.metrics`` can
        # be empty. The metrics.csv file is the complete record.
        path = os.path.join(csv_logger.log_dir, "metrics.csv")
        if not os.path.isfile(path):
            return []
        with open(path, newline="") as fh:
            return [
                {key: _to_number(value) for key, value in row.items() if value}
                for row in csv.DictReader(fh)
            ]

    model_defs = [
        (models.resnet18, models.ResNet18_Weights.DEFAULT, "fc"),
        (models.resnet18, models.ResNet18_Weights.DEFAULT, "layer3"),
        (models.efficientnet_b3, models.EfficientNet_B3_Weights.DEFAULT, "features.7"),
        (models.mobilenet_v3_large, models.MobileNet_V3_Large_Weights.DEFAULT, "features.6"),
        (models.efficientnet_b0, models.EfficientNet_B0_Weights.DEFAULT, "features.7"),
        (models.efficientnet_b0, models.EfficientNet_B0_Weights.DEFAULT, "features.6"),
    ]

    learning_rates = [1e-4]
    use_gpu = torch.cuda.is_available()
    eval_device = torch.device("cuda" if use_gpu else "cpu")

    os.makedirs("checkpoints", exist_ok=True)
    results = {}
    history = {}

    for lr in learning_rates:
        for fn, w, freeze_name in model_defs:
            # Include the unfreeze point: the same backbone occurs more than once.
            name = f"{fn.__name__}_{freeze_name.replace('.', '_')}"
            run_key = f"{name}_lr{lr}"
            print(f"Entrenando {name} con lr={lr}")
            logger = CSVLogger("lightning_logs", name=run_key)
            early_stop_callback = EarlyStopping(
                monitor="val_loss",
                min_delta=0.001,
                patience=5,
                verbose=True,
                mode="min"
            )
            # EarlyStopping only stops training; it does not roll the weights back.
            # Keep the lowest-val_loss weights so the test set is scored with them,
            # rather than with weights up to `patience` epochs past the optimum.
            best_ckpt_callback = ModelCheckpoint(
                dirpath="checkpoints",
                filename=f"{run_key}_best",
                monitor="val_loss",
                mode="min",
                save_top_k=1,
            )

            model = LitClassifier(fn, w, freeze_name, num_classes=3, lr=lr)
            trainer = Trainer(
                max_epochs=40,
                accelerator="gpu" if use_gpu else "cpu",
                logger=logger,
                callbacks=[early_stop_callback, best_ckpt_callback]
            )
            trainer.fit(model, train_loader, val_loader)

            # Last-epoch weights, saved as before.
            trainer.save_checkpoint(f"checkpoints/{run_key}_final.ckpt")

            history[run_key] = _read_logged_metrics(logger)

            if best_ckpt_callback.best_model_path:
                # Checkpoint written by this very run (trusted). It also carries the
                # saved hyperparameters (a builder function and a weights object),
                # which weights_only=True loading may reject.
                checkpoint = torch.load(
                    best_ckpt_callback.best_model_path, map_location="cpu", weights_only=False
                )
                model.load_state_dict(checkpoint["state_dict"])

            # Score on the same kind of device that was used for training.
            model.to(eval_device).eval()
            all_preds, all_labels = [], []
            with torch.inference_mode():
                for x, y in test_loader:
                    all_preds.append(model(x.to(eval_device)).argmax(1).cpu())
                    all_labels.append(y.cpu())
            results[run_key] = {
                "preds": torch.cat(all_preds),
                "labels": torch.cat(all_labels)
            }

    return results, history

In [35]:
# 5. ResNet18: frozen up to (not including) layer3; layer3, layer4 and the new
#    head are trainable.
resnet18_layer3_model = LitClassifier(
    models.resnet18,
    models.ResNet18_Weights.DEFAULT,
    "layer3",
    num_classes=3,
)
total_params = sum(map(torch.numel, resnet18_layer3_model.model.parameters()))
trainable_params = sum(
    torch.numel(p) for p in resnet18_layer3_model.model.parameters() if p.requires_grad
)
print(f"ResNet18 (layer3) total params: {total_params:,}")
print(f"ResNet18 (layer3) trainable params: {trainable_params:,}")
summary(resnet18_layer3_model.model, input_size=(1, 3, 224, 224))

ResNet18 (layer3) total params: 11,308,611
ResNet18 (layer3) trainable params: 10,625,539


Layer (type:depth-idx)                        Output Shape              Param #
Sequential                                    [1, 3]                    --
├─ResNet: 1-1                                 [1, 512]                  --
│    └─Conv2d: 2-1                            [1, 64, 112, 112]         (9,408)
│    └─BatchNorm2d: 2-2                       [1, 64, 112, 112]         (128)
│    └─ReLU: 2-3                              [1, 64, 112, 112]         --
│    └─MaxPool2d: 2-4                         [1, 64, 56, 56]           --
│    └─Sequential: 2-5                        [1, 64, 56, 56]           --
│    │    └─BasicBlock: 3-1                   [1, 64, 56, 56]           (73,984)
│    │    └─BasicBlock: 3-2                   [1, 64, 56, 56]           (73,984)
│    └─Sequential: 2-6                        [1, 128, 28, 28]          --
│    │    └─BasicBlock: 3-3                   [1, 128, 28, 28]          (230,144)
│    │    └─BasicBlock: 3-4                   [1, 128, 28, 28]      

In [36]:
# 6. EfficientNet-B3: frozen up to (not including) features.7; features.7
#    onwards plus the new head are trainable.
efficientnet_b3_model = LitClassifier(
    models.efficientnet_b3,
    models.EfficientNet_B3_Weights.DEFAULT,
    "features.7",
    num_classes=3,
)
total_params = sum(map(torch.numel, efficientnet_b3_model.model.parameters()))
trainable_params = sum(
    torch.numel(p) for p in efficientnet_b3_model.model.parameters() if p.requires_grad
)
print(f"EfficientNet-B3 (features.7) total params: {total_params:,}")
print(f"EfficientNet-B3 (features.7) trainable params: {trainable_params:,}")
summary(efficientnet_b3_model.model, input_size=(1, 3, 224, 224))

EfficientNet-B3 (features.7) total params: 11,090,475
EfficientNet-B3 (features.7) trainable params: 4,271,357


Layer (type:depth-idx)                                       Output Shape              Param #
Sequential                                                   [1, 3]                    --
├─EfficientNet: 1-1                                          [1, 1536]                 --
│    └─Sequential: 2-1                                       [1, 1536, 7, 7]           --
│    │    └─Conv2dNormActivation: 3-1                        [1, 40, 112, 112]         (1,160)
│    │    └─Sequential: 3-2                                  [1, 24, 112, 112]         (3,504)
│    │    └─Sequential: 3-3                                  [1, 32, 56, 56]           (48,118)
│    │    └─Sequential: 3-4                                  [1, 48, 28, 28]           (110,912)
│    │    └─Sequential: 3-5                                  [1, 96, 14, 14]           (638,700)
│    │    └─Sequential: 3-6                                  [1, 136, 14, 14]          (1,387,760)
│    │    └─Sequential: 3-7                             

In [37]:
# 7. MobileNet-V3 Large: frozen up to (not including) features.6; features.6
#    onwards plus the new head are trainable.
mobilenet_v3_model = LitClassifier(
    models.mobilenet_v3_large,
    models.MobileNet_V3_Large_Weights.DEFAULT,
    "features.6",
    num_classes=3,
)
total_params = sum(map(torch.numel, mobilenet_v3_model.model.parameters()))
trainable_params = sum(
    torch.numel(p) for p in mobilenet_v3_model.model.parameters() if p.requires_grad
)
print(f"MobileNet-V3 Large (features.6) total params: {total_params:,}")
print(f"MobileNet-V3 Large (features.6) trainable params: {trainable_params:,}")
summary(mobilenet_v3_model.model, input_size=(1, 3, 224, 224))

MobileNet-V3 Large (features.6) total params: 3,218,739
MobileNet-V3 Large (features.6) trainable params: 3,178,611


Layer (type:depth-idx)                                  Output Shape              Param #
Sequential                                              [1, 3]                    --
├─MobileNetV3: 1-1                                      [1, 960]                  --
│    └─Sequential: 2-1                                  [1, 960, 7, 7]            --
│    │    └─Conv2dNormActivation: 3-1                   [1, 16, 112, 112]         (464)
│    │    └─InvertedResidual: 3-2                       [1, 16, 112, 112]         (464)
│    │    └─InvertedResidual: 3-3                       [1, 24, 56, 56]           (3,440)
│    │    └─InvertedResidual: 3-4                       [1, 24, 56, 56]           (4,440)
│    │    └─InvertedResidual: 3-5                       [1, 40, 28, 28]           (10,328)
│    │    └─InvertedResidual: 3-6                       [1, 40, 28, 28]           (20,992)
│    │    └─InvertedResidual: 3-7                       [1, 40, 28, 28]           20,992
│    │    └─InvertedResidual

In [38]:
# 8. EfficientNet-B0: frozen up to (not including) features.7; features.7
#    onwards plus the new head are trainable.
efficientnet_b0_model = LitClassifier(
    models.efficientnet_b0,
    models.EfficientNet_B0_Weights.DEFAULT,
    "features.7",
    num_classes=3,
)
total_params = sum(map(torch.numel, efficientnet_b0_model.model.parameters()))
trainable_params = sum(
    torch.numel(p) for p in efficientnet_b0_model.model.parameters() if p.requires_grad
)
print(f"EfficientNet-B0 (features.7) total params: {total_params:,}")
print(f"EfficientNet-B0 (features.7) trainable params: {trainable_params:,}")
summary(efficientnet_b0_model.model, input_size=(1, 3, 224, 224))

EfficientNet-B0 (features.7) total params: 4,336,255
EfficientNet-B0 (features.7) trainable params: 1,458,099


Layer (type:depth-idx)                                       Output Shape              Param #
Sequential                                                   [1, 3]                    --
├─EfficientNet: 1-1                                          [1, 1280]                 --
│    └─Sequential: 2-1                                       [1, 1280, 7, 7]           --
│    │    └─Conv2dNormActivation: 3-1                        [1, 32, 112, 112]         (928)
│    │    └─Sequential: 3-2                                  [1, 16, 112, 112]         (1,448)
│    │    └─Sequential: 3-3                                  [1, 24, 56, 56]           (16,714)
│    │    └─Sequential: 3-4                                  [1, 40, 28, 28]           (46,640)
│    │    └─Sequential: 3-5                                  [1, 80, 14, 14]           (242,930)
│    │    └─Sequential: 3-6                                  [1, 112, 14, 14]          (543,148)
│    │    └─Sequential: 3-7                                  

In [39]:
# 9. EfficientNet-B0: frozen up to (not including) features.6; features.6
#    onwards plus the new head are trainable.
efficientnet_b0_model_6 = LitClassifier(
    models.efficientnet_b0,
    models.EfficientNet_B0_Weights.DEFAULT,
    "features.6",
    num_classes=3,
)
total_params = sum(map(torch.numel, efficientnet_b0_model_6.model.parameters()))
trainable_params = sum(
    torch.numel(p) for p in efficientnet_b0_model_6.model.parameters() if p.requires_grad
)
print(f"EfficientNet-B0 (features.6) total params: {total_params:,}")
print(f"EfficientNet-B0 (features.6) trainable params: {trainable_params:,}")
summary(efficientnet_b0_model_6.model, input_size=(1, 3, 224, 224))

EfficientNet-B0 (features.6) total params: 4,336,255
EfficientNet-B0 (features.6) trainable params: 3,484,447


Layer (type:depth-idx)                                       Output Shape              Param #
Sequential                                                   [1, 3]                    --
├─EfficientNet: 1-1                                          [1, 1280]                 --
│    └─Sequential: 2-1                                       [1, 1280, 7, 7]           --
│    │    └─Conv2dNormActivation: 3-1                        [1, 32, 112, 112]         (928)
│    │    └─Sequential: 3-2                                  [1, 16, 112, 112]         (1,448)
│    │    └─Sequential: 3-3                                  [1, 24, 56, 56]           (16,714)
│    │    └─Sequential: 3-4                                  [1, 40, 28, 28]           (46,640)
│    │    └─Sequential: 3-5                                  [1, 80, 14, 14]           (242,930)
│    │    └─Sequential: 3-6                                  [1, 112, 14, 14]          (543,148)
│    │    └─Sequential: 3-7                                  