In [1]:
import os
import shutil
import random

import numpy as np
import torch
import torch.nn as nn
from torch.nn.attention import flex_attention
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
import timm
from clearml import Task, Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from omegaconf import OmegaConf

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def test_task(cfg_path='configs/config.yaml'):
    """Smoke-test ClearML logging of an OmegaConf config file.

    Loads the YAML config at ``cfg_path``, initialises a ClearML task, attaches the
    whole config as a configuration object and records two training hyperparameters.

    Args:
        cfg_path: path to the YAML config. It was previously ignored in favour of a
            hard-coded 'configs/config.yaml'; that path is now the default.

    Returns:
        The initialised ClearML task.
    """
    cfg = OmegaConf.load(cfg_path)

    train_task = Task.init(task_name='123')
    # Hand ClearML a plain nested dict rather than the DictConfig object.
    train_task.connect_configuration(OmegaConf.to_container(cfg, resolve=True), name="config")

    # The config stores the learning rate as train.lr_start. There is no
    # train.learning_rate key. The hyperparameter keeps its original name.
    train_task.set_parameter("train.epoch", cfg.train.epoch)
    train_task.set_parameter("train.learning_rate", cfg.train.lr_start)
    print(f"EPOCH = {cfg.train.epoch}")
    return train_task

In [None]:
# Scratch: ClearML task used by the config experiments below.
# NOTE(review): no project_name is given, so confirm which ClearML project this lands in.
# A later cell calls Task.init again in the same kernel; check which task that run logs to.
task = Task.init(task_name='123')

In [18]:
# Load the experiment configuration as an OmegaConf DictConfig.
cfg = OmegaConf.load('configs/config.yaml')

In [21]:
# NOTE(review): dict(cfg) converts only the top level, so nested sections presumably
# stay DictConfig objects. OmegaConf.to_container(cfg, resolve=True) yields plain nested dicts.
type(dict(cfg))

dict

In [22]:
dict(cfg)

{'pipeline': {'pipe_name': 'SSL pipeline', 'pipe_proj_name': 'PixPro'},
 'task': {'proj_name': 'PixPro', 'task_name': 'ResNet'},
 'model': {'backbone': 'resnet18', 'pretrained': False, 'projector_blocks': 1, 'predictor_blocks': 1, 'reduction': 4},
 'data': {'img_size': 640, 'dataset_name': 'ssl_turbine_dataset', 'train_folder': 'turbine_train', 'val_folder': 'turbine_val', 'batchsize': 32, 'numworkers': 16},
 'train': {'epoch': 1, 'lr_start': 0.001, 'lr_end': 1e-05, 'devices': 'auto', 'accelerator': 'auto', 'val_step': 10, 'log_step': 5},
 'val': {'eps': 0.5, 'min_samples': 5, 'sample_fraction': 1.0}}

In [23]:
# Attach the config to the task as a configuration object; the call echoes the dict back.
task.connect_configuration(dict(cfg))

{'pipeline': {'pipe_name': 'SSL pipeline', 'pipe_proj_name': 'PixPro'},
 'task': {'proj_name': 'PixPro', 'task_name': 'ResNet'},
 'model': {'backbone': 'resnet18', 'pretrained': False, 'projector_blocks': 1, 'predictor_blocks': 1, 'reduction': 4},
 'data': {'img_size': 640, 'dataset_name': 'ssl_turbine_dataset', 'train_folder': 'turbine_train', 'val_folder': 'turbine_val', 'batchsize': 32, 'numworkers': 16},
 'train': {'epoch': 1, 'lr_start': 0.001, 'lr_end': 1e-05, 'devices': 'auto', 'accelerator': 'auto', 'val_step': 10, 'log_step': 5},
 'val': {'eps': 0.5, 'min_samples': 5, 'sample_fraction': 1.0}}

