# Network per channel

As we dive more precisely into the topic, we now create a more robust dataset:

* each input is an image of size `(C, W, H)`. Each channel holds a grayscale sliced image, so:
  * the first `C/3` channels are sliced images along the x axis
  * the following `C/3` channels are sliced images along the y axis
  * the last `C/3` channels are sliced images along the z axis
* each output is a list of fabric descriptors

If we take `C=3`, we can use a pretrained VGG model. Indeed, this model is trained on RGB images, which have 3 channels.

# Importing the dataframe

First, we initialize wandb. It is a tool that lets us store the losses and retrieve the dataframe. Alternatively, you can access the dataframe directly from your local computer.

In [None]:
!pip install wandb --upgrade

We import all the useful packages.

In [None]:
import sys
from pathlib import Path

# Detect the hosting environment from the modules that are already loaded.
IS_COLAB = "google.colab" in sys.modules
IS_KAGGLE = "kaggle_secrets" in sys.modules

# Start from the local checkout; a hosted environment overrides it below.
repo_path = Path("/home/matias/microstructure-reconstruction")
if IS_KAGGLE:
    # Kaggle is checked first, so it wins if both flags are set.
    repo_path = Path("../input/microstructure-reconstruction")
elif IS_COLAB:
    from google.colab import drive

    drive.mount("/content/gdrive")
    repo_path = Path("/content/gdrive/MyDrive/microstructure-reconstruction")

# Put the repository root on the import path for the project imports below.
sys.path.append(str(repo_path))

from copy import deepcopy
from importlib import reload

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import pandas as pd
import pytorch_lightning as pl
import torch
from typing import Union, List
import torch.nn as nn
import torch.optim as optim
import torchmetrics
import torchvision.models as pretrained_models
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from tqdm import tqdm

import wandb
from custom_datasets import dataset
from custom_models import models
from tools import dataframe_reformat, inspect_code, plotting, training, wandb_api

# Master switch for Weights & Biases logging in this notebook.
log_wandb = True

# Pick the computation device and the matching DataLoader keyword arguments.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    kwargs = {"num_workers": 2, "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {"num_workers": 4}
print(f"[INFO]: Computation device: {device}")


We initialize a wandb run that will save our metrics.

In [None]:
if log_wandb:
    import wandb

    # Authenticate through the project helper before opening the run.
    wandb_api.login()

    # Settings identifying this run inside the wandb project.
    run_settings = {
        "project": "microstructure-reconstruction",
        "group": "Gans",
        "job_type": "test",
    }
    run = wandb.init(**run_settings)


Parameters of our run:

In [None]:
# Use the live wandb config when logging is enabled, otherwise the project's
# local stand-in object.
config = wandb.config if log_wandb else wandb_api.Config()

# `run` only exists when a wandb run was opened; without one, fall back to the
# job type that run is created with ("test").
config["job_type"] = run.job_type if log_wandb else "test"
config["train_val_split"] = 0.7
config["seed"] = 42
config["batch_size"] = 32
config["learning_rate_generator"] = 0.00001
config["learning_rate_discriminator"] = 0.0001
config["device"] = device
config["architecture"] = "GANS"
config["input_width"] = 64
config["epochs"] = 0
config["nb_image_per_axis"] = 1
# Mirror the notebook-level switch so that code reading `config["log_wandb"]`
# agrees with the cells that test `log_wandb` directly.
config["log_wandb"] = log_wandb
config["beta1"] = 0.5
config["beta2"] = 0.9
config["n_critic"] = 1

# Seed the random number generators for reproducibility.
torch.manual_seed(config["seed"])
pl.seed_everything(config["seed"])


We retrieve the dataframe containing the descriptors. This can locally be done on your computer.

