In [None]:
import torch

REPO_DIR = '/home/fergus/repos/dinov3'
# An earlier DINOv2 checkpoint path was immediately overridden by the DINOv3 satellite
# checkpoint below; kept only as a reference.
# WEIGHTS_PATH = '/home/fergus/dinov2_vitg14_reg4_pretrain.pth'
WEIGHTS_PATH = '/home/fergus/dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth'


# torch.hub entrypoint names in the local dinov3 repo.
MODEL_DINOV3_VITS = "dinov3_vits16"
MODEL_DINOV3_VITSP = "dinov3_vits16plus"
MODEL_DINOV3_VITB = "dinov3_vitb16"
MODEL_DINOV3_VITL = "dinov3_vitl16"
MODEL_DINOV3_VITHP = "dinov3_vith16plus"
MODEL_DINOV3_VIT7B = "dinov3_vit7b16"

# Backbones used with the satellite-pretrained (sat493m) weights.
SAT_SMALL = MODEL_DINOV3_VITL
SAT_BIG = MODEL_DINOV3_VIT7B


# ViT-L/16 backbone used by the exploration cells below.
dinov3_vitl16 = torch.hub.load(REPO_DIR, SAT_SMALL, source='local', weights=WEIGHTS_PATH)




In [None]:
# Freeze the backbone: it is only used as a fixed feature extractor.
for weight in dinov3_vitl16.parameters():
    weight.requires_grad = False
        

In [None]:
# Display the module tree to inspect the backbone's submodules.
dinov3_vitl16

In [None]:
def count_parameters(model):
    """Print and return parameter counts for a model.

    Args:
        model: any object exposing ``parameters()`` (e.g. ``torch.nn.Module``).

    Returns:
        ``(total_params, trainable_params)``; trainable counts only parameters
        with ``requires_grad=True``.
    """
    total_params = 0
    trainable_params = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total_params += size
        if tensor.requires_grad:
            trainable_params += size
    frozen_params = total_params - trainable_params

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {frozen_params:,}")

    return total_params, trainable_params

# Usage: the backbone was frozen above, so every parameter should show up as non-trainable.
total, trainable = count_parameters(dinov3_vitl16)

In [None]:
# Reference: DINOv3 overview — https://www.labellerr.com/blog/dinov3/

In [None]:
def trace_model_forward(model, x):
    """Run a backbone stage by stage and print intermediate shapes/values.

    Steps: ``model.patch_embed`` -> every module in ``model.blocks`` ->
    ``model.norm``, then the full ``model(x)`` call for comparison.

    NOTE(review): the real DINOv3 forward also deals with CLS/storage tokens and
    has a ``rope_embed`` submodule; this manual trace uses neither, so it is a
    rough debugging aid rather than a faithful reproduction of ``model(x)``.
    TODO confirm against the backbone's own forward.

    Args:
        model: module exposing ``patch_embed``, ``blocks``, ``norm`` and ``__call__``.
        x: input batch passed to both ``patch_embed`` and ``model``.

    Returns:
        ``(patches, x_manual, full_output)``.
    """
    print("=== Model Forward Trace ===")

    with torch.no_grad():
        # Step 1: patch embedding
        patches = model.patch_embed(x)
        print(f"1. patch_embed output: {patches.shape}")
        print(f"   First few values: {patches[0, 0, :5]}")

        # Step 2: through blocks manually
        x_manual = patches.clone()
        for i, block in enumerate(model.blocks):
            x_manual = block(x_manual)
            if i == 0:
                print(f"2. After first block: {x_manual.shape}")
                print(f"   First few values: {x_manual[0, 0, :5]}")

        # Step 3: final norm
        x_manual = model.norm(x_manual)
        print(f"3. After final norm: {x_manual.shape}")
        print(f"   First few values: {x_manual[0, 0, :5]}")

        # Step 4: full model call
        full_output = model(x)
        print(f"4. Full model() output: {full_output.shape}")
        print(f"   First few values: {full_output[0, :5]}")

    # Compare raw patch embeddings with the processed tokens. The manual path above never
    # prepends a CLS token, so the tokens line up one-to-one. Slicing off index 0 "as CLS"
    # would leave one token fewer than `patches` and break the comparison.
    if patches.shape == x_manual.shape:
        cosine_sim = torch.nn.functional.cosine_similarity(
            patches.flatten(1),
            x_manual.flatten(1),
            dim=1,
        )
        # One similarity per sample; report the batch mean so batch sizes > 1 work.
        print(f"5. Similarity between patch_embed and processed patches: {cosine_sim.mean().item():.4f}")

    return patches, x_manual, full_output