In [27]:
# NOTE(review): this displayed {} in the next cell. connect_configuration registers a
# configuration object rather than hyperparameters, which is presumably why nothing
# shows up here. task.connect(...) is the hyperparameter API; TODO confirm.
new_cfg = task.get_parameters()

In [None]:
new_cfg

{}

: 

In [None]:
# List the architecture names registered in timm (candidates for model.backbone).
timm.list_models()

In [None]:
class PixelPropagationModule(nn.Module):
    """Non-local self-attention block with a learnable, zero-initialised residual gate.

    Every spatial position attends over all H*W positions. Attention weights are a
    softmax over query-key dot products computed in a reduced channel space
    (``in_channels // reduction``). Values keep the full ``in_channels``. The
    attended map is scaled by ``gamma`` and added to the input. ``gamma`` starts
    at 0, so the block is the identity at initialisation.

    Args:
        in_channels: number of channels C of the input feature map.
        reduction: divisor applied to C to get the query/key projection width.
    """

    def __init__(self, in_channels, reduction=4):
        super().__init__()
        self.inter_channels = in_channels // reduction
        # Creation order (query, key, value, gamma) is preserved so seeded init is unchanged.
        self.query_conv = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channels, height, width = x.size()
        n_pos = height * width

        queries = self.query_conv(x).view(batch, self.inter_channels, n_pos).transpose(1, 2)  # [B, N, C']
        keys = self.key_conv(x).view(batch, self.inter_channels, n_pos)                       # [B, C', N]
        weights = torch.bmm(queries, keys).softmax(dim=-1)                                    # [B, N, N]; rows sum to 1

        values = self.value_conv(x).view(batch, channels, n_pos)                              # [B, C, N]
        # out[:, :, i] = sum_j values[:, :, j] * weights[:, i, j]
        attended = torch.bmm(values, weights.transpose(1, 2)).view(batch, channels, height, width)
        return x + self.gamma * attended

In [None]:
class FlexAttentionPPM(nn.Module):
    """Single-head self-attention over spatial positions with a zero-initialised residual gate.

    One 1x1 conv produces Q, K and V, each with ``in_channels`` channels. Every
    spatial position attends over all H*W positions with scaled dot-product
    attention (scale ``1/sqrt(in_channels)``). The result is added to the input,
    weighted by the learnable ``gamma``. ``gamma`` starts at 0, so the block is
    the identity at initialisation.

    The previous version called ``flex_attention`` with ``attn_mask``, ``dropout_p``
    and ``is_causal``. Those are arguments of ``F.scaled_dot_product_attention``,
    not of ``flex_attention``. The name imported at the top of the notebook also
    appears to be the ``torch.nn.attention.flex_attention`` module rather than the
    function. Either way, the call could not run, so it now uses
    ``F.scaled_dot_product_attention``.

    Args:
        in_channels: number of channels C of the input feature map.
        dropout_p: attention dropout probability, applied only in training mode.
        is_causal: apply a causal mask over the flattened H*W token order.
    """

    def __init__(self, in_channels, dropout_p=0.0, is_causal=False):
        super().__init__()
        # Projection producing Q, K and V in one pass (3 * in_channels outputs).
        self.qkv_proj = nn.Conv2d(in_channels, in_channels * 3, kernel_size=1)
        self.dropout_p = dropout_p
        self.is_causal = is_causal
        # Learnable weight of the residual connection.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.size()
        q, k, v = torch.chunk(self.qkv_proj(x), chunks=3, dim=1)  # each [B, C, H, W]

        # [B, C, H, W] -> [B, 1, H*W, C]: one attention head over H*W tokens.
        q, k, v = (t.flatten(2).transpose(1, 2).unsqueeze(1) for t in (q, k, v))

        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            # Like nn.Dropout, attention dropout is disabled in eval mode.
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=self.is_causal,
        )  # [B, 1, H*W, C]

        attn_out = attn_out.squeeze(1).transpose(1, 2).reshape(B, C, H, W)
        return self.gamma * attn_out + x

