In [1]:
from datasets import load_dataset
from labels import labels, Label
from transformers import AutoImageProcessor
from torchvision.transforms import ColorJitter, ToTensor
import numpy as np 
import evaluate
import numpy as np
import torch
from torch import nn
from transformers import AutoModelForSemanticSegmentation, AdamW, get_scheduler
import wandb
import os
import random
from tqdm import tqdm

import config
from dataset import Dataset
from visualization import visualize_samples, visualize_mask

def compute_metrics(metric, num_labels):
    """Finalize the accumulated metric and make its values plain-Python friendly.

    Args:
        metric: Object exposing ``compute(num_labels=..., ignore_index=...,
            reduce_labels=...)`` that returns a dict.
        num_labels: Number of classes forwarded to ``metric.compute``.

    Returns:
        The dict returned by ``metric.compute``, with every value that is
        exactly an ``np.ndarray`` replaced (in place) by its ``tolist()`` form.
    """
    with torch.no_grad():
        results = metric.compute(
            num_labels=num_labels, ignore_index=0, reduce_labels=False
        )
        # Collect the keys first, then rewrite the matching entries in place.
        array_keys = [
            name for name, entry in results.items() if type(entry) is np.ndarray
        ]
        for name in array_keys:
            results[name] = results[name].tolist()
        return results
    
def add_batch_to_metrics(metric, logits, labels):
    """Turn a batch of logits into class predictions and feed them to ``metric``.

    The logits are bilinearly resized to the last two dimensions of ``labels``
    and reduced with ``argmax`` over dimension 1. Predictions and labels are
    moved to host memory as NumPy arrays, their shapes and unique values are
    printed, and both are passed to ``metric.add_batch``.

    Args:
        metric: Object exposing ``add_batch(predictions=..., references=...)``.
        logits: Tensor accepted by ``nn.functional.interpolate`` in bilinear mode.
        labels: Tensor whose trailing two dimensions give the target size.

    Returns:
        The predicted class ids as a NumPy array.
    """
    target_size = labels.shape[-2:]
    with torch.no_grad():
        upsampled = nn.functional.interpolate(
            logits, size=target_size, mode="bilinear", align_corners=False
        )
        class_map = upsampled.argmax(dim=1)
        pred_labels = class_map.detach().cpu().numpy()
        references = labels.detach().cpu().numpy()
        print(f"Labels: {references.shape}, {np.unique(references)}")
        print(f"Preds: {pred_labels.shape}, {np.unique(pred_labels)}")
        metric.add_batch(predictions=pred_labels, references=references)
    return pred_labels
    
