In [1]:
#Imports
import os
import pickle
import random
import numpy as np
import shutil
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from argparse import Namespace
import torch.nn as nn
import torch
from torch import linalg as LA
import wandb
import logging
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import os

# Module-level logger for this notebook.
logger = logging.getLogger(__name__)


def wandb_logger(dir, version):
    """Return a WandbLogger named "wandb" that saves under `dir` for `version`."""
    return WandbLogger(name="wandb", save_dir=dir, version=version)


def csvlogger(dir, version):
    """Return a CSVLogger rooted at `dir`, run name "csvlogs", for `version`."""
    return CSVLogger(dir, name="csvlogs", version=version)


def tblogger(dir, version):
    """Return a TensorBoardLogger rooted at `dir`, run name "tblogs", for `version`."""
    return TensorBoardLogger(dir, name="tblogs", version=version)

def get_loggers(dir,version,lis=["csv"]):
    """Build the Lightning loggers requested in `lis`.

    Args:
        dir: directory handed to every logger factory.
        version: run/version identifier handed to every logger factory.
        lis: collection of logger keys; recognised keys are "wandb", "csv"
            and "tb". Unrecognised keys are ignored.

    Returns:
        A list of logger instances, always ordered wandb -> csv -> tb
        regardless of the order of `lis`.
    """
    # Thunks keep construction lazy: a logger is only instantiated when its
    # key was actually requested.
    builders = {
        "wandb": lambda: wandb_logger(dir, version),
        "csv": lambda: csvlogger(dir, version),
        "tb": lambda: tblogger(dir, version),
    }
    return [build() for key, build in builders.items() if key in lis]

# Global defaults: dataset root and the pickled embedding files inside it.
DATA_DIR = '/common/home/vk405/Projects/Crossmdl/Data/Recipe/'
EMB_TRN, EMB_VAL, ING_TRN, ING_VAL = (
    f'{DATA_DIR}{fname}'
    for fname in (
        'embeddings_train1.pkl',
        'embeddings_val1.pkl',
        'ingredients_embeddings_train.pkl',
        'ingredients_embeddings_val.pkl',
    )
)
#os.listdir(DATA_DIR)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# with open(EMB_TRN, 'rb') as files:
#     emb_trn = pickle.load(files)
# with open(EMB_VAL, 'rb') as files:
#     emb_val = pickle.load(files)

# with open(ING_TRN, 'rb') as files:
#     ing_trn = pickle.load(files)
# with open(ING_VAL, 'rb') as files:
#     ing_val = pickle.load(files)



In [3]:
class RecipeDset(Dataset):
    """Dataset of pre-computed recipe embeddings served as triplets.

    Each pickle file is expected to hold a 3-tuple that is unpacked into
    ``emb_vid``, ``emb_txt`` and ``ids``. Judging by how the model consumes
    them these are presumably image and text embeddings plus recipe ids;
    verify against the script that produced the pickles.

    Args:
        data_dir: directory containing the embedding pickles. A trailing
            path separator is optional.
        split: one of 'train', 'valid' or 'test'.
        txt_emb_type: which text embedding to load. Only 'total' is
            implemented.

    Raises:
        NotImplementedError: if `txt_emb_type` is not 'total'.
        ValueError: if `split` is not a known split.
    """

    # Maps a split name to the attribute that holds its 'total' pickle path.
    _TOTAL_EMB_ATTR = {'train': 'EMB_TRN', 'valid': 'EMB_VAL', 'test': 'EMB_TST'}

    def __init__(self,data_dir= '/common/home/vk405/Projects/Crossmdl/Data/Recipe/'\
        ,split='train',txt_emb_type='total'):
        self.DATA_DIR = data_dir
        self.init_data_locs()
        self.txt_emb_type = txt_emb_type
        self.split = split
        # Fail at construction instead of creating a dataset with no data,
        # which would only surface later as an AttributeError.
        if self.txt_emb_type != 'total':
            raise NotImplementedError(
                f"txt_emb_type={txt_emb_type!r} is not supported; only 'total' is implemented"
            )
        if self.split not in self._TOTAL_EMB_ATTR:
            raise ValueError(
                f"unknown split {split!r}; expected one of {sorted(self._TOTAL_EMB_ATTR)}"
            )
        emb_path = getattr(self, self._TOTAL_EMB_ATTR[self.split])
        # NOTE: pickle.load can execute arbitrary code; only point data_dir
        # at files from a trusted source.
        with open(emb_path, 'rb') as files:
            self.emb_vid,self.emb_txt,self.ids = pickle.load(files)

    def __getitem__(self,idx):
        """Return ``(emb_vid[idx], emb_txt[idx], emb_vid[neighbour])``.

        The third element is the image embedding of the next sample, or of
        the previous one when `idx` is the last index. With a single-sample
        dataset the neighbour index is -1, i.e. the sample itself.
        """
        rind = idx+1
        if rind == len(self.emb_txt):
            rind = idx-1
        return self.emb_vid[idx],self.emb_txt[idx],self.emb_vid[rind]

    def __len__(self):
        """Number of samples, taken from the text-embedding container."""
        return len(self.emb_txt)

    def init_data_locs(self):
        """Populate the path attributes for every embedding file under DATA_DIR."""
        # os.path.join makes a trailing separator on DATA_DIR optional.
        def loc(fname):
            return os.path.join(self.DATA_DIR, fname)

        # Total embeddings
        self.EMB_TRN = loc('embeddings_train1.pkl')
        self.EMB_VAL = loc('embeddings_val1.pkl')
        self.EMB_TST = loc('embeddings_test1.pkl')

        # Ingredients embeddings
        self.ING_TRN = loc('ingredients_embeddings_train.pkl')
        self.ING_VAL = loc('ingredients_embeddings_val.pkl')
        self.ING_TST = loc('ingredients_embeddings_test.pkl')

        # Title embeddings
        self.TIT_TRN = loc('title_embeddings_train.pkl')
        self.TIT_VAL = loc('title_embeddings_val.pkl')
        self.TIT_TST = loc('title_embeddings_test.pkl')

        # Instructions embeddings
        self.INS_TRN = loc('instructions_embeddings_train.pkl')
        self.INS_VAL = loc('instructions_embeddings_val.pkl')
        self.INS_TST = loc('instructions_embeddings_test.pkl')


        

