In [None]:
from transformers import Dinov2ForImageClassification
import torch.nn as nn

# Define new number of labels
num_labels = dataset.get_num_labels()

# Load Pretrained DINOv2
model = Dinov2ForImageClassification.from_pretrained(
    "facebook/dinov2-small-imagenet1k-1-layer",
    num_labels=num_labels,
    ignore_mismatched_sizes=True  # Ignore mismatched classifier head
)

# Freeze all layers except classifier
for param in model.parameters():
    param.requires_grad = False  # Freeze backbone

# Unfreeze only the classifier head
for param in model.classifier.parameters():
    param.requires_grad = True

# Replace classifier with sigmoid for multi-label classification
model.classifier = nn.Sequential(
    nn.Linear(model.classifier.in_features, num_labels),
    nn.Sigmoid()  # Sigmoid for multi-label classification
)

# Print model structure to confirm changes
print(model)


  from .autonotebook import tqdm as notebook_tqdm
Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-small-imagenet1k-1-layer and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([500]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([500, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Dinov2ForImageClassification(
  (dinov2): Dinov2Model(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Dinov2Encoder(
      (layer): ModuleList(
        (0-11): 12 x Dinov2Layer(
          (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
          (attention): Dinov2SdpaAttention(
            (attention): Dinov2SdpaSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): Dinov2SelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)

In [2]:
# Define Binary Classification Head (MLP)
class BinaryClassifier(nn.Module):
    def __init__(self, input_dim):
        super(BinaryClassifier, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()  # Binary output
        )

    def forward(self, x):
        return self.fc(x)

# Initialize the binary classifier
binary_model = BinaryClassifier(input_dim=num_labels)


In [10]:
from torch.utils.data import Dataset
import torch

class HDV11K(Dataset):
    def __init__(self):
        pass
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.rand(3, 224, 224), torch.randint(0, 2, (1,))

In [11]:
train_dataset = HDV11K()

In [12]:
import torch.optim as optim
from torch.utils.data import DataLoader

# Define Loss & Optimizer
multi_label_loss_fn = nn.BCEWithLogitsLoss()  # Multi-label loss
binary_loss_fn = nn.BCELoss()  # Binary classification loss
optimizer = optim.Adam(list(model.parameters()) + list(binary_model.parameters()), lr=0.001)

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Training Loop
num_epochs = 5

for epoch in range(num_epochs):
    for batch in train_loader:
        images = batch["pixel_values"]
        multi_labels = batch["multi_labels"]
        binary_labels = batch["binary_label"].unsqueeze(1)  # Make binary labels shape (B, 1)

        # Multi-Label Predictions
        multi_label_preds = model(images).logits  # Get multi-label outputs

        # Binary Classification Predictions
        binary_preds = binary_model(multi_label_preds)

        # Compute Loss
        multi_label_loss = multi_label_loss_fn(multi_label_preds, multi_labels)
        binary_loss = binary_loss_fn(binary_preds, binary_labels)
        loss = multi_label_loss + binary_loss  # Combine losses

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}")

# Save trained models
model.save_pretrained("./dinov2-multilabel")
# torch.save(binary_model.state_dict(), "binary_classifier.pth")


AttributeError: 'HDV11K' object has no attribute 'data'