### Load libraries

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from datasets import load_dataset, DatasetDict
from evaluate import load
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [5]:
# Prefer a GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cpu


In [6]:
# Hugging Face Hub dataset with the fashion images and their segmentation masks.
repo_id = "peaceAsh/fashion-sam-dataset"
# Pretrained SegFormer-B0 checkpoint (ADE20k, 512x512) used as the fine-tuning start point.
model_checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"

# Binary segmentation: index 0 is background, index 1 is any fashion item.
id2label = dict(enumerate(["background", "fashion"]))
label2id = {name: index for index, name in id2label.items()}
num_labels = len(id2label)

In [7]:
def load_and_split_dataset(repo_id, holdout_size=0.15, test_fraction=2/3, seed=42):
    """Load a Hub dataset and carve its ``train`` split into train/validation/test.

    First ``holdout_size`` of the rows are held out from training. The held-out
    part is then split again: ``test_fraction`` of it becomes ``test`` and the
    remainder becomes ``validation``. With the defaults this is roughly
    85% train / 5% validation / 10% test.

    Args:
        repo_id: Hugging Face Hub dataset id passed to ``load_dataset``. The
            loaded dataset must expose a ``"train"`` split.
        holdout_size: Fraction of rows removed from the training split.
        test_fraction: Fraction of the held-out rows assigned to ``test``.
        seed: Seed used for both splits so the partition is reproducible.

    Returns:
        DatasetDict with ``"train"``, ``"validation"`` and ``"test"`` splits.
    """
    dataset = load_dataset(repo_id)
    train_holdout = dataset["train"].train_test_split(test_size=holdout_size, seed=seed)
    val_test = train_holdout["test"].train_test_split(test_size=test_fraction, seed=seed)
    return DatasetDict({
        "train": train_holdout["train"],
        "validation": val_test["train"],
        "test": val_test["test"],
    })


fashion_dataset = load_and_split_dataset(repo_id)
# Keep the mask values exactly as given: masks are binarised to {0, 1} in FashionSegDataset,
# and per id2label, 0 is the real "background" class rather than a value to be remapped.
processor = SegformerImageProcessor.from_pretrained(model_checkpoint, do_reduce_labels=False)

  image_processor = cls(**image_processor_dict)


In [8]:
# Sanity check: show the rows per split produced by load_and_split_dataset.
fashion_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 212
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 12
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 26
    })
})

In [9]:
# Inspect the preprocessing config (resize, rescale, normalisation, do_reduce_labels).
processor

SegformerImageProcessor {
  "do_normalize": true,
  "do_reduce_labels": false,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "SegformerImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 512,
    "width": 512
  }
}

