In [2]:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler, autocast        
from torch.optim.lr_scheduler import CosineAnnealingLR  
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from yacs.config import CfgNode as CN

from data import make_dataset, transform_map
from eval import evaluate
from model import UModel
from torch_utils import ema_update, freeze_layers
from utils import setup, clean_exp_savedir


In [3]:
# Build the experiment configuration from config.yaml. `new_allowed=True` is
# passed so the (initially empty) node can take keys from the file.
cfg = CN(new_allowed=True)
cfg.merge_from_file("config.yaml")

# Create the four data loaders (source/target x train/test) from the config.
# The exact batch layout is defined in the project's `data` module, which is
# not visible here.
source_train_loader, target_train_loader, source_test_loader, target_test_loader = (
    make_dataset(
        source_dataset=cfg.dataset.source,
        target_dataset=cfg.dataset.target,
        imgsize=cfg.img_size,
        train_bs=cfg.domain_adapt.train_bs,
        eval_bs=cfg.domain_adapt.eval_bs,
        num_workers=cfg.domain_adapt.num_workers,
    )
)


# Instantiate the model with the backbone settings from the config; the output
# dimension is the number of dataset classes.
model = UModel(
    backbone=cfg.model.backbone.type,
    hidden_dim=cfg.model.backbone.hidden_dim,
    out_dim=cfg.dataset.num_classes,
    imgsize=cfg.img_size,
    freeze_backbone=cfg.model.backbone.freeze,
)

# Move the model to the device named in the config; later cells move their
# input batches to this same `device`.
device = torch.device(cfg.device)
model = model.to(device)

Loaded pretrained weights.


