In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import numpy as np
import os
import torch.optim as optim
from torch.nn import functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.model_selection import KFold
from copy import deepcopy
import logging
import json
import seaborn as sns
from pathlib import Path
import warnings
from typing import Tuple, Dict, Any, List
from torch.optim.lr_scheduler import OneCycleLR
import time

# Additional import for dynamic output in Jupyter Notebook
from IPython.display import clear_output, display

import logging

# Reset handlers if needed (important in Jupyter or scripts with imported modules)
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training.log'),
        logging.StreamHandler()
    ]
)

# Set device and random seeds
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)


class EarlyStopping:
    def __init__(self, patience=7, min_delta=0, verbose=False):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model = None

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_model = deepcopy(model.state_dict())
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.verbose:
                logging.info(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_model = deepcopy(model.state_dict())
            self.counter = 0
        return self.early_stop


class ModelCheckpoint:
    def __init__(self, save_dir: str, metric_name: str = 'val_loss', mode: str = 'min'):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.metric_name = metric_name
        self.mode = mode
        self.best_metric = float('inf') if mode == 'min' else float('-inf')

    def __call__(self, model: nn.Module, current_metric: float, epoch: int) -> None:
        improved = (self.mode == 'min' and current_metric < self.best_metric) or \
                   (self.mode == 'max' and current_metric > self.best_metric)

        if improved:
            self.best_metric = current_metric
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                f'best_{self.metric_name}': self.best_metric
            }
            torch.save(checkpoint, self.save_dir / f'best_model_{self.metric_name}.pth')
            logging.info(f'Saved new best model with {self.metric_name}: {self.best_metric:.4f}')


class ConfidenceBasedThresholding:
    def __init__(self, initial_threshold=0.95, momentum=0.999):
        self.threshold = initial_threshold
        self.momentum = momentum

    def update(self, confidence_scores):
        with torch.no_grad():
            sorted_confidence = torch.sort(confidence_scores)[0]
            new_threshold = sorted_confidence[int(0.95 * len(sorted_confidence))]
            self.threshold = self.momentum * self.threshold + (1 - self.momentum) * new_threshold

    def get_mask(self, confidence_scores):
        return confidence_scores >= self.threshold


class TrainingProgressTracker:
    def __init__(self, total_epochs, num_folds):
        self.start_time = None
        self.total_epochs = total_epochs
        self.num_folds = num_folds
        self.current_fold = 0
        self.current_epoch = 0
        self.epoch_times = []

    def start_training(self):
        self.start_time = time.time()
        logging.info(f"Training started at {time.strftime('%Y-%m-%d %H:%M:%S')}")

    def update_progress(self, fold, epoch):
        self.current_fold = fold + 1
        self.current_epoch = epoch + 1

        # Calculate time metrics
        elapsed = time.time() - self.start_time
        avg_time_per_epoch = np.mean(self.epoch_times) if self.epoch_times else 0
        remaining_epochs = (self.num_folds - self.current_fold) * self.total_epochs + \
                           (self.total_epochs - self.current_epoch)
        estimated_remaining = avg_time_per_epoch * remaining_epochs

        # Clear the output below the cell and display the updated progress in real time
        clear_output(wait=True)
        print(f"Progress: Fold {self.current_fold}/{self.num_folds} | Epoch {self.current_epoch}/{self.total_epochs}")
        print(f"Elapsed: {self.format_time(elapsed)} | Remaining: {self.format_time(estimated_remaining)}")

    def record_epoch_time(self, epoch_time):
        self.epoch_times.append(epoch_time)

    @staticmethod
    def format_time(seconds):
        if seconds < 0:
            return "--:--:--"
        hours, rem = divmod(seconds, 3600)
        minutes, seconds = divmod(rem, 60)
        return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"


