In [36]:
#Imports
import os
import pickle
import random
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from argparse import Namespace
import torch.nn as nn
import torch
from torch import linalg as LA

# Global paths: every embedding pickle lives directly under DATA_DIR.
DATA_DIR = '/common/home/vk405/Projects/Crossmdl/Data/Recipe/'
# combined ("total") recipe embeddings
EMB_TRN = os.path.join(DATA_DIR, 'embeddings_train1.pkl')
EMB_VAL = os.path.join(DATA_DIR, 'embeddings_val1.pkl')
# ingredient-only embeddings
ING_TRN = os.path.join(DATA_DIR, 'ingredients_embeddings_train.pkl')
ING_VAL = os.path.join(DATA_DIR, 'ingredients_embeddings_val.pkl')
#os.listdir(DATA_DIR)

In [5]:

# with open(EMB_TRN, 'rb') as files:
#     emb_trn = pickle.load(files)
# with open(EMB_VAL, 'rb') as files:
#     emb_val = pickle.load(files)

# with open(ING_TRN, 'rb') as files:
#     ing_trn = pickle.load(files)
# with open(ING_VAL, 'rb') as files:
#     ing_val = pickle.load(files)



In [2]:
class RecipeDset(Dataset):
    """Triplet dataset over precomputed recipe image ("vid") and text embeddings.

    Each split is a pickle holding a tuple ``(emb_vid, emb_txt, ids)``. Item ``idx``
    is ``(emb_vid[idx], emb_txt[idx], emb_vid[neg])``, where ``neg`` is a uniformly
    random index different from ``idx``. That third element is a mismatched
    (negative) image for the triplet loss.

    Args:
        data_dir: directory holding the embedding pickles. File names are appended
            by plain string concatenation, so it must end with a path separator.
        split: ``'train'``, ``'valid'`` or ``'test'``.
        txt_emb_type: which text embedding to load; only ``'total'`` is implemented.

    Raises:
        NotImplementedError: if ``txt_emb_type`` is not ``'total'``.
        ValueError: if ``split`` is not one of the supported splits.
    """

    def __init__(self, data_dir='/common/home/vk405/Projects/Crossmdl/Data/Recipe/',
                 split='train', txt_emb_type='total'):
        self.DATA_DIR = data_dir
        self.init_data_locs()
        self.txt_emb_type = txt_emb_type
        self.split = split
        if self.txt_emb_type != 'total':
            # Paths for the other embedding types exist (see init_data_locs), but no
            # loader is wired up for them. Fail here instead of leaving the embedding
            # attributes unset, which would only surface later as an AttributeError.
            raise NotImplementedError(
                f"txt_emb_type={self.txt_emb_type!r} is not supported; use 'total'")
        split_paths = {'train': self.EMB_TRN, 'valid': self.EMB_VAL, 'test': self.EMB_TST}
        if self.split not in split_paths:
            raise ValueError(
                f"split must be one of {sorted(split_paths)}, got {self.split!r}")
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(split_paths[self.split], 'rb') as files:
            self.emb_vid, self.emb_txt, self.ids = pickle.load(files)

    def __getitem__(self, idx):
        """Return ``(emb_vid[idx], emb_txt[idx], emb_vid[neg])`` with ``neg != idx``.

        ``neg`` is drawn uniformly from the other ``len(self) - 1`` indices.

        Raises:
            ValueError: if the dataset has fewer than two items, so no negative exists.
        """
        n = len(self.emb_txt)
        if n < 2:
            raise ValueError("RecipeDset needs at least two items to sample a negative")
        pos = idx % n  # normalise negative indices for the exclusion test only
        # Sample from [0, n-1) and shift past `pos`. This excludes the positive
        # index in O(1), with no rejection loop.
        neg = np.random.randint(n - 1)
        if neg >= pos:
            neg += 1
        return self.emb_vid[idx], self.emb_txt[idx], self.emb_vid[neg]

    def __len__(self):
        """Number of (image, text) pairs in the split."""
        return len(self.emb_txt)

    def init_data_locs(self):
        """Build the pickle paths for every embedding type and split under ``self.DATA_DIR``."""
        # TOTAL embeddings (the only type __init__ loads)
        self.EMB_TRN = self.DATA_DIR + 'embeddings_train1.pkl'
        self.EMB_VAL = self.DATA_DIR + 'embeddings_val1.pkl'
        self.EMB_TST = self.DATA_DIR + 'embeddings_test1.pkl'
        # Ingredient embeddings
        self.ING_TRN = self.DATA_DIR + 'ingredients_embeddings_train.pkl'
        self.ING_VAL = self.DATA_DIR + 'ingredients_embeddings_val.pkl'
        self.ING_TST = self.DATA_DIR + 'ingredients_embeddings_test.pkl'

        # Title embeddings
        self.TIT_TRN = self.DATA_DIR + 'title_embeddings_train.pkl'
        self.TIT_VAL = self.DATA_DIR + 'title_embeddings_val.pkl'
        self.TIT_TST = self.DATA_DIR + 'title_embeddings_test.pkl'

        # Instruction embeddings
        self.INS_TRN = self.DATA_DIR + 'instructions_embeddings_train.pkl'
        self.INS_VAL = self.DATA_DIR + 'instructions_embeddings_val.pkl'
        self.INS_TST = self.DATA_DIR + 'instructions_embeddings_test.pkl'


        