In [None]:
class EigenCAM:
    """Saliency maps from the leading eigenvector of a layer's token activations.

    A forward hook captures the output of ``target_layer``. For every sample the
    token-by-token Gram matrix of the (feature-centred) activations is built and
    its dominant eigenvector is estimated with power iteration. That vector,
    reshaped to a square grid and min-max normalised to [0, 1], is the map.

    Args:
        model: Module called as ``model(input_tensor, branch=...)``.
        target_layer: Sub-module whose forward output is captured. The output is
            expected to be a 3-D tensor ``[B, tokens, features]``.
        power_iters: Number of power-iteration steps (default 10, the value that
            was previously hard-coded).
    """

    def __init__(self, model, target_layer, power_iters=10):
        self.model = model
        self.target_layer = target_layer
        self.power_iters = power_iters
        self.activations = None
        self.hook_handle = None

    def save_activation(self, module, input, output):
        """Forward-hook callback: keep the layer output, detached from autograd."""
        # Detach so a stored activation never keeps a training graph alive when
        # the model is run outside of ``__call__`` while the hook is attached.
        self.activations = output.detach()

    def register_hook(self):
        """Attach the forward hook to ``target_layer`` (no-op if already attached)."""
        if self.hook_handle is None:
            self.hook_handle = self.target_layer.register_forward_hook(self.save_activation)

    def remove_hook(self):
        """Detach the forward hook if one is attached."""
        if self.hook_handle is not None:
            self.hook_handle.remove()
            self.hook_handle = None

    def __call__(self, input_tensor, branch="tch"):
        """Run the model on ``input_tensor`` and return one map per sample.

        Args:
            input_tensor: Batch forwarded to ``self.model``.
            branch: Forwarded verbatim as the model's ``branch`` keyword.

        Returns:
            Tensor ``[B, S, S]`` with values in [0, 1], where ``S * S`` is the
            number of tokens kept. Half-precision activations (float16 or
            bfloat16) are processed in float32 and the result is cast back to
            the activation dtype.

        Raises:
            RuntimeError: If the layer produced no activation, the activation is
                not 3-D, or the number of kept tokens is not a perfect square.
        """
        self.activations = None

        # If the caller did not register the hook, attach it just for this call
        # so the forward pass can be captured, and clean it up afterwards.
        temporary_hook = self.hook_handle is None
        if temporary_hook:
            self.register_hook()
        try:
            with torch.no_grad():
                _ = self.model(input_tensor, branch=branch)
        finally:
            if temporary_hook:
                self.remove_hook()

        if self.activations is None:
            raise RuntimeError("No activations captured. Check target layer.")

        activations = self.activations
        if activations.dim() != 3:
            raise RuntimeError(
                f"Expected activations of shape [B, tokens, features], got {tuple(activations.shape)}."
            )
        B = activations.shape[0]

        # Drop the first token when there is more than one.
        # Presumably this is the class token of a ViT-style backbone; verify
        # against the backbone in use.
        if activations.shape[1] > 1:
            activations = activations[:, 1:, :]

        num_tokens = activations.shape[1]
        num_patches_side = int(round(num_tokens ** 0.5))
        if num_patches_side * num_patches_side != num_tokens:
            raise RuntimeError(
                f"Cannot reshape {num_tokens} tokens into a square grid; "
                "check the target layer and any extra (non-patch) tokens."
            )

        # Do the decomposition in float32 for every reduced-precision dtype.
        original_dtype = activations.dtype
        low_precision = original_dtype in (torch.float16, torch.bfloat16)
        if low_precision:
            activations = activations.float()

        # NOTE(review): the mean is taken over the feature axis (per token);
        # confirm this is the intended centring.
        mean = activations.mean(dim=2, keepdim=True)
        centered = activations - mean

        # Token-by-token Gram matrix, [B, tokens, tokens].
        cov = torch.bmm(centered, centered.transpose(1, 2))

        # Power iteration towards the dominant eigenvector.
        v = torch.randn(B, cov.shape[1], 1, device=cov.device, dtype=cov.dtype)
        for _ in range(self.power_iters):
            v = torch.bmm(cov, v)
            v = v / (torch.norm(v, dim=1, keepdim=True) + 1e-10)

        cam = v.squeeze(-1)

        # An eigenvector is only defined up to sign and the random start picks
        # one arbitrarily, which would randomly invert the normalised map.
        # Fix the sign so the entry with the largest magnitude is positive.
        peak_index = cam.abs().argmax(dim=1, keepdim=True)
        peak_sign = torch.sign(cam.gather(1, peak_index))
        peak_sign = torch.where(peak_sign == 0, torch.ones_like(peak_sign), peak_sign)
        cam = cam * peak_sign

        cam = cam.view(B, num_patches_side, num_patches_side)

        # Per-sample min-max normalisation to [0, 1].
        flat = cam.view(B, -1)
        cam_min = flat.min(dim=1, keepdim=True)[0].view(B, 1, 1)
        cam_max = flat.max(dim=1, keepdim=True)[0].view(B, 1, 1)
        cam = (cam - cam_min) / (cam_max - cam_min + 1e-10)

        if low_precision:
            cam = cam.to(original_dtype)

        return cam

In [5]:
# Pull a single batch from each training loader for interactive inspection.
source_data = next(iter(source_train_loader))
target_data = next(iter(target_train_loader))
# Both loaders yield 4-tuples. From the source batch only the second element
# and the labels are used; from the target batch the two views and the affine
# parameters are used.
# NOTE(review): the q/k naming suggests two augmented views of the same image,
# with `affine_params` describing the augmentation between them — confirm
# against the `data` module.
_, src_k_data, src_labels, _ = source_data
tgt_q_data, tgt_k_data, _, affine_params = target_data

# Move everything that is fed to the model onto the model's device.
src_img = src_k_data.to(device)
src_labels = src_labels.to(device)
tgt_k_img = tgt_k_data.to(device)  # strong
tgt_q_img = tgt_q_data.to(device)  # weak


In [6]:
# Saliency for the weak target view, taken at the last transformer block's norm2.
cam = EigenCAM(model, model.backbone.transformer.blocks[-1].norm2)
cam.register_hook()
try:
    sal_map = cam(tgt_q_img)
finally:
    # Always detach the forward hook. Otherwise every re-run of this cell
    # leaves one more hook (and its stored activations) on the same layer.
    # Call cam.register_hook() again before reusing `cam` in a later cell.
    cam.remove_hook()

