In [None]:
from datasets import load_dataset

# Only the first 5,000 Food-101 training images are used, to keep the example fast.
food = load_dataset(path="food101", split="train[:5000]")

In [None]:
# Hold out 20% of the subset for evaluation. A fixed seed makes the split
# reproducible across kernel restarts (with seed=None, datasets draws fresh
# OS entropy each time).
# NOTE: this rebinds `food` to a DatasetDict, so re-run the loading cell
# before re-running this one.
food = food.train_test_split(test_size=0.2, seed=42)

In [None]:
import matplotlib.pyplot as plt

# Show the first training image.
first_example = food["train"][0]
plt.imshow(first_example["image"])
plt.show()

In [None]:
# Class-name <-> id lookups for the model config. Ids are stored as strings.
labels = food["train"].features["label"].names
label2id = {name: str(class_idx) for class_idx, name in enumerate(labels)}
id2label = {str(class_idx): name for class_idx, name in enumerate(labels)}

In [None]:
# Class name for id 79 (keys are strings).
id2label["79"]

In [None]:
from transformers import AutoImageProcessor

# The processor supplies the normalisation statistics (image_mean / image_std)
# and the target size used by the augmentation pipeline.
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [None]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# Training augmentation: random resized crop to the processor's input size,
# then tensor conversion and normalisation with the checkpoint's statistics.
if "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
else:
    size = (image_processor.size["height"], image_processor.size["width"])
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

In [None]:
def transforms(examples):
    """Batch transform: add augmented ``pixel_values`` and drop ``image``.

    Every image is converted to RGB and passed through ``_transforms``.
    The batch mapping is modified in place and returned.
    """
    pixel_values = []
    for img in examples["image"]:
        pixel_values.append(_transforms(img.convert("RGB")))
    examples["pixel_values"] = pixel_values
    del examples["image"]
    return examples

In [None]:
# Apply `transforms` lazily whenever rows are accessed.
food = food.with_transform(transform=transforms)

In [None]:
from transformers import DefaultDataCollator

# Stack examples into PyTorch tensors ("pt" is also the collator's default).
data_collator = DefaultDataCollator(return_tensors="pt")

## Evaluate


In [None]:
import evaluate

# Top-1 accuracy metric used by compute_metrics.
accuracy = evaluate.load(path="accuracy")

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    """Top-1 accuracy for the image-classification Trainer.

    ``eval_pred`` unpacks to ``(predictions, labels)``; the predicted class is
    the argmax of the predictions along axis 1.
    """
    logits, references = eval_pred
    predicted_classes = np.argmax(logits, axis=1)
    return accuracy.compute(predictions=predicted_classes, references=references)

## Training

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Load the checkpoint with a classifier sized to the dataset's label set and
# store the label maps in the model config.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    label2id=label2id,
    id2label=id2label,
    num_labels=len(labels),
)

