Выбор датасета.

Мой датасет представляет собой сегментацию медицинской маски. Еще недавно придумывали различные модели для сегментации и детектирования масок во время короновируса.

In [1]:
!curl -L -o qw.zip https://www.kaggle.com/api/v1/datasets/download/perke986/face-mask-segmentation-dataset
!unzip qw.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100 78.9M  100 78.9M    0     0  17.9M      0  0:00:04  0:00:04 --:--:-- 28.2M
Archive:  qw.zip
  inflating: Info.txt                
  inflating: images/coronavirus-4947340_1920.jpg  
  inflating: images/coronavirus-5064371_1920.jpg  
  inflating: images/mask-4898571_1920.jpg  
  inflating: images/mask-5136259_1920.jpg  
  inflating: images/mouth-guard-5060809_1920.jpg  
  inflating: images/mouth-guard-5068146_1920.jpg  
  inflating: images/nurse-4962034_1920.jpg  
  inflating: images/pexels-andrea-piacquadio-3881247.jpg  
  inflating: images/pexels-anna-shvets-3902881.jpg  
  inflating: images/pexels-anna-shvets-3943881.jpg  
  inflating: images/pexels-anna-shvets-3943882.jpg  
  inflating: images/pexels-anna-shvets-3943883.jpg  
  inflating: images/p

Загрука данных

In [12]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from segmentation_models_pytorch import Unet, Segformer
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import cv2

class FaceMaskDataset(Dataset):
    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        self.transform = transform

        image_files = os.listdir(images_dir)
        self.image_names = [f for f in image_files if os.path.isfile(os.path.join(masks_dir, os.path.splitext(f)[0] + ".png"))]

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        base_name = os.path.splitext(img_name)[0]

        img_path = os.path.join(self.images_dir, img_name)
        mask_path = os.path.join(self.masks_dir, base_name + ".png")

        image = np.array(Image.open(img_path).convert("RGB"))
        mask = mask_rgb_to_class(Image.open(mask_path).convert("RGB"))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask.long()

def mask_rgb_to_class(mask):
    mask_array = np.array(mask)
    class_mask = np.zeros(mask_array.shape[:2], dtype=np.uint8)

    for rgb, class_idx in MASK_COLORS.items():
        matches = np.all(mask_array == rgb, axis=-1)
        class_mask[matches] = class_idx
    return class_mask

Аугментация данных

In [7]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
EPOCHS = 5
NUM_CLASSES = 2
LEARNING_RATE = 1e-4
IMAGE_SIZE = (224, 224)

MASK_COLORS = {
    (255, 255, 255): 1,
    (253, 237, 237): 1,
    (252, 219, 219): 1,
}
train_transform = A.Compose([
    A.Resize(*IMAGE_SIZE),
    A.HorizontalFlip(),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(*IMAGE_SIZE),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])


Загрузка встроенных моделей

In [8]:
dataset = FaceMaskDataset("images", "masks", transform=train_transform)
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

val_dataset.dataset.transform = val_transform

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

unet_model = Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=NUM_CLASSES,
    activation=None
).to(DEVICE)

segformer_model = Segformer(
    encoder_name="mit_b0",
    encoder_weights="imagenet",
    in_channels=3,
    classes=NUM_CLASSES,
    activation=None
).to(DEVICE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Функция тренировки и оценки

In [9]:
def train_epoch(model, loader, loss_fn, optimizer):
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(loader, desc="Training"):
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)

        outputs = model(images)
        loss = loss_fn(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / len(loader)

def evaluate(model, loader):
    model.eval()
    ious = []
    dices = []
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Evaluation"):
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)

            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)

            intersection = (preds & masks).float().sum((1,2))
            union = (preds | masks).float().sum((1,2))
            iou = (intersection + 1e-6) / (union + 1e-6)
            ious.append(iou.cpu().numpy())

            dice = (2*intersection + 1e-6) / (preds.sum((1,2)) + masks.sum((1,2)) + 1e-6)
            dices.append(dice.cpu().numpy())

    return np.mean(np.concatenate(ious)), np.mean(np.concatenate(dices))


