In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
from tqdm import tqdm
from PIL import Image
import pickle
import logging
import os

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Paths
df_path = 'C:/Users/Aniruddha shinde/DL Project/Local/preprocessed_fashion_data.pkl'
user_item_path = 'C:/Users/Aniruddha shinde/DL Project/Local/user_item_interactions.pkl'
new_image_dir = 'C:/Users/Aniruddha shinde/DL Project/Local/images'
save_dir = 'C:/Users/Aniruddha shinde/DL Project/Local/clip_fashion_model'

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load data
df = pd.read_pickle(df_path)
with open(user_item_path, 'rb') as f:
    user_item_data = pickle.load(f)

# Convert user_item_data to DataFrame and aggregate duplicates
rows = []
for user_id, interactions in user_item_data.items():
    for item_id, description, image_paths in interactions:
        rows.append({'user_id': user_id, 'item_id': item_id, 'description': description, 'image_paths': image_paths})
user_item_df = pd.DataFrame(rows)
user_item_df = user_item_df.groupby(['user_id', 'item_id']).agg({
    'description': 'first',
    'image_paths': lambda x: list(set(sum(x, [])))
}).reset_index()

# Map user IDs to indices
user_mapping = {user_id: idx for idx, user_id in enumerate(user_item_df['user_id'].unique())}
user_item_df['user_idx'] = user_item_df['user_id'].map(user_mapping)

# Map item IDs to df indices using 'asin'
item_mapping = dict(zip(df['asin'], df.index))
user_item_df['item_idx'] = user_item_df['item_id'].map(item_mapping)

# Pre-check for missing images
def precheck_images(df, image_dir):
    missing_images = 0
    for idx, row in df.iterrows():
        image_path = row['image_paths'][0] if row['image_paths'] else None
        if not image_path or not isinstance(image_path, str) or not os.path.exists(image_path):
            missing_images += 1
            logger.warning(f"Missing or invalid image path at index {idx}: {image_path}")
    logger.info(f"Pre-check: Total missing or invalid images: {missing_images}")

logger.info("Starting pre-check for missing images...")
precheck_images(df, new_image_dir)

# Dataset for ITC fine-tuning
class FashionDataset(Dataset):
    def __init__(self, dataframe, image_dir, processor):
        self.dataframe = dataframe
        self.image_dir = image_dir
        self.processor = processor
        self.missing_images = 0
        self.missing_texts = 0

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image_path = row['image_paths'][0] if row['image_paths'] else None
        if not image_path or not isinstance(image_path, str) or not os.path.exists(image_path):
            image_path = f"{self.image_dir}/default.jpg"
            self.missing_images += 1
            if self.missing_images <= 10:
                logger.warning(f"Missing or invalid image path at index {idx}, using default: {image_path}")
        try:
            image = Image.open(image_path).convert('RGB')
        except Exception as e:
            logger.error(f"Error loading image {image_path} at index {idx}: {e}")
            self.missing_images += 1
            image = Image.new('RGB', (224, 224), color='gray')
        
        text = row.get('concatenated_text', row.get('title', 'No description available'))
        if not isinstance(text, str) or len(text.strip()) == 0:
            text = row.get('description', 'No description available')
            self.missing_texts += 1
            if self.missing_texts <= 10:
                logger.warning(f"Missing or empty text at index {idx}, using description: {text}")
        
        inputs = self.processor(
            text=[text],
            images=image,
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            truncation=True
        )
        
        return {
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0)
        }

    def log_data_issues(self):
        logger.info(f"Total missing or invalid images: {self.missing_images}")
        logger.info(f"Total missing or empty texts: {self.missing_texts}")

