# Import modules

In [1]:
# Data Handlers
import pandas as pd
import numpy as np
from PIL import Image
from PIL import ImageOps

# Pytorch
import torch
import torch.nn as nn  # NN; networks (CNN, RNN, losses)
import torch.optim as optim  # Optimizers (Adam, Adadelta, Adagrad)
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, IterableDataset  # Dataset manager
from torch.autograd import Variable
from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler

# Other
import datetime
import os
from tqdm import tqdm
from pathlib import Path
from os import listdir
from fastaniso import anisodiff

# Graphics
from matplotlib import pyplot as plt
import seaborn as sns

# Additional modules
from dataset_creator import generate_csv
from assistive_funcs import filtering_image, check_ssim, convert_to_grayscale, get_dataset_name
from csv_dataloader import get_train_test_data
from math import floor

  from .autonotebook import tqdm as notebook_tqdm


# Generate Dataset

In [None]:
create_dataset = 0
if create_dataset:
    generate_csv(win_size=7, dump_to_file=50000, step=4)

## Define constants

In [2]:
# Paths
p_main_data = Path("../data")
p_models = Path("../models")

p_scv_folder = p_main_data / "csv_files" # datasets_path
p_img = p_main_data / "images"

p_noised_imgs = p_main_data / "imgs_with_noise" 

p_filtered_images = p_main_data / "filtered_imgs"

p_gray_images = p_main_data / "gray_images"

# Hyperparameters 
learning_rate = 0.001
num_epoches = 4
batch_size = 128

# Dataset
win_size = 7
_step = 4

dataset_name = get_dataset_name(win_size, _step, p_scv_folder) #r"W5_S1_L3696640.csv"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{dataset_name = }\n{device = }")

dataset_name = 'W7_S4_L231040.csv'
device = device(type='cpu')


# NN Model

In [None]:
class DefaultModel(nn.Module):
    def __init__(self, in_len, out_len) -> None:
        super().__init__()
        self.in_len = in_len
        self.out_len = out_len
        self.hid_n = 20
        
        self.fcs = nn.Sequential(
            nn.Linear(self.in_len, self.hid_n),
            nn.BatchNorm1d(self.hid_n),
            nn.Mish(),
            nn.Linear(self.hid_n, self.hid_n * 2),
            nn.BatchNorm1d(self.hid_n * 2),
            nn.Mish(),
            nn.Linear(self.hid_n * 2, self.hid_n),
            nn.BatchNorm1d(self.hid_n),
            nn.Mish(),
            nn.Linear(self.hid_n, self.out_len),
        )

    def forward(self, x):
        x = self.fcs(x)
        return x

## Initialize model

