Установка и импорт библиотек

In [None]:
# %pip install torch torchvision torchaudio matplotlib segmentation-models-pytorch albumentations timm tqdm imagehash
import segmentation_models_pytorch.losses as smp_losses
import os
import cv2
import numpy as np
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import imagehash
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
import random
import timm
import torch.nn.functional as F


Инициализация путей

In [None]:

# Root folder of the dataset; every split directory below is derived from it.
DATASET_NAME = "camvid-dataset"

# Training images and their colour-coded annotation masks.
X_TRAIN_DIR = DATASET_NAME + "/Train"
Y_TRAIN_DIR = DATASET_NAME + "/Trainannot"

# Validation images and their annotation masks.
X_VALID_DIR = DATASET_NAME + "/Validation"
Y_VALID_DIR = DATASET_NAME + "/Validationannot"

# Test images (no annotation directory is defined for this split).
X_TEST_DIR = DATASET_NAME + "/Test"



Гиперпараметры


In [None]:
# --- Optimisation ---
NUM_EPOCHS = 50
BATCH_SIZE = 8
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0

# --- Model / data ---
# NOTE(review): despite its name this is a free-form run tag; in the visible
# cells it is only interpolated into the output file names below.
ENCODER_NAME = "50 эпох 0 веса  0.001 самое новое"
# NOTE(review): not referenced by the visible cells (the model cell passes
# pretrained=True to timm directly) - confirm before relying on it.
ENCODER_WEIGHTS = "imagenet"
NUM_CLASSES = 1
# NOTE(review): not referenced by the visible cells; the transforms resize to 256x256.
IMAGE_SIZE = 512

# Colour that marks the target class in the annotation masks. Masks are read
# with cv2.imread and compared without channel reordering, so the tuple is
# matched against OpenCV's channel order.
TARGET_COLOR = (0, 255, 255)
WHEEL_COLOR = TARGET_COLOR

# --- Output artefacts ---
MODEL_SAVE_PATH = f"models/{ENCODER_NAME}_best_unet.pth"  # model state_dict checkpoint
GRAF_PATH = f"graf/{ENCODER_NAME}_graf.pth"  # per-epoch loss / Dice history


Генератор синтетических кругов

In [None]:
def generate_circle_image(width=512, height=384):
    """Create one synthetic sample: random filled circles plus Gaussian noise.

    Between 1 and 5 filled circles with random centre, radius (30..149 px) and
    colour are drawn on a black canvas. The same circles are painted into the
    mask with the fixed label colour ``(0, 255, 255)``. Zero-mean Gaussian
    noise (sigma = 10) is then added to the image only.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        tuple: ``(image, mask)``, both ``uint8`` arrays of shape
        ``(height, width, 3)``.
    """
    image = np.zeros((height, width, 3), dtype=np.uint8)
    mask = np.zeros((height, width, 3), dtype=np.uint8)

    num_circles = np.random.randint(1, 6)
    for _ in range(num_circles):
        center = (int(np.random.randint(0, width)), int(np.random.randint(0, height)))
        radius = int(np.random.randint(30, 150))
        # Convert to plain Python ints: some OpenCV builds refuse NumPy integer
        # scalars inside a colour tuple.
        color = tuple(np.random.randint(0, 256, size=3).tolist())
        cv2.circle(image, center, radius, color, -1)
        cv2.circle(mask, center, radius, (0, 255, 255), -1)

    # Add the signed noise in floating point and clip afterwards. Casting the
    # noise itself to uint8 would wrap negative samples around to ~255 and
    # turn roughly half of the pixels into saturated speckle.
    noise = np.random.normal(0, 10, (height, width, 3))
    noisy = image.astype(np.float32) + noise
    image = np.clip(np.rint(noisy), 0, 255).astype(np.uint8)
    return image, mask