# Dataset for VBPR training
class VBPRDataset(Dataset):
    def __init__(self, user_item_df, item_df, processor):
        self.user_item_df = user_item_df
        self.item_df = item_df
        self.processor = processor
        self.items = list(range(len(item_df)))

    def __len__(self):
        return len(self.user_item_df)

    def __getitem__(self, idx):
        row = self.user_item_df.iloc[idx]
        user_idx = row['user_idx']
        pos_item_idx = row['item_idx']
        neg_item_idx = np.random.choice([i for i in self.items if i != pos_item_idx])
        
        pos_item = self.item_df.iloc[pos_item_idx]
        neg_item = self.item_df.iloc[neg_item_idx]
        
        pos_image_path = pos_item['image_paths'][0] if pos_item['image_paths'] else f"{new_image_dir}/default.jpg"
        neg_image_path = neg_item['image_paths'][0] if neg_item['image_paths'] else f"{new_image_dir}/default.jpg"
        
        try:
            pos_image = Image.open(pos_image_path).convert('RGB')
        except Exception as e:
            logger.error(f"Error loading pos image {pos_image_path} at index {idx}: {e}")
            pos_image = Image.new('RGB', (224, 224), color='gray')
        
        try:
            neg_image = Image.open(neg_image_path).convert('RGB')
        except Exception as e:
            logger.error(f"Error loading neg image {neg_image_path} at index {idx}: {e}")
            neg_image = Image.new('RGB', (224, 224), color='gray')
        
        pos_inputs = self.processor(images=pos_image, return_tensors="pt", padding="max_length", max_length=77, truncation=True)['pixel_values'].squeeze(0)
        neg_inputs = self.processor(images=neg_image, return_tensors="pt", padding="max_length", max_length=77, truncation=True)['pixel_values'].squeeze(0)
        
        return {
            'user_idx': torch.tensor(user_idx, dtype=torch.long),
            'pos_item_idx': torch.tensor(pos_item_idx, dtype=torch.long),
            'neg_item_idx': torch.tensor(neg_item_idx, dtype=torch.long),
            'pos_pixel_values': pos_inputs,
            'neg_pixel_values': neg_inputs
        }

# ITC Loss with temperature
class ITCLoss(nn.Module):
    def __init__(self, temperature=0.05):
        super(ITCLoss, self).__init__()
        self.temperature = temperature

    def forward(self, image_features, text_features, labels):
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)
        logits_per_image = (image_features @ text_features.T) / self.temperature
        logits_per_text = (text_features @ image_features.T) / self.temperature
        labels = torch.arange(logits_per_image.size(0), device=logits_per_image.device)
        loss_i = F.cross_entropy(logits_per_image, labels)
        loss_t = F.cross_entropy(logits_per_text, labels)
        return (loss_i + loss_t) / 2

# VBPR Model
class VBPR(nn.Module):
    def __init__(self, num_users, num_items, embedding_dim=100, visual_dim=512):
        super(VBPR, self).__init__()
        self.user_embeddings = nn.Embedding(num_users, embedding_dim)
        self.item_embeddings = nn.Embedding(num_items, embedding_dim)
        self.visual_embeddings = nn.Linear(visual_dim, embedding_dim)
        self.bias = nn.Parameter(torch.zeros(num_items))
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.user_embeddings.weight)
        nn.init.xavier_uniform_(self.item_embeddings.weight)
        nn.init.xavier_uniform_(self.visual_embeddings.weight)
        nn.init.zeros_(self.visual_embeddings.bias)

    def forward(self, user_idx, item_idx, image_features):
        user_emb = self.user_embeddings(user_idx)
        item_emb = self.item_embeddings(item_idx)
        batch_size = user_idx.size(0)
        image_features = image_features.unsqueeze(0).expand(batch_size, -1, -1)
        visual_emb = self.visual_embeddings(image_features)
        scores = (user_emb.unsqueeze(1) * (item_emb + visual_emb)).sum(-1) + self.bias[item_idx]
        return scores

# Load CLIP model and processor
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
model.train()

# Set differential learning rates and StepLR scheduler
optimizer_grouped_parameters = [
    {'params': model.text_model.parameters(), 'lr': 5e-4},
    {'params': model.vision_model.parameters(), 'lr': 1e-4}
]
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Prepare datasets
dataset = FashionDataset(df, new_image_dir, processor)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

vbpr_dataset = VBPRDataset(user_item_df, df, processor)
vbpr_loader = DataLoader(vbpr_dataset, batch_size=32, shuffle=True, num_workers=0)

