In [8]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
from sklearn.preprocessing import MinMaxScaler
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch
import matplotlib.pyplot as plt

In [9]:
# Device selection: use the GPU when CUDA is available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

Using device: cpu


In [10]:
# Dataset construction
class SequenceDataset(Dataset):
    """Autoencoder-style dataset yielding ``(window, window)`` float32 tensor pairs.

    ``data`` may be either:
      * 2-D ``(n_rows, n_features)``: windows of ``seq_len`` consecutive rows
        are cut on the fly, and ``seq_len`` is required; or
      * 3-D ``(n_windows, window_len, n_features)``: already windowed, e.g. the
        output of ``data_shape``. Each entry along axis 0 is one sample, and
        ``seq_len``, if given, must equal ``window_len``.

    Raises:
        ValueError: if ``seq_len`` is missing/invalid for 2-D data, or does not
            match the window length of 3-D data.
    """

    def __init__(self, data, seq_len=None):
        self.data = data
        self.seq_len = seq_len
        # Pre-windowed input must not be windowed a second time. Doing so
        # produces 4-D samples that the LSTM models cannot consume.
        self._prewindowed = np.ndim(data) == 3

        if self._prewindowed:
            window_len = np.shape(data)[1]
            if seq_len is not None and seq_len != window_len:
                raise ValueError(
                    f"seq_len={seq_len} does not match the window length "
                    f"{window_len} of the pre-windowed data"
                )
        elif seq_len is None:
            raise ValueError("seq_len is required when data is not already windowed")
        elif seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    def __len__(self):
        if self._prewindowed:
            return len(self.data)
        # Clamp at 0 so data shorter than one window gives an empty dataset
        # instead of a negative length.
        return max(0, len(self.data) - self.seq_len + 1)

    def __getitem__(self, idx):
        n_items = len(self)
        if idx < 0:
            idx += n_items
        # Explicit bounds check: without it, out-of-range indices returned
        # short or empty windows, and plain iteration never terminated.
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} out of range for dataset of length {n_items}")

        if self._prewindowed:
            window = self.data[idx]
        else:
            window = self.data[idx:idx + self.seq_len]
        # Input and target are the same window.
        return torch.tensor(window, dtype=torch.float32), torch.tensor(window, dtype=torch.float32)

In [11]:
# Load & preprocess
def data_load(path):
    """Read a CSV file and return its rows ordered by the ``dates`` column.

    Args:
        path: path to a CSV file that contains a ``dates`` column.

    Returns:
        A new DataFrame whose ``dates`` column is parsed with ``pd.to_datetime``,
        sorted ascending by date, with a fresh 0..n-1 index.
    """
    raw = pd.read_csv(path)
    return (
        raw.assign(dates=pd.to_datetime(raw['dates']))
        .sort_values('dates')
        .reset_index(drop=True)
    )

def data_sclaer(df,col):
    """Min-max scale one column of ``df`` into [0, 1].

    Args:
        df: input DataFrame; it is not modified.
        col: name of the column to scale.

    Returns:
        A tuple ``(scaled_copy, scaler)``: a copy of ``df`` with ``col``
        replaced by its scaled values, and the fitted ``MinMaxScaler``
        (presumably kept so model outputs can be inverse-transformed later).
    """
    fitted_scaler = MinMaxScaler()
    scaled_values = fitted_scaler.fit_transform(df[[col]])

    result = df.copy()
    result[col] = scaled_values.ravel()
    return result, fitted_scaler

# Reshape training data into fixed-length model input windows
def data_shape(data,time_steps,col_list):
    '''
    Cut ``data`` into overlapping windows of ``time_steps`` consecutive rows.

    Args:
        data: DataFrame holding the rows in time order.
        time_steps: window length (rows per window), must be >= 1.
        col_list: feature columns to keep, in output order.

    Returns:
        ndarray of shape (n_windows, time_steps, len(col_list)), where
        n_windows = max(0, len(data) - time_steps + 1) and window i covers
        rows i .. i + time_steps - 1 by *position*.

    Raises:
        ValueError: if time_steps < 1.
    '''
    if time_steps < 1:
        raise ValueError(f"time_steps must be >= 1, got {time_steps}")

    # Positional extraction: the previous label-based .loc slicing silently
    # produced wrong windows whenever the index was not a 0..n-1 RangeIndex.
    values = data[list(col_list)].to_numpy()
    n_windows = len(values) - time_steps + 1

    if n_windows <= 0:
        # Keep the 3-D shape even when there are too few rows for one window.
        windows = np.empty((0, time_steps, values.shape[1]), dtype=values.dtype)
    else:
        # sliding_window_view gives (n_windows, n_features, time_steps);
        # transpose to (n_windows, time_steps, n_features) and materialise a
        # contiguous copy so the result does not alias `values`.
        windows = np.lib.stride_tricks.sliding_window_view(values, time_steps, axis=0)
        windows = np.ascontiguousarray(windows.transpose(0, 2, 1))

    print(windows.shape)

    return windows