In [None]:
model = DefaultModel(in_len=(win_size ** 2), out_len=1).to(device=device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
losses = []
losses_append = losses.append

valid_losses = []
valid_losses_append = valid_losses.append

for epoch in range(num_epoches):
    model.train()
    train_loader, test_loader = get_train_test_data(scv_folder=p_scv_folder, dataset_name=dataset_name, batch_size=batch_size, train_size=0.9)
    for batch_ind, (data, targets) in tqdm(enumerate(train_loader)):
        # Data on cuda
        data = data.to(device=device)
        targets = targets.to(device=device)
        
        # Forward
        scores = model(data) # Equal to model.forward(data)
        loss = criterion(scores, targets)
        if batch_ind % 3 == 0:
            losses_append(loss.item())
        # Backprop
        loss.backward()

        # Gradient descent or adam step
        optimizer.step()
        optimizer.zero_grad()
    model.eval()
    with torch.no_grad():
        for batch_ind, (data, targets) in tqdm(enumerate(test_loader)):
            data = data.to(device=device)
            targets = targets.to(device=device)
            
            # Forward
            scores = model(data) # Equal to model.forward(data)
            loss = criterion(scores, targets)
            valid_losses_append(loss.item())



In [None]:
sns.set()
fig, (ax_train, ax_test) = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
fig.suptitle('Loss')

ax_train.set_title("Train loss")
ax_test.set_title("Test loss")

ax_train.set_ylabel('Loss value')
ax_test.set_ylabel('Loss value')

ax_train.set_xlabel("k * Batch")
ax_test.set_xlabel("k * Batch")

sns.lineplot(data=losses, ax=ax_train)
sns.lineplot(data=valid_losses, ax=ax_test)

plt.show()

In [None]:
path_to_model = p_models / dataset_name
torch.save(model, path_to_model)

# Check NN works

In [None]:
list_images = listdir(p_noised_imgs)
for img_name in list_images:
    filtering_image(model, p_filtered_images, p_noised_imgs, img_name, win_size, device)

In [None]:
p_real_imgs = p_main_data / "real_images"

p_raw_image = p_real_imgs / "raw"
p_out_imgs = p_real_imgs / "filtered"

list_images = listdir(p_raw_image)
for img_name in list_images:
    filtering_image(model, p_out_imgs, p_raw_image, img_name, win_size, device)

In [None]:
images_names = listdir(p_noised_imgs)
for name in images_names:
    check_ssim(p_filtered_images, p_gray_images, name)

# Anis diff

In [None]:
p_img = p_raw_image / list_images[2]
img_arr = np.array(ImageOps.grayscale(Image.open(p_img)))
niters = [1, 2, 5, 10, 20]
kappa = [1, 5, 10, 50, 100]
for i in niters:
    for k in kappa:
        img_filtered = anisodiff(img_arr, niter=i, kappa=k)
        img_filtered = img_filtered.astype(np.uint8)
        Image.fromarray(img_filtered).save(f"{p_img}_I{i}_K{k}.jpg")

# Load model

In [None]:
model_name = "also_good_modelW9_S1.csv"
load_model = False
if load_model:
    model = torch.load(p_models / model_name)

# CNN Model




In [3]:
class CNNFilter(nn.Module):
    def __init__(self, *, num_filter, num_channel) -> None:
        super().__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(1, num_filter, 3, padding=1),
            nn.BatchNorm2d(num_filter),
            nn.LeakyReLU(0.25, inplace=True),
            nn.Conv2d(num_filter, num_filter, 3, padding=1),
            nn.BatchNorm2d(num_filter),
            nn.LeakyReLU(0.25, inplace=True),
            nn.Conv2d(num_filter, num_filter, 3, padding=1),
            nn.BatchNorm2d(num_filter),
            nn.LeakyReLU(0.25, inplace=True),
            nn.Conv2d(num_filter, num_channel, 3, padding=1),
            nn.BatchNorm2d(1),
            nn.Tanh()
        )
        
    def forward(self, x):
        return self.sequential(x)

In [4]:
cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda else 'cpu')
w, h = 608, 608 
num_filter = 64
num_channel = 1

