# Downloadin' the Dataset from Kaggle

In [1]:
import os
import sys
from tempfile import NamedTemporaryFile
from urllib.request import urlopen
from urllib.parse import unquote, urlparse
from urllib.error import HTTPError
from zipfile import ZipFile
import tarfile
import shutil

CHUNK_SIZE = 40960
DATA_SOURCE_MAPPING = 'home-office-dataset:https%3A%2F%2Fstorage.googleapis.com%2Fkaggle-data-sets%2F2609958%2F4458388%2Fbundle%2Farchive.zip%3FX-Goog-Algorithm%3DGOOG4-RSA-SHA256%26X-Goog-Credential%3Dgcp-kaggle-com%2540kaggle-161607.iam.gserviceaccount.com%252F20240929%252Fauto%252Fstorage%252Fgoog4_request%26X-Goog-Date%3D20240929T182135Z%26X-Goog-Expires%3D259200%26X-Goog-SignedHeaders%3Dhost%26X-Goog-Signature%3D6335ea9becd6a8fb3ab0ee0f2a593cf71061fe87419776bd67aded05e4a070c929bd28c43a5baa78975f162afc2e1962d2957365a93b663e4983d784677fb7d2c002a4749470c319251673796792bc26d5f46400067981cce63a149b1ff4e1fdecba4c1442d65f04177410e97fbc0fbe2dd89217eeda0cd63a60446a971513b346b08829770702afc0bb01f817fb9fe3a06cf13ba9f1422478d90e1df6f4a44494a013a581eb375f419753f3c006cd2538214b4b65ab5fae37f0f7a0a649eba61920fe96573f68eb29d36102f3a194508fd9b8d52c69be90697ed6630e7233b8994ab4761c899cf23a0598b9156c7f5949e6421f297bf3d2888b8a5629a4882b'

KAGGLE_INPUT_PATH='/kaggle/input'
KAGGLE_WORKING_PATH='/kaggle/working'
KAGGLE_SYMLINK='kaggle'

!umount /kaggle/input/ 2> /dev/null
shutil.rmtree('/kaggle/input', ignore_errors=True)
os.makedirs(KAGGLE_INPUT_PATH, 0o777, exist_ok=True)
os.makedirs(KAGGLE_WORKING_PATH, 0o777, exist_ok=True)

try:
  os.symlink(KAGGLE_INPUT_PATH, os.path.join("..", 'input'), target_is_directory=True)
except FileExistsError:
  pass
try:
  os.symlink(KAGGLE_WORKING_PATH, os.path.join("..", 'working'), target_is_directory=True)
except FileExistsError:
  pass

for data_source_mapping in DATA_SOURCE_MAPPING.split(','):
    directory, download_url_encoded = data_source_mapping.split(':')
    download_url = unquote(download_url_encoded)
    filename = urlparse(download_url).path
    destination_path = os.path.join(KAGGLE_INPUT_PATH, directory)
    try:
        with urlopen(download_url) as fileres, NamedTemporaryFile() as tfile:
            total_length = fileres.headers['content-length']
            print(f'Downloading {directory}, {total_length} bytes compressed')
            dl = 0
            data = fileres.read(CHUNK_SIZE)
            while len(data) > 0:
                dl += len(data)
                tfile.write(data)
                done = int(50 * dl / int(total_length))
                sys.stdout.write(f"\r[{'=' * done}{' ' * (50-done)}] {dl} bytes downloaded")
                sys.stdout.flush()
                data = fileres.read(CHUNK_SIZE)
            if filename.endswith('.zip'):
              with ZipFile(tfile) as zfile:
                zfile.extractall(destination_path)
            else:
              with tarfile.open(tfile.name) as tarfile:
                tarfile.extractall(destination_path)
            print(f'\nDownloaded and uncompressed: {directory}')
    except HTTPError as e:
        print(f'Failed to load (likely expired) {download_url} to path {destination_path}')
        continue
    except OSError as e:
        print(f'Failed to load {download_url} to path {destination_path}')
        continue

print('Data source import complete.')


