In [12]:
import requests, zipfile, io
from datasets import load_dataset

In [13]:
from torch.utils.data import Dataset
import os
from PIL import Image

class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, root_dir, feature_extractor, train=True):
        """
        Args:
            root_dir (string): Root directory of the dataset containing the images + annotations.
            feature_extractor (SegFormerFeatureExtractor): feature extractor to prepare images + segmentation maps.
            train (bool): Whether to load "training" or "validation" images + annotations.
        """
        self.root_dir = root_dir
        self.feature_extractor = feature_extractor
        self.train = train

        sub_path = "train" if self.train else "validation"
        self.img_dir = os.path.join(self.root_dir, sub_path, "rgb")
        self.ann_dir = os.path.join(self.root_dir, sub_path, "labels")
        
        # read images
        image_file_names = []
        for root, dirs, files in os.walk(self.img_dir):
          image_file_names.extend(files)
        self.images = sorted(image_file_names)
        
        # read annotations
        annotation_file_names = []
        for root, dirs, files in os.walk(self.ann_dir):
          annotation_file_names.extend(files)
        self.annotations = sorted(annotation_file_names)

        assert len(self.images) == len(self.annotations), "There must be as many images as there are segmentation maps"

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        
        image = Image.open(os.path.join(self.img_dir, self.images[idx]))
        segmentation_map = Image.open(os.path.join(self.ann_dir, self.annotations[idx]))

        # randomly crop + pad both image and segmentation map to same size
        encoded_inputs = self.feature_extractor(image, segmentation_map, return_tensors="pt")

        for k,v in encoded_inputs.items():
          encoded_inputs[k].squeeze_() # remove batch dimension

        return encoded_inputs

In [32]:
from transformers import SegformerFeatureExtractor

root_dir = '/home/klimenko/seg_materials/VAL_SEGFORMER/data/4/'# '/home/klimenko/facade_materials/materials/'
feature_extractor = SegformerFeatureExtractor(reduce_labels=True)

train_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor)
valid_dataset = SemanticSegmentationDataset(root_dir=root_dir, feature_extractor=feature_extractor, train=False)



In [33]:
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(valid_dataset))

Number of training examples: 121
Number of validation examples: 23


In [34]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=10, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=1)

In [35]:
from transformers import SegformerForSemanticSegmentation
import json
from huggingface_hub import cached_download, hf_hub_url, hf_hub_download

# load id2label mapping from a JSON on the hub
id2label = json.load(open('materials.json'))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}

# define model
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0",
                                                         num_labels=7, 
                                                         id2label=id2label, 
                                                         label2id=label2id,
)


  return torch.load(checkpoint_file, map_location="cpu")
Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.3.proj.weight', 'decode_head.linear_c.0.proj.weight', 'decode_head.classifier.weight', 'decode_head.linear_c.2.proj.bias

In [36]:
from datasets import load_metric
metric = load_metric("mean_iou")

In [37]:
import torch
import numpy as np
from torch import nn
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
# move model to GPU
# os.environ["CUDA_VISIBLE_DEVICES"]="0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


for epoch in range(12):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    loss_list = []
    model.train()

    
    for idx, batch in enumerate(tqdm(train_dataloader)):
        # get the inputs;
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        
        #print(pixel_values.shape)
        #print(labels)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits
        print(idx)
        
        loss.backward()
        optimizer.step()

        # evaluate
#         with torch.no_grad():
#             upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
#             predicted = upsampled_logits.argmax(dim=1)
#             metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())
#             metrics = metric.compute(num_labels=len(id2label), ignore_index=255,reduce_labels=False)
            
#             if idx % 1 == 0:
#                 print("idx: ", str(idx)," Loss:", str(loss.item())[0:7]," Mean_iou:", str(metrics["mean_iou"])[0:5], " Mean accuracy:", str(metrics["mean_accuracy"])[0:5] )
#         loss_list.append(loss.item())
        
#     loss_value = np.mean(np.array(loss_list))  
#     print('loss mean ', loss_value)
    #model.save_pretrained("weights/fold_0_"+str(epoch)+"_ep_"+str(loss_value)[0:4]+".pth")
    
    model.eval()
    
    for idx, batch in enumerate(tqdm(valid_dataloader)):
        # get the inputs;
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits
        
