In [None]:
!pip install timm torchvision
!pip install wandb



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader, random_split
import timm
from timm.loss import LabelSmoothingCrossEntropy
from timm.models.layers import DropPath
import data_preprocessing
from wandb_logger import WandBLogger
from torch.cuda.amp import autocast, GradScaler
import torch.nn as nn
import torch.optim as optim

In [None]:
!wandb login 89e5fee022a3a1cf86f958ee0b3dff6f2aa57aad

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# --- Experiment configuration -------------------------------------------
# Weights & Biases project that receives the run.
PROJECT_NAME = "federated-learning-project"

# Data loading / training length.
BATCH_SIZE = 64
EPOCHS = 20
VAL_SPLIT = 0.1  # forwarded to the data pipeline as `val_split`

# Seed handed to the RNG in the reproducibility cell.
SEED = 42

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [None]:
# Reproducibility: seed every RNG the run may draw from, not only torch's.
# Imports are local to this cell so it stays runnable on its own.
import random

import numpy as np

random.seed(SEED)
np.random.seed(SEED)
# Kept as the last expression so the cell still echoes the torch Generator.
torch.manual_seed(SEED)

<torch._C.Generator at 0x7c0e38d6baf0>

In [None]:
# Build the train / validation / test datasets through the project pipeline.
pipeline = data_preprocessing.CIFAR100Pipeline(
    val_split=VAL_SPLIT,
    use_augment=True,
)
(trainset, valset, testset) = pipeline.run_pipeline()

