In [1]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mharadai[0m ([33menergy_project_uab[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
from models import LSTMModel , LSTMModel_seeOnce
from dataset import ImageCaptionDataset
import pickle
from torch.utils.data import DataLoader
import wandb
import yaml
from typing import Dict
import gensim.downloader as api
import tensorflow_hub as hub
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

import torch
import torch.nn as nn

from tqdm import tqdm
from icecream import ic

import numpy as np

import os

os.environ["TOKENIZERS_PARALLELISM"] = "False"


###SCRIPT CONFIG!####
device = "mps" #r u running cuda my boy? or mps? :D
num_epochs = 80
batch_size = 10
COSINE_SIM_IMPORTANCE = 0.8
print("Have you runned wandb login?? OK. go aheadd...")
#####################

#convert to dictionary a yaml 
def nested_dict(original_dict):
    nested_dict = {}
    for key, value in original_dict.items():
        parts = key.split(".")
        d = nested_dict
        for part in parts[:-1]:
            if part not in d:
                d[part] = {}
            d = d[part]
        d[parts[-1]] = value
    return nested_dict

#load datasets
with open('/Users/josepsmachine/Documents/UNI/DL/dlnn-project_ia-group_10/dataset/train_dataset.pkl', 'rb') as inp:
    train_dataset = pickle.load(inp)
with open('/Users/josepsmachine/Documents/UNI/DL/dlnn-project_ia-group_10/dataset/val_dataset.pkl', 'rb') as inp:
    val_dataset = pickle.load(inp)
#with open('/Users/josepsmachine/Documents/UNI/DL/dlnn-project_ia-group_10/dataset/debug_dataset.pkl', 'rb') as inp:
#    debug_dataset = pickle.load(inp)

#create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
#debug_dataloader = DataLoader(debug_dataset, batch_size=1, shuffle=False)

#load hyperparamaters to do grid search on
#setup wandb stuff
with open('hyperparams.yaml', 'r') as stream:
    try:
        sweep_config = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)

#create sweep
sweep_id = wandb.sweep(sweep_config, project="energy_project_uab")

#load word2vec pretrained embedding layer
word2vec_emb = api.load('word2vec-google-news-300')
word2vec_emb = torch.FloatTensor(word2vec_emb.vectors)

#import vocabulary to word2vec indexes that we know
with open('vocabidx2word2vecidx.pkl', 'rb') as inp:
    vocabidx2word2vecidx = pickle.load(inp)

with open('vocabulary.pkl', 'rb') as inp:
    vocabulary = pickle.load(inp)

word2vec_emb = word2vec_emb[vocabidx2word2vecidx]

word2vec_emb = nn.Embedding.from_pretrained(word2vec_emb)
word2vec_emb.requires_grad_ = False #freeze word2vec embeddding layer

#load universal sentence encoder to define our own loss function
sntc_enc = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

#define cross entropy loss function
cross_entrop = nn.CrossEntropyLoss()

  from .autonotebook import tqdm as notebook_tqdm


Have you runned wandb login?? OK. go aheadd...
Create sweep with ID: trvl8xhn
Sweep URL: https://wandb.ai/energy_project_uab/energy_project_uab/sweeps/trvl8xhn


In [3]:
def count_parameters(model):
    total_params = 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            num_params = param.numel()
            print(f"{name}: {num_params}")
            total_params += num_params
    print(f"Total Parameters: {total_params}")

In [6]:
def train(config: Dict = None):
    with wandb.init(config, project="dl2023_imagecaptioning", entity = "dl2023team"):
        config = wandb.config
        config = nested_dict(config)
        
        if config["embedding_layer"] == "word2vec":
                emb_layer = word2vec_emb
        else:
            #the learnt embedding layer will have the same vocab size as word2vec for comparaison reasons.
            emb_layer = nn.Embedding(num_embeddings=word2vec_emb.weight.shape[0], embedding_dim=config["hidden_size"])

        if config["see_once"]: #if we want the model with residual at each step or not
            model = LSTMModel_seeOnce(input_dim=512,embedding_layer=emb_layer,hidden_dim=config["hidden_size"],n_layers=config['num_layers'])
        else:
            model = LSTMModel(input_dim=512,embedding_layer=emb_layer,hidden_dim=config["hidden_size"],n_layers=config['num_layers'])
        
        #print(config["see_once"])
        #count_parameters(model)
        
        model.init_weights()
        model.to(device)

        #define optimizer for this run
        optimizer_config = config["optimizer"]
        if optimizer_config["type"] == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr = optimizer_config['lr'])
        
        #define loss for this run
        if config["loss_funct"] == "crossentropy":
            def loss_funct(ref,pred): 
                loss = cross_entrop(pred,ref)
                return loss
        else:
            def loss_funct(ref,pred):
                loss1 = cross_entrop(pred,ref)

                pred_keys = torch.argmax(pred,axis=-1)
                
                pred_sntc = ""
                target_sntc = ""
                
                for batch_ref,batch_pred in zip(ref,pred_keys):

                    p_word =  vocabulary[batch_pred.item()]
                    t_worc =  vocabulary[batch_ref.item()]
                    pred_sntc += " " + vocabulary[batch_pred.item()]
                    target_sntc += " " + vocabulary[batch_ref.item()]

                #ic(caption)
                #ic(target_sntc)
                #ic(pred_sntc)

                embeddings = sntc_enc.encode([pred_sntc, target_sntc])
                similarity = cosine_similarity(embeddings[0].reshape(1, -1), embeddings[1].reshape(1, -1))
                loss = 1/(similarity[0][0])*COSINE_SIM_IMPORTANCE + loss1*(1-COSINE_SIM_IMPORTANCE)
                return loss

        ############
        #Train loop:
        ############

        best_loss = float('inf')
        for epoch in range(num_epochs):
            model.train()
            training_losses = [] # renamed from epoch_losses
            progress_bar = tqdm(enumerate(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}")
            
            for batch,(X1,X2,caption) in progress_bar:
            #for batch,(X1,X2) in enumerate(train_dataloader):
                optimizer.zero_grad()
                
                X1 = X1.to(device) 
                X2 = X2.to(device)
                out,h,c = model(X1,X2)
                ref = X2
                #change ref to join batch and sequence dim
                ref = ref.view(ref.shape[0]*ref.shape[1])


                #same for the out logits
                #change ref to join batch and sequence dim
                out = out.view(out.shape[0]*out.shape[1],out.shape[2])
                ic(out.shape)
                
                loss = loss_funct(ref,out)
                loss.backward()
                optimizer.step()    
                training_losses.append(loss.item())
                progress_bar.set_postfix({'Batch Loss': loss.item()})

            average_training_loss = sum(training_losses) / len(training_losses) # renamed from avg_loss
            wandb.log({'Train_Epoch_Loss': average_training_loss})

            model.eval()  
            with torch.no_grad():  
                validation_losses = [] # renamed from val_losses
                for X1,X2,caption in tqdm(val_dataloader, desc='Validation'):
                    X1 = X1.to(device) 
                    X2 = X2.to(device)

                    out,h,c = model(X1,X2)
                    ref = X2.float()

                    ref = ref.view(ref.shape[0]*ref.shape[1])
                    out = out.view(out.shape[0]*out.shape[1],out.shape[2])
                    
                    loss = loss_funct(ref,out)

                    validation_losses.append(loss.item())

                average_validation_loss = sum(validation_losses) / len(validation_losses) # renamed from avg_val_loss
                wandb.log({'Validation_Epoch_Loss': average_validation_loss})

            if average_training_loss < best_loss:
                best_loss = average_training_loss
                torch.save(model.state_dict(), f'{wandb.run.id}_LSTM&resnet18.pt')
                wandb.save(f'{wandb.run.id}_LSTM&resnet18.pt')
                print(f"Model saved at {'{wandb.run.id}_LSTM&resnet18.pt'}")

        wandb.finish()

ic| pred_sntc: 

(' eos white the blue wearing the on and posing on and in launches launches '
                'launches launches launches launches launches launches launches launches '
                'launches launches launches launches launches launches launches launches '
                'launches launches launches launches launches eos of dog dog holding is from '
                'the and in launches launches launches launches launches launches launches '
                'launches launches launches launches launches launches launches launches '
                'launches launches launches launches launches launches launches launches '
                'launches launches eos white is is and people the in launches launches '
                'launches launches launches launches launches launches launches launches '
                'launches launches launches launches launches launches launches launches '
                'launches launches launches launches launches launches launches launches '
        

In [5]:
wandb.agent(sweep_id, function=train)

[34m[1mwandb[0m: Agent Starting Run: 81470bse with config:
[34m[1mwandb[0m: 	embedding_layer: word2vec
[34m[1mwandb[0m: 	hidden_size: 30
[34m[1mwandb[0m: 	loss_funct: CrssEntrop+UnivSntcEnc_cosDist
[34m[1mwandb[0m: 	num_layers: 1
[34m[1mwandb[0m: 	optimizer: {'lr': 0.007980127776484412, 'type': 'adam'}
[34m[1mwandb[0m: 	see_once: False
[34m[1mwandb[0m: Currently logged in as: [33mharadai[0m ([33mdl2023team[0m). Use [1m`wandb login --relogin`[0m to force relogin


False
imf2lstm_h.weight: 15360
imf2lstm_h.bias: 30
imf2lstm_c.weight: 15360
imf2lstm_c.bias: 30
lstm.weight_ih_l0: 36000
lstm.weight_hh_l0: 3600
lstm.bias_ih_l0: 120
lstm.bias_hh_l0: 120
LM_FC.weight: 252780
LM_FC.bias: 8426
Total Parameters: 331826


Epoch 1/80: 0it [00:00, ?it/s]ic| out.shape: torch.Size([350, 8426])
ic| caption: ('bos little girl runs along the beach toward man who admires the sky eos',
              'bos ragged man sleeping behind building eos',
              'bos the dog in vest leaps in the air and there is bird flying eos',
              'bos the surfer rides wave eos',
              'bos brown dog runs on the sand carrying stick eos',
              'bos man with camera aimed at singer in green jacket eos',
              'bos two dogs running through sand eos',
              'bos little black dog chasing little brown one eos',
              'bos dog shaking water off eos',
              'bos people dressed in all white are looking at some shaved lambs eos')
ic| target_sntc: (' eos grass to its plays on near tall of onto directly on toward in launches '
                  'launches launches launches launches launches launches launches launches '
                  'launches launches launches launches launches la