In [3]:
# Training split, read from the notebook-level DATA_DIR rather than from a second
# hard-coded copy of the path (the RecipeDset default).
trn_data = RecipeDset(data_dir=DATA_DIR, split='train', txt_emb_type='total')

In [50]:
class EmbModel(nn.Module):
    """Feed-forward tower of ``params['lyrs']`` (Linear, activation) pairs.

    ``params`` keys:
        input_dim: width of the input features.
        lyrs: number of (Linear, activation) pairs; 0 gives an identity mapping.
        fin_dim: output width of the last Linear; defaults to ``input_dim``.
        act: activation name, one of ``'relu'``, ``'gelu'``, ``'tanh'``.

    Hidden layers keep width ``input_dim``; only the last Linear maps to ``fin_dim``.
    The activation is also applied after the last Linear.

    Raises:
        NotImplementedError: for an unknown ``act`` when ``lyrs > 0``.
        ValueError: if ``lyrs == 0`` but ``fin_dim != input_dim`` (nothing can
            change the width).
    """

    _ACTIVATIONS = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh}

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.lyrs = []
        n_layers = self.params['lyrs']
        in_dim = self.params['input_dim']
        # fin_dim used to be ignored, so the output width was always input_dim even
        # though downstream layers are sized from fin_dim.
        out_dim = self.params.get('fin_dim', in_dim)
        if n_layers > 0 and self.params['act'] not in self._ACTIVATIONS:
            raise NotImplementedError(f"unsupported activation {self.params['act']!r}")
        if n_layers == 0 and out_dim != in_dim:
            raise ValueError("lyrs=0 cannot map input_dim to a different fin_dim")
        for i in range(n_layers):
            layer_out = out_dim if i == n_layers - 1 else in_dim
            self.lyrs.append(nn.Linear(in_dim, layer_out))
            self.lyrs.append(self._ACTIVATIONS[self.params['act']]())
        self.feedforward = nn.Sequential(*self.lyrs)

    def forward(self, x):
        """Apply the tower to the last dimension of ``x``."""
        return self.feedforward(x)





