In [1]:
import os
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformers

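# albumentations is required by the cells below; install it first if it is missing.
# Quote the version spec, otherwise the shell treats ">" as an output redirect:
# %pip install "albumentations>=1.0.0" --quiet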
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been initialised yet;
# torch is already imported above, so no CUDA call may happen before this line.
os.environ["CUDA_VISIBLE_DEVICES"] = "5,6"  # expose only GPUs 5 and 6 to this kernel

# Down-scaling factor of the data; when != 1 the ./data_x{rescale} copy is loaded.
rescale = 1
PATCH_SIZE = 256 // rescale  # side length images and masks are resized to
NUM_LABELS = 3

# Normalisation statistics: ImageNet defaults (kept for reference) and the dataset's own.
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
CUSTOM_MEAN = [0.1778, 0.2696, 0.1686]
CUSTOM_STD = [0.0942, 0.0915, 0.0762]

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.5 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations


In [2]:
from datasets import load_dataset
import logging

# NOTE(review): DEBUG on the root logger is very verbose; INFO is probably enough.
logging.basicConfig(level=logging.DEBUG)

# The splits come from a local loading script, which `datasets` only executes with
# trust_remote_code=True. Down-scaled data lives in ./data_x{rescale}.
dataset = load_dataset(
    "./my_segmentation_dataset.py",
    data_dir="./data" if rescale == 1 else f"./data_x{rescale}",
    trust_remote_code=True,
)
train_ds, valid_ds = dataset["train"], dataset["validation"]

# Class names per id (presumably the pixel values stored in the annotation masks).
id2label = dict(enumerate(("unlabeled", "forest0", "forest1")))
label2id = {name: idx for idx, name in id2label.items()}

INFO:datasets:PyTorch version 2.6.0+cu124 available.


In [3]:
def transforms(examples):
    """Randomly flip / rotate (by 90 degrees) each image together with its mask.

    Reads the "image" and "annotation" columns of a batch dict, applies the same random
    geometric augmentation to each image/mask pair, and stores the resulting numpy arrays
    under the new keys "pixel_values" and "label". The batch dict is modified in place and
    also returned.
    """
    augment = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        # Further augmentations (brightness/contrast, noise, ...) can be added here if needed.
        # Optional simple normalisation (normally left to the image processor):
        # A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        # ToTensorV2()
    ])

    augmented = [
        augment(image=np.array(img), mask=np.array(mask))
        for img, mask in zip(examples["image"], examples["annotation"])
    ]
    examples["pixel_values"] = [item["image"] for item in augmented]
    examples["label"] = [item["mask"] for item in augmented]
    return examples

In [4]:
import evaluate

metric = evaluate.load("mean_iou")
metric_f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    """Segmentation metrics for ``Trainer(compute_metrics=...)``.

    Args:
        eval_pred: pair ``(logits, labels)``; an ``EvalPrediction`` unpacks this way.
            The class axis of ``logits`` is taken to be axis 1 (assumed ``(N, C, H, W)``,
            with ``labels`` ``(N, H, W)``).

    Returns:
        dict with the ``mean_iou`` metric results, per-class ``accuracy_<name>`` and
        ``iou_<name>`` entries (names from ``id2label``) and the pixel-level macro
        ``f1_score``. Pixels labelled ``255`` are excluded from every metric.
    """
    ignore_index = 255  # same ignore label for mean_iou and F1

    logits, labels = eval_pred
    pred_labels = np.argmax(logits, axis=1)

    metrics = metric.compute(
        predictions=pred_labels,
        references=labels,
        num_labels=NUM_LABELS,
        ignore_index=ignore_index,
        reduce_labels=False,
    )

    # Pixel-level macro F1. Drop ignored pixels so the ignore label is not scored as an
    # extra class (mean_iou above already skips them via ignore_index).
    flat_refs = np.asarray(labels).ravel()
    flat_preds = np.asarray(pred_labels).ravel()
    keep = flat_refs != ignore_index
    if keep.any():
        f1_score = metric_f1.compute(
            predictions=flat_preds[keep],
            references=flat_refs[keep],
            average="macro",
        )["f1"]
    else:
        # Every pixel is ignored: F1 is undefined.
        f1_score = float("nan")

    per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
    per_category_iou = metrics.pop("per_category_iou").tolist()

    metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
    metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
    metrics.update({"f1_score": f1_score})

    return metrics

In [5]:
from transformers import AutoImageProcessor

# Active checkpoint; alternatives that were tried are kept below for reference.
checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
# checkpoint = "nvidia/mit-b0"
# checkpoint = "sawthiha/segformer-b0-finetuned-deprem-satellite"
# checkpoint = "nvidia/segformer-b5-finetuned-ade-640-640"

# Resize images and masks to PATCH_SIZE x PATCH_SIZE and normalise with the dataset's own
# statistics. Labels are kept as-is: id 0 ("unlabeled") is one of the NUM_LABELS classes.
image_processor = SegformerImageProcessor.from_pretrained(
    checkpoint,
    image_mean=CUSTOM_MEAN,
    image_std=CUSTOM_STD,
    size={"height": PATCH_SIZE, "width": PATCH_SIZE},
    do_reduce_labels=False,
)

from torchvision.transforms import ColorJitter

# Photometric augmentation for training images only (masks are never jittered).
jitter = ColorJitter(
    hue=0.1,
    brightness=0.25,
    contrast=0.25,
    saturation=0.25,
)

def train_transforms(example_batch):
    """Training-time preprocessing used with ``train_ds.set_transform``.

    Steps:
      1. ``transforms`` applies the same random flips/rotations to each image and its mask.
      2. The image alone gets colour jitter.
      3. ``image_processor`` resizes and normalises, returning the model inputs.
    """
    example_batch = transforms(example_batch)
    # Read the *augmented* arrays that `transforms` wrote to "pixel_values"/"label".
    # Reading the raw "image"/"annotation" columns here would throw the geometric
    # augmentation away.
    # `transforms` yields numpy arrays, while ColorJitter expects PIL images (or tensors).
    images = [jitter(Image.fromarray(x)) for x in example_batch["pixel_values"]]
    labels = list(example_batch["label"])
    inputs = image_processor(images, labels)
    return inputs


def val_transforms(example_batch):
    """Validation preprocessing: no augmentation, only ``image_processor`` (resize + normalise).

    Returns whatever ``image_processor`` produces for the batch's "image" and "annotation"
    columns (presumably ``pixel_values`` and ``labels``, matching the model's forward arguments).
    """
    images = list(example_batch["image"])
    masks = list(example_batch["annotation"])
    return image_processor(images, masks)


# Preprocessing is attached lazily: the dataset applies it on every row access, so the
# random training augmentation is re-drawn each epoch.
train_ds.set_transform(train_transforms)  # augmentation + processor
valid_ds.set_transform(val_transforms)    # processor only

  return func(*args, **kwargs)


In [6]:
from transformers import TrainerCallback

class SaveEvalPredictionsCallback(TrainerCallback):
    """Save side-by-side PNGs of model predictions after every evaluation.

    On each ``on_evaluate`` the callback runs ``trainer.predict`` on ``eval_dataset``. For
    every sample it writes ``<output_dir>/step_<global_step>/pred_<stem>.png``, built from three
    horizontally concatenated panels: predicted mask | ground-truth mask | input image.
    Masks are rescaled from ``[0, num_labels - 1]`` to ``[0, 255]`` so the classes are visible.

    NOTE(review): ``trainer.predict`` is a second full pass over ``eval_dataset``, on top of
    the evaluation that triggered this hook.
    """

    def __init__(self, trainer, eval_dataset, output_dir, num_labels=NUM_LABELS, val_transforms=val_transforms):
        """
        Args:
            trainer: Trainer used for prediction. It is passed explicitly so ``on_evaluate``
                can call ``trainer.predict``.
            eval_dataset: dataset to predict on. Its transform is temporarily removed
                (``reset_format``) to read raw rows.
            output_dir: root directory for the per-step folders (created if missing).
            num_labels: number of classes, used to rescale masks for display.
            val_transforms: transform re-attached to ``eval_dataset`` after raw rows were read.
        """
        super().__init__()
        self.trainer = trainer
        self.eval_dataset = eval_dataset
        self.output_dir = output_dir
        self.num_labels = num_labels
        # Keep the transform that was passed in; the old code restored the global instead.
        self.val_transforms = val_transforms
        os.makedirs(self.output_dir, exist_ok=True)

    def upvalue(self, img):
        """Rescale class ids ``[0, num_labels - 1]`` to ``[0, 255]`` as uint8 for display.

        NOTE(review): ids above ``num_labels - 1`` (e.g. an ignore label 255) overflow the
        uint8 cast.
        """
        return ((img / (self.num_labels - 1)) * 255).astype(np.uint8)

    def _make_panel(self, pred_map, annotation, image):
        """Concatenate the predicted mask, the rescaled ground truth and the input image side by side.

        ``pred_map`` must already be rescaled with ``upvalue``. All three panels must share the
        same height, and the image must have 3 channels; otherwise np.concatenate raises.
        """
        pred_rgb = np.repeat(pred_map[:, :, None], 3, -1)
        gt_rgb = self.upvalue(np.repeat(np.asarray(annotation)[:, :, None], 3, -1))
        return np.concatenate((pred_rgb, gt_rgb, np.asarray(image)), axis=1)

    def on_evaluate(self, args, state, control, **kwargs):
        """Predict on ``eval_dataset`` and save one PNG panel per sample for this step."""
        pred_res = self.trainer.predict(self.eval_dataset)
        pred_labels = np.argmax(pred_res.predictions, axis=1)  # class axis 1 -> [N, H, W]
        pred_labels = self.upvalue(pred_labels)

        # One sub-folder per global step.
        step_folder = os.path.join(self.output_dir, f"step_{state.global_step}")
        os.makedirs(step_folder, exist_ok=True)

        try:
            # Drop the transform so rows come back raw (original image and annotation).
            self.eval_dataset.reset_format()
            for i, pred_map in enumerate(pred_labels):
                row = self.eval_dataset[i]  # fetch once: every access decodes the row again
                panel = self._make_panel(pred_map, row["annotation"], row["image"])
                # Strip only the extension ("a.b.png" -> "a.b") so dotted names stay unique.
                stem = os.path.splitext(row["filename"])[0]
                Image.fromarray(panel).save(os.path.join(step_folder, f"pred_{stem}.png"))
        finally:
            # Always restore the transform, whatever the error. Otherwise later evaluations
            # would see untransformed rows.
            self.eval_dataset.set_transform(self.val_transforms)

        print(f"Saved evaluation predictions to: {step_folder}")

In [7]:
from transformers import (
    AutoModelForSemanticSegmentation,
    TrainingArguments,
    Trainer
)
from transformers.modeling_outputs import SemanticSegmenterOutput
from typing import Optional, Union, Tuple

# Replace the checkpoint's 150-class head with a freshly initialised NUM_LABELS-class head
# (hence ignore_mismatched_sizes). Also store the class names in the model config so saved
# checkpoints carry them.
model = SegformerForSemanticSegmentation.from_pretrained(
    checkpoint,
    num_labels=NUM_LABELS,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

class DiceLoss(nn.Module):
    """Soft region-overlap loss, averaged over classes.

    For every class ``c`` with predicted map ``P_c`` and one-hot target ``T_c`` (sums over the
    batch and spatial dimensions) the per-class score is::

        (sum(P_c * T_c) + smooth) / (sum(P_c) + sum(T_c) - sum(P_c * T_c) + smooth)

    and the loss is ``1 - mean_c(score)``.

    NOTE: despite the class name, this ratio is the soft *Jaccard* (IoU) index, not the Dice
    coefficient ``2|P∩T| / (|P| + |T|)``. The formula is kept unchanged for backward compatibility.
    """

    def __init__(self, smooth=1.0, ignore_index=None):
        """
        Args:
            smooth: additive smoothing in numerator and denominator. It also makes the score 1
                for classes absent from both prediction and target.
            ignore_index: optional target value (e.g. 255) whose pixels are excluded from all
                sums. ``None`` (the default) uses every pixel.
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, outputs, targets):
        """
        Args:
            outputs (torch.Tensor): per-class scores of shape (N, C, H, W). They should be
                probabilities (e.g. softmax over dim 1); log-probabilities or raw logits make
                the overlap terms meaningless.
            targets (torch.Tensor): class ids of shape (N, H, W) or (N, 1, H, W), with values
                in [0, C-1] (plus ``ignore_index`` when it is set).

        Returns:
            torch.Tensor: scalar loss.
        """
        if targets.ndim == 3:
            targets = targets.unsqueeze(1)
        targets = targets.long()  # scatter_ needs int64 indices

        valid = None
        if self.ignore_index is not None:
            valid = targets != self.ignore_index  # (N, 1, H, W), broadcasts over C
            # Give ignored pixels an in-range index so scatter_ does not fail;
            # they are zeroed out of both tensors anyway.
            targets = targets.masked_fill(~valid, 0)
            outputs = outputs * valid

        targets_one_hot = torch.zeros_like(outputs).scatter_(1, targets, 1)
        if valid is not None:
            targets_one_hot = targets_one_hot * valid

        reduce_dims = (0, 2, 3)  # everything except the class axis
        intersection = (outputs * targets_one_hot).sum(dim=reduce_dims)
        union = outputs.sum(dim=reduce_dims) + targets_one_hot.sum(dim=reduce_dims) - intersection
        score = (intersection + self.smooth) / (union + self.smooth)
        return 1 - score.mean()

class ModelWrapper(torch.nn.Module):
    """Wrap a Segformer model so its logits are returned at full patch resolution.

    The wrapped model's logits (presumably at reduced resolution) are bilinearly upsampled to
    ``patch_size x patch_size``, or to the spatial size of ``pixel_values`` when ``patch_size``
    is ``None``. Predictions then line up pixel-for-pixel with the labels used in
    ``compute_metrics``. The loss is whatever the wrapped model returns for ``labels``.
    """

    def __init__(self, segformer, patch_size=PATCH_SIZE, accepts_loss_kwargs=False):
        """
        Args:
            segformer: model called as ``segformer(pixel_values, labels=..., ...)``.
            patch_size: output side length for the upsampled logits. ``None`` means "same
                as the input's H x W".
            accepts_loss_kwargs: stored as an attribute. NOTE(review): presumably inspected by
                ``transformers.Trainer`` when deciding how to pass loss kwargs; confirm.
        """
        super().__init__()
        self.main = segformer
        self.accepts_loss_kwargs = accepts_loss_kwargs
        self.patch_size = patch_size
        self.dice_loss = DiceLoss()  # currently unused (see the commented line in forward)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        """Run the wrapped model and upsample its logits.

        ``return_dict`` is accepted for signature compatibility only. The wrapper always
        returns a ``SemanticSegmenterOutput``. Requested hidden states/attentions are passed
        through; they are ``None`` unless requested.
        """
        # Always ask the inner model for a ModelOutput: `.loss`/`.logits` are read below and
        # would fail on the plain tuple returned when return_dict=False.
        outputs = self.main(
            pixel_values,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        if self.patch_size is None:
            target_size = tuple(pixel_values.shape[-2:])
        else:
            target_size = (self.patch_size, self.patch_size)
        upsampled_logits = torch.nn.functional.interpolate(
            outputs.logits, size=target_size, mode="bilinear", align_corners=False
        )

        loss = outputs.loss  # None when labels is None
        # Optional extra term (the loss expects probabilities, not raw logits):
        # loss += 0.5 * self.dice_loss(upsampled_logits.softmax(dim=1), labels)
        return SemanticSegmenterOutput(
            loss=loss,
            logits=upsampled_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([3]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([3, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Weights & Biases integration of the HF Trainer: upload checkpoints as artifacts and
# watch both gradients and parameters.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_WATCH"] = "all"
# os.environ["WANDB_DISABLED"] = "true"

# Hugging Face token: never hardcode it in the notebook. Take it from the environment (or
# a saved `huggingface-cli login`). Assigning a placeholder here would overwrite a real
# token that is already configured.
if not os.environ.get("HF_TOKEN"):
    print("HF_TOKEN is not set; falling back to a saved `huggingface-cli login` token or anonymous Hub access.")

In [None]:
import warnings

# NOTE(review): blanket filter. It hides every warning (deprecations included) for the rest
# of the kernel session; consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

modelwr = ModelWrapper(model)
print(f"num_model_parameters:{sum(p.numel() for p in modelwr.parameters())}")

# One run directory for checkpoints and prediction dumps, suffixed for down-scaled data.
outputs_dir = "outputs_custom_stat" if rescale <= 1 else f"outputs_custom_stat_x{rescale}"

training_args = TrainingArguments(
    output_dir=os.path.join(outputs_dir, "checkpoints"),
    # Optimisation.
    learning_rate=6e-5,
    num_train_epochs=1000,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Evaluate/log every 1000 steps and checkpoint every 5000 (a multiple of eval_steps,
    # as load_best_model_at_end requires); keep the 5 most recent checkpoints.
    eval_strategy="steps",
    eval_steps=1000,
    eval_accumulation_steps=5,
    logging_steps=1000,
    save_strategy="steps",
    save_steps=5000,
    save_total_limit=5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_mean_iou",
    # The set_transform preprocessing reads the raw "image"/"annotation" columns, so they
    # must not be pruned.
    remove_unused_columns=False,
    # Options toggled during experiments:
    # log_level='debug',
    # use_cpu=True,
    # push_to_hub=True,
)

trainer = Trainer(
    model=modelwr,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
)

# The callback needs the trainer itself (it calls trainer.predict), so it is built after
# the trainer and attached via add_callback.
callback = SaveEvalPredictionsCallback(
    trainer=trainer,
    eval_dataset=valid_ds,
    output_dir=os.path.join(outputs_dir, "eval_predict"),
)
trainer.add_callback(callback)
# trainer.add_callback(transformers.integrations.WandbCallback())

trainer.train()

num_model_parameters:3714915


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33marslan-valeev-03[0m ([33mcourse_sr[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss
