In [1]:
class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    A loss counts as an improvement only when it is lower than the best loss
    seen so far by more than ``delta``. After ``patience`` consecutive calls
    without an improvement, ``early_stop`` is latched to ``True``. A NaN loss
    is always treated as "no improvement" and is never stored as the best.
    """

    def __init__(self, patience=10, delta=0):
        """
        Args:
            patience (int): How many epochs to wait after last best validation loss.
            delta (float): Minimum change to qualify as an improvement.
        """
        self.patience = patience
        self.delta = delta
        self.best_loss = None  # lowest (non-NaN) loss seen so far; None until the first one
        self.counter = 0  # consecutive calls without an improvement
        self.early_stop = False  # latched True once counter reaches patience

    def __call__(self, val_loss):
        """Record one validation loss and update the stopping state.

        Args:
            val_loss: The latest validation loss (a number, or anything
                ``math.isnan`` accepts, e.g. a 0-d tensor).

        Returns:
            bool: The current value of ``early_stop``.
        """
        import math

        is_nan = math.isnan(val_loss)
        if self.best_loss is None and not is_nan:
            # First real loss only establishes the baseline.
            self.best_loss = val_loss
        elif is_nan or val_loss > self.best_loss - self.delta:
            # NaN must not pass as an improvement: `nan > x` is False, which
            # would otherwise reset the counter and poison best_loss.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0
        return self.early_stop

In [None]:
from torch.cuda.amp import autocast, GradScaler

# Training driver: resumes from `checkpoint_path` if present, trains with AMP,
# validates each epoch, keeps `best_model.pth` for the lowest validation loss,
# and stops early via `EarlyStopping`. Relies on `model`, `optimizer`,
# `scheduler`, `criterion`, `train_loader`, `val_loader`, `device`,
# `num_epochs` and `F` being defined in earlier cells.
scaler = GradScaler()
train_losses = []
val_losses = []
val_accuracies = []
start_epoch = 0
early_stopper = EarlyStopping(patience=10, delta=0.001)

checkpoint_path = 'checkpoint.pth'
best_checkpoint_path = 'best_model.pth'
best_val_loss = float('inf')


def _forward_batch(batch):
    """Move one batch to `device`, embed both inputs and compute the loss.

    Returns (out1, out2, target, loss): L2-normalised embeddings, the +/-1
    targets fed to `criterion`, and the loss (computed under autocast).
    """
    (img1, lbp1), (img2, lbp2), label = batch
    img1, lbp1 = img1.to(device), lbp1.to(device)
    img2, lbp2 = img2.to(device), lbp2.to(device)
    label = label.to(device)
    # Dataset label 1 -> -1.0, anything else -> +1.0.
    # NOTE(review): assumes `criterion` is a cosine-embedding style loss where
    # +1 means "similar" and dataset label 1 means "different" — confirm.
    target = torch.where(label == 1, torch.tensor(-1.0, device=device), torch.tensor(1.0, device=device))

    with autocast():
        out1, out2 = model((img1, lbp1), (img2, lbp2))
        out1 = F.normalize(out1, p=2, dim=1)
        out2 = F.normalize(out2, p=2, dim=1)
        loss = criterion(out1, out2, target)
    return out1, out2, target, loss


def _checkpoint_state(epoch):
    """Collect everything needed to resume training after `epoch`."""
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies,
        'best_val_loss': best_val_loss,
        # Persist early-stopping progress so patience does not reset on resume.
        'early_stopper_best_loss': early_stopper.best_loss,
        'early_stopper_counter': early_stopper.counter,
    }


try:
    # map_location lets a checkpoint written on one device resume on another.
    checkpoint = torch.load(checkpoint_path, map_location=device)
except FileNotFoundError:
    print("🚀 No checkpoint found, starting from scratch.")
else:
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']
    val_accuracies = checkpoint['val_accuracies']
    best_val_loss = checkpoint.get('best_val_loss', float('inf'))
    # .get keeps older checkpoints (without these keys) loadable.
    early_stopper.best_loss = checkpoint.get('early_stopper_best_loss', early_stopper.best_loss)
    early_stopper.counter = checkpoint.get('early_stopper_counter', early_stopper.counter)
    print(f"✅ Resumed from checkpoint at epoch {start_epoch}")

for epoch in range(start_epoch, num_epochs):
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        optimizer.zero_grad()
        _, _, _, loss = _forward_batch(batch)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Stepped once per batch, so `scheduler` is presumably a per-iteration
        # schedule (e.g. OneCycleLR) — confirm.
        scheduler.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in val_loader:
            out1, out2, target, loss = _forward_batch(batch)
            val_loss += loss.item()

            cos_sim = F.cosine_similarity(out1, out2)
            predictions = torch.where(cos_sim > 0.5, torch.tensor(1.0, device=device), torch.tensor(-1.0, device=device))
            correct += (predictions == target).sum().item()
            total += target.size(0)

    val_avg_loss = val_loss / len(val_loader)
    val_accuracy = (correct / total) * 100.0

    val_losses.append(val_avg_loss)
    val_accuracies.append(val_accuracy)

    print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {avg_loss:.4f} | Val Loss: {val_avg_loss:.4f} | Val Accuracy: {val_accuracy:.2f}%")

    # Update the best loss and early-stopping state BEFORE writing the resume
    # checkpoint; otherwise checkpoint.pth stores a stale best_val_loss and a
    # resumed run could overwrite best_model.pth with a worse model.
    is_best = val_avg_loss < best_val_loss
    if is_best:
        best_val_loss = val_avg_loss
    early_stopper(val_avg_loss)

    state = _checkpoint_state(epoch)
    torch.save(state, checkpoint_path)
    if is_best:
        torch.save(state, best_checkpoint_path)
        print(f"💾 Saved new best model at epoch {epoch+1} with Val Loss: {val_avg_loss:.4f}")

    torch.cuda.empty_cache()

    if early_stopper.early_stop:
        print("🛑 Early stopping triggered at epoch", epoch + 1)
        break