In [1]:
import os
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformers

# Если используете albumentations, раскомментируйте импорт
# !pip install albumentations>=1.0.0 --quiet
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

rescale = 1
os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
CUSTOM_MEAN = [0.1778, 0.2696, 0.1686]
CUSTOM_STD = [0.0942, 0.0915, 0.0762]

PATCH_SIZE = 256 // rescale
NUM_LABELS = 3

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.5 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations


In [2]:
from datasets import load_dataset
import logging

logging.basicConfig(level=logging.DEBUG)

dataset = load_dataset("./my_segmentation_dataset.py", data_dir="./data_overlapped" + f"_x{rescale}" * (rescale != 1), trust_remote_code=True)
train_ds = dataset["train"]
valid_ds = load_dataset("./my_segmentation_dataset.py", data_dir="./data" + f"_x{rescale}" * (rescale != 1), trust_remote_code=True)["validation"]
# valid_ds = dataset

id2label = {0: "unlabeled", 1: "forest0", 2: "forest1"}
label2id = {v: k for k, v in id2label.items()}

INFO:datasets:PyTorch version 2.6.0+cu124 available.


In [3]:
def transforms(examples):
    transformA = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        
        # При необходимости можно добавить и другие аугментации, 
        # например, яркость/контраст, шум и т.д.
        
        # Простейшая нормализация, если нужно
        # A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),  
        # ToTensorV2()
    ])
    
    transformed_images, transformed_masks = [], []
    for image, seg_mask in zip(examples["image"], examples["annotation"]):
        image, seg_mask = np.array(image), np.array(seg_mask)
        transformed = transformA(image=image, mask=seg_mask)
        transformed_images.append(transformed["image"])
        transformed_masks.append(transformed["mask"])
    examples["pixel_values"] = transformed_images
    examples["label"] = transformed_masks
    return examples

In [4]:
import evaluate

metric = evaluate.load("mean_iou")
metric_f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    with torch.no_grad():
        # print(eval_pred.losses)
        logits, labels = eval_pred.predictions, eval_pred.label_ids

        pred_labels = np.argmax(logits, axis=1)
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=NUM_LABELS,
            ignore_index=255,
            reduce_labels=False,
        )
        metrics_f1 = metric_f1.compute(
            predictions=pred_labels.flatten(),
            references=labels.flatten(),
            average="macro",
            # ignore_index=255,
            # reduce_labels=True,
        )
        per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
        per_category_iou = metrics.pop("per_category_iou").tolist()
    
        metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
        metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
        metrics.update({"f1_score": metrics_f1["f1"]})
    
        return metrics

In [5]:
from transformers import AutoImageProcessor, MobileViTImageProcessor

# checkpoint = "Intel/dpt-large-ade"
checkpoint = "apple/deeplabv3-mobilevit-small"
image_processor = MobileViTImageProcessor(do_resize=False, size={"height": PATCH_SIZE, "width": PATCH_SIZE},
                                                    do_reduce_labels=False, image_mean=CUSTOM_MEAN, image_std=CUSTOM_STD)

from torchvision.transforms import ColorJitter

jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

def train_transforms(example_batch):
    example_batch = transforms(example_batch)
    images = [jitter(x) for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor.preprocess(images, labels, do_resize=False)
    return inputs


def val_transforms(example_batch):
    images = [x for x in example_batch["image"]]
    labels = [x for x in example_batch["annotation"]]
    inputs = image_processor(images, labels)
    return inputs


train_ds.set_transform(train_transforms)
valid_ds.set_transform(val_transforms)

In [6]:
from transformers import TrainerCallback

class SaveEvalPredictionsCallback(TrainerCallback):
    def __init__(self, trainer, eval_dataset, output_dir, num_labels=3, val_transforms=val_transforms):
        super().__init__()
        self.trainer = trainer  # явно передаём объект trainer
        self.eval_dataset = eval_dataset
        self.output_dir = output_dir
        self.num_labels = num_labels
        os.makedirs(self.output_dir, exist_ok=True)

    def upvalue(self, img):
        return ((img / (self.num_labels - 1)) * 255).astype(np.uint8)

    def on_evaluate(self, args, state, control, **kwargs):
        # Выполняем предсказания на eval_dataset с помощью полученного trainer
        pred_res = self.trainer.predict(self.eval_dataset)
        pred_labels = np.argmax(pred_res.predictions, axis=1)  # [N, H, W]
        # upvalue
        pred_labels = self.upvalue(pred_labels)
    
        # Создаём поддиректорию в зависимости от шага
        step_folder = os.path.join(self.output_dir, f"step_{state.global_step}")
        os.makedirs(step_folder, exist_ok=True)
    
        # Сохраняем каждую маску как png
        try:
            self.eval_dataset.reset_format()
            for i, label_map in enumerate(pred_labels):
                label_img = Image.fromarray(np.concatenate((np.repeat(label_map[:, :, None], 3, -1), self.upvalue(np.repeat(np.array(self.eval_dataset[i]["annotation"])[:, :, None], 3, -1)), self.eval_dataset[i]["image"]), axis=1))
                label_img.save(os.path.join(step_folder, f"pred_{self.eval_dataset[i]['filename'].rsplit('.')[0]}.png"))
        except RuntimeError as e:
            self.eval_dataset.set_transform(val_transforms)
            raise e
        self.eval_dataset.set_transform(val_transforms)

        print(f"Saved evaluation predictions to: {step_folder}")

In [7]:
from transformers import (
    AutoModelForSemanticSegmentation,
    TrainingArguments,
    Trainer
)
from transformers.modeling_outputs import SemanticSegmenterOutput
from typing import Optional, Union, Tuple

model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, num_labels=3, ignore_mismatched_sizes=True)