In [None]:
# Scratch: a standalone randomly initialised ResNet-18 feature extractor (returns a list of feature maps).
backbone = timm.create_model(model_name='resnet18', features_only=True, pretrained=False)

In [None]:
class PixPro(nn.Module):
    """Dense SSL model: a shared backbone and projector, plus predictor and pixel-propagation heads.

    Both views go through the same timm backbone (``features_only=True``; only the
    last feature map is used) and the same 1x1-conv projector. From each projection
    ``z`` the model returns:

    * ``p = predictor(z)``, the prediction branch;
    * ``y = pixel_propagation(z)``, the branch that ``pixpro_loss`` uses as the
      (detached) target.

    NOTE(review): the PixPro paper uses a momentum (EMA) target encoder. Here both
    branches share weights, which is closer to a SimSiam-style setup. Confirm this
    is intended.

    Args:
        backbone_name: timm model name.
        pretrained: whether timm loads pretrained weights.
        reduction: query/key channel reduction of ``PixelPropagationModule``.
            Defaults to 4, the previously hard-coded value.
    """

    def __init__(self,
                 backbone_name,
                 pretrained=False,
                 reduction=4
                 ):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, features_only=True)
        # Channel count of the last feature map, read from timm's feature metadata.
        # The former dummy 1x3x1024x768 forward pass was slow. Because the freshly
        # created backbone is in train mode, it also updated BatchNorm running
        # statistics with random noise.
        self.in_features = self.backbone.feature_info.channels()[-1]
        # The projector widens channels 4x; the predictor bottlenecks to in_features // 4.
        self.proj_dim = self.in_features * 4
        self.hidden_dim = self.in_features // 4

        self.projector = nn.Sequential(
            nn.Conv2d(self.in_features, self.proj_dim, kernel_size=1),
            nn.BatchNorm2d(self.proj_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.proj_dim, self.proj_dim, kernel_size=1)
        )
        self.predictor = nn.Sequential(
            nn.Conv2d(self.proj_dim, self.hidden_dim, kernel_size=1),
            nn.BatchNorm2d(self.hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.hidden_dim, self.proj_dim, kernel_size=1)
        )
        self.pixel_propagation = PixelPropagationModule(self.proj_dim, reduction=reduction)

    def forward(self, x1, x2):
        """Return ``(p1, p2, y1, y2)`` for two views ``x1`` and ``x2``.

        ``p*`` are predictor outputs and ``y*`` are pixel-propagation outputs. Each
        has shape ``[B, proj_dim, h, w]``, where ``h, w`` is the spatial size of
        the backbone's last feature map.
        """
        f1 = self.backbone(x1)[-1]  # [B, in_features, h, w]
        f2 = self.backbone(x2)[-1]

        z1 = self.projector(f1)     # [B, proj_dim, h, w]
        z2 = self.projector(f2)

        p1 = self.predictor(z1)     # predictions (the branch that receives gradients)
        p2 = self.predictor(z2)

        y1 = self.pixel_propagation(z1)  # targets (detached inside pixpro_loss)
        y2 = self.pixel_propagation(z2)

        return p1, p2, y1, y2

In [None]:
def pixpro_loss(p1, p2, y1, y2):
    """Symmetric negative cosine similarity between predictions and stop-gradient targets.

    For each pair (p1 vs y2, p2 vs y1), the cosine similarity is taken over the
    channel axis at every spatial position and averaged over batch and positions.
    The two terms are averaged. The targets are detached, so no gradient flows
    through ``y1``/``y2``.

    NOTE(review): pixel i of one view is compared with pixel i of the other. That
    is the same image location only if both views share their spatial transform.
    With independent random crops/flips the positions generally do not
    correspond. Confirm this is intended.

    Args:
        p1, p2: predictions, ``[B, D, H, W]``.
        y1, y2: targets, same shape as the predictions.

    Returns:
        Scalar loss in ``[-1, 1]``.
    """
    def _neg_cos(pred, target):
        # [B, D, H, W] -> [B, D, H*W]; similarity over D at each position.
        return -F.cosine_similarity(pred.flatten(2), target.flatten(2).detach(), dim=1).mean()

    return 0.5 * (_neg_cos(p1, y2) + _neg_cos(p2, y1))

