In [None]:
# Mount Google Drive so the project files under MyDrive are reachable from this Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import pandas as pd

# Load the pre-split train / validation / test sets (Parquet files in the project folder on Drive).
SPLIT_DIR = "/content/drive/MyDrive/GitHub_Repos/CS610-Product-Image-Text-Consistency-Detection-System-for-E-commerce/amazon_meta_data/split_data"

train_path, val_path, test_path = (
    f"{SPLIT_DIR}/{split_name}_data.parquet" for split_name in ("train", "val", "test")
)

# Read in the same order as the paths: train, then val, then test.
train_df, val_df, test_df = (
    pd.read_parquet(parquet_file) for parquet_file in (train_path, val_path, test_path)
)

In [None]:
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import pandas as pd
import os
import json
from tqdm import tqdm
import clip

def preprocess_for_clip(train_df, val_df, test_df,
                       text_column='text_for_deep',
                       image_column='processed_image_path',
                       output_dir='clip_features',
                       clip_model_name="ViT-B/16",
                       batch_size=16):
    """
    Prepare text and image data for CLIP model

    Parameters:
    train_df, val_df, test_df: Pre-split datasets
    text_column: Column name for text data
    image_column: Column name for image paths
    output_dir: Directory to save the features
    clip_model_name: Name of the CLIP model to use (ViT-B/32, ViT-B/16, RN50, RN101, etc.)
    batch_size: Batch size for processing
    """
    os.makedirs(output_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print(f"Preparing data for CLIP ({clip_model_name}) model...")

    # === 1. Load CLIP model and preprocessors ===
    print("\n1. Loading CLIP model and preprocessor")
    model, preprocess = clip.load(clip_model_name, device=device)

    # Extract tokenizer for text processing
    tokenizer = lambda text: clip.tokenize(text, truncate=True)

    # === 2. Create dataset class ===
    class CLIPDataset(Dataset):
        def __init__(self, df, preprocess, tokenizer, text_column, image_column):
            self.df = df
            self.preprocess = preprocess
            self.tokenizer = tokenizer
            self.text_column = text_column
            self.image_column = image_column

            # Record valid indices
            self.valid_indices = []
            for idx, row in tqdm(df.iterrows(), total=len(df), desc="Validating image paths"):
                if pd.notna(row[image_column]) and os.path.exists(row[image_column]):
                    self.valid_indices.append(idx)

            print(f"Valid samples: {len(self.valid_indices)}/{len(df)}")

        def __len__(self):
            return len(self.valid_indices)

        def __getitem__(self, idx):
            orig_idx = self.valid_indices[idx]
            row = self.df.iloc[orig_idx]
            text = row[self.text_column]
            image_path = row[self.image_column]
            label = row['is_match']

            # Process text - use CLIP's tokenizer
            # Note: Actual tokenization will happen during batching; here, we return the raw text
            text_for_tokenization = text

            # Process image - use CLIP's preprocessor
            try:
                image = Image.open(image_path).convert('RGB')
                image_tensor = self.preprocess(image)
            except Exception as e:
                print(f"Error processing image {image_path}: {e}")
                # Return a zero tensor as a placeholder
                image_tensor = torch.zeros((3, 224, 224))

            return {
                'text': text_for_tokenization,
                'image': image_tensor,
                'label': torch.tensor(label, dtype=torch.long),
                'idx': torch.tensor(orig_idx)
            }

    # === 3. Create datasets ===
    print("\n2. Creating datasets")
    train_dataset = CLIPDataset(
        train_df, preprocess, tokenizer,
        text_column, image_column
    )

    val_dataset = CLIPDataset(
        val_df, preprocess, tokenizer,
        text_column, image_column
    )

    test_dataset = CLIPDataset(
        test_df, preprocess, tokenizer,
        text_column, image_column
    )

    # Custom collate function to handle CLIP's batching
    def clip_collate_fn(batch):
        texts = [item['text'] for item in batch]
        images = torch.stack([item['image'] for item in batch])
        labels = torch.stack([item['label'] for item in batch])
        indices = torch.stack([item['idx'] for item in batch])

        # Tokenize text in batches
        text_tokens = tokenizer(texts)

        return {
            'text': texts,          # Raw text for debugging and visualization
            'text_tokens': text_tokens,  # Tokenized text
            'image': images,
            'label': labels,
            'idx': indices
        }

    # === 4. Create data loaders ===
    print("\n3. Creating data loaders")
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=clip_collate_fn,
        pin_memory=True if torch.cuda.is_available() else False
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        collate_fn=clip_collate_fn,
        pin_memory=True if torch.cuda.is_available() else False
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        collate_fn=clip_collate_fn,
        pin_memory=True if torch.cuda.is_available() else False
    )

    # === 5. Save dataset processing information ===
    print("\n4. Saving dataset information")
    dataset_info = {
        'train_size': len(train_dataset),
        'val_size': len(val_dataset),
        'test_size': len(test_dataset),
        'clip_model': clip_model_name,
    }

    with open(os.path.join(output_dir, 'dataset_info.json'), 'w') as f:
        json.dump(dataset_info, f)

    # Save valid indices
    np.save(os.path.join(output_dir, 'train_valid_indices.npy'), np.array(train_dataset.valid_indices))
    np.save(os.path.join(output_dir, 'val_valid_indices.npy'), np.array(val_dataset.valid_indices))
    np.save(os.path.join(output_dir, 'test_valid_indices.npy'), np.array(test_dataset.valid_indices))

    # Save valid DataFrames
    train_df_valid = train_df.iloc[train_dataset.valid_indices].copy()
    val_df_valid = val_df.iloc[val_dataset.valid_indices].copy()
    test_df_valid = test_df.iloc[test_dataset.valid_indices].copy()

    train_df_valid.to_csv(os.path.join(output_dir, 'train_df_valid.csv'), index=False)
    val_df_valid.to_csv(os.path.join(output_dir, 'val_df_valid.csv'), index=False)
    test_df_valid.to_csv(os.path.join(output_dir, 'test_df_valid.csv'), index=False)

    print("\nCLIP model data preparation complete!")

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'test_dataset': test_dataset,
        'model': model,
        'preprocess': preprocess,
        'tokenizer': tokenizer,
        'train_df_valid': train_df_valid,
        'val_df_valid': val_df_valid,
        'test_df_valid': test_df_valid
    }

In [None]:
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import clip

# Build CLIP-ready datasets/loaders for all three splits and persist the valid-row metadata on Drive.
clip_data = preprocess_for_clip(
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    text_column='text_for_deep',
    image_column='processed_image_path',
    output_dir=('/content/drive/MyDrive/GitHub_Repos/'
                'CS610-Product-Image-Text-Consistency-Detection-System-for-E-commerce/'
                'amazon_meta_data/clip_features'),
    clip_model_name="ViT-B/16",  # alternatives: "ViT-B/32", "RN50", "RN101"
    batch_size=16,
)

# Report how many samples per split survived image-path validation.
for split_label, dataset_key in (("Training", "train_dataset"),
                                 ("Validation", "val_dataset"),
                                 ("Test", "test_dataset")):
    print(f"{split_label} set size: {len(clip_data[dataset_key])} valid samples")