# VGG_Preprocessing
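DESCRIPTION: This notebook contains data-preprocessing scripts for VGG model training. It computes the per-channel mean and the PCA (SVD of the channel covariance matrix) of the CIFAR-10 training images.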

@author: Jian Zhong

In [None]:
## include modules 
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision.transforms import v2

In [None]:
## load data set

## convert each image into a float32 torch tensor; scale = True rescales integer
## pixel values into [0, 1]
data_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32,scale = True),
])

## NOTE: dataset_root_dir defaults to the author's local path. Set the CIFAR10_ROOT
## environment variable (or edit the default below) to point it at the desired
## data location on your computer.
dataset_root_dir = os.environ.get("CIFAR10_ROOT", r"E:\Python\DataSet\TorchDataSet\CIFAR10")

## create training data set (downloaded into dataset_root_dir if not present)
train_data = torchvision.datasets.CIFAR10(
    root = dataset_root_dir,
    train = True,
    download = True,
    transform=data_transform,
)

## create test data set
test_data = torchvision.datasets.CIFAR10(
    root = dataset_root_dir,
    train = False,
    download = True,
    transform=data_transform,
)

print(f"train_data length: {len(train_data)}")
print(f"test_data length: {len(test_data)}")

In [None]:
## create data loader
## Both loaders keep shuffle = False so the statistics computed below always
## visit the samples in the same, deterministic order.

train_batch_size = 128
test_batch_size = 128

def _make_loader(dataset, batch_size):
    """Wrap `dataset` in an unshuffled DataLoader yielding batches of `batch_size`."""
    return torch.utils.data.DataLoader(dataset,
                                       batch_size = batch_size,
                                       shuffle = False)

train_dataloader = _make_loader(train_data, train_batch_size)
test_dataloader = _make_loader(test_data, test_batch_size)

In [None]:
## calculate the averaged channel values across the entire data set
## Each image is first reduced to its per-channel spatial mean; averaging those
## over all images gives the global per-channel mean, since every image in a
## batch has the same height and width.

input_dataloader = train_dataloader
nof_batchs = len(input_dataloader)

# one tensor of per-image channel means (batch, channel, 1, 1) per batch
avg_ch_vals = []
for inputs, labels in input_dataloader:
    avg_ch_vals.append(inputs.mean(dim = (-1, -2), keepdim = True))

# stack every image along the batch axis, then average over images
avg_ch_vals = torch.cat(avg_ch_vals, dim = 0)
avg_ch_val = avg_ch_vals.mean(dim = 0)

print("result size = ", avg_ch_val.size(), sep = "\n")
print("result val = ", repr(avg_ch_val), sep = "\n")

In [None]:
## SVD decomposition for covariance matrix of image channels across all the pixels
## Every pixel is treated as one sample of a C-dimensional channel vector. Instead
## of concatenating all pixels of the data set (hundreds of MB for CIFAR-10), the
## covariance is accumulated batch by batch from the running sums of x and x x^T,
## so only a (C, C) matrix is kept. Accumulation is done in float64 to limit
## round-off over tens of millions of samples.

input_dataloader = train_dataloader
nof_batchs = len(input_dataloader)

ch_sum = 0.0        # running sum of channel vectors, becomes shape (C,)
ch_outer_sum = 0.0  # running sum of channel outer products, becomes shape (C, C)
nof_pixels = 0      # total number of pixel samples seen

for inputs, labels in input_dataloader:
    # swap channel and batch axis and flatten (batch, image height, image width):
    # (B, C, H, W) -> (C, B*H*W), one column per pixel
    cur_ch_vecs = torch.flatten(torch.swapaxes(inputs, 0, 1), start_dim = 1, end_dim = -1).to(torch.float64)
    ch_sum = ch_sum + cur_ch_vecs.sum(dim = -1)
    ch_outer_sum = ch_outer_sum + cur_ch_vecs @ cur_ch_vecs.T
    nof_pixels += cur_ch_vecs.size(-1)

ch_mean = ch_sum / nof_pixels
# unbiased estimate (divide by N - 1), matching torch.cov's default correction = 1;
# cast back to float32 to match the dtype produced by data_transform
ch_cov = ((ch_outer_sum - nof_pixels * torch.outer(ch_mean, ch_mean)) / (nof_pixels - 1)).to(torch.float32)

U, S, Vh = torch.linalg.svd(ch_cov, full_matrices = True)

## Each column of U is a channel PCA eigenvector
## S contains the eigenvalues corresponding to those eigenvectors (for this
## symmetric positive semi-definite covariance matrix the singular values are the
## eigenvalues), sorted in descending order

print("U:")
print(repr(U))
print("S:")
print(S)
print("Vh:")
print(Vh)