In [None]:
import os
import os.path as osp
import numpy as np
import cv2
from torchvision import transforms
import torch
from skimage import io
from PIL import Image

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


In [None]:
# Directory that holds the full-scene rasters.
folder_dir = osp.join('Aerial Data', 'Collection')

# Read the radiance cube, move axis 0 to the end (presumably bands-first ->
# bands-last; confirm against the file) and drop the first 53 rows and the
# first 7 columns.
hsi_path = osp.join(folder_dir, 'image_hsi_radiance.tif')
image_hsi_rad = np.transpose(io.imread(hsi_path), (1, 2, 0))[53:, 7:, :]

# Clamp values to [0, 2**14] and rescale them to [0, 1].
image_hsi_rad = np.clip(image_hsi_rad, 0, 2**14) / 2**14

print(image_hsi_rad.shape)

In [None]:
# Colour palette of the label raster: encode_segmap maps a pixel whose colour
# equals row i of this table to class index i.
get_labels = np.array([
    [255, 0, 0],
    [0, 255, 0],
    [0, 0, 255],
    [0, 255, 255],
    [255, 127, 80],
    [153, 0, 0],
])

def encode_segmap(mask, labels=None, unlabeled=0):
    """Convert a colour-coded label image into a 2-D map of class indices.

    Args:
        mask: array of shape (H, W, C) holding one colour per pixel on the
            last axis. C must match the row length of the palette.
        labels: optional palette of shape (n_classes, C); row ``i`` is the
            colour of class ``i``. When omitted, the module-level
            ``get_labels`` table is used.
        unlabeled: value written to pixels whose colour matches no palette
            row. The default of 0 reproduces the historical behaviour, in
            which such pixels are indistinguishable from class 0; pass a
            value such as -1 to tell them apart.

    Returns:
        Integer ndarray of shape (H, W) with one class index per pixel.
    """
    palette = get_labels if labels is None else np.asarray(labels)
    mask = np.asarray(mask).astype(int)
    label_mask = np.full(mask.shape[:2], unlabeled, dtype=int)
    for ii, label in enumerate(palette):
        # Boolean (H, W) map of pixels whose every channel equals this colour.
        label_mask[np.all(mask == label, axis=-1)] = ii
    return label_mask

In [None]:
# Read the label raster, drop the first 53 rows and first 7 columns (the same
# crop applied to the HSI cube), and turn its colours into class indices.
labels = encode_segmap(
    io.imread(osp.join(folder_dir, 'image_labels.tif'))[53:, 7:, :]
)

In [None]:
# Fix the NumPy and PyTorch global RNG seeds so that the pixel sampling,
# weight initialisation and batch shuffling below are repeatable.
np.random.seed(0)
torch.manual_seed(0)

In [None]:
total_pixels = 1920 * 3968  # not referenced by the sampling below

# Pixels whose first band is exactly zero are excluded from sampling.
mask = image_hsi_rad[..., 0] != 0
nonzero_indices = np.argwhere(mask)  # one (row, col) pair per remaining pixel

# Draw up to 100000 distinct pixels. Capping the request at the number of
# candidates avoids the ValueError np.random.choice raises when asked for more
# samples than the population holds with replace=False.
n_samples = min(100000, len(nonzero_indices))
chosen_indices = nonzero_indices[np.random.choice(len(nonzero_indices), n_samples, replace=False)]
sampled_hsi_rad = image_hsi_rad[chosen_indices[:, 0], chosen_indices[:, 1]]  # shape: (n_samples, n_bands)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Oper1D(nn.Module):
    """Map each input spectrum to a (bands x bands) matrix with a zero diagonal.

    Three Conv1d layers (1 input channel -> ``bands`` output channels,
    kernel size 3, padding 1) are applied to ``x``, ``x**2`` and ``x**3``
    respectively. Their outputs are summed, passed through tanh, and the
    main diagonal is zeroed.

    Args:
        bands: number of spectral bands. It sets both the number of
            convolution output channels and the size of the diagonal mask,
            so the length of the input's last axis must equal it.
    """

    def __init__(self, bands=51):
        super().__init__()
        self.bands = bands
        self.tanh = nn.Tanh()
        # One convolution per power of the input (x, x**2, x**3).
        self.all_layers = nn.ModuleList([
            nn.Conv1d(in_channels=1, out_channels=bands,
                      kernel_size=3, padding=1)
            for _ in range(3)
        ])

    def forward(self, x):
        """Return a (batch, bands, bands) tensor for ``x`` of shape (batch, 1, bands)."""
        out1 = self.all_layers[0](x)
        out2 = self.all_layers[1](x**2)
        out3 = self.all_layers[2](x**3)

        out = self.tanh(out1 + out2 + out3)

        # Zero the main diagonal: 1 everywhere except 0 at [i, i]; the mask
        # broadcasts over the batch axis.
        off_diagonal = 1 - torch.eye(self.bands, device=out.device, dtype=out.dtype)
        return out * off_diagonal

