# Import modules

In [5]:
# Data Handlers
import pandas as pd
import numpy as np
from PIL import Image
from PIL import ImageOps

# Pytorch
import torch
import torch.nn as nn  # NN; networks (CNN, RNN, losses)
import torch.optim as optim  # Optimizers (Adam, Adadelta, Adagrad)
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset  # Dataset manager

# Other
from tqdm import tqdm
import os
from pathlib import Path

# Graphics
from matplotlib import pyplot as plt
import seaborn as sns

# Additional modules
from dataset_creator import generate_csv
from assistive_funcs import filtering_image, check_ssim, convert_to_grayscale, get_dataset_name
from csv_dataloader import get_train_test_data
from math import floor


sns.set()  # apply seaborn's default theme to the matplotlib figures drawn later in the notebook

# Define constants

In [6]:
# Data and model locations, relative to the notebook's working directory.
main_data_path = Path("../data")
models_path = Path("..") / "models"

# Folder searched by get_dataset_name for the windowed CSV datasets
# (the "scv" spelling is kept because later cells use this name).
scv_folder = main_data_path.joinpath("csv_files")
img_path = main_data_path.joinpath("images")

# Generate Dataset

In [None]:
# Flip to a truthy value to (re)generate the windowed CSV dataset with generate_csv;
# it stays off so that "Run All" does not redo the generation step.
create_dataset = 0
if create_dataset:
    generate_csv(
        win_size=9,
        step=1,
        dump_to_file=50000,
    )

# NN Model

In [7]:
class DefaultModel(nn.Module):
    """Fully connected regressor mapping a flattened input vector to ``out_len`` values.

    Hidden widths are ``hid_n -> 2*hid_n -> 3*hid_n -> 3*hid_n -> 3*hid_n -> 2*hid_n -> hid_n``,
    each followed by ReLU; BatchNorm1d follows the 1st, 3rd, 5th and 7th hidden layers.
    The layer order inside ``self.fcs`` is fixed so saved state_dict keys stay compatible.

    Args:
        in_len: Number of input features (``win_size ** 2`` in this notebook).
        out_len: Number of outputs produced per sample.
        hid_n: Base hidden width; defaults to 400, the previously hard-coded value.
    """

    def __init__(self, in_len, out_len, hid_n=400) -> None:
        super().__init__()
        self.in_len = in_len
        self.out_len = out_len
        self.hid_n = hid_n

        h = self.hid_n
        self.fcs = nn.Sequential(
            nn.Linear(self.in_len, h),
            nn.BatchNorm1d(h),
            nn.ReLU(),
            nn.Linear(h, h * 2),
            nn.ReLU(),
            nn.Linear(h * 2, h * 3),
            nn.BatchNorm1d(h * 3),
            nn.ReLU(),

            nn.Linear(h * 3, h * 3),
            nn.ReLU(),
            nn.Linear(h * 3, h * 3),
            nn.BatchNorm1d(h * 3),
            nn.ReLU(),

            nn.Linear(h * 3, h * 2),
            nn.ReLU(),
            nn.Linear(h * 2, h),
            nn.BatchNorm1d(h),
            nn.ReLU(),
            nn.Linear(h, self.out_len),
        )

    def forward(self, x):
        """Run the MLP on ``x``.

        ``x`` should be 2-D ``(batch, in_len)``: BatchNorm1d would misread an extra middle
        dimension as channels. In train mode BatchNorm also needs batch > 1.
        Returns a ``(batch, out_len)`` tensor.
        """
        return self.fcs(x)

## Define the NN's hyperparameters

In [8]:
# Optimisation hyperparameters.
learning_rate = 0.001
num_epoches = 2
batch_size = 256

# Dataset selection: the model input is a win_size x win_size window (in_len = win_size ** 2);
# _step is passed to get_dataset_name together with win_size to pick the matching CSV.
win_size = 9
_step = 1
dataset_name = get_dataset_name(win_size, _step, scv_folder)

# Fix RNG seeds so weight initialisation and torch/numpy-driven shuffling are reproducible
# across "Restart & Run All" (torch.manual_seed also seeds every CUDA device).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{dataset_name = }")
print(f"{device = }")

dataset_name = 'W9_S1_L3696640.csv'
device = device(type='cuda')


## Initialize the model, loss, and optimizer, then train

In [None]:
# One scalar prediction per flattened win_size x win_size window, trained with MSE + Adam.
n_inputs = win_size * win_size
model = DefaultModel(in_len=n_inputs, out_len=1).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Training curves, concatenated across epochs:
#   losses       -- loss of every `log_every`-th training batch
#   valid_losses -- loss of every validation batch
losses = []
valid_losses = []
log_every = 3

