In [7]:
# Cell: Install and Import Dependencies
# !pip install scikit-image
import csv
import os
import pickle
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from skimage import filters
from skimage.color import rgb2gray
from torch.utils.data import ConcatDataset
from torchvision import transforms
from tqdm import tqdm

# Configuration
# Run settings gathered in one mapping; later cells look values up by key
# (e.g. CONFIG['DATA_DIR'], CONFIG['N_TRAIN_SAMPLES']).
CONFIG = dict(
    MODEL_ID="google/siglip-base-patch16-224",
    OUTPUT_DIR="./siglip-scin-full-3000",  # alternative: ./siglip-scin-lora
    DATA_DIR="./data/scin_cache",          # local cache directory
    BATCH_SIZE=16,
    LEARNING_RATE=5e-5,                    # 1e-4 before
    LORA_RANK=32,
    LORA_ALPHA=32,
    MAX_STEPS=3000,
    LOSS_TYPE="sigmoid",                   # or "contrastive"
    N_VAL_SAMPLES=1000,
    N_TRAIN_SAMPLES=5000,                  # set to None to use all available
)

In [8]:
# Cell: Define Custom Augmentations

class GaussianNoise:
    """Adds Gaussian noise to a PIL Image.

    The image is scaled to [0, 1], noise drawn from
    ``np.random.normal(mean, std)`` is added per element, the result is
    clipped back to [0, 1] and returned as a new 8-bit PIL Image.

    Args:
        mean: Mean of the noise, in [0, 1] intensity units.
        std: Standard deviation of the noise, in [0, 1] intensity units.

    Note:
        The division by 255 assumes 8-bit channel data -- TODO confirm no
        16-bit or float-mode images reach this transform.
    """
    def __init__(self, mean=0., std=0.1):
        self.mean = mean
        self.std = std

    def __call__(self, img):
        """Return a noisy copy of `img` as a uint8-backed PIL Image."""
        # Convert PIL Image to float numpy array (0-1)
        np_img = np.array(img).astype(np.float32) / 255.0
        # Draw noise from numpy's global RNG, one value per array element
        noise = np.random.normal(self.mean, self.std, np_img.shape)
        # Add noise and clip back into the valid intensity range
        noisy_img = np.clip(np_img + noise, 0, 1)
        # Round to the nearest level before casting. A bare astype(uint8)
        # truncates, so values like 199.99999 (produced by the float32
        # round trip) would drop a full level and darken the image even
        # when no noise is added.
        return Image.fromarray(np.rint(noisy_img * 255).astype(np.uint8))

class SobelFilter:
    """Turns a PIL Image into a 3-channel Sobel edge map (also a PIL Image)."""
    def __call__(self, img):
        """Return the min-max normalized Sobel response of `img`, replicated to 3 channels."""
        pixels = np.array(img)

        # A 3-D array carries a channel axis; collapse it to grayscale first
        if pixels.ndim == 3:
            pixels = rgb2gray(pixels)

        edges = filters.sobel(pixels)

        # Min-max normalize; the small epsilon guards against a flat response
        lowest = np.min(edges)
        highest = np.max(edges)
        edges = (edges - lowest) / (highest - lowest + 1e-6)

        # Scale to 8-bit levels
        edges_u8 = (edges * 255).astype(np.uint8)

        # Replicate the single channel three times to mimic RGB for the model
        three_channel = np.stack([edges_u8, edges_u8, edges_u8], axis=-1)
        return Image.fromarray(three_channel)

In [None]:
# Cell: Generation Script (Corrected)

# --- 1. Define Configuration ---
# Output layout: <OUTPUT_SYNTHETIC_DIR>/images/*.png plus a labels.csv index.
OUTPUT_SYNTHETIC_DIR = "./data/synthetic"
IMG_DIR = os.path.join(OUTPUT_SYNTHETIC_DIR, "images")
METADATA_FILE = os.path.join(OUTPUT_SYNTHETIC_DIR, "labels.csv")

# Create directories (also creates OUTPUT_SYNTHETIC_DIR as the parent)
os.makedirs(IMG_DIR, exist_ok=True)

# Side length (pixels) of the square images produced by the pipelines.
# An optional CONFIG['IMG_SIZE'] overrides it; otherwise 224 is used, matching
# the "-224" in the MODEL_ID "google/siglip-base-patch16-224".
# globals().get keeps this working even when CONFIG has not been defined yet.
IMG_SIZE = globals().get('CONFIG', {}).get('IMG_SIZE', 224)

# --- 2. Define Augmentation Pipelines ---
# Maps a transform name (it becomes part of each output filename) to the
# Compose applied to a single source image. Dict order is the order in which
# the generation loop runs the transforms.
def _random_crop(min_scale):
    """Build a RandomResizedCrop to IMG_SIZE with scale=(min_scale, 1.0)."""
    return transforms.RandomResizedCrop(IMG_SIZE, scale=(min_scale, 1.0))


aug_pipelines = {
    "crop_resize": transforms.Compose([
        _random_crop(0.5),
    ]),
    "crop_resize_flip": transforms.Compose([
        _random_crop(0.5),
        transforms.RandomHorizontalFlip(p=0.5),
    ]),
    "color_distort": transforms.Compose([
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        _random_crop(0.8),
    ]),
    # Edge map is computed on the source image, then resized to the target size
    "sobel": transforms.Compose([
        SobelFilter(),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
    ]),
    "blur": transforms.Compose([
        transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
        _random_crop(0.8),
    ]),
    "noise": transforms.Compose([
        GaussianNoise(mean=0., std=0.05),
        _random_crop(0.8),
    ]),
}

# --- 3. Run Generation Loop ---
print(f"Generating synthetic data in: {OUTPUT_SYNTHETIC_DIR}")

