In [None]:
!pip install -q transformers==4.41.0 torch pillow requests tqdm pandas \
  scikit-learn scikit-image accelerate bitsandbytes peft datasets \
  opencv-python imageio imageio-ffmpeg einops timm
# Core libraries used throughout the notebook.
import torch
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
import warnings
# Silences every Python warning for the rest of the kernel session.
warnings.filterwarnings('ignore')
# Report the installed PyTorch version and the GPU name ('CPU' when CUDA is unavailable).
print(f"‚úÖ PyTorch: {torch.__version__}")
print(f"‚úÖ GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

‚úÖ PyTorch: 2.8.0+cu126
‚úÖ GPU: Tesla T4


In [None]:
# Load the raw tweet CSV from Google Drive and create the train/test split.
from google.colab import drive
import os
import pandas as pd
from sklearn.model_selection import train_test_split
# Drive must be mounted before any path below is read or created.
drive.mount('/content/drive')
# Input CSV and output directory (Colab / Google Drive specific absolute paths).
DATA_CSV = "/content/drive/MyDrive/adobe/train.csv"
OUTPUT_DIR = "/content/drive/MyDrive/adobe/output"

os.makedirs(OUTPUT_DIR, exist_ok=True)
df = pd.read_csv(DATA_CSV)
# 80/20 split; random_state is fixed so the split is reproducible across runs.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
print(f"‚úÖ Train data: {len(train_df)} samples")
print(f"‚úÖ Test data: {len(test_df)} samples")
print(f"\nColumns: {train_df.columns.tolist()}")
# Show one raw 'media' value; this string is what parse_media_string() consumes later.
print(f"\nSample media:\n{train_df['media'].iloc[0]}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
‚úÖ Train data: 13864 samples
‚úÖ Test data: 3467 samples

Columns: ['id', 'date', 'likes', 'content', 'username', 'media', 'inferred company']

Sample media:
[Photo(previewUrl='https://pbs.twimg.com/media/D1LiAVeXQAAjpZw?format=png&name=small', fullUrl='https://pbs.twimg.com/media/D1LiAVeXQAAjpZw?format=png&name=large')]


In [None]:
!pip install flash_attn



In [None]:
# Load the Florence-2 model and its processor; both are used as notebook globals
# (florence_model, florence_processor, device) by the captioning code below.
from transformers import AutoProcessor, AutoModelForCausalLM
import torch

print("="*70)
print("üöÄ LOADING FLORENCE-2 (GPU-OPTIMIZED)")
print("="*70)

# Use the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "microsoft/Florence-2-large"

print(f"\nLoading {model_id} on {device}...")
# Weights are loaded in float16 on CUDA and float32 on CPU, then moved to `device`.
# trust_remote_code=True allows the model code shipped with the checkpoint to run.
florence_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    trust_remote_code=True
).to(device)
florence_processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True
)