Тренировка моделей

In [10]:
optimizer = Adam(unet_model.parameters(), lr=LEARNING_RATE)
loss_fn = CrossEntropyLoss()

for epoch in range(EPOCHS):
    print(f"UNet Epoch {epoch+1}/{EPOCHS}")
    train_loss = train_epoch(unet_model, train_loader, loss_fn, optimizer)
    iou, dice = evaluate(unet_model, val_loader)
    print(f"Loss: {train_loss:.4f}, IoU: {iou:.4f}, Dice: {dice:.4f}")

optimizer = Adam(segformer_model.parameters(), lr=LEARNING_RATE)
for epoch in range(EPOCHS):
    print(f"SegFormer Epoch {epoch+1}/{EPOCHS}")
    train_loss = train_epoch(segformer_model, train_loader, loss_fn, optimizer)
    iou, dice = evaluate(segformer_model, val_loader)
    print(f"Loss: {train_loss:.4f}, IoU: {iou:.4f}, Dice: {dice:.4f}")

print("Evaluation:")
for model_name, model in [("UNet", unet_model), ("SegFormer", segformer_model)]:
    iou, dice = evaluate(model, val_loader)
    print(f"{model_name}: IoU={iou:.4f}, Dice={dice:.4f}")


UNet Epoch 1/5


Training: 100%|██████████| 45/45 [01:04<00:00,  1.44s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.28s/it]


Loss: 0.3494, IoU: 0.2047, Dice: 0.2917
UNet Epoch 2/5


Training: 100%|██████████| 45/45 [00:59<00:00,  1.33s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.24s/it]


Loss: 0.1684, IoU: 0.3868, Dice: 0.4875
UNet Epoch 3/5


Training: 100%|██████████| 45/45 [00:59<00:00,  1.33s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.25s/it]


Loss: 0.1039, IoU: 0.5139, Dice: 0.6261
UNet Epoch 4/5


Training: 100%|██████████| 45/45 [01:00<00:00,  1.33s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Loss: 0.0743, IoU: 0.5909, Dice: 0.6937
UNet Epoch 5/5


Training: 100%|██████████| 45/45 [00:59<00:00,  1.32s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Loss: 0.0582, IoU: 0.6172, Dice: 0.7191
SegFormer Epoch 1/5


Training: 100%|██████████| 45/45 [00:58<00:00,  1.30s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.24s/it]


Loss: 0.3072, IoU: 0.2545, Dice: 0.3304
SegFormer Epoch 2/5


Training: 100%|██████████| 45/45 [00:58<00:00,  1.31s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.24s/it]


Loss: 0.1329, IoU: 0.3487, Dice: 0.4297
SegFormer Epoch 3/5


Training: 100%|██████████| 45/45 [00:58<00:00,  1.31s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.27s/it]


Loss: 0.0782, IoU: 0.4772, Dice: 0.5790
SegFormer Epoch 4/5


Training: 100%|██████████| 45/45 [00:59<00:00,  1.32s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.26s/it]


Loss: 0.0552, IoU: 0.5555, Dice: 0.6553
SegFormer Epoch 5/5


Training: 100%|██████████| 45/45 [00:58<00:00,  1.31s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Loss: 0.0417, IoU: 0.5866, Dice: 0.6836
Evaluation:


Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.26s/it]


UNet: IoU=0.6172, Dice=0.7191


Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.25s/it]

SegFormer: IoU=0.5866, Dice=0.6836





Улучшение бейзлайна

In [13]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

IMG_SIZE = 256

train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.ElasticTransform(0.5),
    A.ElasticTransform(p=0.2),
    A.Normalize(),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(),
    ToTensorV2(),
])



optimizer = Adam(unet_model.parameters(), lr=LEARNING_RATE)
loss = CrossEntropyLoss()