In [None]:
# The pipeline is built once. Random parameters are drawn on every call, so
# rebuilding the Compose for every image (as before) was pure overhead.
_TO_PIL_IMAGE = transforms.ToPILImage()
_SSL_AUGMENTATION = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
    # (Optional) Solarize: uncomment if experiments show a benefit.
    # transforms.RandomSolarize(threshold=128, p=0.2),
    transforms.ToTensor(),
    # ImageNet normalisation, suited to a backbone pretrained with these statistics.
    # NOTE(review): base_transform does not normalise, so validation features are
    # extracted from un-normalised inputs. Confirm this train/eval mismatch is intended.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])


def advanced_augmentations(image_tensor):
    """Return one randomly augmented, ImageNet-normalised 3x224x224 view of an image tensor.

    The input tensor is converted to a PIL image. It then goes through a random
    resized crop, horizontal flip, colour jitter (p=0.8), grayscale (p=0.2) and
    Gaussian blur. Finally it is converted back to a tensor and normalised with
    ImageNet statistics.

    Args:
        image_tensor: a single image ``[C, H, W]`` accepted by ``ToPILImage``
            (presumably float in [0, 1], as produced by ``ToTensor``).
    """
    return _SSL_AUGMENTATION(_TO_PIL_IMAGE(image_tensor))

def batch_augmentations(batch_images):
    """
    Augment each image of a batch independently and restack the results.

    Takes a batch tensor ``[B, C, H, W]``, applies ``advanced_augmentations`` to every
    image separately (each gets its own random parameters) and returns the stacked
    augmented batch.
    """
    return torch.stack([advanced_augmentations(image) for image in batch_images])

In [None]:
def train_pixpro(model, dataloader, optimizer, device, epoch, augment_fn=batch_augmentations, log_interval=10):
    """
    Train the PixPro model for one epoch and return the mean batch loss.

    For every batch, two views are produced with ``augment_fn`` and passed through
    the model. The result is optimised with ``pixpro_loss``. Every ``log_interval``
    batches the loss is printed and, if a ClearML task is active, reported as
    "Training"/"Loss" at iteration ``epoch * len(dataloader) + batch_idx``.

    Args:
        model: module returning ``(p1, p2, y1, y2)`` for two views.
        dataloader: yields ``(images, labels)``; the labels are ignored.
        optimizer: optimiser over the model parameters.
        device: device the views are moved to.
        epoch: epoch number, used for printing and the logging iteration.
        augment_fn: batch -> augmented batch (default ``batch_augmentations``).
        log_interval: log every this many batches (default 10, the former hard-coded value).

    Returns:
        Average loss over all batches of the epoch.

    Raises:
        ValueError: if the dataloader contains no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_pixpro: dataloader is empty, cannot compute an average loss")

    model.train()
    total_loss = 0.0
    # Task.current_task() returns None when no ClearML task is active; then only print.
    task = Task.current_task()
    logger = task.get_logger() if task is not None else None

    for batch_idx, (images, _) in enumerate(dataloader):
        # Two independently augmented views of the same batch.
        x1 = augment_fn(images.clone()).to(device)
        x2 = augment_fn(images.clone()).to(device)

        optimizer.zero_grad()
        p1, p2, y1, y2 = model(x1, x2)
        loss = pixpro_loss(p1, p2, y1, y2)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()  # a single host sync per batch
        total_loss += loss_value
        if batch_idx % log_interval == 0:
            print(f"Epoch {epoch} Batch {batch_idx}/{num_batches} Loss: {loss_value:.4f}")
            if logger is not None:
                logger.report_scalar("Training", "Loss", iteration=epoch * num_batches + batch_idx, value=loss_value)

    avg_loss = total_loss / num_batches
    print(f"=== Epoch {epoch} Average Loss: {avg_loss:.4f} ===")
    return avg_loss

In [None]:
def extract_dense_features_global(model, dataloader, device, sample_fraction=1.0):
    """
    Collect L2-normalised per-pixel backbone features for a whole dataset.

    For each batch, the last feature map of ``model.backbone`` is normalised over
    channels and reshaped from ``[B, C, H, W]`` to ``[B*H*W, C]`` (one row per
    spatial position). If ``sample_fraction < 1``, a random subset of
    ``int(rows * sample_fraction)`` rows is kept per batch.

    Returns:
        NumPy array ``[N, C]`` with the rows of all batches concatenated.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch, _labels in dataloader:
            fmap = F.normalize(model.backbone(batch.to(device))[-1], p=2, dim=1)
            n_channels = fmap.shape[1]
            # [B, C, H, W] -> [B, H, W, C] -> [B*H*W, C]
            pixels = fmap.permute(0, 2, 3, 1).reshape(-1, n_channels)
            if sample_fraction < 1.0:
                keep = int(pixels.size(0) * sample_fraction)
                pixels = pixels[torch.randperm(pixels.size(0))[:keep]]
            chunks.append(pixels.cpu().numpy())
    return np.concatenate(chunks, axis=0)

