In [None]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torch.utils.data

import adin 
sys.path.insert(0, '../')
from domainbed import datasets, hparams_registry


# --- configuration -----------------------------------------------------------
dataset_name = 'PACS'
seed = 0
data_dir = ''  # input data dir here
# NOTE(review): env 0 is listed as a test env, yet the cells below train and
# visualise on envs 0, 1, 2 (A, C, P) — confirm this is intended.
test_envs = [0]
holdout_fraction = 0.2  # not used in the cells below
trial_seed = 0  # not used in the cells below

# Seed the RNGs the notebook draws from (DataLoader shuffling uses torch,
# DistributionUncertainty uses numpy and torch) so re-runs are reproducible.
np.random.seed(seed)
torch.manual_seed(seed)

hparams = hparams_registry.default_hparams('ERM', dataset_name)
hparams['data_augmentation'] = False

dataset = vars(datasets)[dataset_name](data_dir, test_envs, hparams)

In [None]:
# Download the pretrained AdaIN weights (vgg_normalised.pth, decoder.pth) into
# pretrained_weights/ — link: https://github.com/naoto0804/pytorch-AdaIN
device = 'cpu'

# map_location lets checkpoints that were saved from a GPU load on a CPU-only
# machine. NOTE(review): torch.load unpickles the file — only load trusted
# weights (recent torch versions also accept weights_only=True).
encoder = adin.vgg
encoder.load_state_dict(torch.load('pretrained_weights/vgg_normalised.pth', map_location=device))
decoder = adin.decoder
decoder.load_state_dict(torch.load('pretrained_weights/decoder.pth', map_location=device))

In [None]:
# Freeze the featurizer (only the heads' parameters are given to the optimizers)
# and train a class classifier and a domain classifier for feature decomposition.
# NOTE(review): whether gradients are still computed through the encoder depends
# on adin.Net (requires_grad of its encoder params) — confirm.
NUM_CLASSES = 7        # number of classes passed to adin.Net (PACS has 7)
NUM_TRAIN_DOMAINS = 3  # envs 0, 1, 2 (A, C, P) are used for training
LR = 5e-4

network = adin.Net(encoder, decoder, NUM_CLASSES, NUM_TRAIN_DOMAINS)


def _env_loaders():
    """Return fresh shuffled DataLoaders for the training envs, in env order."""
    return [
        torch.utils.data.DataLoader(dataset[env], batch_size=hparams['batch_size'], shuffle=True)
        for env in range(NUM_TRAIN_DOMAINS)
    ]


def _domain_labels(batches):
    """Label every sample of the concatenated batch with the index of its env."""
    sizes = torch.tensor([batch[0].size(0) for batch in batches])
    return torch.repeat_interleave(torch.arange(len(batches)), sizes)


def _train_head(params, forward, targets, max_epochs, max_steps):
    """SGD-train ``params`` with cross-entropy on ``forward(all_x)`` vs ``targets(batches)``.

    Each step concatenates one batch from every training env. Training stops
    after ``max_steps`` steps or ``max_epochs`` epochs, whichever comes first.
    The running mean loss of the current epoch is printed every 100 steps.
    """
    optimizer = torch.optim.SGD(params, lr=LR)
    criterion = torch.nn.CrossEntropyLoss()
    step = 0
    for _ in range(max_epochs):
        loss_total = 0
        count = 0
        for batches in zip(*_env_loaders()):
            all_x = torch.cat([batch[0] for batch in batches])
            batch_loss = criterion(forward(all_x), targets(batches))
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            loss_total += batch_loss.item()
            count += 1
            step += 1
            if step % 100 == 0:
                print(loss_total / count)
            if step == max_steps:
                return


_train_head(
    network.class_classifier.parameters(),
    network.forward_class_classifier,
    lambda batches: torch.cat([batch[1] for batch in batches]),
    max_epochs=30,
    max_steps=400,
)
_train_head(
    network.domain_classifier.parameters(),
    network.forward_domain_classifier,
    _domain_labels,
    max_epochs=5,
    max_steps=200,
)

