# Network per channel

As we dive more precisely into the topic, we now create a more robust dataset:

* each input is an image of size `(C, W, H)`. Each channel holds one grayscale sliced image, so:
  * the first `C/3` channels are sliced images along the x axis
  * the following `C/3` channels are sliced images along the y axis
  * the last `C/3` channels are sliced images along the z axis
* each output is a list of fabric descriptors

If we take `C=3`, we can use a pretrained VGG model. Indeed, this model is trained on RGB images, which have 3 channels.

# Importing the dataframe

Firstly, we initialize wandb. It is a tool that allows us to store the losses and retrieve the dataframe. Otherwise, you can directly access the dataframe locally on your computer.

In [None]:
!pip install wandb --upgrade

We import all the useful packages.

In [None]:
import sys
from pathlib import Path

# Work out where the notebook is running by looking at the modules that the
# hosting platform has already imported, then point `repo_path` at the copy of
# the repository available there.
IS_COLAB, IS_KAGGLE = (
    "google.colab" in sys.modules,
    "kaggle_secrets" in sys.modules,
)

if not (IS_KAGGLE or IS_COLAB):
    # Plain local machine.
    repo_path = Path("/home/matias/microstructure-reconstruction")
elif IS_KAGGLE:
    # Kaggle takes precedence over Colab when both markers are present.
    repo_path = Path("../input/microstructure-reconstruction")
else:
    # Colab: the repository lives on the mounted Google Drive.
    from google.colab import drive

    drive.mount("/content/gdrive")
    repo_path = Path("/content/gdrive/MyDrive/microstructure-reconstruction")

# Make the project packages under the repository root importable.
sys.path.append(str(repo_path))

from copy import deepcopy
from importlib import reload

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import pandas as pd
import pytorch_lightning as pl
import torch
from typing import Union, List
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import torchvision.models as pretrained_models
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from tqdm import tqdm

import wandb
from custom_datasets import dataset
from custom_models import models
from tools import dataframe_reformat, inspect_code, plotting, training, wandb_api

# Master switch for wandb logging in this notebook.
log_wandb = True

# Pick the computation device and the matching DataLoader options.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    kwargs = {"num_workers": 2, "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {"num_workers": 4}
print(f"[INFO]: Computation device: {device}")


We initialize a wandb run, that will save our metrics

In [None]:
# Start a wandb run only when logging is enabled. When it is disabled no `run`
# name is created, which is what the later `"run" in locals()` checks rely on.
if log_wandb:
    import wandb

    # Authenticate through the project's helper before opening the run.
    wandb_api.login()
    run = wandb.init(
        project="microstructure-reconstruction",
        group="Autoencoders",
        job_type="test",
    )


Parameters of our run:

In [None]:
# Hyper-parameters of the run. With wandb enabled they are written straight into
# the run's config object; otherwise a plain dict is used.
config = wandb.config if log_wandb else {}

# `run` only exists when a wandb run was started in an earlier cell.
config["job_type"] = run.job_type if "run" in locals() else "test"
config["train_val_split"] = 0.7
config["seed"] = 42
config["batch_size"] = 64
config["learning_rate"] = 0.001
config["device"] = device
config["architecture"] = "Autoencoder"
config["sparse_term"] = 0.1
config["weight_decay"] = 0.0001
config["input_width"] = 88
config["epochs"] = 0
config["nb_image_per_axis"] = 10
config["latent_size"] = 1
# Follow the notebook-level switch instead of a hard-coded True, so that turning
# `log_wandb` off also stops DataModule from fetching wandb artifacts.
config["log_wandb"] = log_wandb
torch.manual_seed(config["seed"])
pl.seed_everything(config["seed"])


We retrieve the dataframe containing the descriptors. This can locally be done on your computer.

