## Trial One

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        print("Mask shape (in attention): ", mask.shape)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        print("Att. scores shape:", attn_scores.shape)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        print("In combined Heads:", x.shape)
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        
        print("Query shape before", Q.shape)
        print("Key shape before", K.shape)
        print("Value shape before", V.shape)

        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        print("Query shape after", Q.shape)
        print("Key shape after", K.shape)
        print("Value shape after", V.shape)
        print("------------ LAYER FINISHED --------------")
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
    

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
    

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
    

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x
    

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        print(src_mask.shape)
        print(tgt_mask.shape)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        print("+++++++++++++++++++++++++++++++++++++++++ ENCODER FINISHED +++++++++++++++++++++++++++++++++++++++++++++")
        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output

In [6]:
src_vocab_size = 5000
tgt_vocab_size = 10000
d_model = 512
num_heads = 8
num_layers = 4
d_ff = 2048
max_seq_length_src = 25
max_seq_length_tgt = 50
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, 2*max_seq_length_src, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (4, max_seq_length_src))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (4, max_seq_length_tgt))  # (batch_size, seq_length)
print(src_data.shape)
print(tgt_data.shape)

torch.Size([4, 25])
torch.Size([4, 50])


In [7]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data)
    print(output.shape)
    print(output.contiguous().view(-1, tgt_vocab_size).shape)
    print(tgt_data.contiguous().view(-1).shape)
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data.contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
    break

torch.Size([4, 1, 1, 25])
torch.Size([4, 1, 50, 50])
Query shape before torch.Size([4, 25, 512])
Key shape before torch.Size([4, 25, 512])
Value shape before torch.Size([4, 25, 512])
Query shape after torch.Size([4, 8, 25, 64])
Key shape after torch.Size([4, 8, 25, 64])
Value shape after torch.Size([4, 8, 25, 64])
------------ LAYER FINISHED --------------
Mask shape (in attention):  torch.Size([4, 1, 1, 25])
Att. scores shape: torch.Size([4, 8, 25, 25])
In combined Heads: torch.Size([4, 8, 25, 64])
Query shape before torch.Size([4, 25, 512])
Key shape before torch.Size([4, 25, 512])
Value shape before torch.Size([4, 25, 512])
Query shape after torch.Size([4, 8, 25, 64])
Key shape after torch.Size([4, 8, 25, 64])
Value shape after torch.Size([4, 8, 25, 64])
------------ LAYER FINISHED --------------
Mask shape (in attention):  torch.Size([4, 1, 1, 25])
Att. scores shape: torch.Size([4, 8, 25, 25])
In combined Heads: torch.Size([4, 8, 25, 64])
Query shape before torch.Size([4, 25, 512])

## Trial Two

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim

import math
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

In [38]:
class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)
        
        # Info
        self.dropout = nn.Dropout(dropout_p)
        
        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1) # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model) # 1000^(2i/dim_model)
        
        # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        
        # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding",pos_encoding)
        
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

In [91]:
class Transformer(nn.Module):
    """
    Model from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/p/c80afbc9ffb1/
    """
    # Constructor
    def __init__(self,num_tokens,dim_model,num_heads,num_encoder_layers,num_decoder_layers,dropout_p):
        super().__init__()

        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model

        # LAYERS
        self.positional_encoder = PositionalEncoding(dim_model=dim_model, dropout_p=dropout_p, max_len=5000)
        self.embedding = nn.Embedding(num_tokens, dim_model)
        self.transformer = nn.Transformer(d_model=dim_model,nhead=num_heads,num_encoder_layers=num_encoder_layers,num_decoder_layers=num_decoder_layers,dropout=dropout_p)
        self.out = nn.Linear(dim_model, num_tokens)

    def forward(self,src,tgt,tgt_mask):
        # Src size must be (batch_size, src sequence length)
        # Tgt size must be (batch_size, tgt sequence length)

        # Embedding + positional encoding - Out size = (batch_size, sequence length, dim_model)
        src = self.embedding(src)*math.sqrt(self.dim_model)
        tgt = self.embedding(tgt)*math.sqrt(self.dim_model)
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)

        # we permute to obtain size (sequence length, batch_size, dim_model),
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)

        # Transformer blocks - Out size = (sequence length, batch_size, num_tokens)
        transformer_out = self.transformer(src, tgt, src_mask=None, tgt_mask=tgt_mask)
        out = self.out(transformer_out)

        return out
    
    def get_tgt_mask(self, size) -> torch.tensor:
        # Generates a squeare matrix where the each row allows one word more to be seen
        mask = torch.tril(torch.ones(size, size) == 1) # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf')) # Convert zeros to -inf
        mask = mask.masked_fill(mask == 1, float(0.0)) # Convert ones to 0
        
        # EX for size=5:
        # [[0., -inf, -inf, -inf, -inf],
        #  [0.,   0., -inf, -inf, -inf],
        #  [0.,   0.,   0., -inf, -inf],
        #  [0.,   0.,   0.,   0., -inf],
        #  [0.,   0.,   0.,   0.,   0.]]
        
        return mask
    
    def create_pad_mask(self, matrix: torch.tensor, pad_token: int) -> torch.tensor:
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        return (matrix == pad_token)