In [None]:
# MTad-GAN model definitions

# Encoder
class Encoder(nn.Module):
    """Compress a window of feature vectors into one latent vector.

    An LSTM (batch_first) reads the window, and a linear layer maps the final
    layer's last hidden state to ``latent_dim`` values.
    """
    def __init__(self, input_dim=3, latent_dim=24, hidden_dim=16):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Return the latent code, shape (batch, latent_dim), for input ``x``."""
        _, (h_n, _) = self.lstm(x)
        last_hidden = h_n[-1]
        return self.fc(last_hidden)

# Generator
class Generator(nn.Module):
    """Decode a latent vector into a window of ``time_steps`` feature vectors.

    The latent code is repeated along a new time axis and fed to an LSTM
    (batch_first). A linear layer projects each step's output to
    ``output_dim`` features.
    """
    def __init__(self, latent_dim=24, time_steps=120, output_dim=3, hidden_dim=16):
        super(Generator, self).__init__()
        self.repeat = time_steps  # number of output time steps
        self.lstm = nn.LSTM(latent_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        """Map ``z`` of shape (batch, latent_dim) to (batch, time_steps, output_dim)."""
        # (batch, latent_dim) -> (batch, time_steps, latent_dim)
        z_seq = z.unsqueeze(1).repeat(1, self.repeat, 1)
        # Bug fix: the LSTM must consume the repeated latent sequence. The
        # original passed the not-yet-assigned local `x`, which raised
        # UnboundLocalError.
        x, _ = self.lstm(z_seq)
        out = self.fc(x)
        return out

# Critic
class CriticX(nn.Module):
    """Score a whole window with one real value per sample.

    An LSTM (batch_first) summarises the window, and its final layer's last
    hidden state is mapped by a linear layer to a single score.
    """
    def __init__(self, input_dim=3, hidden_dim=16):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return scores of shape (batch, 1) for the windows in ``x``."""
        _, (h_n, _) = self.lstm(x)
        summary = h_n[-1]
        return self.fc(summary)

class CriticZ(nn.Module):
    """Score a latent vector with one real value.

    Two-layer MLP: Linear(latent_dim -> hidden_dim), LeakyReLU(0.2),
    Linear(hidden_dim -> 1).
    """
    def __init__(self, latent_dim=24, hidden_dim=16):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)
        self.leakyrelu = nn.LeakyReLU(0.2)

    def forward(self, z):
        """Return scores of shape (batch, 1) for latent codes ``z``."""
        hidden = self.fc1(z)
        hidden = self.leakyrelu(hidden)
        return self.fc2(hidden)

In [None]:
train_col = ['rn', 'vl', 'wl']   # feature columns fed to the model
time_steps = 120                 # window length (rows per sample)
latent_dim = 24                  # size of the encoder's latent vector
hidden_dim = 16                  # hidden width of the LSTMs / CriticZ MLP
num_features = len(train_col)    # derived so it cannot drift from train_col
batch_size = 64
epochs = 50
critic_iter = 5                  # critic updates per Encoder/Generator update
lambda_z = 0.1                   # weight of the Critic_z term in the EG loss

# Seed torch so weight init, DataLoader shuffling and prior samples are
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Load the raw training CSVs for both stations (sorted by date by data_load).
gwangjoo_train_df = data_load('../new_data_set/train_2920010001045020.csv')
changwon_train_df = data_load('../new_data_set/train_4812110001018020.csv')

