### Make Necessary Imports

In [1]:
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm

from data import MNISTM
from models import Net

If no CUDA-enabled GPU is found, the notebook runs on the CPU.

In [2]:
# Use the default CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

If necessary, change the path to the saved checkpoint of the source-domain-trained model. Available models are RevGrad, ADDA, and WDGRL.

In [3]:
# Checkpoint path for each supported model (relative to the notebook's working
# directory). Select the model to evaluate by setting MODEL_NAME to one of the
# keys below instead of commenting/uncommenting lines.
MODEL_FILES = {
    'revgrad': "trained_models/revgrad.pt",  # RevGrad
    'adda': "trained_models/adda.pt",        # ADDA
    'wdgrl': "trained_models/wdgrl.pt",      # WDGRL
}
MODEL_NAME = 'revgrad'
MODEL_FILE = MODEL_FILES[MODEL_NAME]

# Number of test images fed to the model per forward pass.
batch_size = 256

### Test the model on the target-domain test set

In [4]:
# Evaluate the trained model on the MNIST-M test split (train=False).
dataset = MNISTM(train=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        drop_last=False, num_workers=1,
                        # Pinned host memory only speeds up host-to-GPU copies.
                        pin_memory=(device.type == 'cuda'))

model = Net().to(device)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_FILE, map_location=device))
model.eval()

# Accumulate correct predictions over all samples. Averaging per-batch
# accuracies would give the final batch (which may be smaller, since
# drop_last=False) the same weight as a full batch.
n_correct = 0
n_total = 0
with torch.no_grad():
    for x, y_true in tqdm(dataloader, leave=False):
        x, y_true = x.to(device), y_true.to(device)
        y_pred = model(x)
        n_correct += (y_pred.argmax(dim=1) == y_true).sum().item()
        n_total += y_true.size(0)

mean_accuracy = n_correct / n_total
print(f'Accuracy on target data: {mean_accuracy:.4f}')

                                                                                                                                                    

Accuracy on target data: 0.8107
