In [1]:
import torch
import torch.nn as nn
import torchvision.models
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import numpy as np

import os
from time import time

  check_for_updates()


### Get the data

In [2]:
import gdown
url = 'https://drive.google.com/uc?id=10f1H2T-5W-BiqabHHtlZ4ASs19TZmg8R'
output = 'data.zip'
gdown.download(url, output, quiet=False)
!unzip data.zip

Downloading...
From (original): https://drive.google.com/uc?id=10f1H2T-5W-BiqabHHtlZ4ASs19TZmg8R
From (redirected): https://drive.google.com/uc?id=10f1H2T-5W-BiqabHHtlZ4ASs19TZmg8R&confirm=t&uuid=558ce7b5-3afa-4b45-840b-9e78a2cef627
To: /content/data.zip
100%|██████████| 979M/979M [00:09<00:00, 108MB/s]


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0098_51410.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0115_51891.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0106_52729.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0056_51523.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0113_51525.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0009_51301.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0117_51363.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0104_52614.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0108_51108.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0029_52613.jpg  
  inflating: data/train/images/064.Ring_billed_Gull/Ring_Billed_Gull_0119_5

### Utilities (0.5 point)

Complete dataset to load prepared images and masks. Don't forget to use augmentations.

Some of the images have a single channel, so convert them to RGB (e.g. with `gray2rgb`).

#### config

In [150]:
# config/config.py
from dataclasses import dataclass
from typing import Tuple, Optional
import torch


@dataclass
class TrainingConfig:
    """Hyper-parameters and settings shared by the dataset, model and trainer.

    Every field has a default, so ``TrainingConfig()`` yields a ready-to-use
    configuration; override individual fields through keyword arguments.
    """
    # Training
    epochs: int = 100  # upper bound on epochs; early stopping may end training sooner
    early_stopping_patience: int = 3  # epochs without validation-loss improvement before stopping
    batch_size: int = 8
    learning_rate: float = 1e-3
    dropout_rate: float = 0.1  # NOTE(review): not read by any block visible in this notebook
    # Evaluated once, when this class body runs: "cuda" if a GPU is visible then, else "cpu".
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Optimizer
    weight_decay: float = 1e-5
    # NOTE(review): the two scheduler_* fields are not read by the visible Trainer,
    # which builds a StepLR with fixed step_size/gamma — confirm whether they are still needed.
    scheduler_patience: int = 2
    scheduler_factor: float = 0.1

    # Logging
    project_name: str = "bird-segmentation"
    run_name: Optional[str] = None
    use_wandb: bool = True  # when True the Trainer opens a Weights & Biases run

    # Model
    input_size: Tuple[int, int] = (256, 256)  # size passed to the resize/crop transforms
    num_classes: int = 1  # number of output channels of the segmentation head
    iou_threshold: float = 0.3  # probability cut-off used when computing IoU during training/validation
    # Per-channel normalisation statistics passed to A.Normalize.
    normalize_mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    normalize_std: Tuple[float, float, float] = (0.229, 0.224, 0.225)

### Dataset

In [151]:
# dataset.py
import cv2
import os
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from typing import Tuple, Optional
import torch


