In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data


from domainbed import datasets, hparams_registry, algorithms


import random

# --- Experiment configuration -------------------------------------------------
algorithm_name = 'xdomain_mix'
dataset_name = 'PACS'
seed = 0
data_dir = ''  # input data dir here
test_envs = [0]
holdout_fraction = 0.2  # not used by the cells below
trial_seed = 0          # not used by the cells below

# Seed every RNG the notebook relies on so the randomly initialised model and
# the shuffled DataLoaders (and therefore the sampled batches and the plots)
# are reproducible across Restart & Run All.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

hparams = hparams_registry.default_hparams(algorithm_name, dataset_name)

dataset = vars(datasets)[dataset_name](data_dir, test_envs, hparams)
algorithm_class = algorithms.get_algorithm_class(algorithm_name)
# Number of training domains = all environments minus the held-out test envs.
algorithm = algorithm_class(dataset.input_shape, dataset.num_classes,
                            len(dataset) - len(test_envs), hparams)
device = 'cpu'

model_path = ''  # input model path here

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model = torch.load(model_path, map_location=device)
algorithm.network.load_state_dict(model['model_dict'])
algorithm.to(device)
algorithm.eval()

In [None]:
import pandas as pd
import seaborn as sns
import torch.nn as nn

# Global average pooling: collapses each channel's spatial map to a single 1x1 value.
pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

In [None]:
from domainbed.mixup_module import DomainClassMixAugmentation
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

def extract_four_mask(feature_map, class_gradient, domain_gradient):
    """Split one sample's features into class/domain specific-vs-generic masks.

    The feature map is averaged over dims 1 and 2 (kept as singleton dims) and
    multiplied by the class and the domain gradient, giving a class activation
    map and a domain activation map. Each map is split at the threshold returned
    by ``DomainClassMixAugmentation.get_threshold(map, 0.5)``
    (NOTE(review): presumably a 0.5 quantile of the map; confirm in mixup_module).

    Args:
        feature_map: features of a single sample; assumed (C, H, W) — TODO confirm.
        class_gradient: class-score gradient w.r.t. the features (broadcastable
            against the pooled map).
        domain_gradient: domain-score gradient w.r.t. the features.

    Returns:
        Tuple of boolean masks
        ``(cs_idx, ds_idx, csds_mask, csdi_mask, cgds_mask, cgdi_mask)``:
        class-specific, domain-specific, and the four class x domain
        combinations (specific/generic x specific/invariant).
    """
    pooled = feature_map.mean(dim=(1, 2), keepdim=True)
    class_map = pooled * class_gradient
    domain_map = pooled * domain_gradient

    class_threshold = DomainClassMixAugmentation.get_threshold(class_map, 0.5)
    domain_threshold = DomainClassMixAugmentation.get_threshold(domain_map, 0.5)

    # Both sides are compared explicitly (rather than negating one mask) so a
    # value that satisfies neither comparison (e.g. NaN) lands in neither mask.
    class_specific = class_map >= class_threshold
    class_generic = class_map < class_threshold
    domain_specific = domain_map >= domain_threshold
    domain_invariant = domain_map < domain_threshold

    return (
        class_specific,
        domain_specific,
        class_specific & domain_specific,
        class_specific & domain_invariant,
        class_generic & domain_specific,
        class_generic & domain_invariant,
    )