for epoch in range(num_epoches):
    # NOTE(review): the loaders are rebuilt every epoch, as before. If get_train_test_data
    # re-splits the CSV randomly on each call, samples validated in one epoch can be
    # trained on in the next. Build them once before this loop if the returned loaders
    # can be iterated more than once.
    train_loader, test_loader = get_train_test_data(
        scv_folder=scv_folder,
        dataset_name=dataset_name,
        batch_size=batch_size,
        train_size=0.9,
    )

    # Training pass: BatchNorm uses per-batch statistics.
    model.train()
    for batch_ind, (data, targets) in enumerate(
        tqdm(train_loader, desc=f"train epoch {epoch + 1}/{num_epoches}")
    ):
        data = data.to(device=device)
        scores = model(data)
        # Align targets with the (batch, out_len) predictions. A (batch,) target would
        # broadcast against (batch, 1) inside MSELoss and silently produce a wrong loss.
        targets = targets.to(device=device).reshape_as(scores)
        loss = criterion(scores, targets)
        if batch_ind % log_every == 0:
            losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation pass: running BatchNorm statistics, no gradient tracking.
    model.eval()
    with torch.no_grad():
        for data, targets in tqdm(test_loader, desc=f"valid epoch {epoch + 1}/{num_epoches}"):
            data = data.to(device=device)
            scores = model(data)
            targets = targets.to(device=device).reshape_as(scores)
            valid_losses.append(criterion(scores, targets).item())



In [None]:
# Larger default figure size for the loss curves that follow.
seaborn_rc = {"figure.figsize": (7, 6)}
sns.set(rc=seaborn_rc)

In [None]:
# Training loss curve; each point is one logged training batch (every 3rd batch).
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="Training loss (MSE)", xlabel="Logged training batch (every 3rd)", ylabel="MSE loss");

In [None]:
# Validation loss curve; one point per validation batch, epochs concatenated.
fig, ax = plt.subplots()
ax.plot(valid_losses)
ax.set(title="Validation loss (MSE)", xlabel="Validation batch (all epochs)", ylabel="MSE loss");

In [None]:
# Save the whole pickled module (not only the state_dict) because the loading cell restores
# it with torch.load. Use a .pt suffix rather than reusing the dataset's ".csv" file name.
models_path.mkdir(parents=True, exist_ok=True)
path_to_model = (models_path / dataset_name).with_suffix(".pt")
torch.save(model, path_to_model)

# Check that the NN works

In [None]:
# Folder with the noisy input images denoised by the loop below (relative to main_data_path).
path_to_noised_imgs = main_data_path.joinpath("imgs_with_noise")

In [None]:
# Denoise the sample images 1.jpg ... 10.jpg. The call uses the same six-argument form as the
# real-image cell at the end of the notebook:
# filtering_image(model, out_dir, in_dir, img_name, win_size, device).
noised_out_path = main_data_path / "filtered_imgs"
noised_out_path.mkdir(parents=True, exist_ok=True)
for i in range(1, 11):
    filtering_image(model, noised_out_path, path_to_noised_imgs, f"{i}.jpg", win_size, device)

In [None]:
# Directory with the network's denoised outputs, derived from main_data_path instead of an
# absolute machine-specific path (kept as str, like the original value).
filtered_images = str(main_data_path / "filtered_imgs")
# NOTE(review): the grayscale reference images were presumably produced earlier with
# convert_to_grayscale(genuine_images).


In [None]:
# Grayscale ground-truth images, derived from main_data_path instead of an absolute
# machine-specific path (kept as str, like the original value).
genuine_images = str(main_data_path / "gray_images")

# Compare the filtered outputs with the ground truth using check_ssim.
check_ssim(filtered_images, genuine_images)

In [9]:
# Optionally replace the freshly trained model with a previously saved full model.
# NOTE: torch.load unpickles arbitrary Python objects -- only load model files you trust.
model_name = "also_good_modelW9_S1.csv"
load_model = True
if load_model:
    # map_location lets a checkpoint saved on CUDA be restored on a CPU-only machine.
    # NOTE(review): on torch >= 2.6 weights_only defaults to True; loading a full-module
    # pickle then needs weights_only=False.
    model = torch.load(models_path / model_name, map_location=device)
    # Inference mode: BatchNorm layers use running statistics rather than batch statistics.
    model.eval()
    

In [10]:
path_real_imgs = main_data_path / "real_images"
path_to_image = path_real_imgs / "raw"
out_path = path_real_imgs / "filtered"
out_path.mkdir(parents=True, exist_ok=True)

# Pass only image files to the filter. This skips sub-directories and OS artefacts
# such as Thumbs.db / desktop.ini, and processes files in a deterministic (sorted) order.
image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
list_images = sorted(
    entry.name
    for entry in path_to_image.iterdir()
    if entry.is_file() and entry.suffix.lower() in image_suffixes
)
for img_name in list_images:
    filtering_image(model, out_path, path_to_image, img_name, win_size, device)

100%|██████████| 450/450 [00:06<00:00, 70.83it/s] 
100%|██████████| 162/162 [00:00<00:00, 211.61it/s]
100%|██████████| 216/216 [00:00<00:00, 294.84it/s]
100%|██████████| 224/224 [00:00<00:00, 297.25it/s]
100%|██████████| 216/216 [00:00<00:00, 293.65it/s]