class TrainingMistakeDetector:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.previous_train_loss = None
        self.previous_val_loss = None

    def check_common_issues(self, model, train_metrics, val_metrics, epoch):
        issues = []

        # Ensure train_metrics is a valid dictionary before accessing it
        if not train_metrics or 'loss' not in train_metrics:
            issues.append("Training metrics not computed for this epoch!")
            return issues

        # Check for NaN values
        if torch.isnan(torch.tensor(train_metrics['loss'])):
            issues.append("NaN detected in training loss!")

        # Check for exploding gradients
        if self.check_exploding_gradients(model):
            issues.append("Potential exploding gradients detected!")

        # Check for validation loss divergence if metrics are available
        if epoch > 1 and val_metrics:
            if val_metrics['val_loss'] > 2 * train_metrics['loss']:
                issues.append("Validation loss significantly higher than training loss (possible overfitting)")

            if self.previous_train_loss is not None and train_metrics['loss'] > self.previous_train_loss * 1.5:
                issues.append("Training loss increased significantly from previous epoch")

        # Update previous values if available
        self.previous_train_loss = train_metrics.get('loss', None)
        self.previous_val_loss = val_metrics.get('val_loss', None) if val_metrics else None

        return issues

    @staticmethod
    def check_exploding_gradients(model, threshold=1e6):
        for name, param in model.named_parameters():
            if param.grad is not None and torch.any(torch.abs(param.grad) > threshold):
                return True
        return False



In [2]:
class TransformFix(object):
    def __init__(self):
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(96, padding=4),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(96, padding=4),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            transforms.RandomRotation(20),
            transforms.RandAugment(num_ops=3, magnitude=15),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __call__(self, x):
        return self.weak(x), self.strong(x)


def get_stl10_dataloaders(batch_size=64):
    transform = TransformFix()
    data_dir = "/content/drive/MyDrive"
    os.makedirs(data_dir, exist_ok=True)

    try:
        labeled_dataset = torchvision.datasets.STL10(
            root=data_dir, split='train', download=True, transform=transform
        )
        unlabeled_dataset = torchvision.datasets.STL10(
            root=data_dir, split='unlabeled', download=True, transform=transform
        )
        test_dataset = torchvision.datasets.STL10(
            root=data_dir, split='test', download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        )
    except Exception as e:
        logging.error(f"Error loading datasets: {str(e)}")
        raise

    # Create train-validation split
    train_size = int(0.8 * len(labeled_dataset))
    val_size = len(labeled_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        labeled_dataset, [train_size, val_size]
    )

    # Create data loaders with error handling
    try:
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
        )

        val_dataset.dataset.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, num_workers=4
        )

        unlabeled_loader = DataLoader(
            unlabeled_dataset, batch_size=batch_size, shuffle=True, num_workers=4
        )

        test_loader = DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
        )
    except Exception as e:
        logging.error(f"Error creating dataloaders: {str(e)}")
        raise

    return train_loader, val_loader, unlabeled_loader, test_loader


In [3]:
class ResNet18(nn.Module):
    def __init__(self, num_classes=10, dropout_rate=0.3):
        super().__init__()
        self.model = torchvision.models.resnet18(pretrained=True)
        self.model.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.model(x)