In [None]:
class DataModule(pl.LightningDataModule):
    """Lightning data module serving slice images together with their descriptors.

    Dataframes passed to the constructor are used as-is. Missing ones are loaded
    from wandb artifacts when ``config["log_wandb"]`` is truthy, and otherwise
    rebuilt from the ``REV1_600`` files found under ``repo_path``.

    Args:
        config: Mapping read for ``log_wandb``, ``nb_image_per_axis``,
            ``input_width`` and ``batch_size``, plus ``train_val_split`` and
            ``seed`` when the dataframes are built locally.
        repo_path: ``pathlib.Path`` of the repository root; photo paths are
            joined onto it.
        train_df: Optional pre-built training dataframe.
        test_df: Optional pre-built test dataframe.
        train_dataset: Optional ready-made training dataset.
        validation_dataset: Optional ready-made validation dataset.
    """

    def __init__(
        self,
        config,
        repo_path,
        train_df=None,
        test_df=None,
        train_dataset=None,
        validation_dataset=None,
    ):
        super().__init__()
        self.config = config
        self.repo_path = repo_path
        self.train_df = train_df.convert_dtypes() if train_df is not None else None
        self.test_df = test_df.convert_dtypes() if test_df is not None else None
        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset

        # Only resolve the artifacts that are actually needed.
        if self.config["log_wandb"]:
            if self.train_df is None:
                self.training_data_at = wandb.Api().artifact(
                    f"matiasetcheverry/microstructure-reconstruction/train_df:{self.config['nb_image_per_axis']}_images_invariants"
                )
            if self.test_df is None:
                self.test_data_at = wandb.Api().artifact(
                    f"matiasetcheverry/microstructure-reconstruction/test_df:{self.config['nb_image_per_axis']}_images_invariants"
                )

        self.transform = transforms.Compose(
            [
                transforms.GaussianBlur(kernel_size=3, sigma=0.5),
            ]
        )

    def prepare_data(self):
        """Download the wandb artifacts for the dataframes that were not supplied."""
        if self.config["log_wandb"]:
            if self.train_df is None:
                self.training_data_at.download()
            if self.test_df is None:
                self.test_data_at.download()

    def _load_wandb_table(self, artifact):
        """Return the ``fabrics`` table of ``artifact`` as a dataframe.

        Every entry of the ``photos`` column is rewritten so that each path is
        joined onto ``self.repo_path``.
        """
        df = wandb_api.convert_table_to_dataframe(artifact.get("fabrics"))
        df["photos"] = df["photos"].apply(
            func=lambda photo_paths: [
                str(self.repo_path / Path(x)) for x in photo_paths
            ]
        )
        return df

    def _init_df_wandb(self):
        """Fill in the missing dataframes from their wandb artifacts."""
        if self.train_df is None:
            self.train_df = self._load_wandb_table(self.training_data_at)
        if self.test_df is None:
            self.test_df = self._load_wandb_table(self.test_data_at)

    def _init_df_local(self):
        """Fill in the missing dataframes from the local ``REV1_600`` files."""
        if self.train_df is not None and self.test_df is not None:
            # Nothing to build; avoid touching the disk.
            return
        fabrics_df = pd.read_csv(self.repo_path / "REV1_600/fabrics.txt")
        path_to_slices = self.repo_path / "REV1_600/REV1_600Slices"
        fabrics_df["photos"] = fabrics_df["id"].apply(
            func=dataframe_reformat.associate_rev_id_to_its_images,
            args=(path_to_slices, self.config["nb_image_per_axis"]),
        )
        # Drop rows without any photo. `.copy()` makes the result an independent
        # frame so the column assignment below does not write into a slice.
        fabrics_df = fabrics_df[fabrics_df.photos.str.len().gt(0)].copy()
        fabrics_df["photos"] = fabrics_df["photos"].apply(func=sorted)
        # Use this instance's config rather than the notebook-level global.
        train_df, test_df = train_test_split(
            fabrics_df,
            train_size=self.config["train_val_split"],
            random_state=self.config["seed"],
            shuffle=True,
        )
        if self.train_df is None:
            self.train_df = train_df.reset_index(drop=True)
        if self.test_df is None:
            self.test_df = test_df.reset_index(drop=True)

    def init_df(self):
        """Populate ``train_df``/``test_df`` from wandb or from local files."""
        if self.config["log_wandb"]:
            self._init_df_wandb()
        else:
            self._init_df_local()

    def setup(self, stage=None):
        """Build the scaler and the datasets that were not supplied.

        Args:
            stage: Lightning stage name; not used by this implementation.
        """
        if self.train_dataset is None or self.validation_dataset is None:
            self.init_df()

            # Columns 1..-2 are the scaled columns.
            # NOTE(review): presumably column 0 is the id and the last column is
            # `photos`; confirm for the wandb tables.
            # NOTE(review): the scaler also sees the test rows, so the test
            # range leaks into the normalisation; confirm this is intended.
            self.scaler = MinMaxScaler(feature_range=(0, 1))
            self.scaler.partial_fit(self.train_df.iloc[:, 1:-1])
            self.scaler.partial_fit(self.test_df.iloc[:, 1:-1])

            normalized_train_df = deepcopy(self.train_df)
            normalized_train_df.iloc[:, 1:-1] = self.scaler.transform(
                self.train_df.iloc[:, 1:-1]
            )
            normalized_test_df = deepcopy(self.test_df)
            normalized_test_df.iloc[:, 1:-1] = self.scaler.transform(
                self.test_df.iloc[:, 1:-1]
            )

            if self.train_dataset is None:
                self.train_dataset = dataset.SinglePhotoDataset(
                    normalized_train_df,
                    width=self.config["input_width"],
                    transform=self.transform,
                    noise=0,
                )
            if self.validation_dataset is None:
                self.validation_dataset = dataset.SinglePhotoDataset(
                    normalized_test_df,
                    width=self.config["input_width"],
                    transform=self.transform,
                    noise=0,
                )
            # Un-normalised test descriptors.
            self.targets = self.test_df.iloc[:, 1:-1].to_numpy()

    # The loaders below forward the module-level `kwargs` dict (worker count and
    # pin_memory) chosen next to the device selection.

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.config["batch_size"],
            shuffle=True,
            **kwargs,
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(
            self.validation_dataset,
            batch_size=self.config["batch_size"],
            shuffle=False,
            **kwargs,
        )

    def test_dataloader(self):
        """Return the validation loader; it doubles as the test loader."""
        return self.val_dataloader()

    def predict_dataloader(self):
        """Return a loader over the validation images only (targets dropped)."""
        return DataLoader(
            [image for image, _ in self.validation_dataset],
            batch_size=self.config["batch_size"],
            shuffle=False,
            **kwargs,
        )


