In [None]:
# Все импорты

import torch
import sys
import os
import pandas as pd
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from math import ceil
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, CosineAnnealingWarmRestarts
from torch.optim import AdamW
from torch_optimizer import Lamb
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from tqdm import tqdm
from IPython.display import clear_output
from torchvision.ops import StochasticDepth
from torch.cuda.amp import autocast, GradScaler
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch
import torchvision.transforms.functional as TF  
import random
import numpy as np
from torch.optim import SGD
from torchvision.models import EfficientNet_B4_Weights, ResNet50_Weights
from sklearn.metrics import confusion_matrix
import seaborn as sns
from torchvision import models

In [None]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

print(f"Используемое устройство: {device}")

In [None]:
class MoleDataset(Dataset):
    """Binary mole-image dataset backed by a label CSV and an image folder.

    The CSV must contain an ``isic_id`` column (image file stem) and a
    ``benign_malignant`` column. The label is converted with
    ``torch.tensor(..., dtype=torch.float32)``, so it has to be numeric.

    Images are decoded and resized once, then cached in memory as tensors.
    When ``is_train`` is true, random augmentations are applied to a copy of
    the cached tensor on every access; ImageNet normalization is always
    applied last.

    Args:
        csv_file: Path to the label CSV.
        img_dir: Directory containing ``<isic_id>.<ext>`` image files.
        img_size: Side length of the square image returned by ``__getitem__``.
        is_train: Enables random augmentations.
    """

    # File extensions probed, in this order, when resolving an image id.
    IMAGE_EXTENSIONS = ('jpg', 'png', 'jpeg')

    def __init__(self, csv_file: str, img_dir: str, img_size: int = 256, is_train: bool = False):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.img_size = img_size
        self.is_train = is_train
        # Iterate the two columns directly: ``iterrows`` builds one Series per
        # row and may upcast values (e.g. an integer id to float), which would
        # corrupt the file name derived from ``isic_id``.
        self.samples = [
            (self._find_image_path(img_id), label)
            for img_id, label in zip(self.annotations['isic_id'],
                                     self.annotations['benign_malignant'])
        ]
        # Maps image path -> resized, un-normalized image tensor.
        self.image_cache = {}

    def _find_image_path(self, img_id):
        """Return the first existing ``<img_dir>/<img_id>.<ext>`` path, or None."""
        for ext in self.IMAGE_EXTENSIONS:
            path = os.path.join(self.img_dir, f"{img_id}.{ext}")
            if os.path.exists(path):
                return path
        return None

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _overlaps_center(x, y, mask_w, mask_h, width, height) -> bool:
        """Tell whether a mask rectangle intersects the central half of the image.

        The central region spans a quarter of the side on each side of the
        image centre, on both axes. ``(x, y)`` is the mask's top-left corner.
        """
        center_x, center_y = width // 2, height // 2
        quarter_w, quarter_h = width // 4, height // 4
        overlaps_x = x < center_x + quarter_w and x + mask_w > center_x - quarter_w
        overlaps_y = y < center_y + quarter_h and y + mask_h > center_y - quarter_h
        return overlaps_x and overlaps_y

    def _apply_augmentations(self, image: torch.Tensor) -> torch.Tensor:
        """Randomly augment a (C, H, W) image tensor; a no-op unless ``is_train``.

        Every augmentation keeps the spatial size at ``img_size x img_size``.
        """
        if not self.is_train:
            return image

        target_size = (self.img_size, self.img_size)

        # Small random translation (up to 3% of the side). ``TF.affine`` keeps
        # the input size, so no resize is needed afterwards.
        if random.random() > 0.5:
            shift_x = random.uniform(-0.03, 0.03) * self.img_size
            shift_y = random.uniform(-0.03, 0.03) * self.img_size
            image = TF.affine(
                image,
                angle=0,
                translate=(shift_x, shift_y),
                scale=1.0,
                shear=0
            )

        # Small random rotation; without ``expand`` the output keeps the input size.
        if random.random() > 0.5:
            angle = random.uniform(-10, 10)
            image = TF.rotate(image, angle, fill=0)

        # Zoom in/out by up to 5%, then crop (or zero-pad) back to the target size.
        if random.random() > 0.5:
            scale_factor = random.uniform(0.95, 1.05)
            new_size = int(self.img_size * scale_factor)
            image = TF.resize(image, (new_size, new_size))
            image = TF.center_crop(image, target_size)

        # Slight brightness / contrast jitter.
        if random.random() > 0.5:
            brightness_factor = random.uniform(0.95, 1.05)
            contrast_factor = random.uniform(0.95, 1.05)
            image = TF.adjust_brightness(image, brightness_factor)
            image = TF.adjust_contrast(image, contrast_factor)

        # Occasional mild Gaussian blur with an odd kernel of about 5% of the side.
        # NOTE(review): sigma=(0.1, 0.2) is presumably read as (sigma_x, sigma_y),
        # not as a sampling range - confirm against the torchvision docs.
        if random.random() > 0.7:
            kernel_size = int(0.05 * self.img_size)
            if kernel_size % 2 == 0:
                kernel_size += 1
            image = TF.gaussian_blur(image, kernel_size=[kernel_size, kernel_size], sigma=(0.1, 0.2))

        # Cutout on the background only: zero a small random rectangle unless
        # any part of it would cover the central region of the image.
        if random.random() > 0.8:
            h, w = image.shape[1], image.shape[2]
            mask_size_h = int(random.uniform(0.02, 0.1) * h)
            mask_size_w = int(random.uniform(0.02, 0.1) * w)
            y = random.randint(0, h - mask_size_h)
            x = random.randint(0, w - mask_size_w)

            if not self._overlaps_center(x, y, mask_size_w, mask_size_h, w, h):
                image[:, y:y+mask_size_h, x:x+mask_size_w] = 0

        return image

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Returns:
            A normalized float image tensor of shape (3, img_size, img_size)
            and the label as a scalar float32 tensor.

        Raises:
            FileNotFoundError: If no image file was found for this sample.
        """
        img_path, target = self.samples[idx]

        if img_path is None:
            img_id = self.annotations['isic_id'].iloc[idx]
            raise FileNotFoundError(
                f"No image file found for isic_id '{img_id}' in '{self.img_dir}' "
                f"(tried extensions: {', '.join(self.IMAGE_EXTENSIONS)})"
            )

        if img_path not in self.image_cache:
            img = Image.open(img_path).convert('RGB')
            img = TF.resize(img, (self.img_size, self.img_size))
            img = TF.to_tensor(img)
            self.image_cache[img_path] = img

        # Work on a copy so in-place augmentations never modify the cached tensor.
        image = self.image_cache[img_path].clone()
        image = self._apply_augmentations(image)

        # ImageNet normalization, applied after all augmentations.
        image = TF.normalize(image,
                          mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])

        return image, torch.tensor(target, dtype=torch.float32)

In [None]:
# Batch size shared by all loaders, and the square side length images are resized to.
BATCH_SIZE = 64
img_size = 224

# One dataset per split; only the training split is augmented.
train_ds = MoleDataset('dataset_split/train_labels.csv', 'dataset_split/train',
                       img_size=img_size, is_train=True)
val_ds = MoleDataset('dataset_split/val_labels.csv', 'dataset_split/val',
                     img_size=img_size, is_train=False)
test_ds = MoleDataset('dataset_split/test_labels.csv', 'dataset_split/test',
                      img_size=img_size, is_train=False)

# Only the training loader shuffles and pins memory; evaluation order stays fixed.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# ResNet50 PyTorch

class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip connection.

    The block maps ``in_channels`` to ``out_channels * expansion`` channels.
    The 3x3 convolution carries the stride.

    Args:
        in_channels: Number of input channels.
        out_channels: Width of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input before it is added
            to the main branch (used when the shapes differ).
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        expanded_channels = out_channels * self.expansion

        # 1x1 reduction.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 spatial convolution (strided when the block downsamples).
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 1x1 expansion.
        self.conv3 = nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded_channels)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Apply the three convolutions, add the (optionally projected) input, then ReLU."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
    
class ResNet50(nn.Module):
    """ResNet-50 built from ``Bottleneck`` blocks with a dropout + linear head.

    Args:
        num_classes: Number of output logits (1 for binary classification).
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Running input width for the next stage; updated by ``_make_layer``.
        self.in_channels = 64

        # Stem: 7x7 strided convolution followed by max pooling.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages with 3, 4, 6 and 3 bottleneck blocks.
        self.layer1 = self._make_layer(64, 3)
        self.layer2 = self._make_layer(128, 4, stride=2)
        self.layer3 = self._make_layer(256, 6, stride=2)
        self.layer4 = self._make_layer(512, 3, stride=2)

        # Head: global average pooling, dropout, linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

    def _make_layer(self, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` bottlenecks and advance ``self.in_channels``.

        Only the first block receives ``stride`` and, when the input shape
        differs from the stage output, a 1x1 projection on the skip path.
        """
        expanded_channels = out_channels * Bottleneck.expansion

        downsample = None
        if stride != 1 or self.in_channels != expanded_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, expanded_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded_channels),
            )

        first_block = Bottleneck(self.in_channels, out_channels,
                                 stride=stride, downsample=downsample)
        self.in_channels = expanded_channels

        remaining_blocks = [Bottleneck(self.in_channels, out_channels)
                            for _ in range(1, blocks)]

        return nn.Sequential(first_block, *remaining_blocks)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes)."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = torch.flatten(self.avgpool(x), 1)
        logits = self.fc(self.dropout(x))

        return logits

# Instantiate the ResNet50 defined above (default single-logit head) on the selected device.
model = ResNet50().to(device)

In [None]:
# EfficientNet B4
# `weights` must be a member of the weights enum; passing the enum class
# itself is rejected by torchvision's weight verification.
model1 = models.efficientnet_b4(weights=EfficientNet_B4_Weights.DEFAULT)
# Replace the ImageNet head with dropout + a single logit for binary
# classification; the input width is read from the existing head.
model1.classifier = nn.Sequential(
    nn.Dropout(0.6),
    nn.Linear(model1.classifier[-1].in_features, 1),
)
model1 = model1.to(device)

# Freeze the whole network, then re-enable gradients for the new head only.
for param in model1.parameters():
    param.requires_grad = False
for param in model1.classifier.parameters():
    param.requires_grad = True

In [None]:
# ResNet50
# `weights` must be a member of the weights enum; passing the enum class
# itself is rejected by torchvision's weight verification.
model2 = models.resnet50(weights=ResNet50_Weights.DEFAULT)

# Replace the classifier with dropout + a single logit for binary classification.
num_features = model2.fc.in_features
model2.fc = nn.Sequential(
    nn.Dropout(0.8),  # Dropout for regularization
    nn.Linear(num_features, 1)  # One output for binary classification
)

model2 = model2.to(device)

# Freeze the whole network, then re-enable gradients for the new head only.
for param in model2.parameters():
    param.requires_grad = False

for param in model2.fc.parameters():
    param.requires_grad = True


In [None]:
# ConvNeXt
# Use the `weights=` API like the other model cells; `pretrained=True` is the
# deprecated legacy spelling.
model3 = models.convnext_small(weights=models.ConvNeXt_Small_Weights.DEFAULT)

# Freeze the pretrained network.
for param in model3.parameters():
    param.requires_grad = False

# Replace the last classifier layer with a binary head.
num_features = model3.classifier[-1].in_features
model3.classifier[-1] = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(num_features, 1),  # Output layer with a single neuron
    nn.Flatten()
)

# Unfreeze the classifier (its normalization layer and the new head).
for param in model3.classifier.parameters():
    param.requires_grad = True

# Keep the model on the same device as model1 / model2.
model3 = model3.to(device)

In [None]:
def show_metrics(epoch: int, **samples) -> None:
    """Redraw the metric curves and print the latest value of every metric.

    Args:
        epoch: Epoch number shown in the printed summary.
        **samples: One keyword per metric. Each value is a dict mapping a
            phase name ('train', 'val', 'test') to the list of per-epoch
            values. The metric named 'Loss' is drawn on a log scale.
    """
    phase_order = ('train', 'val', 'test')

    # Replace the previous output instead of stacking a new figure per epoch.
    clear_output(wait=True)
    plt.figure(figsize=(18, 10))

    # Three plots per row.
    n_rows = ceil(len(samples) / 3)
    for position, (metric_name, history) in enumerate(samples.items(), start=1):
        plt.subplot(n_rows, 3, position)
        plt.title(metric_name)
        plt.yscale('log' if metric_name == 'Loss' else 'linear')

        for phase in phase_order:
            if phase in history:
                values = history[phase]
                plt.plot(range(len(values)), values, label=f"{phase}")
        plt.legend()

    plt.show()

    print(f"\nEpoch {epoch} summary:")
    for metric_name, history in samples.items():
        # Only phases that already have at least one recorded value are listed.
        latest = [
            f"{phase}: {history[phase][-1]:.4f}"
            for phase in phase_order
            if phase in history and history[phase]
        ]
        print(f"{metric_name:<9} | {' | '.join(latest)}")


def plot_confusion_matrix(all_labels, all_predictions):
    """Draw the 2x2 benign/malignant confusion matrix as an annotated heatmap.

    Args:
        all_labels: Ground-truth class ids (0 = benign, 1 = malignant).
        all_predictions: Predicted class ids, in the same order.
    """
    # Pin both classes so the matrix is always 2x2 and lines up with the two
    # tick labels, even when one class is absent from labels and predictions.
    cm = confusion_matrix(all_labels, all_predictions, labels=[0, 1])
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Доброкачественные', 'Злокачественные'],
                yticklabels=['Доброкачественные', 'Злокачественные'])
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.show()

In [None]:
class TrainingPipeline:
    """Training / validation / test driver for a single-logit binary classifier.

    The pipeline runs epochs with optional mixed precision and gradient
    accumulation, records per-epoch metrics, optionally unfreezes the backbone
    at a given epoch, applies early stopping on the validation loss, saves one
    checkpoint per epoch, and finally evaluates on the test loader.

    Args:
        model: Network producing one logit per sample.
        train_loader, val_loader, test_loader: Loaders yielding ``(inputs, labels)``.
        criterion: Loss taking ``(logits, float_labels)``.
        optimizer: Optimizer used until the backbone is unfrozen.
        device: Device the model and batches are moved to.
        scheduler: Optional LR scheduler bound to ``optimizer``.
        metrics_visualizer: Callable ``(epoch, **metrics)`` invoked after every
            epoch, or a falsy value to disable plotting.
        scheduler_step_per_epoch: When true the scheduler steps once per epoch;
            when false it steps after every optimizer update.
            ``ReduceLROnPlateau`` always steps once per epoch with the
            validation loss.
        checkpoint_dir: Directory for per-epoch checkpoints (created if needed).
        metric_average: Stored for reference only; precision / recall / F1 are
            computed with ``average='binary'``.
        grad_accum_steps: Number of batches accumulated per optimizer update.
        use_amp: Request mixed precision; it is only activated on CUDA devices.
        unfreeze_epoch: Epoch at which ``unfreeze_backbone`` is called.
        early_stopping_patience: Epochs without improvement before stopping,
            or None to disable early stopping.
        early_stopping_min_delta: Minimum decrease of the validation loss that
            counts as an improvement.
    """

    # Top-level attribute names holding the classification head. Matching on
    # the first name component avoids misclassifying backbone layers that merely
    # contain "fc" (e.g. squeeze-excitation ``fc1`` / ``fc2`` layers).
    HEAD_MODULE_NAMES = ('classifier', 'fc')

    # Pairs of (key in a per-epoch result dict, key in ``self.metrics``).
    METRIC_ALIASES = (
        ('loss', 'Loss'),
        ('accuracy', 'Accuracy'),
        ('precision', 'Precision'),
        ('recall', 'Recall'),
        ('f1', 'F1'),
    )

    def __init__(self, model: nn.Module, train_loader: DataLoader,
                 val_loader: DataLoader, test_loader: DataLoader, criterion: nn.Module,
                 optimizer: optim.Optimizer, device: torch.device,
                 scheduler=None,
                 metrics_visualizer=show_metrics,
                 scheduler_step_per_epoch: bool = True,
                 checkpoint_dir=None,
                 metric_average: str = 'macro',
                 grad_accum_steps: int = 1,
                 use_amp: bool = True,
                 unfreeze_epoch: int = 20,
                 early_stopping_patience: int = None,
                 early_stopping_min_delta: float = 0.0):
        if grad_accum_steps < 1:
            raise ValueError(f"grad_accum_steps must be >= 1, got {grad_accum_steps}")

        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.scheduler = scheduler
        self.metrics_visualizer = metrics_visualizer
        self.scheduler_step_per_epoch = scheduler_step_per_epoch
        self.checkpoint_dir = checkpoint_dir
        self.metric_average = metric_average
        self.grad_accum_steps = grad_accum_steps
        self.use_amp = use_amp
        # CUDA autocast and gradient scaling only apply to CUDA tensors, so AMP
        # is switched off when the pipeline runs on another device.
        self._amp_enabled = bool(use_amp) and torch.device(device).type == 'cuda'
        self.scaler = torch.amp.GradScaler(enabled=self._amp_enabled)

        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_min_delta = early_stopping_min_delta
        self.best_val_loss = float('inf')
        self.epochs_without_improvement = 0

        self.current_stage = 1  # 1 - classifier head only, 2 - all layers
        self.unfreeze_epoch = unfreeze_epoch  # epoch at which the backbone is unfrozen

        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)

        self.metrics = {
            alias: {'train': [], 'val': [], 'test': []}
            for _, alias in self.METRIC_ALIASES
        }

    @staticmethod
    def _flatten_logits(outputs: torch.Tensor) -> torch.Tensor:
        """Collapse model outputs to shape (batch,) when there is one logit per sample."""
        if outputs.dim() > 2:
            outputs = outputs.reshape(outputs.size(0), -1)
        if outputs.dim() == 2 and outputs.shape[1] == 1:
            outputs = outputs.squeeze(-1)
        return outputs

    def unfreeze_backbone(self):
        """Unfreeze every parameter and switch to a two-group AdamW optimizer.

        The head (top-level ``classifier`` / ``fc`` module) gets lr 1e-6 and the
        rest of the network lr 1e-5. The previous optimizer is replaced, so a
        scheduler bound to it can no longer influence training and is dropped.
        """
        print("\nUnfreezing backbone layers...")

        # 1. Unfreeze all parameters.
        for param in self.model.parameters():
            param.requires_grad = True

        # 2. Split parameters into head and backbone groups.
        head_params = []
        backbone_params = []
        for name, param in self.model.named_parameters():
            if name.split('.', 1)[0] in self.HEAD_MODULE_NAMES:
                head_params.append(param)
            else:
                backbone_params.append(param)

        # 3. Replace the optimizer (empty groups are skipped).
        param_groups = [
            {'params': head_params, 'lr': 1e-6},
            {'params': backbone_params, 'lr': 1e-5},
        ]
        self.optimizer = torch.optim.AdamW(
            [group for group in param_groups if group['params']],
            weight_decay=0.05,
        )

        # 4. The old scheduler still references the discarded optimizer.
        if self.scheduler is not None:
            print("Scheduler was bound to the previous optimizer and is dropped; "
                  "the new parameter groups keep fixed learning rates.")
            self.scheduler = None

        self.current_stage = 2

    def evaluate_test(self):
        """Evaluate on the test loader, record the metrics and plot the confusion matrix.

        Returns:
            Dict with 'accuracy', 'precision', 'recall', 'f1' and 'loss'.
        """
        self.model.eval()
        all_labels_list = []
        all_outputs_list = []

        with torch.no_grad():
            for inputs, labels in tqdm(self.test_loader):
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self._flatten_logits(self.model(inputs))

                all_outputs_list.append(outputs.float().cpu())
                all_labels_list.append(labels.cpu())

        all_outputs = torch.cat(all_outputs_list)
        all_labels = torch.cat(all_labels_list)

        # Compute the metrics; the loss is evaluated once over the whole test set.
        test_res = self._calculate_metrics(all_outputs, all_labels)
        test_res['loss'] = self.criterion(all_outputs.to(self.device),
                                        all_labels.float().to(self.device)).item()

        # Store the metrics (this is the only place test metrics are recorded).
        for key, alias in self.METRIC_ALIASES:
            self.metrics[alias]['test'].append(test_res[key])

        # Class predictions for the confusion matrix.
        predictions = (torch.sigmoid(all_outputs) > 0.5).long()
        if predictions.dim() > 1 and predictions.shape[1] == 1:
            predictions = predictions.squeeze(1)

        print("\n=== Test Confusion Matrix ===")
        plot_confusion_matrix(all_labels.numpy(), predictions.numpy())

        return test_res

    def _check_early_stopping(self, val_loss: float) -> bool:
        """Update the early-stopping state and tell whether training should stop."""
        if self.early_stopping_patience is None:
            return False

        improved = (self.best_val_loss - val_loss) > self.early_stopping_min_delta

        if improved:
            self.best_val_loss = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
            if self.epochs_without_improvement >= self.early_stopping_patience:
                print(f"\nEarly stopping triggered! No improvement for {self.early_stopping_patience} epochs.")
                return True
        return False

    def _calculate_metrics(self, all_outputs: torch.Tensor, all_labels: torch.Tensor):
        """Compute accuracy / precision / recall / F1 from logits at a 0.5 threshold.

        Returns zeros for every metric when there are no labels.
        """
        if all_labels.numel() == 0:
            return {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

        # Cast to float32 first: logits produced under autocast may be half precision.
        predicted = (torch.sigmoid(all_outputs.float()) > 0.5).long().cpu()
        if predicted.dim() > 1 and predicted.shape[1] == 1:
            predicted = predicted.squeeze(1)

        all_labels = all_labels.cpu()
        labels_np = all_labels.numpy()
        predicted_np = predicted.numpy()

        accuracy = (predicted == all_labels).float().mean().item()
        precision = precision_score(labels_np, predicted_np, average='binary', zero_division=0)
        recall = recall_score(labels_np, predicted_np, average='binary', zero_division=0)
        f1 = f1_score(labels_np, predicted_np, average='binary', zero_division=0)

        return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1}

    def _run_epoch(self, phase: str):
        """Run one pass over the loader of ``phase`` ('train', 'val', anything else = test).

        Returns:
            Dict with the sample-weighted mean 'loss' plus the classification metrics.
        """
        is_train = phase == 'train'
        if phase == 'train':
            loader = self.train_loader
        elif phase == 'val':
            loader = self.val_loader
        else:  # test
            loader = self.test_loader
        self.model.train(is_train)

        # Per-batch schedulers step after every optimizer update.
        step_scheduler_per_batch = (
            is_train
            and self.scheduler is not None
            and not self.scheduler_step_per_epoch
            and not isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau)
        )

        running_loss = 0.0
        all_labels_list = []
        all_outputs_list = []
        processed_samples = 0
        num_batches = len(loader)

        context = torch.enable_grad() if is_train else torch.no_grad()
        with context:
            for batch_idx, (inputs, labels) in enumerate(tqdm(loader)):
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                with torch.amp.autocast(device_type='cuda', enabled=self._amp_enabled and is_train):
                    outputs = self._flatten_logits(self.model(inputs))
                    loss = self.criterion(outputs, labels.float())

                    if is_train:
                        # Scale so the accumulated gradient matches one large batch.
                        loss = loss / self.grad_accum_steps

                if is_train:
                    self.scaler.scale(loss).backward()

                    # Update on every grad_accum_steps-th batch and on the last batch.
                    is_update_step = ((batch_idx + 1) % self.grad_accum_steps == 0
                                      or (batch_idx + 1) == num_batches)
                    if is_update_step:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        self.optimizer.zero_grad()
                        if step_scheduler_per_batch:
                            self.scheduler.step()

                batch_size = inputs.size(0)
                # Undo the accumulation scaling so the reported loss is per sample.
                running_loss += loss.item() * batch_size * (self.grad_accum_steps if is_train else 1)
                processed_samples += batch_size

                all_outputs_list.append(outputs.detach().float().cpu())
                all_labels_list.append(labels.detach().cpu())

        all_outputs_tensor = torch.cat(all_outputs_list, dim=0) if all_outputs_list else torch.empty(0)
        all_labels_tensor = torch.cat(all_labels_list, dim=0) if all_labels_list else torch.empty(0)

        epoch_loss = running_loss / processed_samples if processed_samples > 0 else 0.0
        epoch_metrics = self._calculate_metrics(all_outputs_tensor, all_labels_tensor)

        return {'loss': epoch_loss, **epoch_metrics}

    def _save_checkpoint(self, epoch: int, val_loss: float):
        """Write ``model_epoch_<epoch>.pth`` into ``checkpoint_dir`` (no-op without one)."""
        if not self.checkpoint_dir:
            return

        checkpoint_path = os.path.join(self.checkpoint_dir, f'model_epoch_{epoch}.pth')
        state = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'val_loss': val_loss,
            'scaler_state_dict': self.scaler.state_dict() if self.use_amp else None,
        }
        if self.scheduler:
            state['scheduler_state_dict'] = self.scheduler.state_dict()

        torch.save(state, checkpoint_path)

    def run_training(self, num_epochs: int):
        """Train for up to ``num_epochs`` epochs, then evaluate once on the test set.

        Returns:
            ``self.metrics``: metric name -> phase -> list of recorded values.
        """
        for epoch in range(1, num_epochs + 1):
            if epoch == self.unfreeze_epoch and self.current_stage == 1:
                self.unfreeze_backbone()
            train_res = self._run_epoch('train')
            val_res = self._run_epoch('val')

            for key, alias in self.METRIC_ALIASES:
                self.metrics[alias]['train'].append(train_res[key])
                self.metrics[alias]['val'].append(val_res[key])

            if self.metrics_visualizer:
                # The test history is empty during training, so it is not plotted.
                filtered_metrics = {
                    k: {'train': v['train'], 'val': v['val']}
                    for k, v in self.metrics.items()
                }
                self.metrics_visualizer(epoch, **filtered_metrics)
                print('Current LR:', self.optimizer.param_groups[0]['lr'])

            if self._check_early_stopping(val_res['loss']):
                break

            if self.scheduler is not None:
                if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_res['loss'])
                elif self.scheduler_step_per_epoch:
                    self.scheduler.step()
                # Per-batch schedulers are stepped inside _run_epoch.

            if self.checkpoint_dir:
                self._save_checkpoint(epoch, val_res['loss'])

        # evaluate_test already records the test metrics in self.metrics.
        test_res = self.evaluate_test()
        print("\n=== Final Test Evaluation ===")
        for metric, value in test_res.items():
            print(f"{metric.capitalize():<10}: {value:.4f}")

        print("Training finished.")
        return self.metrics

    def evaluate_ensemble(self, model1, model2, model3):
        """Evaluate a ``HardVotingEnsemble`` of three models on the test loader.

        Returns:
            Dict with 'accuracy', 'precision', 'recall' and 'f1'.
        """
        ensemble = HardVotingEnsemble(model1, model2, model3, self.device)
        preds, labels = ensemble.predict(self.test_loader)

        # Flatten to 1-D so predictions and labels have matching shapes.
        preds = preds.reshape(-1).cpu()
        labels = labels.reshape(-1).cpu()

        metrics = {
            "accuracy": accuracy_score(labels, preds),
            "precision": precision_score(labels, preds, zero_division=0),
            "recall": recall_score(labels, preds, zero_division=0),
            "f1": f1_score(labels, preds, zero_division=0),
        }

        self._plot_confusion_matrix(labels.numpy(), preds.numpy())

        return metrics

    def _plot_confusion_matrix(self, y_true, y_pred):
        """Plot the ensemble confusion matrix as an annotated heatmap."""
        # Pin both classes so the matrix is always 2x2 and matches the tick labels.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        plt.figure(figsize=(8, 6))
        sns.heatmap(
            cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=["Benign (0)", "Malignant (1)"],
            yticklabels=["Benign (0)", "Malignant (1)"]
        )
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.title("Ensemble Confusion Matrix")
        plt.show()

In [None]:
# Restore the fine-tuned ConvNeXt-small weights. The checkpoint belongs to the
# ConvNeXt model (model3), not to the from-scratch ResNet50 stored in `model`.
# map_location keeps loading working when the checkpoint was saved on a GPU
# and the notebook runs on the CPU.
checkpoint = torch.load('tested_models/checkpoints_ConvNeXt_small/ConvNeXt.pth',
                        map_location=device)['model_state_dict']
model3.load_state_dict(checkpoint)

In [None]:
# FocalLoss

class FocalLoss(nn.Module):
    """Focal loss on logits for binary targets.

    Each element's binary cross-entropy is scaled by
    ``alpha * (1 - p_t) ** gamma``, where ``p_t = exp(-BCE)``. ``alpha`` is
    applied uniformly to all elements.

    Args:
        alpha: Constant scale factor.
        gamma: Focusing exponent.
        reduction: 'mean', 'sum', or anything else to return the unreduced loss.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets``."""
        per_element_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # exp(-BCE) is the probability assigned to the true class.
        prob_true_class = torch.exp(-per_element_bce)
        modulating_factor = (1 - prob_true_class) ** self.gamma
        focal_loss = self.alpha * modulating_factor * per_element_bce

        if self.reduction == 'sum':
            return focal_loss.sum()
        if self.reduction == 'mean':
            return focal_loss.mean()
        return focal_loss

In [None]:
model # Display the model architecture

In [None]:
num_epochs = 0 # Number of training epochs, set separately; with 0 the training loop is skipped

In [None]:
# AdamW over all of model1's parameters.
# NOTE(review): every parameter is passed here, not only the unfrozen ones -
# confirm that frozen parameters are meant to be registered with the optimizer.
optimizer = torch.optim.AdamW(model1.parameters(), lr=1e-6, weight_decay=0.05)

# Binary cross-entropy on logits, with the positive class weighted by 1.37.
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.37], device=device))

# NOTE(review): eta_min equals the initial lr (1e-6), so as configured the
# cosine schedule presumably keeps the learning rate constant - confirm intent.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=30, eta_min=1e-6)

In [None]:
# Assemble the training pipeline for model1 with gradient accumulation over
# two batches, mixed precision, backbone unfreezing at epoch 70 and early stopping.
pipeline = TrainingPipeline(
    model=model1,
    train_loader=train_loader, val_loader=val_loader, test_loader=test_loader,
    criterion=criterion, optimizer=optimizer, scheduler=scheduler,
    device=device,
    metrics_visualizer=show_metrics,
    scheduler_step_per_epoch=True,
    grad_accum_steps=2,
    use_amp=True,
    unfreeze_epoch=70,
    early_stopping_patience=50,
    early_stopping_min_delta=0.00001,
    # NOTE(review): the directory name mentions ConvNeXt while the model passed
    # above is model1 - confirm this is the intended checkpoint location.
    checkpoint_dir='tested_models/checkpoints_ConvNeXt_small',
)


In [None]:
# Train for `num_epochs` epochs (followed by the final test evaluation) and
# print the collected metric history.
final_metrics = pipeline.run_training(num_epochs=num_epochs)
print(final_metrics)

In [None]:
class HardVotingEnsemble:
    """Three-model ensemble that predicts class 1 when ANY member predicts class 1.

    Each member votes by thresholding the sigmoid of its logit at 0.5. The
    ensemble output is 0 only if all three members vote 0, so this is an "OR"
    rule rather than a majority vote.

    Args:
        model1, model2, model3: Models producing one logit per sample.
        device: Device the models and inputs are moved to.
    """

    def __init__(self, model1, model2, model3, device="cuda"):
        self.models = [model1.to(device), model2.to(device), model3.to(device)]
        self.device = device

    def predict(self, dataloader):
        """Predict classes for every batch of ``dataloader``.

        Returns:
            ``(predictions, labels)``: predictions are an int tensor with one
            entry per sample (shape (N,) for single-logit models); labels are
            the concatenated labels yielded by the loader.
        """
        # Switch to evaluation mode once instead of once per batch.
        for model in self.models:
            model.eval()

        all_preds = []
        all_labels = []

        with torch.no_grad():
            for inputs, labels in tqdm(dataloader, desc="Ensemble Prediction"):
                inputs = inputs.to(self.device)

                # One vote per model. A (batch, 1) output is squeezed to
                # (batch,) so the votes line up with the 1-D labels.
                batch_votes = []
                for model in self.models:
                    outputs = model(inputs)
                    outputs = outputs.reshape(inputs.size(0), -1).squeeze(-1)
                    batch_votes.append((torch.sigmoid(outputs) > 0.5).int().cpu())

                # 1 if at least one model voted 1; 0 only if all voted 0.
                votes = torch.stack(batch_votes)
                ensemble_preds = (votes.sum(dim=0) > 0).int()

                all_preds.append(ensemble_preds)
                all_labels.append(labels.cpu())

        return torch.cat(all_preds), torch.cat(all_labels)

In [None]:
# Restore the fine-tuned weights of every ensemble member. map_location keeps
# loading working when a checkpoint was saved on a GPU and the notebook runs
# on the CPU.
for member, checkpoint_path in (
    (model1, 'tested_models/checkpoints_ENetb4/ENet.pth'),
    (model2, 'tested_models/checkpoints_ResNet50/ResNet.pth'),
    (model3, 'tested_models/checkpoints_ConvNeXt_small/ConvNeXt.pth'),
):
    checkpoint = torch.load(checkpoint_path, map_location=device)['model_state_dict']
    member.load_state_dict(checkpoint)

# Evaluate the ensemble on the test set.
ensemble_metrics = pipeline.evaluate_ensemble(model1, model2, model3)

print("\n=== Ensemble Test Metrics ===")
for name, value in ensemble_metrics.items():
    print(f"{name.capitalize():<10}: {value:.4f}")