In [3]:
import os
from pathlib import Path
# Locate the project's `data` directory by walking upwards from the current
# working directory. The search is bounded by the filesystem root: the root is
# its own parent, so an unbounded `while` loop would spin forever when no
# ancestor contains a `data` directory.
home = os.getcwd()
current = Path(home)
for candidate in (current, *current.parents):
    # require an actual directory: a plain file named `data` is not usable
    if (candidate / 'data').is_dir():
        current = candidate
        break
else:
    raise FileNotFoundError(
        f"Could not find a 'data' directory in '{home}' or any of its parent directories"
    )
DATA_FOLDER = os.path.join(current, 'data')

In [4]:
# Paths of the Amazon domain under DATA_FOLDER/amazon: the dataset directory
# itself, followed by the `train` and `val` sub-directories of `amazon_splitted`.
amazon_dir, amazon_train, amazon_val = (
    os.path.join(DATA_FOLDER, 'amazon', *relative_parts)
    for relative_parts in (
        ('amazon',),
        ('amazon_splitted', 'train'),
        ('amazon_splitted', 'val'),
    )
)

In [6]:
import torch
from torch import nn
from typing import Tuple, List, Union
from torch.utils.data import DataLoader
from copy import deepcopy
from mypt.distances.MMD import GaussianMMD
from transferable_alexnet import TransferAlexNet


