In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai import *
from fastai.text import *

In [3]:
path = Path('../data')

In [4]:
data_lm = load_data(path, 'poems_ds')
data_lm.show_batch()

idx,text
0,"! heart ! \n o the bleeding drops of red , \n xxmaj where on the deck my xxmaj captain lies , \n xxmaj fallen cold and dead . \n \n o xxmaj captain ! my xxmaj captain ! rise up and hear the bells ; \n xxmaj rise up -- for you the flag is flung -- for you the bugle trills ; 10"
1,"ten thousand saw i at a glance , \n xxmaj tossing their heads in sprightly dance . \n \n \n xxmaj the waves beside them danced ; but they \n xxmaj out - did the sparkling waves in glee : \n a poet could not but be gay , \n xxmaj in such a jocund company : \n i gazed — and gazed —"
2,"xxmaj what did i know , what did i know \n of love 's austere and lonely offices ? xxbos \n \n xxup life , believe , is not a dream \n xxmaj so dark as sages say ; \n xxmaj oft a little morning rain \n xxmaj foretells a pleasant day . \n xxmaj sometimes there are clouds of gloom , \n"
3,nought can xxmaj deform the xxmaj human xxmaj race \n xxmaj like to the xxmaj armours iron brace \n xxmaj when xxmaj gold & xxmaj gems adorn the xxmaj plow \n xxmaj to peaceful xxmaj arts shall xxmaj envy xxmaj bow \n a xxmaj riddle or the xxmaj crickets xxmaj cry \n xxmaj is to xxmaj doubt a fit xxmaj reply \n xxmaj the
4,"xxmaj not everything that can be counted counts , \n and not everything that counts can be counted . \n \n \n xxmaj only one who devotes himself \n to a cause \n with his whole strength and soul \n can be a true master . \n xxmaj for this reason mastery \n demands all of a person . \n \n \n"


In [5]:
trn_dl = data_lm.train_dl
val_dl = data_lm.valid_dl

In [6]:
def lm_loss(input, target, kld_weight=0):
    sl, bs = target.size()
    sl_in,bs_in,nc = input.size()
    return F.cross_entropy(input.view(-1,nc), target.view(-1))

In [16]:
def bn_drop_lin(n_in, n_out, bn=True, initrange=0.01,p=0, bias=True, actn=nn.LeakyReLU(inplace=True)):
    layers = [nn.BatchNorm1d(n_in)] if bn else []
    if p != 0: layers.append(nn.Dropout(p))
    linear = nn.Linear(n_in, n_out, bias=bias)
    if initrange:linear.weight.data.uniform_(-initrange, initrange)
    if bias: linear.bias.data.zero_()
    layers.append(linear)
    if actn is not None: layers.append(actn)
    return layers

In [8]:
learn = language_model_learner(data_lm, arch=AWD_LSTM)
learn.load('poems_fine_tuned')