In [None]:
# remove_unused_columns=False keeps the raw "image" column, which `transforms`
# reads; the best epoch by accuracy is reloaded when training ends.
training_args = TrainingArguments(
    output_dir="my_awesome_food_model",
    push_to_hub=False,
    remove_unused_columns=False,
    # schedule
    num_train_epochs=3,
    learning_rate=5e-5,
    warmup_ratio=0.1,
    # batching (effective train batch per device = 16 x 4 accumulation steps)
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    # evaluation / checkpointing
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=food["train"],
    eval_dataset=food["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # the image processor is saved alongside the model checkpoints
    tokenizer=image_processor,
)

trainer.train()

## Inference

In [None]:
# A handful of held-out validation images for a quick inference check.
ds = load_dataset(path="food101", split="validation[:10]")


In [None]:
from transformers import pipeline

import random

# Pick one validation image at random and show it. Index the row first so
# only that one image is decoded (ds["image"] would decode every row).
idx = random.randint(0, len(ds) - 1)
image = ds[idx]["image"]

plt.imshow(image)
plt.show()

# Use the checkpoint the Trainer recorded as best (load_best_model_at_end=True)
# rather than a hard-coded "checkpoint-NNN": the step number depends on the
# dataset size, batch size, gradient accumulation and epoch count.
best_checkpoint = trainer.state.best_model_checkpoint
if best_checkpoint is None:
    raise RuntimeError("No best checkpoint recorded; run the training cell first.")

classifier = pipeline("image-classification", model=best_checkpoint)
classifier(image)

# SegFormer - AutoLamella


TODO:
- Add id2label.json to the dataset repo
- Add an inference pipeline to model.py
- Migrate to a standalone script

In [None]:
from datasets import load_dataset, concatenate_datasets

def _load_autolamella_split(subset, split):
    """Load one split of one AutoLamella subset from the Hugging Face Hub."""
    return load_dataset("patrickcleeve/autolamella", name=subset, split=split)


waffle_train_ds = _load_autolamella_split("waffle", "train")
liftout_train_ds = _load_autolamella_split("liftout", "train")
serial_liftout_train_ds = _load_autolamella_split("serial-liftout", "train")

waffle_test_ds = _load_autolamella_split("waffle", "test")
liftout_test_ds = _load_autolamella_split("liftout", "test")
serial_liftout_test_ds = _load_autolamella_split("serial-liftout", "test")

# Concatenate every subset into one "mega" dataset per split. To train on a
# single subset instead, load only that subset's "train"/"test" splits.
train_ds = concatenate_datasets([waffle_train_ds, liftout_train_ds, serial_liftout_train_ds])
test_ds = concatenate_datasets([waffle_test_ds, liftout_test_ds, serial_liftout_test_ds], split="test")

print(len(train_ds))
print(len(test_ds))

In [None]:
import matplotlib.pyplot as plt

from PIL import Image
import numpy as np
import random

# Overlay a random raw training image with its segmentation annotation.
idx = random.randint(0, len(train_ds) - 1)
sample = train_ds[idx]  # index once: each access decodes the images

image = np.asarray(Image.fromarray(np.asarray(sample["image"])).convert("RGB"))
# Use a dedicated name: reassigning the global `labels` here would clobber the
# Food-101 class-name list that `num_labels=len(labels)` reads earlier.
annotation = sample["annotation"]

print(image.shape)
plt.imshow(image)
plt.imshow(annotation, alpha=0.5)
plt.show()

In [None]:
import json
from huggingface_hub import hf_hub_download

# AutoLamella segmentation classes: class id -> class name, and the reverse.
id2label = dict(enumerate([
    "background",
    "lamella",
    "manipulator",
    "landing_post",
    "copper_adapter",
    "volume_block",
]))
label2id = {name: class_id for class_id, name in id2label.items()}

num_labels = len(id2label)
print(id2label, num_labels)


In [None]:
from torchvision.transforms import ColorJitter, ToPILImage
from transformers import SegformerImageProcessor
import numpy as np


def to_rgb(image):
    """Return ``image`` as a 3-channel RGB numpy array.

    The input is materialised with ``np.asarray``, wrapped in a PIL image,
    converted to ``"RGB"`` mode and returned as an array again.

    NOTE(review): presumably the dataset yields single-channel PIL images
    here — confirm. Calling ``convert("RGB")`` directly on the input is not
    guaranteed to be equivalent for every PIL mode (e.g. palette images).
    """
    pil_image = Image.fromarray(np.asarray(image))
    return np.asarray(pil_image.convert("RGB"))

# Resize every image / mask pair to 512 x 768 (height x width).
processor = SegformerImageProcessor(size={"height": 512, "width": 768}, do_resize=True)
# NOTE(review): `jitter` is defined but train_transforms does not apply it.
jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

def train_transforms(example_batch):
    """Turn a batch of raw examples into SegFormer training inputs.

    Images are converted to RGB arrays with ``to_rgb`` and passed, together
    with their ``annotation`` masks, through ``processor``.

    NOTE(review): currently identical to ``val_transforms`` — the module-level
    ``jitter`` augmentation is not applied. Confirm this is intentional.
    """
    rgb_images = list(map(to_rgb, example_batch['image']))
    segmentation_maps = list(example_batch['annotation'])
    return processor(rgb_images, segmentation_maps)


def val_transforms(example_batch):
    """Turn a batch of raw examples into SegFormer evaluation inputs.

    Each image is converted to an RGB array with ``to_rgb``; images and their
    ``annotation`` masks are then resized/normalised by ``processor``.
    """
    batch_images = []
    for raw_image in example_batch['image']:
        batch_images.append(to_rgb(raw_image))
    batch_masks = list(example_batch['annotation'])
    return processor(batch_images, batch_masks)


# Attach the batch transforms; datasets applies them on the fly on access.
train_ds.set_transform(transform=train_transforms)
test_ds.set_transform(transform=val_transforms)


In [None]:
import matplotlib.pyplot as plt

import random

# Sanity-check the *transformed* training data: a random sample after
# `processor`, overlaid with its label map.
idx = random.randint(0, len(train_ds) - 1)
sample = train_ds[idx]
image = sample["pixel_values"].transpose(1, 2, 0)  # channels-first -> channels-last for imshow
label_map = sample["labels"]  # not `labels`, which would clobber the Food-101 class list

# `processor` normalises pixel values by default (do_normalize is not disabled),
# so they fall outside [0, 1] and imshow would clip them. Min-max rescale for
# display only.
value_range = np.ptp(image)
display_image = (image - image.min()) / value_range if value_range > 0 else np.zeros_like(image)

plt.imshow(display_image)
plt.imshow(label_map, alpha=0.5)
plt.show()



In [None]:
from transformers import SegformerForSemanticSegmentation

# SegFormer on the MiT-b1 encoder; the label maps configure the output classes.
pretrained_model_name = "nvidia/mit-b1"
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name, label2id=label2id, id2label=id2label
)