for epoch in range(EPOCHS):
    print(f"UNet Epoch {epoch+1}/{EPOCHS}")
    train_loss = train_epoch(unet_model, train_loader, loss_fn, optimizer)
    iou, dice = evaluate(unet_model, val_loader)
    print(f"Loss: {train_loss:.4f}, IoU: {iou:.4f}, Dice: {dice:.4f}")

optimizer = Adam(segformer_model.parameters(), lr=LEARNING_RATE)
for epoch in range(EPOCHS):
    print(f"SegFormer Epoch {epoch+1}/{EPOCHS}")
    train_loss = train_epoch(segformer_model, train_loader, loss_fn, optimizer)
    iou, dice = evaluate(segformer_model, val_loader)
    print(f"Loss: {train_loss:.4f}, IoU: {iou:.4f}, Dice: {dice:.4f}")

print("Evaluation:")
for model_name, model in [("UNet", unet_model), ("SegFormer", segformer_model)]:
    iou, dice = evaluate(model, val_loader)
    print(f"{model_name}: IoU={iou:.4f}, Dice={dice:.4f}")


UNet Epoch 1/5


Training: 100%|██████████| 45/45 [01:00<00:00,  1.35s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.26s/it]


Loss: 0.0429, IoU: 0.6349, Dice: 0.7369
UNet Epoch 2/5


Training: 100%|██████████| 45/45 [01:05<00:00,  1.45s/it]
Evaluation: 100%|██████████| 11/11 [00:16<00:00,  1.52s/it]


Loss: 0.0342, IoU: 0.6917, Dice: 0.7831
UNet Epoch 3/5


Training: 100%|██████████| 45/45 [01:05<00:00,  1.46s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.26s/it]


Loss: 0.0235, IoU: 0.7480, Dice: 0.8229
UNet Epoch 4/5


Training: 100%|██████████| 45/45 [01:00<00:00,  1.35s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.32s/it]


Loss: 0.0171, IoU: 0.7770, Dice: 0.8437
UNet Epoch 5/5


Training: 100%|██████████| 45/45 [01:01<00:00,  1.36s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.27s/it]


Loss: 0.0136, IoU: 0.7746, Dice: 0.8432
SegFormer Epoch 1/5


Training: 100%|██████████| 45/45 [00:59<00:00,  1.32s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.26s/it]


Loss: 0.0342, IoU: 0.6011, Dice: 0.7016
SegFormer Epoch 2/5


Training: 100%|██████████| 45/45 [00:58<00:00,  1.30s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Loss: 0.0228, IoU: 0.6002, Dice: 0.6952
SegFormer Epoch 3/5


Training: 100%|██████████| 45/45 [00:58<00:00,  1.30s/it]
Evaluation: 100%|██████████| 11/11 [00:15<00:00,  1.41s/it]


Loss: 0.0177, IoU: 0.6771, Dice: 0.7749
SegFormer Epoch 4/5


Training: 100%|██████████| 45/45 [01:02<00:00,  1.39s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.32s/it]


Loss: 0.0140, IoU: 0.6536, Dice: 0.7525
SegFormer Epoch 5/5


Training: 100%|██████████| 45/45 [00:58<00:00,  1.30s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.28s/it]


Loss: 0.0126, IoU: 0.6996, Dice: 0.7916
Evaluation:


Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


UNet: IoU=0.7746, Dice=0.8432


Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.29s/it]

SegFormer: IoU=0.6996, Dice=0.7916





Собственная имплементация

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class MyUNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=2, features=[64, 128, 256, 512]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], num_classes, kernel_size=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])
            x = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](x)

        return self.final_conv(x)


In [30]:
class MySegformer(nn.Module):
    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, num_classes, kernel_size=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [31]:
unet_model = MyUNet(in_channels=3, num_classes=NUM_CLASSES).to(DEVICE)
segformer_model = MySegformer(in_channels=3, num_classes=NUM_CLASSES).to(DEVICE)
optimizer = Adam(segformer_model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    print(f"UNet Epoch {epoch+1}/{EPOCHS}")
    train_loss = train_epoch(unet_model, train_loader, loss_fn, optimizer)
    iou, dice = evaluate(unet_model, val_loader)
    print(f"Loss: {train_loss:.4f}, IoU: {iou:.4f}, Dice: {dice:.4f}")

for epoch in range(EPOCHS):
    print(f"SegFormer Epoch {epoch+1}/{EPOCHS}")
    train_loss = train_epoch(segformer_model, train_loader, loss_fn, optimizer)
    iou, dice = evaluate(segformer_model, val_loader)
    print(f"Loss: {train_loss:.4f}, IoU: {iou:.4f}, Dice: {dice:.4f}")

print("Evaluation:")
for model_name, model in [("UNet", unet_model), ("SegFormer", segformer_model)]:
    iou, dice = evaluate(model, val_loader)
    print(f"{model_name}: IoU={iou:.4f}, Dice={dice:.4f}")

UNet Epoch 1/5


Training: 100%|██████████| 45/45 [01:05<00:00,  1.45s/it]
Evaluation: 100%|██████████| 11/11 [00:15<00:00,  1.37s/it]


Loss: 0.5837, IoU: 0.0370, Dice: 0.0679
UNet Epoch 2/5


Training: 100%|██████████| 45/45 [01:03<00:00,  1.41s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Loss: 0.5848, IoU: 0.0331, Dice: 0.0615
UNet Epoch 3/5


Training: 100%|██████████| 45/45 [01:03<00:00,  1.42s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Loss: 0.5838, IoU: 0.0329, Dice: 0.0613
UNet Epoch 4/5


Training: 100%|██████████| 45/45 [01:03<00:00,  1.41s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Loss: 0.5841, IoU: 0.0339, Dice: 0.0629
UNet Epoch 5/5


Training: 100%|██████████| 45/45 [01:04<00:00,  1.44s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.34s/it]


Loss: 0.5834, IoU: 0.0335, Dice: 0.0622
SegFormer Epoch 1/5


Training: 100%|██████████| 45/45 [00:56<00:00,  1.25s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.24s/it]


Loss: 0.6565, IoU: 0.0017, Dice: 0.0033
SegFormer Epoch 2/5


Training: 100%|██████████| 45/45 [00:57<00:00,  1.27s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.23s/it]


Loss: 0.5899, IoU: 0.0000, Dice: 0.0000
SegFormer Epoch 3/5


Training: 100%|██████████| 45/45 [00:55<00:00,  1.24s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.25s/it]


Loss: 0.4708, IoU: 0.0000, Dice: 0.0000
SegFormer Epoch 4/5


Training: 100%|██████████| 45/45 [00:56<00:00,  1.25s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.22s/it]


Loss: 0.3341, IoU: 0.0000, Dice: 0.0000
SegFormer Epoch 5/5


Training: 100%|██████████| 45/45 [00:56<00:00,  1.25s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.29s/it]


Loss: 0.2320, IoU: 0.0000, Dice: 0.0000
Evaluation:


Evaluation: 100%|██████████| 11/11 [00:15<00:00,  1.41s/it]


UNet: IoU=0.0335, Dice=0.0622


Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.30s/it]

SegFormer: IoU=0.0000, Dice=0.0000





Собственная имплементаация с улучшенным бейзлайном

In [32]:
class MyUNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=2, features=[64, 128, 256, 512]):
        super(MyUNet, self).__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.dropout = nn.Dropout(0.3)

        for feature in reversed(features):
            self.ups.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(feature*2, feature))

        self.final_conv = nn.Conv2d(features[0], num_classes, kernel_size=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        x = self.dropout(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])

            x = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](x)

        return self.final_conv(x)


class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, embed_dim=64, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return self.proj(x)

class MLPDecoder(nn.Module):
    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, num_classes, kernel_size=1)
        )

    def forward(self, x):
        return self.decoder(x)

class MySegformer(nn.Module):
    def __init__(self, in_channels=3, num_classes=2, embed_dim=64):
        super().__init__()
        self.patch_embed = PatchEmbedding(in_channels, embed_dim=embed_dim, patch_size=4)

        self.encoder = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(embed_dim),
            nn.ReLU(inplace=True),
        )

        self.decoder_head = MLPDecoder(embed_dim, num_classes)
        self.upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.encoder(x)
        x = self.decoder_head(x)
        x = self.upsample(x)
        return x