# Run the trace
# NOTE(review): `inp` is only created in a later cell (the preprocessed image tile); on a
# fresh kernel this cell raises NameError — run it after `inp` has been defined.
patch_embed_out, final_patches, model_out = trace_model_forward(dinov3_vitl16, inp)

In [None]:
# List the attributes/submodules available on the backbone (exploration only).
dir(dinov3_vitl16)

In [None]:
# Inspect the patch-embedding submodule. Calling it with no arguments raises a TypeError
# because its forward needs an input tensor; run it as `dinov3_vitl16.patch_embed(inp)`.
dinov3_vitl16.patch_embed

In [None]:
# forward_features returns a mapping of intermediate outputs, indexed by string keys.
# NOTE(review): depends on `inp`, which is defined in a later cell.
ff = dinov3_vitl16.forward_features(inp)
ff

In [None]:
# Compare the shapes of the token groups returned by forward_features (CLS, storage, patch, pre-norm, per the key names).
ff['x_norm_clstoken'].shape, ff['x_storage_tokens'].shape, ff['x_norm_patchtokens'].shape, ff['x_prenorm'].shape

In [None]:
# Probe the `rope_embed` submodule (presumably rotary position embeddings).
# NOTE(review): it is unclear from here what this module expects as input (it may take grid
# sizes rather than an image tensor) — confirm against the dinov3 source before relying on it.
dinov3_vitl16.rope_embed(inp)

In [None]:
import torchvision
from torchvision import transforms




def make_transform(resize_size: int = 224):
    """Build the image preprocessing pipeline fed to the DINOv3 backbone.

    Pipeline: HWC image -> tensor (ToTensor) -> square antialiased resize to
    ``resize_size`` -> per-channel normalization with the constants below.

    NOTE(review): the mean/std values look like the statistics for the
    satellite-pretrained weights — confirm they match WEIGHTS_PATH.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Resize((resize_size, resize_size), antialias=True),
        transforms.Normalize(
            mean=(0.430, 0.411, 0.296),
            std=(0.213, 0.156, 0.143),
        ),
    ]
    return transforms.Compose(steps)

def make_transform_mask(resize_size: int = 224):
    """Tensor conversion + square antialiased resize, with no normalization.

    NOTE(review): Resize is used with its default interpolation mode. Integer
    label masks would normally need a nearest-neighbour resize — TODO confirm the
    intended use of this pipeline.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((resize_size, resize_size), antialias=True),
    ])

In [None]:
import matplotlib.pyplot as plt 
import os 
import cv2 

# Load one tile for a quick visual check.
tile_dir = '/mnt/gis/image/18'
# os.listdir order is arbitrary; sort so that `files[1]` is the same tile on every run.
files = sorted(os.listdir(tile_dir))
tile_path = os.path.join(tile_dir, files[1])
img = cv2.imread(tile_path)
if img is None:  # cv2.imread returns None instead of raising on failure
    raise FileNotFoundError(f'Could not read image: {tile_path}')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
plt.imshow(img)

In [None]:
# Load the tile used for the feature experiments below.
tile_dir = '/mnt/gis/image/18'
# os.listdir order is arbitrary; sort so that `files[9]` is the same tile on every run.
files = sorted(os.listdir(tile_dir))
tile_path = os.path.join(tile_dir, files[9])
img = cv2.imread(tile_path)
if img is None:  # cv2.imread returns None instead of raising on failure
    raise FileNotFoundError(f'Could not read image: {tile_path}')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
plt.imshow(img)

In [None]:
import numpy as np 

transform = make_transform()
# Preprocess the tile and add a leading batch dimension.
inp = transform(img).unsqueeze(0)
inp.shape