In [None]:
def global_clustering_dbscan(model, dataloader, device, eps=0.5, min_samples=5, sample_fraction=1.0, exclude_noise=True):
    """
    Run DBSCAN on the dense features of the whole dataset and score the clustering.

    The features come from ``extract_dense_features_global``. When DBSCAN finds
    at least two non-noise clusters, the following are computed:
      - silhouette score,
      - Davies-Bouldin index,
      - Calinski-Harabasz score.

    Args:
        eps, min_samples: DBSCAN parameters.
        sample_fraction: fraction of per-batch pixel features kept for clustering.
        exclude_noise: if True (default), points labelled -1 (noise) are dropped
            before scoring. The cluster count already ignores -1; previously the
            metrics still scored noise as an extra cluster.

    Returns:
        ``(sil, db_index, ch_score, cluster_labels)``. The first three are None when
        fewer than two clusters are found. ``cluster_labels`` always covers all
        features, noise included.
    """
    features = extract_dense_features_global(model, dataloader, device, sample_fraction)
    print(f"Total extracted features: {features.shape[0]}")
    cluster_labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(features)

    n_clusters = len(set(cluster_labels) - {-1})
    if n_clusters < 2:
        print("Not enough clusters (>=2 required) in global clustering.")
        return None, None, None, cluster_labels

    if exclude_noise:
        keep = cluster_labels != -1
        scored_features, scored_labels = features[keep], cluster_labels[keep]
    else:
        scored_features, scored_labels = features, cluster_labels

    sil = silhouette_score(scored_features, scored_labels)
    db_index = davies_bouldin_score(scored_features, scored_labels)
    ch_score = calinski_harabasz_score(scored_features, scored_labels)
    return sil, db_index, ch_score, cluster_labels

In [None]:
def extract_dense_features_per_image(model, image, device):
    """
    Return the L2-normalised per-pixel backbone features of one image as ``[H*W, C]``.

    The image ``[3, H, W]`` is batched to size 1 and passed through ``model.backbone``.
    The last feature map is normalised over channels, and each spatial position
    becomes one row of the returned NumPy array.
    """
    model.eval()
    with torch.no_grad():
        batch = image.unsqueeze(0).to(device)                     # [1, 3, H, W]
        fmap = F.normalize(model.backbone(batch)[-1], p=2, dim=1)  # [1, C, h, w]
        pixels = fmap[0].flatten(1).t().contiguous()               # [C, h*w] -> [h*w, C]
    return pixels.cpu().numpy()