# Min-max scale each feature column independently, keeping one fitted scaler
# per column and station. The scaled frames get their own stage-suffixed
# names, so they stay available after windowing.
gwangjoo_train_scaled, g_rn_sclaer = data_sclaer(gwangjoo_train_df, 'rn')
gwangjoo_train_scaled, g_vl_sclaer = data_sclaer(gwangjoo_train_scaled, 'vl')
gwangjoo_train_scaled, g_wl_sclaer = data_sclaer(gwangjoo_train_scaled, 'wl')
changwon_train_scaled, c_rn_sclaer = data_sclaer(changwon_train_df, 'rn')
changwon_train_scaled, c_vl_sclaer = data_sclaer(changwon_train_scaled, 'vl')
changwon_train_scaled, c_wl_sclaer = data_sclaer(changwon_train_scaled, 'wl')

# Cut the scaled frames into (row, time_steps, feature) window arrays.
gwangjoo_train_data = data_shape(gwangjoo_train_scaled, time_steps, train_col)
changwon_train_data = data_shape(changwon_train_scaled, time_steps, train_col)

In [None]:
# Bug fix: SequenceDataset requires seq_len. Pass the same window length as
# the Changwon loader below.
train_dataset = SequenceDataset(gwangjoo_train_data, seq_len=time_steps)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

In [None]:
# Reuse the device chosen in the setup cell instead of re-detecting it here,
# so both names always agree.
device = DEVICE
encoder = Encoder(num_features, latent_dim, hidden_dim).to(device)
generator = Generator(latent_dim, time_steps, num_features, hidden_dim).to(device)
critic_x = CriticX(num_features, hidden_dim).to(device)
critic_z = CriticZ(latent_dim, hidden_dim).to(device)

In [None]:
# Adam settings shared by all three optimizers.
adam_kwargs = {"lr": 0.0002, "betas": (0.5, 0.9)}
# Encoder and Generator parameters are optimised jointly.
optimizer_EG = optim.Adam([*encoder.parameters(), *generator.parameters()], **adam_kwargs)
optimizer_Cx = optim.Adam(critic_x.parameters(), **adam_kwargs)
optimizer_Cz = optim.Adam(critic_z.parameters(), **adam_kwargs)

In [None]:
# Training setup: per-epoch loss history for pre-training, plus tracking of
# the best checkpoint seen so far.
critic_losses_x, critic_losses_z, generator_losses = [], [], []

best_loss, best_epoch = float('inf'), -1

In [None]:
# Adversarial pre-training on the Gwangjoo windows.
# Each batch runs `critic_iter` Critic_x updates, then `critic_iter` Critic_z
# updates, then one joint Encoder+Generator update. Losses are averaged per epoch.
# NOTE(review): neither critic has a Lipschitz constraint (no gradient penalty
# or weight clipping), and the Encoder/Generator loss has no reconstruction
# term. Confirm this simplification of TadGAN is intended.
for epoch in range(epochs):
    epoch_loss_cx, epoch_loss_cz, epoch_loss_eg = [], [], []

    for batch in train_loader:
        # SequenceDataset yields (input, target) pairs, which the DataLoader
        # collates into a list. Only the input window is needed here.
        real_series = batch[0] if isinstance(batch, (list, tuple)) else batch
        real_series = real_series.to(device)

        # The critic optimizers never touch the Encoder/Generator weights, so
        # their outputs are the same for every critic iteration. Compute them
        # once, without building an autograd graph.
        with torch.no_grad():
            z_encoded_fixed = encoder(real_series)
            generated_fixed = generator(z_encoded_fixed)

        # Critic_x: push scores of real windows up and generated windows down.
        for _ in range(critic_iter):
            critic_real_x = critic_x(real_series)
            critic_fake_x = critic_x(generated_fixed)
            loss_cx = -(torch.mean(critic_real_x) - torch.mean(critic_fake_x))

            optimizer_Cx.zero_grad()
            loss_cx.backward()
            optimizer_Cx.step()

        # Critic_z: push scores of N(0, I) samples up and encoded latents down.
        for _ in range(critic_iter):
            z_real = torch.randn(real_series.size(0), latent_dim, device=device)

            critic_real_z = critic_z(z_real)
            critic_fake_z = critic_z(z_encoded_fixed)
            loss_cz = -(torch.mean(critic_real_z) - torch.mean(critic_fake_z))

            optimizer_Cz.zero_grad()
            loss_cz.backward()
            optimizer_Cz.step()

        # Encoder + Generator: maximise both critics' scores on their outputs.
        z_encoded = encoder(real_series)
        generated_series = generator(z_encoded)
        critic_fake_x = critic_x(generated_series)
        critic_fake_z = critic_z(z_encoded)

        loss_eg = -torch.mean(critic_fake_x) - lambda_z * torch.mean(critic_fake_z)

        optimizer_EG.zero_grad()
        loss_eg.backward()
        optimizer_EG.step()

        epoch_loss_cx.append(loss_cx.item())
        epoch_loss_cz.append(loss_cz.item())
        epoch_loss_eg.append(loss_eg.item())

    if not epoch_loss_eg:
        raise RuntimeError("train_loader yielded no batches; check the training data.")

    # Record the epoch mean rather than the noisy last-batch value.
    mean_cx = sum(epoch_loss_cx) / len(epoch_loss_cx)
    mean_cz = sum(epoch_loss_cz) / len(epoch_loss_cz)
    mean_eg = sum(epoch_loss_eg) / len(epoch_loss_eg)
    critic_losses_x.append(mean_cx)
    critic_losses_z.append(mean_cz)
    generator_losses.append(mean_eg)

    # Keep the checkpoint with the lowest mean Encoder/Generator loss.
    # NOTE(review): an adversarial loss is only a rough proxy for model
    # quality. Consider selecting on a reconstruction/validation metric.
    if mean_eg < best_loss:
        best_loss = mean_eg
        best_epoch = epoch
        torch.save({
            'encoder_state_dict': encoder.state_dict(),
            'generator_state_dict': generator.state_dict(),
            'critic_x_state_dict': critic_x.state_dict(),
            'critic_z_state_dict': critic_z.state_dict(),
            'epoch': epoch,
            'loss': best_loss
        }, "best_tadgan_criticxz_model.pt")
        print(f"[Saved best model at epoch {epoch} | Loss: {best_loss:.4f}]")

    if epoch % 5 == 0:
        print(f"[{epoch}] Critic_x: {mean_cx:.4f}, Critic_z: {mean_cz:.4f}, EG: {mean_eg:.4f}")