In [None]:
network = network.to('cpu')  # nn.Module.to moves parameters in place and returns the module itself

def show_image(image):
    """Display a channel-first image array with ``plt.imshow``.

    Each channel is min-max scaled to [0, 1] independently. A constant channel
    (max == min) is shown as 0 instead of becoming NaN through 0 / 0.

    Args:
        image: array-like of shape (C, H, W), e.g. ``make_grid(...).numpy()``.
    """
    image = np.asarray(image, dtype=np.float64)
    spatial_axes = tuple(range(1, image.ndim))
    shifted = image - image.min(axis=spatial_axes, keepdims=True)
    span = shifted.max(axis=spatial_axes, keepdims=True)
    # Only divide where the channel actually varies; constant channels stay 0.
    scaled = np.divide(shifted, span, out=np.zeros_like(shifted), where=span > 0)
    plt.imshow(np.transpose(scaled.astype(np.float32), (1, 2, 0)))

In [None]:
def return_image_to_show(image):
    """Convert a channel-first image into a displayable channel-last float32 array.

    Each channel is min-max scaled to [0, 1] independently. A constant channel
    (max == min) becomes 0 instead of NaN through 0 / 0.

    Args:
        image: array-like of shape (C, H, W).

    Returns:
        float32 array of shape (H, W, C) with values in [0, 1].
    """
    image = np.asarray(image, dtype=np.float64)
    spatial_axes = tuple(range(1, image.ndim))
    shifted = image - image.min(axis=spatial_axes, keepdims=True)
    span = shifted.max(axis=spatial_axes, keepdims=True)
    # Only divide where the channel actually varies; constant channels stay 0.
    scaled = np.divide(shifted, span, out=np.zeros_like(shifted), where=span > 0)
    return np.transpose(scaled.astype(np.float32), (1, 2, 0))

In [None]:
# Get one batch from each training env (0, 1, 2), then visualise how it looks
# after augmentation.
VIS_BATCH_SIZE = 32
vis_loaders = [
    torch.utils.data.DataLoader(dataset[env], batch_size=VIS_BATCH_SIZE, shuffle=True)
    for env in range(3)
]
# next() raises StopIteration if an env is empty. A `for ...: break` loop would
# instead silently leave all_x / all_y / domain bound to stale values from an
# earlier cell.
batches = next(zip(*vis_loaders))
all_x = torch.cat([batch[0] for batch in batches])
all_y = torch.cat([batch[1] for batch in batches])
# Domain label of each sample = index of the env its batch came from.
domain = torch.repeat_interleave(
    torch.arange(len(batches)),
    torch.tensor([batch[0].size(0) for batch in batches]),
)

In [None]:
# Encode and decode the batch without any augmentation, for comparison.
# no_grad: this is inference only. It also keeps .numpy() from failing if the
# decoder's parameters require grad.
with torch.no_grad():
    all_features = network.encode(all_x)
    reconstruct = network.get_image(all_features)
fig, ax = plt.subplots(figsize=(20, 20))
show_image(torchvision.utils.make_grid(reconstruct, nrow=8, padding=12).numpy())

In [None]:
from domainbed.mixup_module import DomainClassMixAugmentation