In [None]:
# NOTE(review): this cell re-defines extract_dense_features_per_image with the same
# behaviour as the previous cell; running it silently replaces that definition.
# Keep one copy and delete the other.
def extract_dense_features_per_image(model, image, device):
    """
    Extract dense features for a single image; returns an array of shape [H*W, C].

    The last backbone feature map is L2-normalised over channels, and each spatial
    position becomes one row.
    """
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)  # [1, 3, H, W]
        feat_map = model.backbone(image)[-1]     # [1, C, H, W]
        feat_map = F.normalize(feat_map, p=2, dim=1)
        _, C, H, W = feat_map.shape
        features = feat_map.view(1, C, -1).permute(0, 2, 1).contiguous().view(-1, C)
    return features.cpu().numpy()

In [None]:
def per_image_clustering_dbscan(model, dataloader, device, eps=0.5, min_samples=5, exclude_noise=True):
    """
    Cluster each image's dense features with DBSCAN and average the clustering metrics.

    For every image, per-pixel features are clustered. Images with fewer than two
    non-noise clusters are skipped. For the rest, the silhouette score, the
    Davies-Bouldin index and the Calinski-Harabasz score are computed and then
    averaged over the scored images.

    Args:
        eps, min_samples: DBSCAN parameters.
        exclude_noise: if True (default), points labelled -1 (noise) are dropped
            before scoring. The cluster count already ignores -1; previously the
            metrics still scored noise as an extra cluster.

    Returns:
        ``(avg_sil, avg_db, avg_ch)``, or ``(None, None, None)`` if no image could be scored.
    """
    model.eval()
    sil_scores, db_scores, ch_scores = [], [], []
    image_count = 0   # images successfully scored
    image_idx = -1    # running index over all images (used in error messages)
    for images, _ in dataloader:
        for image in images:
            image_idx += 1
            features = extract_dense_features_per_image(model, image, device)
            labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(features)
            if len(set(labels) - {-1}) < 2:
                continue
            if exclude_noise:
                keep = labels != -1
                features, labels = features[keep], labels[keep]
            try:
                sil = silhouette_score(features, labels)
                db = davies_bouldin_score(features, labels)
                ch = calinski_harabasz_score(features, labels)
            except Exception as e:
                # Best effort: one degenerate image must not abort the whole validation pass.
                print(f"Error on image {image_idx}: {e}")
                continue
            sil_scores.append(sil)
            db_scores.append(db)
            ch_scores.append(ch)
            image_count += 1
    if len(sil_scores) == 0:
        print("No image produced enough clusters for per-image evaluation.")
        return None, None, None
    avg_sil = np.mean(sil_scores)
    avg_db = np.mean(db_scores)
    avg_ch = np.mean(ch_scores)
    print(f"Processed {image_count} images. Avg Silhouette: {avg_sil:.4f}, Avg Davies-Bouldin: {avg_db:.4f}, Avg Calinski-Harabasz: {avg_ch:.4f}")
    return avg_sil, avg_db, avg_ch