Downloading home-office-dataset, 1030224236 bytes compressed
Downloaded and uncompressed: home-office-dataset
Data source import complete.


# Necessary Imports

In [2]:
!pip install tensorflow



In [3]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# Generator Block

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),  # Input: 3 channels (RGB), 64x64 -> 32x32
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),  # 8x8 -> 4x4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, latent_dim)
        )

    def forward(self, x):
        return self.network(x)

# Discriminator Block

In [5]:
class Discriminator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.network(z)

# Autoencoder (Debo Bhai)

In [6]:
class Autoencoder(nn.Module):
    def __init__(self, latent_dim=100):
        super(Autoencoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, latent_dim)
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (256, 4, 4)),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
            nn.Tanh()  # Output between [-1, 1]
        )

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z)

# Loss functions

In [7]:
adversarial_loss = nn.BCELoss()  # Binary Cross-Entropy for adversarial loss
reconstruction_loss = nn.MSELoss()  # MSE for reconstruction

# Model Init

In [8]:
latent_dim = 100
generator = Generator(latent_dim)
discriminator = Discriminator(latent_dim)
autoencoder = Autoencoder(latent_dim)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=1e-3)
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-3)
optimizer_AE = optim.Adam(autoencoder.parameters(), lr=1e-3)

# Helper Func (to compute accuracy)

In [9]:
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim

def compute_accuracy(preds, labels):
    predicted = (preds > 0.5).float()
    correct = (predicted == labels).float().sum()
    accuracy = correct / labels.size(0)
    return accuracy

# Domain Classifier

