In [1]:
import os
import sys

from evaluate import load
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import MaskFormerForInstanceSegmentation, MaskFormerImageProcessor

sys.path.append('..')

from training.dataset import SemanticSegmentationDataset
from training.trainer import SegmenterModeltrainer

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Class names in label-id order: position i is the name of class id i.
idx2label = ['Barren', 'Forest', 'Agriculture', 'Road', 'Building', 'Water']

# Hugging Face configs document `id2label` as an int -> str dict and `label2id`
# as its inverse; pass real dicts so the stored config has the documented types.
id2label = dict(enumerate(idx2label))
label2id = {name: idx for idx, name in id2label.items()}

# NOTE(review): reduce_labels=True presumably relies on mask value 0 meaning
# "no class" (it is remapped to ignore_index=255) — confirm against the masks.
# NOTE(review): resize / rescale / normalize are all disabled here, so the
# dataset is presumably responsible for any scaling the backbone needs — confirm.
processor = MaskFormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=False,
    do_normalize=False,
)

# ignore_mismatched_sizes=True lets the checkpoint load even though its
# classification head was trained for a different number of classes; the
# mismatching head weights are freshly initialised and must be trained.
model = MaskFormerForInstanceSegmentation.from_pretrained(
    "facebook/maskformer-swin-base-ade",
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

Some weights of MaskFormerForInstanceSegmentation were not initialized from the model checkpoint at facebook/maskformer-swin-base-ade and are newly initialized because the shapes did not match:
- class_predictor.weight: found shape torch.Size([151, 256]) in the checkpoint and torch.Size([7, 256]) in the model instantiated
- class_predictor.bias: found shape torch.Size([151]) in the checkpoint and torch.Size([7]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([151]) in the checkpoint and torch.Size([7]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Root of the TerraLabel data folder. Set the TERRALABEL_DATA_DIR environment
# variable to run this notebook on another machine; the default keeps the
# original location so existing runs behave the same.
data_root = os.environ.get('TERRALABEL_DATA_DIR', '/home/a-ploskin/repos/TerraLabel/data')
img_dir = os.path.join(data_root, 'task_0', 'data')
masks_dir = os.path.join(data_root, 'masks')

# Build the train / eval datasets from the image and mask folders.
train_ds, eval_ds = SemanticSegmentationDataset.get_train_and_eval_datasets(
    processor, img_dir, masks_dir
)

In [4]:
import torch
import torch.nn.functional as F

def custom_collate_fn(batch):
    """Collate variable-sized segmentation samples into one padded batch.

    Every sample is padded on the bottom/right up to the largest height and
    width found in ``batch`` (sizes are read from each item's ``pixel_values``).

    Two item layouts are supported:

    * Items carrying a ``"labels"`` map: labels are padded with 255 (the ignore
      index) and stacked. The result is ``{"pixel_values", "labels"}``.
    * Items without ``"labels"`` but with ``"mask_labels"`` and
      ``"class_labels"``: masks are zero-padded spatially and returned as a
      list (one tensor per sample), class labels are returned as a list, and a
      ``"pixel_mask"`` (1 = real pixel, 0 = padding) is stacked alongside the
      images. NOTE(review): this layout is assumed to be what the dataset
      yields when driven by a MaskFormer processor — confirm against the
      dataset and the trainer.

    Args:
        batch: Non-empty sequence of mapping-like samples, each with a
            ``"pixel_values"`` tensor whose last two dims are (height, width).

    Returns:
        dict with the batched tensors described above.

    Raises:
        ValueError: if ``batch`` is empty.
        KeyError: if the items match neither supported layout.
    """
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    # Target spatial size: the largest height / width in the batch.
    max_height = max(item["pixel_values"].shape[-2] for item in batch)
    max_width = max(item["pixel_values"].shape[-1] for item in batch)

    def _pad(tensor, item, fill=0):
        # Pad bottom/right by the amount this item's image is short of the target.
        pad_height = max_height - item["pixel_values"].shape[-2]
        pad_width = max_width - item["pixel_values"].shape[-1]
        return F.pad(tensor, (0, pad_width, 0, pad_height), value=fill)

    collated = {
        "pixel_values": torch.stack([_pad(item["pixel_values"], item) for item in batch])
    }

    # Layout 1: dense label maps; padded area is marked with the ignore index.
    if all("labels" in item for item in batch):
        collated["labels"] = torch.stack(
            [_pad(item["labels"], item, fill=255) for item in batch]
        )
        return collated

    # Layout 2: per-instance binary masks plus their class ids.
    if all("mask_labels" in item and "class_labels" in item for item in batch):
        pixel_masks = []
        for item in batch:
            if "pixel_mask" in item:
                pixel_mask = item["pixel_mask"]
            else:
                # No mask supplied: every original pixel counts as real.
                pixel_mask = torch.ones(
                    tuple(item["pixel_values"].shape[-2:]), dtype=torch.long
                )
            pixel_masks.append(_pad(pixel_mask, item))
        collated["pixel_mask"] = torch.stack(pixel_masks)
        # The number of masks can differ per sample, so these stay as lists.
        collated["mask_labels"] = [_pad(item["mask_labels"], item) for item in batch]
        collated["class_labels"] = [item["class_labels"] for item in batch]
        return collated

    raise KeyError(
        "batch items need either 'labels' or both 'mask_labels' and "
        f"'class_labels'; first item has keys {sorted(batch[0].keys())}"
    )


In [5]:
BATCH_SIZE = 2

# Training: reshuffle every epoch and drop the trailing partial batch.
train_dataloader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    collate_fn=custom_collate_fn,
)

# Evaluation: keep every sample and a fixed order, so the metric covers the
# whole eval set and is comparable between epochs.
eval_dataloader = DataLoader(
    eval_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    collate_fn=custom_collate_fn,
)

In [6]:
# Directory that receives the TensorBoard event files for this run.
log_path = './logs'
writer = SummaryWriter(log_dir=log_path)

# Metric object loaded by name through `evaluate`.
metric = load("mean_iou")

In [7]:
from datetime import datetime


# Prefer the second GPU (the original setting), but fall back to whatever is
# available so the cell does not fail on hosts with fewer than two GPUs.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

trainer = SegmenterModeltrainer(
    model=model,
    device=device,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    metric=metric,
    writer=writer
)

# str(datetime.now()) contains a space and colons, which are awkward in shells
# and invalid in Windows paths; use a compact, filesystem-safe timestamp.
run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
models_dir = './models'
# Make sure the output folder exists before a long training run tries to save.
os.makedirs(models_dir, exist_ok=True)
file_path = f'{models_dir}/maskformer_{run_stamp}'

segmenter = trainer.train(file_path=file_path, n_epochs=50)

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 0


  0%|          | 0/50 [00:00<?, ?it/s]


KeyError: 'labels'