In [None]:
# Full backbone forward pass on the preprocessed tile.
out = dinov3_vitl16(inp)

In [None]:
# Shape of the backbone's default forward output.
out.shape

In [None]:
# Raw patch embeddings (before any transformer block); used by the PCA visualisation cell.
emb = dinov3_vitl16.patch_embed(inp)
emb.shape

In [None]:
def get_patch_tokens_only(model, x):
    """Return the final normalized patch tokens and CLS token for a batch.

    Uses the backbone's own ``forward_features`` so the token layout matches the
    model's feature extraction. The previous manual ``patch_embed -> blocks ->
    norm`` loop never prepended a CLS token, so its ``x[:, 0]`` was really the
    first *patch* token and one patch was dropped.

    Args:
        model: backbone whose ``forward_features(x)`` returns a mapping with
            ``'x_norm_patchtokens'`` and ``'x_norm_clstoken'``.
        x: input image batch.

    Returns:
        ``(patch_tokens, cls_token)``. ``cls_token`` keeps a singleton token axis
        (``(B, 1, D)`` when the backbone returns ``(B, D)``). ``patch_tokens`` is
        ``None`` when the backbone returns no patch tokens.
    """
    with torch.no_grad():
        features = model.forward_features(x)

    patch_tokens = features['x_norm_patchtokens']
    cls_token = features['x_norm_clstoken']
    if cls_token.dim() == 2:
        # Keep the previous return layout x[:, 0:1] -> (B, 1, D).
        cls_token = cls_token.unsqueeze(1)

    if patch_tokens.numel() == 0:
        patch_tokens = None
    return patch_tokens, cls_token
    


In [None]:

import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import cv2
from pathlib import Path

# Visualise the raw patch embeddings: project them onto their top-3 principal components
# and display those components as RGB.
features = emb[0].detach().cpu().numpy()  # .cpu() so this also works when emb is on a GPU
feature_dim = features.shape[-1]          # embedding width (1024 for this ViT-L backbone)

# One row per patch: (n_patches, feature_dim).
patches = features.reshape(-1, feature_dim)
n_patches = patches.shape[0]

# Side length of the square patch grid, derived from the data rather than hard-coded
# (e.g. a 224 px input with 16 px patches gives a 14 x 14 grid).
grid_side = int(round(np.sqrt(n_patches)))
assert grid_side * grid_side == n_patches, f"expected a square patch grid, got {n_patches} patches"

patches_scaled = StandardScaler().fit_transform(patches)

# Reduce to 3 components (-> RGB)
pca = PCA(n_components=3, random_state=42)
pca_result = pca.fit_transform(patches_scaled)

# Min-max scale each component to [0, 1] for display
pca_min = pca_result.min(axis=0)
pca_max = pca_result.max(axis=0)
pca_normalized = (pca_result - pca_min) / (pca_max - pca_min + 1e-8)

pca_spatial = pca_normalized.reshape(grid_side, grid_side, 3)

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].imshow(img)
axs[0].set_title('Input tile')
axs[1].imshow(pca_spatial)
axs[1].set_title(f'PCA of patch embeddings ({grid_side}x{grid_side})')
for ax in axs:
    ax.axis('off')
plt.show()

In [None]:
# Train

In [None]:
import requests
from PIL import Image
import os
import torch
from torch.utils.data import Dataset

from PIL import Image
import os
import albumentations as A
import cv2 
import numpy as np 
import matplotlib.pyplot as plt 
from gis.config import Config
from sklearn.model_selection import train_test_split
import albumentations as A
from torch.utils.data import DataLoader
config = Config()


def get_image_and_mask_files():
    """Pair every zoom-18 label file with its image tile.

    Label files are expected to be named ``{x}_{y}.npy``; the matching image is
    ``18_{x}_{y}.jpg``. Non-``.npy`` entries are ignored. Results are sorted
    because ``os.listdir`` returns entries in arbitrary order, which would make
    the seeded train/val split irreproducible.

    Returns:
        ``(image_files, mask_files)``: parallel lists of file names.
    """
    mask_dir = config.mnt_path / 'label/18'
    mask_files = sorted(f for f in os.listdir(mask_dir) if f.endswith('.npy'))
    image_files = []
    for mask in mask_files:
        x, y = os.path.splitext(mask)[0].split('_')
        image_files.append(f'18_{int(x)}_{int(y)}.jpg')
    return image_files, mask_files


