In [None]:
import os
import numpy as np
import tensorflow as tf
import keras
# Report the deep-learning stack versions this notebook was run against.
print(
    f"TensorFlow version: {tf.__version__}",
    f"Keras version: {keras.__version__}",
    sep="\n",
)

In [None]:
# First, let's analyze the U-Net implementation structure
def analyze_unet_structure(root_dir=".", max_depth=2,
                           extensions=('.py', '.md', '.txt', '.ipynb')):
    """Print an indented tree of the project directory.

    Args:
        root_dir: Directory to start from (defaults to the current directory).
        max_depth: Deepest directory level (root is level 0) that is listed.
            Directories below this level are not traversed at all.
        extensions: File suffixes worth listing; other files are omitted.

    Output order is sorted so repeated runs print the same tree.
    """

    print("📁 U-Net Project Structure:")
    suffixes = tuple(extensions)
    for root, dirs, files in os.walk(root_dir):
        rel_path = os.path.relpath(root, root_dir)
        level = 0 if rel_path == os.curdir else rel_path.count(os.sep) + 1

        # Sorting in place makes os.walk visit sub-directories deterministically.
        dirs.sort()
        if level >= max_depth:
            # Prune instead of merely skipping the print: os.walk honours
            # in-place edits of `dirs`, so deep trees are never scanned.
            dirs[:] = []

        indent = ' ' * 2 * level
        print(f'{indent}{os.path.basename(os.path.normpath(root))}/')
        subindent = ' ' * 2 * (level + 1)
        for file in sorted(files):
            if file.endswith(suffixes):
                print(f'{subindent}{file}')

analyze_unet_structure()