class RecipeModel(pl.LightningModule):
    """Maps image and text embeddings into a shared space, trained with a cosine triplet loss.

    Each modality passes through its own ``EmbModel`` tower and then through one
    ``nn.Linear`` (``self.shared``) whose weights are shared by both modalities.

    Hyperparameters read from ``hparams``:
        txt_model, img_model: ``EmbModel`` param dicts. ``txt_model['fin_dim']``
            sets the input width of the shared projection.
        shared_emb_dim: output width of the shared projection.
        lr or learning_rate (optional): Adam learning rate, default 1e-3.
        eps (optional): triplet margin, default 1e-8.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.txt_emb = EmbModel(hparams.txt_model)
        self.img_emb = EmbModel(hparams.img_model)
        self.shared = nn.Linear(hparams.txt_model['fin_dim'], hparams.shared_emb_dim)

    def forward(self, x):
        """Map ``(img, txt)`` to shared-space embeddings ``(img_fin_emb, txt_fin_emb)``.

        The negative image of a triplet is not part of ``x``; see ``_triplet_step``.
        """
        img, txt = x
        img_fin_emb = self.shared(self.img_emb(img))
        txt_fin_emb = self.shared(self.txt_emb(txt))
        return img_fin_emb, txt_fin_emb

    def _triplet_step(self, batch):
        """Shared train/val logic: embed an ``(img, txt, neg_img)`` batch and return the loss."""
        img, txt, neg_img = batch
        neg_img_fin_emb = self.shared(self.img_emb(neg_img))
        img_fin_emb, txt_fin_emb = self((img, txt))
        return self.get_loss(img_fin_emb, txt_fin_emb, neg_img_fin_emb)

    def training_step(self, batch, batch_idx):
        loss = self._triplet_step(batch)
        self.log("train_loss", loss, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._triplet_step(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True)
        return loss

    def get_loss(self, img, txt, anch, reduce='mean'):
        """Cosine triplet (hinge) loss, with the text embedding as the anchor.

        Per sample: ``max(cos(anch, txt) - cos(img, txt) + eps, 0)``. The matching
        image ``img`` must be more cosine-similar to ``txt`` than the mismatched
        image ``anch`` by at least the margin ``eps``. Despite its name, ``anch``
        is the negative.

        Args:
            img: positive image embeddings, shape (..., dim).
            txt: text embeddings, shape (..., dim).
            anch: negative image embeddings, shape (..., dim).
            reduce: ``'mean'`` (default), ``'sum'``, or ``'none'`` for per-sample losses.

        Raises:
            ValueError: for an unknown ``reduce`` (previously this returned None).
        """
        eps = self.hparams.eps if 'eps' in self.hparams else 1e-8
        # F.normalize clamps the norm from below. An all-zero embedding therefore
        # gives a cosine of 0 instead of NaN from dividing by a zero norm.
        normd_img = F.normalize(img, dim=-1)
        normd_txt = F.normalize(txt, dim=-1)
        normd_anch = F.normalize(anch, dim=-1)

        cos_sim_p = torch.sum(normd_img * normd_txt, dim=-1)
        cos_sim_n = torch.sum(normd_anch * normd_txt, dim=-1)

        clipped_loss = torch.relu(cos_sim_n - cos_sim_p + eps)
        if reduce == 'mean':
            return torch.mean(clipped_loss)
        if reduce == 'sum':
            return torch.sum(clipped_loss)
        if reduce == 'none':
            return clipped_loss
        raise ValueError(f"reduce must be 'mean', 'sum' or 'none', got {reduce!r}")

    def configure_optimizers(self):
        """Adam over all parameters, with the learning rate taken from the hyperparameters."""
        # Accept `learning_rate` as well as `lr`: the cfg Namespace in this notebook
        # sets `learning_rate`, and reading only `lr` silently fell back to 1e-3.
        if 'lr' in self.hparams:
            lr = self.hparams.lr
        elif 'learning_rate' in self.hparams:
            lr = self.hparams.learning_rate
        else:
            lr = 1e-3
        return torch.optim.Adam(self.parameters(), lr=lr)




        

In [52]:

# Experiment configuration. RecipeModel reads txt_model, img_model and
# shared_emb_dim (plus some optional keys) from it.
# NOTE(review): keys such as edim, attnhdim, nheads, wrdim, vidim, hdim, dropoutp,
# seqlen and framecnt are not read by any code visible in this notebook. They may be
# leftovers from another model; confirm before pruning.
cfg = Namespace(**{
    # run bookkeeping
    'version': 'trail',
    'artifacts_loc': "/common/home/vk405/Projects/Crossmdl/nbs/Recipe",
    'data_dir': "/common/home/vk405/Projects/Crossmdl/Data/Recipe/",
    'mode': 'train',
    # model architecture
    'txt_model': {'input_dim': 1024, 'lyrs': 2, 'fin_dim': 1024, 'act': 'relu'},
    'img_model': {'input_dim': 1024, 'lyrs': 2, 'fin_dim': 1024, 'act': 'relu'},
    'shared_emb_dim': 1024,
    'txt_emb_type': 'total',
    # optimisation / logging
    'learning_rate': 1e-4,
    'loggers': ["csv", "wandb"],
    'seed': 0,
    'cbs': ["checkpoint"],
    'trainer': {'log_every_n_steps': 1, 'max_epochs': 10},
    'checkpoint': {"every_n_epochs": 1, "monitor": "train_loss"},
    # misc
    'use_precomp_emb': True,
    'edim': 100,
    'attnhdim': 50,
    'nheads': 10,
    'wrdim': 768,
    'vidim': 512,
    'hdim': 30,
    'dropoutp': 0.0,
    'seqlen': 26,
    'framecnt': 499,
    'batch_size': 512,
})

In [6]:
# m = RecipeModel(cfg)
# dl = DataLoader(trn_data,2,shuffle=False)
# batch = next(iter(dl))
# out = m((batch[0],batch[1]))
# t_img,t_txt = out[0].detach(),out[1].detach()
# img,txt,anch = t_img,t_txt,batch[0]

In [7]:
# NOTE(review): EmbModel is also defined in an earlier cell; this later definition
# shadows it on Run All.
class EmbModel(nn.Module):
    """Feed-forward tower of ``params['lyrs']`` (Linear, activation) pairs.

    ``params`` keys:
        input_dim: width of the input features.
        lyrs: number of (Linear, activation) pairs; 0 gives an identity mapping.
        fin_dim: output width of the last Linear; defaults to ``input_dim``.
        act: activation name, one of ``'relu'``, ``'gelu'``, ``'tanh'``.

    Hidden layers keep width ``input_dim``; only the last Linear maps to ``fin_dim``.
    The activation is also applied after the last Linear.

    Raises:
        NotImplementedError: for an unknown ``act`` when ``lyrs > 0``.
        ValueError: if ``lyrs == 0`` but ``fin_dim != input_dim`` (nothing can
            change the width).
    """

    _ACTIVATIONS = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh}

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.lyrs = []
        n_layers = self.params['lyrs']
        in_dim = self.params['input_dim']
        # fin_dim used to be ignored, so the output width was always input_dim even
        # though downstream layers are sized from fin_dim.
        out_dim = self.params.get('fin_dim', in_dim)
        if n_layers > 0 and self.params['act'] not in self._ACTIVATIONS:
            raise NotImplementedError(f"unsupported activation {self.params['act']!r}")
        if n_layers == 0 and out_dim != in_dim:
            raise ValueError("lyrs=0 cannot map input_dim to a different fin_dim")
        for i in range(n_layers):
            layer_out = out_dim if i == n_layers - 1 else in_dim
            self.lyrs.append(nn.Linear(in_dim, layer_out))
            self.lyrs.append(self._ACTIVATIONS[self.params['act']]())
        self.feedforward = nn.Sequential(*self.lyrs)

    def forward(self, x):
        """Apply the tower to the last dimension of ``x``."""
        return self.feedforward(x)


In [8]:
# Standalone tower for a quick shape check: 1024-d in, 1024-d out, two ReLU layers.
model = EmbModel(dict(input_dim=1024, lyrs=2, fin_dim=1024, act='relu'))

In [15]:
# Smoke test: one random 1024-d vector should map to a (1, 1024) embedding.
t = torch.randn(1, 1024)
with torch.no_grad():  # a shape check needs no autograd graph
    out = model(t)
assert out.shape == (1, 1024), out.shape
out.shape

torch.Size([1, 1024])

In [9]:

# Experiment configuration. This cell rebinds `cfg` to the same values as the cfg
# cell earlier in the notebook. RecipeModel reads txt_model, img_model and
# shared_emb_dim (plus some optional keys) from it.
# NOTE(review): keys such as edim, attnhdim, nheads, wrdim, vidim, hdim, dropoutp,
# seqlen and framecnt are not read by any code visible in this notebook. They may be
# leftovers from another model; confirm before pruning.
cfg = Namespace(**{
    # run bookkeeping
    'version': 'trail',
    'artifacts_loc': "/common/home/vk405/Projects/Crossmdl/nbs/Recipe",
    'data_dir': "/common/home/vk405/Projects/Crossmdl/Data/Recipe/",
    'mode': 'train',
    # model architecture
    'txt_model': {'input_dim': 1024, 'lyrs': 2, 'fin_dim': 1024, 'act': 'relu'},
    'img_model': {'input_dim': 1024, 'lyrs': 2, 'fin_dim': 1024, 'act': 'relu'},
    'shared_emb_dim': 1024,
    'txt_emb_type': 'total',
    # optimisation / logging
    'learning_rate': 1e-4,
    'loggers': ["csv", "wandb"],
    'seed': 0,
    'cbs': ["checkpoint"],
    'trainer': {'log_every_n_steps': 1, 'max_epochs': 10},
    'checkpoint': {"every_n_epochs": 1, "monitor": "train_loss"},
    # misc
    'use_precomp_emb': True,
    'edim': 100,
    'attnhdim': 50,
    'nheads': 10,
    'wrdim': 768,
    'vidim': 512,
    'hdim': 30,
    'dropoutp': 0.0,
    'seqlen': 26,
    'framecnt': 499,
    'batch_size': 512,
})