In [1]:
import os
from datasets_gta5 import GTA5, CityScapes
import albumentations as A
import torch

CITYSCAPES_PATH = '/home/arda/.cache/kagglehub/datasets/ardaerendoru/gtagta/versions/1/Cityscapes/Cityscapes'


device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  check_for_updates()


In [7]:

transform = A.Compose([
    A.Resize(512, 1024),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
CITYSCAPES_dataset = CityScapes(CITYSCAPES_PATH, train_val='train', transform=transform)


def load_student_checkpoint(checkpoint_path: str) -> dict:
    """
    Load and process student model checkpoint.
    
    Args:
        checkpoint_path (str): Path to the checkpoint file
        
    Returns:
        dict: Processed state dict containing only student model weights
    """
    # Load checkpoint
    state_dict = torch.load(checkpoint_path)
    
    # Extract and process student weights
    student_state_dict = {k.replace('student.model.', ''): v 
                         for k, v in state_dict['state_dict'].items() 
                         if k.startswith('student') and not k.startswith('student.feature_matchers')}
    
    return student_state_dict

# Load checkpoint and save state dict
checkpoint_path = '/home/arda/dinov2/distillation/logs/stdc2/distillation/version_144/checkpoints/last.ckpt'
student_state_dict = load_student_checkpoint(checkpoint_path)
# torch.save(student_state_dict, '/home/arda/dinov2/distillation/logs/resnet50/distillation/version_7/checkpoints/student_state_dict.pth')
# student_state_dict.keys()






In [8]:
import sys
sys.path.append('/home/arda/dinov2/distillation')
from models.stdc_wrapper import STDCWrapper
encoder = STDCWrapper()
encoder.load_state_dict(student_state_dict)

<All keys matched successfully>

In [9]:

encoder.eval()
encoder.to(device)
encoder.model.eval()
encoder = encoder.model

In [10]:
asd = torch.randn(1, 3, 512, 1024).to(device)
encoder(asd)["res5"].shape

# # Freeze all parameters of the encoder
for param in encoder.parameters():
    param.requires_grad = False

In [11]:
# First, let's create a simple decoder network
import numpy as np
import tqdm as tqdm
class SegmentationDecoder(torch.nn.Module):
    def __init__(self, in_channels=1024, num_classes=19):
        super().__init__()
        self.decoder = torch.nn.Sequential(
            # 16x32 -> 32x64
            torch.nn.ConvTranspose2d(in_channels, 1024, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            
            # 32x64 -> 64x128
            torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            
            # 64x128 -> 128x256
            torch.nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            
            # 128x256 -> 256x512
            torch.nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            
            # 256x512 -> 512x1024
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            
            # Final 1x1 conv to get to num_classes
            torch.nn.Conv2d(64, num_classes, kernel_size=1)
        )
    def forward(self, x):
        x = self.decoder(x)
        # Ensure exact output size
        if x.shape[-2:] != (512, 1024):
            x = torch.nn.functional.interpolate(
                x, size=(512, 1024), 
                mode='bilinear', 
                align_corners=False
            )
        return x

# Initialize decoder, optimizer, and loss function
decoder = SegmentationDecoder().to(device)
optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

def fast_hist(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    k = (b >= 0) & (b < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)

def per_class_iou(hist: np.ndarray) -> np.ndarray:
    epsilon = 1e-5
    return (np.diag(hist)) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)

def train_epoch(encoder, decoder, dataloader, optimizer, criterion, device, num_classes=19):
    decoder.train()
    encoder.train()  # Keep DINO frozen
    
    total_loss = 0
    hist = np.zeros((num_classes, num_classes))  # Single histogram for entire epoch
    total_pixels = 0
    correct_pixels = 0
    
    for images, labels in tqdm.tqdm(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
        # Get DINO features
        # with torch.no_grad():
        features = encoder(images)["res5"]
        
        # Forward pass through decoder
        outputs = decoder(features)
        
        # Resize outputs to match label size if needed
        if outputs.shape[-2:] != labels.shape[-2:]:
            outputs = torch.nn.functional.interpolate(
                outputs, size=labels.shape[-2:], mode='bilinear', align_corners=False)
        
        # Calculate loss
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        # Calculate metrics
        preds = torch.argmax(torch.softmax(outputs, dim=1), dim=1)
        
        # Pixel Accuracy
        valid_mask = labels != 255  # Ignore index
        total_pixels += valid_mask.sum().item()
        correct_pixels += ((preds == labels) & valid_mask).sum().item()
        
        # IoU
        preds = preds.cpu().numpy()
        target = labels.cpu().numpy()
        hist += fast_hist(preds.flatten(), target.flatten(), num_classes)
    
    # Calculate final metrics
    pixel_acc = correct_pixels / total_pixels
    
    # Per-class accuracy (mean class accuracy)
    class_acc = np.diag(hist) / (hist.sum(1) + np.finfo(np.float32).eps)
    mean_class_acc = np.nanmean(class_acc)
    
    # IoU metrics
    iou = per_class_iou(hist)
    mean_iou = np.nanmean(iou)
    
    metrics = {
        'loss': total_loss / len(dataloader),
        'pixel_acc': pixel_acc,
        'mean_class_acc': mean_class_acc,
        'mean_iou': mean_iou,
        'class_iou': iou,
        'class_acc': class_acc
    }
    
    return metrics

train_loader = torch.utils.data.DataLoader(
    CITYSCAPES_dataset, 
    batch_size=4,
    shuffle=True,
    num_workers=4
)
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    metrics = train_epoch(encoder, decoder, train_loader, optimizer, criterion, device)
    
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Loss: {metrics['loss']:.4f}")
    print(f"Pixel Accuracy: {metrics['pixel_acc']:.4f}")
    print(f"Mean Class Accuracy: {metrics['mean_class_acc']:.4f}")
    print(f"Mean IoU: {metrics['mean_iou']:.4f}")
    
    # Optionally print per-class metrics
    print("\nPer-class metrics:")
    for i in range(19):  # Assuming 19 classes
        print(f"Class {i:2d} - Acc: {metrics['class_acc'][i]:.4f}, IoU: {metrics['class_iou'][i]:.4f}")

100%|██████████| 393/393 [01:59<00:00,  3.29it/s]



Epoch 1/10
Loss: 1.1482
Pixel Accuracy: 0.8062
Mean Class Accuracy: 0.2691
Mean IoU: 0.2060

Per-class metrics:
Class  0 - Acc: 0.9333, IoU: 0.8824
Class  1 - Acc: 0.5396, IoU: 0.4129
Class  2 - Acc: 0.8299, IoU: 0.7430
Class  3 - Acc: 0.0137, IoU: 0.0007
Class  4 - Acc: 0.0202, IoU: 0.0008
Class  5 - Acc: 0.0323, IoU: 0.0016
Class  6 - Acc: 0.0010, IoU: 0.0008
Class  7 - Acc: 0.0051, IoU: 0.0013
Class  8 - Acc: 0.7395, IoU: 0.6839
Class  9 - Acc: 0.3263, IoU: 0.0159
Class 10 - Acc: 0.8010, IoU: 0.5600
Class 11 - Acc: 0.0605, IoU: 0.0205
Class 12 - Acc: 0.0013, IoU: 0.0003
Class 13 - Acc: 0.7810, IoU: 0.5808
Class 14 - Acc: 0.0070, IoU: 0.0001
Class 15 - Acc: 0.0032, IoU: 0.0004
Class 16 - Acc: 0.0001, IoU: 0.0001
Class 17 - Acc: 0.0020, IoU: 0.0006
Class 18 - Acc: 0.0160, IoU: 0.0071


100%|██████████| 393/393 [01:59<00:00,  3.29it/s]



Epoch 2/10
Loss: 0.6058
Pixel Accuracy: 0.8665
Mean Class Accuracy: 0.3939
Mean IoU: 0.2810

Per-class metrics:
Class  0 - Acc: 0.9522, IoU: 0.9181
Class  1 - Acc: 0.7002, IoU: 0.5675
Class  2 - Acc: 0.8410, IoU: 0.7895
Class  3 - Acc: 0.0377, IoU: 0.0001
Class  4 - Acc: 0.5044, IoU: 0.0051
Class  5 - Acc: 0.0446, IoU: 0.0005
Class  6 - Acc: 0.0075, IoU: 0.0016
Class  7 - Acc: 0.1253, IoU: 0.0006
Class  8 - Acc: 0.8372, IoU: 0.7698
Class  9 - Acc: 0.5755, IoU: 0.3032
Class 10 - Acc: 0.8786, IoU: 0.8005
Class 11 - Acc: 0.6337, IoU: 0.3313
Class 12 - Acc: 0.0037, IoU: 0.0000
Class 13 - Acc: 0.8140, IoU: 0.7533
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0224, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0005, IoU: 0.0000
Class 18 - Acc: 0.5051, IoU: 0.0972


100%|██████████| 393/393 [01:59<00:00,  3.30it/s]



Epoch 3/10
Loss: 0.4690
Pixel Accuracy: 0.8816
Mean Class Accuracy: 0.4825
Mean IoU: 0.3277

Per-class metrics:
Class  0 - Acc: 0.9584, IoU: 0.9279
Class  1 - Acc: 0.7488, IoU: 0.6120
Class  2 - Acc: 0.8497, IoU: 0.8035
Class  3 - Acc: 0.3028, IoU: 0.0004
Class  4 - Acc: 0.5300, IoU: 0.1452
Class  5 - Acc: 0.3054, IoU: 0.0023
Class  6 - Acc: 0.0035, IoU: 0.0000
Class  7 - Acc: 0.7309, IoU: 0.0163
Class  8 - Acc: 0.8591, IoU: 0.7927
Class  9 - Acc: 0.6554, IoU: 0.4041
Class 10 - Acc: 0.8948, IoU: 0.8207
Class 11 - Acc: 0.6690, IoU: 0.5141
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8717, IoU: 0.8115
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.2222, IoU: 0.0000
Class 17 - Acc: 0.0024, IoU: 0.0000
Class 18 - Acc: 0.5640, IoU: 0.3758


100%|██████████| 393/393 [01:45<00:00,  3.72it/s]



Epoch 4/10
Loss: 0.4043
Pixel Accuracy: 0.8901
Mean Class Accuracy: 0.4959
Mean IoU: 0.3653

Per-class metrics:
Class  0 - Acc: 0.9606, IoU: 0.9322
Class  1 - Acc: 0.7748, IoU: 0.6348
Class  2 - Acc: 0.8664, IoU: 0.8186
Class  3 - Acc: 0.5150, IoU: 0.0311
Class  4 - Acc: 0.5592, IoU: 0.3023
Class  5 - Acc: 0.4451, IoU: 0.0710
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.6554, IoU: 0.2719
Class  8 - Acc: 0.8781, IoU: 0.8118
Class  9 - Acc: 0.6619, IoU: 0.4402
Class 10 - Acc: 0.9018, IoU: 0.8295
Class 11 - Acc: 0.6815, IoU: 0.5372
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8766, IoU: 0.8180
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0185, IoU: 0.0000
Class 18 - Acc: 0.6270, IoU: 0.4418


100%|██████████| 393/393 [01:57<00:00,  3.35it/s]



Epoch 5/10
Loss: 0.3639
Pixel Accuracy: 0.8961
Mean Class Accuracy: 0.6162
Mean IoU: 0.3890

Per-class metrics:
Class  0 - Acc: 0.9629, IoU: 0.9357
Class  1 - Acc: 0.7904, IoU: 0.6530
Class  2 - Acc: 0.8763, IoU: 0.8280
Class  3 - Acc: 0.6182, IoU: 0.1874
Class  4 - Acc: 0.6128, IoU: 0.3525
Class  5 - Acc: 0.4628, IoU: 0.1155
Class  6 - Acc: 0.0513, IoU: 0.0000
Class  7 - Acc: 0.6357, IoU: 0.3427
Class  8 - Acc: 0.8882, IoU: 0.8219
Class  9 - Acc: 0.6847, IoU: 0.4678
Class 10 - Acc: 0.9107, IoU: 0.8420
Class 11 - Acc: 0.6927, IoU: 0.5565
Class 12 - Acc: 1.0000, IoU: 0.0000
Class 13 - Acc: 0.8826, IoU: 0.8245
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0833, IoU: 0.0000
Class 17 - Acc: 0.9429, IoU: 0.0001
Class 18 - Acc: 0.6122, IoU: 0.4624


100%|██████████| 393/393 [01:59<00:00,  3.30it/s]



Epoch 6/10
Loss: 0.3351
Pixel Accuracy: 0.9005
Mean Class Accuracy: 0.6994
Mean IoU: 0.4035

Per-class metrics:
Class  0 - Acc: 0.9652, IoU: 0.9394
Class  1 - Acc: 0.8057, IoU: 0.6701
Class  2 - Acc: 0.8833, IoU: 0.8338
Class  3 - Acc: 0.6529, IoU: 0.2839
Class  4 - Acc: 0.6328, IoU: 0.3817
Class  5 - Acc: 0.4969, IoU: 0.1440
Class  6 - Acc: 0.0664, IoU: 0.0000
Class  7 - Acc: 0.6380, IoU: 0.3604
Class  8 - Acc: 0.8919, IoU: 0.8281
Class  9 - Acc: 0.7205, IoU: 0.5114
Class 10 - Acc: 0.9103, IoU: 0.8389
Class 11 - Acc: 0.7037, IoU: 0.5724
Class 12 - Acc: 0.7778, IoU: 0.0000
Class 13 - Acc: 0.8860, IoU: 0.8317
Class 14 - Acc: 0.3256, IoU: 0.0000
Class 15 - Acc: 1.0000, IoU: 0.0000
Class 16 - Acc: 0.5882, IoU: 0.0000
Class 17 - Acc: 0.7414, IoU: 0.0003
Class 18 - Acc: 0.6028, IoU: 0.4708


100%|██████████| 393/393 [01:57<00:00,  3.34it/s]



Epoch 7/10
Loss: 0.3115
Pixel Accuracy: 0.9047
Mean Class Accuracy: 0.7111
Mean IoU: 0.4134

Per-class metrics:
Class  0 - Acc: 0.9673, IoU: 0.9420
Class  1 - Acc: 0.8132, IoU: 0.6844
Class  2 - Acc: 0.8899, IoU: 0.8423
Class  3 - Acc: 0.6550, IoU: 0.3242
Class  4 - Acc: 0.6714, IoU: 0.4206
Class  5 - Acc: 0.5131, IoU: 0.1608
Class  6 - Acc: 0.6680, IoU: 0.0004
Class  7 - Acc: 0.6484, IoU: 0.3849
Class  8 - Acc: 0.8973, IoU: 0.8349
Class  9 - Acc: 0.7201, IoU: 0.5144
Class 10 - Acc: 0.9163, IoU: 0.8491
Class 11 - Acc: 0.7160, IoU: 0.5857
Class 12 - Acc: 0.8618, IoU: 0.0041
Class 13 - Acc: 0.8897, IoU: 0.8367
Class 14 - Acc: 0.2112, IoU: 0.0018
Class 15 - Acc: 1.0000, IoU: 0.0000
Class 16 - Acc: 0.1275, IoU: 0.0001
Class 17 - Acc: 0.7679, IoU: 0.0011
Class 18 - Acc: 0.5768, IoU: 0.4678


100%|██████████| 393/393 [01:58<00:00,  3.31it/s]



Epoch 8/10
Loss: 0.2939
Pixel Accuracy: 0.9083
Mean Class Accuracy: 0.6959
Mean IoU: 0.4294

Per-class metrics:
Class  0 - Acc: 0.9696, IoU: 0.9455
Class  1 - Acc: 0.8235, IoU: 0.6998
Class  2 - Acc: 0.8925, IoU: 0.8463
Class  3 - Acc: 0.6674, IoU: 0.3440
Class  4 - Acc: 0.6844, IoU: 0.4330
Class  5 - Acc: 0.5291, IoU: 0.1729
Class  6 - Acc: 0.5759, IoU: 0.0004
Class  7 - Acc: 0.6614, IoU: 0.3928
Class  8 - Acc: 0.9001, IoU: 0.8389
Class  9 - Acc: 0.7395, IoU: 0.5398
Class 10 - Acc: 0.9193, IoU: 0.8528
Class 11 - Acc: 0.7253, IoU: 0.5947
Class 12 - Acc: 0.8087, IoU: 0.0945
Class 13 - Acc: 0.8963, IoU: 0.8467
Class 14 - Acc: 0.5936, IoU: 0.0561
Class 15 - Acc: 0.1929, IoU: 0.0000
Class 16 - Acc: 0.2280, IoU: 0.0001
Class 17 - Acc: 0.8335, IoU: 0.0245
Class 18 - Acc: 0.5816, IoU: 0.4747


100%|██████████| 393/393 [01:45<00:00,  3.72it/s]



Epoch 9/10
Loss: 0.2783
Pixel Accuracy: 0.9116
Mean Class Accuracy: 0.7617
Mean IoU: 0.4674

Per-class metrics:
Class  0 - Acc: 0.9698, IoU: 0.9467
Class  1 - Acc: 0.8335, IoU: 0.7086
Class  2 - Acc: 0.8990, IoU: 0.8512
Class  3 - Acc: 0.6977, IoU: 0.3961
Class  4 - Acc: 0.7022, IoU: 0.4635
Class  5 - Acc: 0.5377, IoU: 0.1842
Class  6 - Acc: 0.8475, IoU: 0.0076
Class  7 - Acc: 0.6731, IoU: 0.4108
Class  8 - Acc: 0.9009, IoU: 0.8420
Class  9 - Acc: 0.7524, IoU: 0.5590
Class 10 - Acc: 0.9183, IoU: 0.8514
Class 11 - Acc: 0.7407, IoU: 0.6085
Class 12 - Acc: 0.7053, IoU: 0.2626
Class 13 - Acc: 0.9048, IoU: 0.8564
Class 14 - Acc: 0.4975, IoU: 0.1844
Class 15 - Acc: 0.8755, IoU: 0.0038
Class 16 - Acc: 0.5576, IoU: 0.0214
Class 17 - Acc: 0.8201, IoU: 0.2078
Class 18 - Acc: 0.6383, IoU: 0.5137


100%|██████████| 393/393 [01:57<00:00,  3.35it/s]


Epoch 10/10
Loss: 0.2610
Pixel Accuracy: 0.9165
Mean Class Accuracy: 0.7726
Mean IoU: 0.5073

Per-class metrics:
Class  0 - Acc: 0.9718, IoU: 0.9496
Class  1 - Acc: 0.8389, IoU: 0.7203
Class  2 - Acc: 0.9047, IoU: 0.8594
Class  3 - Acc: 0.7106, IoU: 0.4210
Class  4 - Acc: 0.7314, IoU: 0.4982
Class  5 - Acc: 0.5576, IoU: 0.2008
Class  6 - Acc: 0.7993, IoU: 0.0609
Class  7 - Acc: 0.6991, IoU: 0.4338
Class  8 - Acc: 0.9072, IoU: 0.8496
Class  9 - Acc: 0.7633, IoU: 0.5684
Class 10 - Acc: 0.9253, IoU: 0.8627
Class 11 - Acc: 0.7552, IoU: 0.6243
Class 12 - Acc: 0.7016, IoU: 0.3323
Class 13 - Acc: 0.9169, IoU: 0.8678
Class 14 - Acc: 0.5265, IoU: 0.2987
Class 15 - Acc: 0.9206, IoU: 0.0847
Class 16 - Acc: 0.5714, IoU: 0.1052
Class 17 - Acc: 0.8039, IoU: 0.3658
Class 18 - Acc: 0.6744, IoU: 0.5353





In [14]:
Epoch 1/10
Loss: 1.0804
Pixel Accuracy: 0.8199
Mean Class Accuracy: 0.2722
Mean IoU: 0.2217

Per-class metrics:
Class  0 - Acc: 0.9248, IoU: 0.8810
Class  1 - Acc: 0.5940, IoU: 0.4342
Class  2 - Acc: 0.8282, IoU: 0.7515
Class  3 - Acc: 0.0153, IoU: 0.0037
Class  4 - Acc: 0.1686, IoU: 0.0064
Class  5 - Acc: 0.0268, IoU: 0.0022
Class  6 - Acc: 0.0011, IoU: 0.0002
Class  7 - Acc: 0.0038, IoU: 0.0000
Class  8 - Acc: 0.8135, IoU: 0.7149
Class  9 - Acc: 0.0188, IoU: 0.0061
Class 10 - Acc: 0.7689, IoU: 0.6534
Class 11 - Acc: 0.2046, IoU: 0.0687
Class 12 - Acc: 0.0019, IoU: 0.0009
Class 13 - Acc: 0.7896, IoU: 0.6825
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0004, IoU: 0.0000
Class 16 - Acc: 0.0039, IoU: 0.0028
Class 17 - Acc: 0.0014, IoU: 0.0007
Class 18 - Acc: 0.0059, IoU: 0.0032
100%|██████████| 393/393 [01:59<00:00,  3.29it/s]

Epoch 2/10
Loss: 0.5921
Pixel Accuracy: 0.8667
Mean Class Accuracy: 0.4243
Mean IoU: 0.2693

Per-class metrics:
Class  0 - Acc: 0.9499, IoU: 0.9164
Class  1 - Acc: 0.6833, IoU: 0.5565
Class  2 - Acc: 0.8429, IoU: 0.7943
Class  3 - Acc: 0.1755, IoU: 0.0008
Class  4 - Acc: 0.3265, IoU: 0.0441
Class  5 - Acc: 0.0000, IoU: 0.0000
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.8625, IoU: 0.0008
Class  8 - Acc: 0.8498, IoU: 0.7810
Class  9 - Acc: 0.6544, IoU: 0.1346
Class 10 - Acc: 0.8494, IoU: 0.7781
Class 11 - Acc: 0.5073, IoU: 0.3340
Class 12 - Acc: 0.0000, IoU: 0.0000
Class 13 - Acc: 0.8419, IoU: 0.7770
Class 14 - Acc: 0.0000, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.5185, IoU: 0.0000
Class 18 - Acc: 0.0000, IoU: 0.0000
100%|██████████| 393/393 [01:56<00:00,  3.37it/s]

Epoch 3/10
Loss: 0.4732
Pixel Accuracy: 0.8788
Mean Class Accuracy: 0.5255
Mean IoU: 0.3166

Per-class metrics:
Class  0 - Acc: 0.9529, IoU: 0.9216
Class  1 - Acc: 0.7350, IoU: 0.5924
Class  2 - Acc: 0.8512, IoU: 0.8028
Class  3 - Acc: 0.3163, IoU: 0.0065
Class  4 - Acc: 0.4898, IoU: 0.2711
Class  5 - Acc: 0.6530, IoU: 0.0001
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.7375, IoU: 0.1374
Class  8 - Acc: 0.8674, IoU: 0.7979
Class  9 - Acc: 0.6460, IoU: 0.3652
Class 10 - Acc: 0.8977, IoU: 0.8232
Class 11 - Acc: 0.6725, IoU: 0.4984
Class 12 - Acc: 0.0722, IoU: 0.0000
Class 13 - Acc: 0.8550, IoU: 0.7957
Class 14 - Acc: 0.4718, IoU: 0.0002
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0000, IoU: 0.0000
Class 18 - Acc: 0.7669, IoU: 0.0037
100%|██████████| 393/393 [01:59<00:00,  3.29it/s]

Epoch 4/10
Loss: 0.4095
Pixel Accuracy: 0.8876
Mean Class Accuracy: 0.5370
Mean IoU: 0.3550

Per-class metrics:
Class  0 - Acc: 0.9565, IoU: 0.9275
Class  1 - Acc: 0.7604, IoU: 0.6160
Class  2 - Acc: 0.8707, IoU: 0.8215
Class  3 - Acc: 0.6204, IoU: 0.0553
Class  4 - Acc: 0.5190, IoU: 0.3313
Class  5 - Acc: 0.5279, IoU: 0.0226
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.6574, IoU: 0.3039
Class  8 - Acc: 0.8781, IoU: 0.8118
Class  9 - Acc: 0.6756, IoU: 0.4226
Class 10 - Acc: 0.9052, IoU: 0.8327
Class 11 - Acc: 0.6879, IoU: 0.5296
Class 12 - Acc: 0.4548, IoU: 0.0003
Class 13 - Acc: 0.8530, IoU: 0.7984
Class 14 - Acc: 0.0374, IoU: 0.0000
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.1667, IoU: 0.0000
Class 18 - Acc: 0.6328, IoU: 0.2716
100%|██████████| 393/393 [02:00<00:00,  3.27it/s]

Epoch 5/10
Loss: 0.3688
Pixel Accuracy: 0.8946
Mean Class Accuracy: 0.5375
Mean IoU: 0.3832

Per-class metrics:
Class  0 - Acc: 0.9593, IoU: 0.9311
Class  1 - Acc: 0.7775, IoU: 0.6356
Class  2 - Acc: 0.8833, IoU: 0.8345
Class  3 - Acc: 0.6123, IoU: 0.2143
Class  4 - Acc: 0.5975, IoU: 0.3719
Class  5 - Acc: 0.4950, IoU: 0.0902
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.6794, IoU: 0.3441
Class  8 - Acc: 0.8865, IoU: 0.8228
Class  9 - Acc: 0.6978, IoU: 0.4548
Class 10 - Acc: 0.9104, IoU: 0.8398
Class 11 - Acc: 0.6921, IoU: 0.5434
Class 12 - Acc: 0.5197, IoU: 0.0134
Class 13 - Acc: 0.8658, IoU: 0.8111
Class 14 - Acc: 0.0234, IoU: 0.0003
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.0625, IoU: 0.0000
Class 18 - Acc: 0.5505, IoU: 0.3739
100%|██████████| 393/393 [01:47<00:00,  3.66it/s]

Epoch 6/10
Loss: 0.3383
Pixel Accuracy: 0.8999
Mean Class Accuracy: 0.5900
Mean IoU: 0.4068

Per-class metrics:
Class  0 - Acc: 0.9621, IoU: 0.9353
Class  1 - Acc: 0.7931, IoU: 0.6533
Class  2 - Acc: 0.8925, IoU: 0.8431
Class  3 - Acc: 0.6518, IoU: 0.2993
Class  4 - Acc: 0.6223, IoU: 0.3914
Class  5 - Acc: 0.4808, IoU: 0.1278
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.6993, IoU: 0.3745
Class  8 - Acc: 0.8920, IoU: 0.8302
Class  9 - Acc: 0.7077, IoU: 0.4828
Class 10 - Acc: 0.9139, IoU: 0.8461
Class 11 - Acc: 0.7020, IoU: 0.5560
Class 12 - Acc: 0.5079, IoU: 0.1365
Class 13 - Acc: 0.8770, IoU: 0.8254
Class 14 - Acc: 0.3341, IoU: 0.0313
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.6476, IoU: 0.0009
Class 18 - Acc: 0.5267, IoU: 0.3943
100%|██████████| 393/393 [01:57<00:00,  3.34it/s]

Epoch 7/10
Loss: 0.3132
Pixel Accuracy: 0.9046
Mean Class Accuracy: 0.6116
Mean IoU: 0.4342

Per-class metrics:
Class  0 - Acc: 0.9645, IoU: 0.9384
Class  1 - Acc: 0.8041, IoU: 0.6683
Class  2 - Acc: 0.8976, IoU: 0.8499
Class  3 - Acc: 0.6763, IoU: 0.3529
Class  4 - Acc: 0.6583, IoU: 0.4310
Class  5 - Acc: 0.4748, IoU: 0.1467
Class  6 - Acc: 0.0000, IoU: 0.0000
Class  7 - Acc: 0.7080, IoU: 0.3954
Class  8 - Acc: 0.8977, IoU: 0.8366
Class  9 - Acc: 0.7146, IoU: 0.5002
Class 10 - Acc: 0.9180, IoU: 0.8529
Class 11 - Acc: 0.7111, IoU: 0.5655
Class 12 - Acc: 0.5126, IoU: 0.2226
Class 13 - Acc: 0.8906, IoU: 0.8390
Class 14 - Acc: 0.3960, IoU: 0.1433
Class 15 - Acc: 0.0000, IoU: 0.0000
Class 16 - Acc: 0.0000, IoU: 0.0000
Class 17 - Acc: 0.8021, IoU: 0.0645
Class 18 - Acc: 0.5951, IoU: 0.4418
100%|██████████| 393/393 [01:59<00:00,  3.30it/s]

Epoch 8/10
Loss: 0.2945
Pixel Accuracy: 0.9085
Mean Class Accuracy: 0.7302
Mean IoU: 0.4629

Per-class metrics:
Class  0 - Acc: 0.9667, IoU: 0.9420
Class  1 - Acc: 0.8166, IoU: 0.6851
Class  2 - Acc: 0.9008, IoU: 0.8529
Class  3 - Acc: 0.6885, IoU: 0.3750
Class  4 - Acc: 0.6791, IoU: 0.4547
Class  5 - Acc: 0.4792, IoU: 0.1614
Class  6 - Acc: 0.7800, IoU: 0.0077
Class  7 - Acc: 0.7181, IoU: 0.4115
Class  8 - Acc: 0.9007, IoU: 0.8416
Class  9 - Acc: 0.7311, IoU: 0.5272
Class 10 - Acc: 0.9191, IoU: 0.8550
Class 11 - Acc: 0.7238, IoU: 0.5808
Class 12 - Acc: 0.5949, IoU: 0.3069
Class 13 - Acc: 0.9001, IoU: 0.8475
Class 14 - Acc: 0.4454, IoU: 0.2258
Class 15 - Acc: 0.2458, IoU: 0.0001
Class 16 - Acc: 0.9798, IoU: 0.0001
Class 17 - Acc: 0.7477, IoU: 0.2295
Class 18 - Acc: 0.6556, IoU: 0.4907
100%|██████████| 393/393 [01:59<00:00,  3.29it/s]

Epoch 9/10
Loss: 0.2726
Pixel Accuracy: 0.9135
Mean Class Accuracy: 0.7577
Mean IoU: 0.4908

Per-class metrics:
Class  0 - Acc: 0.9696, IoU: 0.9462
Class  1 - Acc: 0.8290, IoU: 0.7039
Class  2 - Acc: 0.9067, IoU: 0.8607
Class  3 - Acc: 0.7221, IoU: 0.4184
Class  4 - Acc: 0.7113, IoU: 0.4963
Class  5 - Acc: 0.5224, IoU: 0.1762
Class  6 - Acc: 0.7757, IoU: 0.1308
Class  7 - Acc: 0.7232, IoU: 0.4289
Class  8 - Acc: 0.9048, IoU: 0.8460
Class  9 - Acc: 0.7436, IoU: 0.5466
Class 10 - Acc: 0.9203, IoU: 0.8577
Class 11 - Acc: 0.7327, IoU: 0.5933
Class 12 - Acc: 0.6154, IoU: 0.3383
Class 13 - Acc: 0.9101, IoU: 0.8596
Class 14 - Acc: 0.4096, IoU: 0.2638
Class 15 - Acc: 0.6476, IoU: 0.0117
Class 16 - Acc: 0.9485, IoU: 0.0138
Class 17 - Acc: 0.7313, IoU: 0.3245
Class 18 - Acc: 0.6727, IoU: 0.5081
100%|██████████| 393/393 [02:00<00:00,  3.25it/s]
Epoch 10/10
Loss: 0.2589
Pixel Accuracy: 0.9167
Mean Class Accuracy: 0.7606
Mean IoU: 0.5258

Per-class metrics:
Class  0 - Acc: 0.9703, IoU: 0.9476
Class  1 - Acc: 0.8346, IoU: 0.7116
Class  2 - Acc: 0.9116, IoU: 0.8657
Class  3 - Acc: 0.7337, IoU: 0.4497
Class  4 - Acc: 0.7450, IoU: 0.5394
Class  5 - Acc: 0.5560, IoU: 0.1903
Class  6 - Acc: 0.7002, IoU: 0.2374
Class  7 - Acc: 0.7371, IoU: 0.4402
Class  8 - Acc: 0.9063, IoU: 0.8491
Class  9 - Acc: 0.7590, IoU: 0.5671
Class 10 - Acc: 0.9229, IoU: 0.8600
Class 11 - Acc: 0.7403, IoU: 0.6005
Class 12 - Acc: 0.6436, IoU: 0.3547
Class 13 - Acc: 0.9147, IoU: 0.8628
Class 14 - Acc: 0.4679, IoU: 0.3213
Class 15 - Acc: 0.6989, IoU: 0.1052
Class 16 - Acc: 0.7634, IoU: 0.1535
Class 17 - Acc: 0.7544, IoU: 0.4113
Class 18 - Acc: 0.6910, IoU: 0.5232

SyntaxError: invalid syntax (1867324368.py, line 1)

In [9]:
tuple([1,2,3])

(1, 2, 3)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionDistiller(nn.Module):
    def __init__(self, 
                 student_dim,   # e.g., 256 or 512
                 teacher_dim,   # e.g., 768
                 hidden_dim,    # dimension for cross-attention
                 num_heads=8):
        super().__init__()
        # Linear projections to match dimension
        self.student_proj = nn.Linear(student_dim, hidden_dim, bias=False)
        self.teacher_proj = nn.Linear(teacher_dim, hidden_dim, bias=False)
        
        # Using PyTorch's multi-head attention
        # batch_first=True means input shape is [B, seq_len, dim]
        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, 
                                                num_heads=num_heads, 
                                                batch_first=True)
        
        # Optional final projection back to student_dim
        self.proj_back = nn.Linear(hidden_dim, student_dim, bias=False)

    def forward(self, student_map, teacher_map):
        """
        student_map: [B, C_s, H_s, W_s]
        teacher_map: [B, C_t, H_t, W_t]
        returns cross_attended_map: [B, C_s, H_s, W_s] (updated student features)
        """
        B, C_s, H_s, W_s = student_map.shape
        _, C_t, H_t, W_t = teacher_map.shape
        
        # Flatten [B, C_s, H_s, W_s] -> [B, H_s*W_s, C_s], then project
        student_tokens = student_map.permute(0,2,3,1).reshape(B, H_s*W_s, C_s)
        student_tokens = self.student_proj(student_tokens)  # => [B, H_s*W_s, hidden_dim]
        
        # Flatten teacher -> [B, H_t*W_t, C_t], then project
        teacher_tokens = teacher_map.permute(0,2,3,1).reshape(B, H_t*W_t, C_t)
        teacher_tokens = self.teacher_proj(teacher_tokens)  # => [B, H_t*W_t, hidden_dim]
        
        # Cross-Attention: Q=student, K=teacher, V=teacher
        cross_attended, _ = self.cross_attn(query=student_tokens,
                                            key=teacher_tokens,
                                            value=teacher_tokens)
        
        # Project back to student dimension if needed
        cross_attended = self.proj_back(cross_attended)  # [B, H_s*W_s, C_s]
        
        # Reshape back to [B, C_s, H_s, W_s]
        cross_attended_map = cross_attended.view(B, H_s, W_s, C_s).permute(0, 3, 1, 2)
        
        return cross_attended_map


# Example usage:
if __name__ == "__main__":
    B = 2
    student_map = torch.randn(B, 256, 16, 16)  # e.g., [B, C_s, H_s, W_s]
    teacher_map = torch.randn(B, 768, 8, 8)    # [B, C_t, H_t, W_t]
    
    # Suppose we choose hidden_dim=384
    distiller = CrossAttentionDistiller(student_dim=256, teacher_dim=768, hidden_dim=384, num_heads=6)
    updated_student_map = distiller(student_map, teacher_map)
    
    print("updated_student_map shape:", updated_student_map.shape)
    # => [2, 256, 16, 16]
    
    # Distillation loss (example)
    # You'd typically need to align teacher_map shape or do some pooling, 
    # but let's just do a simple MSE with naive upsampling of teacher_map:
    teacher_map_upsampled = F.interpolate(teacher_map, size=(16,16), mode='bilinear')
    loss = F.mse_loss(F.normalize(updated_student_map, dim=1),
                      F.normalize(teacher_map_upsampled, dim=1))
    print("distillation loss:", loss.item())


updated_student_map shape: torch.Size([2, 256, 16, 16])


  loss = F.mse_loss(F.normalize(updated_student_map, dim=1),


RuntimeError: The size of tensor a (256) must match the size of tensor b (768) at non-singleton dimension 1