# Build the data module and run its preparation steps by hand so one batch can
# be inspected before any training happens.
dm = DataModule(config, repo_path)
dm.prepare_data()
dm.setup(stage="fit")
# `first_batch` is reused by later cells as a fixed validation batch.
first_batch = next(iter(dm.val_dataloader()))
print("Nb of descriptors:", len(first_batch[1][0]))
print("Nb batch in dataset:", len(dm.train_dataloader()))
print("Size of a batch:", len(first_batch[0]))
images, labels = first_batch[0], first_batch[1]
print("Image shape:", images[0].shape)
# Tile the batch into one image; make_grid returns channels first, imshow
# expects channels last, hence the transpose.
grid = utils.make_grid(images)
fig = plt.figure(figsize=(90, 30))
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.show()


# Network definition

The next step is to define our model. This model is inspired by VGG11:

* we define several convolutional blocks.
* each of these blocks is a sequence of:
  * a convolutional layer with `kernel_size=3, padding=1`
  * an activation function, here the `ReLU`
  * a max pooling layer with `kernel_size=2, stride=2`, which aims at reducing the size of the convolutional layers' output

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder from a single-channel square image to a latent vector.

    Three conv / max-pool / LeakyReLU / batch-norm stages are followed by a
    flatten and one linear layer of size ``config["latent_size"]``.

    Args:
        config: Mapping providing ``input_width`` (side of the square input
            image) and ``latent_size`` (dimension of the encoding).
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=5, stride=2, padding=0),
            nn.MaxPool2d(2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=0),
            nn.MaxPool2d(2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=0),
            nn.MaxPool2d(2),
            nn.LeakyReLU(),
            nn.BatchNorm2d(32),
            nn.Flatten(),
            #             nn.Dropout(p=0.3),
        )

        # Infer the flattened feature size with one dummy forward pass. The pass
        # runs in eval mode and without autograd so that it neither updates the
        # BatchNorm running statistics nor builds a graph; a single sample is
        # enough because only the shape is needed.
        input_width = config["input_width"]
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 1, input_width, input_width)
            flat_features = self.model(probe).shape[1]
        self.model.train(was_training)

        self.output = nn.Sequential(
            nn.Linear(
                flat_features,
                config["latent_size"],
            ),
            #             nn.Tanh(),
        )

    def forward(self, img):
        """Encode ``img`` and return a ``(batch, latent_size)`` tensor."""
        x = self.model(img)
        return self.output(x)