In [None]:
# Only the training loader reshuffles its samples; validation and test keep
# the dataset order.
trainloader = DataLoader(dataset=trainset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(dataset=valset, batch_size=BATCH_SIZE, shuffle=False)
testloader = DataLoader(dataset=testset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Create model
def create_dino_vit_s16_for_cifar100(freezing=True, num_classes=100):
    """Build a DINO-pretrained ViT-S/16 with a fresh linear classification head.

    Args:
        freezing: If True, every parameter except the new head is frozen
            (``requires_grad=False``). If False, all parameters stay trainable.
        num_classes: Output size of the new linear head. Defaults to 100
            (CIFAR-100), which is the value that used to be hard-coded.

    Returns:
        The timm model with its ``head`` attribute replaced by
        ``nn.Linear(model.num_features, num_classes)``.
    """
    # num_classes=0 requests the backbone without a classifier; the head is
    # attached explicitly right below.
    # NOTE(review): the recorded cell output shows a warning raised from inside
    # timm.create_model for this call; confirm whether this legacy model name
    # should be replaced by a newer identifier for the installed timm version.
    model = timm.create_model("vit_small_patch16_224_dino", pretrained=True, num_classes=0)

    # Replace the head with the classification head
    model.head = nn.Linear(model.num_features, num_classes)

    if freezing:
        # Freeze everything, then re-enable gradients for the head only.
        model.requires_grad_(False)
        model.head.requires_grad_(True)

    return model

# Build the network with every parameter left trainable, then move it to the
# selected device.
# NOTE(review): the training cell only hands model.head.parameters() to the
# optimizer, so the backbone is never updated even though its gradients are
# computed; confirm whether freezing=True was intended here.
model = create_dino_vit_s16_for_cifar100(freezing=False)
model = model.to(DEVICE)

  model = create_fn(


In [None]:
# Show which device holds the model, then count its parameters in one pass.
first_param = next(model.parameters())
print(first_param.device)

trainable, total = 0, 0
for param in model.parameters():
    n_elements = param.numel()
    total += n_elements
    if param.requires_grad:
        trainable += n_elements
print(f"Trainable params: {trainable:,} / {total:,}")

# Turn on the cuDNN benchmark flag for the rest of the session.
torch.backends.cudnn.benchmark = True

cuda:0
Trainable params: 21,704,164 / 21,704,164


In [None]:
from torch.optim import SGD

class SparseSGDM(SGD):
    """SGD with momentum that zeroes masked-out gradient entries before each update.

    Note: only the gradients are masked. Weight decay and already-accumulated
    momentum are applied inside ``SGD.step`` and can therefore still move
    entries whose mask is zero.
    """

    def __init__(self, params, lr, momentum=0, weight_decay=0, masks=None):
        """
        params: Model parameters to update.
        lr: Learning rate.
        momentum: Momentum factor.
        weight_decay: Weight decay (L2 penalty).
        masks: List of binary masks matching the shape of model parameters,
            one per parameter, in the order the parameters appear across all
            parameter groups. ``None`` disables masking.
        """
        super().__init__(params, lr=lr, momentum=momentum, weight_decay=weight_decay)
        self.masks = masks

    def step(self, closure=None):
        """Mask the gradients, then perform one SGD update.

        closure: Optional callable that re-evaluates the model and returns the
            loss. It is run first, so the mask is applied to the gradients it
            produces rather than to stale ones.

        Returns the closure's loss, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if self.masks is not None:
            # One running index across all groups keeps masks aligned with
            # parameters even when there is more than one parameter group.
            all_params = (p for group in self.param_groups for p in group['params'])
            for idx, p in enumerate(all_params):
                if p.grad is None:
                    continue
                mask = self.masks[idx]
                p.grad.mul_(mask.to(device=p.grad.device, dtype=p.grad.dtype))

        # The closure (if any) has already been evaluated above.
        super().step()
        return loss


In [None]:
"""def compute_fisher_diag(model, dataloader, criterion, device, num_batches=1):
    model.eval()
    fisher_diagonal = [torch.zeros_like(p, device=device) for p in model.head.parameters()]
    total_samples = 0

    for batch_idx, (x, y) in enumerate(dataloader):
        if batch_idx >= num_batches:
            break

        x, y = x.to(device), y.to(device)
        model.zero_grad()

        output = model(x)
        loss = criterion(output, y)
        loss.backward()

        for i, p in enumerate(model.head.parameters()):
            if p.grad is not None:
                fisher_diagonal[i] += (p.grad.data ** 2) * x.size(0)

        total_samples += x.size(0)

    for i in range(len(fisher_diagonal)):
        fisher_diagonal[i] /= total_samples

    return fisher_diagonal

def fisher_mask_from_diag(fisher_diag, sparsity=0.5):
    masks = []
    for f in fisher_diag:
        threshold = torch.quantile(f.flatten(), sparsity)
        mask = (f >= threshold).float()
        masks.append(mask)
    return masks
"""

"""# Compute Fisher diagonal and generate masks
fisher_diag = compute_fisher_diag(model, valloader, criterion, DEVICE, num_batches=3)
masks = fisher_mask_from_diag(fisher_diag, sparsity=0.2)"""

'# Compute Fisher diagonal and generate masks\nfisher_diag = compute_fisher_diag(model, valloader, criterion, DEVICE, num_batches=3)\nmasks = fisher_mask_from_diag(fisher_diag, sparsity=0.2)'

In [None]:
import torch
import torch.nn as nn

class TaLoSPruner:
    """Scores the parameters of a model's classification head and builds binary masks.

    The score of every head-parameter entry is the squared gradient of the
    cross-entropy loss (a diagonal-Fisher-style approximation), weighted by the
    batch size and averaged over the calibration samples. ``generate_masks``
    keeps the entries whose score is at or above the requested quantile.
    """

    # Lower bound on the fraction of entries a single mask keeps.
    MIN_KEEP_FRACTION = 0.05

    def __init__(self, model, device):
        """
        Args:
            model: Network to analyse. If it exposes a ``head`` attribute only
                that sub-module is scored; otherwise the whole model is.
            device: Device on which input batches and score tensors are placed.
        """
        self.model = model
        self.device = device
        # Maps each scored parameter tensor to its accumulated score tensor.
        self.scores = {}

        # Check the structure of the model
        if hasattr(model, 'head'):
            self.head = model.head
        else:
            print("Attenzione: il modello non ha un attributo 'head'. Utilizzo l'intero modello.")
            self.head = model

    def score(self, dataloader, num_batches=1):
        """Accumulate squared-gradient scores for the head parameters.

        Args:
            dataloader: Iterable yielding ``(inputs, targets)`` batches.
            num_batches: Maximum number of batches to consume.

        The model is switched to eval mode while scoring. Its previous
        train/eval mode is restored and the calibration gradients are cleared
        before returning, so scoring leaves no state behind besides
        ``self.scores``.
        """
        was_training = self.model.training
        self.model.eval()

        criterion = nn.CrossEntropyLoss()
        head_params = list(self.head.parameters())

        # (Re)initialise the score dictionary
        for p in head_params:
            self.scores[p] = torch.zeros_like(p, device=self.device)

        total_samples = 0

        try:
            for batch_idx, (x, y) in enumerate(dataloader):
                if batch_idx >= num_batches:
                    break

                x, y = x.to(self.device), y.to(self.device)

                self.model.zero_grad()

                if hasattr(self.model, 'forward_features'):
                    # The backbone runs without gradient tracking; only the
                    # head forward pass below is differentiated.
                    with torch.no_grad():
                        features = self.model.forward_features(x)

                    if features.ndim == 3:
                        # 3-D features: keep only the entry at index 0 of
                        # dim 1 (treated as the CLS token).
                        cls_token = features[:, 0, :]
                    else:
                        cls_token = features

                    output = self.head(cls_token)
                else:
                    # No forward_features: fall back to the regular forward.
                    output = self.model(x)

                if output.ndim == 3:
                    output = output.squeeze(1)

                loss = criterion(output, y)
                loss.backward()

                # Accumulate the scores (approximate diagonal Fisher)
                for p in head_params:
                    if p.grad is not None:
                        self.scores[p] += (p.grad.data ** 2) * x.size(0)

                total_samples += x.size(0)
        finally:
            # Do not leak calibration gradients or the eval mode to the caller.
            self.model.zero_grad()
            self.model.train(was_training)

        # Normalise the scores by the number of samples seen
        if total_samples > 0:
            for p in head_params:
                self.scores[p] /= total_samples

    def generate_masks(self, sparsity=0.5):
        """Build one binary mask per head parameter from the stored scores.

        Args:
            sparsity: Quantile of the scores used as threshold; entries scoring
                below it are masked out (0), the rest are kept (1).

        Returns:
            A list of float masks, one per parameter of the head and in the
            same order as ``self.head.parameters()``. Parameters that have no
            score yet, or whose scores are all zero, get an all-ones mask so
            the list always stays aligned with the parameters.
        """
        masks = []

        for p in self.head.parameters():
            score = self.scores.get(p)

            if score is None:
                # score() has not covered this parameter: keep it fully
                # trainable instead of silently dropping its mask.
                print(f"Attenzione: nessuno score disponibile per un parametro di forma {p.shape}; "
                      f"uso una maschera di soli uno.")
                masks.append(torch.ones_like(p, device=self.device))
                continue

            # Guard against a completely zero score tensor
            if torch.all(score == 0):
                print(f"Attenzione: tutti gli score sono zero per un parametro di forma {p.shape}!")
                masks.append(torch.ones_like(p, device=self.device))
                continue

            # Threshold and mask
            threshold = torch.quantile(score.flatten(), sparsity)
            mask = (score >= threshold).float()

            # Make sure the mask does not remove too many weights
            keep_percent = mask.sum() / mask.numel()
            if keep_percent < self.MIN_KEEP_FRACTION:
                print(f"Attenzione: stai mantenendo solo {keep_percent:.2%} dei pesi! Regolando...")
                # Lower the threshold so the top-scoring minimum fraction is kept
                top_k = max(int(self.MIN_KEEP_FRACTION * mask.numel()), 1)
                values, _ = torch.topk(score.flatten(), top_k)
                threshold = values.min()
                mask = (score >= threshold).float()

            masks.append(mask)

        return masks


In [None]:
def iterative_pruning(pruner, dataloader, rounds=4, final_sparsity=0.9, num_batches=3):
    """Run several score/mask rounds with a progressively tighter sparsity target.

    In round ``r`` (1-based) the kept fraction is
    ``(1 - final_sparsity) ** (r / rounds)``, so the target sparsity grows each
    round and equals ``final_sparsity`` in the last one.

    Args:
        pruner: Object exposing ``score(dataloader, num_batches=...)`` and
            ``generate_masks(sparsity=...)``.
        dataloader: Calibration data forwarded to ``pruner.score``.
        rounds: Number of score/mask rounds (at least 1).
        final_sparsity: Fraction of weights to remove after the last round,
            in ``[0, 1]``.
        num_batches: Batches consumed by ``pruner.score`` in every round.

    Returns:
        The masks produced by the last round.

    Raises:
        ValueError: If ``rounds`` is smaller than 1 or ``final_sparsity`` is
            outside ``[0, 1]``.
    """
    if rounds < 1:
        raise ValueError(f"rounds must be >= 1, got {rounds}")
    if not 0.0 <= final_sparsity <= 1.0:
        raise ValueError(f"final_sparsity must be in [0, 1], got {final_sparsity}")

    # The final sparsity target is the fraction of weights to remove
    keep_ratio = 1.0 - final_sparsity

    # NOTE(review): the masks of earlier rounds are overwritten and never
    # applied to the model here, so unless the pruner carries state between
    # rounds only the last round determines the result; confirm this is intended.
    masks = None
    for r in range(rounds):
        # Intermediate sparsity (grows round after round)
        current_keep = keep_ratio ** ((r + 1) / rounds)
        current_sparsity = 1.0 - current_keep
        print(f"[Round {r+1}/{rounds}] Target sparsity: {current_sparsity:.4f}")

        # Gradient-based (Fisher) scoring
        pruner.score(dataloader, num_batches=num_batches)

        # Mask for the current target
        masks = pruner.generate_masks(sparsity=current_sparsity)

    return masks

In [None]:
# Experiment tracker. The config below is what gets recorded for the run, so
# it has to mirror the values actually used by the training cell.
logger = WandBLogger(
    project_name=PROJECT_NAME,
    run_name="CENTRALIZED MODEL EDITING (Talos calibrating)-Run-1",
    config={
        "learning_rate": 5e-5,
        "epochs": EPOCHS,
        "batch_size": BATCH_SIZE,
        "optimizer": "SparseSGDM",
        "scheduler": "CosineAnnealing + Warmup",
        "weight_decay": 0.0005,
        # Matches final_sparsity=0.5 passed to iterative_pruning in the
        # training cell (the previously logged 0.6 was stale).
        "sparsity": 0.5,
        "calibration_rounds": 4
    }
)

0,1
learning_rate,▁▄█
train_acc,▁▂█
train_loss,█▄▁

0,1
learning_rate,3e-05
train_acc,0.04789
train_loss,4.74213


In [None]:
import json
from torch.cuda.amp import autocast, GradScaler
import torch.nn as nn
import torch.optim as optim
import torch

# Initialize components
criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
# Mixed precision is enabled only when running on CUDA; on CPU the scaler and
# autocast are disabled, so the same loop runs unchanged.
use_amp = DEVICE.type == "cuda"
scaler = GradScaler(enabled=use_amp)

# Set up pruning
pruner = TaLoSPruner(model, device=DEVICE)

# Calibrate the masks on the validation loader: 50% final sparsity reached
# over 4 rounds, scoring 10 batches per round.
masks = iterative_pruning(pruner, valloader, rounds=4, final_sparsity=0.5, num_batches=10)

# Report the masks before using them
print(f"Numero di maschere generate: {len(masks)}")
for i, mask in enumerate(masks):
    keep_percent = mask.sum() / mask.numel()
    print(f"Maschera #{i}: shape={mask.shape}, mantiene {keep_percent:.2%} dei pesi")

# Masks are matched to the head parameters by position, so fail fast on any
# mismatch instead of discovering it in the middle of training.
head_params = list(model.head.parameters())
if len(masks) != len(head_params):
    raise ValueError(
        f"Expected {len(head_params)} masks for the head parameters, got {len(masks)}"
    )
for i, (p, mask) in enumerate(zip(head_params, masks)):
    if p.shape != mask.shape:
        raise ValueError(
            f"Mask #{i} has shape {tuple(mask.shape)} but its parameter has shape {tuple(p.shape)}"
        )
# Device-resident copies used to re-zero pruned weights after every step.
device_masks = [mask.to(p.device) for p, mask in zip(head_params, masks)]

# Initialize optimizer with masks
optimizer = SparseSGDM(
    head_params,
    lr=5e-5,
    momentum=0.9,
    weight_decay=0.0005,
    masks=masks
)

# Scheduler: linear warm-up followed by cosine annealing
warmup_epochs = 5
cosine_epochs = EPOCHS - warmup_epochs

scheduler = optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[
        optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs),
        optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cosine_epochs)
    ],
    milestones=[warmup_epochs]
)

# Early stopping parameters
patience = 6
best_val_acc = 0.0
epochs_no_improve = 0
# Defined up front so the restore step below is safe even if validation
# accuracy never improves.
best_model_state = None

# Training loop
for epoch in range(EPOCHS):
    model.train()
    correct, total, train_loss = 0, 0, 0.0

    for x, y in trainloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()

        with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
            outputs = model(x)
            loss = criterion(outputs, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Zero out pruned weights after the update
        with torch.no_grad():
            for p, mask in zip(head_params, device_masks):
                p.mul_(mask)

        train_loss += loss.item() * y.size(0)
        _, pred = torch.max(outputs, 1)
        correct += (pred == y).sum().item()
        total += y.size(0)

    # Learning rate that was in effect during this epoch (read before the
    # scheduler advances to the next epoch's value).
    epoch_lr = optimizer.param_groups[0]["lr"]
    scheduler.step()

    train_acc = correct / total
    train_loss /= total

    # Validation
    model.eval()
    correct, total, val_loss = 0, 0, 0.0
    with torch.no_grad():
        for x, y in valloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            outputs = model(x)
            loss = criterion(outputs, y)

            val_loss += loss.item() * y.size(0)
            _, pred = torch.max(outputs, 1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    val_acc = correct / total
    val_loss /= total

    # Log the validation results
    logger.log_metrics({
        "val_loss": val_loss,
        "val_acc": val_acc
    }, step=epoch)

    print(f"Epoch {epoch+1:02d}/{EPOCHS} — Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | LR: {epoch_lr:.6f}")

    # Early stopping logic
    if val_acc > best_val_acc:
        improvement = val_acc - best_val_acc
        best_val_acc = val_acc
        epochs_no_improve = 0
        # state_dict() returns references to the live tensors; clone them so
        # later epochs cannot overwrite the snapshot of the best model.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
        logger.log_model(model, path="best_model.pth")
        print(f"Miglioramento dell'accuratezza: +{improvement:.4f}. Modello salvato.")
    else:
        epochs_no_improve += 1
        print(f"Nessun miglioramento per {epochs_no_improve}/{patience} epoche.")
        if epochs_no_improve >= patience:
            print(f"Early stopping triggeerato all'epoca {epoch+1}. Migliore accuratezza: {best_val_acc:.4f}")
            break

# Restore the best model at the end of training
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"Training completato. Ripristinato il miglior modello con accuratezza: {best_val_acc:.4f}")

# Final summary
print("\n==== RIASSUNTO TRAINING ====")
print(f"Migliore accuratezza validazione: {best_val_acc:.4f}")
print(f"Sparsità finale della testa: {1.0 - sum(m.sum() for m in masks) / sum(m.numel() for m in masks):.2%}")
print("============================")

  scaler = GradScaler()


[Round 1/4] Target sparsity: 0.1591
[Round 2/4] Target sparsity: 0.2929
[Round 3/4] Target sparsity: 0.4054
[Round 4/4] Target sparsity: 0.5000
Numero di maschere generate: 2
Maschera #0: shape=torch.Size([100, 384]), mantiene 50.00% dei pesi
Maschera #1: shape=torch.Size([100]), mantiene 50.00% dei pesi


  with autocast():


Epoch 01/20 — Train Acc: 0.0091 | Val Acc: 0.0144 | LR: 0.000014
Miglioramento dell'accuratezza: +0.0144. Modello salvato.
Epoch 02/20 — Train Acc: 0.0297 | Val Acc: 0.0494 | LR: 0.000023
Miglioramento dell'accuratezza: +0.0350. Modello salvato.
Epoch 03/20 — Train Acc: 0.0910 | Val Acc: 0.1420 | LR: 0.000032
Miglioramento dell'accuratezza: +0.0926. Modello salvato.
Epoch 04/20 — Train Acc: 0.2104 | Val Acc: 0.2726 | LR: 0.000041
Miglioramento dell'accuratezza: +0.1306. Modello salvato.




Epoch 05/20 — Train Acc: 0.3306 | Val Acc: 0.3762 | LR: 0.000050
Miglioramento dell'accuratezza: +0.1036. Modello salvato.
Epoch 06/20 — Train Acc: 0.4233 | Val Acc: 0.4534 | LR: 0.000049
Miglioramento dell'accuratezza: +0.0772. Modello salvato.
Epoch 07/20 — Train Acc: 0.4833 | Val Acc: 0.5006 | LR: 0.000048
Miglioramento dell'accuratezza: +0.0472. Modello salvato.
Epoch 08/20 — Train Acc: 0.5225 | Val Acc: 0.5328 | LR: 0.000045
Miglioramento dell'accuratezza: +0.0322. Modello salvato.
Epoch 09/20 — Train Acc: 0.5520 | Val Acc: 0.5520 | LR: 0.000042
Miglioramento dell'accuratezza: +0.0192. Modello salvato.
Epoch 10/20 — Train Acc: 0.5727 | Val Acc: 0.5664 | LR: 0.000038
Miglioramento dell'accuratezza: +0.0144. Modello salvato.
Epoch 11/20 — Train Acc: 0.5863 | Val Acc: 0.5768 | LR: 0.000033
Miglioramento dell'accuratezza: +0.0104. Modello salvato.
Epoch 12/20 — Train Acc: 0.5974 | Val Acc: 0.5856 | LR: 0.000028
Miglioramento dell'accuratezza: +0.0088. Modello salvato.
Epoch 13/20 — Tr

In [None]:
# Load the best model. map_location lets a checkpoint written on the GPU be
# restored on a CPU-only runtime as well.
# NOTE(review): assumes logger.log_model stored a plain state_dict at this
# path; confirm against the WandBLogger implementation.
model.load_state_dict(torch.load('best_model.pth', map_location=DEVICE))

# Test evaluation
model.eval()
correct, total, test_loss = 0, 0, 0.0
with torch.no_grad():
    for x, y in testloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        outputs = model(x)
        loss = criterion(outputs, y)

        # Weight the batch loss by batch size so the final value is a
        # per-sample average.
        test_loss += loss.item() * y.size(0)
        _, pred = torch.max(outputs, 1)
        correct += (pred == y).sum().item()
        total += y.size(0)

test_acc = correct / total
test_loss /= total

logger.log_metrics({
    "test_loss": test_loss,
    "test_acc": test_acc
})

logger.finish()

print(f"\nFinal Test Accuracy: {test_acc:.4f} | Test Loss: {test_loss:.4f}")

0,1
test_acc,▁
test_loss,▁
val_acc,▁▁▃▄▅▆▇▇▇███████████
val_loss,█▆▅▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test_acc,0.6052
test_loss,2.08908
val_acc,0.608
val_loss,2.06817



Final Test Accuracy: 0.6052 | Test Loss: 2.0891


In [None]:
"""#!pip install ray[tune]
from ray import tune

search_space = {
    "lr": tune.grid_search([0.01, 0.1]),
    "momentum": 0.9,
    "weight_decay": 5e-4,
    "batch_size": 128,
    "sparsity": tune.grid_search([0.2, 0.5, 0.8])  # proporzione di pesi NON aggiornati
}"""

'#!pip install ray[tune]\nfrom ray import tune\n\nsearch_space = {\n    "lr": tune.grid_search([0.01, 0.1]),\n    "momentum": 0.9,\n    "weight_decay": 5e-4,\n    "batch_size": 128,\n    "sparsity": tune.grid_search([0.2, 0.5, 0.8])  # proporzione di pesi NON aggiornati\n}'

In [None]:
"""from torchvision import transforms

def get_cifar_transform() -> transforms.Compose:
    return transforms.Compose([
        transforms.Resize((224, 224)),  # Resize CIFAR images to 224x224
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet means
            std=[0.229, 0.224, 0.225]    # ImageNet stds
        )
    ])
"""

'from torchvision import transforms\n\ndef get_cifar_transform() -> transforms.Compose:\n    return transforms.Compose([\n        transforms.Resize((224, 224)),  # Resize CIFAR images to 224x224\n        transforms.ToTensor(),\n        transforms.Normalize(\n            mean=[0.485, 0.456, 0.406],  # ImageNet means\n            std=[0.229, 0.224, 0.225]    # ImageNet stds\n        )\n    ])\n'

In [None]:
"""from ray.train import Checkpoint
import json
import os
import torch
from ray.train import report
from torch.cuda.amp import autocast, GradScaler
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import shutil

def train_vit(config):
    # 1. Crea il modello
    model = create_dino_vit_s16_for_cifar100().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()

    # 2. Prepara i dati
    pipeline = data_preprocessing.CIFAR100Pipeline(val_split=VAL_SPLIT, use_augment=True)
    trainset, valset, testset = pipeline.run_pipeline()
    trainloader = DataLoader(trainset, batch_size=config["batch_size"], shuffle=True)
    valloader = DataLoader(valset, batch_size=config["batch_size"])
    testloader = DataLoader(testset, batch_size=config["batch_size"])

    # 3. Calcola Fisher diagonal e maschera
    fisher_diag = compute_fisher_diag(model, valloader, criterion, DEVICE, num_batches=3)
    masks = fisher_mask_from_diag(fisher_diag, sparsity=config["sparsity"])

    # --- CALCOLO SCORE & MASCHERE CON IL TUO TALOSPRUNER ---
    pruner = TaLoSPruner(model.head, device=DEVICE)
    pruner.score(valloader, num_batches=3)  # Fisher score stimato sui dati di validazione
    masks = pruner.generate_masks(sparsity=0.2)

    # 4. Ottimizzatore custom con maschera
    optimizer = SparseSGDM(
        model.head.parameters(),
        lr=config["lr"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        masks=masks
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

    # Struttura per salvare i risultati
    results = {
        'train_acc': [],
        'train_loss': [],
        'val_acc': [],
        'val_loss': [],
        'best_epoch': 0,
        'best_val_acc': 0.0,
        'test_acc': 0.0,
        'test_loss': 0.0,
        'config': config
    }

    best_val_acc = 0.0
    best_model_state = None

    for epoch in range(20):
        # Training loop (come prima)
        # ...

        # Dopo ogni epoca, salva le metriche
        results['train_acc'].append(train_acc)
        results['train_loss'].append(train_loss)
        results['val_acc'].append(val_acc)
        results['val_loss'].append(val_loss)

        # Aggiorna il miglior modello
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            results['best_val_acc'] = best_val_acc
            results['best_epoch'] = epoch
            best_model_state = model.state_dict()

        # Report a Ray Tune (per monitorare durante il training)
        report({
            "val_accuracy": val_acc,
            "train_accuracy": train_acc,
            "sparsity": config["sparsity"],
            "epoch": epoch
        })

    # Dopo il training, valuta sul test set con il miglior modello
    model.load_state_dict(best_model_state)
    model.eval()
    correct, total, test_loss = 0, 0, 0.0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            outputs = model(x)
            loss = criterion(outputs, y)
            test_loss += loss.item() * y.size(0)
            _, pred = outputs.max(1)
            correct += (pred == y).sum().item()
            total += y.size(0)

    test_acc = correct / total
    results['test_acc'] = test_acc
    results['test_loss'] = test_loss / total

    # Crea una checkpoint directory temporanea
    checkpoint_dir = os.path.join(os.getcwd(), "checkpoint")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Salva il modello e i risultati
    torch.save(best_model_state, os.path.join(checkpoint_dir, "model.pth"))
    with open(os.path.join(checkpoint_dir, "results.json"), "w") as f:
        json.dump(results, f, indent=4)

    # Crea un checkpoint per Ray Tune
    checkpoint = Checkpoint.from_directory(checkpoint_dir)

    # Report finale con il checkpoint
    report({
        "val_accuracy": best_val_acc,
        "train_accuracy": train_acc,
        "test_accuracy": test_acc,
        "sparsity": config["sparsity"],
        "checkpoint": checkpoint
    })

    # Pulisci la directory temporanea
    try:
        shutil.rmtree(checkpoint_dir)
    except:
        pass"""

'from ray.train import Checkpoint\nimport json\nimport os\nimport torch\nfrom ray.train import report\nfrom torch.cuda.amp import autocast, GradScaler\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\nimport shutil\n\ndef train_vit(config):\n    # 1. Crea il modello\n    model = create_dino_vit_s16_for_cifar100().to(DEVICE)\n    criterion = nn.CrossEntropyLoss()\n    scaler = GradScaler()\n\n    # 2. Prepara i dati\n    pipeline = data_preprocessing.CIFAR100Pipeline(val_split=VAL_SPLIT, use_augment=True)\n    trainset, valset, testset = pipeline.run_pipeline()\n    trainloader = DataLoader(trainset, batch_size=config["batch_size"], shuffle=True)\n    valloader = DataLoader(valset, batch_size=config["batch_size"])\n    testloader = DataLoader(testset, batch_size=config["batch_size"])\n\n    # 3. Calcola Fisher diagonal e maschera\n    fisher_diag = compute_fisher_diag(model, valloader, criterion, DEVICE, num_batches=3)\n    masks = fisher_mask

In [None]:
"""from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.basic_variant import BasicVariantGenerator
from ray.tune import CLIReporter
import os

storage_uri = f"file://ray_results"

reporter = CLIReporter(
    metric_columns=["val_accuracy", "train_accuracy", "test_accuracy", "sparsity", "training_iteration"]
)

analysis = tune.run(
    train_vit,
    config=search_space,
    storage_path=storage_uri,
    search_alg=BasicVariantGenerator(),
    num_samples=10,
    resources_per_trial={"cpu": 2, "gpu": 1},
    scheduler=ASHAScheduler(metric="val_accuracy", mode="max"),
    name="vit_hyperparam_search",
    progress_reporter=reporter
)

SyntaxError: incomplete input (<ipython-input-127-d99af327ec26>, line 1)

In [None]:
"""# Ottieni il miglior trial
best_trial = analysis.get_best_trial("val_accuracy", mode="max", scope="all")

# Percorso del checkpoint
best_checkpoint = best_trial.checkpoint.value

# Carica i risultati
with open(os.path.join(best_checkpoint, "results.json"), "r") as f:
    best_results = json.load(f)

# Carica il modello
best_model = create_dino_vit_s16_for_cifar100().to(DEVICE)
best_model.load_state_dict(torch.load(os.path.join(best_checkpoint, "model.pth")))"""