In [19]:
def transform_map(salience_map, affine_params, transform_params=(0.0, 1.0), imgsize=224):
    """Re-weight a batch of saliency maps and warp them with per-sample affine parameters.

    The map is first blended as ``bg_w * (1 - map) + fg_w * map``. It is then
    warped by the affine transform that ``_get_affine_matrix_vectorized`` builds
    from ``affine_params``, and finally resized to ``imgsize x imgsize`` if needed.

    ``F.affine_grid`` expects a matrix that maps *output* locations to *input*
    locations in normalised [-1, 1] coordinates. The helper produces a forward
    matrix in pixel coordinates, so it is inverted and converted before use.
    Feeding the pixel-space matrix directly (the previous behaviour) pushed the
    samples far outside the map for any rotation or translation.

    Args:
        salience_map: ``[B, H, W]`` tensor or NumPy array.
        affine_params: Mapping with keys ``'angle'`` and ``'scale'`` (one value
            per sample) and ``'translate'`` and ``'shear'`` (a pair of
            per-sample sequences, or a ``[B, 2]`` tensor).
            NOTE(review): angles/shears are treated as degrees and translations
            as pixels of an ``imgsize x imgsize`` image — confirm against the
            data pipeline that produces them.
        transform_params: ``(bg_w, fg_w)`` blend weights.
        imgsize: Side length of the returned maps. Also the resolution the
            translations are assumed to be expressed in (see note above).

    Returns:
        Tensor ``[B, 1, imgsize, imgsize]``; locations sampled from outside the
        map are zero.
    """
    bg_w, fg_w = transform_params

    if not torch.is_tensor(salience_map):
        salience_map = torch.from_numpy(salience_map)
    device = salience_map.device

    salience_map = bg_w * (1 - salience_map) + fg_w * salience_map  # [B, H, W]
    salience_map = salience_map.unsqueeze(1)  # [B, 1, H, W]
    map_h, map_w = salience_map.shape[-2:]

    angles = torch.as_tensor(affine_params['angle'], dtype=torch.float32, device=device).view(-1)
    scales = torch.as_tensor(affine_params['scale'], dtype=torch.float32, device=device).view(-1)
    translates = _stack_pairs(affine_params['translate'], device)
    shears = _stack_pairs(affine_params['shear'], device)

    # The map is usually coarser than the image, so rescale the translation
    # from image pixels to map pixels.
    px_ratio = torch.tensor(
        [map_w / imgsize, map_h / imgsize], dtype=torch.float32, device=device
    )
    translates = translates * px_ratio

    forward_px = _get_affine_matrix_vectorized(
        angles, translates, scales, shears, (map_h, map_w), device
    )
    # Match the map's dtype: grid_sample requires input and grid to agree.
    theta = _pixel_affine_to_theta(forward_px, map_h, map_w).to(dtype=salience_map.dtype)

    grid = F.affine_grid(theta, list(salience_map.shape), align_corners=False)

    transformed_maps = F.grid_sample(
        salience_map,
        grid,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False
    )

    if tuple(transformed_maps.shape[-2:]) != (imgsize, imgsize):
        transformed_maps = F.interpolate(
            transformed_maps,
            size=(imgsize, imgsize),
            mode='bilinear',
            align_corners=False
        )

    return transformed_maps


def _stack_pairs(pairs, device):
    """Return ``pairs`` as a float32 ``[B, 2]`` tensor on ``device``."""
    if torch.is_tensor(pairs):
        stacked = pairs
    else:
        stacked = torch.stack([torch.as_tensor(p) for p in pairs], dim=-1)
    return stacked.to(device=device, dtype=torch.float32).reshape(-1, 2)


