In [8]:
train_data = [
    (['Four older adults are sitting in the back yard, watching a toddler who is standing near them.', 
      'Young girl in pink hat eating a meal of bread, fruits, and vegetables.',
      'A girl is smiling while swinging on a swing',
      'An older man sitting with a young boy in a hat.',
      'The grandfather feeds his granddaughter as the rest of the family looks on.'],
     "The adults sat in their lawn chairs supervising the children. My niece had some of her favorite food to eat, including fruits. She went to play on her new swing set after eating. My nephew happily kept my uncle company. Today was my niece's birthday so we went over to celebrate."),
     (['A large building that has been restored sitting against a clear sky.',
   'A picture of buildings with Chinese architecture. The area seems to be relatively empty except for a few people in the distance, and the sky is clear.',
   'What a wonderful looking building, very different and unique.',
   'A column lined path leads to a temple',
   'Asian garden with fox statue, lantern, gates and trees'],
  'They used such distinctive architecture and paid so much attention to details. Our trip to Japan was amazing. We got to see how they built their temples. Even the steps leading up to the temples were cool. They were lined with statues that all had important meanings behind them.'),
 (['The man is wearing a red cap and is participating in a marathon.',
   'Runner number 1212 walks through the airport with his bag over his shoulder before the race.',
   'Two people are giving one another a high five.',
   'A man in a white t-shirt stands behind a group of people who have a glass ceiling above them.',
   'A group of runners taking a group photo.'],
  'There was a real atmosphere of eagerness. Many of the runners were excellent athletes. But regardless of skill everyone there had fun. All the runners were excited to help a good cause. Some group photos would forever commemorate the event.')
     ]

def prepare_data(data):
    inputs = []
    outputs = []
    for captions, story in data:
        input_text = ' [SEP] '.join(captions)
        inputs.append(input_text)
        outputs.append(story)
    return inputs, outputs

train_inputs1, train_outputs1 = prepare_data([train_data[0]])
train_inputs2, train_outputs2 = prepare_data([train_data[1]])
train_inputs3, train_outputs3 = prepare_data([train_data[2]])
all_train_inputs = [train_inputs1,train_inputs2,train_inputs3]
all_train_outputs = [train_outputs1,train_outputs2,train_outputs3]

In [9]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('t5-small')
train_encodings = [tokenizer(train_inputs, padding=True, truncation=True, return_tensors="pt") for train_inputs in all_train_inputs]
train_labels = [tokenizer(train_outputs, padding=True, truncation=True, return_tensors="pt").input_ids for train_outputs in all_train_outputs]

In [12]:
from torch.utils.data import DataLoader, Dataset
import torch

class StoryDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val.clone().detach() for key, val in self.encodings[idx].items()}
        item['labels'] = self.labels[idx].clone().detach()
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = StoryDataset(train_encodings, train_labels)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

for btch in train_loader:
    print(btch)

{'input_ids': tensor([[[   71,   508,   740,    24,    65,   118, 13216,  3823,   581,     3,
              9,   964,  5796,     5,   784,   134,  8569,   908,    71,  1554,
             13,  3950,    28,  2830,  4648,     5,    37,   616,  1330,    12,
             36,  4352,  6364,  3578,    21,     3,     9,   360,   151,    16,
              8,  2357,     6,    11,     8,  5796,    19,   964,     5,   784,
            134,  8569,   908,   363,     3,     9,  1627,   479,   740,     6,
            182,   315,    11,   775,     5,   784,   134,  8569,   908,    71,
           6710, 14372,  2071,  3433,    12,     3,     9,  7657,   784,   134,
           8569,   908,  6578,  2004,    28,     3, 20400, 12647,     6, 24167,
              6, 18975,    11,  3124,     1]]]), 'attention_mask': tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [None]:
# Initialize the model
model = T5ForConditionalGeneration.from_pretrained('t5-small')
model.train()
# Define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [33]:
# Training loop
num_epochs = 20
for epoch in range(num_epochs):
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        input_ids = input_ids.squeeze(1)
        attention_mask = attention_mask.squeeze(1)
        labels = labels.squeeze(1)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch} completed with loss: {loss.item()}')


Epoch 0 completed with loss: 0.018168499693274498
Epoch 1 completed with loss: 0.017402423545718193
Epoch 2 completed with loss: 0.021405942738056183
Epoch 3 completed with loss: 0.008466326631605625
Epoch 4 completed with loss: 0.019594982266426086
Epoch 5 completed with loss: 0.008092958480119705
Epoch 6 completed with loss: 0.01446730550378561
Epoch 7 completed with loss: 0.014023566618561745
Epoch 8 completed with loss: 0.007621606346219778
Epoch 9 completed with loss: 0.013131316751241684
Epoch 10 completed with loss: 0.007378798443824053
Epoch 11 completed with loss: 0.012347742915153503
Epoch 12 completed with loss: 0.013710255734622478
Epoch 13 completed with loss: 0.011672087013721466
Epoch 14 completed with loss: 0.006939155049622059
Epoch 15 completed with loss: 0.012357023544609547
Epoch 16 completed with loss: 0.011987525969743729
Epoch 17 completed with loss: 0.01163025014102459
Epoch 18 completed with loss: 0.010244892910122871
Epoch 19 completed with loss: 0.01094258762

In [34]:
model.eval()
new_captions = ["you cant get to the lake from here but it is a great view",
   'A field that is near the waters of the beach.',
   'A tree in the forest has fallen over on the grass.',
   'A large group of deciduous trees are behind a grassy field.',
   'People walking down a hill with a bunch of animals on it.']

input_text = ' [SEP] '.join(new_captions)
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# Generate narrative story
with torch.no_grad():
    generated_ids = model.generate(input_ids, max_length=200)
story = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(story)


There is no way to get to the lake from here but it has a great view. My niece had some of her favorite things to do with the lake. She went to see how she got to see her. She went to see her niece. She went to see her niece. She went to see her niece niece.
