In [None]:
# --- 0. KAGGLE ENVIRONMENT SETUP ---
# Verify the notebook runs inside Kaggle, report which input datasets are
# attached, and create the output directories used by later cells.
import os
import sys

print("=== KAGGLE ENVIRONMENT SETUP ===")
print(f"Python version: {sys.version.split()[0]}")
print(f"Working directory: {os.getcwd()}")

# Validate Kaggle environment. An explicit raise is used instead of `assert`
# because assertions are removed when Python runs with -O.
if '/kaggle/' not in os.getcwd():
    raise RuntimeError("This notebook is designed specifically for Kaggle environment!")
print("‚úì Confirmed running in Kaggle environment")

# Check available datasets (sorted so the report is deterministic;
# os.listdir returns entries in arbitrary order)
input_dir = "/kaggle/input"
if os.path.exists(input_dir):
    datasets = sorted(os.listdir(input_dir))
    print(f"Available datasets: {datasets}")

    # Look for sarcasm detection dataset by name
    sarcasm_datasets = [d for d in datasets if 'sarcasm' in d.lower() or 'multimodal' in d.lower()]
    if sarcasm_datasets:
        print(f"‚úì Found sarcasm datasets: {sarcasm_datasets}")
    else:
        print("‚ö† Warning: No sarcasm detection dataset found. Please add the dataset to this notebook.")
else:
    print("‚ö† Warning: No input datasets found")

# Setup working directories for processed data and model checkpoints
os.makedirs("/kaggle/working/processed_data", exist_ok=True)
os.makedirs("/kaggle/working/models", exist_ok=True)

print("‚úì Kaggle environment ready")
print("=" * 50)

In [None]:
# --- 1. LIBRARY INSTALLATION AND IMPORTS ---
print("Menginstal library yang diperlukan...")

# Install with suppressed pip output to avoid log spam.
# sys is imported again so this cell does not rely on the setup cell.
import subprocess
import sys