def extract_four_mask(feature_map, class_gradient, domain_gradient, ratio=0.5):
    """Partition the channels of one sample's feature map into four groups.

    Each channel gets a class score (spatially averaged activation times
    ``class_gradient``) and a domain score (same, times ``domain_gradient``).
    A channel is in the "cs" group when its class score is >= the threshold
    from ``DomainClassMixAugmentation.get_threshold(score, ratio)``, otherwise
    "cg". The domain score splits it likewise into "ds" / "di". (Presumably
    class-specific/-generic and domain-specific/-invariant — naming inferred
    from the variable names.)

    Args:
        feature_map: features of one sample; assumed (C, H, W) — TODO confirm.
        class_gradient: 1-D per-channel weights (C,) for the sample's class.
        domain_gradient: 1-D per-channel weights (C,) for the sample's domain.
        ratio: threshold ratio passed to ``get_threshold`` (default 0.5).

    Returns:
        ``(cs_idx, ds_idx, csds_mask, csdi_mask, cgds_mask, cgdi_mask)``, all
        boolean tensors. They are (C, 1, 1) when ``get_threshold`` returns a
        broadcastable scalar, so they broadcast over the spatial dims.
    """
    pooled = torch.mean(feature_map, dim=(1, 2), keepdim=True)  # (C, 1, 1)

    cam = pooled * class_gradient[:, None, None]
    dam = pooled * domain_gradient[:, None, None]
    class_thr = DomainClassMixAugmentation.get_threshold(cam, ratio)
    domain_thr = DomainClassMixAugmentation.get_threshold(dam, ratio)

    cs_idx = cam >= class_thr
    cg_idx = cam < class_thr
    ds_idx = dam >= domain_thr
    di_idx = dam < domain_thr

    # Boolean AND of the class split and the domain split -> four channel groups.
    csds_mask = cs_idx & ds_idx
    csdi_mask = cs_idx & di_idx
    cgds_mask = cg_idx & ds_idx
    cgdi_mask = cg_idx & di_idx
    return cs_idx, ds_idx, csds_mask, csdi_mask, cgds_mask, cgdi_mask

def extract_gradients(network, y, style):
    """Return the classifier weight rows for a sample's class and domain.

    Despite the name, nothing is differentiated here. The returned tensors are
    ``network.class_classifier.weight[y]`` and
    ``network.domain_classifier.weight[style]``.

    Returns:
        ``(class_weights, domain_weights)``.
    """
    class_rows = network.class_classifier.weight
    domain_rows = network.domain_classifier.weight
    return class_rows[y], domain_rows[style]

In [None]:
# Augmentation by us: split every sample's feature channels into four groups
# with extract_four_mask, then rebuild each sample from its own groups and the
# groups of sampled partners.
# Mixing weights: 1.0 keeps the sample's own group, 0.0 takes it from the partner.
CSDS_KEEP = 1.0  # cs/ds channels: keep own (same-class partner contributes 0)
CGDS_KEEP = 0.0  # cg/ds channels: take entirely from the different-class partner

with torch.no_grad():
    all_features = network.encode(all_x)

    # zeros_like matches the dtype/device of the features.
    result = torch.zeros_like(all_features)
    csds = torch.zeros_like(all_features)
    cgds = torch.zeros_like(all_features)
    csdi = torch.zeros_like(all_features)
    cgdi = torch.zeros_like(all_features)

    for i in range(all_x.size(0)):
        current_feature = all_features[i]
        cg, dg = extract_gradients(network, all_y[i], domain[i])
        _, _, csds_f, csdi_f, cgds_f, cgdi_f = extract_four_mask(current_feature, cg, dg)

        # The masks broadcast over the spatial dims, so no hard-coded
        # (512, 28, 28) expand is needed.
        csds[i] = current_feature * csds_f
        csdi[i] = current_feature * csdi_f
        cgds[i] = current_feature * cgds_f
        cgdi[i] = current_feature * cgdi_f

    for b in range(all_x.size(0)):
        # Pass b, the sample being augmented. Reusing `i` here would silently
        # refer to the last index of the loop above.
        diff_y = DomainClassMixAugmentation.sample_different_class_different_domain(b, all_y, domain, all_y[b], domain[b])
        same_y = DomainClassMixAugmentation.sample_same_class_different_domain(b, all_y, domain, all_y[b], domain[b])

        new_csds = CSDS_KEEP * csds[b] + (1 - CSDS_KEEP) * csds[same_y]
        new_cgds = CGDS_KEEP * cgds[b] + (1 - CGDS_KEEP) * cgds[diff_y]

        result[b] = new_csds + new_cgds + csdi[b] + cgdi[b]