def validate(model, metric, eval_ds, id2labels, num_images_to_log, device):
    """Run one evaluation pass over ``eval_ds`` and collect loss and IoU metrics.

    Args:
        model: Called as ``model(pixel_values=..., labels=...)``; the result
            must expose ``.loss`` and ``.logits``. The model is switched to
            eval mode and left in that mode.
        metric: Accumulator handed to ``add_batch_to_metrics`` for every batch
            and to ``compute_metrics`` at the end.
        eval_ds: Sized iterable of batches; each batch is indexed with
            ``"pixel_values"`` and ``"labels"``.
        id2labels: Mapping of class id to class name; only its length is used
            while image logging is disabled.
        num_images_to_log: Currently unused; it only feeds the disabled image
            logging block inside the loop.
        device: Device the inputs and labels are moved to.

    Returns:
        dict with ``"eval/loss"`` (loss averaged over batches), ``"eval/mIoU"``,
        ``"eval/mean_acc"`` and ``"eval/overall_acc"``, plus any entries of the
        (currently always empty) image log dict.

    Raises:
        ValueError: If ``eval_ds`` contains no batches.
    """
    num_batches = len(eval_ds)
    if num_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError (or an
        # obscure metric error) after the loop.
        raise ValueError("eval_ds is empty; nothing to evaluate")

    num_labels = len(id2labels)
    model.eval()
    eval_loss = 0.0
    image_log_dict = {}

    # The context manager closes the bar even if evaluation is interrupted.
    with tqdm(total=num_batches, desc="Evaluating") as progress_bar:
        for i, batch in enumerate(eval_ds):
            print(f"Eval {i}:")
            input_image = batch["pixel_values"].to(device)
            gt_labels = batch["labels"].to(device)
            for label in gt_labels:
                # NumPy cannot read device memory: copy to host first, otherwise
                # this raises TypeError whenever `device` is a CUDA device.
                unique_ids = np.unique(label.detach().cpu().numpy())
                print(f"Labels: {label.shape}, {unique_ids}")
            with torch.no_grad():
                outputs = model(pixel_values=input_image, labels=gt_labels)
                eval_loss += outputs.loss.item()
                pred_labels = add_batch_to_metrics(metric, outputs.logits, gt_labels)
            # Image logging is disabled for now; kept for when it is re-enabled.
            # if i == 0:
            #     input_image = input_image.cpu().numpy()
            #     gt_labels = gt_labels.cpu().numpy()
            #     for j in range(min(gt_labels.shape[0], num_images_to_log)):
            #         image_log_dict[f"eval/image_{j}"] = wandb.Image(np.transpose(input_image[j], (1, 2, 0)), masks={
            #             "predictions" : {
            #                 "mask_data" : pred_labels[j],
            #                 "class_labels" : id2labels
            #             },
            #             "ground_truth" : {
            #                 "mask_data" : gt_labels[j],
            #                 "class_labels" : id2labels
            #             }
            #         })
            progress_bar.update(1)

    metric_results = compute_metrics(metric, num_labels)
    print(metric_results)
    result_dict = {
        "eval/loss": eval_loss / num_batches,
        "eval/mIoU": metric_results["mean_iou"],
        "eval/mean_acc": metric_results["mean_accuracy"],
        "eval/overall_acc": metric_results["overall_accuracy"],
    }
    result_dict.update(image_log_dict)
    return result_dict


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Build the dataset wrapper and fetch the test dataloader.
dataset = Dataset(
    config.checkpoint,
    config.image_size,
    config.batch_size,
    config.to_sample,
    config.sample_size,
)
test_ds = dataset.get_test_dataloader()

# Sanity check: inspect the types and shapes of the first test batch.
print("CHECK TEST DATASET")
example = next(iter(test_ds))
ex_image, ex_labels = example["pixel_values"], example["labels"]
print(type(ex_image), ex_image.shape, sep="\n")
print(type(ex_labels), ex_labels.shape, sep="\n")
# NOTE(review): the recorded run printed tensor([0]) here, i.e. the first test
# label map held only class 0, while compute_metrics uses ignore_index=0.
# Confirm that this split actually carries annotations before trusting the scores.
print(torch.unique(ex_labels[0]))
num_labels = dataset.get_num_labels()

# Load the segmentation checkpoint, passing in the dataset's label mappings.
model = AutoModelForSemanticSegmentation.from_pretrained(
    "segformer-b0-cityscapes/nvidia-mit-b0_1",
    id2label=dataset.id2label,
    label2id=dataset.label2id,
)

# Pick the evaluation device (GPU when available) and put the model in eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Score the model on the test dataloader with the mean IoU metric.
metric = evaluate.load("mean_iou")
eval_results = validate(
    model,
    metric,
    test_ds,
    dataset.id2label,
    config.eval_images_to_log,
    device,
)
print(eval_results)

Found cached dataset parquet (/home/dejang/.cache/huggingface/datasets/Chris1___parquet/Chris1--cityscapes-2bd50e1e8cc703b7/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
100%|██████████| 3/3 [00:00<00:00, 630.28it/s]


Dataset({
    features: ['image', 'semantic_segmentation'],
    num_rows: 2975
})
Dataset({
    features: ['image', 'semantic_segmentation'],
    num_rows: 500
})
Dataset({
    features: ['image', 'semantic_segmentation'],
    num_rows: 1525
})
Dataset is sampled
Example from dataset class
(1024, 2048, 3)
(1024, 2048, 3)
[0 1 3]


Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


Validation example from dataset class
<class 'torch.Tensor'>
torch.Size([2, 3, 512, 512])
<class 'torch.Tensor'>
torch.Size([2, 512, 512])
tensor([ 0,  1,  2,  3,  5,  6,  8,  9, 11, 12, 14])
CHECK TEST DATASET
<class 'torch.Tensor'>
torch.Size([2, 3, 512, 512])
<class 'torch.Tensor'>
torch.Size([2, 512, 512])
tensor([0])


Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Eval 0:
Labels: torch.Size([512, 512]), [0]
Labels: torch.Size([512, 512]), [0]
Labels: (2, 512, 512), [0]
Preds: (2, 512, 512), [ 0  1  2  3  9 11 14]


KeyboardInterrupt: 