class SegmentationDataset(Dataset):
    """Image/mask pairs for segmentation of zoom-18 tiles.

    Each item is ``(image, target, original_image, original_mask)``:

    - ``image`` is the transformed image tensor.
    - ``target`` is a ``long`` mask resized (nearest-neighbour) to the image's
      spatial size, so it lines up with the network input.
    - ``original_image`` and ``original_mask`` are the untouched numpy arrays.

    Args:
        images: image file names inside ``config.mnt_path / 'image/18'``.
        masks: ``.npy`` mask file names inside ``config.mnt_path / 'label/18'``,
            parallel to ``images``.
        transform: callable mapping an HWC uint8 array to a CHW tensor. When
            ``None``, the image is converted to a float CHW tensor in ``[0, 1]``.
    """
    def __init__(self, images, masks, transform=None):
        if len(images) != len(masks):
            raise ValueError(f"images and masks differ in length: {len(images)} vs {len(masks)}")
        self.images = images
        self.masks = masks
        self.transform = transform
        self.image_dir = config.mnt_path / 'image/18'
        self.mask_dir = config.mnt_path / 'label/18'

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _mask_to_tensor(mask, size):
        """Convert a 2-D label array to a ``long`` tensor resized to ``size`` (H, W).

        Nearest-neighbour interpolation keeps label values intact (no blending).
        """
        target = torch.as_tensor(np.asarray(mask)).long()
        size = tuple(size)
        if tuple(target.shape[-2:]) != size:
            target = torch.nn.functional.interpolate(
                target[None, None].float(), size=size, mode='nearest'
            )[0, 0].long()
        return target

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])
        original_image = np.array(Image.open(img_path).convert("RGB"))
        original_mask = np.load(mask_path)

        if self.transform is not None:
            image = self.transform(original_image)
        else:
            image = torch.from_numpy(original_image).permute(2, 0, 1).float() / 255.0
        # The transform already yields a tensor; torch.tensor() on a tensor copies it and
        # emits a UserWarning, whereas as_tensor only converts when needed.
        image = torch.as_tensor(image)

        # The transform resizes the image, so resize the mask to the same spatial size.
        target = self._mask_to_tensor(original_mask, image.shape[-2:])
        return image, target, original_image, original_mask
    
model_config = {
    'batch_size': 4,
    'epochs': 5,
    'learning_rate': 1e-4,
    'val_split': 0.2,
    'num_workers': 4,
}


image_files, mask_files = get_image_and_mask_files()

train_images, val_images, train_masks, val_masks = train_test_split(
    image_files, mask_files, test_size=model_config['val_split'], random_state=42
)


# Train and val images must go through the same preprocessing, including normalization.
# Validating on un-normalized images (as make_transform_mask produces) would feed the
# model a different input distribution than it was trained on.
train_transform = make_transform()
val_transform = make_transform()

train_dataset = SegmentationDataset(train_images, train_masks, transform=train_transform)
val_dataset = SegmentationDataset(val_images, val_masks, transform=val_transform)


# Shuffle only the training set; use the configured worker count for both loaders.
train_loader = DataLoader(
    train_dataset,
    batch_size=model_config['batch_size'],
    shuffle=True,
    num_workers=model_config['num_workers'],
)
val_loader = DataLoader(
    val_dataset,
    batch_size=model_config['batch_size'],
    shuffle=False,
    num_workers=model_config['num_workers'],
)


In [None]:
import torch.nn as nn
    
"""
DINOv3's dense-prediction evaluation is essentially a linear probe. The paper states:

We perform linear probing on top of the dense features for two tasks: semantic segmentation and monocular depth estimation. 
In both cases, we train a linear transform on top of the frozen patch outputs of DINOv3.

In other words: keep DINOv3 frozen and train a single linear layer on each patch token, e.g. nn.Linear(768, 256)
for a 768-dim backbone, then reshape each token's 256 outputs into its 16x16 pixel patch.
Stitching the patches back together gives a 224x224 segmentation map for a 224x224 input.

https://www.reddit.com/r/computervision/comments/1mrvhrp/not_understanding_the_dense_feature_maps_of_dinov3/

[1, 784, 1024]
B, P*P, F
"""

