In [13]:
import torch
import torch.nn as nn
import torchvision.models
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F

import albumentations as A
from albumentations.pytorch import ToTensorV2
import wandb

from tqdm import tqdm
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import numpy as np

import os
from time import time

### Get the data

In [14]:
import gdown
url = 'https://drive.google.com/uc?id=10f1H2T-5W-BiqabHHtlZ4ASs19TZmg8R'
output = 'data.zip'
gdown.download(url, output, quiet=False)
!unzip data.zip

Downloading...
From (original): https://drive.google.com/uc?id=10f1H2T-5W-BiqabHHtlZ4ASs19TZmg8R
From (redirected): https://drive.google.com/uc?id=10f1H2T-5W-BiqabHHtlZ4ASs19TZmg8R&confirm=t&uuid=f3e62ebf-e2c7-426e-a89b-6ff761a73994
To: /content/data.zip
100%|██████████| 979M/979M [00:10<00:00, 89.1MB/s]


Archive:  data.zip
replace data/val/gt/193.Bewick_Wren/Bewick_Wren_0124_184771.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

### Utilities (0.5 point)

Complete dataset to load prepared images and masks. Don't forget to use augmentations.

Some of the images are 1 channels, so use `gray2rgb`.

In [20]:
# config.py
from dataclasses import dataclass
from typing import Tuple, Optional


@dataclass
class TrainingConfig:
    # Training parameters
    epochs: int = 15
    batch_size: int = 16
    learning_rate: float = 1e-4
    dropout_rate: float = 0.1
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset parameters
    input_size: Tuple[int, int] = (256, 256)
    num_classes: int = 1

    # Optimizer parameters
    weight_decay: float = 0.01
    scheduler_patience: int = 2
    scheduler_factor: float = 0.1

    # WandB parameters
    project_name: str = "bird-segmentation"
    run_name: Optional[str] = None

    # Data augmentation parameters
    normalize_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    normalize_std: Tuple[float, float, float] = (0.229, 0.224, 0.225)

    pred_threshold: float = 0.3

### Dataset

In [21]:
# dataset.py
import cv2
import os
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from typing import Tuple, Optional
import torch