print("‚úÖ Florence-2 and processor loaded")
# Prints the allocated CUDA memory on GPU; on CPU this prints an empty line.
print(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB" if device == "cuda" else "")
def generate_caption_gpu_optimized(image, task="<MORE_DETAILED_CAPTION>"):
    """
    Generate a Florence-2 caption for a single image.

    Args:
        image: PIL image to caption; its ``width`` and ``height`` are passed
            to the processor's post-processing step.
        task: Florence-2 task prompt. It is also the key looked up in the
            post-processed result.

    Returns:
        The value stored under ``task`` in the post-processed output, or ""
        when anything fails. Failures are printed rather than raised so a
        single bad image cannot stop a long run.
    """
    try:
        inputs = florence_processor(
            text=task,
            images=image,
            return_tensors="pt"
        )
        # Only the image tensor is cast to the model's floating dtype. Token
        # ids must stay integer: casting them to float16 (as the previous
        # version did for every tensor) makes the embedding lookup fail, and
        # that failure was silently turned into an empty caption.
        model_dtype = next(florence_model.parameters()).dtype
        input_ids = inputs["input_ids"].to(device)
        pixel_values = inputs["pixel_values"].to(device, model_dtype)

        with torch.no_grad():
            generated_ids = florence_model.generate(
                input_ids=input_ids,
                pixel_values=pixel_values,
                max_new_tokens=256,
                num_beams=3
            )
        # Special tokens are kept because post_process_generation parses them.
        result = florence_processor.batch_decode(
            generated_ids,
            skip_special_tokens=False
        )[0]
        parsed = florence_processor.post_process_generation(
            result,
            task=task,
            image_size=(image.width, image.height)
        )

        return parsed.get(task, "")

    except Exception as e:
        # Best-effort contract is kept (return ""), but the cause is reported.
        print(f"Caption generation failed: {e!r}")
        return ""

print("‚úÖ GPU-optimized caption function ready")
def check_gpu_setup():
    """Print where the model lives, CUDA memory figures, and time one test caption.

    The test caption is generated for a random 224x224 RGB image via
    generate_caption_gpu_optimized(); nothing is returned.
    """
    print("\nüîç GPU Setup Check:")

    # Resolve the device of the first model parameter; any failure is
    # reported as "Unknown" instead of aborting the check.
    try:
        model_device = next(florence_model.parameters()).device
        device_line = f"  Model device: {model_device}"
    except:
        device_line = "  Model device: Unknown"
    print(device_line)

    if torch.cuda.is_available():
        print(f"  GPU allocated: {torch.cuda.memory_allocated()/1e9:.2f}GB")
        print(f"  GPU reserved: {torch.cuda.memory_reserved()/1e9:.2f}GB")
        print(f"  GPU total: {torch.cuda.get_device_properties(0).total_memory/1e9:.2f}GB")

    print("\n  Testing inference on GPU...")
    from PIL import Image
    import numpy as np
    import time

    # Random-noise image: exercises the pipeline end to end.
    noise = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    test_img = Image.fromarray(noise)

    started_at = time.time()
    caption = generate_caption_gpu_optimized(test_img)
    elapsed = time.time() - started_at

    print(f"  ‚úÖ Inference time: {elapsed:.2f}s")
    if not caption:
        print(f"  ‚ö†Ô∏è No caption generated (might be test image)")
    else:
        print(f"  ‚úÖ Generated: {caption[:50]}...")

check_gpu_setup()


üöÄ LOADING FLORENCE-2 (GPU-OPTIMIZED)

Loading microsoft/Florence-2-large on cuda...
‚úÖ Florence-2 and processor loaded
GPU Memory: 1.55 GB
‚úÖ GPU-optimized caption function ready

üîç GPU Setup Check:
  Model device: cuda:0
  GPU allocated: 1.55GB
  GPU reserved: 3.13GB
  GPU total: 15.83GB

  Testing inference on GPU...
  ‚úÖ Inference time: 0.04s
  ‚ö†Ô∏è No caption generated (might be test image)


In [None]:
import re
import cv2
import imageio
import requests
import tempfile
from io import BytesIO
from PIL import Image
from typing import Dict, List, Tuple
from scipy.stats import entropy as scipy_entropy
def _media_call_body(text, open_idx):
    """Return ``(body, end)`` for the parenthesised call opening at ``text[open_idx]``.

    Parentheses are matched with nesting, ignoring any that appear inside
    single-quoted strings. When the input is unbalanced the body ends at the
    first ``)`` (the behaviour of the previous non-greedy regex); when there
    is no ``)`` at all, ``(None, len(text))`` is returned.
    """
    depth = 0
    in_quote = False
    for idx in range(open_idx, len(text)):
        ch = text[idx]
        if ch == "'":
            in_quote = not in_quote
        elif not in_quote:
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
                if depth == 0:
                    return text[open_idx + 1:idx], idx + 1
    close_idx = text.find(')', open_idx)
    if close_idx == -1:
        return None, len(text)
    return text[open_idx + 1:close_idx], close_idx + 1


def parse_media_string(media_str):
    """Parse the string form of a tweet's media list.

    The input looks like ``[Photo(previewUrl='...', fullUrl='...')]`` or
    ``[Video(thumbnailUrl='...', variants=[VideoVariant(...), ...])]``.

    Args:
        media_str: Raw value of the ``media`` column (any type; it is
            converted with ``str``). ``None``, empty and ``'nan'`` values
            yield an empty list.

    Returns:
        A list with one dict per ``Photo``/``Video``/``Gif`` entry, with keys
        ``type`` (lower-cased entry name), ``thumbnail_url``, ``video_url``
        and ``content_type`` ("" when not found).
    """
    if media_str is None:
        return []
    media_str = str(media_str)
    if not media_str or media_str == 'nan':
        return []

    media_list = []
    opener = re.compile(r'(Photo|Video|Gif)\(')
    pos = 0
    while True:
        match = opener.search(media_str, pos)
        if match is None:
            break
        # Balanced matching keeps nested VideoVariant(...) entries inside the
        # body; a non-greedy regex would cut the body at the first ')'.
        content, pos = _media_call_body(media_str, match.end() - 1)
        if content is None:
            break

        media_dict = {
            "type": match.group(1).lower(),
            "thumbnail_url": "",
            "video_url": "",
            "content_type": "",
        }

        thumb_match = re.search(r"previewUrl='([^']+)'", content) or \
                     re.search(r"thumbnailUrl='([^']+)'", content)
        if thumb_match:
            media_dict["thumbnail_url"] = thumb_match.group(1)

        # Accept an optional query string after ".mp4" (e.g. "...mp4?tag=10").
        video_match = re.search(r"url='([^']+\.mp4(?:\?[^']*)?)'", content)
        if video_match:
            media_dict["video_url"] = video_match.group(1)
            media_dict["content_type"] = "video/mp4"

        # An explicit contentType (the first one found) overrides the default.
        type_match = re.search(r"contentType='([^']+)'", content)
        if type_match:
            media_dict["content_type"] = type_match.group(1)

        media_list.append(media_dict)

    return media_list

def get_best_media_url(media_list):
    """Pick the URL to download from the first usable parsed media entry.

    Entries are scanned in order. Within an entry a video URL wins over a
    thumbnail URL. Returns ``(url, type, content_type)``; thumbnails are
    reported with content type ``'image'``. ``("", "", "")`` means no entry
    had a URL.
    """
    for entry in media_list:
        video_url = entry.get('video_url')
        if video_url:
            content_type = entry.get('content_type', 'video/mp4')
            return (video_url, entry['type'], content_type)

        thumbnail_url = entry.get('thumbnail_url')
        if thumbnail_url:
            return (thumbnail_url, entry['type'], 'image')

    return ("", "", "")

def detect_media_type(url, content_type, parsed_type):
    """Classify a media item as ``'video'``, ``'gif'`` or ``'image'``.

    Evidence is checked in this order: substrings of the content type, then
    file extensions appearing anywhere in the URL, then the parsed entry
    type. Anything unrecognised is treated as ``'image'``.
    """
    lowered_url = str(url).lower()
    lowered_content_type = str(content_type).lower()

    # (haystack, needles, result) rules, evaluated top to bottom.
    rules = (
        (lowered_content_type, ('mp4', 'video'), 'video'),
        (lowered_content_type, ('gif',), 'gif'),
        (lowered_url, ('.mp4',), 'video'),
        (lowered_url, ('.gif',), 'gif'),
        (lowered_url, ('.jpg', '.jpeg', '.png', '.webp'), 'image'),
    )
    for haystack, needles, result in rules:
        for needle in needles:
            if needle in haystack:
                return result

    # No textual evidence: trust the parsed type; 'photo' and unknown -> image.
    if parsed_type != 'photo' and parsed_type in ['video', 'gif']:
        return parsed_type
    return 'image'
def fetch_media(url, timeout=10):
    """Download ``url`` and return its body as an in-memory buffer.

    Args:
        url: Address to fetch. Empty/None values are rejected without a
            network call.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        ``BytesIO`` with the response body, or ``None`` when the URL is
        empty, the request fails, or the server answers with an HTTP error
        status.
    """
    if not url:
        return None
    try:
        response = requests.get(url, timeout=timeout)
        # Without this check an error page (404/403 ...) would be returned
        # as if it were media bytes.
        response.raise_for_status()
        return BytesIO(response.content)
    except Exception:
        # Best-effort download: any request failure means "no media".
        return None
def calculate_entropy(frame):
    """Return the entropy of a frame's 256-bin grayscale histogram.

    ``frame`` may be a PIL image or an array; 3-dimensional arrays are
    converted to grayscale with OpenCV first. Any failure yields 0.
    """
    try:
        pixels = np.array(frame) if isinstance(frame, Image.Image) else frame

        is_color = len(pixels.shape) == 3
        gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY) if is_color else pixels

        counts = np.histogram(gray.flatten(), bins=256, range=(0, 256))[0]
        # The small epsilon keeps the division defined for an empty histogram.
        probabilities = counts / (counts.sum() + 1e-7)
        return scipy_entropy(probabilities)
    except:
        return 0