In [None]:
def split_dataset(source_dir, train_dir, val_dir, train_ratio=0.8, extensions=('.png', '.jpg', '.jpeg', '.bmp'), seed=None):
    """
    Randomly split the image files of ``source_dir`` into train/validation sets and copy them.

    Args:
        source_dir (str): folder with the source images.
        train_dir (str): destination folder for training images (created if missing).
        val_dir (str): destination folder for validation images (created if missing).
        train_ratio (float): fraction of images that go to the training set.
        extensions (tuple): accepted file extensions (case-insensitive).
        seed (int | None): if given, the shuffle uses ``random.Random(seed)`` and the
            split is reproducible. If None, the global ``random`` state is used, as before.
    """
    import warnings

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    # Existing files are never removed. After a re-split with a different shuffle, an
    # image could end up in both folders (train/val leakage), so warn about it.
    if os.listdir(train_dir) or os.listdir(val_dir):
        warnings.warn(f"{train_dir!r} or {val_dir!r} is not empty; existing files are kept "
                      f"and may overlap with the new split.")

    # Regular files with an accepted extension only (a directory named 'x.jpg' would
    # crash shutil.copy). Sorting makes the split depend only on the RNG, not on
    # filesystem listing order.
    all_files = sorted(
        f for f in os.listdir(source_dir)
        if f.lower().endswith(extensions) and os.path.isfile(os.path.join(source_dir, f))
    )
    print(f"Найдено {len(all_files)} изображений в папке {source_dir}.")

    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(all_files)

    split_index = int(len(all_files) * train_ratio)
    train_files = all_files[:split_index]
    val_files = all_files[split_index:]

    for filename in train_files:
        shutil.copy(os.path.join(source_dir, filename), os.path.join(train_dir, filename))

    for filename in val_files:
        shutil.copy(os.path.join(source_dir, filename), os.path.join(val_dir, filename))

    print(f"Тренировочный набор: {len(train_files)} изображений.")
    print(f"Валидационный набор: {len(val_files)} изображений.")