# Smoke test: encode the stored validation batch and report the encoder size.
descriptors = first_batch[1]
images = first_batch[0]
d = Encoder(config)
print(d(images).shape)
total_params = sum(p.numel() for p in d.parameters())
print(f"[INFO]: {total_params:,} total parameters.")
# Disabled check for a `Generator` model that is not defined in this notebook.
# g = Generator(config)
# print(g(descriptors).shape)
# total_params = sum(p.numel() for p in g.parameters())
# print(f"[INFO]: {total_params:,} total parameters.")


In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder from a latent vector to a one-channel image.

    The latent vector is projected to 128 values, viewed as a ``32 x 2 x 2``
    feature map, and expanded by three transposed-conv / upsample / LeakyReLU /
    batch-norm stages before a final transposed convolution and a sigmoid.

    Args:
        config: Mapping providing ``latent_size``.
        latent_dim: Accepted for interface compatibility; not read.
    """

    @staticmethod
    def _upsampling_stage(in_channels, out_channels, kernel_size, stride):
        """Return the four layers of one expansion stage, in order."""
        return (
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=0,
            ),
            nn.Upsample(scale_factor=(2, 2)),
            nn.LeakyReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def __init__(self, config, latent_dim=28):
        super().__init__()
        self.linear = nn.Linear(config["latent_size"], 128)

        # (in_channels, out_channels, kernel_size, stride) of each stage.
        stage_specs = (
            (32, 32, 3, 1),
            (32, 64, 3, 1),
            (64, 128, 5, 2),
        )
        layers = []
        for spec in stage_specs:
            layers.extend(self._upsampling_stage(*spec))
        # Final projection to a single channel, squashed into [0, 1].
        layers.append(nn.ConvTranspose2d(128, 1, kernel_size=3, stride=1, padding=0))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent vectors ``z`` into images."""
        feature_map = self.linear(z).reshape(-1, 32, 2, 2)
        return self.model(feature_map)


# Smoke test: decode a random latent batch and report the decoder size.
# Note that `d` is rebound here from the encoder to the decoder.
descriptors = first_batch[1]
images = first_batch[0]
d = Decoder(config)
print(d(torch.rand((config["batch_size"], config["latent_size"]))).shape)
total_params = sum(p.numel() for p in d.parameters())
print(f"[INFO]: {total_params:,} total parameters.")


