필요한 라이브러리 및 모듈 임포트

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd.variable import Variable
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import DataLoader
import random

Generator 클래스 정의

In [None]:
class Generator(nn.Module):
    """
    Generator: thin wrapper around a pretrained GPT-2 language-model head.

    input (Tensor): token ids of the real data (plus optional attention mask / labels)
    output: whatever the wrapped ``GPT2LMHeadModel`` returns for those arguments
    """
    def __init__(self, pretrained_model='gpt2'):
        super().__init__()
        # Load the pretrained weights named by `pretrained_model`.
        self.generator = GPT2LMHeadModel.from_pretrained(pretrained_model)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Delegate straight to the wrapped model; nothing is post-processed here.
        return self.generator(input_ids, attention_mask=attention_mask, labels=labels)

Discriminator 클래스 정의

In [None]:
class Discriminator(nn.Module):
    """
    Discriminator: scores each input row as real or fake.

    input (Tensor): real or fake data with ``input_size`` features per row
    output (Tensor): one sigmoid score per row
    """
    def __init__(self, input_size):
        super().__init__()
        hidden_units = 128
        # Same layer order as before so state_dict keys stay "model.<idx>.*".
        layers = [
            nn.Linear(input_size, hidden_units),
            nn.BatchNorm1d(hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.model(x)
        return scores

GAN 클래스 정의

In [None]:
class GAN(nn.Module):
    """
    Generator + Discriminator chained into one module.

    input (Tensor): data fed to the generator (integer token ids when the
        generator is a language model)
    output (Tensor): discriminator score for the generated data

    When the generator returns an object exposing ``.logits`` (as the
    ``Generator`` wrapper does), the logits are turned into float token ids
    with a straight-through estimator so the discriminator loss can still
    back-propagate into the generator. A generator that returns a plain
    tensor is passed through ``.float()`` as before.
    """
    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    @staticmethod
    def _token_ids_with_gradient(logits):
        """Return argmax token ids as floats; gradients flow via the softmax expectation."""
        probs = F.softmax(logits, dim=-1)
        vocab_index = torch.arange(logits.size(-1), device=logits.device, dtype=probs.dtype)
        # Expected token id under the predicted distribution (differentiable).
        expected_ids = probs @ vocab_index
        hard_ids = logits.argmax(dim=-1).to(probs.dtype)
        # Forward value equals hard_ids; backward uses d(expected_ids)/d(logits).
        return expected_ids + (hard_ids - expected_ids).detach()

    def forward(self, x):
        generated = self.generator(x)
        logits = getattr(generated, 'logits', None)
        if logits is None:
            # Generator already produced a tensor: keep the original behaviour.
            generated_data = generated.float()
        else:
            # Model-output object: calling .float() on it would raise AttributeError.
            generated_data = self._token_ids_with_gradient(logits)
        return self.discriminator(generated_data)

real & fake data 생성

In [None]:
def generate_real_and_fake_data(batch_data, generator, tokenizer, max_length=50):
    """
    Turn a batch of real sentences into padded token-id tensors and produce a
    matching batch of fake token ids from the generator.

    input:
    - batch_data (list): real sentences making up one batch
    - generator (Generator): model used to produce the fake data
    - tokenizer (GPT2Tokenizer): tokenizer; a pad token is set on it if it has none
    - max_length (int): maximum number of tokens per sentence (default: 50)

    output:
    - real_indexes (Tensor): token ids of the real data, size = (batch, max_length)
    - fake_indexes (Tensor): token ids of the generated data, size = (batch, max_length)
    """
    # Run on the device the generator actually lives on, so the inputs never
    # end up on a different device than the model weights.
    try:
        device = next(generator.parameters()).device
    except StopIteration:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Only set a pad token when one is missing. Reusing EOS keeps the pad id
    # inside the generator's embedding table; adding a brand-new '[PAD]' token
    # would need the model embeddings to be resized as well.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    pad_id = tokenizer.pad_token_id

    real_indexes, fake_indexes = [], []

    for real_data in batch_data:
        input_ids = tokenizer.encode(real_data, return_tensors='pt', truncation=True, max_length=max_length)
        # Force an integer dtype: an empty sentence would otherwise yield a float tensor.
        input_ids = input_ids.to(device=device, dtype=torch.long)

        # Real data
        real_indexes.append(pad_sequence(input_ids[0], max_length, pad_id))

        # Fake data: greedy (argmax) token per position. argmax carries no
        # gradient, so skip building the autograd graph.
        if input_ids.size(1) == 0:
            output_ids = input_ids[0]  # nothing to feed the generator; result is all padding
        else:
            with torch.no_grad():
                output_ids = generator(input_ids).logits[0].argmax(dim=-1)
        fake_indexes.append(pad_sequence(output_ids, max_length, pad_id))

    if not real_indexes:
        # torch.stack rejects an empty list; return well-formed empty batches instead.
        empty = torch.empty((0, max_length), dtype=torch.long, device=device)
        return empty, empty.clone()

    return torch.stack(real_indexes), torch.stack(fake_indexes)

In [None]:
def pad_sequence(sequence, max_length, pad_token_id):
    """
    Right-pad a sequence up to max_length.

    input:
    - sequence (Tensor): input sequence
    - max_length (int): length to pad up to
    - pad_token_id (int): value used for the padding positions

    output:
    - padded_sequence (Tensor): the padded sequence; the input itself is
      returned unchanged when it is already max_length or longer
    """
    missing = max_length - len(sequence)

    # Nothing to add: hand back the very same tensor (no truncation is done).
    if missing <= 0:
        return sequence

    return F.pad(sequence, (0, missing), value=pad_token_id)

Train

In [None]:
def _straight_through_token_ids(logits):
    """Return argmax token ids as floats while routing gradients through the softmax."""
    probs = F.softmax(logits, dim=-1)
    vocab_index = torch.arange(logits.size(-1), device=logits.device, dtype=probs.dtype)
    expected_ids = probs @ vocab_index
    hard_ids = logits.argmax(dim=-1).to(probs.dtype)
    # Forward value equals hard_ids; backward uses the gradient of expected_ids.
    return expected_ids + (hard_ids - expected_ids).detach()


def train_gan(generator, discriminator, gan, num_epochs, batch_size, learning_rate, tokenizer, data,
              save_dir='/change_your_dir'):
    """
    Train the discriminator on every batch and the generator once per epoch.

    input:
    - generator (Generator): model producing fake token ids; must return an object with ``.logits``
    - discriminator (Discriminator): real/fake classifier over (batch, max_length) float inputs
    - gan (GAN): combined module; it is only moved to the training device here,
      the generator update is computed directly from generator + discriminator
    - num_epochs (int): number of passes over ``data``
    - batch_size (int): sentences per batch (incomplete last batches are dropped)
    - learning_rate (float): Adam learning rate for both optimizers
    - tokenizer (GPT2Tokenizer): tokenizer handed to generate_real_and_fake_data
    - data (list): real sentences
    - save_dir (str): directory receiving generator.pt / discriminator.pt every 10 epochs

    raises:
    - ValueError: when ``data`` holds fewer than ``batch_size`` items, so no batch exists
    """
    import os

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Move the modules first, then build the optimizers over the moved parameters.
    generator = generator.to(device)
    discriminator = discriminator.to(device)
    gan = gan.to(device)
    criterion = nn.BCELoss().to(device)

    optimizer_generator = optim.Adam(generator.parameters(), lr=learning_rate)
    optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=learning_rate)

    real_data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True)
    if len(real_data_loader) == 0:
        raise ValueError('data must contain at least batch_size items to form one batch')

    for epoch in range(num_epochs):
        # Discriminator training
        for batch_text in real_data_loader:
            real_ids, fake_ids = generate_real_and_fake_data(batch_text, generator, tokenizer)
            real_ids = real_ids.to(device)
            fake_ids = fake_ids.to(device)

            rows = real_ids.size(0)
            real_labels = torch.ones(rows, 1, device=device)
            fake_labels = torch.zeros(rows, 1, device=device)

            # Clear gradients BEFORE backward: the generator step below also
            # back-propagates through the discriminator and leaves gradients there.
            optimizer_discriminator.zero_grad()

            loss_real = criterion(discriminator(real_ids.float()), real_labels)
            loss_fake = criterion(discriminator(fake_ids.detach().float()), fake_labels)
            (loss_real + loss_fake).backward()

            optimizer_discriminator.step()

        # Generator training on the last batch of the epoch (the loader shuffles,
        # so this is a random batch).
        prompt_ids = real_ids.long()
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            keep = torch.ones_like(prompt_ids, dtype=torch.bool)
        else:
            keep = prompt_ids != pad_id
        # Padding positions may hold an id outside the generator's embedding table;
        # swap in a valid id and mask those positions out of the attention.
        safe_ids = prompt_ids.masked_fill(~keep, 0)
        logits = generator(safe_ids, attention_mask=keep.long()).logits

        generated = _straight_through_token_ids(logits)
        if pad_id is not None:
            # Mirror the padding layout the discriminator saw for fake data.
            generated = generated.masked_fill(~keep, float(pad_id))

        labels = torch.ones(generated.size(0), 1, device=device)

        optimizer_generator.zero_grad()
        loss_generator = criterion(discriminator(generated), labels)
        loss_generator.backward()
        optimizer_generator.step()

        if (epoch + 1) % 10 == 0:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(generator, os.path.join(save_dir, 'generator.pt'))
            torch.save(discriminator, os.path.join(save_dir, 'discriminator.pt'))

            print(f'Epoch [{epoch+1}/{num_epochs}],'
                  f'Generator Loss: {loss_generator.item():.4f}, '
                  f'Discriminator Loss: {0.5 * (loss_real + loss_fake).item():.4f}')

In [None]:
pretrained_model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name)
generator = Generator(pretrained_model=pretrained_model_name)
# The discriminator input width must match the max_length used when building
# the real/fake batches (generate_real_and_fake_data defaults to 50).
discriminator = Discriminator(50)
gan = GAN(generator, discriminator)

num_epochs = 100
batch_size = 64
learning_rate = 0.0002

# list(DataFrame) iterates over column labels, not rows, so extract the
# sentences from a column explicitly.
# NOTE(review): assumes the text lives in the first CSV column - confirm and adjust.
frame = pd.read_csv('your_data').dropna(axis=0, how='any')
data = frame.iloc[:, 0].astype(str).tolist()  # natural-language sentences as a list

# Train the GAN
train_gan(generator, discriminator, gan, num_epochs, batch_size, learning_rate, tokenizer, data)