# Load the cached training samples that the augmentations are applied to.
# The cache is expected at <DATA_DIR>/train_<N_TRAIN_SAMPLES>.pkl.
train_cache_path = None  # stays None if we fail before the path is built
try:
    # Fail with an actionable message when the config cell has not been run
    if 'CONFIG' not in globals():
        raise NameError("name 'CONFIG' is not defined. Please run Cell 1 first.")

    train_cache_path = os.path.join(CONFIG['DATA_DIR'], f"train_{CONFIG['N_TRAIN_SAMPLES']}.pkl")
    print(f"Loading source data from: {train_cache_path}")

    # SECURITY: pickle.load can execute arbitrary code from the file.
    # Only load cache files that this project wrote itself.
    with open(train_cache_path, "rb") as f:
        source_data = pickle.load(f)

    # The file exists but holds no samples: that is a bad value, not a missing file
    if not source_data:
        raise ValueError("Source data file is empty")

except Exception as e:
    print("FATAL ERROR: Could not load source data.")
    if train_cache_path:
        print(f"  Attempted path: {train_cache_path}")
    print("  Please ensure you have run Cell 1 (to define CONFIG) and Cell 2 (to cache the dataset).")
    print(f"  Error details: {e}")
    # Stop execution if we can't load the data; bare raise keeps the original traceback
    raise

# Running count of successfully saved images; doubles as the unique id in filenames
image_counter = 0
# One [filename, label] row per saved image, written to the CSV afterwards
metadata_rows = []

for aug_name, pipeline in aug_pipelines.items():
    print(f"Applying transform: {aug_name}")

    progress = tqdm(source_data, desc=f"Generating {aug_name}")
    for sample in progress:
        original_image, label = sample['image'], sample['text']

        # Best effort: a sample that fails to transform or save is reported and
        # skipped, and does not consume a counter value or a metadata row.
        try:
            augmented = pipeline(original_image)
            out_name = f"aug_{aug_name}_{image_counter:07d}.png"
            augmented.save(os.path.join(IMG_DIR, out_name))
            metadata_rows.append([out_name, label])
            image_counter += 1
        except Exception as e:
            print(f"Warning: Skipping image due to error: {e}")

# --- 4. Save Metadata CSV ---
print(f"\nSaving metadata to {METADATA_FILE}...")
# newline='' lets the csv module control line endings. The encoding is pinned
# to UTF-8 so labels with non-ASCII characters are written the same way on
# every platform and match what pandas.read_csv expects by default.
with open(METADATA_FILE, 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerow(["image_file", "label"])  # Header; column names are read back by SyntheticDataset
    writer.writerows(metadata_rows)

print(f"\nDone! Generated {image_counter} new images.")

Generating synthetic data in: ./data/synthetic
Loading source data from: ./data/scin_cache/train_5000.pkl


In [4]:
# Cell: New Dataset Class for loading synthetic data

from torch.utils.data import Dataset
import pandas as pd

class SyntheticDataset(Dataset):
    """
    A PyTorch Dataset that loads synthetic images and their labels
    from disk.

    Args:
        metadata_file: Path to a CSV file with `image_file` and `label` columns.
        img_dir: Directory containing the files named in `image_file`.
    """
    def __init__(self, metadata_file, img_dir):
        print(f"Loading metadata from: {metadata_file}")
        # Read every column as plain text and switch off NA inference.
        # Otherwise a label such as "NA" or "None" would come back as a float
        # NaN and numeric-looking text would be converted to numbers.
        self.metadata = pd.read_csv(
            metadata_file,
            dtype=str,
            keep_default_na=False,
        )
        self.img_dir = img_dir
        print(f"Found {len(self.metadata)} images.")

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``{"image": <RGB PIL image>, "text": <label>}`` for row `idx`.

        Returns None (after printing the error) when the image cannot be
        loaded, so consumers must be prepared to drop None samples.
        """
        # Get the filename and label for the given index
        row = self.metadata.iloc[idx]
        filename = row['image_file']
        label = row['label']

        # Construct the full image path
        img_path = os.path.join(self.img_dir, filename)

        try:
            # convert() returns a new in-memory image, so the file handle can
            # be released as soon as the with-block ends
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception as e:
            print(f"Error loading {img_path}, returning None. Error: {e}")
            return None

        # Same {"image", "text"} layout the generation cell reads from its source samples
        return {"image": image, "text": label}

In [None]:
# --- Example of how to combine datasets ---

# 1. Initialize the synthetic dataset from the files written by the generation cell
synthetic_dataset = SyntheticDataset(
    metadata_file=METADATA_FILE,
    img_dir=IMG_DIR
)

# 2. Load the original training samples from the cache file.
#    This is the same <DATA_DIR>/train_<N_TRAIN_SAMPLES>.pkl file the generation
#    cell reads; it is assumed to hold an indexable sequence of
#    {"image", "text"} samples -- TODO confirm against the caching cell.
original_cache_path = Path(CONFIG['DATA_DIR']) / f"train_{CONFIG['N_TRAIN_SAMPLES']}.pkl"
# SECURITY: pickle.load can execute arbitrary code from the file.
# Only load cache files that this project wrote itself.
with open(original_cache_path, "rb") as f:
    original_dataset = pickle.load(f)

# 3. Combine them (ConcatDataset needs the loaded samples, not the path to them)
combined_train_dataset = ConcatDataset([original_dataset, synthetic_dataset])

# 4. Use this combined dataset in your Trainer
# print(f"Total training samples: {len(combined_train_dataset)}")
# trainer = CustomTrainer(
#     ...
#     train_dataset=combined_train_dataset, # <-- Use the combined data
#     ...
# )