#         loss.backward()
#         optimizer.step()

        # evaluate
        with torch.no_grad():
            upsampled_logits = nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            predicted = upsampled_logits.argmax(dim=1)
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())
            metrics = metric.compute(num_labels=len(id2label), ignore_index=255,reduce_labels=False)
            
            if idx % 1 == 0:
                print("idx: ", str(idx)," Loss:", str(loss.item())[0:7]," Mean_iou:", str(metrics["mean_iou"])[0:5], " Mean accuracy:", str(metrics["mean_accuracy"])[0:5] )
        loss_list.append(loss.item())
        
    loss_value = np.mean(np.array(loss_list))  
    print('loss mean ', loss_value)
    model.save_pretrained("weights/fold_4_"+str(epoch)+"_ep_"+str(loss_value)[0:4]+".pth")

Epoch: 0


  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

  acc = total_area_intersect / total_area_label


idx:  0  Loss: 1.88071  Mean_iou: 0.062  Mean accuracy: 0.223
idx:  1  Loss: 1.69356  Mean_iou: 0.364  Mean accuracy: 0.471
idx:  2  Loss: 1.76160  Mean_iou: 0.221  Mean accuracy: 0.372
idx:  3  Loss: 1.92636  Mean_iou: 0.022  Mean accuracy: 0.113
idx:  4  Loss: 1.81595  Mean_iou: 0.105  Mean accuracy: 0.185
idx:  5  Loss: 1.83824  Mean_iou: 0.056  Mean accuracy: 0.175


  iou = total_area_intersect / total_area_union


