library

In [23]:
import argparse
import os
from PIL import Image
from tqdm.notebook import tqdm
import pickle

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"

parameter & utils

In [302]:
parser = argparse.ArgumentParser()

parser.add_argument('--sampling_frequency',type=int, default=20000)
parser.add_argument('--lr_length',type=int, default=1000, help='length of low-resolution signal')
parser.add_argument('--hr_length',type=int, default=10000)
parser.add_argument('--train_x_img_dir',type=str, default='./data/train/x')
parser.add_argument('--train_y_img_dir',type=str, default='./data/train/y')
parser.add_argument('--test_x_img_dir',type=str, default='./data/test/x')
parser.add_argument('--test_y_img_dir',type=str, default='./data/test/y')
parser.add_argument('--valid_x_img_dir',type=str, default='./data/valid/x')
parser.add_argument('--valid_y_img_dir',type=str, default='./data/valid/y')
parser.add_argument('--n_samples',type=int, default=100)
parser.add_argument('--patch_size',type=int, default=128)
parser.add_argument('--stride',type=int, default=96)

parser.add_argument('--epochs',type=int, default=50)
parser.add_argument('--lr',type=float, default=1e-4)
parser.add_argument('--early_stop',type=int, default=20, help='early stop_patience')
parser.add_argument('--batch_size',type=int, default=16)
#parser.add_argument('--train_random_seed',type=int, default=42)
#parser.add_argument('--test_random_seed',type=int, default=43)

opt = parser.parse_args('')

data preprocssing

In [303]:
class FFTProcessor():
    
    '''signal to FFT image'''
    
    def __init__(self, sampling_frequency, img_x_dir, img_y_dir):
        self.fs = sampling_frequency
        self.img_x_dir = img_x_dir
        self.img_y_dir = img_y_dir
        #self.random_seed = random_seed
        
    def generate_random_signal(self, length):
        t = np.linspace(0, length, int(self.fs), endpoint=False)
        x = np.zeros_like(t)
        
        num_components = np.random.randint(20, 50)  # 2~4
        for _ in range(num_components):
            A = np.random.uniform(0.5, 3.0)            # 진폭
            f = np.random.uniform(1, 20)               # 주파수 (Hz)
            phi = np.random.uniform(0, 2*np.pi)        # 위상
            x += A * np.sin(2 * np.pi * f * t + phi)
        return x

    def process_batch(self, n_samples, patch_size, stride):
        for sample_idx in range(n_samples):
            #np.random.seed(self.random_seed)
            x = self.generate_random_signal(opt.lr_length)  # 이렇게 opt. 쓰는 것?? 별론가
            y = self.generate_random_signal(opt.hr_length)

            X = np.fft.fft(x)
            Y = np.fft.fft(y)
            freqs = np.fft.fftfreq(len(X), d=1/self.fs)
            mag_X = np.abs(X)* 2 / len(X)
            mag_Y = np.abs(Y)* 2 / len(Y)

        
            half = int(len(X)/2)
            freq_mag_X = np.column_stack((freqs[:half], mag_X[:half]))
            freq_mag_Y = np.column_stack((freqs[:half], mag_Y[:half]))
            
            save_img(freq_mag_X, self.img_x_dir, sample_idx)
            save_img(freq_mag_Y, self.img_y_dir, sample_idx)
            
            save_patch(f"{self.img_x_dir}/fft_{sample_idx:03d}.png", self.img_x_dir, sample_idx, patch_size, stride)
            save_patch(f"{self.img_y_dir}/fft_{sample_idx:03d}.png", self.img_y_dir, sample_idx, patch_size, stride)


def save_img(x, img_dir, idx):
    
    plt.plot(x[:, 0], x[:, 1])
    plt.axis('off')     # 축 숨기기 (x, y 모두)
    plt.gca().spines['top'].set_visible(False)    # 테두리(스파인) 숨기기
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['bottom'].set_visible(False)
    plt.gca().spines['left'].set_visible(False)
    plt.savefig(f"{img_dir}/fft_{idx:03d}.png", bbox_inches = 'tight', transparent = True)
    plt.close()
    
def save_patch(x, img_dir, sample_idx, patch_size, stride):
    patch_idx = 0
    img = Image.open(x)
    img = np.array(img)
    for i in range(0, img.shape[0] - patch_size + 1, stride):
        for j in range(0, img.shape[1] - patch_size + 1, stride):
            patch = img[i:i+patch_size, j:j+patch_size]
            plt.imsave(f"{img_dir}/fft_{sample_idx:03d}_{patch_idx:02d}.png", patch, cmap='gray')
            patch_idx += 1

In [304]:
FFTProcessor(opt.sampling_frequency, opt.train_x_img_dir, opt.train_y_img_dir).process_batch(opt.n_samples, opt.patch_size, opt.stride)
FFTProcessor(opt.sampling_frequency, opt.test_x_img_dir, opt.test_y_img_dir).process_batch(opt.n_samples, opt.patch_size, opt.stride)
FFTProcessor(opt.sampling_frequency, opt.valid_x_img_dir, opt.valid_y_img_dir).process_batch(opt.n_samples, opt.patch_size, opt.stride)

data load