In [None]:
def calculate_loss(source_logits: torch.Tensor, 
                   source_labels: torch.Tensor,
                   source_features: List[torch.Tensor], 
                   target_features: List[torch.Tensor], 
                   loss_coefficient: float,
                   reduction: str = 'mean', 
                   sigma: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
    """Combine the source classification loss with the source/target MMD terms.

    Args:
        source_logits: classifier outputs for the source batch.
        source_labels: labels of the source batch.
        source_features: intermediate features of the source batch.
        target_features: intermediate features of the target batch; entry ``i``
            is compared against ``source_features[i]``.
        loss_coefficient: weight applied to the summed MMD terms.
        reduction: forwarded to ``nn.CrossEntropyLoss``.
        sigma: forwarded to ``GaussianMMD``.

    Returns:
        A tuple ``(total_loss, cls_loss, distribution_losses)`` where
        ``total_loss = cls_loss + loss_coefficient * sum(MMD terms)`` is the
        tensor to backpropagate through, ``cls_loss`` is the cross entropy
        loss alone and ``distribution_losses`` holds one detached CPU copy of
        each MMD term (for logging only).

    Raises:
        ValueError: if the two feature lists have different lengths.
    """
    if len(source_features) != len(target_features):
        raise ValueError("Please make sure the number of features is the same acoss target and source domains")

    # first part is the cross entropy loss as usual
    cls_loss = nn.CrossEntropyLoss(reduction=reduction)(source_logits, source_labels)

    # one MMD term per (source, target) feature pair; these stay attached to the autograd graph
    mmd_terms = [
        GaussianMMD(sigma=sigma).forward(x=fs, y=ft)
        for fs, ft in zip(source_features, target_features)
    ]

    # detached copies used only to track the different similarities along with training
    distribution_losses = [term.detach().cpu() for term in mmd_terms]

    # Accumulate out-of-place: `deepcopy` raises on non-leaf tensors (and would
    # disconnect a leaf from the graph), while an in-place `+=` would mutate a
    # tensor autograd may still need. Without any features the domain term is 0.
    domain_confusion_loss = sum(mmd_terms) if mmd_terms else 0.0

    return cls_loss + loss_coefficient * domain_confusion_loss, cls_loss, distribution_losses

In [None]:
from torchvision.datasets import ImageFolder
import torchvision.transforms as tr
import mypt.utilities.directories_and_files as dirf
from mypt.data.dataloaders.standard_dataloaders import initialize_train_dataloader

def get_dataloader(root: Union[str, Path],
                   image_transform:tr, 
                   batch_size: int,
                   seed: int, 
                   num_workers: int = 2) -> DataLoader:
    """Build a training dataloader over the image folder located at ``root``.

    Args:
        root: directory handed to ``ImageFolder``.
        image_transform: transform applied by ``ImageFolder`` to every image.
        batch_size: forwarded to ``initialize_train_dataloader``.
        seed: forwarded to ``initialize_train_dataloader``.
        num_workers: forwarded to ``initialize_train_dataloader``.

    Returns:
        Whatever ``initialize_train_dataloader`` builds for the dataset.
    """
    # check `root` with the project helper: directories allowed, files not, and
    # the `image_dataset_directory` condition applied
    # NOTE(review): presumably raises with the given error message on failure — confirm in mypt
    dirf.process_path(
        root,
        dir_ok=True,
        file_ok=False,
        condition=dirf.image_dataset_directory,
        error_message=dirf.image_dataset_directory_error(root),
    )

    image_dataset = ImageFolder(root, transform=image_transform)

    loader_options = dict(seed=seed, batch_size=batch_size, num_workers=num_workers)
    return initialize_train_dataloader(image_dataset, **loader_options)

In [None]:
def train_epoch(
                model: TransferAlexNet,
                source_train_dir: Union[str, Path], 
                target_dir: Union[str, Path], 
                image_transform: tr,
                optimizer: torch.optim.Optimizer, 
                lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
                loss_coefficient: float,
                device: str, 
                seed: int,
                batch_size: int = 256):
    """Run one pass over the source training data with a domain confusion loss.

    Every source batch is paired with a target batch. Both go through
    ``model``; the last element of the model output is used as the logits and
    the preceding elements as the features compared by ``calculate_loss``.
    When the target dataloader runs out before the source one, it is rebuilt
    with a different seed and iteration continues.

    Args:
        model: network being trained.
            NOTE(review): assumed to already live on ``device`` — confirm at the call site.
        source_train_dir: root directory of the labeled source images.
        target_dir: root directory of the target images (their labels are ignored).
        image_transform: transform applied to the images of both domains.
        optimizer: optimizer updating ``model``.
        lr_scheduler: scheduler stepped once after every optimizer step.
        loss_coefficient: weight of the domain confusion term.
        device: device the batches are moved to.
        seed: seed handed to the dataloaders.
        batch_size: batch size used for both dataloaders.

    Raises:
        ValueError: if the target dataloader yields no batch at all.
    """
    dl_source_train = get_dataloader(root=source_train_dir, 
                                     image_transform=image_transform,
                                     batch_size=batch_size, 
                                     seed=seed)

    dl_target = get_dataloader(root=target_dir, 
                               image_transform=image_transform,
                               batch_size=batch_size, 
                               seed=seed)

    # a DataLoader is an iterable, not an iterator: `next` needs an explicit `iter`
    target_iterator = iter(dl_target)
    target_restarts = 0

    for xs, ys in dl_source_train:
        target_batch = next(target_iterator, None)

        # if the target domain is already exhausted, then re-initialize it
        # and fetch a batch from the fresh dataloader right away
        if target_batch is None:
            target_restarts += 1
            dl_target = get_dataloader(root=target_dir, 
                                       image_transform=image_transform, 
                                       batch_size=batch_size, 
                                       # a different seed on every restart
                                       seed=seed + target_restarts)
            target_iterator = iter(dl_target)
            target_batch = next(target_iterator, None)
            if target_batch is None:
                raise ValueError(f"The target dataloader built from {target_dir} does not yield any batch")

        optimizer.zero_grad()

        # forward pass: source
        xs, ys = xs.to(device), ys.to(device)
        model_output_source = model.forward(xs)
        source_features, logits = model_output_source[:-1], model_output_source[-1]

        # forward pass: target (labels are unused; the images must sit on the same device as the model)
        xt, _ = target_batch
        xt = xt.to(device)
        model_output_target = model.forward(xt)
        target_features = model_output_target[:-1]

        # classification loss on the source batch + weighted feature discrepancy between domains
        final_loss, _, _ = calculate_loss(source_logits=logits,
                                          source_labels=ys, 
                                          source_features=source_features,
                                          target_features=target_features,
                                          loss_coefficient=loss_coefficient, 
                                          )

        # backpropagation
        final_loss.backward()
        optimizer.step()
        lr_scheduler.step()
        