#### Installing the required libraries

In [None]:
# !pip install torchinfo

#### Loading in the required libraries

In [1]:
import torch
import torch.nn as nn
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights
from torchvision.datasets import MNIST, USPS, SVHN
from torchinfo import summary
from tqdm import tqdm
import torch.optim as optim
from sklearn.manifold import TSNE
import numpy as np
import random
from torch.autograd import Function
import time

#### Setting the device

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("The device is: ", device)

The device is:  cpu


## Loading in the Datasets

In [3]:
# Function to create a subset
def create_subset(dataset, subset_size, seed=42):
    """
    Creates a random subset of the dataset.

    Args:
        dataset (Dataset): The original dataset.
        subset_size (int): The number of samples in the subset.
        seed (int): Random seed for reproducibility.

    Returns:
        Subset: A PyTorch Subset object.
    """
    random.seed(seed)
    indices = random.sample(range(len(dataset)), subset_size)
    return Subset(dataset, indices)

#### Office-31

In [4]:
data_path = 'OFFICE31'
amazon_path = os.path.join(data_path, 'amazon')
webcam_path = os.path.join(data_path, 'webcam')
dslr_path = os.path.join(data_path, 'dslr')

print('amazon_path:', amazon_path)
print('webcam_path:', webcam_path)
print('dslr_path:', dslr_path)

def load_data(root_path, domain, batch_size, phase):
    transform_dict = {
        'src': transforms.Compose(
        [transforms.RandomResizedCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
         ]),
        'tar': transforms.Compose(
        [transforms.Resize(224),
         transforms.ToTensor(),
         transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
         ])}
    data = datasets.ImageFolder(root=os.path.join(root_path, domain), transform=transform_dict[phase])
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=phase=='src', drop_last=phase=='tar', num_workers=8)
    return data_loader

amazon_loader = load_data(data_path, 'amazon', 64, 'src')
webcam_loader = load_data(data_path, 'webcam', 64, 'tar')
dslr_loader = load_data(data_path, 'dslr', 64, 'tar')

# Checking the size of these data loaders
print('amazon_loader size:', len(amazon_loader))
print('webcam_loader size:', len(webcam_loader))
print('dslr_loader size:', len(dslr_loader))

# # Check the size of the first batch
# amazon_data = next(iter(amazon_loader))
# webcam_data = next(iter(webcam_loader))
# dslr_data = next(iter(dslr_loader))

# print('amazon_data size:', amazon_data[0].size())
# print('webcam_data size:', webcam_data[0].size())
# print('dslr_data size:', dslr_data[0].size())


amazon_path: OFFICE31/amazon
webcam_path: OFFICE31/webcam
dslr_path: OFFICE31/dslr
amazon_loader size: 45
webcam_loader size: 12
dslr_loader size: 7


#### Digits Datasets

In [8]:
# Define transformations for the datasets
transform_mnist_usps = transforms.Compose([
    transforms.Resize((224, 224)),                  # Resize to 224x224
    transforms.Grayscale(num_output_channels=3),  # Ensure grayscale images (for USPS/MNIST) and convert to 3 channels
    transforms.ToTensor(),                        # Convert to Tensor
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    # transforms.Normalize((0.5,), (0.5,))          # Normalize to [-1, 1]
])

transform_svhn = transforms.Compose([
    transforms.Resize((224, 224)),                  # Resize to 224x224
    transforms.ToTensor(),                        # Convert to Tensor
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    # transforms.Normalize((0.5,), (0.5,))          # Normalize to [-1, 1]
])

batch_size = 64

# Loading in MNIST dataset
mnist_train = MNIST(root='./data', train=True, download=True, transform=transform_mnist_usps)
mnist_test = MNIST(root='./data', train=False, download=True, transform=transform_mnist_usps)

# MNIST dataloaders
mnist_train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
mnist_test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