class FlatMatchLoss(nn.Module):
    def __init__(self, model, rho=1e-2, threshold=0.95, num_classes=10, temperature=1.0, use_sam=True):
        super().__init__()
        self.model = model
        self.rho = rho
        self.threshold = threshold
        self.num_classes = num_classes
        self.temperature = temperature
        self.use_sam = use_sam
        self.cross_entropy = nn.CrossEntropyLoss(reduction='none')
        self.kl_div = nn.KLDivLoss(reduction='batchmean')  # KL divergence for cross-sharpness

    def forward(self, logits_w, logits_s, inputs_u, model, device):
        # Standard pseudo-labeling
        with torch.no_grad():
            probs_w = torch.softmax(logits_w / self.temperature, dim=1)
            max_probs, pseudo_labels = torch.max(probs_w, dim=1)
            mask = max_probs.ge(self.threshold)

        # Compute pseudo-label loss
        loss_pseudo = self.cross_entropy(logits_s, pseudo_labels)
        loss_pseudo = (loss_pseudo * mask.float()).mean()

        # Flat loss (KL from uniform)
        flat_targets = torch.ones_like(logits_w) / self.num_classes
        probs_s = torch.softmax(logits_s / self.temperature, dim=1)
        flat_loss = -torch.mean(torch.sum(flat_targets * torch.log(probs_s + 1e-6), dim=1))

        # Total = pseudo-label + flat loss
        total_loss = loss_pseudo + 0.1 * flat_loss

        # Cross-sharpness with SAM-style θ̃
        if self.use_sam:
            # First backward pass: get gradient
            inputs_u = inputs_u.detach()
            inputs_u.requires_grad = True
            logits_u = model(inputs_u)
            probs_u = torch.softmax(logits_u, dim=1)
            uniform = torch.full_like(probs_u, 1.0 / self.num_classes)
            kl_flat = torch.sum(probs_u * torch.log(probs_u / uniform + 1e-6), dim=1).mean()

            # Backward to get grad
            grad = torch.autograd.grad(kl_flat, model.parameters(), create_graph=True)
            grad_norm = torch.sqrt(sum([g.norm() ** 2 for g in grad])) + 1e-12

            # Perturb parameters: θ̃ = θ + ρ * g / ||g||
            eps = []
            for p, g in zip(model.parameters(), grad):
                e = self.rho * g / grad_norm
                p.data.add_(e)
                eps.append(e)

            # Forward with perturbed θ̃
            logits_tilde = model(inputs_u)
            probs_tilde = torch.log_softmax(logits_tilde, dim=1)
            probs_orig = torch.softmax(logits_u.detach(), dim=1)
            sharpness_loss = self.kl_div(probs_tilde, probs_orig)

            # Restore original weights
            for p, e in zip(model.parameters(), eps):
                p.data.sub_(e)

            total_loss += 0.5 * sharpness_loss

        return total_loss, mask.float().mean()


class DistributionAlignment(nn.Module):
    def __init__(self, num_classes, momentum=0.999):
        super().__init__()
        self.num_classes = num_classes
        self.momentum = momentum
        self.register_buffer("p_model", torch.ones(num_classes) / num_classes)

    def forward(self, probs):
        with torch.no_grad():
            p_current = probs.mean(0)
            self.p_model = self.momentum * self.p_model + (1 - self.momentum) * p_current

        qt = probs.mean(0)
        p_ratio = self.p_model / (qt + 1e-6)
        return probs * p_ratio.unsqueeze(0)


def find_learning_rate(model: nn.Module, train_loader: DataLoader,
                       criterion: nn.Module, device: torch.device,
                       start_lr: float = 1e-7, end_lr: float = 10,
                       num_iterations: int = 100) -> Tuple[list, list]:
    logging.info("Starting learning rate finder...")

    model.train()
    optimizer = optim.SGD(model.parameters(), lr=start_lr)
    scheduler = optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=(end_lr/start_lr)**(1/num_iterations)
    )

    learning_rates = []
    losses = []

    try:
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            if batch_idx >= num_iterations:
                break

            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            learning_rates.append(optimizer.param_groups[0]['lr'])
            losses.append(loss.item())

            scheduler.step()

            if loss.item() > 4 * min(losses) or not np.isfinite(loss.item()):
                break

    except Exception as e:
        logging.error(f"Error during learning rate finding: {str(e)}")
        raise

    return learning_rates, losses


def validate(model: nn.Module, val_loader: DataLoader,
             criterion: nn.Module, device: torch.device) -> Dict[str, float]:
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    metrics = {
        'val_loss': val_loss / len(val_loader),
        'val_acc': correct / total,
        'val_precision': precision_score(all_targets, all_preds, average='macro'),
        'val_recall': recall_score(all_targets, all_preds, average='macro'),
        'val_f1': f1_score(all_targets, all_preds, average='macro')
    }

    return metrics