def extract_frames_entropy(url, media_type, num_frames=5):
    """Download a media item and return its most informative frames.

    Args:
        url: Media URL; an empty value returns ``[]`` without any download.
        media_type: ``'image'``, ``'video'`` or ``'gif'``; anything else
            returns ``[]``.
        num_frames: Maximum number of frames returned for video/gif input.

    Returns:
        A list of RGB PIL images. Images give a single-element list. For
        video/gif, up to 15 evenly spaced frames are sampled and the
        ``num_frames`` with the highest histogram entropy are returned,
        highest first. Any failure yields ``[]``.
    """
    import os  # local import keeps this function self-contained

    if not url:
        return []
    try:
        media_file = fetch_media(url)
        if media_file is None:
            return []
        if media_type == 'image':
            return [Image.open(media_file).convert('RGB')]
        if media_type not in ('video', 'gif'):
            return []

        tmp_path = None
        reader = None
        try:
            # imageio reads from a path, so the bytes are spilled to disk.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp:
                tmp.write(media_file.getvalue())
                tmp_path = tmp.name

            reader = imageio.get_reader(tmp_path)
            try:
                total_frames = len(reader)
            except (TypeError, OverflowError, ValueError):
                # Some readers report a non-integer (infinite) length, which
                # makes len() fail; count the frames explicitly instead.
                total_frames = reader.count_frames()

            if total_frames <= 0:
                return []

            sample_indices = np.linspace(
                0, total_frames - 1, min(15, total_frames), dtype=int
            )
            sampled_frames = []
            for idx in sample_indices:
                try:
                    frame = reader.get_data(idx)
                    sampled_frames.append(Image.fromarray(frame).convert('RGB'))
                except Exception:
                    # Skip frames that cannot be decoded.
                    continue
            if not sampled_frames:
                return []

            scored = [(calculate_entropy(np.array(img)), img) for img in sampled_frames]
            scored.sort(key=lambda item: item[0], reverse=True)
            return [img for _, img in scored[:num_frames]]
        finally:
            # Release the decoder and delete the temp file so repeated calls
            # do not leak file handles or fill the disk.
            if reader is not None:
                try:
                    reader.close()
                except Exception:
                    pass
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass
    except Exception:
        return []