def extract_gradients(algo, x, y, style):
    """Compute backbone features for ``x`` and the class/domain score gradients.

    The order of the steps matters because of the zero_grad/backward calls:

    1. Run the backbone once, detach the result and turn it into a fresh leaf,
       so the domain classifier's gradient is collected in ``feature_maps.grad``.
    2. Back-propagate the sum of each sample's domain logit for its ``style``
       label and store that gradient on ``algo.domain_gradient``.
    3. Run ``algo.predict(x)`` and back-propagate the sum of each sample's class
       logit for its label ``y``; ``algo.class_gradient`` is read afterwards.
       NOTE(review): nothing in this function assigns ``algo.class_gradient`` —
       presumably a hook installed by ``algo.predict`` fills it during backward.

    Args:
        algo: trained algorithm exposing ``network.get_feature``,
            ``network.domain_classifier``, ``predict``, ``optimizer`` and
            ``domain_optimizer``.
        x: input batch.
        y: per-sample class labels (tensor; ``.item()`` is taken per element).
        style: per-sample domain labels (same convention as ``y``).

    Returns:
        ``(feature_maps, class_gradient, domain_gradient)``: the detached leaf
        features and clones of the class and domain gradients.
    """
    feature_maps = algo.network.get_feature(x).detach()
    feature_maps.requires_grad_(True)

    # Step 2: gradient of the per-sample domain logits w.r.t. the leaf features.
    domain_logits = algo.network.domain_classifier(feature_maps)
    domain_score = sum(
        ClassifierOutputTarget(label.item())(row)
        for label, row in zip(style, domain_logits)
    )
    algo.domain_optimizer.zero_grad()
    domain_score.backward(retain_graph=True)
    algo.domain_gradient = feature_maps.grad

    # Step 3: gradient of the per-sample class logits, captured by the algorithm.
    class_logits = algo.predict(x)
    class_score = sum(
        ClassifierOutputTarget(label.item())(row)
        for label, row in zip(y, class_logits)
    )
    algo.optimizer.zero_grad()
    class_score.backward(retain_graph=True)

    domain_gradient = algo.domain_gradient.clone()
    class_gradient = algo.class_gradient.clone()
    return feature_maps, class_gradient, domain_gradient

In [None]:
def get_logits(algo, feature):
    """Apply the backbone's own ``avgpool`` + ``fc`` head to a feature map.

    Args:
        algo: algorithm whose ``network.network`` exposes ``avgpool`` and ``fc``.
        feature: batch of feature maps accepted by that ``avgpool``.

    Returns:
        Output of ``fc`` on the pooled features, flattened from dim 1 onward.
    """
    backbone = algo.network.network
    pooled = torch.flatten(backbone.avgpool(feature), 1)
    return backbone.fc(pooled)

def get_four_features(feature_maps, class_gradients, domain_gradients):
    """Mask each sample's features into class- and domain-specific parts and complements.

    For every sample, ``extract_four_mask`` yields boolean masks ``cs_idx``
    (class-specific) and ``ds_idx`` (domain-specific). The sample's feature map
    is multiplied by each mask and by its logical complement.

    Args:
        feature_maps: batch of feature maps; dim 0 indexes samples.
        class_gradients: per-sample class gradients, indexed like ``feature_maps``.
        domain_gradients: per-sample domain gradients, indexed like ``feature_maps``.

    Returns:
        ``(cs_maps, ci_maps, ds_maps, di_maps)``: four tensors with the shape,
        dtype and device of ``feature_maps`` holding the class-specific part,
        its complement, the domain-specific part and its complement.
    """
    # zeros_like (instead of torch.zeros(size)) keeps the outputs on the input's
    # device and dtype rather than always allocating default-dtype CPU tensors.
    cs_maps = torch.zeros_like(feature_maps)
    ci_maps = torch.zeros_like(feature_maps)
    ds_maps = torch.zeros_like(feature_maps)
    di_maps = torch.zeros_like(feature_maps)

    for index in range(feature_maps.size(0)):
        fmap = feature_maps[index]
        # Only the two base masks are needed; the combined masks are ignored.
        cs_idx, ds_idx, *_ = extract_four_mask(
            fmap, class_gradients[index], domain_gradients[index])
        cs_maps[index] = fmap * cs_idx
        ci_maps[index] = fmap * ~cs_idx
        ds_maps[index] = fmap * ds_idx
        di_maps[index] = fmap * ~ds_idx

    return cs_maps, ci_maps, ds_maps, di_maps

