In [1]:
from transformers import AutoImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
import requests

In [2]:
# Start from the SegFormer-B0 checkpoint fine-tuned on ADE (512x512) and adapt it:
# the input patch embedding is rebuilt for multi-channel aerial tiles and the
# classifier head for this task's label set. Those layers no longer match the
# checkpoint's shapes, so `ignore_mismatched_sizes=True` lets loading proceed;
# transformers then randomly initialises them (and logs a warning listing them),
# which means the model must be trained before its predictions are meaningful.
CHECKPOINT = 'nvidia/segformer-b0-finetuned-ade-512-512'
NUM_CHANNELS = 5   # channels per aerial input tile
NUM_LABELS = 13    # number of segmentation classes

model = SegformerForSemanticSegmentation.from_pretrained(
    CHECKPOINT,
    num_labels=NUM_LABELS,
    num_channels=NUM_CHANNELS,
    ignore_mismatched_sizes=True,
)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- segformer.encoder.patch_embeddings.0.proj.weight: found shape torch.Size([32, 3, 7, 7]) in the checkpoint and torch.Size([32, 5, 7, 7]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([13, 256, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([13]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
import os
import sys

from torch.utils.data import DataLoader

# Root of the flair-2 project checkout. Defined once (it was previously repeated)
# and overridable through the FLAIR2_ROOT environment variable so the notebook is
# not tied to a single machine; the default is the original author's path.
PROJECT_ROOT = os.environ.get('FLAIR2_ROOT', '/Users/louis/Projects/00 - RosIA/flair-2/')

# Make the project's `src` package importable. Guarded so re-running this cell
# does not keep appending duplicate entries to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from src.data.make_dataset import FLAIR2Dataset, get_list_images
from src.constants import get_constants

cst = get_constants()

path_data = os.path.join(PROJECT_ROOT, cst.path_data_train)
list_images = get_list_images(path_data)

# NOTE(review): semantics of the sen_* / prob_cover arguments are defined in
# src.data.make_dataset and are not visible here; values kept as-is.
dataset = FLAIR2Dataset(
    list_images=list_images,
    sen_size=40,
    sen_temp_size=3,
    sen_temp_reduc='median',
    sen_list_bands=['2', '3', '4', '5', '6', '7', '8', '8a', '11', '12'],
    prob_cover=10,
    is_test=False,
)

dataloader = DataLoader(
    dataset=dataset,
    batch_size=2,
    shuffle=False,
)

# Take a single batch for a quick forward-pass smoke test. Unlike `for ...: break`,
# this raises StopIteration on an empty loader instead of silently leaving stale
# values from a previous run in `aerial`, `labels`, etc.
image_id, aerial, sen, labels = next(iter(dataloader))

In [10]:
# Sanity check before the forward pass: the aerial batch must have exactly as many
# channels as the model's input patch embedding was configured for; a mismatch
# would otherwise surface later as an opaque convolution error.
assert aerial.shape[1] == model.config.num_channels, (
    f'aerial has {aerial.shape[1]} channels, model expects {model.config.num_channels}'
)
aerial.shape

torch.Size([2, 5, 512, 512])

In [16]:
from torch import nn

# Forward pass on the sample batch. Per the transformers Segformer docs, the
# decode head emits logits at a reduced spatial resolution, so they are resized
# back to the input's (H, W) to line up pixel-for-pixel with the labels.
# NOTE(review): autograd is enabled here; wrap in torch.no_grad() if this cell
# is meant to be inference-only.
outputs = model(pixel_values=aerial)
logits = outputs.logits

target_size = aerial.shape[-2:]
upsampled_logits = nn.functional.interpolate(
    input=logits,
    size=target_size,
    mode='bilinear',
    align_corners=False,
)

In [17]:
# Sanity check: one logit map per class, at the same spatial size as the input batch.
expected_shape = (aerial.shape[0], model.config.num_labels, *aerial.shape[-2:])
assert upsampled_logits.shape == expected_shape, (
    f'got {tuple(upsampled_logits.shape)}, expected {expected_shape}'
)
upsampled_logits.shape

torch.Size([2, 13, 512, 512])