In [101]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(num_tokens=30522, dim_model=512, num_heads=8, num_encoder_layers=4, num_decoder_layers=4, dropout_p=0.1).to(device)
opt = optim.AdamW(model.parameters(), lr=0.0001)
loss_fn = nn.CrossEntropyLoss()

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
    
#optimizer = optim.AdamW(model.parameters(), lr=0.0001)
print("The trainable parameters of the model are: ", count_parameters(model))

The trainable parameters of the model are:  60712762


In [102]:
import pickle 

path_train_a = f'C:/Users/admitos/Desktop/ThesisUU/Phase_2/data_phase2/Combined/train_caps_1.pkl'
path_val_a = f'C:/Users/admitos/Desktop/ThesisUU/Phase_2/data_phase2/Combined/val_caps_1.pkl'
path_test_a = f'C:/Users/admitos/Desktop/ThesisUU/Phase_2/data_phase2/Combined/test_caps_1.pkl'

with open(path_train_a, 'rb') as f:
    tokenized_train_captions1 = pickle.load(f)

with open(path_val_a, 'rb') as f:
    tokenized_val_captions1 = pickle.load(f)
    
with open(path_test_a, 'rb') as f:
    tokenized_test_captions1 = pickle.load(f)

path_train_stories = f'C:/Users/admitos/Desktop/ThesisUU/Phase_2/data_phase2/Combined/train_stories.pkl'
path_val_stories = f'C:/Users/admitos/Desktop/ThesisUU/Phase_2/data_phase2/Combined/val_stories.pkl'
path_test_stories = f'C:/Users/admitos/Desktop/ThesisUU/Phase_2/data_phase2/Combined/test_stories.pkl'

with open(path_train_stories, 'rb') as f:
    tokenized_train_stories = pickle.load(f)

with open(path_val_stories, 'rb') as f:
    tokenized_val_stories = pickle.load(f)

with open(path_test_stories, 'rb') as f:
    tokenized_test_stories = pickle.load(f)

class CustomDataset(Dataset):
    def __init__(self, entries, references):
        self.entries = entries
        self.references = references

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx):
        item = {}
        item['caption'] = self.entries[idx]
        item['story'] = self.references [idx] 
        return item

        
def my_collate_fn(batch):
    input_ids = [item['caption'].transpose(0,1) for item in batch]
    target_ids = [item['story'].transpose(0,1) for item in batch]
    
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=0)
    target_ids_padded = pad_sequence(target_ids, batch_first=True, padding_value=0)

    final_input_ids = [item.transpose(0,1) for item in input_ids_padded]
    final_labels = [item.transpose(0,1) for item in target_ids_padded]

    return {'caption_ids': final_input_ids, 'story_ids': final_labels}

In [103]:
b_s = 10
train_dataset1 = CustomDataset(tokenized_train_captions1, tokenized_train_stories)
train_loader1 = DataLoader(train_dataset1, batch_size=b_s, shuffle=False, collate_fn=my_collate_fn)