def install_package(package):
    """Run ``pip install -q <package>`` using the current interpreter.

    Args:
        package: Requirement string handed to pip.

    Returns:
        True if pip exited successfully; False (after printing a warning) if
        pip exited with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "-q", package]
    try:
        subprocess.check_call(pip_command)
    except subprocess.CalledProcessError as err:
        print(f"Warning: Failed to install {package}: {err}")
        return False
    return True

# Packages this notebook needs beyond the base image
required_packages = [
    "transformers",
    "ftfy",
    "regex",
    "accelerate",
    "openai",
]

# Install each one, printing a check mark or a cross on the same line
for package in required_packages:
    print(f"Installing {package}...", end=" ")
    status_mark = "‚úì" if install_package(package) else "‚úó"
    print(status_mark)

print("\nImporting libraries...")

# Import libraries with error handling: a missing package is reported and the
# ImportError is re-raised so the notebook stops at this cell.
# NOTE(review): re, time, train_test_split and OpenAI are not used in the cells
# shown here — confirm whether later cells need them.
try:
    import os
    import json
    import torch
    import pandas as pd
    import re
    import time
    from torch.utils.data import Dataset, DataLoader
    from PIL import Image
    from transformers import CLIPProcessor, CLIPModel
    from torch.optim import AdamW
    from sklearn.model_selection import train_test_split
    from tqdm.notebook import tqdm
    import torch.nn as nn
    import torch.nn.functional as F
    from sklearn.metrics import accuracy_score, f1_score
    from openai import OpenAI
    
    print("‚úì All libraries imported successfully")
    
except ImportError as e:
    print(f"Error importing libraries: {e}")
    print("Please make sure all required packages are installed")
    raise

# Fix the torch random seeds so runs are reproducible.
_SEED = 42
torch.manual_seed(_SEED)
if torch.cuda.is_available():
    # Also seed the current CUDA device explicitly
    torch.cuda.manual_seed(_SEED)

In [None]:
# --- 2. GENERAL CONFIGURATION AND GPU ---
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

In [None]:
# --- 5. DATA LOADING FROM SEPARATE PREPROCESSED CSV FILES ---
print("Loading separate preprocessed CSV files from team preprocessing...")

# Configuration for different shot variations
SHOT_VARIATION = "16shot"  # Options: "16shot", "64shot", "128shot", "512shot", "1024shot", "all"

print(f"üéØ Selected variation: {SHOT_VARIATION}")

# Shot configuration mapping: variation name -> samples per class.
# None means "use every available training sample".
shot_configs = {
    "16shot": 16,
    "64shot": 64,
    "128shot": 128,
    "512shot": 512,
    "1024shot": 1024,
    "all": None,  # Use all available data
}

# Fail fast on an unknown variation instead of silently training on all data.
if SHOT_VARIATION not in shot_configs:
    raise ValueError(
        f"Unknown SHOT_VARIATION {SHOT_VARIATION!r}; expected one of {list(shot_configs)}"
    )

samples_per_class = shot_configs[SHOT_VARIATION]
if samples_per_class is not None:
    print(f"üìä Using {samples_per_class} samples per class (total: {samples_per_class * 2})")
else:
    print(f"üìä Using all available training data")

# Expected preprocessed files from team (SEPARATE FILES).
# Each split maps to a list of candidate paths tried in order; there is one
# candidate per split (train: Person 1, validation: Person 2, test: Person 3 output).
_PREPROCESSED_DIR = '/kaggle/input/preprocessed-sarcasm-text'
preprocessed_files = {
    split: [f"{_PREPROCESSED_DIR}/{split}_processed_complete.csv"]
    for split in ('train', 'validation', 'test')
}

def load_preprocessed_dataset(dataset_name, file_paths):
    """Load one preprocessed split from the first usable CSV among *file_paths*.

    A candidate is usable when it exists, parses with ``pd.read_csv`` and has
    all required columns (``id``, ``text``, ``processed_text``, ``sarcasm``,
    ``image_path``). Class counts and missing ``processed_text`` counts are
    printed for the file that is accepted.

    Args:
        dataset_name: Label used in progress and error messages.
        file_paths: Candidate CSV paths, tried in order.

    Returns:
        The loaded DataFrame.

    Raises:
        FileNotFoundError: if no candidate is usable. The message lists why each
            candidate was rejected (missing file, unreadable, missing columns).
    """
    required_cols = ['id', 'text', 'processed_text', 'sarcasm', 'image_path']
    rejected = []  # human-readable reason per rejected candidate

    for file_path in file_paths:
        if not os.path.exists(file_path):
            rejected.append(f"{file_path} (does not exist)")
            continue

        print(f"‚úì Loading {dataset_name} data from: {file_path}")
        try:
            df = pd.read_csv(file_path)
        except (OSError, ValueError) as e:
            # OSError: permission/IO problems; ValueError covers pandas
            # ParserError/EmptyDataError and UnicodeDecodeError.
            print(f"   ‚ùå Error loading {file_path}: {e}")
            rejected.append(f"{file_path} (unreadable: {e})")
            continue

        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            print(f"‚ö†Ô∏è  {dataset_name} missing columns: {missing_cols}, trying next path...")
            rejected.append(f"{file_path} (missing columns: {missing_cols})")
            continue

        print(f"   üìã Loaded {len(df)} samples")
        print(f"   üìä Sarcastic: {len(df[df['sarcasm']==1])}, Non-sarcastic: {len(df[df['sarcasm']==0])}")

        # Check processed text quality
        missing_processed = df['processed_text'].isna().sum()
        if missing_processed > 0:
            print(f"   ‚ö†Ô∏è  Missing processed_text: {missing_processed}")
        else:
            print(f"   ‚úÖ All samples have processed text")

        return df

    details = "; ".join(rejected) if rejected else "no candidate paths given"
    raise FileNotFoundError(
        f"‚ùå {dataset_name} data not found in any of the expected paths: {file_paths}. "
        f"Rejected: {details}"
    )

# Load all datasets separately, in train -> validation -> test order
print("\n=== LOADING SEPARATE PREPROCESSED DATASETS ===")

# (display name, key in preprocessed_files) for each split
_SPLITS_TO_LOAD = (("Training", "train"), ("Validation", "validation"), ("Test", "test"))
train_df, val_df, test_df = [
    load_preprocessed_dataset(display_name, preprocessed_files[split_key])
    for display_name, split_key in _SPLITS_TO_LOAD
]

print(f"\n‚úÖ All datasets loaded successfully!")

# Apply shot variation to training data
print(f"\nüîß Applying {SHOT_VARIATION} configuration to training data...")


def _take_shots(pool, n_shots, label_name):
    """Sample *n_shots* rows of *pool* (seed 42), or keep all of *pool* with a warning when it is smaller."""
    if len(pool) >= n_shots:
        return pool.sample(n=n_shots, random_state=42)
    print(f"‚ö†Ô∏è  Warning: Only {len(pool)} {label_name} samples available, using all")
    return pool


if not samples_per_class or SHOT_VARIATION == "all":
    print(f"  Using all {len(train_df)} training samples")
else:
    # Split the training pool by label
    sarcastic_train = train_df.loc[train_df['sarcasm'] == 1]
    non_sarcastic_train = train_df.loc[train_df['sarcasm'] == 0]
    print(f"  Available - Sarcastic: {len(sarcastic_train)}, Non-sarcastic: {len(non_sarcastic_train)}")

    # Draw the same number of shots from each class
    selected_sarcastic = _take_shots(sarcastic_train, samples_per_class, "sarcastic")
    selected_non_sarcastic = _take_shots(non_sarcastic_train, samples_per_class, "non-sarcastic")

    # Concatenate (sarcastic rows first) and shuffle with a fixed seed
    combined = pd.concat([selected_sarcastic, selected_non_sarcastic])
    train_df = combined.sample(frac=1, random_state=42).reset_index(drop=True)

    print(f"  Selected - Sarcastic: {len(selected_sarcastic)}, Non-sarcastic: {len(selected_non_sarcastic)}")
    print(f"  Final training set: {len(train_df)} samples")

# Final summary: size and class balance of each split
print(f"\n=== FINAL DATASET SUMMARY ===")
print(f"Configuration: {SHOT_VARIATION}")
for split_label, split_frame in (("Training", train_df), ("Validation", val_df), ("Test", test_df)):
    split_labels = split_frame['sarcasm']
    print(f"{split_label} samples: {len(split_frame)}")
    print(f"  - Sarcastic: {int((split_labels == 1).sum())}")
    print(f"  - Non-sarcastic: {int((split_labels == 0).sum())}")

# Verify processed text quality across all datasets. These counts are taken
# before the cleaning cell fills gaps and are reused by the execution summary.
missing_processed_train, missing_processed_val, missing_processed_test = (
    frame['processed_text'].isna().sum() for frame in (train_df, val_df, test_df)
)

print(f"\nüîç Data Quality Check:")
print(f"Missing processed_text - Train: {missing_processed_train}, Val: {missing_processed_val}, Test: {missing_processed_test}")

total_missing = sum((missing_processed_train, missing_processed_val, missing_processed_test))
total_samples = sum(len(frame) for frame in (train_df, val_df, test_df))

if total_missing:
    print(f"‚ö†Ô∏è  {total_missing}/{total_samples} missing processed text detected")
    print("üìã This is normal for fallback cases where LLM processing failed")
else:
    print("‚úÖ All datasets have complete processed text - LLM preprocessing successful!")
    print("ü§ù Team preprocessing workflow completed successfully!")

print(f"\nüìà Ready for {SHOT_VARIATION} training with preprocessed data!")
print("=" * 60)

In [None]:
# --- 5.1 DATA VALIDATION AND CLEANING ---
print("=== DATA VALIDATION AND CLEANING ===")

def validate_and_clean_text_data(df, dataset_name):
    """Report and repair the text columns of *df* so the tokenizer never sees bad input.

    The returned copy satisfies:

    * ``text``: null or whitespace-only values become ``"No text available"``;
      everything is cast to ``str`` (NaN/None never become "nan"/"None").
    * ``processed_text``: null or whitespace-only values fall back to the cleaned
      ``text``; the column is created from ``text`` when absent.

    Args:
        df: DataFrame with at least a ``text`` column (``processed_text`` optional).
        dataset_name: Label used in the printed report.

    Returns:
        A cleaned copy of *df*; the caller's DataFrame is not modified.

    Raises:
        KeyError: if *df* has no ``text`` column.
    """
    fallback = "No text available"

    def _is_blank(series):
        # Null, or empty after stripping. astype(str) keeps this working for
        # all-NaN (float) columns, where the .str accessor would raise.
        return series.isna() | series.astype(str).str.strip().eq('')

    def _count_empty(series):
        # Non-null values that are empty or whitespace-only.
        return int((series.notna() & series.astype(str).str.strip().eq('')).sum())

    print(f"\nüîç Validating {dataset_name} dataset...")
    df = df.copy()  # never mutate the caller's frame

    original_count = len(df)
    print(f"   Original samples: {original_count}")

    # Check text column
    if 'text' in df.columns:
        print(f"   Null text: {df['text'].isna().sum()}")
        print(f"   Empty text: {_count_empty(df['text'])}")

    # Original text with unusable entries replaced. The cast happens first and
    # the blanks are then overwritten, so missing values become the fallback.
    text_blank = _is_blank(df['text'])
    clean_text = df['text'].astype(str).where(~text_blank, fallback)

    if 'processed_text' in df.columns:
        processed = df['processed_text']
        print(f"   Null processed_text: {processed.isna().sum()}")
        print(f"   Empty processed_text: {_count_empty(processed)}")

        processed_blank = _is_blank(processed)
        # Rows where neither column has usable text end up with the fallback
        still_null = int((processed_blank & text_blank).sum())
        if still_null > 0:
            print(f"   ‚ö†Ô∏è  Still {still_null} null values after filling, using fallback")
        df['processed_text'] = processed.astype(str).where(~processed_blank, clean_text)
    else:
        # Create processed_text from text if not exists
        df['processed_text'] = clean_text
        print(f"   Created processed_text from text column")

    df['text'] = clean_text

    final_count = len(df)
    print(f"   Final samples: {final_count}")
    print(f"   ‚úÖ All text values validated and cleaned")

    return df

# Validate and clean all datasets (each call returns the cleaned frame)
print("üßπ Cleaning datasets to prevent tokenizer errors...")

train_df, val_df, test_df = (
    validate_and_clean_text_data(frame, label)
    for frame, label in ((train_df, "Training"), (val_df, "Validation"), (test_df, "Test"))
)

# Re-check that no null or empty processed_text values remain
print("\nüîç Final validation check:")
for split_name, frame in (("Train", train_df), ("Val", val_df), ("Test", test_df)):
    processed = frame['processed_text']
    n_null = processed.isna().sum()
    n_empty = processed.str.strip().eq('').sum()
    print(f"   {split_name}: {n_null} null, {n_empty} empty processed_text values")

print("‚úÖ Data validation and cleaning completed!")
print("=" * 60)

In [None]:
# --- 6. PYTORCH DATASET CLASS (KAGGLE OPTIMIZED) ---
class SarcasmDataset(Dataset):
    """Map-style dataset yielding CLIP-ready (text, image, label) samples.

    Per row, the text used is the first non-blank of ``processed_text`` (when
    that column exists) and ``text``, stripped; otherwise ``"No text available"``.
    Images that cannot be opened are replaced by a white 224x224 placeholder.

    Args:
        dataframe: Frame with ``text``, ``image_path`` and ``sarcasm`` columns;
            ``processed_text`` is preferred when present.
        processor: CLIP processor called with ``text=``, ``images=``,
            ``return_tensors``, ``padding``, ``max_length`` and ``truncation``.

    Each item is a dict with ``input_ids``, ``attention_mask`` and
    ``pixel_values`` (leading batch dimension squeezed) plus a ``torch.long``
    scalar ``labels`` tensor.
    """

    # Placeholder used whenever a row has no usable text
    FALLBACK_TEXT = "No text available"
    # Size of the blank image substituted for unreadable image files
    FALLBACK_IMAGE_SIZE = (224, 224)

    def __init__(self, dataframe, processor):
        self.dataframe = dataframe.copy()
        self.processor = processor

        if 'processed_text' in dataframe.columns:
            # Prefer processed text, then the original text, then the placeholder
            self.texts = [
                self._first_usable_text(proc_text, orig_text)
                for proc_text, orig_text in zip(dataframe['processed_text'], dataframe['text'])
            ]
            print(f"‚úì Using processed_text for {len(self.texts)} samples")
        else:
            self.texts = [self._first_usable_text(orig_text) for orig_text in dataframe['text']]
            print(f"‚ö† Using original text for {len(self.texts)} samples")

        self.image_paths = dataframe['image_path'].tolist()
        self.labels = dataframe['sarcasm'].tolist()

    @classmethod
    def _first_usable_text(cls, *candidates):
        """Return the first non-null, non-blank candidate (stripped, as str), else the fallback."""
        for candidate in candidates:
            if candidate is None:
                continue
            if not isinstance(candidate, str):
                if pd.isna(candidate):
                    continue
                # Keep non-string values such as numeric-only tweets
                candidate = str(candidate)
            stripped = candidate.strip()
            if stripped:
                return stripped
        return cls.FALLBACK_TEXT

    def _load_image(self, image_path):
        """Open *image_path* as RGB, or return a white placeholder if that fails."""
        try:
            with Image.open(image_path) as img:
                # convert() returns a new, fully loaded image, so the file can close here
                return img.convert("RGB")
        except Exception:
            # Best effort: missing/corrupt files or non-path values (e.g. NaN) fall
            # back to a blank image, without swallowing KeyboardInterrupt/SystemExit.
            return Image.new('RGB', self.FALLBACK_IMAGE_SIZE, color='white')

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Texts are cleaned in __init__; re-checked in case self.texts was replaced
        text = self._first_usable_text(self.texts[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        image = self._load_image(self.image_paths[idx])

        # Process with CLIP: fixed 77-token text length
        inputs = self.processor(
            text=[text],
            images=image,
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            truncation=True,
        )

        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'labels': label,
        }

In [None]:
# --- 7. MODEL CLASS (FINAL, FIXED VERSION) ---
class CueLearningSarcasmModel(nn.Module):
    """Binary sarcasm classifier built on a frozen CLIP model plus learnable prompts.

    All CLIP weights are frozen in ``__init__``; the trainable parameters are:

    * ``text_prompts`` (12 x ``text_embed_dim``): prepended to every input's
      token embeddings before the CLIP text encoder.
    * ``image_prompts`` (20 x ``vision_embed_dim``). NOTE(review): created but
      never used in ``forward``, so it receives no gradient — confirm whether
      visual prompting was meant to be wired in.
    * ``sarcasm_prompt_embeds`` / ``non_sarcasm_prompt_embeds``: 8 random rows
      followed by the averaged token embeddings of two short class phrases. Each
      sequence is run through the CLIP text encoder to produce a class anchor.

    ``forward`` returns logits of shape (batch, 2): column 0 = non-sarcastic,
    column 1 = sarcastic. They are cosine similarities between the fused
    text+image feature and each class anchor, scaled by ``clip.logit_scale.exp()``.

    Args:
        clip_model_name: Hugging Face id used for both CLIPModel and CLIPProcessor.
    """
    def __init__(self, clip_model_name="openai/clip-vit-large-patch14"):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        # Freeze every CLIP weight; only the prompt parameters below are trained.
        for param in self.clip.parameters():
            param.requires_grad = False
        text_prompt_length, image_prompt_length, sarcasm_prompt_length = 12, 20, 8
        d_model = self.clip.text_embed_dim
        self.text_prompts = nn.Parameter(torch.randn(text_prompt_length, d_model))
        self.image_prompts = nn.Parameter(torch.randn(image_prompt_length, self.clip.vision_embed_dim))
        sarcasm_texts, non_sarcasm_texts = ["a sarcastic tweet", "this is sarcasm"], ["a normal tweet", "this is not sarcasm"]
        sarcasm_tokens = self.processor(text=sarcasm_texts, return_tensors="pt", padding=True, truncation=True)
        non_sarcasm_tokens = self.processor(text=non_sarcasm_texts, return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            # Average the token embeddings over the two phrases (dim=0), including
            # any special/padding tokens the processor adds -> (padded_len, d_model)
            sarcasm_word_embeds = self.clip.text_model.embeddings.token_embedding(sarcasm_tokens.input_ids).mean(dim=0)
            non_sarcasm_word_embeds = self.clip.text_model.embeddings.token_embedding(non_sarcasm_tokens.input_ids).mean(dim=0)
        # Class anchor sequences: learnable random prefix followed by the phrase embeddings
        self.sarcasm_prompt_embeds = nn.Parameter(torch.cat([torch.randn(sarcasm_prompt_length, d_model), sarcasm_word_embeds], dim=0))
        self.non_sarcasm_prompt_embeds = nn.Parameter(torch.cat([torch.randn(sarcasm_prompt_length, d_model), non_sarcasm_word_embeds], dim=0))

    def _prepare_4d_attention_mask(self, mask_2d, dtype, device):
        """Turn a (batch, seq) 1/0 keep-mask into an additive (batch, 1, 1, seq) mask.

        Kept positions become 0 and masked positions become ``torch.finfo(dtype).min``.
        ``device`` is unused.
        """
        mask_4d = mask_2d.to(dtype).unsqueeze(1).unsqueeze(1)
        inverted_mask = 1.0 - mask_4d
        return inverted_mask * torch.finfo(dtype).min

    def _prepare_4d_causal_attention_mask(self, shape, dtype, device):
        """Build an additive causal mask of shape (bsz, 1, seq_len, seq_len).

        Entries above the diagonal (future keys) are ``torch.finfo(dtype).min``;
        the rest are 0.
        """
        bsz, seq_len = shape[0], shape[1]
        causal_mask = torch.empty((bsz, seq_len, seq_len), dtype=dtype, device=device)
        causal_mask.fill_(torch.finfo(dtype).min)
        causal_mask.triu_(1)
        return causal_mask.unsqueeze(1)

    def forward(self, input_ids, attention_mask, pixel_values):
        """Return (batch, 2) logits ordered [non-sarcastic, sarcastic]."""
        # Token embeddings only. NOTE(review): no position embeddings are added
        # here; confirm this is intended (the prompted sequence is longer than
        # the 77-token inputs).
        inputs_embeds = self.clip.text_model.embeddings.token_embedding(input_ids)
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        # Index 1 of the vision-model output, presumably the pooled image embedding — confirm
        image_embeds = vision_outputs[1]
        # Prepend the shared learnable text prompts to every sample's token embeddings
        prompted_text_embeds = torch.cat([self.text_prompts.unsqueeze(0).expand(inputs_embeds.shape[0], -1, -1), inputs_embeds], dim=1)
        # Prompt positions are always attended to
        prompt_attention_mask = torch.ones(prompted_text_embeds.shape[0], self.text_prompts.shape[0], dtype=attention_mask.dtype, device=input_ids.device)
        extended_attention_mask_2d = torch.cat([prompt_attention_mask, attention_mask], dim=1)
        
        # Additive padding + causal mask. Where both apply, two finfo.min values
        # may overflow to -inf; every query row can still attend to key 0 (an
        # always-kept prompt position), so no softmax row is fully masked.
        padding_mask_4d = self._prepare_4d_attention_mask(extended_attention_mask_2d, prompted_text_embeds.dtype, input_ids.device)
        causal_mask_4d = self._prepare_4d_causal_attention_mask(prompted_text_embeds.shape, prompted_text_embeds.dtype, input_ids.device)
        final_attention_mask = padding_mask_4d + causal_mask_4d
        
        text_encoder_outputs = self.clip.text_model.encoder(inputs_embeds=prompted_text_embeds, attention_mask=final_attention_mask)
        last_hidden_state = self.clip.text_model.final_layer_norm(text_encoder_outputs[0])
        
        # Pick each sequence's EOS hidden state, shifted by the prompt length.
        # argmax over ids assumes the EOS token has the largest id — TODO confirm for this tokenizer.
        shifted_eos_pos = input_ids.argmax(dim=-1) + self.text_prompts.shape[0]
        text_features = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=input_ids.device), shifted_eos_pos]
        
        # Fuse modalities: average the projected text and image features, then
        # L2-normalize (they are not normalized individually before averaging).
        image_features = self.clip.visual_projection(image_embeds)
        text_features_proj = self.clip.text_projection(text_features)
        multi_modal_features = F.normalize((text_features_proj + image_features) / 2.0, p=2, dim=-1)

        def get_prompt_features(prompt_embeds):
            # Encode one (L, d_model) prompt sequence with causal masking and
            # project the hidden state at its last position. Re-run on every forward call.
            prompt_embeds_b1 = prompt_embeds.unsqueeze(0)
            causal_mask = self._prepare_4d_causal_attention_mask(prompt_embeds_b1.shape, prompt_embeds_b1.dtype, prompt_embeds_b1.device)
            encoder_out = self.clip.text_model.encoder(inputs_embeds=prompt_embeds_b1, attention_mask=causal_mask)
            features = self.clip.text_model.final_layer_norm(encoder_out[0])[:, -1, :]
            return self.clip.text_projection(features)

        sarcasm_prompt_features = F.normalize(get_prompt_features(self.sarcasm_prompt_embeds), p=2, dim=-1)
        non_sarcasm_prompt_features = F.normalize(get_prompt_features(self.non_sarcasm_prompt_embeds), p=2, dim=-1)
        
        sim_sarcasm = F.cosine_similarity(multi_modal_features, sarcasm_prompt_features.squeeze(0))
        sim_non_sarcasm = F.cosine_similarity(multi_modal_features, non_sarcasm_prompt_features.squeeze(0))
        # Column 0 = non-sarcastic, column 1 = sarcastic (matches labels sarcasm==0/1)
        logits = torch.stack([sim_non_sarcasm, sim_sarcasm], dim=1) * self.clip.logit_scale.exp()
        return logits

In [None]:
# --- 8. TRAINING & EVALUATION (KAGGLE OPTIMIZED) ---
print("=== TRAINING SETUP ===")

# Build the prompt-learning model on the selected device and reuse its CLIP processor
model = CueLearningSarcasmModel().to(device)
processor = model.processor

# Wrap the cleaned train/validation frames as datasets
train_dataset, val_dataset = (SarcasmDataset(frame, processor) for frame in (train_df, val_df))

# Small shuffled training batches; larger, ordered validation batches.
# num_workers=0 keeps data loading in the notebook process.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8, num_workers=0)

for loader_label, loader in (("Training", train_loader), ("Validation", val_loader)):
    print(f"{loader_label} batches: {len(loader)}")

# Setup optimizer over the trainable parameters only (CLIP itself is frozen in the model)
learnable_params = [p for p in model.parameters() if p.requires_grad]
print(f"Learnable parameters: {sum(p.numel() for p in learnable_params):,}")

optimizer = AdamW(learnable_params, lr=2e-3, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# Training parameters
num_epochs = 1  # Reduced for faster Kaggle execution
best_val_acc = 0.0
model_save_path = "/kaggle/working/models/best_model.pth"

# Banner reflects the configured shot variation (e.g. "16shot" -> "16-SHOT")
shot_label = SHOT_VARIATION.replace("shot", "-shot").upper()
print(f"\n=== STARTING {shot_label} TRAINING ({num_epochs} epochs) ===")

# Training loop: one training pass and one validation pass per epoch. The
# checkpoint with the best validation accuracy seen so far is written to
# model_save_path.
for epoch in range(num_epochs):
    # Training phase
    model.train()
    total_loss = 0
    train_correct = 0
    train_total = 0
    
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device) 
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass: (batch, 2) logits
        outputs = model(input_ids, attention_mask, pixel_values)
        loss = criterion(outputs, labels)
        
        # Backward pass; gradients are clipped to a total norm of 1.0 before the step
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(learnable_params, max_norm=1.0)
        optimizer.step()
        
        # Statistics: running loss sum and training-accuracy counts
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    # Calculate training metrics
    # NOTE(review): an empty train_loader would make these divisions raise ZeroDivisionError.
    avg_loss = total_loss / len(train_loader)
    train_acc = 100 * train_correct / train_total

    # Validation phase: argmax predictions over the whole validation set
    model.eval()
    val_preds, val_labels = [], []
    
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            pixel_values = batch['pixel_values'].to(device) 
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask, pixel_values)
            preds = torch.argmax(outputs, dim=1)
            
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    # Calculate validation metrics (accuracy in percent; binary F1 for class 1 = sarcastic)
    val_acc = accuracy_score(val_labels, val_preds) * 100
    val_f1 = f1_score(val_labels, val_preds)
    
    # Print results
    print(f"Epoch {epoch+1:2d}: Loss={avg_loss:.4f} | Train Acc={train_acc:.1f}% | Val Acc={val_acc:.1f}% | F1={val_f1:.3f}")
    
    # Save best model. The comparison is strict against an initial 0.0, so a run
    # whose validation accuracy never exceeds 0 writes no checkpoint.
    # NOTE(review): model.state_dict() includes the frozen CLIP weights, so each
    # checkpoint stores the full model, not just the prompts.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Use safe saving format for PyTorch 2.6+ compatibility
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_acc': float(best_val_acc),  # Convert to native Python float
        }, model_save_path)
        print(f"  ‚úì New best model saved! Val Acc: {best_val_acc:.1f}%")

print(f"\n=== TRAINING COMPLETED ===")
print(f"Best validation accuracy: {best_val_acc:.1f}%")

# Final test evaluation with the best checkpoint written during training
if len(test_df) > 0:
    print(f"\n=== FINAL TEST EVALUATION ===")

    # Load best model with PyTorch 2.6+ compatibility. weights_only=True (safe
    # unpickling) is tried first. The weights_only=False fallback performs full
    # unpickling, which is acceptable only because this notebook wrote the file.
    try:
        checkpoint = torch.load(model_save_path, map_location=device, weights_only=True)
    except Exception as safe_load_error:
        print(f"‚ö†Ô∏è  Safe loading failed: {safe_load_error}")
        print("üîÑ Attempting legacy loading (weights_only=False)...")
        try:
            checkpoint = torch.load(model_save_path, map_location=device, weights_only=False)
            print("‚úÖ Legacy loading successful")
        except Exception as legacy_error:
            print(f"‚ùå Both loading methods failed: {legacy_error}")
            print("üí° Skipping test evaluation due to model loading issues")
            checkpoint = None

    if checkpoint is not None:
        # Restore the best weights exactly once
        model.load_state_dict(checkpoint['model_state_dict'])

        test_dataset = SarcasmDataset(test_df, processor)
        # shuffle=False keeps predictions in test_df row order
        test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)

        model.eval()
        test_preds, test_labels = [], []

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Testing"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                pixel_values = batch['pixel_values'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(input_ids, attention_mask, pixel_values)
                preds = torch.argmax(outputs, dim=1)

                test_preds.extend(preds.cpu().numpy())
                test_labels.extend(labels.cpu().numpy())

        test_acc = accuracy_score(test_labels, test_preds) * 100
        test_f1 = f1_score(test_labels, test_preds)

        print(f"FINAL RESULTS:")
        print(f"  Test Accuracy: {test_acc:.1f}%")
        print(f"  Test F1-Score: {test_f1:.3f}")
        print(f"  Best Val Accuracy: {best_val_acc:.1f}%")
    else:
        print("‚ö†Ô∏è  Test evaluation skipped due to model loading issues")
        print("üí° Training completed but unable to load saved model for testing")
else:
    print("No test data available for final evaluation")

# Clean up GPU memory (only meaningful when CUDA is in use)
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# --- GENERATE ADDITIONAL OUTPUT FILES ---
print("\n=== GENERATING OUTPUT FILES ===")

import json  # repeated so this cell does not depend on the import cell

# Hyperparameters are read back from the live training objects so the saved
# config always matches what was actually run.
_first_param_group = optimizer.param_groups[0]
run_hyperparameters = {
    'learning_rate': _first_param_group['lr'],
    'weight_decay': _first_param_group['weight_decay'],
    'num_epochs': num_epochs,
    'batch_size': train_loader.batch_size,
    'val_batch_size': val_loader.batch_size,
}

# 1. Save training metrics and results
results_summary = {
    'experiment_config': {
        'shot_variation': SHOT_VARIATION,
        'training_samples': len(train_df),
        'validation_samples': len(val_df),
        'test_samples': len(test_df),
        'device': str(device),
        'pytorch_version': torch.__version__
    },
    'training_results': {
        'best_validation_accuracy': float(best_val_acc) if 'best_val_acc' in locals() else None,
        'final_test_accuracy': float(test_acc) if 'test_acc' in locals() else None,
        'final_test_f1': float(test_f1) if 'test_f1' in locals() else None
    },
    'model_info': {
        'architecture': 'CLIP + Cue Learning',
        'base_model': 'openai/clip-vit-large-patch14',
        'text_preprocessing': 'LLM-enhanced (mistralai/mistral-nemo)'
    }
}

# Save results as JSON
results_file = f"/kaggle/working/training_results_{SHOT_VARIATION}.json"
with open(results_file, 'w', encoding='utf-8') as f:
    json.dump(results_summary, f, indent=2)
print(f"‚úÖ Results saved: {results_file}")

# 2. Save model configuration and hyperparameters
config_file = f"/kaggle/working/model_config_{SHOT_VARIATION}.json"
model_config = {
    'shot_variation': SHOT_VARIATION,
    'model_architecture': 'CueLearningSarcasmModel',
    'base_clip_model': 'openai/clip-vit-large-patch14',
    'hyperparameters': run_hyperparameters,
    'prompt_lengths': {
        'text_prompts': int(model.text_prompts.shape[0]),
        'image_prompts': int(model.image_prompts.shape[0]),
        'sarcasm_prompts': 8,  # prefix length hard-coded in CueLearningSarcasmModel.__init__
    },
}

with open(config_file, 'w', encoding='utf-8') as f:
    json.dump(model_config, f, indent=2)
print(f"‚úÖ Config saved: {config_file}")

# 3. Save predictions if test evaluation was performed
if 'test_preds' in locals() and 'test_labels' in locals():
    predictions = {
        'true_label': test_labels,
        'predicted_label': test_preds,
        'correct': [int(true == pred) for true, pred in zip(test_labels, test_preds)],
    }
    # The test loader does not shuffle, so predictions follow test_df row order;
    # attach ids so each prediction can be traced back to its sample.
    if 'id' in test_df.columns and len(test_df) == len(test_preds):
        predictions['id'] = test_df['id'].tolist()
    predictions_df = pd.DataFrame(predictions)
    predictions_file = f"/kaggle/working/test_predictions_{SHOT_VARIATION}.csv"
    predictions_df.to_csv(predictions_file, index=False)
    print(f"‚úÖ Predictions saved: {predictions_file}")

# 4. Save sample processed data for verification
if len(train_df) > 0:
    sample_data = train_df.head(10)[['id', 'text', 'processed_text', 'sarcasm']].copy()
    sample_file = f"/kaggle/working/sample_data_{SHOT_VARIATION}.csv"
    sample_data.to_csv(sample_file, index=False)
    print(f"‚úÖ Sample data saved: {sample_file}")

# 5. Create experiment summary text file
summary_file = f"/kaggle/working/experiment_summary_{SHOT_VARIATION}.txt"
with open(summary_file, 'w', encoding='utf-8') as f:
    f.write(f"SARCASM DETECTION EXPERIMENT SUMMARY\n")
    f.write(f"===================================\n\n")
    f.write(f"Configuration: {SHOT_VARIATION}\n")
    f.write(f"Training samples: {len(train_df)}\n")
    f.write(f"Validation samples: {len(val_df)}\n")
    f.write(f"Test samples: {len(test_df)}\n")
    f.write(f"Device: {device}\n")
    f.write(f"PyTorch version: {torch.__version__}\n\n")

    if 'best_val_acc' in locals():
        f.write(f"Best Validation Accuracy: {best_val_acc:.2f}%\n")
    if 'test_acc' in locals():
        f.write(f"Final Test Accuracy: {test_acc:.2f}%\n")
        f.write(f"Final Test F1-Score: {test_f1:.4f}\n")

    f.write(f"\nModel Architecture: CLIP + Cue Learning\n")
    f.write(f"Base Model: openai/clip-vit-large-patch14\n")
    f.write(f"Text Preprocessing: LLM-enhanced (mistralai/mistral-nemo)\n")

print(f"‚úÖ Summary saved: {summary_file}")

print("‚úì Training completed successfully!")

In [None]:
# --- 9. KAGGLE EXECUTION SUMMARY ---
print("=" * 70)
print("üéØ MULTIMODAL SARCASM DETECTION - EXECUTION SUMMARY")
print("=" * 70)

print(f"üìä Dataset Information:")
print(f"   ‚Ä¢ Training samples: {len(train_df)} ({SHOT_VARIATION})")
print(f"   ‚Ä¢ Validation samples: {len(val_df) if 'val_df' in locals() else 'N/A'}")
print(f"   ‚Ä¢ Test samples: {len(test_df) if 'test_df' in locals() else 'N/A'}")
print(f"   ‚Ä¢ Data source: Separate team-preprocessed CSV files")

# Show shot configuration details. The total is the real size of the sampled
# training set, which is below 2 * samples_per_class when a class had fewer
# rows. A variation mapping to None used all training data.
samples_per_class = shot_configs.get(SHOT_VARIATION)
if samples_per_class is not None:
    total_samples = len(train_df)
    print(f"   ‚Ä¢ Shot configuration: {samples_per_class} per class ‚Üí {total_samples} total training")
else:
    print(f"   ‚Ä¢ Shot configuration: All available training data")

print(f"\nü§ñ Model Configuration:")
print(f"   ‚Ä¢ Architecture: CLIP + Cue Learning (Multimodal)")
print(f"   ‚Ä¢ Device: {device}")
print(f"   ‚Ä¢ Text preprocessing: LLM-enhanced (mistralai/mistral-nemo)")
print(f"   ‚Ä¢ Image processing: CLIP visual encoder")

print(f"\nüéØ Shot Learning Experiment:")
available_shots = ["16shot", "64shot", "128shot", "512shot", "1024shot", "all"]
print(f"   ‚Ä¢ Current configuration: {SHOT_VARIATION}")
print(f"   ‚Ä¢ Available configurations: {', '.join(available_shots)}")
print(f"   ‚Ä¢ Purpose: Compare performance across different data scales")
print(f"   ‚Ä¢ Training data: Subsampled from preprocessed train CSV")
print(f"   ‚Ä¢ Val/Test data: Full preprocessed datasets for fair evaluation")

print(f"\nüìà Training Results:")
if 'best_val_acc' in locals():
    print(f"   ‚Ä¢ Best Validation Accuracy: {best_val_acc:.1f}%")
if 'test_acc' in locals():
    print(f"   ‚Ä¢ Final Test Accuracy: {test_acc:.1f}%")
    print(f"   ‚Ä¢ Final Test F1-Score: {test_f1:.3f}")

# Data quality summary
if 'missing_processed_train' in locals():
    total_missing = missing_processed_train + missing_processed_val + missing_processed_test
    total_samples_all = len(train_df) + len(val_df) + len(test_df)
    quality_pct = ((total_samples_all - total_missing) / total_samples_all) * 100
    print(f"\nüîç Data Quality:")
    print(f"   ‚Ä¢ Preprocessed text quality: {quality_pct:.1f}% complete")
    print(f"   ‚Ä¢ Missing processed entries: {total_missing}/{total_samples_all}")
    print(f"   ‚Ä¢ Train processed: {len(train_df) - missing_processed_train}/{len(train_df)}")
    print(f"   ‚Ä¢ Val processed: {len(val_df) - missing_processed_val}/{len(val_df)}")
    print(f"   ‚Ä¢ Test processed: {len(test_df) - missing_processed_test}/{len(test_df)}")

print(f"\nüíæ Output Files:")
kaggle_working = "/kaggle/working"
output_files = []

# Enhanced file listing with categorization
for root, dirs, files in os.walk(kaggle_working):
    for file in files:
        if file.endswith(('.pth', '.csv', '.txt', '.json')):
            filepath = os.path.join(root, file)
            size_kb = os.path.getsize(filepath) / 1024
            rel_path = os.path.relpath(filepath, kaggle_working)
            
            # Categorize files
            if file.endswith('.pth'):
                file_type = "ü§ñ Model"
            elif file.endswith('.json'):
                file_type = "üìä Config/Results"
            elif file.endswith('.csv'):
                file_type = "üìã Data"
            elif file.endswith('.txt'):
                file_type = "üìù Summary"
            else:
                file_type = "üìÑ Other"
                
            output_files.append((file_type, rel_path, size_kb))

if output_files:
    # Sort by file type for better organization
    output_files.sort(key=lambda x: x[0])
    
    print("   Expected output files for download:")
    for file_type, rel_path, size_kb in output_files:
        print(f"   {file_type}: {rel_path} ({size_kb:.1f} KB)")
        
    print(f"\n   üì¶ Total files generated: {len(output_files)}")
    print(f"   üìÅ All files available in: /kaggle/working/")
else:
    print("   ‚Ä¢ No output files found")

# Additional file expectations
print(f"\nüéØ Expected Output Files for {SHOT_VARIATION}:")
expected_files = [
    f"ü§ñ best_model.pth - Trained model weights",
    f"üìä training_results_{SHOT_VARIATION}.json - Experiment results",
    f"üìä model_config_{SHOT_VARIATION}.json - Model configuration", 
    f"üìã test_predictions_{SHOT_VARIATION}.csv - Test predictions",
    f"üìã sample_data_{SHOT_VARIATION}.csv - Sample processed data",
    f"üìù experiment_summary_{SHOT_VARIATION}.txt - Human-readable summary"
]

for expected in expected_files:
    print(f"   ‚Ä¢ {expected}")

print(f"\nüí° Download Instructions:")
print(f"   1. Go to Output tab in Kaggle")
print(f"   2. Download all files from /kaggle/working/")
print(f"   3. Use best_model.pth for inference")
print(f"   4. Share results JSON with team for comparison")

print(f"\nüìÅ Input Files Used:")
print(f"   ‚Ä¢ Training: {len(train_df)} samples from team preprocessing")
print(f"   ‚Ä¢ Validation: {len(val_df)} samples from team preprocessing")  
print(f"   ‚Ä¢ Test: {len(test_df)} samples from team preprocessing")
print(f"   ‚Ä¢ File structure: 3 separate CSV files (no merging)")

print(f"\nüöÄ Next Steps:")
print(f"   ‚Ä¢ Try different SHOT_VARIATION values to compare performance")
print(f"   ‚Ä¢ Upload model results to team for analysis")
print(f"   ‚Ä¢ Consider ensemble methods for final submission")
print(f"   ‚Ä¢ Each team member can focus on their specific preprocessing task")

print(f"\nüéØ Workflow Benefits:")
print(f"   ‚Ä¢ No merge required: Direct use of separate CSV files")
print(f"   ‚Ä¢ Parallel preprocessing: Each person works independently")
print(f"   ‚Ä¢ Easy debugging: Issues isolated per dataset")
print(f"   ‚Ä¢ Flexible deployment: Can use any combination of processed files")

print(f"\n‚úÖ Experiment completed successfully!")
print(f"Shot: {SHOT_VARIATION} | Device: {device} | Separate file workflow")

# --- PYTORCH VERSION & COMPATIBILITY CHECK ---
# Import before first use so this section does not rely on an earlier cell.
import torch


def torch_major_minor(version):
    """Return ``(major, minor)`` as ints parsed from a version like '2.6.0+cu124'."""
    major, minor = str(version).split('.')[:2]
    return int(major), int(minor)


print("=== PYTORCH VERSION & COMPATIBILITY CHECK ===")
print(f"PyTorch version: {torch.__version__}")

# Check PyTorch version for compatibility.
pytorch_version = torch.__version__
major_version, minor_version = torch_major_minor(pytorch_version)

# Compare as a tuple: testing `major >= 2 and minor >= 6` separately would
# wrongly treat e.g. 3.0 as older than 2.6.
if (major_version, minor_version) >= (2, 6):
    print(f"‚úÖ PyTorch {pytorch_version} detected (2.6+)")
    print("üîß Using safe model loading by default")
    print("üí° If loading fails, will fallback to legacy mode")
else:
    print(f"‚úÖ PyTorch {pytorch_version} detected (< 2.6)")
    print("üîß Using legacy model loading")

# Check if we have numpy compatibility issues
try:
    import numpy as np
    print(f"NumPy version: {np.__version__}")

    # Test numpy scalar compatibility
    test_scalar = np.float64(1.0)
    print(f"‚úÖ NumPy scalar test passed: {type(test_scalar)}")
except Exception as e:
    print(f"‚ö†Ô∏è  NumPy compatibility issue: {e}")

print("=" * 50)

# --- MODEL SAVING HELPER FUNCTIONS ---
def save_model_safe(model, optimizer, epoch, val_acc, filepath):
    """Save a training checkpoint built from plain Python values plus state dicts.

    Scalars are coerced to native ``int``/``float``/``str`` so the checkpoint
    does not pickle NumPy scalars, tensors-as-metrics, or ``str`` subclasses,
    which keeps it friendly to ``torch.load(..., weights_only=True)``.

    Args:
        model: Module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        epoch: Epoch number, stored as a native ``int``.
        val_acc: Validation accuracy (anything ``float()`` accepts, e.g. a
            NumPy scalar or 0-d tensor), stored as a native ``float``.
        filepath: Destination path passed to ``torch.save``.

    Returns:
        True on success; False if building or writing the checkpoint raised
        (the error is printed, not re-raised).
    """
    try:
        checkpoint = {
            'epoch': int(epoch),
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_acc': float(val_acc),
            # torch.__version__ is a TorchVersion (a str subclass); str() turns
            # it into a plain str so that class is not recorded in the pickle.
            'pytorch_version': str(torch.__version__),
        }
        torch.save(checkpoint, filepath)
    except Exception as e:  # deliberate best-effort: report and signal failure
        print(f"‚ùå Error saving model: {e}")
        return False
    else:
        print(f"‚úÖ Model saved successfully: {filepath}")
        return True

def load_model_safe(filepath, device='cpu', *, allow_unsafe_fallback=True):
    """Load a checkpoint, preferring PyTorch's restricted ``weights_only`` unpickler.

    First tries ``torch.load(..., weights_only=True)``. If the file cannot be
    read at all (an ``OSError`` such as a missing file), returns None right
    away, because a legacy load would fail the same way. For any other failure,
    and only if ``allow_unsafe_fallback`` is true, retries with
    ``weights_only=False``.

    SECURITY: ``weights_only=False`` uses full pickle deserialization, which
    can execute arbitrary code embedded in the file. Keep the fallback only
    for checkpoints you produced yourself. Pass ``allow_unsafe_fallback=False``
    for files from untrusted sources.

    Args:
        filepath: Path to the checkpoint file.
        device: ``map_location`` forwarded to ``torch.load``.
        allow_unsafe_fallback: Keyword-only. Whether to retry with full
            unpickling after the restricted load fails. Defaults to True,
            matching the previous behaviour.

    Returns:
        The loaded object, or None if loading failed (errors are printed).
    """
    try:
        checkpoint = torch.load(filepath, map_location=device, weights_only=True)
    except OSError as io_error:
        # Unreadable/missing file: retrying in legacy mode cannot help.
        print(f"‚ùå Could not read checkpoint: {io_error}")
        return None
    except Exception as safe_error:
        print(f"‚ö†Ô∏è  Safe loading failed: {safe_error}")
    else:
        print("‚úÖ Safe model loading successful")
        return checkpoint

    if not allow_unsafe_fallback:
        print("‚ùå Unsafe fallback disabled; not retrying with weights_only=False")
        return None

    try:
        print("üîÑ Attempting legacy loading...")
        checkpoint = torch.load(filepath, map_location=device, weights_only=False)
    except Exception as legacy_error:
        print(f"‚ùå Legacy loading also failed: {legacy_error}")
        return None
    print("‚úÖ Legacy model loading successful")
    return checkpoint

print("‚úÖ Model saving/loading helper functions defined")