In [None]:
# Mount Google Drive so that project files under MyDrive are readable from this Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import os

import pandas as pd

# Load the pre-split train / validation / test data.
# All three splits live in the same directory, so every path is derived from one base constant
# instead of repeating the long absolute Drive path.
PROJECT_ROOT = '/content/drive/MyDrive/GitHub_Repos/CS610-Product-Image-Text-Consistency-Detection-System-for-E-commerce'
SPLIT_DATA_DIR = os.path.join(PROJECT_ROOT, 'amazon_meta_data', 'split_data')

train_path = os.path.join(SPLIT_DATA_DIR, 'train_data.parquet')
val_path = os.path.join(SPLIT_DATA_DIR, 'val_data.parquet')
test_path = os.path.join(SPLIT_DATA_DIR, 'test_data.parquet')

train_df = pd.read_parquet(train_path)
val_df = pd.read_parquet(val_path)
test_df = pd.read_parquet(test_path)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import pandas as pd
import os
import json
from tqdm import tqdm

def load_sparse_attention_data(input_dir='sparse_attention_features', batch_size=None, num_workers=2):
    """
    Load data from files generated by the preprocessing function and recreate the sparse attention data structure.

    Parameters:
    input_dir: Directory where the preprocessed data is saved. It must contain dataset_info.json,
        {train,val,test}_valid_indices.npy and {train,val,test}_df_valid.csv.
    batch_size: The batch size for data loading. If None, the 'batch_size' saved in
        dataset_info.json is used when present, otherwise 32.
    num_workers: Number of DataLoader worker processes for each split (default 2).

    Returns:
    A dictionary with keys 'train_loader', 'val_loader', 'test_loader', 'train_dataset',
    'val_dataset', 'test_dataset', 'tokenizer', 'image_transforms', 'train_df_valid',
    'val_df_valid', 'test_df_valid' and 'dataset_info'.

    Raises:
    FileNotFoundError if one of the expected files is missing; KeyError if dataset_info.json
    lacks 'max_length' or 'bert_model'.
    """
    print(f"Loading sparse attention data from {input_dir}...")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    splits = ('train', 'val', 'test')
    # Markers delimiting the sections of the text column. They are registered as special tokens,
    # and their order here fixes the order (and token_type_id) of fields in the merged sequence.
    field_markers = ['[TITLE]', '[CAT]', '[FEAT]', '[DESC]', '[DET]']
    image_size = 224  # resize target; also the shape of the fallback tensor for unreadable images

    # === 1. Load dataset information ===
    print("\n1. Loading dataset information")
    with open(os.path.join(input_dir, 'dataset_info.json'), 'r') as f:
        dataset_info = json.load(f)

    max_length = dataset_info['max_length']
    bert_model_name = dataset_info['bert_model']

    if batch_size is None:
        # Reuse the batch size recorded at preprocessing time (as documented); fall back to 32.
        batch_size = dataset_info.get('batch_size', 32)

    # === 2. Load valid indices ===
    print("\n2. Loading valid indices")
    valid_indices = {
        split: np.load(os.path.join(input_dir, f'{split}_valid_indices.npy'))
        for split in splits
    }

    # === 3. Load valid DataFrame ===
    print("\n3. Loading valid DataFrame")
    valid_dfs = {
        split: pd.read_csv(os.path.join(input_dir, f'{split}_df_valid.csv'))
        for split in splits
    }

    # === 4. Initialize BERT tokenizer ===
    print("\n4. Initializing BERT tokenizer")
    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    tokenizer.add_special_tokens({'additional_special_tokens': list(field_markers)})

    # === 5. Define Vision Transformer image transformations ===
    print("\n5. Defining Vision Transformer image transformations")
    image_transforms = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # === 6. Create dataset class ===
    print("\n6. Creating dataset class")
    class SparseAttentionDataset(Dataset):
        """Image/text pairs whose text is tokenized field by field.

        Each item is a dict with 'input_ids', 'attention_mask' and 'token_type_ids' (each of
        length max_length), 'field_boundaries' ({field name: (start, end)} inclusive token
        positions), 'image', 'label' and 'idx'.
        """

        def __init__(self, df, tokenizer, transform, valid_indices, max_length,
                     text_column='text_for_deep', image_column='processed_image_path'):
            self.df = df
            self.tokenizer = tokenizer
            self.transform = transform
            # Only its length is used. df is assumed to already hold exactly the valid rows in
            # the same order -- TODO confirm against the preprocessing step.
            self.valid_indices = valid_indices
            self.max_length = max_length
            self.text_column = text_column
            self.image_column = image_column

            print(f"Dataset size: {len(self.valid_indices)}")

        def __len__(self):
            return len(self.valid_indices)

        @staticmethod
        def _split_fields(text):
            """Map each marker present in text to the stripped text up to the next marker.

            Only the first occurrence of each marker is used; the result is ordered by the
            markers' positions in the text.
            """
            found = [(marker, text.find(marker)) for marker in field_markers]
            found = sorted((item for item in found if item[1] != -1), key=lambda item: item[1])

            field_texts = {}
            for i, (marker, pos) in enumerate(found):
                end = found[i + 1][1] if i + 1 < len(found) else len(text)
                field_texts[marker] = text[pos + len(marker):end].strip()
            return field_texts

        def _encode_fields(self, field_texts):
            """Tokenize the non-empty fields and concatenate them in field_markers order.

            Returns (input_ids, token_type_ids, field_boundaries), unpadded. Each field may use
            at most max_length // len(field_texts) tokens (including any special tokens the
            tokenizer adds), so the concatenation never exceeds max_length.
            """
            input_ids, token_type_ids, field_boundaries = [], [], {}
            if not field_texts:
                return input_ids, token_type_ids, field_boundaries

            # Split the token budget evenly across every detected field, including empty ones.
            per_field_length = self.max_length // len(field_texts)

            token_type = 0
            for marker in field_markers:
                field_text = field_texts.get(marker)
                if not field_text:  # missing or empty fields are skipped
                    continue
                field_ids = self.tokenizer(
                    field_text,
                    max_length=per_field_length,
                    truncation=True,
                )['input_ids']

                start_pos = len(input_ids)
                input_ids.extend(field_ids)
                # NOTE(review): each field gets the next token_type_id (0, 1, 2, ...), so ids can
                # reach len(field_markers) - 1. Stock BERT checkpoints have only 2 token types;
                # confirm the downstream model's token-type embedding is large enough.
                token_type_ids.extend([token_type] * len(field_ids))
                field_boundaries[marker.strip('[]')] = (start_pos, len(input_ids) - 1)
                token_type += 1

            return input_ids, token_type_ids, field_boundaries

        def __getitem__(self, idx):
            # self.df already holds only the valid rows, so idx indexes it positionally.
            row = self.df.iloc[idx]
            text = row[self.text_column]
            if not isinstance(text, str):
                # pd.read_csv turns empty cells into NaN; treat those as empty text instead of
                # crashing the DataLoader worker on text.find().
                text = ''
            image_path = row[self.image_column]
            label = row['is_match']

            input_ids, token_type_ids, field_boundaries = self._encode_fields(self._split_fields(text))

            # Defensive truncation, then right-pad everything to exactly max_length.
            input_ids = input_ids[:self.max_length]
            token_type_ids = token_type_ids[:self.max_length]
            attention_mask = [1] * len(input_ids)  # every merged token is a real (non-pad) token

            padding_length = self.max_length - len(input_ids)
            input_ids += [self.tokenizer.pad_token_id] * padding_length
            attention_mask += [0] * padding_length
            token_type_ids += [0] * padding_length

            # Process image; the context manager makes sure the file handle is released.
            try:
                with Image.open(image_path) as image:
                    image_tensor = self.transform(image.convert('RGB'))
            except Exception as e:
                # Best effort: log and substitute a blank image so one bad file does not stop an epoch.
                print(f"Error processing image {image_path}: {e}")
                image_tensor = torch.zeros((3, image_size, image_size))

            return {
                'input_ids': torch.tensor(input_ids),
                'attention_mask': torch.tensor(attention_mask),
                'token_type_ids': torch.tensor(token_type_ids),
                'field_boundaries': field_boundaries,
                'image': image_tensor,
                'label': torch.tensor(label, dtype=torch.long),
                'idx': torch.tensor(idx)
            }

    # === 7. Create datasets ===
    print("\n7. Creating datasets")
    datasets = {
        split: SparseAttentionDataset(
            valid_dfs[split], tokenizer, image_transforms, valid_indices[split], max_length
        )
        for split in splits
    }

    def sparse_attention_collate_fn(batch):
        """Stack tensor fields; keep field_boundaries as a list of per-sample dicts."""
        return {
            'input_ids': torch.stack([item['input_ids'] for item in batch]),
            'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
            'token_type_ids': torch.stack([item['token_type_ids'] for item in batch]),
            # Dicts of variable keys cannot be stacked into a tensor.
            'field_boundaries': [item['field_boundaries'] for item in batch],
            'image': torch.stack([item['image'] for item in batch]),
            'label': torch.stack([item['label'] for item in batch]),
            'idx': torch.stack([item['idx'] for item in batch])
        }

    # === 8. Create data loaders ===
    print("\n8. Creating data loaders")

    def make_loader(dataset, shuffle):
        # NOTE(review): with num_workers > 0 under a 'spawn' start method (Windows/macOS default),
        # the locally defined dataset class and collate_fn cannot be pickled; pass num_workers=0 there.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=sparse_attention_collate_fn,
            pin_memory=torch.cuda.is_available()
        )

    loaders = {split: make_loader(datasets[split], shuffle=(split == 'train')) for split in splits}

    print("\nSparse attention data loading complete!")

    return {
        'train_loader': loaders['train'],
        'val_loader': loaders['val'],
        'test_loader': loaders['test'],
        'train_dataset': datasets['train'],
        'val_dataset': datasets['val'],
        'test_dataset': datasets['test'],
        'tokenizer': tokenizer,
        'image_transforms': image_transforms,
        'train_df_valid': valid_dfs['train'],
        'val_df_valid': valid_dfs['val'],
        'test_df_valid': valid_dfs['test'],
        'dataset_info': dataset_info
    }

# Load the sparse-attention features saved by the preprocessing step.
# preprocess_for_sparse_attention is not defined in this notebook, so a fresh-kernel run must use
# load_sparse_attention_data (defined above). max_length and the BERT model name are read back
# from dataset_info.json inside the loader, so only the directory and batch size are passed here.
SPARSE_FEATURES_DIR = '/content/drive/MyDrive/GitHub_Repos/CS610-Product-Image-Text-Consistency-Detection-System-for-E-commerce/amazon_meta_data/sparse_attention_features'

sparse_attention_data = load_sparse_attention_data(
    input_dir=SPARSE_FEATURES_DIR,
    batch_size=16
)

# Get the returned data loaders and tokenizer (all from the same result dict).
train_loader = sparse_attention_data['train_loader']
val_loader = sparse_attention_data['val_loader']
test_loader = sparse_attention_data['test_loader']
tokenizer = sparse_attention_data['tokenizer']

# Print the dataset sizes
print(f"Training set size: {len(sparse_attention_data['train_dataset'])} valid samples")
print(f"Validation set size: {len(sparse_attention_data['val_dataset'])} valid samples")
print(f"Test set size: {len(sparse_attention_data['test_dataset'])} valid samples")