In [1]:
#  pip install --upgrade torch torchvision torchaudio

In [2]:
# pip install --quiet --user "torch>=1.6, <1.9" "lightning-bolts" "pytorch-lightning>=1.3" "torchmetrics>=0.3" "torchvision" "torchtext"

In [3]:
import os
import random

import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from PIL import Image
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, TQDMProgressBar
from pytorch_lightning.loggers import WandbLogger
from torch.nn.functional import cosine_similarity
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.io import read_image
from torchvision.transforms import Lambda, ToTensor

In [4]:
# Root directory of the product images. Set the HM_IMAGE_DIR environment variable
# to point elsewhere instead of editing the per-user default below.
mypath = os.environ.get(
    "HM_IMAGE_DIR",
    os.path.join("c:", os.sep, "Users", "ojubh", "Desktop", "images"),
)

In [5]:
seed_everything(7)

# os.cpu_count() returns None when the CPU count cannot be determined.
_CPU_COUNT = os.cpu_count() or 1
# On Windows, DataLoader workers are started with "spawn" and must re-import the
# Dataset class from __main__; classes defined in notebook cells are not
# importable that way, so multi-process loading fails or hangs. Load in-process there.
NUM_WORKERS = 0 if os.name == "nt" else _CPU_COUNT // 2

Global seed set to 7


In [6]:
def create_splits(csv_file, val_fraction=0.2, columns=None):
    """Randomly split the rows of a CSV file into train and validation frames.

    Sampling uses Python's ``random`` module, so the split is reproducible once
    ``random`` has been seeded (e.g. via ``seed_everything``).

    Args:
        csv_file: Path of the CSV to read with ``pd.read_csv``.
        val_fraction: Fraction of rows put in the validation split; the
            validation size is ``int(val_fraction * n_rows)``.
        columns: Column names assigned to both returned frames, replacing the
            CSV header. Defaults to
            ["index", "article_id", "index_group_name", "detail_desc", "location"].

    Returns:
        ``(train_df, validation_df)``: disjoint frames that together hold every
        CSV row exactly once, each with a fresh 0..n-1 index. Train rows keep
        file order; validation rows are in sampled order.

    Raises:
        ValueError: if the CSV's column count differs from ``len(columns)``, or
            ``val_fraction`` yields a sample size outside ``[0, n_rows]``.
    """
    if columns is None:
        columns = ["index", "article_id", "index_group_name", "detail_desc", "location"]

    df = pd.read_csv(csv_file)
    # Replace the header with the expected names (length mismatch -> ValueError).
    df.columns = list(columns)

    n_rows = len(df)
    n_val = int(val_fraction * n_rows)
    # Sample row positions instead of row values: removing sampled rows from a
    # list by value was O(n^2). random.sample over range(n) picks the same
    # positions as sampling the row list itself, so seeded splits are unchanged.
    val_positions = random.sample(range(n_rows), n_val)
    val_set = set(val_positions)
    train_positions = [pos for pos in range(n_rows) if pos not in val_set]

    train_df = df.iloc[train_positions].reset_index(drop=True)
    validation_df = df.iloc[val_positions].reset_index(drop=True)
    return train_df, validation_df
    


In [7]:

class HMDataset(Dataset):
    """Triplet dataset of H&M product images.

    Each item combines a reference row with a randomly drawn positive (another
    row with the same ``index_group_name``; the reference itself only when it is
    the sole member of its group) and a randomly drawn negative (a row with a
    different ``index_group_name``). Images are read from ``image_dir`` joined
    with the row's "/"-separated ``location``.

    Args:
        data: DataFrame with at least the columns "location", "detail_desc" and
            "index_group_name". Rows are addressed by position (``iloc``).
        image_dir: Root directory that the ``location`` paths are relative to.
    """

    def __init__(self, data, image_dir):
        self.df = data
        self.image_dir = image_dir
        # Resize/crop to 224x224 and normalize with ImageNet channel statistics.
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self._index_groups()

    def _index_groups(self):
        """Precompute row positions per group so sampling does not rescan the frame."""
        labels = self.df["index_group_name"].tolist()
        positions_by_group = {}
        for pos, label in enumerate(labels):
            positions_by_group.setdefault(label, []).append(pos)
        self._labels = labels
        self._positions_by_group = positions_by_group

    def __len__(self):
        return len(self.df)

    def _sample_positive(self, idx):
        """Return a random row position in ``idx``'s group other than ``idx`` itself.

        Falls back to ``idx`` when it is the only member of its group.
        """
        candidates = self._positions_by_group[self._labels[idx]]
        if len(candidates) == 1:
            return candidates[0]
        while True:
            # With >= 2 candidates at least half differ from idx, so this ends
            # after about two draws on average.
            pos = random.choice(candidates)
            if pos != idx:
                return pos

    def _sample_negative(self, idx):
        """Return a uniformly random row position whose group differs from ``idx``'s.

        Raises:
            ValueError: if every row belongs to ``idx``'s group.
        """
        ref_group = self._labels[idx]
        other_groups = [g for g in self._positions_by_group if g != ref_group]
        if not other_groups:
            raise ValueError(
                f"Cannot sample a negative: every row belongs to group {ref_group!r}."
            )
        # Choosing a group with probability proportional to its size, then a row
        # inside it, is a uniform draw over all rows outside ref_group.
        sizes = [len(self._positions_by_group[g]) for g in other_groups]
        neg_group = random.choices(other_groups, weights=sizes)[0]
        return random.choice(self._positions_by_group[neg_group])

    def _load_image(self, location):
        """Open ``image_dir/location`` and return the preprocessed image tensor."""
        path = os.path.join(self.image_dir, *location.split("/"))
        # The context manager closes the file handle after decoding; converting
        # to RGB makes grayscale/RGBA files match the 3-channel Normalize.
        with Image.open(path) as img:
            return self.preprocess(img.convert("RGB"))

    @staticmethod
    def _text(value):
        """Return the description, mapping missing values (NaN/None) to ""."""
        # A float NaN mixed into a batch of strings breaks default collation.
        return "" if pd.isna(value) else value

    def __getitem__(self, idx):
        """Return a dict with ref/pos/neg image tensors and their descriptions."""
        # Python's `random` (unlike numpy's global RNG in older PyTorch) is
        # reseeded per DataLoader worker, so workers draw different samples.
        ref_data = self.df.iloc[idx]
        pos_data = self.df.iloc[self._sample_positive(idx)]
        neg_data = self.df.iloc[self._sample_negative(idx)]

        return {
            "ref_image": self._load_image(ref_data["location"]),
            "ref_text": self._text(ref_data["detail_desc"]),
            "pos_image": self._load_image(pos_data["location"]),
            "pos_text": self._text(pos_data["detail_desc"]),
            "neg_image": self._load_image(neg_data["location"]),
            "neg_text": self._text(neg_data["detail_desc"]),
        }

In [8]:
class HMDataModule(pl.LightningDataModule):
    """LightningDataModule serving HMDataset triplets for training and validation.

    The CSV is split into train/validation frames once, at construction time,
    by ``create_splits``.

    Args:
        csv_file: Path of the metadata CSV passed to ``create_splits``.
        img_dir: Root directory of the product images.
        batch_size: Batch size used by both dataloaders.
        num_workers: DataLoader worker processes. ``None`` (default) uses the
            notebook-level ``NUM_WORKERS`` at dataloader-creation time; pass 0 to
            load in the main process.
    """

    def __init__(self, csv_file="", img_dir="", batch_size=64, num_workers=None):
        super().__init__()
        self.batch_size = batch_size
        self.img_dir = img_dir
        self.num_workers = num_workers
        self.train_df, self.val_df = create_splits(csv_file)

    def _resolved_workers(self):
        """Worker count to use: the explicit setting, else the global NUM_WORKERS."""
        return NUM_WORKERS if self.num_workers is None else self.num_workers

    def setup(self, stage=None):
        self.train_data = HMDataset(data=self.train_df, image_dir=self.img_dir)
        self.val_data = HMDataset(data=self.val_df, image_dir=self.img_dir)

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self._resolved_workers(),
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_data,
            batch_size=self.batch_size,
            num_workers=self._resolved_workers(),
        )