print("‚úÖ Helper functions ready")

‚úÖ Helper functions ready


In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
import gc
import time
from concurrent.futures import ThreadPoolExecutor
import queue
# Allow TF32 math for CUDA matmul and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
print("‚úÖ TF32 enabled (faster on A100/RTX 30/40 series)")
# Compilation is optional: on failure the uncompiled model keeps being used.
# Only ordinary exceptions are caught, so an interrupt still stops the cell,
# and the reason is printed instead of being discarded.
try:
    florence_model = torch.compile(florence_model, mode="reduce-overhead")
    print("‚úÖ Model compiled with torch.compile()")
except Exception as compile_error:
    print("‚ö†Ô∏è torch.compile not available")
    print(f"   Reason: {compile_error!r}")

class AsyncMediaDataset(Dataset):
    """Map-style dataset that fetches a row's media and extracts frames on access.

    Each item is built from the row's ``media`` string: it is parsed, the
    best URL is chosen, the media type is detected and up to 5 frames are
    extracted. Rows without usable media produce an empty frame list.
    """

    def __init__(self, dataframe, num_workers=8):
        """Store the dataframe and the requested worker count.

        Args:
            dataframe: Frame with a ``media`` column (and optionally ``id``).
            num_workers: Size of the thread pool created below.
        """
        self.df = dataframe
        self.num_workers = num_workers
        # NOTE(review): neither the cache nor the executor is used by the
        # code visible in this notebook; kept for compatibility.
        self.frame_cache = {}
        self.executor = ThreadPoolExecutor(max_workers=num_workers)

    def __len__(self):
        """Number of rows in the wrapped dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict with ``frames``, ``index``, ``media_type`` and ``id``.

        ``idx`` is positional (``iloc``). ``media_type`` is ``'none'``
        whenever no frame could be extracted. ``id`` falls back to ``idx``
        when the row has no ``id`` field.
        """
        row = self.df.iloc[idx]
        media_str = str(row['media'])

        # Defaults are bound up front so the return statement never depends
        # on how far the try block got.
        frames = []
        actual_type = 'none'
        try:
            parsed = parse_media_string(media_str)
            url, parsed_type, content_type = get_best_media_url(parsed)
            # Rows without any media URL skip the download attempt entirely.
            if url:
                actual_type = detect_media_type(url, content_type, parsed_type)
                frames = extract_frames_entropy(url, actual_type, num_frames=5)
        except Exception:
            # Best-effort: a broken row yields no frames instead of killing
            # the loader; interrupts are no longer swallowed.
            frames = []

        return {
            'frames': frames,
            'index': idx,
            'media_type': actual_type if frames else 'none',
            'id': row.get('id', idx)
        }