In [10]:
class DomainClassifier(nn.Module):
    def __init__(self, latent_dim=100):
        super(DomainClassifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.classifier(z)

# DA (Hank?) Trainer

In [11]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

data_dir_source = '/kaggle/input/home-office-dataset/OfficeHomeDataset_10072016/Art'
data_dir_target = '/kaggle/input/home-office-dataset/OfficeHomeDataset_10072016/Clipart'

dataset_source = datasets.ImageFolder(root=data_dir_source, transform=transform)
dataset_target = datasets.ImageFolder(root=data_dir_target, transform=transform)

# DataLoaders
dataloader_source = DataLoader(dataset_source, batch_size=64, shuffle=True)
dataloader_target = DataLoader(dataset_target, batch_size=64, shuffle=True)

In [None]:
def train_gan_domain_adaptation(dataloader_source, dataloader_target, n_epochs=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    generator.to(device)
    discriminator.to(device)
    autoencoder.to(device)
    domain_classifier = DomainClassifier(latent_dim).to(device)  # init domain classifier

    # optimizers
    optimizer_G = optim.Adam(generator.parameters(), lr=1e-3)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-3)
    optimizer_AE = optim.Adam(autoencoder.parameters(), lr=1e-3)
    optimizer_DC = optim.Adam(domain_classifier.parameters(), lr=1e-3)  # optimizer for domain classifier

    for epoch in range(n_epochs):
        total_dc_loss = 0.0
        total_d_loss = 0.0
        total_g_loss = 0.0
        total_ae_loss = 0.0
        total_dc_accuracy = 0.0
        total_batches = 0

        loop = tqdm(zip(dataloader_source, dataloader_target), total=min(len(dataloader_source), len(dataloader_target)))
        loop.set_description(f"Epoch [{epoch+1}/{n_epochs}]")

        for (source_images, _), (target_images, _) in loop:
            source_images = source_images.to(device)
            target_images = target_images.to(device)

            # for batch size consistency
            min_batch_size = min(source_images.size(0), target_images.size(0))
            source_images = source_images[:min_batch_size]
            target_images = target_images[:min_batch_size]

            # =======================
            # 1. train the discriminator

            optimizer_D.zero_grad()

            target_latent_fake = generator(target_images)
            source_latent_real = autoencoder.encode(source_images)  # real latent vectors from source

            real_labels = torch.ones(min_batch_size, 1).to(device)
            fake_labels = torch.zeros(min_batch_size, 1).to(device)

            # discriminator loss on real and fake
            d_loss_real = adversarial_loss(discriminator(source_latent_real), real_labels)
            d_loss_fake = adversarial_loss(discriminator(target_latent_fake.detach()), fake_labels)
            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            optimizer_D.step()

            # discriminator Accuracy (real vs fake)
            real_accuracy = compute_accuracy(discriminator(source_latent_real), real_labels)
            fake_accuracy = compute_accuracy(discriminator(target_latent_fake.detach()), fake_labels)
            d_accuracy = (real_accuracy + fake_accuracy) / 2

            # =======================
            # 2. train the Generator

            optimizer_G.zero_grad()

            g_loss = adversarial_loss(discriminator(target_latent_fake), real_labels)
            g_loss.backward()
            optimizer_G.step()

            # =======================
            # 3. train the autoencoder (for source domain reconstruction)

            optimizer_AE.zero_grad()

            source_reconstructed = autoencoder(source_images)  # reconstruct source images
            ae_loss = reconstruction_loss(source_reconstructed, source_images)  # reconstruction loss
            ae_loss.backward()
            optimizer_AE.step()

            # =======================
            # 4. train the Domain Classifier

            optimizer_DC.zero_grad()

            # Domain classification: 0 for source, 1 for target
            domain_labels_source = torch.zeros(min_batch_size, 1).to(device)
            domain_labels_target = torch.ones(min_batch_size, 1).to(device)

            # train domain classifier on both source and target latent vectors
            domain_pred_source = domain_classifier(source_latent_real.detach())
            domain_pred_target = domain_classifier(target_latent_fake.detach())

            # domain classifier loss (BCE)
            dc_loss_source = adversarial_loss(domain_pred_source, domain_labels_source)
            dc_loss_target = adversarial_loss(domain_pred_target, domain_labels_target)
            dc_loss = (dc_loss_source + dc_loss_target) / 2
            dc_loss.backward()
            optimizer_DC.step()

            # domain classifier accuracy
            dc_accuracy_source = compute_accuracy(domain_pred_source, domain_labels_source)
            dc_accuracy_target = compute_accuracy(domain_pred_target, domain_labels_target)
            dc_accuracy = (dc_accuracy_source + dc_accuracy_target) / 2

            # accumulation of loss and accuracy
            total_dc_loss += dc_loss.item()
            total_d_loss += d_loss.item()
            total_g_loss += g_loss.item()
            total_ae_loss += ae_loss.item()
            total_dc_accuracy += dc_accuracy.item()
            total_batches += 1

            loop.set_postfix(
                D_loss=total_d_loss/total_batches,
                G_loss=total_g_loss/total_batches,
                AE_loss=total_ae_loss/total_batches,
                DC_loss=total_dc_loss/total_batches,
                DC_acc=total_dc_accuracy/total_batches
            )

        print(f"Epoch [{epoch+1}/{n_epochs}] Summary:")
        print(f"  Discriminator Loss: {total_d_loss/total_batches:.4f}")
        print(f"  Generator Loss: {total_g_loss/total_batches:.4f}")
        print(f"  Autoencoder Loss: {total_ae_loss/total_batches:.4f}")
        print(f"  Domain Classifier Loss: {total_dc_loss/total_batches:.4f}")
        print(f"  Domain Classifier Accuracy: {total_dc_accuracy/total_batches:.4f}")

train_gan_domain_adaptation(dataloader_source, dataloader_target, n_epochs=10)

Epoch [1/10]: 100%|██████████| 38/38 [01:17<00:00,  2.04s/it, AE_loss=0.142, DC_acc=0.981, DC_loss=0.198, D_loss=0.0571, G_loss=3.51]


Epoch [1/10] Summary:
  Discriminator Loss: 0.0571
  Generator Loss: 3.5098
  Autoencoder Loss: 0.1420
  Domain Classifier Loss: 0.1980
  Domain Classifier Accuracy: 0.9811


Epoch [2/10]: 100%|██████████| 38/38 [01:19<00:00,  2.09s/it, AE_loss=0.0919, DC_acc=1, DC_loss=0.0223, D_loss=0.02, G_loss=4.15]


Epoch [2/10] Summary:
  Discriminator Loss: 0.0200
  Generator Loss: 4.1522
  Autoencoder Loss: 0.0919
  Domain Classifier Loss: 0.0223
  Domain Classifier Accuracy: 0.9996


Epoch [3/10]: 100%|██████████| 38/38 [01:25<00:00,  2.25s/it, AE_loss=0.0876, DC_acc=0.999, DC_loss=0.0323, D_loss=0.0326, G_loss=4]


Epoch [3/10] Summary:
  Discriminator Loss: 0.0326
  Generator Loss: 3.9963
  Autoencoder Loss: 0.0876
  Domain Classifier Loss: 0.0323
  Domain Classifier Accuracy: 0.9985


Epoch [4/10]: 100%|██████████| 38/38 [01:21<00:00,  2.14s/it, AE_loss=0.0851, DC_acc=0.999, DC_loss=0.0121, D_loss=0.0138, G_loss=4.61]


Epoch [4/10] Summary:
  Discriminator Loss: 0.0138
  Generator Loss: 4.6070
  Autoencoder Loss: 0.0851
  Domain Classifier Loss: 0.0121
  Domain Classifier Accuracy: 0.9994


Epoch [5/10]: 100%|██████████| 38/38 [01:18<00:00,  2.08s/it, AE_loss=0.0832, DC_acc=0.999, DC_loss=0.0121, D_loss=0.0146, G_loss=4.4]


Epoch [5/10] Summary:
  Discriminator Loss: 0.0146
  Generator Loss: 4.3976
  Autoencoder Loss: 0.0832
  Domain Classifier Loss: 0.0121
  Domain Classifier Accuracy: 0.9990


Epoch [6/10]: 100%|██████████| 38/38 [01:22<00:00,  2.17s/it, AE_loss=0.0809, DC_acc=0.999, DC_loss=0.0152, D_loss=0.0131, G_loss=4.37]


Epoch [6/10] Summary:
  Discriminator Loss: 0.0131
  Generator Loss: 4.3682
  Autoencoder Loss: 0.0809
  Domain Classifier Loss: 0.0152
  Domain Classifier Accuracy: 0.9992


Epoch [7/10]: 100%|██████████| 38/38 [01:23<00:00,  2.19s/it, AE_loss=0.0793, DC_acc=0.999, DC_loss=0.0143, D_loss=0.0122, G_loss=4.2]


Epoch [7/10] Summary:
  Discriminator Loss: 0.0122
  Generator Loss: 4.1980
  Autoencoder Loss: 0.0793
  Domain Classifier Loss: 0.0143
  Domain Classifier Accuracy: 0.9994


Epoch [8/10]: 100%|██████████| 38/38 [01:18<00:00,  2.07s/it, AE_loss=0.078, DC_acc=1, DC_loss=0.00685, D_loss=0.00653, G_loss=4.99]


Epoch [8/10] Summary:
  Discriminator Loss: 0.0065
  Generator Loss: 4.9872
  Autoencoder Loss: 0.0780
  Domain Classifier Loss: 0.0068
  Domain Classifier Accuracy: 0.9998


Epoch [9/10]: 100%|██████████| 38/38 [01:21<00:00,  2.14s/it, AE_loss=0.0768, DC_acc=0.999, DC_loss=0.00531, D_loss=0.00349, G_loss=5.5]


Epoch [9/10] Summary:
  Discriminator Loss: 0.0035
  Generator Loss: 5.4977
  Autoencoder Loss: 0.0768
  Domain Classifier Loss: 0.0053
  Domain Classifier Accuracy: 0.9994


Epoch [10/10]: 100%|██████████| 38/38 [01:17<00:00,  2.05s/it, AE_loss=0.0751, DC_acc=1, DC_loss=0.00451, D_loss=0.00206, G_loss=6.05]

Epoch [10/10] Summary:
  Discriminator Loss: 0.0021
  Generator Loss: 6.0477
  Autoencoder Loss: 0.0751
  Domain Classifier Loss: 0.0045
  Domain Classifier Accuracy: 0.9996





# V2.0 Starts

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim

class SourceClassifier(nn.Module):
    def __init__(self, latent_dim=100, num_classes=10):
        super(SourceClassifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, z):
        return self.classifier(z)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

source_classifier = SourceClassifier(latent_dim=100, num_classes=10).to(device)
source_criterion = nn.CrossEntropyLoss()
optimizer_SC = optim.Adam(source_classifier.parameters(), lr=1e-3)

In [15]:
import torch

def train_source_classifier(dataloader_source, n_epochs=10):
    autoencoder.eval()
    for epoch in range(n_epochs):
        total_loss = 0
        for source_images, source_labels in dataloader_source:
            source_images = source_images.to(device)
            source_labels = source_labels.to(device)

            source_latent = autoencoder.encode(source_images)
            output = source_classifier(source_latent)
            source_labels = torch.clamp(source_labels, 0, source_classifier.classifier[-1].out_features - 1)

            loss = source_criterion(output, source_labels)
            optimizer_SC.zero_grad()
            loss.backward()
            optimizer_SC.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{n_epochs}], Loss: {total_loss/len(dataloader_source):.4f}")

