## Car Damage Severity Classification

This notebook trains a car damage severity classifier on top of the DINOv3 ConvNeXt-Small backbone, sorting images into three classes: minor, moderate, and severe. The backbone is kept frozen and only a small classification head is trained. The dataset was created by Prajwal Bhamere and is available at https://www.kaggle.com/datasets/prajwalbhamere/car-damage-severity-dataset/data.

### Dependencies

In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from torchvision import transforms
from torchvision.datasets import ImageFolder
from transformers import AutoImageProcessor, AutoModel
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import numpy as np
import os


In [None]:
pretrained_model_name = "facebook/dinov3-convnext-small-pretrain-lvd1689m"

# The image processor is loaded for reference only; preprocessing below uses torchvision transforms
processor = AutoImageProcessor.from_pretrained(pretrained_model_name)
base_model = AutoModel.from_pretrained(pretrained_model_name)

## Wrapper Module
Let's write a wrapper module around the pretrained backbone. Its `forward` method passes images through the backbone, feeds the pooled features (`pooler_output`) into a small MLP head, and returns just the class logits.

In [7]:
class CarDamageClassifier(nn.Module):
    def __init__(self, backbone: nn.Module, feature_dim: int = 768, hidden_dim: int = 256, dropout: float = 0.3):
        """
        Creates a torch.nn.Module for our car damage classifier.

        Parameters:
            backbone: nn.Module - Pretrained feature extractor whose output has a `pooler_output` attribute
            feature_dim: int = 768 - Dimension of the pooled backbone features (768 for DINOv3 ConvNeXt-Small)
            hidden_dim: int = 256 - Hidden layer dimension of the classification head
            dropout: float = 0.3 - Dropout rate
        """
        super().__init__()
        self.backbone = backbone
        self.head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 3),
        )

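    def forward(self, images):
        # Pooled image embedding from the backbone: (batch, feature_dim)
        features = self.backbone(images).pooler_output
        # Raw class scores, no softmax (CrossEntropyLoss expects logits): (batch, 3)
        logits = self.head(features)
        return logits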
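To check the wrapper's tensor shapes without downloading the DINOv3 weights, here is a sketch that swaps in a hypothetical `StubBackbone`. The only thing the wrapper relies on is that the backbone's output exposes a `pooler_output` tensor of shape `(batch, feature_dim)`; the stub fakes that with `types.SimpleNamespace`. The head is copied from the cell above.

```python
from types import SimpleNamespace

import torch
from torch import nn


class StubBackbone(nn.Module):
    """Hypothetical stand-in for the DINOv3 backbone: returns a 768-dim pooled feature per image."""

    def __init__(self, feature_dim: int = 768):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(3, feature_dim)

    def forward(self, images):
        pooled = self.pool(images).flatten(1)  # (B, 3)
        # Mimic the Hugging Face output object, which carries a `pooler_output` attribute
        return SimpleNamespace(pooler_output=self.proj(pooled))


class CarDamageClassifier(nn.Module):
    # Same structure as the notebook's wrapper: backbone -> pooled features -> MLP head -> logits
    def __init__(self, backbone, feature_dim=768, hidden_dim=256, dropout=0.3):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 3),
        )

    def forward(self, images):
        return self.head(self.backbone(images).pooler_output)


model = CarDamageClassifier(StubBackbone())
logits = model(torch.randn(4, 3, 256, 256))
print(logits.shape)  # torch.Size([4, 3])
```

Any backbone that returns an object with a correctly sized `pooler_output` can be dropped in the same way.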

In [8]:
model = CarDamageClassifier(base_model, feature_dim=768, dropout=0.5)

# Freeze the backbone: only the classification head is trained
for p in model.backbone.parameters():
    p.requires_grad = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

CarDamageClassifier(
  (backbone): DINOv3ConvNextModel(
    (stages): ModuleList(
      (0): DINOv3ConvNextStage(
        (downsample_layers): ModuleList(
          (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
          (1): DINOv3ConvNextLayerNorm((96,), eps=1e-06, elementwise_affine=True)
        )
        (layers): ModuleList(
          (0-2): 3 x DINOv3ConvNextLayer(
            (depthwise_conv): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
            (layer_norm): DINOv3ConvNextLayerNorm((96,), eps=1e-06, elementwise_affine=True)
            (pointwise_conv1): Linear(in_features=96, out_features=384, bias=True)
            (activation_fn): GELUActivation()
            (pointwise_conv2): Linear(in_features=384, out_features=96, bias=True)
            (drop_path): Identity()
          )
        )
      )
      (1): DINOv3ConvNextStage(
        (downsample_layers): ModuleList(
          (0): DINOv3ConvNextLayerNorm((96,), eps=1e-06, elementwi
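Freezing is easy to verify by counting trainable parameters. A minimal sketch on a hypothetical `TinyModel` with the same `backbone`/`head` split (`count_params` is also a hypothetical helper); `model.backbone.requires_grad_(False)` is an equivalent one-liner to the freezing loop. On the real model with the default `hidden_dim=256`, the trainable count should be the head alone: (768 × 256 + 256) + (256 × 3 + 3) = 197,635.

```python
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in with the same backbone/head split as CarDamageClassifier."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3)  # 8*3*3*3 weights + 8 biases = 224 params
        self.head = nn.Linear(8, 3)                     # 8*3 weights + 3 biases = 27 params


def count_params(module: nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts."""
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return trainable, total


model = TinyModel()
for p in model.backbone.parameters():
    p.requires_grad = False  # same freezing step as the notebook

trainable, total = count_params(model)
print(trainable, total)  # 27 251
```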

## Data Loading
Let's load the data and apply some augmentations.

### Augmentations

First, let's decide which augmentations we actually need.

For both training and validation, we'll resize to 256x256.

For training, we’ll use a more aggressive set of augmentations:

* **Random Horizontal Flips**
* **Color Jitter** (brightness, contrast, saturation, hue) so the model doesn’t overfit to specific lighting conditions.
* **Mild Gaussian blur or small rotations** to simulate motion blur or slight camera tilt. These are not yet enabled in the `train_transform` cell below, which only resizes, flips, and color-jitters before normalizing.
* **Normalization using the ImageNet mean and standard deviation**, matching what was used to pretrain the DINOv3 ConvNeXt backbone.

For validation, there are no random augmentations: images are only resized and normalized.
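To make the normalization step concrete: `ToTensor` scales 8-bit pixel values from 0–255 to 0–1, and `Normalize` then computes `(x - mean) / std` per channel. The sketch below (plain NumPy, with a hypothetical `normalize_pixel` helper) reproduces that arithmetic for a single RGB pixel.

```python
import numpy as np

# ImageNet statistics used in the notebook's Normalize transform
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def normalize_pixel(rgb_uint8):
    """Mimic ToTensor (scale 0-255 to 0-1) followed by Normalize ((x - mean) / std) for one RGB pixel."""
    x = np.asarray(rgb_uint8, dtype=np.float64) / 255.0
    return (x - mean) / std


white = normalize_pixel([255, 255, 255])
black = normalize_pixel([0, 0, 0])
print(white.round(4), black.round(4))
```

A pure white pixel maps to roughly (2.25, 2.43, 2.64) and a pure black one to roughly (-2.12, -2.04, -1.80), so the normalized inputs sit in a range of about [-2.1, 2.6].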


In [9]:
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.05
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

### Loading the Datasets
Now we can use torchvision's `ImageFolder` to load the data. It infers each image's label from the name of the subdirectory it sits in and applies the given transform every time an image is loaded.

In [10]:
train_dataset = ImageFolder("../dataset/train", transform=train_transform)
val_dataset   = ImageFolder("../dataset/val",   transform=val_transform)

### Training Configuration


Now we can shift to training the model. First, we set the training hyperparameters (epochs, batch size, learning rate), the optimizer, and the loss function. Only the head's parameters are passed to `AdamW`, since the backbone is frozen.

Once the batch size is defined, we can also create the dataloaders.

In [11]:
epochs = 100
batch_size = 32
learning_rate = 1e-3
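# Only the head is optimized; the frozen backbone's parameters are not passed to AdamW
optimizer = AdamW(params=model.head.parameters(), lr=learning_rate, weight_decay=0.01)
loss = CrossEntropyLoss()  # expects raw logits and integer class labels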
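`CrossEntropyLoss` applies log-softmax internally, which is why the head returns raw logits with no softmax layer. For a single sample the loss is `-log(softmax(logits)[label])`. A worked sketch in NumPy, with a hypothetical `cross_entropy` helper and made-up logits for the three classes:

```python
import numpy as np


def cross_entropy(logits, label):
    """Cross-entropy for one sample: -log softmax(logits)[label], computed stably via log-sum-exp."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max()  # subtract the max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[label]


# Hypothetical logits for (minor, moderate, severe); the true class is index 0
print(round(cross_entropy([2.0, 1.0, 0.1], 0), 4))  # 0.417
```

With the default `reduction="mean"`, PyTorch averages this per-sample value over the batch. Uniform logits over three classes give `ln 3 ≈ 1.0986`, a useful reference point for the epoch-1 losses in the log below.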

In [12]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=4)
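`len(loader)` is the number of batches per epoch, `ceil(N / batch_size)` with the default `drop_last=False`, which is where the `44/44` and `8/8` totals in the training progress bars come from. A minimal sketch with a hypothetical 100-sample `TensorDataset`:

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of 100 samples; the real split sizes come from len(train_dataset)
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))
loader = DataLoader(dataset, batch_size=32, shuffle=False)

batch_sizes = [batch[0].shape[0] for batch in loader]
print(len(loader), batch_sizes)  # 4 [32, 32, 32, 4]
assert len(loader) == math.ceil(len(dataset) / 32)
```

Separately, with `num_workers=4` each loader starts fresh worker processes at the beginning of every epoch. Passing `persistent_workers=True` to `DataLoader` keeps them alive between epochs, which may cut per-epoch startup time, especially on platforms that use the `spawn` start method (macOS and Windows).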

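Now we can train the model! Each epoch makes one pass over the training set, then evaluates on the validation set and saves a checkpoint whenever the validation loss improves.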
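The loop multiplies each batch's mean loss by its batch size before summing, then divides by the total sample count. That yields a true per-sample average even when the last batch is smaller. A small illustration with made-up batch losses:

```python
# Hypothetical per-batch mean losses and batch sizes (the last batch is smaller)
batch_losses = [0.9, 0.7, 0.5]
batch_sizes = [32, 32, 4]

# Sample-weighted average, as in the training loop: sum(loss * n) / sum(n)
weighted = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

# Naive average of batch means gives the 4-sample batch the same weight as a 32-sample one
naive = sum(batch_losses) / len(batch_losses)

print(round(weighted, 4), round(naive, 4))  # 0.7824 0.7
```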

In [13]:
best_val_loss = np.inf

for epoch in range(epochs):
    train_loss = 0.0
    correct = 0
    total = 0

    model.train()  # ensure training mode each epoch

    # ----- TRAINING -----
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # forward pass → logits
        logits = model(images)

        # compute loss
        computed_loss = loss(logits, labels)

        # backward + step
        computed_loss.backward()
        optimizer.step()

        # stats
        train_loss += computed_loss.item() * images.size(0)
        _, preds = torch.max(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_train_loss = train_loss / total
    train_acc = correct / total

    # ----- VALIDATION -----
    model.eval()  # Enter evaluation mode
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f"Validation Epoch {epoch+1}/{epochs}"):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            computed_loss = loss(logits, labels)

            val_loss += computed_loss.item() * images.size(0)
            _, preds = torch.max(logits, dim=1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    avg_val_loss = val_loss / val_total
    val_acc = val_correct / val_total

    save_dir = f"checkpoints/e{epochs}_b{batch_size}_lr{learning_rate}"

    # Save best model based on validation loss
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        
        os.makedirs(save_dir, exist_ok=True)
        
        torch.save(model.state_dict(), f"{save_dir}/best_model.pt")
        print("New best val loss: model saved!")

    print(
        f"Epoch [{epoch+1}/{epochs}] "
        f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} "
        f"| Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

Epoch 1/100: 100%|██████████| 44/44 [03:29<00:00,  4.76s/it]
Validation Epoch 1/100: 100%|██████████| 8/8 [00:55<00:00,  6.95s/it]


New best val loss: model saved!
Epoch [1/100] Train Loss: 0.9851 | Train Acc: 0.5358 | Val Loss: 0.6781 | Val Acc: 0.7097


Epoch 2/100: 100%|██████████| 44/44 [03:34<00:00,  4.88s/it]
Validation Epoch 2/100: 100%|██████████| 8/8 [00:57<00:00,  7.15s/it]


New best val loss: model saved!
Epoch [2/100] Train Loss: 0.7109 | Train Acc: 0.6833 | Val Loss: 0.6441 | Val Acc: 0.7339


Epoch 3/100: 100%|██████████| 44/44 [03:26<00:00,  4.69s/it]
Validation Epoch 3/100: 100%|██████████| 8/8 [00:55<00:00,  6.99s/it]


Epoch [3/100] Train Loss: 0.6303 | Train Acc: 0.7245 | Val Loss: 0.6474 | Val Acc: 0.6855


Epoch 4/100: 100%|██████████| 44/44 [03:28<00:00,  4.73s/it]
Validation Epoch 4/100: 100%|██████████| 8/8 [00:55<00:00,  6.98s/it]


Epoch [4/100] Train Loss: 0.5602 | Train Acc: 0.7636 | Val Loss: 0.6625 | Val Acc: 0.6774


Epoch 5/100: 100%|██████████| 44/44 [03:27<00:00,  4.73s/it]
Validation Epoch 5/100: 100%|██████████| 8/8 [00:56<00:00,  7.07s/it]


Epoch [5/100] Train Loss: 0.5240 | Train Acc: 0.7643 | Val Loss: 0.6522 | Val Acc: 0.7540


Epoch 6/100: 100%|██████████| 44/44 [03:25<00:00,  4.67s/it]
Validation Epoch 6/100: 100%|██████████| 8/8 [00:55<00:00,  6.97s/it]


Epoch [6/100] Train Loss: 0.4656 | Train Acc: 0.8040 | Val Loss: 0.6736 | Val Acc: 0.7298


Epoch 7/100: 100%|██████████| 44/44 [03:25<00:00,  4.67s/it]
Validation Epoch 7/100: 100%|██████████| 8/8 [00:57<00:00,  7.13s/it]


Epoch [7/100] Train Loss: 0.4293 | Train Acc: 0.8301 | Val Loss: 0.6773 | Val Acc: 0.7540


Epoch 8/100: 100%|██████████| 44/44 [03:26<00:00,  4.69s/it]
Validation Epoch 8/100: 100%|██████████| 8/8 [00:55<00:00,  6.97s/it]


Epoch [8/100] Train Loss: 0.3993 | Train Acc: 0.8409 | Val Loss: 0.6959 | Val Acc: 0.7339


Epoch 9/100: 100%|██████████| 44/44 [03:25<00:00,  4.68s/it]
Validation Epoch 9/100: 100%|██████████| 8/8 [00:55<00:00,  6.92s/it]


Epoch [9/100] Train Loss: 0.3621 | Train Acc: 0.8648 | Val Loss: 0.7227 | Val Acc: 0.7137


Epoch 10/100: 100%|██████████| 44/44 [03:26<00:00,  4.69s/it]
Validation Epoch 10/100: 100%|██████████| 8/8 [00:55<00:00,  6.98s/it]


Epoch [10/100] Train Loss: 0.3468 | Train Acc: 0.8655 | Val Loss: 0.7060 | Val Acc: 0.7177


Epoch 11/100: 100%|██████████| 44/44 [03:24<00:00,  4.65s/it]
Validation Epoch 11/100: 100%|██████████| 8/8 [00:55<00:00,  6.95s/it]


Epoch [11/100] Train Loss: 0.3397 | Train Acc: 0.8764 | Val Loss: 0.7170 | Val Acc: 0.7177


Epoch 12/100: 100%|██████████| 44/44 [03:24<00:00,  4.64s/it]
Validation Epoch 12/100: 100%|██████████| 8/8 [00:56<00:00,  7.02s/it]


Epoch [12/100] Train Loss: 0.3226 | Train Acc: 0.8713 | Val Loss: 0.7480 | Val Acc: 0.7500


Epoch 13/100: 100%|██████████| 44/44 [03:24<00:00,  4.64s/it]
Validation Epoch 13/100: 100%|██████████| 8/8 [00:55<00:00,  6.98s/it]


Epoch [13/100] Train Loss: 0.2867 | Train Acc: 0.8800 | Val Loss: 0.7935 | Val Acc: 0.6855


Epoch 14/100: 100%|██████████| 44/44 [03:23<00:00,  4.63s/it]
Validation Epoch 14/100: 100%|██████████| 8/8 [00:55<00:00,  6.96s/it]


Epoch [14/100] Train Loss: 0.2696 | Train Acc: 0.8959 | Val Loss: 0.7997 | Val Acc: 0.7258


Epoch 15/100: 100%|██████████| 44/44 [03:31<00:00,  4.80s/it]
Validation Epoch 15/100: 100%|██████████| 8/8 [00:57<00:00,  7.21s/it]


Epoch [15/100] Train Loss: 0.2538 | Train Acc: 0.8973 | Val Loss: 0.8078 | Val Acc: 0.7218


Epoch 16/100: 100%|██████████| 44/44 [03:37<00:00,  4.93s/it]
Validation Epoch 16/100: 100%|██████████| 8/8 [00:57<00:00,  7.17s/it]


Epoch [16/100] Train Loss: 0.2252 | Train Acc: 0.9205 | Val Loss: 0.8509 | Val Acc: 0.7177


Epoch 17/100: 100%|██████████| 44/44 [03:34<00:00,  4.86s/it]
Validation Epoch 17/100: 100%|██████████| 8/8 [00:58<00:00,  7.26s/it]


Epoch [17/100] Train Loss: 0.2332 | Train Acc: 0.9132 | Val Loss: 0.8215 | Val Acc: 0.7097


Epoch 18/100: 100%|██████████| 44/44 [03:32<00:00,  4.83s/it]
Validation Epoch 18/100: 100%|██████████| 8/8 [00:55<00:00,  6.91s/it]


Epoch [18/100] Train Loss: 0.2245 | Train Acc: 0.9183 | Val Loss: 0.7935 | Val Acc: 0.7137


Epoch 19/100: 100%|██████████| 44/44 [03:29<00:00,  4.77s/it]
Validation Epoch 19/100: 100%|██████████| 8/8 [00:57<00:00,  7.15s/it]


Epoch [19/100] Train Loss: 0.2065 | Train Acc: 0.9262 | Val Loss: 0.8683 | Val Acc: 0.7258


Epoch 20/100: 100%|██████████| 44/44 [04:26<00:00,  6.06s/it]
Validation Epoch 20/100: 100%|██████████| 8/8 [01:09<00:00,  8.64s/it]


Epoch [20/100] Train Loss: 0.1831 | Train Acc: 0.9277 | Val Loss: 0.8663 | Val Acc: 0.7258


Epoch 21/100: 100%|██████████| 44/44 [05:04<00:00,  6.91s/it]
Validation Epoch 21/100: 100%|██████████| 8/8 [01:08<00:00,  8.53s/it]


Epoch [21/100] Train Loss: 0.1535 | Train Acc: 0.9501 | Val Loss: 0.8535 | Val Acc: 0.7258


Epoch 22/100: 100%|██████████| 44/44 [05:50<00:00,  7.97s/it]
Validation Epoch 22/100: 100%|██████████| 8/8 [00:55<00:00,  6.88s/it]


Epoch [22/100] Train Loss: 0.1714 | Train Acc: 0.9328 | Val Loss: 0.9309 | Val Acc: 0.6976


Epoch 23/100: 100%|██████████| 44/44 [03:23<00:00,  4.63s/it]
Validation Epoch 23/100: 100%|██████████| 8/8 [00:55<00:00,  6.94s/it]


Epoch [23/100] Train Loss: 0.1596 | Train Acc: 0.9537 | Val Loss: 0.9501 | Val Acc: 0.7258


Epoch 24/100: 100%|██████████| 44/44 [03:24<00:00,  4.64s/it]
Validation Epoch 24/100: 100%|██████████| 8/8 [00:54<00:00,  6.87s/it]


Epoch [24/100] Train Loss: 0.1793 | Train Acc: 0.9349 | Val Loss: 0.8882 | Val Acc: 0.7218


Epoch 25/100: 100%|██████████| 44/44 [03:25<00:00,  4.67s/it]
Validation Epoch 25/100: 100%|██████████| 8/8 [00:55<00:00,  6.91s/it]


Epoch [25/100] Train Loss: 0.1570 | Train Acc: 0.9429 | Val Loss: 0.9524 | Val Acc: 0.7137


Epoch 26/100: 100%|██████████| 44/44 [03:23<00:00,  4.63s/it]
Validation Epoch 26/100: 100%|██████████| 8/8 [00:55<00:00,  6.94s/it]


Epoch [26/100] Train Loss: 0.1369 | Train Acc: 0.9602 | Val Loss: 0.9612 | Val Acc: 0.7097


Epoch 27/100: 100%|██████████| 44/44 [03:29<00:00,  4.76s/it]
Validation Epoch 27/100: 100%|██████████| 8/8 [00:57<00:00,  7.13s/it]


Epoch [27/100] Train Loss: 0.1459 | Train Acc: 0.9537 | Val Loss: 1.0303 | Val Acc: 0.6895


Epoch 28/100:   0%|          | 0/44 [00:00<?, ?it/s]

KeyboardInterrupt: 
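In the log above, validation loss is lowest at epoch 2 (0.6441) and drifts upward afterwards, reaching 1.0303 by epoch 27, while training loss keeps falling. This is a typical sign that the head is overfitting, and the run was eventually interrupted by hand. One common remedy is early stopping. Here is a minimal sketch with a hypothetical `EarlyStopping` helper and an arbitrary `patience=5`, replaying the first ten validation losses from the log:

```python
class EarlyStopping:
    """Hypothetical helper: signals a stop after `patience` epochs without val-loss improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's val loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Validation losses for epochs 1-10, copied from the log above
val_losses = [0.6781, 0.6441, 0.6474, 0.6625, 0.6522, 0.6736, 0.6773, 0.6959, 0.7227, 0.7060]

stopper = EarlyStopping(patience=5)
for epoch, vl in enumerate(val_losses, start=1):
    if stopper.step(vl):
        print(f"Stopping after epoch {epoch}; best val loss {stopper.best}")
        break
```

With these values the helper would stop after epoch 7. In the notebook's loop, `stopper` would be created before the loop and `if stopper.step(avg_val_loss): break` placed after the checkpointing block, so the best checkpoint is still the one saved at the lowest validation loss.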