In [1]:
cd ATML-PA-2/DAN/

/kaggle/working/ATML-PA-2/DAN


In [91]:
import importlib
import torch
import utils
import architecture
import train
import types
# Re-execute the three local modules so edits made on disk are picked up without a kernel restart.
# NOTE(review): importlib.reload() only re-executes the module object it is given. The
# `from utils.dataset import ...` / `from utils.tsne import ...` lines below bind names from
# submodules that are not reloaded explicitly here — confirm that edits to those files are
# actually picked up (or switch to `%autoreload 2`).
importlib.reload(utils)
importlib.reload(architecture)
importlib.reload(train)


from utils.dataset import FeatureTensorDataset, OfficeHomeDataset, _truncate_resnet_from
from torchvision.models import ResNet50_Weights
from importlib import reload
import train
from train import train_workflow
import types
from utils.tsne import tsne_plot, get_features
import matplotlib.pyplot as plt
import os
from architecture import resnet_classifier

In [93]:
# Dataset root and the CSV index handed to OfficeHomeDataset.
root_dir = "../datasets/OfficeHomeDataset"
csv_file = f"{root_dir}/ImageInfo.csv"

# Preprocessing pipeline that ships with the ImageNet-1k V1 ResNet-50 weights.
transform = ResNet50_Weights.IMAGENET1K_V1.transforms()

# Entries of the "Art" domain folder are treated as the class list (len(classes) is used as
# num_classes). os.listdir() returns names in arbitrary order and also lists plain files,
# so keep sub-directories only and sort them to get a stable, correctly sized list.
# NOTE(review): assumes every domain folder has the same class sub-directories as "Art".
classes = sorted(
    entry
    for entry in os.listdir(f"{root_dir}/Art")
    if os.path.isdir(os.path.join(root_dir, "Art", entry))
)

In [94]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [99]:
import torch
import numpy as np
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    accuracy_score
)
import seaborn as sns

def evaluate_ds(ds, model, device, domain_name=None, class_names=None):
    """Run `model` over `ds`, print classification metrics and return them.

    The model is put in eval mode and the dataset is read in its stored order
    (no shuffling) in batches of 32 under `torch.inference_mode()`. The
    predicted class is the argmax of the model output along dim 1.

    Args:
        ds: Dataset whose items unpack as (input, label) pairs.
        model: Module called on each input batch; its output is argmax-ed over dim 1.
        device: Device each batch is moved to before the forward pass.
        domain_name: Optional label printed in the header and echoed in the result.
        class_names: Optional display names forwarded to `classification_report`
            as `target_names`.

    Returns:
        dict with keys 'domain', 'accuracy', 'confusion_matrix',
        'classification_report', 'y_true' and 'y_pred'.
    """
    model.eval()
    batches = torch.utils.data.DataLoader(ds, shuffle=False, batch_size=32)

    pred_chunks, label_chunks = [], []
    with torch.inference_mode():
        for inputs, targets in batches:
            inputs = inputs.to(device)
            targets = targets.to(device)
            scores = model(inputs)
            # Collect on the CPU so the concatenated results can become numpy arrays.
            pred_chunks.append(torch.argmax(scores, dim=1).cpu())
            label_chunks.append(targets.cpu())

    y_true = torch.cat(label_chunks).numpy()
    y_pred = torch.cat(pred_chunks).numpy()

    # All metrics are computed before anything is printed.
    acc = accuracy_score(y_true, y_pred)
    conf_mat = confusion_matrix(y_true, y_pred)
    report_text = classification_report(
        y_true, y_pred, target_names=class_names, digits=4
    )

    if domain_name:
        print(f"\n=== Evaluation for {domain_name} Domain ===")
    print(f"Accuracy: {acc*100:.2f}%")
    print("Confusion Matrix:\n")
    sns.heatmap(conf_mat, cmap="Blues")
    plt.show()
    print("Classification Report:\n", report_text)

    return {
        'domain': domain_name,
        'accuracy': acc,
        'confusion_matrix': conf_mat,
        'classification_report': report_text,
        'y_true': y_true,
        'y_pred': y_pred,
    }


In [89]:
import matplotlib.pyplot as plt