class BirdsDataset(Dataset):
    """Dataset of bird images paired with their segmentation masks.

    Expected layout: ``<folder>/images/<class>/<name>.<ext>`` for images and
    ``<folder>/gt/<class>/<name>.png`` for the corresponding masks.

    Args:
        folder: Root folder containing the ``images`` and ``gt`` sub-folders.
        config: Configuration providing ``input_size``, ``normalize_mean`` and
            ``normalize_std`` for the built-in transforms.
        transform: Optional albumentations pipeline; when given it replaces the
            built-in pipelines entirely.
        is_training: Selects the augmenting pipeline (True) or the plain
            resize + normalise pipeline (False) when ``transform`` is not given.
    """

    def __init__(
        self,
        folder: str,
        config: TrainingConfig,
        transform: Optional[A.Compose] = None,
        is_training: bool = True
    ) -> None:
        self.image_paths, self.mask_paths = self._get_paths(folder)
        self.transform = transform or (
            self._get_train_transforms(config) if is_training
            else self._get_default_transforms(config)
        )

    def _get_paths(self, folder: str) -> Tuple[list, list]:
        """Collect parallel lists of image paths and mask paths.

        Directory listings are sorted so that sample indices are identical on
        every machine (``os.listdir`` order is otherwise arbitrary).
        """
        images_folder = os.path.join(folder, "images")
        gt_folder = os.path.join(folder, "gt")

        image_paths = []
        mask_paths = []

        for class_name in sorted(os.listdir(images_folder)):
            class_folder = os.path.join(images_folder, class_name)
            if not os.path.isdir(class_folder):
                continue
            for fname in sorted(os.listdir(class_folder)):
                # splitext handles extensions of any length (".jpg", ".jpeg", ...),
                # unlike slicing off the last three characters.
                stem, _ = os.path.splitext(fname)
                image_paths.append(os.path.join(class_folder, fname))
                mask_paths.append(os.path.join(gt_folder, class_name, stem + ".png"))

        return image_paths, mask_paths

    @staticmethod
    def _get_train_transforms(config: TrainingConfig) -> A.Compose:
        """Augmenting pipeline used for training samples."""
        return A.Compose([
            A.RandomResizedCrop(*config.input_size, scale=(0.8, 1.0)),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.5),
            A.Normalize(mean=config.normalize_mean, std=config.normalize_std),
            ToTensorV2(),
        ])

    @staticmethod
    def _get_default_transforms(config: TrainingConfig) -> A.Compose:
        """Deterministic resize + normalise pipeline used for validation."""
        return A.Compose([
            A.Resize(*config.input_size),
            A.Normalize(
                mean=config.normalize_mean,
                std=config.normalize_std
            ),
            ToTensorV2(),
        ])

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, transform and return the ``(image, mask)`` pair at ``index``.

        Returns:
            The transformed image tensor and the mask as a float tensor with a
            leading channel axis, scaled from 0..255 to 0..1.

        Raises:
            FileNotFoundError: If the image or its mask cannot be read.
        """
        image_path = self.image_paths[index]
        mask_path = self.mask_paths[index]

        # cv2.imread returns None instead of raising when a file is missing or
        # unreadable; fail here with the offending path rather than deep inside
        # cvtColor / the transform pipeline.
        # With its default flag imread decodes to a 3-channel BGR array, so
        # single-channel image files need no separate gray-to-RGB step.
        img = cv2.imread(image_path)
        if img is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        # Image and mask go through the pipeline together so that geometric
        # augmentations are applied to both.
        transformed = self.transform(image=img, mask=mask)

        return (
            transformed["image"],
            transformed["mask"].float().unsqueeze(0) / 255.0,
        )

    def __len__(self) -> int:
        """Number of image/mask pairs."""
        return len(self.image_paths)


### Architecture (1 point)
Your task for today is to build your own Unet to solve the segmentation problem.

As the encoder, you can use models (or parts of them) from torchvision that are pre-trained on ImageNet. The decoder must be trained from scratch.
It is forbidden to use data not from the `data` folder.

I advise you to experiment with the number of blocks so as not to overfit on the training sample and get good quality on validation.

In [152]:
# models/blocks.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


class AttentionGate(nn.Module):
    """Additive attention gate that re-weights skip features using a gating signal.

    Args:
        F_g: Channels of the gating signal ``g``.
        F_l: Channels of the skip features ``x``.
        F_int: Channels of the intermediate space both inputs are projected to.
    """

    def __init__(self, F_g: int, F_l: int, F_int: int):
        super().__init__()

        def project(channels_in: int, channels_out: int, *tail: nn.Module) -> nn.Sequential:
            # 1x1 convolution followed by batch norm, plus optional trailing layers.
            return nn.Sequential(
                nn.Conv2d(channels_in, channels_out, kernel_size=1),
                nn.BatchNorm2d(channels_out),
                *tail,
            )

        self.W_g = project(F_g, F_int)
        self.W_x = project(F_l, F_int)
        self.psi = project(F_int, 1, nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled element-wise by a single-channel attention map in (0, 1)."""
        gate_features = self.W_g(g)
        skip_features = self.W_x(x)
        attention_map = self.psi(self.relu(gate_features + skip_features))
        return x * attention_map


