In [None]:
from utils.dataset import FeatureTensorDataset, OfficeHomeDataset, _truncate_resnet_from
from torchvision.models import ResNet50_Weights 
from importlib import reload
import train
from train import train_workflow
import types
from utils.tsne import tsne_plot, get_features
from utils.evaluation import evaluate_ds
import matplotlib.pyplot as plt
import os
from architecture import resnet_classifier
import torch

In [None]:

root_dir = "../datasets/OfficeHomeDataset"
csv_file = f"{root_dir}/ImageInfo.csv"

transform = ResNet50_Weights.IMAGENET1K_V1.transforms()

# Class names are the sub-directories of one domain folder ("Art").
# Kept as a list (not a bare count) because downstream code calls len(classes).
# Sorting makes the order deterministic across filesystems; the isdir filter
# skips stray files (e.g. hidden OS metadata) that would inflate the count.
art_dir = f"{root_dir}/Art"
classes = sorted(
    name for name in os.listdir(art_dir)
    if os.path.isdir(os.path.join(art_dir, name))
)

In [None]:
# Compute device for feature extraction and training: GPU if torch sees one, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def make_cache(domains, layer_name='layer3', batch_size=64):
    """Build feature caches of ``layer_name`` activations for OfficeHome domains.

    A single ``resnet_classifier`` is created and shared by every domain; each
    domain's ``OfficeHomeDataset`` is wrapped in a ``FeatureTensorDataset`` that
    extracts features at ``layer_name`` on the notebook-level ``device``.

    Args:
        domains: One domain name (e.g. ``"Art"``) or a list of domain names.
        layer_name: Backbone layer whose output is cached. Defaults to
            ``'layer3'``, the layer ``analysis`` truncates the model at.
        batch_size: Batch size used while extracting features.

    Returns:
        A single ``FeatureTensorDataset`` when ``domains`` is not a list,
        otherwise a list of them in the same order as ``domains``.
    """
    model = resnet_classifier(num_classes=len(classes))

    def _cache_domain(domain):
        # One dataset + feature cache per domain, all using the shared model.
        ds = OfficeHomeDataset(root_dir, csv_file, domain, transform)
        return FeatureTensorDataset(ds, model, layer_name, device=device, batch_size=batch_size)

    if isinstance(domains, list):
        return [_cache_domain(domain) for domain in domains]
    return _cache_domain(domains)


In [None]:
def analysis(sources, target, config, caches_source=None, cache_target=None):
    """Train a ResNet head on cached ``layer3`` features, then evaluate on the target.

    Args:
        sources: Source domain name(s), forwarded to ``make_cache`` if
            ``caches_source`` is not supplied.
        target: Target domain name, or a list whose first entry is used as the
            name reported to ``evaluate_ds``.
        config: Training hyper-parameters passed through to ``train_workflow``.
        caches_source: Pre-built source feature cache(s); built if ``None``.
        cache_target: Pre-built target feature cache; built if ``None``.
    """
    # Features are cached at this layer and the model is truncated at the same
    # point; _forward_impl below assumes layer4 is the next stage.
    layer_name = 'layer3'

    if caches_source is None:
        caches_source = make_cache(sources, layer_name)
    if cache_target is None:
        cache_target = make_cache(target, layer_name)

    model = resnet_classifier(num_classes=len(classes))
    truncated_model = _truncate_resnet_from(model, layer_name)

    def _forward_impl(self, x):
        # x is a cached layer3 activation: run the remaining stages + classifier.
        x = self.layer4(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

    # Bind as an instance attribute so only this model's forward is replaced.
    truncated_model.forward = types.MethodType(_forward_impl, truncated_model)

    train_workflow(truncated_model, caches_source, cache_target, config, device)

    # Indexing a str would yield its first character ("Clipart"[0] == "C"),
    # so only index when a list of target domains was given.
    # NOTE(review): with a list target, cache_target is a list of caches — confirm
    # evaluate_ds accepts that.
    target_name = target if isinstance(target, str) else target[0]
    evaluate_ds(cache_target, truncated_model, device, target_name, None)


In [None]:
sources = ["Art", "Real World", "Product"]
target = "Clipart"

# Build the 'layer3' feature caches once so repeated `analysis` calls can reuse
# them instead of re-running feature extraction.
caches_source = make_cache(sources, 'layer3')
cache_target = make_cache(target, 'layer3')

In [None]:
# Seed torch's global RNG so the stochastic parts of training (presumably head
# initialisation and batch shuffling inside train_workflow) repeat across runs.
SEED = 0
torch.manual_seed(SEED)

config = {
    'epochs': 2,
    'lr': 1e-3,
    'batch_size': 32,
    'weight_decay': 1e-5,
    'phi': 1.0,  # NOTE(review): semantics defined by train_workflow — document there
}
analysis(sources, target, config, caches_source, cache_target)