In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from livelossplot import PlotLosses
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights

In [5]:
# Standard ImageNet channel mean/std (the backbone is ImageNet-pretrained).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
SEED = 42

train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),  # effectively a no-op after the square 224x224 resize
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Two views of the same folder: identical samples, different transforms.
dataset_train = datasets.ImageFolder(root='dataset', transform=train_transforms)
dataset_val = datasets.ImageFolder(root='dataset', transform=val_transforms)
# Split indices are computed on one scan and applied to the other, so both
# scans must list the same files in the same order.
assert dataset_train.samples == dataset_val.samples, "train/val folder scans differ"

targets = np.array(dataset_train.targets)

# Single stratified 80/20 split: class proportions are preserved in both parts.
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
train_idx, val_idx = next(sss.split(np.zeros(len(targets)), targets))

train_dataset = Subset(dataset_train, train_idx)
val_dataset = Subset(dataset_val, val_idx)

# A seeded generator makes the training shuffle order reproducible across runs.
# NOTE: batch_size is hard-coded here; keep it in sync with BATCH_SIZE.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


## Setup Device

In [11]:
# Pick the best available accelerator automatically: CUDA (e.g. the RTX 3060
# this was developed on), then Apple-silicon MPS, then CPU. The previous
# version tested for CUDA but selected 'mps', which is wrong on both platforms.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

## Hiperparámetros

In [21]:
BATCH_SIZE = 32  # NOTE: the DataLoaders hard-code batch_size=32; keep in sync
# Derive the class count from the ImageFolder layout instead of a hand-typed
# number, so the classifier head always matches the dataset.
NUM_CLASSES = len(dataset_train.classes)
INITIAL_EPOCHS = 5               # epochs training only the head (backbone frozen)
INTIAL_EPOCHS = INITIAL_EPOCHS   # misspelled alias kept for existing references
FINE_TUNE_EPOCHS = 10            # epochs after unfreezing the last backbone stage
LEARNING_RATE_HEAD = 1e-3
LEARNING_RATE_FINE = 1e-4

In [None]:
# ImageNet-pretrained EfficientNetV2-S backbone.
weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1
model = efficientnet_v2_s(weights=weights)

# Replace classifier[1] (the head's Linear layer) with one sized for our classes.
model.classifier[1] = nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=NUM_CLASSES,
)
# Lowercase alias kept for any code using it; it now always equals NUM_CLASSES
# instead of a stale hard-coded value that disagreed with the head size.
num_classes = NUM_CLASSES
model = model.to(device)


In [None]:
# Cross-entropy over the classifier's logits.
criterion = nn.CrossEntropyLoss()

# Phase 1 optimizes only the classifier head. The scheduler lowers the LR by
# `factor` once the value passed to scheduler.step() stops decreasing
# (mode='min', patience=2).
optimizer = optim.Adam(
    params=model.classifier.parameters(),
    lr=LEARNING_RATE_HEAD,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.5,
    patience=2,
)

In [25]:
# Mutable training state shared by the training cells.
logs = {}                # latest epoch metrics, written by train_one_epoch / evaluate
best_acc = 0             # best validation accuracy observed so far
liveloss = PlotLosses()  # live-updating metric plot

In [24]:
# Freeze every parameter of the feature extractor so that, in the first phase,
# gradients only flow into the classifier head.
model.features.requires_grad_(False);

In [30]:
def train_one_epoch(model, loader, optimizer, criterion, metrics_out=None):
    """Run one training epoch over ``loader`` and record loss/accuracy.

    Args:
        model: network to train; it is switched to ``train()`` mode. Batches are
            moved to the device of its first parameter (CPU if it has none).
        loader: iterable of ``(inputs, targets)`` batches (e.g. a DataLoader).
        optimizer: optimizer stepping the trainable parameters.
        criterion: loss taking ``(outputs, targets)``; assumed to return the
            batch mean (the ``nn.CrossEntropyLoss`` default).
        metrics_out: dict that receives ``'loss'`` and ``'accuracy'``. Defaults
            to the notebook-global ``logs`` dict for backward compatibility.

    Returns:
        dict with ``'loss'`` (sample-weighted mean loss) and ``'accuracy'``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    target = logs if metrics_out is None else metrics_out
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        outputs = model(xb)
        loss = criterion(outputs, yb)
        loss.backward()
        optimizer.step()

        batch_size = yb.size(0)
        # Weight by batch size so a smaller final batch doesn't skew the epoch mean.
        loss_sum += loss.item() * batch_size
        correct += (outputs.argmax(dim=1) == yb).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch: loader produced no samples")

    metrics = {'loss': loss_sum / total, 'accuracy': correct / total}
    target.update(metrics)
    return metrics

In [31]:
def evaluate(model, loader, criterion, metrics_out=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; it is switched to ``eval()`` mode. Batches
            are moved to the device of its first parameter (CPU if it has none).
        loader: iterable of ``(inputs, targets)`` batches (e.g. a DataLoader).
        criterion: loss taking ``(outputs, targets)``; assumed to return the
            batch mean (the ``nn.CrossEntropyLoss`` default).
        metrics_out: dict that receives ``'val_loss'`` and ``'val_accuracy'``.
            Defaults to the notebook-global ``logs`` dict for backward
            compatibility.

    Returns:
        dict with ``'val_loss'`` (sample-weighted mean loss) and
        ``'val_accuracy'``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    target = logs if metrics_out is None else metrics_out
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            outputs = model(xb)
            batch_size = yb.size(0)
            # Weight by batch size so a smaller final batch doesn't skew the mean.
            loss_sum += criterion(outputs, yb).item() * batch_size
            correct += (outputs.argmax(dim=1) == yb).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("evaluate: loader produced no samples")

    metrics = {'val_loss': loss_sum / total, 'val_accuracy': correct / total}
    target.update(metrics)
    return metrics


In [None]:
def run_training_phase(num_epochs, phase_optimizer, phase_scheduler, checkpoint_path):
    """Train ``model`` for ``num_epochs`` epochs, checkpointing on new best val accuracy.

    Each epoch calls ``train_one_epoch`` and ``evaluate`` (which fill the global
    ``logs`` dict), plots the metrics with ``liveloss``, steps the
    ReduceLROnPlateau scheduler on ``logs['val_loss']``, and saves the model's
    ``state_dict`` to ``checkpoint_path`` whenever validation accuracy beats the
    global ``best_acc``.

    NOTE: ``best_acc`` is shared across phases, so a later phase only writes its
    checkpoint if it beats the best accuracy of every earlier phase.
    """
    global best_acc
    for _ in range(num_epochs):
        train_one_epoch(model, train_loader, phase_optimizer, criterion)
        evaluate(model, val_loader, criterion)

        liveloss.update(logs)
        liveloss.draw()

        phase_scheduler.step(logs['val_loss'])

        if logs['val_accuracy'] > best_acc:
            best_acc = logs['val_accuracy']
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Mejor validación: {best_acc:.4f}")


# Phase 1: only the classifier head is trainable.
print("Entrenando el head")
run_training_phase(INTIAL_EPOCHS, optimizer, scheduler, 'best_model_head.pth')

# Phase 2: unfreeze only the last stage of model.features; earlier stages stay frozen.
print("Descongelando el backbone")
for param in model.features[-1].parameters():
    param.requires_grad = True

# Rebuild optimizer/scheduler so they cover the newly trainable parameters
# at the lower fine-tuning learning rate.
optimizer = optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE_FINE,
)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

print("Fine tuning")
run_training_phase(FINE_TUNE_EPOCHS, optimizer, scheduler, 'best_model_finetuned.pth')