# Fine-tune CLIP with ITC loss
itc_loss = ITCLoss(temperature=0.05)
num_epochs = 15  # Reduced to 15 epochs
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch in tqdm(dataloader, desc=f"ITC Fine-tuning Epoch {epoch+1}/{num_epochs}"):
        pixel_values = batch['pixel_values'].to(device)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
        image_features = outputs.image_embeds
        text_features = outputs.text_embeds

        labels = torch.arange(image_features.size(0), device=device)
        loss = itc_loss(image_features, text_features, labels)
        total_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / len(dataloader)
    scheduler.step()
    print(f"ITC Fine-tuning Epoch {epoch+1}/{num_epochs} Completed, Average Loss: {avg_loss:.4f}")

# Log data issues
dataset.log_data_issues()

# Save fine-tuned CLIP model and processor
model.save_pretrained(save_dir)
processor.save_pretrained(save_dir)

# Generate and normalize embeddings for all items
all_image_features = []
model.eval()
with torch.no_grad():
    for batch in tqdm(dataloader, desc="Generating embeddings"):
        pixel_values = batch['pixel_values'].to(device)
        image_features = model.get_image_features(pixel_values=pixel_values)
        all_image_features.append(image_features.cpu())
all_image_features = torch.cat(all_image_features, dim=0)
all_image_features = F.normalize(all_image_features, dim=-1)  # Normalize features
torch.save(all_image_features, f"{save_dir}/all_image_features.pt")

# VBPR training
num_users = len(user_mapping)
num_items = len(df)
vbpr_model = VBPR(num_users, num_items).to(device)
criterion = nn.MarginRankingLoss(margin=1.0)
optimizer = torch.optim.Adam(vbpr_model.parameters(), lr=0.001)

num_epochs = 10  # Reduced to 10 epochs
vbpr_model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in tqdm(vbpr_loader, desc=f"VBPR Training Epoch {epoch+1}/{num_epochs}"):
        user_idx = batch['user_idx'].to(device)
        pos_item_idx = batch['pos_item_idx'].to(device)
        neg_item_idx = batch['neg_item_idx'].to(device)
        pos_pixel_values = batch['pos_pixel_values'].to(device)
        neg_pixel_values = batch['neg_pixel_values'].to(device)

        pos_image_features = model.get_image_features(pixel_values=pos_pixel_values)
        neg_image_features = model.get_image_features(pixel_values=neg_pixel_values)

        pos_scores = vbpr_model(user_idx, pos_item_idx, pos_image_features)
        neg_scores = vbpr_model(user_idx, neg_item_idx, neg_image_features)

        loss = criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
        total_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / len(vbpr_loader)
    print(f"VBPR Training Epoch {epoch+1}/{num_epochs} Completed, Average Loss: {avg_loss:.4f}")

# Save VBPR model
torch.save(vbpr_model.state_dict(), f"{save_dir}/vbpr_model.pt")

INFO:__main__:Starting pre-check for missing images...
INFO:__main__:Pre-check: Total missing or invalid images: 0
ITC Fine-tuning Epoch 1/15: 100%|██████████| 364/364 [02:06<00:00,  2.89it/s]


ITC Fine-tuning Epoch 1/15 Completed, Average Loss: 2.1127


ITC Fine-tuning Epoch 2/15: 100%|██████████| 364/364 [02:07<00:00,  2.86it/s]


ITC Fine-tuning Epoch 2/15 Completed, Average Loss: 1.5733


ITC Fine-tuning Epoch 3/15: 100%|██████████| 364/364 [02:39<00:00,  2.28it/s]


ITC Fine-tuning Epoch 3/15 Completed, Average Loss: 1.1874


ITC Fine-tuning Epoch 4/15: 100%|██████████| 364/364 [02:38<00:00,  2.30it/s]


ITC Fine-tuning Epoch 4/15 Completed, Average Loss: 1.0915


ITC Fine-tuning Epoch 5/15: 100%|██████████| 364/364 [02:35<00:00,  2.34it/s]


ITC Fine-tuning Epoch 5/15 Completed, Average Loss: 1.1076


ITC Fine-tuning Epoch 6/15: 100%|██████████| 364/364 [02:33<00:00,  2.37it/s]