def analysis(source, target, config, cache_source=None, cache_target=None):
    """Train the truncated ResNet on source/target features, then evaluate and plot a t-SNE.

    Steps:
      1. Build a classifier with `resnet_classifier` and, unless caches are supplied,
         precompute 'layer3' features for both domains with `FeatureTensorDataset`.
      2. Take `_truncate_resnet_from(model, 'layer3')` and train it with `train_workflow`
         on 80/20 random splits of the cached features.
      3. Print metrics (via `evaluate_ds`) for the target test, source test and
         source train splits.
      4. Plot a t-SNE of the source and target test features, coloured by domain.

    Args:
        source: Source domain name(s), forwarded to `OfficeHomeDataset`.
        target: Target domain name(s), forwarded to `OfficeHomeDataset`.
        config: Training configuration passed through unchanged to `train_workflow`.
        cache_source: Optional precomputed source feature dataset; built here when None.
        cache_target: Optional precomputed target feature dataset; built here when None.

    Returns:
        None. Results are printed and plotted.
    """

    def _domain_name(domains):
        # Accept a single name or a list/tuple of names and produce one readable label.
        if isinstance(domains, (list, tuple)):
            return ", ".join(str(d) for d in domains)
        return str(domains)

    model = resnet_classifier(num_classes=len(classes))

    # NOTE(review): caches passed in by the caller were computed with the caller's model,
    # not with `model` above — confirm the layers up to layer3 are identical in both.
    if cache_source is None:
        source_ds = OfficeHomeDataset(root_dir, csv_file, source, transform)
        cache_source = FeatureTensorDataset(source_ds, model, layer_name='layer3', device=device, batch_size=64)

    if cache_target is None:
        target_ds = OfficeHomeDataset(root_dir, csv_file, target, transform)
        cache_target = FeatureTensorDataset(target_ds, model, layer_name='layer3', device=device, batch_size=64)

    truncated_model = _truncate_resnet_from(model, 'layer3')

    def _forward_with_features(self, x):
        # Training-time forward: returns the logits plus [logits, layer4 features].
        # Presumably train_workflow uses the list for the MK-MMD term — TODO confirm.
        f4 = self.layer4(x)
        x = torch.flatten(self.avgpool(f4), 1)
        x = self.fc(x)
        return x, [x, f4]

    def _forward_logits(self, x):
        # Evaluation-time forward: logits only, which is what evaluate_ds expects.
        f4 = self.layer4(x)
        x = torch.flatten(self.avgpool(f4), 1)
        x = self.fc(x)
        return x

    truncated_model.forward = types.MethodType(_forward_with_features, truncated_model)

    # The source split gets its own seeded generator so it is reproducible. Previously it
    # was drawn before the global seed was set and therefore varied between runs.
    cache_source_train, cache_source_test = torch.utils.data.random_split(
        cache_source, [0.8, 0.2], generator=torch.Generator().manual_seed(100)
    )
    # The global seed stays here so the target split and the training run see the same
    # RNG state as before.
    torch.manual_seed(100)
    cache_target_train, cache_target_test = torch.utils.data.random_split(cache_target, [0.8, 0.2])

    train_workflow(truncated_model, cache_source_train, cache_target_train, cache_target_test, config, device)

    truncated_model.forward = types.MethodType(_forward_logits, truncated_model)

    source_name = _domain_name(source)
    target_name = _domain_name(target)

    print(device)
    print("Evaluating on target domain test set")
    evaluate_ds(cache_target_test, truncated_model, device, target_name)

    # Fixed: this step used to evaluate cache_target_test again under the source label.
    print("Evaluating on source domain test set")
    evaluate_ds(cache_source_test, truncated_model, device, source_name)

    print("Evaluating on source domain train set")
    evaluate_ds(cache_source_train, truncated_model, device, source_name)

    # truncated_model[:-1] drops the last sub-module — presumably the fc layer, so that
    # get_features returns pre-classifier features; TODO confirm against _truncate_resnet_from.
    latents_s, _ = get_features(truncated_model[:-1], cache_source_test, batchsize=32, device=device)
    latents_t, _ = get_features(truncated_model[:-1], cache_target_test, batchsize=32, device=device)

    # Flatten any trailing dimensions so each sample is a single feature vector.
    if latents_s.dim() > 2:
        latents_s = latents_s.flatten(start_dim=1)
    if latents_t.dim() > 2:
        latents_t = latents_t.flatten(start_dim=1)

    latents = torch.cat([latents_s, latents_t], dim=0)
    domain_labels = torch.cat([
        torch.zeros(latents_s.shape[0], dtype=torch.long),   # 0 = source
        torch.ones(latents_t.shape[0], dtype=torch.long)     # 1 = target
    ])

    fig, ax = plt.subplots(figsize=(8, 6))
    tsne_plot(latents, domain_labels, classes=["Source", "Target"], ax=ax, fig=fig, perplexity=30)
    plt.title("t-SNE Domain Clustering: Source vs Target")
    plt.show()

In [96]:
# Multi-source setup: three Office-Home domains as the source, Clipart as the target.
sources, target = ["Art", "Real World", "Product"], ["Clipart"]

# Build one dataset per side with the same root, CSV index and transform.
source_ds, target_ds = [
    OfficeHomeDataset(root_dir, csv_file, domains, transform)
    for domains in (sources, target)
]

In [97]:
# Classifier with one output per entry of `classes`. This is the instance the next cell
# passes to FeatureTensorDataset to precompute the cached features.
model = resnet_classifier(num_classes=len(classes))

