In [1]:
import torch
import torch.nn as nn

from models.lstm_generator import LSTMGenerator
from models.cnn_discriminator import CNNDiscriminator
from models.rollout import ROLLOUT

### Генератор: проверяем, что forward и sample работают

In [2]:
# Seed the global PyTorch RNG so that weight initialisation and the
# torch.randint batches below are reproducible on Restart & Run All
# (G.sample is presumably covered too, assuming it draws from the torch RNG).
SEED = 0
torch.manual_seed(SEED)

# Toy hyper-parameters: deliberately tiny so every check in the notebook is instant.
emb_dim = 10       # embedding size; also passed to the discriminator later on
hidden_dim = 20    # LSTM hidden-state size
vocab_size = 100   # number of distinct token ids
max_seq_len = 15   # sequence length used for all batches in this notebook
padding_idx = 0    # token id handed to the model as its padding index

G = LSTMGenerator(embedding_dim=emb_dim,
                  hidden_dim=hidden_dim,
                  vocab_size=vocab_size,
                  max_seq_len=max_seq_len,
                  padding_idx=padding_idx)

In [3]:
# Display the module repr to check the generator's layers and their sizes.
G

LSTMGenerator(
  (embeddings): Embedding(100, 10, padding_idx=0)
  (lstm): LSTM(10, 20, batch_first=True)
  (lstm2out): Linear(in_features=20, out_features=100, bias=True)
  (logsoftmax): LogSoftmax(dim=-1)
)

In [4]:
# Smoke-test the generator's forward pass on a batch of random token ids.
batch_size = 32
seq_len = max_seq_len

# Draw ids from [0, vocab_size) instead of a literal 100, so the batch stays
# within the embedding table if vocab_size is changed in the config cell.
batch = torch.randint(0, vocab_size, size=(batch_size, seq_len))

# All-zero initial LSTM state. The leading dimension of 1 presumably matches
# the single-layer, unidirectional LSTM shown in G's repr.
h = torch.zeros(1, batch_size, hidden_dim)
c = torch.zeros(1, batch_size, hidden_dim)

res = G(batch, (h, c))

In [5]:
# Shape check: the forward pass should give (batch_size, seq_len, vocab_size).
res.shape

torch.Size([32, 15, 100])

In [6]:
# Smoke-test sampling. The two positional arguments are presumably
# (num_samples, batch_size), matching the keyword calls used later in this
# notebook -- TODO confirm against LSTMGenerator.sample.
G.sample(5, 5, start_letter=3)

tensor([[78, 15, 79, 98, 96, 47,  6, 50, 34, 53, 35, 16, 67, 30, 59],
        [41, 68, 84, 70, 27, 63, 24, 67, 10, 89, 13, 57, 91, 48, 89],
        [70, 31,  8, 85, 58, 10, 99, 10, 52, 24, 77, 87, 43, 29,  2],
        [62, 15,  6, 58, 11, 22, 72, 44, 63, 15, 36, 95, 34, 79,  1],
        [32, 94, 57, 21, 80, 86, 63,  8, 32,  4, 74, 33,  8, 21,  7]])

### Дискриминатор: проверяем, что forward работает

In [7]:
# Use generator samples as the discriminator's test input. The positional
# arguments are presumably (num_samples, batch_size) -- TODO confirm.
inp = G.sample(batch_size, batch_size, start_letter=3)

In [8]:
# Tiny CNN discriminator: two filter banks (widths 2 and 3) with 2 filters each.
# It uses the same embedding size, vocabulary and padding index as the
# generator; padding_idx comes from the shared config constant rather than a
# literal 0 so the two models cannot drift apart.
D = CNNDiscriminator(embed_dim=emb_dim,
                     vocab_size=vocab_size,
                     filter_sizes=[2, 3],
                     num_filters=[2, 2],
                     padding_idx=padding_idx)

In [9]:
# Display the module repr to check the discriminator's layers and their sizes.
D