def get_detail_four_features(feature_maps, class_gradients, domain_gradients):
    """Mask each sample's features with the four class x domain combination masks.

    For every sample, ``extract_four_mask`` yields the combined masks
    class-specific & domain-specific (csds), class-specific & domain-invariant
    (csdi), class-generic & domain-specific (cgds) and class-generic &
    domain-invariant (cgdi); the sample's feature map is multiplied by each.

    Args:
        feature_maps: batch of feature maps; dim 0 indexes samples.
        class_gradients: per-sample class gradients, indexed like ``feature_maps``.
        domain_gradients: per-sample domain gradients, indexed like ``feature_maps``.

    Returns:
        ``(csds_maps, csdi_maps, cgds_maps, cgdi_maps)``: four tensors with the
        shape, dtype and device of ``feature_maps``.
    """
    # zeros_like (instead of torch.zeros(size)) keeps the outputs on the input's
    # device and dtype rather than always allocating default-dtype CPU tensors.
    csds_maps = torch.zeros_like(feature_maps)
    csdi_maps = torch.zeros_like(feature_maps)
    cgds_maps = torch.zeros_like(feature_maps)
    cgdi_maps = torch.zeros_like(feature_maps)

    for index in range(feature_maps.size(0)):
        fmap = feature_maps[index]
        # The first two returned values (base masks) are not needed here.
        _, _, csds_mask, csdi_mask, cgds_mask, cgdi_mask = extract_four_mask(
            fmap, class_gradients[index], domain_gradients[index])
        csds_maps[index] = fmap * csds_mask
        csdi_maps[index] = fmap * csdi_mask
        cgds_maps[index] = fmap * cgds_mask
        cgdi_maps[index] = fmap * cgdi_mask

    return csds_maps, csdi_maps, cgds_maps, cgdi_maps

In [None]:
# Linear probe with a 2-unit bottleneck, trained on frozen backbone features to
# predict the class label. Its first layer (fc_network[0]) is reused later as a
# learned 2-D projection for the class-feature scatter plot.
# Training domains are dataset[1], dataset[2], dataset[3] (env 0 is the test env).
# NOTE(review): the 2048 input width assumes the backbone's channel count
# (e.g. ResNet-50) — confirm for other backbones.
algorithm.to('cuda')
fc_network = nn.Sequential(nn.Linear(2048, 2), nn.Linear(2, dataset.num_classes)).to('cuda')

