In [1]:
import os
from datasets_gta5 import GTA5, CityScapes
import albumentations as A
import torch

CITYSCAPES_PATH = '/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/Cityscapes/Cityscapes'


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  check_for_updates()


In [2]:

transform = A.Compose([
    A.Resize(512, 1024),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
CITYSCAPES_dataset = CityScapes(CITYSCAPES_PATH, train_val='train', transform=transform)


def load_student_checkpoint(checkpoint_path: str) -> dict:
    """
    Load and process student model checkpoint.
    
    Args:
        checkpoint_path (str): Path to the checkpoint file
        
    Returns:
        dict: Processed state dict containing only student model weights
    """
    # Load checkpoint
    state_dict = torch.load(checkpoint_path)
    
    # Extract and process student weights
    student_state_dict = {k.replace('student.model.', ''): v 
                         for k, v in state_dict['state_dict'].items() 
                         if k.startswith('student') and not k.startswith('student.feature_matchers')}
    
    return student_state_dict

# Load checkpoint and save state dict
checkpoint_path = '/home/arda/dinov2/distillation/logs/resnet50/distillation/version_69/checkpoints/epoch=2-val_similarity=1.00.ckpt'
student_state_dict = load_student_checkpoint(checkpoint_path)
# torch.save(student_state_dict, '/home/arda/dinov2/distillation/logs/resnet50/distillation/version_7/checkpoints/student_state_dict.pth')
# student_state_dict.keys()






In [7]:
student_state_dict

{'model.stem.conv1.weight': tensor([[[[nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan],
           ...,
           [nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan]],
 
          [[nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan],
           ...,
           [nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan]],
 
          [[nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan],
           ...,
           [nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan],
           [nan, nan, nan,  ..., nan, nan, nan]]],
 
 
         [[[nan, nan, nan,  ...

In [3]:
import sys
sys.path.append('/home/arda/dinov2/distillation')
from models.resnet_wrapper import ResNetWrapper
encoder = ResNetWrapper(depth=50, out_features=['res5'])
encoder.load_state_dict(student_state_dict)

<All keys matched successfully>

In [4]:

encoder.eval()
encoder.to(device)
encoder.model.eval()
encoder = encoder.model

In [5]:
asd = torch.randn(1, 3, 512, 1024).to(device)
encoder(asd)["res5"].shape

# # Freeze all parameters of the encoder
# for param in encoder.parameters():
#     param.requires_grad = False

torch.Size([1, 2048, 16, 32])

In [6]:
# First, let's create a simple decoder network
import numpy as np
import tqdm as tqdm
class SegmentationDecoder(torch.nn.Module):
    def __init__(self, in_channels=2048, num_classes=19):
        super().__init__()
        self.decoder = torch.nn.Sequential(
            # 16x32 -> 32x64
            torch.nn.ConvTranspose2d(in_channels, 1024, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            
            # 32x64 -> 64x128
            torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            
            # 64x128 -> 128x256
            torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            
            # 128x256 -> 256x512
            torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            
            # 256x512 -> 512x1024
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            
            # Final 1x1 conv to get to num_classes
            torch.nn.Conv2d(64, num_classes, kernel_size=1)
        )
    def forward(self, x):
        x = self.decoder(x)
        # Ensure exact output size
        if x.shape[-2:] != (512, 1024):
            x = torch.nn.functional.interpolate(
                x, size=(512, 1024), 
                mode='bilinear', 
                align_corners=False
            )
        return x

# Initialize decoder, optimizer, and loss function
decoder = SegmentationDecoder().to(device)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    k = (b >= 0) & (b < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iou(hist: np.ndarray) -> np.ndarray:
    epsilon = 1e-5
    return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)

def train_epoch(encoder, decoder, dataloader, optimizer, criterion, device, num_classes=19):
    decoder.train()
    encoder.train()  # Keep DINO frozen
    
    total_loss = 0
    hist = np.zeros((num_classes, num_classes))  # Single histogram for entire epoch
    total_pixels = 0
    correct_pixels = 0
    
    for images, labels in tqdm.tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Get DINO features
        # with torch.no_grad():
        features = encoder(images)["res5"]
        
        # Forward pass through decoder
        outputs = decoder(features)
        
        # Resize outputs to match label size if needed
        if outputs.shape[-2:] != labels.shape[-2:]:
            outputs = torch.nn.functional.interpolate(
                outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
        
        # Calculate loss
        print(outputs.shape, labels.shape)
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate metrics
        preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
        
        # Pixel Accuracy
        valid_mask = labels != 255  # Ignore index
        total_pixels += valid_mask.sum().item()
        correct_pixels += ((preds == labels) & valid_mask).sum().item()
        
        # IoU
        preds = preds.cpu().numpy()
        target = labels.cpu().numpy()
        hist += fast_hist(preds.flatten(), target.flatten(), num_classes)
    
    # Calculate final metrics
    pixel_acc = correct_pixels / total_pixels
    
    # Per-class accuracy (mean class accuracy)
    class_acc = np.diag(hist) / (hist.sum(1) + np.finfo(np.float32).eps)
    mean_class_acc = np.nanmean(class_acc)
    
    # IoU metrics
    iou = per_class_iou(hist)
    mean_iou = np.nanmean(iou)
    
    metrics = {
        'loss': total_loss / len(dataloader),
        'pixel_acc': pixel_acc,
        'mean_class_acc': mean_class_acc,
        'mean_iou': mean_iou,
        'class_iou': iou,
        'class_acc': class_acc
    }
    
    return metrics

train_loader = torch.utils.data.DataLoader(
    CITYSCAPES_dataset, 
    batch_size=4,
    shuffle=True,
    num_workers=4
)
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    metrics = train_epoch(encoder, decoder, train_loader, optimizer, criterion, device)
    
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Loss: {metrics['loss']:.4f}")
    print(f"Pixel Accuracy: {metrics['pixel_acc']:.4f}")
    print(f"Mean Class Accuracy: {metrics['mean_class_acc']:.4f}")
    print(f"Mean IoU: {metrics['mean_iou']:.4f}")
    
    # Optionally print per-class metrics
    print("\nPer-class metrics:")
    for i in range(19):  # Assuming 19 classes
        print(f"Class {i:2d} - Acc: {metrics['class_acc'][i]:.4f}, IoU: {metrics['class_iou'][i]:.4f}")

100%|██████████| 393/393 [03:00<00:00,  2.18it/s]



Epoch 1/10
Loss: nan
Pixel Accuracy: 0.3638
Mean Class Accuracy: 0.0191
Mean IoU: 0.0191

Per-class metrics:
Class  0 - Acc: 0.3638, IoU: 0.3638
Class  1 - Acc: 0.0000, IoU: 0.0000
Class  2 - Acc: 0.0000, IoU: 0.0000
Class  3 - Acc: 0.0000, IoU: 0.0000
Class  4 - Acc: 0.0000, IoU: 0.0000
Class  5 - Acc: 0.0000, IoU: 0.0000
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.0000, IoU: 0.0000
Class  9 - Acc: 0.0000, IoU: 0.0000
Class 10 - Acc: 0.0000, IoU: 0.0000
Class 11 - Acc: 0.0000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.0000, IoU: 0.0000
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 393/393 [02:59<00:00,  2.19it/s]



Epoch 2/10
Loss: nan
Pixel Accuracy: 0.3638
Mean Class Accuracy: 0.0191
Mean IoU: 0.0191

Per-class metrics:
Class  0 - Acc: 0.3638, IoU: 0.3638
Class  1 - Acc: 0.0000, IoU: 0.0000
Class  2 - Acc: 0.0000, IoU: 0.0000
Class  3 - Acc: 0.0000, IoU: 0.0000
Class  4 - Acc: 0.0000, IoU: 0.0000
Class  5 - Acc: 0.0000, IoU: 0.0000
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.0000, IoU: 0.0000
Class  9 - Acc: 0.0000, IoU: 0.0000
Class 10 - Acc: 0.0000, IoU: 0.0000
Class 11 - Acc: 0.0000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.0000, IoU: 0.0000
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 393/393 [03:00<00:00,  2.18it/s]



Epoch 3/10
Loss: nan
Pixel Accuracy: 0.3638
Mean Class Accuracy: 0.0191
Mean IoU: 0.0191

Per-class metrics:
Class  0 - Acc: 0.3638, IoU: 0.3638
Class  1 - Acc: 0.0000, IoU: 0.0000
Class  2 - Acc: 0.0000, IoU: 0.0000
Class  3 - Acc: 0.0000, IoU: 0.0000
Class  4 - Acc: 0.0000, IoU: 0.0000
Class  5 - Acc: 0.0000, IoU: 0.0000
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.0000, IoU: 0.0000
Class  9 - Acc: 0.0000, IoU: 0.0000
Class 10 - Acc: 0.0000, IoU: 0.0000
Class 11 - Acc: 0.0000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.0000, IoU: 0.0000
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


100%|██████████| 393/393 [03:00<00:00,  2.18it/s]



Epoch 4/10
Loss: nan
Pixel Accuracy: 0.3638
Mean Class Accuracy: 0.0191
Mean IoU: 0.0191

Per-class metrics:
Class  0 - Acc: 0.3638, IoU: 0.3638
Class  1 - Acc: 0.0000, IoU: 0.0000
Class  2 - Acc: 0.0000, IoU: 0.0000
Class  3 - Acc: 0.0000, IoU: 0.0000
Class  4 - Acc: 0.0000, IoU: 0.0000
Class  5 - Acc: 0.0000, IoU: 0.0000
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.0000, IoU: 0.0000
Class  8 - Acc: 0.0000, IoU: 0.0000
Class  9 - Acc: 0.0000, IoU: 0.0000
Class 10 - Acc: 0.0000, IoU: 0.0000
Class 11 - Acc: 0.0000, IoU: 0.0000
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.0000, IoU: 0.0000
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000


 88%|████████▊ | 344/393 [02:36<00:22,  2.19it/s]


KeyboardInterrupt: 

In [7]:
import torch

x = torch.randn(10)
y = torch.randn(10)
x,y

(tensor([-0.3624,  0.8973, -1.0784,  0.3372, -1.4589, -0.1333,  0.2062,  0.4842,
          1.9451, -0.1368]),
 tensor([-0.2461,  0.7652,  0.0446, -0.3514,  0.1806,  0.4895,  0.2342,  0.8889,
         -0.8288, -0.6070]))

In [9]:
tuple([1,2,3])

(1, 2, 3)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionDistiller(nn.Module):
    def __init__(self, 
                 student_dim,   # e.g., 256 or 512
                 teacher_dim,   # e.g., 768
                 hidden_dim,    # dimension for cross-attention
                 num_heads=8):
        super().__init__()
        # Linear projections to match dimension
        self.student_proj = nn.Linear(student_dim, hidden_dim, bias=False)
        self.teacher_proj = nn.Linear(teacher_dim, hidden_dim, bias=False)
        
        # Using PyTorch's multi-head attention
        # batch_first=True means input shape is [B, seq_len, dim]
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, 
                                                num_heads=num_heads, 
                                                batch_first=True)
        
        # Optional final projection back to student_dim
        self.proj_back = nn.Linear(hidden_dim, student_dim, bias=False)

    def forward(self, student_map, teacher_map):
        """
        student_map: [B, C_s, H_s, W_s]
        teacher_map: [B, C_t, H_t, W_t]
        returns cross_attended_map: [B, C_s, H_s, W_s] (updated student features)
        """
        B, C_s, H_s, W_s = student_map.shape
        _, C_t, H_t, W_t = teacher_map.shape
        
        # Flatten [B, C_s, H_s, W_s] -> [B, H_s*W_s, C_s], then project
        student_tokens = student_map.permute(0,2,3,1).reshape(B, H_s*W_s, C_s)
        student_tokens = self.student_proj(student_tokens)  # => [B, H_s*W_s, hidden_dim]
        
        # Flatten teacher -> [B, H_t*W_t, C_t], then project
        teacher_tokens = teacher_map.permute(0,2,3,1).reshape(B, H_t*W_t, C_t)
        teacher_tokens = self.teacher_proj(teacher_tokens)  # => [B, H_t*W_t, hidden_dim]
        
        # Cross-Attention: Q=student, K=teacher, V=teacher
        cross_attended, _ = self.cross_attn(query=student_tokens,
                                            key=teacher_tokens,
                                            value=teacher_tokens)
        
        # Project back to student dimension if needed
        cross_attended = self.proj_back(cross_attended)  # [B, H_s*W_s, C_s]
        
        # Reshape back to [B, C_s, H_s, W_s]
        cross_attended_map = cross_attended.view(B, H_s, W_s, C_s).permute(0, 3, 1, 2)
        
        return cross_attended_map


# Example usage:
if __name__ == "__main__":
    B = 2
    student_map = torch.randn(B, 256, 16, 16)  # e.g., [B, C_s, H_s, W_s]
    teacher_map = torch.randn(B, 768, 8, 8)    # [B, C_t, H_t, W_t]
    
    # Suppose we choose hidden_dim=384
    distiller = CrossAttentionDistiller(student_dim=256, teacher_dim=768, hidden_dim=384, num_heads=6)
    updated_student_map = distiller(student_map, teacher_map)
    
    print("updated_student_map shape:", updated_student_map.shape)
    # => [2, 256, 16, 16]
    
    # Distillation loss (example)
    # You'd typically need to align teacher_map shape or do some pooling, 
    # but let's just do a simple MSE with naive upsampling of teacher_map:
    teacher_map_upsampled = F.interpolate(teacher_map, size=(16,16), mode='bilinear')
    loss = F.mse_loss(F.normalize(updated_student_map, dim=1),
                      F.normalize(teacher_map_upsampled, dim=1))
    print("distillation loss:", loss.item())


updated_student_map shape: torch.Size([2, 256, 16, 16])


  loss = F.mse_loss(F.normalize(updated_student_map, dim=1),


RuntimeError: The size of tensor a (256) must match the size of tensor b (768) at non-singleton dimension 1