In [1]:
import torch
import encoderDecoder
import torchvision.transforms as transforms
from dataset import get_loader

-----Loading vocabulary from vocab.pkl file!----
----Vocabulary successfully loaded from vocab.pkl file!---
-----Loading vocabulary from vocab.pkl file!----
----Vocabulary successfully loaded from vocab.pkl file!---


In [2]:
transform = transforms.Compose(
        [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

val_dataset, val_loader = get_loader(
        root_folder="../data/flickr8k/Images",
        captions_file="../data/flickr8k/captions_val.txt",
        batch_size = 16,
        transform=transform,
        num_workers=1,
        split_type='val'
    )


-----Loading vocabulary from vocab.pkl file!----
----Vocabulary successfully loaded from vocab.pkl file!---


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


-----Loading vocabulary from vocab.pkl file!----
----Vocabulary successfully loaded from vocab.pkl file!---


In [3]:
device = "cuda" if torch.cuda.is_available() else \
         ("mps" if torch.backends.mps.is_available() else "cpu" ) 
#device = 'cpu'
print(device) 

cuda


In [4]:
def indices_to_words(indices, vocab, referenceCorpus):
    decoded_sentences = []
    count = 0 ; 
    for idx_row in indices:
        count += 1 
        sentence = []
    
        for idx in idx_row:
            word = vocab.get_word(idx.item())
            if(word == '<start>'):
                continue 
            if word == '<end>':  # Stop decoding if '<end>' token is encountered
                break
            sentence.append(word)
        if(referenceCorpus):
            corpus = []
            corpus.append(sentence)
            decoded_sentences.append(corpus)
        else: 
            decoded_sentences.append(sentence)
    # print("the loop runs for" , count)
    return decoded_sentences

def generate_caption(imgs, caption, model):
    pred, alpha = model(imgs, caption)
    pred = pred.argmax(2)
    return indices_to_words(pred, val_dataset.vocab, False)
    
model = encoderDecoder.EncoderDecoderAttention(256, 256, len(val_dataset.vocab))
print(model.load_state_dict(torch.load("../model/model_epoch10.pth")))
model = model.to(device)

<All keys matched successfully>


In [12]:
predictions = []
ground_truth = []
topn = -1
log_interval = 50
with torch.no_grad():
    for batch_idx, (imgs, captions, all_captions) in enumerate(val_loader):
        if batch_idx % log_interval == 0 and batch_idx > 0:
            print('Batch:',batch_idx)
        imgs, captions = imgs.to(device), captions.to(device)
        preds = generate_caption(imgs, captions, model)
        truth = indices_to_words(captions, val_dataset.vocab, False)
        predictions.extend(preds)
        ground_truth.extend(truth)
        if topn != -1 and batch_idx == topn:
            break

Batch: 50
Batch: 100
Batch: 150
Batch: 200
Batch: 250
Batch: 300
Batch: 350


In [13]:
pred_captions = [' '.join(p) for p in predictions]
true_captions = [' '.join(g) for g in ground_truth]

In [14]:
pred_captions

['a little is standing on a',
 'a is a woman in . a red . <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>',
 'a people are playing on a skateboard . . the . . . <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>',
 'a people are in',
 'a man man in a white on a field . a . .',
 'a dogs playing in a the field . . a . .',
 'a dog is a snow . a green . . the mouth .',
 'a little in in a in a the red . .',
 'a young in a and in in running on the field .',
 'a dog and is in ball . a . .',
 'a man is in a and white . . a mouth .',
 'a people is a in a .',
 'a black dog running a black dog . a mouth . a . the grass . .',
 'a people are',
 'a man is a and standing a trick . . a a a . . .',
 'a young in a red shirt in standing a man . a red . .',
 'a man is in a a and in',
 'a young is a boy in a red girl in 

In [15]:
true_captions

['a toddler is sitting down .',
 'man carrying a tool box on a sidewalk',
 'two men are sitting on a bench looking at a water fountain',
 'two people watch tv from over a white ledge .',
 'a young woman crosses a bridge through a field of tall grass .',
 'two youn boys lay in a dog house with their dog .',
 'a dog in the snow carrying a tree limb in its mouth .',
 'a girl sits and laughs while in a kiddie pool .',
 'a boy wearing black swimming trunks is standing in a fountain .',
 'the brown dog pulls a skier across snow-covered ground .',
 'the owner tries to hand a deflated ball to his dog .',
 'three climbers on rocks near waterfall .',
 'a black dog carries a white toy in his mouth and walks on the snowy field .',
 'four children .',
 'a man in jeans is using a yard work tool just above some pavement steps .',
 'a child in a gray t-shirt is bouncing a basketball on a basketball court .',
 'a dog races while wearing number 6 .',
 'a woman holding a ball chases a small boy running i

In [17]:
len(true_captions), len(pred_captions)

(6065, 6065)

In [18]:
combined = ['true|predicted']
combined.extend([true_captions[i] + '|' + pred_captions[i] for i in range(len(true_captions))])

In [19]:
combined

['true|predicted',
 'a toddler is sitting down .|a little is standing on a',
 'man carrying a tool box on a sidewalk|a is a woman in . a red . <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>',
 'two men are sitting on a bench looking at a water fountain|a people are playing on a skateboard . . the . . . <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>',
 'two people watch tv from over a white ledge .|a people are in',
 'a young woman crosses a bridge through a field of tall grass .|a man man in a white on a field . a . .',
 'two youn boys lay in a dog house with their dog .|a dogs playing in a the field . . a . .',
 'a dog in the snow carrying a tree limb in its mouth .|a dog is a snow . a green . . the mouth .',
 'a girl sits and laughs while in a kiddie pool .|a little in in a in a the red . .',


In [20]:
with open('predictions.txt','w') as f:
    f.writelines(combined)

In [None]:
for batch_idx, (imgs, captions, all_captions) in enumerate(val_loader):
    imgs, captions = imgs.to(device), captions.to(device)
    pred, alpha = model(imgs, captions)
    break

In [None]:
imgs.shape, captions.shape, all_captions

In [None]:
pred = pred.argmax(2)

In [None]:
def indices_to_words(indices, vocab, referenceCorpus):
    decoded_sentences = []
    count = 0 ; 
    for idx_row in indices:
        count += 1 
        sentence = []
    
        for idx in idx_row:
            word = vocab.get_word(idx.item())
            if(word == '<start>'):
                continue 
            if word == '<end>':  # Stop decoding if '<end>' token is encountered
                break
            sentence.append(word)
        if(referenceCorpus):
            corpus = []
            corpus.append(sentence)
            decoded_sentences.append(corpus)
        else: 
            decoded_sentences.append(sentence)
    # print("the loop runs for" , count)
    return decoded_sentences

In [None]:
indices_to_words(pred, val_dataset.vocab, False)

In [None]:


model = encoderDecoder.EncoderDecoderAttention(256, 256, len(val_dataset.vocab))
model.load_state_dict(torch.load("../model/model_epoch10.pth"))

In [None]:
pred, alpha = model(imgs, captions)
pred = pred.argmax(1)
indices_to_words(pred, val_dataset.vocab, False)

In [None]:
def generate_from_img(img_path, caption, transform, model, caption):
    img = Image.open(img_location).convert("RGB")
    img = transform(img)
    img.to(device)



In [None]:
predictions = []
for batch_idx, (imgs, captions, all_captions) in enumerate(val_loader):
    imgs, captions = imgs.to(device), captions.to(device)
    preds = generate_caption(imgs, captions, model)
    predictions.extend(preds)

In [None]:
a = torch.rand((1,3,6))
b = a.squeeze(0)

In [None]:
a

In [None]:
b

In [None]:
from PIL import Image

def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

def generate_caption_visualization(encoder, decoder, img_path, vocab, beam_size=3, smooth=True):
    
    
    img = pil_loader(img_path)
    img = transform(img)
    img = torch.FloatTensor(img)
    img = img.unsqueeze(0)
    img = img.to(device)
    img_features = encoder(img)
    
    print(img_features.device)


    #img = img.to('cpu')
    #img_features = img_features.to('cpu')
    
    #print(decoder.weight.device)
    img_features = img_features.expand(beam_size, img_features.size(1), img_features.size(2))
    sentence, alpha = decoder.caption(img_features, beam_size)

    token_dict = vocab.idx2word
    sentence_tokens = []
    for word_idx in sentence:
        sentence_tokens.append(token_dict[word_idx])
        if word_idx == word_dict['<end>']:
            break

    img = Image.open(img_path)
    w, h = img.size
    if w > h:
        w = w * 256 / h
        h = 256
    else:
        h = h * 256 / w
        w = 256
    left = (w - 224) / 2
    top = (h - 224) / 2
    resized_img = img.resize((int(w), int(h)), Image.BICUBIC).crop((left, top, left + 224, top + 224))
    img = np.array(resized_img.convert('RGB').getdata()).reshape(224, 224, 3)
    img = img.astype('float32') / 255

    num_words = len(sentence_tokens)
    w = np.round(np.sqrt(num_words))
    h = np.ceil(np.float32(num_words) / w)
    alpha = torch.tensor(alpha)

    plot_height = ceil((num_words + 3) / 4.0)
    ax1 = plt.subplot(4, plot_height, 1)
    plt.imshow(img)
    plt.axis('off')
    for idx in range(num_words):
        ax2 = plt.subplot(4, plot_height, idx + 2)
        label = sentence_tokens[idx]
        plt.text(0, 1, label, backgroundcolor='white', fontsize=13)
        plt.text(0, 1, label, color='black', fontsize=13)
        plt.imshow(img)

        if encoder.network == 'vgg19':
            shape_size = 14
        else:
            shape_size = 7

        if smooth:
            alpha_img = skimage.transform.pyramid_expand(alpha[idx, :].reshape(shape_size, shape_size), upscale=16, sigma=20)
        else:
            alpha_img = skimage.transform.resize(alpha[idx, :].reshape(shape_size,shape_size), [img.shape[0], img.shape[1]])
        plt.imshow(alpha_img, alpha=0.8)
        plt.set_cmap(cm.Greys_r)
        plt.axis('off')
    plt.show()

In [None]:
with torch.no_grad():
    encoder = model.encoder.to(device)
    decoder = model.decoder.to(device)
    
    generate_caption_visualization(encoder, decoder, "../data/flickr8k/Images/997722733_0cb5439472.jpg", val_dataset.vocab)

In [None]:
# used for calculating bleu scores
references = []
hypotheses = []
with torch.no_grad():
    for batch_idx, (imgs, captions, all_captions) in enumerate(val_loader):
        # imgs, captions = Variable(imgs).cuda(), Variable(captions).cuda()
        imgs, captions = imgs.to(device), captions.to(device)
        preds, alphas = model(imgs, captions)
        targets = captions[:, 1:]

        targets = pack_padded_sequence(targets, [len(tar) - 1 for tar in targets], batch_first=True)[0]
        packed_preds = pack_padded_sequence(preds, [len(pred) - 1 for pred in preds], batch_first=True)[0]

        att_regularization = alpha_c * ((1 - alphas.sum(1))**2).mean()

        loss = cross_entropy_loss(packed_preds, targets)
        loss += att_regularization

        total_caption_length = calculate_caption_lengths(vocab, captions)
        acc1 = accuracy(packed_preds, targets, 1)
        acc5 = accuracy(packed_preds, targets, 5)
        losses.update(loss.item(), total_caption_length)
        top1.update(acc1, total_caption_length)
        top5.update(acc5, total_caption_length)

        start_token = vocab.get_index('<start>')
        pad_token = vocab.get_index('<pad>')
        for cap_set in all_captions.tolist():
            caps = []
            for caption in cap_set:
                cap = [word_idx for word_idx in caption
                                if word_idx != start_token and word_idx != pad_token]
                caps.append(cap)
            references.append(caps)

        word_idxs = torch.max(preds, dim=2)[1]
        for idxs in word_idxs.tolist():
            hypotheses.append([idx for idx in idxs
                                   if idx != start_token and idx != pad_token])

        if batch_idx % log_interval == 0:
            print('Validation Batch: [{0}/{1}]\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Top 1 Accuracy {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Top 5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                      batch_idx, len(data_loader), loss=losses, top1=top1, top5=top5))
        print('References',references)
        print('Hypothesis')
