In [1]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from PIL import ImageFile
from torchvision import transforms, datasets
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import os

In [2]:
# Let Pillow decode image files whose data is truncated instead of raising,
# so a few damaged files in the dataset do not abort loading.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
# --- Configuration ---
DATA_DIR = "./dataset"
# Expected dataset layout (one sub-folder per class):
# dataset/
#    train/
#        rabbit/
#        misc/
#        nothing/
#    test/
#        rabbit/
#        misc/
#        nothing/

SEED = 42
# Seed PyTorch's global RNG so that head initialisation and any torch-driven
# shuffling/augmentation start from the same state on every Restart & Run All.
# (GPU kernels may still introduce small run-to-run differences.)
torch.manual_seed(SEED)

BATCH_SIZE = 16
LR = 1e-4
EPOCHS = 10
NUM_CLASSES = 3  # rabbit, misc, nothing (matches the folders listed above)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_SAVE_PATH = "./model/vit_intruder.pth"
MODEL_QUANTIZED_PATH = "./model/vit_intruder_q.pth"

In [4]:
# --- Augmentation and preprocessing pipelines ---
# Per-channel statistics shared by both pipelines (the commonly used ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training: mild geometric and photometric augmentation, then tensor + normalisation.
_train_steps = [
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.GaussianBlur(3, sigma=(0.1, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transforms = transforms.Compose(_train_steps)

# Evaluation: deterministic resize + centre crop, same normalisation as training.
_test_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
test_transforms = transforms.Compose(_test_steps)

In [5]:
# --- Datasets and DataLoaders ---
train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_transforms)
# Evaluation must use the deterministic pipeline; feeding the test split through
# the random training augmentations makes the reported test metrics noisy.
test_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "test"), transform=test_transforms)

# Both splits must expose the same class folders, otherwise the integer labels
# of the two datasets would refer to different classes.
assert train_dataset.classes == test_dataset.classes, (
    f"train/test class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
# shuffle stays on for the test loader: it does not change the aggregate metrics,
# and the quantization calibration cell takes only the first few batches, which
# should contain a mix of classes rather than a single folder.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [6]:
# --- Class weights for the imbalanced training set ---
train_dir = os.path.join(DATA_DIR, "train")
# Sort the class folders: the weight vector is indexed by label id, and
# ImageFolder numbers classes in sorted-name order, whereas os.listdir returns
# entries in arbitrary order. Without sorting, weights can land on the wrong class.
classes = sorted(d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d)))

# Per-class file count, keyed by class name.
# NOTE(review): this counts every regular file in the folder, not only the
# image files the dataset actually loads — confirm the folders hold images only.
class_counts = {}

for cls in classes:
    cls_path = os.path.join(train_dir, cls)
    class_counts[cls] = sum(
        1 for f in os.listdir(cls_path) if os.path.isfile(os.path.join(cls_path, f))
    )

print(class_counts)

# An empty class would give 1/0 = inf below and turn every normalised weight into NaN/0.
empty_classes = [cls for cls in classes if class_counts[cls] == 0]
if empty_classes:
    raise ValueError(f"No training files found for classes: {empty_classes}")

counts = torch.tensor([class_counts[cls] for cls in classes], dtype=torch.float)

# Inverse-frequency weights, normalised to sum to 1.
weights = 1.0 / counts
class_weights = weights / weights.sum()

print(class_weights)

{'misc': 95, 'nothing': 403, 'rabbit': 7948}
tensor([0.8015, 0.1889, 0.0096])


In [7]:
# --- Model: torchvision ViT-B/16 with a replaced classification head ---
# Load the default pretrained weights for vit_b_16.
weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights=weights)

# Swap the final linear layer for one that outputs NUM_CLASSES logits,
# keeping the input width of the original head.
in_features = model.heads.head.in_features
model.heads.head = nn.Linear(in_features=in_features, out_features=NUM_CLASSES)

# Move the complete model (including the new head) to the target device.
model = model.to(device=DEVICE)

In [8]:
# --- Class-weighted loss and optimizer ---
# The per-class weights must live on the same device as the logits.
criterion = nn.CrossEntropyLoss(weight=class_weights.to(DEVICE))
# AdamW over every model parameter (pretrained layers and new head alike).
optimizer = optim.AdamW(params=model.parameters(), lr=LR)

