In [9]:
import requests, zipfile, io
from datasets import load_dataset

from torch.utils.data import Dataset
import os
from PIL import Image

class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, root_dir, feature_extractor, train=True):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            feature_extractor (SegFormerFeatureExtractor): feature extractor to prepare images + segmentation maps.
            train (bool): Whether to load "training" or "validation" images + annotations.
        """
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.train = train

        sub_path = "train" if self.train else "validation"
        self.img_dir = os.path.join(self.root_dir, sub_path, "rgb")
        self.ann_dir = os.path.join(self.root_dir, sub_path, "labels")
        
        # read images
        image_file_names = []
        for root, dirs, files in os.walk(self.img_dir):
          image_file_names.extend(files)
        self.images = sorted(image_file_names)
        
        # read annotations
        annotation_file_names = []
        for root, dirs, files in os.walk(self.ann_dir):
          annotation_file_names.extend(files)
        self.annotations = sorted(annotation_file_names)

        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        
        image = Image.open(os.path.join(self.img_dir, self.images[idx]))
        segmentation_map = Image.open(os.path.join(self.ann_dir, self.annotations[idx]))

        # randomly crop + pad both image and segmentation map to same size
        encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")

        for k,v in encoded_inputs.items():
          encoded_inputs[k].squeeze_() # remove batch dimension

        return encoded_inputs
    
    


In [22]:
from transformers import SegformerFeatureExtractor

root_dir = '/home/klimenko/seg_materials/VAL_SEGFORMER/data/4/'# '/home/klimenko/facade_materials/materials/'
feature_extractor = SegformerFeatureExtractor(reduce_labels=True)

train_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor)
valid_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor, train=False)

In [23]:
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(valid_dataset))

Number of training examples: 121
Number of validation examples: 23


In [24]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=1)

from transformers import SegformerForSemanticSegmentation
import json
from huggingface_hub import cached_download, hf_hub_url, hf_hub_download

# load id2label mapping from a JSON on the hub
id2label = json.load(open('materials.json'))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

# define model
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0",
                                                         num_labels=7, 
                                                         id2label=id2label, 
                                                         label2id=label2id,
)


Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.3.proj.weight', 'decode_head.linear_c.2.proj.weight', 'decode_head.batch_norm.bias', 'decode_head.batch_norm.running_mean', 'decode_head.linear_c.1.proj.bias', 'decode_head.line

In [25]:
import torch
import numpy as np
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

model = SegformerForSemanticSegmentation.from_pretrained("weights/fold_4_10_ep_0.91.pth")



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for filename in os.listdir(root_dir+'validation/rgb/'):
    
    image_path = root_dir+'validation/rgb/'+filename
    image = Image.open(image_path).convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt").to(device)
    resized_img = image.resize((128, 128))
    image_np = np.array(resized_img)
    
    model.eval()
    with torch.no_grad():
        outputs = model(inputs.pixel_values).logits
        upsampled_logits = nn.functional.interpolate(outputs,size=image.size[::-1],mode='bilinear',align_corners=False)
        print(upsampled_logits.shape)
        seg = upsampled_logits.cpu().argmax(dim=1)[0].numpy()
        replacement_dict = {0: 10, 1: 11, 2:12, 3:13, 4:16, 5:17, 6:0}
        seg2 = np.vectorize(replacement_dict.get)(seg)
        seg3 = np.stack([seg2] * 3, axis=-1)
        
        np.save('/home/klimenko/seg_materials/VAL_SEGFORMER/results/'+filename.replace(".png", ".npy"), seg2)
        

torch.Size([1, 7, 1000, 1000])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 1000, 1000])
torch.Size([1, 7, 1000, 1000])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 1000, 1000])
torch.Size([1, 7, 1000, 1000])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 921, 921])
torch.Size([1, 7, 2304, 3066])
torch.Size([1, 7, 921, 921])
