- https://github.com/facebookresearch/dinov2
- https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DINOv2/Train_a_linear_classifier_on_top_of_DINOv2_for_semantic_segmentation.ipynb#scrollTo=rLzR_mt_SnE2
- https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DINOv2/Fine_tune_DINOv2_for_image_classification_%5Bminimal%5D.ipynb
- 

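- DINOv2 uses a 14x14 patch size, so a $W \times W$ input yields $\lfloor W/14 \rfloor^2$ patch tokens plus 1 CLS token in the hidden state (e.g. $W = 256$ gives $18 \times 18 = 324$ patch tokens + 1 CLS token).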

In [None]:
import requests
from PIL import Image
import os
import torch
from torch.utils.data import Dataset

from PIL import Image
import os
import albumentations as A
import cv2 
import numpy as np 
import matplotlib.pyplot as plt 
from gis.config import Config

# Project-wide settings; `config.mnt_path` is the root under which the
# image and label tile folders are looked up below.
config = Config()

In [None]:
from sklearn.model_selection import train_test_split
import albumentations as A
from torch.utils.data import DataLoader


def get_image_and_mask_files(mask_dir=None):
    """Return parallel lists of image and mask file names for zoom level 18.

    Mask files are named ``{x}_{y}.npy``; the image paired with each mask is
    ``18_{x}_{y}.jpg``. Only ``.npy`` files are considered, and the result is
    sorted by mask file name so the order (and therefore any seeded
    train/val split built on it) is reproducible across machines.

    Args:
        mask_dir: Directory holding the mask files. Defaults to
            ``config.mnt_path / 'label/18'``.

    Returns:
        ``(image_files, mask_files)`` where ``image_files[i]`` pairs with
        ``mask_files[i]``.
    """
    if mask_dir is None:
        mask_dir = config.mnt_path / 'label/18'
    # os.listdir order is arbitrary; sort so a seeded split is stable, and skip
    # stray non-mask files that would otherwise break the name parsing below.
    mask_files = sorted(name for name in os.listdir(mask_dir) if name.endswith('.npy'))
    image_files = []
    for mask in mask_files:
        x, y = os.path.splitext(mask)[0].split('_')
        # int() round-trip normalizes the coordinates exactly as before.
        image_files.append(f'18_{int(x)}_{int(y)}.jpg')
    return image_files, mask_files

class SegmentationDataset(Dataset):
    """Pairs tile images with their ``.npy`` masks for segmentation.

    Each item is ``(image, target, original_image, original_mask)``:
    ``image`` is the (optionally transformed) image permuted to channels-first,
    ``target`` is the (optionally transformed) mask as an int64 tensor, and the
    last two are the untransformed RGB image array and mask array.

    Args:
        images: Image file names, relative to ``image_dir``.
        masks: Mask file names, relative to ``mask_dir``; ``masks[i]`` pairs
            with ``images[i]``.
        transform: Optional albumentations-style callable taking
            ``image=`` / ``mask=`` keywords and returning a dict with
            ``'image'`` and ``'mask'``. When ``None`` the arrays are used as-is.
        image_dir: Image folder; defaults to ``config.mnt_path / 'image/18'``.
        mask_dir: Mask folder; defaults to ``config.mnt_path / 'label/18'``.

    Raises:
        ValueError: if ``images`` and ``masks`` differ in length.
    """

    def __init__(self, images, masks, transform=None, image_dir=None, mask_dir=None):
        if len(images) != len(masks):
            raise ValueError(
                f"got {len(images)} images but {len(masks)} masks; they must pair up"
            )
        self.images = images
        self.masks = masks
        self.transform = transform
        self.image_dir = config.mnt_path / 'image/18' if image_dir is None else image_dir
        self.mask_dir = config.mnt_path / 'label/18' if mask_dir is None else mask_dir

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])
        # Close the file handle once the pixels are copied into numpy.
        with Image.open(img_path) as img:
            original_image = np.array(img.convert("RGB"))
        original_mask = np.load(mask_path)

        if self.transform is not None:
            transformed = self.transform(image=original_image, mask=original_mask)
            image, mask = transformed['image'], transformed['mask']
        else:
            image, mask = original_image, original_mask

        # HWC -> CHW for the model; masks become integer class targets.
        image = torch.tensor(image).permute(2, 0, 1)
        target = torch.tensor(mask, dtype=torch.long)
        return image, target, original_image, original_mask


model_config = {
    'batch_size': 4,
    'epochs': 5,
    'learning_rate': 1e-4,
    'val_split': 0.2,
    'num_workers': 4,
}


image_files, mask_files = get_image_and_mask_files()

# Hold out `val_split` of the tiles for validation; seeded for reproducibility.
train_images, val_images, train_masks, val_masks = train_test_split(
    image_files, mask_files, test_size=model_config['val_split'], random_state=42
)


# ImageNet channel statistics (mean 0.485/0.456/0.406, std 0.229/0.224/0.225)
# written in 0-255 units and rescaled to 0-1, as A.Normalize expects.
ADE_MEAN = np.array([123.675, 116.280, 103.530]) / 255
ADE_STD = np.array([58.395, 57.120, 57.375]) / 255
# Side length every tile is resized to; 256 // 14 == 18 DINOv2 patches per side.
width = 256

train_transform = A.Compose([
    A.Resize(width=width, height=width),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
])

val_transform = A.Compose([
    A.Resize(width=width, height=width),
    A.Normalize(mean=ADE_MEAN, std=ADE_STD),
])

train_dataset = SegmentationDataset(train_images, train_masks, transform=train_transform)
val_dataset = SegmentationDataset(val_images, val_masks, transform=val_transform)

# Reshuffle the training tiles every epoch; validation order does not matter.
# NOTE(review): model_config['num_workers'] is not passed to the loaders —
# worker processes can fail in notebooks on spawn-based platforms, so confirm
# before enabling it.
train_loader = DataLoader(train_dataset, batch_size=model_config['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=model_config['batch_size'])




In [None]:
# Class names stored in the model config. The segmentation head predicts a
# single logit per pixel, so "track" (label 1) is the positive class.
id2label = {
    0: "background",
    1: "track",
}

PATCH_SIZE = 14  # DINOv2 patch side length in pixels
# Patch-token grid side for a `width` x `width` input: floor(width / 14) (18 for 256).
TOKEN_WIDTH = width // PATCH_SIZE

import torch
from transformers import Dinov2Model, Dinov2PreTrainedModel
from transformers.modeling_outputs import SemanticSegmenterOutput

class LinearClassifier(torch.nn.Module):
    """Segmentation head over a grid of patch embeddings.

    Reshapes a flat sequence of patch tokens back into a ``tokenH x tokenW``
    grid, then applies a 3x3 conv (``in_channels -> 128``) followed by a 1x1
    conv (``128 -> num_labels``). There is no activation between the two
    convs, so the head stays linear overall.

    Args:
        in_channels: Embedding size of each patch token.
        tokenW: Width of the patch grid.
        tokenH: Height of the patch grid.
        num_labels: Number of output channels (1 for a single binary logit).
    """

    def __init__(self, in_channels, tokenW=TOKEN_WIDTH, tokenH=TOKEN_WIDTH, num_labels=1):
        super().__init__()

        self.in_channels = in_channels
        self.width = tokenW
        self.height = tokenH
        # padding=1 keeps the output grid the same size as the token grid.
        # Without it the map loses one patch on every side and is later
        # stretched over the whole image, misaligning predictions with pixels.
        self.mixer = torch.nn.Conv2d(in_channels, 128, (3, 3), padding=1)
        self.classifier = torch.nn.Conv2d(128, num_labels, (1, 1))

    def forward(self, embeddings):
        """Return logits of shape ``(batch, num_labels, tokenH, tokenW)``.

        ``embeddings`` must hold ``tokenH * tokenW`` tokens of size
        ``in_channels`` per sample (e.g. ``(batch, tokenH * tokenW, in_channels)``).
        """
        grid = embeddings.reshape(-1, self.height, self.width, self.in_channels)
        grid = grid.permute(0, 3, 1, 2)  # channels-last -> channels-first for Conv2d

        return self.classifier(self.mixer(grid))

class Dinov2ForSemanticSegmentation(Dinov2PreTrainedModel):
    """DINOv2 backbone topped with a patch-level segmentation head.

    The forward pass runs the backbone, discards the first (CLS) token, turns
    the remaining patch tokens into logits with ``LinearClassifier`` and
    bilinearly upsamples them to the spatial size of the input. Freezing the
    backbone is left to the caller.
    """

    def __init__(self, config):
        super().__init__(config)

        self.dinov2 = Dinov2Model(config)
        self.classifier = LinearClassifier(
            config.hidden_size, tokenW=TOKEN_WIDTH, tokenH=TOKEN_WIDTH, num_labels=1
        )

    def forward(self, pixel_values, output_hidden_states=False, output_attentions=False, labels=None):
        """Return per-pixel logits with the same height/width as ``pixel_values``.

        ``labels`` is accepted but ignored; no loss is computed here.
        """
        backbone_out = self.dinov2(
            pixel_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )
        # Position 0 is the CLS token; everything after it is a patch token.
        patch_tokens = backbone_out.last_hidden_state[:, 1:, :]
        patch_logits = self.classifier(patch_tokens)
        return torch.nn.functional.interpolate(
            patch_logits,
            size=pixel_values.shape[2:],
            mode="bilinear",
            align_corners=False,
        )

# Build the model from the facebook/dinov2-base checkpoint.
# NOTE(review): num_labels=2 only sets the HF config; the head itself is built
# with a single output channel and does not read config.num_labels.
model = Dinov2ForSemanticSegmentation.from_pretrained(
    "facebook/dinov2-base", id2label=id2label, num_labels=len(id2label)
)
# Freeze the backbone so only the segmentation head's parameters are trained.
for backbone_param in model.dinov2.parameters():
    backbone_param.requires_grad = False

In [None]:
from tqdm import tqdm
import segmentation_models_pytorch as smp

def calculate_metrics(pred, target, threshold=0.5):
    """Compute IoU, Dice and pixel accuracy for a binary prediction.

    Args:
        pred: Scores compared with ``> threshold`` to obtain the binary mask
            (use 0.5 for probabilities, 0 for logits).
        target: Ground-truth mask with values 0/1, same shape as ``pred``.
        threshold: Decision threshold applied to ``pred``.

    Returns:
        Dict with float ``'iou'``, ``'dice'`` and ``'accuracy'``. When both
        the prediction and the target are empty, IoU and Dice are 1.0 (a
        perfect match) rather than 0.
    """
    pred_binary = (pred > threshold).float()
    target_binary = target.float()

    intersection = (pred_binary * target_binary).sum()
    pred_sum = pred_binary.sum()
    target_sum = target_binary.sum()
    union = pred_sum + target_sum - intersection

    if union > 0:
        iou = (intersection / union).item()
        dice = (2 * intersection / (pred_sum + target_sum)).item()
    else:
        # Nothing predicted and nothing to find: correct, not a failure.
        iou = dice = 1.0

    accuracy = (pred_binary == target_binary).float().mean().item()

    return {
        'iou': iou,
        'dice': dice,
        'accuracy': accuracy,
    }

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Metrics come from ``calculate_metrics`` applied to sigmoid probabilities.

    Returns:
        ``(mean_loss, mean_metrics)``: the loss and the IoU/Dice/accuracy
        dict, each averaged over batches.
    """
    model.train()
    running_loss = 0.0
    totals = dict.fromkeys(('iou', 'dice', 'accuracy'), 0.0)

    for images, masks, _, _ in train_loader:
        images, masks = images.to(device), masks.to(device)
        # Add a channel dim so targets line up with the (B, 1, H, W) logits.
        targets = masks.unsqueeze(1)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        with torch.no_grad():
            batch_metrics = calculate_metrics(torch.sigmoid(logits), targets)
        for name in totals:
            totals[name] += batch_metrics[name]

    n_batches = len(train_loader)
    mean_metrics = {name: total / n_batches for name, total in totals.items()}
    return running_loss / n_batches, mean_metrics

def validate_model(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Metrics are computed on sigmoid probabilities, the same way
    ``train_epoch`` does, so train and validation IoU/Dice/accuracy use the
    same decision threshold.

    Returns:
        ``(mean_loss, mean_metrics)``: the loss and the IoU/Dice/accuracy
        dict, each averaged over batches.
    """
    model.eval()
    val_loss = 0.0
    val_metrics = {'iou': 0.0, 'dice': 0.0, 'accuracy': 0.0}

    with torch.no_grad():
        for images, masks, _, _ in val_loader:
            images = images.to(device)
            targets = masks.to(device).unsqueeze(1)  # add channel dim to match the logits

            logits = model(images)
            val_loss += criterion(logits, targets).item()

            # The model returns logits; threshold probabilities, not raw logits,
            # otherwise 0.5 would correspond to p > ~0.62 instead of p > 0.5.
            batch_metrics = calculate_metrics(torch.sigmoid(logits), targets)
            for key in val_metrics:
                val_metrics[key] += batch_metrics[key]

    n_batches = len(val_loader)
    val_loss /= n_batches
    for key in val_metrics:
        val_metrics[key] /= n_batches

    return val_loss, val_metrics



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# The backbone is frozen, so hand the optimiser only the trainable (head) params.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=model_config['learning_rate'],
)
# Dice loss on a single-channel binary target; the model outputs raw logits.
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


METRIC_NAMES = ('iou', 'dice', 'accuracy')
history = {
    'train_loss': [], 'val_loss': [],
    'train_iou': [], 'val_iou': [],
    'train_dice': [], 'val_dice': [],
    'train_accuracy': [], 'val_accuracy': []
}

# NOTE(review): this runs 50 epochs while model_config['epochs'] is 5 and
# unused here — reconcile the two.
num_epochs = 50
for epoch in range(num_epochs):
    train_loss, train_metrics = train_epoch(
        model, train_loader, loss_fn, optimizer, device
    )
    val_loss, val_metrics = validate_model(model, val_loader, loss_fn, device)  # todo: get eval loaders

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    for name in METRIC_NAMES:
        history[f'train_{name}'].append(train_metrics[name])
        history[f'val_{name}'].append(val_metrics[name])

    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train IoU: {train_metrics['iou']:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val IoU: {val_metrics['iou']:.4f}")

In [None]:
import matplotlib.pyplot as plt 


def plot_training_history(history):
    """Draw a 2x2 grid of train-vs-val curves (loss, IoU, Dice, accuracy) per epoch."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (history key suffix, legend label suffix, panel title), in row-major panel order.
    panels = [
        ('loss', 'Loss', 'Loss'),
        ('iou', 'IoU', 'IoU'),
        ('dice', 'Dice', 'Dice Coefficient'),
        ('accuracy', 'Accuracy', 'Accuracy'),
    ]
    for ax, (key, label, title) in zip(axes.flat, panels):
        ax.plot(history[f'train_{key}'], label=f'Train {label}')
        ax.plot(history[f'val_{key}'], label=f'Val {label}')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.legend()

    plt.tight_layout()
    plt.show()

plot_training_history(history)

In [None]:
# Visual sanity check on the first few training batches. Columns:
# de-normalized model input | prediction (logit > 0, i.e. p > 0.5) | original tile | target mask.
model.eval()
for batch_i, (images, masks, original_images, _) in enumerate(train_loader):
    with torch.no_grad():
        preds = model(images.to(device)).cpu()

    images = images.cpu()
    masks = masks.cpu()

    # Use the real batch size: the final batch can be smaller than
    # model_config['batch_size'].
    for i in range(images.shape[0]):
        # Undo A.Normalize so imshow gets values in [0, 1] with true colours.
        shown = np.clip(images[i].permute(1, 2, 0).numpy() * ADE_STD + ADE_MEAN, 0, 1)
        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        axs[0].imshow(shown)
        axs[1].imshow(preds[i].squeeze() > 0)
        axs[2].imshow(original_images[i])
        axs[3].imshow(masks[i])
        plt.show()

    if batch_i > 4:
        break

In [None]:
from pathlib import Path

save_dir = Path('saved_models/')
# torch.save does not create missing parent directories.
save_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), save_dir / "model_state_dict.pth")