# Network per channel

As we dive more precisely into the topic, we now create a more robust dataset:

* each input is an image of size `(C, W, H)`. Each channel holds one grayscale sliced image, so:
  * the first `C/3` channels are images sliced along the x axis
  * the following `C/3` channels are images sliced along the y axis
  * the last `C/3` channels are images sliced along the z axis
* each output is a list of fabric descriptors

If we take `C=3`, we can use a pretrained VGG model. Indeed, this model is trained on RGB images, which have 3 channels.

# Importing the dataframe

First, we initialize wandb, a tool that lets us store the losses and retrieve the dataframe. Alternatively, you can load the dataframe directly from your local machine.

In [None]:
!pip install wandb --upgrade

We import all the useful packages.

In [None]:
import sys
from pathlib import Path

# Hosted notebook environments are recognised by the modules they preload.
IS_KAGGLE = "kaggle_secrets" in sys.modules
IS_COLAB = "google.colab" in sys.modules

if not (IS_KAGGLE or IS_COLAB):
    # Plain local checkout.
    repo_path = Path("/home/matias/microstructure-reconstruction")
elif IS_KAGGLE:
    # Kaggle takes precedence over Colab when both flags are set.
    repo_path = Path("../input/microstructure-reconstruction")
else:
    # Colab: the repository lives on the mounted Google Drive.
    from google.colab import drive

    drive.mount("/content/gdrive")
    repo_path = Path("/content/gdrive/MyDrive/microstructure-reconstruction")

# Make the repository's own packages importable from the notebook.
sys.path.append(str(repo_path))

from copy import deepcopy
from importlib import reload

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import pandas as pd
import pytorch_lightning as pl
import torch
from typing import Union, List
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import torchvision.models as pretrained_models
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from tqdm import tqdm

import wandb
from custom_datasets import dataset
from custom_models import models
from tools import dataframe_reformat, inspect_code, plotting, training, wandb_api

# Master switch: when True, metrics and artifacts go through wandb.
log_wandb = True

use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    # DataLoader options used when a GPU is available.
    kwargs = {"num_workers": 2, "pin_memory": True}
else:
    device = torch.device("cpu")
    # DataLoader options used on CPU-only machines.
    kwargs = {"num_workers": 4}
print(f"[INFO]: Computation device: {device}")


We initialize a wandb run, which will save our metrics.

In [None]:
if log_wandb:
    import wandb

    # Log in through the project helper before opening the run.
    wandb_api.login()

    # Settings identifying this run inside the wandb project.
    run_settings = {
        "project": "microstructure-reconstruction",
        "group": "NWidth Network",
        "job_type": "test",
    }
    run = wandb.init(**run_settings)


Parameters of our run:

In [None]:
# Parameters of the run. With wandb logging enabled they are stored in
# `wandb.config`; otherwise a plain dict is used.
if log_wandb:
    config = wandb.config
    config["job_type"] = run.job_type
else:
    config = {}
    # No wandb run exists in this mode, so `run.job_type` cannot be read.
    # Fall back to the job type the wandb run is created with when logging is on.
    config["job_type"] = "test"

config["train_val_split"] = 0.7
config["seed"] = 42
config["batch_size"] = 64
config["learning_rate"] = 0.0001
config["device"] = device
config["momentum"] = 0.9
config["architecture"] = "pretrained VGG"
config["input_width"] = 64
config["weight_decay"] = 0.00005
config["epochs"] = 0
config["frac_sample"] = 1
config["frac_noise"] = 0
config["nb_image_per_axis"] = 3
# "total_layers" and "fixed_layers" are set right before the model is built.
config["log_wandb"] = log_wandb

# Seed torch explicitly, then every RNG handled by Lightning.
torch.manual_seed(config["seed"])
pl.seed_everything(config["seed"])

We retrieve the dataframe containing the descriptors. This can also be done locally on your computer.