def train_epoch(model, model_ema, labeled_loader, unlabeled_loader, optimizer, dist_align,
                flatmatch_loss, confidence_thresholder, device, temperature=1.0, lambda_u=1.0):
    model.train()
    dist_align.train()
    total_loss = 0
    correct = 0
    total = 0
    unlabeled_iter = iter(unlabeled_loader)

    for batch_idx, (x_l, y_l) in enumerate(labeled_loader):
        try:
            (u_w, u_s), _ = next(unlabeled_iter)
        except StopIteration:
            unlabeled_iter = iter(unlabeled_loader)
            (u_w, u_s), _ = next(unlabeled_iter)

        x_l, y_l = x_l.to(device), y_l.to(device)
        u_w, u_s = u_w.to(device), u_s.to(device)

        logits_l = model(x_l)
        logits_u_w = model(u_w)
        logits_u_s = model(u_s)

        loss_sup = F.cross_entropy(logits_l, y_l)

        _, predicted = torch.max(logits_l.data, 1)
        total += y_l.size(0)
        correct += (predicted == y_l).sum().item()

        with torch.no_grad():
            probs_u_w = torch.softmax(logits_u_w / temperature, dim=1)
            max_probs, _ = torch.max(probs_u_w, dim=1)
            confidence_thresholder.update(max_probs)
            conf_mask = confidence_thresholder.get_mask(max_probs)
            probs_u_w = dist_align(probs_u_w)

        loss_unsup, mask_mean = flatmatch_loss(
            logits_u_w, logits_u_s, u_w, model, device
        )
        loss_unsup = loss_unsup * conf_mask.float().mean()

        loss = loss_sup + lambda_u * loss_unsup

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        update_ema(model, model_ema)

        total_loss += loss.item()

    return {
        'loss': total_loss / (batch_idx + 1),
        'acc': correct / total,
        'confidence_threshold': confidence_thresholder.threshold
    }



In [5]:
def update_ema(model, model_ema, alpha=0.999):
    for ema_param, param in zip(model_ema.parameters(), model.parameters()):
        ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)

