In [7]:
import os
import torch 

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

from datasets import load_metric
from transformers import SegformerImageProcessor
from transformers import SegformerForSemanticSegmentation

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset

In [8]:
class SemanticSegmentationDataset(Dataset):
    """Image (semantic) segmentation dataset."""

    def __init__(self, image_processor, train, img_dir = './data/imgs', masks_dir = './data/masks'):
        

        self.img_pathes = [ f'{img_dir}/{item}' for item in sorted(os.listdir(img_dir)) ]
        self.img_pathes = [item for item in self.img_pathes if item[-3:] == 'jpg']
        self.mask_pathes = [ f'{masks_dir}/{item}' for item in sorted(os.listdir(masks_dir)) ]
        # Евалюсь на последних 20 чтоб не ебаться с шаффлами
        if train == False: 
            self.img_pathes = self.img_pathes[:20]
            self.mask_pathes = self.mask_pathes[:20]
        
        self.image_processor = image_processor

        assert len(self.img_pathes) == len(self.mask_pathes), "There must be as many images as there are segmentation maps"

    def __len__(self):
        return len(self.img_pathes)

    def __getitem__(self, idx):

        mask = np.load(self.mask_pathes[idx])
        image = np.array( Image.open(self.img_pathes[idx])) 

        encoded_inputs = self.image_processor(image, mask, return_tensors="pt")

        for k,v in encoded_inputs.items():
            encoded_inputs[k].squeeze_() # remove batch dimension

        return encoded_inputs
     

In [11]:
image_processor = SegformerImageProcessor(reduce_labels=False)
train_ds = SemanticSegmentationDataset(image_processor, True)
eval_ds = SemanticSegmentationDataset(image_processor, False)

In [12]:
train_dataloader = DataLoader(train_ds, batch_size=8, shuffle=True)
eval_dataloader = DataLoader(eval_ds, batch_size=8, shuffle=True)

In [13]:
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/mit-b0",
                                                         num_labels=6,
                                                       #  id2label=id2label,
                                                       #  label2id=label2id,
)
model.to('cuda')

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

In [14]:
log_path = './logs'
writer = SummaryWriter(log_path)
metric = load_metric("mean_iou")

  metric = load_metric("mean_iou")


In [None]:
import torch
from tqdm import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

train_loss_iter = []
train_loss_epoch = []
eval_iou = []
eval_acc = []
eval_loss = []
model.train()

for epoch in tqdm(range(200)):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    curr_epoch_loss = []
    curr_epoch_eval_loss = []
    for idx, batch in enumerate(train_dataloader):
        
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        optimizer.zero_grad()
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits

        loss.backward()
        optimizer.step()
        #print()
        train_loss_iter.append(loss.item())
        curr_epoch_loss.append(loss.item())
        writer.add_scalar("Train/loss_step", train_loss_iter[-1], idx + epoch * len(train_dataloader))
        writer.add_scalar("Train/epoch", epoch + 1, idx + epoch * len(train_dataloader))
    train_loss_epoch.append(sum(curr_epoch_loss) / len(curr_epoch_loss))   
    writer.add_scalar("Train/loss_epoch", train_loss_epoch[-1], epoch + 1)
    with torch.no_grad():
        
        for batch in eval_dataloader:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss, logits = outputs.loss, outputs.logits
            curr_epoch_eval_loss.append(outputs.loss.item())
            upsampled_logits = torch.nn.functional.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            predicted = upsampled_logits.argmax(dim=1)
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels.detach().cpu().numpy())
            
        metrics = metric.compute(num_labels=6, 
                                       ignore_index=255,
                                       reduce_labels=False, # we've already reduced the labels before)
            )
        eval_iou.append( metrics["mean_iou"])
        eval_acc.append(metrics["mean_accuracy"])
        eval_loss.append(sum(curr_epoch_eval_loss) / len(eval_dataloader))
        
        writer.add_scalar("Eval/loss",eval_loss[-1], epoch + 1)
        writer.add_scalar("Eval/Accuracy", metrics["mean_accuracy"], epoch + 1)
        writer.add_scalar("Eval/IoU", metrics["mean_iou"], epoch + 1)
        
        print("Mean_iou:", metrics["mean_iou"])
        print("Loss:", train_loss_epoch[-1])
        print("Mean accuracy:", metrics["mean_accuracy"])

  0%|                                                   | 0/200 [00:00<?, ?it/s]

Epoch: 0


  0%|▏                                       | 1/200 [02:09<7:08:12, 129.11s/it]

Mean_iou: 0.2652821412123574
Loss: 1.636090874671936
Mean accuracy: 0.35721163772084336
Epoch: 1


  1%|▍                                       | 2/200 [04:08<6:46:45, 123.26s/it]

Mean_iou: 0.2684602946840831
Loss: 1.3528793183240024
Mean accuracy: 0.34861102911303227
Epoch: 2


In [16]:
!nvidia-smi

Mon Dec 16 23:42:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce GTX 1650 Ti     Off | 00000000:01:00.0  On |                  N/A |
| N/A   67C    P8               4W /  50W |   3896MiB /  4096MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [17]:
!kill -9 14330