idx:  6  Loss: 1.72436  Mean_iou: 0.179  Mean accuracy: 0.284
idx:  7  Loss: 1.75097  Mean_iou: 0.189  Mean accuracy: 0.318
idx:  8  Loss: 1.67079  Mean_iou: 0.170  Mean accuracy: 0.418
idx:  9  Loss: 1.79329  Mean_iou: 0.105  Mean accuracy: 0.226
idx:  10  Loss: 1.90454  Mean_iou: 0.028  Mean accuracy: 0.191
idx:  11  Loss: 1.82430  Mean_iou: 0.084  Mean accuracy: 0.196
idx:  12  Loss: 1.63593  Mean_iou: 0.222  Mean accuracy: 0.325
idx:  13  Loss: 1.72226  Mean_iou: 0.198  Mean accuracy: 0.273
idx:  14  Loss: 1.75188  Mean_iou: 0.088  Mean accuracy: 0.170
idx:  15  Loss: 1.72929  Mean_iou: 0.198  Mean accuracy: 0.301
idx:  16  Loss: 1.73056  Mean_iou: 0.206  Mean accuracy: 0.295
idx:  17  Loss: 1.82725  Mean_iou: 0.143  Mean accuracy: 0.284
idx:  18  Loss: 1.50481  Mean_iou: 0.230  Mean accuracy: 0.348
idx:  19  Loss: 1.62975  Mean_iou: 0.259  Mean accuracy: 0.378
idx:  20  Loss: 1.87397  Mean_iou: 0.041  Mean accuracy: 0.175
idx:  21  Loss: 1.73183  Mean_iou: 0.179  Mean accuracy: 0.

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 1.63332  Mean_iou: 0.065  Mean accuracy: 0.160
idx:  1  Loss: 1.18432  Mean_iou: 0.432  Mean accuracy: 0.466
idx:  2  Loss: 1.18535  Mean_iou: 0.287  Mean accuracy: 0.364
idx:  3  Loss: 1.83413  Mean_iou: 0.053  Mean accuracy: 0.173
idx:  4  Loss: 1.51373  Mean_iou: 0.129  Mean accuracy: 0.206
idx:  5  Loss: 1.59234  Mean_iou: 0.057  Mean accuracy: 0.178
idx:  6  Loss: 1.24090  Mean_iou: 0.291  Mean accuracy: 0.407
idx:  7  Loss: 1.23742  Mean_iou: 0.239  Mean accuracy: 0.333
idx:  8  Loss: 1.21257  Mean_iou: 0.229  Mean accuracy: 0.409
idx:  9  Loss: 1.48915  Mean_iou: 0.125  Mean accuracy: 0.205
idx:  10  Loss: 1.80579  Mean_iou: 0.028  Mean accuracy: 0.195
idx:  11  Loss: 1.57979  Mean_iou: 0.094  Mean accuracy: 0.208
idx:  12  Loss: 0.99734  Mean_iou: 0.247  Mean accuracy: 0.348
idx:  13  Loss: 1.25423  Mean_iou: 0.286  Mean accuracy: 0.309
idx:  14  Loss: 1.34238  Mean_iou: 0.111  Mean accuracy: 0.174
idx:  15  Loss: 1.36347  Mean_iou: 0.194  Mean accuracy: 0.299
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 1.50346  Mean_iou: 0.066  Mean accuracy: 0.165
idx:  1  Loss: 0.91569  Mean_iou: 0.367  Mean accuracy: 0.484
idx:  2  Loss: 0.79797  Mean_iou: 0.288  Mean accuracy: 0.350
idx:  3  Loss: 1.72221  Mean_iou: 0.063  Mean accuracy: 0.197
idx:  4  Loss: 1.28318  Mean_iou: 0.133  Mean accuracy: 0.214
idx:  5  Loss: 1.47254  Mean_iou: 0.076  Mean accuracy: 0.182
idx:  6  Loss: 1.12622  Mean_iou: 0.290  Mean accuracy: 0.413
idx:  7  Loss: 0.84703  Mean_iou: 0.252  Mean accuracy: 0.330
idx:  8  Loss: 1.05894  Mean_iou: 0.202  Mean accuracy: 0.425
idx:  9  Loss: 1.50060  Mean_iou: 0.127  Mean accuracy: 0.228
idx:  10  Loss: 1.68641  Mean_iou: 0.035  Mean accuracy: 0.203
idx:  11  Loss: 1.40996  Mean_iou: 0.101  Mean accuracy: 0.210
idx:  12  Loss: 0.76526  Mean_iou: 0.229  Mean accuracy: 0.324
idx:  13  Loss: 0.98063  Mean_iou: 0.294  Mean accuracy: 0.323
idx:  14  Loss: 1.20064  Mean_iou: 0.112  Mean accuracy: 0.176
idx:  15  Loss: 1.23475  Mean_iou: 0.193  Mean accuracy: 0.301
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 1.40290  Mean_iou: 0.078  Mean accuracy: 0.162
idx:  1  Loss: 0.76001  Mean_iou: 0.424  Mean accuracy: 0.492
idx:  2  Loss: 0.72790  Mean_iou: 0.330  Mean accuracy: 0.383
idx:  3  Loss: 1.42593  Mean_iou: 0.200  Mean accuracy: 0.294
idx:  4  Loss: 1.16194  Mean_iou: 0.180  Mean accuracy: 0.283
idx:  5  Loss: 1.46400  Mean_iou: 0.083  Mean accuracy: 0.186
idx:  6  Loss: 0.80232  Mean_iou: 0.337  Mean accuracy: 0.415
idx:  7  Loss: 0.67438  Mean_iou: 0.268  Mean accuracy: 0.340
idx:  8  Loss: 0.99290  Mean_iou: 0.291  Mean accuracy: 0.412
idx:  9  Loss: 1.20648  Mean_iou: 0.140  Mean accuracy: 0.207
idx:  10  Loss: 1.68138  Mean_iou: 0.042  Mean accuracy: 0.214
idx:  11  Loss: 1.38366  Mean_iou: 0.099  Mean accuracy: 0.211
idx:  12  Loss: 0.65351  Mean_iou: 0.252  Mean accuracy: 0.315
idx:  13  Loss: 0.82482  Mean_iou: 0.288  Mean accuracy: 0.312
idx:  14  Loss: 1.20526  Mean_iou: 0.103  Mean accuracy: 0.169
idx:  15  Loss: 1.22716  Mean_iou: 0.206  Mean accuracy: 0.299
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 1.14096  Mean_iou: 0.130  Mean accuracy: 0.201
idx:  1  Loss: 0.84205  Mean_iou: 0.398  Mean accuracy: 0.475
idx:  2  Loss: 0.65945  Mean_iou: 0.304  Mean accuracy: 0.335
idx:  3  Loss: 1.16254  Mean_iou: 0.254  Mean accuracy: 0.290
idx:  4  Loss: 1.06846  Mean_iou: 0.172  Mean accuracy: 0.272
idx:  5  Loss: 1.33991  Mean_iou: 0.105  Mean accuracy: 0.205
idx:  6  Loss: 0.81577  Mean_iou: 0.342  Mean accuracy: 0.431
idx:  7  Loss: 0.55385  Mean_iou: 0.293  Mean accuracy: 0.341
idx:  8  Loss: 0.91573  Mean_iou: 0.252  Mean accuracy: 0.435
idx:  9  Loss: 1.29113  Mean_iou: 0.137  Mean accuracy: 0.219
idx:  10  Loss: 1.56358  Mean_iou: 0.062  Mean accuracy: 0.244
idx:  11  Loss: 1.30811  Mean_iou: 0.099  Mean accuracy: 0.201
idx:  12  Loss: 0.61084  Mean_iou: 0.264  Mean accuracy: 0.324
idx:  13  Loss: 0.86611  Mean_iou: 0.218  Mean accuracy: 0.304
idx:  14  Loss: 1.20781  Mean_iou: 0.102  Mean accuracy: 0.168
idx:  15  Loss: 1.15037  Mean_iou: 0.206  Mean accuracy: 0.302
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 1.18139  Mean_iou: 0.099  Mean accuracy: 0.177
idx:  1  Loss: 0.62814  Mean_iou: 0.410  Mean accuracy: 0.450
idx:  2  Loss: 0.62548  Mean_iou: 0.281  Mean accuracy: 0.335
idx:  3  Loss: 0.98928  Mean_iou: 0.244  Mean accuracy: 0.287
idx:  4  Loss: 1.17607  Mean_iou: 0.124  Mean accuracy: 0.205
idx:  5  Loss: 1.11753  Mean_iou: 0.185  Mean accuracy: 0.258
idx:  6  Loss: 0.82190  Mean_iou: 0.331  Mean accuracy: 0.448
idx:  7  Loss: 0.45623  Mean_iou: 0.292  Mean accuracy: 0.332
idx:  8  Loss: 0.85144  Mean_iou: 0.327  Mean accuracy: 0.444
idx:  9  Loss: 1.30304  Mean_iou: 0.136  Mean accuracy: 0.224
idx:  10  Loss: 1.39055  Mean_iou: 0.108  Mean accuracy: 0.299
idx:  11  Loss: 1.25079  Mean_iou: 0.120  Mean accuracy: 0.206
idx:  12  Loss: 0.56831  Mean_iou: 0.298  Mean accuracy: 0.348
idx:  13  Loss: 0.70398  Mean_iou: 0.230  Mean accuracy: 0.270
idx:  14  Loss: 1.12313  Mean_iou: 0.108  Mean accuracy: 0.170
idx:  15  Loss: 1.10709  Mean_iou: 0.210  Mean accuracy: 0.305
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 1.20328  Mean_iou: 0.126  Mean accuracy: 0.223
idx:  1  Loss: 0.69467  Mean_iou: 0.361  Mean accuracy: 0.440
idx:  2  Loss: 0.74884  Mean_iou: 0.286  Mean accuracy: 0.353
idx:  3  Loss: 0.85941  Mean_iou: 0.264  Mean accuracy: 0.309
idx:  4  Loss: 1.03948  Mean_iou: 0.172  Mean accuracy: 0.272
idx:  5  Loss: 1.15209  Mean_iou: 0.164  Mean accuracy: 0.251
idx:  6  Loss: 0.80353  Mean_iou: 0.352  Mean accuracy: 0.442
idx:  7  Loss: 0.46748  Mean_iou: 0.299  Mean accuracy: 0.342
idx:  8  Loss: 0.85084  Mean_iou: 0.332  Mean accuracy: 0.448
idx:  9  Loss: 1.26012  Mean_iou: 0.137  Mean accuracy: 0.215
idx:  10  Loss: 1.65822  Mean_iou: 0.064  Mean accuracy: 0.251
idx:  11  Loss: 1.30668  Mean_iou: 0.109  Mean accuracy: 0.221
idx:  12  Loss: 0.59558  Mean_iou: 0.263  Mean accuracy: 0.322
idx:  13  Loss: 0.71419  Mean_iou: 0.244  Mean accuracy: 0.323
idx:  14  Loss: 1.10146  Mean_iou: 0.107  Mean accuracy: 0.176
idx:  15  Loss: 1.03146  Mean_iou: 0.213  Mean accuracy: 0.304
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 0.95171  Mean_iou: 0.203  Mean accuracy: 0.266
idx:  1  Loss: 0.65160  Mean_iou: 0.395  Mean accuracy: 0.454
idx:  2  Loss: 0.77571  Mean_iou: 0.351  Mean accuracy: 0.410
idx:  3  Loss: 0.48536  Mean_iou: 0.269  Mean accuracy: 0.304
idx:  4  Loss: 0.96448  Mean_iou: 0.192  Mean accuracy: 0.303
idx:  5  Loss: 0.92717  Mean_iou: 0.195  Mean accuracy: 0.262
idx:  6  Loss: 0.62898  Mean_iou: 0.370  Mean accuracy: 0.448
idx:  7  Loss: 0.43431  Mean_iou: 0.317  Mean accuracy: 0.358
idx:  8  Loss: 0.84059  Mean_iou: 0.328  Mean accuracy: 0.444
idx:  9  Loss: 1.19567  Mean_iou: 0.143  Mean accuracy: 0.220
idx:  10  Loss: 1.48753  Mean_iou: 0.092  Mean accuracy: 0.261
idx:  11  Loss: 1.02938  Mean_iou: 0.180  Mean accuracy: 0.239
idx:  12  Loss: 0.57463  Mean_iou: 0.270  Mean accuracy: 0.328
idx:  13  Loss: 0.71370  Mean_iou: 0.217  Mean accuracy: 0.283
idx:  14  Loss: 1.11029  Mean_iou: 0.107  Mean accuracy: 0.170
idx:  15  Loss: 0.95488  Mean_iou: 0.251  Mean accuracy: 0.328
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 1.27677  Mean_iou: 0.090  Mean accuracy: 0.167
idx:  1  Loss: 0.43105  Mean_iou: 0.413  Mean accuracy: 0.449
idx:  2  Loss: 0.38317  Mean_iou: 0.280  Mean accuracy: 0.336
idx:  3  Loss: 0.58265  Mean_iou: 0.222  Mean accuracy: 0.245
idx:  4  Loss: 1.32589  Mean_iou: 0.106  Mean accuracy: 0.185
idx:  5  Loss: 0.97682  Mean_iou: 0.190  Mean accuracy: 0.258
idx:  6  Loss: 0.77247  Mean_iou: 0.307  Mean accuracy: 0.438
idx:  7  Loss: 0.41628  Mean_iou: 0.301  Mean accuracy: 0.333
idx:  8  Loss: 0.78444  Mean_iou: 0.333  Mean accuracy: 0.451
idx:  9  Loss: 1.49985  Mean_iou: 0.132  Mean accuracy: 0.228
idx:  10  Loss: 1.78627  Mean_iou: 0.052  Mean accuracy: 0.269
idx:  11  Loss: 1.25724  Mean_iou: 0.147  Mean accuracy: 0.225
idx:  12  Loss: 0.49928  Mean_iou: 0.313  Mean accuracy: 0.360
idx:  13  Loss: 0.67959  Mean_iou: 0.228  Mean accuracy: 0.310
idx:  14  Loss: 1.06066  Mean_iou: 0.118  Mean accuracy: 0.188
idx:  15  Loss: 0.81382  Mean_iou: 0.358  Mean accuracy: 0.410
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 1.47317  Mean_iou: 0.105  Mean accuracy: 0.232
idx:  1  Loss: 0.65591  Mean_iou: 0.384  Mean accuracy: 0.489
idx:  2  Loss: 1.37404  Mean_iou: 0.247  Mean accuracy: 0.351
idx:  3  Loss: 0.42817  Mean_iou: 0.257  Mean accuracy: 0.288
idx:  4  Loss: 1.11647  Mean_iou: 0.252  Mean accuracy: 0.362
idx:  5  Loss: 0.95045  Mean_iou: 0.187  Mean accuracy: 0.251
idx:  6  Loss: 0.81876  Mean_iou: 0.317  Mean accuracy: 0.384
idx:  7  Loss: 0.65302  Mean_iou: 0.381  Mean accuracy: 0.445
idx:  8  Loss: 0.89267  Mean_iou: 0.326  Mean accuracy: 0.436
idx:  9  Loss: 1.40822  Mean_iou: 0.113  Mean accuracy: 0.165
idx:  10  Loss: 1.44373  Mean_iou: 0.093  Mean accuracy: 0.296
idx:  11  Loss: 0.97937  Mean_iou: 0.211  Mean accuracy: 0.266
idx:  12  Loss: 0.77918  Mean_iou: 0.258  Mean accuracy: 0.313
idx:  13  Loss: 1.59945  Mean_iou: 0.183  Mean accuracy: 0.282
idx:  14  Loss: 1.40522  Mean_iou: 0.112  Mean accuracy: 0.191
idx:  15  Loss: 0.91006  Mean_iou: 0.281  Mean accuracy: 0.354
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 1.06734  Mean_iou: 0.159  Mean accuracy: 0.198
idx:  1  Loss: 0.45372  Mean_iou: 0.422  Mean accuracy: 0.480
idx:  2  Loss: 0.57205  Mean_iou: 0.375  Mean accuracy: 0.460
idx:  3  Loss: 0.51124  Mean_iou: 0.277  Mean accuracy: 0.313
idx:  4  Loss: 1.15408  Mean_iou: 0.192  Mean accuracy: 0.310
idx:  5  Loss: 1.07283  Mean_iou: 0.180  Mean accuracy: 0.249
idx:  6  Loss: 0.71384  Mean_iou: 0.287  Mean accuracy: 0.350
idx:  7  Loss: 0.64581  Mean_iou: 0.315  Mean accuracy: 0.400
idx:  8  Loss: 0.85248  Mean_iou: 0.333  Mean accuracy: 0.447
idx:  9  Loss: 1.16473  Mean_iou: 0.146  Mean accuracy: 0.216
idx:  10  Loss: 1.74269  Mean_iou: 0.051  Mean accuracy: 0.253
idx:  11  Loss: 1.16751  Mean_iou: 0.167  Mean accuracy: 0.215
idx:  12  Loss: 0.65463  Mean_iou: 0.260  Mean accuracy: 0.319
idx:  13  Loss: 0.96731  Mean_iou: 0.242  Mean accuracy: 0.305
idx:  14  Loss: 1.25396  Mean_iou: 0.104  Mean accuracy: 0.174
idx:  15  Loss: 1.03438  Mean_iou: 0.204  Mean accuracy: 0.294
id

  0%|          | 0/13 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