ITC Fine-tuning Epoch 6/15 Completed, Average Loss: 0.6040


ITC Fine-tuning Epoch 7/15: 100%|██████████| 364/364 [01:53<00:00,  3.21it/s]


ITC Fine-tuning Epoch 7/15 Completed, Average Loss: 0.4246


ITC Fine-tuning Epoch 8/15: 100%|██████████| 364/364 [02:24<00:00,  2.53it/s]


ITC Fine-tuning Epoch 8/15 Completed, Average Loss: 0.3707


ITC Fine-tuning Epoch 9/15: 100%|██████████| 364/364 [02:21<00:00,  2.58it/s]


ITC Fine-tuning Epoch 9/15 Completed, Average Loss: 0.3317


ITC Fine-tuning Epoch 10/15: 100%|██████████| 364/364 [02:27<00:00,  2.46it/s]


ITC Fine-tuning Epoch 10/15 Completed, Average Loss: 0.3183


ITC Fine-tuning Epoch 11/15: 100%|██████████| 364/364 [02:38<00:00,  2.30it/s]


ITC Fine-tuning Epoch 11/15 Completed, Average Loss: 0.2899


ITC Fine-tuning Epoch 12/15: 100%|██████████| 364/364 [02:34<00:00,  2.35it/s]


ITC Fine-tuning Epoch 12/15 Completed, Average Loss: 0.2818


ITC Fine-tuning Epoch 13/15: 100%|██████████| 364/364 [02:23<00:00,  2.54it/s]


ITC Fine-tuning Epoch 13/15 Completed, Average Loss: 0.2664


ITC Fine-tuning Epoch 14/15: 100%|██████████| 364/364 [02:12<00:00,  2.74it/s]


ITC Fine-tuning Epoch 14/15 Completed, Average Loss: 0.2606


ITC Fine-tuning Epoch 15/15: 100%|██████████| 364/364 [01:50<00:00,  3.29it/s]
INFO:__main__:Total missing or invalid images: 0
INFO:__main__:Total missing or empty texts: 0


ITC Fine-tuning Epoch 15/15 Completed, Average Loss: 0.2728


Generating embeddings: 100%|██████████| 364/364 [01:38<00:00,  3.71it/s]
VBPR Training Epoch 1/10: 100%|██████████| 76/76 [00:43<00:00,  1.77it/s]


VBPR Training Epoch 1/10 Completed, Average Loss: 0.9740


VBPR Training Epoch 2/10: 100%|██████████| 76/76 [00:44<00:00,  1.70it/s]


VBPR Training Epoch 2/10 Completed, Average Loss: 0.8838


VBPR Training Epoch 3/10: 100%|██████████| 76/76 [00:44<00:00,  1.69it/s]


VBPR Training Epoch 3/10 Completed, Average Loss: 0.7903


VBPR Training Epoch 4/10: 100%|██████████| 76/76 [00:48<00:00,  1.57it/s]


VBPR Training Epoch 4/10 Completed, Average Loss: 0.6713


VBPR Training Epoch 5/10: 100%|██████████| 76/76 [00:44<00:00,  1.69it/s]


VBPR Training Epoch 5/10 Completed, Average Loss: 0.5535


VBPR Training Epoch 6/10: 100%|██████████| 76/76 [00:49<00:00,  1.55it/s]


VBPR Training Epoch 6/10 Completed, Average Loss: 0.4105


VBPR Training Epoch 7/10: 100%|██████████| 76/76 [00:47<00:00,  1.61it/s]


VBPR Training Epoch 7/10 Completed, Average Loss: 0.2626


VBPR Training Epoch 8/10: 100%|██████████| 76/76 [00:49<00:00,  1.53it/s]


VBPR Training Epoch 8/10 Completed, Average Loss: 0.1521


VBPR Training Epoch 9/10: 100%|██████████| 76/76 [00:43<00:00,  1.74it/s]


VBPR Training Epoch 9/10 Completed, Average Loss: 0.0947


VBPR Training Epoch 10/10: 100%|██████████| 76/76 [00:41<00:00,  1.82it/s]


VBPR Training Epoch 10/10 Completed, Average Loss: 0.0861
