# In [None]:
import os
import csv
import subprocess
import xml.etree.ElementTree as ET
import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import AutoTokenizer
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import numpy as np

# Set seeds for reproducibility
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Kaggle-compatible paths.
# `__file__` does not exist when this code runs inside a notebook cell, so the
# local (non-Kaggle) base directory falls back to the current working
# directory instead of raising NameError.
_ON_KAGGLE = os.path.exists("/kaggle")
try:
    _LOCAL_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _LOCAL_DIR = os.getcwd()

BASE_DIR = "/kaggle/working" if _ON_KAGGLE else _LOCAL_DIR
INPUT_DIR = "/kaggle/input" if _ON_KAGGLE else os.path.join(BASE_DIR, "input")

# Configuration
CONFIG = {
    'base_dir': BASE_DIR,
    'input_dir': INPUT_DIR,
    # NOTE(review): the second argument is an absolute Windows path. On Windows
    # os.path.join discards INPUT_DIR and uses it as-is; on POSIX (Kaggle) it
    # is appended verbatim, which will not name a real directory. Replace it
    # with your dataset folder name when running on Kaggle.
    'data_dir': os.path.join(INPUT_DIR, r"C:\Users\rammo\OneDrive\Desktop\medical-captioning\data"),  # Adjust to your dataset name

    'models_dir': os.path.join(BASE_DIR, "models"),
    'cache_dir': os.path.join(BASE_DIR, "cache"),
    'embed_size': 256,          # image feature / token embedding width
    'hidden_size': 512,         # LSTM hidden state width
    'max_len': 20,              # caption length in tokens (padded / truncated)
    'batch_size': 16,
    'epochs': 10,
    'lr': 1e-3,
    'image_size': (224, 224)    # size images are resized to before encoding
}

# Auto-derived paths
PATHS = {
    'images': os.path.join(CONFIG['data_dir'], "images"),
    'xml': os.path.join(CONFIG['data_dir'], "xml"),
    'annotations': os.path.join(CONFIG['base_dir'], "annotations.csv"),
    'clean_annotations': os.path.join(BASE_DIR, "annotations_clean.csv"),
    'encoder_model': os.path.join(CONFIG['models_dir'], "encoder.pth"),
    'decoder_model': os.path.join(CONFIG['models_dir'], "decoder.pth"),
    'tokenizer_cache': CONFIG['cache_dir'],
    'submission': os.path.join(BASE_DIR, "submission.csv")
}

# Create the output directories up front so later saves cannot fail on a
# missing folder.
for dir_path in [CONFIG['models_dir'], CONFIG['cache_dir']]:
    os.makedirs(dir_path, exist_ok=True)


