In [1]:
from fastai.text import *

In [2]:
PATH = Path('data')
POEMS = PATH/'poems.txt'
WIKI = PATH/'wikitext-103/wiki.train.tokens'

In [3]:
poems = POEMS.open().read().split('\n\n\n\n\n\n\n')
poems = [poem.strip().replace('\n','') for poem in poems]
wiki = WIKI.open().read().split('\n')
wiki = [sent.strip() for sent in wiki if len(sent) > 2000]

In [4]:
len(poems), poems[:3], len(wiki)

(746,
 ["O Captain! my Captain! our fearful trip is done;The ship has weather'd every rack, the prize we sought is won;The port is near, the bells I hear, the people all exulting,While follow eyes the steady keel, the vessel grim and daring:But O heart! heart! heart!O the bleeding drops of red,Where on the deck my Captain lies,Fallen cold and dead.O Captain! my Captain! rise up and hear the bells;Rise up--for you the flag is flung--for you the bugle trills; 10For you bouquets and ribbon'd wreaths--for you the shores a-crowding;For you they call, the swaying mass, their eager faces turning;Here Captain! dear father!This arm beneath your head;It is some dream that on the deck,You've fallen cold and dead.My Captain does not answer, his lips are pale and still;My father does not feel my arm, he has no pulse nor will;The ship is anchor'd safe and sound, its voyage closed and done;From fearful trip, the victor ship, comes in with object won; 20Exult, O shores, and ring, O bells!But I, with m

In [5]:
tok_poems = Tokenizer.proc_all(poems, 'en')
tok_wiki = Tokenizer.proc_all(wiki, 'en') 

In [6]:
tok_poems[:2]

[['o',
  'captain',
  '!',
  'my',
  'captain',
  '!',
  'our',
  'fearful',
  'trip',
  'is',
  'done;the',
  'ship',
  'has',
  "weather'd",
  'every',
  'rack',
  ',',
  'the',
  'prize',
  'we',
  'sought',
  'is',
  'won;the',
  'port',
  'is',
  'near',
  ',',
  'the',
  'bells',
  'i',
  'hear',
  ',',
  'the',
  'people',
  'all',
  'exulting',
  ',',
  'while',
  'follow',
  'eyes',
  'the',
  'steady',
  'keel',
  ',',
  'the',
  'vessel',
  'grim',
  'and',
  'daring',
  ':',
  'but',
  'o',
  'heart',
  '!',
  'heart',
  '!',
  'heart!o',
  'the',
  'bleeding',
  'drops',
  'of',
  'red',
  ',',
  'where',
  'on',
  'the',
  'deck',
  'my',
  'captain',
  'lies',
  ',',
  'fallen',
  'cold',
  'and',
  'dead.o',
  'captain',
  '!',
  'my',
  'captain',
  '!',
  'rise',
  'up',
  'and',
  'hear',
  'the',
  'bells;rise',
  'up',
  '--',
  'for',
  'you',
  'the',
  'flag',
  'is',
  'flung',
  '--',
  'for',
  'you',
  'the',
  'bugle',
  'trills',
  ';',
  '10for',
  'you',

In [3]:
trn_poems,val_poems = sklearn.model_selection.train_test_split(tok_poems, test_size=0.2)
trn_wiki,val_wiki = sklearn.model_selection.train_test_split(tok_wiki, test_size=0.2)

pickle.dump(trn_poems, open(PATH/'trn_poems_en.pkl', 'wb'))
pickle.dump(val_poems, open(PATH/'val_poems_en.pkl', 'wb'))

pickle.dump(trn_wiki, open(PATH/'trn_wiki_en.pkl', 'wb'))
pickle.dump(val_wiki, open(PATH/'val_wiki_en.pkl', 'wb'))

trn_poems,val_poems = pickle.load(open(PATH/'trn_poems_en.pkl', 'rb')), pickle.load(open(PATH/'val_poems_en.pkl', 'rb')) 
trn_wiki,val_wiki = pickle.load(open(PATH/'trn_wiki_en.pkl', 'rb')), pickle.load(open(PATH/'val_wiki_en.pkl', 'rb'))

In [4]:
len(trn_poems), len(val_poems), len(trn_wiki), len(val_wiki)

(596, 150, 3244, 812)

In [7]:
def chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield l[i:i+n]

In [8]:
trn_poems_chunk = list(chunks([tok for poem in trn_poems for tok in poem],50))
val_poems_chunk = list(chunks([tok for poem in val_poems for tok in poem],50))

trn_wiki_chunk = list(chunks([tok for art in trn_wiki for tok in art],200))#[:len(trn_poems)*6]
val_wiki_chunk = list(chunks([tok for art in val_wiki for tok in art],200))#[:len(val_poems)*6]


In [9]:
len(trn_wiki_chunk), len(trn_poems_chunk), len(val_poems_chunk), len(val_wiki_chunk), len(trn_poems), len(val_poems), len(trn_wiki), len(val_wiki)

(7036, 3999, 771, 1741, 596, 150, 3244, 812)

In [10]:
import collections

def toks2ids(tok, pre):
    freq = collections.Counter(p for o in tok for p in o)
    itos = [o for o,c in freq.most_common(30000)]
    itos.insert(0, '_bos_')
    itos.insert(1, '_pad_')
    itos.insert(2, '_eos_')
    itos.insert(3, '_unk')
    stoi = collections.defaultdict(lambda: 3, {v:k for k,v in enumerate(itos)})
    pickle.dump(itos, open(PATH/f'{pre}_itos.pkl', 'wb'))
    return itos,stoi

In [11]:
itos,stoi = toks2ids(trn_poems, 'en_2')
itos_d,stoi_d = toks2ids(trn_poems+trn_wiki, 'en_d')

In [9]:
itos = pickle.load(open(PATH/'en_2_itos.pkl', 'rb'))
stoi = collections.defaultdict(lambda: 3, {v:k for k,v in enumerate(itos)})
itos_d = pickle.load(open(PATH/'en_d_itos.pkl', 'rb'))
stoi_d = collections.defaultdict(lambda: 3, {v:k for k,v in enumerate(itos)})

In [12]:
len(stoi_d), len(stoi)

(30004, 26125)

In [13]:
trn_p_ids = np.array([[0] + [stoi[o] if o in stoi else 3 for o in p] + [2] for p in trn_poems_chunk])
val_p_ids = np.array([[0] + [stoi[o] if o in stoi else 3 for o in p] + [2] for p in val_poems_chunk])
trn_w_ids = np.array([[0] + [stoi[o] if o in stoi else 3 for o in p] + [2] for p in trn_wiki_chunk])
val_w_ids = np.array([[0] + [stoi[o] if o in stoi else 3 for o in p] + [2] for p in val_wiki_chunk])


In [12]:
sum(len(i) for i in val_p_ids)/len(val_p_ids)

201.765503875969

In [14]:
import io

def load_vectors(fname):
    fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
    header = fin.readline().split()
    n, d = header[0], header[1]
    data = {}
    #partitions = [fin[:]]
    for line in fin:
        tokens = line.rstrip().split(' ')
        data[tokens[0]] = np.array(tokens[1:], dtype=float)
    return data, n, d

In [15]:
en_vecs, vs, dim_en_vec = load_vectors(str(PATH/'wiki-news-300d-1M.vec'))

In [16]:
class GANDataset(Dataset):
    def __init__(self, x): 
        self.x = x
    def __getitem__(self, idx):
        seq_len = len(self.x[idx])
        return A(self.x[idx][:seq_len//2]), A(self.x[idx][seq_len//2:])
    def __len__(self): return len(self.x)
    
class DiscDataset(Dataset):
    def __init__(self, real, fake): 
        self.real = real
        self.fake = fake
    def __getitem__(self, idx):
        return A(self.real[idx]), A(self.fake[idx])
    def __len__(self): return len(self.real)

In [17]:
trn_seq2seq_ds = GANDataset(trn_p_ids)
val_seq2seq_ds = GANDataset(val_p_ids)

trn_gan_ds = GANDataset(trn_p_ids)
val_gan_ds = GANDataset(val_p_ids)

trn_disc_ds = DiscDataset(trn_p_ids, trn_w_ids[:len(trn_p_ids)])
val_disc_ds = DiscDataset(val_p_ids, val_w_ids[:len(trn_p_ids)])

In [18]:
bs=64

trn_samp = SortishSampler(trn_p_ids, lambda x: len(trn_p_ids[x]), bs)
val_samp = SortSampler(val_p_ids, lambda x: len(val_p_ids[x]))

trn_seq2seq_dl = DataLoader(trn_seq2seq_ds,batch_size=bs,pad_idx=1,num_workers=1,pre_pad=False, transpose_y=True, transpose=True, sampler=trn_samp)
val_seq2seq_dl = DataLoader(val_seq2seq_ds,batch_size=bs,pad_idx=1,num_workers=1,pre_pad=False, transpose_y=True, transpose=True, sampler=val_samp)

trn_gan_dl = DataLoader(trn_gan_ds,batch_size=bs,pad_idx=1,num_workers=1,pre_pad=False, transpose_y=True, transpose=True, sampler=trn_samp)
val_gan_dl = DataLoader(val_gan_ds,batch_size=bs,pad_idx=1,num_workers=1,pre_pad=False, transpose_y=True, transpose=True, sampler=val_samp)

trn_disc_dl = DataLoader(trn_disc_ds,batch_size=bs//2,pad_idx=1,num_workers=1,pre_pad=False, transpose_y=True, transpose=True)
val_disc_dl = DataLoader(val_disc_ds,batch_size=bs//2,pad_idx=1,num_workers=1,pre_pad=False, transpose_y=True, transpose=True)

In [19]:
x, y = next(iter(trn_disc_dl))
x.size(), y.size()

(torch.Size([52, 32]), torch.Size([202, 32]))

In [20]:
def rand_t(*sz): return torch.randn(sz)/math.sqrt(sz[0])
def rand_p(*sz): return nn.Parameter(rand_t(*sz))

In [21]:
class Seq2SeqAttention(nn.Module):
    def __init__(self, vecs, itos, em_sz, nh, nl=2, dropf=1.0):
        super().__init__()
        self.emb_enc = create_emb(vecs, itos, em_sz)
        self.nl,self.nh = nl,nh
        self.lstm_enc = nn.LSTM(em_sz, nh, num_layers=nl, dropout=0.25*dropf, bidirectional=True)
        self.out_enc = nn.Linear(nh*2, em_sz, bias=False)
        self.drop_enc = nn.Dropout(0.25*dropf)
        self.emb_dec = create_emb(vecs, itos, em_sz)
        self.lstm_dec = nn.LSTM(em_sz, em_sz, num_layers=nl, dropout=0.1)
        self.emb_enc_drop = nn.Dropout(0.15*dropf)
        self.out_drop = nn.Dropout(0.35*dropf)
        self.out = nn.Linear(em_sz, len(itos))
        self.out.weight.data = self.emb_dec.weight.data

        self.W1 = rand_p(nh*2, em_sz)
        self.l2 = nn.Linear(em_sz, em_sz)
        self.l3 = nn.Linear(em_sz+nh*2, em_sz)
        self.V = rand_p(em_sz)

    def forward(self, inp, y=None,sampling=False):
        sl,bs = inp.size()
        h_n, c_n = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, (h_n, c_n) = self.lstm_enc(emb, (h_n, c_n))
        h_n = h_n.view(2,2,bs,-1).permute(0,2,1,3).contiguous().view(self.nl,bs,-1)
        h_n = self.out_enc(self.drop_enc(h_n))
        c_n = c_n.view(2,2,bs,-1).permute(0,2,1,3).contiguous().view(self.nl,bs,-1)
        c_n = self.out_enc(self.drop_enc(c_n))
        pg_loss=0
        dec_inp = V(torch.zeros(bs).long())
        res,attns = [],[]
        samples = V(torch.zeros(sl, bs).long())
        w1e = enc_out @ self.W1
        for i in range(sl):
            w2h = self.l2(h_n[-1])
            u = torch.tanh(w1e + w2h)
            a = F.softmax(u @ self.V, 0)
            attns.append(a)
            Xa = (a.unsqueeze(2) * enc_out).sum(0)
            emb = self.emb_dec(dec_inp)
            wgt_enc = self.l3(torch.cat([emb, Xa], 1))
            outp, (h_n, c_n) = self.lstm_dec(wgt_enc.unsqueeze(0), (h_n,c_n))
            outp = self.out(self.out_drop(outp[0]))
            res.append(outp)
            if not sampling:
                dec_inp = V(outp.data.max(1)[1])
            else:
                outp = F.log_softmax(outp, dim=1)
                outp = torch.multinomial(torch.exp(outp), 1)
                samples[i, :] = outp.view(-1).data
                dec_inp = V(outp.view(-1))
            if (dec_inp==1).all(): break
            if (y is not None):
                if i>=len(y): break
                dec_inp = y[i]
        h_n = repackage_hidden(h_n)
        c_n = repackage_hidden(c_n)
        
        return torch.stack(res), samples
    
    
    def initHidden(self, bs): return V(torch.zeros(self.nl*2, bs, self.nh)), V(torch.zeros(self.nl*2, bs, self.nh))

In [22]:
def repackage_hidden(h):
    return h.detach() if isinstance(h, torch.Tensor) else tuple(repackage_hidden(v) for v in h)

In [23]:
def seq2seq_loss(input, target, kld_weight=0):
    decoded = input
    sl, bs = target.size()
    sl_in,bs_in,nc = decoded.size()
    if sl>sl_in: decoded = F.pad(decoded, (0,0,0,0,0,sl-sl_in))
    decoded = decoded[:sl]
    return F.cross_entropy(decoded.view(-1,nc), target.view(-1))

In [24]:
class TextDicriminator(nn.Module):
    def __init__(self,vecs,itos,em_sz,nh,nl=2, dropf=1):
        super().__init__()
        #encoder
        self.emb_enc = create_emb(vecs,itos,em_sz)
        self.emb_enc_drop = nn.Dropout(0.15*dropf)
        self.nl,self.nh, self.vs = nl,nh,len(itos)
        self.lstm_enc = nn.LSTM(em_sz, nh, num_layers=nl, dropout=0.25*dropf, bidirectional=True)
        self.out_drop = nn.Dropout(0.35*dropf)
        #attention
        self.W1 = rand_p(nh*2, em_sz)
        self.l2 = nn.Linear(nh*2, em_sz)
        self.V = rand_p(em_sz)
        
        #classifier
        #self.out_enc = nn.Linear(nh*3*2, em_sz)
        #self.nonlin = nn.Tanh()
        self.out = nn.Linear(nh*2, 1)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, inp, y=None):
        #encode
        sl,bs = inp.size()
        hidden = self.initHidden(bs) # nl*bs*nh
        emb = self.emb_enc_drop(self.emb_enc(inp)) #sl*bs*em_sz
        enc_out, (h_n, c_n) = self.lstm_enc(emb,hidden) #enc_out: sl*bs*nh, hidden: nl*bs*nh
        #classify
        h_n = h_n.view(2,2,bs,-1).permute(0,2,1,3).contiguous().view(self.nl,bs,-1)
        #hidden = torch.cat([enc_out.max(0)[0],enc_out.sum(0),h_n[-1]],1)
        #c = self.nonlin(self.out_enc(hidden))
        #c = self.sigmoid(self.out(self.out_drop(c)))
        out = []
        w1e = enc_out @ self.W1
        #for i in range(sl):
        w2h = self.l2(h_n[-1])
        u = torch.tanh(w1e + w2h)
        a = F.softmax(u @ self.V, 0)
        #out.append(a)
        Xa = (a.unsqueeze(2) * enc_out).sum(0)
            #emb = self.emb_enc(inp[i])
        #wgt_enc = self.l3(torch.cat([h_n[-1], Xa], 1))
            #c = self.nonlin(self.out_enc(hidden))
        c = self.sigmoid(self.out(self.out_drop(Xa)))
            #out.append(self.sigmoid(wgt_enc))
            #enc = self.nonlin(self.out_enc(torch.cat([enc_out.max(0)[0],enc_out.sum(0),enc_out[i]],1)))
            #out.append(self.sigmoid(self.out(self.out_drop(enc))))
        #h_n = repackage_hidden(h_n)
        #c_n = repackage_hidden(c_n)
        #return out
        return  c, a

    def initHidden(self, bs): return V(torch.zeros(self.nl*2, bs, self.nh)), V(torch.zeros(self.nl*2, bs, self.nh))

In [25]:
def create_emb(vecs, itos, em_sz):
    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)
    wgts = emb.weight.data
    miss = []
    for i,w in enumerate(itos):
        try: wgts[i] = torch.from_numpy(vecs[w])
        except: miss.append(w)
    print(len(miss),miss[5:10])
    return emb

In [26]:
disc = TextDicriminator(en_vecs, itos_d, dim_en_vec, 128)
gen = Seq2SeqAttention(en_vecs, itos, dim_en_vec, 128)
gen.cuda(); disc.cuda();

4442 ['<unk>', 't_up', '@.@', '@,@', ',--']
12516 [',--', '’s', 'tk_rep', ';--', 'keewis']
12516 [',--', '’s', 'tk_rep', ';--', 'keewis']


In [27]:
from torch import optim

optimizerD = optim.Adam(disc.parameters(), lr = 1e-3)
optimizerG = optim.Adam(gen.parameters(), lr = 1e-3, betas=(0.7, 0.8))

In [28]:
def train_gen(gen, epochs, trn_dl, val_dl, crit, opt):
    n = epochs
    for epoch in range(epochs):
        with tqdm(total=len(trn_dl)) as pbar:
            gen.train()
            for i, ds in enumerate(trn_dl):
                x, y = ds
                sl, bs = x.size()
                fake = gen(x,y)
                gen.zero_grad()
                gen_loss = creative_loss = crit(fake[0],y)
                gen_loss.backward()
                opt.step()
                torch.nn.utils.clip_grad_norm_(gen.parameters(), 8.)
                pbar.update()
        print(f'Epoch {epoch}:')
        print(f'Loss_G {creative_loss.data.item()} Ppx {torch.exp(creative_loss)}')
        gen.eval()
        for i, ds in enumerate(val_dl):
            x, y = ds
            sl, bs = x.size()
            fake = gen(x)
            gen_loss = creative_loss = crit(fake[0],y)
        print(f'Val Loss_G {creative_loss.data.item()} Ppx {torch.exp(creative_loss)}')

In [162]:
train_gen(gen,20, trn_seq2seq_dl, val_seq2seq_dl,seq2seq_loss, optimizerG)

100%|██████████| 6246/6246 [17:44<00:00,  5.67it/s]
Epoch 0:
Loss_G 4.81662130355835 Ppx 123.54695892333984
Val Loss_G 7.573636054992676 Ppx 1946.203857421875
100%|██████████| 6246/6246 [18:28<00:00,  5.66it/s]
Epoch 1:
Loss_G 5.159106254577637 Ppx 174.00885009765625
Val Loss_G 7.47801399230957 Ppx 1768.724609375
100%|██████████| 6246/6246 [17:42<00:00,  5.97it/s]
Epoch 2:
Loss_G 4.431586742401123 Ppx 84.06470489501953
Val Loss_G 8.16218090057373 Ppx 3505.823974609375
100%|██████████| 6246/6246 [17:00<00:00,  6.18it/s]
Epoch 3:
Loss_G 4.358253002166748 Ppx 78.12053680419922
Val Loss_G 8.25850772857666 Ppx 3860.3291015625
 31%|███       | 1922/6246 [05:20<11:56,  6.04it/s]


KeyboardInterrupt: 

In [163]:
torch.save(gen.state_dict(), PATH/'models/seq2seq_en_poems2.h5')

In [49]:
gen.load_state_dict(torch.load(PATH/'models/seq2seq_en_poems2.h5', map_location=lambda storage, loc: storage))

In [30]:
def produce_out(val_dl, model, use_cuda=True):
    model.eval()
    out = []
    loss = 0
    for i, ds in enumerate(val_dl):
        x, y = ds
        bptt, bs = x.size()
        y = y.view(bptt, bs)
        if use_cuda:
            x = x.cuda()
        probs = model(x)
        loss = seq2seq_loss(probs[0], y)
        preds = np.array(probs[0].topk(2)[1])
        for i in range(x.size(1)):
            seed = ' '.join([itos[o] for o in x[:,i] if o not in [0,1,2,3]])
            actual = ' '.join([itos[o] for o in y[:,i] if o not in [0,1,2,3]])
            fake = ' '.join([itos[o[0]] if o[0] not in [0,1,2,3] else itos[o[-1]] for o in preds[:,i]])
            out.append([seed,actual,fake])
    print(torch.exp(loss))
    return out

In [31]:
out_seq2seq = produce_out(val_gan_dl, gen)
pickle.dump(out_seq2seq, (PATH/'seq2seq_en.pkl').open('wb'))

tensor(64036.4219, device='cuda:0', grad_fn=<ExpBackward>)


In [34]:
out_seq2seq = pickle.load((PATH/'seq2seq_en.pkl').open('rb'))
out_seq2seq[1][0],out_seq2seq[1][1], out_seq2seq[1][2] 

('a on both sides you fall into love , and as you fall out the time while under its are like',
 'on the single does not cut how soon hath time , the subtle thief of youth , on his wing my three',
 'dianna , but the same was not to . the first of the end was the first of the most river , the <unk> @-@ <unk>')

In [32]:
out_seq2seq

[['i am hunting my wayward heart.i can feel that steady drum beat , following the ruby fingertips like , it to the',
  'continually the heart flees , running from this only ardor can all know it is only must armor ourselves is',
  'world , with the and , as the my soul , with joy and . . at the from of a , never could be ,'],
 ['a on both sides -as you fall into love , and as you fall out -all the time while under its are like',
  'on the single does not cut how soon hath time , the subtle thief of youth , on his wing my three',
  'the the to you and day you will never meet the way of the and of his the , and the bellman in tell you ,'],
 ['and year ! my days fly on full career , but my late spring no bud or blossom . perhaps my',
  'might deceive the truth , that i to manhood am arrived so near , and inward doth much less appear , that some',
  'love , and that is there ? no , but here is there to - day and be the same , that s not even no'],
 ['more - happy spirits . yet be it less o

In [33]:
def train_disc(disc, epochs, trn_dl, val_dl, opt):
    for epoch in range(epochs):
        with tqdm(total=len(trn_dl)) as pbar:
            disc.train()
            for p in disc.parameters(): p.data.clamp_(-0.01, 0.01)
            for i, ds in enumerate(trn_dl):
                x, y = ds
                disc.zero_grad()
                real_loss, _ = disc(x)
                fake_loss, _ = disc(y)
                disc_loss = (1 - real_loss + fake_loss).mean(0)
                disc_loss.backward()
                opt.step()
                pbar.update()
            print(f'Epoch {epoch}:')
            print(f'Loss_Real {real_loss.mean(0).data.item()}; Loss_Fake {fake_loss.mean(0).data.item()}; Loss_Disc {disc_loss.data.item()} ')
        disc.eval()
        for i, ds in enumerate(val_dl):
            x, y = ds
            real_loss, _ = disc(x)
            fake_loss, _ = disc(y)
            disc_loss = (1 - real_loss + fake_loss).mean(0)
        print(f'Val Loss_Real {real_loss.mean(0).data.item()}; Val Loss_Fake {fake_loss.mean(0).data.item()}; Loss_Disc {disc_loss.data.item()} ')

In [169]:
train_disc(disc, 5, trn_disc_dl, val_disc_dl, optimizerD)

100%|█████████▉| 12491/12492 [12:13<00:00, 15.27it/s]Epoch 0:
Loss_Real 1.0; Loss_Fake 1.3681117350117233e-10; Loss_Disc 8.358331626823201e-08 
100%|██████████| 12492/12492 [12:13<00:00, 17.03it/s]
Val Loss_Real 1.0; Val Loss_Fake 2.659799464010959e-11; Loss_Disc 2.659799464010959e-11 
  7%|▋         | 814/12492 [00:43<10:27, 18.61it/s]


KeyboardInterrupt: 

In [170]:
torch.save(disc.state_dict(), PATH/'models/disc_en_5.h5')

In [50]:
disc.load_state_dict(torch.load(PATH/'models/disc_en_5.h5', map_location=lambda storage, loc: storage))

In [35]:
def ppx(input, target):
    sl, bs = target.size()
    sl_in,bs_in,nc = input.size()
    if sl>sl_in: input = F.pad(input, (0,0,0,0,0,sl-sl_in))
    input = input[:sl]
    return torch.exp(F.cross_entropy(input.view(-1,nc), target.view(-1)))

In [51]:
def train(gen, disc, epochs, trn_dl, val_dl, optimizerD, optimizerG, first=True):
    gen_iterations = 0
    
    for epoch in range(epochs):
        gen.train(); disc.train()
        n = len(trn_dl)
        #train loop
        with tqdm(total=n) as pbar:
            for i, ds in enumerate(trn_dl):
                d_iters = 20 if (first and (epoch < 2)) else 5
                g_iters = 20 if (first and (epoch < 2)) else 3
                x, y = ds
                sl, bs = x.size()
                for j in range(d_iters):
                    for p in disc.parameters(): p.data.clamp_(-0.01, 0.01)
                    gen.eval()
                    disc.train()
                    _,fake = gen(x, sampling=True)
                    disc.zero_grad()
                    fake_loss, _ = disc(V(fake))
                    real_loss, _ = disc(V(y))
                    disc_loss = (1 - real_loss +fake_loss).mean(0)
                    disc_loss.backward()
                    optimizerD.step()
                
                disc.eval(), gen.train()
                fake_, fake_sample = gen(x,y,sampling=True)
                fake = F.log_softmax(fake_,dim=2)
                reward_t, _ = disc(y)
                reward_f, attn_f = disc(fake_sample)
                missing = 1 - y==fake_sample
                gen.zero_grad()
                gen_loss = 0
                for i in range(sl):
                    for j in range(bs):
                        gen_loss += -fake[i,j,fake_sample[i,j]] * reward_f.squeeze(1)[j] #* attn_f[i,j]
                gen_loss = V(gen_loss, requires_grad=True)/bs
                gen_loss.backward()
                optimizerG.step()
                gen_iterations += 1
                pbar.update()
        print(f'Epoch {epoch}:')
        print(f'Loss_D {disc_loss.data.item()}; Loss_G {gen_loss.data.item()} Ppx {ppx(fake_,y)}')
        print(f'D_real {real_loss.mean(0).view(1).data.item()}; Loss_D_fake {fake_loss.mean(0).view(1).data.item()}')
        # val loop
        gen.eval(), disc.eval()
        for i, ds in enumerate(val_dl):
            x, y = ds
            sl, bs = x.size()
            gen_loss=0
            fake, fake_sample = gen(x, sampling=True)
            real_loss, _ = disc(y)
            fake_loss, attn_f = disc(fake_sample)
            disc_loss = (1-real_loss+fake_loss).mean(0)
            for i in range(sl):
                for j in range(bs):
                    gen_loss += -fake[i,j,fake_sample[i,j]] * fake_loss.squeeze(1)[j] #* attn_f[i,j]
        print(f'Loss_D_val {disc_loss.data.item()}; Loss_G_val {gen_loss.mean(0).data.item()} Ppx {ppx(fake,y)}')
        print(f'D_real_val {real_loss.mean(0).view(1).data.item()}; Loss_D_fake_val {fake_loss.mean(0).view(1).data.item()}')

In [216]:
gen.load_state_dict(torch.load(PATH/'models/seq2seq_en.h5', map_location=lambda storage, loc: storage)) 
disc.load_state_dict(torch.load(PATH/'models/disc_en.h5', map_location=lambda storage, loc: storage)) 

In [41]:
def produce_out_sampling(val_dl, model, use_cuda=True):
    model.eval()
    out = []
    loss = 0
    for i, ds in enumerate(val_dl):
        x, y = ds
        bptt, bs = x.size()
        y = y.view(bptt, bs)
        #x = x.long()
        if use_cuda:
            x = x.cuda()
        probs, sample = model(x, sampling=True)
        p = ppx(probs, y)
        preds = np.array(sample)
        for i in range(x.size(1)):
            seed = ' '.join([itos[o] for o in x[:,i] if o not in [0,1,2,3]])
            actual = ' '.join([itos[o] for o in y[:,i] if o not in [0,1,2,3]])
            fake = ' '.join([itos[o] for o in preds[:,i] if o not in [0,1,2,3]])
            out.append([seed,actual,fake])
    print(p.data.item())
    return out

In [122]:
torch.save(gen.state_dict(), PATH/'models/netG.h5')
torch.save(disc.state_dict(), PATH/'models/netD.h5')

In [127]:
gen.load_state_dict(torch.load(PATH/'models/netG.h5', map_location=lambda storage, loc: storage))
disc.load_state_dict(torch.load(PATH/'models/netD.h5', map_location=lambda storage, loc: storage))

In [54]:
out_gan2 = produce_out_sampling(val_gan_dl, gen)
out_gan = produce_out(val_gan_dl, gen)
pickle.dump(out_gan, (PATH/'v_gan_en.pkl').open('wb'))

65609.6015625
tensor(64036.4219, device='cuda:0', grad_fn=<ExpBackward>)


In [57]:
out_gan

[['i am hunting my wayward heart.i can feel that steady drum beat , following the ruby fingertips like , it to the',
  'continually the heart flees , running from this only ardor can all know it is only must armor ourselves is',
  'world , with the and , as the my soul , with joy and . . at the from of a , never could be ,'],
 ['a on both sides -as you fall into love , and as you fall out -all the time while under its are like',
  'on the single does not cut how soon hath time , the subtle thief of youth , on his wing my three',
  'the the to you and day you will never meet the way of the and of his the , and the bellman in tell you ,'],
 ['and year ! my days fly on full career , but my late spring no bud or blossom . perhaps my',
  'might deceive the truth , that i to manhood am arrived so near , and inward doth much less appear , that some',
  'love , and that is there ? no , but here is there to - day and be the same , that s not even no'],
 ['more - happy spirits . yet be it less o

In [46]:
out_gan = pickle.load((PATH/'v_gan_en.pkl').open('rb'))

In [56]:
out_gan2

[['i am hunting my wayward heart.i can feel that steady drum beat , following the ruby fingertips like , it to the',
  'continually the heart flees , running from this only ardor can all know it is only must armor ourselves is',
  'comes , sweet she , about a unless they gentle some being man and a climb to make a above the each after . still'],
 ['a on both sides -as you fall into love , and as you fall out -all the time while under its are like',
  'on the single does not cut how soon hath time , the subtle thief of youth , on his wing my three',
  'when the wound is many time go home to - night with die or world , whom after nature is by thou hast . make'],
 ['and year ! my days fly on full career , but my late spring no bud or blossom . perhaps my',
  'might deceive the truth , that i to manhood am arrived so near , and inward doth much less appear , that some',
  'laugh at this whose seven alone above the summer , better one to about had hunt the beat and fair , save i all in'],
 