In [9]:

class VisualEmbd(nn.Module):
    """Image encoder: pretrained ResNet-18, then ReLU and a linear projection.

    The ResNet-18 is loaded from ``torch.hub`` (pytorch/vision v0.10.0) with
    pretrained weights. Its 1000-dim output goes through ReLU and is projected
    to ``embedding_dim`` features.

    Args:
        embedding_dim: Size of the output embedding (default 256).
        num_frozen_children: Number of leading top-level child modules of the
            ResNet whose parameters get ``requires_grad=False``. For torchvision's
            ResNet-18 the default 6 covers conv1, bn1, relu, maxpool, layer1 and
            layer2.
    """

    def __init__(self, embedding_dim=256, num_frozen_children=6):
        super().__init__()
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
        self.af = nn.ReLU()
        self.last_layer = nn.Linear(1000, embedding_dim)

        # Freeze the early feature extractors; later blocks stay trainable.
        # NOTE: frozen BatchNorm layers still update their running statistics
        # while the module is in train mode.
        for child in list(self.model.children())[:num_frozen_children]:
            for param in child.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Map a batch of preprocessed images to ``embedding_dim`` features."""
        features = self.model(x)
        return self.last_layer(self.af(features))

In [10]:
class LitVisualEmbd(LightningModule):
    """Lightning module training VisualEmbd with a cosine-similarity triplet loss.

    For each (reference, positive, negative) image triplet the per-sample loss
    is ``relu(cos(ref, neg) - cos(ref, pos) + margin)``; the batch mean is
    optimized and logged as ``{stage}_loss``.

    Args:
        lr: SGD learning rate (read back from ``self.hparams.lr``).
        margin: Required gap between positive and negative cosine similarity.
    """

    def __init__(self, lr=0.01, margin=1.0):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisualEmbd()
        self.margin = margin
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        return self.evaluate(batch, batch_idx, stage="train")

    def evaluate(self, batch, batch_idx, stage=None):
        """Compute (and, when ``stage`` is given, log) the mean triplet loss of a batch."""
        ref_embd = self(batch["ref_image"])
        pos_embd = self(batch["pos_image"])
        neg_embd = self(batch["neg_image"])

        pos_sim = cosine_similarity(ref_embd, pos_embd, dim=1)
        neg_sim = cosine_similarity(ref_embd, neg_embd, dim=1)

        # Hinge on the similarity gap: zero once the positive is at least
        # `margin` more similar to the reference than the negative is.
        loss = torch.relu(neg_sim - pos_sim + self.margin).mean()

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, batch_idx, "val")

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
        return {"optimizer": optimizer}

        


In [11]:
# Data module over balanced_data.csv; image paths are resolved under mypath.
hm_dm = HMDataModule(csv_file="balanced_data.csv", img_dir=mypath)

In [None]:
model = LitVisualEmbd()
model.datamodule = hm_dm

# `progress_bar_refresh_rate` is deprecated (removed in Lightning 1.7);
# the refresh rate is configured on the TQDMProgressBar callback instead.
trainer = Trainer(
    callbacks=[TQDMProgressBar(refresh_rate=10)],
    max_epochs=10,
)
trainer.fit(model, hm_dm)

Using cache found in C:\Users\ojubh/.cache\torch\hub\pytorch_vision_v0.10.0
  rank_zero_deprecation(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | VisualEmbd | 11.9 M
-------------------------------------
11.3 M    Trainable params
683 K     Non-trainable params
11.9 M    Total params
47.783    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]