In [None]:
import os
import glob
import random
from PIL import Image
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models, datasets
from torchvision.models import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights
import wandb

In [None]:
class WignerFolderDataset(Dataset):
    """Folder-per-class image dataset that yields 3-channel normalized tensors.

    Every immediate sub-directory of ``root_dir`` is one class; classes are
    indexed by the sorted order of their directory names. Each image is loaded
    as grayscale, resized, scaled to [0, 1], replicated to three channels,
    optionally augmented with ``transform`` and finally normalized with
    ``NORM_MEAN`` / ``NORM_STD``.

    Args:
        root_dir: Directory whose sub-directories hold the images of each class.
        transform: Optional callable applied to the (3, H, W) float tensor
            before normalization.
        image_size: ``(width, height)`` every image is resized to.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # Per-channel statistics applied by the final normalization step.
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, root_dir, transform=None, image_size=(64, 64)):
        self.root_dir = root_dir
        self.transform = transform
        self.image_size = tuple(image_size)
        self.image_paths = []
        self.labels = []

        # Stored on the instance so get_class_name() can map indices back to names.
        self.class_names = sorted(
            d for d in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, d))
        )

        for class_idx, class_name in enumerate(self.class_names):
            class_path = os.path.join(root_dir, class_name)
            # os.listdir order is arbitrary; sorting keeps sample indices stable
            # across runs and machines.
            for img_name in sorted(os.listdir(class_path)):
                # Compare in lower case so '.PNG' / '.JPG' files are not skipped.
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_path, img_name))
                    self.labels.append(class_idx)

    def __len__(self):
        """Return the number of images found under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, class_index)``."""
        # The context manager releases the file handle as soon as the pixels
        # have been decoded by convert().
        with Image.open(self.image_paths[idx]) as raw:
            image = raw.convert('L').resize(self.image_size, Image.LANCZOS)

        pixels = np.array(image, dtype=np.float32) / 255.0

        # (H, W) -> (1, H, W) -> (3, H, W): replicate the gray channel.
        image = torch.FloatTensor(pixels).unsqueeze(0).repeat(3, 1, 1)

        if self.transform:
            image = self.transform(image)

        normalize = transforms.Normalize(mean=self.NORM_MEAN, std=self.NORM_STD)
        image = normalize(image)

        return image, self.labels[idx]

    def get_class_name(self, label_idx):
        """Return the directory name of the class with index ``label_idx``."""
        return self.class_names[label_idx]

    @staticmethod
    def _add_gaussian_noise(x):
        """Add zero-mean Gaussian noise (std 0.01) to tensor ``x``."""
        return x + torch.randn_like(x) * 0.01

    @staticmethod
    def get_transforms(train=True):
        """Return the tensor-level augmentation pipeline.

        Declared static so it can be called on the class as well as on an
        instance. Returns ``None`` when ``train`` is false (no augmentation).
        """
        if not train:
            return None
        return transforms.Compose([
            transforms.RandomRotation(10),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            # A named function is used instead of a lambda because lambdas
            # cannot be pickled.
            transforms.Lambda(WignerFolderDataset._add_gaussian_noise),
        ])


In [None]:
class WignerTransferLearningModel(nn.Module):
    """Pretrained ResNet feature extractor followed by a small MLP head.

    Args:
        num_classes: Size of the output layer.
        model_name: One of ``'resnet18'``, ``'resnet34'`` or ``'resnet50'``.
        freeze_backbone: When true, every backbone parameter gets
            ``requires_grad = False``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    def __init__(self, num_classes=10, model_name='resnet18', freeze_backbone=True):
        super().__init__()

        # Reject unknown architectures before any weights are loaded.
        if model_name not in ('resnet18', 'resnet34', 'resnet50'):
            raise ValueError(f"Modelo {model_name} no soportado")

        if model_name == 'resnet50':
            self.backbone = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            head_inputs = 2048
        elif model_name == 'resnet34':
            self.backbone = models.resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
            head_inputs = 512
        else:
            self.backbone = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
            head_inputs = 512

        if freeze_backbone:
            for weight in self.backbone.parameters():
                weight.requires_grad = False

        # Every backbone child except the last one; the modules are shared with
        # self.backbone, so (un)freezing one affects the other.
        backbone_children = list(self.backbone.children())
        self.features = nn.Sequential(*backbone_children[:-1])

        head_layers = [
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(head_inputs, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the feature extractor and then the classifier head."""
        return self.classifier(self.features(x))

    def unfreeze_last_layers(self, num_layers=1):
        """Make the last ``num_layers`` residual stages trainable.

        Stages are unfrozen from ``layer4`` backwards; values above four are
        clamped and non-positive values unfreeze nothing.
        """
        stages = (
            self.backbone.layer4,
            self.backbone.layer3,
            self.backbone.layer2,
            self.backbone.layer1,
        )
        how_many = max(0, min(num_layers, len(stages)))
        for stage in stages[:how_many]:
            for weight in stage.parameters():
                weight.requires_grad = True


In [None]:
class ProgressiveTrainer:
    """Two-phase trainer: classifier-only training, then partial fine-tuning.

    ``history`` accumulates one entry per completed epoch across both phases,
    so ``plot_history`` shows the whole run on a single axis.

    Args:
        model: Module exposing ``classifier``, ``backbone`` and
            ``unfreeze_last_layers`` (used by the two training phases).
        device: Requested device; falls back to ``"cpu"`` when CUDA is not
            available.
        use_wandb: When true, per-epoch metrics are sent with ``wandb.log``.
    """

    # Backbone stages in the order they are unfrozen (last stage first).
    _BACKBONE_STAGES = ("layer4", "layer3", "layer2", "layer1")

    def __init__(self, model, device="cuda", use_wandb=False):
        self.device = device if torch.cuda.is_available() else "cpu"
        print(f"Usando dispositivo: {self.device}")
        self.use_wandb = use_wandb
        self.model = model.to(self.device)
        self.history = {
            "train_loss": [], "train_acc": [],
            "val_loss": [], "val_acc": []
        }

    def _record_epoch(self, epoch, epochs, train_loss, train_acc, val_loss, val_acc, lr):
        """Print, store and (optionally) log the metrics of one finished epoch."""
        print(f"Epoch {epoch+1}/{epochs}: "
              f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}% | "
              f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")

        self.history["train_loss"].append(train_loss)
        self.history["train_acc"].append(train_acc)
        self.history["val_loss"].append(val_loss)
        self.history["val_acc"].append(val_acc)

        if self.use_wandb:
            wandb.log({
                # Global epoch index: keeps increasing across phases so phase 2
                # does not log on top of the phase 1 steps.
                'epoch': len(self.history["train_loss"]),
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_loss': val_loss,
                'val_acc': val_acc,
                'learning_rate': lr
            })

    @staticmethod
    def _epoch_metrics(running_loss, correct, total):
        """Turn accumulated sums into ``(mean_loss, accuracy_percent)``.

        Raises:
            ValueError: If no sample was processed (empty dataloader).
        """
        if total == 0:
            raise ValueError("El dataloader no produjo ninguna muestra")
        return running_loss / total, 100. * correct / total

    def phase1_feature_extraction(self, train_loader, val_loader, epochs=6):
        """Train only ``model.classifier`` with Adam for ``epochs`` epochs."""
        print("\n" + "="*50)
        print("FASE 1: Feature Extraction (solo clasificador)")
        print("="*50 + "\n")

        optimizer = optim.Adam(
            self.model.classifier.parameters(),
            lr=0.001,
            weight_decay=0.0001
        )
        criterion = nn.CrossEntropyLoss()

        for epoch in range(epochs):
            train_loss, train_acc = self._train_epoch(
                train_loader, criterion, optimizer
            )
            val_loss, val_acc = self._validate(val_loader, criterion)

            self._record_epoch(
                epoch, epochs, train_loss, train_acc, val_loss, val_acc,
                lr=optimizer.param_groups[0]['lr']
            )

    def phase2_fine_tuning(self, train_loader, val_loader, epochs=6, num_layers_unfreeze=1,
                           patience=5, checkpoint_path="best_model.pth"):
        """Unfreeze the last backbone stages and fine-tune with early stopping.

        Args:
            train_loader: Iterable of ``(images, labels)`` training batches.
            val_loader: Iterable of ``(images, labels)`` validation batches.
            epochs: Maximum number of fine-tuning epochs.
            num_layers_unfreeze: Number of backbone stages to unfreeze, counted
                from ``layer4`` backwards.
            patience: Consecutive epochs without a validation-accuracy
                improvement after which training stops.
            checkpoint_path: File the best ``state_dict`` is written to.
        """
        print("\n" + "="*50)
        print(f"FASE 2: Fine-Tuning ({num_layers_unfreeze} capa(s) del backbone)")
        print("="*50 + "\n")

        self.model.unfreeze_last_layers(num_layers=num_layers_unfreeze)

        # Hand the optimizer every stage that was just unfrozen. layer4 is
        # always included, as before, so num_layers_unfreeze <= 1 is unchanged.
        num_stages = max(1, min(num_layers_unfreeze, len(self._BACKBONE_STAGES)))
        backbone_params = [
            param
            for stage_name in self._BACKBONE_STAGES[:num_stages]
            for param in getattr(self.model.backbone, stage_name).parameters()
        ]

        optimizer = optim.Adam([
            {'params': backbone_params, 'lr': 1e-4},
            {'params': self.model.classifier.parameters(), 'lr': 1e-3}
        ], weight_decay=0.0001)

        criterion = nn.CrossEntropyLoss()
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=3
        )

        # Accuracy is never negative, so the first finished epoch always writes
        # a checkpoint (even when its validation accuracy is exactly 0).
        best_val_acc = -1.0
        patience_counter = 0

        for epoch in range(epochs):
            train_loss, train_acc = self._train_epoch(
                train_loader, criterion, optimizer
            )
            val_loss, val_acc = self._validate(val_loader, criterion)
            scheduler.step(val_loss)

            # Record before the early-stopping check so the epoch that
            # triggers the stop is still part of the history and the logs.
            self._record_epoch(
                epoch, epochs, train_loss, train_acc, val_loss, val_acc,
                lr=optimizer.param_groups[0]['lr']
            )

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                torch.save(self.model.state_dict(), checkpoint_path)
                print(f"  → Mejor modelo guardado (Val Acc: {val_acc:.2f}%)")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("\n⚠️  Early stopping triggered")
                    break

        print(f"\n✓ Mejor validación accuracy: {max(best_val_acc, 0.0):.2f}%")

    def _train_epoch(self, dataloader, criterion, optimizer):
        """Run one optimization pass; return ``(mean_loss, accuracy_percent)``."""
        # NOTE(review): train() also switches any BatchNorm layers of a frozen
        # backbone to training mode, so their running statistics keep
        # updating during phase 1 — confirm this is intended.
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(dataloader, desc='Training')
        for images, labels in pbar:
            images, labels = images.to(self.device), labels.to(self.device)

            optimizer.zero_grad()
            outputs = self.model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The criterion returns a batch mean; weight it by the batch size
            # so the epoch loss is a true per-sample mean.
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/total:.2f}%'
            })

        return self._epoch_metrics(running_loss, correct, total)

    def _validate(self, dataloader, criterion):
        """Evaluate without gradients; return ``(mean_loss, accuracy_percent)``."""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            pbar = tqdm(dataloader, desc='Validation')
            for images, labels in pbar:
                images, labels = images.to(self.device), labels.to(self.device)

                outputs = self.model(images)
                loss = criterion(outputs, labels)

                running_loss += loss.item() * images.size(0)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{100.*correct/total:.2f}%'
                })

        return self._epoch_metrics(running_loss, correct, total)

    def predict(self, dataloader):
        """Run inference over ``dataloader``.

        Returns:
            Tuple ``(preds, labels, probs)`` of NumPy arrays: predicted class
            indices, the labels yielded by the loader (in the same order) and
            the softmax probabilities with one row per sample.
        """
        self.model.eval()
        pred_chunks, label_chunks, prob_chunks = [], [], []

        with torch.no_grad():
            for images, labels in dataloader:
                logits = self.model(images.to(self.device))
                probs = torch.softmax(logits, dim=1)
                prob_chunks.append(probs.cpu())
                pred_chunks.append(probs.argmax(dim=1).cpu())
                label_chunks.append(torch.as_tensor(labels).cpu())

        if not pred_chunks:
            # Empty loader: return empty arrays instead of failing in torch.cat.
            empty_idx = np.empty(0, dtype=np.int64)
            return empty_idx, empty_idx.copy(), np.empty((0, 0), dtype=np.float32)

        return (
            torch.cat(pred_chunks).numpy(),
            torch.cat(label_chunks).numpy(),
            torch.cat(prob_chunks).numpy(),
        )

    def plot_history(self, save_path='training_history.png'):
        """Plot loss and accuracy curves, save them to ``save_path`` and show them."""
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        ax1.plot(self.history['train_loss'], label='Train Loss', marker='o')
        ax1.plot(self.history['val_loss'], label='Val Loss', marker='s')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Training and Validation Loss')
        ax1.legend()
        ax1.grid(True)

        ax2.plot(self.history['train_acc'], label='Train Acc', marker='o')
        ax2.plot(self.history['val_acc'], label='Val Acc', marker='s')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('Accuracy (%)')
        ax2.set_title('Training and Validation Accuracy')
        ax2.legend()
        ax2.grid(True)

        plt.tight_layout()
        plt.savefig(save_path, dpi=150)
        print(f"Gráficas guardadas en: {save_path}")
        plt.show()


In [None]:
def create_dataloaders(data_dir, batch_size=32, val_split=0.2, seed=42, image_size=224, augment=False,
                       num_workers=4):
    """Build train/validation loaders over one ``ImageFolder`` directory.

    The sample indices are shuffled with a private RNG seeded by ``seed`` and
    the first ``floor(val_split * N)`` of them become the validation set. Both
    loaders read the same folder but apply their own transform pipeline.

    Args:
        data_dir: Root directory in ``ImageFolder`` layout (one folder per class).
        batch_size: Batch size of both loaders.
        val_split: Fraction of samples held out for validation, in ``[0, 1)``.
        seed: Seed of the train/validation split.
        image_size: Side length images are resized to.
        augment: When true, random flip / rotation / color jitter are added to
            the training pipeline.
        num_workers: Worker processes used by each loader.

    Returns:
        ``(train_loader, val_loader, full_dataset)`` where ``full_dataset`` is
        the ``ImageFolder`` without any transform.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1)``.
    """
    # Validate first: a negative fraction would slice from the wrong end and a
    # fraction >= 1 would leave the training set empty.
    if not 0.0 <= val_split < 1.0:
        raise ValueError(f"val_split debe estar en [0, 1); se recibió {val_split}")

    resize_steps = [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((image_size, image_size)),
    ]
    tensor_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    # Augmentations run on the resized image, before tensor conversion.
    augment_steps = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ] if augment else []

    train_transform = transforms.Compose(resize_steps + augment_steps + tensor_steps)
    val_transform = transforms.Compose(resize_steps + tensor_steps)

    full_dataset = datasets.ImageFolder(root=data_dir)

    dataset_size = len(full_dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(val_split * dataset_size))

    # A dedicated Random instance yields the same permutation as seeding the
    # global generator, without resetting the caller's random state.
    random.Random(seed).shuffle(indices)
    train_idx, val_idx = indices[split:], indices[:split]

    # Two views of the same folder so each split gets its own transform.
    train_dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)
    val_dataset = datasets.ImageFolder(root=data_dir, transform=val_transform)

    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)

    # Pinned memory only helps host-to-GPU copies; skip it on CPU-only hosts.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler,
                            num_workers=num_workers, pin_memory=pin_memory)

    print(f"Dataset cargado: {dataset_size} imágenes totales")
    print(f"  - Train: {len(train_idx)} imágenes")
    print(f"  - Val: {len(val_idx)} imágenes")
    print(f"  - Clases: {len(full_dataset.classes)}")

    return train_loader, val_loader, full_dataset


if __name__ == "__main__":
    # End-to-end run: build the loaders, train in two phases, reload the best
    # checkpoint and report the final validation accuracy to Weights & Biases.

    DATA_DIR = "quantum_images_png"
    BATCH_SIZE = 32
    MODEL_NAME = 'resnet18'
    PHASE1_EPOCHS = 6
    PHASE2_EPOCHS = 6

    wandb.init(
        project="wigner-classifier",
        config={
            # Logged values mirror what the run below actually uses.
            "batch_size": BATCH_SIZE,
            "epochs": PHASE1_EPOCHS + PHASE2_EPOCHS,
            "learning_rate": 0.001,
            "val_split": 0.2,
            # Early-stopping patience applied by phase2_fine_tuning.
            "patience": 5
        }
    )

    try:
        # Plot every tracked metric against the logged "epoch" value.
        wandb.define_metric("epoch")
        for metric_name in ("train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"):
            wandb.define_metric(metric_name, step_metric="epoch")

        train_loader, val_loader, dataset = create_dataloaders(
            DATA_DIR,
            batch_size=wandb.config.batch_size,
            val_split=wandb.config.val_split,
            augment=False
        )

        # Size the output layer from the folders found on disk, so labels can
        # never fall outside the classifier's range.
        NUM_CLASSES = len(dataset.classes)

        model = WignerTransferLearningModel(
            num_classes=NUM_CLASSES,
            model_name=MODEL_NAME,
            freeze_backbone=True
        )

        wandb.watch(model, log="all", log_freq=100)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # use_wandb=True is required for the per-epoch metrics to be logged.
        trainer = ProgressiveTrainer(model, device=device, use_wandb=True)

        trainer.phase1_feature_extraction(train_loader, val_loader, epochs=PHASE1_EPOCHS)

        trainer.phase2_fine_tuning(train_loader, val_loader, epochs=PHASE2_EPOCHS, num_layers_unfreeze=1)

        model.load_state_dict(torch.load("best_model.pth", map_location=trainer.device))
        print("\n✓ Modelo entrenado y cargado con los mejores pesos")

        trainer.plot_history()

        # Score the restored weights with the trainer's own validation pass;
        # it returns (loss, accuracy in percent).
        _, final_acc = trainer._validate(val_loader, nn.CrossEntropyLoss())

        wandb.log({"final_val_accuracy": final_acc})
    finally:
        # Close the run even when training fails part-way through.
        wandb.finish()