Everything implemented
CodingTil committed Oct 30, 2023
1 parent b35e466 commit cbf71d9
Showing 5 changed files with 62 additions and 60 deletions.
1 change: 1 addition & 0 deletions data/.gitignore
@@ -1,3 +1,4 @@
 lol_dataset/
 intermediate_images/
+pixel_dataset.ds

87 changes: 35 additions & 52 deletions eiuie/consolidate_dataset.py
@@ -1,23 +1,39 @@
+from typing import Generator, Dict, Literal
+import os
+
 import numpy as np
-import pandas as pd
 import cv2
 import glob

+import pixel_dataset as pxds
+
+
+def prepare_dataset() -> None:
+    generator = __consolidate_data()
+    if os.path.exists(pxds.FILE):
+        os.remove(pxds.FILE)
+    with open(pxds.FILE, "wb") as file:
+        for data in generator:
+            combined = np.hstack((data["original"], data["unsharp"], data["homomorphic"], data["retinex"], data["ground_truth"]))
+            for row in combined:
+                file.write(bytes(row))
+

-def consolidate_data(image_files, source_path) -> pd.DataFrame:
-    """
-    consolidate data
-    """
+def __consolidate_data() -> Generator[Dict[Literal["original", "retinex", "unsharp", "homomorphic", "ground_truth"], np.ndarray], None, None]:
     # Path to intermediate images
-    path_retinex = source_path + "intermediate_images/retinex/"
-    path_unsharp = source_path + "intermediate_images/unsharp_masking/"
-    path_homomorphic = source_path + "intermediate_images/homomorphic_filtering/"
+    path_retinex = "data/intermediate_images/retinex/"
+    path_unsharp = "data/intermediate_images/unsharp_masking/"
+    path_homomorphic = "data/intermediate_images/homomorphic_filtering/"

-    list_of_dicts = []
-    for image in image_files:
+    files = glob.glob("data/lol_dataset/*/low/*.png")
+
+    for image in files:
         # read original image
         image_original = cv2.imread(image)

+        # image ground truth
+        image_ground_truth = cv2.imread(image.replace("low", "high"))
+
         # extract image id
         i = image.split("/")[-1].split(".")[0]

@@ -27,52 +43,19 @@ def consolidate_data(image_files, source_path) -> pd.DataFrame:
         image_homomorphic = cv2.imread(path_homomorphic + str(i) + ".png")

         # reshape image to 2D array
-        image2D_original = image_original.reshape((image_original.shape[0]*image_original.shape[1], 3))
-        image2D_retinex = image_retinex.reshape((image_retinex.shape[0]*image_retinex.shape[1], 3))
-        image2D_unsharp = image_unsharp.reshape((image_unsharp.shape[0]*image_unsharp.shape[1], 3))
-        image2D_homomorphic = image_homomorphic.reshape((image_homomorphic.shape[0]*image_homomorphic.shape[1], 3))
+        image2D_original = image_original.reshape(-1, 3)
+        image2D_retinex = image_retinex.reshape(-1, 3)
+        image2D_unsharp = image_unsharp.reshape(-1, 3)
+        image2D_homomorphic = image_homomorphic.reshape(-1, 3)
+        image2D_ground_truth = image_ground_truth.reshape(-1, 3)

-        # convert to single pandas dataframe
-        data = {
+        data: Dict[Literal["original", "retinex", "unsharp", "homomorphic", "ground_truth"], np.ndarray] = {
             "original": image2D_original,
             "retinex": image2D_retinex,
             "unsharp": image2D_unsharp,
             "homomorphic": image2D_homomorphic,
+            "ground_truth": image2D_ground_truth
         }
-        list_of_dicts.append(data)
-    return list_of_dicts

-def write_to_tsv(dataset, source_path):
-    """
-    Write dataset to tsv file.
-    """
-
-    # write to csv file
-    with open(source_path + "dataset.tsv", "w") as file:
-        for data in dataset:
-            # write data to tsv file in the following format: original, unsharp, homomorphic, retinex
-            for i in range(len(data["original"])):
-                line = [data['original'][i, 0], data['original'][i, 1], data['original'][i, 2],
-                        data['unsharp'][i, 0], data['unsharp'][i, 1], data['unsharp'][i, 2],
-                        data['homomorphic'][i, 0], data['homomorphic'][i, 1], data['homomorphic'][i, 2],
-                        data['retinex'][i, 0], data['retinex'][i, 1], data['retinex'][i, 2]]
-
-                # write line to file
-                line_str = '\t'.join(map(str, line))  # Convert vector elements to strings and join with tabs
-                file.write(line_str + '\n')  # Writing the vector as a single line
-    return 0
-
-
-# source path
-source_path = "../data/"
-
-# consolidate dataset in pandas dataframe
-glob_pattern = source_path + "lol_dataset/our485/low/*.png"
-image_files = glob.glob(glob_pattern)
-dataset = consolidate_data(image_files, source_path)
-
-# write dataset to tsv file
-write_to_tsv(dataset, source_path)
-
-
-
+        yield data
+
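For reference, prepare_dataset above serializes each pixel as one fixed-width 15-byte record: the BGR values of the original, unsharp, homomorphic and retinex images followed by the ground truth, all uint8 as returned by cv2.imread. A minimal round-trip sketch with made-up pixel values (illustration only, not repository code) shows why the file can later be read back with a bare np.frombuffer(...).reshape(-1, 15):

```python
import numpy as np

# One hypothetical pixel from each of the five images (uint8 BGR triplets).
original     = np.array([[ 12,  20,  35]], dtype=np.uint8)
unsharp      = np.array([[ 18,  27,  44]], dtype=np.uint8)
homomorphic  = np.array([[ 25,  31,  50]], dtype=np.uint8)
retinex      = np.array([[ 40,  55,  80]], dtype=np.uint8)
ground_truth = np.array([[120, 140, 180]], dtype=np.uint8)

# Same column order as prepare_dataset: original, unsharp, homomorphic, retinex, ground truth.
rows = np.hstack((original, unsharp, homomorphic, retinex, ground_truth))  # shape (1, 15)

# prepare_dataset writes bytes(row) per pixel: 15 raw uint8 values, no delimiters.
blob = b"".join(bytes(row) for row in rows)

# PixelDataset.__init__ can therefore recover the table without any parsing.
recovered = np.frombuffer(blob, dtype=np.uint8).reshape(-1, 15)
assert (recovered == rows).all()
```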
10 changes: 9 additions & 1 deletion eiuie/fusion_model.py
@@ -168,6 +168,8 @@ def _get_latest_checkpoint(self) -> Optional[str]:
         Optional[str]
             Path to the latest checkpoint file or None if no checkpoint found.
         """
+        if not os.path.exists(CHECKPOINT_DIRECTORY):
+            return None
         checkpoint_files = [
             f for f in os.listdir(CHECKPOINT_DIRECTORY) if "checkpoint_epoch_" in f
         ]
@@ -240,11 +242,11 @@ def process_image(self, image: np.ndarray) -> np.ndarray:

     def train_model(
         self,
-        dataset: Dataset = pxds.PixelDataset(),
         total_epochs=100,
         patience=5,
         train_ratio=0.8,
     ):
+        dataset = pxds.PixelDataset()
         # Splitting dataset into training and validation subsets
         train_size = int(train_ratio * len(dataset))
         val_size = len(dataset) - train_size
@@ -261,6 +263,8 @@

         self.net.train()
         for epoch in range(self.start_epoch, total_epochs):
+            print()
+            print(f"Epoch {epoch+1}/{total_epochs}")
             for inputs, targets in train_loader:
                 inputs, targets = inputs.to(self.device), targets.to(self.device)
                 self.optimizer.zero_grad()
@@ -269,15 +273,19 @@
                 loss.backward()
                 self.optimizer.step()
             # After training, check validation loss
+            print("Validating...")
             val_loss = self.validate(val_loader)
+            print(f"Validation loss: {val_loss}")

+            print("Checking early stopping...")
             early_stopping(val_loss, self.net)

             if early_stopping.early_stop:
                 print("Early stopping")
                 break

             # Save checkpoint after every epoch
+            print("Saving checkpoint...")
             self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch}.pth")

     def validate(self, val_loader):
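The two lines added to _get_latest_checkpoint matter on a fresh checkout: os.listdir would raise FileNotFoundError if the checkpoint directory has never been created, so returning None lets training start from scratch instead of crashing. A standalone sketch of that lookup pattern follows; the directory path and the way the newest checkpoint is picked are illustrative assumptions, not code taken from this diff.

```python
import os
import re
from typing import Optional

CHECKPOINT_DIRECTORY = "data/checkpoints"  # assumed location, not shown in this diff


def get_latest_checkpoint() -> Optional[str]:
    """Illustration of the guarded checkpoint lookup."""
    # Guard: on a fresh clone the directory does not exist yet.
    if not os.path.exists(CHECKPOINT_DIRECTORY):
        return None
    checkpoint_files = [
        f for f in os.listdir(CHECKPOINT_DIRECTORY) if "checkpoint_epoch_" in f
    ]
    if not checkpoint_files:
        return None
    # Assume files are named checkpoint_epoch_<N>.pth and pick the highest epoch.
    latest = max(checkpoint_files, key=lambda f: int(re.search(r"\d+", f).group()))
    return os.path.join(CHECKPOINT_DIRECTORY, latest)
```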
5 changes: 4 additions & 1 deletion eiuie/main.py
@@ -4,6 +4,7 @@

 import batch_process as bp
 import base_model as bm
+import consolidate_dataset as cd
 import unsharp_masking
 import retinex
 import homomorphic_filtering
@@ -16,7 +17,7 @@ def main():
     parser.add_argument(
         "command",
         type=str,
-        choices=["single", "batch_process", "train"],
+        choices=["single", "batch_process", "prepare_dataset", "train"],
         help="Command to run",
     )

@@ -60,6 +61,8 @@ def main():
             cv2.waitKey()
         case "batch_process":
             bp.batch_process_dataset()
+        case "prepare_dataset":
+            cd.prepare_dataset()
         case "train":
             method = fusion_model.FusionModel()
             method.train_model()
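Taken together, the new main.py wiring suggests the intended workflow: run batch_process to produce the intermediate enhancement images, prepare_dataset to consolidate them into data/pixel_dataset.ds, and train to fit the fusion model, e.g. `python eiuie/main.py prepare_dataset` followed by `python eiuie/main.py train` (the exact script invocation is an assumption; only the added command choices are visible in this diff).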
19 changes: 13 additions & 6 deletions eiuie/pixel_dataset.py
@@ -1,16 +1,16 @@
 from typing import Tuple

 import numpy as np
-import pandas as pd
 import torch
 from torch.utils.data import Dataset
+import pandas as pd

-TSV_FILE = "data/pixel_dataset.tsv"
+FILE = "data/pixel_dataset.ds"


 class PixelDataset(Dataset):
     """
     PixelDataset class.
     Attributes
     ----------
     df: pd.DataFrame
@@ -20,8 +20,15 @@ class PixelDataset(Dataset):
     df: pd.DataFrame

     def __init__(self):
-        self.df = pd.read_table(TSV_FILE, header=None)
-        self.df = self.df.astype(float)
+        # Load binary data
+        with open(FILE, 'rb') as f:
+            raw_data = f.read()
+
+        # Convert binary data to a numpy array of shape (num_rows, 15)
+        data_array = np.frombuffer(raw_data, dtype=np.uint8).reshape(-1, 15)
+
+        # Convert numpy array to pandas dataframe
+        self.df = pd.DataFrame(data_array)

     def __len__(self) -> int:
         return len(self.df)
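PixelDataset.__getitem__ is collapsed in this view. A minimal sketch of how one 15-value row could map to a training sample, assuming the first 12 columns (original, unsharp, homomorphic, retinex) form the network input and the last 3 columns (ground truth) the target; the helper name and the [0, 1] scaling are assumptions for illustration, not taken from the repository:

```python
from typing import Tuple

import numpy as np
import torch


def row_to_sample(row: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper mirroring what __getitem__ presumably returns."""
    # Column order written by prepare_dataset:
    # [0:3] original, [3:6] unsharp, [6:9] homomorphic, [9:12] retinex, [12:15] ground truth
    values = torch.from_numpy(row.astype(np.float32)) / 255.0  # scaling is an assumption
    inputs = values[:12]   # original pixel plus the three enhancement outputs
    target = values[12:]   # ground-truth pixel from the corresponding "high" LOL image
    return inputs, target


# Example with one dummy 15-value row:
x, y = row_to_sample(np.arange(15, dtype=np.uint8))
print(x.shape, y.shape)  # torch.Size([12]) torch.Size([3])
```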
