In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from transformers import Dinov2Config, Dinov2ForImageClassification, AutoImageProcessor
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
import os

# Set device: prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's global RNG so the randomly initialized classifier head (and the
# DataLoader shuffling later on) is reproducible across kernel restarts.
torch.manual_seed(42)

# Number of target classes; change this to match your dataset.
num_classes = 10

# Load the smallest DINOv2 model with a classification head sized for
# `num_classes`. Building the head through the config (instead of replacing
# `model.classifier` after loading) keeps `model.config.num_labels` consistent
# with the head, so the model can be saved/reloaded with the standard
# transformers `save_pretrained` / `from_pretrained` round trip.
# The base checkpoint presumably ships no classifier weights, so transformers
# will likely warn that the head is newly initialized; that is expected.
model_name = "facebook/dinov2-small"
config = Dinov2Config.from_pretrained(model_name, num_labels=num_classes)
model = Dinov2ForImageClassification.from_pretrained(model_name, config=config)
image_processor = AutoImageProcessor.from_pretrained(model_name)

# Move model to device
model.to(device)

# Define transforms: resize to 224x224 and normalize with the statistics the
# DINOv2 image processor is configured with.
# NOTE: the train and validation subsets below come from one ImageFolder, so
# they share this transform (no train-only augmentation is applied).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
])

# Load your custom dataset (ImageFolder: one sub-directory per class).
data_dir = "/path/to/your/dataset"  # Change this to your dataset path
dataset = ImageFolder(data_dir, transform=transform)

# Fail fast if the classifier size and the dataset disagree. Otherwise
# CrossEntropyLoss errors mid-training (labels >= num_classes) or the model
# silently trains logits that no label ever targets.
if len(dataset.classes) != num_classes:
    raise ValueError(
        f"num_classes={num_classes} but {data_dir!r} contains "
        f"{len(dataset.classes)} class folders: {dataset.classes}"
    )

# Split dataset 80/20 into train and validation. A dedicated, seeded generator
# makes the split reproducible across kernel restarts.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Create data loaders (only the training loader shuffles).
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over `loader`.

    Returns:
        The per-sample mean training loss, or NaN if `loader` yields nothing.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images).logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # criterion averages over the batch; re-weight by batch size so a
        # smaller final batch does not skew the epoch mean.
        batch_n = labels.size(0)
        loss_sum += loss.item() * batch_n
        n_seen += batch_n
    return loss_sum / n_seen if n_seen else float("nan")


def _evaluate(model, loader, criterion, device):
    """Compute per-sample mean loss and accuracy (in %) on `loader`, without gradients.

    Returns:
        ``(mean_loss, accuracy_percent)``, or ``(nan, nan)`` when `loader` is
        empty (e.g. a tiny dataset produced a zero-length validation split),
        instead of dividing by zero.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images).logits
            batch_n = labels.size(0)
            loss_sum += criterion(logits, labels).item() * batch_n
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += batch_n
    if n_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / n_seen, 100.0 * n_correct / n_seen


# Define optimizer and loss function.
# The low learning rate protects the pretrained backbone. The classifier head,
# however, starts from random weights and may learn slowly at this rate.
optimizer = AdamW(model.parameters(), lr=5e-6)
criterion = CrossEntropyLoss()

# Training loop; per-epoch metrics are also collected in `history` for later plotting.
num_epochs = 10
history = []
for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
    history.append(
        {"epoch": epoch + 1, "train_loss": train_loss, "val_loss": val_loss, "val_acc": val_acc}
    )

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val Accuracy: {val_acc:.2f}%")
    print()

# Save the fine-tuned model as a state_dict: only the parameter/buffer tensors
# are written, not the architecture. To reload, construct the same model
# first and call `load_state_dict` on it.
torch.save(obj=model.state_dict(), f="fine_tuned_dinov2_small.pth")

In [None]:
# Show a layer-by-layer table of the model's modules and parameter counts.
# No input_size is passed, so torchinfo performs no forward pass and the
# table lists layers/parameters only (no output shapes).
from torchinfo import summary

summary(model=model)

In [None]:
# Install into the environment of the running kernel: %pip targets it
# reliably, whereas !pip may resolve to a different interpreter's pip.
# TODO: pin the version you tested with for reproducibility, e.g.
#   %pip install torchgeo==<tested-version>
%pip install torchgeo

In [None]:
import torch
from torchgeo.datasets import BigEarthNet
from torch.utils.data import DataLoader
from torchvision import transforms

# Set the root directory where you want to store the dataset
root = "path/to/your/data/directory"

# torchgeo's BigEarthNet accepts bands="s1", "s2" or "all"; there is no "rgb"
# option. With bands="s2" the channels are ordered B01, B02, B03, B04, ...,
# so red/green/blue (B04/B03/B02) sit at indices 3, 2, 1.
# NOTE: these indices assume bands="s2"; with "all" the S1 channels come first.
S2_RGB_INDICES = [3, 2, 1]
# Divisor bringing stored S2 values to roughly [0, 1]. This assumes the usual
# Sentinel-2 "reflectance x 10000" encoding — verify for your copy of the data.
S2_REFLECTANCE_SCALE = 10000.0

# ImageNet statistics, for feeding the RGB composite to ImageNet-style backbones.
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])


def transform(sample):
    """Turn a BigEarthNet sample's S2 bands into a normalized RGB image.

    torchgeo datasets call `transforms` with the whole sample dict
    (``{"image": Tensor, "label": Tensor}``) and expect a dict back. The
    image is already a tensor, so torchvision's
    ``Compose([ToTensor(), Normalize(...)])`` cannot be passed directly:
    ToTensor rejects a dict.

    Args:
        sample: dict whose ``"image"`` tensor is presumably (C, H, W) and holds
            the S2 bands in the order assumed by ``S2_RGB_INDICES``.

    Returns:
        A new dict with ``"image"`` replaced by a normalized 3-channel float
        tensor; other keys (e.g. ``"label"``) are passed through unchanged.
    """
    rgb = sample["image"][S2_RGB_INDICES].float()
    rgb = (rgb / S2_REFLECTANCE_SCALE).clamp(0.0, 1.0)
    return {**sample, "image": imagenet_normalize(rgb)}


# Create the BigEarthNet dataset.
# NOTE: this rebinds `dataset`, `transform` and `batch_size` from the DINOv2
# fine-tuning cell above.
# download=True fetches the archives if they are missing. They are very large
# (tens of GB), so check disk space first.
dataset = BigEarthNet(
    root=root,
    split="train",   # You can change this to "val" or "test" as needed
    bands="s2",      # "s1", "s2" or "all"; `transform` builds RGB from the S2 bands
    num_classes=19,  # You can set this to 43 for the full set of classes
    transforms=transform,
    download=True,
    checksum=True,   # Verify the integrity of downloaded files
)

# Create a DataLoader.
# NOTE: with num_workers > 0 under the "spawn" start method (Windows/macOS),
# a `transform` defined in a notebook cell may fail to pickle. In that case,
# move it into a .py module or set num_workers=0.
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Example of iterating through the dataset.
# BigEarthNet is multi-label: each label is a multi-hot vector of length
# num_classes, so train with BCEWithLogitsLoss rather than CrossEntropyLoss.
for batch in dataloader:
    images = batch['image']
    labels = batch['label']

    # Your training or processing code here
    print(f"Batch image shape: {images.shape}")
    print(f"Batch label shape: {labels.shape}")
    break  # Remove this to process all batches

print(f"Dataset size: {len(dataset)}")