def _pixel_affine_to_theta(matrix, height, width):
    """Convert forward pixel-space affine matrices to ``F.affine_grid`` thetas.

    ``matrix`` is ``[B, 2, 3]`` and acts on pixel coordinates whose origin is
    the top-left corner of a ``height x width`` grid (centre at
    ``(width / 2, height / 2)``). The result is the inverse transform expressed
    in the normalised [-1, 1] coordinates used with ``align_corners=False``.
    """
    batch_size = matrix.shape[0]
    forward = torch.zeros(batch_size, 3, 3, dtype=matrix.dtype, device=matrix.device)
    forward[:, :2, :] = matrix
    forward[:, 2, 2] = 1.0

    # pixel -> normalised:  x_n = 2 * x / W - 1
    to_norm = matrix.new_tensor(
        [[2.0 / width, 0.0, -1.0], [0.0, 2.0 / height, -1.0], [0.0, 0.0, 1.0]]
    )
    # normalised -> pixel:  x = (x_n + 1) * W / 2
    to_pixel = matrix.new_tensor(
        [[width / 2.0, 0.0, width / 2.0], [0.0, height / 2.0, height / 2.0], [0.0, 0.0, 1.0]]
    )

    # output(normalised) -> pixel -> inverse warp -> input(normalised)
    theta = to_norm @ torch.linalg.inv(forward) @ to_pixel
    return theta[:, :2, :]


def _get_affine_matrix_vectorized(angles, translates, scales, shears, img_size, device):
    """Build batched forward affine matrices in pixel coordinates.

    The linear part is ``scale @ shear @ rotation``; the offset is chosen so the
    linear part acts about the grid centre ``(width / 2, height / 2)`` and is
    followed by the translation.

    Args:
        angles: ``[B]`` rotation angles in degrees.
        translates: ``[B, 2]`` x/y translations in pixels of ``img_size``.
        scales: ``[B]`` isotropic scale factors.
        shears: ``[B, 2]`` x/y shear angles in degrees.
        img_size: ``(height, width)`` of the grid the matrix acts on.
        device: Device for the returned tensor.

    Returns:
        ``[B, 2, 3]`` float32 tensor.
    """
    angles = torch.as_tensor(angles, dtype=torch.float32, device=device)
    translates = torch.as_tensor(translates, dtype=torch.float32, device=device)
    scales = torch.as_tensor(scales, dtype=torch.float32, device=device)
    shears = torch.as_tensor(shears, dtype=torch.float32, device=device)

    batch_size = angles.shape[0]

    angle_rad = torch.deg2rad(angles)
    shear_x_rad = torch.deg2rad(shears[:, 0])
    shear_y_rad = torch.deg2rad(shears[:, 1])

    height, width = img_size
    center = torch.tensor([width / 2.0, height / 2.0], device=device).view(1, 2)

    cos_a = torch.cos(angle_rad)
    sin_a = torch.sin(angle_rad)
    rotation_matrix = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
    rotation_matrix[:, 0, 0] = cos_a
    rotation_matrix[:, 0, 1] = -sin_a
    rotation_matrix[:, 1, 0] = sin_a
    rotation_matrix[:, 1, 1] = cos_a

    scale_matrix = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
    scale_matrix[:, 0, 0] = scales
    scale_matrix[:, 1, 1] = scales

    tan_sx = torch.tan(shear_x_rad)
    tan_sy = torch.tan(shear_y_rad)
    shear_matrix = torch.eye(3, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
    shear_matrix[:, 0, 1] = -tan_sx
    shear_matrix[:, 1, 0] = -tan_sy

    matrix = torch.bmm(shear_matrix, rotation_matrix)
    matrix = torch.bmm(scale_matrix, matrix)

    # Offset = translation + centre - linear_part @ centre, i.e. the linear part
    # is applied about the grid centre before translating.
    c = translates[:, 0] + center[:, 0] - matrix[:, 0, 0] * center[:, 0] - matrix[:, 0, 1] * center[:, 1]
    f = translates[:, 1] + center[:, 1] - matrix[:, 1, 0] * center[:, 0] - matrix[:, 1, 1] * center[:, 1]

    final_matrix = torch.zeros(batch_size, 2, 3, device=device)
    final_matrix[:, :, :2] = matrix[:, :2, :2]
    final_matrix[:, 0, 2] = c
    final_matrix[:, 1, 2] = f

    return final_matrix

In [20]:
# Warp the weak-view saliency map with the batch's affine parameters, using the
# default blend weights and output size of the notebook-local `transform_map`.
trans_map = transform_map(sal_map, affine_params)

In [22]:
# Sanity check: transform_map returns [B, 1, imgsize, imgsize].
trans_map.shape

torch.Size([64, 1, 224, 224])