In [1]:
import os
from datasets_gta5 import GTA5, CityScapes
import albumentations as A
import torch

CITYSCAPES_PATH = '/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/Cityscapes/Cityscapes'


device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

  check_for_updates()


In [2]:

transform = A.Compose([
    A.Resize(512, 1024),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
CITYSCAPES_dataset = CityScapes(CITYSCAPES_PATH, train_val='train', transform=transform)


def load_student_checkpoint(checkpoint_path: str) -> dict:
    """
    Load and process student model checkpoint.
    
    Args:
        checkpoint_path (str): Path to the checkpoint file
        
    Returns:
        dict: Processed state dict containing only student model weights
    """
    # Load checkpoint
    state_dict = torch.load(checkpoint_path)
    
    # Extract and process student weights
    student_state_dict = {k.replace('student.model.', ''): v 
                         for k, v in state_dict['state_dict'].items() 
                         if k.startswith('student') and not k.startswith('student.feature_matchers')}
    
    return student_state_dict

# Load checkpoint and save state dict
checkpoint_path = '/home/arda/dinov2/distillation/logs/resnet50/distillation/dino_s_full_scalekd/checkpoints/last.ckpt'
student_state_dict = load_student_checkpoint(checkpoint_path)
# torch.save(student_state_dict, '/home/arda/dinov2/distillation/logs/resnet50/distillation/version_7/checkpoints/student_state_dict.pth')
# student_state_dict.keys()






In [3]:
import sys
sys.path.append('/home/arda/dinov2/distillation')
from models.resnet_wrapper import ResNetWrapper
encoder = ResNetWrapper(depth=50, out_features=['res5'])
encoder.load_state_dict(student_state_dict)

<All keys matched successfully>

In [4]:

encoder.eval()
encoder.to(device)
encoder.model.eval()
encoder = encoder.model

In [5]:
asd = torch.randn(1, 3, 512, 1024).to(device)
encoder(asd)["res5"].shape

# # Freeze all parameters of the encoder
for param in encoder.parameters():
    param.requires_grad = False

In [6]:
# First, let's create a simple decoder network
import numpy as np
import tqdm as tqdm
class SegmentationDecoder(torch.nn.Module):
    def __init__(self, in_channels=2048, num_classes=19):
        super().__init__()
        self.decoder = torch.nn.Sequential(
            # 16x32 -> 32x64
            torch.nn.ConvTranspose2d(in_channels, 1024, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            
            # 32x64 -> 64x128
            torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            
            # 64x128 -> 128x256
            torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            
            # 128x256 -> 256x512
            torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            
            # 256x512 -> 512x1024
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            
            # Final 1x1 conv to get to num_classes
            torch.nn.Conv2d(64, num_classes, kernel_size=1)
        )
    def forward(self, x):
        x = self.decoder(x)
        # Ensure exact output size
        if x.shape[-2:] != (512, 1024):
            x = torch.nn.functional.interpolate(
                x, size=(512, 1024), 
                mode='bilinear', 
                align_corners=False
            )
        return x

# Initialize decoder, optimizer, and loss function
decoder = SegmentationDecoder().to(device)
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    k = (b >= 0) & (b < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iou(hist: np.ndarray) -> np.ndarray:
    epsilon = 1e-5
    return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)

def train_epoch(encoder, decoder, dataloader, optimizer, criterion, device, num_classes=19):
    decoder.train()
    encoder.train()  # Keep DINO frozen
    
    total_loss = 0
    hist = np.zeros((num_classes, num_classes))  # Single histogram for entire epoch
    total_pixels = 0
    correct_pixels = 0
    
    for images, labels in tqdm.tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Get DINO features
        # with torch.no_grad():
        features = encoder(images)["res5"]
        
        # Forward pass through decoder
        outputs = decoder(features)
        
        # Resize outputs to match label size if needed
        if outputs.shape[-2:] != labels.shape[-2:]:
            outputs = torch.nn.functional.interpolate(
                outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate metrics
        preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
        
        # Pixel Accuracy
        valid_mask = labels != 255  # Ignore index
        total_pixels += valid_mask.sum().item()
        correct_pixels += ((preds == labels) & valid_mask).sum().item()
        
        # IoU
        preds = preds.cpu().numpy()
        target = labels.cpu().numpy()
        hist += fast_hist(preds.flatten(), target.flatten(), num_classes)
    
    # Calculate final metrics
    pixel_acc = correct_pixels / total_pixels
    
    # Per-class accuracy (mean class accuracy)
    class_acc = np.diag(hist) / (hist.sum(1) + np.finfo(np.float32).eps)
    mean_class_acc = np.nanmean(class_acc)
    
    # IoU metrics
    iou = per_class_iou(hist)
    mean_iou = np.nanmean(iou)
    
    metrics = {
        'loss': total_loss / len(dataloader),
        'pixel_acc': pixel_acc,
        'mean_class_acc': mean_class_acc,
        'mean_iou': mean_iou,
        'class_iou': iou,
        'class_acc': class_acc
    }
    
    return metrics

train_loader = torch.utils.data.DataLoader(
    CITYSCAPES_dataset, 
    batch_size=4,
    shuffle=True,
    num_workers=4
)
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    metrics = train_epoch(encoder, decoder, train_loader, optimizer, criterion, device)
    
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Loss: {metrics['loss']:.4f}")
    print(f"Pixel Accuracy: {metrics['pixel_acc']:.4f}")
    print(f"Mean Class Accuracy: {metrics['mean_class_acc']:.4f}")
    print(f"Mean IoU: {metrics['mean_iou']:.4f}")
    
    # Optionally print per-class metrics
    print("\nPer-class metrics:")
    for i in range(19):  # Assuming 19 classes
        print(f"Class {i:2d} - Acc: {metrics['class_acc'][i]:.4f}, IoU: {metrics['class_iou'][i]:.4f}")

100%|██████████| 393/393 [02:21<00:00,  2.78it/s]



Epoch 1/10
Loss: 1.1662
Pixel Accuracy: 0.8138
Mean Class Accuracy: 0.2825
Mean IoU: 0.2166

Per-class metrics:
Class  0 - Acc: 0.9126, IoU: 0.8739
Class  1 - Acc: 0.6781, IoU: 0.4225
Class  2 - Acc: 0.8256, IoU: 0.7680
Class  3 - Acc: 0.0460, IoU: 0.0006
Class  4 - Acc: 0.0615, IoU: 0.0360
Class  5 - Acc: 0.0305, IoU: 0.0016
Class  6 - Acc: 0.0019, IoU: 0.0001
Class  7 - Acc: 0.0081, IoU: 0.0037
Class  8 - Acc: 0.7805, IoU: 0.6840
Class  9 - Acc: 0.0116, IoU: 0.0000
Class 10 - Acc: 0.8149, IoU: 0.5956
Class 11 - Acc: 0.4612, IoU: 0.1397
Class 12 - Acc: 0.0014, IoU: 0.0005
Class 13 - Acc: 0.7020, IoU: 0.5830
Class 14 - Acc: 0.0032, IoU: 0.0001
Class 15 - Acc: 0.0021, IoU: 0.0013
Class 16 - Acc: 0.0175, IoU: 0.0023
Class 17 - Acc: 0.0012, IoU: 0.0003
Class 18 - Acc: 0.0073, IoU: 0.0021


100%|██████████| 393/393 [02:20<00:00,  2.80it/s]



Epoch 2/10
Loss: 0.5651
Pixel Accuracy: 0.8816
Mean Class Accuracy: 0.4976
Mean IoU: 0.2904

Per-class metrics:
Class  0 - Acc: 0.9592, IoU: 0.9304
Class  1 - Acc: 0.7370, IoU: 0.6188
Class  2 - Acc: 0.8608, IoU: 0.8168
Class  3 - Acc: 0.0000, IoU: 0.0000
Class  4 - Acc: 0.5119, IoU: 0.2124
Class  5 - Acc: 0.2820, IoU: 0.0016
Class  6 - Acc: 1.0000, IoU: 0.0000
Class  7 - Acc: 0.7279, IoU: 0.0278
Class  8 - Acc: 0.8605, IoU: 0.7992
Class  9 - Acc: 0.0000, IoU: 0.0000
Class 10 - Acc: 0.8706, IoU: 0.8122
Class 11 - Acc: 0.6467, IoU: 0.5119
Class 12 - Acc: 0.1463, IoU: 0.0002
Class 13 - Acc: 0.8343, IoU: 0.7837
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0483, IoU: 0.0000
Class 16 - Acc: 0.5000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.4685, IoU: 0.0030


100%|██████████| 393/393 [02:21<00:00,  2.77it/s]



Epoch 3/10
Loss: 0.4210
Pixel Accuracy: 0.8960
Mean Class Accuracy: 0.5986
Mean IoU: 0.3554

Per-class metrics:
Class  0 - Acc: 0.9645, IoU: 0.9386
Class  1 - Acc: 0.7750, IoU: 0.6584
Class  2 - Acc: 0.8789, IoU: 0.8361
Class  3 - Acc: 0.6822, IoU: 0.0666
Class  4 - Acc: 0.6243, IoU: 0.3690
Class  5 - Acc: 0.5865, IoU: 0.0989
Class  6 - Acc: 0.9640, IoU: 0.0003
Class  7 - Acc: 0.7165, IoU: 0.2929
Class  8 - Acc: 0.8781, IoU: 0.8194
Class  9 - Acc: 0.7898, IoU: 0.1991
Class 10 - Acc: 0.9085, IoU: 0.8502
Class 11 - Acc: 0.6884, IoU: 0.5675
Class 12 - Acc: 0.0370, IoU: 0.0000
Class 13 - Acc: 0.8581, IoU: 0.8112
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.4222, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.6002, IoU: 0.2435


100%|██████████| 393/393 [02:19<00:00,  2.82it/s]



Epoch 4/10
Loss: 0.3505
Pixel Accuracy: 0.9069
Mean Class Accuracy: 0.6120
Mean IoU: 0.4123

Per-class metrics:
Class  0 - Acc: 0.9670, IoU: 0.9431
Class  1 - Acc: 0.8099, IoU: 0.6863
Class  2 - Acc: 0.8968, IoU: 0.8517
Class  3 - Acc: 0.6561, IoU: 0.2631
Class  4 - Acc: 0.6630, IoU: 0.4198
Class  5 - Acc: 0.5851, IoU: 0.2231
Class  6 - Acc: 0.9374, IoU: 0.0008
Class  7 - Acc: 0.7381, IoU: 0.3857
Class  8 - Acc: 0.8985, IoU: 0.8363
Class  9 - Acc: 0.7329, IoU: 0.4603
Class 10 - Acc: 0.9302, IoU: 0.8741
Class 11 - Acc: 0.7221, IoU: 0.6021
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8739, IoU: 0.8282
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.6207, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.5972, IoU: 0.4600


100%|██████████| 393/393 [02:14<00:00,  2.93it/s]



Epoch 5/10
Loss: 0.3049
Pixel Accuracy: 0.9139
Mean Class Accuracy: 0.6498
Mean IoU: 0.4356

Per-class metrics:
Class  0 - Acc: 0.9705, IoU: 0.9483
Class  1 - Acc: 0.8308, IoU: 0.7137
Class  2 - Acc: 0.9082, IoU: 0.8635
Class  3 - Acc: 0.6716, IoU: 0.3519
Class  4 - Acc: 0.6955, IoU: 0.4612
Class  5 - Acc: 0.5986, IoU: 0.2626
Class  6 - Acc: 0.9005, IoU: 0.0574
Class  7 - Acc: 0.7365, IoU: 0.4327
Class  8 - Acc: 0.9072, IoU: 0.8488
Class  9 - Acc: 0.7428, IoU: 0.5167
Class 10 - Acc: 0.9357, IoU: 0.8820
Class 11 - Acc: 0.7428, IoU: 0.6246
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8743, IoU: 0.8303
Class 14 - Acc: 1.0000, IoU: 0.0000
Class 15 - Acc: 0.2370, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.5945, IoU: 0.4819


100%|██████████| 393/393 [02:20<00:00,  2.80it/s]



Epoch 6/10
Loss: 0.2738
Pixel Accuracy: 0.9188
Mean Class Accuracy: 0.6915
Mean IoU: 0.4599

Per-class metrics:
Class  0 - Acc: 0.9720, IoU: 0.9507
Class  1 - Acc: 0.8408, IoU: 0.7243
Class  2 - Acc: 0.9157, IoU: 0.8722
Class  3 - Acc: 0.6937, IoU: 0.4050
Class  4 - Acc: 0.7141, IoU: 0.4955
Class  5 - Acc: 0.6254, IoU: 0.2974
Class  6 - Acc: 0.7853, IoU: 0.2506
Class  7 - Acc: 0.7617, IoU: 0.4780
Class  8 - Acc: 0.9142, IoU: 0.8586
Class  9 - Acc: 0.7475, IoU: 0.5401
Class 10 - Acc: 0.9407, IoU: 0.8910
Class 11 - Acc: 0.7539, IoU: 0.6411
Class 12 - Acc: 0.7475, IoU: 0.0004
Class 13 - Acc: 0.8794, IoU: 0.8383
Class 14 - Acc: 0.7057, IoU: 0.0028
Class 15 - Acc: 0.5432, IoU: 0.0004
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.5969, IoU: 0.4914


100%|██████████| 393/393 [02:20<00:00,  2.81it/s]



Epoch 7/10
Loss: 0.2501
Pixel Accuracy: 0.9231
Mean Class Accuracy: 0.7839
Mean IoU: 0.4889

Per-class metrics:
Class  0 - Acc: 0.9732, IoU: 0.9529
Class  1 - Acc: 0.8523, IoU: 0.7382
Class  2 - Acc: 0.9199, IoU: 0.8778
Class  3 - Acc: 0.7037, IoU: 0.4368
Class  4 - Acc: 0.7293, IoU: 0.5157
Class  5 - Acc: 0.6461, IoU: 0.3227
Class  6 - Acc: 0.7421, IoU: 0.3254
Class  7 - Acc: 0.7762, IoU: 0.5007
Class  8 - Acc: 0.9195, IoU: 0.8659
Class  9 - Acc: 0.7566, IoU: 0.5598
Class 10 - Acc: 0.9416, IoU: 0.8931
Class 11 - Acc: 0.7689, IoU: 0.6583
Class 12 - Acc: 0.7301, IoU: 0.0495
Class 13 - Acc: 0.8960, IoU: 0.8560
Class 14 - Acc: 0.5891, IoU: 0.1150
Class 15 - Acc: 0.5989, IoU: 0.1243
Class 16 - Acc: 0.7992, IoU: 0.0043
Class 17 - Acc: 0.9621, IoU: 0.0001
Class 18 - Acc: 0.5884, IoU: 0.4930


100%|██████████| 393/393 [02:20<00:00,  2.80it/s]



Epoch 8/10
Loss: 0.2322
Pixel Accuracy: 0.9277
Mean Class Accuracy: 0.7966
Mean IoU: 0.5329

Per-class metrics:
Class  0 - Acc: 0.9750, IoU: 0.9556
Class  1 - Acc: 0.8591, IoU: 0.7504
Class  2 - Acc: 0.9254, IoU: 0.8848
Class  3 - Acc: 0.7436, IoU: 0.4877
Class  4 - Acc: 0.7582, IoU: 0.5637
Class  5 - Acc: 0.6597, IoU: 0.3410
Class  6 - Acc: 0.7186, IoU: 0.3561
Class  7 - Acc: 0.7838, IoU: 0.5196
Class  8 - Acc: 0.9216, IoU: 0.8691
Class  9 - Acc: 0.7768, IoU: 0.5870
Class 10 - Acc: 0.9446, IoU: 0.8974
Class 11 - Acc: 0.7795, IoU: 0.6685
Class 12 - Acc: 0.7390, IoU: 0.1863
Class 13 - Acc: 0.9096, IoU: 0.8685
Class 14 - Acc: 0.6682, IoU: 0.1868
Class 15 - Acc: 0.6368, IoU: 0.2739
Class 16 - Acc: 0.8174, IoU: 0.1717
Class 17 - Acc: 0.9026, IoU: 0.0455
Class 18 - Acc: 0.6159, IoU: 0.5121


100%|██████████| 393/393 [02:19<00:00,  2.81it/s]



Epoch 9/10
Loss: 0.2147
Pixel Accuracy: 0.9327
Mean Class Accuracy: 0.7997
Mean IoU: 0.5857

Per-class metrics:
Class  0 - Acc: 0.9769, IoU: 0.9588
Class  1 - Acc: 0.8701, IoU: 0.7666
Class  2 - Acc: 0.9299, IoU: 0.8892
Class  3 - Acc: 0.7605, IoU: 0.5222
Class  4 - Acc: 0.7869, IoU: 0.5979
Class  5 - Acc: 0.6726, IoU: 0.3623
Class  6 - Acc: 0.7266, IoU: 0.3846
Class  7 - Acc: 0.7954, IoU: 0.5436
Class  8 - Acc: 0.9254, IoU: 0.8750
Class  9 - Acc: 0.8002, IoU: 0.6197
Class 10 - Acc: 0.9459, IoU: 0.9000
Class 11 - Acc: 0.7977, IoU: 0.6855
Class 12 - Acc: 0.7064, IoU: 0.3347
Class 13 - Acc: 0.9270, IoU: 0.8860
Class 14 - Acc: 0.6848, IoU: 0.3512
Class 15 - Acc: 0.5764, IoU: 0.3170
Class 16 - Acc: 0.7761, IoU: 0.3217
Class 17 - Acc: 0.8636, IoU: 0.2589
Class 18 - Acc: 0.6728, IoU: 0.5532


100%|██████████| 393/393 [02:20<00:00,  2.80it/s]


Epoch 10/10
Loss: 0.1991
Pixel Accuracy: 0.9366
Mean Class Accuracy: 0.8118
Mean IoU: 0.6237

Per-class metrics:
Class  0 - Acc: 0.9780, IoU: 0.9601
Class  1 - Acc: 0.8750, IoU: 0.7763
Class  2 - Acc: 0.9349, IoU: 0.8955
Class  3 - Acc: 0.7752, IoU: 0.5582
Class  4 - Acc: 0.7906, IoU: 0.6049
Class  5 - Acc: 0.6852, IoU: 0.3806
Class  6 - Acc: 0.7266, IoU: 0.4059
Class  7 - Acc: 0.7930, IoU: 0.5532
Class  8 - Acc: 0.9298, IoU: 0.8823
Class  9 - Acc: 0.8174, IoU: 0.6443
Class 10 - Acc: 0.9486, IoU: 0.9030
Class 11 - Acc: 0.8080, IoU: 0.6962
Class 12 - Acc: 0.6959, IoU: 0.3880
Class 13 - Acc: 0.9354, IoU: 0.8954
Class 14 - Acc: 0.7078, IoU: 0.4313
Class 15 - Acc: 0.6869, IoU: 0.4285
Class 16 - Acc: 0.7975, IoU: 0.4008
Class 17 - Acc: 0.8096, IoU: 0.4573
Class 18 - Acc: 0.7282, IoU: 0.5883





In [16]:
Epoch 10/10
Loss: 0.1003
Pixel Accuracy: 0.9679
Mean Class Accuracy: 0.8727
Mean IoU: 0.7644

Per-class metrics:
Class  0 - Acc: 0.9918, IoU: 0.9840
Class  1 - Acc: 0.9472, IoU: 0.9003
Class  2 - Acc: 0.9699, IoU: 0.9463
Class  3 - Acc: 0.9133, IoU: 0.8134
Class  4 - Acc: 0.9030, IoU: 0.8098
Class  5 - Acc: 0.8038, IoU: 0.5960
Class  6 - Acc: 0.7877, IoU: 0.5778
Class  7 - Acc: 0.8654, IoU: 0.7161
Class  8 - Acc: 0.9620, IoU: 0.9328
Class  9 - Acc: 0.9049, IoU: 0.8156
Class 10 - Acc: 0.9672, IoU: 0.9387
Class 11 - Acc: 0.8874, IoU: 0.8046
Class 12 - Acc: 0.7999, IoU: 0.6324
Class 13 - Acc: 0.9691, IoU: 0.9456
Class 14 - Acc: 0.9386, IoU: 0.8745
Class 15 - Acc: 0.6084, IoU: 0.2200
Class 16 - Acc: 0.6444, IoU: 0.5633
Class 17 - Acc: 0.8798, IoU: 0.7277
Class 18 - Acc: 0.8375, IoU: 0.7250

SyntaxError: invalid syntax (1867324368.py, line 1)

In [9]:
tuple([1,2,3])

(1, 2, 3)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionDistiller(nn.Module):
    def __init__(self, 
                 student_dim,   # e.g., 256 or 512
                 teacher_dim,   # e.g., 768
                 hidden_dim,    # dimension for cross-attention
                 num_heads=8):
        super().__init__()
        # Linear projections to match dimension
        self.student_proj = nn.Linear(student_dim, hidden_dim, bias=False)
        self.teacher_proj = nn.Linear(teacher_dim, hidden_dim, bias=False)
        
        # Using PyTorch's multi-head attention
        # batch_first=True means input shape is [B, seq_len, dim]
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, 
                                                num_heads=num_heads, 
                                                batch_first=True)
        
        # Optional final projection back to student_dim
        self.proj_back = nn.Linear(hidden_dim, student_dim, bias=False)

    def forward(self, student_map, teacher_map):
        """
        student_map: [B, C_s, H_s, W_s]
        teacher_map: [B, C_t, H_t, W_t]
        returns cross_attended_map: [B, C_s, H_s, W_s] (updated student features)
        """
        B, C_s, H_s, W_s = student_map.shape
        _, C_t, H_t, W_t = teacher_map.shape
        
        # Flatten [B, C_s, H_s, W_s] -> [B, H_s*W_s, C_s], then project
        student_tokens = student_map.permute(0,2,3,1).reshape(B, H_s*W_s, C_s)
        student_tokens = self.student_proj(student_tokens)  # => [B, H_s*W_s, hidden_dim]
        
        # Flatten teacher -> [B, H_t*W_t, C_t], then project
        teacher_tokens = teacher_map.permute(0,2,3,1).reshape(B, H_t*W_t, C_t)
        teacher_tokens = self.teacher_proj(teacher_tokens)  # => [B, H_t*W_t, hidden_dim]
        
        # Cross-Attention: Q=student, K=teacher, V=teacher
        cross_attended, _ = self.cross_attn(query=student_tokens,
                                            key=teacher_tokens,
                                            value=teacher_tokens)
        
        # Project back to student dimension if needed
        cross_attended = self.proj_back(cross_attended)  # [B, H_s*W_s, C_s]
        
        # Reshape back to [B, C_s, H_s, W_s]
        cross_attended_map = cross_attended.view(B, H_s, W_s, C_s).permute(0, 3, 1, 2)
        
        return cross_attended_map


# Example usage:
if __name__ == "__main__":
    B = 2
    student_map = torch.randn(B, 256, 16, 16)  # e.g., [B, C_s, H_s, W_s]
    teacher_map = torch.randn(B, 768, 8, 8)    # [B, C_t, H_t, W_t]
    
    # Suppose we choose hidden_dim=384
    distiller = CrossAttentionDistiller(student_dim=256, teacher_dim=768, hidden_dim=384, num_heads=6)
    updated_student_map = distiller(student_map, teacher_map)
    
    print("updated_student_map shape:", updated_student_map.shape)
    # => [2, 256, 16, 16]
    
    # Distillation loss (example)
    # You'd typically need to align teacher_map shape or do some pooling, 
    # but let's just do a simple MSE with naive upsampling of teacher_map:
    teacher_map_upsampled = F.interpolate(teacher_map, size=(16,16), mode='bilinear')
    loss = F.mse_loss(F.normalize(updated_student_map, dim=1),
                      F.normalize(teacher_map_upsampled, dim=1))
    print("distillation loss:", loss.item())


updated_student_map shape: torch.Size([2, 256, 16, 16])


  loss = F.mse_loss(F.normalize(updated_student_map, dim=1),


RuntimeError: The size of tensor a (256) must match the size of tensor b (768) at non-singleton dimension 1