In [1]:
import random
import torch.nn as nn
import torch.optim as optim
import time
import torch
from torchvision import datasets, transforms, models
from sklearn.metrics import f1_score


from torch.utils.data import DataLoader, random_split, ConcatDataset
import numpy as np

In [2]:
import matplotlib.pyplot as plt

def display_image(data_loader, index):
    images, labels = next(iter(data_loader))
    if index >= len(images):
        print(f"Index {index} is out of range. Maximum batch size is {len(images)}.")
        return

    image = images[index]
    label = labels[index].item()

    # Unnormalize the image for display
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    image = image * std + mean
    image = image.clamp(0, 1)  # Clip values to [0, 1]

    # Convert the image tensor to a numpy array
    image = image.permute(1, 2, 0).numpy()

    # Display the image and label
    plt.imshow(image)
    plt.title(f"Label: {label}")
    plt.axis('off')
    plt.show()

In [3]:
def prepare_data_loaders(size_tuple, augmented_count):
    # Define the path to the root directory containing the class folders
    root_dir = "data"

    # Define a mapping from folder names to class numbers
    class_mapping = {
        'barszcz': 1,
        'bigos': 2,
        'Kutia': 3,
        'makowiec': 4,
        'piernik': 5,
        'pierogi': 6,
        'sernik': 7,
        'grzybowa': 8
    }

    batch_size = 32
    validation_split = 0.2  # Proportion of data for validation

    # Data Augmentation and Transformations
    transform = transforms.Compose([
        transforms.Resize(size_tuple),
        transforms.RandomHorizontalFlip(),  # Random horizontal flip
        transforms.RandomCrop(size=(size_tuple[0]-20, size_tuple[1]-20), padding=20),  # Random crop
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=.3, hue=.1) ,
        transforms.RandomVerticalFlip(0.1),
        transforms.Resize(size_tuple),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
    ])

    # Load the original dataset
    original_dataset = datasets.ImageFolder(root=root_dir, transform=transforms.Compose([
        transforms.Resize(size_tuple),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
    ]))

    original_dataset.class_to_idx = {k: class_mapping[k] for k in original_dataset.class_to_idx.keys()}


    for _ in range(augmented_count):
        # Load the original dataset
        augmented_dataset = datasets.ImageFolder(root=root_dir, transform=transform)

        augmented_dataset.class_to_idx = {k: class_mapping[k] for k in augmented_dataset.class_to_idx.keys()}

        # Combine original and augmented datasets
        original_dataset = ConcatDataset([original_dataset, augmented_dataset])


    # Split into train and validation sets
    dataset_size = len(original_dataset)
    validation_size = int(validation_split * dataset_size)
    train_size = dataset_size - validation_size
    train_dataset, validation_dataset = random_split(original_dataset, [train_size, validation_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
    return (train_loader, validation_loader)

# Training

In [4]:
def set_seed(seed_value):
    random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


# Ensemble learning

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
from sklearn.metrics import f1_score
import torchvision.models as models
import timm  # for Vision Transformer

class EnsembleTrainer:
    def __init__(self, num_classes, device, patience=10, min_delta=0.01, max_epochs=50):
        self.max_epochs = max_epochs
        self.device = device
        self.patience = patience
        self.min_delta = min_delta
        
        self.model_weights = dict()

        self.models = {
            #'swin_t': self._prepare_swin_t(num_classes),
            # 'efficientnet_b0': self._prepare_efficientnet_b0(num_classes),
            # 'resnet_18': self._prepare_resnet_18(num_classes),
            'mobilenet_v3_large': self._prepare_mobilenet_v3_large(num_classes),
            'vit_tiny': self._prepare_vit_tiny(num_classes)
        }
        
        # Optimizers for each model
        self.optimizers = {
            name: optim.AdamW(model.parameters()) 
            for name, model in self.models.items()
        }
        
        self.criterion = nn.CrossEntropyLoss()
        
        # Tracking metrics
        self.results = {
            name: {
                'train_losses': [],
                'val_losses': [],
                'f1_scores': [],
                'best_val_f1': 0,
                'patience_counter': 0
            } for name in self.models.keys()
        }


    def _prepare_efficientnet_b0(self, num_classes):
        model = models.efficientnet_b0(weights='DEFAULT')
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model.to(self.device)
    
    def _prepare_resnet_18(self, num_classes):
        model = models.resnet18(weights='DEFAULT')
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model.to(self.device)

    def _prepare_mobilenet_v3_large(self, num_classes):
        model = models.mobilenet_v3_large(weights='IMAGENET1K_V2')
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
        return model.to(self.device)
    
    def _prepare_vit_tiny(self, num_classes):
        model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=num_classes)
        return model.to(self.device)

    def _prepare_swin_t(self, num_classes):
        model = models.swin_t(weights='DEFAULT')
        if hasattr(model, 'head'):
            model.head = nn.Linear(model.head.in_features, num_classes)
        else:
            model.fc = nn.Linear(model.num_features, num_classes)
        return model.to(self.device)
        

    def train_single_model(self, model, optimizer, train_loader, validation_loader, model_name):
        for epoch in range(1, self.max_epochs + 1):  # Your original epoch range
            current_time = time.time()
            
            # Training phase
            model.train()
            epoch_loss_train = 0
            running_loss = 0
            
            for i, (inputs, labels) in enumerate(train_loader, 0):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                epoch_loss_train += loss.item()
                
                if i % 20 == 19:
                    print(f"{model_name} - Current time: {round(time.time() - current_time, 0)} s, "
                          f"epoch: {epoch}/50, minibatch: {i + 1:5d}/{len(train_loader)}, "
                          f"running loss: {running_loss / 500:.3f}")
                    running_loss = 0

            # Validation phase
            model.eval()
            epoch_loss_val = 0
            all_preds = []
            all_labels = []
            
            with torch.no_grad():
                for images, labels in validation_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    outputs = model(images)
                    loss = self.criterion(outputs, labels)
                    epoch_loss_val += loss.item()
                    
                    _, preds = torch.max(outputs, 1)
                    all_preds.extend(preds.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())

            # Compute metrics
            avg_train_loss = epoch_loss_train / len(train_loader)
            avg_val_loss = epoch_loss_val / len(validation_loader)
            val_f1 = f1_score(all_labels, all_preds, average='macro')

            # Update results
            results = self.results[model_name]
            results['train_losses'].append(avg_train_loss)
            results['val_losses'].append(avg_val_loss)
            results['f1_scores'].append(val_f1)

            print(f"{model_name} - Epoch {epoch}: "
                  f"Train Loss = {avg_train_loss:.4f}, "
                  f"Val Loss = {avg_val_loss:.4f}, "
                  f"Val F1 = {val_f1:.4f}")

            # Early stopping
            if val_f1 > results['best_val_f1'] + self.min_delta:
                results['best_val_f1'] = val_f1
                results['patience_counter'] = 0
                
                # Save best model
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': results['train_losses'],
                    'val_loss': results['val_losses'],
                    'f1_metric_val': results['f1_scores'],
                    'best_val_f1': results['best_val_f1'],
                }, f'models/{model_name}_best_f1.pth')
                self.model_weights[model_name] = val_f1
            else:
                results['patience_counter'] += 1

            if results['patience_counter'] >= self.patience:
                print(f"{model_name} - Early stopping triggered at epoch {epoch}, best f1 score {results['f1_scores'][-11]}")
                break

    def train_ensemble(self, train_loader, validation_loader):
        # Train each model
        for name, model in self.models.items():
            print(f"\nTraining {name}")
            self.train_single_model(model, self.optimizers[name], 
                                    train_loader, validation_loader, name)

    def ensemble_predict(self, dataloader):
        # Weighted voting ensemble prediction
        all_ensemble_preds = []
        all_labels = []

        # Normalizacja wag, jeśli nie sumują się do 1
        total_weight = sum(self.model_weights.values())
        model_weights_normalized = {model: weight / total_weight for model, weight in self.model_weights.items()}

        for images, labels in dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Zbieramy predykcje każdego modelu pomnożone przez ich wagi
            weighted_preds = []
            for model_name, model in self.models.items():
                probs = torch.softmax(model(images), dim=1)
                weight = model_weights_normalized.get(model_name, 0.0)  # Domyślna waga 0, jeśli model nie ma wagi
                weighted_preds.append(probs * weight)

            # Suma ważonych predykcji
            ensemble_probs = torch.sum(torch.stack(weighted_preds), dim=0)

            # Ostateczne predykcje
            _, ensemble_preds = torch.max(ensemble_probs, 1)

            all_ensemble_preds.extend(ensemble_preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        # Obliczenie F1-score dla ostatecznych predykcji
        ensemble_f1 = f1_score(all_labels, all_ensemble_preds, average='macro')
        print(f"Weighted Ensemble F1 Score: {ensemble_f1:.4f}")

        return ensemble_f1

In [9]:
# Assuming you have train_loader and validation_loader defined
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

set_seed(23)

train_loader, validation_loader = prepare_data_loaders((224,224), 1)

ensemble_trainer = EnsembleTrainer(num_classes=8, device=device, max_epochs = 50)

# Train all models
ensemble_trainer.train_ensemble(train_loader, validation_loader)

# Perform ensemble prediction
ensemble_f1 = ensemble_trainer.ensemble_predict(validation_loader)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth" to C:\Users\dimpi/.cache\torch\hub\checkpoints\mobilenet_v3_large-5c1a4163.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 26.9MB/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development



Training mobilenet_v3_large
mobilenet_v3_large - Current time: 54.0 s, epoch: 1/50, minibatch:    20/400, running loss: 0.045
mobilenet_v3_large - Current time: 127.0 s, epoch: 1/50, minibatch:    40/400, running loss: 0.028
mobilenet_v3_large - Current time: 197.0 s, epoch: 1/50, minibatch:    60/400, running loss: 0.029
mobilenet_v3_large - Current time: 270.0 s, epoch: 1/50, minibatch:    80/400, running loss: 0.026
mobilenet_v3_large - Current time: 349.0 s, epoch: 1/50, minibatch:   100/400, running loss: 0.024
mobilenet_v3_large - Current time: 400.0 s, epoch: 1/50, minibatch:   120/400, running loss: 0.024
mobilenet_v3_large - Current time: 460.0 s, epoch: 1/50, minibatch:   140/400, running loss: 0.023
mobilenet_v3_large - Current time: 515.0 s, epoch: 1/50, minibatch:   160/400, running loss: 0.020
mobilenet_v3_large - Current time: 564.0 s, epoch: 1/50, minibatch:   180/400, running loss: 0.020
mobilenet_v3_large - Current time: 636.0 s, epoch: 1/50, minibatch:   200/400, ru

KeyboardInterrupt: 

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
from sklearn.metrics import f1_score
import torchvision.models as models
import timm  # for Vision Transformer

class EnsembleTrainer:
    def __init__(self, num_classes, device, patience=10, min_delta=0.01, max_epochs=50):
        self.max_epochs = max_epochs
        self.device = device
        self.patience = patience
        self.min_delta = min_delta
        
        self.model_weights = dict()


        self.models = {
            'mnasnet1_3': self._prepare_mnasnet1_3(num_classes)
        }
        
        # Optimizers for each model
        self.optimizers = {
            name: optim.AdamW(model.parameters()) 
            for name, model in self.models.items()
        }
        
        self.criterion = nn.CrossEntropyLoss()
        
        # Tracking metrics
        self.results = {
            name: {
                'train_losses': [],
                'val_losses': [],
                'f1_scores': [],
                'best_val_f1': 0,
                'patience_counter': 0
            } for name in self.models.keys()
        }

    def _prepare_mnasnet1_3(self, num_classes):
        model = models.mnasnet1_3(weights="DEFAULT")
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model.to(self.device)
    
    def _prepare_efficientnet_b0(self, num_classes):
        model = models.efficientnet_b0(weights='DEFAULT')
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
        return model.to(self.device)
    
    def _prepare_resnet_18(self, num_classes):
        model = models.resnet18(weights='DEFAULT')
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        return model.to(self.device)

    def _prepare_mobilenet_v3_large(self, num_classes):
        model = models.mobilenet_v3_large(weights='IMAGENET1K_V2')
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
        return model.to(self.device)
    
    def _prepare_vit_tiny(self, num_classes):
        model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=num_classes)
        return model.to(self.device)

    def _prepare_swin_t(self, num_classes):
        model = models.swin_t(weights='DEFAULT')
        if hasattr(model, 'head'):
            model.head = nn.Linear(model.head.in_features, num_classes)
        else:
            model.fc = nn.Linear(model.num_features, num_classes)
        return model.to(self.device)
        

    def train_single_model(self, model, optimizer, train_loader, validation_loader, model_name):
        for epoch in range(1, self.max_epochs + 1):  # Your original epoch range
            current_time = time.time()
            
            # Training phase
            model.train()
            epoch_loss_train = 0
            running_loss = 0
            
            for i, (inputs, labels) in enumerate(train_loader, 0):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                epoch_loss_train += loss.item()
                
                if i % 20 == 19:
                    print(f"{model_name} - Current time: {round(time.time() - current_time, 0)} s, "
                          f"epoch: {epoch}/50, minibatch: {i + 1:5d}/{len(train_loader)}, "
                          f"running loss: {running_loss / 500:.3f}")
                    running_loss = 0

            # Validation phase
            model.eval()
            epoch_loss_val = 0
            all_preds = []
            all_labels = []
            
            with torch.no_grad():
                for images, labels in validation_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    outputs = model(images)
                    loss = self.criterion(outputs, labels)
                    epoch_loss_val += loss.item()
                    
                    _, preds = torch.max(outputs, 1)
                    all_preds.extend(preds.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())

            # Compute metrics
            avg_train_loss = epoch_loss_train / len(train_loader)
            avg_val_loss = epoch_loss_val / len(validation_loader)
            val_f1 = f1_score(all_labels, all_preds, average='macro')

            # Update results
            results = self.results[model_name]
            results['train_losses'].append(avg_train_loss)
            results['val_losses'].append(avg_val_loss)
            results['f1_scores'].append(val_f1)

            print(f"{model_name} - Epoch {epoch}: "
                  f"Train Loss = {avg_train_loss:.4f}, "
                  f"Val Loss = {avg_val_loss:.4f}, "
                  f"Val F1 = {val_f1:.4f}")

            # Early stopping
            if val_f1 > results['best_val_f1'] + self.min_delta:
                results['best_val_f1'] = val_f1
                results['patience_counter'] = 0
                
                # Save best model
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': results['train_losses'],
                    'val_loss': results['val_losses'],
                    'f1_metric_val': results['f1_scores'],
                    'best_val_f1': results['best_val_f1'],
                }, f'models/{model_name}_best_f1.pth')
                self.model_weights[model_name] = val_f1
            else:
                results['patience_counter'] += 1

            if results['patience_counter'] >= self.patience:
                print(f"{model_name} - Early stopping triggered at epoch {epoch}, best f1 score {results['f1_scores'][-11]}")
                break

    def train_ensemble(self, train_loader, validation_loader):
        # Train each model
        for name, model in self.models.items():
            print(f"\nTraining {name}")
            self.train_single_model(model, self.optimizers[name], 
                                    train_loader, validation_loader, name)

    def ensemble_predict(self, dataloader):
        # Weighted voting ensemble prediction
        all_ensemble_preds = []
        all_labels = []

        # Normalizacja wag, jeśli nie sumują się do 1
        total_weight = sum(self.model_weights.values())
        model_weights_normalized = {model: weight / total_weight for model, weight in self.model_weights.items()}

        for images, labels in dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Zbieramy predykcje każdego modelu pomnożone przez ich wagi
            weighted_preds = []
            for model_name, model in self.models.items():
                probs = torch.softmax(model(images), dim=1)
                weight = model_weights_normalized.get(model_name, 0.0)  # Domyślna waga 0, jeśli model nie ma wagi
                weighted_preds.append(probs * weight)

            # Suma ważonych predykcji
            ensemble_probs = torch.sum(torch.stack(weighted_preds), dim=0)

            # Ostateczne predykcje
            _, ensemble_preds = torch.max(ensemble_probs, 1)

            all_ensemble_preds.extend(ensemble_preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        # Obliczenie F1-score dla ostatecznych predykcji
        ensemble_f1 = f1_score(all_labels, all_ensemble_preds, average='macro')
        print(f"Weighted Ensemble F1 Score: {ensemble_f1:.4f}")

        return ensemble_f1

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Assuming you have train_loader and validation_loader defined
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

set_seed(23)

train_loader, validation_loader = prepare_data_loaders((224,224), 1)

ensemble_trainer = EnsembleTrainer(num_classes=8, device=device, max_epochs = 50)

# Train all models
ensemble_trainer.train_ensemble(train_loader, validation_loader)

# Perform ensemble prediction
ensemble_f1 = ensemble_trainer.ensemble_predict(validation_loader)

Downloading: "https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth" to C:\Users\dimpi/.cache\torch\hub\checkpoints\mnasnet1_3-a4c69d6f.pth
100%|██████████| 24.2M/24.2M [00:03<00:00, 7.93MB/s]



Training mnasnet1_3
mnasnet1_3 - Current time: 12.0 s, epoch: 1/50, minibatch:    20/400, running loss: 0.051
mnasnet1_3 - Current time: 22.0 s, epoch: 1/50, minibatch:    40/400, running loss: 0.031
mnasnet1_3 - Current time: 33.0 s, epoch: 1/50, minibatch:    60/400, running loss: 0.030
mnasnet1_3 - Current time: 44.0 s, epoch: 1/50, minibatch:    80/400, running loss: 0.028
mnasnet1_3 - Current time: 54.0 s, epoch: 1/50, minibatch:   100/400, running loss: 0.028
mnasnet1_3 - Current time: 64.0 s, epoch: 1/50, minibatch:   120/400, running loss: 0.023
mnasnet1_3 - Current time: 75.0 s, epoch: 1/50, minibatch:   140/400, running loss: 0.022
mnasnet1_3 - Current time: 85.0 s, epoch: 1/50, minibatch:   160/400, running loss: 0.025
mnasnet1_3 - Current time: 96.0 s, epoch: 1/50, minibatch:   180/400, running loss: 0.020
mnasnet1_3 - Current time: 106.0 s, epoch: 1/50, minibatch:   200/400, running loss: 0.022
mnasnet1_3 - Current time: 116.0 s, epoch: 1/50, minibatch:   220/400, running

KeyboardInterrupt: 

In [8]:
import torch
import torch.nn as nn

class StackingMetaLearner(nn.Module):
    def __init__(self, num_input_models, num_classes):
        """
        Initialize the meta-learner neural network for stacking.
        
        Args:
        - num_input_models (int): Number of base models feeding into this meta-learner
        - num_classes (int): Number of output classes to predict
        """
        super(StackingMetaLearner, self).__init__()
        
        # Define the network architecture
        self.layers = nn.Sequential(
            # Input layer: each input is the prediction probability from a base model
            nn.Linear(num_input_models * num_classes, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            
            # Hidden layers
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            
            # Dropout for regularization
            nn.Dropout(0.3),
            
            # Output layer
            nn.Linear(32, num_classes)
        )
        
    def forward(self, x):
        """
        Forward pass of the meta-learner
        
        Args:
        x (torch.Tensor): Concatenated predictions from base models
                           Shape: [batch_size, num_models * num_classes]
        
        Returns:
        torch.Tensor: Final predictions
        """
        return self.layers(x)

In [None]:
def prepare_stacking_input(base_model_predictions):
    """
    Prepare input for the meta-learner by concatenating predictions from base models
    
    Args:
    base_model_predictions (list of torch.Tensor): 
        List of prediction tensors from each base model
        Each tensor should have shape [batch_size, num_classes]
    
    Returns:
    torch.Tensor: Concatenated predictions ready for meta-learner input
    """
    # Ensure all prediction tensors are on the same device
    base_model_predictions = [pred.cpu() for pred in base_model_predictions]
    
    # Concatenate predictions along the last dimension
    stacked_input = torch.cat(base_model_predictions, dim=1)
    
    return stacked_input

In [None]:
class EnsembleTrainer:
    def __init__(self, num_classes, device):
        self.device = device
        self.stacking_model = StackingMetaLearner(6, 8)
        self.stacking_model = self.stacking_model.to(device)
        self.model_weights = dict()
        self.models = {
            #'swin_t': self._prepare_swin_t(num_classes, 'swin_t'),
            'efficientnet_b0': self._prepare_efficientnet_b0(num_classes, 'efficientnet_b0'),
            'resnet_18': self._prepare_resnet_18(num_classes, 'resnet_18'),
            'mobilenet_v3_large': self._prepare_mobilenet_v3_large(num_classes, 'mobilenet_v3_large'),
            #'vit_tiny': self._prepare_vit_tiny(num_classes, 'vit_tiny')
        }
        
        
        # Optimizers for each model
        self.stacking_optimizers = optim.AdamW(self.stacking_model.parameters())
        
        self.criterion = nn.CrossEntropyLoss()
        
        # Tracking metrics
        self.stacking_results = {
                'train_losses': [],
                'val_losses': [],
                'f1_scores': [],
                'best_val_f1': 0,
                'patience_counter': 0
        }


    def _prepare_efficientnet_b0(self, num_classes, model_name):
        model = models.efficientnet_b0(weights='DEFAULT')
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

        # Load the saved checkpoint
        checkpoint = torch.load(f'models/{model_name}_best_f1.pth')

        # Load the model weights
        model.load_state_dict(checkpoint['model_state_dict'])

        # Optional: Set the model to evaluation mode
        model.eval()
        self.model_weights[model_name] = checkpoint['best_val_f1']
        return model.to(self.device)
    
    def _prepare_resnet_18(self, num_classes, model_name):
        model = models.resnet18(weights='DEFAULT')
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        
        # Load the saved checkpoint
        checkpoint = torch.load(f'models/{model_name}_best_f1.pth')

        # Load the model weights
        model.load_state_dict(checkpoint['model_state_dict'])

        # Optional: Set the model to evaluation mode
        model.eval()
        self.model_weights[model_name] = checkpoint['best_val_f1']
        
        return model.to(self.device)

    def _prepare_mobilenet_v3_large(self, num_classes, model_name):
        model = models.mobilenet_v3_large(weights='IMAGENET1K_V2')
        model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
        
        # Load the saved checkpoint
        checkpoint = torch.load(f'models/{model_name}_best_f1.pth')

        # Load the model weights
        model.load_state_dict(checkpoint['model_state_dict'])

        # Optional: Set the model to evaluation mode
        model.eval()
        self.model_weights[model_name] = checkpoint['best_val_f1']
        
        
        return model.to(self.device)
    
    def _prepare_vit_tiny(self, num_classes, model_name):
        model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=num_classes)
        # Load the saved checkpoint
        checkpoint = torch.load(f'models/{model_name}_best_f1.pth')

        # Load the model weights
        model.load_state_dict(checkpoint['model_state_dict'])

        # Optional: Set the model to evaluation mode
        model.eval()
        self.model_weights[model_name] = checkpoint['best_val_f1']
        
        return model.to(self.device)

    def _prepare_swin_t(self, num_classes, model_name):
        model = models.swin_t(weights='DEFAULT')
        if hasattr(model, 'head'):
            model.head = nn.Linear(model.head.in_features, num_classes)
        else:
            model.fc = nn.Linear(model.num_features, num_classes)
            
        # Load the saved checkpoint
        checkpoint = torch.load(f'models/{model_name}_best_f1.pth')

        # Load the model weights
        model.load_state_dict(checkpoint['model_state_dict'])

        # Optional: Set the model to evaluation mode
        model.eval()
        self.model_weights[model_name] = checkpoint['best_val_f1']
        return model.to(self.device)
    
    def train_single_model(self, train_loader, validation_loader, model_name):
        for epoch in range(1, 50 + 1):  # Your original epoch range
            current_time = time.time()
            
            # Training phase
            self.stacking_model.train()
            epoch_loss_train = 0
            running_loss = 0
            weighted_preds = []
            
            for i, (inputs, labels) in enumerate(train_loader, 0):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                weighted_preds = []
                for model_name, model in self.models.items():
                    probs = torch.softmax(model(inputs), dim=1)
                    weighted_preds.append(probs.cpu())
                
                
                stacked_input = torch.cat(weighted_preds, dim=1)
                self.stacking_optimizers.zero_grad()
                outputs = self.stacking_model(stacked_input)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.stacking_optimizers.step()

                running_loss += loss.item()
                epoch_loss_train += loss.item()
                
                if i % 20 == 19:
                    print(f"{model_name} - Current time: {round(time.time() - current_time, 0)} s, "
                          f"epoch: {epoch}/50, minibatch: {i + 1:5d}/{len(train_loader)}, "
                          f"running loss: {running_loss / 500:.3f}")
                    running_loss = 0

            # Validation phase
            self.stacking_model.eval()
            epoch_loss_val = 0
            all_preds = []
            all_labels = []
            
            with torch.no_grad():
                for images, labels in validation_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    weighted_preds = []
                    for model_name, model in self.models.items():
                        probs = torch.softmax(model(images), dim=1)
                        weighted_preds.append(probs.cpu())
                    
                    
                    stacked_input = torch.cat(weighted_preds, dim=1)
                    outputs = self.stacking_model(images)
                    loss = self.criterion(outputs, labels)
                    epoch_loss_val += loss.item()
                    
                    _, preds = torch.max(outputs, 1)
                    all_preds.extend(preds.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())

            # Compute metrics
            avg_train_loss = epoch_loss_train / len(train_loader)
            avg_val_loss = epoch_loss_val / len(validation_loader)
            val_f1 = f1_score(all_labels, all_preds, average='macro')

            # Update results
            self.stacking_results['train_losses'].append(avg_train_loss)
            self.stacking_results['val_losses'].append(avg_val_loss)
            self.stacking_results['f1_scores'].append(val_f1)

            print(f"{model_name} - Epoch {epoch}: "
                  f"Train Loss = {avg_train_loss:.4f}, "
                  f"Val Loss = {avg_val_loss:.4f}, "
                  f"Val F1 = {val_f1:.4f}")

            # Early stopping
            if val_f1 > self.stacking_results['best_val_f1'] + self.min_delta:
                self.stacking_results['best_val_f1'] = val_f1
                self.stacking_results['patience_counter'] = 0
                
                # Save best model
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.stacking_model.state_dict(),
                    'optimizer_state_dict': self.stacking_optimizers.state_dict(),
                    'train_loss': self.stacking_results['train_losses'],
                    'val_loss': self.stacking_results['val_losses'],
                    'f1_metric_val': self.stacking_results['f1_scores'],
                    'best_val_f1': self.stacking_results['best_val_f1'],
                }, f'models/{model_name}_best_f1.pth')
                self.model_weights[model_name] = val_f1
            else:
                self.stacking_results['patience_counter'] += 1

            if self.stacking_results['patience_counter'] >= self.patience:
                print(f"{model_name} - Early stopping triggered at epoch {epoch}, best f1 score {self.stacking_results['f1_scores'][-11]}")
                break
        

    def ensemble_predict(self, dataloader):
        # Weighted voting ensemble prediction
        all_ensemble_preds = []
        all_labels = []

        # Normalizacja wag, jeśli nie sumują się do 1
        total_weight = sum(self.model_weights.values())
        model_weights_normalized = {model: weight / total_weight for model, weight in self.model_weights.items()}

        for images, labels in dataloader:
            print("Percent of complition:", len(images)/len(dataloader))
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Zbieramy predykcje każdego modelu pomnożone przez ich wagi
            weighted_preds = []
            for model_name, model in self.models.items():
                probs = torch.softmax(model(images), dim=1)
                weight = model_weights_normalized.get(model_name, 0.0)  # Domyślna waga 0, jeśli model nie ma wagi
                weighted_preds.append(probs * weight)

            # Suma ważonych predykcji
            ensemble_probs = torch.sum(torch.stack(weighted_preds), dim=0)

            # Ostateczne predykcje
            _, ensemble_preds = torch.max(ensemble_probs, 1)

            all_ensemble_preds.extend(ensemble_preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

        # Obliczenie F1-score dla ostatecznych predykcji
        ensemble_f1 = f1_score(all_labels, all_ensemble_preds, average='macro')
        print(f"Weighted Ensemble F1 Score: {ensemble_f1}")

        return ensemble_f1

In [None]:
# Assuming you have train_loader and validation_loader defined
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

set_seed(23)

train_loader, validation_loader = prepare_data_loaders((224,224), 1)

ensemble_trainer = EnsembleTrainer(num_classes=8, device=device, max_epochs = 50)

# Train all models
ensemble_trainer.train_single_model(train_loader, validation_loader, 'stacking_model')
