# Implementação do paper **Adversarial Discriminative Domain Adaptation**

Esse paper propoe um framework de treinamento adversario, em que a rede de extração de features tenta confundir o discriminador, o qual tenta prever se determinado batch de imagens pertence ao source ou ao target do dominio. Os resultados mostraram a eficiencia em aplicar treinamento adversario em adaptação de dominio.

*Paper disponivel em: https://arxiv.org/abs/1702.05464

In [1]:
from torch import nn
import torch
from torchvision import datasets, transforms
from torchvision.datasets import SVHN, MNIST
import torch.nn.functional as F

True

## 1. Download do Dataset

Para essa implementação, utilizaremos o dataset MNIST e o SVHN. Além disso aplicaremos uma transformação simples, normalizando-o com 0.5 na media e variancia e redimensionando a imagem para o tamanho 28x28 com 3 canais.

A depender de seu hardware, será necessario diminuir o batch_size para poupar recursos no treinamento em GPU/CPU

In [2]:
transform_ = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((28, 28)),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    ]
)

mnist_dataset = datasets.MNIST(
    root="./data/",
    train=True,
    transform=transform_,
    download=True
)

mnist_data_loader = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=128,
    shuffle=True,drop_last = True,
    pin_memory=True, num_workers=7
)

mnist_dataset = datasets.MNIST(
    root="./data/",
    train=False,
    transform=transform_,
    download=True
)

mnist_data_loader_test = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=128,
    shuffle=False,drop_last = True,
    pin_memory=True, num_workers=7
)

In [4]:
usps_dataset = datasets.SVHN(
    root="./data/",
    transform=transforms.Compose([transforms.ToTensor(),transforms.Resize((28, 28))]),
    download=True
)

usps_data_loader_train = torch.utils.data.DataLoader(
    dataset=usps_dataset,
    batch_size=128,
    shuffle=True,drop_last = True,
    pin_memory=True, num_workers=7
)

Using downloaded and verified file: ./data/train_32x32.mat


## 2. Criação dos Modelos

Conforme descrito no paper de Eric Tzeng,, devemos criar tres modelos: discriminador, classificador e extrator de features.

* Extrator de Features: Essa rede é uma sequencia de convoluções que tem como output uma vetor de features de tamanho 100.

* Classificador: A rede de classificação tem como input um vetor de features de tamanho 100 e retorna a probabilidade entre as 10 classes do MNIST (0, 1, 2, ..., 9).

* Discriminador: Para o discriminador, a rede terá como input o vetor de features de tamanho 100 e retornará uma classificação entre duas classes (0 e 1 ou "Source" e "Target").


In [40]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.fully_conec = nn.Linear(100, 10)

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 10, kernel_size=3),
            nn.BatchNorm2d(10),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(10, 25, kernel_size=3),
            nn.BatchNorm2d(25),
            nn.MaxPool2d(kernel_size=4),
            nn.ReLU(),
        )
    
    def forward(self, input):
        output = self.encoder(input).view(-1, 100)
        return self.fully_conec(output)


In [42]:
class Classificador(nn.Module):
    def __init__(self):
        super(Classificador, self).__init__()
        self.fully_connected = nn.Sequential(
            nn.Linear(50, 10)
        )
    
    def forward(self, input):
        out = self.fully_connected(input)
        return out


In [44]:
class CNN(nn.Module):
    def __init__(self, target=False):
        super(CNN, self).__init__()
        self.encoder = Encoder()
        self.classificador = Classificador()
        
        if target:
            for param in self.classificador.parameters():
                param.requires_grad = False
    
    def forward(self, input):
        return self.classificador(self.encoder(input))

In [None]:
discriminator = nn.Sequential(
    nn.Linear(50, 500),
    nn.LeakyReLU(0.2),
    nn.Linear(500, 500),
    nn.LeakyReLU(0.2),
    nn.Linear(500, 2),
).cuda()

