In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from IPython.core.display import display
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm, trange
from matplotlib import pyplot as plt

from utils_torch import *
from datasets.flickr8k import Flickr8kDataset

In [None]:
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
DATASET_BASE_PATH = "data/flickr8k/"  # Flickr8k root, relative to the current working directory

In [None]:
# The vocabulary comes from the training split and is handed to val/test, presumably so that
# all three splits use the same token <-> index mapping.
train_set = Flickr8kDataset(
    dataset_base_path=DATASET_BASE_PATH,
    dist='train',
    device=device,
    load_img_to_memory=False,
)
vocab_set = train_set.get_vocab()
vocab, word2idx, idx2word, max_len = vocab_set
val_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='val',
                          vocab_set=vocab_set, device=device)
test_set = Flickr8kDataset(dataset_base_path=DATASET_BASE_PATH, dist='test',
                           vocab_set=vocab_set, device=device)
tuple(len(split) for split in (train_set, val_set, test_set))

In [None]:
# Vocabulary size (used as the decoder's output dimension); max_len presumably is the
# longest caption length reported by the training set's vocab.
vocab_size = len(vocab)
(vocab_size, max_len)

In [None]:
# The training DataLoader has no custom sampler, so one epoch visits every training item.
samples_per_epoch = len(train_set)
samples_per_epoch