## Prepareing Validation Loaders 
val_dataset1 = CustomDataset(tokenized_val_captions1, tokenized_val_stories)
val_loader1 = DataLoader(val_dataset1, batch_size=b_s, shuffle=False, collate_fn=my_collate_fn)

## Prepareing Test Loaders 
test_dataset1 = CustomDataset(tokenized_test_captions1, tokenized_test_stories)
test_loader1 = DataLoader(test_dataset1, batch_size=b_s, shuffle=False, collate_fn=my_collate_fn)


In [104]:
for batch in train_loader1:
    print("Captions shape:", batch['caption_ids'][0].shape, batch['caption_ids'][1].shape)
    print("Stories shape: ", batch['story_ids'][0].shape, batch['story_ids'][0].shape)
    break

Captions shape: torch.Size([1, 79]) torch.Size([1, 79])
Stories shape:  torch.Size([1, 81]) torch.Size([1, 81])


In [105]:
def transform_before_model(batch):
    batch_caption_ids = torch.stack(batch['caption_ids']).squeeze(1)
    batch_story_ids = torch.stack(batch['story_ids']).squeeze(1)
    return batch_caption_ids, batch_story_ids

In [106]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
    
optimizer = optim.Adam(model.parameters(), lr=0.0001)
print("The trainable parameters of the model are: ", count_parameters(model))

The trainable parameters of the model are:  60712762


In [131]:
def train_loop(model, opt, loss_fn, dataloader):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """
    
    model.train()
    total_loss = 0
    
    for idx,batch in enumerate(dataloader):
        caps, stories = transform_before_model(batch)
        # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
        y_input = stories[:,:-1]
        y_expected = stories[:,1:]
        
        # Get mask to mask out the next words
        sequence_length = y_input.size(1)
        tgt_mask = model.get_tgt_mask(sequence_length)

        # Standard training except we pass in y_input and tgt_mask
        pred = model(caps, y_input, tgt_mask)
        #pred = model(caps, y_input)

        # Permute pred to have batch size first again
        pred = pred.permute(1, 2, 0)      
        loss = loss_fn(pred, y_expected)
        print(f'Iteration: {idx}/{len(dataloader)}, with running loss: {loss.item()}')

        opt.zero_grad()
        loss.backward()
        opt.step()
    
        total_loss += loss.detach().item()
        
    return total_loss / len(dataloader)

In [132]:
def validation_loop(model, loss_fn, dataloader):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """
    
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            caps, stories = transform_before_model(batch)
        
            # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
            y_input = stories[:,:-1]
            y_expected = stories[:,1:]
            
            # Get mask to mask out the next words
            sequence_length = y_input.size(1)
            tgt_mask = model.get_tgt_mask(sequence_length)

            # Standard training except we pass in y_input and src_mask
            pred = model(caps, y_input, tgt_mask)
            #pred = model(caps, y_input)
            
            # Permute pred to have batch size first again
            pred = pred.permute(1, 2, 0)      
            loss = loss_fn(pred, y_expected)
            print(f'Iteration: {idx}/{len(dataloader)}, with running loss: {loss.item()}')
            total_loss += loss.detach().item()

        
    return total_loss / len(dataloader)