10
11
12


  0%|          | 0/23 [00:00<?, ?it/s]

idx:  0  Loss: 0.58644  Mean_iou: 0.264  Mean accuracy: 0.290
idx:  1  Loss: 0.54912  Mean_iou: 0.401  Mean accuracy: 0.471
idx:  2  Loss: 0.40303  Mean_iou: 0.392  Mean accuracy: 0.436
idx:  3  Loss: 0.39193  Mean_iou: 0.261  Mean accuracy: 0.289
idx:  4  Loss: 0.97780  Mean_iou: 0.177  Mean accuracy: 0.278
idx:  5  Loss: 0.82842  Mean_iou: 0.195  Mean accuracy: 0.254
idx:  6  Loss: 0.68324  Mean_iou: 0.344  Mean accuracy: 0.431
idx:  7  Loss: 0.41664  Mean_iou: 0.342  Mean accuracy: 0.375
idx:  8  Loss: 0.72406  Mean_iou: 0.334  Mean accuracy: 0.449
idx:  9  Loss: 1.24964  Mean_iou: 0.139  Mean accuracy: 0.214
idx:  10  Loss: 1.67485  Mean_iou: 0.085  Mean accuracy: 0.256
idx:  11  Loss: 0.94927  Mean_iou: 0.199  Mean accuracy: 0.232
idx:  12  Loss: 0.57486  Mean_iou: 0.260  Mean accuracy: 0.320
idx:  13  Loss: 0.88788  Mean_iou: 0.217  Mean accuracy: 0.288
idx:  14  Loss: 1.17218  Mean_iou: 0.112  Mean accuracy: 0.189
idx:  15  Loss: 0.95373  Mean_iou: 0.248  Mean accuracy: 0.328
id