In [None]:
class DataModule(pl.LightningDataModule):
    """Lightning data module pairing stacked slice images with descriptors.

    The train/test dataframes are taken from the constructor when given.
    Otherwise they are loaded from wandb artifacts (``config["log_wandb"]``
    truthy) or rebuilt from the files under ``repo_path / "REV1_600"``.

    Args:
        config: Mapping-like run configuration. Keys read by this class:
            ``log_wandb``, ``nb_image_per_axis``, ``train_val_split``,
            ``seed``, ``input_width`` and ``batch_size``.
        repo_path: ``pathlib.Path`` to the repository root; photo paths are
            resolved against it.
        train_df: Optional pre-built training dataframe.
        test_df: Optional pre-built test dataframe.
        train_dataset: Optional pre-built training dataset.
        validation_dataset: Optional pre-built validation dataset.
    """

    def __init__(
        self,
        config,
        repo_path,
        train_df=None,
        test_df=None,
        train_dataset=None,
        validation_dataset=None,
    ):
        super().__init__()
        self.config = config
        self.repo_path = repo_path
        self.train_df = train_df.convert_dtypes() if train_df is not None else None
        self.test_df = test_df.convert_dtypes() if test_df is not None else None
        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset

        # Only resolve the wandb artifacts for the dataframes that are missing.
        if self.config["log_wandb"]:
            if self.train_df is None:
                self.training_data_at = wandb.Api().artifact(
                    f"matiasetcheverry/microstructure-reconstruction/train_df:{self.config['nb_image_per_axis']}_images_invariants"
                )
            if self.test_df is None:
                self.test_data_at = wandb.Api().artifact(
                    f"matiasetcheverry/microstructure-reconstruction/test_df:{self.config['nb_image_per_axis']}_images_invariants"
                )

        self.transform = transforms.Compose(
            [
                transforms.GaussianBlur(kernel_size=3, sigma=0.5),
            ]
        )

    def prepare_data(self):
        """Download the wandb artifacts backing any dataframe not supplied."""
        if self.config["log_wandb"]:
            if self.train_df is None:
                self.training_data_at.download()
            if self.test_df is None:
                self.test_data_at.download()

    def _load_wandb_table(self, artifact):
        """Convert an artifact's "fabrics" table to a dataframe.

        Every entry of the ``photos`` column is rewritten as a list of string
        paths prefixed with ``self.repo_path``.
        """
        frame = wandb_api.convert_table_to_dataframe(artifact.get("fabrics"))
        frame["photos"] = frame["photos"].apply(
            func=lambda photo_paths: [
                str(self.repo_path / Path(x)) for x in photo_paths
            ]
        )
        return frame

    def _init_df_wandb(self):
        """Fill the missing dataframes from the wandb artifacts."""
        if self.train_df is None:
            self.train_df = self._load_wandb_table(self.training_data_at)
        if self.test_df is None:
            self.test_df = self._load_wandb_table(self.test_data_at)

    def _init_df_local(self):
        """Fill the missing dataframes from the local ``REV1_600`` files."""
        fabrics_df = pd.read_csv(self.repo_path / "REV1_600/fabrics.txt")
        path_to_slices = self.repo_path / "REV1_600/REV1_600Slices"
        fabrics_df["photos"] = fabrics_df["id"].apply(
            func=dataframe_reformat.associate_rev_id_to_its_images,
            args=(path_to_slices, self.config["nb_image_per_axis"]),
        )
        # Keep only the rows that have at least one photo. The explicit copy
        # makes the column assignment below act on an independent frame
        # instead of a slice of the original.
        fabrics_df = fabrics_df[fabrics_df.photos.str.len().gt(0)].copy()
        fabrics_df["photos"] = fabrics_df["photos"].apply(func=lambda x: sorted(x))
        # Read the split settings from this instance's config rather than
        # from a module-level variable.
        train_df, test_df = train_test_split(
            fabrics_df,
            train_size=self.config["train_val_split"],
            random_state=self.config["seed"],
            shuffle=True,
        )
        if self.train_df is None:
            self.train_df = train_df.reset_index(drop=True)
        if self.test_df is None:
            self.test_df = test_df.reset_index(drop=True)

    def init_df(self):
        """Load the dataframes from wandb or from disk, depending on config."""
        if self.config["log_wandb"]:
            self._init_df_wandb()
        else:
            self._init_df_local()

    def setup(self, stage=None):
        """Build the scaler and the datasets that were not supplied.

        Nothing happens when both datasets were passed to the constructor.

        Args:
            stage: Lightning stage name; not used by this implementation.
        """
        if self.train_dataset is None or self.validation_dataset is None:
            self.init_df()

            # Every column except the first and the last is scaled
            # (presumably the sample id and the photo list -- TODO confirm).
            # The scaler is fitted on both partitions, so the (0, 1) range
            # covers train and test values alike.
            self.scaler = MinMaxScaler(feature_range=(0, 1))
            self.scaler.partial_fit(self.train_df.iloc[:, 1:-1])
            self.scaler.partial_fit(self.test_df.iloc[:, 1:-1])

            normalized_train_df = deepcopy(self.train_df)
            normalized_train_df.iloc[:, 1:-1] = self.scaler.transform(
                self.train_df.iloc[:, 1:-1]
            )
            normalized_test_df = deepcopy(self.test_df)
            normalized_test_df.iloc[:, 1:-1] = self.scaler.transform(
                self.test_df.iloc[:, 1:-1]
            )

            if self.train_dataset is None:
                self.train_dataset = dataset.NWidthStackedPhotosDataset(
                    normalized_train_df,
                    width=self.config["input_width"],
                    nb_image_per_axis=self.config["nb_image_per_axis"],
                    transform=self.transform,
                    noise=0,
                )
            if self.validation_dataset is None:
                self.validation_dataset = dataset.NWidthStackedPhotosDataset(
                    normalized_test_df,
                    width=self.config["input_width"],
                    nb_image_per_axis=self.config["nb_image_per_axis"],
                    transform=self.transform,
                    noise=0,
                )
            # Un-normalized descriptor values of the test partition.
            self.targets = self.test_df.iloc[:, 1:-1].to_numpy()

    def train_dataloader(self):
        """Return a shuffled loader over the training dataset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.config["batch_size"],
            shuffle=True,
            **kwargs,
        )

    def val_dataloader(self):
        """Return an ordered loader over the validation dataset."""
        return DataLoader(
            self.validation_dataset,
            batch_size=self.config["batch_size"],
            shuffle=False,
            **kwargs,
        )

    def test_dataloader(self):
        """Return the validation loader; no separate test set is kept."""
        return self.val_dataloader()

    def predict_dataloader(self):
        """Return an ordered loader over the validation images only."""
        return DataLoader(
            [image for image, _ in self.validation_dataset],
            batch_size=self.config["batch_size"],
            shuffle=False,
            **kwargs,
        )


dm = DataModule(config, repo_path)
dm.prepare_data()
dm.setup(stage="fit")

# Pull a single training batch to sanity-check the data pipeline.
first_batch = next(iter(dm.train_dataloader()))
images, labels = first_batch[0], first_batch[1]
print("Nb of descriptors:", len(labels[0]))
print("Nb batch in dataset:", len(dm.train_dataloader()))
print("Size of a batch:", len(images))
print("Image shape:", images[0].shape)

# Tile the batch into one image and move channels last for matplotlib.
grid = utils.make_grid(images)
fig = plt.figure(figsize=(90, 30))
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.show()


In [None]:
# Upload the first training batch as a reference image. `wandb.log` needs an
# active run, which only exists when `log_wandb` is enabled.
if log_wandb:
    image = wandb.Image(
        first_batch[0],
        caption="First batch target",
    )
    wandb.log({"generated_images": image})

# Network definition

The next step is to define our model. This model is inspired by VGG11:

* we define several convolutional blocks.
* each of these blocks is a sequence of:
  * a convolutional layer with `kernel_size=3, padding=1`
  * an activation function, here the `ReLU`
  * a max pooling layer with `kernel_size=2, stride=2`, which aims at reducing the size of the convolutional layers

In [None]:
class Generator(nn.Module):
    """Map a latent vector to an image through five upsampling stages.

    The latent vector goes through a linear layer, is reshaped to a 1x1
    feature map, and is then upsampled by stride-2 transposed convolutions:
    1 -> 3 -> 7 -> 15 -> 32 -> 64 pixels per side. The output is therefore
    always 64x64, whatever ``config["input_width"]`` says.

    The last stage applies ReLU followed by Tanh, so outputs are non-negative
    and at most 1.

    Args:
        config: Mapping-like run configuration; ``nb_image_per_axis`` sets the
            number of output channels (``3 * nb_image_per_axis``), matching
            the input channels the discriminator expects.
        latent_dim: Size of the input vector.
    """

    def __init__(self, config, latent_dim=28):
        super().__init__()

        # Channel count entering each stage; the last entry is the image depth.
        out_channels = 3 * config["nb_image_per_axis"]
        filters = [400, 512, 256, 64, 32, out_channels]

        self.linear_layer = nn.Sequential(
            nn.Linear(latent_dim, filters[0]),
            nn.ReLU(),
        )

        # One entry per stage: (transposed-conv kernel, transposed-conv
        # padding, activations appended after the batch norm).
        # NOTE(review): the fourth stage has no activation, unlike the
        # others -- kept as in the original architecture; confirm intent.
        stages = [
            (3, 0, (nn.ReLU,)),
            (3, 0, (nn.ReLU,)),
            (3, 0, (nn.ReLU,)),
            (4, 0, ()),
            (4, 1, (nn.ReLU, nn.Tanh)),
        ]

        layers = []
        for (kernel_size, padding, activations), in_ch, out_ch in zip(
            stages, filters, filters[1:]
        ):
            layers += [
                # Upsample by 2 while keeping the channel count...
                nn.ConvTranspose2d(
                    in_ch,
                    in_ch,
                    stride=2,
                    kernel_size=kernel_size,
                    padding=padding,
                ),
                # ...then change the channel count at constant resolution.
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
            ]
            layers += [activation() for activation in activations]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from latent vectors.

        Args:
            z: Tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, 3 * nb_image_per_axis, 64, 64)``.
        """
        z = self.linear_layer(z)
        # Append two singleton spatial dimensions: (batch, C) -> (batch, C, 1, 1).
        z = torch.unsqueeze(torch.unsqueeze(z, -1), -1)
        img = self.model(z)
        return img