optimizer = torch.optim.SGD(fc_network.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

for e in range(10):
    c_loss_total = 0
    count = 0
    C_data = torch.utils.data.DataLoader(dataset[1], batch_size=32, shuffle=True)
    P_data = torch.utils.data.DataLoader(dataset[2], batch_size=32, shuffle=True)
    S_data = torch.utils.data.DataLoader(dataset[3], batch_size=32, shuffle=True)
    for data_1, data_2, data_3 in zip(C_data, P_data, S_data):
        # Batch-local names, so this cell does not overwrite the notebook-level
        # all_x / all_y evaluation batch.
        batch_x = torch.cat((data_1[0], data_2[0], data_3[0])).to('cuda')
        batch_y = torch.cat((data_1[1], data_2[1], data_3[1])).to('cuda')

        # Only the probe is optimized: detach so backward() stops at the features
        # instead of back-propagating through (and accumulating gradients in)
        # the whole backbone.
        features = algorithm.network.get_feature(batch_x).detach()
        f = torch.flatten(pool(features), 1)
        y_hat = fc_network(f)
        c_loss = criterion(y_hat, batch_y)
        optimizer.zero_grad()
        c_loss.backward()
        optimizer.step()
        c_loss_total += c_loss.item()
        count += 1
    if e % 10 == 0:
        print(e, c_loss_total/count)

print(e, c_loss_total/count)

In [None]:

# One shuffled evaluation batch of up to 128 images from each training domain
# (dataset[1], dataset[2], dataset[3]), concatenated in that order.
C_data = torch.utils.data.DataLoader(dataset[1], batch_size=128, shuffle=True)
P_data = torch.utils.data.DataLoader(dataset[2], batch_size=128, shuffle=True)
S_data = torch.utils.data.DataLoader(dataset[3], batch_size=128, shuffle=True)

# Take the first batch of each loader; fail loudly if any loader is empty
# rather than silently keeping all_x / all_y / len1-3 from an earlier cell.
first_batches = next(zip(C_data, P_data, S_data), None)
if first_batches is None:
    raise RuntimeError('an evaluation DataLoader produced no batches')
data_1, data_2, data_3 = first_batches

all_x = torch.cat((data_1[0], data_2[0], data_3[0]))
all_y = torch.cat((data_1[1], data_2[1], data_3[1]))

# Number of samples contributed by each domain, in concatenation order.
len1 = data_1[0].size(0)
len2 = data_2[0].size(0)
len3 = data_3[0].size(0)

# Label histogram of the evaluation batch (distinct label ids and their counts).
values, counts = np.unique(all_y, return_counts=True)

In [None]:
# Human-readable PACS class names keyed by integer label id.
PACS_CLASS_NAMES = {
    0: 'Dog',
    1: 'Elephant',
    2: 'Giraffe',
    3: 'Guitar',
    4: 'Horse',
    5: 'House',
    6: 'Person',
}

# One name per sample of all_y; labels without a known name are skipped.
y_label = [
    PACS_CLASS_NAMES[label.item()]
    for label in all_y
    if label.item() in PACS_CLASS_NAMES
]

In [None]:
# Integer domain id per sample, in concatenation order:
# len1 zeros, then len2 ones, then len3 twos.
domain = torch.cat([
    torch.full((size,), domain_id, dtype=torch.long)
    for domain_id, size in enumerate((len1, len2, len3))
])

In [None]:
# One domain name per sample, in the same order the batches were concatenated:
# len1 x Cartoon, then len2 x Photo, then len3 x Sketch.
domain_label = ['Cartoon'] * len1 + ['Photo'] * len2 + ['Sketch'] * len3

In [None]:
# Gradient extraction and masking are run on the CPU.
algorithm.to('cpu')

gradient_outputs = extract_gradients(algorithm, all_x, all_y, domain)
feature_maps, class_gradients, domain_gradients = gradient_outputs

# Coarse split (specific part vs. complement, for class and for domain) and the
# four class x domain combinations, all from the same features and gradients.
cs, ci, ds, di = get_four_features(*gradient_outputs)
csds, csdi, cgds, cgdi = get_detail_four_features(*gradient_outputs)

In [None]:
fc_network.to('cpu')

# Project the pooled features to 2-D with the first (2-unit) layer of the class
# probe. Plotting only, so run under no_grad: feature_maps / cs / ci may carry
# autograd history and there is no need to extend it.
projection = fc_network[0]
with torch.no_grad():
    feature_maps_reduce, cs_reduce, ci_reduce = (
        projection(torch.flatten(pool(maps), 1)).cpu().numpy()
        for maps in (feature_maps, cs, ci)
    )

vis_data = pd.DataFrame({'x': feature_maps_reduce[:, 0],
                         'y': feature_maps_reduce[:, 1],
                         'cs_1': cs_reduce[:, 0],
                         'cs_2': cs_reduce[:, 1],
                         'ci_1': ci_reduce[:, 0],
                         'ci_2': ci_reduce[:, 1],
                         'class': y_label,
                         'domain': domain_label
                        })

import matplotlib as mpl

mpl.rcParams['font.size'] = 14

fig, axes = plt.subplots(1, 2, figsize=(9, 4.5))

# Alternative 7-colour palette kept for reference; it is NOT applied —
# the figure uses seaborn's 'tab10' palette set on the next line.
custom_palette = ['#9b5fe0', '#16a4d8', '#60dbe8', '#8bd346', '#efdf48', '#f9a52c', '#d64e12']
sns.set_palette('tab10')

sns.scatterplot(data=vis_data, x="cs_1", y="cs_2", hue="class", legend=False, ax=axes[0])
sns.despine(left=False, bottom=False, right=False, top=False)
axes[0].set(xlabel=None, ylabel=None, yticklabels=[], xticklabels=[])
axes[0].set_title('Class-specific feature')

sns.scatterplot(data=vis_data, x="ci_1", y="ci_2", hue="class", legend='full', ax=axes[1])
sns.despine(left=False, bottom=False, right=False, top=False)
axes[1].set(xlabel=None, ylabel=None, yticklabels=[], xticklabels=[])
axes[1].set_title('Class-generic feature')
# Target the right-hand panel explicitly instead of relying on plt's "current axes".
axes[1].legend(title='Class', fontsize=12, bbox_to_anchor=(0.8, -0.05), ncol=4)

fig.savefig('class_feature.png', bbox_inches="tight", dpi=300)



In [None]:
# Train a new linear probe with a 2-unit bottleneck to predict the source domain
# (0 = dataset[1], 1 = dataset[2], 2 = dataset[3]); its first layer is reused
# as the 2-D projection for the domain-feature scatter plot.
# NOTE(review): the 2048 input width assumes the backbone's channel count.
algorithm.to('cuda')
fc_network = nn.Sequential(nn.Linear(2048, 2), nn.Linear(2, 3)).to('cuda')

optimizer = torch.optim.SGD(fc_network.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

for e in range(10):
    c_loss_total = 0
    count = 0
    C_data = torch.utils.data.DataLoader(dataset[1], batch_size=32, shuffle=True)
    P_data = torch.utils.data.DataLoader(dataset[2], batch_size=32, shuffle=True)
    S_data = torch.utils.data.DataLoader(dataset[3], batch_size=32, shuffle=True)
    for data_1, data_2, data_3 in zip(C_data, P_data, S_data):
        domain_images = (data_1[0], data_2[0], data_3[0])
        # Batch-local names, so this cell does not overwrite the notebook-level
        # all_x / all_y / domain evaluation tensors.
        batch_x = torch.cat(domain_images).to('cuda')

        # Build the domain labels from the actual batch sizes: without
        # drop_last a loader's final batch can be shorter than 32, and a fixed
        # 32+32+32 label vector then no longer matches batch_x.
        batch_domain = torch.cat([
            torch.full((images.size(0),), domain_id, dtype=torch.long)
            for domain_id, images in enumerate(domain_images)
        ]).to('cuda')

        # Only the probe is optimized: detach so backward() stops at the features.
        features = algorithm.network.get_feature(batch_x).detach()
        f = torch.flatten(pool(features), 1)
        y_hat = fc_network(f)
        c_loss = criterion(y_hat, batch_domain)
        optimizer.zero_grad()
        c_loss.backward()
        optimizer.step()
        c_loss_total += c_loss.item()
        count += 1
    if e % 10 == 0:
        print(e, c_loss_total/count)

print(e, c_loss_total/count)

In [None]:
fc_network.to('cpu')

# Project the pooled features to 2-D with the first (2-unit) layer of the domain
# probe. Plotting only, so run under no_grad: feature_maps / ds / di may carry
# autograd history and there is no need to extend it.
projection = fc_network[0]
with torch.no_grad():
    feature_maps_reduce, ds_reduce, di_reduce = (
        projection(torch.flatten(pool(maps), 1)).cpu().numpy()
        for maps in (feature_maps, ds, di)
    )

vis_data = pd.DataFrame({'x': feature_maps_reduce[:, 0],
                         'y': feature_maps_reduce[:, 1],
                         'ds_1': ds_reduce[:, 0],
                         'ds_2': ds_reduce[:, 1],
                         'di_1': di_reduce[:, 0],
                         'di_2': di_reduce[:, 1],
                         'class': y_label,
                         'domain': domain_label
                        })

import matplotlib as mpl

mpl.rcParams['font.size'] = 14

sns.set_palette('dark')
fig, axes = plt.subplots(1, 2, figsize=(9, 4.5))

sns.scatterplot(data=vis_data, x="ds_1", y="ds_2", hue="domain", legend=False, ax=axes[0])
sns.despine(left=False, bottom=False, right=False, top=False)
axes[0].set(xlabel=None, ylabel=None, yticklabels=[], xticklabels=[])
axes[0].set_title('Domain-specific feature')

sns.scatterplot(data=vis_data, x="di_1", y="di_2", hue="domain", legend='full', ax=axes[1])
sns.despine(left=False, bottom=False, right=False, top=False)
axes[1].set(xlabel=None, ylabel=None, yticklabels=[], xticklabels=[])
axes[1].set_title('Domain-generic feature')
# Target the right-hand panel explicitly instead of relying on plt's "current axes".
axes[1].legend(title='Domain', fontsize=12, bbox_to_anchor=(0.6, -0.05), ncol=4)

fig.savefig('domain_feature.png', bbox_inches="tight", dpi=300)