## 3. Treinando o Classificador

Primeiramente, vamos treinar a rede de extração de features e o classificador, utilizando o dataset do SVHN para que posteriormente seja possivel adaptar o dominio para o dataset MNIST.

In [46]:
import torch.optim as optim
from tqdm import tqdm

CNN_source = CNN().cuda()

CNN_source.train()

optimizer = optim.Adam(CNN_source.parameters())
criterion = nn.CrossEntropyLoss()

epochs = 10

for epoch in tqdm(range(epochs)):
    for step, (images, ground_truth) in enumerate(usps_data_loader_train):
        images = images.cuda()
        ground_truth = ground_truth.cuda()

        prediction =CNN_source(images)
        
        loss = criterion(prediction, ground_truth)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

print(f" Loss = {loss}")


 24%|██████████████████▏                                                         | 12/50 [01:10<03:41,  5.84s/it]Exception in thread Thread-224:
Traceback (most recent call last):
  File "/home/andre/anaconda3/lib/python3.9/threading.py", line 973, in _bootstrap_inner
    self.run()
  File "/home/andre/anaconda3/lib/python3.9/threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "/home/andre/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 49, in _pin_memory_loop
    do_one_step()
  File "/home/andre/anaconda3/lib/python3.9/site-packages/torch/utils/data/_utils/pin_memory.py", line 26, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
  File "/home/andre/anaconda3/lib/python3.9/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/andre/anaconda3/lib/python3.9/site-packages/torch/multiprocessing/reductions.py", line 305, in rebuild_storage_fd
    fd = df.detac

KeyboardInterrupt: 

Por fim, podemos salvar o modelo treinado.

In [47]:
import shutil

def save(log_dir, state_dict, is_best):
    checkpoint_path = os.path.join(log_dir, 'checkpoint.pt')
    torch.save(state_dict, checkpoint_path)
    if is_best:
        best_model_path = os.path.join(log_dir, 'best_model.pt')
        shutil.copyfile(checkpoint_path, best_model_path)

state_dict = {
    'model': CNN_source.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': 10,
    'val/acc': loss,
}
save("./", state_dict, True)

## 4. Calculando Acuracia para SVHN

In [48]:
from sklearn.metrics import accuracy_score

CNN_source.eval()

accuracy = 0
n = 0
for (images, ground_truth) in usps_data_loader_train:
    n+=1
    images = images.cuda()
    ground_truth = ground_truth.cuda()
    prediction = CNN_source(images)
    pred_max = prediction.data.max(1)[1]
    accuracy += accuracy_score(ground_truth.data.cpu(), pred_max.cpu())


accuracy/n

0.8438592657342657

## 5. Calculando Acuracia para MNIST

In [49]:
CNN_source.eval()

accuracy = 0

for (images, ground_truth) in mnist_data_loader_test:
    images = images.cuda()
    ground_truth = ground_truth.cuda()

    prediction = CNN_source(images)
    pred_max = prediction.data.max(1)[1]
    accuracy += pred_max.eq(ground_truth.data).cpu().sum()

accuracy = accuracy.item()/len(mnist_data_loader_test.dataset)

accuracy

0.4619

Como pudemos ver, para o dataset SVHN, tivemos uma acuracia de ~84%, porem, uma acuracia de 46% para o dataset MNIST.

Por serem datasets com o mesmo dominio, podemos aplicar as tecnicas descritas no paper para melhorar o resultado do MNIST.

## 6. Treinamento com Descriminador

In [53]:
CNN_target = CNN(target=True).cuda()

CNN_target.load_state_dict(CNN_source.state_dict())

discriminator.train()
CNN_target.encoder.train()
CNN_source.eval()


criterion = nn.CrossEntropyLoss()

optimizer_target_encoder = optim.Adam(
    CNN_target.encoder.parameters(),
    lr=3e-4,
    betas=(0.5, 0.999)
)
optimizer_discriminator = optim.Adam(
    discriminator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999)
)