LanguageLearner(data=TextLMDataBunch;

Train: LabelList (672 items)
x: LMTextList
xxbos o xxmaj captain ! my xxmaj captain ! our fearful trip is done ; 
  xxmaj the ship has weather'd every rack , the prize we sought is won ; 
  xxmaj the port is near , the bells i hear , the people all exulting , 
  xxmaj while follow eyes the steady keel , the vessel grim and daring : 
  xxmaj but o heart ! heart ! heart ! 
  o the bleeding drops of red , 
  xxmaj where on the deck my xxmaj captain lies , 
  xxmaj fallen cold and dead . 
 
  o xxmaj captain ! my xxmaj captain ! rise up and hear the bells ; 
  xxmaj rise up -- for you the flag is flung -- for you the bugle trills ; 10 
  xxmaj for you bouquets and ribbon'd wreaths -- for you the shores a - crowding ; 
  xxmaj for you they call , the swaying mass , their eager faces turning ; 
  xxmaj here xxmaj captain ! dear father ! 
  xxmaj this arm beneath your head ; 
  xxmaj it is some dream that on the deck , 
  xxmaj you 've fallen cold and de

In [9]:
encoder = deepcopy(learn.model[0])
encoder

AWD_LSTM(
  (encoder): Embedding(17096, 400, padding_idx=1)
  (encoder_dp): EmbeddingDropout(
    (emb): Embedding(17096, 400, padding_idx=1)
  )
  (rnns): ModuleList(
    (0): WeightDropout(
      (module): LSTM(400, 1150, batch_first=True)
    )
    (1): WeightDropout(
      (module): LSTM(1150, 1150, batch_first=True)
    )
    (2): WeightDropout(
      (module): LSTM(1150, 400, batch_first=True)
    )
  )
  (input_dp): RNNDropout()
  (hidden_dps): ModuleList(
    (0): RNNDropout()
    (1): RNNDropout()
    (2): RNNDropout()
  )
)

In [10]:
x, y = next(iter(trn_dl))
x.size(), y.size()

(torch.Size([64, 70]), torch.Size([64, 70]))

In [11]:
outs = encoder(x)

In [12]:
outs[-1][-1].size()

torch.Size([64, 70, 400])

In [13]:
generator = deepcopy(learn.model) 

In [14]:
generator.load_state_dict(learn.model.state_dict())

In [17]:
class TextDicriminator(nn.Module):
    def __init__(self,encoder, nh, bn_final=True):
        super().__init__()
        #encoder
        self.encoder = encoder
        #classifier
        layers = []
        layers+=bn_drop_lin(nh*3,nh,bias=False)
        layers += bn_drop_lin(nh,nh,p=0.25)
        layers+=bn_drop_lin(nh,1,p=0.15,actn=nn.Sigmoid())
        if bn_final: layers += [nn.BatchNorm1d(1)]
        self.layers = nn.Sequential(*layers)
    
    def pool(self, x, bs, is_max):
        f = F.adaptive_max_pool1d if is_max else F.adaptive_avg_pool1d
        return f(x.permute(0,2,1), (1,)).view(bs,-1)
    
    def forward(self, inp,y=None):
        raw_outputs, outputs = self.encoder(inp)
        output = outputs[-1]
        bs,sl,_ = output.size()
        avgpool = self.pool(output, bs, False)
        mxpool = self.pool(output, bs, True)
        x = torch.cat([output[:,-1], mxpool, avgpool], 1)
        out = self.layers(x)
        return out

In [18]:
disc = TextDicriminator(encoder,400).cuda()

In [19]:
out = disc(x)
out.size()

torch.Size([64, 1])

In [20]:
probs,raw_outputs, outputs = generator(x)

In [21]:
probs.size()

torch.Size([64, 70, 17096])

In [22]:
optimizerD = optim.Adam(disc.parameters(), lr = 3e-4)
optimizerG = optim.Adam(generator.parameters(), lr = 3e-3, betas=(0.7, 0.8))

In [20]:
samples = F.gumbel_softmax(torch.randn(10,100))
samples.size()

torch.Size([10, 100])

In [21]:
def stats(tensor):return torch.mean(tensor),torch.std(tensor)

In [22]:
stats(samples)

(tensor(0.0100), tensor(0.0591))

In [23]:
torch.multinomial(samples,1).squeeze(1)

tensor([36, 62, 99, 19, 86, 31, 11, 51, 12, 77])

In [23]:
def seq_gumbel_softmax(input):
    samples = []
    bs,sl,nc = input.size()
    for i in range(sl): 
        z = F.gumbel_softmax(input[:,i,:])
        samples.append(torch.multinomial(z,1))
    samples = torch.stack(samples).transpose(1,0).squeeze(2) 
    return samples

In [24]:
from tqdm import tqdm

In [25]:
def train(gen, disc, epochs, trn_dl, val_dl, optimizerD, optimizerG, crit=None,first=True):
    gen_iterations = 0
    
    for epoch in range(epochs):
        gen.train(); disc.train()
        n = len(trn_dl)
        #train loop
        with tqdm(total=n) as pbar:
            for i, ds in enumerate(trn_dl):
                x, y = ds
                bs,sl = x.size()
                disc.eval(), gen.train()
                fake,_,_ = gen(x)
                gen.zero_grad()
                fake_sample =seq_gumbel_softmax(fake)
                with torch.no_grad():
                    gen_loss = reward = disc(fake_sample)
                    if crit: gen_loss = crit(fake,fake_sample,reward.squeeze(1))
                    gen_loss = gen_loss.mean()
                gen_loss.requires_grad_(True)
                gen_loss.backward()
                optimizerG.step()
                gen_iterations += 1
                d_iters = 3
                for j in range(d_iters):
                    gen.eval()
                    disc.train()
                    with torch.no_grad():
                        fake,_,_ = gen(x)
                        fake_sample = seq_gumbel_softmax(fake)
                    disc.zero_grad()
                    fake_loss = disc(fake_sample)
                    #fake_loss.requires_grad=True
                    real_loss = disc(y.view(bs,sl))
                    #real_loss.requires_grad=True
                    disc_loss = (fake_loss-real_loss).mean(0)
                    disc_loss.backward()
                    optimizerD.step()
                pbar.update()
        print(f'Epoch {epoch}:')
        print('Train Loss:')
        print(f'Loss_D {disc_loss.data.item()}; Loss_G {gen_loss.data.item()} Ppx {torch.exp(lm_loss(fake,y))}')
        print(f'D_real {real_loss.mean(0).view(1).data.item()}; Loss_D_fake {fake_loss.mean(0).view(1).data.item()}')
        disc.eval(), gen.eval()
        with tqdm(total=len(val_dl)) as pbar:
            for i, ds in enumerate(val_dl):
                with torch.no_grad():
                    x, y = ds
                    bs,sl = x.size()
                    fake,_,_ = gen(x)
                    fake_sample =seq_gumbel_softmax(fake)
                    gen_loss = reward = disc(fake_sample)
                    if crit: gen_loss = crit(fake,fake_sample,reward.squeeze(1))
                    gen_loss = gen_loss.mean()
                    fake_sample = seq_gumbel_softmax(fake)
                    fake_loss = disc(fake_sample)
                    real_loss = disc(y.view(bs,sl))
                    disc_loss = (fake_loss-real_loss).mean(0)
                pbar.update()
        print('Valid Loss:')
        print(f'Loss_D {disc_loss.data.item()}; Loss_G {gen_loss.data.item()} Ppx {torch.exp(lm_loss(fake,y))}')
        print(f'D_real {real_loss.mean(0).view(1).data.item()}; Loss_D_fake {fake_loss.mean(0).view(1).data.item()}')

In [42]:
#gen.load_state_dict(torch.load(PATH/'models/seq2seq_en.h5', map_location=lambda storage, loc: storage)) 
#disc.load_state_dict(torch.load(PATH/'models/disc_en.h5', map_location=lambda storage, loc: storage)) 

In [26]:
disc.train()
generator.train();

In [27]:
train(generator, disc, 6, trn_dl, val_dl, optimizerD, optimizerG, first=False)

100%|██████████| 69/69 [01:06<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.33it/s]

Epoch 0:
Train Loss:
Loss_D -1.341104507446289e-07; Loss_G -0.086343914270401 Ppx 7.658246994018555
D_real 1.6763806343078613e-07; Loss_D_fake 2.9802322387695312e-08


100%|██████████| 6/6 [00:01<00:00,  4.02it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D -0.013350073248147964; Loss_G -0.28069180250167847 Ppx 56.721004486083984
D_real -0.23393051326274872; Loss_D_fake -0.24728058278560638


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.35it/s]

Epoch 1:
Train Loss:
Loss_D 1.7043203115463257e-07; Loss_G -0.03880370408296585 Ppx 7.571606159210205
D_real -2.2724270820617676e-07; Loss_D_fake -5.960464477539063e-08


100%|██████████| 6/6 [00:01<00:00,  4.02it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D 0.0012347488664090633; Loss_G -0.24660125374794006 Ppx 56.881778717041016
D_real -0.2859431505203247; Loss_D_fake -0.2847083806991577


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.35it/s]

Epoch 2:
Train Loss:
Loss_D -1.7881393432617188e-07; Loss_G 0.09737280756235123 Ppx 7.843847751617432
D_real 2.2351741790771484e-07; Loss_D_fake 5.029141902923584e-08


100%|██████████| 6/6 [00:01<00:00,  4.02it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D 0.04497455060482025; Loss_G 0.16908785700798035 Ppx 57.84066390991211
D_real 0.23502245545387268; Loss_D_fake 0.2799970209598541


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.37it/s]

Epoch 3:
Train Loss:
Loss_D -3.6694109439849854e-07; Loss_G 0.07079493254423141 Ppx 8.204992294311523
D_real 2.682209014892578e-07; Loss_D_fake -1.043081283569336e-07


100%|██████████| 6/6 [00:01<00:00,  4.01it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D -0.024879172444343567; Loss_G 0.2076968550682068 Ppx 57.07601547241211
D_real 0.2681492567062378; Loss_D_fake 0.24327006936073303


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.34it/s]

Epoch 4:
Train Loss:
Loss_D 2.123415470123291e-07; Loss_G 0.08917149156332016 Ppx 7.705118656158447
D_real -3.948807716369629e-07; Loss_D_fake -1.9371509552001953e-07


100%|██████████| 6/6 [00:01<00:00,  4.02it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D 0.0038774628192186356; Loss_G 0.17146974802017212 Ppx 55.272804260253906
D_real 0.21537131071090698; Loss_D_fake 0.21924875676631927


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.35it/s]

Epoch 5:
Train Loss:
Loss_D -1.6391277313232422e-07; Loss_G 0.03497626259922981 Ppx 7.247567176818848
D_real 8.568167686462402e-08; Loss_D_fake -7.82310962677002e-08


100%|██████████| 6/6 [00:01<00:00,  4.02it/s]

Valid Loss:
Loss_D -0.025480497628450394; Loss_G 0.18148395419120789 Ppx 55.032745361328125
D_real 0.251420795917511; Loss_D_fake 0.2259402871131897





In [28]:
learn.model.load_state_dict(generator.state_dict())

In [29]:
learn.predict("O the bleeding drops of red",n_words=50)

"O the bleeding drops of red throbbing \n  There are some trampling feet allowed to reach \n  The heads of steed - chariots makes goblet mist \n  As though the departed ones ' bodies were wound . \n  Who knows what that age may impart to those \n  Who feel curious"

In [30]:
learn.save('poems_gan_gumbel')

In [32]:
def reinforce_loss(input,sample,reward):
    loss=0
    bs,sl = sample.size()
    for i in range(sl):
        loss += -input[:,i,sample[:,i]] * reward
    return loss/sl

In [33]:
learn.load('poems_fine_tuned')
encoder = deepcopy(learn.model[0])
disc = TextDicriminator(encoder,400).cuda()
generator = deepcopy(learn.model) 

In [34]:
optimizerD = optim.Adam(disc.parameters(), lr = 1e-4)
optimizerG = optim.Adam(generator.parameters(), lr = 3e-3, betas=(0.7, 0.8))

In [35]:
train(generator, disc, 6, trn_dl, val_dl, optimizerD, optimizerG, crit=reinforce_loss,first=False)

100%|██████████| 69/69 [01:06<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.45it/s]

Epoch 0:
Train Loss:
Loss_D -6.50063157081604e-07; Loss_G 0.10009565949440002 Ppx 7.42289924621582
D_real 5.21540641784668e-07; Loss_D_fake -1.2293457984924316e-07


100%|██████████| 6/6 [00:01<00:00,  4.10it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D -0.21008385717868805; Loss_G -2.8133182525634766 Ppx 53.0468864440918
D_real -0.2947290241718292; Loss_D_fake -0.5048128366470337


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.50it/s]

Epoch 1:
Train Loss:
Loss_D -1.9744038581848145e-07; Loss_G -0.6046543717384338 Ppx 7.8499321937561035
D_real 6.07222318649292e-07; Loss_D_fake 4.0978193283081055e-07


100%|██████████| 6/6 [00:01<00:00,  4.10it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D -0.07184897363185883; Loss_G -0.6218725442886353 Ppx 52.736106872558594
D_real -0.1615934669971466; Loss_D_fake -0.23344242572784424


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.47it/s]

Epoch 2:
Train Loss:
Loss_D -7.82310962677002e-08; Loss_G -2.000972032546997 Ppx 8.486431121826172
D_real -4.544854164123535e-07; Loss_D_fake -5.364418029785156e-07


100%|██████████| 6/6 [00:01<00:00,  4.07it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D -0.14020952582359314; Loss_G -1.3472065925598145 Ppx 51.095420837402344
D_real -0.23206868767738342; Loss_D_fake -0.37227821350097656


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.49it/s]

Epoch 3:
Train Loss:
Loss_D 2.1141022443771362e-07; Loss_G -0.8980365991592407 Ppx 8.2464017868042
D_real -1.7508864402770996e-07; Loss_D_fake 3.725290298461914e-08


100%|██████████| 6/6 [00:01<00:00,  4.09it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D -0.12584710121154785; Loss_G -1.3681776523590088 Ppx 51.840694427490234
D_real -0.26572832465171814; Loss_D_fake -0.3915754556655884


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.43it/s]

Epoch 4:
Train Loss:
Loss_D 2.523884177207947e-07; Loss_G -1.5117201805114746 Ppx 8.406465530395508
D_real -3.3527612686157227e-08; Loss_D_fake 2.1420419216156006e-07


100%|██████████| 6/6 [00:01<00:00,  4.08it/s]
  0%|          | 0/69 [00:00<?, ?it/s]

Valid Loss:
Loss_D -0.0779573991894722; Loss_G -1.1873434782028198 Ppx 55.788875579833984
D_real -0.23494872450828552; Loss_D_fake -0.3129061460494995


100%|██████████| 69/69 [01:07<00:00,  1.03it/s]
 17%|█▋        | 1/6 [00:00<00:00,  5.44it/s]

Epoch 5:
Train Loss:
Loss_D 6.705522537231445e-07; Loss_G -1.3681237697601318 Ppx 8.54354476928711
D_real -3.296881914138794e-07; Loss_D_fake 3.427267074584961e-07


100%|██████████| 6/6 [00:01<00:00,  4.07it/s]

Valid Loss:
Loss_D -0.04955030605196953; Loss_G -1.7284224033355713 Ppx 58.09437942504883
D_real -0.3295239806175232; Loss_D_fake -0.3790743052959442





In [36]:
learn.model.load_state_dict(generator.state_dict())

In [37]:
learn.predict("O the bleeding drops of red",n_words=50)

'O the bleeding drops of red \n  discoloured plant dead \n  are fallen over this river day and day \n  All day long their thin eyelids \n  droop and never flow . \n \n \n  Reflected upon the image \n  Walnut trees ; \n  Stroking their tender hands \n  With shut eyelids .'

In [38]:
learn.save('poems_gan_reinforce')