In [1]:
import os
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import transformers

# Если используете albumentations, раскомментируйте импорт
# !pip install albumentations>=1.0.0 --quiet
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# --- Run configuration ------------------------------------------------------
# Only GPUs 4 and 6 are made visible to CUDA in this process.
os.environ.update(CUDA_VISIBLE_DEVICES="4,6")

# Scale factor of the data variant in use; it selects the "_x<rescale>" data
# directories and divides the patch size (1 = no rescaling).
rescale = 1
PATCH_SIZE = 256 // rescale

# Number of segmentation classes and samples per device batch.
NUM_LABELS, BATCH_SIZE = 3, 32

# Per-channel mean / std used by the albumentations Normalize step.
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]

  from .autonotebook import tqdm as notebook_tqdm
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.5 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations


In [2]:
from datasets import load_dataset
import logging

logging.basicConfig(level=logging.DEBUG)

# The training split is read from "./data_overlapped"; a "_x<rescale>" suffix
# is appended to the directory name whenever rescale differs from 1.
dataset = load_dataset(
    "./my_segmentation_dataset.py",
    data_dir="./data_overlapped" + (f"_x{rescale}" if rescale != 1 else ""),
    trust_remote_code=True,
)
train_ds = dataset["train"]

# The validation split is taken from the plain "./data" directory (same suffix
# rule) instead of dataset["validation"].
valid_ds = load_dataset(
    "./my_segmentation_dataset.py",
    data_dir="./data" + (f"_x{rescale}" if rescale != 1 else ""),
    trust_remote_code=True,
)["validation"]

# Class id <-> class name lookup tables.
id2label = dict(enumerate(("unlabeled", "forest0", "forest1")))
label2id = {name: class_id for class_id, name in id2label.items()}

INFO:datasets:PyTorch version 2.6.0+cu124 available.


In [3]:
def transforms(examples):
    """Augment and normalise a batch of image / mask pairs with albumentations.

    Every image in ``examples["image"]`` is paired with the mask at the same
    position in ``examples["annotation"]``; both are converted to numpy arrays
    and passed together through one pipeline (horizontal flip, vertical flip,
    90-degree rotation, each with p=0.5, then normalisation with the
    IMAGENET_DEFAULT_* statistics and ``ToTensorV2``).

    Args:
        examples: batch dict with "image" and "annotation" sequences.

    Returns:
        The same ``examples`` dict, extended in place with "pixel_values"
        (transformed images) and "label" (transformed masks).
    """
    augment = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        # Further augmentations (brightness/contrast, noise, ...) can be added here.
        A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ToTensorV2(),
    ])

    # One pipeline call per sample, in batch order.
    augmented = [
        augment(image=np.array(picture), mask=np.array(annotation))
        for picture, annotation in zip(examples["image"], examples["annotation"])
    ]

    examples["pixel_values"] = [item["image"] for item in augmented]
    examples["label"] = [item["mask"] for item in augmented]
    return examples

In [4]:
from transformers import AutoImageProcessor, MaskFormerFeatureExtractor

# Checkpoint used for both the image processor and the model.
# Alternatives that were tried earlier (SegFormer family):
# checkpoint = "nvidia/mit-b0"
# checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
# checkpoint = "sawthiha/segformer-b0-finetuned-deprem-satellite"
# checkpoint = "nvidia/segformer-b5-finetuned-ade-640-640"
checkpoint = "facebook/maskformer-swin-base-ade"

# image_processor = SegformerImageProcessor.from_pretrained(checkpoint, size={"height": PATCH_SIZE, "width": PATCH_SIZE}, do_reduce_labels=False)
# The processor is configured for PATCH_SIZE x PATCH_SIZE inputs; labels are
# passed through unchanged (do_reduce_labels=False).
image_processor = AutoImageProcessor.from_pretrained(
    checkpoint,
    size={"height": PATCH_SIZE, "width": PATCH_SIZE},
    do_reduce_labels=False,
    num_labels=NUM_LABELS,
    return_tensors="pt",
)
# feature_extractor = MaskFormerFeatureExtractor.from_pretrained(checkpoint)