class DINOv3Seg(nn.Module):
    """Linear probe on frozen DINOv3 patch tokens.

    ``forward`` returns per-token logits of shape ``(B, N, num_classes)``, where
    ``N`` is the number of patch tokens. No spatial reshaping or upsampling is
    done here.

    Args:
        num_classes: output features of the linear head.
        feature_dim: width of the backbone's patch tokens. The default of 1024
            matches the ViT-L tokens (``[1, 784, 1024]``). The previous hard-coded
            384 did not match and made the head fail on the first forward.
    """
    def __init__(self, num_classes=1, feature_dim=1024):
        super(DINOv3Seg, self).__init__()

        self.dino = torch.hub.load(REPO_DIR, SAT_SMALL, source='local', weights=WEIGHTS_PATH)

        # Frozen backbone: only the head is trained.
        for param in self.dino.parameters():
            param.requires_grad = False

        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        tokens = self.dino.forward_features(x)['x_norm_patchtokens']  # e.g. torch.Size([1, 784, 1024])
        return self.head(tokens)

# Unet inte

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleConv(nn.Module):
    """Two stages of 3x3 conv (no bias) -> BatchNorm -> ReLU; spatial size is preserved.

    Args:
        in_channels: channels of the input.
        out_channels: channels produced by the second conv.
        mid_channels: channels between the two convs (falls back to ``out_channels``).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        layers = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Up(nn.Module):
    """Optionally fuse a skip tensor, then upsample x2 (bilinear) and apply DoubleConv.

    When ``x2`` is given, it is padded (or cropped, for negative differences) to
    ``x1``'s spatial size. It is then concatenated with ``x1`` along channels
    *before* upsampling, so ``in_channels`` must equal the combined channel count.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2=None):
        merged = x1
        if x2 is not None:
            dy = x1.size(2) - x2.size(2)
            dx = x1.size(3) - x2.size(3)
            padding = [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2]
            merged = torch.cat([x1, F.pad(x2, padding)], dim=1)
        return self.conv(self.up(merged))
    