class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a skip connection around them.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        main_path = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv_block = nn.Sequential(*main_path)
        self.shortcut = self._make_shortcut(in_channels, out_channels)
        self.relu = nn.ReLU(inplace=True)

    def _make_shortcut(self, in_channels: int, out_channels: int) -> nn.Module:
        """Identity when channel counts match, otherwise a 1x1 conv + batch-norm projection."""
        if in_channels == out_channels:
            return nn.Identity()
        projection = nn.Conv2d(in_channels, out_channels, 1)
        return nn.Sequential(projection, nn.BatchNorm2d(out_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv path, add the shortcut branch, then ReLU."""
        identity = self.shortcut(x)
        features = self.conv_block(x)
        # In-place add on the freshly computed conv output; the caller's tensor is untouched.
        features += identity
        return self.relu(features)


class DecoderBlock(nn.Module):
    """Upsample, gate the skip connection with attention, concatenate, then refine.

    Args:
        in_channels: Channels of the tensor coming from the deeper stage.
        skip_channels: Channels of the encoder skip tensor.
        out_channels: Channels produced by this block.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        super().__init__()
        self.attention = AttentionGate(F_g=out_channels, F_l=skip_channels, F_int=out_channels)
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.residual1 = ResidualBlock(out_channels + skip_channels, out_channels)
        self.residual2 = ResidualBlock(out_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Return the decoded features at the spatial size of ``skip``."""
        upsampled = self._handle_size_mismatch(self.upconv(x), skip)
        gated_skip = self.attention(upsampled, skip)
        fused = torch.cat((upsampled, gated_skip), dim=1)
        return self.residual2(self.residual1(fused))

    @staticmethod
    def _handle_size_mismatch(x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Pad ``x`` along height/width so its spatial size matches ``skip``.

        ``x`` is returned as-is when the two shapes are identical; otherwise the
        height/width difference is split between the two sides, with the extra
        row/column going to the bottom/right.
        """
        if x.shape == skip.shape:
            return x
        gap_h = skip.size(2) - x.size(2)
        gap_w = skip.size(3) - x.size(3)
        left, top = gap_w // 2, gap_h // 2
        # F.pad takes pads for the last dimension first: (left, right, top, bottom).
        return F.pad(x, [left, gap_w - left, top, gap_h - top])


class UNet(nn.Module):
    """U-Net with a torchvision ResNet50 encoder and attention-gated decoder blocks.

    Args:
        config: Object providing ``num_classes``, the number of output channels.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Encoder stages are taken from a ResNet50 built with torchvision's default weights.
        backbone = torchvision.models.resnet50(weights="DEFAULT")

        # Encoder (trailing numbers: output channels of each stage)
        self.encoder1 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)  # 64
        self.pool = backbone.maxpool
        self.encoder2 = backbone.layer1  # 256
        self.encoder3 = backbone.layer2  # 512
        self.encoder4 = backbone.layer3  # 1024
        self.encoder5 = backbone.layer4  # 2048

        # Decoder: each block consumes the deeper output plus the matching encoder skip.
        self.decoder4 = DecoderBlock(in_channels=2048, skip_channels=1024, out_channels=512)
        self.decoder3 = DecoderBlock(in_channels=512, skip_channels=512, out_channels=256)
        self.decoder2 = DecoderBlock(in_channels=256, skip_channels=256, out_channels=128)
        self.decoder1 = DecoderBlock(in_channels=128, skip_channels=64, out_channels=64)

        # 1x1 convolution mapping decoder features to per-class logits.
        self.final_conv = nn.Conv2d(64, config.num_classes, kernel_size=1)

    def forward(self, x):
        """Return logits with the same height and width as ``x``."""
        # Encoder: keep every stage output for the skip connections.
        stage1 = self.encoder1(x)
        stage2 = self.encoder2(self.pool(stage1))
        stage3 = self.encoder3(stage2)
        stage4 = self.encoder4(stage3)
        bottleneck = self.encoder5(stage4)

        # Decoder: walk back up, pairing each step with its encoder stage.
        features = self.decoder4(bottleneck, stage4)
        features = self.decoder3(features, stage3)
        features = self.decoder2(features, stage2)
        features = self.decoder1(features, stage1)

        # The decoder output is resized to the input's height and width.
        logits = self.final_conv(features)
        return F.interpolate(
            logits,
            size=x.shape[2:],
            mode="bilinear",
            align_corners=False,
        )

### Train script (0.5 point)

Complete the train and predict scripts.

#### loss and trainer

In [153]:
# training/trainer.py
import os
import torch
import wandb
from torch.utils.data import DataLoader
from typing import Dict
from tqdm import tqdm


class DiceLoss(torch.nn.Module):
    """Soft Dice loss on sigmoid probabilities, computed over the whole batch at once."""

    def forward(self, inputs, targets, smooth=1):
        """Return ``1 - dice`` for raw logits ``inputs`` against ``targets``.

        Args:
            inputs: Raw (pre-sigmoid) predictions.
            targets: Ground-truth values with the same number of elements as ``inputs``.
            smooth: Added to numerator and denominator so the ratio stays
                defined when both prediction and target are empty.

        Returns:
            Scalar tensor; 0 means perfect overlap.
        """
        # reshape (not view) so non-contiguous tensors, e.g. transposed or
        # sliced ones, are flattened instead of raising.
        probs = torch.sigmoid(inputs).reshape(-1)
        targets = targets.reshape(-1)

        intersection = (probs * targets).sum()
        dice = (2.0 * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        return 1 - dice


class Trainer:
    """Handles model training, validation, early stopping and checkpointing.

    The loss is binary cross-entropy on logits plus a soft Dice loss. Whenever
    the validation loss reaches a new minimum the model is written to
    ``best_model.pth``; training stops once the validation loss has failed to
    improve for ``config.early_stopping_patience`` consecutive epochs.
    """

    def __init__(self, model: torch.nn.Module, config):
        """Set up loss, optimizer, scheduler, early-stopping state and logging.

        Args:
            model: Network producing raw logits. It is not moved to a device here.
            config: Object exposing ``device``, ``learning_rate``, ``weight_decay``,
                ``early_stopping_patience``, ``epochs``, ``iou_threshold``,
                ``use_wandb``, ``project_name`` and ``run_name``.
        """
        self.model = model
        self.config = config
        self.device = config.device

        # Initialize metrics
        self.metrics = MetricsCalculator()

        # Loss and optimization. The Dice module is built once and reused
        # instead of being re-instantiated on every loss evaluation.
        self._dice_loss = DiceLoss()
        self.criterion = self._combined_loss
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
        # Fixed schedule: halve the learning rate every 10 epochs.
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=10, gamma=0.5
        )

        self.early_stopping_patience = config.early_stopping_patience
        self.no_improve_epochs = 0
        self.best_val_loss = float('inf')

        # Setup logging
        if config.use_wandb:
            self._setup_wandb()

    def _combined_loss(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Binary cross-entropy on logits plus soft Dice loss."""
        bce = torch.nn.functional.binary_cross_entropy_with_logits(inputs, targets)
        return bce + self._dice_loss(inputs, targets)

    def _setup_wandb(self):
        """Initialize WandB logging"""
        self.run = wandb.init(
            project=self.config.project_name,
            name=self.config.run_name,
            config=self.config
        )
        wandb.watch(self.model, self.criterion, log="all", log_freq=10)

    def train(self, train_loader: DataLoader, val_loader: DataLoader) -> torch.nn.Module:
        """Run the training loop with per-epoch validation.

        Args:
            train_loader: Batches of ``(inputs, targets)`` used for optimisation.
            val_loader: Batches of ``(inputs, targets)`` used for validation.

        Returns:
            The model in its state after the last executed epoch. The weights
            of the best epoch are stored in ``best_model.pth``.
        """
        try:
            for epoch in range(self.config.epochs):
                # Training
                train_metrics = self._train_epoch(train_loader)

                # Validation
                val_metrics = self._validate_epoch(val_loader)

                # Update learning rate
                self.scheduler.step()

                # Logging
                self._log_metrics(epoch, train_metrics, val_metrics)

                # Checkpoint only on improvement. Saving when the loss got
                # worse would overwrite the best checkpoint with a weaker model.
                if val_metrics["loss"] < self.best_val_loss:
                    self.best_val_loss = val_metrics["loss"]
                    self.no_improve_epochs = 0
                    self._save_checkpoint("best_model.pth")
                else:
                    self.no_improve_epochs += 1
                    if self.no_improve_epochs >= self.early_stopping_patience:
                        print("Early stopping triggered.")
                        break
        finally:
            # Close the WandB run even if training is interrupted or fails.
            if self.config.use_wandb:
                self.run.finish()

        return self.model

    def _train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        """Run one optimisation pass over ``dataloader`` and return averaged loss/IoU."""
        self.model.train()
        self.metrics.reset()

        for inputs, targets in tqdm(dataloader, desc="Training"):
            # Move data to device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Forward pass
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Calculate metrics
            with torch.inference_mode():
                iou = self.metrics.calculate_iou(outputs, targets, self.config.iou_threshold)
                self.metrics.update(loss.item(), iou)

        return self.metrics.compute()

    def _validate_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate on ``dataloader`` without gradients and return averaged loss/IoU."""
        self.model.eval()
        self.metrics.reset()

        with torch.inference_mode():
            for inputs, targets in tqdm(dataloader, desc="Validation"):
                # Move data to device
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                # Forward pass
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                # Calculate metrics
                iou = self.metrics.calculate_iou(outputs, targets, self.config.iou_threshold)
                self.metrics.update(loss.item(), iou)

        return self.metrics.compute()

    def _log_metrics(
        self,
        epoch: int,
        train_metrics: Dict[str, float],
        val_metrics: Dict[str, float]
    ) -> None:
        """Log metrics to console and WandB if enabled"""
        # Console logging
        print(f"\nEpoch {epoch + 1}/{self.config.epochs}")
        print(f"Train Loss: {train_metrics['loss']:.4f} - Train IoU: {train_metrics['iou']:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f} - Val IoU: {val_metrics['iou']:.4f}")

        # WandB logging
        if self.config.use_wandb:
            wandb.log({
                "epoch": epoch,
                "train_loss": train_metrics["loss"],
                "train_iou": train_metrics["iou"],
                "val_loss": val_metrics["loss"],
                "val_iou": val_metrics["iou"],
                "learning_rate": self.scheduler.get_last_lr()[0]
            })

    def _save_checkpoint(self, filename: str) -> None:
        """Write model, optimizer and scheduler state plus the config to ``filename``."""
        import dataclasses

        # Store a dataclass config as a plain dict: pickling the object itself
        # makes the file unreadable for torch.load when it runs with
        # weights_only=True, which only accepts tensors and basic containers.
        config = self.config
        if dataclasses.is_dataclass(config) and not isinstance(config, type):
            config = dataclasses.asdict(config)

        state = {
            "epoch": self.scheduler.last_epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "config": config
        }
        torch.save(state, filename)

    def load_checkpoint(self, filename: str) -> None:
        """Restore model, optimizer and scheduler state from ``filename``.

        Raises:
            FileNotFoundError: If ``filename`` does not exist.
        """
        if not os.path.exists(filename):
            raise FileNotFoundError(f"Checkpoint file {filename} not found")

        checkpoint = torch.load(filename, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        # Move model to correct device after loading
        self.model = self.model.to(self.device)

    def predict(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return binary predictions (0.0 / 1.0) for a batch of inputs.

        Uses a fixed 0.5 probability cut-off, independent of ``config.iou_threshold``.
        """
        self.model.eval()
        with torch.inference_mode():
            outputs = self.model(inputs.to(self.device))
            preds = (torch.sigmoid(outputs) > 0.5).float()
        return preds

In [154]:
# training/metrics.py
import torch
from typing import Dict


class MetricsCalculator:
    """Accumulates per-batch loss and IoU and reports their per-batch averages."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Start a fresh accumulation window."""
        self.metrics = dict(loss=0.0, iou=0.0, count=0)

    def update(self, loss: float, iou: float):
        """Add one batch's loss and IoU to the running totals."""
        for key, increment in (("loss", loss), ("iou", iou), ("count", 1)):
            self.metrics[key] += increment

    def compute(self) -> Dict[str, float]:
        """Return the mean loss and IoU over the batches seen since the last reset."""
        # With no batches recorded, divide by 1 so the result is 0.0 rather than an error.
        batches = max(1, self.metrics["count"])
        return {name: self.metrics[name] / batches for name in ("loss", "iou")}

    @staticmethod
    def calculate_iou(logits: torch.Tensor, target: torch.Tensor, threshold: float) -> float:
        """Mean per-sample IoU between thresholded predictions and the target.

        Predictions are ``sigmoid(logits) > threshold``; every non-zero target
        value counts as foreground. Sums run over dimensions 1-3, so 4-D inputs
        are expected.
        """
        predicted = torch.sigmoid(logits) > threshold
        truth = target.bool()
        per_sample_dims = (1, 2, 3)
        eps = 1e-6  # keeps the ratio defined (and equal to 1) when both masks are empty

        overlap = torch.logical_and(predicted, truth).float().sum(per_sample_dims)
        combined = torch.logical_or(predicted, truth).float().sum(per_sample_dims)
        scores = (overlap + eps) / (combined + eps)
        return scores.mean().item()

#### trainer

#### main

In [None]:
# main.py
def main():
    """Train a UNet on the birds data with the default configuration and return it."""
    config = TrainingConfig()

    # One loader per split; only the training split is augmented and shuffled.
    loaders = []
    for split_folder, is_training in (("data/train", True), ("data/val", False)):
        dataset = BirdsDataset(split_folder, config, is_training=is_training)
        loaders.append(
            DataLoader(dataset, batch_size=config.batch_size, shuffle=is_training)
        )
    train_loader, val_loader = loaders

    # Model and trainer share the same configuration object.
    network = UNet(config).to(config.device)
    trained_model = Trainer(network, config).train(train_loader, val_loader)

    print("Training completed.")
    return trained_model


if __name__ == "__main__":
    model = main()

You can also experiment with models and write a short report about the results. If the report is meaningful, you will receive an extra point.

### Testing (8 points)
Your model will be tested on the new data, similar to validation, so use techniques to prevent overfitting the model.

* IoU > 0.85 — 8 points
* IoU > 0.80 — 7 points
* IoU > 0.75 — 6 points
* IoU > 0.70 — 5 points
* IoU > 0.60 — 4 points
* IoU > 0.50 — 3 points
* IoU > 0.40 — 2 points
* IoU > 0.30 — 1 point

In [None]:
def predict(model, img_path):
    """Segment a single image file and return the mask at the image's original size.

    Args:
        model: Segmentation network returning raw logits for a normalised
            ``1 x 3 x 256 x 256`` input.
        img_path: Path of the image to segment.

    Returns:
        ``numpy`` float array with values 0.0 / 1.0 and the height and width of
        the original image (for a single-channel model output).

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    transform = A.Compose([
        A.Resize(256, 256),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])

    # cv2.imread returns None (it does not raise) for a missing/unreadable file.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    height, width = img.shape[:2]

    input_tensor = transform(image=img)['image'].unsqueeze(0)

    # Run on whatever device the model lives on; feeding a CPU tensor to a CUDA
    # model raises a device-mismatch error. Models that expose no parameters
    # receive the tensor unchanged.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(input_tensor))
        # Resize the probabilities first and threshold afterwards: thresholding
        # before bilinear interpolation discards the confidence information
        # that should decide the resized mask's boundary.
        probs = F.interpolate(probs, size=(height, width), mode='bilinear', align_corners=False)
        pred = (probs > 0.5).float()

    # Drop batch and channel axes explicitly (a bare squeeze() would also remove
    # a height or width of 1) and move to the CPU before converting to numpy.
    return pred.squeeze(0).squeeze(0).cpu().numpy()

def get_model(path):
    """Build a UNet with the default configuration and load weights from ``path``.

    Accepts either a bare ``state_dict`` file or a checkpoint dictionary that
    stores the weights under ``"model_state_dict"`` (the layout written by
    ``Trainer._save_checkpoint``).

    Returns:
        The model in eval mode, on the CPU; move it to the target device afterwards.
    """
    # UNet requires a config object; calling it without one raises TypeError.
    model = UNet(TrainingConfig())

    # map_location="cpu" lets weights saved on a GPU load on a CPU-only machine.
    state = torch.load(path, map_location="cpu")
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]

    model.load_state_dict(state)
    model.eval()
    return model

In [None]:
# Load the trained weights and move the model to the GPU for evaluation.
# NOTE(review): no cell in this notebook writes 'model_14.pth' (Trainer saves
# "best_model.pth") — confirm the file exists before running this cell.
model = get_model('model_14.pth').to('cuda')

In [None]:
def get_iou(gt, pred):
    """Intersection-over-union of two boolean masks of equal shape.

    NOTE(review): no other cell of this notebook defines ``get_iou``; this is the
    standard mask IoU — swap in the course-provided metric if one exists.
    Two empty masks are treated as a perfect match (IoU = 1.0).
    """
    gt = np.asarray(gt, dtype=bool)
    pred = np.asarray(pred, dtype=bool)
    union = np.logical_or(gt, pred).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(gt, pred).sum() / union)


# Per-image IoU and wall-clock prediction time over the validation split.
ious, times = [], []
test_dir = 'data/val/'

for class_name in tqdm(sorted(os.listdir(os.path.join(test_dir, 'images')))):
    for img_name in sorted(os.listdir(os.path.join(test_dir, 'images', class_name))):

        t_start = time()
        pred = predict(model, os.path.join(test_dir, 'images', class_name, img_name))
        times.append(time() - t_start)

        # Swap only the extension; str.replace('jpg', 'png') would also rewrite
        # a 'jpg' that appears elsewhere in the file name.
        gt_name = os.path.splitext(img_name)[0] + '.png'
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(os.path.join(test_dir, 'gt', class_name, gt_name)) as gt_image:
            gt = np.asarray(gt_image, dtype = np.uint8)
        # Multi-channel masks: keep the first channel only.
        if len(gt.shape) > 2:
            gt = gt[:, :, 0]

        iou = get_iou(gt==255, pred>0.5)
        ious.append(iou)

np.mean(ious), np.mean(times)

### Compression (1 point)

Try to speed up the model in any way without losing more than 1% in iou score.
For example [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt)

In [None]:
def get_fast_model():
    """Return the model to benchmark in the compression section.

    Placeholder: no compression is implemented yet, so this returns the
    notebook-level ``model`` object unchanged.
    """
    # YOUR CODE HERE
    return model

In [None]:
# Model used for the compression benchmark below, moved to the GPU.
fast_model = get_fast_model().to('cuda')

In [None]:
# Same benchmark as for the uncompressed model, this time run on `fast_model`:
# per-image IoU against the ground-truth mask and wall-clock prediction time.
ious, times = [], []
test_dir = 'data/val/'
images_root = os.path.join(test_dir, 'images')
gt_root = os.path.join(test_dir, 'gt')

for class_name in tqdm(sorted(os.listdir(images_root))):
    class_dir = os.path.join(images_root, class_name)
    for img_name in sorted(os.listdir(class_dir)):

        # Only the predict call (including building its path argument) is timed.
        t_start = time()
        pred = predict(fast_model, os.path.join(class_dir, img_name))
        times.append(time() - t_start)

        # The ground-truth mask has the same name with 'jpg' replaced by 'png'.
        gt_name = img_name.replace('jpg', 'png')
        gt_path = os.path.join(gt_root, class_name, gt_name)
        gt = np.asarray(Image.open(gt_path), dtype = np.uint8)
        # Multi-channel masks: keep the first channel only.
        if gt.ndim > 2:
            gt = gt[:, :, 0]

        ious.append(get_iou(gt==255, pred>0.5))

np.mean(ious), np.mean(times)

**Bonus:** The three best IoU scores on the test set (without compression) in the group earn 1.5, 1 and 0.5 extra points (for 1st, 2nd and 3rd place, respectively).