In [10]:
# Precompute features for both datasets with `model` up to 'layer3' (the progress bar reads
# "Precomputing up to layer3"), so analysis() can be given the caches instead of rebuilding them.
# NOTE(review): analysis() constructs its own resnet_classifier; confirm its layers up to
# layer3 match this `model`, otherwise the cached features and the trained head are mismatched.
cache_source = FeatureTensorDataset(source_ds, model, layer_name='layer3', device=device, batch_size=32)
cache_target = FeatureTensorDataset(target_ds, model, layer_name='layer3', device=device, batch_size=32)

Precomputing up to layer3: 100%|██████████| 176/176 [03:57<00:00,  1.35s/it]
Precomputing up to layer3: 100%|██████████| 69/69 [00:55<00:00,  1.24it/s]


In [102]:
# Settings handed to train_workflow through analysis().
# `sigmas` is presumably the set of kernel bandwidths for the MK-MMD term shown in the
# training log — confirm against train.py.
config = dict(
    epochs=20,
    lr=1e-3,
    batch_size=48,
    weight_decay=1e-4,
    sigmas=[2e-16, 2e-12, 2e-8, 2e-4, 2e+0, 2e+4, 2e+8, 2e+12, 2e+16],
    scale=128.0,
)

# Reuse the precomputed feature caches so the backbone is not run again.
analysis(
    source=["Art", "Real World", "Product"],
    target=["Clipart"],
    config=config,
    cache_source=cache_source,
    cache_target=cache_target,
)

Training:   0%|          | 0/20 [00:00<?, ?it/s]


Epoch 1/20

Train loss: 0.79825 | Supervised: 0.79636 | MK-MMD: 0.00188 | Source train acc: 20.51% | Target test acc=42.27%



Training:   5%|▌         | 1/20 [00:09<03:00,  9.48s/it]


Epoch 2/20

Train loss: 0.31559 | Supervised: 0.31554 | MK-MMD: 0.00005 | Source train acc: 30.57% | Target test acc=43.30%



Training:  10%|█         | 2/20 [00:18<02:50,  9.50s/it]


Epoch 3/20

Train loss: 0.21284 | Supervised: 0.21279 | MK-MMD: 0.00006 | Source train acc: 32.95% | Target test acc=45.59%



Training:  15%|█▌        | 3/20 [00:28<02:41,  9.53s/it]


Epoch 4/20

Train loss: 0.17103 | Supervised: 0.16970 | MK-MMD: 0.00133 | Source train acc: 34.03% | Target test acc=42.50%



Training:  20%|██        | 4/20 [00:38<02:32,  9.52s/it]


Epoch 5/20

Train loss: 0.12661 | Supervised: 0.12658 | MK-MMD: 0.00003 | Source train acc: 35.34% | Target test acc=43.64%



Training:  25%|██▌       | 5/20 [00:47<02:22,  9.50s/it]


Epoch 6/20

Train loss: 0.09285 | Supervised: 0.09283 | MK-MMD: 0.00001 | Source train acc: 36.08% | Target test acc=44.56%



Training:  30%|███       | 6/20 [00:57<02:13,  9.51s/it]


Epoch 7/20

Train loss: 0.06996 | Supervised: 0.06996 | MK-MMD: 0.00000 | Source train acc: 36.83% | Target test acc=44.56%



Training:  35%|███▌      | 7/20 [01:06<02:03,  9.53s/it]


Epoch 8/20

Train loss: 0.05297 | Supervised: 0.05296 | MK-MMD: 0.00001 | Source train acc: 37.38% | Target test acc=46.16%



Training:  40%|████      | 8/20 [01:16<01:54,  9.51s/it]


Epoch 9/20

Train loss: 0.03866 | Supervised: 0.03865 | MK-MMD: 0.00001 | Source train acc: 37.67% | Target test acc=44.90%



Training:  45%|████▌     | 9/20 [01:25<01:44,  9.51s/it]


Epoch 10/20

Train loss: 0.03019 | Supervised: 0.03018 | MK-MMD: 0.00000 | Source train acc: 37.96% | Target test acc=45.25%



Training:  50%|█████     | 10/20 [01:35<01:35,  9.51s/it]


Epoch 11/20

Train loss: 0.02440 | Supervised: 0.02440 | MK-MMD: 0.00000 | Source train acc: 38.09% | Target test acc=45.93%



Training:  55%|█████▌    | 11/20 [01:44<01:25,  9.54s/it]


Epoch 12/20

Train loss: 0.02211 | Supervised: 0.02210 | MK-MMD: 0.00000 | Source train acc: 38.08% | Target test acc=46.39%



Training:  60%|██████    | 12/20 [01:54<01:16,  9.53s/it]


Epoch 13/20


Training:  60%|██████    | 12/20 [02:03<01:22, 10.29s/it]


KeyboardInterrupt: 

NameError: name 'cache_source_train' is not defined