In [None]:
from transformers import TrainingArguments

# Training hyper-parameters.
epochs = 50
lr = 0.00006
batch_size = 2

# Build the run / output-dir / hub name from the backbone actually being
# fine-tuned (e.g. "nvidia/mit-b1" -> "b1"). A hand-written tag can drift
# from the checkpoint in use and mislabel the artifacts.
backbone = pretrained_model_name.rsplit("-", 1)[-1]
hub_model_id = f"segformer-{backbone}-finetuned-autolamella-mega-1"

training_args = TrainingArguments(
    hub_model_id,  # output_dir
    learning_rate=lr,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_total_limit=3,
    # save and eval on the same 20-step cadence so load_best_model_at_end can
    # pick a saved checkpoint
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    load_best_model_at_end=True,
    push_to_hub=False,
    hub_model_id=hub_model_id,
    # hub_strategy="end",
    report_to="wandb",
    run_name=hub_model_id,
    remove_unused_columns=False,
)


In [None]:
import torch
from torch import nn
import evaluate

# Mean IoU / per-category IoU and accuracy for semantic segmentation.
metric = evaluate.load(path="mean_iou")

def compute_metrics(eval_pred):
    """Segmentation metrics for the SegFormer Trainer.

    Logits are bilinearly upsampled to the label resolution and reduced to
    class ids with argmax, then scored with ``mean_iou``. Pixels labelled 0
    are excluded via ``ignore_index=0``. The per-category accuracy and IoU
    arrays are flattened into ``accuracy_<class>`` and ``iou_<class>`` keys.
    """
    with torch.no_grad():
        logits, labels = eval_pred
        upsampled = nn.functional.interpolate(
            torch.from_numpy(logits),
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        pred_labels = upsampled.argmax(dim=1).detach().cpu().numpy()

        # `_compute` is used instead of `compute`; see
        # https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
        metrics = metric._compute(
            predictions=pred_labels,
            references=labels,
            num_labels=len(id2label),
            ignore_index=0,
            reduce_labels=processor.do_reduce_labels,
        )

        per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
        per_category_iou = metrics.pop("per_category_iou").tolist()

        for class_id, value in enumerate(per_category_accuracy):
            metrics[f"accuracy_{id2label[class_id]}"] = value
        for class_id, value in enumerate(per_category_iou):
            metrics[f"iou_{id2label[class_id]}"] = value

        return metrics


In [None]:
from transformers import Trainer

# No data_collator is passed, so the Trainer falls back to its default collator.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
)


In [None]:
# Train from scratch; pass a checkpoint path (or True) to resume instead.
trainer.train(resume_from_checkpoint=None)


In [None]:
# Model-card metadata for the uploaded model.
kwargs = dict(
    tags=["vision", "image-segmentation"],
    finetuned_from=pretrained_model_name,
    dataset="patrickcleeve/autolamella",
)

