In [1]:
# Fetch the project repository; later cells take the helper modules, the vocabulary file
# and the example images from this checkout.
!git clone https://github.com/CornerSiow/zero-shot-image-captioning.git

Cloning into 'zero-shot-image-captioning'...
remote: Enumerating objects: 164, done.[K
remote: Counting objects: 100% (75/75), done.[K
remote: Compressing objects: 100% (75/75), done.[K
remote: Total 164 (delta 34), reused 0 (delta 0), pack-reused 89[K
Receiving objects: 100% (164/164), 76.89 MiB | 16.87 MiB/s, done.
Resolving deltas: 100% (73/73), done.


In [2]:
# Copy the repository's helper modules into the working directory so that
# `Vocabulary` and `DecoderLSTM` can be imported by name below.
!cp "zero-shot-image-captioning/code/Vocabulary.py" "Vocabulary.py"
!cp "zero-shot-image-captioning/code/DecoderLSTM.py" "DecoderLSTM.py"

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pickle
import torchvision
import torchvision.transforms as transforms
from skimage import io
from Vocabulary import Vocabulary
from DecoderLSTM import DecoderLSTM
import random
import numpy as np
import nltk
# Download the NLTK 'punkt' tokenizer data.
# NOTE(review): presumably required by Vocabulary when tokenising captions — confirm in Vocabulary.py.
nltk.download('punkt')
# Seed the Python, PyTorch and NumPy random number generators with the same fixed value.
random.seed(10)
torch.manual_seed(10)
np.random.seed(10)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Generate Dataset

In [5]:
# Load the caption vocabulary shipped with the cloned repository.
# NOTE(review): the file is a pickle; presumably loadFile unpickles it, so only load
# files from a trusted source — confirm in Vocabulary.py.
vocab = Vocabulary()
vocab.loadFile("zero-shot-image-captioning/data/vocab.pickle")

In [6]:
# Directory (inside the cloned repository) that holds the example images.
IMAGE_DIR = "zero-shot-image-captioning/img_test"

# One reference caption per scene. Every scene has two photos, `<scene>_1.jpg`
# and `<scene>_2.jpg`, which share the same caption, so the caption is stored once.
CAPTIONS_BY_SCENE = {
    "working": "A person using a laptop in the office",
    "eating": "A person eats a banana in front of a laptop",
    "washing": "A person washes his face in the sink",
    "cycling": "A person riding a bike on a clear sky",
    "bus": "Someone is waiting at the bus stop",
}

# Lists of (image path, caption) pairs: photo 1 of each scene is used for
# training, photo 2 of the same scene for testing.
trainData = [
    (f"{IMAGE_DIR}/{scene}_1.jpg", caption)
    for scene, caption in CAPTIONS_BY_SCENE.items()
]
testData = [
    (f"{IMAGE_DIR}/{scene}_2.jpg", caption)
    for scene, caption in CAPTIONS_BY_SCENE.items()
]
class ImageCaptionDataset(Dataset):
    """In-memory dataset of preprocessed images paired with tokenised captions.

    All images are read from disk and transformed once, in the constructor, so
    indexing afterwards only looks up the cached results.

    Args:
        imagesList: sequence of entries whose first item is an image path and
            whose second item is the caption text.
        vocab: object providing ``convertSentenceToToken(sentence)``.
            NOTE(review): the returned tokens are assumed to be a 1-D tensor,
            since the notebook pads them with ``pad_sequence`` — confirm in
            Vocabulary.py.
    """

    def __init__(self,imagesList, vocab):
        self.vocab = vocab
        self.imagesList = imagesList

        # Preprocessing pipeline; also reused outside the class through
        # ``dataset.transform``. Normalisation uses mean 0.5 / std 0.5 on each
        # of three channels.
        pipeline = [
            transforms.ToTensor(),
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(pipeline)

        self.samples = []
        self.captions = []
        for entry in imagesList:
            pixels = io.imread(entry[0])
            self.samples.append(self.transform(pixels))
            self.captions.append(vocab.convertSentenceToToken(entry[1]))

    def __len__(self):
        """Return the number of (image, caption) entries."""
        return len(self.imagesList)

    def __getitem__(self, idx):
        """Return the ``(image tensor, caption tokens)`` pair stored at ``idx``.

        ``idx`` may be a plain index or a tensor holding one; a tensor is
        converted with ``tolist()`` before the lookup.
        """
        position = idx.tolist() if torch.is_tensor(idx) else idx
        return self.samples[position], self.captions[position]



In [7]:
def collate_fn(data):
    """Merge a list of ``(image, caption)`` samples into one batch.

    Args:
        data: iterable of 2-tuples ``(image tensor, caption token tensor)``;
            all image tensors must have the same shape.

    Returns:
        A tuple ``(images, captions)``: the images stacked along a new leading
        batch dimension, and the captions padded (with zeros, the
        ``pad_sequence`` default) to the longest caption, batch first.
    """
    pairs = list(data)
    images = [image for image, _ in pairs]
    tokens = [caption for _, caption in pairs]

    padded_captions = torch.nn.utils.rnn.pad_sequence(tokens, batch_first=True)
    return torch.stack(images), padded_captions

# Build both datasets; every image is read and preprocessed here, inside the constructor.
trainDataset = ImageCaptionDataset(trainData, vocab)
testDataset = ImageCaptionDataset(testData, vocab)
# One sample per batch, reshuffled on every pass; collate_fn stacks the images and pads the captions.
trainLoader = DataLoader(trainDataset, batch_size = 1, shuffle = True, collate_fn =collate_fn)
# NOTE(review): the test loader is shuffled as well, so the order of evaluation output varies.
testLoader = DataLoader(testDataset, batch_size = 1, shuffle = True, collate_fn =collate_fn)


# Initialize Encoder and Decoder

In [8]:
class Encoder(nn.Module):
    """Image encoder: an ImageNet-pretrained ResNet-101 without its final linear layer.

    Args:
        encoded_image_size: stored as ``enc_image_size``; the network built
            here does not use it.
    """

    def __init__(self, encoded_image_size=14):
        super().__init__()
        self.enc_image_size = encoded_image_size

        # `pretrained=True` is deprecated since torchvision 0.13 in favour of
        # the `weights=` argument. IMAGENET1K_V1 is the checkpoint that
        # `pretrained=True` used to load. Fall back to the old call on
        # torchvision versions that do not provide the weights enum yet.
        weights_enum = getattr(torchvision.models, "ResNet101_Weights", None)
        if weights_enum is not None:
            resnet = torchvision.models.resnet101(weights=weights_enum.IMAGENET1K_V1)
        else:
            resnet = torchvision.models.resnet101(pretrained=True)

        # Remove the last child module (the linear classifier) so the network
        # outputs features instead of class scores.
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

    def forward(self, images):
        """Encode a batch of images into one flat feature vector per image.

        Args:
            images: batch of image tensors accepted by ResNet-101.

        Returns:
            Tensor with every dimension after the batch dimension flattened,
            i.e. ``(batch_size, 2048)``.
        """
        out = self.resnet(images)  # (batch_size, 2048, 1, 1)
        return out.flatten(1)

In [9]:
# Decoder hyper-parameters.
vocab_size = len(vocab)
embed_size = 1*1*2048  # length of the flattened feature vector produced by Encoder.forward
hidden_size = 256

# Instantiate both networks and move them to the selected device.
encoder = Encoder()
encoder.to(device)
decoder = DecoderLSTM(embed_size, hidden_size, vocab_size)
# Last expression of the cell, so the notebook displays the decoder's module summary.
decoder.to(device)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth


  0%|          | 0.00/171M [00:00<?, ?B/s]

DecoderLSTM(
  (embedding): Embedding(30, 2048)
  (lstm): LSTM(2048, 256, bias=False, batch_first=True)
  (linear): Linear(in_features=256, out_features=30, bias=True)
)

In [10]:
# Train the encoder and the decoder jointly: one optimizer owns both parameter sets.
params = list(encoder.parameters()) + list(decoder.parameters())
# A single `.to(device)` covers both the CUDA and the CPU case, so no separate
# `.cuda()` branch is needed.
# NOTE(review): collate_fn pads captions with 0 and this loss sets no `ignore_index`,
# so padded positions would count towards the loss. With batch_size=1 nothing is
# padded; revisit before using larger batches.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.9,0.999), eps=1e-8)

# Start Training

In [11]:
# Train encoder and decoder end to end for 1000 passes over the training loader.
# NOTE(review): the encoder is in train mode with one image per batch, so its
# batch-norm layers update their running statistics from single images; the
# evaluation output below does not always reproduce the training captions, which
# may be related — consider keeping the encoder in eval mode or freezing it.
encoder.train()
decoder.train()
bar = tqdm(range(1000))
for epoch in bar:
    totalLoss = 0
    for x, y in trainLoader:
        # Move the batch to the device once and reuse it for the forward pass and the loss.
        images = x.to(device)
        targets = y.to(device)

        # The optimizer holds the parameters of both models, so one call
        # clears every gradient.
        optimizer.zero_grad()

        features = encoder(images)
        # NOTE(review): assumes the decoder returns one score vector of size
        # vocab_size per target token — confirm in DecoderLSTM.py.
        outputs = decoder(features, targets)

        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()
        totalLoss += loss.item()
    # Show the loss summed over all batches of this pass in the progress bar.
    bar.set_description("Total Loss: {:.4f}".format(totalLoss))

Total Loss: 0.2628: 100%|██████████| 1000/1000 [09:45<00:00,  1.71it/s]


# Start Testing

In [14]:
# Switch both models to evaluation mode and disable gradient tracking.
encoder.eval()
decoder.eval()
with torch.no_grad():
    # The same steps run for both splits: training images first, then test images.
    for loader in (trainLoader, testLoader):
        for x, y in loader:
            features = encoder(x.to(device))
            output = decoder.sample(features.unsqueeze(0))
            sentence = vocab.clean_sentence(output)
            # Batches hold one sample, so row 0 is the whole reference caption.
            ground_truth = vocab.clean_sentence(y.numpy()[0])
            # Printed as: generated caption, then reference caption.
            print(sentence, ground_truth)

 a person washes his face in the sink  a person washes his face in the sink
 a person eats a banana in front of a laptop  a person eats a banana in front of a laptop
 a person eats a banana in front of a laptop  a person using a laptop in the office
 a person washes his face in the sink  someone is waiting at the bus stop
 someone is waiting at the bus stop  a person riding a bike on a clear sky
 someone is waiting at the bus stop  someone is waiting at the bus stop
 someone is waiting at the bus stop  a person eats a banana in front of a laptop
 a person washes his face in the sink  a person washes his face in the sink
 a person washes his face in the sink  a person riding a bike on a clear sky
 a person washes his face in the sink  a person using a laptop in the office


In [13]:
import os

# Caption every .jpg file in the example-image folder.
image_dir = "zero-shot-image-captioning/img_test"

# Set evaluation mode here instead of relying on another cell having done it.
encoder.eval()
decoder.eval()
with torch.no_grad():
    # os.listdir returns entries in arbitrary order, so the print order can vary.
    for file in os.listdir(image_dir):
        if not file.endswith(".jpg"):
            continue
        img = os.path.join(image_dir, file)
        image = io.imread(img)
        # Reuse the dataset's preprocessing and add a batch dimension of 1.
        x = testDataset.transform(image).unsqueeze(0)
        features = encoder(x.to(device))
        output = decoder.sample(features.unsqueeze(0))
        sentence = vocab.clean_sentence(output)
        print(file + "\t"+sentence)

bus_2.jpg	 someone is waiting at the bus stop
working_2.jpg	 a person washes his face in the sink
washing_1.jpg	 a person washes his face in the sink
eating_2.jpg	 someone is waiting at the bus stop
cycling_2.jpg	 a person washes his face in the sink
cycling_1.jpg	 someone is waiting at the bus stop
working_1.jpg	 a person eats a banana in front of a laptop
bus_1.jpg	 a person washes his face in the sink
eating_1.jpg	 a person eats a banana in front of a laptop
washing_2.jpg	 a person washes his face in the sink
