## Imports

In [None]:
# Pin Deep Lake to 3.x: the cells below use the 3.x API (`deeplake.load`, `ds.filter`).
!pip install "deeplake<4"

In [None]:
import deeplake
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data import ConcatDataset
import torch
import numpy as np
import matplotlib.pyplot as plt
import random

In [None]:
from utils.dataset import MEAN, STD, DeepLakeWrapper

### Loading Dataset using DeepLake

In [None]:
# Office-Home domain-adaptation dataset hosted on Activeloop (loaded via the Deep Lake 3.x API).
ds = deeplake.load("hub://activeloop/office-home-domain-adaptation")

# Domain id -> domain name, kept for reference when reading ``domain_categories``.
domain_map = dict(enumerate(["RealWorld", "Product", "Art", "Clipart"]))

def filter_by_domain(ds, domain_id):
    """Select the samples of ``ds`` that belong to domain ``domain_id``.

    Args:
        ds: Deep Lake dataset exposing a ``domain_categories`` tensor.
        domain_id: Integer domain id to keep (see ``domain_map``).

    Returns:
        Whatever ``ds.filter`` produces for the predicate (a Deep Lake subset).
    """
    def _in_domain(sample):
        # Compare the first entry of the sample's ``domain_categories`` value
        # (presumably a single id per sample) against the requested domain.
        first_id = sample['domain_categories'].numpy()[0]
        return int(first_id) == domain_id

    return ds.filter(_in_domain)
    
# One filtered view per domain (ids as in ``domain_map``), built in the order Art, Clipart, Product, RealWorld.
art_ds, clipart_ds, product_ds, real_ds = (
    filter_by_domain(ds, domain_id) for domain_id in (2, 3, 1, 0)
)

In [None]:
# Wrap each domain view as a PyTorch-style dataset tagged with its domain id.
art_torch, product_torch, real_torch, clipart_torch = (
    DeepLakeWrapper(view, domain_label=label)
    for view, label in ((art_ds, 2), (product_ds, 1), (real_ds, 0), (clipart_ds, 3))
)

In [None]:
# Target domain is Clipart; the source domain is the union of Art, Product and RealWorld.
target_ds = clipart_torch
source_ds = ConcatDataset(datasets=[art_torch, product_torch, real_torch])

In [None]:
# Both domains share the same loader settings: batches of 32, reshuffled each epoch, 2 worker processes.
source_loader, target_loader = (
    DataLoader(split, batch_size=32, shuffle=True, num_workers=2)
    for split in (source_ds, target_ds)
)

In [None]:
def denormalize_batch(imgs_tensor, mean=MEAN, std=STD):
    """Undo per-channel normalization and convert a batch to displayable images.

    Args:
        imgs_tensor: Normalized image tensor of shape (N, C, H, W), i.e. images
            that went through ``(x - mean) / std`` per channel. A single
            (C, H, W) image is also accepted and treated as a batch of one.
        mean: Per-channel means used for normalization (length C, or a scalar).
        std: Per-channel standard deviations used for normalization (length C,
            or a scalar).

    Returns:
        ``np.uint8`` array of shape (N, H, W, C) with values clipped to [0, 255].
    """
    imgs = imgs_tensor.detach().cpu().float()
    if imgs.dim() == 3:
        imgs = imgs.unsqueeze(0)
    # Shape mean/std as (1, C, 1, 1) so they broadcast over the channel axis.
    # Iterating ``imgs`` directly (as a zip would) walks the batch axis instead,
    # which rescales only the first few images by the wrong statistics.
    mean_t = torch.as_tensor(mean, dtype=imgs.dtype).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, dtype=imgs.dtype).view(1, -1, 1, 1)
    # Out-of-place arithmetic, so the caller's tensor is never modified.
    imgs = imgs * std_t + mean_t
    imgs = imgs.permute(0, 2, 3, 1).numpy()  # N, H, W, C; floats roughly in [0, 1]
    return (imgs * 255.0).clip(0, 255).astype(np.uint8)

In [None]:
def show_images(imgs, labels=None, domains=None, title="Batch", max_images=8, class_map=None, domain_map=None):
    """Plot up to ``max_images`` images from a normalized batch in a grid.

    Args:
        imgs: Normalized image tensor (N, C, H, W); converted for display with
            ``denormalize_batch``.
        labels: Optional per-image class labels (tensor elements or ints).
        domains: Optional per-image domain ids (tensor elements or ints).
        title: Figure super-title.
        max_images: Maximum number of images to draw (at most 8 per row).
        class_map: Optional mapping from class id to display name; when omitted
            the caption shows ``Class: <id>``.
        domain_map: Optional mapping from domain id to display name.
    """
    imgs_np = denormalize_batch(imgs)  # uint8 array laid out as (N, H, W, C)
    n = min(max_images, len(imgs_np))
    if n <= 0:
        # Nothing to draw. Without this guard ``cols`` would be 0 (or negative)
        # and the row computation below would divide by zero.
        return
    cols = min(8, n)
    rows = (n + cols - 1) // cols  # ceil(n / cols)

    plt.figure(figsize=(cols * 3, rows * 3))
    for i in range(n):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(imgs_np[i])
        caption_lines = []
        if labels is not None:
            lab = labels[i].item() if hasattr(labels[i], "item") else int(labels[i])
            if class_map is not None:
                caption_lines.append(f"{class_map.get(lab, lab)}")
            else:
                caption_lines.append(f"Class: {lab}")
        if domains is not None:
            d = domains[i].item() if hasattr(domains[i], "item") else int(domains[i])
            dname = domain_map.get(d, d) if domain_map is not None else d
            caption_lines.append(f"Domain: {dname}")
        # Join instead of prefixing "\n", so a domain-only caption has no blank first line.
        plt.title("\n".join(caption_lines), fontsize=9)
        plt.axis("off")
    plt.suptitle(title)
    plt.show()


In [None]:
# Pull one batch from each loader for a quick visual sanity check.
imgs, labels, domains = next(iter(source_loader))
t_imgs, t_labels, t_domains = next(iter(target_loader))

# Source domains first, then the Clipart target.
show_images(
    imgs, labels=labels, domains=domains,
    title="Source batch (Art + Product + Real)", max_images=8, domain_map=domain_map,
)
show_images(
    t_imgs, labels=t_labels, domains=t_domains,
    title="Target Batch (Clipart)", max_images=8, domain_map=domain_map,
)

### Loading Repo

!git clone -b talib-1 http://github.com/Zapy67/ATML-PA-2

In [None]:
!git pull http://github.com/Zapy67/ATML-PA-2 talib-1

In [None]:
# Switch the notebook's working directory into the cloned repository.
%cd ATML-PA-2

In [None]:
# List the files in the current working directory (the cloned repo, after the `%cd` cell).
!ls

## DANN Training