In [None]:
model.save_pretrained("fold_0_1ep.pth")

In [None]:
model = SegformerForSemanticSegmentation.from_pretrained("fold_0_5ep.pth")
model.to(device)

In [None]:

model = SegformerForSemanticSegmentation.from_pretrained("fold_0_5ep.pth")
model.to(device)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


for filename in os.listdir(root_dir+'validation/rgb/'):
    
    image_path = root_dir+'validation/rgb/'+filename
    image = Image.open(image_path).convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt").to(device)
    resized_img = image.resize((128, 128))
    image_np = np.array(resized_img)
    
    model.eval()
    with torch.no_grad():
        outputs = model(inputs.pixel_values).logits
        upsampled_logits = nn.functional.interpolate(outputs,size=image.size[::-1],mode='bilinear',align_corners=False)
        print(upsampled_logits.shape)
        seg = upsampled_logits.cpu().argmax(dim=1)[0].numpy()
        replacement_dict = {0: 10, 1: 11, 2:12, 3:13, 4:16, 5:17, 6:0}
        seg2 = np.vectorize(replacement_dict.get)(seg)
        seg3 = np.stack([seg2] * 3, axis=-1)
        
        np.save('/home/klimenko/seg_materials/VAL_SEGFORMER/results/'+filename.replace(".png", ".npy"), seg2)
        
        
        #image = Image.fromarray(seg3.astype(np.uint8))
        #image.save('/home/klimenko/seg_materials/VAL_SEGFORMER/results/'+filename.replace(".png", ".jpg"))
    
    