def train_with_kfold_improved(model, labeled_dataset, unlabeled_loader, num_folds=5, num_epochs=200, device=device):
    # Initialize progress tracker and mistake detector only once
    progress_tracker = TrainingProgressTracker(num_epochs, num_folds)
    progress_tracker.start_training()
    mistake_detector = TrainingMistakeDetector(num_classes=10)

    kfold = KFold(n_splits=num_folds, shuffle=True, random_state=42)
    fold_results = []
    best_val_acc = 0
    best_model = None

    # Create checkpoint handler
    checkpoint_handler = ModelCheckpoint(save_dir='checkpoints', metric_name='val_acc', mode='max')

    for fold, (train_ids, val_ids) in enumerate(kfold.split(labeled_dataset)):
        logging.info(f'\nStarting Fold {fold+1}/{num_folds}')

        # Create data loaders for the current fold
        train_loader = DataLoader(
            labeled_dataset,
            batch_size=64,
            sampler=SubsetRandomSampler(train_ids)
        )
        val_loader = DataLoader(
            labeled_dataset,
            batch_size=64,
            sampler=SubsetRandomSampler(val_ids)
        )

        # Initialize model and components for the current fold
        fold_model = ResNet18().to(device)
        model_ema = deepcopy(fold_model)
        for param in model_ema.parameters():
          param.requires_grad = False
        confidence_thresholder = ConfidenceBasedThresholding(initial_threshold=0.95)
        dist_align = DistributionAlignment(num_classes=10).to(device)
        flatmatch_loss = FlatMatchLoss(model=fold_model,threshold=0.95, num_classes=10, temperature=1.0).to(device)

        # Find optimal learning rate
        lrs, losses = find_learning_rate(fold_model, train_loader, nn.CrossEntropyLoss(), device)
        optimal_lr = lrs[np.argmin(losses)]
        logging.info(f"Optimal learning rate found: {optimal_lr:.6f}")

        # Initialize optimizer and scheduler
        optimizer = optim.SGD([
            {'params': fold_model.model.fc.parameters(), 'lr': optimal_lr},
            {'params': [p for n, p in fold_model.model.named_parameters() if not n.startswith('fc')],
             'lr': optimal_lr/10}
        ], momentum=0.9, weight_decay=1e-4)

        scheduler = OneCycleLR(optimizer, max_lr=optimal_lr, epochs=num_epochs,
                                steps_per_epoch=len(train_loader))
        early_stopping = EarlyStopping(patience=10, verbose=True)

        # Training loop for the current fold
        fold_train_losses = []
        fold_val_losses = []
        fold_train_accs = []
        fold_val_accs = []

        for epoch in range(num_epochs):
            epoch_start_time = time.time()
            try:
                train_metrics = train_epoch(
                    fold_model, model_ema, train_loader, unlabeled_loader,
                    optimizer, dist_align, flatmatch_loss,
                    confidence_thresholder, device
                    )


                # Validation phase
                val_metrics = validate(fold_model, val_loader, nn.CrossEntropyLoss(), device)
                epoch_time = time.time() - epoch_start_time
                progress_tracker.record_epoch_time(epoch_time)
                progress_tracker.update_progress(fold, epoch)

                issues = mistake_detector.check_common_issues(
                    fold_model, train_metrics, val_metrics, epoch
                )

                if issues:
                    logging.warning("Potential issues detected:")
                    for issue in issues:
                        logging.warning(f"  - {issue}")

                # Store metrics
                fold_train_losses.append(train_metrics['loss'])
                fold_train_accs.append(train_metrics['acc'])
                fold_val_losses.append(val_metrics['val_loss'])
                fold_val_accs.append(val_metrics['val_acc'])

                # Log metrics
                logging.info(f'Epoch {epoch+1}/{num_epochs}:')
                logging.info(f'Train Loss: {train_metrics["loss"]:.4f}, Train Acc: {train_metrics["acc"]:.4f}')
                logging.info(f'Confidence Threshold: {train_metrics["confidence_threshold"]:.4f}')
                logging.info(f'Val Loss: {val_metrics["val_loss"]:.4f}, Val Acc: {val_metrics["val_acc"]:.4f}')

                # Checkpoint saving
                checkpoint_handler(fold_model, val_metrics["val_acc"], epoch)

                # Early stopping check
                if early_stopping(val_metrics["val_loss"], fold_model):
                    logging.info("Early stopping triggered!")
                    fold_model.load_state_dict(early_stopping.best_model)
                    break

                scheduler.step()

            except Exception as e:
                logging.error(f"Error during training: {str(e)}")
                logging.error(f"Error occurred at Fold {fold+1}, Epoch {epoch+1}")
                logging.error(f"Current learning rate: {optimizer.param_groups[0]['lr']}")
                raise

        # Store fold results
        fold_results.append({
            'model': fold_model,
            'val_metrics': val_metrics,
            'train_losses': fold_train_losses,
            'val_losses': fold_val_losses,
            'train_accs': fold_train_accs,
            'val_accs': fold_val_accs
        })

        # Update best model
        if val_metrics['val_acc'] > best_val_acc:
            best_val_acc = val_metrics['val_acc']
            best_model = deepcopy(fold_model)

    return best_model, fold_results