CNNDiscriminator(
  (embeddings): Embedding(100, 10, padding_idx=0)
  (convs): ModuleList(
    (0): Conv2d(1, 2, kernel_size=(2, 10), stride=(1, 1))
    (1): Conv2d(1, 2, kernel_size=(3, 10), stride=(1, 1))
  )
  (highway): Linear(in_features=4, out_features=4, bias=True)
  (feature2out): Linear(in_features=4, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [10]:
# Forward pass of the discriminator on the generated batch.
# NOTE(review): D contains a Dropout layer and no eval() call is made in this
# notebook, so these values presumably vary from run to run -- confirm.
out = D(inp)

In [11]:
# Inspect the raw discriminator outputs: one row per input sequence, two columns.
out

tensor([[-0.6857, -0.1100],
        [-0.8835, -0.1034],
        [-0.5736, -0.0756],
        [-0.0322,  0.0009],
        [-0.0240, -0.1292],
        [-0.5415, -0.0896],
        [-0.0364, -0.0650],
        [-0.9004, -0.1183],
        [-0.0573, -0.1141],
        [-0.1450, -0.1378],
        [-0.7085, -0.0803],
        [-0.6632, -0.1365],
        [-0.7624, -0.1216],
        [-0.4052, -0.1949],
        [-0.1462, -0.1602],
        [-0.3522, -0.0019],
        [-0.7602, -0.0924],
        [-0.3490, -0.0620],
        [-0.5350, -0.1046],
        [-0.2492, -0.0875],
        [-0.8708, -0.1234],
        [-0.7513, -0.1440],
        [-0.1119, -0.0687],
        [-0.6054, -0.1338],
        [-0.8201, -0.1144],
        [-0.4562, -0.0895],
        [-0.1772,  0.0027],
        [-0.6325, -0.1095],
        [-0.7966, -0.0822],
        [-0.5749,  0.0305],
        [-0.5983, -0.0959],
        [-0.3857, -0.1477]], grad_fn=<AddmmBackward>)

In [12]:
# Shape check: one pair of scores per sequence, i.e. (batch_size, 2).
out.shape

torch.Size([32, 2])

### forward + backward

In [13]:
# Losses: NLL for the generator's pre-training step, cross-entropy for the
# discriminator's two-column output.
mle_criterion, dis_criterion = nn.NLLLoss(), nn.CrossEntropyLoss()

# Both networks use the same step size.
gen_lr = discr_lr = 0.01

# The generator gets two independent Adam instances over the same parameters:
# gen_opt drives the pre-training step and gen_adv_opt the adversarial step,
# so each phase keeps its own optimizer state.
gen_opt, gen_adv_opt = (
    torch.optim.Adam(G.parameters(), lr=gen_lr) for _ in range(2)
)
discr_opt = torch.optim.Adam(D.parameters(), lr=discr_lr)

**pretrain generator**

In [14]:
# One maximum-likelihood update of the generator on dummy data
# (checks that forward + backward + optimizer step run end to end).
epochs = 10  # not used in this cell: only a single optimisation step is run

batch_size = 32
seq_len = max_seq_len
# Random token ids in [0, vocab_size); tied to the config value instead of a
# literal 100 so the ids stay valid if the vocabulary size changes.
batch = torch.randint(0, vocab_size, size=(batch_size, seq_len))
h = torch.zeros(1, batch_size, hidden_dim)
c = torch.zeros(1, batch_size, hidden_dim)

# Dummy targets: every position is token id 1.
target = torch.ones(batch_size, seq_len, dtype=torch.long)

pred = G(batch, (h, c))  # (batch, seq, vocab), as confirmed by res.shape above
# NLLLoss wants the class dimension second, so move vocab in front of seq:
# input (batch, vocab, seq) against target (batch, seq).
loss = mle_criterion(pred.permute(0, 2, 1), target)
gen_opt.zero_grad()
loss.backward()
gen_opt.step()

**pretrain discriminator**

In [15]:
# One discriminator update on a balanced batch of "real" and generated sequences.
# The "real" half is a stand-in: uniformly random token ids in [0, vocab_size)
# (config value instead of a literal 100, so ids stay within the embedding table).
pos_samples = torch.randint(0, vocab_size, size=(batch_size, seq_len))
# The "fake" half is drawn from the generator.
neg_samples = G.sample(num_samples=batch_size,
                       batch_size=batch_size, start_letter=0)
inp = torch.cat((pos_samples, neg_samples), dim=0).long().detach()
# Labels follow the concatenation order: 1 for the pos rows, 0 for the generated rows.
targets = torch.tensor([1]*batch_size + [0]*batch_size,
                       dtype=torch.long)
pred = D(inp)
loss = dis_criterion(pred, targets)
discr_opt.zero_grad()
loss.backward()
discr_opt.step()

**adversarial training**

In [19]:
# train generator
# One policy-gradient update of the generator on placeholder data
# (checks that rollout rewards, the PG loss and the optimizer step run end to end).
rollout_num = 20  # forwarded to get_reward as its second argument

rollout_func = ROLLOUT(G)
# Placeholder input/target sequences: random ids in [0, vocab_size), tied to
# the config value instead of a literal 100.
gen_inp = torch.randint(0, vocab_size, size=(batch_size, seq_len))
gen_target = torch.randint(0, vocab_size, size=(batch_size, seq_len))
rewards = rollout_func.get_reward(gen_target, rollout_num, D)
# Printed shape is (batch_size, seq_len): one reward per token position.
print(rewards.shape)

adv_loss = G.batchPGLoss(gen_inp, gen_target, rewards)
gen_adv_opt.zero_grad()
adv_loss.backward()
gen_adv_opt.step()

torch.Size([32, 15])


In [20]:
# train discriminator
# One discriminator update on a balanced batch, now against the generator as
# updated by the adversarial step.
# The "real" half is a stand-in: uniformly random token ids in [0, vocab_size)
# (config value instead of a literal 100, so ids stay within the embedding table).
pos_samples = torch.randint(0, vocab_size, size=(batch_size, seq_len))
# The "fake" half is drawn from the generator.
neg_samples = G.sample(num_samples=batch_size,
                       batch_size=batch_size, start_letter=0)
inp = torch.cat((pos_samples, neg_samples), dim=0).long().detach()
# Labels follow the concatenation order: 1 for the pos rows, 0 for the generated rows.
targets = torch.tensor([1]*batch_size + [0]*batch_size,
                       dtype=torch.long)
pred = D(inp)
loss = dis_criterion(pred, targets)
discr_opt.zero_grad()
loss.backward()
discr_opt.step()