
# Saliency / Cross-Attention Relation Classifier
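Supports single-input baselines (saliency-only, cross-attention-only), early-fusion variants (concat, triple, diff), late fusion (two-branch), input normalization/smoothing and concept-channel ablations, and evaluation with per-class accuracy and a confusion matrix.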


In [None]:

import json
from pathlib import Path
from collections import Counter
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm.auto import tqdm

ROOT = Path('..').resolve()
DATA_ROOT = ROOT / 'saliency_datasets' / 'early_layers'
# Explicit raise rather than `assert`, so the check is not stripped under `python -O`.
if not DATA_ROOT.exists():
    raise FileNotFoundError(f'Dataset directory not found: {DATA_ROOT}')

# MODE options:
#   'saliency' -> use saliency_maps (3 ch)
#   'cross'    -> use cross_attention_maps (3 ch)
#   'concat'   -> concat saliency + cross (6 ch)
#   'triple'   -> saliency + cross + product (9 ch)
#   'diff'     -> saliency + (cross - saliency) (6 ch)
#   'late'     -> two-branch model (separate WRNs for saliency & cross, logits averaged)
MODE = 'saliency'

# Preprocessing
NORM_STRATEGY = 'none'  # 'none' | 'minmax' | 'zscore'
SMOOTH = False          # apply simple avg pooling blur
SMOOTH_KERNEL = 3

# Ablations (channel layout: 0=subject, 1=predicate token, 2=object)
ZERO_SUBJECT = False
ZERO_PREDICATE = False
ZERO_OBJECT = False
SHUFFLE_CONCEPTS = False  # randomly permute the concept channels of each map stack

# Data protocol
VAL_SPLIT = 0.1
USE_BALANCED_SAMPLER = False  # class-balanced sampling for train loader
AUG_HORIZONTAL_FLIP = False   # beware: left/right relations would need relabeling; default False

# Fail fast on a mistyped setting here instead of deep inside a DataLoader worker.
_VALID_MODES = ('saliency', 'cross', 'concat', 'triple', 'diff', 'late')
_VALID_NORMS = ('none', 'minmax', 'zscore')
if MODE not in _VALID_MODES:
    raise ValueError(f'Unknown MODE {MODE!r}; expected one of {_VALID_MODES}')
if NORM_STRATEGY not in _VALID_NORMS:
    raise ValueError(f'Unknown NORM_STRATEGY {NORM_STRATEGY!r}; expected one of {_VALID_NORMS}')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE, '| MODE:', MODE)


## Dataset loader with fusion, normalization, and ablations

In [None]:

def _normalize(x, strategy):
    if strategy == 'none':
        return x
    if strategy == 'minmax':
        xmin = x.amin(dim=(1, 2), keepdim=True)
        xmax = x.amax(dim=(1, 2), keepdim=True)
        denom = torch.clamp(xmax - xmin, min=1e-6)
        return (x - xmin) / denom
    if strategy == 'zscore':
        mean = x.mean(dim=(1, 2), keepdim=True)
        std = x.std(dim=(1, 2), keepdim=True).clamp(min=1e-6)
        return (x - mean) / std
    raise ValueError(strategy)


def _smooth(x, kernel):
    if kernel <= 1:
        return x
    return F.avg_pool2d(x.unsqueeze(0), kernel_size=kernel, stride=1, padding=kernel // 2).squeeze(0)


def _apply_ablation(x):
    # Channels: 0=subject, 1=predicate token, 2=object (as produced in data)
    if x.shape[0] >= 1 and ZERO_SUBJECT:
        x[0] = 0
    if x.shape[0] >= 2 and ZERO_PREDICATE:
        x[1] = 0
    if x.shape[0] >= 3 and ZERO_OBJECT:
        x[2] = 0
    if SHUFFLE_CONCEPTS:
        perm = torch.randperm(x.shape[0])
        x = x[perm]
    return x


def fuse_inputs(sal, cross, mode):
    """Combine saliency and cross-attention map stacks according to ``mode``.

    'saliency' / 'cross' return one input unchanged; 'concat', 'triple' and 'diff'
    concatenate along axis 0 ([sal, cross], [sal, cross, sal*cross],
    [sal, cross - sal]); 'late' returns the ``(sal, cross)`` tuple for the
    two-branch model.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode == 'saliency':
        fused = sal
    elif mode == 'cross':
        fused = cross
    elif mode == 'late':
        fused = (sal, cross)
    else:
        if mode == 'concat':
            pieces = (sal, cross)
        elif mode == 'triple':
            pieces = (sal, cross, sal * cross)
        elif mode == 'diff':
            pieces = (sal, cross - sal)
        else:
            raise ValueError(mode)
        fused = torch.cat(pieces, dim=0)
    return fused


class SaliencyDataset(Dataset):
    """Map-style dataset over the ``class_*/sample_*.pt`` files under ``root``.

    Each file is read with ``torch.load`` and is expected to provide
    ``saliency_maps``, ``cross_attention_maps`` and ``class_id`` (assumed schema —
    confirm against the data exporter). Items are ``(fused_input, class_id)``.
    Both map stacks are float-cast, normalized, optionally smoothed and ablated
    according to the module-level settings, then fused per ``mode``.
    """

    def __init__(self, root: Path, mode: str = 'saliency'):
        self.root = Path(root)
        self.mode = mode
        discovered = sorted(self.root.glob('class_*/sample_*.pt'))
        if not discovered:
            raise RuntimeError(f'No samples found under {self.root}')
        self.files = discovered

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _prepare(raw):
        """Float-cast, normalize, optionally smooth, then ablate one map stack."""
        maps = _normalize(raw.float(), NORM_STRATEGY)
        if SMOOTH:
            maps = _smooth(maps, SMOOTH_KERNEL)
        return _apply_ablation(maps)

    def __getitem__(self, idx):
        record = torch.load(self.files[idx])
        # Saliency is processed before cross, so SHUFFLE_CONCEPTS draws its
        # permutations in that order.
        sal = self._prepare(record['saliency_maps'])
        cross = self._prepare(record['cross_attention_maps'])
        return fuse_inputs(sal, cross, self.mode), int(record['class_id'])

    def class_counts(self):
        """Return ``{class_id: n_samples}`` sorted by class id (loads every file)."""
        tally = Counter()
        for path in self.files:
            tally[int(torch.load(path)['class_id'])] += 1
        return {cid: tally[cid] for cid in sorted(tally)}


def make_sampler(counts, labels=None):
    """Build a class-balanced ``WeightedRandomSampler``.

    Each sample gets weight ``total / (num_classes * count[class])``, so every class
    is drawn with roughly equal probability.

    Args:
        counts: mapping ``class_id -> number of samples`` used to derive the weights.
        labels: class id of every sample, in the index order of the dataset the
            sampler will feed. If None, labels are read from every sorted
            ``DATA_ROOT/class_*/sample_*.pt`` file. That order only matches a
            dataset built over all of DATA_ROOT, not a ``random_split`` subset.

    Returns:
        A sampler drawing ``len(labels)`` indices with replacement.
    """
    total = sum(counts.values())
    # Classes with a zero count contribute no weight (and no division by zero).
    class_weights = {c: total / (len(counts) * n) for c, n in counts.items() if n}
    if labels is None:
        labels = [int(torch.load(f)['class_id']) for f in sorted(DATA_ROOT.glob('class_*/sample_*.pt'))]
    weights = [class_weights[int(c)] for c in labels]
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)


def build_dataloaders(root: Path, mode: str, batch_size=64, val_split=0.1, seed=42, use_balanced=False):
    """Build train/val DataLoaders over a seeded random split of ``SaliencyDataset(root)``.

    Args:
        root: dataset directory passed to ``SaliencyDataset``.
        mode: fusion mode passed to ``SaliencyDataset``.
        batch_size: batch size for both loaders.
        val_split: fraction of samples (floored) held out for validation.
        seed: seed for the ``random_split`` generator, so the split is reproducible.
        use_balanced: draw training batches with a class-balanced sampler
            (with replacement) instead of plain shuffling.

    Returns:
        ``(train_loader, val_loader, ds)``, where ``ds`` is the full, unsplit dataset.
    """
    ds = SaliencyDataset(root, mode=mode)
    n_total = len(ds)
    n_val = int(n_total * val_split)
    n_train = n_total - n_val
    generator = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(ds, [n_train, n_val], generator=generator)

    # Pinned memory only helps host->GPU copies; without CUDA it just warns.
    loader_kwargs = dict(num_workers=2, pin_memory=torch.cuda.is_available())
    if use_balanced:
        # The sampler indexes the training Subset (0..n_train-1), so its weights must
        # follow train_ds.indices and come from training labels only. Weighting the
        # full dataset would emit out-of-range indices and mis-weight samples.
        train_labels = [int(torch.load(ds.files[i])['class_id']) for i in train_ds.indices]
        train_counts = Counter(train_labels)
        class_weights = {c: n_train / (len(train_counts) * n) for c, n in train_counts.items()}
        weights = [class_weights[c] for c in train_labels]
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
        train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, **loader_kwargs)
    else:
        train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, ds


### Dataset summary

In [None]:

info_rows = []
for info_path in sorted(DATA_ROOT.glob('class_*/class_info.json')):
    with info_path.open(encoding='utf-8') as f:
        info = json.load(f)
    info_rows.append((info['class_id'], info['predicate'], info['num_samples'], info_path.parent.name))
print('Declared classes:', len(info_rows))
print('Example rows:', info_rows[:3])

# Actual counts from disk. Only the dataset is needed here, so build it directly
# instead of going through build_dataloaders.
full_ds = SaliencyDataset(DATA_ROOT, mode=MODE)
actual_counts = full_ds.class_counts()
print('Total samples:', len(full_ds))
print('Per-class counts:', actual_counts)

# Flag classes whose on-disk sample count disagrees with class_info.json.
count_mismatches = [
    (int(cid), int(declared), actual_counts.get(int(cid), 0))
    for cid, _, declared, _ in info_rows
    if actual_counts.get(int(cid), 0) != int(declared)
]
if count_mismatches:
    print('Declared vs actual count mismatches (class_id, declared, actual):', count_mismatches)

x0, y0 = full_ds[0]
if MODE == 'late':
    print('Input shapes for mode late:', tuple(x0[0].shape), tuple(x0[1].shape), '| label', y0)
else:
    print('Input shape:', tuple(x0.shape), '| label', y0)


## Models (WRN backbone) and fusion variants

In [None]:

class BasicBlock(nn.Module):
    """Pre-activation residual block used by WideResNet: (BN-ReLU-conv3x3) twice.

    Args:
        in_planes: input channel count.
        out_planes: output channel count.
        stride: stride of the first 3x3 conv, and of the 1x1 projection shortcut when used.
        dropRate: dropout probability applied between the two convs (active only in training).
    """
    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # Projection (1x1, strided) shortcut only when channel counts differ; otherwise identity.
        # NOTE(review): the identity shortcut is also chosen when in_planes == out_planes
        # but stride != 1, which would make the residual add fail on shape. NetworkBlock
        # in this file never builds that combination — confirm for other callers.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        if not self.equalInOut:
            # Channel change: x itself is replaced by its BN-ReLU'd version, so the
            # projection shortcut below also sees the pre-activated input.
            x = self.relu1(self.bn1(x))
            out = self.relu2(self.bn2(self.conv1(x)))
        else:
            # Identity shortcut: bn1 returns a new tensor, so the in-place ReLU
            # leaves the raw x intact for the residual add.
            out = self.relu1(self.bn1(x))
            out = self.relu2(self.bn2(self.conv1(out)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(x)
        return torch.add(shortcut, out)


class NetworkBlock(nn.Module):
    """A stage of ``nb_layers`` residual blocks built by the ``block`` factory.

    Only the first block maps ``in_planes -> out_planes`` and applies ``stride``;
    the rest keep ``out_planes`` channels at stride 1. ``nb_layers`` may be a
    float (it is truncated with ``int``).
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super().__init__()
        stage = []
        for position in range(int(nb_layers)):
            is_first = position == 0
            stage.append(
                block(
                    in_planes if is_first else out_planes,
                    out_planes,
                    stride if is_first else 1,
                    dropRate,
                )
            )
        self.layer = nn.Sequential(*stage)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    """Wide ResNet (WRN-depth-k) classifier over multi-channel 2-D maps.

    Args:
        depth: network depth; must satisfy ``(depth - 4) % 6 == 0``.
        num_classes: number of output logits.
        widen_factor: channel multiplier k for the three stages (16k, 32k, 64k).
        dropRate: dropout probability inside each BasicBlock.
        in_channels: channels of the input (e.g. 3, 6 or 9 depending on the fusion mode).
        global_pool: 'max' for global max pooling; 'avg' or any other value for
            global average pooling.

    Global pooling is adaptive, so any input that survives the two stride-2 stages
    works. For 32x32 inputs this equals the classic fixed 8x8 pool.

    Raises:
        ValueError: if ``depth`` is not of the form 6n + 4.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0, in_channels=3, global_pool='avg'):
        super().__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        if (depth - 4) % 6 != 0:
            raise ValueError(f'WRN depth must be 6n + 4, got {depth}')
        n = (depth - 4) // 6
        block = BasicBlock
        self.conv1 = nn.Conv2d(in_channels, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]
        self.global_pool = global_pool

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # Adaptive pooling always yields one vector per sample. A fixed 8x8 window
        # followed by view(-1, C) silently folds extra spatial cells into the batch
        # axis for inputs larger than 32x32, and fails for smaller ones.
        if self.global_pool == 'max':
            out = F.adaptive_max_pool2d(out, 1)
        else:
            out = F.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)
        return self.fc(out)


class DualWRN(nn.Module):
    """Late-fusion model: independent 3-channel WRNs for saliency and cross maps.

    The forward pass takes a ``(sal, cross)`` pair (any 2-element sequence), runs
    each through its own branch, and returns the mean of the two logit tensors.
    """

    def __init__(self, depth, num_classes, widen_factor=8, dropRate=0.3):
        super().__init__()
        branch_args = (depth, num_classes, widen_factor, dropRate)
        self.sal_branch = WideResNet(*branch_args, in_channels=3)
        self.cross_branch = WideResNet(*branch_args, in_channels=3)

    def forward(self, x_tuple):
        sal_input, cross_input = x_tuple
        summed = self.sal_branch(sal_input) + self.cross_branch(cross_input)
        return summed / 2


In [None]:

NUM_CLASSES = 24

in_channels_lookup = {
    'saliency': 3,
    'cross': 3,
    'concat': 6,
    'triple': 9,
    'diff': 6,
}

# Fail fast if the data holds class ids the classifier head cannot represent.
# Otherwise CrossEntropyLoss fails later with an opaque (on GPU, device-side)
# error. `actual_counts` comes from the dataset-summary cell.
_out_of_range_ids = sorted(c for c in actual_counts if not 0 <= c < NUM_CLASSES)
if _out_of_range_ids:
    raise ValueError(f'class_id(s) {_out_of_range_ids} fall outside [0, {NUM_CLASSES}); adjust NUM_CLASSES')
if len(actual_counts) < NUM_CLASSES:
    print(f'Note: only {len(actual_counts)} of {NUM_CLASSES} classes have samples on disk')

if MODE == 'late':
    model = DualWRN(depth=28, num_classes=NUM_CLASSES, widen_factor=8, dropRate=0.3).to(DEVICE)
else:
    in_channels = in_channels_lookup[MODE]
    model = WideResNet(depth=28, num_classes=NUM_CLASSES, widen_factor=8, dropRate=0.3, in_channels=in_channels).to(DEVICE)

print('Model params:', sum(p.numel() for p in model.parameters()) / 1e6, 'M')


## Training and evaluation (with confusion matrix)

In [None]:

def maybe_augment(x, *, enabled=None, p=0.5):
    """With probability ``p``, flip the whole input left-right (along the width axis).

    Left/right relation labels are NOT swapped, which is why this is off by default.

    Args:
        x: a batch tensor, or a tuple/list of batch tensors (late-fusion mode). Every
            tensor in a container gets the same flip, so saliency and cross maps stay aligned.
        enabled: overrides the module-level AUG_HORIZONTAL_FLIP when not None.
        p: probability of flipping. The RNG is consulted only when flipping is enabled.

    Returns:
        The (possibly flipped) input, in the same container type as ``x``.
    """
    if enabled is None:
        enabled = AUG_HORIZONTAL_FLIP
    if not (enabled and random.random() < p):
        return x
    # Width is the last axis for both (C, H, W) samples and (B, C, H, W) batches.
    # dims=[2] would flip the height of a batch.
    if isinstance(x, (tuple, list)):
        return type(x)(torch.flip(t, dims=[-1]) for t in x)
    return torch.flip(x, dims=[-1])


def train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: network; called with a tensor, or with a tuple of tensors in late-fusion mode.
        loader: yields ``(inputs, labels)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss taking ``(logits, labels)``.
        device: device the batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all training samples; both are NaN if the
        loader yields nothing.
    """
    model.train()
    running_loss = 0.0
    running_correct = 0
    total = 0
    for xb, yb in tqdm(loader, leave=False):
        # default_collate turns each sample's (sal, cross) tuple into a *list* of
        # batched tensors. Normalize it to a tuple so the multi-input path is taken.
        if isinstance(xb, list):
            xb = tuple(xb)
        xb = maybe_augment(xb)
        if isinstance(xb, tuple):
            xb = tuple(t.to(device) for t in xb)
        else:
            xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * yb.size(0)
        preds = logits.argmax(dim=1)
        running_correct += (preds == yb).sum().item()
        total += yb.size(0)

    if total == 0:
        return float('nan'), float('nan')
    return running_loss / total, running_correct / total


def evaluate(model, loader, criterion, device, num_classes=NUM_CLASSES):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network; called with a tensor, or with a tuple of tensors in late-fusion mode.
        loader: yields ``(inputs, labels)`` batches.
        criterion: loss taking ``(logits, labels)``.
        device: device the batch is moved to.
        num_classes: side length of the confusion matrix (default bound at definition time).

    Returns:
        ``(mean_loss, accuracy, conf, per_class_acc)``:
        - ``conf`` is a CPU long tensor indexed ``[true, pred]``;
        - ``per_class_acc`` is each row's diagonal fraction (recall per true class,
          0 for classes with no samples).
        Loss and accuracy are NaN if the loader yields nothing.
    """
    model.eval()
    running_loss = 0.0
    running_correct = 0
    total = 0
    conf = torch.zeros(num_classes, num_classes, dtype=torch.long)
    with torch.no_grad():
        for xb, yb in loader:
            # default_collate turns per-sample (sal, cross) tuples into a list of batches.
            if isinstance(xb, (tuple, list)):
                xb = tuple(t.to(device) for t in xb)
            else:
                xb = xb.to(device)
            yb = yb.to(device)
            logits = model(xb)
            loss = criterion(logits, yb)

            running_loss += loss.item() * yb.size(0)
            preds = logits.argmax(dim=1)
            running_correct += (preds == yb).sum().item()
            total += yb.size(0)

            # One vectorized scatter-add on CPU instead of a Python loop of
            # per-element (GPU-scalar) index updates.
            true_idx = yb.view(-1).long().cpu()
            pred_idx = preds.view(-1).long().cpu()
            conf.index_put_((true_idx, pred_idx), torch.ones_like(true_idx), accumulate=True)

    per_class_acc = conf.diag() / conf.sum(dim=1).clamp(min=1)
    if total == 0:
        return float('nan'), float('nan'), conf, per_class_acc
    return running_loss / total, running_correct / total, conf, per_class_acc


In [None]:

BATCH_SIZE = 128
EPOCHS = 5
LR = 1e-3
WEIGHT_DECAY = 1e-4

train_loader, val_loader, _ = build_dataloaders(
    DATA_ROOT,
    mode=MODE,
    batch_size=BATCH_SIZE,
    val_split=VAL_SPLIT,
    use_balanced=USE_BALANCED_SAMPLER,
)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Per-epoch metric curves, keyed in the order train_loss, train_acc, val_loss, val_acc.
history = {key: [] for key in ('train_loss', 'train_acc', 'val_loss', 'val_acc')}

for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
    val_loss, val_acc, conf, per_class_acc = evaluate(model, val_loader, criterion, DEVICE)
    scheduler.step()

    for key, value in zip(history, (train_loss, train_acc, val_loss, val_acc)):
        history[key].append(value)

    print(f"Epoch {epoch:02d}/{EPOCHS} | train_loss={train_loss:.4f} acc={train_acc:.3f} | val_loss={val_loss:.4f} acc={val_acc:.3f}")

# Metrics below come from the final epoch's validation pass.
print('Per-class accuracy (val):', per_class_acc.tolist())

# Off-diagonal confusion-matrix entries are the misclassifications.
mis = conf.clone()
mis.fill_diagonal_(0)
vals, idxs = torch.topk(mis.view(-1), k=10)
print('Top confusions (true, pred, count):')
for count, flat_idx in zip(vals.tolist(), idxs.tolist()):
    if count > 0:
        true_id, pred_id = divmod(flat_idx, NUM_CLASSES)
        print(true_id, pred_id, count)


## Save checkpoint

In [None]:

ckpt_dir = Path('runs')
ckpt_dir.mkdir(parents=True, exist_ok=True)

# Append one tag per enabled preprocessing / ablation / sampling switch (all are
# off by default), so ablation runs do not overwrite each other's weights. With
# every switch off, the name stays "wrn_mode-<MODE>_norm-<NORM_STRATEGY>.pt".
ckpt_tags = [
    tag
    for tag, active in (
        (f'smooth{SMOOTH_KERNEL}', SMOOTH),
        ('zsubj', ZERO_SUBJECT),
        ('zpred', ZERO_PREDICATE),
        ('zobj', ZERO_OBJECT),
        ('shuffle', SHUFFLE_CONCEPTS),
        ('balanced', USE_BALANCED_SAMPLER),
        ('hflip', AUG_HORIZONTAL_FLIP),
    )
    if active
]
ckpt_name = f"wrn_mode-{MODE}_norm-{NORM_STRATEGY}" + ''.join(f'_{tag}' for tag in ckpt_tags) + '.pt'
ckpt_path = ckpt_dir / ckpt_name
# Saves the final-epoch weights (state_dict only; configuration is encoded in the name).
torch.save(model.state_dict(), ckpt_path)
print('Saved to', ckpt_path.resolve())


## Quick inference on one sample

In [None]:

sample_path = next(DATA_ROOT.glob('class_02_behind/sample_*.pt'), None)
if sample_path is None:
    # Example class folder missing: fall back to any sample rather than StopIteration.
    sample_path = next(DATA_ROOT.glob('class_*/sample_*.pt'))
sample = torch.load(sample_path)

# Same preprocessing pipeline as SaliencyDataset.__getitem__.
sal = _normalize(sample['saliency_maps'].float(), NORM_STRATEGY)
cross = _normalize(sample['cross_attention_maps'].float(), NORM_STRATEGY)
if SMOOTH:
    sal = _smooth(sal, SMOOTH_KERNEL)
    cross = _smooth(cross, SMOOTH_KERNEL)
sal = _apply_ablation(sal)
cross = _apply_ablation(cross)

# Reuse the shared fusion logic (raises ValueError for an unknown MODE), then add a
# batch axis to every model input.
fused = fuse_inputs(sal, cross, MODE)
if isinstance(fused, tuple):
    x = tuple(t.unsqueeze(0).to(DEVICE) for t in fused)
else:
    x = fused.unsqueeze(0).to(DEVICE)

model.eval()
with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()
print('Sample:', sample_path)
print('Prompt:', sample['prompt'])
print('Predicted class_id:', pred)
print('True class_id:', int(sample['class_id']))