In [None]:
# Pre-training loss curves (one point per recorded epoch).
# x-values come from each list's own length rather than `epochs`, so the plot
# still works after an interrupted or repeated training run.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(len(critic_losses_x)), critic_losses_x, label="Critic_x Loss")
ax.plot(range(len(critic_losses_z)), critic_losses_z, label="Critic_z Loss")
ax.plot(range(len(generator_losses)), generator_losses, label="Generator Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Gwangjoo_TadGAN Loss Curves (Critic_x + Critic_z)")
ax.legend()
ax.grid(True)
fig.savefig(f'mtad_gan_{time_steps}_loss_gwangjoo.png')
plt.show()

In [None]:
# Point the loader at the Changwon windows for fine-tuning.
train_dataset = SequenceDataset(changwon_train_data, seq_len=time_steps)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

In [None]:
# Fine-tune the pre-trained models on the Changwon windows, reusing the
# optimizers (and their state) from pre-training. The update schedule matches
# pre-training, and losses are averaged per epoch.
critic_losses_x_finetune = []
critic_losses_z_finetune = []
generator_losses_finetune = []

# Bug fix: restart best-checkpoint tracking. Carrying over the pre-training
# best_loss, which was computed on a different dataset, meant the fine-tuned
# checkpoint was only written if it happened to beat that value.
best_loss = float('inf')
best_epoch = -1