# def preprocess_examples(examples):
#     # В датасете "scene_parse_150" поля называются "image" и "annotation".
#     # "image" — путь к изображению, "annotation" — путь к маске.
#     images = []
#     masks = []
#     for img, mask in zip(examples["image"], examples["annotation"]):
#         image = img.convert("RGB")
#         mask = mask.convert("L")  # Маска в градациях серого
#         images.append(image)
#         masks.append(mask)
    
#     # feature_extractor вернёт словарь с:
#     # {
#     #   "pixel_values": тензор [batch_size, 3, height, width],
#     #   "class_labels": (опционально, если задать labels) ...
#     # }
#     # Для семантической сегментации используется аргумент "masks".
#     encoded_inputs = feature_extractor(images, masks, return_tensors="pt")
    
#     # Преобразуем результат из batched формата datasets в обычный список Python
#     batch = {}
#     for k, v in encoded_inputs.items():
#         batch[k] = v
    
#     return batch

# train_ds = train_ds.map(preprocess_examples, batched=True, batch_size=BATCH_SIZE)
# valid_ds = valid_ds.map(preprocess_examples, batched=True, batch_size=BATCH_SIZE)

# columns_to_remove = set(train_ds.column_names) - {"pixel_values", "mask_labels"}
# train_ds = train_ds.remove_columns(list(columns_to_remove))
# valid_ds = valid_ds.remove_columns(list(columns_to_remove))

# train_ds.set_format("torch")
# valid_ds.set_format("torch")



from torchvision.transforms import ColorJitter

# Photometric augmentation applied to training images only (see train_transforms):
# random brightness / contrast / saturation / hue perturbation.
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

def train_transforms(example_batch):
    """Augment a batch of training samples and encode it with ``image_processor``.

    Each image is colour-jittered, then the image and its annotation mask go
    through the same random flips / 90-degree rotation so they stay aligned.
    The augmented pairs are what gets encoded, so the geometric augmentation
    actually reaches the model.

    Args:
        example_batch: batch dict with "image" and "annotation" sequences.

    Returns:
        Whatever ``image_processor(images, masks, return_tensors="pt")``
        returns for the augmented batch.
    """
    # Geometric augmentations only; encoding stays with image_processor, so no
    # Normalize / ToTensorV2 step here.
    geometric = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ])

    images, labels = [], []
    for image, seg_mask in zip(example_batch["image"], example_batch["annotation"]):
        augmented = geometric(image=np.array(jitter(image)), mask=np.array(seg_mask))
        # Flips / rotations may hand back strided views; make them contiguous
        # before they are passed on.
        images.append(np.ascontiguousarray(augmented["image"]))
        labels.append(np.ascontiguousarray(augmented["mask"]))

    inputs = image_processor(images, labels, return_tensors="pt")
    return inputs


def val_transforms(example_batch):
    """Encode a batch of validation samples without any augmentation.

    Args:
        example_batch: batch dict with "image" and "annotation" sequences.

    Returns:
        The result of ``image_processor(images, masks, return_tensors="pt")``.
    """
    images = list(example_batch["image"])
    labels = list(example_batch["annotation"])
    return image_processor(images, labels, return_tensors="pt")


# Attach the preprocessing functions to the datasets: training samples go
# through train_transforms, validation samples through val_transforms.
train_ds.set_transform(train_transforms)
valid_ds.set_transform(val_transforms)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
  return func(*args, **kwargs)


In [5]:
import evaluate

# Metric objects used by compute_metrics: mean IoU (with per-category
# accuracy / IoU) and F1.
metric = evaluate.load("mean_iou")
metric_f1 = evaluate.load("f1")

