# SegFormer Inference

In [9]:
import torch
import numpy as np
import random
import os
import sys
from src.utils.utils import *

def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic, non-benchmarking mode.

    Args:
        seed: value handed unchanged to each seeding function and written,
            as a string, to the ``PYTHONHASHSEED`` environment variable.
    """
    # NOTE(review): Python presumably reads PYTHONHASHSEED only at interpreter
    # start-up, so this likely affects child processes rather than the running
    # kernel; confirm if hash determinism matters here.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The seeding calls are independent of one another, so apply them in a loop.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    # Disable the cuDNN autotuner and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Add the parent directory to the module search path.
# NOTE(review): this executes after the `src.utils.utils` import at the top of
# this cell, so it cannot have helped that import resolve; confirm whether
# the path tweak is still needed.
sys.path.append("..")
# Fix all RNG seeds (random / NumPy / PyTorch) before anything stochastic runs.
set_seed(42)

In [10]:
from src.segFormer import SegFormer

# Choose the compute device in order of preference: CUDA, then MPS, then CPU.
# The chained conditional short-circuits, so the MPS check is only evaluated
# when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [None]:
from src.utils.utils import *
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch import nn

# Validation/inference preprocessing applied to every input image:
#   1. resize to the fixed 256x256 resolution,
#   2. normalize with zero mean / unit std and max_pixel_value=255, which
#      (per albumentations' Normalize formula) rescales pixels to [0, 1],
#   3. convert the array to a PyTorch tensor.
val_transforms = A.Compose([
    A.Resize(height=256, width=256),
    A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0),
    ToTensorV2(),
])

In [None]:
# Path of the trained weights restored below.
CHECKPOINT_FILENAME = "./checkpoints/segformer-checkpoints/checkpoint_7.pth.tar"

# Build the SegFormer with the hyper-parameters listed here and place it on
# the selected device. num_classes=1 means a single-channel output.
# NOTE(review): these values presumably have to match the configuration the
# checkpoint was trained with; confirm against the training run.
model = SegFormer(
  in_channels=3,
  widths=[64, 128, 256, 512],
  depths=[3, 4, 6, 3],
  all_num_heads=[1, 2, 4, 8],
  patch_sizes=[7, 3, 3, 3],
  overlap_sizes=[4, 2, 2, 2],
  reduction_ratios=[8, 4, 2, 1],
  mlp_expansions=[4, 4, 4, 4],
  decoder_channels=256,
  scale_factors=[8, 4, 2, 1],
  num_classes=1,
  drop_prob=0.3,
).to(device)

# map_location remaps the stored tensors onto the active device. Without it, a
# checkpoint saved from a CUDA run cannot be deserialized on the MPS/CPU
# fallbacks that the device-selection cell explicitly supports.
checkpoint = torch.load(CHECKPOINT_FILENAME, map_location=device)
# load_checkpoint returns a 3-tuple; only the restored model is needed here.
model, _, _ = load_checkpoint(checkpoint, model)

=> Loading checkpoint


In [32]:
from PIL import Image
import torch
import torch.nn.functional as F

def run_inference(model, image_path, device="cuda"):
    """Segment a single image and save the resulting binary mask to disk.

    The image is preprocessed with the module-level ``val_transforms``
    pipeline and run through ``model`` in eval mode without gradients. The
    output is passed through a sigmoid, thresholded at 0.5 and resized back
    to the original resolution with nearest-neighbour interpolation. The mask
    is written to ``./streamlit-app/data/prediction/`` under the input file's
    basename.

    Args:
        model: network invoked as ``model(batch)``; its output is treated as
            logits. It is switched to eval mode by this function.
        image_path: path of the image to segment.
        device: device the input batch is moved to (default ``"cuda"``).
            # assumes the model already lives on this device (confirm at caller)

    Returns:
        torch.Tensor: CPU float tensor of 0./1. values, shaped
        ``(1, 1, height, width)`` with the original image's height and width.
    """
    # Decode inside a context manager so the file handle is released, and force
    # three RGB channels: grayscale, palette or RGBA files would otherwise
    # yield an array whose channel count does not match the 3-channel model.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    orig_width, orig_height = rgb_image.size  # PIL reports (width, height)

    # Preprocess, add the batch dimension and move to the target device.
    batch = val_transforms(image=np.array(rgb_image))['image']
    batch = batch.unsqueeze(0).to(device)

    # Forward pass in eval mode without tracking gradients.
    model.eval()
    with torch.no_grad():
        logits = model(batch)

    # Probabilities, then hard 0/1 decisions at the 0.5 threshold.
    mask = (torch.sigmoid(logits) > 0.5).float()

    # Drop singleton dims, then restore the (N, C, H, W) layout that
    # F.interpolate needs before scaling back to the original size.
    mask = mask.squeeze().cpu().unsqueeze(0).unsqueeze(0)
    mask = F.interpolate(mask, size=(orig_height, orig_width), mode="nearest")

    # Scale {0, 1} to {0, 255} and wrap as an 8-bit image.
    mask_img = Image.fromarray((mask.squeeze().numpy() * 255).astype(np.uint8))

    # Create the output directory on demand so a fresh checkout does not fail.
    # NOTE(review): the mask reuses the input's file name, so for .jpg inputs
    # it is saved with lossy JPEG compression and may not stay strictly 0/255;
    # confirm that downstream consumers tolerate this.
    output_dir = "./streamlit-app/data/prediction/"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, os.path.basename(image_path))
    mask_img.save(output_path)

    return mask

# Segment one sample image on the selected device. Besides returning the mask
# tensor (echoed below as the cell output), the call writes the mask image
# into ./streamlit-app/data/prediction/ under the same file name.
run_inference(model, "./streamlit-app/data/sample/ISIC_0036333.jpg", device=device)

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])