In [1]:
import PIL
import PIL.Image  # `import PIL` alone does not guarantee the Image submodule used by process_img is loaded
import pandas as pd
import torch
import torchvision
from tqdm import tqdm

from multimodal_model import MMLSTM, BenchmarkLSTM

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


KeyboardInterrupt



In [None]:
# Show the working directory: DATA_PATH and the ./saved_models and ./images paths used below are relative to it.
!pwd

In [None]:
# Root folder of the Miami data; get_miami_data is later pointed at its eng/ and spa/ sub-folders.
DATA_PATH = "./data/miami_with_tag/"

In [None]:
import os

# English sub-folder of DATA_PATH (get_miami_data receives its directory as an argument,
# so this name is only a convenience for interactive inspection).
directory = f"{DATA_PATH}eng/"

def get_miami_data(directory, train_pct=0.8):
    """Load every data file in `directory`, BPE-encode its sentences and split the token stream.

    Files are read in sorted filename order so the train/validation split is
    reproducible (``os.listdir`` order is arbitrary). Sub-directories and hidden
    files (names starting with '.') are skipped.

    Only rows whose ``sentence`` column has more than three whitespace-separated
    words are kept (empty cells are dropped), and forward slashes are removed from
    the kept sentences. Every sentence is encoded with ``MultiBPE`` (no padding,
    no EOS), the token lists are concatenated in order, and the result is split
    sequentially.

    :param directory: folder containing CSV files with a ``sentence`` column
    :param train_pct: fraction of the concatenated tokens that goes to the train split
    :return: ``(train_tokens, val_tokens)`` -- the first ``int(len(tokens) * train_pct)``
        tokens and the remainder
    :raises FileNotFoundError: if ``directory`` contains no data files
    """
    from multi_bpe import MultiBPE

    paths = [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
    ]
    if not paths:
        raise FileNotFoundError(f"no data files found in {directory!r}")

    # Concatenate once instead of growing the frame file by file (quadratic copying).
    frames = pd.concat([pd.read_csv(path) for path in paths], ignore_index=True, axis=0)

    # read_csv returns NaN for empty cells; drop them before any string handling.
    sentences = frames['sentence'].dropna().astype(str)
    sentences = sentences[sentences.str.split().str.len() > 3]
    sentences = sentences.str.replace('/', '', regex=False)

    multi_bpe = MultiBPE()
    all_tokens = []
    for sentence in sentences:
        all_tokens.extend(multi_bpe.encode(sentence, padding=False, use_eos=False))

    split = int(len(all_tokens) * train_pct)
    return all_tokens[:split], all_tokens[split:]

In [None]:
# Per-language token streams, each split 80/20 into train/validation (get_miami_data's default train_pct).
eng_train_tokens, eng_val_tokens = get_miami_data(f"{DATA_PATH}eng/")
spn_train_tokens, spn_val_tokens = get_miami_data(f"{DATA_PATH}spa/")


In [None]:
# Combined English + Spanish streams for the mixed-language runs.
# Both splits are built the same way: English tokens first, then Spanish.
all_train_tokens = [*eng_train_tokens, *spn_train_tokens]
all_val_tokens = [*eng_val_tokens, *spn_val_tokens]

In [None]:
# Sanity check (train, val): each combined split must be exactly the
# concatenation of the English and Spanish splits.
(
    len(all_train_tokens) == len(eng_train_tokens) + len(spn_train_tokens),
    len(all_val_tokens) == len(eng_val_tokens) + len(spn_val_tokens),
)