def compute_metrics(eval_pred, trainer, dataset, processor):
    """Compute mean-IoU, per-class accuracy / IoU and macro-F1 for a prediction run.

    NOTE(review): this function is unfinished and cannot run to completion as
    written. ``logits`` and ``labels`` are never assigned (the line unpacking
    them from ``eval_pred`` is commented out), so the ``np.argmax`` call raises
    ``NameError``. ``preds`` is computed but never used. The notebook currently
    passes ``compute_metrics=None`` to the Trainer, so this code is not executed.

    Args:
        eval_pred: evaluation payload; currently unused by the body.
        trainer: object whose ``predict(dataset)`` is called to obtain outputs.
        dataset: dataset handed to ``trainer.predict``.
        processor: object providing ``post_process_semantic_segmentation``.

    Returns:
        dict: the result of ``metric.compute`` with the per-category arrays
        expanded into ``accuracy_<label>`` / ``iou_<label>`` entries (names from
        ``id2label``) plus an ``f1_score`` entry.
    """
    with torch.no_grad():
        # print(eval_pred.inputs)
        # print(list(eval_pred.inputs.keys()))
        # # logits, labels = eval_pred.predictions, eval_pred.label_ids
        # print(len(logits))
        # print(logits[0].shape)
        # print(labels)
        # NOTE(review): this runs a full prediction pass over `dataset` on every call.
        outputs = trainer.predict(dataset)
        print(outputs)
        
        # NOTE(review): `preds` is only printed, never fed into the metrics below;
        # whether `outputs` is a valid argument for this post-processing call
        # needs to be confirmed.
        preds = processor.post_process_semantic_segmentation(outputs)
        print(preds.shape)
                                                           

        # NOTE(review): `logits` and `labels` are undefined at this point.
        pred_labels = np.argmax(logits, axis=1)
        metrics = metric.compute(
            predictions=pred_labels,
            references=labels,
            num_labels=NUM_LABELS,
            ignore_index=255,
            reduce_labels=False,
        )
        # F1 is computed over all flattened pixels; unlike the IoU metric above,
        # no ignore_index is applied here.
        metrics_f1 = metric_f1.compute(
            predictions=pred_labels.flatten(),
            references=labels.flatten(),
            average="macro",
            # ignore_index=255,
            # reduce_labels=True,
        )
        # Replace the per-category arrays by one scalar entry per class name.
        per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
        per_category_iou = metrics.pop("per_category_iou").tolist()
    
        metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
        metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
        metrics.update({"f1_score": metrics_f1["f1"]})
    
        return metrics

In [6]:
from transformers import TrainerCallback
import subprocess

class SaveEvalPredictionsCallback(TrainerCallback):
    """Trainer callback that saves evaluation predictions as PNG images.

    After every evaluation, predictions for ``eval_dataset`` are computed and
    one image per sample is written to ``<output_dir>/step_<global_step>/``.
    Each image is a horizontal strip: prediction | annotation | input image.

    Args:
        trainer: trainer whose ``predict`` method is used (passed explicitly).
        eval_dataset: dataset to predict on. With its format reset, items must
            provide "image", "annotation" and "filename".
        output_dir: root directory for the saved images (created if missing).
        num_labels: number of classes; used to stretch class ids to grey levels.
        val_transforms: transform re-attached to ``eval_dataset`` once the
            images have been written.
    """

    def __init__(self, trainer, eval_dataset, output_dir, num_labels=3, val_transforms=val_transforms):
        super().__init__()
        self.trainer = trainer  # the trainer object is passed in explicitly
        self.eval_dataset = eval_dataset
        self.output_dir = output_dir
        self.num_labels = num_labels
        # Keep the transform that was handed in, so on_evaluate restores this
        # one rather than whatever the global name points to at that time.
        self.val_transforms = val_transforms
        os.makedirs(self.output_dir, exist_ok=True)

    def upvalue(self, img):
        """Scale class ids so that ``num_labels - 1`` maps to 255; returns uint8."""
        return ((img / (self.num_labels - 1)) * 255).astype(np.uint8)

    def on_evaluate(self, args, state, control, **kwargs):
        """Predict on ``eval_dataset`` and save one comparison PNG per sample."""
        pred_res = self.trainer.predict(self.eval_dataset)
        # argmax over axis 1 (the class axis) -> [N, H, W], then stretch to grey levels
        pred_labels = self.upvalue(np.argmax(pred_res.predictions, axis=1))

        # One sub-directory per training step
        step_folder = os.path.join(self.output_dir, f"step_{state.global_step}")
        os.makedirs(step_folder, exist_ok=True)

        try:
            # Raw (untransformed) items are needed for the annotation and image panels.
            self.eval_dataset.reset_format()
            for i, label_map in enumerate(pred_labels):
                sample = self.eval_dataset[i]  # fetch the item once per sample
                prediction_rgb = np.repeat(label_map[:, :, None], 3, -1)
                annotation_rgb = self.upvalue(
                    np.repeat(np.array(sample["annotation"])[:, :, None], 3, -1)
                )
                strip = np.concatenate(
                    (prediction_rgb, annotation_rgb, np.array(sample["image"])), axis=1
                )
                # Drop only the extension, so names containing extra dots stay distinct.
                stem = os.path.splitext(sample["filename"])[0]
                Image.fromarray(strip).save(os.path.join(step_folder, f"pred_{stem}.png"))
        finally:
            # Re-attach the transform whatever happened above; otherwise later
            # evaluation passes would receive raw items.
            self.eval_dataset.set_transform(self.val_transforms)

        print(f"Saved evaluation predictions to: {step_folder}")

