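## Fine-tuning BEiT

Fine-tunes `microsoft/beit-base-patch16-224` as a binary image classifier on the images listed in `image_data.csv`, exports the fine-tuned model to ONNX, and runs a sample prediction with ONNX Runtime.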

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BeitForImageClassification, BeitImageProcessor
from PIL import Image
import pandas as pd
import time

# Custom dataset class
class CustomDataset(Dataset):
    """Image-classification dataset backed by a CSV of ``image_path`` / ``label`` rows.

    Rows that repeat an ``image_path`` are dropped (first occurrence kept), so each
    image appears once per epoch.

    Args:
        csv_file: Path to a CSV containing at least ``image_path`` and ``label`` columns.
        feature_extractor: Callable used as
            ``feature_extractor(images=..., return_tensors="pt")`` that returns a mapping
            with a ``pixel_values`` tensor (e.g. a ``BeitImageProcessor``).

    Raises:
        KeyError: if ``csv_file`` lacks the ``image_path`` or ``label`` column.
    """

    def __init__(self, csv_file, feature_extractor):
        print("Loading dataset...")
        data = pd.read_csv(csv_file)
        # Fail fast on a malformed CSV instead of erroring on the first __getitem__ call.
        missing = {'image_path', 'label'} - set(data.columns)
        if missing:
            raise KeyError(f"{csv_file} is missing required column(s): {sorted(missing)}")
        self.data = data.drop_duplicates(subset=['image_path']).reset_index(drop=True)
        self.feature_extractor = feature_extractor
        print(f"Dataset loaded with {len(self.data)} unique samples.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for row ``idx``.

        ``pixel_values`` has the processor's leading batch dimension squeezed away
        (presumably ``(3, 224, 224)`` for the BEiT processor — TODO confirm);
        ``label`` is a 0-dim ``torch.long`` tensor.
        """
        row = self.data.iloc[idx]
        # The context manager closes the file; convert() returns a new, fully loaded image.
        with Image.open(row['image_path']) as img:
            image = img.convert('RGB')
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        # int() makes a NaN / non-numeric label fail loudly instead of casting to a garbage long.
        label = torch.tensor(int(row['label']), dtype=torch.long)
        return inputs['pixel_values'].squeeze(0), label

# Run configuration: everything a reader might want to tune.
MODEL_NAME = 'microsoft/beit-base-patch16-224'
NUM_LABELS = 2  # binary classification
BATCH_SIZE = 16
LEARNING_RATE = 5e-5
SEED = 42

# Seed before building the model so the freshly initialised classifier head is reproducible.
torch.manual_seed(SEED)

# Load pre-trained BEiT model and feature extractor.
# Passing num_labels (with ignore_mismatched_sizes=True) swaps the ImageNet head for a new
# NUM_LABELS-way head AND records num_labels in model.config. Replacing model.classifier by
# hand left config.num_labels at the ImageNet value, so save_pretrained() wrote a config
# that did not match the saved weights.
print("Loading BEiT model and feature extractor...")
model = BeitForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
    ignore_mismatched_sizes=True,
)
feature_extractor = BeitImageProcessor.from_pretrained(MODEL_NAME)
print("Model and feature extractor loaded successfully.")
assert model.config.num_labels == NUM_LABELS

# Trade compute for memory: activations are recomputed during the backward pass.
model.gradient_checkpointing_enable()

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model.to(device)

# Load dataset.
dataset = CustomDataset('image_data.csv', feature_extractor)
# A dedicated generator makes the shuffle order depend only on SEED, not on how much of
# the global RNG model initialisation consumed. pin_memory only helps host->CUDA copies.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=device.type == 'cuda',
    generator=torch.Generator().manual_seed(SEED),
)

# Optimizer and loss
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.CrossEntropyLoss()

# Loss scaling for mixed precision is enabled only on CUDA; when disabled the scaler is a
# passthrough (scale() returns the loss unchanged, step() just calls optimizer.step()).
scaler = torch.amp.GradScaler(device.type, enabled=device.type == 'cuda')


Loading BEiT model and feature extractor...


  return func(*args, **kwargs)


Model and feature extractor loaded successfully.
Modifying model for binary classification...
Model adjusted.
Loading dataset...
Dataset loaded with 2594 unique samples.
Using device: cuda


In [5]:
# Early stopping parameters.
# NOTE(review): this monitors the *training* loss — there is no validation split — so it
# detects stalled optimisation, not overfitting.
patience = 3  # Number of epochs to wait before stopping if no improvement
best_loss = float('inf')
stopping_counter = 0

# Training loop
num_epochs = 5
# Autocast follows the model's device; mixed precision is only enabled on CUDA.
amp_enabled = device.type == 'cuda'

print("Starting training...")
start_time = time.time()
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs} - Training started...")
    epoch_start_time = time.time()
    model.train()
    epoch_loss = 0.0
    n_seen = 0
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast(device_type=device.type, enabled=amp_enabled):
            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        # Weight by batch size so the short final batch does not skew the epoch mean.
        epoch_loss += batch_loss * labels.size(0)
        n_seen += labels.size(0)
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx}/{len(dataloader)}, Loss: {batch_loss:.4f}")

    # Per-sample mean loss; max() guards against an empty dataloader.
    avg_loss = epoch_loss / max(n_seen, 1)
    epoch_end_time = time.time()
    print(f"Epoch {epoch+1} completed. Avg Loss: {avg_loss:.4f}, Time taken: {(epoch_end_time - epoch_start_time):.2f} sec")

    # Early stopping check
    if avg_loss < best_loss:
        best_loss = avg_loss
        stopping_counter = 0  # Reset counter if loss improves
        model.save_pretrained('fine_tuned_beit_best')  # Save best model
        print("New best model saved!")
    else:
        stopping_counter += 1
        print(f"Early stopping counter: {stopping_counter}/{patience}")
        if stopping_counter >= patience:
            print("Early stopping triggered. Stopping training.")
            break

total_time = time.time() - start_time
print(f"Training completed in {total_time/60:.2f} minutes.")

# Save final (last-epoch) model; the lowest-loss epoch is in 'fine_tuned_beit_best'.
model.save_pretrained('fine_tuned_beit')
print("Model fine-tuned and saved successfully!")

Starting training...

Epoch 1/5 - Training started...
Epoch 1, Batch 0/163, Loss: 0.8319
Epoch 1, Batch 10/163, Loss: 0.5901
Epoch 1, Batch 20/163, Loss: 0.4773
Epoch 1, Batch 30/163, Loss: 0.1502
Epoch 1, Batch 40/163, Loss: 0.5048
Epoch 1, Batch 50/163, Loss: 0.6168
Epoch 1, Batch 60/163, Loss: 0.4445
Epoch 1, Batch 70/163, Loss: 0.2021
Epoch 1, Batch 80/163, Loss: 0.5241
Epoch 1, Batch 90/163, Loss: 0.6617
Epoch 1, Batch 100/163, Loss: 0.1382
Epoch 1, Batch 110/163, Loss: 0.4132
Epoch 1, Batch 120/163, Loss: 0.2105
Epoch 1, Batch 130/163, Loss: 0.2608
Epoch 1, Batch 140/163, Loss: 0.4116
Epoch 1, Batch 150/163, Loss: 0.6157
Epoch 1, Batch 160/163, Loss: 0.5728
Epoch 1 completed. Avg Loss: 0.4691, Time taken: 1366.22 sec
New best model saved!

Epoch 2/5 - Training started...
Epoch 2, Batch 0/163, Loss: 0.3528
Epoch 2, Batch 10/163, Loss: 0.8584
Epoch 2, Batch 20/163, Loss: 0.2477
Epoch 2, Batch 30/163, Loss: 0.6037
Epoch 2, Batch 40/163, Loss: 0.7081
Epoch 2, Batch 50/163, Loss: 0.48

## Converting Fine-tuned Model to ONNX

In [7]:
from transformers import BeitConfig, BeitForImageClassification
import torch

# Final-epoch weights; 'fine_tuned_beit_best' holds the lowest-training-loss epoch instead.
CHECKPOINT_DIR = 'fine_tuned_beit'
ONNX_PATH = "beit_finetuned_model.onnx"

# A checkpoint whose head was swapped by hand may still carry the ImageNet num_labels in
# config.json; forcing 2 makes the config agree with the saved 2-way classifier weights
# (a no-op when the config already says 2).
config = BeitConfig.from_pretrained(CHECKPOINT_DIR)
config.num_labels = 2

model = BeitForImageClassification.from_pretrained(CHECKPOINT_DIR, config=config)
model.eval()

# Size the tracing input from the config instead of hard-coding 1x3x224x224.
dummy_input = torch.randn(1, config.num_channels, config.image_size, config.image_size)
torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    input_names=["input"],
    output_names=["output"],
    # Symbolic batch dimension: the exported graph accepts any batch size, not only 1.
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)
print("Model converted to ONNX!")


  if num_channels != self.num_channels:
  if interpolate_pos_encoding:


Model converted to ONNX!


In [9]:
import onnxruntime as ort
import torch
import numpy as np
from PIL import Image
from transformers import BeitImageProcessor

ONNX_PATH = "beit_finetuned_model.onnx"
SAMPLE_IMAGE_PATH = "classified_images/malignant/ISIC_0000013.jpg"

# 1. Load the ONNX model into an InferenceSession.
#    ORT >= 1.9 requires `providers` explicitly on builds that ship non-CPU execution
#    providers (e.g. onnxruntime-gpu); prefer CUDA when present, otherwise CPU.
available_providers = ort.get_available_providers()
providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available_providers]
session = ort.InferenceSession(ONNX_PATH, providers=providers)

# 2. Same processor checkpoint as training, so resizing/normalisation match.
processor = BeitImageProcessor.from_pretrained('microsoft/beit-base-patch16-224')

# 3. Load and preprocess an image (the context manager closes the file handle).
with Image.open(SAMPLE_IMAGE_PATH) as img:
    image = img.convert("RGB")
inputs = processor(images=image, return_tensors="pt")
# inputs["pixel_values"] is expected to be shape [1, 3, 224, 224]

# 4. ONNX Runtime expects NumPy arrays; the graph was traced with a float32 input.
pixel_values = inputs["pixel_values"].numpy().astype(np.float32, copy=False)
ort_inputs = {session.get_inputs()[0].name: pixel_values}

# 5. Run inference; the first output is the model's logits.
ort_outputs = session.run(None, ort_inputs)
logits = ort_outputs[0]  # shape: [1, 2]

# 6. Convert logits to probabilities and get the predicted class index.
#    NOTE(review): index -> class name follows the `label` column of image_data.csv and is
#    not recorded here — confirm which index means "malignant".
logits_tensor = torch.tensor(logits)
probs = torch.softmax(logits_tensor, dim=1)
pred_label = torch.argmax(probs, dim=1)

print("ONNX Output logits:", logits)
print("Probabilities:", probs)
print("Predicted class:", pred_label.item())


ONNX Output logits: [[-2.1029592  1.2445506]]
Probabilities: tensor([[0.0340, 0.9660]])
Predicted class: 1


In [10]:
# Raw label table for inspection. Unlike CustomDataset, duplicate image_path rows are NOT dropped here.
df = pd.read_csv(filepath_or_buffer="image_data.csv")

In [5]:
# Number of batches the DataLoader yields per epoch (must be re-run after changing batch size).
print("Expected batches per epoch:", len(dataloader))


Expected batches per epoch: 82