In [133]:
def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """
    
    # Used for plotting later on
    train_loss_list, validation_loss_list = [], []
    
    print("Training and validating model")
    for epoch in range(epochs):
        print("-"*25, f"Epoch {epoch + 1}","-"*25)
        
        train_loss = train_loop(model, opt, loss_fn, train_dataloader)
        train_loss_list += [train_loss]
        
        validation_loss = validation_loop(model, loss_fn, val_dataloader)
        validation_loss_list += [validation_loss]
        
        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {validation_loss:.4f}")
        print()
        
    return train_loss_list, validation_loss_list
    
train_loss_list, validation_loss_list = fit(model, opt, loss_fn, train_loader1, val_loader1, 10)


Training and validating model
------------------------- Epoch 1 -------------------------
Iteration: 0/2657, with running loss: 3.435696601867676
Iteration: 1/2657, with running loss: 3.595099925994873
Iteration: 2/2657, with running loss: 3.0748696327209473
Iteration: 3/2657, with running loss: 3.4868874549865723
Iteration: 4/2657, with running loss: 3.357516288757324
Iteration: 5/2657, with running loss: 2.9999752044677734
Iteration: 6/2657, with running loss: 3.655576467514038
Iteration: 7/2657, with running loss: 3.0454952716827393
Iteration: 8/2657, with running loss: 2.570434808731079
Iteration: 9/2657, with running loss: 2.8568203449249268
Iteration: 10/2657, with running loss: 3.0220389366149902
Iteration: 11/2657, with running loss: 2.7257168292999268
Iteration: 12/2657, with running loss: 3.416012763977051
Iteration: 13/2657, with running loss: 3.6815953254699707
Iteration: 14/2657, with running loss: 3.2250218391418457
Iteration: 15/2657, with running loss: 2.964179515838623

KeyboardInterrupt: 

In [134]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
start_token_id = tokenizer.cls_token_id
eos_token_id = tokenizer.sep_token_id

def predict(model, input_sequence, SOS_token, EOS_token, max_length=150):
    """
    Method from "A detailed guide to Pytorch's nn.Transformer() module.", by
    Daniel Melchor: https://medium.com/@danielmelchor/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
    """
    model.eval()
    y_input = torch.tensor([[SOS_token]], dtype=torch.long)

    num_tokens = len(input_sequence[0])

    for _ in range(max_length):
        # Get source mask
        tgt_mask = model.get_tgt_mask(y_input.size(1))
        pred = model(input_sequence, y_input, tgt_mask)
        
        next_item = pred.topk(1)[1].view(-1)[-1].item() # num with highest probability
        next_item = torch.tensor([[next_item]])

        # Concatenate previous input with predicted best word
        y_input = torch.cat((y_input, next_item), dim=1)

        # Stop if model predicts end of sentence
        if next_item.view(-1).item() == EOS_token:
            break

    return y_input.view(-1).tolist()

In [135]:
new_entry = (['A tall, barren tree by a flowing creek.',
   'A fallen tree trunk on a broken wood bridge.',
   'The iron gate of a small palace with shrubs on the side.',
   'Perpendicular plants are standing straight out of a muddy ground.',
   'An outdoor shot showing a river up around leafless trees.'],
  'The swamp had started as an upsetting stream of dead fishes, flies, and frogs. After having made it across the wobbly bridge. I finally got to the house just beyond the marsh. Even before that I walked through the suffocating swamp. I had to cross the creepy lake.')

new_captions = new_entry[0]
new_story = new_entry[1]
input_text = ' [SEP] '.join(new_captions)
print(input_text)
print()
print(new_story)
caption_ids = tokenizer(input_text, return_tensors="pt").input_ids
target_ids = tokenizer(new_story, return_tensors="pt").input_ids
print()
print(caption_ids.shape)
print(target_ids.shape)

A tall, barren tree by a flowing creek. [SEP] A fallen tree trunk on a broken wood bridge. [SEP] The iron gate of a small palace with shrubs on the side. [SEP] Perpendicular plants are standing straight out of a muddy ground. [SEP] An outdoor shot showing a river up around leafless trees.

The swamp had started as an upsetting stream of dead fishes, flies, and frogs. After having made it across the wobbly bridge. I finally got to the house just beyond the marsh. Even before that I walked through the suffocating swamp. I had to cross the creepy lake.

torch.Size([1, 62])
torch.Size([1, 61])


In [136]:
generated_ids = predict(model, caption_ids, start_token_id, eos_token_id)
print(generated_ids)
story = tokenizer.decode(generated_ids, skip_special_tokens=True)

[101, 1996, 3193, 1997, 1996, 2103, 2001, 3376, 1012, 1996, 2103, 2001, 3376, 1012, 1996, 3193, 2001, 3376, 1012, 1996, 2103, 2001, 3376, 1012, 1996, 2103, 2001, 3376, 1012, 1996, 2103, 2001, 3376, 1012, 102]


In [137]:
print(story)

the view of the city was beautiful. the city was beautiful. the view was beautiful. the city was beautiful. the city was beautiful. the city was beautiful.