In [10]:
class FashionSegDataset(Dataset):
    """PyTorch dataset that turns (image, mask) rows into SegFormer inputs.

    Each row's image is converted to RGB. Its mask is converted to a single
    channel and binarised (0 stays 0, any non-zero value becomes 1) before
    both are passed through the processor. Items are returned as
    ``(pixel_values, labels)`` with ``labels`` cast to ``torch.long``.

    Args:
        dataset: Indexable split whose rows hold ``'image'`` and ``'label'``
            entries exposing a PIL-style ``.convert(mode)``.
        processor: Callable such as ``SegformerImageProcessor`` accepting
            ``images=``, ``segmentation_maps=`` and ``return_tensors=``.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        rgb_image = row['image'].convert("RGB")

        # Single-channel mask -> {0, 1}: every non-zero pixel counts as "fashion".
        gray_mask = np.array(row['label'].convert("L"))
        binary_mask = np.where(gray_mask > 0, 1, 0).astype(np.uint8)

        # The processor works on batches: wrap in one-element lists, then drop the
        # leading batch dimension from the returned tensors.
        encoded = self.processor(
            images=[rgb_image],
            segmentation_maps=[binary_mask],
            return_tensors="pt",
        )
        pixel_values = encoded['pixel_values'].squeeze(dim=0)
        labels = encoded['labels'].squeeze(dim=0).long()
        return pixel_values, labels

In [11]:
# Wrap the HF splits so each item yields (pixel_values, labels) tensors for SegFormer.
train_dataset = FashionSegDataset(fashion_dataset['train'], processor)
val_dataset = FashionSegDataset(fashion_dataset['validation'], processor)

# A dedicated, seeded generator makes the training shuffle order reproducible across
# Restart & Run All (same seed as the dataset split).
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
# Validation keeps a fixed order.
val_loader = DataLoader(val_dataset, batch_size=4)

In [20]:
# Seed torch's RNG so the freshly initialised 2-class decode head (the checkpoint's
# 150-class classifier cannot be reused) starts from the same weights on every run.
torch.manual_seed(42)
model = SegformerForSemanticSegmentation.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    # Allow the classifier shape to differ from the checkpoint's (150 -> num_labels).
    ignore_mismatched_sizes=True,
).to(device)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([2]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([2, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
# AdamW over every model parameter: the pretrained encoder and the new 2-class head alike.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=6e-5)

In [22]:
# Settings consumed by the training loop.
num_epochs = 1  # passes over the training split
writer = SummaryWriter(log_dir='runs/segformer_fashion_pytorch')  # TensorBoard log directory
metric = load(path="mean_iou")  # validation metric (mean IoU)
best_iou = -1.0  # sentinel below any IoU value, so the first scored epoch is saved

Downloading builder script: 12.9kB [00:00, 13.9MB/s]


In [None]:
# Fine-tune for `num_epochs` epochs. Each epoch trains on `train_loader`, computes mIoU on
# `val_loader`, and saves the weights whenever validation mIoU beats `best_iou`.
# try/finally guarantees the TensorBoard writer is flushed and closed even when the run
# raises or is interrupted from the notebook (KeyboardInterrupt).
try:
    for epoch in tqdm(range(num_epochs), desc="Overall Training Progress"):
        # ---- Training ----
        model.train()
        train_loss = 0.0
        train_samples = 0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
        for pixel_values, labels in train_pbar:
            pixel_values = pixel_values.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # Passing `labels` makes the model return its segmentation loss in `outputs.loss`.
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            loss.backward()
            optimizer.step()

            # NOTE(review): outputs.loss is assumed to be a per-batch mean; weighting it by
            # the batch size keeps a smaller final batch from skewing the epoch average.
            batch_loss = loss.item()
            batch_size = pixel_values.size(0)
            train_loss += batch_loss * batch_size
            train_samples += batch_size
            train_pbar.set_postfix(loss=f"{batch_loss:.4f}")

        # max(..., 1) avoids a ZeroDivisionError if the training loader yields nothing.
        avg_train_loss = train_loss / max(train_samples, 1)
        writer.add_scalar('Loss/train', avg_train_loss, epoch)
        print(f"Epoch {epoch+1} - Average Training Loss: {avg_train_loss:.4f}")

        # ---- Validation ----
        model.eval()
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
            for pixel_values, labels in val_pbar:
                pixel_values = pixel_values.to(device)
                labels = labels.to(device)

                outputs = model(pixel_values=pixel_values)
                logits = outputs.logits

                # Upsample logits to the label resolution so predictions and references
                # line up pixel-for-pixel for the metric.
                upsampled_logits = torch.nn.functional.interpolate(
                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
                )
                predicted = upsampled_logits.argmax(dim=1)

                metric.add_batch(predictions=predicted.cpu().numpy(), references=labels.cpu().numpy())

        metrics = metric.compute(num_labels=num_labels, ignore_index=255, reduce_labels=False)
        mean_iou = metrics['mean_iou']
        writer.add_scalar('mIoU/validation', mean_iou, epoch)
        print(f"Epoch {epoch+1} - Validation mIoU: {mean_iou:.4f}")

        # Save the best model so far (a NaN mIoU never compares greater, so it is skipped).
        if mean_iou > best_iou:
            best_iou = mean_iou
            torch.save(model.state_dict(), "segformer_best.pth")
            print(f"New best model saved with mIoU: {best_iou:.4f}")
finally:
    writer.close()