# Intrabatch permutation DRIT on the MIDOG 2021 dataset

**References**

- [Dataset location](https://zenodo.org/record/4643381)
- [DomainShiftQuantification](https://github.com/DeepPathology/MIDOG/tree/main/DomainShiftQuantification) code repo.

In [1]:
from collections import defaultdict, OrderedDict

import logging
import os
import openslide
import urllib
import requests
import shutil
from IPython.core.debugger import set_trace
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from enum import Enum, auto

import itertools
import cv2
import numpy as np
import matplotlib.pyplot as plt
import h5py

import torch
import torch.nn as nn
import pytorch_lightning as pl
from contextlib import ExitStack

In [2]:
%reload_ext autoreload
%autoreload 2

In [23]:
import torch
from skimage.transform import rotate
from torch.utils.data import Dataset
from torchvision.transforms import Compose
from torchvision.transforms import Normalize
from torch.utils.data import Dataset

# 6. Training module definition

In [51]:
def train_model(
    num_input_channels: int,
    experiment_name: str,
    train_hyperparams: Dict[str, Any],
    train_sample_list_by_domain_index: Dict[int, List[Sample]],
    val_sample_list_by_domain_index: Dict[int, List[Sample]],
):
    """Entry point for training the scanner-transform model.

    The original author describes the model as a cycleGAN. As written here,
    the function only logs that training has started: the TensorBoard logger
    setup is disabled (see the note at the end of the body) and none of the
    arguments are consumed yet.

    Args:
        num_input_channels: Number of channels of the input.
        experiment_name: Name of the experiment to run.
        train_hyperparams: Hyperparameters for the training run.
        train_sample_list_by_domain_index: Training samples grouped by
            domain; maps a domain index to that domain's list of samples.
        val_sample_list_by_domain_index: Validation samples grouped by
            domain; maps a domain index to that domain's list of samples.
    """
    logging.info("Training model...")
    # NOTE: TensorBoard logging is currently disabled. The original setup was:
    #   tensorboard_logger = get_tensorboard_logger(experiment_name)
    #   logging.info(f"Tensorboard save dir = {tensorboard_logger.save_dir}, log dir = {tensorboard_logger.log_dir}")
    
    

In [53]:


def _adjust_minibatch_size(org_batch_size: int) -> int:
    """Round a minibatch size down to the nearest even number.

    Args:
        org_batch_size: The requested minibatch size.

    Returns:
        ``org_batch_size`` itself when it is even, otherwise
        ``org_batch_size - 1``.
    """
    leftover = org_batch_size % 2  # 1 for odd sizes, 0 for even ones
    return org_batch_size - leftover

def _loss_string_from_weight_dict(loss_weights_by_name: Dict[str, float]) -> str:
    """Flatten a loss-weight mapping into one underscore-separated string.

    Each entry contributes ``"<name>_<weight>"`` and the entries are joined
    with ``"_"`` in the mapping's iteration order. Weights are formatted with
    their default string form, so ``1.0`` and ``1`` give different output.

    Args:
        loss_weights_by_name: Mapping from loss-term name to its weight.

    Returns:
        The joined string, or an empty string for an empty mapping.
    """
    return "_".join(
        f"{name}_{weight}" for name, weight in loss_weights_by_name.items()
    )

# Training configuration for this run.
#
# No pretrained checkpoint is loaded (pretrain_model_path is None). The
# commented-out assignment below keeps an S3 checkpoint path for reference.
#pretrain_model_path = "s3://imaging-team/tnguyen/breast_cancer/trained_models/Stain_separation_19888_bs_22_samples_w_real_fake_weight_1.0_recon_weight_20.0_content_consistency_weight_3.0_attr_consistency_weight_1.0_mode_seeking_loss_weight_1.0_patch_size_512_v1/every_n/periodic_epoch=139_val_encoders_generators_total_loss=1.532_train_encoders_generators_total_loss=1.604_train_disc_total_loss=1.224.ckpt" 
pretrain_model_path = None
# Per-GPU minibatch size, rounded down to an even number by _adjust_minibatch_size.
# NOTE(review): the original note says 28 is the largest size that avoids OOM,
# yet 14 is requested here -- confirm which value the note refers to.
batch_size_per_gpu = _adjust_minibatch_size(org_batch_size=14)  # 28 is the largest number that does not cause OOM.

# Weight of each loss term, keyed by name; passed on through
# train_hyperparams['loss_weights_by_name'].
# NOTE(review): the logged experiment name appears to embed these names and
# values as written, so changing `1` to `1.0` would likely rename the
# experiment -- confirm before normalizing.
loss_weights_by_name = {
    'real_fake_weight': 1.0,
    'recon_weight': 20.0,
    'content_consistency_weight': 3.0, 
    'attr_consistency_weight': 1.0,
    'mode_seeking_loss_weight': 1.0, 
    'content_channel_covariance_loss_weight': 1,
}

# Hyperparameter dictionary handed to train_model(). How each key is consumed
# is not visible in this part of the file.
train_hyperparams = {
    'weight_decay': 0.001,
    'gen_learning_rate': 1e-3,
    'disc_learning_rate': 5e-2,
    'num_epochs': 10000,
    # Total batch size across all GPUs (NUM_GPUS is defined elsewhere).
    'batch_size': batch_size_per_gpu * NUM_GPUS,  
    'num_dataloaders': 20,
    'loss_weights_by_name': loss_weights_by_name,
    'pretrained_model_path': pretrain_model_path,
    'number_gen_optimization_steps_to_update_disc': 1,
    'number_of_steps_to_update_lr': 1,
    'periodically_save_training_results': False,
}

# Requested sample counts (56 train, 28 validation), adjusted by a helper
# defined outside this view; per its name, it makes the count split evenly
# across GPUs -- TODO confirm the exact rounding rule.
num_train_samples = _adjust_num_samples_so_that_each_gpu_has_an_even_number_of_samples(56, NUM_GPUS, batch_size_per_gpu)
num_val_samples = _adjust_num_samples_so_that_each_gpu_has_an_even_number_of_samples(28, NUM_GPUS, batch_size_per_gpu)


INFO:root:experiment_name = Stain_separation_32_stain_vectors_56_bs_14_samples_w_real_fake_weight_1.0_recon_weight_20.0_content_consistency_weight_3.0_attr_consistency_weight_1.0_mode_seeking_loss_weight_1.0_content_channel_covariance_loss_weight_1_MIDOG_v1.0


In [56]:
# Launch training on 3-channel input with the experiment name, hyperparameters
# and per-domain train/validation sample lists prepared earlier in the notebook.
train_model(
    num_input_channels=3,
    experiment_name=experiment_name,
    train_hyperparams=train_hyperparams,
    train_sample_list_by_domain_index=train_sample_list_by_domain_index,
    val_sample_list_by_domain_index=val_sample_list_by_domain_index,
)

INFO:root:Training model...


KeyboardInterrupt: 