class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, outputs, targets):
        """
        Args:
            outputs (torch.Tensor): Прогнозы модели размерности (N, C, H, W), от LogSoftmax или Softmax.
            targets (torch.Tensor): Истинные метки размерности (N, H, W) или (N, 1, H, W) с классами в [0, C-1].

        Returns:
            torch.Tensor: Значение Dice Loss.
        """
        num_classes = outputs.size(1)

        # Конвертируем targets в one-hot формат, если необходимо
        if targets.ndim == 3:
            targets = targets.unsqueeze(1)
        targets_one_hot = torch.zeros_like(outputs).scatter_(1, targets, 1)

        # Вычисляем Dice Loss для каждого класса
        dice_loss = 0.0
        for c in range(num_classes):
            output_c = outputs[:, c, :, :]
            target_c = targets_one_hot[:, c, :, :]
            
            intersection = (output_c * target_c).sum()
            union = output_c.sum() + target_c.sum() - intersection

            dice_loss += (intersection + self.smooth) / (union + self.smooth)

        return 1 - dice_loss / num_classes

class ModelWrapper(torch.nn.Module):
    def __init__(self, segmodel, patch_size=PATCH_SIZE, accepts_loss_kwargs=False):
        super().__init__()
        self.main = segmodel
        self.accepts_loss_kwargs = accepts_loss_kwargs
        self.patch_size = patch_size
        self.dice_loss = DiceLoss()

    # def forward(self, *args, **kwargs):
    #     outputs = self.main(*args, **kwargs)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        outputs = self.main(pixel_values, labels=labels,
                            output_hidden_states=output_hidden_states, return_dict=return_dict)
        loss, logits_val = outputs.loss, outputs.logits
        upsampled_val_logits = torch.nn.functional.interpolate(
        logits_val, size=(self.patch_size, self.patch_size), mode="bilinear", align_corners=False
        )
        # loss += 0.5 * self.dice_loss(upsampled_val_logits, labels)
        return SemanticSegmenterOutput(loss=loss, logits=upsampled_val_logits)