# Loading in the USPS dataset
usps_train = USPS(root='./data', train=True, download=True, transform=transform_mnist_usps)
usps_test = USPS(root='./data', train=False, download=True, transform=transform_mnist_usps)

# USPS dataloaders
usps_train_loader = DataLoader(usps_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
usps_test_loader = DataLoader(usps_test, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

# Loading in the SVHN dataset
svhn_train = SVHN(root='./data', split='train', download=True, transform=transform_svhn)
svhn_test = SVHN(root='./data', split='test', download=True, transform=transform_svhn)

# SVHN dataloaders
svhn_train_loader = DataLoader(svhn_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
svhn_test_loader = DataLoader(svhn_test, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

# Print dataset sizes for verification
print(f"MNIST Train Size: {len(mnist_train)}")
print(f"MNIST Test Size: {len(mnist_test)}")
print(f"USPS Train Size: {len(usps_train)}")
print(f"USPS Test Size: {len(usps_test)}")
print(f"SVHN Train Size: {len(svhn_train)}")
print(f"SVHN Test Size: {len(svhn_test)}")

Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat
MNIST Train Size: 60000
MNIST Test Size: 10000
USPS Train Size: 7291
USPS Test Size: 2007
SVHN Train Size: 73257
SVHN Test Size: 26032


In [9]:
# Function to create a subset
def create_subset(dataset, subset_size, seed=42):
    """
    Creates a random subset of the dataset.

    Args:
        dataset (Dataset): The original dataset.
        subset_size (int): The number of samples in the subset.
        seed (int): Random seed for reproducibility.

    Returns:
        Subset: A PyTorch Subset object.
    """
    random.seed(seed)
    indices = random.sample(range(len(dataset)), subset_size)
    return Subset(dataset, indices)

# Create subsets of the train and test datasets
mnist_train_subset = create_subset(mnist_train, 10000)
mnist_test_subset = create_subset(mnist_test, 5000)
usps_train_subset = create_subset(usps_train, 5000)
usps_test_subset = create_subset(usps_test, 1000)
svhn_train_subset = create_subset(svhn_train, 10000)
svhn_test_subset = create_subset(svhn_test, 5000)

# Creating the subset dataloaders
mnist_train_subset_loader = DataLoader(mnist_train_subset, batch_size=batch_size, shuffle=True)
mnist_test_subset_loader = DataLoader(mnist_test_subset, batch_size=batch_size, shuffle=False)
usps_train_subset_loader = DataLoader(usps_train_subset, batch_size=batch_size, shuffle=True)
usps_test_subset_loader = DataLoader(usps_test_subset, batch_size=batch_size, shuffle=False)
svhn_train_subset_loader = DataLoader(svhn_train_subset, batch_size=batch_size, shuffle=True)
svhn_test_subset_loader = DataLoader(svhn_test_subset, batch_size=batch_size, shuffle=False)


# Check the size of the subsets
print(f"MNIST Train Subset Size: {len(mnist_train_subset)}")
print(f"MNIST Test Subset Size: {len(mnist_test_subset)}")
print(f"USPS Train Subset Size: {len(usps_train_subset)}")
print(f"USPS Test Subset Size: {len(usps_test_subset)}")
print(f"SVHN Train Subset Size: {len(svhn_train_subset)}")
print(f"SVHN Test Subset Size: {len(svhn_test_subset)}")


MNIST Train Subset Size: 10000
MNIST Test Subset Size: 5000
USPS Train Subset Size: 5000
USPS Test Subset Size: 1000
SVHN Train Subset Size: 10000
SVHN Test Subset Size: 5000


## DANN Model

#### Util functions

In [5]:
class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None

def optimizer_scheduler(optimizer, p):
    """
    Adjust the learning rate of optimizer
    :param optimizer: optimizer for updating parameters
    :param p: a variable for adjusting learning rate
    :return: optimizer
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.01 / (1. + 10 * p) ** 0.75

    return optimizer

#### DANN Class

In [6]:
# class DANN(nn.Module):
#     def __init__(self, num_classes, alpha=1.0):
#         super(DANN, self).__init__()
#         self.alpha = alpha
#         self.feature_extractor = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
#         num_features = self.feature_extractor.fc.in_features
#         self.feature_extractor.fc = nn.Identity()
#         self.class_classifier = nn.Sequential(
#             nn.Linear(num_features, 512),
#             nn.ReLU(),
#             nn.Linear(512, 256),
#             nn.ReLU(),
#             nn.Linear(256, num_classes)
#         )
#         self.domain_classifier = nn.Sequential(
#             nn.Linear(num_features, 512),
#             nn.ReLU(),
#             nn.Linear(512, 256),
#             nn.ReLU(),
#             nn.Linear(256, 1)
#         )

#     def forward(self, x, alpha=None):
#         if alpha is None:
#             alpha = self.alpha
#         features = self.feature_extractor(x)
#         reverse_features = ReverseLayerF.apply(features, alpha)
#         class_output = self.class_classifier(features)
#         domain_output = self.domain_classifier(reverse_features)
#         return class_output, domain_output


# # Create the model
# num_classes = 31
# model = DANN(num_classes, 1.0).to(device)
# summary(model, input_size=(64, 3, 224, 224))

class Extractor(nn.Module):
    def __init__(self):
        super(Extractor, self).__init__()
        self.feature_extractor = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        num_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Identity()

    def forward(self, x):
        features = self.feature_extractor(x)
        return features

class Classifier(nn.Module):
    def __init__(self, num_classes):
        super(Classifier, self).__init__()
        self.class_classifier = nn.Sequential(
            nn.Linear(2048, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, num_classes)
        )

    def forward(self, x):
        class_output = self.class_classifier(x)
        return class_output

class DomainClassifier(nn.Module):
    def __init__(self):
        super(DomainClassifier, self).__init__()
        self.domain_classifier = nn.Sequential(
            nn.Linear(2048, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 2)
        )

    def forward(self, x, alpha):
        reversed_input = ReverseLayerF.apply(x, alpha)
        domain_output = self.domain_classifier(reversed_input)
        return domain_output

# Testing the flow of the model
extractor = Extractor().to(device)
classifier = Classifier(31).to(device)
domain_classifier = DomainClassifier().to(device)

# Testing the flow of the model
x = torch.randn(64, 3, 224, 224).to(device)
features = extractor(x)
print("Feature shape:", features.shape)

class_output = classifier(features)
print("Class output shape:", class_output.shape)

domain_output = domain_classifier(features, 1.0)
print("Domain output shape:", domain_output.shape)

Feature shape: torch.Size([64, 2048])
Class output shape: torch.Size([64, 31])
Domain output shape: torch.Size([64, 2])


#### Defining the training and testing functions

In [None]:
# def train_dann_model(model, source_loader, target_loader, num_epochs, loss_class, loss_domain, optimizer, source_name, target_name, device):
#     model.train()
#     for epoch in range(num_epochs):
#         total_loss = 0.0
#         correct_class = 0
#         correct_domain = 0
#         total_samples = 0

#         source_iter = iter(source_loader)
#         target_iter = iter(target_loader)
#         num_batches = min(len(source_iter), len(target_iter))

#         with tqdm(total=num_batches, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as pbar:
#             for _ in range(num_batches):
#                 source_data, source_labels = next(source_iter)
#                 target_data, _ = next(target_iter)

#                 source_data, source_labels = source_data.to(device), source_labels.to(device)
#                 target_data = target_data.to(device)

#                 optimizer.zero_grad()

#                 # Forward pass
#                 class_output, domain_output = model(source_data)
#                 _, target_domain_output = model(target_data)

#                 # Compute losses
#                 loss_s_label = loss_class(class_output, source_labels)
#                 loss_s_domain = loss_domain(domain_output, torch.zeros_like(domain_output))
#                 loss_t_domain = loss_domain(target_domain_output, torch.ones_like(target_domain_output))

#                 loss = loss_s_label + loss_s_domain + loss_t_domain
#                 total_loss += loss.item()

#                 # Backward pass and optimization
#                 loss.backward()
#                 optimizer.step()

#                 # Compute accuracies
#                 _, predicted_class = torch.max(class_output, 1)
#                 correct_class += (predicted_class == source_labels).sum().item()

#                 _, predicted_domain = torch.max(domain_output, 1)
#                 correct_domain += (predicted_domain == torch.zeros_like(domain_output)).sum().item()

#                 _, predicted_target_domain = torch.max(target_domain_output, 1)
#                 correct_domain += (predicted_target_domain == torch.ones_like(target_domain_output)).sum().item()

#                 total_samples += source_labels.size(0)

#                 pbar.set_postfix(loss=total_loss / total_samples, class_acc=correct_class / total_samples, domain_acc=correct_domain / (2 * total_samples))
#                 pbar.update(1)

#         print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/total_samples:.4f}, Class Accuracy: {correct_class/total_samples:.4f}, Domain Accuracy: {correct_domain/(2*total_samples):.4f}")

def tester(encoder, classifier, discriminator, source_test_loader, target_test_loader, training_mode, device):
    encoder.to(device)
    classifier.to(device)
    
    # Set models to eval mode
    encoder.eval()
    classifier.eval()
    # discriminator.eval()

    if training_mode == 'DANN':
        discriminator.to(device)
        discriminator.eval()
        domain_correct = 0

    source_correct = 0
    target_correct = 0

    for batch_idx, (source_data, target_data) in enumerate(zip(source_test_loader, target_test_loader)):
        p = float(batch_idx) / len(source_test_loader)
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        # Process source and target data
        source_image, source_label = source_image.to(device), source_label.to(device)
        target_image, target_label = target_image.to(device), target_label.to(device)

        # Compute source and target predictions
        source_pred = compute_output(encoder, classifier, source_image, alpha=None)
        target_pred = compute_output(encoder, classifier, target_image, alpha=None)

        # Update correct counts
        source_correct += source_pred.eq(source_label.data.view_as(source_pred)).sum().item()
        target_correct += target_pred.eq(target_label.data.view_as(target_pred)).sum().item()

        if training_mode == 'DANN':
            # Process combined images for domain classification
            combined_image = torch.cat((source_image, target_image), 0)
            domain_labels = torch.cat((torch.zeros(source_label.size(0), dtype=torch.long),
                                       torch.ones(target_label.size(0), dtype=torch.long)), 0).cuda()

            # Compute domain predictions
            domain_pred = compute_output(encoder, discriminator, combined_image, alpha=alpha)
            domain_correct += domain_pred.eq(domain_labels.data.view_as(domain_pred)).sum().item()

    source_dataset_len = len(source_test_loader.dataset)
    target_dataset_len = len(target_test_loader.dataset)

    accuracies = {
        "Source": {
            "correct": source_correct,
            "total": source_dataset_len,
            "accuracy": calculate_accuracy(source_correct, source_dataset_len)
        },
        "Target": {
            "correct": target_correct,
            "total": target_dataset_len,
            "accuracy": calculate_accuracy(target_correct, target_dataset_len)
        }
    }

    if training_mode == 'DANN':
        accuracies["Domain"] = {
            "correct": domain_correct,
            "total": source_dataset_len + target_dataset_len,
            "accuracy": calculate_accuracy(domain_correct, source_dataset_len + target_dataset_len)
        }

    print_accuracy(training_mode, accuracies)

def compute_output(encoder, classifier, images, alpha=None):
    features = encoder(images)
    if isinstance(classifier, DomainClassifier):
        outputs = classifier(features, alpha)  # Domain classifier
    else:
        outputs = classifier(features)  # Category classifier
    preds = outputs.data.max(1, keepdim=True)[1]
    return preds


def calculate_accuracy(correct, total):
    return 100. * correct / total


def print_accuracy(training_mode, accuracies):
    print(f"Test Results on {training_mode}:")
    for key, value in accuracies.items():
        print(f"{key} Accuracy: {value['correct']}/{value['total']} ({value['accuracy']:.2f}%)")

def train_dann(encoder, classifier, discriminator, source_train_loader, target_train_loader, epochs, device):
    print("Training with the DANN adaptation method")

    classifier_criterion = nn.CrossEntropyLoss().to(device)
    discriminator_criterion = nn.CrossEntropyLoss().to(device)

    optimizer = optim.SGD(
        list(encoder.parameters()) +
        list(classifier.parameters()) +
        list(discriminator.parameters()),
        lr=0.01,
        momentum=0.9)

    for epoch in range(epochs):
        print(f"Epoch: {epoch}")
        # Setting the models to train mode
        encoder.train()
        classifier.train()
        discriminator.train()

        start_steps = epoch * len(source_train_loader)
        total_steps = epochs * len(target_train_loader)

        for batch_idx, (source_data, target_data) in enumerate(zip(source_train_loader, target_train_loader)):

            source_image, source_label = source_data
            target_image, target_label = target_data

            p = float(batch_idx + start_steps) / total_steps
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # source_image = torch.cat((source_image, source_image, source_image), 1)

            source_image, source_label = source_image.to(device), source_label.to(device)
            target_image, target_label = target_image.to(device), target_label.to(device)

            print(f"Source image shape: {source_image.shape}")
            print(f"Target image shape: {target_image.shape}")

            combined_image = torch.cat((source_image, target_image), 0)

            optimizer = optimizer_scheduler(optimizer=optimizer, p=p)
            optimizer.zero_grad()

            combined_feature = encoder(combined_image)
            source_feature = encoder(source_image)

            # 1.Classification loss
            class_pred = classifier(source_feature)
            class_loss = classifier_criterion(class_pred, source_label)

            # 2. Domain loss
            domain_pred = discriminator(combined_feature, alpha)

            domain_source_labels = torch.zeros(source_label.shape[0]).type(torch.LongTensor)
            domain_target_labels = torch.ones(target_label.shape[0]).type(torch.LongTensor)
            domain_combined_label = torch.cat((domain_source_labels, domain_target_labels), 0).cuda()
            domain_loss = discriminator_criterion(domain_pred, domain_combined_label)

            total_loss = class_loss + domain_loss
            total_loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 100 == 0:
                print('[{}/{} ({:.0f}%)]\tTotal Loss: {:.4f}\tClassification Loss: {:.4f}\tDomain Loss: {:.4f}'.format(
                    batch_idx * len(target_image), len(target_train_loader.dataset), 100. * batch_idx / len(target_train_loader), total_loss.item(), class_loss.item(), domain_loss.item()))

In [18]:
amazon_loader = load_data(data_path, 'amazon', 64, 'src')
webcam_loader = load_data(data_path, 'webcam', 64, 'tar')

print("Number of labels in Amazon dataset:", len(amazon_loader.dataset.classes))

# Create the model
encoder = Extractor().to(device)
classifier = Classifier(31).to(device)
discriminator = DomainClassifier().to(device)


# Train the model
# train_dann_model(model, amazon_loader, webcam_loader, 10, loss_class, loss_domain, optimizer, "amazon", "webcam", device)
# train_dann_model(model, amazon_loader, webcam_loader, 5, loss_class, loss_domain, optimizer, "amazon", "webcam", device)
train_dann(encoder, classifier, discriminator, amazon_loader, webcam_loader, 5, device)

Number of labels in Amazon dataset: 31
Training with the DANN adaptation method
Epoch: 0
Source image shape: torch.Size([64, 3, 224, 224])
Target image shape: torch.Size([64, 3, 224, 224])


: 