In [None]:
image_path = "/home/klimenko/seg_materials/VAL_SEGFORMER/data/0/train/rgb/Copy of YAqsjUKCMc_DkjpPMtrIcQ_180.png"  # replace with your image path
image = Image.open(image_path).convert("RGB")

inputs = feature_extractor(images=image, return_tensors="pt").to(device)
resized_img = image.resize((128, 128))
image_np = np.array(resized_img)

In [None]:
image.size

In [None]:

upsampled_logits = nn.functional.interpolate(outputs,size=(1000, 1000)[::-1],mode='bilinear',align_corners=False)

In [None]:
upsampled_logits.shape

In [None]:
model.eval()
with torch.no_grad():
    outputs = model(inputs.pixel_values).logits
    upsampled_logits = nn.functional.interpolate(outputs,size=(1000, 1000)[::-1],mode='bilinear',align_corners=False)
    print(upsampled_logits.shape)
    seg = upsampled_logits.cpu().argmax(dim=1)[0].numpy()

In [None]:
seg.shape

In [None]:
import numpy as np
outputs_np = upsampled_logits.cpu().numpy()[0]

In [None]:
{
  "0": "nothing",
  "1": "10_glazing",
  "2": "11_concrete",
  "3": "12_masonry",
  "4": "13_siding",
  "5": "16_stucco",
  "6": "17_metal"
}




In [None]:
replacement_dict = {0: 10, 1: 11, 2:12, 3:13, 4:16, 5:17, 6:0}
seg2 = np.vectorize(replacement_dict.get)(seg)
seg3 = np.stack([seg2] * 3, axis=-1)

In [None]:
seg2

In [None]:
import cv2
image_path = "/home/klimenko/seg_materials/VAL_SEGFORMER/data/0/train/labels/Copy of YAqsjUKCMc_DkjpPMtrIcQ_180.png"
qq = cv2.imread(image_path)

In [None]:
%matplotlib notebook
plt.imshow(qq*10)

In [None]:
seg = outputs.logits.cpu().argmax(dim=1)[0].numpy()

In [None]:
print(np.unique(seg))