In [None]:
class DataModule(pl.LightningDataModule):
    """Lightning data module serving slice images and their fabric descriptors.

    The descriptor dataframes come either from wandb artifacts (when
    ``config["log_wandb"]`` is truthy) or from the ``REV1_600`` folder under
    ``repo_path``. Every column except the first and the last one is min-max
    scaled to ``[0, 1]`` before the datasets are built.

    Args:
        config: mapping providing ``log_wandb``, ``nb_image_per_axis`` and
            ``batch_size`` (plus ``train_val_split`` and ``seed`` for the local
            loading path).
        repo_path: repository root; photo paths are resolved against it.
        train_df: optional ready-made training dataframe.
        test_df: optional ready-made test dataframe.
        train_dataset: optional ready-made training dataset.
        validation_dataset: optional ready-made validation dataset.
    """

    def __init__(
        self,
        config,
        repo_path,
        train_df=None,
        test_df=None,
        train_dataset=None,
        validation_dataset=None,
    ):
        super().__init__()
        self.config = config
        self.repo_path = repo_path
        self.train_df = train_df.convert_dtypes() if train_df is not None else None
        self.test_df = test_df.convert_dtypes() if test_df is not None else None
        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset

        # Only reference the wandb artifacts for the dataframes that were not
        # handed in by the caller.
        if self.config["log_wandb"]:
            if self.train_df is None:
                self.training_data_at = wandb.Api().artifact(
                    f"matiasetcheverry/microstructure-reconstruction/train_df:{self.config['nb_image_per_axis']}_images_invariants"
                )
            if self.test_df is None:
                self.test_data_at = wandb.Api().artifact(
                    f"matiasetcheverry/microstructure-reconstruction/test_df:{self.config['nb_image_per_axis']}_images_invariants"
                )

        self.transform = transforms.Compose(
            [
                transforms.GaussianBlur(kernel_size=3, sigma=0.5),
            ]
        )

    def prepare_data(self):
        """Download the wandb artifacts backing the missing dataframes."""
        if self.config["log_wandb"]:
            if self.train_df is None:
                self.training_data_at.download()
            if self.test_df is None:
                self.test_data_at.download()

    def _artifact_to_dataframe(self, artifact):
        """Convert the ``fabrics`` table of ``artifact`` into a dataframe.

        Every entry of the ``photos`` column is rewritten as a list of string
        paths prefixed with ``self.repo_path``.
        """
        df = wandb_api.convert_table_to_dataframe(artifact.get("fabrics"))
        df["photos"] = df["photos"].apply(
            func=lambda photo_paths: [
                str(self.repo_path / Path(x)) for x in photo_paths
            ]
        )
        return df

    def _init_df_wandb(self):
        """Fill the missing dataframes from the wandb artifacts."""
        if self.train_df is None:
            self.train_df = self._artifact_to_dataframe(self.training_data_at)
        if self.test_df is None:
            self.test_df = self._artifact_to_dataframe(self.test_data_at)

    def _init_df_local(self):
        """Fill the missing dataframes from the local ``REV1_600`` folder."""
        fabrics_df = pd.read_csv(self.repo_path / "REV1_600/fabrics.txt")
        path_to_slices = self.repo_path / "REV1_600/REV1_600Slices"
        fabrics_df["photos"] = fabrics_df["id"].apply(
            func=dataframe_reformat.associate_rev_id_to_its_images,
            args=(path_to_slices, self.config["nb_image_per_axis"]),
        )
        # Keep only the rows that have at least one photo. The explicit copy
        # makes the column assignment below act on an independent frame
        # (no SettingWithCopyWarning).
        fabrics_df = fabrics_df[fabrics_df.photos.str.len().gt(0)].copy()
        fabrics_df["photos"] = fabrics_df["photos"].apply(func=lambda x: sorted(x))
        # Split parameters come from this module's own config rather than from
        # a module-level global.
        train_df, test_df = train_test_split(
            fabrics_df,
            train_size=self.config["train_val_split"],
            random_state=self.config["seed"],
            shuffle=True,
        )
        if self.train_df is None:
            self.train_df = train_df.reset_index(drop=True)
        if self.test_df is None:
            self.test_df = test_df.reset_index(drop=True)

    def init_df(self):
        """Load the dataframes from wandb or from disk, depending on the config."""
        if self.config["log_wandb"]:
            self._init_df_wandb()
        else:
            self._init_df_local()

    def setup(self, stage=None):
        """Build the scaler, the datasets and the raw validation targets.

        Nothing is done when both datasets already exist; in that case
        ``self.scaler`` and ``self.targets`` are not (re)assigned here.

        Args:
            stage: Lightning stage name; not used by this implementation.
        """
        if self.train_dataset is None or self.validation_dataset is None:
            self.init_df()

            # NOTE(review): the scaler is fitted on the train AND the test
            # descriptors, so test statistics influence the normalisation.
            # Kept as is; confirm this is intended.
            self.scaler = MinMaxScaler(feature_range=(0, 1))
            self.scaler.partial_fit(self.train_df.iloc[:, 1:-1])
            self.scaler.partial_fit(self.test_df.iloc[:, 1:-1])

            normalized_train_df = deepcopy(self.train_df)
            normalized_train_df.iloc[:, 1:-1] = self.scaler.transform(
                self.train_df.iloc[:, 1:-1]
            )
            normalized_test_df = deepcopy(self.test_df)
            normalized_test_df.iloc[:, 1:-1] = self.scaler.transform(
                self.test_df.iloc[:, 1:-1]
            )

            if self.train_dataset is None:
                self.train_dataset = dataset.NWidthStackedPhotosDataset(
                    normalized_train_df,
                    nb_image_per_axis=self.config["nb_image_per_axis"],
                    transform=self.transform,
                    noise=0,
                )
            if self.validation_dataset is None:
                self.validation_dataset = dataset.NWidthStackedPhotosDataset(
                    normalized_test_df,
                    nb_image_per_axis=self.config["nb_image_per_axis"],
                    transform=self.transform,
                    noise=0,
                )
            # Unscaled descriptors of the test dataframe.
            self.targets = self.test_df.iloc[:, 1:-1].to_numpy()

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset."""
        # `kwargs` is the module-level dict of DataLoader options.
        return DataLoader(
            self.train_dataset,
            batch_size=self.config["batch_size"],
            shuffle=True,
            **kwargs,
        )

    def val_dataloader(self):
        """Return an ordered loader over the validation dataset."""
        return DataLoader(
            self.validation_dataset,
            batch_size=self.config["batch_size"],
            shuffle=False,
            **kwargs,
        )

    def test_dataloader(self):
        """Return the validation loader (no separate test set is defined)."""
        return self.val_dataloader()

    def predict_dataloader(self):
        """Return an ordered loader over the validation images only."""
        # The images are materialised in a list; the second item of every
        # dataset sample is dropped.
        return DataLoader(
            [image for image, _ in self.validation_dataset],
            batch_size=self.config["batch_size"],
            shuffle=False,
            **kwargs,
        )

dm = DataModule(config, repo_path)

dm.prepare_data()
dm.setup(stage="fit")

# Pull a single batch to inspect what the loader yields.
train_loader = dm.train_dataloader()
nb_batches = len(train_loader)
first_batch = next(iter(train_loader))
images, labels = first_batch[0], first_batch[1]

print("Nb of descriptors:", len(labels[0]))
print("Nb batch in dataset:", nb_batches)
print("Size of a batch:", len(images))
print("Approx size of the dataset:", len(images) * nb_batches)
print("Image shape:", images[0].shape)

# Show the whole batch as one image grid (channels moved last for matplotlib).
grid = utils.make_grid(images)
fig = plt.figure(figsize=(40, 10))
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.show()

# Network definition

The next step is to define our model. This model is inspired by VGG11:

* we define several convolutional blocks.
* each of these blocks is a sequence of:
  * a convolutional layer with `kernel_size=3, padding=1`
  * an activation function, here the `ReLU`
  * a max pooling layer with `kernel_size=2, stride=2`, which aims at reducing the size of the feature maps

In [None]:
class PreTrainedVGG(models.BaseModel):
    """Regression network built on the feature layers of a pretrained VGG16-BN.

    The first ``config["total_layers"]`` feature layers of torchvision's
    ``vgg16_bn`` are kept; the convolutions among the first
    ``config["fixed_layers"]`` of them are frozen. A fully connected head with
    28 outputs is stacked on top.

    Args:
        config: mapping providing ``total_layers``, ``fixed_layers``,
            ``input_width``, ``nb_image_per_axis`` and ``device``.
        scaler: optional scaler object, stored as ``self.scaler``.
    """

    def __init__(self, config, scaler=None):
        super().__init__(config)

        self.config = config
        self.config["model_type"] = type(self)
        self.scaler = scaler

        self.configure_model()
        self.configure_criterion()
        self.configure_metrics()

    def configure_model(self):
        """Build the truncated VGG feature extractor and the linear head."""
        assert (
            self.config["total_layers"] >= self.config["fixed_layers"]
        ), "fixed_layers cannot exceed total_layers"
        vgg = pretrained_models.vgg16_bn(pretrained=True)
        self.layers = nn.Sequential(
            *(list(vgg.features.children())[: self.config["total_layers"]])
        )
        # Only convolution layers are frozen; other layer types in the fixed
        # range keep their trainable parameters.
        for idx, child in enumerate(self.layers.children()):
            if idx < self.config["fixed_layers"] and isinstance(child, nn.Conv2d):
                for param in child.parameters():
                    param.requires_grad = False

        input_fc = self._count_flattened_features()
        # fully connected linear layers
        self.linear_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=input_fc, out_features=512),
            nn.ReLU(),
            # Activations are 2-D (batch, features) at this point, so the
            # element-wise Dropout is the matching layer (Dropout2d is meant
            # for channel maps).
            nn.Dropout(0.5),
            nn.Linear(in_features=512, out_features=512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=512, out_features=28),
        )

    def _count_flattened_features(self):
        """Return the number of values ``self.layers`` outputs for one sample."""
        probe = torch.rand(
            (
                1,
                3,
                self.config["input_width"],
                self.config["nb_image_per_axis"] * self.config["input_width"],
            )
        )
        # Probe in eval mode and without autograd so the random input does not
        # update the BatchNorm running statistics of the pretrained layers.
        was_training = self.layers.training
        self.layers.eval()
        try:
            with torch.no_grad():
                features = self.layers(probe)
        finally:
            self.layers.train(was_training)
        # Count per-sample elements directly; this stays correct when a
        # spatial dimension collapses to 1.
        return int(features[0].numel())

    def forward(self, x):
        """Run the feature extractor, then the linear head."""
        x = self.layers(x)
        x = self.linear_layers(x)
        return x

    def validation_step(self, batch, batch_idx):
        """Compute and log every validation metric for one batch.

        Returns:
            dict mapping each metric name to its value on the batch.
        """
        x, y = batch
        y_hat = self(x)
        # Metrics take (preds, target); the order matters for the asymmetric
        # ones (MAPE, R2).
        metrics = {name: metric(y_hat, y) for name, metric in self.metrics.items()}
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)
        return metrics

    def configure_metrics(self):
        """Create the validation metrics on the configured device."""
        self.metrics = {
            "val_loss": self.criterion.to(self.config["device"]),
            "mae": torchmetrics.MeanAbsoluteError().to(self.config["device"]),
            "mape": torchmetrics.MeanAbsolutePercentageError().to(
                self.config["device"]
            ),
            "smape": torchmetrics.SymmetricMeanAbsolutePercentageError().to(
                self.config["device"]
            ),
            "r2_score": torchmetrics.R2Score(num_outputs=28).to(self.config["device"]),
            "cosine_similarity": torchmetrics.CosineSimilarity(reduction="mean").to(
                self.config["device"]
            ),
        }
    
    

config["total_layers"] = 44
config["fixed_layers"] = 0
model = PreTrainedVGG(config)
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"[INFO]: {total_params:,} total parameters.")

# Smoke-test the forward pass on one random input. Eval mode and no_grad keep
# the random sample from updating BatchNorm running statistics and from
# building an autograd graph; the previous mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    sanity_output = model(
        torch.rand(
            (
                1,
                3,
                config["input_width"],
                config["nb_image_per_axis"] * config["input_width"],
            )
        )
    )
model.train(was_training)
sanity_output


# Checkpoint

We add 2 checkpoints to our training:

* one for saving our model every time we reach a new minimum of the validation loss
* one for saving the scripts of the model and of the data module

In [None]:
# Checkpoints go to the wandb run directory when a run exists, else to "tmp/".
checkpoint_dir = run.dir if "run" in locals() else "tmp/"

# Saves the model whenever the monitored "val_loss" reaches a new minimum,
# plus the last state.
model_checkpoint = pl.callbacks.model_checkpoint.ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="{epoch}-{val_loss:.3f}",
    monitor="val_loss",
    mode="min",
    verbose=True,
    save_last=True,
)

script_checkpoint = training.ScriptCheckpoint(
    dirpath=checkpoint_dir,
)

# Explicit override: models are saved whatever the job type is. Set it to
# False to save models for "train" jobs only.
always_save_models = True

callbacks = [script_checkpoint]
log = None
if config["job_type"] == "train" or always_save_models:
    callbacks.append(model_checkpoint)
    print("[INFO]: saving models.")
else:
    print("[INFO]: not saving models.")
if config["job_type"] == "debug":
    log = "all"


# Training

We then train our model.

In [None]:
# A wandb logger is attached only when wandb logging is enabled.
wandb_logger = None
if config["log_wandb"]:
    wandb_logger = pl.loggers.WandbLogger()
    wandb_logger.watch(model, log=log, log_graph=True)

# Debugging overrides that can be added here when needed:
# limit_train_batches=0.3, limit_val_batches=1, log_every_n_steps=1.
trainer_options = {
    "max_epochs": 700,
    "max_time": {"hours": 11},
    "callbacks": callbacks,
    "logger": wandb_logger,
    "devices": "auto",
    "accelerator": "auto",
}
trainer = pl.Trainer(**trainer_options)
trainer.fit(model, datamodule=dm)


In [None]:
dm.prepare_data()
dm.setup("validate")

# One prediction tensor per batch, concatenated into a single tensor.
batch_predictions = trainer.predict(model, dataloaders=dm.predict_dataloader())
predictions = torch.cat(batch_predictions)

# Unscaled descriptors stored by the data module, as a float tensor.
targets = torch.FloatTensor(dm.targets)

In [None]:
# Capture what the 4th feature layer outputs for the first preview image.
save_output = training.SaveOutput()
handle = model.layers[3].register_forward_hook(save_output)
image = images[0]
try:
    model(image.unsqueeze(0))
finally:
    # Always detach the hook, even if the forward pass fails, so re-running
    # the cell never leaves a stale hook on the layer.
    handle.remove()

# Presumably `outputs[0]` is the captured activation (1, channels, H, W);
# verify against training.SaveOutput. Channels are moved to the batch axis and
# the first 30 are kept so each one becomes a tile of the grid.
outputs = save_output.outputs[0].permute(1, 0, 2, 3).detach().cpu()[:30]
grid_img = utils.make_grid(outputs, normalize=True, pad_value=1, padding=1)
plt.figure(figsize=(30, 30))
plt.imshow(grid_img.permute(1, 2, 0))

In [None]:
# Targets are passed through the data module's scaler before being compared
# with the network outputs.
scaled_targets = dm.scaler.transform(targets.cpu().numpy())
predicted_values = predictions.cpu().numpy()

fig = plotting.plot_kde(
    [scaled_targets, predicted_values],
    nb_hist_per_line=6,
    columns=dm.train_df.columns[1:-1],
)

In [None]:
# A wandb run only exists when logging is enabled; closing it unconditionally
# would raise a NameError otherwise.
if log_wandb:
    run.finish()