def custom_collate_fn(batch):
    """Merge per-sample dicts into one batch dict with a flat frame list.

    ``frame_to_sample[k]`` is the position (within ``batch``) of the sample
    that frame ``k`` of ``frames`` came from, so captions can be regrouped
    per sample after flat inference.
    """
    flat_frames = []
    owners = []
    for position, item in enumerate(batch):
        item_frames = item['frames']
        flat_frames.extend(item_frames)
        owners.extend(position for _ in item_frames)

    return {
        'frames': flat_frames,
        'frame_to_sample': owners,
        'media_types': [item['media_type'] for item in batch],
        'indices': [item['index'] for item in batch],
        'ids': [item['id'] for item in batch],
        'batch_size': len(batch)
    }
print("\nüìä Creating optimized DataLoader...")
media_dataset = AsyncMediaDataset(train_df, num_workers=8)
data_loader = DataLoader(
    media_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    collate_fn=custom_collate_fn,
    prefetch_factor=4,
    persistent_workers=True
)

print(f"‚úÖ DataLoader created with:")
print(f"   - Batch size: 128")
print(f"   - Workers: 8 (async extraction)")
print(f"   - Prefetch: 4 batches")
print(f"   - Pin memory: Yes")
print("\n" + "="*70)
print("‚ö° ULTRA-OPTIMIZED INFERENCE (Batch 128 + Async + Greedy Decode)")
print("="*70)

CAPTION_TASK = "<MORE_DETAILED_CAPTION>"

# One slot per train_df row (positional); filled as batches complete.
all_results = [None] * len(train_df)

inf_batch_size = 32
total_frames_processed = 0
failed_sub_batches = 0
start_time = time.time()