train_source_classifier(dataloader_source, n_epochs=10)

Epoch [1/10], Loss: 2.2142
Epoch [2/10], Loss: 1.2422
Epoch [3/10], Loss: 0.8735
Epoch [4/10], Loss: 0.8630
Epoch [5/10], Loss: 0.8627
Epoch [6/10], Loss: 0.8629
Epoch [7/10], Loss: 0.8636
Epoch [8/10], Loss: 0.8623
Epoch [9/10], Loss: 0.8626
Epoch [10/10], Loss: 0.8624


In [18]:
import torch
import torch.nn as nn
import torch.optim as optim

class DomainClassifier(nn.Module):
    def __init__(self, latent_dim=100):
        super(DomainClassifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.classifier(z)

latent_dim = 100
domain_classifier = DomainClassifier(latent_dim).to(device)

adversarial_loss = nn.BCELoss()

In [19]:
def train_domain_adaptation(dataloader_source, dataloader_target, n_epochs=10):
    generator.to(device)
    domain_classifier.to(device)

    optimizer_G = optim.Adam(generator.parameters(), lr=1e-3)
    optimizer_DC = optim.Adam(domain_classifier.parameters(), lr=1e-3)

    for epoch in range(n_epochs):
        for (source_images, _), (target_images, _) in zip(dataloader_source, dataloader_target):
            source_images = source_images.to(device)
            target_images = target_images.to(device)

            min_batch_size = min(source_images.size(0), target_images.size(0))
            source_images = source_images[:min_batch_size]
            target_images = target_images[:min_batch_size]

            source_latent = autoencoder.encode(source_images)
            target_latent = generator(target_images)

            # =======================
            # 1. Train the Domain Classifier
            optimizer_DC.zero_grad()

            domain_labels_source = torch.zeros(min_batch_size, 1).to(device)  # Source: 0
            domain_labels_target = torch.ones(min_batch_size, 1).to(device)   # Target: 1

            domain_pred_source = domain_classifier(source_latent.detach())
            domain_pred_target = domain_classifier(target_latent.detach())

            dc_loss_source = adversarial_loss(domain_pred_source, domain_labels_source)
            dc_loss_target = adversarial_loss(domain_pred_target, domain_labels_target)
            dc_loss = (dc_loss_source + dc_loss_target) / 2

            dc_loss.backward()
            optimizer_DC.step()

            # =======================
            # 2. Train the Generator (to fool Domain Classifier)
            optimizer_G.zero_grad()

            domain_pred_target = domain_classifier(target_latent)
            g_loss = adversarial_loss(domain_pred_target, domain_labels_source)

            g_loss.backward()
            optimizer_G.step()

        print(f"Epoch [{epoch+1}/{n_epochs}], G_loss: {g_loss.item():.4f}, DC_loss: {dc_loss.item():.4f}")

# domain adaptation model
train_domain_adaptation(dataloader_source, dataloader_target, n_epochs=10)

Epoch [1/10], G_loss: 0.6652, DC_loss: 0.6984
Epoch [2/10], G_loss: 0.6735, DC_loss: 0.6996
Epoch [3/10], G_loss: 0.6886, DC_loss: 0.6942
Epoch [4/10], G_loss: 0.6953, DC_loss: 0.6872
Epoch [5/10], G_loss: 0.6918, DC_loss: 0.6905
Epoch [6/10], G_loss: 0.7172, DC_loss: 0.6833
Epoch [7/10], G_loss: 0.7374, DC_loss: 0.6781
Epoch [8/10], G_loss: 0.7051, DC_loss: 0.6814
Epoch [9/10], G_loss: 0.7988, DC_loss: 0.6148
Epoch [10/10], G_loss: 0.9068, DC_loss: 0.6001


In [20]:
def evaluate_on_target(dataloader_target):
    source_classifier.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for target_images, target_labels in dataloader_target:
            target_images = target_images.to(device)
            target_labels = target_labels.to(device)

            target_latent = generator(target_images)
            outputs = source_classifier(target_latent)
            _, predicted = torch.max(outputs, 1)

            total += target_labels.size(0)
            correct += (predicted == target_labels).sum().item()

    print(f"Accuracy on target domain: {100 * correct / total:.2f}%")

evaluate_on_target(dataloader_target)

Accuracy on target domain: 2.27%


In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np

class Generator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, latent_dim)
        )

    def forward(self, x):
        return self.network(x)

