# FASTVIT PET MOBILE

## 1. Environment setup

In [1]:
!git clone https://github.com/HenryNVP/fastvit-pet-mobile.git
%cd fastvit-pet-mobile

Cloning into 'fastvit-pet-mobile'...
remote: Enumerating objects: 63, done.[K
remote: Counting objects: 100% (63/63), done.[K
remote: Compressing objects: 100% (50/50), done.[K
remote: Total 63 (delta 11), reused 57 (delta 8), pack-reused 0 (from 0)[K
Receiving objects: 100% (63/63), 3.12 MiB | 37.63 MiB/s, done.
Resolving deltas: 100% (11/11), done.
/content/fastvit-pet-mobile


In [2]:
!pip install -r requirements.txt

Collecting onnxruntime (from -r requirements.txt (line 11))
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting coloredlogs (from onnxruntime->-r requirements.txt (line 11))
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime->-r requirements.txt (line 11))
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m138.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## 2. Download and prepare data

This downloads the official dataset and prepares 256x256 splits.

In [3]:
!python scripts/get_data.py
!python scripts/split_dataset.py

Downloading images.tar.gz from https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz
  tf.extractall(destination)
Downloading annotations.tar.gz from https://thor.robots.ox.ac.uk/~vgg/data/pets/annotations.tar.gz
Prepared train=3680 val=1832 test=1837


## 3. Finetune teacher model

In [5]:
import torch
import fastvit.models
from timm.models import create_model
from fastvit.models.modules.mobileone import reparameterize_model

# Create the model
teacher_model = create_model("fastvit_t8", pretrained=False)

# Download and load checkpoint
!wget -O fastvit_t8.pth.tar \
"https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar"

checkpoint = torch.load("fastvit_t8.pth.tar", map_location='cpu')
teacher_model.load_state_dict(checkpoint['state_dict'])

# ... train ...

# For inference
teacher_model.eval()
model_inf = reparameterize_model(teacher_model)
# Use model_inf at test-time

--2025-11-19 22:45:05--  https://docs-assets.developer.apple.com/ml-research/models/fastvit/image_classification_models/fastvit_t8.pth.tar
Resolving docs-assets.developer.apple.com (docs-assets.developer.apple.com)... 17.253.118.202, 2403:300:a32:f100::2
Connecting to docs-assets.developer.apple.com (docs-assets.developer.apple.com)|17.253.118.202|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16378213 (16M) [application/x-tar]
Saving to: ‘fastvit_t8.pth.tar’


2025-11-19 22:45:07 (11.1 MB/s) - ‘fastvit_t8.pth.tar’ saved [16378213/16378213]



In [6]:
from PIL import Image
from torchvision import transforms

# Load image
img = Image.open("/content/fastvit-pet-mobile/data/train/Abyssinian/Abyssinian_1.jpg").convert("RGB")

# Preprocess
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

input_tensor = preprocess(img).unsqueeze(0)  # add batch dimension

with torch.no_grad():
    output = model_inf(input_tensor)  # shape: [1, num_classes]
    predicted_class = torch.argmax(output, dim=1)
    print("Predicted class index:", predicted_class.item())

Predicted class index: 273


In [9]:
!python fastvit/train.py data --model fastvit_t8 -b 128 --lr 1e-3 --native-amp --mixup 0.2 --output ./output/teacher --input-size 3 256 256


If for semantic segmentation, please install mmsegmentation first
If for detection, please install mmdetection first
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
Training with a single process on 1 GPUs.
Model fastvit_t8 created, param count:4026232
Data processing configuration for current model + dataset:
	input_size: (3, 256, 256)
	interpolation: bicubic
	mean: (0.485, 0.456, 0.406)
	std: (0.229, 0.224, 0.225)
	crop_pct: 0.875
	crop_mode: center
Using native Torch AMP. Training in mixed precision.
Scheduled epochs: 310
  with amp_autocast():
  with amp_autocast():
Test: [   0/14]  Time: 0.862 (0.862)  Loss:  6.9453 (6.9453)  Acc@1:  0.0000 ( 0.0000)  Acc@5:  0.0000 ( 0.0000)
Test: [  14/14]  Time: 4.316 (0.445)  Loss:  6.9570 (6.9164)  Acc@1:  0.0000 ( 0.4367)  Acc@5:  0.0000 ( 2.2380)
Test (EMA): [   0/14]  Time: 0.537 (0.537)  Loss:  6.9102 (6.9102)  Acc@1:  0.0000 ( 0.0000)  Acc@5:  0.0000 ( 0.0000)


In [10]:
print(teacher_model)

FastViT(
  (patch_embed): Sequential(
    (0): MobileOneBlock(
      (se): Identity()
      (activation): GELU(approximate='none')
      (rbr_conv): ModuleList(
        (0): Sequential(
          (conv): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (rbr_scale): Sequential(
        (conv): Conv2d(3, 48, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): MobileOneBlock(
      (se): Identity()
      (activation): GELU(approximate='none')
      (rbr_conv): ModuleList(
        (0): Sequential(
          (conv): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=48, bias=False)
          (bn): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (rbr_scale): S

In [12]:
tuned_teacher = create_model(
    "fastvit_t8",
    pretrained=False,
    num_classes=37,
)

checkpoint = torch.load("/content/fastvit-pet-mobile/output/teacher/20251119-225323-fastvit_t8-256/model_best.pth.tar", map_location='cpu')
tuned_teacher_model.load_state_dict(checkpoint['state_dict'])


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL argparse.Namespace was not an allowed global by default. Please use `torch.serialization.add_safe_globals([argparse.Namespace])` or the `torch.serialization.safe_globals([argparse.Namespace])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

## 4. Setup and finetune student model

## 5. Model evaluation

## 6. Export model

## 7. Summary and conclusion