In [None]:
# data split, model choice, GPU setting
from sklearn.model_selection import train_test_split
from compressai.zoo import cheng2020_attn
import torch
import os

# Read the list of FITS file names (one name per line).
data1024_filename = 'data1024lst.txt'

# `with` guarantees the file is closed. Blank lines are skipped because
# os.path.join(PATH, "") would point at the directory itself, not at an image.
with open(data1024_filename, 'r') as f:
    data1024copy = [line.strip() for line in f if line.strip()]
print(len(data1024copy))

torch.cuda.empty_cache()

# train:val:test = 7:1:2
# 20% is held out for test, then 1/8 of the remaining 80% (= 10% overall) for validation.
X_train, X_test = train_test_split(data1024copy, test_size=0.2, random_state=42)
X_train, X_val = train_test_split(X_train, test_size=1/8, random_state=42)

PATH = "/nas/userdata/kim_y/CME autodetection/LZ/CME_unpiped/"
MAX_INT16 = 32767

# GPU to use (index 1).
device = torch.device('cuda:1')

# quality=6 corresponds to lambda=0.0483 as used in the paper.
model = cheng2020_attn(pretrained=True, quality=6).train().to(device)

print('Device:', device)
# NOTE: current_device() reports the process-wide default CUDA device. The
# .to(device) call above does not change it, so this may print 0 even though
# `device` is cuda:1.
print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())


In [None]:
# main
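# <Assumptions>
# The paper's preprocessing byte-scales the 16-bit integer data down to 8-bit integers.
# Should the model then also be quantized to 8-bit integers? The paper is ambiguous here.
# -> For now, proceed without the byte-scaling step in preprocessing.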
from torchvision import transforms
from compressai.losses import RateDistortionLoss
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import astropy.io.fits as fits
import numpy as np

# Module-level record of the value range seen by the most recent minmax() call.
# rescale() reads it back. Both values are 0 until an image has been normalized.
log_max = log_min = 0

def cut_saturation(img_arr, threshold=3):
    """Clip saturated (very bright) pixels down to a statistical ceiling.

    The ceiling is ``mean + threshold * std`` computed over all pixels of
    ``img_arr``. Values above it are replaced by the ceiling; all other values
    are kept unchanged.

    Args:
        img_arr: numpy array of pixel values (any shape).
        threshold: Number of standard deviations above the mean to allow.

    Returns:
        A new array with the same shape as ``img_arr``.
    """
    flat = img_arr.flatten()
    ceiling = np.mean(flat) + threshold * np.std(flat)
    return np.where(img_arr > ceiling, ceiling, img_arr)

def minmax(img_arr):
    """Min-max normalize a tensor to [0, 1] and record its range globally.

    Side effect: stores ``torch.max`` / ``torch.min`` of ``img_arr`` in the
    module globals ``log_max`` / ``log_min`` so that ``rescale`` can undo the
    mapping. Every call overwrites them, so they describe only the most
    recently normalized tensor.

    Args:
        img_arr: Tensor to normalize (in this file, the log1p-transformed image).

    Returns:
        ``(img_arr - min) / (max - min)``. For a constant input an all-zeros
        tensor is returned, because the division would otherwise be 0/0 and
        produce NaN in every pixel (which would poison the training loss).
    """
    global log_max
    global log_min
    log_max = torch.max(img_arr)
    log_min = torch.min(img_arr)
    value_range = log_max - log_min
    if value_range == 0:
        return torch.zeros_like(img_arr)
    return (img_arr - log_min) / value_range

def rescale(img_arr, vmin=None, vmax=None):
    """Map a [0, 1] image back to its original range (the inverse of ``minmax``).

    By default this uses the module globals ``log_min`` / ``log_max`` written
    by the most recent ``minmax`` call. Those globals are overwritten every
    time any image is normalized (e.g. on every dataset fetch). To invert an
    image other than the last one normalized, pass that image's own range
    explicitly.

    Args:
        img_arr: Normalized image (tensor or array) with values in [0, 1].
        vmin: Lower bound of the original range; defaults to global ``log_min``.
        vmax: Upper bound of the original range; defaults to global ``log_max``.

    Returns:
        ``img_arr * (vmax - vmin) + vmin``.
    """
    lo = log_min if vmin is None else vmin
    hi = log_max if vmax is None else vmax
    return img_arr * (hi - lo) + lo

# Forward preprocessing: raw FITS array -> 3-channel [0, 1] tensor on `device`.
transform = transforms.Compose([
    # FITS data can be stored in non-native (big-endian) byte order, and
    # ToTensor raises "given numpy array has byte order different from the
    # native byte order". Casting to float32 yields a native-order copy.
    lambda arr: arr.astype(np.float32),
    transforms.ToTensor(),  # presumably a 2-D image -> (1, H, W) tensor
    transforms.Resize((512, 512)),
    torch.log1p,  # log(1 + x) conversion -> values in [log_min, log_max]
    minmax,  # -> [0, 1]; also records log_min / log_max in module globals
    # Replicate the channel 3 times for the RGB model, force float32, and move
    # the tensor to the target device.
    lambda t: t.repeat(3, 1, 1).float().to(device),
])

# Inverse of `transform`: model reconstruction -> image in the original value range.
invtransform = transforms.Compose([
    # Drop the batch axis and move channels last, e.g. (1, 3, H, W) -> (H, W, 3).
    # squeeze(0) only removes a size-1 batch axis, so this handles one image at a time.
    lambda img: img.squeeze(0).permute(1, 2, 0),
    # Average the 3 replicated channels back into one: (H, W, 3) -> (H, W).
    lambda img: img.mean(dim=2),
    # minmax mapped the log image to [0, 1] via (x - min) / (max - min), whose
    # exact inverse is `rescale` alone. Dividing by the reconstruction's own max
    # (as before) rescaled every pixel by 1/max and still did not guarantee
    # [0, 1]. Clamping only removes out-of-range values, in the same way
    # CompressAI clamps its decoded images.
    lambda img: img.clamp(0, 1),
    rescale,  # [0, 1] -> [log_min, log_max]
    torch.expm1,  # inverse of log1p: [log_min, log_max] -> [0, max]
])

# Custom Dataset that loads FITS images by file name.
class ImageDataset(Dataset):
    """Map-style dataset over FITS files located in ``img_dir``.

    Args:
        img_dir: Directory containing the FITS files.
        data_name_lst: File names relative to ``img_dir``. Item ``idx`` is read
            from ``os.path.join(img_dir, data_name_lst[idx])``.
        transform: Optional callable applied to the primary-HDU data array.
    """

    def __init__(self, img_dir, data_name_lst, transform=None):
        self.img_dir = img_dir
        self.img_labels = data_name_lst  # file names (not class labels)
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels[idx])
        # Close the file deterministically instead of relying on garbage
        # collection of a temporary HDUList. Copying detaches the array from a
        # possibly memory-mapped file so it stays valid after closing.
        with fits.open(img_path) as hdul:
            data = hdul[0].data
            image = None if data is None else data.copy()
        if self.transform:
            image = self.transform(image)
        return image
    

# Hyperparameters and training state
batch_size = 4
learning_rate = 1e-4
patience = 15  # max consecutive epochs without validation improvement
j = 0  # epochs since the last improvement (current patience)
epoch = 0
criterion = RateDistortionLoss(lmbda=0.0483)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
# Keep at 0 while `transform` moves tensors to the GPU inside __getitem__: the
# CUDA runtime does not support the fork start method used for worker processes.
num_workers = 0

# Datasets and loaders (train is shuffled; val/test keep a fixed order)
image_trainset = ImageDataset(PATH, X_train, transform=transform)
train_loader = DataLoader(image_trainset, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers)

image_valset = ImageDataset(PATH, X_val, transform=transform)
val_loader = DataLoader(image_valset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers)

image_testset = ImageDataset(PATH, X_test, transform=transform)
# Default batch_size of 1: one image per batch, matching invtransform.
test_loader = DataLoader(image_testset, shuffle=False, num_workers=num_workers)

train_loss_lst = []
val_loss_lst = []

# Best validation loss seen so far. It must exist before the first comparison
# in the early-stopping check below.
best_val_loss = float("inf")

num_epochs = 100
for epoch in range(num_epochs):
    model.train()  # training mode
    train_loss = 0.0
    for i, inputs in enumerate(train_loader):
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, inputs)

        # Backward pass and parameter update
        # NOTE(review): CompressAI's example training also minimizes
        # model.aux_loss() with a separate optimizer; not done here — confirm.
        optimizer.zero_grad()
        loss['loss'].backward()
        optimizer.step()
        # .item() converts to a Python float. Accumulating the tensor itself
        # would keep every iteration's autograd graph alive and grow memory.
        train_loss += loss['loss'].item() * inputs.size(0)
        if i % 10 == 0:
            # i is a batch index and len(train_loader) is the number of
            # batches, so no batch_size factor belongs in the ratio.
            print("epoch", epoch + 1, "/ iteration:", i, ", ratio:",
                  i / len(train_loader) * 100, "%")
    train_loss /= len(X_train)  # per-sample mean over the epoch
    train_loss_lst.append(train_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}')

    # Validation loop
    model.eval()  # evaluation mode
    val_loss = 0.0
    with torch.no_grad():
        for inputs_val in val_loader:
            outputs_val = model(inputs_val)
            loss_val = criterion(outputs_val, inputs_val)
            val_loss += loss_val['loss'].item() * inputs_val.size(0)
    val_loss /= len(X_val)
    val_loss_lst.append(val_loss)
    print(f'Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}')

    # Early stopping: checkpoint on improvement, stop after `patience` misses.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        j = 0
        torch.save(model.state_dict(), 'best_model.pth')  # save the best model
    else:
        j += 1
        if j >= patience:
            print(f'Early stopping triggered at epoch {epoch+1}')
            break