In [None]:
# Decode the mixed features back to images and show them as a grid.
# no_grad: this is inference only. It also keeps .numpy() from failing if the
# decoder's parameters require grad.
with torch.no_grad():
    decoded = network.get_image(result)

fig, ax = plt.subplots(figsize=(20, 20))
show_image(torchvision.utils.make_grid(decoded, nrow=8, padding=12).numpy())

In [None]:
import torch.nn as nn
class DistributionUncertainty(nn.Module):
    """Distribution Uncertainty module.

    In training mode, with probability ``p``, each sample's per-channel mean
    and std are re-sampled. Gaussian noise is added whose scale is the std of
    those statistics across the batch, and the features are re-normalised with
    the perturbed statistics. Otherwise the input is returned unchanged.

    Args:
        p (float): probability of applying the perturbation, p in [0, 1].
        eps (float): added to variances before the square root for stability.
    """

    def __init__(self, p=0.5, eps=1e-6):
        super(DistributionUncertainty, self).__init__()
        self.eps = eps
        self.p = p
        self.factor = 1.0  # scales the sampled noise; 0.0 reproduces the input

    def _reparameterize(self, mu, std):
        """Return ``mu + N(0, 1) * factor * std`` sampled element-wise."""
        epsilon = torch.randn_like(std) * self.factor
        return mu + epsilon * std

    def sqrtvar(self, x):
        """Std of ``x`` (N, C) across the batch dim, repeated back to (N, C)."""
        t = (x.var(dim=0, keepdim=True) + self.eps).sqrt()
        t = t.repeat(x.shape[0], 1)
        return t

    def forward(self, x):
        """Perturb the channel statistics of ``x`` of shape (N, C, H, W)."""
        if (not self.training) or (np.random.random()) > self.p:
            return x

        # Unbiased variances need at least two values. A batch of one sample
        # (variance over dim 0) or a 1x1 feature map (variance over H, W) would
        # turn every output value into NaN, so leave such inputs unchanged.
        if x.shape[0] < 2 or x.shape[2] * x.shape[3] < 2:
            return x

        mean = x.mean(dim=[2, 3], keepdim=False)
        std = (x.var(dim=[2, 3], keepdim=False) + self.eps).sqrt()

        sqrtvar_mu = self.sqrtvar(mean)
        sqrtvar_std = self.sqrtvar(std)

        beta = self._reparameterize(mean, sqrtvar_mu)
        gamma = self._reparameterize(std, sqrtvar_std)

        x = (x - mean.reshape(x.shape[0], x.shape[1], 1, 1)) / std.reshape(x.shape[0], x.shape[1], 1, 1)
        x = x * gamma.reshape(x.shape[0], x.shape[1], 1, 1) + beta.reshape(x.shape[0], x.shape[1], 1, 1)

        return x
    
def dsu_out(x, network, p=0.5):
    """Run ``x`` through ``network.enc_1`` ... ``enc_4`` with DSU after every stage.

    A fresh ``DistributionUncertainty(p=p)`` is created for each stage on every
    call. New modules start in training mode, so each stage is perturbed with
    probability ``p`` regardless of ``network``'s own train/eval mode.

    Args:
        x: input image batch.
        network: object exposing the encoder stages ``enc_1`` ... ``enc_4``.
        p: perturbation probability for each stage (default 0.5).

    Returns:
        The (possibly perturbed) output of ``enc_4``.
    """
    out = x
    for stage in (network.enc_1, network.enc_2, network.enc_3, network.enc_4):
        out = DistributionUncertainty(p=p)(stage(out))
    return out

In [None]:
# Decode the DSU-perturbed encoder features and show them as a grid.
# no_grad: this is inference only. It also keeps .numpy() from failing if the
# decoder's parameters require grad.
with torch.no_grad():
    decoded2 = network.get_image(dsu_out(all_x, network))
fig, ax = plt.subplots(figsize=(20, 20))
show_image(torchvision.utils.make_grid(decoded2, nrow=8, padding=12).numpy())