class Discriminator(nn.Module):
    def __init__(self, latent_dim=100):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.network(z)

def compute_accuracy(preds, labels):
    predicted = (preds > 0.5).float()
    correct = (predicted == labels).float().sum()
    accuracy = correct / labels.size(0)
    return accuracy.item()

def train_gan_domain_adaptation(dataloader_source, dataloader_target, n_epochs=10):
    generator.to(device)
    discriminator.to(device)

    optimizer_G = optim.Adam(generator.parameters(), lr=1e-3)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-3)
    adversarial_loss = nn.BCELoss()

    for epoch in range(n_epochs):
        total_d_loss = 0.0
        total_g_loss = 0.0
        total_accuracy = 0.0
        total_batches = 0

        loop = tqdm(zip(dataloader_source, dataloader_target), total=min(len(dataloader_source), len(dataloader_target)))
        loop.set_description(f"Epoch [{epoch+1}/{n_epochs}]")

        for (source_images, _), (target_images, _) in loop:
            source_images = source_images.to(device)
            target_images = target_images.to(device)

            min_batch_size = min(source_images.size(0), target_images.size(0))
            source_images = source_images[:min_batch_size]
            target_images = target_images[:min_batch_size]

            # =======================
            # 1. Train the Discriminator
            optimizer_D.zero_grad()

            # Generate latent vectors for the target domain using the Generator
            target_latent_fake = generator(target_images)
            source_latent_real = generator(source_images)

            real_labels = torch.ones(min_batch_size, 1).to(device)
            fake_labels = torch.zeros(min_batch_size, 1).to(device)

            d_loss_real = adversarial_loss(discriminator(source_latent_real), real_labels)
            d_loss_fake = adversarial_loss(discriminator(target_latent_fake.detach()), fake_labels)
            d_loss = (d_loss_real + d_loss_fake) / 2
            d_loss.backward()
            optimizer_D.step()

            real_accuracy = compute_accuracy(discriminator(source_latent_real), real_labels)
            fake_accuracy = compute_accuracy(discriminator(target_latent_fake.detach()), fake_labels)
            d_accuracy = (real_accuracy + fake_accuracy) / 2

            # =======================
            # 2. Train the Generator (to fool the Discriminator)
            optimizer_G.zero_grad()

            g_loss = adversarial_loss(discriminator(target_latent_fake), real_labels)
            g_loss.backward()
            optimizer_G.step()

            total_d_loss += d_loss.item()
            total_g_loss += g_loss.item()
            total_accuracy += d_accuracy
            total_batches += 1

            loop.set_postfix(
                D_loss=total_d_loss/total_batches,
                G_loss=total_g_loss/total_batches,
                D_accuracy=total_accuracy/total_batches
            )

        print(f"Epoch [{epoch+1}/{n_epochs}] Summary:")
        print(f"  Discriminator Loss: {total_d_loss/total_batches:.4f}")
        print(f"  Generator Loss: {total_g_loss/total_batches:.4f}")
        print(f"  Discriminator Accuracy: {total_accuracy/total_batches:.4f}")