def load_data(BATCH_SIZE=32):
    
    transform = transforms.Compose([
        transforms.Resize(size=(w, h)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    Data = datasets.ImageFolder(p_main_data / "images_for_cnn", transform=transform)
   

    batch_size = BATCH_SIZE
    test_size = 0.3
    
    num_data = len(Data)
    indices_data = list(range(num_data))
    np.random.shuffle(indices_data)
    split_tt = int(np.floor(test_size * num_data))
    train_idx, test_idx = indices_data[split_tt:], indices_data[:split_tt]
    
   
    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_idx)
    test_sampler = SubsetRandomSampler(test_idx)
    
    trainloader = torch.utils.data.DataLoader(Data, sampler = train_sampler, batch_size=batch_size, num_workers=4)
    testloader  = torch.utils.data.DataLoader(Data, sampler = test_sampler,  batch_size=batch_size, num_workers=4)
    
    return trainloader, testloader

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    
def noise_input(images):
    result = (images + images * torch.rand(images.size()))
    result_max = Variable(torch.max(result))
    return result / result_max#(images + images * torch.rand(images.size())
    #return images + torch.rand(images.size())
    #return images * (1 - NOISE_RATIO) + torch.rand(images.size()) * NOISE_RATIO

def model_training(mycnn, train_loader, epoch):
    loss_metric = nn.MSELoss()
    optimizer = torch.optim.Adam(mycnn.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    mycnn.train()
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        images, _ = data
        images = Variable(images)
        noise_images = noise_input(images)      # add noise into image data
        if cuda: 
            images, noise_images = images.to(device), noise_images.to(device)
        noise = mycnn(noise_images)
        outputs = noise_images / (1 - noise)
        #outputs = noise_images - noise
        loss = loss_metric(outputs, images)
        loss.backward()
        optimizer.step()
        if (i + 1) % LOG_INTERVAL == 0:
            print('Epoch [{}/{}] - Iter[{}/{}], MSE loss:{:.4f}'.format(
                epoch + 1, EPOCHS, i + 1, len(train_loader.dataset) // BATCH_SIZE, loss.item()
            ))
   #     torch.cuda.empty_cache()

def evaluation(mycnn, test_loader):
    total_loss = 0
    loss_metric = nn.MSELoss()
    mycnn.eval()
    for i, data in enumerate(test_loader):
        images, _ = data
        images = Variable(images)
        if cuda: 
            images = images.to(device)
        outputs = mycnn(images)
        loss = loss_metric(outputs, images)
        total_loss += loss.item() * len(images)
    avg_loss = total_loss / len(test_loader.dataset)

    print('\nAverage MSE Loss on Test set: {:.4f}'.format(avg_loss))

    global BEST_VAL
    if TRAIN and avg_loss < BEST_VAL:
        BEST_VAL = avg_loss
        torch.save(mycnn.state_dict(), '/content/drive/My Drive/history/conv_filter.pt')
        print('Save Best Model in HISTORY\n')


In [7]:
if __name__ == '__main__':
    EPOCHS = 10
    BATCH_SIZE = 8
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-5
    LOG_INTERVAL = 10
    NOISE_RATIO = 0.5
    TRAIN = True       # whether to train a model from scratch
    BEST_VAL = float('inf')     # record the best val loss

   # train_loader, test_loader = data_utils.load_mnist(BATCH_SIZE)
    
    train_loader, test_loader = load_data(BATCH_SIZE)
    mycnn = CNNFilter(64, 1)
    if cuda: mycnn.to(device)

    if TRAIN:
        for epoch in range(EPOCHS):
            starttime = datetime.datetime.now()
            model_training(mycnn, train_loader, epoch)
            endtime = datetime.datetime.now()
            
            print(f'Train a epoch in {(endtime - starttime).seconds} seconds')
            # evaluate on test set and save best model
            evaluation(mycnn, test_loader)
        print('Trainig Complete with best validation loss {:.4f}'.format(BEST_VAL))
    else:
        mycnn.load_state_dict(torch.load('/content/drive/My Drive/history/conv_filter.pt'))
        evaluation(mycnn, test_loader)
        
        mycnn.cpu()
        dataiter = iter(train_loader)
        images, _ = next(dataiter)
        images = Variable(images[:1])

        noise_images = noise_input(images)
        noise = mycnn(noise_images)
        outputs = noise_images / (1-noise)
        #outputs = noise_images - noise

        # plot and save original and reconstruction images for comparisons
        plt.figure(figsize=(20, 15))
        plt.subplot(131)
        plt.title('Image')
        imshow(torchvision.utils.make_grid(images))
        plt.subplot(132)
        plt.title('Noise Image')
        imshow(torchvision.utils.make_grid(noise_images))
        plt.subplot(133)
        plt.title('Reconstruction')
        imshow(torchvision.utils.make_grid(
            outputs.view(images.size(0), num_channel, w, h).data
        ))
        plt.savefig('/content/drive/My Drive/images/filter_net_2.png')


TypeError: __init__() takes 1 positional argument but 3 were given