In [None]:
# Let's examine the key U-Net implementation files
def examine_unet_implementation():
    """Report which U-Net building blocks appear in ``model.py`` and ``data.py``.

    Both files are looked up in the current working directory. The check is a
    case-insensitive substring search on the source text, so it is only a
    heuristic: a keyword inside a comment also counts as "Found".
    Missing or unreadable files are reported instead of raising.
    """

    print("🔍 U-Net Implementation Analysis")
    print("=" * 60)

    # Description -> lowercase keywords; any one keyword is enough.
    # Upsampling accepts the spellings Keras actually uses (UpSampling2D,
    # Conv2DTranspose) in addition to the original 'upconv2d'.
    components = {
        'Convolutional layers': ('conv2d',),
        'Downsampling': ('maxpooling2d',),
        'Upsampling': ('upsampling2d', 'conv2dtranspose', 'upconv2d'),
        'Skip connections': ('concatenate',),
        'Main model definition': ('unet',),
    }

    # Read the main model file
    try:
        with open('model.py', 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading model.py: {e}")
    else:
        print("📄 model.py key components:")
        lowered = content.lower()
        for description, keywords in components.items():
            if any(keyword in lowered for keyword in keywords):
                print(f"   {description} - Found")
            else:
                print(f"   {description} - Not found")

    # Check data preparation
    try:
        with open('data.py', 'r', encoding='utf-8') as f:
            data_content = f.read()
    except (OSError, UnicodeDecodeError):
        # Only I/O and decoding problems are expected here; anything else
        # should surface rather than be hidden by a bare except.
        print("  ℹ️  data.py not found or couldn't be read")
    else:
        if 'load' in data_content or 'generator' in data_content:
            print("  Data loading utilities - Found")

examine_unet_implementation()

In [None]:
# Now let's prepare for the TGS Salt Identification Challenge
def setup_tgs_salt_dataset():
    """Scaffold the TGS Salt Identification experiment in the working directory.

    Side effects (all relative to the current working directory):
      * creates ``tgs_data/{train,test}/{images,masks}``;
      * writes ``tgs_dataset.py`` - a script/class that fabricates a small
        synthetic image/mask dataset and can visualise a sample;
      * writes ``train_tgs_unet.py`` - a trainer that feeds that data to the
        project's ``model.unet``.

    The generated scripts are written as UTF-8 because they contain
    non-ASCII characters in their messages.
    """

    print("PROBLEM 1: TGS Salt Identification Challenge Setup")
    print("=" * 60)

    # Create directory structure for TGS dataset
    for split in ('train', 'test'):
        for kind in ('images', 'masks'):
            os.makedirs(os.path.join('tgs_data', split, kind), exist_ok=True)

    print("Created TGS dataset structure:")
    print("  - tgs_data/")
    print("    - train/")
    print("      - images/  # Training seismic images")
    print("      - masks/   # Salt segmentation masks")
    print("    - test/")
    print("      - images/  # Test seismic images")
    print("      - masks/   # Test masks (if available)")

    # Create dataset preparation script.
    # Only the modules the script actually uses are imported, so it does not
    # fail on machines without pandas / scikit-learn.
    dataset_script = '''
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

class TGSSaltDataset:
    """TGS Salt Identification Challenge Dataset Handler"""
    
    def __init__(self, data_path='tgs_data'):
        self.data_path = data_path
        self.train_csv = os.path.join(data_path, 'train.csv')
        
    def load_and_prepare_data(self):
        """Load and prepare TGS Salt dataset"""
        print("Loading TGS Salt dataset...")
        
        # In a real scenario, you would:
        # 1. Download from Kaggle
        # 2. Extract zip files
        # 3. Organize into train/test splits
        
        # For demonstration, we'll create a sample structure
        self.create_sample_data()
        
    def create_sample_data(self, n_samples=10):
        """Create sample data structure for demonstration"""
        print("🔧 Creating sample dataset structure...")
        
        # Create sample images and masks
        for split in ['train', 'test']:
            image_dir = os.path.join(self.data_path, split, 'images')
            mask_dir = os.path.join(self.data_path, split, 'masks')
            # Make the script usable on its own, without the notebook
            # having created the folders beforehand.
            os.makedirs(image_dir, exist_ok=True)
            os.makedirs(mask_dir, exist_ok=True)
            
            for i in range(n_samples):
                # Create sample seismic image (128x128 grayscale noise)
                seismic_img = np.random.rand(128, 128) * 255
                seismic_img = seismic_img.astype(np.uint8)
                
                # Create sample salt mask (binary) with one fixed rectangle
                salt_mask = np.zeros((128, 128), dtype=np.uint8)
                salt_mask[30:60, 40:80] = 255
                
                # Save images
                img_path = os.path.join(image_dir, f'sample_{i}.png')
                mask_path = os.path.join(mask_dir, f'sample_{i}.png')
                
                Image.fromarray(seismic_img).save(img_path)
                Image.fromarray(salt_mask).save(mask_path)
        
        print(f"Created sample dataset with {n_samples} train and {n_samples} test images")
        
    def visualize_sample(self, split='train', idx=0):
        """Visualize a sample image and mask"""
        img_path = os.path.join(self.data_path, split, 'images', f'sample_{idx}.png')
        mask_path = os.path.join(self.data_path, split, 'masks', f'sample_{idx}.png')
        
        if os.path.exists(img_path) and os.path.exists(mask_path):
            img = Image.open(img_path)
            mask = Image.open(mask_path)
            
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
            
            ax1.imshow(img, cmap='gray')
            ax1.set_title('Seismic Image')
            ax1.axis('off')
            
            ax2.imshow(mask, cmap='gray')
            ax2.set_title('Salt Mask')
            ax2.axis('off')
            
            # Overlay
            ax3.imshow(img, cmap='gray')
            ax3.imshow(mask, cmap='Reds', alpha=0.5)
            ax3.set_title('Overlay')
            ax3.axis('off')
            
            plt.tight_layout()
            plt.show()
        else:
            print("Sample files not found. Run load_and_prepare_data() first.")

# Usage
if __name__ == "__main__":
    dataset = TGSSaltDataset()
    dataset.load_and_prepare_data()
    dataset.visualize_sample()
'''

    with open('tgs_dataset.py', 'w', encoding='utf-8') as f:
        f.write(dataset_script)

    print("📜 Created TGS dataset handler: tgs_dataset.py")

    # Create adaptation script for U-Net.
    # Differences from a naive loader that matter:
    #   * images/masks are forced to single-channel ('L') so the reshape to
    #     (..., 1) holds for RGB inputs too, matching predict();
    #   * PIL's resize takes (width, height), i.e. (cols, rows);
    #   * Adam is configured through `learning_rate` (the `lr` alias is
    #     rejected by newer Keras releases).
    adaptation_script = '''
import os
import numpy as np
from PIL import Image
from model import unet
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, EarlyStopping

class TGSUnetTrainer:
    """U-Net trainer adapted for TGS Salt Identification"""
    
    def __init__(self, img_rows=128, img_cols=128):
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.model = None
        
    def _load_grayscale(self, path):
        """Read one image as a (rows, cols) float array scaled to [0, 1]"""
        img = Image.open(path).convert('L')
        # PIL expects (width, height) == (cols, rows)
        img = img.resize((self.img_cols, self.img_rows))
        return np.array(img) / 255.0
        
    def load_tgs_data(self):
        """Load TGS dataset in U-Net compatible format"""
        print("Loading TGS data for U-Net...")
        
        train_dir = 'tgs_data/train/images'
        mask_dir = 'tgs_data/train/masks'
        
        # Sorted for a reproducible selection; limited to 10 files for the demo
        png_files = []
        if os.path.isdir(train_dir):
            png_files = sorted(
                name for name in os.listdir(train_dir) if name.endswith('.png')
            )[:10]
        
        if not png_files:
            print("Training data not found. Please run tgs_dataset.py first.")
            return None, None
        
        train_images = []
        train_masks = []
        for img_file in png_files:
            # The mask shares the image's file name
            train_images.append(self._load_grayscale(os.path.join(train_dir, img_file)))
            train_masks.append(self._load_grayscale(os.path.join(mask_dir, img_file)))
        
        train_images = np.array(train_images).reshape(-1, self.img_rows, self.img_cols, 1)
        train_masks = np.array(train_masks).reshape(-1, self.img_rows, self.img_cols, 1)
        
        print(f" Loaded {len(train_images)} training samples")
        return train_images, train_masks
    
    def create_model(self):
        """Create U-Net model for salt segmentation"""
        print(" Creating U-Net model...")
        
        # Use the provided U-Net implementation
        self.model = unet(input_size=(self.img_rows, self.img_cols, 1))
        
        # Compile with appropriate loss and metrics for binary segmentation
        self.model.compile(optimizer=Adam(learning_rate=1e-4), 
                          loss='binary_crossentropy',
                          metrics=['accuracy'])
        
        print(" U-Net model created and compiled")
        return self.model
    
    def train(self, epochs=10, batch_size=8):
        """Train the U-Net model"""
        print(" Starting U-Net training...")
        
        # Load data
        train_images, train_masks = self.load_tgs_data()
        
        if train_images is None:
            return
        
        # Create model
        self.create_model()
        
        # Setup callbacks
        callbacks = [
            ModelCheckpoint('tgs_unet_weights.h5', save_best_only=True),
            EarlyStopping(patience=5, restore_best_weights=True)
        ]
        
        # Train model
        history = self.model.fit(
            train_images, train_masks,
            batch_size=batch_size,
            epochs=epochs,
            validation_split=0.2,
            callbacks=callbacks,
            verbose=1
        )
        
        print("Training completed!")
        return history
    
    def predict(self, image_path):
        """Make prediction on a single image"""
        if self.model is None:
            print("Model not trained. Please train first.")
            return None
        
        # Load and preprocess image
        img_array = self._load_grayscale(image_path)
        img_array = img_array.reshape(1, self.img_rows, self.img_cols, 1)
        
        # Predict
        prediction = self.model.predict(img_array)[0]
        
        # Convert to binary mask
        binary_mask = (prediction > 0.5).astype(np.uint8) * 255
        
        return binary_mask

# Training example
if __name__ == "__main__":
    trainer = TGSUnetTrainer()
    trainer.train(epochs=5)  # Short training for demo
'''

    with open('train_tgs_unet.py', 'w', encoding='utf-8') as f:
        f.write(adaptation_script)

    print("🔧 Created U-Net adaptation for TGS: train_tgs_unet.py")

    print("\n TGS Salt Identification setup completed!")
    # A real newline here; the doubled backslash used to print a literal "\n".
    print("\nNext steps:")
    print("1. Run: python tgs_dataset.py (to create sample data)")
    print("2. Run: python train_tgs_unet.py (to train U-Net)")
    print("3. Download actual TGS dataset from Kaggle for real training")

setup_tgs_salt_dataset()

In [None]:
# Let's test the dataset creation
def test_dataset_creation():
    """Test the TGS dataset creation"""
    
    print("🧪 Testing Dataset Creation")
    print("=" * 60)
    
    # Run the dataset creation
    !python tgs_dataset.py
    
    # Visualize samples
    try:
        from tgs_dataset import TGSSaltDataset
        dataset = TGSSaltDataset()
        dataset.visualize_sample('train', 0)
        dataset.visualize_sample('train', 1)
    except Exception as e:
        print(f"Visualization error: {e}")

test_dataset_creation()

In [None]:
def unet_code_reading():
    """Map the concepts of the U-Net paper onto the code in ``model.py``.

    Prints the paper's key components, then scans ``model.py`` (in the current
    working directory) for the matching Keras constructs and sketches the
    encoder/decoder structure of the ``unet`` function.

    The structure sketch is a per-line keyword heuristic. A single line may
    report several components (e.g. ``Conv2D(...)(UpSampling2D(...)(x))`` is
    both an upsampling step and a convolution). A missing or unreadable
    ``model.py`` is reported and the function returns early.
    """

    print("🎯 PROBLEM 2: U-Net Code Reading")
    print("=" * 60)

    # Key components from U-Net paper
    paper_components = {
        "Contracting Path (Encoder)": "Feature extraction with downsampling",
        "Expanding Path (Decoder)": "Feature reconstruction with upsampling",
        "Skip Connections": "Connecting encoder and decoder features",
        "U-Shaped Architecture": "Symmetric encoder-decoder structure",
        "Data Augmentation": "Elastic deformations for limited data",
        "Weighted Loss": "Handling class imbalance in biomedical images",
        "Overlap-tile Strategy": "Seamless segmentation of large images"
    }

    print("📋 Key Components from U-Net Paper:")
    for component, description in paper_components.items():
        print(f"  • {component}: {description}")

    print("\n🔍 Finding corresponding code implementations...")

    # Read the model implementation
    try:
        with open('model.py', 'r', encoding='utf-8') as f:
            model_code = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error analyzing model.py: {e}")
        return

    print("\n📄 U-Net Implementation Analysis:")

    # Check for key architectural components (case-sensitive on purpose:
    # these are the exact Keras identifiers)
    key_elements = {
        'def unet': 'Main U-Net model definition',
        'Conv2D': 'Convolutional layers',
        'MaxPooling2D': 'Downsampling in encoder',
        'UpSampling2D': 'Upsampling in decoder',
        'concatenate': 'Skip connections',
        'Dropout': 'Regularization',
        'sigmoid': 'Binary segmentation output'
    }

    for element, description in key_elements.items():
        if element in model_code:
            print(f"  ✅ {description} - Implemented")
        else:
            print(f"  ❌ {description} - Not found")

    # Extract model architecture details
    print("\n🏗️  Model Architecture Details:")
    in_unet = False
    indent_level = 0  # current depth in the "U": grows on pooling, shrinks on upsampling

    for line in model_code.splitlines():
        stripped = line.strip()
        if 'def unet' in line:
            in_unet = True
            print("  Main U-Net function found")
            continue
        if not in_unet or not stripped:
            continue
        if stripped.startswith('def '):
            # Next function begins: the unet body is over
            break

        lowered = stripped.lower()
        has_upsampling = any(
            key in lowered for key in ('upsampling', 'upconv', 'conv2dtranspose')
        )
        # A transposed convolution is counted as upsampling, not as a plain conv
        has_conv = 'conv2d' in lowered.replace('conv2dtranspose', '')

        # Independent checks (not an elif chain) so that a convolution wrapped
        # around an upsampling call no longer hides the upsampling step.
        if has_upsampling:
            print(f"    {'  ' * indent_level}⬆️  Upsampling")
            indent_level = max(indent_level - 1, 0)
        if 'concatenate' in lowered:
            print(f"    {'  ' * indent_level}🔗 Skip Connection")
        if has_conv:
            print(f"    {'  ' * indent_level}📦 Convolutional Layer")
        if 'maxpooling' in lowered:
            print(f"    {'  ' * indent_level}⬇️  Downsampling (MaxPool)")
            indent_level += 1

unet_code_reading()

In [None]:
# Let's examine the U-Net architecture in more detail
def detailed_architecture_analysis():
    """Instantiate the project's U-Net and list its layers with their roles.

    Primary path: import ``unet`` from ``model``, build a 128x128x1 model,
    print its summary and one line per layer (type, output shape, role tag
    derived from the layer name).

    Fallback path: if anything in the primary path fails, the error is printed
    and ``model.py`` is scanned as text to count convolution and concatenate
    lines instead. Neither path raises.
    """

    print("🔬 Detailed U-Net Architecture Analysis")
    print("=" * 60)

    try:
        # Import and examine the model
        from model import unet

        # Create a small model to examine structure
        model = unet(input_size=(128, 128, 1))

        print("📊 Model Summary:")
        model.summary()

        print("\n🏗️  Layer-by-Layer Analysis:")
        for i, layer in enumerate(model.layers):
            layer_type = layer.__class__.__name__

            # `layer.output_shape` is not available on every Keras version;
            # fall back to the shape of the layer's output tensor.
            output_shape = getattr(layer, 'output_shape', None)
            if output_shape is None:
                output_shape = getattr(getattr(layer, 'output', None), 'shape', 'unknown')
            print(f"  {i:2d}. {layer_type:20} → {str(output_shape)}")

            # Highlight key U-Net components (based on the layer's name)
            layer_name = layer.name
            if 'conv2d' in layer_name and 'concat' not in layer_name:
                if 'up' in layer_name:
                    print("       ⬆️  Decoder Convolution")
                else:
                    print("       ⬇️  Encoder Convolution")
            elif 'max_pooling2d' in layer_name:
                print("       🔽 Downsampling")
            elif 'up_sampling2d' in layer_name:
                print("       🔼 Upsampling")
            elif 'concatenate' in layer_name:
                print("       🔗 Skip Connection")

    except Exception as e:
        # Best-effort notebook cell: report and degrade to a text scan.
        print(f"Error in detailed analysis: {e}")

        # Alternative analysis by reading code
        print("\n📝 Alternative Code Analysis:")
        try:
            with open('model.py', 'r', encoding='utf-8') as f:
                lines = f.readlines()

            encoder_blocks = 0
            decoder_blocks = 0
            skip_connections = 0

            for line in lines:
                line_lower = line.lower()
                if 'conv2d' in line_lower and 'maxpooling' not in line_lower:
                    # Heuristic: a conv line mentioning "up" is taken to
                    # belong to the decoder.
                    if 'up' in line_lower:
                        decoder_blocks += 1
                    else:
                        encoder_blocks += 1
                elif 'concatenate' in line_lower:
                    skip_connections += 1

            print(f"  Encoder Blocks: {encoder_blocks}")
            print(f"  Decoder Blocks: {decoder_blocks}")
            print(f"  Skip Connections: {skip_connections}")
            print(f"  U-Shape Verified: {encoder_blocks == decoder_blocks}")

        except Exception as e2:
            print(f"Could not perform code analysis: {e2}")

detailed_architecture_analysis()

In [None]:
# Create a comprehensive training and evaluation script
def create_complete_pipeline():
    """Write ``complete_pipeline.py``, an end-to-end U-Net training script.

    The generated module defines ``CompleteTGSPipeline`` (data preparation,
    model building with Dice/IoU metrics, training with callbacks, evaluation
    and plotting). It is written as UTF-8 into the current working directory
    because its messages contain non-ASCII characters.
    """

    print("🚀 Creating Complete Training Pipeline")
    print("=" * 60)

    # Notes on the generated script:
    #   * it imports PIL.Image and tensorflow, both of which it uses;
    #   * images are read with PIL as 8-bit grayscale and scaled once by 255
    #     (plt.imread already returns PNGs in [0, 1], so dividing that by 255
    #     would squash inputs and masks towards zero);
    #   * Adam takes `learning_rate`; the `lr` alias is rejected by newer Keras;
    #   * the doubled backslash in "\\n" below is intentional: it must reach
    #     the generated file as the two-character escape sequence.
    complete_script = '''
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from PIL import Image
from sklearn.model_selection import train_test_split
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from model import unet

class CompleteTGSPipeline:
    """Complete pipeline for TGS Salt Identification with U-Net"""
    
    def __init__(self, img_size=128):
        self.img_size = img_size
        self.model = None
        self.history = None
        
    def _load_pairs(self, image_dir, mask_dir):
        """Load matching image/mask PNGs as (N, size, size, 1) arrays in [0, 1]"""
        images = []
        masks = []
        size = (self.img_size, self.img_size)
        
        for img_file in sorted(os.listdir(image_dir)):
            if not img_file.endswith('.png'):
                continue
            # The mask shares the image's file name
            img = Image.open(os.path.join(image_dir, img_file)).convert('L').resize(size)
            mask = Image.open(os.path.join(mask_dir, img_file)).convert('L').resize(size)
            images.append(np.asarray(img, dtype=np.float32) / 255.0)
            masks.append(np.asarray(mask, dtype=np.float32) / 255.0)
        
        # Add channel dimension
        images = np.array(images).reshape(-1, self.img_size, self.img_size, 1)
        masks = np.array(masks).reshape(-1, self.img_size, self.img_size, 1)
        return images, masks
        
    def prepare_data(self, data_dir='tgs_data'):
        """Prepare training and validation data"""
        print("📊 Preparing data...")
        
        images, masks = self._load_pairs(
            os.path.join(data_dir, 'train', 'images'),
            os.path.join(data_dir, 'train', 'masks')
        )
        
        # Split into train and validation
        X_train, X_val, y_train, y_val = train_test_split(
            images, masks, test_size=0.2, random_state=42
        )
        
        print(f"✅ Data prepared: {X_train.shape[0]} train, {X_val.shape[0]} validation")
        return X_train, X_val, y_train, y_val
    
    def build_model(self):
        """Build U-Net model"""
        print("🏗️  Building U-Net model...")
        
        self.model = unet(input_size=(self.img_size, self.img_size, 1))
        
        # Custom compilation for salt segmentation
        self.model.compile(
            optimizer=Adam(learning_rate=1e-4),
            loss='binary_crossentropy',
            metrics=['accuracy', self.dice_coef, self.iou_score]
        )
        
        print("✅ Model built and compiled")
        return self.model
    
    def dice_coef(self, y_true, y_pred):
        """Dice coefficient metric"""
        smooth = 1.0
        y_true_f = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
        y_pred_f = tf.reshape(y_pred, [-1])
        intersection = tf.reduce_sum(y_true_f * y_pred_f)
        return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    
    def iou_score(self, y_true, y_pred):
        """Intersection over Union metric"""
        smooth = 1.0
        y_true = tf.cast(y_true, y_pred.dtype)
        intersection = tf.reduce_sum(y_true * y_pred)
        union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
        return (intersection + smooth) / (union + smooth)
    
    def train(self, epochs=50, batch_size=16):
        """Train the model"""
        print("🎯 Starting training...")
        
        # Prepare data
        X_train, X_val, y_train, y_val = self.prepare_data()
        
        # Build model
        self.build_model()
        
        # Callbacks
        callbacks = [
            ModelCheckpoint('best_tgs_unet.h5', monitor='val_loss', save_best_only=True),
            EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
            ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-7)
        ]
        
        # Train
        self.history = self.model.fit(
            X_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(X_val, y_val),
            callbacks=callbacks,
            verbose=1
        )
        
        print("✅ Training completed!")
        return self.history
    
    def evaluate(self, test_dir='tgs_data/test'):
        """Evaluate model on test set"""
        if self.model is None:
            print("❌ Model not trained")
            return
        
        print("📈 Evaluating model...")
        
        # Load test data
        test_images, test_masks = self._load_pairs(
            os.path.join(test_dir, 'images'),
            os.path.join(test_dir, 'masks')
        )
        
        # Evaluate
        results = self.model.evaluate(test_images, test_masks, verbose=0)
        metrics = ['Loss', 'Accuracy', 'Dice Coefficient', 'IoU Score']
        
        print("\\n📊 Test Results:")
        for metric, value in zip(metrics, results):
            print(f"  {metric}: {value:.4f}")
        
        return results
    
    def plot_results(self):
        """Plot training history and sample predictions"""
        if self.history is None:
            print("❌ No training history available")
            return
        
        # Plot training history
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Loss
        axes[0, 0].plot(self.history.history['loss'], label='Training Loss')
        axes[0, 0].plot(self.history.history['val_loss'], label='Validation Loss')
        axes[0, 0].set_title('Model Loss')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].legend()
        
        # Accuracy
        axes[0, 1].plot(self.history.history['accuracy'], label='Training Accuracy')
        axes[0, 1].plot(self.history.history['val_accuracy'], label='Validation Accuracy')
        axes[0, 1].set_title('Model Accuracy')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].legend()
        
        # Dice Coefficient
        if 'dice_coef' in self.history.history:
            axes[1, 0].plot(self.history.history['dice_coef'], label='Training Dice')
            axes[1, 0].plot(self.history.history['val_dice_coef'], label='Validation Dice')
            axes[1, 0].set_title('Dice Coefficient')
            axes[1, 0].set_ylabel('Dice')
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].legend()
        
        # IoU Score
        if 'iou_score' in self.history.history:
            axes[1, 1].plot(self.history.history['iou_score'], label='Training IoU')
            axes[1, 1].plot(self.history.history['val_iou_score'], label='Validation IoU')
            axes[1, 1].set_title('IoU Score')
            axes[1, 1].set_ylabel('IoU')
            axes[1, 1].set_xlabel('Epoch')
            axes[1, 1].legend()
        
        plt.tight_layout()
        plt.show()

# Run complete pipeline
if __name__ == "__main__":
    pipeline = CompleteTGSPipeline(img_size=128)
    
    # For demo purposes, we'll just show the structure
    # In real scenario: pipeline.train(epochs=50)
    print("🚀 Complete pipeline ready!")
    print("To train: pipeline.train(epochs=50)")
    print("To evaluate: pipeline.evaluate()")
    print("To plot results: pipeline.plot_results()")
'''

    with open('complete_pipeline.py', 'w', encoding='utf-8') as f:
        f.write(complete_script)

    print("📜 Created complete pipeline: complete_pipeline.py")

    print("\n✅ Complete U-Net pipeline created!")
    # A real newline here; the doubled backslash used to print a literal "\n".
    print("\nThis includes:")
    print("  • Data preparation and loading")
    print("  • U-Net model building")
    print("  • Training with callbacks")
    print("  • Evaluation with segmentation metrics")
    print("  • Visualization of results")

create_complete_pipeline()