class DataProcessor:
    """Handles all data preparation tasks.

    Every method is static and takes its locations from the module-level
    ``PATHS`` mapping. The stages communicate through CSV files:
    ``PATHS['annotations']`` holds the raw rows extracted from XML and
    ``PATHS['clean_annotations']`` holds the validated rows used for training.
    """

    @staticmethod
    def create_sample_data():
        """Write a one-row placeholder file to ``PATHS['clean_annotations']``.

        Used as a fallback when no raw annotation CSV exists.
        """
        sample_data = pd.DataFrame({
            'image_id': ['sample.png'],
            'caption': ['Sample medical image'],
        })
        sample_data.to_csv(PATHS['clean_annotations'], index=False)
        print("‚úÖ Sample data created")

    @staticmethod
    def _caption_from_root(root):
        """Join the 'findings' and 'impression' AbstractText of one report."""
        findings = []
        for elem in root.findall(".//AbstractText"):
            label = elem.attrib.get("Label", "").lower()
            # Strip before testing so whitespace-only elements do not add
            # empty fragments to the caption.
            text = (elem.text or "").strip()
            if label in ("findings", "impression") and text:
                findings.append(text)
        return " ".join(findings) if findings else "No findings available"

    @staticmethod
    def extract_from_xml():
        """Extract captions from XML files into ``PATHS['annotations']``.

        Each ``*.xml`` file in ``PATHS['xml']`` contributes one row per
        ``parentImage`` element: ``<id>.png`` paired with the report caption.
        Files that cannot be parsed are reported and skipped. The CSV is only
        written when at least one row was collected.
        """
        xml_dir = PATHS['xml']
        if not os.path.isdir(xml_dir):
            print("No XML directory found")
            return

        rows = []
        # Sorted so the generated CSV is identical from run to run
        # (os.listdir returns entries in arbitrary order).
        for xml_file in sorted(os.listdir(xml_dir)):
            if not xml_file.endswith(".xml"):
                continue

            xml_path = os.path.join(xml_dir, xml_file)
            try:
                root = ET.parse(xml_path).getroot()
            except Exception as e:
                # Deliberately broad: one unreadable report must not abort the
                # whole extraction.
                print(f"Error processing {xml_file}: {e}")
                continue

            caption_text = DataProcessor._caption_from_root(root)

            for parent_img in root.findall(".//parentImage"):
                img_id = parent_img.attrib.get("id")
                if not img_id:
                    # A missing id used to raise TypeError (None + ".png") and
                    # drop the remaining images of this report as well.
                    continue
                rows.append([img_id + ".png", caption_text])

        if not rows:
            print("No annotations extracted from XML")
            return

        with open(PATHS['annotations'], "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["image_id", "caption"])
            writer.writerows(rows)
        print(f"‚úÖ Extracted {len(rows)} annotations from XML")

    @staticmethod
    def clean_annotations():
        """Clean and validate annotations.

        Reads ``PATHS['annotations']``, drops rows lacking an image id or a
        caption, normalises ``.jpg`` ids to ``.png``, keeps only rows whose
        image exists under ``PATHS['images']`` and writes the result (with an
        added ``image_path`` column) to ``PATHS['clean_annotations']``.
        Falls back to :meth:`create_sample_data` when the raw CSV is missing.
        """
        try:
            df = pd.read_csv(PATHS['annotations'])
        except FileNotFoundError:
            DataProcessor.create_sample_data()
            return

        # Rows with a missing id or caption are read back as NaN and would
        # crash os.path.join here or the tokenizer later.
        df = df.dropna(subset=['image_id', 'caption']).copy()

        df['image_id'] = (
            df['image_id'].astype(str).str.strip().str.replace('.jpg', '.png', regex=False)
        )
        df["image_path"] = df["image_id"].apply(lambda x: os.path.join(PATHS['images'], x))

        # Force a boolean dtype: on an empty frame the mask would otherwise be
        # object-typed, which pandas does not treat as a boolean mask.
        exists_mask = df["image_path"].map(os.path.exists).astype(bool)
        df = df.loc[exists_mask]

        df.to_csv(PATHS['clean_annotations'], index=False)
        print(f"‚úÖ Cleaned annotations: {len(df)} valid rows")

    @staticmethod
    def prepare():
        """Main data preparation pipeline.

        Runs XML extraction when the XML directory exists and is non-empty,
        then always runs :meth:`clean_annotations`.
        """
        print("Starting data preparation...")
        print(f"Looking for data in: {CONFIG['data_dir']}")

        xml_dir = PATHS['xml']
        # isdir (not exists) so a stray file at that path cannot crash listdir.
        if os.path.isdir(xml_dir) and os.listdir(xml_dir):
            DataProcessor.extract_from_xml()

        DataProcessor.clean_annotations()


class CaptionDataset(Dataset):
    """Dataset of (image, caption token ids) pairs for image captioning.

    Args:
        csv_file: CSV with at least ``image_id`` and ``caption`` columns.
        img_folder: directory the ``image_id`` file names are relative to.
        tokenizer: Hugging Face style tokenizer (callable, exposing
            ``bos_token_id`` / ``eos_token_id`` / ``pad_token_id``).
        transform: optional callable applied to the RGB PIL image.

    Each item is ``(image, ids)`` where ``ids`` is a 1-D long tensor of
    length ``CONFIG['max_len']``.
    """

    def __init__(self, csv_file, img_folder, tokenizer, transform=None):
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.transform = transform
        self.img_folder = img_folder

        self.image_paths = [
            os.path.join(img_folder, fname) for fname in self.data['image_id'].astype(str)
        ]
        # pandas reads an empty caption cell back as NaN (a float); normalise
        # to "" so tokenisation and string operations never see a non-string.
        self.captions = self.data['caption'].fillna("").astype(str).tolist()

    def __len__(self):
        return len(self.data)

    def _encode_caption(self, caption):
        """Return ``[BOS] tokens [EOS] [PAD]...`` as a ``max_len`` long tensor.

        DecoderRNN.forward feeds ``captions[:, :-1]`` and is scored against
        ``captions[:, 1:]``, so position 0 is never predicted, and
        ModelTrainer.generate_caption starts decoding from ``bos_token_id``
        and stops on ``eos_token_id``. The tokens are therefore wrapped
        explicitly; otherwise the first caption token is never learned and
        the end-of-sequence token never appears as a training target.
        """
        tok = self.tokenizer
        max_len = CONFIG['max_len']
        bos_id = getattr(tok, "bos_token_id", None)
        eos_id = getattr(tok, "eos_token_id", None)
        pad_id = getattr(tok, "pad_token_id", None)

        if bos_id is None or eos_id is None or pad_id is None or max_len < 3:
            # The tokenizer has no usable markers (or there is no room for
            # them): keep the plain padded encoding.
            encodings = tok(
                caption,
                padding="max_length",
                truncation=True,
                max_length=max_len,
                return_tensors="pt"
            )
            return encodings["input_ids"].squeeze(0)

        # Reserve two slots so truncation can never cut off the EOS marker.
        body = tok(
            caption,
            add_special_tokens=False,
            truncation=True,
            max_length=max_len - 2
        )["input_ids"]
        ids = [bos_id, *body, eos_id]
        ids.extend([pad_id] * (max_len - len(ids)))
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        try:
            with Image.open(path) as img:
                image = img.convert("RGB")
                if self.transform is not None:
                    image = self.transform(image)
        except Exception as e:
            # Best effort: a single unreadable file must not stop training.
            print(f"Error loading image {path}: {e}")
            # Return a blank image if loading fails
            image = torch.zeros(3, *CONFIG['image_size'])

        return image, self._encode_caption(self.captions[idx])


class EncoderCNN(nn.Module):
    """Image encoder built on torchvision's ResNet50.

    The network is loaded with ``weights='DEFAULT'``, all of its parameters
    are frozen, and its final ``fc`` layer is then replaced by a new linear
    projection to ``CONFIG['embed_size']`` (the only trainable part).
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(weights='DEFAULT')

        # Freeze everything that was loaded; the replacement head created
        # below is a new layer and therefore stays trainable.
        # NOTE(review): freezing only disables gradients. Calling .train() on
        # this module presumably still puts the backbone's normalisation
        # layers in training mode; confirm that is intended.
        backbone.requires_grad_(False)

        head_in_features = backbone.fc.in_features
        backbone.fc = nn.Linear(head_in_features, CONFIG['embed_size'])

        self.model = backbone

    def forward(self, images):
        """Return the output of the wrapped network for ``images``."""
        embedded = self.model(images)
        return embedded


class DecoderRNN(nn.Module):
    """Caption decoder: token embedding -> LSTM -> vocabulary projection."""

    def __init__(self, vocab_size):
        super().__init__()
        embed_dim = CONFIG['embed_size']
        hidden_dim = CONFIG['hidden_size']

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions):
        """Score the next token for every caption position.

        The image ``features`` are inserted as the first LSTM step, followed
        by the embeddings of every caption token except the last. The output
        produced at the image step is dropped, so the result has one
        time step per token of ``captions[:, :-1]``.
        """
        token_embeds = self.embedding(captions[:, :-1])
        image_step = features.unsqueeze(1)
        sequence = torch.cat((image_step, token_embeds), 1)

        hidden_states, _ = self.lstm(sequence)
        logits = self.fc(hidden_states)
        return logits[:, 1:, :]


class ModelTrainer:
    """Handles model training and inference.

    Holds the tokenizer, the image transform and the per-epoch training
    losses, and provides training, greedy caption generation and plotting.

    Args:
        device: torch device the models and batches are moved to.
    """

    def __init__(self, device):
        self.device = device
        self.tokenizer = self._setup_tokenizer()
        self.transform = self._setup_transforms()
        self.train_losses = []

    def _setup_tokenizer(self):
        """Load the GPT-2 tokenizer, with a local-download fallback.

        Returns:
            The tokenizer, with a ``[PAD]`` token added when it has none.

        Raises:
            RuntimeError: if both the regular load and the fallback fail.
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                "gpt2",
                cache_dir=PATHS.get('tokenizer_cache', None),
                legacy=True
            )
        except Exception as e:
            print(f"‚ö†Ô∏è Error loading tokenizer: {e}")
            print("üîÅ Falling back to local GPT-2 tokenizer...")

            try:
                # Derived from BASE_DIR so the fallback also works off Kaggle
                # (on Kaggle this is still /kaggle/working/gpt2_tokenizer).
                local_dir = os.path.join(BASE_DIR, "gpt2_tokenizer")
                # Download again when an earlier failed attempt left an empty
                # directory behind.
                if not os.path.isdir(local_dir) or not os.listdir(local_dir):
                    os.makedirs(local_dir, exist_ok=True)
                    print("üì¶ Downloading GPT-2 tokenizer locally...")
                    subprocess.run([
                        "huggingface-cli", "download", "gpt2",
                        "--local-dir", local_dir,
                        "--repo-type", "model"
                    ], check=True)

                tokenizer = AutoTokenizer.from_pretrained(
                    local_dir,
                    legacy=True,
                    local_files_only=True
                )
            except Exception as e2:
                print(f"‚ùå Fallback tokenizer load failed: {e2}")
                # Chain the cause so the original failure stays in the traceback.
                raise RuntimeError(
                    "Tokenizer could not be loaded. Check internet connection or model availability."
                ) from e2

        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        return tokenizer

    def _setup_transforms(self):
        """Build the resize -> tensor -> normalise image pipeline."""
        return transforms.Compose([
            transforms.Resize(CONFIG['image_size']),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _collate_fn(self, batch):
        """Stack a list of (image, caption) pairs into two batch tensors."""
        images, captions = zip(*batch)
        return torch.stack(images, dim=0), torch.stack(captions, dim=0)

    def create_dataset(self):
        """Create a CaptionDataset over the cleaned annotations."""
        return CaptionDataset(
            PATHS['clean_annotations'],
            PATHS['images'],
            self.tokenizer,
            self.transform
        )

    def train(self):
        """Train the encoder head and the decoder.

        Only the decoder parameters and the encoder's ``fc`` layer are passed
        to the optimizer. The average loss of each epoch is appended to
        ``self.train_losses``.

        Returns:
            ``(encoder, decoder)``, or ``(None, None)`` when the dataset is
            empty.
        """
        dataset = self.create_dataset()
        if len(dataset) == 0:
            print("No valid data found.")
            return None, None

        dataloader = DataLoader(
            dataset,
            batch_size=min(CONFIG['batch_size'], len(dataset)),
            shuffle=True,
            collate_fn=self._collate_fn,
            num_workers=0  # Kaggle compatibility
        )

        # len(tokenizer) rather than tokenizer.vocab_size so the added
        # padding token is covered by the embedding and output layers.
        vocab_size = len(self.tokenizer)
        encoder = EncoderCNN().to(self.device)
        decoder = DecoderRNN(vocab_size).to(self.device)

        # Padding positions do not contribute to the loss.
        criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        params = list(decoder.parameters()) + list(encoder.model.fc.parameters())
        optimizer = torch.optim.Adam(params, lr=CONFIG['lr'])

        print(f"Training on {self.device} with vocab size: {vocab_size}")

        for epoch in range(CONFIG['epochs']):
            encoder.train()
            decoder.train()
            epoch_loss = 0.0

            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{CONFIG['epochs']}")
            for images, captions in pbar:
                images, captions = images.to(self.device), captions.to(self.device)

                features = encoder(images)
                outputs = decoder(features, captions)

                # The decoder output at step t is scored against token t+1.
                targets = captions[:, 1:].contiguous().view(-1)
                outputs = outputs.contiguous().view(-1, vocab_size)

                loss = criterion(outputs, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})

            avg_loss = epoch_loss / len(dataloader)
            self.train_losses.append(avg_loss)
            print(f"Epoch [{epoch+1}/{CONFIG['epochs']}] Avg Loss: {avg_loss:.4f}")

        return encoder, decoder

    def generate_caption(self, image_path, encoder, decoder):
        """Generate caption for a single image.

        Greedy decoding: starting from the tokenizer's BOS id (0 when it has
        none), the most likely next token is appended until the EOS id is
        produced or ``CONFIG['max_len']`` tokens are reached.

        Returns:
            The decoded caption, ``"Error loading image"`` when the image
            cannot be read, or ``"Unable to generate caption"`` when no token
            was produced.
        """
        encoder.eval()
        decoder.eval()

        try:
            with Image.open(image_path) as img:
                image = img.convert('RGB')
                image = self.transform(image).unsqueeze(0).to(self.device)
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
            return "Error loading image"

        max_len = CONFIG['max_len']
        bos_id = self.tokenizer.bos_token_id
        eos_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            features = encoder(image)
            inputs = torch.zeros(1, max_len, dtype=torch.long, device=self.device)
            inputs[0, 0] = bos_id if bos_id is not None else 0

            generated_ids = []
            for i in range(1, max_len):
                # The decoder drops the last token of its input (it slices
                # [:, :-1]), so pass one slot more than has been filled.
                outputs = decoder(features, inputs[:, :i + 1])
                predicted_id = outputs[0, -1, :].argmax().item()
                generated_ids.append(predicted_id)
                inputs[0, i] = predicted_id

                if predicted_id == eos_id:
                    break

        if not generated_ids:
            return "Unable to generate caption"
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True)

    def visualize_training(self):
        """Plot training loss curve and save it as ``training_loss.png``."""
        if not self.train_losses:
            return

        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(self.train_losses) + 1), self.train_losses, marker='o', linewidth=2)
        plt.title('Training Loss Over Epochs', fontsize=16, fontweight='bold')
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Average Loss', fontsize=12)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(BASE_DIR, 'training_loss.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("‚úÖ Training loss plot saved")

    def visualize_predictions(self, encoder, decoder, num_samples=6):
        """Visualize sample predictions and save them as ``predictions.png``.

        Draws up to ``num_samples`` randomly chosen images, each titled with
        the first 40 characters of its true and predicted caption.
        """
        dataset = self.create_dataset()
        if len(dataset) == 0:
            return

        n_samples = min(num_samples, len(dataset))
        indices = np.random.choice(len(dataset), n_samples, replace=False)

        n_cols = 3
        # Keep the 2x3 layout for up to six samples and add rows beyond that,
        # so a larger num_samples is no longer silently truncated.
        n_rows = max(2, -(-n_samples // n_cols))
        _, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
        axes = axes.flatten()

        # Turn every axis off up front so unused grid cells stay blank.
        for ax in axes:
            ax.axis('off')

        for idx, ax in zip(indices, axes):
            img_path = dataset.image_paths[idx]
            true_caption = str(dataset.captions[idx])
            pred_caption = self.generate_caption(img_path, encoder, decoder)

            try:
                # Context manager closes the file handle once it is drawn.
                with Image.open(img_path) as img:
                    ax.imshow(img.convert('RGB'))
            except Exception as e:
                # An unreadable image leaves its cell empty instead of
                # aborting the whole figure.
                print(f"Error loading image {img_path}: {e}")

            ax.set_title(f"True: {true_caption[:40]}...\nPred: {pred_caption[:40]}...",
                         fontsize=9, wrap=True)

        plt.tight_layout()
        plt.savefig(os.path.join(BASE_DIR, 'predictions.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("‚úÖ Predictions visualization saved")

    def visualize_dataset_stats(self):
        """Visualize dataset statistics and save them as ``dataset_stats.png``.

        Four panels: caption length in words, the 20 most common words,
        dataset size, and caption length in characters.
        """
        dataset = self.create_dataset()
        if len(dataset) == 0:
            return

        # Guard against non-string captions (an empty CSV cell is read as NaN).
        captions = [cap if isinstance(cap, str) else "" for cap in dataset.captions]

        _, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Caption length distribution
        caption_lengths = [len(cap.split()) for cap in captions]
        axes[0, 0].hist(caption_lengths, bins=30, color='skyblue', edgecolor='black')
        axes[0, 0].set_title('Caption Length Distribution', fontweight='bold')
        axes[0, 0].set_xlabel('Number of Words')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].grid(True, alpha=0.3)

        # Word frequency
        all_words = ' '.join(captions).lower().split()
        word_freq = Counter(all_words).most_common(20)
        if word_freq:
            # zip(*[]) cannot be unpacked into two names, so the bars are only
            # drawn when at least one word exists.
            words, counts = zip(*word_freq)
            axes[0, 1].barh(range(len(words)), counts, color='coral')
            axes[0, 1].set_yticks(range(len(words)))
            axes[0, 1].set_yticklabels(words)
            axes[0, 1].invert_yaxis()
        axes[0, 1].set_title('Top 20 Most Common Words', fontweight='bold')
        axes[0, 1].set_xlabel('Frequency')

        # Dataset size
        axes[1, 0].bar(['Total Images'], [len(dataset)], color='lightgreen', edgecolor='black')
        axes[1, 0].set_title('Dataset Size', fontweight='bold')
        axes[1, 0].set_ylabel('Count')
        axes[1, 0].grid(True, alpha=0.3, axis='y')

        # Caption character distribution
        char_lengths = [len(cap) for cap in captions]
        axes[1, 1].hist(char_lengths, bins=30, color='plum', edgecolor='black')
        axes[1, 1].set_title('Caption Character Length Distribution', fontweight='bold')
        axes[1, 1].set_xlabel('Number of Characters')
        axes[1, 1].set_ylabel('Frequency')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(os.path.join(BASE_DIR, 'dataset_stats.png'), dpi=300, bbox_inches='tight')
        plt.close()
        print("‚úÖ Dataset statistics visualization saved")


def main():
    """Main execution function.

    Prepares the data, plots dataset statistics, trains the models and, when
    training produced models, saves their weights, writes the training and
    prediction plots, prints one sample caption and zips the models
    directory. Any exception is reported with its traceback instead of
    propagating.
    """
    import shutil
    import traceback

    try:
        print("üè• Medical Image Captioning - Kaggle Version")
        print(f"Base directory: {BASE_DIR}")
        print(f"Input directory: {INPUT_DIR}")

        DataProcessor.prepare()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        trainer = ModelTrainer(device)

        # Visualize dataset before training
        print("\nüìä Generating dataset visualizations...")
        trainer.visualize_dataset_stats()

        encoder, decoder = trainer.train()
        # train() returns (None, None) when there is nothing to train on.
        if encoder is None or decoder is None:
            return

        torch.save(encoder.state_dict(), PATHS['encoder_model'])
        torch.save(decoder.state_dict(), PATHS['decoder_model'])
        print(f"‚úÖ Models saved to: {CONFIG['models_dir']}")

        # Generate visualizations
        print("\nüìä Generating visualizations...")
        trainer.visualize_training()
        trainer.visualize_predictions(encoder, decoder)

        # Test caption generation
        dataset = trainer.create_dataset()
        if len(dataset) > 0:
            test_path = dataset.image_paths[0]
            if os.path.exists(test_path):
                caption = trainer.generate_caption(test_path, encoder, decoder)
                print(f"Generated caption: {caption}")

        # Archive the saved models. This depends only on the models
        # directory, not on whether the sample image above exists.
        models_dir = CONFIG.get("models_dir")
        if models_dir and os.path.exists(models_dir):
            # Remove trailing slash if present (prevents weird double paths)
            models_dir = models_dir.rstrip("/")

            # base_name is the archive path without its extension, so this
            # writes "<models_dir>.zip", the file the message below names.
            shutil.make_archive(models_dir, "zip", models_dir)
            print(f"üì¶ Successfully zipped model directory ‚Üí {models_dir}.zip")
        else:
            print(f"‚ö†Ô∏è Model directory not found or not set: {models_dir}")

    except Exception as e:
        # Top-level boundary: report the failure with its full traceback.
        print(f"‚ùå Error: {e}")
        traceback.print_exc()


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()