Some weights of MobileViTForSemanticSegmentation were not initialized from the model checkpoint at apple/deeplabv3-mobilevit-small and are newly initialized because the shapes did not match:
- segmentation_head.classifier.convolution.weight: found shape torch.Size([21, 256, 1, 1]) in the checkpoint and torch.Size([3, 256, 1, 1]) in the model instantiated
- segmentation_head.classifier.convolution.bias: found shape torch.Size([21]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_WATCH"] = "all"
# os.environ["WANDB_DISABLED"] = "true"
os.environ["HF_TOKEN"] = "<your-token>"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
import warnings
warnings.filterwarnings("ignore")

modelwr = ModelWrapper(model)
# modelwr = model
print(f"num_model_parameters:{sum(par.numel() for par in modelwr.parameters())}")
outputs_dir = "outputs_DeepLab"
if rescale > 1:
    outputs_dir += f"_x{rescale}"

training_args = TrainingArguments(
    output_dir=os.path.join(outputs_dir, "checkpoints"),
    learning_rate=6e-5,
    num_train_epochs=1000,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    save_total_limit=5,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=1000,
    eval_steps=100,
    logging_steps=100,
    eval_accumulation_steps=5,
    remove_unused_columns=False,
    load_best_model_at_end=True,
    metric_for_best_model = "eval_mean_iou",
    # log_level='debug',
    # use_cpu=True,
    # push_to_hub=True,
)

trainer = Trainer(
    model=modelwr,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    compute_metrics=compute_metrics,
)
callback = SaveEvalPredictionsCallback(
    trainer=trainer,
    eval_dataset=valid_ds,
    output_dir=os.path.join(outputs_dir, "eval_predict"),
)

# добавляем callback после инициализации trainer'а
trainer.add_callback(callback)
# trainer.add_callback(transformers.integrations.WandbCallback())
trainer.train()

num_model_parameters:6353315


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33marslan-valeev-03[0m ([33mcourse_sr[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Mean Iou,Mean Accuracy,Overall Accuracy,Accuracy Unlabeled,Accuracy Forest0,Accuracy Forest1,Iou Unlabeled,Iou Forest0,Iou Forest1,F1 Score
100,0.9899,0.872477,0.463818,0.576046,0.76495,0.177301,0.91298,0.637858,0.176027,0.728706,0.486722,0.599061
200,0.8109,0.716575,0.613858,0.729924,0.823626,0.480523,0.88694,0.822309,0.457538,0.775078,0.608959,0.752691
300,0.6635,0.583689,0.702051,0.817757,0.859417,0.705033,0.887206,0.861032,0.638516,0.812177,0.655461,0.822539
400,0.5578,0.487352,0.721595,0.828131,0.871882,0.751333,0.909214,0.823847,0.658634,0.827715,0.678436,0.836113
500,0.4849,0.425505,0.746834,0.856509,0.881486,0.805293,0.901354,0.862879,0.70629,0.837154,0.697058,0.853572
600,0.4393,0.38647,0.759118,0.85105,0.889799,0.806965,0.927553,0.818634,0.72957,0.850346,0.697436,0.861506
700,0.3931,0.349917,0.774001,0.870977,0.89494,0.850934,0.919703,0.842293,0.758704,0.855437,0.707863,0.871277
800,0.3621,0.337404,0.76937,0.84926,0.893433,0.823419,0.941254,0.783109,0.774266,0.856785,0.677059,0.867693
900,0.3331,0.312665,0.78581,0.865557,0.899536,0.837448,0.934711,0.824512,0.799701,0.863351,0.694378,0.878332
1000,0.3161,0.284459,0.801515,0.880479,0.907977,0.856134,0.936127,0.849176,0.808311,0.873428,0.722805,0.888512


Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_100
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_200
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_300
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_400
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_500
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_600
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_700
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_800
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_900
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1000


[34m[1mwandb[0m: Adding directory to artifact (./outputs_DeepLab/checkpoints/checkpoint-1000)... Done. 0.7s


Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1100
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1200
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1300
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1400
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1500
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1600
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1700
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1800
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_1900
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2000


[34m[1mwandb[0m: Adding directory to artifact (./outputs_DeepLab/checkpoints/checkpoint-2000)... Done. 0.1s


Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2100
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2200
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2300
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2400
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2500
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2600
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2700
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2800
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_2900
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3000


[34m[1mwandb[0m: Adding directory to artifact (./outputs_DeepLab/checkpoints/checkpoint-3000)... Done. 0.1s


Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3100
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3200
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3300
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3400
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3500
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3600
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3700
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3800
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_3900
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4000


[34m[1mwandb[0m: Adding directory to artifact (./outputs_DeepLab/checkpoints/checkpoint-4000)... Done. 0.1s


Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4100
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4200
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4300
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4400
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4500
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4600
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4700
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4800
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_4900
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_5100
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_5200
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_5300
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_5400
Saved evaluation predictions to: outputs_DeepLab/eval_predict/st

[34m[1mwandb[0m: Adding directory to artifact (./outputs_DeepLab/checkpoints/checkpoint-6000)... Done. 0.1s


Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_6100
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_6200
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_6300
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_6400
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_6500
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_6600
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_6700
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_6800
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_6900
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_7000


[34m[1mwandb[0m: Adding directory to artifact (./outputs_DeepLab/checkpoints/checkpoint-7000)... Done. 0.1s


Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_7100
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_7200
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_7300
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_7400
Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_7500


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Saved evaluation predictions to: outputs_DeepLab/eval_predict/step_12900