def generate_dataset(num_train=1000, num_val=200):
    """Write a synthetic train/validation dataset of circle images to disk.

    Creates the four split directories if needed, then stores ``num_train``
    training pairs and ``num_val`` validation pairs produced by
    ``generate_circle_image``. An image and its mask share the same file name
    (``train_0000.png`` / ``val_0000.png`` ...) in their respective folders.

    Args:
        num_train: Number of training samples to generate.
        num_val: Number of validation samples to generate.
    """
    # (image dir, mask dir, file-name prefix, number of samples)
    splits = (
        (X_TRAIN_DIR, Y_TRAIN_DIR, "train", num_train),
        (X_VALID_DIR, Y_VALID_DIR, "val", num_val),
    )

    for images_dir, masks_dir, _prefix, _count in splits:
        os.makedirs(images_dir, exist_ok=True)
        os.makedirs(masks_dir, exist_ok=True)

    for images_dir, masks_dir, prefix, count in splits:
        for index in range(count):
            sample_image, sample_mask = generate_circle_image()
            file_name = f"{prefix}_{index:04d}.png"
            cv2.imwrite(f"{images_dir}/{file_name}", sample_image)
            cv2.imwrite(f"{masks_dir}/{file_name}", sample_mask)

    print("готово")

Преобразование Хафа

In [None]:

def show_hough_filled(img_bgr, circles,
                      fill_color=(0, 0, 255),
                      alpha=0.35):
    """Render detected circles as semi-transparent filled discs.

    Args:
        img_bgr: Source image; it is not modified (two copies are drawn on).
        circles: Array whose rows unpack as ``(x, y, r)``, or ``None`` when
            nothing was detected. Values are cast to ``int`` before drawing.
        fill_color: Colour used for both the disc fill and its outline.
        alpha: Blend weight of the filled layer; the outlined layer gets
            ``1 - alpha``.

    Returns:
        The blended image.
    """
    filled_layer = img_bgr.copy()
    outline_layer = img_bgr.copy()

    detections = [] if circles is None else circles.astype(int)
    for cx, cy, radius in detections:
        centre = (cx, cy)
        cv2.circle(filled_layer, centre, radius, fill_color, -1)   # solid disc
        cv2.circle(outline_layer, centre, radius, fill_color, 2)   # 2 px outline
        cv2.circle(outline_layer, centre, 2, (0, 255, 0), -1)      # centre marker

    return cv2.addWeighted(filled_layer, alpha, outline_layer, 1 - alpha, 0)


# Classical baseline: run the Hough circle transform on every test image and
# display the detections as filled overlays.
for fname in tqdm(sorted(os.listdir(X_TEST_DIR))):
    if not fname.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue

    path = os.path.join(X_TEST_DIR, fname)
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        # cv2.imread signals an unreadable/corrupt file by returning None
        # instead of raising; skip it rather than crash inside cvtColor.
        print(f"Skipping unreadable image: {path}")
        continue

    # Median blur suppresses noise before the edge-based Hough transform.
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    gray = cv2.medianBlur(gray, 5)

    circles = cv2.HoughCircles(
        gray, cv2.HOUGH_GRADIENT,
        dp=1.2,
        minDist=30,
        param1=400,
        param2=40,
        minRadius=15, maxRadius=200)

    # HoughCircles returns None when nothing is found; otherwise the
    # detections are in the first element of the returned array.
    vis = show_hough_filled(
        img_bgr,
        circles[0] if circles is not None else None,
        fill_color=(0, 0, 255),
        alpha=0.35)

    fig = plt.figure(figsize=(6, 4))
    plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
    plt.axis("off")
    plt.show()
    # Release the figure explicitly so a long test folder does not pile up
    # open figures in memory.
    plt.close(fig)


Класс датасета (для DataLoader)

In [None]:
class WheelDataset(Dataset):
    """Image / binary-mask pairs for single-class segmentation.

    Images are every ``.jpg`` / ``.jpeg`` / ``.png`` file in ``images_dir``
    (sorted by path). The mask of an image is looked up in ``masks_dir`` under
    the same base name with a ``.png`` extension. A pixel belongs to the target
    class when its mask colour equals ``target_color`` exactly.

    Args:
        images_dir: Directory with the input images.
        masks_dir: Directory with the colour-coded masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries; the
            returned mask must support ``.unsqueeze(0)``.
        target_color: Colour of the target class. The mask is compared exactly
            as ``cv2.imread`` returns it (no channel reordering), so the tuple
            must be given in OpenCV's channel order.
    """

    def __init__(self, images_dir, masks_dir, transform=None, target_color=(0, 255, 255)):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform
        self.target_color = np.array(target_color)

        self.images_fps = sorted([
            os.path.join(images_dir, fname)
            for fname in os.listdir(images_dir)
            if fname.lower().endswith(('.jpg', '.jpeg', '.png'))
        ])

        # Masks are always expected as PNG, whatever the image extension is.
        self.masks_fps = [
            os.path.join(masks_dir, os.path.basename(img).rsplit('.', 1)[0] + '.png')
            for img in self.images_fps
        ]

    def __len__(self):
        """Return the number of images found in ``images_dir``."""
        return len(self.images_fps)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, binary_mask)``. With a transform, the transformed
            image and the transformed mask with a leading channel axis added.
            Without one, the RGB image array and a 2-D ``float32`` mask.

        Raises:
            FileNotFoundError: If the image cannot be read.
            ValueError: If the mask file exists but cannot be decoded.
        """
        image_path = self.images_fps[idx]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising; fail with the path
            # in the message rather than with an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_path = self.masks_fps[idx]
        if os.path.exists(mask_path):
            mask = cv2.imread(mask_path)
            if mask is None:
                # Comparing None with target_color would silently yield a
                # scalar instead of a per-pixel mask.
                raise ValueError(f"Could not decode mask: {mask_path}")
            binary_mask = np.all(mask == self.target_color, axis=-1).astype('float32')
        else:
            # A missing mask file is treated as "no target pixels in this image".
            binary_mask = np.zeros(image.shape[:2], dtype='float32')

        if self.transform:
            augmented = self.transform(image=image, mask=binary_mask)
            image = augmented['image']
            binary_mask = augmented['mask'].unsqueeze(0)

        return image, binary_mask


Аугментация

In [None]:
# Shared preprocessing settings: network input size and the standard ImageNet
# channel mean / std used for normalisation.
_INPUT_HW = (256, 256)
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, random augmentations, then normalise and convert
# to a tensor.
train_transform = A.Compose([
    A.Resize(*_INPUT_HW),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.GaussNoise(p=0.3),
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.5),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
])

# Validation pipeline: the same deterministic steps, without augmentation.
val_transform = A.Compose([
    A.Resize(*_INPUT_HW),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
])


Загружаем батчи

In [None]:
# Training split: augmented samples, reshuffled every epoch.
train_dataset = WheelDataset(X_TRAIN_DIR, Y_TRAIN_DIR, train_transform, WHEEL_COLOR)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# Validation split: deterministic preprocessing, fixed order.
val_dataset = WheelDataset(X_VALID_DIR, Y_VALID_DIR, val_transform, WHEEL_COLOR)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)


Dice: функция потерь и метрика

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on raw logits.

    The logits go through a sigmoid and the Dice coefficient is evaluated over
    all elements of the batch at once (one global sum, not a per-sample mean).

    Args:
        smooth: Constant added to numerator and denominator.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """Return ``1 - dice`` for logits ``preds`` against ``targets``."""
        probs = torch.sigmoid(preds)
        overlap = torch.sum(probs * targets)
        denominator = torch.sum(probs) + torch.sum(targets) + self.smooth
        dice_coeff = (2.0 * overlap + self.smooth) / denominator
        return 1.0 - dice_coeff

def dice_score(preds, targets, threshold=0.5):
    """Hard Dice coefficient between thresholded predictions and targets.

    Args:
        preds: Raw logits; a sigmoid is applied before thresholding.
        targets: Ground-truth mask; it is binarised with the same ``threshold``.
        threshold: Cut-off applied to both probabilities and targets.

    Returns:
        Scalar tensor with the Dice coefficient, computed over all elements of
        the batch at once.
    """
    eps = 1e-6
    pred_mask = torch.sigmoid(preds) > threshold
    target_mask = targets > threshold
    intersection = (pred_mask & target_mask).float().sum()
    total = pred_mask.sum() + target_mask.sum()
    # eps is added to the numerator as well, so an empty prediction for an
    # empty target scores 1 (a correct result) instead of 0. Non-empty cases
    # change only by a negligible ~1e-6.
    return (2.0 * intersection + eps) / (total + eps)

Архитектура модели и энкодер

In [None]:

class ConvBlock(nn.Module):
    """Two consecutive ``3x3 conv -> BatchNorm -> ReLU`` stages.

    Padding of 1 keeps the spatial size; the first stage maps ``in_ch`` to
    ``out_ch`` channels and the second keeps ``out_ch``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Layer order (and therefore the state_dict keys block.0 ... block.4)
        # is part of the checkpoint format.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both convolution stages to ``x``."""
        out = self.block(x)
        return out