class DINOv3_UNet(nn.Module):
    """Frozen DINOv3 backbone + UNet-style ``Up`` decoder producing full-resolution logits.

    The stride-16 patch-token grid is projected with 1x1 convs and resized to four
    pseudo-scales (H/4, H/8, H/16, H/32) that feed the decoder. The output is
    ``(B, num_classes, H, W)`` for inputs whose H and W are multiples of 32.

    Args:
        num_classes: output channels of the final 1x1 head (default 1 = binary).
    """
    def __init__(self, num_classes=1):
        super(DINOv3_UNet, self).__init__()

        self.dino = torch.hub.load(REPO_DIR, SAT_SMALL, source='local', weights=WEIGHTS_PATH)

        # Frozen backbone: only the decoder is trained.
        for param in self.dino.parameters():
            param.requires_grad = False

        # One 1x1 projection (1024 -> 128 channels) per decoder scale.
        self.reduce1 = nn.Conv2d(1024, 128, 1)
        self.reduce2 = nn.Conv2d(1024, 128, 1)
        self.reduce3 = nn.Conv2d(1024, 128, 1)
        self.reduce4 = nn.Conv2d(1024, 128, 1)

        # up1-up3 consume [decoder, skip] concatenations (128 + 128 channels).
        self.up1 = Up(256, 128)
        self.up2 = Up(256, 128)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 128)
        self.head = nn.Conv2d(128, num_classes, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.dino.forward_features(x)['x_norm_patchtokens']
        # (B, N, D) tokens -> (B, D, H/16, W/16) feature map (16 px patches).
        x = x.reshape(B, H // 16, W // 16, -1).permute(0, 3, 1, 2)
        # Create 4 different scales from the same features
        x1 = F.interpolate(self.reduce1(x), size=(H // 4, W // 4), mode='bilinear')
        x2 = F.interpolate(self.reduce2(x), size=(H // 8, W // 8), mode='bilinear')
        x3 = F.interpolate(self.reduce3(x), size=(H // 16, W // 16), mode='bilinear')
        x4 = F.interpolate(self.reduce4(x), size=(H // 32, W // 32), mode='bilinear')
        x = self.up4(x4)      # H/32 -> H/16
        x = self.up3(x, x3)   # H/16 -> H/8
        x = self.up2(x, x2)   # H/8  -> H/4
        x = self.up1(x, x1)   # H/4  -> H/2
        # Apply the head once; it used to be evaluated twice, with the first result discarded.
        out = self.head(x)
        out = F.interpolate(out, scale_factor=2, mode='bilinear')  # H/2 -> H
        return out

# REPO_DIR = '/home/fergus/repos/dinov3'
# WEIGHTS_PATH = '/home/fergus/dinov2_vitg14_reg4_pretrain.pth'
# WEIGHTS_PATH = '/home/fergus/dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth'
# dinoUnet = DINOv3_UNet()
# count_parameters(dinoUnet)

In [None]:
class SingleConv(nn.Module):
    """One 3x3 conv (no bias) -> BatchNorm -> ReLU; spatial size is preserved."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class LightweightUp(nn.Module):
    """Optional skip fusion, x2 bilinear upsample, then a single SingleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = SingleConv(in_channels, out_channels)

    def forward(self, x1, x2=None):
        if x2 is None:
            fused = x1
        else:
            # Pad (or crop) the skip tensor to x1's spatial size before concatenating.
            height_gap = x1.size(2) - x2.size(2)
            width_gap = x1.size(3) - x2.size(3)
            x2 = F.pad(x2, [width_gap // 2, width_gap - width_gap // 2,
                            height_gap // 2, height_gap - height_gap // 2])
            fused = torch.cat([x1, x2], dim=1)
        return self.conv(self.up(fused))

class UltraLightDINOv3_UNet(nn.Module):
    """Frozen DINOv3 backbone + single-conv UNet decoder.

    Trainable parameters with the defaults (hidden_dim=32, num_classes=1):

    - 4 x 1x1 reductions: 4 * 32,800
    - 3 x LightweightUp(64->32): 3 * 18,496
    - LightweightUp(32->32): 9,280
    - head: 33

    Total: 196,001 (~196K, not the ~300K previously stated).

    Args:
        num_classes: output channels of the final 1x1 head.
        hidden_dim: decoder width; the parameter count grows with it.

    ``forward`` returns logits ``(B, num_classes, H, W)`` for H, W multiples of 32.
    """
    def __init__(self, num_classes=1, hidden_dim=32):
        super().__init__()

        # Load DINOv3 and freeze it: only the decoder below is trained.
        self.dino = torch.hub.load(REPO_DIR, SAT_SMALL, source='local', weights=WEIGHTS_PATH)
        for param in self.dino.parameters():
            param.requires_grad = False

        # 1x1 channel reductions from the 1024-dim tokens, one per decoder scale.
        self.reduce1 = nn.Conv2d(1024, hidden_dim, 1)
        self.reduce2 = nn.Conv2d(1024, hidden_dim, 1)
        self.reduce3 = nn.Conv2d(1024, hidden_dim, 1)
        self.reduce4 = nn.Conv2d(1024, hidden_dim, 1)

        # Lightweight decoder: up1-up3 consume [decoder, skip] concatenations (2 * hidden_dim).
        self.up1 = LightweightUp(hidden_dim * 2, hidden_dim)
        self.up2 = LightweightUp(hidden_dim * 2, hidden_dim)
        self.up3 = LightweightUp(hidden_dim * 2, hidden_dim)
        self.up4 = LightweightUp(hidden_dim, hidden_dim)

        # Final head
        self.head = nn.Conv2d(hidden_dim, num_classes, 1)

    def forward(self, x):
        B, C, H, W = x.shape

        # (B, N, 1024) patch tokens -> (B, 1024, H/16, W/16) feature map.
        features = self.dino.forward_features(x)['x_norm_patchtokens']
        features = features.reshape(B, H // 16, W // 16, -1).permute(0, 3, 1, 2)

        # Pseudo multi-scale features, all derived from the same stride-16 map.
        x1 = F.interpolate(self.reduce1(features), size=(H // 4, W // 4), mode='bilinear')
        x2 = F.interpolate(self.reduce2(features), size=(H // 8, W // 8), mode='bilinear')
        x3 = F.interpolate(self.reduce3(features), size=(H // 16, W // 16), mode='bilinear')
        x4 = F.interpolate(self.reduce4(features), size=(H // 32, W // 32), mode='bilinear')

        # Decoder
        x = self.up4(x4)      # H/32 -> H/16
        x = self.up3(x, x3)   # H/16 -> H/8
        x = self.up2(x, x2)   # H/8  -> H/4
        x = self.up1(x, x1)   # H/4  -> H/2

        # Final output
        out = self.head(x)
        out = F.interpolate(out, scale_factor=2, mode='bilinear')  # H/2 -> H

        return out

# Even more extreme version (~19K trainable parameters)
class MinimalDINOv3_UNet(nn.Module):
    """Minimal decoder: a 1x1 reduction plus one 3x3 conv block, with no skip connections.

    Trainable parameters with the defaults (hidden_dim=16, num_classes=1):
    16,400 (reduce) + 2,320 (3x3 conv) + 32 (BatchNorm) + 17 (head) = 18,769
    (not the ~100K previously stated).

    Args:
        num_classes: output channels.
        hidden_dim: width of the reduced feature map.

    ``forward`` returns logits ``(B, num_classes, H, W)``, upsampled to the input
    resolution like the other DINOv3 UNet variants. It previously returned
    H/2 x W/2 despite the "final size" comment.
    """
    def __init__(self, num_classes=1, hidden_dim=16):
        super().__init__()

        # Load DINOv3 (frozen)
        self.dino = torch.hub.load(REPO_DIR, SAT_SMALL, source='local', weights=WEIGHTS_PATH)
        for param in self.dino.parameters():
            param.requires_grad = False

        # Single channel reduction from the 1024-dim tokens
        self.reduce = nn.Conv2d(1024, hidden_dim, 1)

        # Simple decoder without skip connections
        self.decoder = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, num_classes, 1)
        )

    def forward(self, x):
        B, C, H, W = x.shape

        # (B, N, 1024) patch tokens -> (B, 1024, H/16, W/16) feature map.
        features = self.dino.forward_features(x)['x_norm_patchtokens']
        features = features.reshape(B, H // 16, W // 16, -1).permute(0, 3, 1, 2)

        # Reduce channels and decode at stride 16
        x = self.reduce(features)
        x = self.decoder(x)

        # Upsample to the input resolution so logits align with full-size masks
        out = F.interpolate(x, size=(H, W), mode='bilinear')

        return out
    

# Instantiate the minimal model (frozen backbone + small trainable decoder) and report its parameter counts.
model_minimal = MinimalDINOv3_UNet(num_classes=1)
print("\n=== Minimal Model ===")
count_parameters(model_minimal)

In [None]:
from tqdm import tqdm
import segmentation_models_pytorch as smp

def calculate_metrics(pred, target, threshold=0.5):
    """Binary segmentation metrics over every element of a batch.

    Args:
        pred: scores compared against ``threshold`` (pass probabilities, i.e.
            apply a sigmoid to logits first).
        target: ground-truth mask of the same shape, expected to hold 0/1 values.
        threshold: cut-off used to binarize ``pred``.

    Returns:
        dict of python floats: ``'iou'``, ``'dice'`` and ``'accuracy'`` (pixel accuracy).
    """
    eps = 1e-8
    predicted = (pred > threshold).float()
    truth = target.float()

    overlap = (predicted * truth).sum()
    pred_area = predicted.sum()
    true_area = truth.sum()

    iou = overlap / (pred_area + true_area - overlap + eps)
    dice = 2 * overlap / (pred_area + true_area + eps)
    accuracy = (predicted == truth).sum() / truth.numel()

    return {'iou': iou.item(), 'dice': dice.item(), 'accuracy': accuracy.item()}

def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train for one epoch and return the mean loss and mean batch metrics.

    Args:
        model: network producing logits compared against ``masks.unsqueeze(1)``.
        train_loader: yields ``(images, masks, _, _)`` batches.
        criterion: loss called as ``criterion(logits, masks.unsqueeze(1))``.
        optimizer: optimizer over the trainable parameters.
        device: device each batch is moved to.

    Returns:
        ``(train_loss, train_metrics)``: averages over batches (zeros if the
        loader yields no batches).
    """
    model.train()
    train_loss = 0.0
    train_metrics = {'iou': 0.0, 'dice': 0.0, 'accuracy': 0.0}
    n_batches = 0

    # Progress bar instead of the per-batch shape prints left over from debugging.
    for images, masks, _, _ in tqdm(train_loader, desc='train', leave=False):
        images = images.to(device)
        # Add the channel axis once so targets line up with the model's logits.
        targets = masks.to(device).unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        n_batches += 1

        with torch.no_grad():
            # Metrics expect probabilities, so convert the logits first.
            batch_metrics = calculate_metrics(torch.sigmoid(outputs), targets)
        for key in train_metrics:
            train_metrics[key] += batch_metrics[key]

    # Guard against an empty loader (would otherwise divide by zero).
    if n_batches:
        train_loss /= n_batches
        for key in train_metrics:
            train_metrics[key] /= n_batches

    return train_loss, train_metrics

def validate_model(model, val_loader, criterion, device):
    """Evaluate the model and return the mean loss and mean batch metrics.

    Metrics are computed on sigmoid probabilities, exactly like ``train_epoch``.
    Previously the raw logits were thresholded at 0.5, which is a different
    decision boundary (logit 0.5 ~ probability 0.62) and made train and val
    metrics incomparable.

    Args:
        model: network producing logits compared against ``masks.unsqueeze(1)``.
        val_loader: yields ``(images, masks, _, _)`` batches.
        criterion: loss called as ``criterion(logits, masks.unsqueeze(1))``.
        device: device each batch is moved to.

    Returns:
        ``(val_loss, val_metrics)``: averages over batches (zeros if the loader
        yields no batches).
    """
    model.eval()
    val_loss = 0.0
    val_metrics = {'iou': 0.0, 'dice': 0.0, 'accuracy': 0.0}
    n_batches = 0

    with torch.no_grad():
        for images, masks, _, _ in val_loader:
            images = images.to(device)
            targets = masks.to(device).unsqueeze(1)  # add channel dim for masks

            outputs = model(images)
            loss = criterion(outputs, targets)

            val_loss += loss.item()
            n_batches += 1

            batch_metrics = calculate_metrics(torch.sigmoid(outputs), targets)
            for key in val_metrics:
                val_metrics[key] += batch_metrics[key]

    # Average metrics; guard against an empty loader.
    if n_batches:
        val_loss /= n_batches
        for key in val_metrics:
            val_metrics[key] /= n_batches

    return val_loss, val_metrics



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model first, then optimize only the trainable (decoder) parameters:
# the DINOv3 backbone is frozen, so its weights never receive gradients.
model = model_minimal.to(device)
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=model_config['learning_rate'],
)
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)


history = {
    'train_loss': [], 'val_loss': [],
    'train_iou': [], 'val_iou': [],
    'train_dice': [], 'val_dice': [],
    'train_accuracy': [], 'val_accuracy': []
}


# Use the configured epoch count instead of a hard-coded 5.
n_epochs = model_config['epochs']
for epoch in range(n_epochs):
    train_loss, train_metrics = train_epoch(
        model, train_loader, loss_fn, optimizer, device
    )
    val_loss, val_metrics = validate_model(model, val_loader, loss_fn, device)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    for metric in ('iou', 'dice', 'accuracy'):
        history[f'train_{metric}'].append(train_metrics[metric])
        history[f'val_{metric}'].append(val_metrics[metric])

    print(f"Epoch {epoch + 1}/{n_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train IoU: {train_metrics['iou']:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val IoU: {val_metrics['iou']:.4f}")