This notebook is primarily based on [Carlucci et al. (2019)](https://arxiv.org/pdf/1903.06864.pdf).
We first explain the methods and procedures in detail, then apply them to a dataset other than the one used in the article (PACS).

# Methods and procedures for domain generalization by solving jigsaw puzzles

## Domain generalization

Domain generalization refers to the ability of a machine learning model to generalize to unseen domains, i.e. to out-of-distribution data. This contrasts with traditional supervised learning, which assumes that training and test data come from the same domain, i.e. the same distribution. [Wang et al. (2022)]

The model is trained on several source domains (for image styles, these could be drawings, paintings, cartoons, etc.). We want it to predict the class accurately on an unseen target domain (for example, photos).

*Definition 1 (Domain).*
Let $\mathcal{X}$ be a nonempty input space and $\mathcal{Y}$ an output space. A domain consists of data sampled from a distribution. We denote it $\mathcal{S} = \{(x_i, y_i)\}_{i=1}^{n} \sim P_{XY}$, where $x \in \mathcal{X} \subset \mathbb{R}^d$, $y \in \mathcal{Y} \subset \mathbb{R}$ denotes the label, and $P_{XY}$ denotes the joint distribution of the input sample and the output label. $X$ and $Y$ denote the corresponding random variables. [Wang et al. (2022)]

*Definition 2 (Domain generalization).*
In domain generalization, we are given $M$ training (source) domains $\mathcal{S}_{train} = \{\mathcal{S}^i \mid i = 1, \dots, M\}$, where $\mathcal{S}^i = \{(x^i_j, y^i_j)\}_{j=1}^{n_i}$ denotes the $i$-th domain. The joint distributions of any two domains differ: $P^i_{XY} \neq P^j_{XY}$ for $1 \leq i \neq j \leq M$. The goal of domain generalization is to learn, from the $M$ training domains, a robust and generalizable predictive function $h : \mathcal{X} \to \mathcal{Y}$ that achieves minimal prediction error on an unseen test domain $\mathcal{S}_{test}$ (i.e. $\mathcal{S}_{test}$ is not accessible during training and $P^{test}_{XY} \neq P^i_{XY}$ for $i \in \{1, \dots, M\}$):
$$ \min_h \; \mathbb{E}_{(x,y) \in \mathcal{S}_{test}} \left[ \ell(h(x), y) \right] $$
where $\ell(\cdot, \cdot)$ is the loss function. [Wang et al. (2022)]
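In practice (and in the `Args` class further below, with `source` and `target`), this setting is evaluated with a leave-one-domain-out protocol: one domain plays the role of $\mathcal{S}_{test}$ and the remaining ones form $\mathcal{S}_{train}$. A minimal sketch, assuming the four PACS domain names used later in the notebook; the helper name `leave_one_domain_out` is ours:

```python
from typing import List, Tuple

# The four PACS domains (same names as `pacs_datasets` later in the notebook).
PACS_DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]


def leave_one_domain_out(domains: List[str]) -> List[Tuple[List[str], str]]:
    """Return one (source domains, target domain) pair per held-out domain."""
    splits = []
    for target in domains:
        sources = [d for d in domains if d != target]  # the target is never seen in training
        splits.append((sources, target))
    return splits


for sources, target in leave_one_domain_out(PACS_DOMAINS):
    print(f"train on {sources} -> test on {target}")
```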

## JiGENDG
The algorithm relies on the idea of using jigsaw puzzles to train a model that is invariant across domains. The network learns simultaneously to solve jigsaw puzzles and to classify images. [Carlucci et al. (2019)]


### Dataset
The input data is a set of $N$ images coming from $S$ domains. In each domain $i$, we have $N_i$ labeled observations. We write $\left\{ x^i_j, y^i_j \right\}_{j=1}^{N_i}$, meaning that the $j$-th image of the $i$-th domain, $x^i_j$, has the associated label $y^i_j$.

We have $x^i_j \in \mathbb{R}^{n_p \times n_p}$, where $n_p \times n_p$ is the image size in pixels, assuming square images. We have $y^i_j \in \mathbb{R}^{C}$, where $C$ is the number of classes, since the label $y^i_j$ is one-hot encoded.

In terms of dimensions, for each domain $i = 1, \dots, S$ we have $\left\{x^i_j, y^i_j\right\}_{j=1}^{N_i} \in \left( \mathbb{R}^{n_p \times n_p} \times \mathbb{R}^{C} \right)^{N_i}$, and $\sum_{i=1}^{S} N_i \leq N$, because the number of labeled images cannot exceed the total number of images $N$.
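As a small illustration of the label encoding, `torch.nn.functional.one_hot` turns integer class indices into vectors $y^i_j \in \mathbb{R}^C$ (here $C = 7$, the number of PACS classes, as in `Args.n_classes`). The sketch also checks that cross-entropy gives the same value whether the target is given as a one-hot vector or as a class index, which is why the loaders below can simply return integer labels:

```python
import torch
import torch.nn.functional as F

C = 7                                   # number of classes (PACS has 7 object categories)
labels = torch.tensor([0, 3, 6])        # integer class indices for three images

one_hot = F.one_hot(labels, num_classes=C).float()  # shape (3, C), one 1 per row
print(one_hot)

# Cross-entropy with class indices vs. the explicit one-hot formula -sum_c y_c log p_c.
logits = torch.randn(3, C)
loss_idx = F.cross_entropy(logits, labels)
loss_onehot = -(one_hot * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
print(loss_idx.item(), loss_onehot.item())
```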


### Permuted dataset
From the unpermuted dataset, we create a new dataset used for the jigsaw-solving task. We consider permutations of the tiles of an $n \times n$ grid (both the article and our work use $n=3$).

Although there are $(n^2)!$ possible permutations, we only consider $P$ of them. Following the article, they are chosen with a Hamming-distance-based algorithm, which keeps permutations that differ from one another as much as possible (maximal Hamming distance), so that they are easy to tell apart. Restricting the set simplifies the task somewhat and also reduces computation time: using all $(3^2)! = 9! = 362\,880$ possibilities would be far more time-consuming than using $P=30$ permutations.

The identity (the original, unshuffled image) is always included among the permutation classes.
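A simplified sketch of a Hamming-distance-based selection, in the spirit of the greedy "maximal Hamming distance" procedure the article relies on: start from the identity, then repeatedly add the permutation whose minimum Hamming distance to the already selected ones is largest. The function name `select_permutations` and the tie-breaking rule (lowest index first) are our assumptions; this is an illustration, not the authors' script.

```python
import itertools

import numpy as np


def select_permutations(grid_size: int, n_perms: int) -> np.ndarray:
    """Greedily pick `n_perms` tile permutations that are far apart in Hamming distance.

    Row 0 is the identity. Each new row maximizes its minimum Hamming distance
    to the rows already selected (ties broken by the lowest index).
    """
    n_tiles = grid_size * grid_size
    # All (n_tiles)! permutations, in lexicographic order (9! = 362 880 rows for 3x3).
    all_perms = np.array(list(itertools.permutations(range(n_tiles))), dtype=np.int8)
    if n_perms > len(all_perms):
        raise ValueError("n_perms exceeds the number of possible permutations")

    selected = [0]  # index 0 is the identity (0, 1, ..., n_tiles - 1)
    # Minimum Hamming distance from every candidate to the selected set.
    min_dist = (all_perms != all_perms[0]).sum(axis=1)
    for _ in range(n_perms - 1):
        best = int(np.argmax(min_dist))  # farthest candidate from the current set
        selected.append(best)
        dist_to_best = (all_perms != all_perms[best]).sum(axis=1)
        min_dist = np.minimum(min_dist, dist_to_best)  # selected rows drop to distance 0
    return all_perms[selected]


perms = select_permutations(grid_size=3, n_perms=30)
print(perms.shape)  # (30, 9)
print(perms[:3])
```

Since `JigsawDataset` below already reserves class 0 for the ordered image and loads the other permutations from a `permutations_%d.npy` file, a file meant for it would presumably exclude the identity row (`perms[1:]`); this is our reading of the loader, not a documented format.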

Each permutation is associated with an index, which lets us treat the problem as a classification task whose label is the one-hot encoded permutation index.

We write $\left\{z^i_k, p^i_k\right\}_{k=1}^{K_i} \in \left( \left( \mathbb{R}^{n_p \times n_p} \times \mathbb{R}^{P} \right) ^ {K_i} \right) ^ {S}$, where $z^i_k$ is the permuted image, $p^i_k$ is the (one-hot encoded) index of the permutation applied to it, $K_i$ is the number of labeled instances and $P$ is the number of permutations considered.
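To make the construction of $z^i_k$ concrete, here is a minimal sketch that cuts a square image into an $n \times n$ grid of tiles, reorders them with a permutation, and stitches them back together. It mirrors the indexing used later in `JigsawDataset.__getitem__` (output tile `t` is input tile `perm[t]`), but works on a plain tensor instead of PIL tiles; the function name `shuffle_tiles` is ours.

```python
import torch


def shuffle_tiles(img: torch.Tensor, perm, grid_size: int = 3) -> torch.Tensor:
    """Return a copy of `img` (C, H, W) whose tiles are reordered by `perm`.

    Output tile t (row-major order) is input tile perm[t].
    H and W are assumed divisible by grid_size.
    """
    _, h, w = img.shape
    th, tw = h // grid_size, w // grid_size
    # Cut into row-major tiles, each of shape (C, th, tw).
    tiles = [
        img[:, r * th:(r + 1) * th, col * tw:(col + 1) * tw]
        for r in range(grid_size)
        for col in range(grid_size)
    ]
    shuffled = [tiles[perm[t]] for t in range(grid_size * grid_size)]
    # Stitch back: concatenate each row of tiles along W, then the rows along H.
    rows = [torch.cat(shuffled[r * grid_size:(r + 1) * grid_size], dim=2) for r in range(grid_size)]
    return torch.cat(rows, dim=1)


img = torch.arange(36.0).reshape(1, 6, 6)      # toy 1-channel 6x6 "image"
identity = list(range(9))
assert torch.equal(shuffle_tiles(img, identity), img)  # identity leaves the image unchanged

perm = [8, 7, 6, 5, 4, 3, 2, 1, 0]              # reverse the tile order
z = shuffle_tiles(img, perm)
print(z[0])
```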



### Loss function
Let us briefly recall how a neural network is trained:
- The model processes a batch of $b$ input samples. Each sample goes through the network, and the output is computed.
- The loss function is applied to the predicted outputs and the target values of the batch. This loss measures the dissimilarity between predicted and true values.
- The gradients of the loss with respect to each parameter are computed (backpropagation).
- The model parameters are updated using the computed gradients. The learning rate $\eta$ controls how far the parameters move in the direction that decreases the loss.

This process is repeated for $E$ epochs. Each epoch is one pass over the whole dataset.
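The four steps above correspond to the usual PyTorch training loop. A minimal sketch on random data, with a toy linear model standing in for the real network (sizes, learning rate and number of epochs are arbitrary; here one batch is the whole toy dataset):

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
b, d, C = 128, 20, 7                 # batch size, input features, classes (toy values)
x = torch.randn(b, d)
y = torch.randint(0, C, (b,))

model = nn.Linear(d, C)              # stand-in for the deep network
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)  # lr plays the role of eta

E = 5                                # number of epochs
losses = []
for epoch in range(E):
    optimizer.zero_grad()            # PyTorch accumulates gradients, so reset them first
    logits = model(x)                # 1. forward pass on the batch
    loss = criterion(logits, y)      # 2. loss between predictions and targets
    loss.backward()                  # 3. gradients w.r.t. every parameter
    optimizer.step()                 # 4. parameter update scaled by the learning rate
    losses.append(loss.item())
    print(epoch, loss.item())
```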

Batches mix ordered and shuffled images. The ratio is set by the data bias $\beta$: for $\beta=0.75$, 75% of the batch consists of ordered images and the rest of shuffled images. With a batch size of $b=128$, this means $0.75\times128=96$ ordered images (the role of $N_i$ at batch level) and $(1-0.75)\times128=32$ shuffled images (the role of $K_i$).
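In the notebook's code, $\beta$ is not enforced exactly per batch: `JigsawDataset.__getitem__` (further below) treats each image independently, forcing it to be ordered with probability `bias_whole_image` and otherwise drawing an order uniformly among the $P + 1$ classes, class 0 (ordered) included. The expected share of ordered images is therefore $\beta + (1-\beta)/(P+1)$, slightly above $\beta$. A small simulation of that rule, assuming $P = 30$; the helper `sample_orders` is ours:

```python
import numpy as np


def sample_orders(n: int, beta: float, n_perms: int, rng: np.random.Generator) -> np.ndarray:
    """Mimic the order-sampling rule of JigsawDataset.__getitem__ for n images."""
    orders = rng.integers(0, n_perms + 1, size=n)  # uniform over {0, ..., P}
    keep_whole = rng.random(n) < beta              # with probability beta, force order 0
    orders[keep_whole] = 0
    return orders


beta, P = 0.75, 30
rng = np.random.default_rng(0)
orders = sample_orders(100_000, beta, P, rng)

expected = beta + (1 - beta) / (P + 1)
print(f"empirical share of ordered images: {(orders == 0).mean():.3f}")
print(f"expected share: {expected:.3f}")           # 0.758
print(f"expected ordered images per batch of 128: {128 * expected:.1f}")
```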

In JiGen, the loss function takes a particular form because two tasks are learned jointly.


#### Supervised case
We optimize the parameters by solving the following minimization problem:

$$ \operatorname*{arg\,min}_{\theta_f, \theta_p, \theta_c} \sum_{i=1}^{S} \left( \sum_{j=1}^{N_i} \mathcal{L}_c \left( h(x^i_j|\theta_f, \theta_c), y^i_j\right) + \sum_{k=1}^{K_i} \alpha \mathcal{L}_p \left( h(z^i_k|\theta_f, \theta_p), p^i_k\right) \right) $$

- $\mathcal{L}_c$ is the loss (cross-entropy) for the image classification task. Recall that $\mathcal{L}_c \left( h(x^i_j|\theta_f, \theta_c), y^i_j\right) = - \sum_{c=1}^{C} y^i_{j,c} \log \mathbb{P}\left(h(x^i_j|\theta_f, \theta_c)=c\right)$, where $y^i_{j,c}$ is the $c$-th component of the one-hot label;
- $\mathcal{L}_p$ is the loss (cross-entropy) for the jigsaw-solving task;
- $\alpha$ is the weight of the jigsaw loss (how much importance we give to the jigsaw task relative to the classification task);
- $h$ is the function computed by the deep model; it outputs the prediction;
- $\theta_f$ is the set of shared parameters (weights and biases) of the feature extractor, i.e. the convolutional layers (and, for AlexNet, the shared fully connected layers);
- $\theta_p$ is the set of parameters of the last fully connected layer dedicated to permutation recognition (`jigsaw_classifier` in the code below);
- $\theta_c$ is the set of parameters of the last fully connected layer dedicated to object classification (`class_classifier` in the code below).

The jigsaw loss $\mathcal{L}_p$ is also computed on the ordered images (they form the identity class), whereas the classification loss $\mathcal{L}_c$ is not computed on the shuffled images, since that would make object recognition harder.
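A minimal sketch of the supervised objective for one batch, assuming the network returns `(jigsaw_logits, class_logits)` as the models below do, and that jigsaw label 0 denotes the ordered image (as in `JigsawDataset`). The classification term is restricted to ordered images (what the `classify_only_sane` option in `Args` refers to), while the jigsaw term uses every image. The function name `jigen_loss` is ours, and batch averages replace the sums of the formula:

```python
import torch
import torch.nn.functional as F


def jigen_loss(jigsaw_logits, class_logits, jig_labels, class_labels, alpha=0.7):
    """Classification CE on ordered images only + alpha * jigsaw CE on all images."""
    ordered = jig_labels == 0
    if ordered.any():
        loss_c = F.cross_entropy(class_logits[ordered], class_labels[ordered])
    else:  # no ordered image in the batch: no classification signal
        loss_c = class_logits.sum() * 0.0
    loss_p = F.cross_entropy(jigsaw_logits, jig_labels)
    return loss_c + alpha * loss_p, loss_c, loss_p


torch.manual_seed(0)
b, C, P = 8, 7, 30
jigsaw_logits = torch.randn(b, P + 1, requires_grad=True)  # P permutations + ordered class
class_logits = torch.randn(b, C, requires_grad=True)
jig_labels = torch.tensor([0, 0, 0, 0, 0, 0, 5, 17])         # 6 ordered, 2 shuffled images
class_labels = torch.randint(0, C, (b,))

total, loss_c, loss_p = jigen_loss(jigsaw_logits, class_logits, jig_labels, class_labels)
total.backward()
print(total.item(), loss_c.item(), loss_p.item())
# Shuffled images receive no gradient from the classification term:
print(class_logits.grad[6:].abs().sum().item())  # 0.0
```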


#### Unsupervised case
JiGen can also be used in an unsupervised setting, where part of the images are unlabeled (the article applies this to unsupervised domain adaptation, with unlabeled target images). The only difference from the supervised case lies in the loss for the image classification task:

$$ \operatorname*{arg\,min}_{\theta_f, \theta_p, \theta_c} \sum_{i=1}^{S} \left( \mathcal{L}_E (x^i) + \sum_{k=1}^{K_i} \alpha \mathcal{L}_p \left( h(z^i_k|\theta_f, \theta_p), p^i_k\right) \right) $$

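with $\mathcal{L}_E (x^i) = -\sum_{y \in \mathcal{Y}} h_y(x^i|\theta_f, \theta_c) \log h_y(x^i|\theta_f, \theta_c)$, where $h_y$ is the predicted probability of class $y$: the empirical entropy of the predictions, whose minimization encourages confident predictions on unlabeled images.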

Note: the sum $\sum_{j=1}^{N_i}$ disappears because we consider all images, not only the labeled ones.
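A sketch of an entropy term on unlabeled images. It computes $-\sum_y p(y|x)\log p(y|x)$ averaged over the batch, i.e. the entropy with an explicit minus sign, so that minimizing it pushes the classifier towards confident predictions; this sign convention is our assumption about the intended objective. The function name `entropy_loss` is ours.

```python
import torch
import torch.nn.functional as F


def entropy_loss(class_logits: torch.Tensor) -> torch.Tensor:
    """Mean prediction entropy over a batch of unlabeled images (logits of shape (B, C))."""
    p = F.softmax(class_logits, dim=1)
    log_p = F.log_softmax(class_logits, dim=1)  # numerically safer than p.log()
    return -(p * log_p).sum(dim=1).mean()


C = 7
uniform = torch.zeros(2, C)                      # maximally uncertain predictions
confident = torch.tensor([[20.0] + [0.0] * (C - 1)] * 2)
print(entropy_loss(uniform).item())              # log(7), about 1.946
print(entropy_loss(confident).item())            # close to 0
```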


### Test
To test the model, we only use the classification part of the network: the final fully connected layer used for jigsaw solving is not used. This amounts to setting $\alpha=0$.
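At test time only the classification head matters. A minimal sketch with a hypothetical two-head toy model (`TwoHeadToy`, ours) that returns `(jigsaw_logits, class_logits)` like the networks defined below; the jigsaw output is simply discarded:

```python
import torch
from torch import nn


class TwoHeadToy(nn.Module):
    """Hypothetical stand-in for the JiGen networks: shared features, two linear heads."""

    def __init__(self, d=20, jigsaw_classes=31, n_classes=7):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(d, 64), nn.ReLU())
        self.jigsaw_classifier = nn.Linear(64, jigsaw_classes)
        self.class_classifier = nn.Linear(64, n_classes)

    def forward(self, x):
        f = self.features(x)
        return self.jigsaw_classifier(f), self.class_classifier(f)


@torch.no_grad()
def accuracy(model, x, y):
    model.eval()                    # evaluation mode (e.g. disables dropout)
    _, class_logits = model(x)      # jigsaw head ignored, equivalent to alpha = 0
    preds = class_logits.argmax(dim=1)
    return (preds == y).float().mean().item()


torch.manual_seed(0)
model = TwoHeadToy()
x, y = torch.randn(50, 20), torch.randint(0, 7, (50,))
print(f"accuracy of an untrained model: {accuracy(model, x, y):.2f}")
```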


### Parameters
For every experiment, we will state the characteristics of the dataset explicitly: the image size $n_p$, the number of images $N$, and the number of classes $C$.

For all experiments, we use the same jigsaw parameters: the grid size $n$, the number of permutations considered $P$, and the data bias $\beta$. The authors choose these parameters for each experiment by cross-validation on 10% of the dataset.

We will fix the experiment parameters: the batch size $b$, the number of epochs $E$, the learning rate $\eta$, and the jigsaw weight $\alpha$. (The authors set ...)

The model parameters optimized by backpropagation, rather than chosen by the user, are $\theta_f$, $\theta_p$, and $\theta_c$.

## References

**[Carlucci et al. (2019)]** Carlucci, F. M., D'Innocente, A., Bucci, S., Caputo, B., & Tommasi, T. (2019). Domain Generalization by Solving Jigsaw Puzzles. arXiv preprint arXiv:1903.06864. [URL](https://arxiv.org/pdf/1903.06864.pdf)

**[Wang et al. (2022)]** Wang, J., Lan, C., Liu, C., Ouyang, Y., Qin, T., Lu, W., Chen, Y., Zeng, W., & Yu, P. S. (2022). Generalizing to Unseen Domains: A Survey on Domain Generalization. arXiv preprint arXiv:2103.03097. [URL](https://arxiv.org/pdf/2103.03097.pdf)

# Using JiGen on PACS (as in the article)

In [1]:
# Store the PACS folder in the same directory as this notebook

In [2]:
import torch
from IPython.core.debugger import set_trace
from torch import nn
from torch.nn import functional as F
import torch.utils.model_zoo as model_zoo
from torch.autograd import Function
from torchvision.models.resnet import BasicBlock,Bottleneck
import torch.utils.data as data
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision import datasets
from torch import optim
import torchvision

import numpy as np
import os
from os.path import join, dirname
from collections import OrderedDict
from itertools import chain
from PIL import Image
from random import sample, random
import bisect
import warnings
import tensorflow as tf
import scipy.misc 
try:
    from StringIO import StringIO  # Python 2.7
except ImportError:
    from io import BytesIO         # Python 3.x
from time import time




In [21]:
# Notes on the choice of network:
# resnet18 works, requires image_size=222
# resnet50 works with image_size=222
# alexnet does not work: it receives an extra argument, jigsaw_classes, which alexnet(classes, pretrained) does not accept (it also has no jigsaw head)
# caffenet does not work, because it needs the pretrained network (pretrained/alexnet_caffe.pth.tar)
# lenet does not work; we probably need to find a suitable image_size (its Linear(128 * 4 * 4) layer implies 28x28 inputs)


class Args:
    source = ['photo','cartoon','sketch']
    target = 'art_painting'
    batch_size = 64
    image_size = 222              # 222 for resnet18
    
    min_scale = 0.8               # Minimum scale percent
    max_scale = 1.0               # Maximum scale percent
    random_horiz_flip = 0.0       # Chance of random horizontal flip
    jitter = 0.0                  # Color jitter amount
    tile_random_grayscale = 0.1   # Chance of randomly greyscaling a tile
    
    limit_source = None     # If set, it will limit the number of training samples
    limit_target = None     # If set, it will limit the number of testing samples
    
    learning_rate = 0.01
    epochs = 5
    n_classes = 7              # Number of classes for object prediction
    jigsaw_n_classes = 31       # Number of permutation classes for the puzzle
    network = "resnet50"        # To choose from : 'caffenet', 'alexnet', 'resnet18', 'resnet50', 'lenet'
    jig_weight = 0.7            # Weight for the jigsaw puzzle compared to the classification
    ooo_weight = 0              # Weight for odd one out task
    tf_logger = True            # If True will save tensorboard compatible logs
    val_size = 0.1              # Validation size (between 0 and 1)
    folder_name = "Test"        # Used by the logger to save logs
    bias_whole_image = 0.9      # If set, will bias the training procedure to show more often the whole image
    TTA = False                 # Activate test time data augmentation
    classify_only_sane = False  # If true, the network will only try to classify the non scrambled images
    train_all = True            # If true, all network weights will be trained
    suffix = ""                 # Suffix for the logger
    nesterov = False            # Use nesterov
    


#### Files from /model

In [4]:
# Common to all networks definition
class Id(nn.Module):
    def __init__(self):
        super(Id, self).__init__()

    def forward(self, x):
        return x

In [5]:
# model_utils.py

class GradientKillerLayer(Function):
    @staticmethod
    def forward(ctx, x, **kwargs):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return None, None


class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.lambda_val

        return output, None
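`ReverseLayerF` is a gradient reversal layer: the identity in the forward pass, and a gradient multiplied by $-\lambda$ in the backward pass (here it only appears in commented-out domain-classifier code). A quick check, with the class redefined so that the snippet runs on its own:

```python
import torch
from torch.autograd import Function


class ReverseLayerF(Function):
    @staticmethod
    def forward(ctx, x, lambda_val):
        ctx.lambda_val = lambda_val
        return x.view_as(x)  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient w.r.t. x is reversed and scaled; lambda_val gets no gradient.
        return grad_output.neg() * ctx.lambda_val, None


x = torch.ones(3, requires_grad=True)
y = ReverseLayerF.apply(x, 0.5)
y.sum().backward()
print(y)        # same values as x
print(x.grad)   # tensor([-0.5000, -0.5000, -0.5000])
```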

In [6]:
# caffenet


class AlexNetCaffe(nn.Module):
    def __init__(self, jigsaw_classes=1000, n_classes=100, domains=3, dropout=True):
        super(AlexNetCaffe, self).__init__()
        print("Using Caffe AlexNet")
        self.features = nn.Sequential(OrderedDict([
            ("conv1", nn.Conv2d(3, 96, kernel_size=11, stride=4)),
            ("relu1", nn.ReLU(inplace=True)),
            ("pool1", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
            ("norm1", nn.LocalResponseNorm(5, 1.e-4, 0.75)),
            ("conv2", nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2)),
            ("relu2", nn.ReLU(inplace=True)),
            ("pool2", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
            ("norm2", nn.LocalResponseNorm(5, 1.e-4, 0.75)),
            ("conv3", nn.Conv2d(256, 384, kernel_size=3, padding=1)),
            ("relu3", nn.ReLU(inplace=True)),
            ("conv4", nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2)),
            ("relu4", nn.ReLU(inplace=True)),
            ("conv5", nn.Conv2d(384, 256, kernel_size=3, padding=1, groups=2)),
            ("relu5", nn.ReLU(inplace=True)),
            ("pool5", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)),
        ]))
        self.classifier = nn.Sequential(OrderedDict([
            ("fc6", nn.Linear(256 * 6 * 6, 4096)),
            ("relu6", nn.ReLU(inplace=True)),
            ("drop6", nn.Dropout() if dropout else Id()),
            ("fc7", nn.Linear(4096, 4096)),
            ("relu7", nn.ReLU(inplace=True)),
            ("drop7", nn.Dropout() if dropout else Id())]))

        self.jigsaw_classifier = nn.Linear(4096, jigsaw_classes)
        self.class_classifier = nn.Linear(4096, n_classes)
        # self.domain_classifier = nn.Sequential(
        #     nn.Linear(256 * 6 * 6, 1024),
        #     nn.ReLU(),
        #     nn.Dropout(),
        #     nn.Linear(1024, 1024),
        #     nn.ReLU(),
        #     nn.Dropout(),
        #     nn.Linear(1024, domains))

    def get_params(self, base_lr):
        return [{"params": self.features.parameters(), "lr": 0.},
                {"params": chain(self.classifier.parameters(), self.jigsaw_classifier.parameters()
                                 , self.class_classifier.parameters()#, self.domain_classifier.parameters()
                                 ), "lr": base_lr}]

    def is_patch_based(self):
        return False

    def forward(self, x, lambda_val=0):
        x = self.features(x*57.6)  #57.6 is the magic number needed to bring torch data back to the range of caffe data, based on used std
        x = x.view(x.size(0), -1)
        #d = ReverseLayerF.apply(x, lambda_val)
        x = self.classifier(x)
        return self.jigsaw_classifier(x), self.class_classifier(x)#, self.domain_classifier(d)


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


def caffenet(jigsaw_classes, classes):
    model = AlexNetCaffe(jigsaw_classes, classes)
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight, .1)
            nn.init.constant_(m.bias, 0.)

    state_dict = torch.load(os.path.join(os.path.abspath(''), "pretrained/alexnet_caffe.pth.tar"))
    del state_dict["classifier.fc8.weight"]
    del state_dict["classifier.fc8.bias"]
    model.load_state_dict(state_dict, strict=False)

    return model


def caffenet_gap(jigsaw_classes, classes):
    model = AlexNetCaffe(jigsaw_classes, classes)
    state_dict = torch.load(os.path.join(os.path.abspath(''), "pretrained/alexnet_caffe.pth.tar"))
    del state_dict["classifier.fc6.weight"]
    del state_dict["classifier.fc6.bias"]
    del state_dict["classifier.fc7.weight"]
    del state_dict["classifier.fc7.bias"]
    del state_dict["classifier.fc8.weight"]
    del state_dict["classifier.fc8.bias"]
    model.load_state_dict(state_dict, strict=False)
    # weights are initialized in the constructor
    return model


In [7]:
# alexnet.py

model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
}


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, dropout=True):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Dropout() if dropout else Id(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout() if dropout else Id(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 6 * 6)
        x = self.classifier(x)
        return x


def alexnet(classes, pretrained=False):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = AlexNet(classes, True)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))

    model.classifier[-1] = nn.Linear(4096, classes)
    nn.init.xavier_uniform_(model.classifier[-1].weight, .1)
    nn.init.constant_(model.classifier[-1].bias, 0.)
    return model

In [20]:
# resnet.py

class ResNet(nn.Module):
    def __init__(self, block, layers, jigsaw_classes=1000, classes=100, domains=3):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.jigsaw_classifier = nn.Linear(512 * block.expansion, jigsaw_classes)
        self.class_classifier = nn.Linear(512 * block.expansion, classes)
        #self.domain_classifier = nn.Linear(512 * block.expansion, domains)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def is_patch_based(self):
        return False

    def forward(self, x, **kwargs):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.jigsaw_classifier(x),self.class_classifier(x)


def resnet18(pretrained=True, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    print("Using ResNet-18")
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url('https://download.pytorch.org/models/resnet18-5c106cde.pth'), strict=False)
        #model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
    return model

def resnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    print("Using ResNet-50")
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth'), strict=False)
    return model
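Why `image_size = 222` works with these ResNets: the final `AvgPool2d(7, stride=1)` expects a 7×7 feature map, so that the flattened vector has `512 * block.expansion` entries, matching the two linear heads. A small sketch of the spatial-size arithmetic, using the standard formula $\lfloor (H + 2p - k)/s \rfloor + 1$ with the kernel, stride and padding values of the layers above (3×3, stride-2, padding-1 convolutions at the start of `layer2` to `layer4`, as in torchvision's blocks):

```python
def conv_out(h: int, k: int, s: int, p: int) -> int:
    """Output size of a convolution or pooling layer along one spatial dimension."""
    return (h + 2 * p - k) // s + 1


def resnet_final_size(image_size: int) -> int:
    """Spatial size of the feature map entering avgpool in the ResNet above."""
    h = conv_out(image_size, k=7, s=2, p=3)  # conv1
    h = conv_out(h, k=3, s=2, p=1)           # maxpool
    # layer1 keeps the size; layer2, layer3 and layer4 each halve it
    for _ in range(3):
        h = conv_out(h, k=3, s=2, p=1)
    return h


for size in (222, 224, 255):
    h = resnet_final_size(size)
    pooled = conv_out(h, k=7, s=1, p=0)       # AvgPool2d(7, stride=1)
    print(size, "->", h, "x", h, "-> pooled", pooled, "x", pooled)
# 222 -> 7 x 7 -> pooled 1 x 1
# 224 -> 7 x 7 -> pooled 1 x 1
# 255 -> 8 x 8 -> pooled 2 x 2
```

With a 255-pixel input, the pooled map is 2×2 and the flattened vector no longer matches the input size of `jigsaw_classifier` and `class_classifier`.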

In [9]:
# mnist.py

# built as https://github.com/ricvolpi/generalize-unseen-domains/blob/master/model.py
class MnistModel(nn.Module):
    def __init__(self, jigsaw_classes=1000, n_classes=100):
        super().__init__()
        
        outfeats = 1024 
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1024),
            nn.ReLU(True),
        )
#         outfeats = 100
#         self.features = nn.Sequential(
#             nn.Conv2d(3, 32, 5),
#             nn.ReLU(True),
#             nn.MaxPool2d(2, 2),
#             nn.Conv2d(32, 48, 5),
#             nn.ReLU(True),
#             nn.MaxPool2d(2, 2)
#         )
#         self.classifier = nn.Sequential(
#             nn.Linear(48 * 4 * 4, 100),
#             nn.ReLU(True),
#             nn.Linear(100, outfeats),
#             nn.ReLU(True),
#         )
        print("Using LeNet (%d)" % outfeats)
        self.jigsaw_classifier = nn.Linear(outfeats, jigsaw_classes)
        self.class_classifier = nn.Linear(outfeats, n_classes)

    def get_params(self, base_lr):
        raise NotImplementedError("No pretrained weights exist for LeNet - use train_all")

    def is_patch_based(self):
        return False

    def forward(self, x, lambda_val=0):
        # print(x.shape)
        x = self.features(x)
        # print(x.shape)
        x = self.classifier(x.view(x.size(0), -1))
        return self.jigsaw_classifier(x), self.class_classifier(x)


def lenet(jigsaw_classes, classes):
    model = MnistModel(jigsaw_classes, classes)
    return model
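The first linear layer of `MnistModel` expects `128 * 4 * 4` inputs, which fixes the input resolution. Using the standard size formula $\lfloor (H + 2p - k)/s \rfloor + 1$ (5×5 convolutions without padding, 2×2 max-pooling), a 28×28 input (the MNIST resolution) yields exactly a 4×4 map, while the 222 used for ResNet does not:

```python
def conv_out(h: int, k: int, s: int = 1, p: int = 0) -> int:
    """Output size of a convolution or pooling layer along one spatial dimension."""
    return (h + 2 * p - k) // s + 1


def lenet_final_size(image_size: int) -> int:
    """Spatial size of the feature map at the end of MnistModel.features."""
    h = conv_out(image_size, k=5)   # Conv2d(3, 64, 5)
    h = conv_out(h, k=2, s=2)       # MaxPool2d(2, 2)
    h = conv_out(h, k=5)            # Conv2d(64, 128, 5)
    return conv_out(h, k=2, s=2)    # MaxPool2d(2, 2)


for size in (28, 222):
    h = lenet_final_size(size)
    print(size, "->", h, "x", h, "| flattened:", 128 * h * h, "(expected 2048)")
```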

In [10]:
# model_factory.py

nets_map = {
    'caffenet': caffenet,
    'alexnet': alexnet,
    'resnet18': resnet18,
    'resnet50': resnet50,
    'lenet': lenet
}


def get_network(name):
    if name not in nets_map:
        raise ValueError('Name of network unknown %s' % name)

    def get_network_fn(**kwargs):
        return nets_map[name](**kwargs)

    return get_network_fn

#### Files from /data

In [11]:
# StandardDataset.py

def get_dataset(path, mode, image_size):
    if mode == "train":
        img_transform = transforms.Compose([
            transforms.RandomResizedCrop(image_size, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1/256., 1/256., 1/256.])  # std=[1/256., 1/256., 1/256.] #[0.229, 0.224, 0.225]
        ])
    else:
        img_transform = transforms.Compose([
            transforms.Resize(image_size),
            # transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], std=[1/256., 1/256., 1/256.])  # std=[1/256., 1/256., 1/256.]
        ])
    return datasets.ImageFolder(path, transform=img_transform)

In [12]:
# JigsawLoader.py



def get_random_subset(names, labels, percent):
    """
    Randomly split names and labels into a training part and a validation part.

    :param names: list of names
    :param labels: list of labels
    :param percent: 0 < float < 1, fraction of samples put in the validation part
    :return: name_train, name_val, labels_train, labels_val
    """
    samples = len(names)
    amount = int(samples * percent)
    random_index = sample(range(samples), amount)
    val_index = set(random_index)  # set for O(1) membership tests
    name_val = [names[k] for k in random_index]
    name_train = [v for k, v in enumerate(names) if k not in val_index]
    labels_val = [labels[k] for k in random_index]
    labels_train = [v for k, v in enumerate(labels) if k not in val_index]
    return name_train, name_val, labels_train, labels_val


def _dataset_info(txt_labels):
    with open(txt_labels, 'r') as f:
        images_list = f.readlines()

    file_names = []
    labels = []
    for row in images_list:
        row = row.split(' ')
        file_names.append(row[0])
        labels.append(int(row[1]))

    return file_names, labels


def get_split_dataset_info(txt_list, val_percentage):
    names, labels = _dataset_info(txt_list)
    return get_random_subset(names, labels, val_percentage)


class JigsawDataset(data.Dataset):
    def __init__(self, names, labels, jig_classes=100, img_transformer=None, tile_transformer=None, patches=True, bias_whole_image=None):
        self.data_path = ""
        self.names = names
        self.labels = labels

        self.N = len(self.names)
        self.permutations = self.__retrieve_permutations(jig_classes)
        self.grid_size = 3
        self.bias_whole_image = bias_whole_image
        if patches:
            self.patch_size = 64
        self._image_transformer = img_transformer
        self._augment_tile = tile_transformer
        if patches:
            self.returnFunc = lambda x: x
        else:
            def make_grid(x):
                return torchvision.utils.make_grid(x, self.grid_size, padding=0)
            self.returnFunc = make_grid

    def get_tile(self, img, n):
        w = float(img.size[0]) / self.grid_size
        y = int(n / self.grid_size)
        x = n % self.grid_size
        tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w])
        tile = self._augment_tile(tile)
        return tile
    
    def get_image(self, index):
        framename = self.data_path + self.names[index]
        img = Image.open(framename).convert('RGB')
        return self._image_transformer(img)
        
    def __getitem__(self, index):
        img = self.get_image(index)
        n_grids = self.grid_size ** 2
        tiles = [None] * n_grids
        for n in range(n_grids):
            tiles[n] = self.get_tile(img, n)

        order = np.random.randint(len(self.permutations) + 1)  # added 1 for class 0: unsorted
        if self.bias_whole_image:
            if self.bias_whole_image > random():
                order = 0
        if order == 0:
            data = tiles
        else:
            data = [tiles[self.permutations[order - 1][t]] for t in range(n_grids)]
            
        data = torch.stack(data, 0)
        return self.returnFunc(data), int(order), int(self.labels[index])

    def __len__(self):
        return len(self.names)

    def __retrieve_permutations(self, classes):
        all_perm = np.load('permutations_%d.npy' % (classes))
        # from range [1,9] to [0,8]
        if all_perm.min() == 1:
            all_perm = all_perm - 1

        return all_perm


class JigsawTestDataset(JigsawDataset):
    def __init__(self, *args, **xargs):
        super().__init__(*args, **xargs)

    def __getitem__(self, index):
        framename = self.data_path + self.names[index]
        img = Image.open(framename).convert('RGB')
        return self._image_transformer(img), 0, int(self.labels[index])


class JigsawTestDatasetMultiple(JigsawDataset):
    def __init__(self, *args, **xargs):
        super().__init__(*args, **xargs)
        self._image_transformer = transforms.Compose([
            transforms.Resize(255, Image.BILINEAR),
        ])
        self._image_transformer_full = transforms.Compose([
            transforms.Resize(225, Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        self._augment_tile = transforms.Compose([
            transforms.Resize((75, 75), Image.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __getitem__(self, index):
        framename = self.data_path + self.names[index]
        _img = Image.open(framename).convert('RGB')
        img = self._image_transformer(_img)

        w = float(img.size[0]) / self.grid_size
        n_grids = self.grid_size ** 2
        images = []
        jig_labels = []
        tiles = [None] * n_grids
        for n in range(n_grids):
            y = int(n / self.grid_size)
            x = n % self.grid_size
            tile = img.crop([x * w, y * w, (x + 1) * w, (y + 1) * w])
            tile = self._augment_tile(tile)
            tiles[n] = tile
        for order in range(0, len(self.permutations)+1, 3):
            if order==0:
                data = tiles
            else:
                data = [tiles[self.permutations[order-1][t]] for t in range(n_grids)]
            data = self.returnFunc(torch.stack(data, 0))
            images.append(data)
            jig_labels.append(order)
        images = torch.stack(images, 0)
        jig_labels = torch.LongTensor(jig_labels)
        return images, jig_labels, int(self.labels[index])

In [13]:
# concat_dataset.py


class ConcatDataset(Dataset):
    """
    Dataset to concatenate multiple datasets.
    Purpose: useful to assemble different existing datasets, possibly
    large-scale datasets as the concatenation operation is done in an
    on-the-fly manner.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def isMulti(self):
        return isinstance(self.datasets[0], JigsawTestDatasetMultiple)

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx], dataset_idx

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
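`ConcatDataset.__getitem__` maps a global index to a (dataset, local index) pair with a binary search over the cumulative sizes. It also returns the position of the dataset, which the `Trainer` receives as the domain index `d_idx`. The plain-Python sketch below shows that mapping for three hypothetical source datasets of sizes 3, 5 and 2.

```python
import bisect


def cumsum(sizes):
    """Running totals, e.g. [3, 5, 2] -> [3, 8, 10], as in ConcatDataset.cumsum."""
    totals, running = [], 0
    for size in sizes:
        running += size
        totals.append(running)
    return totals


def locate(cumulative_sizes, idx):
    """Map a global index to (dataset_idx, sample_idx)."""
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    if dataset_idx == 0:
        sample_idx = idx
    else:
        sample_idx = idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx


sizes = [3, 5, 2]  # hypothetical sizes of three source-domain datasets
cum = cumsum(sizes)
for idx in (0, 2, 3, 7, 8, 9):
    print(idx, locate(cum, idx))
```

The code uses `bisect_right` rather than `bisect_left` so that an index equal to a cumulative boundary (3 or 8 here) falls into the *next* dataset as its sample 0.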

In [14]:
# data_helper.py

mnist = 'mnist'
mnist_m = 'mnist_m'
svhn = 'svhn'
synth = 'synth'
usps = 'usps'

vlcs_datasets = ["CALTECH", "LABELME", "PASCAL", "SUN"]
pacs_datasets = ["art_painting", "cartoon", "photo", "sketch"]
office_datasets = ["amazon", "dslr", "webcam"]
digits_datasets = [mnist, mnist_m, svhn, usps]
available_datasets = office_datasets + pacs_datasets + vlcs_datasets + digits_datasets
#office_paths = {dataset: "/home/enoon/data/images/office/%s" % dataset for dataset in office_datasets}
#pacs_paths = {dataset: "/home/enoon/data/images/PACS/kfold/%s" % dataset for dataset in pacs_datasets}
vlcs_paths = {dataset: "/home/goulmdata/images/VLCS/%s/test" % dataset for dataset in vlcs_datasets}
#paths = {**office_paths, **pacs_paths, **vlcs_paths}

dataset_std = {mnist: (0.30280363, 0.30280363, 0.30280363),
               mnist_m: (0.2384788, 0.22375608, 0.24496263),
               svhn: (0.1951134, 0.19804622, 0.19481073),
               synth: (0.29410212, 0.2939651, 0.29404707),
               usps: (0.25887518, 0.25887518, 0.25887518),
               }

dataset_mean = {mnist: (0.13909429, 0.13909429, 0.13909429),
                mnist_m: (0.45920207, 0.46326601, 0.41085603),
                svhn: (0.43744073, 0.4437959, 0.4733686),
                synth: (0.46332872, 0.46316052, 0.46327512),
                usps: (0.17025368, 0.17025368, 0.17025368),
                }


class Subset(torch.utils.data.Dataset):
    def __init__(self, dataset, limit):
        indices = torch.randperm(len(dataset))[:limit]
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)


def get_train_dataloader(args, patches):
    dataset_list = args.source
    assert isinstance(dataset_list, list)
    datasets = []
    val_datasets = []
    img_transformer, tile_transformer = get_train_transformers(args)
    limit = args.limit_source
    for dname in dataset_list:
        name_train, name_val, labels_train, labels_val = get_split_dataset_info(join(os.path.abspath(''), 'data/txt_lists', '%s_train.txt' % dname), args.val_size)
        train_dataset = JigsawDataset(name_train, labels_train, patches=patches, img_transformer=img_transformer,
                                      tile_transformer=tile_transformer, jig_classes=args.jigsaw_n_classes, bias_whole_image=args.bias_whole_image)
        if limit:
            train_dataset = Subset(train_dataset, limit)
        datasets.append(train_dataset)
        val_datasets.append(JigsawTestDataset(name_val, labels_val, img_transformer=get_val_transformer(args),
                                              patches=patches, jig_classes=args.jigsaw_n_classes))
    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    return loader, val_loader


def get_val_dataloader(args, patches=False):
    names, labels = _dataset_info(join(os.path.abspath(''), 'data/txt_lists', '%s_test.txt' % args.target))
    img_tr = get_val_transformer(args)
    val_dataset = JigsawTestDataset(names, labels, patches=patches, img_transformer=img_tr, jig_classes=args.jigsaw_n_classes)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    return loader


def get_jigsaw_val_dataloader(args, patches=False):
    names, labels = _dataset_info(join(os.path.abspath(''), 'data/txt_lists', '%s_test.txt' % args.target))
    img_tr = [transforms.Resize((args.image_size, args.image_size))]
    tile_tr = [transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
    img_transformer = transforms.Compose(img_tr)
    tile_transformer = transforms.Compose(tile_tr)
    val_dataset = JigsawDataset(names, labels, patches=patches, img_transformer=img_transformer,
                                tile_transformer=tile_transformer, jig_classes=args.jigsaw_n_classes, bias_whole_image=args.bias_whole_image)
    if args.limit_target and len(val_dataset) > args.limit_target:
        val_dataset = Subset(val_dataset, args.limit_target)
        print("Using %d subset of val dataset" % args.limit_target)
    dataset = ConcatDataset([val_dataset])
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    return loader


def get_train_transformers(args):
    img_tr = [transforms.RandomResizedCrop((int(args.image_size), int(args.image_size)), (args.min_scale, args.max_scale))]
    if args.random_horiz_flip > 0.0:
        img_tr.append(transforms.RandomHorizontalFlip(args.random_horiz_flip))
    if args.jitter > 0.0:
        img_tr.append(transforms.ColorJitter(brightness=args.jitter, contrast=args.jitter, saturation=args.jitter, hue=min(0.5, args.jitter)))

    tile_tr = []
    if args.tile_random_grayscale:
        tile_tr.append(transforms.RandomGrayscale(args.tile_random_grayscale))
    tile_tr = tile_tr + [transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]

    return transforms.Compose(img_tr), transforms.Compose(tile_tr)


def get_val_transformer(args):
    img_tr = [transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(),
              transforms.Normalize([0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
    return transforms.Compose(img_tr)


def get_target_jigsaw_loader(args):
    img_transformer, tile_transformer = get_train_transformers(args)
    name_train, _, labels_train, _ = get_split_dataset_info(join(os.path.abspath(''), 'data/txt_lists', '%s_train.txt' % args.target), 0)
    dataset = JigsawDataset(name_train, labels_train, patches=False, img_transformer=img_transformer,
                            tile_transformer=tile_transformer, jig_classes=args.jigsaw_n_classes, bias_whole_image=args.bias_whole_image)
    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    return loader
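The tile and validation pipelines above end with `transforms.ToTensor()` followed by `transforms.Normalize` with the ImageNet channel statistics. `ToTensor` scales 8-bit pixel values from [0, 255] to [0, 1]. `Normalize` then computes `(x - mean) / std` for each channel independently. The sketch below applies the same arithmetic to a single RGB pixel in plain Python. `to_normalized` is an illustrative helper, not a torchvision function.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def to_normalized(pixel, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Mimic ToTensor + Normalize on one 8-bit RGB pixel."""
    scaled = [channel / 255.0 for channel in pixel]  # ToTensor: [0, 255] -> [0, 1]
    return [(x - m) / s for x, m, s in zip(scaled, mean, std)]  # Normalize, per channel


print(to_normalized((255, 255, 255)))  # white pixel: all channels positive
print(to_normalized((0, 0, 0)))        # black pixel: all channels negative
```

To use the MNIST statistics defined at the top of `data_helper.py`, pass `mean=dataset_mean[mnist]` and `std=dataset_std[mnist]` instead of the defaults.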

#### File from /optimizer

In [15]:
# optimizer_helper.py

def get_optim_and_scheduler(network, epochs, lr, train_all, nesterov=False):
    if train_all:
        params = network.parameters()
    else:
        params = network.get_params(lr)
    optimizer = optim.SGD(params, weight_decay=.0005, momentum=.9, nesterov=nesterov, lr=lr)
    #optimizer = optim.Adam(params, lr=lr)
    step_size = int(epochs * .8)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size)
    print("Step size: %d" % step_size)
    return optimizer, scheduler
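`get_optim_and_scheduler` sets `step_size = int(epochs * .8)` and leaves the `gamma` of `StepLR` at its default of 0.1. The learning rate is therefore divided by ten once 80% of the epochs have elapsed. If `scheduler.step()` is called once at the end of each epoch, the rate used during epoch `e` (counted from 0) is `lr * gamma ** (e // step_size)`. The plain-Python sketch below evaluates that closed form. `step_lr_schedule` is a hypothetical helper, and the example uses the values of the run further down (`epochs = 5`, `lr = 0.01`, hence `Step size: 4`).

```python
def step_lr_schedule(lr, epochs, gamma=0.1):
    """Learning rate of each epoch under StepLR with step_size = int(epochs * .8),
    assuming scheduler.step() is called once after every epoch."""
    step_size = int(epochs * .8)
    if step_size < 1:
        raise ValueError("epochs must be >= 2 so that step_size >= 1")
    return [lr * gamma ** (epoch // step_size) for epoch in range(epochs)]


print(step_lr_schedule(0.01, 5))  # four epochs at 0.01, then about 0.001
```

Recent PyTorch versions expect `step()` after the epoch's optimizer steps. If `step()` is called at the start of each epoch instead, the decay happens one epoch earlier.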

#### Files from /utils

In [16]:
# tf_logger.py

class TFLogger(object):
    
    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.create_file_writer(log_dir)
        #self.writer = tf.summary.FileWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        #summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
        #self.writer.add_summary(summary, step)
        with self.writer.as_default():
            tf.summary.scalar(tag, value, step=step)
            self.writer.flush()
            
    def image_summary(self, tag, images, step):
        """Log a list of images (H x W or H x W x C arrays, uint8 or floats in [0, 1])."""
        # TF2 API, consistent with tf.summary.create_file_writer in __init__
        # (tf.Summary, add_summary and scipy.misc.toimage no longer exist).
        with self.writer.as_default():
            for i, img in enumerate(images):
                img = np.asarray(img)
                if img.ndim == 2:
                    img = img[:, :, np.newaxis]  # grayscale: add a channel axis
                # tf.summary.image expects a batch of shape [k, h, w, c]
                tf.summary.image('%s/%d' % (tag, i), img[np.newaxis], step=step)
            self.writer.flush()

    def histo_summary(self, tag, values, step, bins=1000):
        """Log a histogram of the tensor of values."""
        with self.writer.as_default():
            tf.summary.histogram(tag, np.asarray(values), step=step, buckets=bins)
            self.writer.flush()

In [17]:
# Logger.py


_log_path = join(os.path.abspath(''), '../logs')


# high level wrapper for tf_logger.TFLogger
class Logger():
    def __init__(self, args, update_frequency=10):
        self.current_epoch = 0
        self.max_epochs = args.epochs
        self.last_update = time()
        self.start_time = time()
        self._clean_epoch_stats()
        self.update_f = update_frequency
        folder, logname = self.get_name_from_args(args)
        log_path = join(_log_path, folder, logname)
        if args.tf_logger:
            self.tf_logger = TFLogger(log_path)
            print("Saving to %s" % log_path)
        else:
            self.tf_logger = None
        self.current_iter = 0

    def new_epoch(self, learning_rates):
        self.current_epoch += 1
        self.last_update = time()
        self.lrs = learning_rates
        print("New epoch - lr: %s" % ", ".join([str(lr) for lr in self.lrs]))
        self._clean_epoch_stats()
        if self.tf_logger:
            for n, v in enumerate(self.lrs):
                self.tf_logger.scalar_summary("aux/lr%d" % n, v, self.current_iter)

    def log(self, it, iters, losses, samples_right, total_samples):
        self.current_iter += 1
        loss_string = ", ".join(["%s : %.3f" % (k, v) for k, v in losses.items()])
        for k, v in samples_right.items():
            past = self.epoch_stats.get(k, 0.0)
            self.epoch_stats[k] = past + v
        self.total += total_samples
        acc_string = ", ".join(["%s : %.2f" % (k, 100 * (v / total_samples)) for k, v in samples_right.items()])
        if it % self.update_f == 0:
            print("%d/%d of epoch %d/%d %s - acc %s [bs:%d]" % (it, iters, self.current_epoch, self.max_epochs, loss_string,
                                                                acc_string, total_samples))
            # update tf log
            if self.tf_logger:
                for k, v in losses.items(): self.tf_logger.scalar_summary("train/loss_%s" % k, v, self.current_iter)

    def _clean_epoch_stats(self):
        self.epoch_stats = {}
        self.total = 0

    def log_test(self, phase, accuracies):
        print("Accuracies on %s: " % phase + ", ".join(["%s : %.2f" % (k, v * 100) for k, v in accuracies.items()]))
        if self.tf_logger:
            for k, v in accuracies.items(): self.tf_logger.scalar_summary("%s/acc_%s" % (phase, k), v, self.current_iter)

    def save_best(self, val_test, best_test):
        print("It took %g" % (time() - self.start_time))
        if self.tf_logger:
            for x in range(10):
                self.tf_logger.scalar_summary("best/from_val_test", val_test, x)
                self.tf_logger.scalar_summary("best/max_test", best_test, x)

    @staticmethod
    def get_name_from_args(args):
        folder_name = "%s_to_%s" % ("-".join(sorted(args.source)), args.target)
        if args.folder_name:
            folder_name = join(args.folder_name, folder_name)
        name = "eps%d_bs%d_lr%g_class%d_jigClass%d_jigWeight%g" % (args.epochs, args.batch_size, args.learning_rate, args.n_classes,
                                                                   args.jigsaw_n_classes, args.jig_weight)
        # if args.ooo_weight > 0:
        #     name += "_oooW%g" % args.ooo_weight
        if args.train_all:
            name += "_TAll"
        if args.bias_whole_image:
            name += "_bias%g" % args.bias_whole_image
        if args.classify_only_sane:
            name += "_classifyOnlySane"
        if args.TTA:
            name += "_TTA"
        try:
            name += "_entropy%g_jig_tW%g" % (args.entropy_weight, args.target_weight)
        except AttributeError:
            pass
        if args.suffix:
            name += "_%s" % args.suffix
        name += "_%d" % int(time() % 1000)
        return folder_name, name
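`Logger.get_name_from_args` encodes the main hyper-parameters in the log directory name. This is how the path printed by the run below (`Test/cartoon-photo-sketch_to_art_painting/eps5_bs64_lr0.01_class7_jigClass31_jigWeight0.7_TAll_bias0.9_113`) can be read back. The sketch reproduces the folder and name format for the flags set in that run. A `types.SimpleNamespace` named `demo_args` stands in for `Args`, and its values are read off that path. The order of `source` is an assumption, but the code sorts it anyway. The time-based suffix `int(time() % 1000)` is left out so that the result is deterministic.

```python
from os.path import join
from types import SimpleNamespace


def run_folder_and_name(args):
    """Subset of Logger.get_name_from_args, without the time-based suffix."""
    folder_name = "%s_to_%s" % ("-".join(sorted(args.source)), args.target)
    if args.folder_name:
        folder_name = join(args.folder_name, folder_name)
    name = "eps%d_bs%d_lr%g_class%d_jigClass%d_jigWeight%g" % (
        args.epochs, args.batch_size, args.learning_rate, args.n_classes,
        args.jigsaw_n_classes, args.jig_weight)
    if args.train_all:
        name += "_TAll"
    if args.bias_whole_image:
        name += "_bias%g" % args.bias_whole_image
    return folder_name, name


# Hypothetical stand-in for Args, with the values visible in the log path.
demo_args = SimpleNamespace(source=["photo", "cartoon", "sketch"], target="art_painting",
                            folder_name="Test", epochs=5, batch_size=64, learning_rate=0.01,
                            n_classes=7, jigsaw_n_classes=31, jig_weight=0.7,
                            train_all=True, bias_whole_image=0.9)
print(run_folder_and_name(demo_args))
```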

#### Main file train_jigsaw.py


In [18]:
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = get_network(args.network)(jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = get_train_dataloader(args, patches=model.is_patch_based())
        self.target_loader = get_val_dataloader(args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" % (len(self.source_loader.dataset), len(self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(model, args.epochs, args.learning_rate, args.train_all, nesterov=args.nesterov)
        self.jig_weight = args.jig_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(self.device), class_l.to(self.device), d_idx.to(self.device)
            # absolute_iter_count = it + self.current_epoch * self.len_dataloader
            # p = float(absolute_iter_count) / self.args.epochs / self.len_dataloader
            # lambda_val = 2. / (1. + np.exp(-10 * p)) - 1
            # if domain_error > 2.0:
            #     lambda_val  = 0
            # print("Shutting down LAMBDA to prevent implosion")

            self.optimizer.zero_grad()

            jigsaw_logit, class_logit = self.model(data)  # , lambda_val=lambda_val)
            jigsaw_loss = criterion(jigsaw_logit, jig_l)
            # domain_loss = criterion(domain_logit, d_idx)
            # domain_error = domain_loss.item()
            if self.only_non_scrambled:
                if self.target_id is not None:
                    idx = (jig_l == 0) & (d_idx != self.target_id)
                    class_loss = criterion(class_logit[idx], class_l[idx])
                else:
                    class_loss = criterion(class_logit[jig_l == 0], class_l[jig_l == 0])

            elif self.target_id is not None:  # target_id may be 0, which is falsy
                class_loss = criterion(class_logit[d_idx != self.target_id], class_l[d_idx != self.target_id])
            else:
                class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            # _, domain_pred = domain_logit.max(dim=1)
            loss = class_loss + jigsaw_loss * self.jig_weight  # + 0.1 * domain_loss

            loss.backward()
            self.optimizer.step()

            self.logger.log(it, len(self.source_loader),
                            {"jigsaw": jigsaw_loss.item(), "class": class_loss.item()  # , "domain": domain_loss.item()
                             },
                            # ,"lambda": lambda_val},
                            {"jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                             "class": torch.sum(cls_pred == class_l.data).item(),
                             # "domain": torch.sum(domain_pred == d_idx.data).item()
                             },
                            data.shape[0])
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    jigsaw_correct, class_correct, single_acc = self.do_test_multi(loader)
                    print("Single vs multi: %g %g" % (float(single_acc) / total, float(class_correct) / total))
                else:
                    jigsaw_correct, class_correct = self.do_test(loader)
                jigsaw_acc = float(jigsaw_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"jigsaw": jigsaw_acc, "class": class_acc})
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_test_multi(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        single_correct = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(self.device), class_l.to(self.device)
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0], self.n_classes).to(self.device)
            for k in range(n_permutations):
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[0] *= 4 * n_permutations  # bias more the original image
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            jigsaw_logit, single_logit = self.model(data[:, 0])
            _, jig_pred = jigsaw_logit.max(dim=1)
            _, single_pred = single_logit.max(dim=1)
            single_correct += torch.sum(single_pred == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data[:, 0])
        return jigsaw_correct, class_correct, single_correct

    def do_training(self):
        self.logger = Logger(self.args, update_frequency=30)  # , "domain", "lambda"
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
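        for self.current_epoch in range(self.args.epochs):
            # Since PyTorch 1.1 the scheduler is stepped after the epoch's optimizer steps;
            # get_last_lr() returns the learning rate currently in use.
            self.logger.new_epoch(self.scheduler.get_last_lr())
            self._do_epoch()
            self.scheduler.step()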
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        #print("Best val %g, corresponding test %g - best test: %g" % (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
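In `_do_epoch`, the training objective is `class_loss + jig_weight * jigsaw_loss`. With `classify_only_sane`, the classification term only uses the samples whose jigsaw label is 0, that is, the unshuffled images. `do_test_multi` averages the class softmax over all permutations after multiplying the one of the original image by `4 * n_permutations`. The plain-Python sketch below mirrors both computations on hand-written logits. `cross_entropy`, `jigen_loss` and `tta_predict` are illustrative names, and the numbers are made up.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(batch_logits, labels):
    """Mean negative log-likelihood, like nn.CrossEntropyLoss with default reduction.
    Assumes a non-empty batch."""
    losses = [-math.log(softmax(l)[y]) for l, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


def jigen_loss(class_logits, class_labels, jig_logits, jig_labels, jig_weight,
               only_non_scrambled=True):
    jigsaw_loss = cross_entropy(jig_logits, jig_labels)
    if only_non_scrambled:
        # keep only the unshuffled images (jigsaw label 0) for classification
        keep = [i for i, j in enumerate(jig_labels) if j == 0]
        class_logits = [class_logits[i] for i in keep]
        class_labels = [class_labels[i] for i in keep]
    class_loss = cross_entropy(class_logits, class_labels)
    return class_loss + jig_weight * jigsaw_loss


def tta_predict(per_permutation_logits):
    """per_permutation_logits[k] holds the class logits of permutation k for one
    sample, k = 0 being the original image. Returns the predicted class index."""
    n = len(per_permutation_logits)
    probs = [softmax(l) for l in per_permutation_logits]
    probs[0] = [p * 4 * n for p in probs[0]]  # bias towards the original image
    mean = [sum(col) / n for col in zip(*probs)]
    return max(range(len(mean)), key=mean.__getitem__)


class_logits = [[2.0, 0.5, 0.1], [0.2, 0.1, 3.0]]  # two images, three classes
jig_logits = [[1.5, 0.2], [0.1, 2.0]]              # jigsaw labels 0 and 1
print(jigen_loss(class_logits, [0, 2], jig_logits, [0, 1], jig_weight=0.7))
print(tta_predict([[0.1, 1.0, 0.0], [3.0, 0.0, 0.0], [3.0, 0.0, 0.0]]))
```

In the last call, the two shuffled versions favour class 0. However, the up-weighted original image favours class 1, and this class wins after averaging.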


In [19]:
# main entry point of the Python code

args = Args()

def main(args):

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trainer = Trainer(args, device)
    trainer.do_training()

main(args)

Dataset size: train 7150, val 793, test 2048
Step size: 4
Saving to /home/tiphaign/Documents/5A/HDDL/Projet_HDDL/JigenDG-master/../logs/Test/cartoon-photo-sketch_to_art_painting/eps5_bs64_lr0.01_class7_jigClass31_jigWeight0.7_TAll_bias0.9_113
New epoch - lr: 0.01


2023-12-01 14:41:53.003685: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


0/111 of epoch 1/5 jigsaw : 3.254, class : 1.974 - acc jigsaw : 10.94, class : 17.19 [bs:64]
30/111 of epoch 1/5 jigsaw : 0.440, class : 0.848 - acc jigsaw : 89.06, class : 67.19 [bs:64]
60/111 of epoch 1/5 jigsaw : 0.321, class : 0.378 - acc jigsaw : 92.19, class : 85.94 [bs:64]
90/111 of epoch 1/5 jigsaw : 0.427, class : 0.434 - acc jigsaw : 89.06, class : 85.94 [bs:64]
Accuracies on val: jigsaw : 100.00, class : 89.28
Accuracies on test: jigsaw : 100.00, class : 69.34
New epoch - lr: 0.01
0/111 of epoch 2/5 jigsaw : 0.140, class : 0.289 - acc jigsaw : 96.88, class : 92.19 [bs:64]
30/111 of epoch 2/5 jigsaw : 0.567, class : 0.213 - acc jigsaw : 84.38, class : 89.06 [bs:64]
60/111 of epoch 2/5 jigsaw : 0.226, class : 0.311 - acc jigsaw : 93.75, class : 92.19 [bs:64]
90/111 of epoch 2/5 jigsaw : 0.217, class : 0.329 - acc jigsaw : 95.31, class : 90.62 [bs:64]
Accuracies on val: jigsaw : 100.00, class : 90.67
Accuracies on test: jigsaw : 99.90, class : 72.61
New epoch - lr: 0.01
0/111 o