This notebook is based on [Carlucci et al. (2019)](https://arxiv.org/pdf/1903.06864.pdf).
We first explain the methods and procedures in detail, then apply them to a dataset other than the one used in the article (PACS).

# Methods and procedures used for Domain Generalization by Solving Jigsaw Puzzles

## Domain Generalization

Domain generalization refers to the ability of a machine learning model to generalize to unseen domains or out-of-distribution data. This contrasts with traditional supervised learning, which assumes that the training and test data come from the same domain or distribution. [Wang et al. (2022)]

The model is trained on several domains, which form the source data (for image styles, these could be drawings, paintings, cartoons, ...). The goal is for the model to predict the class accurately on an unseen domain, the target data (e.g. photos).
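PACS, the dataset used in the article, contains four domains (photo, art painting, cartoon, sketch), and the usual protocol holds out each domain in turn as the target while training on the other three. The short sketch below enumerates these splits; the helper name `leave_one_domain_out` is ours, and the domain strings are the ones used as `--source`/`--target` values later in this notebook.

```python
PACS_DOMAINS = ["photo", "art_painting", "cartoon", "sketch"]


def leave_one_domain_out(domains):
    """Yield (source_domains, target_domain) pairs, holding out each domain exactly once."""
    for target in domains:
        sources = [d for d in domains if d != target]
        yield sources, target


for sources, target in leave_one_domain_out(PACS_DOMAINS):
    print(f"train on {sources} -> test on {target}")
```

The split with `art_painting` as target is the one run later in this notebook.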

*Definition 1 (Domain).*
Let $\mathcal{X}$ denote a nonempty input space and $\mathcal{Y}$ an output space. A domain is composed of data that are sampled from a distribution. We denote it as $\mathcal{S} = \{(x_i, y_i)\}_{i=1}^{n} \sim P_{XY}$, where $x \in \mathcal{X} \subset \mathbb{R}^d$, $y \in \mathcal{Y} \subset \mathbb{R}$ denotes the label, and $P_{XY}$ denotes the joint distribution of the input sample and output label. $X$ and $Y$ denote the corresponding random variables. [Wang et al. (2022)]

*Definition 2 (Domain generalization).*
In domain generalization, we are given $M$ training (source) domains $\mathcal{S}_{train} = \{\mathcal{S}^i \mid i = 1, \dots, M\}$, where $\mathcal{S}^i = \{(x^i_j, y^i_j)\}_{j=1}^{n_i}$ denotes the $i$-th domain. The joint distributions between each pair of domains are different: $P^i_{XY} \neq P^j_{XY}$, $1 \leq i \neq j \leq M$. The goal of domain generalization is to learn a robust and generalizable predictive function $h: \mathcal{X} \to \mathcal{Y}$ from the $M$ training domains that achieves a minimum prediction error on an unseen test domain $\mathcal{S}_{test}$ (i.e., $\mathcal{S}_{test}$ cannot be accessed during training and $P^{test}_{XY} \neq P^i_{XY}$ for $i \in \{1, \dots, M\}$):
$$ \min_h \; \mathbb{E}_{(x,y) \in \mathcal{S}_{test}} \left[ \ell(h(x), y) \right] $$
where $\ell(\cdot, \cdot)$ is the loss function. [Wang et al. (2022)]
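In practice the expectation is estimated by the average loss over a finite sample of the test domain, and since that domain is unseen during training this estimate can only serve for evaluation. The toy sketch below (plain Python; the data, the threshold classifier `h` and the helper names are invented for illustration) uses the 0-1 loss to show a predictor with zero error on its source sample but a much higher error on a shifted target sample, which is the gap domain generalization aims to reduce.

```python
def zero_one_loss(y_pred, y_true):
    """0-1 loss: 1 for a wrong prediction, 0 for a correct one."""
    return 0.0 if y_pred == y_true else 1.0


def empirical_risk(h, samples, loss):
    """Average loss of predictor h over a finite sample of (x, y) pairs (empirical expectation)."""
    return sum(loss(h(x), y) for x, y in samples) / len(samples)


def h(x):
    """Toy threshold classifier that fits the source sample perfectly."""
    return int(x > 0)


source_sample = [(-2.0, 0), (-1.0, 0), (1.0, 1), (2.0, 1)]
# Hypothetical target domain where the class boundary has moved to about x = 2
target_sample = [(-1.0, 0), (0.5, 0), (1.5, 0), (3.0, 1)]

print(empirical_risk(h, source_sample, zero_one_loss))  # 0.0
print(empirical_risk(h, target_sample, zero_one_loss))  # 0.5
```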

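## JiGen
JiGen trains a single network on two tasks at once: the supervised object classification task, and a self-supervised jigsaw task in which the image is cut into a 3×3 grid of tiles, the tiles are shuffled according to one of a fixed set of permutations, and the network must recognize which permutation was applied. Solving the puzzle acts as an auxiliary regularizer: the authors argue that it pushes the network to learn spatial relations between object parts, which are less tied to a particular visual style, and therefore helps it generalize to unseen domains. [Carlucci et al. (2019)]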
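The NumPy sketch below illustrates the data side and the objective of the jigsaw task under our own assumptions; it is not the authors' code, and all function names are hypothetical. `select_permutations` greedily picks P tile permutations that are far apart in Hamming distance (the paper selects its permutation set with a Hamming-distance criterion; this greedy max-min rule is our simplification). `shuffle_tiles` recomposes an image from permuted tiles. `make_jigsaw_sample` keeps the image ordered with probability `bias_whole_image` (label 0) and otherwise applies permutation k (label k + 1). `jigen_loss` adds the classification loss on ordered images to the jigsaw loss weighted by `jig_weight`. The parameters `bias_whole_image` and `jig_weight` mirror the `train_jigsaw.py` options of the same name.

```python
import itertools

import numpy as np


def select_permutations(n_perms, n_tiles=9):
    """Greedily pick n_perms tile permutations far (in Hamming distance) from the identity and each other."""
    all_perms = np.array(list(itertools.permutations(range(n_tiles))), dtype=np.int8)
    identity = np.arange(n_tiles, dtype=np.int8)
    # Distance from every candidate to the closest permutation kept so far (identity included)
    min_dist = (all_perms != identity).sum(axis=1)
    chosen = []
    for _ in range(n_perms):
        idx = int(min_dist.argmax())  # candidate farthest from the current set
        chosen.append(all_perms[idx])
        min_dist = np.minimum(min_dist, (all_perms != all_perms[idx]).sum(axis=1))
    return np.stack(chosen)


def shuffle_tiles(image, perm, grid=3):
    """Cut an (H, W, C) image into grid x grid tiles and reassemble them in the order given by perm."""
    th, tw = image.shape[0] // grid, image.shape[1] // grid
    tiles = [image[r * th:(r + 1) * th, c * tw:(c + 1) * tw]
             for r in range(grid) for c in range(grid)]
    shuffled = [tiles[i] for i in perm]
    rows = [np.concatenate(shuffled[r * grid:(r + 1) * grid], axis=1) for r in range(grid)]
    return np.concatenate(rows, axis=0)


def make_jigsaw_sample(image, perms, bias_whole_image, rng):
    """Keep the image ordered with probability bias_whole_image (label 0), else shuffle it (label k + 1)."""
    if rng.random() < bias_whole_image:
        return image, 0
    k = int(rng.integers(len(perms)))
    return shuffle_tiles(image, perms[k]), k + 1


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy, computed stably by subtracting the row maximum."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def jigen_loss(class_logits, class_labels, jig_logits, jig_labels, jig_weight):
    """Classification loss on ordered images (jigsaw label 0) plus jig_weight times the jigsaw loss."""
    ordered = jig_labels == 0
    loss_cls = cross_entropy(class_logits[ordered], class_labels[ordered]) if ordered.any() else 0.0
    loss_jig = cross_entropy(jig_logits, jig_labels)
    return loss_cls + jig_weight * loss_jig
```

In this sketch jigsaw labels range from 0 to P, so the jigsaw head needs P + 1 outputs, and the image height and width should be divisible by 3 so that all tiles have the same size (both 222 and 225 are).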

## References

**[Carlucci et al. (2019)]** Carlucci, F. M., D'Innocente, A., Bucci, S., Caputo, B., & Tommasi, T. (2019). Domain Generalization by Solving Jigsaw Puzzles. arXiv preprint arXiv:1903.06864. [URL](https://arxiv.org/pdf/1903.06864.pdf)

**[Wang et al. (2022)]** Wang, J., Lan, C., Liu, C., Ouyang, Y., Qin, T., Lu, W., Chen, Y., Zeng, W., & Yu, P. S. (2022). Generalizing to Unseen Domains: A Survey on Domain Generalization. arXiv preprint arXiv:2103.03097. [URL](https://arxiv.org/pdf/2103.03097.pdf)

# Using JiGen on PACS (as in the article)

In [None]:
import torch
from IPython.core.debugger import set_trace
from torch import nn
from torch.nn import functional as F
import numpy as np

# These imports will no longer be needed once the corresponding code is defined in this notebook
from data import data_helper
from data.data_helper import available_datasets
from models import model_factory
from optimizer.optimizer_helper import get_optim_and_scheduler
from utils.Logger import Logger


In [None]:
class Args:
    """Plain attributes replacing the argparse options of train_jigsaw.py, so the notebook can use args.<name>."""
    # Data
    source = ['photo', 'cartoon', 'sketch']   # Source (training) domains
    target = 'art_painting'                   # Single target (test) domain, as in train_jigsaw.py
    batch_size = 64                           # Batch size
    image_size = 225                          # Image size (222 for resnet18)

    # Data augmentation
    min_scale = 0.8                # Minimum scale percent
    max_scale = 1.0                # Maximum scale percent
    random_horiz_flip = 0.0        # Chance of random horizontal flip
    jitter = 0.0                   # Color jitter amount
    tile_random_grayscale = 0.1    # Chance of randomly greyscaling a tile

    limit_source = None            # If set, limits the number of training samples
    limit_target = None            # If set, limits the number of testing samples

    # Optimization and model
    learning_rate = 0.01           # Learning rate
    epochs = 30                    # Number of epochs
    n_classes = 7                  # Number of object classes (PACS has 7)
    jigsaw_n_classes = 31          # Number of classes for the jigsaw task (puzzle permutations)
    network = "resnet18"           # One of 'caffenet', 'alexnet', 'resnet18', 'resnet50', 'lenet'
    jig_weight = 0.1               # Weight for the jigsaw puzzle loss
    ooo_weight = 0                 # Weight for the odd-one-out task
    nesterov = False               # Use Nesterov momentum
    train_all = False              # If true, all network weights will be trained

    # Validation, logging and misc
    tf_logger = True               # If true, save TensorBoard-compatible logs
    val_size = 0.1                 # Validation size (between 0 and 1)
    folder_name = None             # Used by the logger to save logs
    suffix = ""                    # Suffix for the logger
    bias_whole_image = None        # If set, training shows the whole (unshuffled) image more often
    TTA = False                    # Activate test-time data augmentation
    classify_only_sane = False     # If true, only the non-scrambled images are classified


args = Args()
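The options `--TTA`, `--nesterov`, `--train_all` and `--classify_only_sane` of `train_jigsaw.py` appear to be declared with `type=bool`. With `argparse`, `type=bool` simply calls `bool()` on the command-line string, and any non-empty string (including `"False"`) is truthy, so passing `--TTA False` or `--nesterov False` enables the option instead of disabling it; to keep such an option off, omit the flag. The usual alternative is `action="store_true"` without a `type` argument. The demonstration below uses a throwaway parser; the `--tta` flag and the name `args_bool` are hypothetical.

```python
import argparse

parser = argparse.ArgumentParser()
# Declared with type=bool, the pattern --nesterov appears to use in train_jigsaw.py
parser.add_argument("--nesterov", type=bool, help="Use nesterov")
# Flag-style alternative: present -> True, absent -> False
parser.add_argument("--tta", action="store_true", help="Activate test time data augmentation")

args_bool = parser.parse_args(["--nesterov", "False"])
print(args_bool.nesterov)                 # True: bool("False") is True for a non-empty string
print(args_bool.tta)                      # False: the flag was not given
print(parser.parse_args(["--tta"]).tta)   # True
```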


In [1]:
! python train_jigsaw.py --batch_size 128 --n_classes 7 --learning_rate 0.001 --network resnet18 --val_size 0.1 --folder_name test --jigsaw_n_classes 30 --train_all True --min_scale 0.8 --max_scale 1.0 --random_horiz_flip 0.5 --jitter 0.4 --tile_random_grayscale 0.1 --source photo cartoon sketch --target art_painting --jig_weight 0.7 --bias_whole_image 0.9 --image_size 222

Traceback (most recent call last):
  File "/home/bocquet-/Bureau/5A/Projet HDDL/JigenDG/train_jigsaw.py", line 10, in <module>
    from models import model_factory
  File "/home/bocquet-/Bureau/5A/Projet HDDL/JigenDG/models/model_factory.py", line 5, in <module>
    from models import resnet
  File "/home/bocquet-/Bureau/5A/Projet HDDL/JigenDG/models/resnet.py", line 3, in <module>
    from torchvision.models.resnet import BasicBlock, model_urls, Bottleneck
ImportError: cannot import name 'model_urls' from 'torchvision.models.resnet' (/home/bocquet-/.local/lib/python3.11/site-packages/torchvision/models/resnet.py)