In [None]:
# --- Training loop ---
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over `loader` and return `(mean_loss, accuracy)`.

    Args:
        model: network to run; switched to train mode when `optimizer` is
            given and to eval mode otherwise.
        loader: iterable yielding `(images, labels)` batches.
        criterion: loss callable applied to `(outputs, labels)`.
        device: device the batches are moved to before the forward pass.
        optimizer: when provided, a backward pass and an optimizer step are
            performed for every batch; when None the pass is evaluation only
            (the caller is expected to wrap it in `torch.no_grad()`).

    Returns:
        Tuple `(loss, accuracy)`, both averaged over the number of samples seen.

    Raises:
        ValueError: if the loader yields no samples.
    """
    is_training = optimizer is not None
    if is_training:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        if is_training:
            optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        if is_training:
            loss.backward()
            optimizer.step()

        # The criterion returns a batch-level average; scale by the batch size
        # so that a smaller final batch is not over-weighted in the epoch mean.
        loss_sum += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    if n_seen == 0:
        raise ValueError("DataLoader produced no samples; cannot compute epoch metrics")
    return loss_sum / n_seen, n_correct / n_seen


print("--- TRAINING STARTED ---")
for epoch in range(EPOCHS):
    start_time = time.time()

    train_loss, train_acc = _run_epoch(model, train_loader, criterion, DEVICE, optimizer)

    # Evaluation on the test split: no gradients needed.
    with torch.no_grad():
        test_loss, test_acc = _run_epoch(model, test_loader, criterion, DEVICE)

    epoch_time = time.time() - start_time

    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f} | "
          f"{epoch_time/60:.2f} minutes")
print("--- TRAINING ENDED ---")

--- TRAINING STARTED ---


In [None]:
# --- Save the trained (float) model ---
# torch.save does not create missing parent folders, so make sure the target
# directory exists first ("or '.'" covers a path without a directory part).
os.makedirs(os.path.dirname(MODEL_SAVE_PATH) or ".", exist_ok=True)
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Model salvat la {MODEL_SAVE_PATH}")

In [11]:
# --- Prepare the model for 8-bit post-training quantization ---
# Move to CPU and switch to eval mode before inserting observers.
model.cpu()
model.eval()

# 'fbgemm' targets x86, 'qnnpack' targets ARM.
QUANT_BACKEND = 'qnnpack'

# The active quantized engine should match the backend the qconfig was built
# for; otherwise weights are packed at convert time for a different backend
# than the one the observers/qconfig assume.
if QUANT_BACKEND in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = QUANT_BACKEND
else:
    print(f"Warning: quantized engine '{QUANT_BACKEND}' is not available in this build; "
          f"keeping '{torch.backends.quantized.engine}'")

model.qconfig = torch.quantization.get_default_qconfig(QUANT_BACKEND)
# NOTE(review): eager-mode static quantization normally expects QuantStub /
# DeQuantStub around the float model; none is added here — verify that the
# converted model actually runs inference on float inputs.
torch.quantization.prepare(model, inplace=True)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.prepare(model, inplace=True)


VisionTransformer(
  (conv_proj): Conv2d(
    3, 768, kernel_size=(16, 16), stride=(16, 16)
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm(
          (768,), eps=1e-06, elementwise_affine=True
          (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
        )
        (self_attention): QuantizableMultiheadAttention(
          (out_proj): Linear(
            in_features=768, out_features=768, bias=True
            (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
          )
          (linear_Q): Linear(
            in_features=768, out_features=768, bias=True
            (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
          )
          (linear_K): Linear(
            in_features=768, out_features=768, bias=True
            (a

In [14]:
# --- Calibrate the quantization observers with a few test batches ---
CALIBRATION_BATCHES = 10  # number of batches fed through the observed model

print("Running calibraiton for quantization...")
with torch.no_grad():
    for i, (images, _) in enumerate(test_loader):
        # Check before the forward pass so exactly CALIBRATION_BATCHES batches
        # are used (checking afterwards with `i >= 10` ran one batch too many).
        if i >= CALIBRATION_BATCHES:
            break
        model(images)  # forward pass only; the inserted observers record activation statistics

Running calibraiton for quantization...


In [16]:
# --- Convert the calibrated model to its 8-bit quantized form ---
# The conversion is done in place, so `model` itself becomes the quantized model.
torch.quantization.convert(module=model, inplace=True)
print("8-bit Quantization compelte")

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  torch.quantization.convert(model, inplace=True)


8-bit Quantization compelte


In [17]:
# --- Save the quantized model ---
# Write to the dedicated quantized path: saving to MODEL_SAVE_PATH would
# overwrite the float checkpoint and contradict the message printed below.
os.makedirs(os.path.dirname(MODEL_QUANTIZED_PATH) or ".", exist_ok=True)
torch.save(model.state_dict(), MODEL_QUANTIZED_PATH)
print(f"Model salvat la {MODEL_QUANTIZED_PATH}")

Model salvat la ./vit_intruder.pth
