## Pre Process

In [20]:
import pandas as pd, numpy as np, torch, argparse
from PIL import Image
from tqdm.notebook import tqdm
from torchvision import transforms
from torch.nn.utils.rnn import pad_sequence,pad_packed_sequence
from torch.utils.data import DataLoader
import torch, os, pickle
from scripts.pre_process import *
import os, shutil
import matplotlib.pyplot as plt

In [2]:
class arg_inputs():
    def __init__(self,data_dir=None, output_dir=None, meta_data=None,
                 dataset_size=None, batch_size=100, num_epochs=None,
                 project_path=None):
        
        self.data_dir=data_dir
        self.output_dir=output_dir
        self.dataset_size=dataset_size
        self.batch_size=batch_size
        self.meta_data = meta_data
        self.num_epochs = num_epochs
        self.project_path = project_path
   

### Set Args

In [3]:
data_dir = '../../image_captioning_pytorch/data/toy_dataset/'
output_dir = 'smallest_test/'
batch_size = 20
meta_data = '../../image_captioning_pytorch/data/toy_dataset/toy_dataset_label.csv'
dataset_size = 100


In [4]:
args = arg_inputs(data_dir=data_dir,output_dir=output_dir,meta_data=meta_data, batch_size=batch_size, 
                  dataset_size=dataset_size)


In [5]:
def reset_everything(path):
    try:
        folder = path
        for filename in tqdm(os.listdir(folder)):
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print('Failed to delete %s. Reason: %s' % (file_path, e))
                
    except Exception as e:
        print('I do not think this folder exists: {}'.format(e))



In [6]:
#reset_everything(output_dir)

In [7]:
#df = pd.read_csv('../../image_captioning_pytorch/data/toy_dataset/toy_dataset_label.csv', sep='\t')

In [8]:
def resize_images(args, output, pretrained):

    df = pd.read_csv(args.meta_data, sep='\t')

    if os.path.exists(args.output_dir) == False:
        os.mkdir(args.output_dir)

        image_folder = '{}images/'.format(args.output_dir)
        os.mkdir(image_folder)


    image_folder = '{}images/'.format(args.output_dir)
    
    pre_processed_images = []
    processed_images = []
    captions = []
    image_paths = []
    failed_path = []
    
    print('Processing Artwork\n')

    for num, i in tqdm(enumerate(zip(df['FILE'], df['TITLE'])), total=args.dataset_size):
        
        if num == args.dataset_size:
            break
            
        try:
            
            
                
            image = Image.open('{}{}'.format(args.data_dir, i[0]))
            
            resized_image = pre_process(image, pretrained)
            
            pre_processed_images.append(image)
            processed_images.append(resized_image)
            
            captions.append(i[1])
            
            if args.output_dir and output == True:
                
                img2 = reconstruct_image(resized_image)
                img2.save('{}{}'.format(image_folder, i[0]))
                image_paths.append('{}{}'.format(image_folder,i[0]))

            else:

                image_paths.append('{}{}'.format(args.data_dir,i[0]))

        except Exception as e:
            print('something Wrong: {}'.format(e))
            failed_path.append('{}{}'.format(args.data_dir,i[0]))

            continue
        
    print('\n{} artworks added to dataset'.format(len(processed_images)))
    print('{} failed to load\n'.format(len(failed_path)))

    return processed_images,captions, image_paths, pre_processed_images

### Resize images for Resnet pretrained 

In [9]:
images, captions, image_paths, preprocessed_images = resize_images(args, output=True, pretrained=True)

Processing Artwork



HBox(children=(FloatProgress(value=0.0), HTML(value='')))

something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]
something Wrong: output with shape [1, 224, 224] doesn't match the broadcast shape [3, 224, 224]


90 artworks added to dataset

### Tokenizing Titles 

In [10]:
tokenized_titles = [tokenize(i) for i in captions]
vocab = gen_vocab(tokenized_titles) 
print('\nOutputting vocab object to {}'.format(pickle_data(vocab, args.output_dir, 'vocab')))
encoded_titles, title_lengths = encode(tokenized_titles, vocab)

Generating Vocab



HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))



173 tokens in vocab

Outputting vocab object to smallest_test/vocab/vocab.pkl


### Outputting DataLoader

In [11]:
dataset = CustomDataset(images=images, captions=encoded_titles, image_paths=image_paths)
train_dataloader = DataLoader(dataset, batch_size = args.batch_size)
print('\nOutputting dataloader object to {}'.format(pickle_data(train_dataloader, args.output_dir, 'dataloader_different_preprocessing')))
             