def plot_training_curves(fold_results, save_dir):
    plt.figure(figsize=(15, 10))

    # Plot losses
    plt.subplot(2, 2, 1)
    for fold_idx, fold_data in enumerate(fold_results):
        plt.plot(fold_data['train_losses'], label=f'Fold {fold_idx+1} Train')
        plt.plot(fold_data['val_losses'], label=f'Fold {fold_idx+1} Val')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Plot accuracies
    plt.subplot(2, 2, 2)
    for fold_idx, fold_data in enumerate(fold_results):
        plt.plot(fold_data['train_accs'], label=f'Fold {fold_idx+1} Train')
        plt.plot(fold_data['val_accs'], label=f'Fold {fold_idx+1} Val')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'training_curves.png'))
    plt.close()


def plot_confusion_matrix(model, test_loader, device, save_dir):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())

    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig(os.path.join(save_dir, 'confusion_matrix.png'))
    plt.close()


def main():
    # Set save directory
    save_dir = "/content/drive/MyDrive/flatmatch-stl-10-results-trial"
    os.makedirs(save_dir, exist_ok=True)

    try:
        # Load data
        train_loader, val_loader, unlabeled_loader, test_loader = get_stl10_dataloaders()
        logging.info("Data loaded successfully")

        # Train model with k-fold cross-validation
        best_model, fold_results = train_with_kfold_improved(
            model=None,
            labeled_dataset=train_loader.dataset,
            unlabeled_loader=unlabeled_loader,
            num_folds=5,
            num_epochs=120,
            device=device
        )

        # Evaluate on test set
        test_metrics = validate(best_model, test_loader, nn.CrossEntropyLoss(), device)

        # Plot results
        plot_training_curves(fold_results, save_dir)
        plot_confusion_matrix(best_model, test_loader, device, save_dir)

        # Save model and results
        torch.save({
            'model_state_dict': best_model.state_dict(),
            'test_metrics': test_metrics
        }, os.path.join(save_dir, 'best_model_flatmatch_stl-10.pth'))

        # Log final results
        logging.info("\nFinal Test Metrics:")
        logging.info(f"Accuracy: {test_metrics['val_acc']:.4f}")
        logging.info(f"Precision: {test_metrics['val_precision']:.4f}")
        logging.info(f"Recall: {test_metrics['val_recall']:.4f}")
        logging.info(f"F1 Score: {test_metrics['val_f1']:.4f}")

        # Log final results using print
        print("\nFinal Test Metrics:")
        print(f"Accuracy: {test_metrics['val_acc']:.4f}")
        print(f"Precision: {test_metrics['val_precision']:.4f}")
        print(f"Recall: {test_metrics['val_recall']:.4f}")
        print(f"F1 Score: {test_metrics['val_f1']:.4f}")


        return best_model, test_metrics

    except Exception as e:
        logging.error(f"Error in main pipeline: {str(e)}")
        raise


if __name__ == "__main__":
    warnings.filterwarnings('ignore')
    model, metrics = main()


2025-04-14 23:31:39,302 - INFO - Epoch 21/120:
2025-04-14 23:31:39,303 - INFO - Train Loss: 2.3663, Train Acc: 0.1084
2025-04-14 23:31:39,305 - INFO - Confidence Threshold: 0.5028
2025-04-14 23:31:39,306 - INFO - Val Loss: 2.3119, Val Acc: 0.1050
2025-04-14 23:31:39,306 - INFO - EarlyStopping counter: 10 out of 10
2025-04-14 23:31:39,307 - INFO - Early stopping triggered!


Progress: Fold 5/5 | Epoch 21/120
Elapsed: 01:21:41 | Remaining: 00:28:05


2025-04-14 23:31:51,806 - INFO - 
Final Test Metrics:
2025-04-14 23:31:51,807 - INFO - Accuracy: 0.5499
2025-04-14 23:31:51,809 - INFO - Precision: 0.5735
2025-04-14 23:31:51,809 - INFO - Recall: 0.5499
2025-04-14 23:31:51,810 - INFO - F1 Score: 0.5294



Final Test Metrics:
Accuracy: 0.5499
Precision: 0.5735
Recall: 0.5499
F1 Score: 0.5294