unet_model = MyUNet(in_channels=3, num_classes=NUM_CLASSES).to(DEVICE)
segformer_model = MySegformer(in_channels=3, num_classes=NUM_CLASSES).to(DEVICE)
optimizer = Adam(segformer_model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    print(f"UNet Epoch {epoch+1}/{EPOCHS}")
    train_loss = train_epoch(unet_model, train_loader, loss_fn, optimizer)
    iou, dice = evaluate(unet_model, val_loader)
    print(f"Loss: {train_loss:.4f}, IoU: {iou:.4f}, Dice: {dice:.4f}")

for epoch in range(EPOCHS):
    print(f"SegFormer Epoch {epoch+1}/{EPOCHS}")
    train_loss = train_epoch(segformer_model, train_loader, loss_fn, optimizer)
    iou, dice = evaluate(segformer_model, val_loader)
    print(f"Loss: {train_loss:.4f}, IoU: {iou:.4f}, Dice: {dice:.4f}")

print("Evaluation:")
for model_name, model in [("UNet", unet_model), ("SegFormer", segformer_model)]:
    iou, dice = evaluate(model, val_loader)
    print(f"{model_name}: IoU={iou:.4f}, Dice={dice:.4f}")

UNet Epoch 1/5


Training: 100%|██████████| 45/45 [01:11<00:00,  1.59s/it]
Evaluation: 100%|██████████| 11/11 [00:15<00:00,  1.44s/it]


Loss: 0.7223, IoU: 0.0152, Dice: 0.0295
UNet Epoch 2/5


Training: 100%|██████████| 45/45 [01:04<00:00,  1.44s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.33s/it]


Loss: 0.7231, IoU: 0.0156, Dice: 0.0302
UNet Epoch 3/5


Training: 100%|██████████| 45/45 [01:09<00:00,  1.55s/it]
Evaluation: 100%|██████████| 11/11 [00:15<00:00,  1.41s/it]


Loss: 0.7235, IoU: 0.0149, Dice: 0.0289
UNet Epoch 4/5


Training: 100%|██████████| 45/45 [01:05<00:00,  1.45s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.30s/it]


Loss: 0.7220, IoU: 0.0155, Dice: 0.0301
UNet Epoch 5/5


Training: 100%|██████████| 45/45 [01:05<00:00,  1.46s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.31s/it]


Loss: 0.7220, IoU: 0.0154, Dice: 0.0298
SegFormer Epoch 1/5


Training: 100%|██████████| 45/45 [00:55<00:00,  1.23s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.22s/it]


Loss: 0.3076, IoU: 0.0000, Dice: 0.0000
SegFormer Epoch 2/5


Training: 100%|██████████| 45/45 [00:56<00:00,  1.26s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.22s/it]


Loss: 0.1061, IoU: 0.0000, Dice: 0.0000
SegFormer Epoch 3/5


Training: 100%|██████████| 45/45 [00:55<00:00,  1.23s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.23s/it]


Loss: 0.0982, IoU: 0.0000, Dice: 0.0000
SegFormer Epoch 4/5


Training: 100%|██████████| 45/45 [00:55<00:00,  1.24s/it]
Evaluation: 100%|██████████| 11/11 [00:13<00:00,  1.22s/it]


Loss: 0.0933, IoU: 0.0000, Dice: 0.0000
SegFormer Epoch 5/5


Training: 100%|██████████| 45/45 [00:55<00:00,  1.24s/it]
Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.34s/it]


Loss: 0.0881, IoU: 0.0000, Dice: 0.0000
Evaluation:


Evaluation: 100%|██████████| 11/11 [00:16<00:00,  1.46s/it]


UNet: IoU=0.0154, Dice=0.0298


Evaluation: 100%|██████████| 11/11 [00:14<00:00,  1.31s/it]

SegFormer: IoU=0.0000, Dice=0.0000