latent_dim = 100
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator(latent_dim=latent_dim)

train_gan_domain_adaptation(dataloader_source, dataloader_target, n_epochs=10)

Epoch [1/10]: 100%|██████████| 38/38 [00:55<00:00,  1.47s/it, D_accuracy=0.509, D_loss=2.89, G_loss=1.18]


Epoch [1/10] Summary:
  Discriminator Loss: 2.8905
  Generator Loss: 1.1805
  Discriminator Accuracy: 0.5093


Epoch [2/10]: 100%|██████████| 38/38 [00:56<00:00,  1.49s/it, D_accuracy=0.5, D_loss=0.693, G_loss=0.68]


Epoch [2/10] Summary:
  Discriminator Loss: 0.6933
  Generator Loss: 0.6803
  Discriminator Accuracy: 0.5000


Epoch [3/10]: 100%|██████████| 38/38 [00:55<00:00,  1.47s/it, D_accuracy=0.501, D_loss=0.693, G_loss=0.693]


Epoch [3/10] Summary:
  Discriminator Loss: 0.6931
  Generator Loss: 0.6931
  Discriminator Accuracy: 0.5015


Epoch [4/10]: 100%|██████████| 38/38 [00:55<00:00,  1.45s/it, D_accuracy=0.503, D_loss=0.693, G_loss=0.694]


