In [6]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import BeitForImageClassification
from torch.optim import Adam
from torch import nn

# Dataset layout: a root directory holding one sub-directory per split
DATASET_DIR = './dataset_split/'
TRAIN_DIR, VAL_DIR, TEST_DIR = (
    os.path.join(DATASET_DIR, split_name)
    for split_name in ('train', 'val', 'test')
)

# Class names for the four-way classification task
CATEGORIES = [
    'hurricane',
    'earthquake',
    'wildfire',
    'not meaningful',
]

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Preprocessing pipeline shared by the train, validation and test splits
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Fixed 224x224 input size for the BEiT model
    transforms.ToTensor(),
    # Normalize each RGB channel with mean 0.5 and std 0.5
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# One ImageFolder dataset per split directory
train_dataset = datasets.ImageFolder(root=TRAIN_DIR, transform=transform)
val_dataset = datasets.ImageFolder(root=VAL_DIR, transform=transform)
test_dataset = datasets.ImageFolder(root=TEST_DIR, transform=transform)

# Batches of 16 images; only the training split is shuffled
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

# Pretrained BEiT backbone with a 4-class head. The checkpoint ships a
# 1000-class classifier, so ignore_mismatched_sizes is required to let the
# classifier weights be newly initialized with the 4-class shape.
model = BeitForImageClassification.from_pretrained(
    'microsoft/beit-base-patch16-224',
    ignore_mismatched_sizes=True,
    num_labels=4,
)
model = model.to(device)  # Place the weights on the GPU when one is available

# Optimizer over all model parameters, and the classification loss
optimizer = Adam(params=model.parameters(), lr=5e-5)
criterion = nn.CrossEntropyLoss()

# Persist a model to disk
def save_model(model, path):
    """Write ``model`` to ``path`` via ``save_pretrained`` and print a confirmation."""
    model.save_pretrained(path)
    confirmation = f"Model saved to {path}"
    print(confirmation)

# Function to load the model
def load_model(path):
    """Load a fine-tuned BEiT classifier previously written by ``save_pretrained``.

    Args:
        path: Location handed to ``BeitForImageClassification.from_pretrained``.

    Returns:
        The loaded model, moved to the module-level ``device``.
    """
    # ignore_mismatched_sizes is deliberately NOT passed here. With that flag a
    # checkpoint whose classifier shape differs from ``num_labels`` gets a
    # newly (randomly) initialized head and only a warning is emitted, which
    # would make a broken or wrong checkpoint look like a successful load.
    # Without it, a shape mismatch is expected to surface as an error instead
    # (see the transformers ``from_pretrained`` documentation).
    model = BeitForImageClassification.from_pretrained(
        path,
        num_labels=4,  # Must match the number of classes the checkpoint was trained with
    )
    model.to(device)
    print(f"Model loaded from {path}")
    return model

# Function to train the model
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=5,
                save_path='beit_custom_model'):
    """Fine-tune ``model`` and save it once all epochs have finished.

    Every epoch performs one pass over ``train_loader``, evaluates on
    ``val_loader`` through ``validate_model`` and prints the loss summed over
    all batches of the epoch, the training accuracy and the validation
    accuracy.

    Args:
        model: Model whose forward pass returns an object exposing ``.logits``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss callable applied to ``(logits, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        num_epochs: Number of passes over ``train_loader``.
        save_path: Directory handed to ``save_model`` after the last epoch.
    """
    for epoch in range(num_epochs):
        # validate_model() switches the model to eval mode and never switches
        # it back, so training mode has to be re-enabled at the start of every
        # epoch; enabling it only once before the loop would leave epochs
        # 2..N training in eval mode.
        model.train()

        running_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images).logits
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Statistics (running_loss is a sum over batches, not a mean)
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

        # Guard against an empty training loader instead of dividing by zero
        if total_predictions:
            train_accuracy = 100 * correct_predictions / total_predictions
        else:
            train_accuracy = 0.0
        val_accuracy = validate_model(model, val_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss:.4f}, '
              f'Train Accuracy: {train_accuracy:.2f}%, Val Accuracy: {val_accuracy:.2f}%')

    # Save the model after training
    save_model(model, save_path)

# Function to validate the model
def validate_model(model, val_loader):
    """Return the top-1 accuracy of ``model`` on ``val_loader`` as a percentage.

    The model is switched to evaluation mode and is left in that mode on
    return; callers that keep training afterwards must call ``model.train()``.

    Args:
        model: ``torch.nn.Module`` whose forward pass returns an object with a
            ``.logits`` attribute; the predicted class is the index of the
            maximum along dimension 1.
        val_loader: Iterable of ``(images, labels)`` tensor batches.

    Returns:
        Accuracy between 0 and 100, or ``0.0`` when ``val_loader`` yields no
        samples (rather than raising ``ZeroDivisionError``).
    """
    model.eval()

    # Send batches to wherever the model's weights actually live, so the
    # function does not depend on a module-level device setting matching the
    # model that was passed in.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    correct_predictions = 0
    total_predictions = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(target_device), labels.to(target_device)

            outputs = model(images).logits
            _, predicted = torch.max(outputs, 1)
            correct_predictions += (predicted == labels).sum().item()
            total_predictions += labels.size(0)

    if total_predictions == 0:
        return 0.0
    return 100 * correct_predictions / total_predictions

# Function to test the model
def test_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader``, print the accuracy and return it.

    The metric is the same top-1 accuracy that ``validate_model`` computes, so
    the evaluation loop is reused rather than duplicated here.

    Args:
        model: Model accepted by ``validate_model``.
        test_loader: Iterable of ``(images, labels)`` test batches.

    Returns:
        The test accuracy as a percentage, so callers can use the value
        instead of only seeing it printed.
    """
    test_accuracy = validate_model(model, test_loader)
    print(f'Test Accuracy: {test_accuracy:.2f}%')
    return test_accuracy

# Entry point: train, reload the saved checkpoint, evaluate on the test split
def main():
    """Run the full train -> reload -> test pipeline with the module-level objects."""
    epochs = 5

    # Fine-tune; train_model saves the result when it finishes
    train_model(model, train_loader, val_loader, criterion, optimizer, epochs)

    # Evaluate the checkpoint read back from disk, not the in-memory model
    reloaded = load_model('beit_custom_model')
    test_model(reloaded, test_loader)

if __name__ == "__main__":
    main()


Some weights of BeitForImageClassification were not initialized from the model checkpoint at microsoft/beit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([4, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch [1/5], Loss: 69.4819, Train Accuracy: 69.00%, Val Accuracy: 86.43%
Epoch [2/5], Loss: 22.0745, Train Accuracy: 92.07%, Val Accuracy: 94.75%
Epoch [3/5], Loss: 6.1060, Train Accuracy: 98.33%, Val Accuracy: 95.40%
Epoch [4/5], Loss: 2.0920, Train Accuracy: 99.49%, Val Accuracy: 96.28%
Epoch [5/5], Loss: 1.3074, Train Accuracy: 99.42%, Val Accuracy: 96.72%
Model saved to beit_custom_model
Model loaded from beit_custom_model
Test Accuracy: 96.85%