In [None]:
def process_img(image_path):
    """Load an image file and turn it into a normalised batch of one.

    The image is converted to RGB, resized to 224x224, turned into a float
    tensor with values in [0, 1] and normalised per channel with the mean/std
    below (the commonly used ImageNet statistics).

    :param image_path: path of the image file to load
    :return: tensor of shape (1, 3, 224, 224)
    """
    from PIL import Image

    transform = torchvision.transforms.Compose([
        # Resize image to 224 x 224 as required by most vision models
        torchvision.transforms.Resize(size=(224, 224)),
        # Convert PIL image to tensor with image values in [0, 1]
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # The context manager closes the file handle promptly; convert() returns a
    # new in-memory image, so it stays valid after the file is closed.
    with Image.open(image_path) as im:
        rgb_image = im.convert('RGB')

    return transform(rgb_image).unsqueeze(0)


In [None]:
def finetuning(loaded_model,
               lr=5e-4,
               is_multimodal=False,
               train_visual_module=False,
               is_benchmark=False,
               herring_train_text=None,
               herring_test_text=None,
               max_length=64,
               stride=32,
               epochs=5,
               save_name=None):
    """Fine-tune a saved LSTM checkpoint on a token stream, then report validation perplexity.

    Loads ``./saved_models/<loaded_model>`` into a ``BenchmarkLSTM`` (when
    ``is_benchmark``) or an ``MMLSTM``, trains it with Adam for ``epochs`` passes
    over ``herring_train_text``, evaluates it on ``herring_test_text`` and saves
    the resulting weights under ``./saved_models/``.

    Both phases walk the token stream with a strided sliding window: the window
    for offset ``i`` covers tokens ``[max(i + stride - max_length, 0),
    min(i + stride, n))`` and only tokens from ``i`` onward are scored (earlier
    targets are set to -100, which ``cross_entropy`` ignores by default).
    Non-benchmark models are run through ``forward_text``, so no image is used.

    :param loaded_model: checkpoint filename inside ``./saved_models/``
    :param lr: Adam learning rate
    :param is_multimodal: forwarded to ``MMLSTM``; also picks the default save name
    :param train_visual_module: forwarded to ``MMLSTM``
    :param is_benchmark: use ``BenchmarkLSTM`` instead of ``MMLSTM``
    :param herring_train_text: token ids to train on; ``None`` means the global
        ``all_train_tokens`` as it is at call time
    :param herring_test_text: token ids to evaluate on; ``None`` means the global
        ``all_val_tokens`` as it is at call time
    :param max_length: maximum number of tokens in one window
    :param stride: offset between consecutive windows
    :param epochs: number of passes over the training tokens
    :param save_name: output filename inside ``./saved_models/``; defaults to
        ``finetuned_multimodal_lstm`` or ``finetuned_benchmark_lstm`` depending on
        ``is_multimodal`` (runs sharing a default name overwrite each other)
    :return: validation perplexity (0-dim tensor)
    :raises ValueError: on invalid window settings or fewer than two tokens in either stream
    """
    # Resolve defaults at call time so re-running the data cells is picked up.
    if herring_train_text is None:
        herring_train_text = all_train_tokens
    if herring_test_text is None:
        herring_test_text = all_val_tokens
    if stride < 1 or max_length < 2 or stride > max_length:
        raise ValueError("need 1 <= stride <= max_length and max_length >= 2")

    inp = torch.tensor(herring_train_text).view(1, -1)
    test_inp = torch.tensor(herring_test_text).view(1, -1)
    if inp.size(1) < 2 or test_inp.size(1) < 2:
        raise ValueError("herring_train_text and herring_test_text need at least two tokens each")

    if is_benchmark:
        mm_model = BenchmarkLSTM().to(device)
    else:
        mm_model = MMLSTM(is_multimodal=is_multimodal,
                          train_visual_module=train_visual_module).to(device)
    # map_location lets a checkpoint saved on another device load onto `device`.
    mm_model.load_state_dict(torch.load(f'./saved_models/{loaded_model}', map_location=device))

    def windows(ids):
        """Yield (input_ids, target_ids, trg_len) for each strided window over `ids`."""
        total = ids.size(1)
        for i in range(0, total, stride):
            begin_loc = max(i + stride - max_length, 0)
            end_loc = min(i + stride, total)
            trg_len = end_loc - i  # may be different from stride on last loop
            input_ids = ids[:, begin_loc:end_loc].to(device)
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100  # context already scored by the previous window
            # Shift by one so position t predicts token t + 1.
            input_ids = input_ids[..., :-1].contiguous()
            target_ids = target_ids[..., 1:].contiguous()
            if input_ids.size(1) == 0:
                continue  # a one-token window has nothing to predict (its loss would be NaN)
            yield input_ids, target_ids, trg_len

    def window_loss(input_ids, target_ids):
        """Mean cross-entropy over the unmasked targets of one window."""
        output = mm_model(input_ids) if is_benchmark else mm_model.forward_text(input_ids)
        return torch.nn.functional.cross_entropy(output.view(-1, output.size(-1)), target_ids.view(-1))

    optimizer = torch.optim.Adam(mm_model.parameters(), lr=lr)

    mm_model.train()
    for epoch in range(epochs):
        nlls = []
        scored = 0
        n_windows = len(range(0, inp.size(1), stride))
        for step, (input_ids, target_ids, trg_len) in enumerate(tqdm(windows(inp), total=n_windows)):
            loss = window_loss(input_ids, target_ids)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(mm_model.parameters(), 2.0)
            optimizer.step()

            # detach() so the per-step autograd graph is not kept alive for the whole epoch.
            nlls.append(loss.detach() * trg_len)
            scored += trg_len
            if (step + 1) % 200 == 0:
                tqdm.write(f"Current train loss: {(torch.stack(nlls).sum() / scored).item()}")

        ppl1 = torch.exp(torch.stack(nlls).sum() / inp.size(1))
        print(f'Training Perplexity for epoch {epoch}: {ppl1}')

    mm_model.eval()
    nlls = []
    with torch.no_grad():
        n_windows = len(range(0, test_inp.size(1), stride))
        for input_ids, target_ids, trg_len in tqdm(windows(test_inp), total=n_windows):
            nlls.append(window_loss(input_ids, target_ids) * trg_len)
    val_ppl = torch.exp(torch.stack(nlls).sum() / test_inp.size(1))

    print("Validation Perplexity: ")
    print(val_ppl)

    if save_name is None:
        save_name = 'finetuned_multimodal_lstm' if is_multimodal else 'finetuned_benchmark_lstm'
    torch.save(mm_model.state_dict(), f'./saved_models/{save_name}')
    return val_ppl

In [None]:
def pretraining(loaded_model,
               lr=5e-4,
               is_multimodal=False,
               train_visual_module=False,
               is_benchmark=False,
               herring_train_text=None,
               herring_test_text=None,
               max_length=64,
               stride=32,
               epochs=5):
    """Zero-shot evaluation: sliding-window perplexity of a saved checkpoint, without any training.

    Despite its name, this function does not train. It loads
    ``./saved_models/<loaded_model>`` into a ``BenchmarkLSTM`` (when
    ``is_benchmark``) or an ``MMLSTM``, evaluates it on ``herring_test_text`` and
    prints the perplexity. ``lr``, ``epochs`` and ``herring_train_text`` are
    accepted only so the signature matches ``finetuning``; they are ignored.

    The window for offset ``i`` covers tokens ``[max(i + stride - max_length, 0),
    min(i + stride, n))`` and only tokens from ``i`` onward are scored (earlier
    targets are set to -100, which ``cross_entropy`` ignores by default).
    Non-benchmark models are run through ``forward_text``, so no image is used.

    :param loaded_model: checkpoint filename inside ``./saved_models/``
    :param is_multimodal: forwarded to ``MMLSTM``
    :param train_visual_module: forwarded to ``MMLSTM``
    :param is_benchmark: use ``BenchmarkLSTM`` instead of ``MMLSTM``
    :param herring_test_text: token ids to evaluate on; ``None`` means the global
        ``all_val_tokens`` as it is at call time
    :param max_length: maximum number of tokens in one window
    :param stride: offset between consecutive windows
    :return: validation perplexity (0-dim tensor)
    :raises ValueError: on invalid window settings or fewer than two test tokens
    """
    # Resolve the default at call time so re-running the data cells is picked up.
    if herring_test_text is None:
        herring_test_text = all_val_tokens
    if stride < 1 or max_length < 2 or stride > max_length:
        raise ValueError("need 1 <= stride <= max_length and max_length >= 2")

    test_inp = torch.tensor(herring_test_text).view(1, -1)
    total = test_inp.size(1)
    if total < 2:
        raise ValueError("herring_test_text needs at least two tokens")

    if is_benchmark:
        mm_model = BenchmarkLSTM().to(device)
    else:
        mm_model = MMLSTM(is_multimodal=is_multimodal,
                          train_visual_module=train_visual_module).to(device)
    # map_location lets a checkpoint saved on another device load onto `device`.
    mm_model.load_state_dict(torch.load(f'./saved_models/{loaded_model}', map_location=device))
    mm_model.eval()

    nlls = []
    with torch.no_grad():
        for i in tqdm(range(0, total, stride)):
            begin_loc = max(i + stride - max_length, 0)
            end_loc = min(i + stride, total)
            trg_len = end_loc - i  # may be different from stride on last loop
            input_ids = test_inp[:, begin_loc:end_loc].to(device)
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100  # context already scored by the previous window

            # Shift by one so position t predicts token t + 1.
            input_ids = input_ids[..., :-1].contiguous()
            target_ids = target_ids[..., 1:].contiguous()
            if input_ids.size(1) == 0:
                continue  # a one-token window has nothing to predict (its loss would be NaN)

            output = mm_model(input_ids) if is_benchmark else mm_model.forward_text(input_ids)
            loss = torch.nn.functional.cross_entropy(output.view(-1, output.size(-1)), target_ids.view(-1))
            nlls.append(loss * trg_len)

    ppl1 = torch.exp(torch.stack(nlls).sum() / total)

    print("Validation Perplexity: ")
    print(ppl1)
    return ppl1

In [None]:
def _sample_logits(logits, temperature=1.0):
    """Pick one index from a 1-D score vector.

    temperature == 0.0 returns the argmax (greedy); otherwise an index is sampled
    from softmax(logits / temperature). Works for raw logits and for
    log-probabilities alike, since both give the same softmax.
    """
    if temperature == 0.0:
        return logits.argmax()
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)[0]


def generate_seq(model, seed, origin_size, size, temperature=1.0):
    """
    Extend a token sequence by sampling from a language model.

    :param model: callable taking a (1, L) tensor of token ids and returning per-position
        scores indexed as [batch, position, vocab]
    :param seed: 1-D tensor of token ids; positions up to and including ``origin_size`` are
        kept as given. If it is shorter than ``size`` it is zero-padded; padding after
        ``origin_size`` is overwritten by generation. ``seed`` itself is not modified.
    :param origin_size: generation fills positions ``origin_size + 1`` to ``size - 1``
    :param size: total length of the sequence to generate
    :param temperature: This controls how much we follow the probabilities provided by the network. For t=1.0 we just
        sample directly according to the probabilities. Lower temperatures make the high-probability words more likely
        (providing more likely, but slightly boring sentences) and higher temperatures make the lower probabilities more
        likely (resulting in weirder sentences). For temperature=0.0, the generation is _greedy_, i.e. the word with the
        highest probability is always chosen.
    :return: 1-D tensor of token ids on ``device`` representing the sampled sentence
    """
    # clone(): .to(device) returns `seed` itself when it is already on `device`,
    # and writing into it would silently modify the caller's tensor.
    tokens = seed.to(device).clone()
    if tokens.size(0) < size:
        tokens = torch.cat([tokens, tokens.new_zeros(size - tokens.size(0))])

    with torch.no_grad():
        for i in range(origin_size + 1, size):
            scores = model(tokens[None, :])
            # The output at position i - 1 predicts token i.
            tokens[i] = _sample_logits(scores[0, i - 1, :], temperature=temperature)

    return tokens

In [None]:
# Checkpoint filenames (relative to ./saved_models/) used by the runs below.
visual_model = 'multimodal_lstm_200000_1.0_15v6'
benchmark_model = 'benchmark_model_200000_1.0_15v7'
# Text-only checkpoint, currently unused:
# text_model = 'monomodal_model_50000_0.00025_6'

## English Perplexity (after fine-tuning)

In [None]:
# English-only fine-tuning of the multimodal LSTM (weights go to
# ./saved_models/finetuned_multimodal_lstm, overwriting earlier runs).
finetuning(
    visual_model,
    lr=5e-4,
    is_multimodal=True,
    herring_train_text=eng_train_tokens,
    herring_test_text=eng_val_tokens,
    max_length=32,
    stride=16,
    epochs=5,
)

In [None]:
# Disabled: fine-tuning of the text-only checkpoint (`text_model` is commented out
# in the checkpoint-name cell). NOTE(review): the lr value below reads "`.0", which
# looks like a typo -- confirm the intended learning rate before re-enabling.
# finetuning(text_model, 
#            lr=`.0, 
#            is_multimodal=False,
#            train_visual_module=False,
#            max_length = 32,
#            stride = 16,
#            epochs = 5)

In [None]:
# English-only fine-tuning of the benchmark LSTM (weights go to
# ./saved_models/finetuned_benchmark_lstm, overwriting earlier runs).
finetuning(
    benchmark_model,
    lr=5e-4,
    is_multimodal=False,
    train_visual_module=False,
    herring_train_text=eng_train_tokens,
    herring_test_text=eng_val_tokens,
    is_benchmark=True,
    max_length=32,
    stride=16,
    epochs=5,
)

## Spanish Perplexity (after fine-tuning)

In [None]:
# Spanish-only fine-tuning of the multimodal LSTM (weights go to
# ./saved_models/finetuned_multimodal_lstm, overwriting earlier runs).
finetuning(
    visual_model,
    lr=5e-4,
    is_multimodal=True,
    herring_train_text=spn_train_tokens,
    herring_test_text=spn_val_tokens,
    max_length=32,
    stride=16,
    epochs=5,
)

In [None]:
# Spanish-only fine-tuning of the benchmark LSTM (weights go to
# ./saved_models/finetuned_benchmark_lstm, overwriting earlier runs).
finetuning(
    benchmark_model,
    lr=5e-4,
    is_multimodal=False,
    train_visual_module=False,
    is_benchmark=True,
    herring_train_text=spn_train_tokens,
    herring_test_text=spn_val_tokens,
    max_length=32,
    stride=16,
    epochs=5,
)

## English + Spanish Perplexity (after fine-tuning)

In [None]:
# English + Spanish fine-tuning of the multimodal LSTM. The combined token lists
# are passed explicitly so the run always uses their current values, not the
# objects the function's defaults captured when its cell was executed.
finetuning(
    visual_model,
    lr=5e-4,
    is_multimodal=True,
    herring_train_text=all_train_tokens,
    herring_test_text=all_val_tokens,
    max_length=32,
    stride=16,
    epochs=5,
)

In [None]:
# English + Spanish fine-tuning of the benchmark LSTM. The combined token lists
# are passed explicitly so the run always uses their current values, not the
# objects the function's defaults captured when its cell was executed.
finetuning(
    benchmark_model,
    lr=5e-4,
    is_multimodal=False,
    train_visual_module=False,
    is_benchmark=True,
    herring_train_text=all_train_tokens,
    herring_test_text=all_val_tokens,
    max_length=32,
    stride=16,
    epochs=5,
)

## English + Spanish Zero-shot Perplexity

In [None]:
# Zero-shot (no training) perplexity of the benchmark LSTM on English + Spanish
# validation. `pretraining` never trains, so lr/epochs are not passed.
pretraining(
    benchmark_model,
    is_multimodal=False,
    train_visual_module=False,
    is_benchmark=True,
    herring_test_text=all_val_tokens,
    max_length=32,
    stride=16,
)

In [None]:
# Zero-shot (no training) perplexity of the multimodal LSTM on English + Spanish
# validation. `pretraining` never trains, so lr/epochs are not passed.
pretraining(
    visual_model,
    is_multimodal=True,
    herring_test_text=all_val_tokens,
    max_length=32,
    stride=16,
)

## Spanish Zero-shot Perplexity

In [None]:
# Zero-shot (no training) perplexity of the benchmark LSTM on Spanish validation.
# `pretraining` never trains, so lr/epochs are not passed.
pretraining(
    benchmark_model,
    is_multimodal=False,
    train_visual_module=False,
    herring_test_text=spn_val_tokens,
    is_benchmark=True,
    max_length=32,
    stride=16,
)

In [None]:
# Zero-shot (no training) perplexity of the multimodal LSTM on Spanish validation.
# `pretraining` never trains, so lr/epochs are not passed.
pretraining(
    visual_model,
    is_multimodal=True,
    herring_test_text=spn_val_tokens,
    max_length=32,
    stride=16,
)

## English Zero-shot Perplexity

In [None]:
# Zero-shot (no training) perplexity of the benchmark LSTM on English validation.
# `pretraining` never trains, so lr/epochs are not passed.
pretraining(
    benchmark_model,
    is_multimodal=False,
    train_visual_module=False,
    herring_test_text=eng_val_tokens,
    is_benchmark=True,
    max_length=32,
    stride=16,
)

In [None]:
# Zero-shot (no training) perplexity of the multimodal LSTM on English validation.
# `pretraining` never trains, so lr/epochs are not passed.
pretraining(
    visual_model,
    is_multimodal=True,
    herring_test_text=eng_val_tokens,
    max_length=32,
    stride=16,
)