Epoch [4/10] Summary:
  Discriminator Loss: 0.6931
  Generator Loss: 0.6935
  Discriminator Accuracy: 0.5031


Epoch [5/10]: 100%|██████████| 38/38 [00:55<00:00,  1.46s/it, D_accuracy=0.501, D_loss=0.693, G_loss=0.694]


Epoch [5/10] Summary:
  Discriminator Loss: 0.6931
  Generator Loss: 0.6938
  Discriminator Accuracy: 0.5014


Epoch [6/10]: 100%|██████████| 38/38 [00:55<00:00,  1.46s/it, D_accuracy=0.5, D_loss=0.693, G_loss=0.694]


Epoch [6/10] Summary:
  Discriminator Loss: 0.6931
  Generator Loss: 0.6943
  Discriminator Accuracy: 0.5002


Epoch [7/10]: 100%|██████████| 38/38 [00:56<00:00,  1.48s/it, D_accuracy=0.503, D_loss=0.693, G_loss=0.694]


Epoch [7/10] Summary:
  Discriminator Loss: 0.6931
  Generator Loss: 0.6939
  Discriminator Accuracy: 0.5029


Epoch [8/10]: 100%|██████████| 38/38 [00:55<00:00,  1.47s/it, D_accuracy=0.5, D_loss=0.693, G_loss=0.694]


Epoch [8/10] Summary:
  Discriminator Loss: 0.6931
  Generator Loss: 0.6942
  Discriminator Accuracy: 0.5004


Epoch [9/10]: 100%|██████████| 38/38 [00:55<00:00,  1.47s/it, D_accuracy=0.501, D_loss=0.693, G_loss=0.694]


Epoch [9/10] Summary:
  Discriminator Loss: 0.6931
  Generator Loss: 0.6944
  Discriminator Accuracy: 0.5009


Epoch [10/10]: 100%|██████████| 38/38 [00:56<00:00,  1.49s/it, D_accuracy=0.5, D_loss=0.693, G_loss=0.695]

Epoch [10/10] Summary:
  Discriminator Loss: 0.6931
  Generator Loss: 0.6948
  Discriminator Accuracy: 0.5000





