In [None]:
# Data Handlers
import numpy as np

# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# Other
from pathlib import Path
from os import listdir
from time import time
from sklearn.preprocessing import OneHotEncoder
# from fastaniso import anisodiff

# Graphics
from matplotlib import pyplot as plt
import seaborn as sns

# Additional modules
from dataset_creator import generate_csv
from assistive_funcs import filtering_image, check_ssim, check_gmsd, get_dataset_name
from csv_dataloader import get_train_test_small_data

In [None]:
# Paths: every data folder hangs off a single configurable root.
p_main_data = Path("../data")
p_models = Path("../models")

p_scv_folder, p_img, p_noised_imgs, p_filtered_images, p_gray_images = (
    p_main_data / sub_dir
    for sub_dir in (
        "csv_files",           # generated datasets
        "images",              # clean reference images
        "FC_imgs_with_noise",  # noisy inputs to be filtered
        "FC_filtered_images",  # model output
        "gray_images",
    )
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Dataset generation settings.
win_size = 3         # window side; the model later consumes win_size ** 2 input features
step = 5
create_dataset = 1   # 0 skips generation (presumably reusing CSVs already in p_scv_folder)

if create_dataset:
    # NOTE(review): force_create_dataset=1 regenerates the CSV on every full run;
    # dump_to_file semantics live in dataset_creator — confirm there.
    generate_csv(
        win_size=win_size,
        step=step,
        dump_to_file=10000,
        force_create_dataset=1,
        classification=True,
    )

In [None]:
# File name of the classification CSV for this (win_size, step) pair,
# presumably the one written by generate_csv (e.g. "W5_S1_L3696640.csv").
dataset_name = get_dataset_name(
    win_size, step, p_scv_folder,
    classification=True,
)
dataset_name

In [None]:
# Optional, Google Colab only: mount Google Drive and unpack a pre-built dataset archive.
# from google.colab import drive
# drive.mount('/content/gdrive/')
# !unzip -q /content/gdrive/MyDrive/NIR/data/FC_data/W5_S5_L146410.zip


In [None]:
class FCBlock(nn.Module):
    """Fully connected layer followed by a regulariser and a ReLU.

    The regulariser is selected by ``p_dropout``:

    * truthy value -> ``Linear -> Dropout(p_dropout) -> ReLU``
    * falsy value (the default) -> ``Linear -> BatchNorm1d(out_len) -> ReLU``

    Args:
        in_len: number of input features.
        out_len: number of output features.
        p_dropout: dropout probability; ``False``/``0`` selects batch norm instead.
    """

    def __init__(self, in_len, out_len, p_dropout=False) -> None:
        super().__init__()
        linear = nn.Linear(in_len, out_len)
        regulariser = nn.Dropout(p_dropout) if p_dropout else nn.BatchNorm1d(out_len)
        # Order Linear / regulariser / ReLU keeps the state_dict keys fc_block.0/1/2.
        self.fc_block = nn.Sequential(linear, regulariser, nn.ReLU())

    def forward(self, x):
        """Apply the block; assumes ``x`` is (batch, in_len) — TODO confirm."""
        return self.fc_block(x)

In [None]:
class DefaultModel(nn.Module):
    """MLP of five FCBlock layers (in_len -> 2*in_len -> 3*in_len) plus a linear head.

    Every hidden block uses batch norm (``p_dropout=0``). The head returns raw,
    unnormalised scores (no softmax), as expected by ``nn.CrossEntropyLoss``.

    Args:
        in_len: number of input features.
        out_len: number of output scores.
    """

    def __init__(self, in_len, out_len) -> None:
        super().__init__()
        # Consecutive pairs of widths give each hidden FCBlock's (in, out) sizes:
        # (in, in), (in, 2in), (2in, 2in), (2in, 3in), (3in, 3in).
        widths = [in_len, in_len, 2 * in_len, 2 * in_len, 3 * in_len, 3 * in_len]
        hidden = [FCBlock(n_in, n_out, p_dropout=0) for n_in, n_out in zip(widths, widths[1:])]
        self.structure = nn.Sequential(*hidden, nn.Linear(widths[-1], out_len))

    def forward(self, x):
        """Return unnormalised scores; (batch, out_len) for 2-D input."""
        return self.structure(x)

In [None]:
class FitModel():
    """Train/validate a filtering model and evaluate it on image folders.

    Wraps the training loop (``fit``), a per-epoch loss plot (``plot_graph``),
    filtering of the noisy images (``filtering_all_images``) and SSIM/GMSD
    comparison against the clean images (``check_metrics``).

    NOTE: ``filtering_all_images`` and ``check_metrics`` read the notebook
    globals ``p_noised_imgs``, ``p_filtered_images``, ``p_img`` and ``win_size``.

    Args:
        model: network mapping a batch of inputs to scores.
        criterion: loss callable, ``criterion(scores, targets)``.
        optimizer: torch optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per epoch, or ``None`` to disable.
        p_scv_folder: folder containing the CSV datasets.
        dataset_name: CSV file name inside ``p_scv_folder``.
        batch_size: loader batch size.
        device: device the model and batches live on.
        num_epoches: number of training epochs.
        classification: when True, targets are one-hot encoded over the
            256 categories 0..255 before being passed to ``criterion``.
    """

    def __init__(self, model, criterion, optimizer, scheduler,
                 p_scv_folder, dataset_name,
                 batch_size, device, num_epoches, classification=True):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.num_epoches = num_epoches
        self.p_scv_folder = p_scv_folder
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.train_losses = []  # mean training loss, one entry per epoch
        self.valid_losses = []  # mean validation loss, one entry per epoch
        self.device = device
        self.classification = classification
        self.images_filtered = False  # True once filtering_all_images() has completed

        if self.classification:
            # One-hot encoder fitted on the 256 categories 0..255.
            X = [[i] for i in range(0, 256)]
            self.enc = OneHotEncoder()
            self.enc.fit(X)

    def _prepare_targets(self, targets):
        """Return targets on ``self.device``, one-hot encoded first in classification mode."""
        if self.classification:
            targets = torch.Tensor(self.enc.transform(targets).toarray())
        return targets.to(device=self.device)

    def _train(self, current_epoch):
        """Run one optimisation pass over ``self.train_loader`` and record the mean loss."""
        total_loss = []
        start_time = time()
        self.model.train()
        for data, targets in self.train_loader:
            data = data.to(device=self.device)
            targets = self._prepare_targets(targets)

            scores = self.model(data)
            loss = self.criterion(scores, targets)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            total_loss.append(loss.item())
        mean_total_loss = np.mean(total_loss)

        self.train_losses.append(mean_total_loss)
        print(f"Epoch: {current_epoch}/{self.num_epoches}, time: {int(time() - start_time)} s\n\tTrain loss: {mean_total_loss:.2f}", end=" ")

    def _valid(self, current_epoch):
        """Evaluate on ``self.test_loader`` without gradients and record the mean loss."""
        total_loss = []

        self.model.eval()
        with torch.no_grad():
            for data, targets in self.test_loader:
                data = data.to(device=self.device)
                targets = self._prepare_targets(targets)

                scores = self.model(data)
                loss = self.criterion(scores, targets)
                total_loss.append(loss.item())
        mean_total_loss = np.mean(total_loss)
        self.valid_losses.append(mean_total_loss)
        # scheduler may be None (fit() supports that); report the optimizer's LRs instead.
        if self.scheduler is not None:
            lr_info = self.scheduler.get_last_lr()
        else:
            lr_info = [group["lr"] for group in self.optimizer.param_groups]
        print(f"Valid loss: {mean_total_loss:.2f}, lr = {lr_info}")

    def fit(self):
        """Train for ``num_epoches`` epochs, validating and stepping the scheduler each epoch."""
        for epoch in range(self.num_epoches):
            # NOTE(review): loaders are rebuilt every epoch. If get_train_test_small_data
            # splits randomly, validation samples change between epochs and end up being
            # trained on — confirm this is intended.
            self.train_loader, self.test_loader = get_train_test_small_data(scv_folder=self.p_scv_folder, dataset_name=self.dataset_name,
                                                                      batch_size=self.batch_size, train_size=0.8)
            self._train(epoch + 1)
            self._valid(epoch + 1)
            if self.scheduler is not None:
                self.scheduler.step()

    def plot_graph(self):
        """Plot the recorded per-epoch train and validation losses side by side."""
        sns.set()
        fig, (ax_train, ax_test) = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
        fig.suptitle('Loss')

        ax_train.set_title("Train loss")
        ax_test.set_title("Valid loss")

        ax_train.set_ylabel('Loss value')
        ax_test.set_ylabel('Loss value')

        # One loss value is appended per epoch in _train/_valid, so the x axis is epochs.
        ax_train.set_xlabel("Epoch")
        ax_test.set_xlabel("Epoch")

        sns.lineplot(data=self.train_losses, ax=ax_train)
        sns.lineplot(data=self.valid_losses, ax=ax_test)

        plt.show()

    def filtering_all_images(self):
        """Run ``filtering_image`` with the current model on every file in ``p_noised_imgs``."""
        images_names = listdir(p_noised_imgs)
        for name in images_names:
            filtering_image(self.model, p_filtered_images, p_noised_imgs, name, win_size, self.device, classification=self.classification)
        # Only mark as filtered once every image has been processed successfully.
        self.images_filtered = True

    @staticmethod
    def _check_filtering(p_target_images, p_original_images):
        """Return per-image ``(ssim_list, gmsd_list)`` for files in ``p_target_images``
        compared with same-named files in ``p_original_images``."""
        ssim_metric = []
        gmsd_metric = []
        images_names = listdir(p_target_images)
        for name in images_names:
            ssim_metric.append(check_ssim(p_target_images, p_original_images, name))
            gmsd_metric.append(check_gmsd(p_target_images, p_original_images, name))
        return ssim_metric, gmsd_metric

    def check_metrics(self):
        """Print mean SSIM/GMSD of the filtered and of the noisy images versus ``p_img``."""
        if not self.images_filtered:
            print("Warning: images weren't filtered")
        metrics_after_filtering = self._check_filtering(p_filtered_images, p_img)
        metrics_before_filtering = self._check_filtering(p_noised_imgs, p_img)
        print(f"After filtering\n\tSSIM: {np.mean(metrics_after_filtering[0]):.3f}\n\tGMSD: {np.mean(metrics_after_filtering[1]):.3f}")

        print(f"Before filtering\n\tSSIM: {np.mean(metrics_before_filtering[0]):.3f}\n\tGMSD: {np.mean(metrics_before_filtering[1]):.3f}")
        

In [None]:
# Hyperparameters: initial learning rate, number of training epochs, mini-batch size.
learning_rate, num_epoches, batch_size = (
    0.1,
    18,
    300,
)

In [None]:
# Model: one input feature per window pixel, 256 output scores (one per class).
model = DefaultModel(in_len=win_size ** 2, out_len=256).to(device=device)
criterion = nn.CrossEntropyLoss()

# Also tried: Adagrad, RAdam, Adam, Adamax, NAdam (recorded result: 0.89 — metric unspecified).
optimizer = optim.Rprop(model.parameters(), lr=learning_rate)

# LR is multiplied by 0.1 after the 4th and 8th scheduler step (FitModel.fit steps once per epoch).
# Alternatives: optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1), or scheduler = None.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4, 8], gamma=0.1)

In [None]:
# Build the trainer from the objects configured above and train it.
fit_model = FitModel(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    p_scv_folder=p_scv_folder,
    dataset_name=dataset_name,
    batch_size=batch_size,
    device=device,
    num_epoches=num_epoches,
    classification=True,
)
fit_model.fit()

In [None]:
# Plot the train and validation loss curves recorded once per epoch during fit().
fit_model.plot_graph()

In [None]:
# Run the trained model over every image in p_noised_imgs via filtering_image
# (presumably writing results into p_filtered_images — see assistive_funcs).
fit_model.filtering_all_images()

In [None]:
# Print mean SSIM and GMSD for the filtered and the noisy images, each compared with the clean images in p_img.
fit_model.check_metrics()