class AttentionGate(nn.Module):
    """Additive attention gate that re-weights a skip connection.

    Both inputs are projected to ``inter_channels`` with 1x1 convolutions,
    summed, passed through ReLU, reduced to a single channel and squashed with
    a sigmoid. The resulting map multiplies ``x``.

    Args:
        g_channels: Channels of the gating signal ``g``.
        x_channels: Channels of the skip feature map ``x``.
        inter_channels: Channels of the intermediate projection.
    """

    def __init__(self, g_channels, x_channels, inter_channels):
        super().__init__()
        self.W_g = nn.Conv2d(g_channels, inter_channels, kernel_size=1)
        self.W_x = nn.Conv2d(x_channels, inter_channels, kernel_size=1)
        self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, g, x):
        """Return ``x`` scaled by the attention map derived from ``g`` and ``x``.

        ``g`` and ``x`` must have matching spatial size, since their
        projections are added element-wise.
        """
        joint = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.sigmoid(self.psi(joint))
        # The single-channel map broadcasts over all channels of x.
        return x * attention

class UNetEfficientNetB4(nn.Module):
    """U-Net style decoder with attention gates on a timm EfficientNet-B4 encoder.

    The encoder is created with ``features_only=True``; its five deepest
    feature maps are used. The deepest one goes through a ``ConvBlock``
    bottleneck, then four decoder stages each upsample by 2, gate the matching
    encoder feature with an ``AttentionGate``, concatenate and refine with a
    ``ConvBlock``. A 1x1 convolution produces the logits, which are resized
    back to the input resolution.

    Args:
        num_classes: Number of output channels (logits).
        pretrained: Forwarded to ``timm.create_model``. Pass ``False`` to skip
            loading pretrained encoder weights, e.g. when a full checkpoint is
            loaded right afterwards.
    """

    def __init__(self, num_classes=1, pretrained=True):
        super().__init__()
        self.encoder = timm.create_model("efficientnet_b4", pretrained=pretrained, features_only=True)
        enc_channels = self.encoder.feature_info.channels()

        # NOTE: attribute names and creation order define the state_dict
        # layout used by saved checkpoints - do not rename or reorder.
        self.center = ConvBlock(enc_channels[-1], 512)

        self.attn4 = AttentionGate(256, enc_channels[-2], enc_channels[-2] // 2)
        self.attn3 = AttentionGate(128, enc_channels[-3], enc_channels[-3] // 2)
        self.attn2 = AttentionGate(64, enc_channels[-4], enc_channels[-4] // 2)
        self.attn1 = AttentionGate(32, enc_channels[-5], enc_channels[-5] // 2)

        self.up4 = self._up_block(512, 256)
        self.up3 = self._up_block(256, 128)
        self.up2 = self._up_block(128, 64)
        self.up1 = self._up_block(64, 32)

        self.conv_block4 = ConvBlock(256 + enc_channels[-2], 256)
        self.conv_block3 = ConvBlock(128 + enc_channels[-3], 128)
        self.conv_block2 = ConvBlock(64 + enc_channels[-4], 64)
        self.conv_block1 = ConvBlock(32 + enc_channels[-5], 32)

        self.final_conv = nn.Conv2d(32, num_classes, kernel_size=1)

    def _up_block(self, in_ch, out_ch):
        """Build a ``2x bilinear upsample -> 3x3 conv -> BatchNorm -> ReLU`` block."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def _decode_stage(self, x, skip, up, attn, conv):
        """Run one decoder stage: upsample, gate the skip feature, fuse."""
        x = up(x)
        if x.shape[2:] != skip.shape[2:]:
            # When the input size is not a multiple of the encoder stride the
            # 2x upsample can be off by a pixel from the skip feature; align
            # it so the attention sum and the concatenation below stay valid.
            # This is a no-op whenever the sizes already match.
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        gated_skip = attn(x, skip)
        return conv(torch.cat([x, gated_skip], dim=1))

    def forward(self, x):
        """Return logits with the same spatial size as ``x``."""
        input_size = x.shape[2:]
        features = self.encoder(x)

        x = self.center(features[-1])

        x = self._decode_stage(x, features[-2], self.up4, self.attn4, self.conv_block4)
        x = self._decode_stage(x, features[-3], self.up3, self.attn3, self.conv_block3)
        x = self._decode_stage(x, features[-4], self.up2, self.attn2, self.conv_block2)
        x = self._decode_stage(x, features[-5], self.up1, self.attn1, self.conv_block1)

        x = self.final_conv(x)
        # The shallowest encoder feature used is coarser than the input, so
        # resize the logits back to the input resolution.
        return F.interpolate(x, size=input_size, mode='bilinear', align_corners=False)

Инициализация модели

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook does not crash on CUDA-less machines.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetEfficientNetB4(num_classes=NUM_CLASSES).to(DEVICE)

Обучение (по умолчанию отключено; чтобы запустить, включите код ячейки)

In [None]:
# Training is expensive, so it is opt-in: set RUN_TRAINING = True to (re)train.
# With the default False this cell does nothing and the later cells work from
# the previously saved history and checkpoint.
RUN_TRAINING = False

if RUN_TRAINING:
    # torch.save does not create missing parent directories.
    for out_path in (MODEL_SAVE_PATH, GRAF_PATH):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Train on whatever device the model already lives on.
    device = next(model.parameters()).device

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # mode='max' because the scheduler is stepped with the validation Dice,
    # which should increase.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)
    dice_loss = DiceLoss()
    best_val_dice = 0
    train_losses, val_losses = [], []
    train_dices, val_dices = [], []

    start_time = time.time()

    for epoch in range(1, NUM_EPOCHS + 1):
        epoch_start = time.time()
        print(f"\nЭпоха {epoch}/{NUM_EPOCHS}")

        # ---- training pass ----
        model.train()
        total_loss, total_dice = 0, 0
        for images, masks in tqdm(train_loader):
            images, masks = images.to(device), masks.to(device)
            preds = model(images)
            loss = dice_loss(preds, masks)
            dsc = dice_score(preds, masks)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_dice += dsc.item()
        avg_train_loss = total_loss / len(train_loader)
        avg_train_dice = total_dice / len(train_loader)
        train_losses.append(avg_train_loss)
        train_dices.append(avg_train_dice)
        print(f"Тренеровочный Loss: {avg_train_loss:.4f}, Dice: {avg_train_dice:.4f}")

        # ---- validation pass ----
        model.eval()
        val_loss, val_dice = 0, 0
        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                preds = model(images)
                loss = dice_loss(preds, masks)
                dsc = dice_score(preds, masks)
                val_loss += loss.item()
                val_dice += dsc.item()
        avg_val_loss = val_loss / len(val_loader)
        avg_val_dice = val_dice / len(val_loader)
        val_losses.append(avg_val_loss)
        val_dices.append(avg_val_dice)
        print(f"Проверочный   Loss: {avg_val_loss:.4f}, Dice: {avg_val_dice:.4f}")

        scheduler.step(avg_val_dice)
        epoch_time = time.time() - epoch_start
        # Remaining-time estimate based on the duration of the last epoch.
        eta = epoch_time * (NUM_EPOCHS - epoch)
        print(f"Время: {epoch_time:.1f}s | Осталость: {eta/60:.1f} min")

        # Persist the history after every epoch so an interrupted run can
        # still be plotted.
        torch.save({
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_dices': train_dices,
            'val_dices': val_dices
        }, GRAF_PATH)

        # Keep only the checkpoint with the best validation Dice.
        if avg_val_dice > best_val_dice:
            best_val_dice = avg_val_dice
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
            print("Эту сохранил")

    total_time = time.time() - start_time
    print(f"\nЗавершил {total_time/60:.2f} минут")

Графики Loss и Dice по эпохам

In [None]:


# Load the training history from GRAF_PATH as defined in the hyperparameter
# cell. It was previously re-declared here with a backslash-separated literal,
# which only resolves on Windows and could drift from the configured run name.
data = torch.load(GRAF_PATH)

train_losses = data['train_losses']
val_losses = data['val_losses']
train_dices = data['train_dices']
val_dices = data['val_dices']
print(f"Последние значения:")
print(f"Train Loss: {train_losses[-1]}")
print(f"Val Loss:   {val_losses[-1]}")
print(f"Train Dice: {train_dices[-1]}")
print(f"Val Dice:   {val_dices[-1]}")

# Left panel: loss curves; right panel: Dice curves.
fig, (ax_loss, ax_dice) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(train_losses, label='Обучение')
ax_loss.plot(val_losses, label='Валидация')
ax_loss.set_title('Функция Loss по эпохам')
ax_loss.set_xlabel('Эпоха')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_dice.plot(train_dices, label='Обучение')
ax_dice.plot(val_dices, label='Валидация')
ax_dice.set_title('Dice по эпохам')
ax_dice.set_xlabel('Эпоха')
ax_dice.set_ylabel('Dice')
ax_dice.legend()

fig.tight_layout()
plt.show()


Предсказание на тестовых изображениях

In [None]:


# Map the checkpoint onto the device the model currently lives on; without
# map_location a checkpoint saved from CUDA cannot be loaded on a CPU-only host.
_model_device = next(model.parameters()).device
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=_model_device))
model.eval()

# Inference preprocessing: same resize and normalisation as validation, no
# augmentation.
transform = A.Compose([
    A.Resize(256, 256),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

def visualize_prediction(img_path, threshold=0.6):
    """Predict a mask for one image and show original, mask and overlay.

    Uses the notebook-level ``model`` and ``transform``.

    Args:
        img_path: Path of the image to segment.
        threshold: Probability above which a pixel is counted as foreground.
    """
    # Close the file handle deterministically instead of leaving it to the GC.
    with Image.open(img_path) as pil_img:
        orig_img = np.array(pil_img.convert("RGB"))

    # Run on whatever device the model lives on (works with or without CUDA).
    device = next(model.parameters()).device
    tensor = transform(image=orig_img)['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(tensor)
    pred_mask = (torch.sigmoid(logits).cpu().squeeze().numpy() > threshold).astype(np.uint8)
    # cv2.resize expects (width, height); bring the mask back to the size of
    # the original image.
    pred_mask_resized = cv2.resize(pred_mask, (orig_img.shape[1], orig_img.shape[0]))

    # Paint predicted pixels red, then blend 50/50 with the original.
    overlay_img = orig_img.copy()
    overlay_img[pred_mask_resized == 1] = [255, 0, 0]
    overlay_img = cv2.addWeighted(orig_img, 0.5, overlay_img, 0.5, 0)

    fig, (ax_orig, ax_mask, ax_overlay) = plt.subplots(1, 3, figsize=(12, 4))

    ax_orig.imshow(orig_img)
    ax_orig.axis('off')

    ax_mask.imshow(pred_mask_resized, cmap="gray")
    ax_mask.axis('off')

    ax_overlay.imshow(overlay_img)
    ax_overlay.axis('off')

    fig.tight_layout()
    plt.show()
    # Free the figure; this function is called in a loop.
    plt.close(fig)

# Filter by extension BEFORE taking the first 20 entries; slicing first would
# let non-image files use up slots and show fewer than 20 predictions.
test_image_files = [
    name for name in sorted(os.listdir(X_TEST_DIR))
    if name.lower().endswith((".png", ".jpg", ".jpeg"))
]

for img_file in test_image_files[:20]:
    img_path = os.path.join(X_TEST_DIR, img_file)
    visualize_prediction(img_path)