In [None]:
def train_model_new(train_loader, encoder, decoder, loss_fn, optimizer, vocab_size, acc_fn, desc=''):
    """Run one training epoch and return the mean per-batch loss.

    Each batch is unpacked as ``(images, captions, lengths)``; ``lengths`` is not used here.
    The encoder maps images to features, the decoder maps ``(features, captions)`` to logits
    whose last dimension is ``vocab_size``. Logits and captions are flattened and scored with
    ``loss_fn``, and running loss/accuracy averages are shown on the progress bar.

    Args:
        train_loader: iterable of ``(images, captions, lengths)`` batches, e.g. a ``DataLoader``.
        encoder: module producing features from ``images``; switched to train mode.
        decoder: module producing logits from ``(features, captions)``; switched to train mode.
        loss_fn: callable ``(logits[N, vocab_size], targets[N]) -> scalar loss tensor``.
        optimizer: optimizer over the trainable parameters; zeroed and stepped every batch.
        vocab_size: size of the output vocabulary (last dimension of the decoder output).
        acc_fn: callable ``(predicted_ids[N], targets[N]) -> accuracy`` returning a number
            or a 0-d tensor.
        desc: label for the progress bar.

    Returns:
        float: mean of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    encoder.train()
    decoder.train()
    running_loss = 0.0
    running_acc = 0.0
    num_batches = 0

    # Pass the loader itself (not iter(loader)) so tqdm can read len() and show total/ETA.
    progress = tqdm(train_loader, desc=desc)
    for images, captions, _lengths in progress:
        optimizer.zero_grad()
        features = encoder(images)
        outputs = decoder(features, captions)

        # reshape (unlike view) also works when the decoder output is non-contiguous.
        logits = outputs.reshape(-1, vocab_size)
        targets = captions.reshape(-1)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        num_batches += 1
        running_loss += loss.item()
        # Accuracy is only a metric: keep it out of the autograd graph and store a plain float.
        with torch.no_grad():
            running_acc += float(acc_fn(logits.detach().argmax(dim=1), targets))
        progress.set_postfix({'loss': running_loss / num_batches,
                              'acc': running_acc / num_batches,
                              }, refresh=True)

    if num_batches == 0:
        raise ValueError('train_loader yielded no batches')
    return running_loss / num_batches

In [None]:
MODEL = "resnet50_monolstm"
EMBEDDING_DIM = 50
EMBEDDING = f"GLV{EMBEDDING_DIM}"  # tag used in MODEL_NAME; NOTE(review): no GloVe vectors are visibly loaded here
BATCH_SIZE = 16
LR = 1e-2
MODEL_NAME = f'saved_models/{MODEL}_b{BATCH_SIZE}_emd{EMBEDDING}'  # checkpoint path prefix
NUM_EPOCHS = 2
SEED = 42

# Seed torch's global RNG so weight init, DataLoader shuffling and the random train-time
# augmentations (RandomCrop / RandomHorizontalFlip) are reproducible on Restart & Run All.
torch.manual_seed(SEED)

In [None]:
from models.torch.resnet50_monolstm import Encoder

# encoder = Encoder(embed_size=300).to(device=device)

In [None]:
from models.torch.resnet50_monolstm import Decoder

encoder = Encoder(embed_size=EMBEDDING_DIM).to(device)
# NOTE(review): 256 is presumably the decoder's hidden size — confirm against Decoder's signature.
decoder = Decoder(EMBEDDING_DIM, 256, vocab_size, num_layers=2).to(device)

# Padding targets are ignored by the loss; acc_fn is presumably configured to skip them too.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=train_set.pad_value).to(device)
acc_fn = accuracy_fn(ignore_value=train_set.pad_value)

# Only the decoder and the encoder's `embed` and `bn` submodules are handed to the optimizer;
# any other encoder parameters (presumably the pretrained backbone) are never updated.
params = [
    param
    for module in (decoder, encoder.embed, encoder.bn)
    for param in module.parameters()
]
optimizer = torch.optim.Adam(params=params, lr=LR)

In [None]:
# Training-time augmentation for the pretrained CNN encoder.
# NOTE(review): `transforms` is not imported explicitly in this notebook; it presumably comes
# from `from utils_torch import *` — confirm, or import torchvision's transforms directly.
train_set.transformations = transforms.Compose([
    transforms.Resize(256),  # smaller edge of image resized to 256
    transforms.RandomCrop(224),  # get 224x224 crop from random location
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),  # convert the PIL Image to a tensor
    transforms.Normalize((0.485, 0.456, 0.406),  # normalize image for pre-trained model
                         (0.229, 0.224, 0.225))
])
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

train_losses = []              # mean training loss of each epoch, for later inspection/plotting
train_loss_min = float('inf')  # best epoch loss so far (a fixed sentinel like 100 can never be beaten by larger losses)
for epoch in range(NUM_EPOCHS):
    train_loss = train_model_new(desc=f'Epoch {epoch + 1}/{NUM_EPOCHS}', encoder=encoder, decoder=decoder,
                                 optimizer=optimizer, loss_fn=loss_fn, acc_fn=acc_fn,
                                 train_loader=train_loader, vocab_size=vocab_size)
    train_losses.append(train_loss)
    train_loss_min = min(train_loss_min, train_loss)
    # TODO: checkpointing is disabled; the previous (commented-out) version referenced an
    # undefined `final_model`. To re-enable, save encoder/decoder/optimizer state_dicts under
    # f'{MODEL_NAME}...' (e.g. whenever train_loss improves on train_loss_min).

encoder.eval()
decoder.eval();

In [None]:
# Qualitative check on one training example: ground-truth caption vs. generated caption.
t_i = 1003
# Fetch the sample once: train_set was given random crop/flip transforms, so indexing it
# again for the plot would presumably show a different image than the one that was captioned.
sample = train_set[t_i]
image, caption = sample[0], sample[1]
with torch.no_grad():  # inference only — no autograd graph needed
    feat = encoder(image.unsqueeze(0))
    print(' '.join(idx2word[idx.item()] for idx in caption))
    print(' '.join(idx2word[idx] for idx in decoder.sample(feat.unsqueeze(1))))

# NOTE(review): the train transforms end with Normalize, so imshow clips the values and the
# colors look off; un-normalize before plotting for a faithful view.
plt.imshow(image.detach().cpu().permute(1, 2, 0), interpolation="bicubic");

In [None]:
# Qualitative check on one validation example: ground-truth caption vs. generated caption.
t_i = 2020
# Fetch the sample once so the captioned image and the plotted image are the same tensor.
sample = val_set[t_i]
image, caption = sample[0], sample[1]
with torch.no_grad():  # inference only — no autograd graph needed
    feat = encoder(image.unsqueeze(0))
    print(' '.join(idx2word[idx.item()] for idx in caption))
    print(' '.join(idx2word[idx] for idx in decoder.sample(feat.unsqueeze(1))))

# NOTE(review): if val_set's transforms normalize the image, imshow will clip it — confirm and
# un-normalize before plotting if so.
plt.imshow(image.detach().cpu().permute(1, 2, 0), interpolation="bicubic");

In [None]:
# Qualitative check on one test example: ground-truth caption vs. generated caption.
t_i = 2020
# Fetch the sample once and plot THAT image (the previous version plotted val_set[t_i] here).
sample = test_set[t_i]
image, caption = sample[0], sample[1]
with torch.no_grad():  # inference only — no autograd graph needed
    feat = encoder(image.unsqueeze(0))
    print(' '.join(idx2word[idx.item()] for idx in caption))
    print(' '.join(idx2word[idx] for idx in decoder.sample(feat.unsqueeze(1))))

# NOTE(review): if test_set's transforms normalize the image, imshow will clip it — confirm and
# un-normalize before plotting if so.
plt.imshow(image.detach().cpu().permute(1, 2, 0), interpolation="bicubic");