class SLRol(nn.Module):
    """Rebuild each input from itself through the matrix produced by Oper1D.

    Args:
        n_bands: number of spectral bands; stored and passed on to Oper1D.
        q: accepted for interface compatibility; not used by this class.
    """

    def __init__(self, n_bands, q=None):
        super().__init__()
        self.n_bands = n_bands
        self.Oper1D = Oper1D(n_bands)

    def forward(self, x):
        """Return ``(reconstruction, l1_penalty)`` for a batch ``x``.

        ``reconstruction`` is the batched matrix product of ``x`` with the
        Oper1D output; ``l1_penalty`` is 0.01 times the sum of the absolute
        values of that output, taken over the whole batch.
        """
        coeffs = self.Oper1D(x)

        reconstruction = x @ coeffs
        sparsity_term = 0.01 * coeffs.abs().sum()

        return reconstruction, sparsity_term


In [None]:
# Sanity check: shape of the sampled spectra, which were drawn only from
# pixels whose first band is non-zero.
sampled_hsi_rad.shape

In [None]:
# Hyperparameters
batch_size = 10
epochs = 10
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Training data: insert a singleton channel axis, (N, 51) -> (N, 1, 51)
# (assuming sampled_hsi_rad is (N, 51)), and move the whole tensor to the
# target device up front.
x_tensor = torch.tensor(sampled_hsi_rad, dtype=torch.float32).unsqueeze(1).to(device)
dataset = TensorDataset(x_tensor)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model, optimiser and reconstruction loss
model = SLRol(n_bands=51).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Training loop
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    # Each batch from a single-tensor TensorDataset is a one-element sequence.
    for (xb,) in loader:
        xb = xb.to(device)

        optimizer.zero_grad()
        preds, l1_penalty = model(xb)

        # squeeze(-1) only removes a trailing axis of length 1, so for a
        # (B, 1, 51) batch the target keeps the same shape as the input.
        loss_recon = criterion(preds, xb.squeeze(-1))
        loss = loss_recon + l1_penalty
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean of the per-batch total loss (reconstruction + L1 term) this epoch.
    print(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss / len(loader):.6f}")


In [None]:
model.eval()

# Score computation: run every sampled spectrum through the second
# convolution of Oper1D, take absolute values, average over the sample axis,
# then sum over axis 0 of what remains (the conv output-channel axis), which
# leaves one score per position along the last axis.
# NOTE(review): Oper1D.forward feeds this layer x**2, whereas here it receives
# x directly, and the other two branches, the tanh and the diagonal mask are
# bypassed -- confirm this is the intended scoring.
with torch.no_grad():
    A = model.Oper1D.all_layers[1](x_tensor).abs().mean(dim=0).sum(dim=0)

A = A.cpu().detach().numpy()

# Ascending order: the highest-scoring positions end up last.
indices = np.argsort(A)

In [None]:
# How many bands to keep.
number_of_sampled_bands = 5

# `indices` is sorted by ascending score, so the tail holds the top-scoring
# band indices (the best one is last).
indices[-number_of_sampled_bands:]