for epoch in range(epochs):
    epoch_loss_cx, epoch_loss_cz, epoch_loss_eg = [], [], []

    for batch in train_loader:
        # SequenceDataset yields (input, target) pairs, which the DataLoader
        # collates into a list. Only the input window is needed here.
        real_series = batch[0] if isinstance(batch, (list, tuple)) else batch
        real_series = real_series.to(device)

        # The critic optimizers never touch the Encoder/Generator weights, so
        # their outputs are the same for every critic iteration. Compute them
        # once, without building an autograd graph.
        with torch.no_grad():
            z_encoded_fixed = encoder(real_series)
            generated_fixed = generator(z_encoded_fixed)

        # Critic_x: push scores of real windows up and generated windows down.
        for _ in range(critic_iter):
            critic_real_x = critic_x(real_series)
            critic_fake_x = critic_x(generated_fixed)
            loss_cx = -(torch.mean(critic_real_x) - torch.mean(critic_fake_x))

            optimizer_Cx.zero_grad()
            loss_cx.backward()
            optimizer_Cx.step()

        # Critic_z: push scores of N(0, I) samples up and encoded latents down.
        for _ in range(critic_iter):
            z_real = torch.randn(real_series.size(0), latent_dim, device=device)

            critic_real_z = critic_z(z_real)
            critic_fake_z = critic_z(z_encoded_fixed)
            loss_cz = -(torch.mean(critic_real_z) - torch.mean(critic_fake_z))

            optimizer_Cz.zero_grad()
            loss_cz.backward()
            optimizer_Cz.step()

        # Encoder + Generator: maximise both critics' scores on their outputs.
        z_encoded = encoder(real_series)
        generated_series = generator(z_encoded)
        critic_fake_x = critic_x(generated_series)
        critic_fake_z = critic_z(z_encoded)

        loss_eg = -torch.mean(critic_fake_x) - lambda_z * torch.mean(critic_fake_z)

        optimizer_EG.zero_grad()
        loss_eg.backward()
        optimizer_EG.step()

        epoch_loss_cx.append(loss_cx.item())
        epoch_loss_cz.append(loss_cz.item())
        epoch_loss_eg.append(loss_eg.item())

    if not epoch_loss_eg:
        raise RuntimeError("train_loader yielded no batches; check the training data.")

    # Record the epoch mean rather than the noisy last-batch value.
    mean_cx = sum(epoch_loss_cx) / len(epoch_loss_cx)
    mean_cz = sum(epoch_loss_cz) / len(epoch_loss_cz)
    mean_eg = sum(epoch_loss_eg) / len(epoch_loss_eg)
    critic_losses_x_finetune.append(mean_cx)
    critic_losses_z_finetune.append(mean_cz)
    generator_losses_finetune.append(mean_eg)

    # Keep the fine-tuned checkpoint with the lowest mean Encoder/Generator loss.
    if mean_eg < best_loss:
        best_loss = mean_eg
        best_epoch = epoch
        torch.save({
            'encoder_state_dict': encoder.state_dict(),
            'generator_state_dict': generator.state_dict(),
            'critic_x_state_dict': critic_x.state_dict(),
            'critic_z_state_dict': critic_z.state_dict(),
            'epoch': epoch,
            'loss': best_loss
        }, "best_tadgan_criticxz_model_finetuned.pt")
        print(f"[Saved fine-tuned best model at epoch {epoch} | Loss: {best_loss:.4f}]")

    if epoch % 5 == 0:
        print(f"[{epoch}] Critic_x: {mean_cx:.4f}, Critic_z: {mean_cz:.4f}, EG: {mean_eg:.4f}")

In [None]:
# Fine-tuning loss curves (one point per recorded epoch).
# x-values come from each list's own length rather than `epochs`, so the plot
# still works if fine-tuning was interrupted.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(len(critic_losses_x_finetune)), critic_losses_x_finetune, label="Critic_x Loss (fine-tune)")
ax.plot(range(len(critic_losses_z_finetune)), critic_losses_z_finetune, label="Critic_z Loss (fine-tune)")
ax.plot(range(len(generator_losses_finetune)), generator_losses_finetune, label="Generator Loss (fine-tune)")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Changwon TadGAN Loss Curves (Critic_x + Critic_z)")
ax.legend()
ax.grid(True)
fig.savefig(f'mtad_gan_{time_steps}_loss_changwon.png')
plt.show()

In [None]:
# Concatenate pre-training (Gwangjoo) and fine-tuning (Changwon) histories.
gc_critic_losses_x = [*critic_losses_x, *critic_losses_x_finetune]
gc_critic_losses_z = [*critic_losses_z, *critic_losses_z_finetune]
gc_generator_losses = [*generator_losses, *generator_losses_finetune]

fig, ax = plt.subplots(figsize=(10, 6))
for series, series_label in (
    (gc_critic_losses_x, "gc_Critic_x Loss"),
    (gc_critic_losses_z, "gc_Critic_z Loss"),
    (gc_generator_losses, "gc_Generator Loss"),
):
    ax.plot(range(len(series)), series, label=series_label)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("gc TadGAN Loss Curves")
ax.legend()
ax.grid(True)
fig.savefig(f'mtad_gan_{time_steps}_loss_gc.png')
plt.show()