Outputting dataloader object to smallest_test/dataloader_different_preprocessing/dataloader_different_preprocessing.pkl


### Sample Dataset

In [27]:
sampler = iter(train_dataloader)

In [28]:
def sample_dataset(data_sampler):
    
    image, caption, i = next(data_sampler)
    image.shape
    plt.imshow(image.squeeze(0).permute(1,2,0))
    plt.title('Sample Image')
    plt.show()
    
    original_caption = decode_text(caption.squeeze(0).numpy())
    


In [26]:
sample_dataset(sampler)

StopIteration: 

In [None]:
test_dataloader = DataLoader(dataset, batch_size = 1 ,shuffle=True)


## Training

In [None]:
import pickle, argparse
from scripts.models import EncoderCNN, DecoderRNN
import torch, math
from tqdm.notebook import tqdm
import torch.nn as nn
from scripts.train import *
import matplotlib.pyplot as plt
import torchvision

In [None]:
def train(train_dataloader, args, vocab, num_epochs):

    encoder_model = EncoderCNN(300)
    decoder_model = DecoderRNN(embed_size=300, hidden_size=512, vocab_size=len(vocab))

    device = 'cpu'#torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder_model.to(device)
    decoder_model.to(device)
    criterion = nn.CrossEntropyLoss().cuda() if torch.cuda.is_available() else nn.CrossEntropyLoss()

    params = list(decoder_model.parameters()) + list(encoder_model.embed.parameters())

    total_step = math.ceil(len(train_dataloader.dataset.caption_lengths) / train_dataloader.batch_sampler.batch_size)

    optimizer = torch.optim.Adam(encoder_model.parameters(), lr=0.01)

    encoder_model.train()
    decoder_model.train()
    vocab_size = len(vocab)
    num = 1

    for epoch in tqdm(range(1, num_epochs+1),total=num_epochs):
        
        
        
        for i in tqdm(train_dataloader):
            
            
            try:
                image = i[0].to(device)
                caption = i[1].to(device)

                decoder_model.zero_grad()
                encoder_model.zero_grad()


                features = encoder_model(image)
                outputs = decoder_model(features, caption)

                loss = criterion(outputs.view(-1, vocab_size), caption.view(-1))
                loss.backward()
                optimizer.step()

                num+=1 
            except Exception as e:
                failed_batch += 1
                continue
                
                
        print('Loss after epoch {}: {}'.format(epoch, loss))

    return encoder_model, decoder_model

### Training Args

In [None]:
train_args=arg_inputs(num_epochs=50, project_path='medium_dataset/')

In [None]:
#vocab, dataloader = load_objects(train_args.project_path)

In [None]:
trained_encoder, trained_decoder = train(train_dataloader, train_args, vocab, 50)



In [None]:
trained_encoder = trained_encoder.save()

In [None]:
def get_test_obj(dataset_tuple):

    test_image = dataset_tuple[0]
    test_caption = decode_text(dataset_tuple[1].numpy())
    
    return test_image, test_caption

In [None]:
def decode_text(list_nums):
    decode_vocab = {num: word for word, num in vocab.items()}
    return ' '.join([decode_vocab[i] for i in list_nums])

In [None]:
titles = []

test_encoder.cuda()
test_decoder.cuda()

for i in tqdm(dataset[:20]):
    
    image, caption = get_test_obj(i)
    
    image = image.cuda()
    features = trained_encoder(image.unsqueeze(0))
    output = trained_decoder.sample(features.unsqueeze(1))
    
    cleaned_text = decode_text(output)
    
    titles.append((cleaned_text, caption))
    
    
    
    

In [None]:
    plt.imshow(dataset[3][0].permute(1,2,0))


In [None]:
it = iter(test_dataloader)



In [None]:
def get_prediction(train_dataloader, title):
    
    
    orig_image, caption, path = next(train_dataloader)
    image = orig_image.cuda()
    plt.imshow(orig_image.squeeze(0).permute(1,2,0))
    plt.title('Sample Image')
    #plt.show()
    features = trained_encoder(image.cpu()).unsqueeze(0)
    output = trained_decoder.sample(features)    
    sentence = decode_text(output)
    original_caption = decode_text(caption.squeeze(0).numpy())
    
    title.append((sentence,original_caption))
   


In [None]:
title = []

for i in tqdm(range(0,5000),total=5000):
    get_prediction(it, title)

In [None]:
set([i[0] for i in title])