def _caption_frames(frames):
    """Caption a list of PIL frames in one generate() call; returns one string per frame."""
    inputs = florence_processor(
        # One prompt per image: a single string would produce a batch of one
        # prompt against len(frames) images.
        text=[CAPTION_TASK] * len(frames),
        images=frames,
        return_tensors="pt"
    )
    # Token ids stay integer; only pixel values take the model's float dtype.
    # Casting every tensor to float16 breaks the embedding lookup.
    model_dtype = next(florence_model.parameters()).dtype
    input_ids = inputs["input_ids"].to(device)
    pixel_values = inputs["pixel_values"].to(device, model_dtype)

    with torch.no_grad():
        generated_ids = florence_model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            max_new_tokens=200,
            num_beams=1,
            do_sample=False,
            use_cache=True
        )
    decoded = florence_processor.batch_decode(
        generated_ids,
        skip_special_tokens=False
    )

    # Special tokens are kept for post-processing, so padding added to the
    # shorter sequences of a batch has to be removed from the final text.
    pad_token = getattr(getattr(florence_processor, "tokenizer", None), "pad_token", None) or "<pad>"

    captions = []
    for caption_raw, frame in zip(decoded, frames):
        parsed = florence_processor.post_process_generation(
            caption_raw,
            task=CAPTION_TASK,
            image_size=(frame.width, frame.height)
        )
        caption = parsed.get(CAPTION_TASK, "")
        if isinstance(caption, str):
            caption = caption.replace(pad_token, "").strip()
        captions.append(caption)
    return captions


for batch_idx, batch in enumerate(tqdm(data_loader, desc="Processing Batches")):

    flat_frames = batch['frames']
    frame_to_sample = batch['frame_to_sample']
    batch_size = batch['batch_size']

    # Caption all frames of the batch in GPU-sized chunks. A failing chunk
    # contributes empty captions so frame/caption positions stay aligned.
    all_captions = []
    for i in range(0, len(flat_frames), inf_batch_size):
        sub_batch_frames = flat_frames[i:i+inf_batch_size]
        try:
            all_captions.extend(_caption_frames(sub_batch_frames))
        except Exception as e:
            failed_sub_batches += 1
            if failed_sub_batches <= 3:
                # Report the first few failures; staying silent would make a
                # systematic error look like "no captions".
                print(f"\nCaption sub-batch failed: {e!r}")
            all_captions.extend([""] * len(sub_batch_frames))

    # Regroup the flat caption list per sample.
    sample_captions = [[] for _ in range(batch_size)]
    for frame_idx, caption in enumerate(all_captions):
        sample_captions[frame_to_sample[frame_idx]].append(caption)

    # Every sample gets a result, including batches without any frame, so
    # all_results stays aligned with train_df row for row.
    for local_idx, sample_idx in enumerate(batch['indices']):
        combined_caption = " ".join([c for c in sample_captions[local_idx] if c])

        all_results[sample_idx] = {
            'row_id': batch['ids'][local_idx],
            'media_type': batch['media_types'][local_idx],
            'combined_caption': combined_caption,
            'num_frames': len(sample_captions[local_idx]),
            'has_media': len(sample_captions[local_idx]) > 0
        }

    total_frames_processed += len(flat_frames)
    if batch_idx % 5 == 0:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        elapsed = time.time() - start_time
        fps = total_frames_processed / elapsed if elapsed > 0 else 0.0
        # ETA from the observed time per batch; the number of frames per row
        # varies, so a frame-based estimate is unreliable.
        batches_done = batch_idx + 1
        eta_mins = elapsed / batches_done * (len(data_loader) - batches_done) / 60

        print(f"\nüìä Progress: {batch_idx+1}/{len(data_loader)} batches")
        print(f"   Frames/sec: {fps:.1f}")
        print(f"   ETA: {eta_mins:.1f} mins")

elapsed_total = time.time() - start_time

print(f"\n" + "="*70)
print(f"‚úÖ INFERENCE COMPLETE")
print(f"="*70)
print(f"Total time: {elapsed_total/60:.1f} minutes")
print(f"Frames/sec: {total_frames_processed/elapsed_total:.1f}")
if failed_sub_batches:
    print(f"Failed caption sub-batches: {failed_sub_batches}")
analysis_df = pd.DataFrame([r for r in all_results if r is not None])
print(f"\nüìä Results:")
print(f"  Processed: {len(analysis_df)}")
print(f"  With captions: {(analysis_df['combined_caption'].str.len() > 0).sum()}")
print(f"  Total frames: {analysis_df['num_frames'].sum()}")


‚úÖ TF32 enabled (faster on A100/RTX 30/40 series)
‚úÖ Model compiled with torch.compile()