epochs = 50


for epoch in range(epochs):
    discriminator.train()
    CNN_target.encoder.train()
    for step, ((images_target, _), (images_source, _)) in enumerate(zip(mnist_data_loader, usps_data_loader_train)):
        
        
        images_source = images_source.cuda()
        feature_source = CNN_source.encoder(images_source)
        prediction_source = discriminator(feature_source)
        
        images_target = images_target.cuda()
        feature_target = CNN_target.encoder(images_target)
        prediction_target = discriminator(feature_target)
        
        label_source_domain = torch.zeros(images_source.size(0)).long().cuda()
        label_target_domain = torch.ones(images_target.size(0)).long().cuda()
        
        
        # Treinamento do Discriminator
        descriminator_predictions = torch.cat([prediction_source, prediction_target], dim=0)
        descriminator_target_label = torch.cat([label_source_domain, label_target_domain], dim=0)
        descriminator_loss = criterion(descriminator_predictions, descriminator_target_label)
        optimizer_discriminator.zero_grad()
        descriminator_loss.backward()
        optimizer_discriminator.step()
        

        descriminator_target_output = CNN_target.encoder(images_target)
        descrminator_output = discriminator(descriminator_target_output)
        loss_target = criterion(descrminator_output, label_source_domain)
        optimizer_target_encoder.zero_grad()
        loss_target.backward()
        optimizer_target_encoder.step()
        
        
    accuracy = 0
    CNN_target.eval()

    accuracy = 0
    n = 0
    for (images, ground_truth) in mnist_data_loader:
        n+=1
        images = images.cuda()
        ground_truth = ground_truth.cuda()
        prediction = CNN_target(images)
        pred_max = prediction.data.max(1)[1]
        accuracy += accuracy_score(ground_truth.data.cpu(), pred_max.cpu())
        

    print(f"Epoch = {epoch}; Discriminator Loss= {descriminator_loss}; Target Encoder Loss = {loss_target}; Target ACC={accuracy/n}")

Epoch = 0; Discriminator Loss= 0.381192684173584; Target Encoder Loss = 1.6581473350524902; Target ACC=0.3797242254273504
Epoch = 1; Discriminator Loss= 0.4029332101345062; Target Encoder Loss = 2.7219181060791016; Target ACC=0.4188034188034188
Epoch = 2; Discriminator Loss= 0.43447619676589966; Target Encoder Loss = 2.0638458728790283; Target ACC=0.3385750534188034
Epoch = 3; Discriminator Loss= 0.4020487666130066; Target Encoder Loss = 2.1944973468780518; Target ACC=0.3480902777777778
Epoch = 4; Discriminator Loss= 0.4178672134876251; Target Encoder Loss = 2.327063798904419; Target ACC=0.37837206196581197
Epoch = 5; Discriminator Loss= 0.3326684236526489; Target Encoder Loss = 1.8382474184036255; Target ACC=0.34132946047008544
Epoch = 6; Discriminator Loss= 0.3535568416118622; Target Encoder Loss = 1.7790955305099487; Target ACC=0.3447516025641026
Epoch = 7; Discriminator Loss= 0.3385522961616516; Target Encoder Loss = 2.18906307220459; Target ACC=0.3482238247863248
Epoch = 8; Discri

## 7. Calculando Acuracia para MNIST pós adaptação

In [54]:
CNN_target.eval()

accuracy = 0
n = 0
for (images, ground_truth) in mnist_data_loader_test:
    n+=1
    images = images.cuda()
    ground_truth = ground_truth.cuda()

    prediction = CNN_target(images)
    pred_max = prediction.data.max(1)[1]
    accuracy += accuracy_score(ground_truth.data.cpu(), pred_max.cpu())

# accuracy = accuracy.item()/len(mnist_data_loader_test.dataset)

accuracy/n

0.4914863782051282