In [None]:
!curl https://raw.githubusercontent.com/Ojasvi-97/DL_Final_Project_Spring_22/main/balanced_data.csv --output data.csv
# get images from google drive

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 8829k  100 8829k    0     0  17.1M      0 --:--:-- --:--:-- --:--:-- 17.1M


In [None]:
# !pip install --quiet --user "torch>=1.6, <1.9" "lightning-bolts" "pytorch-lightning>=1.3" "torchmetrics>=0.3" "torchvision" "torchtext"

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
!unzip gdrive/My\ Drive/Colab\ Notebooks/balanced_images_crop.zip > /dev/null

In [4]:
import os

for n in os.listdir("balanced_images_crop"):
  print(n)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
0892455003.jpg
0631473001.jpg
0796483006.jpg
0615797002.jpg
0579010042.jpg
0633529001.jpg
0674711006.jpg
0623374008.jpg
0733782002.jpg
0896557002.jpg
0568732008.jpg
0584284001.jpg
0662704001.jpg
0126589011.jpg
0378447033.jpg
0474915002.jpg
0820097001.jpg
0614316002.jpg
0879529002.jpg
0715828029.jpg
0617350002.jpg
0789629001.jpg
0720813002.jpg
0599502017.jpg
0693754002.jpg
0724392001.jpg
0722923002.jpg
0741433001.jpg
0554450019.jpg
0408875030.jpg
0610962001.jpg
0680214001.jpg
0682314002.jpg
0542138001.jpg
0918924001.jpg
0834058001.jpg
0588966001.jpg
0676090002.jpg
0756379003.jpg
0743079001.jpg
0920528002.jpg
0533404017.jpg
0646756001.jpg
0633136001.jpg
0815549002.jpg
0868283004.jpg
0846772002.jpg
0824764012.jpg
0876525001.jpg
0842066003.jpg
0665188002.jpg
0824358001.jpg
0607372005.jpg
0777728001.jpg
0839083002.jpg
0453358003.jpg
0735607003.jpg
0711761001.jpg
0747563005.jpg
0903518001.jpg
0886557004.jpg
0846270001.jpg
07542

KeyboardInterrupt: ignored

In [None]:
import os
import pandas as pd
from PIL import Image
import random

from torchvision.io import read_image
from torchvision import transforms
from torchvision.transforms import ToTensor, Lambda

import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import torch.nn as nn
from torch.nn.functional import cosine_similarity

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import OneCycleLR

seed_everything(7)
NUM_WORKERS = int(os.cpu_count() / 2)
AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64

In [None]:
def create_splits(csv_file):
    df = pd.read_csv(csv_file)
    
    all_data = df.values.tolist()
    validation_size = int(0.2*len(all_data))
   
    validation_split = random.sample(all_data, validation_size)
    try:
        for x in validation_split:
            all_data.remove(x)
    except:
        pass
    
    validation_df = pd.DataFrame(validation_split, columns = ["index","article_id", "index_group_name", "detail_desc", "location"]).reset_index()
    train_df = pd.DataFrame(all_data, columns = ["index","article_id", "index_group_name", "detail_desc", "location"]).reset_index()
    
    return train_df, validation_df

In [None]:
class HMDataset(Dataset):
    def __init__(self, data, image_dir):
        self.df = data
        self.image_dir = image_dir
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # preprocess function for text


    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        
        ref_data = self.df.iloc[idx]
#         ref_image = read_image(os.path.join(self.image_dir,*ref_data["location"].split("/")))
        ref_image = Image.open(os.path.join(self.image_dir,*ref_data["location"].split("/")))
        ref_image = self.preprocess(ref_image)
        ref_text = ref_data["detail_desc"]
        
        pos_data = pd.DataFrame(self.df[self.df["index_group_name"] == ref_data["index_group_name"]].sample(n=1))
        pos_data = pos_data.iloc[0]
        pos_image = Image.open(os.path.join(self.image_dir,*pos_data["location"].split("/")))
        pos_image = self.preprocess(pos_image)
        pos_text = pos_data["detail_desc"]
        
        neg_data = pd.DataFrame(self.df[self.df["index_group_name"] != ref_data["index_group_name"]].sample(n=1))
        neg_data = neg_data.iloc[0]
        neg_image = Image.open(os.path.join(self.image_dir,*neg_data["location"].split("/")))
        neg_image = self.preprocess(neg_image)
        neg_text = neg_data["detail_desc"]
        
        attributes = {
            "ref_image": ref_image,
            "ref_text": ref_text,
            "pos_image": pos_image,
            "pos_text": pos_text,
            "neg_image": neg_image,
            "neg_text": neg_text
        }
        
        
        
        return attributes

In [None]:
class HMDataModule(pl.LightningDataModule):
    def __init__(self, csv_file="", img_dir="", batch_size= 64):
        super().__init__()
        self.batch_size = batch_size
        self.img_dir = img_dir
        self.train_df, self.val_df = create_splits(csv_file)
        

    def setup(self, stage=None):
        self.train_data = HMDataset(data=self.train_df,image_dir=self.img_dir)
        self.val_data = HMDataset(data=self.val_df,image_dir=self.img_dir)
        

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=NUM_WORKERS)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=NUM_WORKERS)

47211

In [None]:
class VisualEmbd(nn.Module):
    # resnet last 2 layers
    # dense layers
    # similarity
    def __init__(self):
        super().__init__()
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
        self.af = nn.ReLU()
        self.last_layer = nn.Linear(1000, 256)

        count=0
        for child in self.model.children():
            count+=1
            if count < 7:
                for param in child.parameters():
                    param.requires_grad = False
    
    

    def forward(self, x):
        x = self.model(x)
        x = self.last_layer(self.af(x))
        return x

in
Unnamed: 0                                                          0
article_id                                                  145872001
index_group_name                                                Sport
detail_desc         Long-sleeved sports top in fast-drying, breath...
location                                           014/0145872001.jpg
Name: 0, dtype: object


In [None]:
class LitVisualEmbd(LightningModule):
    def __init__(self, lr=0.01, margin=1.0):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisualEmbd()
        self.margin = margin
        self.lr = lr
        
    def forward(self, x):
        x = self.model(x)
        return x
    
    def training_step(self, batch, batch_idx):
        return self.evaluate(batch, batch_idx, stage="train")
    
    def evaluate(self, batch, batch_idx, stage=None):
        ref_batch = batch["ref_image"]
        pos_batch = batch["pos_image"]
        neg_batch = batch["neg_image"]
        
        ref_embd = self(ref_batch)
        pos_embd = self(pos_batch)
        neg_embd = self(neg_batch)
        
        pos_sim = cosine_similarity(ref_embd, pos_embd, dim=1)
        neg_sim = cosine_similarity(ref_embd, neg_embd, dim=1)
        
        loss = torch.relu(neg_sim - pos_sum + self.margin)
        
        if stage:
            self.log(f"{stage}_loss", loss.mean(), prog_bar=True)
#             self.log(f"{stage}_acc", acc, prog_bar=True)

        return loss.mean()

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, batch_idx, "val")
        
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
#         steps_per_epoch = 40000 // BATCH_SIZE
        return {"optimizer": optimizer}

        


In [None]:
hm_dm = HMDataModule(
    csv_file="balanced_data.csv",
    img_dir=mypath
)

In [None]:
model = LitVisualEmbd()
model.datamodule = hm_dm

trainer = Trainer(
          progress_bar_refresh_rate=10,
          gpus=AVAIL_GPUS,
          max_epochs=10,
      )
trainer.fit(model, hm_dm)