In [2]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import spectral

from core.name_convention import *
import core.fruit_list as fruit_list
from core.datasets.hyperspectral_dataset import get_records
from core.datasets.hyperspectral_dataset import HyperspectralDataset, prepare_fruit


In [None]:
def main(classification_type, fruit, camera_type, data_path, type):
    """Load resized hyperspectral cubes for one fruit / camera and return train & test inputs.

    Records are obtained from ``get_records`` (with inter-ripeness levels); the
    validation split is merged into the training split. Each record is loaded
    with ``is_already_referenced=True`` and resized to 128x128 (bicubic).

    Args:
        classification_type: A ``ClassificationType`` (RIPENESS, FIRMNESS or SUGAR);
            selects which label encoder is applied to each record.
        fruit: Fruit selector passed to ``get_records`` (e.g. ``Fruit.KIWI``).
        camera_type: Camera selector passed to ``get_records`` (e.g. ``CameraType.VIS``).
        data_path: Root directory handed to ``FruitRecord.load``.
        type: Debug switch. When ``"B"``, every training cube whose encoded label
            is 2 has its shape and label printed and is shown with ``spectral.imshow``.
            (The name shadows the builtin but is kept for caller compatibility.)

    Returns:
        ``(X_train, X_test)``: stacked cube arrays. Labels are computed but not returned.

    Raises:
        ValueError: If ``classification_type`` is not one of the three supported types.
    """
    train_records, val_records, test_records = get_records(fruit, camera_type, classification_type,
                                                           use_inter_ripeness_levels=True)

    # No separate validation split is used here: fold it into training.
    train_records = np.concatenate([train_records, val_records])

    def load_fruit(r: FruitRecord):
        """Load one record as a 128x128 cube and return ``(cube, encoded_label)``."""
        _header, _data = r.load(data_path, is_already_referenced=True)
        # NOTE(review): assumes _data is (H, W, bands) so only the spatial axes are resized — confirm.
        # cv2.resize already returns a fresh ndarray, so no extra copy is needed.
        _d = cv2.resize(_data, dsize=(128, 128), interpolation=cv2.INTER_CUBIC)

        if classification_type == ClassificationType.RIPENESS:
            return _d, ripeness2int(r.label.ripeness_state)
        if classification_type == ClassificationType.FIRMNESS:
            return _d, firmness2int(r.label.get_firmness_level())
        if classification_type == ClassificationType.SUGAR:
            return _d, sugar2int(r.label.get_sugar_level())
        # Falling through used to return None, which failed later with an opaque
        # "cannot unpack NoneType" error in get_dataset.
        raise ValueError(f"Unsupported classification_type: {classification_type!r}")

    def get_dataset(records):
        """Load every record and stack the cubes and labels into two arrays."""
        samples = [load_fruit(_r) for _r in records]
        X = np.stack([x for x, _ in samples])
        Y = np.stack([y for _, y in samples])
        return X, Y

    X_train, Y_train_a = get_dataset(train_records)
    # Test labels are computed alongside the cubes but are not part of the return value.
    X_test, _ = get_dataset(test_records)

    # Debug visualisation; the mode check does not depend on the sample, so test it once.
    if type == "B":
        for cube, label in zip(X_train, Y_train_a):
            if label == 2:
                print(cube.shape)
                spectral.imshow(cube)
                print(label)

    return X_train, X_test



In [None]:

# Load 128x128 kiwi cubes for the sugar task from both cameras. The "B" run
# additionally prints/shows each training cube whose encoded label is 2.
train_V, test_V = main(ClassificationType.SUGAR, Fruit.KIWI, CameraType.VIS, "Data/", "A")
train_N, test_N = main(ClassificationType.SUGAR, Fruit.KIWI, CameraType.NIR, "Data/", "B")

# Show every band of the first VIS training cube as a grayscale tile in a grid.
sample_cube = train_V[0]
num_bands = sample_cube.shape[2]
num_images_per_row = 10  # grid width; adjust to taste

# Ceiling division: just enough rows to hold every band.
num_rows = -(-num_bands // num_images_per_row)

# squeeze=False keeps `axes` 2-D even when num_rows == 1; without it, axes[row, col]
# raises IndexError whenever the cube has <= num_images_per_row bands.
fig, axes = plt.subplots(num_rows, num_images_per_row, figsize=(20, 4 * num_rows), squeeze=False)

# axes.flat walks the grid row-major, so tile i sits at (i // width, i % width).
for i, ax in enumerate(axes.flat):
    if i < num_bands:
        ax.imshow(sample_cube[:, :, i], cmap='gray')
        ax.set_title(f'Band {i+1}')
    # Hide axes on every tile, including the unused trailing ones.
    ax.axis('off')

plt.tight_layout()
plt.show()




In [None]:
import spectral
import numpy as np
import matplotlib.pyplot as plt
import spectral.io.envi as envi
from mpl_toolkits.mplot3d import Axes3D  # registers the "3d" projection on older matplotlib versions
from core.util import display_hyper_spectral_data,display_all_bands,plot_3d_data

# A single NIR capture used for visual inspection.
SAMPLE_HDR = "Data/Kiwi/NIR/day_07/kiwi_day_07_04_front.hdr"
SAMPLE_BIN = "Data/Kiwi/NIR/day_07/kiwi_day_07_04_front.bin"
# Keep every n-th pixel along each spatial axis in the 3D plot; 1 plots every pixel.
POINT_STRIDE = 1

envi_header = envi.open(SAMPLE_HDR, image=SAMPLE_BIN)
cube = envi_header.load()
display_hyper_spectral_data(cube)
print(cube.shape[2])

# Display the RGB visualization of the hyperspectral image
spectral.imshow(cube)

# 3D point cloud: one scatter layer per band, placed at height z = band index and
# coloured by that band's values. One 2-D (row, col) grid is shared by all bands,
# instead of materialising three full-cube-sized index arrays.
grid_x, grid_y = np.meshgrid(np.arange(0, cube.shape[1], POINT_STRIDE),
                             np.arange(0, cube.shape[0], POINT_STRIDE))
grid_x, grid_y = grid_x.ravel(), grid_y.ravel()

fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection='3d')

for i in range(cube.shape[2]):
    band = np.asarray(cube[::POINT_STRIDE, ::POINT_STRIDE, i]).ravel()
    # No vmin/vmax is passed, so each band's colours are scaled to that band's own range.
    ax.scatter(grid_x, grid_y, np.full(grid_x.shape, i), c=band, cmap='viridis')

# Keep the axes visible: calling ax.set_axis_off() would hide these labels.
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Band')
ax.set_title('Hyperspectral cube (colour = band value)')

plt.show()

### Example output after running for about 120 epochs

![Example output after about 120 epochs](https://imgur.com/JfOx37U.png)