In [22]:
class SourceClassifier(nn.Module):
    def __init__(self, latent_dim=100, num_classes=10):
        super(SourceClassifier, self).__init__()
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, z):
        return self.classifier(z)


In [24]:
def train_source_classifier(dataloader_source, n_epochs=10):
    source_classifier.to(device)
    generator.eval()

    classification_loss = nn.CrossEntropyLoss()
    optimizer_SC = optim.Adam(source_classifier.parameters(), lr=1e-3)

    for epoch in range(n_epochs):
        total_loss = 0
        correct = 0
        total = 0

        loop = tqdm(dataloader_source, total=len(dataloader_source))
        loop.set_description(f"Train Source Classifier - Epoch [{epoch+1}/{n_epochs}]")

        for source_images, source_labels in loop:
            source_images = source_images.to(device)
            source_labels = source_labels.to(device)

            source_labels = torch.clamp(source_labels, 0, 9)

            source_latent = generator(source_images)

            outputs = source_classifier(source_latent)
            loss = classification_loss(outputs, source_labels)

            optimizer_SC.zero_grad()
            loss.backward()
            optimizer_SC.step()

            _, predicted = torch.max(outputs.data, 1)
            total += source_labels.size(0)
            correct += (predicted == source_labels).sum().item()

            total_loss += loss.item()

            loop.set_postfix(loss=total_loss/total, accuracy=100*correct/total)

    print(f"Final Accuracy on Source Domain: {100*correct/total:.2f}%")

source_classifier = SourceClassifier(latent_dim=100, num_classes=10).to(device)
train_source_classifier(dataloader_source, n_epochs=10)

Train Source Classifier - Epoch [1/10]: 100%|██████████| 38/38 [00:33<00:00,  1.12it/s, accuracy=77.6, loss=0.0206]
Train Source Classifier - Epoch [2/10]: 100%|██████████| 38/38 [00:34<00:00,  1.10it/s, accuracy=81.5, loss=0.0137]
Train Source Classifier - Epoch [3/10]: 100%|██████████| 38/38 [00:33<00:00,  1.13it/s, accuracy=81.5, loss=0.0136]
Train Source Classifier - Epoch [4/10]: 100%|██████████| 38/38 [00:33<00:00,  1.12it/s, accuracy=81.5, loss=0.0136]
Train Source Classifier - Epoch [5/10]: 100%|██████████| 38/38 [00:33<00:00,  1.13it/s, accuracy=81.5, loss=0.0136]
Train Source Classifier - Epoch [6/10]: 100%|██████████| 38/38 [00:34<00:00,  1.12it/s, accuracy=81.5, loss=0.0136]
Train Source Classifier - Epoch [7/10]: 100%|██████████| 38/38 [00:33<00:00,  1.13it/s, accuracy=81.5, loss=0.0136]
Train Source Classifier - Epoch [8/10]: 100%|██████████| 38/38 [00:34<00:00,  1.11it/s, accuracy=81.5, loss=0.0136]
Train Source Classifier - Epoch [9/10]: 100%|██████████| 38/38 [00:33<00

Final Accuracy on Source Domain: 81.50%





In [25]:
def evaluate_on_target(dataloader_target):
    source_classifier.eval()
    generator.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        loop = tqdm(dataloader_target, total=len(dataloader_target))
        loop.set_description(f"Evaluating on Target Domain")

        for target_images, target_labels in loop:
            target_images = target_images.to(device)
            target_labels = target_labels.to(device)

            target_latent = generator(target_images)

            outputs = source_classifier(target_latent)

            _, predicted = torch.max(outputs.data, 1)
            total += target_labels.size(0)
            correct += (predicted == target_labels).sum().item()

            loop.set_postfix(accuracy=100*correct/total)

    print(f"Final Accuracy on Target Domain: {100*correct/total:.2f}%")

evaluate_on_target(dataloader_target)

Evaluating on Target Domain: 100%|██████████| 69/69 [00:24<00:00,  2.84it/s, accuracy=2.27]

Final Accuracy on Target Domain: 2.27%