In [None]:
class Autoencoder(pl.LightningModule):
    """Lightning module chaining ``Encoder`` and ``Decoder`` for image reconstruction.

    Args:
        config: Mapping forwarded to both sub-networks and read for
            ``learning_rate`` and ``weight_decay``; its ``epochs`` entry is
            incremented at the end of every training epoch.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.config = config

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    def forward(self, x):
        """Return the reconstruction of ``x``."""
        return self.decoder(self.encoder(x))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.config["learning_rate"],
            weight_decay=self.config["weight_decay"],
        )

    def compute_metrics(self, img):
        """Compute the reconstruction metrics for a batch of images.

        Returns:
            Dict with ``loss`` and ``mse`` (both the mean squared error),
            ``mae_encoding`` (mean absolute value of the encoding) and ``bce``
            (binary cross-entropy between reconstruction and input).
        """
        latent = self.encoder(img)
        reconstruction = self.decoder(latent)
        # The sparsity penalty `config["sparse_term"] * mae_encoding` is
        # currently not added to the loss, so the loss equals the MSE.
        mse = nn.functional.mse_loss(reconstruction, img, reduction="mean")
        return {
            "loss": mse,
            "mse": mse,
            "mae_encoding": latent.abs().mean(),
            "bce": nn.functional.binary_cross_entropy(
                reconstruction, img, reduction="mean"
            ),
        }

    def _prefixed_metrics(self, prefix, img):
        """Return ``compute_metrics(img)`` with ``prefix`` prepended to every key."""
        return {prefix + name: value for name, value in self.compute_metrics(img).items()}

    def training_step(self, batch, batch_idx):
        """Log the epoch-level training metrics and return the training loss."""
        metrics = self._prefixed_metrics("train_", batch[0])
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)
        return metrics["train_loss"]

    def validation_step(self, batch, batch_idx):
        """Log and return the epoch-level validation metrics."""
        metrics = self._prefixed_metrics("val_", batch[0])
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)
        return metrics

    def training_epoch_end(self, outputs):
        """Count the finished epoch in the shared config."""
        self.config["epochs"] += 1


# Instantiate the full autoencoder, report its size, and run one training step
# by hand on the stored batch as a smoke test (no trainer is attached yet).
model = Autoencoder(config)
print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f"[INFO]: {total_params:,} total parameters.")
model.training_step(first_batch, 1)


# Checkpoint

We add 2 checkpoints to our training:

* one for saving our model every time we reach a new minimum of the validation loss
* one for saving the model and data module scripts

In [None]:
# Everything is written into the wandb run directory when a run exists,
# otherwise into a local scratch folder.
output_dir = run.dir if "run" in locals() else "tmp/"

# Keep the checkpoint with the lowest validation BCE, plus the latest one.
model_checkpoint = pl.callbacks.model_checkpoint.ModelCheckpoint(
    dirpath=output_dir,
    filename="{epoch}-{val_bce:.3f}",
    monitor="val_bce",
    mode="min",
    verbose=True,
    save_last=True,
)

# Project callbacks: one is given the output directory, the other the fixed
# validation batch and a one-epoch logging period.
script_checkpoint = training.ScriptCheckpoint(
    dirpath=output_dir,
)
images_callback = training.AutoencoderGeneratedImagesCallback(
    images=first_batch[0].to(device), log_every_n_epochs=1
)
callbacks = [script_checkpoint, images_callback]
log = None

# Model checkpoints are meant for "train" jobs, but saving is currently forced
# on for every job type. The override is an explicit flag (instead of `or True`
# inside the condition) so it is visible and can be switched off in one place.
force_save_models = True
if force_save_models or config["job_type"] == "train":
    callbacks.append(model_checkpoint)
    print("[INFO]: saving models.")
else:
    print("[INFO]: not saving models.")
# Debug jobs pass log="all" to the logger's watch call.
if config["job_type"] == "debug":
    log = "all"


# Training

We then train our model.

In [None]:
# Attach a wandb logger when logging is enabled; `log` was set together with
# the callbacks (None, or "all" for debug jobs).
if config["log_wandb"]:
    wandb_logger = pl.loggers.WandbLogger()
    wandb_logger.watch(model, log=log, log_graph=True)
else:
    # NOTE(review): `logger=None` is then passed to the Trainer; confirm how
    # the installed Lightning version interprets None.
    wandb_logger = None
trainer = pl.Trainer(
    max_epochs=400,
    callbacks=callbacks,
    logger=wandb_logger,
    devices="auto",
    accelerator="auto",
    #     limit_train_batches=0.3,
    #     limit_val_batches=1,
    #     log_every_n_steps=1,
)
# Train with the data module built earlier in the notebook.
trainer.fit(
    model,
    datamodule=dm,
)


In [None]:
# Close the wandb run. No `run` exists when logging was disabled, so guard the
# call the same way the rest of the notebook does.
if "run" in locals():
    run.finish()

In [None]:
# Images of one batch drawn from the (shuffled) training loader; use
# `first_batch[0]` instead to look at the fixed validation batch.
targets = next(iter(dm.train_dataloader()))[0]
grid_img = utils.make_grid(targets, normalize=False, pad_value=1, padding=1)
plt.figure(figsize=(30, 30))
# make_grid returns channels first, imshow expects channels last.
plt.imshow(grid_img.cpu().numpy().transpose(1, 2, 0))
# The title names the loader the images really come from.
plt.title("Target images from a training batch", fontdict={"fontsize": 70})
plt.tight_layout()

In [None]:
# Reconstruct the target batch in inference mode: eval() makes BatchNorm use
# its running statistics instead of updating them, and no_grad() avoids
# building an autograd graph. The input is moved to wherever the model
# currently lives rather than to the global `device`.
model.eval()
with torch.no_grad():
    fake_img = model(targets.to(model.device))
outputs = fake_img  # optional binarisation: (-fake_img + 1 > 0.6).float()
print(outputs.shape)
grid_img = utils.make_grid(outputs, normalize=False, pad_value=1, padding=1)
plt.figure(figsize=(30, 30))
# make_grid returns channels first, imshow expects channels last.
plt.imshow(grid_img.cpu().numpy().transpose(1, 2, 0))
# `targets` comes from the training loader, so say so in the title.
plt.title("Reconstructed images from a training batch", fontdict={"fontsize": 70})
plt.tight_layout()