class ReconstructionCallback(TrainerCallback):
    """
    Callback that runs an external script after a checkpoint has been saved
    (image reconstruction, post-processing, IoU measurement and the like).
    """
    def on_save(self, args, state, control, **kwargs):
        """
        Called right after the model has been saved.

        The script is launched for the step that precedes the current one by
        ``args.save_steps``; nothing is run while that step is not positive.

        Parameters:
        -----------
        args  : TrainingArguments
        state : TrainerState
        control : TrainerControl

        Returns:
        --------
        The unchanged ``control`` object.

        Raises:
        -------
        RuntimeError
            If the script exits with a non-zero return code.
        """
        # Script to launch
        script_path = "reconstructionv3.py"

        # The script receives the parent of the checkpoint directory and the
        # step to reconstruct.
        checkpoint_dir = args.output_dir.rsplit('/', 1)[0]
        previous_step = state.global_step - args.save_steps

        # No earlier checkpoint to work on yet (first save, or a save that
        # happens before a full save interval has elapsed).
        if previous_step <= 0:
            return control
        step_str = str(previous_step)

        # The child inherits this process' environment, including
        # CUDA_VISIBLE_DEVICES.
        cmd = [
            "python3", script_path,
            "--model", "MaskFormer",
            "--out_dir", checkpoint_dir,
            "--step", step_str,
            "--rescale", str(rescale),
            "--format", "png"
        ]
        # run() waits for the child while draining its output; polling without
        # reading would spin a CPU core and block forever once the pipe fills.
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )
        if result.returncode != 0:
            print('Command exited with non zero return code')
            print(f'Return code = {result.returncode}')
            if result.stdout:
                # Show the script's own output so the failure can be diagnosed.
                print(result.stdout)
            raise RuntimeError(
                f"{script_path} exited with return code {result.returncode}"
            )
        print(f"[INFO] Reconstruction script {script_path} завершился успешно.")

        return control

In [7]:
from transformers import (
    AutoModelForSemanticSegmentation,
    TrainingArguments,
    Trainer
)
from transformers.modeling_outputs import SemanticSegmenterOutput
from typing import Optional, Union, Tuple

# model = SegformerForSemanticSegmentation.from_pretrained(checkpoint, num_labels=3, ignore_mismatched_sizes=True)

class DiceLoss(nn.Module):
    """Multi-class soft Dice loss, averaged over classes.

    For each class ``c`` the Dice coefficient is
    ``(2 * sum(p_c * t_c) + smooth) / (sum(p_c) + sum(t_c) + smooth)``,
    where ``p_c`` are the predicted scores and ``t_c`` the one-hot targets,
    summed over the whole batch. The loss is ``1 - mean_c(dice_c)``.

    Args:
        smooth: additive smoothing term that keeps the ratio defined when a
            class is absent from both prediction and target.
    """

    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, outputs, targets):
        """
        Args:
            outputs (torch.Tensor): model predictions of shape (N, C, H, W).
                The formula treats them as per-class scores, so they are
                expected to be probabilities (e.g. after Softmax), not
                log-probabilities.
            targets (torch.Tensor): ground-truth class ids of shape (N, H, W)
                or (N, 1, H, W), with values in [0, C-1].

        Returns:
            torch.Tensor: scalar Dice loss.
        """
        # Convert targets to one-hot along the class axis.
        if targets.ndim == 3:
            targets = targets.unsqueeze(1)
        # scatter_ requires an int64 index tensor.
        targets_one_hot = torch.zeros_like(outputs).scatter_(1, targets.long(), 1)

        # Per-class sums over batch and spatial dimensions.
        reduce_dims = (0, 2, 3)
        intersection = (outputs * targets_one_hot).sum(dim=reduce_dims)
        cardinality = outputs.sum(dim=reduce_dims) + targets_one_hot.sum(dim=reduce_dims)

        # Dice uses 2*|A∩B| / (|A| + |B|); dividing the plain intersection by
        # the union would give the Jaccard index instead.
        dice_per_class = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)

        return 1 - dice_per_class.mean()

