In [1]:
!git clone https://github.com/Hannibal96/ImageCaptionProject.git

Cloning into 'ImageCaptionProject'...
remote: Enumerating objects: 202, done.[K
remote: Counting objects: 100% (202/202), done.[K
remote: Compressing objects: 100% (163/163), done.[K
remote: Total 202 (delta 116), reused 100 (delta 37), pack-reused 0[K
Receiving objects: 100% (202/202), 8.90 MiB | 23.07 MiB/s, done.
Resolving deltas: 100% (116/116), done.


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%cd ImageCaptionProject/

/content/ImageCaptionProject


In [4]:
!git status

On branch master
Your branch is up to date with 'origin/master'.

nothing to commit, working tree clean


In [5]:
from data import *

In [6]:
from models import *

In [16]:
from data import *
from models import *
import torchvision.transforms as T
from torchtext.vocab import GloVe # for pretrained model
import torch.optim as optim
from torch.utils.data import DataLoader
import pickle
from nltk.translate.bleu_score import sentence_bleu
from tqdm import tqdm

def save_model(model, num_epochs):
    path = "caption_model_E_"+str(num_epochs)+".torch"
    torch.save(model, path)


def evaluate(model, val_data_set):
    blew_score = 0
    total_val_loss = 0

    with torch.no_grad():
        for idx, (image, captions) in tqdm(enumerate(iter(val_data_set))):
            image, captions = image.to(device), captions.to(device)

            all_captions = val_dataset.get_last_captions()
            features = model.encoder(image[0:1].to(device))
            caps, alphas = model.decoder.generate_caption(features, vocab=vocab)
            hyp_caption = ' '.join(caps)
            curr_blew_score = sentence_bleu(references=all_captions, hypothesis=hyp_caption.split())
            blew_score += curr_blew_score

            outputs, attentions = model(image, captions)
            outputs = outputs.to(device)
            targets = captions[:, 1:]
            targets = targets.to(device)
            val_loss = criterion(outputs.view(-1, vocab_size), targets.reshape(-1))
            total_val_loss += val_loss.item()

    perplexity = total_val_loss / len(val_data_set)
    perplexity = np.exp(perplexity)
    return blew_score / len(val_data_set), total_val_loss/len(val_data_set), perplexity



In [None]:
!ls 

In [12]:


images_path = '/content/drive/MyDrive/ImageCaption/Images'
captions_file_path = "captions.txt"
karpathy_json_path = 'Karpathy_data.json'

# define the transforms to be applied which needed for the pretrained CNN
transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

#build vocab
vocab = build_vocab(captions_file_path=captions_file_path)

#build datasets
train, val, test = karpathy_split(captions_file_path, karpathy_json_path)
train_dataset = FlickrDataset(root_dir=images_path,vocab= vocab, captions_df=train,transform=transforms)
val_dataset = FlickrDataset(root_dir=images_path,vocab= vocab, captions_df=val,transform=transforms)
test_dataset = FlickrDataset(root_dir=images_path,vocab= vocab, captions_df=test,transform=transforms)
print("Finished building the Datasets.")

# Hyperparams
weights_matrix = None
# load pretrained embeddings (to train embeddings from scrach just set the embed_size)
#g = GloVe(name ='6B', dim=100)
embed_size = 300#g.dim
#weights_matrix = load_embedding_weights(vocab, g)

vocab_size = len(vocab)
attention_dim = 256
encoder_dim = 2048
decoder_dim = 512
learning_rate = 3e-4
BATCH_SIZE = 32

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# init model
model = EncoderDecoder(
    embed_size=embed_size,
    vocab_size=vocab_size,
    attention_dim=attention_dim,
    encoder_dim=encoder_dim,
    decoder_dim=decoder_dim,
    embedding_weights = weights_matrix
).to(device)

pad_idx = vocab.stoi["<PAD>"]
train_data = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, num_workers=0,
                        shuffle=True, collate_fn=CapsCollate(pad_idx=pad_idx, batch_first=True))
val_data = DataLoader(dataset=val_dataset, batch_size=1, num_workers=0,
                      shuffle=True, collate_fn=CapsCollate(pad_idx=pad_idx, batch_first=True))

criterion = nn.CrossEntropyLoss(ignore_index=vocab.stoi["<PAD>"]).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

num_epochs = 10
print_every = 1000


  0%|          | 0/40455 [00:00<?, ?it/s][A
  3%|▎         | 1083/40455 [00:00<00:03, 10812.38it/s][A
  6%|▌         | 2358/40455 [00:00<00:03, 11325.97it/s][A
  9%|▉         | 3781/40455 [00:00<00:03, 12063.65it/s][A
 13%|█▎        | 5418/40455 [00:00<00:02, 13093.77it/s][A
 18%|█▊        | 7096/40455 [00:00<00:02, 14016.82it/s][A
 22%|██▏       | 8963/40455 [00:00<00:02, 15123.82it/s][A
 27%|██▋       | 10868/40455 [00:00<00:01, 16120.49it/s][A
 31%|███       | 12637/40455 [00:00<00:01, 16559.19it/s][A
 35%|███▌      | 14340/40455 [00:00<00:01, 16697.16it/s][A
 40%|████      | 16331/40455 [00:01<00:01, 17545.20it/s][A
 45%|████▌     | 18317/40455 [00:01<00:01, 18175.96it/s][A
 50%|████▉     | 20142/40455 [00:01<00:01, 17680.06it/s][A
 55%|█████▍    | 22145/40455 [00:01<00:00, 18321.67it/s][A
 59%|█████▉    | 24016/40455 [00:01<00:00, 18435.04it/s][A
 64%|██████▍   | 26034/40455 [00:01<00:00, 18924.53it/s][A
 69%|██████▉   | 27936/40455 [00:01<00:00, 18807.46it/s][A


Finished building the Datasets.


In [20]:
def train_one_epoch(model, train_data):
  total_loss = 0
  for idx, (image, captions) in tqdm(enumerate(iter(train_data))):
      image, captions = image.to(device), captions.to(device)

      # Zero the gradients.
      optimizer.zero_grad()

      # Feed forward
      outputs, attentions = model(image, captions)
      outputs = outputs.to(device)

      # Calculate the batch loss.
      targets = captions[:, 1:]
      targets = targets.to(device)
      loss = criterion(outputs.view(-1, vocab_size), targets.reshape(-1))

      # Backward pass.
      loss.backward()

      # Update the parameters in the optimizer.
      optimizer.step()

      total_loss += loss.item()
    
  return total_loss/len(train_data)

In [17]:
len(train_data)

938

In [None]:
loss_list = []
perplexity_list = []
bleu_list = []
total_loss = 0
for epoch in range(1, num_epochs + 1):
    model.train()
    avg_training_loss = train_one_epoch(model, train_data)
    print('average loss = {}'.format(avg_training_loss))
    save_model(model, epoch)

    model.eval()
    bleu, loss, perp = evaluate(model, val_data)
    print("Epoch: {} loss: {:.3f}, perplexity: {:.3f}, BLEU: {:.3f}".format(epoch, loss, perp, bleu))
    perplexity_list.append(perp)
    loss_list.append(loss)
    bleu_list.append(bleu)

    pickle.dump(perplexity_list, open('perplexity_list.p', 'wb'))
    pickle.dump(loss_list, open('loss_list.p', 'wb'))
    pickle.dump(bleu_list, open('blew_list.p', 'wb'))


plt.plot(loss_list)
plt.title('Loss')
plt.show()

plt.plot(perplexity_list)
plt.title('Perplexity')
plt.show()

plt.plot(bleu_list)
plt.title('BLEU')
plt.show()




0it [00:00, ?it/s][A[A[A


1it [00:08,  8.04s/it][A[A[A


2it [00:16,  8.19s/it][A[A[A


3it [00:25,  8.28s/it][A[A[A


4it [00:32,  8.17s/it][A[A[A


5it [00:40,  7.93s/it][A[A[A


6it [00:48,  7.86s/it][A[A[A


7it [00:55,  7.67s/it][A[A[A


8it [01:02,  7.46s/it][A[A[A


9it [01:09,  7.31s/it][A[A[A


10it [01:16,  7.39s/it][A[A[A


11it [01:24,  7.44s/it][A[A[A


12it [01:32,  7.54s/it][A[A[A


13it [01:39,  7.45s/it][A[A[A


14it [01:46,  7.25s/it][A[A[A


15it [01:52,  7.12s/it][A[A[A


16it [02:02,  7.74s/it][A[A[A


17it [02:09,  7.59s/it][A[A[A


18it [02:16,  7.40s/it][A[A[A


19it [02:23,  7.34s/it][A[A[A


20it [02:30,  7.25s/it][A[A[A


21it [02:38,  7.54s/it][A[A[A


22it [02:45,  7.42s/it][A[A[A


23it [02:52,  7.09s/it][A[A[A


24it [02:59,  7.15s/it][A[A[A


25it [03:05,  6.90s/it][A[A[A


26it [03:12,  6.73s/it][A[A[A


27it [03:19,  6.81s/it][A[A[A


28it [03:25,  6.76s/it][A[A[

average loss = 3.799947449393364


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
1433it [01:53, 17.36it/s][A[A[A


1436it [01:53, 18.41it/s][A[A[A


1438it [01:53, 18.58it/s][A[A[A


1441it [01:53, 19.26it/s][A[A[A


1443it [01:53, 14.03it/s][A[A[A


1445it [01:54, 14.91it/s][A[A[A


1447it [01:54, 15.84it/s][A[A[A


1449it [01:54, 16.48it/s][A[A[A


1451it [01:54, 17.16it/s][A[A[A


1453it [01:54, 17.89it/s][A[A[A


1455it [01:54, 13.27it/s][A[A[A


1457it [01:54, 14.45it/s][A[A[A


1459it [01:54, 15.45it/s][A[A[A


1461it [01:54, 16.47it/s][A[A[A


1463it [01:55, 16.95it/s][A[A[A


1465it [01:55, 17.39it/s][A[A[A


1467it [01:55, 17.98it/s][A[A[A


1469it [01:55, 18.43it/s][A[A[A


1471it [01:55, 18.59it/s][A[A[A


1473it [01:55, 18.72it/s][A[A[A


1475it [01:55, 18.52it/s][A[A[A


1477it [01:55, 18.37it/s][A[A[A


1479it [01:55, 18.35it/s][A[A[A


1481it [01:56, 18.09it/s][A[A[A


1483it [01:56, 17.86it/s][A[A[A


1485it [01

Epoch: 1 loss: 3.269, perplexity: 26.274, BLEU: 0.462





1it [00:00,  1.72it/s][A[A[A


2it [00:01,  1.73it/s][A[A[A


3it [00:01,  1.72it/s][A[A[A


4it [00:02,  1.72it/s][A[A[A


5it [00:02,  1.74it/s][A[A[A


6it [00:03,  1.73it/s][A[A[A


7it [00:04,  1.75it/s][A[A[A


8it [00:04,  1.74it/s][A[A[A


9it [00:05,  1.73it/s][A[A[A


10it [00:05,  1.72it/s][A[A[A


11it [00:06,  1.74it/s][A[A[A


12it [00:06,  1.74it/s][A[A[A


13it [00:07,  1.77it/s][A[A[A


14it [00:08,  1.74it/s][A[A[A


15it [00:08,  1.75it/s][A[A[A


16it [00:09,  1.74it/s][A[A[A


17it [00:09,  1.73it/s][A[A[A


18it [00:10,  1.74it/s][A[A[A


19it [00:10,  1.74it/s][A[A[A


20it [00:11,  1.73it/s][A[A[A


21it [00:12,  1.74it/s][A[A[A


22it [00:12,  1.74it/s][A[A[A


23it [00:13,  1.75it/s][A[A[A


24it [00:13,  1.76it/s][A[A[A


25it [00:14,  1.77it/s][A[A[A


26it [00:14,  1.75it/s][A[A[A


27it [00:15,  1.76it/s][A[A[A


28it [00:16,  1.78it/s][A[A[A


29it [00:16,  1.76it/s][A

average loss = 3.09358877286728


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
1575it [01:24, 18.26it/s][A[A[A


1577it [01:24, 17.88it/s][A[A[A


1579it [01:24, 18.18it/s][A[A[A


1581it [01:24, 18.07it/s][A[A[A


1583it [01:24, 18.17it/s][A[A[A


1585it [01:24, 17.73it/s][A[A[A


1588it [01:24, 18.60it/s][A[A[A


1590it [01:24, 18.58it/s][A[A[A


1592it [01:25, 18.56it/s][A[A[A


1594it [01:25, 18.83it/s][A[A[A


1596it [01:25, 19.11it/s][A[A[A


1598it [01:25, 18.31it/s][A[A[A


1600it [01:25, 17.81it/s][A[A[A


1602it [01:25, 17.93it/s][A[A[A


1604it [01:25, 17.32it/s][A[A[A


1606it [01:25, 17.38it/s][A[A[A


1608it [01:25, 17.42it/s][A[A[A


1610it [01:26, 17.76it/s][A[A[A


1612it [01:26, 17.81it/s][A[A[A


1614it [01:26, 17.40it/s][A[A[A


1616it [01:26, 17.40it/s][A[A[A


1618it [01:26, 17.98it/s][A[A[A


1620it [01:26, 18.07it/s][A[A[A


1622it [01:26, 18.22it/s][A[A[A


1624it [01:26, 18.43it/s][A[A[A


1626it [01

Epoch: 2 loss: 2.989, perplexity: 19.872, BLEU: 0.454





1it [00:00,  1.55it/s][A[A[A


2it [00:01,  1.59it/s][A[A[A


3it [00:01,  1.64it/s][A[A[A


4it [00:02,  1.65it/s][A[A[A


5it [00:02,  1.68it/s][A[A[A


6it [00:03,  1.70it/s][A[A[A


7it [00:04,  1.71it/s][A[A[A


8it [00:04,  1.72it/s][A[A[A


9it [00:05,  1.73it/s][A[A[A


10it [00:05,  1.73it/s][A[A[A


11it [00:06,  1.71it/s][A[A[A


12it [00:07,  1.71it/s][A[A[A


13it [00:07,  1.72it/s][A[A[A


14it [00:08,  1.74it/s][A[A[A


15it [00:08,  1.73it/s][A[A[A


16it [00:09,  1.75it/s][A[A[A


17it [00:09,  1.74it/s][A[A[A


18it [00:10,  1.73it/s][A[A[A


19it [00:11,  1.74it/s][A[A[A


20it [00:11,  1.70it/s][A[A[A


21it [00:12,  1.70it/s][A[A[A


22it [00:12,  1.73it/s][A[A[A


23it [00:13,  1.75it/s][A[A[A


24it [00:13,  1.75it/s][A[A[A


25it [00:14,  1.75it/s][A[A[A


26it [00:15,  1.75it/s][A[A[A


27it [00:15,  1.76it/s][A[A[A


28it [00:16,  1.75it/s][A[A[A


29it [00:16,  1.73it/s][A

average loss = 2.847757714135306


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
1598it [01:26, 19.25it/s][A[A[A


1600it [01:26, 18.26it/s][A[A[A


1602it [01:26, 17.78it/s][A[A[A


1604it [01:26, 17.95it/s][A[A[A


1606it [01:26, 17.62it/s][A[A[A


1608it [01:27, 17.80it/s][A[A[A


1610it [01:27, 17.81it/s][A[A[A


1612it [01:27, 18.07it/s][A[A[A


1614it [01:27, 18.25it/s][A[A[A


1616it [01:27, 17.98it/s][A[A[A


1618it [01:27, 18.04it/s][A[A[A


1620it [01:27, 17.85it/s][A[A[A


1622it [01:27, 17.19it/s][A[A[A


1624it [01:27, 17.52it/s][A[A[A


1626it [01:28, 18.10it/s][A[A[A


1628it [01:28, 18.47it/s][A[A[A


1630it [01:28, 18.22it/s][A[A[A


1632it [01:28, 18.34it/s][A[A[A


1634it [01:28, 18.44it/s][A[A[A


1636it [01:28, 18.72it/s][A[A[A


1638it [01:28, 18.90it/s][A[A[A


1640it [01:28, 18.85it/s][A[A[A


1642it [01:28, 18.81it/s][A[A[A


1644it [01:29, 19.13it/s][A[A[A


1647it [01:29, 19.83it/s][A[A[A


1649it [01

Epoch: 3 loss: 2.858, perplexity: 17.418, BLEU: 0.473





1it [00:00,  1.75it/s][A[A[A


2it [00:01,  1.70it/s][A[A[A


3it [00:01,  1.70it/s][A[A[A


4it [00:02,  1.71it/s][A[A[A


5it [00:02,  1.70it/s][A[A[A


6it [00:03,  1.70it/s][A[A[A


7it [00:04,  1.72it/s][A[A[A


8it [00:04,  1.69it/s][A[A[A


9it [00:05,  1.71it/s][A[A[A


10it [00:05,  1.69it/s][A[A[A


11it [00:06,  1.73it/s][A[A[A


12it [00:07,  1.72it/s][A[A[A


13it [00:07,  1.72it/s][A[A[A


14it [00:08,  1.73it/s][A[A[A


15it [00:08,  1.72it/s][A[A[A


16it [00:09,  1.71it/s][A[A[A


17it [00:09,  1.70it/s][A[A[A


18it [00:10,  1.68it/s][A[A[A


19it [00:11,  1.68it/s][A[A[A


20it [00:11,  1.71it/s][A[A[A


21it [00:12,  1.73it/s][A[A[A


22it [00:12,  1.72it/s][A[A[A


23it [00:13,  1.73it/s][A[A[A


24it [00:14,  1.72it/s][A[A[A


25it [00:14,  1.72it/s][A[A[A


26it [00:15,  1.72it/s][A[A[A


27it [00:15,  1.72it/s][A[A[A


28it [00:16,  1.70it/s][A[A[A


29it [00:16,  1.72it/s][A

average loss = 2.692328441371796





0it [00:00, ?it/s][A[A[A


2it [00:00, 14.19it/s][A[A[A


4it [00:00, 14.78it/s][A[A[A


6it [00:00, 15.61it/s][A[A[A


8it [00:00, 16.20it/s][A[A[A


10it [00:00, 16.51it/s][A[A[A


12it [00:00, 16.34it/s][A[A[A


14it [00:00, 17.11it/s][A[A[A


16it [00:00, 17.25it/s][A[A[A


18it [00:01, 17.07it/s][A[A[A


20it [00:01, 17.60it/s][A[A[A


22it [00:01, 17.32it/s][A[A[A


24it [00:01, 17.56it/s][A[A[A


26it [00:01, 17.69it/s][A[A[A


28it [00:01, 17.62it/s][A[A[A


30it [00:01, 17.83it/s][A[A[A


32it [00:01, 18.07it/s][A[A[A


34it [00:01, 18.12it/s][A[A[A


36it [00:02, 17.64it/s][A[A[A


38it [00:02, 17.39it/s][A[A[A


40it [00:02, 17.48it/s][A[A[A


42it [00:02, 17.42it/s][A[A[A


45it [00:02, 18.56it/s][A[A[A


47it [00:02, 18.72it/s][A[A[A


49it [00:02, 19.04it/s][A[A[A


51it [00:02, 18.62it/s][A[A[A


53it [00:02, 18.71it/s][A[A[A


55it [00:03, 18.55it/s][A[A[A


57it [00:03, 17.88it/s][A