class Discriminator(nn.Module):
    """Convolutional classifier scoring images with a probability in [0, 1].

    Four convolution / batch-norm / LeakyReLU stages (the last three with
    stride 2) are followed by a flatten, a single linear unit and a sigmoid.

    Args:
        config: Mapping-like run configuration. Keys read:
            ``nb_image_per_axis`` (input channels are ``3 *`` this value) and
            ``input_width`` (side length of the square input images).
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()

        in_channels = 3 * config["nb_image_per_axis"]
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(),
            nn.Conv2d(512, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(),
            nn.Conv2d(256, 126, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(126),
            nn.LeakyReLU(),
            nn.Conv2d(126, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(),
            nn.Flatten(),
        )

        # Infer the flattened feature size by pushing one dummy image through
        # the convolutional stack. Eval mode plus no_grad keeps this probe
        # from touching the batch-norm running statistics or building an
        # autograd graph; a single image is enough to read the shape.
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            probe = torch.zeros(
                1,
                in_channels,
                config["input_width"],
                config["input_width"],
            )
            length = self.model(probe).shape[1]
        self.model.train(was_training)

        self.output = nn.Sequential(
            nn.Linear(
                length,
                1,
            ),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor of shape
                ``(batch, 3 * nb_image_per_axis, input_width, input_width)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        x = self.model(img)
        return self.output(x)


class WGANGP(pl.LightningModule):
    """Adversarial pair of a descriptor-conditioned generator and a discriminator.

    Despite the name, the discriminator is currently trained with binary
    cross-entropy on its sigmoid output and no gradient penalty is applied.
    The generator minimises ``-mean(D(G(descriptors)))``.
    ``compute_gradient_penalty`` is kept available for a WGAN-GP critic loss
    of the form ``-mean(D(real)) + mean(D(fake)) + lambda * penalty``.

    Args:
        config: Mapping-like run configuration. Keys read here:
            ``learning_rate_generator``, ``learning_rate_discriminator``,
            ``beta1``, ``beta2`` and ``n_critic``; the sub-networks read
            their own keys.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.config = config

        self.generator = Generator(config, latent_dim=28)
        self.discriminator = Discriminator(config)

    def forward(self, z):
        """Generate images from the input vectors ``z``."""
        return self.generator(z)

    def compute_gradient_penalty(self, real_samples, fake_samples):
        """Calculate the gradient penalty loss for WGAN GP.

        Args:
            real_samples: Batch of real images, shape ``(batch, C, H, W)``.
            fake_samples: Batch of generated images with the same shape.

        Returns:
            Scalar tensor: mean squared deviation from 1 of the L2 norm of
            the discriminator's gradient at random interpolates.
        """
        # One interpolation weight per sample, created directly on the
        # samples' device and dtype.
        alpha = torch.rand(
            real_samples.size(0),
            1,
            1,
            1,
            device=real_samples.device,
            dtype=real_samples.dtype,
        )
        # Random interpolation between real and fake samples.
        interpolates = (
            alpha * real_samples + (1 - alpha) * fake_samples
        ).requires_grad_(True)
        d_interpolates = self.discriminator(interpolates)
        # Gradient of the discriminator output w.r.t. the interpolates;
        # create_graph keeps it differentiable so the penalty can be trained.
        gradients = torch.autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.reshape(gradients.size(0), -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    def adversarial_loss(self, y_hat, y):
        """Return the mean binary cross-entropy between ``y_hat`` and ``y``."""
        return nn.functional.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Run one optimisation step for the optimizer selected by Lightning.

        Args:
            batch: Pair ``(imgs, descriptors)``; the descriptors are fed to
                the generator as its input vector.
            batch_idx: Index of the batch (unused).
            optimizer_idx: 0 for the generator, 1 for the discriminator.

        Returns:
            ``{"loss": ...}`` for the selected optimizer, or ``None`` for any
            other index.
        """
        imgs, descriptors = batch

        # Generator step: push the discriminator's score on fakes upwards.
        if optimizer_idx == 0:
            g_loss = -torch.mean(self.discriminator(self(descriptors)))
            self.log_dict(
                {"g_loss": g_loss},
                on_step=False,
                on_epoch=True,
            )
            return {"loss": g_loss}

        # Discriminator step: real images labelled 1, generated images 0.
        if optimizer_idx == 1:
            # The generator is not updated here, so cut its graph.
            fake_imgs = self(descriptors).detach()
            real_validity = self.discriminator(imgs)
            fake_validity = self.discriminator(fake_imgs)

            real_loss = self.adversarial_loss(
                real_validity, torch.ones_like(real_validity)
            )
            fake_loss = self.adversarial_loss(
                fake_validity, torch.zeros_like(fake_validity)
            )
            d_loss = (real_loss + fake_loss) / 2

            metrics = {
                "real_discriminator": real_loss,
                "fake_discriminator": fake_loss,
                "d_loss": d_loss,
            }
            self.log_dict(
                metrics,
                on_step=False,
                on_epoch=True,
            )
            return {"loss": d_loss}

    def configure_optimizers(self):
        """Create one Adam optimizer per sub-network.

        Returns:
            Two optimizer dictionaries; the generator gets frequency 1 and
            the discriminator ``config["n_critic"]``.
        """
        betas = (self.config["beta1"], self.config["beta2"])
        opt_g = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.config["learning_rate_generator"],
            betas=betas,
        )
        opt_d = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.config["learning_rate_discriminator"],
            betas=betas,
        )
        return (
            {"optimizer": opt_g, "frequency": 1},
            {"optimizer": opt_d, "frequency": self.config["n_critic"]},
        )

model = WGANGP(config)
print(model)

# Add up the element counts of every parameter tensor in the model.
total_params = 0
for parameter in model.parameters():
    total_params += parameter.numel()
print(f"[INFO]: {total_params:,} total parameters.")

# Smoke test: one random 28-dimensional input; the cell shows the output shape.
model(torch.rand(1, 28)).shape


# Checkpoint

We add 2 checkpoints to our training:

* one for saving our model every time we have a minimum in the validation loss 
* one for saving the model's and data module script

In [None]:
# Directory receiving the checkpoints: the wandb run folder when a run
# exists, a local temporary folder otherwise.
checkpoint_dir = run.dir if "run" in locals() else "tmp/"

script_checkpoint = training.ScriptCheckpoint(
    dirpath=checkpoint_dir,
)
images_callback = training.GeneratedImagesCallback(
    descriptors=first_batch[1].to(device), log_every_n_epochs=10
)
callbacks = [script_checkpoint, images_callback]
log = None
if config["job_type"] == "train":
    # The model defines no validation step and logs no "val_loss", so the
    # checkpoint cannot monitor a validation metric. Without `monitor`, it
    # simply keeps the most recent weights.
    model_checkpoint = pl.callbacks.ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename="{epoch}",
        verbose=True,
        save_last=True,
    )
    callbacks.append(model_checkpoint)
    print("[INFO]: saving models.")
else:
    print("[INFO]: not saving models.")
if config["job_type"] == "debug":
    # Passed to `wandb_logger.watch` in the training cell.
    log = "all"



# Training

We then train our model.

In [None]:
# Attach a wandb logger only when logging is enabled in the config.
wandb_logger = None
if config["log_wandb"]:
    wandb_logger = pl.loggers.WandbLogger()
    wandb_logger.watch(model, log=log, log_graph=True)

# Debugging knobs left disabled: limit_train_batches=0.3,
# limit_val_batches=1, log_every_n_steps=1.
trainer_options = {
    "max_epochs": 4000,
    "callbacks": callbacks,
    "logger": wandb_logger,
    "devices": "auto",
    "accelerator": "auto",
}
trainer = pl.Trainer(**trainer_options)
trainer.fit(model, datamodule=dm)


In [None]:
# Generate images for the first 60 validation descriptors. The input is sent
# to the device the model currently lives on rather than the global `device`.
# NOTE(review): the module may no longer be on `device` once `trainer.fit`
# returns -- confirm against the Lightning version in use.
# no_grad avoids building an autograd graph for a display-only pass.
with torch.no_grad():
    outputs = model(dm.validation_dataset[:60][1].to(model.device))
print(outputs.shape)
grid_img = utils.make_grid(outputs, normalize=True, pad_value=1, padding=1)
plt.figure(figsize=(30, 30))
plt.imshow(grid_img.permute(1, 2, 0).cpu())

In [None]:
# `run` only exists when a wandb run was opened.
if log_wandb:
    run.finish()

In [None]:
# Generate one image from the first validation descriptor and show it
# channels-last. The input follows the model's device, and the result is
# brought back to the CPU before the NumPy conversion.
with torch.no_grad():
    single_descriptor = torch.unsqueeze(dm.validation_dataset[0][1], 0)
    outputs = torch.permute(
        model(single_descriptor.to(model.device)), (0, 2, 3, 1)
    )
plt.figure(figsize=(10, 10))
plt.imshow(np.squeeze(outputs.detach().cpu().numpy()))