In [None]:
class ImageFolderDataset(torch.utils.data.Dataset):
    """
    Unlabelled dataset over the image files of a single folder.

    Regular files with an accepted extension (default .png, .jpg, .jpeg, .bmp;
    case-insensitive) are used in sorted path order. Since the data is unlabelled,
    a dummy label 0 is returned.

    Subclasses ``torch.utils.data.Dataset`` explicitly. The bare name ``Dataset``
    in this notebook is ClearML's ``Dataset`` (``from clearml import Task, Dataset``),
    which is unrelated to PyTorch datasets.
    """
    def __init__(self, root_dir, transform=None, extensions=('.png', '.jpg', '.jpeg', '.bmp')):
        """
        Args:
            root_dir (str): folder with the images.
            transform (callable, optional): transform applied to each PIL image.
            extensions (tuple): accepted file extensions (default matches the former hard-coded list).
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = sorted(
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith(extensions) and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        # The context manager closes the file handle immediately. convert() returns a
        # new, fully loaded image, so it stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        # Dummy label for unlabelled data.
        return image, 0

In [None]:
# Deterministic preprocessing used when loading images. Resize to 256x256
# (aspect ratio not preserved), centre-crop 224x224, then convert to a float
# tensor in [0, 1].
# NOTE(review): no ImageNet Normalize here, while advanced_augmentations normalises
# the training views. Validation features are therefore extracted from
# un-normalised inputs. Confirm this mismatch is intended.
_base_steps = [
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
base_transform = transforms.Compose(_base_steps)

In [None]:
# Raw images and the train/val folders produced from them by split_dataset.
DATASET_ROOT = 'dataset'
data_folder, train_folder, val_folder = (
    f'{DATASET_ROOT}/turbine',
    f'{DATASET_ROOT}/turbine_train',
    f'{DATASET_ROOT}/turbine_val',
)

# Uncomment to (re)create the train/val split from data_folder:
# split_dataset(source_dir=data_folder, train_dir=train_folder, val_dir=val_folder)

In [None]:
# Both splits use the same deterministic base_transform; augmentation happens later, per batch.
train_dataset = ImageFolderDataset(root_dir=train_folder, transform=base_transform)
val_dataset = ImageFolderDataset(root_dir=val_folder, transform=base_transform)

# Shared loader settings (batches of 32, loading in the main process); only training shuffles.
_loader_kwargs = dict(batch_size=32, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [None]:
from utils.compute_mean_std import compute_mean_std

In [None]:
# Dataset statistics over train_dataset via the project helper; the value is only displayed, not stored.
# NOTE(review): the positional 32 and 0 presumably mean batch size and number of workers,
# and the result presumably per-channel mean/std for input normalisation. Check
# utils.compute_mean_std's signature.
compute_mean_std(train_dataset, 32, 0)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs this run uses (Python random, NumPy, torch incl. CUDA) so that
# augmentations, feature subsampling and weight init are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): an earlier cell already calls Task.init(task_name='123') in this kernel.
# Check that this call yields the intended "SSL_Detection" task.
task = Task.init(project_name="SSL_Detection", task_name="PixPro Training ")
logger = task.get_logger()

model = PixPro(backbone_name='resnet18').to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Run settings. They are hard-coded here; the loaded config has related fields
# (e.g. train.lr_start, val.eps).
num_epochs = 100
val_interval = 5       # DBSCAN validation every val_interval epochs
save_interval = 10     # save and upload a checkpoint every save_interval epochs
eps = 0.5              # DBSCAN neighbourhood radius
min_samples = 5        # DBSCAN core-point threshold
sample_fraction = 0.2  # fraction of pixel features kept for global clustering

for epoch in range(1, num_epochs + 1):
    avg_loss = train_pixpro(model, train_loader, optimizer, device, epoch, augment_fn=batch_augmentations)
    # Kept apart from train_pixpro's per-batch "Training"/"Loss" scalar. That one uses
    # iteration = epoch * len(loader) + batch_idx while this one uses iteration = epoch,
    # so sharing one series interleaved two different x-axes.
    logger.report_scalar("Training_Epoch", "Avg_Loss", iteration=epoch, value=avg_loss)

    # Validation every val_interval epochs.
    if epoch % val_interval == 0:
        print(f"--- Validation at Epoch {epoch} ---")
        # Global DBSCAN clustering over the whole validation set.
        global_results = global_clustering_dbscan(model, val_loader, device, eps, min_samples, sample_fraction)
        if global_results[0] is not None:
            sil, db_index, ch_score, _ = global_results
            print(f"Global DBSCAN -> Silhouette: {sil:.4f}, Davies-Bouldin: {db_index:.4f}, Calinski-Harabasz: {ch_score:.4f}")
            logger.report_scalar("Clustering_Sil", "Global_Silhouette", iteration=epoch, value=sil)
            logger.report_scalar("Clustering_DB", "Global_Davies_Bouldin", iteration=epoch, value=db_index)
            logger.report_scalar("Clustering_CH", "Global_Calinski_Harabasz", iteration=epoch, value=ch_score)
        else:
            print("Global clustering did not produce enough clusters.")

        # Separate DBSCAN clustering for each image, metrics averaged.
        per_img_results = per_image_clustering_dbscan(model, val_loader, device, eps, min_samples)
        if per_img_results[0] is not None:
            avg_sil, avg_db, avg_ch = per_img_results
            print(f"Per-image DBSCAN -> Avg Silhouette: {avg_sil:.4f}, Avg Davies-Bouldin: {avg_db:.4f}, Avg Calinski-Harabasz: {avg_ch:.4f}")
            logger.report_scalar("Clustering_Sil", "PerImage_Silhouette", iteration=epoch, value=avg_sil)
            logger.report_scalar("Clustering_DB", "PerImage_Davies_Bouldin", iteration=epoch, value=avg_db)
            logger.report_scalar("Clustering_CH", "PerImage_Calinski_Harabasz", iteration=epoch, value=avg_ch)
        else:
            print("Per-image clustering did not produce metrics for enough images.")

    if epoch % save_interval == 0:
        path_to_ckpt = f'checkpoint_epoch_{epoch}.pth'
        torch.save(model.state_dict(), path_to_ckpt)
        # NOTE(review): the same artifact name is reused, so each upload presumably
        # replaces the previous checkpoint artifact. Confirm that keeping only the
        # latest is intended.
        task.upload_artifact(name="model_checkpoint", artifact_object=path_to_ckpt)