In [4]:
#trn_data = RecipeDset()

In [5]:
class EmbModel(nn.Module):
    """Stack of same-width ``Linear`` + activation pairs.

    `params` is a mapping providing:
        'input_dim': width of every linear layer (input and output).
        'lyrs': number of Linear/activation pairs.
        'act': activation name; only 'relu' is implemented.

    Raises:
        NotImplementedError: when at least one layer is requested and
            'act' is not 'relu'.
    """

    def __init__(self,params):
        super().__init__()
        self.params = params
        width = params['input_dim']
        stack = []
        for _ in range(params['lyrs']):
            linear = nn.Linear(width, width)
            if params['act'] != 'relu':
                raise NotImplementedError
            stack.extend((linear, nn.ReLU()))
        # Plain list kept for introspection; parameters are registered
        # through the Sequential below.
        self.lyrs = stack
        self.feedforward = nn.Sequential(*stack)

    def forward(self,x):
        """Apply the feed-forward stack to `x`."""
        return self.feedforward(x)





class RecipeModel(pl.LightningModule):
    """Two-tower image/text model trained with a cosine triplet loss.

    Image and text inputs go through separate ``EmbModel`` encoders and are
    then projected by a single shared linear layer.

    `hparams` must provide ``txt_model`` and ``img_model`` (``EmbModel``
    parameter dicts) and ``shared_emb_dim``. Optional: ``eps`` (loss margin,
    default 1e-8) and ``lr`` (Adam learning rate, default 1e-3).
    """

    def __init__(self,hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.txt_emb = EmbModel(hparams.txt_model)
        self.img_emb = EmbModel(hparams.img_model)
        # NOTE(review): the shared layer's input width comes from
        # txt_model['fin_dim'], while EmbModel emits 'input_dim'-wide vectors
        # for both towers; the three values have to agree for this to work.
        self.shared = nn.Linear(hparams.txt_model['fin_dim'],\
            hparams.shared_emb_dim)

    def forward(self,x):
        """Embed an ``(img, txt)`` pair and return ``(img_emb, txt_emb)``.

        The third element of a training batch is not handled here.
        """
        img,txt = x
        img_fin_emb = self.shared(self.img_emb(img))
        txt_fin_emb = self.shared(self.txt_emb(txt))
        return img_fin_emb,txt_fin_emb

    def _batch_loss(self,batch):
        """Embed a ``(img, txt, other_img)`` batch and return `get_loss`'s output."""
        img,txt,anch_img = batch
        anch_img_fin_emb = self.shared(self.img_emb(anch_img))
        img_fin_emb,txt_fin_emb = self((img,txt))
        return self.get_loss(img_fin_emb,txt_fin_emb,anch_img_fin_emb)

    def training_step(self,batch,batch_idx):
        """Compute and log the training loss and both mean cosine similarities."""
        loss,log_losses = self._batch_loss(batch)
        self.log("train_loss",loss,on_step=True)
        self.log("cos_sim_n",log_losses[0],on_step=True)
        self.log("cos_sim_p",log_losses[-1],on_step=True)

        return loss

    def validation_step(self,batch,batch_idx):
        """Compute the validation loss; metrics are logged per epoch."""
        loss,log_losses = self._batch_loss(batch)
        self.log("val_loss",loss,on_step=False, on_epoch=True)
        self.log("val_cos_sim_n",log_losses[0],on_step=False, on_epoch=True)
        self.log("val_cos_sim_p",log_losses[-1],on_step=False, on_epoch=True)
        return loss

    def get_loss(self,img,txt,anch,reduce='mean'):
        """Cosine triplet loss ``max(cos(anch, txt) - cos(img, txt) + margin, 0)``.

        Args:
            img: embeddings paired with `txt` (the positive).
            txt: text embeddings that both similarities are measured against.
            anch: embeddings used as the negative for `txt`.
            reduce: 'mean', 'sum' or 'none', applied over all elements of
                the loss and of both similarity tensors.

        Returns:
            ``(loss, (cos_sim_n, cos_sim_p))`` reduced according to `reduce`.

        Raises:
            ValueError: if `reduce` is not a supported reduction.
        """
        reducers = {'mean': torch.mean, 'sum': torch.sum, 'none': lambda t: t}
        if reduce not in reducers:
            raise ValueError(
                f"unsupported reduce={reduce!r}; expected one of {sorted(reducers)}"
            )
        # The hyper-parameter is called `eps` but acts as the triplet margin.
        eps = self.hparams.eps if 'eps' in self.hparams else 1e-8

        def unit(vec):
            # Clamp the norm so an all-zero embedding gives 0 rather than NaN;
            # keepdim lets this work for any number of leading dimensions.
            return vec / LA.norm(vec,dim=-1,keepdim=True).clamp_min(1e-12)

        normd_img = unit(img)
        normd_txt = unit(txt)
        normd_anch = unit(anch)

        cos_sim_p = torch.sum(normd_img*normd_txt,dim=-1)
        cos_sim_n = torch.sum(normd_anch*normd_txt,dim=-1)

        clipped_loss = torch.relu(cos_sim_n-cos_sim_p+eps)
        agg = reducers[reduce]
        return agg(clipped_loss),(agg(cos_sim_n),agg(cos_sim_p))

    def configure_optimizers(self):
        """Adam over all parameters, using ``hparams.lr`` when set (else 1e-3)."""
        lr = self.hparams.lr if 'lr' in self.hparams else 1e-3
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        return optimizer




        

In [6]:

# cfg = Namespace(
#     version = 'temp',
#     artifacts_loc = "/common/home/vk405/Projects/Crossmdl/nbs/Recipe/",
#     data_dir = "/common/home/vk405/Projects/Crossmdl/Data/Recipe/",
#     mode = 'train',
#     txt_model = {'input_dim':1024,'lyrs':2,'fin_dim':1024,'act':'relu'},
#     img_model = {'input_dim':1024,'lyrs':2,'fin_dim':1024,'act':'relu'},
#     shared_emb_dim = 1024,
#     txt_emb_type = 'total',
#     learning_rate = 1e-4,
#     loggers = ["csv","wandb"],
#     seed = 0,
#     cbs = ["checkpoint"],
#     trainer = {'log_every_n_steps': 1,
#     'max_epochs': 10},
#     checkpoint = {"every_n_epochs": 1,
#     "monitor": "train_loss"},


#     use_precomp_emb = True,
#     edim = 100,
#     attnhdim = 50,
#     nheads = 10,
#     wrdim = 768,
#     vidim = 512,
#     hdim = 30,
#     dropoutp=0.0,
#     seqlen=26,
#     framecnt=499,
#     batch_size=512


# )

In [7]:
# Sanity check -> OVERFIT ON SMALL DATA

# trainer = pl.Trainer(overfit_batches=10)
# trn_loader = DataLoader(trn_data,50,shuffle=True)
# net = RecipeModel(cfg)

# trainer.fit(net,trn_loader)


In [8]:
#batch = next(iter(trn_loader))

In [9]:
def run(cfg):
    """Train a RecipeModel as described by `cfg`.

    Reads from `cfg`: seed, artifacts_loc, version, loggers, cbs, mode,
    data_dir, txt_emb_type, batch_size, trainer, plus `early_stop` and
    `checkpoint` when the matching callback is listed in `cbs`.

    Side effects: seeds the RNGs, wipes and recreates
    ``<artifacts_loc>/ckpts/<version>`` when checkpointing is enabled, and
    starts a wandb run when "wandb" is among the loggers.

    Returns:
        The fitted ``pl.Trainer`` when ``cfg.mode == 'train'``, else ``None``.
    """
    pl.seed_everything(cfg.seed)
    artifacts_dir = cfg.artifacts_loc
    version = str(cfg.version)
    logger_list = get_loggers(artifacts_dir, version,cfg.loggers)
    cbs = []
    if "early_stop" in cfg.cbs:
        # NOTE: the original author flagged early stopping as not really
        # working at the moment.
        earlystopcb = EarlyStopping(**cfg.early_stop, min_delta=0.00, verbose=False)
        cbs.append(earlystopcb)
    if "checkpoint" in cfg.cbs:
        store_path = os.path.join(artifacts_dir, "ckpts", version)
        # Start every run from an empty checkpoint directory: drop whatever
        # an earlier run left behind, then always (re)create it.
        if os.path.isdir(store_path):
            shutil.rmtree(store_path)
        os.makedirs(store_path, exist_ok=True)
        fname = "{epoch}-{val_loss:.2f}"
        checkptcb = ModelCheckpoint(**cfg.checkpoint, dirpath=store_path, filename=fname)
        cbs.append(checkptcb)

    if 'wandb' in cfg.loggers:
        wandb.init(project="RecipeRetrieval", config=cfg)

    if cfg.mode != 'train':
        # No other mode is implemented yet.
        return None

    recipedata_trn = RecipeDset(data_dir=cfg.data_dir,split='train',\
        txt_emb_type = cfg.txt_emb_type)
    recipedata_vld = RecipeDset(data_dir=cfg.data_dir,split='valid',\
        txt_emb_type = cfg.txt_emb_type)

    train_loader = DataLoader(recipedata_trn,batch_size=cfg.batch_size,shuffle=True,\
        num_workers=4,pin_memory=True)

    valid_loader = DataLoader(recipedata_vld,batch_size=cfg.batch_size,shuffle=False)
    net = RecipeModel(cfg)

    trainer = pl.Trainer(
        logger=logger_list,callbacks=cbs, gpus=1,deterministic=True, **cfg.trainer
    )
    trainer.fit(net, train_loader,valid_loader)
    return trainer
    

In [10]:
#eps = eps = 0.1
cfg = Namespace(
    # --- run bookkeeping ---
    seed=0,
    version='temp',
    mode='train',
    artifacts_loc="/common/home/vk405/Projects/Crossmdl/nbs/Recipe/",
    data_dir="/common/home/vk405/Projects/Crossmdl/Data/Recipe/",
    # --- model ---
    txt_model=dict(input_dim=1024, lyrs=2, fin_dim=1024, act='relu'),
    img_model=dict(input_dim=1024, lyrs=2, fin_dim=1024, act='relu'),
    shared_emb_dim=1024,
    txt_emb_type='total',
    # --- optimisation ---
    lr=1e-4,
    eps=0.1,  # read by RecipeModel.get_loss as the loss margin
    batch_size=512,
    # --- logging, callbacks and trainer kwargs ---
    loggers=["csv", "wandb"],
    cbs=["checkpoint", "early_stop"],
    trainer=dict(log_every_n_steps=50, max_epochs=10),
    checkpoint=dict(every_n_epochs=1, monitor="val_loss"),
    early_stop=dict(monitor="val_loss", patience=2, mode='min'),
)

In [11]:
# Launch training. Despite the variable name, run() returns the pl.Trainer
# (not the model) when cfg.mode == 'train'.
trained_model = run(cfg)

Global seed set to 0
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvin136[0m (use `wandb login --relogin` to force relogin)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
  rank_zero_warn(
  rank_zero_warn(

  | Name    | Type     | Params
-------------------------------------
0 | txt_emb | EmbModel | 2.1 M 
1 | img_emb | EmbModel | 2.1 M 
2 | shared  | Linear   | 1.0 M 
-------------------------------------
5.2 M     Trainable params
0         Non-trainable params
5.2 M     Total params
20.992    Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


                                                              

  rank_zero_warn(
Global seed set to 0


Epoch 6: 100%|██████████| 669/669 [00:09<00:00, 70.53it/s, loss=0.000164, v_num=temp]


In [12]:
import pandas as pd
# Metrics file written by the CSV logger for this run. The path is derived
# from cfg so it follows the run's artifacts directory and version, and it
# gets its own name so `met` only ever refers to the DataFrame.
metrics_path = os.path.join(cfg.artifacts_loc, 'csvlogs', str(cfg.version), 'metrics.csv')
met = pd.read_csv(metrics_path)

In [13]:
# Logged training loss with missing entries dropped (presumably rows that
# only carry other metrics - verify against metrics.csv).
met['train_loss'].dropna()

0     0.001964
1     0.001240
2     0.001516
3     0.002234
4     0.001287
        ...   
78    0.000196
79    0.000360
80    0.000340
81    0.000098
82    0.000325
Name: train_loss, Length: 77, dtype: float64

In [14]:
# m = RecipeModel(cfg)
# dl = DataLoader(trn_data,2,shuffle=False)
# batch = next(iter(dl))
# out = m((batch[0],batch[1]))
# t_img,t_txt = out[0].detach(),out[1].detach()
# img,txt,anch = t_img,t_txt,batch[0]