In [306]:
class make_dataset():
    def __init__(self, train_x, train_y, test_x, test_y, valid_x, valid_y):
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y
        self.valid_x = valid_x
        self.valid_y = valid_y
        
    def main(self):
        self.to_tensor()
        return self.train_dl, self.test_dl, self.valid_dl
                
    def to_tensor(self):
        # To tensor
        train_x = torch.tensor(self.train_x)
        train_y = torch.tensor(self.train_y)
        test_x = torch.tensor(self.test_x)
        test_y = torch.tensor(self.test_y)
        valid_x = torch.tensor(self.valid_x)
        valid_y = torch.tensor(self.valid_y)

        train = MyDataset(train_x, train_y)
        test = MyDataset(test_x, test_y)
        valid = MyDataset(valid_x, valid_y)

        self.train_dl = DataLoader(train, opt.batch_size, shuffle=True)
        self.test_dl = DataLoader(test, opt.batch_size, shuffle=False)
        self.valid_dl = DataLoader(valid, opt.batch_size, shuffle=False)
    
class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = x.float()
        self.y = y.float()
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        x = self.x[idx]  # [H, W]
        y = self.y[idx]
        x = x.unsqueeze(0)  # → [1, H, W]
        y = y.unsqueeze(0)
        
        return x.to(device), y.to(device)

In [307]:
train_x = []
for filename in os.listdir('./data/train/x'):
    if filename.count('_') != 1:
    #if filename.count('_') == 1:
        filename = filename
        img_path = os.path.join('./data/train/x', filename)
        img = Image.open(img_path).convert('L')
        img_array = np.array(img)
        train_x.append(img_array)
    
train_y = []
for filename in os.listdir('./data/train/y'):
    if filename.count('_') != 1:
        filename = filename
        img_path = os.path.join('./data/train/y', filename)
        img = Image.open(img_path).convert('L')
        img_array = np.array(img)
        train_y.append(img_array)
    
test_x = []
for filename in os.listdir('./data/test/x'):
    if filename.count('_') != 1:
    #if filename.count('_') == 1:
        filename = filename
        img_path = os.path.join('./data/test/x', filename)
        img = Image.open(img_path).convert('L')
        img_array = np.array(img)
        test_x.append(img_array)
    
test_y = []
for filename in os.listdir('./data/test/y'):
    if filename.count('_') != 1:
        img_path = os.path.join('./data/test/y', filename)
        img = Image.open(img_path).convert('L')
        img_array = np.array(img)
        test_y.append(img_array)

valid_x = []
for filename in os.listdir('./data/valid/x'):
    if filename.count('_') != 1:
    #if filename.count('_') == 1:
        filename = filename
        img_path = os.path.join('./data/valid/x', filename)
        img = Image.open(img_path).convert('L')
        img_array = np.array(img)
        valid_x.append(img_array)
    
valid_y = []
for filename in os.listdir('./data/valid/y'):
    if filename.count('_') != 1:
        img_path = os.path.join('./data/valid/y', filename)
        img = Image.open(img_path).convert('L')
        img_array = np.array(img)
        valid_y.append(img_array)

In [308]:
train_dl, test_dl, valid_dl = make_dataset(train_x, train_y, test_x, test_y, valid_x, valid_y).main()

model

In [309]:
class SRCNN(nn.Module):
    def __init__(self, num_channels=1): #흑백
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding = 9//2) #(H, W) 
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding = 5//2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding = 5//2)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x=self.relu(self.conv1(x))
        x=self.relu(self.conv2(x))
        x=self.conv3(x)
        return x

train

In [310]:
def train(model, train_dl, valid_dl, optimizer, loss_fn, device, num_epochs):
    model = model.to(device)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        train_bar = tqdm(train_dl, desc="Training", leave=False)

        for inputs, targets in train_bar:
            inputs = inputs.to(device).float()
            targets = targets.to(device).float()

            # Forward
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            train_bar.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_dl)
        print(f"Train Loss: {avg_train_loss:.4f}")

        # Validation
        if valid_dl:
            model.eval()
            val_loss = 0.0
            val_bar = tqdm(valid_dl, desc="Validating", leave=False)
            with torch.no_grad():
                for inputs, targets in val_bar:
                    inputs = inputs.to(device)
                    targets = targets.to(device)

                    outputs = model(inputs)
                    loss = loss_fn(outputs, targets)
                    val_loss += loss.item()
                    val_bar.set_postfix(loss=loss.item())

            avg_val_loss = val_loss / len(valid_dl)
            print(f"Valid Loss: {avg_val_loss:.4f}")

In [None]:
model = SRCNN()
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train(model, train_dl, valid_dl, optimizer, loss_fn, device, opt.epochs)
#3669/1947 -> 1947/1920 (epoch 50): 별 의미 x

test

In [335]:
def test(model, test_dl, device, loss_fn):
    model.eval()
    model.to(device)
    
    total_loss = 0
    n = 0
    with torch.no_grad():
        for inputs, targets in test_dl:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            total_loss += loss.item()
            n += inputs.size(0)
            
    print(f"total loss: {total_loss/n:.2f}")


In [334]:
model = SRCNN()
#model.load_state_dict(torch.load('best.pth', map_location='cpu'))

test(model, test_dl, device, loss_fn)


total loss: 3826.01