üìä Creating optimized DataLoader...
‚úÖ DataLoader created with:
   - Batch size: 128
   - Workers: 8 (async extraction)
   - Prefetch: 4 batches
   - Pin memory: Yes

‚ö° ULTRA-OPTIMIZED INFERENCE (Batch 128 + Async + Greedy Decode)


Processing Batches:   0%|          | 0/109 [00:00<?, ?it/s]

Exception ignored in: <function _releaseLock at 0x7d3eb4989620>Exception ignored in: <function _releaseLock at 0x7d3eb4989620>

Traceback (most recent call last):


KeyboardInterrupt: 

  File "/usr/lib/python3.12/logging/__init__.py", line 243, in _releaseLock
Traceback (most recent call last):
    KeyboardInterrupt

In [None]:
print("="*70)
print("üíæ SAVING DATASET")
print("="*70)

# all_results holds one slot per train_df row, in positional order. Rows that
# never received a result (interrupted run, skipped batch) are still None, so
# they are filled with an explicit "no media" record. Building the columns
# from this aligned list avoids the length mismatch that assigning
# analysis_df columns by position raises when any row is missing.
if len(all_results) != len(train_df):
    raise ValueError(
        f"all_results has {len(all_results)} entries but train_df has {len(train_df)} rows"
    )

_missing_result = {
    'combined_caption': '',
    'media_type': 'none',
    'num_frames': 0,
    'has_media': False,
}
_n_missing = sum(r is None for r in all_results)
if _n_missing:
    print(f"{_n_missing} rows have no inference result; saving them with empty captions")

aligned_results = pd.DataFrame(
    [r if r is not None else _missing_result for r in all_results],
    columns=['combined_caption', 'media_type', 'num_frames', 'has_media'],
)

train_df_enhanced = train_df.copy()
train_df_enhanced['florence_caption'] = aligned_results['combined_caption'].values
train_df_enhanced['media_type'] = aligned_results['media_type'].values
train_df_enhanced['num_frames'] = aligned_results['num_frames'].values
train_df_enhanced['has_media'] = aligned_results['has_media'].astype(bool).values

enhanced_csv = f"{OUTPUT_DIR}/train_florence2_enhanced.csv"
train_df_enhanced.to_csv(enhanced_csv, index=False)

print(f"‚úÖ Saved: {enhanced_csv}")
print(f"Shape: {train_df_enhanced.shape}")

# Preview the first row that has media, if there is one.
rows_with_media = train_df_enhanced[train_df_enhanced['has_media'] == True]
if rows_with_media.empty:
    print("\nNo rows with media to preview")
else:
    sample = rows_with_media.iloc[0]
    print(f"\nüìã Sample:")
    print(f"  Company: {sample['inferred company']}")
    print(f"  Media: {sample['media_type']}")
    print(f"  Frames: {sample['num_frames']}")
    print(f"  Caption: {sample['florence_caption'][:100]}...")


In [None]:
print("Creating instructions...")


def _instruction_pair(row):
    """Return the (prompt, target tweet) pair for one row of train_df_enhanced."""
    company = str(row.get('inferred company', 'Unknown'))
    username = str(row.get('username', company))
    likes = int(row.get('likes', 0))
    caption = str(row.get('florence_caption', ''))
    content = str(row.get('content', ''))

    # The caption, when present, is the only media context given to the prompt.
    if caption:
        context = f"Media: {caption}"
    else:
        context = "No media"

    instruction = f"""Generate tweet for {company} (@{username}) targeting {likes} likes:
{context}

Tweet:"""
    return instruction, content


instructions = []
responses = []
row_iterator = tqdm(train_df_enhanced.iterrows(), total=len(train_df_enhanced))
for _, row in row_iterator:
    prompt_text, target_text = _instruction_pair(row)
    instructions.append(prompt_text)
    responses.append(target_text)

# Persist the prompt/response pairs next to the other outputs.
instruction_csv = f"{OUTPUT_DIR}/instructions_florence2.csv"
dataset_df = pd.DataFrame({'instruction': instructions, 'response': responses})
dataset_df.to_csv(instruction_csv, index=False)

print(f"‚úÖ {len(dataset_df)} instructions created")