# Upload the image-processor config, then the trained model, to the Hub.
processor.push_to_hub(hub_model_id)
trainer.push_to_hub(**kwargs)

In [None]:
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

# Reload the fine-tuned SegFormer from the Hub.
hf_username = "patrickcleeve"
hub_model_id = "segformer-b1-autolamella-mega-1"
repo_id = f"{hf_username}/{hub_model_id}"
model = SegformerForSemanticSegmentation.from_pretrained(repo_id)


In [None]:
import torch
from torch import nn

from fibsem.segmentation.utils import decode_segmap_v2

# Qualitative check: overlay ground truth and model predictions for random
# samples from the combined AutoLamella test splits.
ds1 = load_dataset("patrickcleeve/autolamella", name="waffle", split="test")
ds2 = load_dataset("patrickcleeve/autolamella", name="liftout", split="test")
ds3 = load_dataset("patrickcleeve/autolamella", name="serial-liftout", split="test")

ds = concatenate_datasets([ds1, ds2, ds3])

# The processor does not change between iterations: load it once instead of
# re-fetching it from the Hub on every pass through the loop.
processor = SegformerImageProcessor.from_pretrained(f"{hf_username}/{hub_model_id}")

for i in range(100):

    idx = random.randint(0, len(ds) - 1)
    sample = ds[idx]  # index once: each access decodes the images

    gt_seg = np.asarray(sample['annotation'])
    image = np.asarray(Image.fromarray(np.asarray(sample['image'])).convert("RGB"))

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # shape (batch_size, num_labels, height/4, width/4)

    # Rescale logits to this image's own (height, width) instead of a
    # hard-coded 1024x1536, so the prediction lines up with the overlay for
    # any input resolution.
    upsampled_logits = nn.functional.interpolate(
        logits,
        size=image.shape[:2],
        mode='bilinear',
        align_corners=False
    )

    # Per-pixel class id, as numpy (the same type as gt_seg) for decode_segmap_v2.
    pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()

    # plot the prediction and ground truth
    fig, ax = plt.subplots(1, 2, figsize=(15, 5))
    plt.suptitle(f"Image {idx}")
    ax[0].imshow(image)
    ax[0].imshow(decode_segmap_v2(gt_seg), alpha=0.5)
    ax[0].axis('off')
    ax[0].set_title('Image')
    ax[1].imshow(image)
    ax[1].imshow(decode_segmap_v2(pred_seg), alpha=0.5)
    ax[1].set_title('Prediction')
    ax[1].axis('off')
    plt.show()
    plt.close(fig)  # don't accumulate open figures across 100 iterations






In [None]:

%load_ext autoreload
%autoreload 2

from fibsem.segmentation.hf_segmentation_model import SegmentationModelHuggingFace
from fibsem.segmentation.model import load_model
from fibsem.segmentation.utils import decode_segmap_v2
from datasets import load_dataset, concatenate_datasets
import matplotlib.pyplot as plt
import numpy as np
import random


ds1 = load_dataset("patrickcleeve/autolamella", name="waffle", split="test")
ds2 = load_dataset("patrickcleeve/autolamella", name="liftout", split="test")
ds3 = load_dataset("patrickcleeve/autolamella", name="serial-liftout", split="test")

ds = concatenate_datasets([ds1, ds2, ds3])

checkpoint = "patrickcleeve/segformer-b1-autolamella-mega-1"
model = load_model(checkpoint)


for i in range(20):

    idx = random.randint(0, len(ds) - 1)

    image = np.asarray(ds[idx]['image'])
    gt_seg = np.asarray(ds[idx]['annotation'])

    masks = model.inference(image, rgb=False)


    # plot the prediction and ground truth
    fig, ax = plt.subplots(1, 2, figsize=(12, 5))
    plt.suptitle(f"Image {idx}")
    ax[0].imshow(image, cmap="gray")
    ax[0].imshow(decode_segmap_v2(gt_seg), alpha=0.5)
    ax[0].axis('off')
    ax[0].set_title('Image')
    ax[1].imshow(image, cmap="gray")
    ax[1].imshow(decode_segmap_v2(masks), alpha=0.5)
    ax[1].set_title('Prediction')
    ax[1].axis('off')
    plt.show()