class BirdsDataset(Dataset):
    """Dataset class for bird segmentation"""

    def __init__(
        self,
        folder: str,
        config: TrainingConfig,
        transform: Optional[A.Compose] = None,
        is_training: bool = True
    ) -> None:
        self.image_paths, self.mask_paths = self._get_paths(folder)
        self.transform = transform or (
            self._get_train_transforms(config) if is_training
            else self._get_default_transforms(config)
        )

    def _get_paths(self, folder: str) -> Tuple[list, list]:
        """Get paths for images and masks"""
        images_folder = os.path.join(folder, "images")
        gt_folder = os.path.join(folder, "gt")

        image_paths = []
        mask_paths = []

        for class_name in os.listdir(images_folder):
            class_folder = os.path.join(images_folder, class_name)
            if os.path.isdir(class_folder):
                for fname in os.listdir(class_folder):
                    image_paths.append(os.path.join(class_folder, fname))
                    mask_paths.append(
                        os.path.join(gt_folder, class_name, fname[:-3] + "png")
                    )

        return image_paths, mask_paths

    @staticmethod
    def _get_train_transforms(config: TrainingConfig) -> A.Compose:
        """Get enhanced training augmentation pipeline"""
        return A.Compose([
            # Spatial augmentations
            A.RandomResizedCrop(
                *config.input_size,
                scale=(0.8, 1.0),
                ratio=(0.9, 1.1),
                p=1.0
            ),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.2,
                scale_limit=0.2,
                rotate_limit=30,
                border_mode=cv2.BORDER_CONSTANT,
                value=0,
                p=0.5
            ),

            # Color augmentations
            A.OneOf([
                A.RandomBrightnessContrast(
                    brightness_limit=0.2,
                    contrast_limit=0.2,
                    p=1
                ),
                A.RandomGamma(gamma_limit=(80, 120), p=1),
                A.HueSaturationValue(
                    hue_shift_limit=20,
                    sat_shift_limit=30,
                    val_shift_limit=20,
                    p=1
                )
            ], p=0.3),

            # Noise augmentations
            A.OneOf([
                A.GaussNoise(var_limit=(10.0, 50.0), p=1),
                A.GaussianBlur(blur_limit=(3, 7), p=1),
                A.MotionBlur(blur_limit=7, p=1)
            ], p=0.2),

            # Dropout for regularization
            A.CoarseDropout(
                max_holes=8,
                max_height=32,
                max_width=32,
                min_holes=5,
                min_height=8,
                min_width=8,
                fill_value=0,
                p=0.2
            ),

            # Normalization and conversion to tensor
            A.Normalize(
                mean=config.normalize_mean,
                std=config.normalize_std
            ),
            ToTensorV2(),
        ])

    @staticmethod
    def _get_default_transforms(config: TrainingConfig) -> A.Compose:
        """Get default transforms for validation"""
        return A.Compose([
            A.Resize(*config.input_size),
            A.Normalize(
                mean=config.normalize_mean,
                std=config.normalize_std
            ),
            ToTensorV2(),
        ])

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a sample from the dataset"""
        # Load images
        img = cv2.imread(self.image_paths[index])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[index], cv2.IMREAD_GRAYSCALE)

        # Apply transforms
        transformed = self.transform(image=img, mask=mask)

        return (
            transformed["image"],
            transformed["mask"].float().unsqueeze(0) / 255.0,
        )

    def __len__(self) -> int:
        return len(self.image_paths)

### Architecture (1 point)
Your task for today is to build your own Unet to solve the segmentation problem.

As an encoder, you can use pre-trained on IMAGENET models(or parts) from torchvision. The decoder must be trained from scratch.
It is forbidden to use data not from the `data` folder.

I advise you to experiment with the number of blocks so as not to overfit on the training sample and get good quality on validation.

In [22]:
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class AttentionGate(nn.Module):
    """Attention Gate for focusing on relevant features"""

    def __init__(self, F_g: int, F_l: int, F_int: int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class ResidualBlock(nn.Module):
    """Residual block for enhanced feature learning"""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm2d(out_channels)
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.shortcut(x)
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        x = self.relu(x)
        return x

class DecoderBlock(nn.Module):
    """Enhanced decoder block with attention and residual connections"""

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        super().__init__()
        self.attention = AttentionGate(out_channels, skip_channels, out_channels)
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.residual1 = ResidualBlock(out_channels + skip_channels, out_channels)
        self.residual2 = ResidualBlock(out_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = self.upconv(x)
        x = self._handle_size_mismatch(x, skip)
        skip = self.attention(x, skip)
        x = torch.cat([x, skip], dim=1)
        x = self.residual1(x)
        x = self.residual2(x)
        return x

    @staticmethod
    def _handle_size_mismatch(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        if x.size() != skip.size():
            diff_h = skip.size()[2] - x.size()[2]
            diff_w = skip.size()[3] - x.size()[3]
            x = F.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )
        return x

class UNet(nn.Module):
    """U-Net architecture with ResNet50 encoder"""

    def __init__(self, config: TrainingConfig):
        super().__init__()
        self.config = config

        # Initialize ResNet encoder
        resnet = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.DEFAULT
        )

        # Encoder
        self.encoder1 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
        )  # 64 channels
        self.pool = resnet.maxpool
        self.encoder2 = resnet.layer1  # 256 channels
        self.encoder3 = resnet.layer2  # 512 channels
        self.encoder4 = resnet.layer3  # 1024 channels
        self.encoder5 = resnet.layer4  # 2048 channels

        # Decoder with adjusted channel sizes for ResNet50
        self.decoder4 = DecoderBlock(2048, 1024, 512)
        self.decoder3 = DecoderBlock(512, 512, 256)
        self.decoder2 = DecoderBlock(256, 256, 128)
        self.decoder1 = DecoderBlock(128, 64, 64)

        self.final_conv = nn.Conv2d(64, config.num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder
        enc1 = self.encoder1(x)           # 64 channels
        enc2 = self.encoder2(self.pool(enc1))  # 256 channels
        enc3 = self.encoder3(enc2)        # 512 channels
        enc4 = self.encoder4(enc3)        # 1024 channels
        enc5 = self.encoder5(enc4)        # 2048 channels

        # Decoder
        dec4 = self.decoder4(enc5, enc4)
        dec3 = self.decoder3(dec4, enc3)
        dec2 = self.decoder2(dec3, enc2)
        dec1 = self.decoder1(dec2, enc1)

        # Final output
        out = self.final_conv(dec1)
        return F.interpolate(
            out,
            size=x.shape[2:],
            mode="bilinear",
            align_corners=False
        )

### Train script (0.5 point)

Complete the train and predict scripts.

In [23]:
def calculate_iou(outputs: torch.Tensor, masks: torch.Tensor, threshold: float) -> float:
    """Calculate IoU score"""
    pred = (torch.sigmoid(outputs) > threshold)
    intersection = (pred & masks.bool()).float().sum((1, 2, 3))
    union = (pred | masks.bool()).float().sum((1, 2, 3))
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou.mean().item()

In [24]:
import wandb
from torch.utils.data import DataLoader
from typing import Dict
import torch.nn as nn
from tqdm import tqdm
import torch


class DiceLoss(nn.Module):
    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        inputs = torch.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice


class Trainer:
    """Training class for U-Net model"""

    def __init__(self, model: nn.Module, config: TrainingConfig):
        self.model = model
        self.config = config
        self.device = config.device

        self.bce_loss = nn.BCEWithLogitsLoss()
        self.dice_loss = DiceLoss()
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode="min",
            patience=config.scheduler_patience,
            factor=config.scheduler_factor,
        )

        self._setup_wandb()

    def _setup_wandb(self):
        """Initialize WandB run"""
        self.run = wandb.init(
            project=self.config.project_name,
            name=self.config.run_name,
            config=self.config,
        )
        wandb.watch(self.model, self.bce_loss, log="all", log_freq=10)

    def _init_metrics(self) -> Dict[str, float]:
        """Initialize metrics dictionary"""
        return {
            "loss": 0.0,
            "iou": 0.0,
            "count": 0,
        }

    def _update_metrics(
        self, epoch_metrics: Dict[str, float], batch_metrics: Dict[str, float]
    ):
        """Update epoch metrics with batch metrics"""
        epoch_metrics["loss"] += batch_metrics["loss"]
        epoch_metrics["iou"] += batch_metrics["iou"]
        epoch_metrics["count"] += 1

    def _finalize_metrics(
        self, epoch_metrics: Dict[str, float], num_batches: int
    ) -> Dict[str, float]:
        """Calculate final metrics for the epoch"""
        return {
            "loss": epoch_metrics["loss"] / num_batches,
            "iou": epoch_metrics["iou"] / num_batches,
        }

    def _validate_epoch(self, val_loader: DataLoader) -> Dict[str, float]:
        """Validation loop for one epoch"""
        self.model.eval()
        val_metrics = self._init_metrics()

        with torch.inference_mode():
            for inputs, masks in tqdm(val_loader):
                inputs = inputs.to(self.device)
                masks = masks.to(self.device)

                outputs = self.model(inputs)
                loss = self.bce_loss(outputs, masks) + self.dice_loss(outputs, masks)
                iou = calculate_iou(outputs, masks, self.config.pred_threshold)

                self._update_metrics(
                    val_metrics, {"loss": loss.item(), "iou": iou}
                )

        return {
            "val_loss": val_metrics["loss"] / len(val_loader),
            "val_iou": val_metrics["iou"] / len(val_loader),
        }

    def _log_metrics(
        self,
        epoch: int,
        train_metrics: Dict[str, float],
        val_metrics: Dict[str, float],
    ):
        """Log metrics to WandB"""
        metrics = {
            "epoch": epoch,
            "train_loss": train_metrics["loss"],
            "train_iou": train_metrics["iou"],
            "val_loss": val_metrics["val_loss"],
            "val_iou": val_metrics["val_iou"],
            "learning_rate": self.optimizer.param_groups[0]["lr"],
        }
        wandb.log(metrics)
        print(f"Epoch {epoch}:", metrics)

    def train(
        self, train_loader: DataLoader, val_loader: DataLoader
    ) -> nn.Module:
        """Main training loop"""
        best_val_iou = 0

        for epoch in range(self.config.epochs):
            # Training phase
            train_metrics = self._train_epoch(train_loader)

            # Validation phase
            val_metrics = self._validate_epoch(val_loader)

            # Update scheduler
            self.scheduler.step(val_metrics["val_loss"])

            # Log metrics
            self._log_metrics(epoch, train_metrics, val_metrics)

            # Save best model
            if val_metrics["val_iou"] > best_val_iou:
                best_val_iou = val_metrics["val_iou"]
                self._save_model("best_model.pth")

        self.run.finish()
        return self.model

    def _train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
        """Training loop for one epoch"""
        self.model.train()
        epoch_metrics = self._init_metrics()

        for step, (inputs, masks) in enumerate(tqdm(train_loader)):
            batch_metrics = self._train_step(inputs, masks)
            self._update_metrics(epoch_metrics, batch_metrics)

            if step % 50 == 0:
                wandb.log(
                    {
                        "train_batch_loss": batch_metrics["loss"],
                        "train_batch_iou": batch_metrics["iou"],
                        "learning_rate": self.optimizer.param_groups[0]["lr"],
                    }
                )

        return self._finalize_metrics(epoch_metrics, len(train_loader))

    def _train_step(
        self, inputs: torch.Tensor, masks: torch.Tensor
    ) -> Dict[str, float]:
        """Single training step"""
        inputs = inputs.to(self.device)
        masks = masks.to(self.device)

        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.bce_loss(outputs, masks) + self.dice_loss(outputs, masks)
        loss.backward()
        self.optimizer.step()

        with torch.inference_mode():
            iou = calculate_iou(outputs, masks, self.config.pred_threshold)

        return {"loss": loss.item(), "iou": iou}

    @staticmethod


    def _save_model(self, filename: str):
        """Save model to WandB"""
        path = os.path.join(wandb.run.dir, filename)
        torch.save(self.model.state_dict(), path)
        artifact = wandb.Artifact('best_model', type='model')
        artifact.add_file(path)
        self.run.log_artifact(artifact)

In [None]:
# main.py
from torch.utils.data import DataLoader


def main():
    # Initialize configuration
    config = TrainingConfig()

    # Create datasets
    train_dataset = BirdsDataset("data/train", config, is_training=True)
    val_dataset = BirdsDataset("data/val", config, is_training=False)

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False
    )

    # Initialize model
    model = UNet(config).to(config.device)

    # Initialize trainer and train
    trainer = Trainer(model, config)
    trained_model = trainer.train(train_loader, val_loader)

    return trained_model


if __name__ == "__main__":
    model = main()


You can also experiment with models and write a small report about results. If the report will be meaningful, you will receive an extra point.

### Testing (8 points)
Your model will be tested on the new data, similar to validation, so use techniques to prevent overfitting the model.

* IoU > 0.85 — 8 points
* IoU > 0.80 — 7 points
* IoU > 0.75 — 6 points
* IoU > 0.70 — 5 points
* IoU > 0.60 — 4 points
* IoU > 0.50 — 3 points
* IoU > 0.40 — 2 points
* IoU > 0.30 — 1 points

In [26]:
def get_model(config):
    wandb.init(project="bird-segmentation", entity="team-aspisov")
    artifact = wandb.use_artifact('model:v1')
    artifact_dir = artifact.download()
    model_path = f"{artifact_dir}/best_model.pth"
    model = UNet(config)
    model.load_state_dict(torch.load(model_path))
    return model

In [27]:
config = TrainingConfig()

model = get_model(config).to('cuda')

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Downloading large artifact model:v1, 188.33MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:7.7
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 174MB/s]
  model.load_state_dict(torch.load(model_path))


In [60]:
def predict(model, img_path):
    transform = A.Compose([
        A.Resize(256, 256),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])

    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    transformed = transform(image=img)
    input_tensor = transformed['image'].unsqueeze(0)

    model.eval()
    with torch.inference_mode():
        return model(input_tensor.to(config.device))

In [61]:
ious, times = [], []
test_dir = 'data/val/'

for class_name in tqdm(sorted(os.listdir(os.path.join(test_dir, 'images')))):
    for img_name in sorted(os.listdir(os.path.join(test_dir, 'images', class_name))):

        t_start = time()
        pred = predict(model, os.path.join(test_dir, 'images', class_name, img_name))
        times.append(time() - t_start)

        gt_name = img_name.replace('jpg', 'png')
        gt = np.asarray(Image.open(os.path.join(test_dir, 'gt', class_name, gt_name)), dtype = np.uint8)
        if len(gt.shape) > 2:
            gt = gt[:, :, 0]
        transform = A.Compose([
            A.Resize(256, 256),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2()
        ])
        gt = transform(image=gt)['image'].unsqueeze(0).to(config.device)

        iou = calculate_iou(gt==255, pred, config.pred_threshold)
        ious.append(iou)

np.mean(ious), np.mean(times)

100%|██████████| 200/200 [00:30<00:00,  6.46it/s]


(1.0, 0.01816144233747024)

### Compression (1 point)

Try to speed up the model in any way without losing more than 1% in iou score.
For example [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt)

In [None]:
def get_fast_model():
    # YOUR CODE HERE
    return model

In [None]:
fast_model = get_fast_model().to('cuda')

In [None]:
ious, times = [], []
test_dir = 'data/val/'

for class_name in tqdm(sorted(os.listdir(os.path.join(test_dir, 'images')))):
    for img_name in sorted(os.listdir(os.path.join(test_dir, 'images', class_name))):

        t_start = time()
        pred = predict(fast_model, os.path.join(test_dir, 'images', class_name, img_name))
        times.append(time() - t_start)

        gt_name = img_name.replace('jpg', 'png')
        gt = np.asarray(Image.open(os.path.join(test_dir, 'gt', class_name, gt_name)), dtype = np.uint8)
        if len(gt.shape) > 2:
            gt = gt[:, :, 0]

        iou = get_iou(gt==255, pred>0.5)
        ious.append(iou)

np.mean(ious), np.mean(times)

**Bonus:** For the best iou score on test(without compression) in group you will get 1.5, 1, 0.5 extra points(for 1st, 2nd, 3rd places).