class ModelWrapper(torch.nn.Module):
    """Wrap a segmentation model so its logits are returned at patch resolution.

    The wrapped model's logits are bilinearly resized to
    ``(patch_size, patch_size)`` and returned, together with the wrapped
    model's loss, as a ``SemanticSegmenterOutput``.

    Args:
        segformer: the model to wrap; called with ``pixel_values`` and
            ``labels`` and expected to return an object with ``loss`` and
            ``logits`` attributes.
        patch_size: target height and width of the returned logits.
        accepts_loss_kwargs: stored as an attribute on the wrapper
            (presumably read by the Trainer — TODO confirm).
    """

    def __init__(self, segformer, patch_size=PATCH_SIZE, accepts_loss_kwargs=False):
        super().__init__()
        self.main = segformer
        self.accepts_loss_kwargs = accepts_loss_kwargs
        self.patch_size = patch_size
        self.dice_loss = DiceLoss()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        """Run the wrapped model and upsample its logits.

        ``return_dict`` is accepted for interface compatibility, but the
        wrapped model is always asked for a dict-style output because ``loss``
        and ``logits`` are read by attribute below.

        Returns:
            SemanticSegmenterOutput with the wrapped model's loss and logits of
            spatial size ``(patch_size, patch_size)``.
        """
        outputs = self.main(
            pixel_values,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        loss = outputs.loss
        upsampled_val_logits = torch.nn.functional.interpolate(
            outputs.logits,
            size=(self.patch_size, self.patch_size),
            mode="bilinear",
            align_corners=False,
        )
        # Optional auxiliary Dice term (disabled):
        # loss += 0.2 * self.dice_loss(upsampled_val_logits, labels)
        return SemanticSegmenterOutput(loss=loss, logits=upsampled_val_logits)

In [None]:
# Weights & Biases integration settings, read by the Trainer's W&B reporting.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_WATCH"] = "all"
# os.environ["WANDB_DISABLED"] = "true"
# os.environ["WANDB_RESUME"] = "must"
# os.environ["WANDB_RUN_ID"] = "h4h4s471"

# The Hugging Face token is a credential: it must come from the environment
# (e.g. `export HF_TOKEN=...` before starting Jupyter, or `huggingface-cli login`)
# and must never be written into the notebook. Assigning a literal here would
# also overwrite a valid token that is already set.
if not os.environ.get("HF_TOKEN"):
    print("[WARN] HF_TOKEN is not set; only public Hugging Face Hub resources will be accessible.")

In [None]:
import warnings
import functools
# Silence all Python warnings from here on.
warnings.filterwarnings("ignore")

# modelwr = ModelWrapper(model)
# MaskFormer is loaded with a head for NUM_LABELS classes; checkpoint weights
# whose shapes do not match are re-initialised instead of raising.
modelwr = transformers.MaskFormerForInstanceSegmentation.from_pretrained(
    checkpoint,
    num_labels=NUM_LABELS,
    ignore_mismatched_sizes=True,
)
print(f"num_model_parameters:{sum(weight.numel() for weight in modelwr.parameters())}")

# Root directory for checkpoints; tagged with the scale factor when rescale > 1.
outputs_dir = "outputs_MaskFormer_wo_postproc" + (f"_x{rescale}" if rescale > 1 else "")


def custom_collate_fn(batch):
    """Collate a list of per-sample dicts into one batch dict.

    "pixel_values" and "pixel_mask" are stacked along a new leading dimension;
    "mask_labels" and "class_labels" are kept as plain Python lists with one
    entry per sample.

    Args:
        batch: list of dicts, each with the four keys named above.

    Returns:
        dict with keys "pixel_values", "pixel_mask", "mask_labels" and
        "class_labels".
    """
    def column(key):
        # Gather one field across all samples, preserving batch order.
        return [sample[key] for sample in batch]

    return {
        "pixel_values": torch.stack(column("pixel_values"), dim=0),
        "pixel_mask": torch.stack(column("pixel_mask"), dim=0),
        "mask_labels": column("mask_labels"),
        "class_labels": column("class_labels"),
    }

# Training configuration: step-based checkpointing and logging. Evaluation is
# currently switched off (all eval-related options are commented out).
training_args = TrainingArguments(
    output_dir=os.path.join(outputs_dir, "checkpoints"),
    learning_rate=6e-5,
    num_train_epochs=1000,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    save_total_limit=5,
    # eval_strategy="steps",
    save_strategy="steps",
    # ReconstructionCallback reads args.save_steps to work out which step to process.
    save_steps=1000,
    # eval_steps=10,
    logging_steps=100,
    # eval_accumulation_steps=2,
    # The dataset transforms read the "image" / "annotation" columns, so the
    # Trainer must not drop columns that are not model inputs.
    remove_unused_columns=False,
    # load_best_model_at_end=True,
    # metric_for_best_model = "eval_mean_iou",
    # log_level='debug',
    # use_cpu=True,
    # push_to_hub=True,
)

# Trainer without an evaluation dataset or metrics; batches are assembled by
# custom_collate_fn.
trainer = Trainer(
    model=modelwr,
    args=training_args,
    train_dataset=train_ds,
    # eval_dataset=valid_ds,
    compute_metrics=None,
    data_collator=custom_collate_fn,
)

# callback = SaveEvalPredictionsCallback(
#     trainer=trainer,
#     eval_dataset=valid_ds,
#     output_dir=os.path.join(outputs_dir, "eval_predict"),
# )

rec_callback = ReconstructionCallback()

# Callbacks are added after the trainer has been created
# trainer.add_callback(callback)
trainer.add_callback(rec_callback)
# trainer.add_callback(transformers.integrations.WandbCallback())
trainer.train()

Some weights of MaskFormerForInstanceSegmentation were not initialized from the model checkpoint at facebook/maskformer-swin-base-ade and are newly initialized because the shapes did not match:
- class_predictor.weight: found shape torch.Size([151, 256]) in the checkpoint and torch.Size([4, 256]) in the model instantiated
- class_predictor.bias: found shape torch.Size([151]) in the checkpoint and torch.Size([4]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([151]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


num_model_parameters:101793660


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33marslan-valeev-03[0m ([33mcourse_sr[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
100,0.6982
200,0.5992
300,0.5462
400,0.5086
500,0.4705
600,0.4593
700,0.3985
800,0.4316
900,0.3914
1000,0.3681


[34m[1mwandb[0m: Adding directory to artifact (./outputs_MaskFormer_wo_postproc/checkpoints/checkpoint-1000)... Done. 2.2s
[34m[1mwandb[0m: Adding directory to artifact (./outputs_MaskFormer_wo_postproc/checkpoints/checkpoint-2000)... Done. 2.1s


[INFO] Reconstruction script reconstructionv3.py завершился успешно.


[34m[1mwandb[0m: Adding directory to artifact (./outputs_MaskFormer_wo_postproc/checkpoints/checkpoint-3000)... Done. 2.1s


[INFO] Reconstruction script reconstructionv3.py завершился успешно.


[34m[1mwandb[0m: Adding directory to artifact (./outputs_MaskFormer_wo_postproc/checkpoints/checkpoint-4000)... Done. 2.0s


[INFO] Reconstruction script reconstructionv3.py завершился успешно.


[34m[1mwandb[0m: Adding directory to artifact (./outputs_MaskFormer_wo_postproc/checkpoints/checkpoint-5000)... Done. 2.2s


[INFO] Reconstruction script reconstructionv3.py завершился успешно.


[34m[1mwandb[0m: Adding directory to artifact (./outputs_MaskFormer_wo_postproc/checkpoints/checkpoint-6000)... Done. 2.1s


[INFO] Reconstruction script reconstructionv3.py завершился успешно.


[34m[1mwandb[0m: Adding directory to artifact (./outputs_MaskFormer_wo_postproc/checkpoints/checkpoint-7000)... Done. 2.1s


[INFO] Reconstruction script reconstructionv3.py завершился успешно.


[34m[1mwandb[0m: Adding directory to artifact (./outputs_MaskFormer_wo_postproc/checkpoints/checkpoint-8000)... Done. 2.1s


[INFO] Reconstruction script reconstructionv3.py завершился успешно.
