In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def load_dataset(dataset_path, max_samples_per_class=50):
    """Load an image-folder dataset, falling back to synthetic data.

    ``dataset_path`` is expected to contain one sub-directory per class.
    Sub-directories are sorted by name and their position in that order is
    used as the integer label. Every readable image is converted from BGR to
    RGB and resized to 64x64.

    Args:
        dataset_path: Directory holding one sub-directory per class.
        max_samples_per_class: Maximum number of images loaded per class.
            Files are visited in sorted name order so the chosen subset is
            reproducible between runs.

    Returns:
        On success: ``(images, labels, class_names)`` where ``images`` and
        ``labels`` are NumPy arrays and ``class_names`` is the sorted list of
        class directory names.

        On fallback (path is not a directory, or no image could be loaded):
        ``((synthetic_images, synthetic_labels), [], [])``. The first element
        is then a *tuple*, which is how callers detect the fallback. Both
        fallback situations return this same 3-element shape so that a
        3-way unpack of the result always works.
    """
    image_extensions = ('.jpg', '.png', '.jpeg')

    # A missing path and a path that is a regular file are both unusable.
    if not os.path.isdir(dataset_path):
        return generate_synthetic_data(), [], []

    images = []
    labels = []
    class_names = []

    class_dirs = sorted(
        entry for entry in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, entry))
    )

    for class_idx, class_name in enumerate(class_dirs):
        class_path = os.path.join(dataset_path, class_name)
        class_names.append(class_name)

        # Case-insensitive match so that e.g. "IMG_001.JPG" is not skipped.
        image_files = sorted(
            name for name in os.listdir(class_path)
            if name.lower().endswith(image_extensions)
        )
        loaded_count = 0

        for image_file in image_files:
            if loaded_count >= max_samples_per_class:
                break

            image_path = os.path.join(class_path, image_file)

            try:
                if os.path.getsize(image_path) == 0:
                    continue

                img = cv2.imread(image_path)
                if img is None:
                    # OpenCV signals an undecodable file by returning None.
                    continue

                img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                images.append(cv2.resize(img_rgb, (64, 64)))
                labels.append(class_idx)
                loaded_count += 1
            except Exception:
                # Best effort: skip unreadable/corrupt files, but do not
                # swallow KeyboardInterrupt/SystemExit like a bare except.
                continue

    if not images:
        # Same shape as the "path missing" fallback above.
        return generate_synthetic_data(), [], []

    return np.array(images), np.array(labels), class_names

def generate_synthetic_data(variations_per_class=20, seed=None):
    """Generate noisy 64x64 RGB stand-ins for a 52-class card dataset.

    One class is produced for every (card_idx, suit_idx) pair with
    ``card_idx`` in 0..12 and ``suit_idx`` in 0..3; the label is
    ``card_idx * 4 + suit_idx``. Each class has one clean template to which
    independent Gaussian noise (mean 0, std 15) is added per variation.

    Template layout:
      * ``card_idx == 0``: a single centred square.
      * ``card_idx > 9``: one large block with a white band inside it.
      * otherwise: ``min(card_idx + 1, 8)`` small squares laid out in up to
        two columns of four.

    NOTE(review): because of the cap at 8 squares, card_idx 7, 8 and 9 share
    the same template, and suits 0/1 (and 2/3) share a colour, so several
    labels are visually indistinguishable apart from noise.

    Args:
        variations_per_class: Number of noisy images generated per class
            (default 20, i.e. 1040 images in total).
        seed: Optional seed. When given, noise comes from a private
            ``np.random.default_rng(seed)`` generator and the output is
            reproducible. When ``None``, the global ``np.random`` state is
            used.

    Returns:
        ``(images, labels)``: ``images`` is a uint8 array of shape
        ``(52 * variations_per_class, 64, 64, 3)`` and ``labels`` the
        matching 1-D integer array.

    Raises:
        ValueError: If ``variations_per_class`` is negative.
    """
    if variations_per_class < 0:
        raise ValueError("variations_per_class must be non-negative")

    image_shape = (64, 64, 3)
    rng = np.random.default_rng(seed) if seed is not None else None

    images = []
    labels = []

    for card_idx in range(13):
        for suit_idx in range(4):
            base_color = [200, 50, 50] if suit_idx < 2 else [50, 50, 50]

            # The clean template only depends on (card_idx, suit_idx), so it
            # is drawn once per class rather than once per variation.
            template = np.zeros(image_shape, dtype=np.uint8)
            if card_idx == 0:
                template[25:40, 25:40] = base_color
            elif card_idx > 9:
                template[15:50, 20:45] = base_color
                template[20:30, 25:40] = [255, 255, 255]
            else:
                for i in range(min(card_idx + 1, 8)):
                    y = 10 + (i % 4) * 12
                    x = 15 + (i // 4) * 25
                    template[y:y+8, x:x+8] = base_color
            template_float = template.astype(float)

            label = card_idx * 4 + suit_idx
            for _ in range(variations_per_class):
                if rng is None:
                    noise = np.random.normal(0, 15, image_shape)
                else:
                    noise = rng.normal(0, 15, image_shape)
                noisy = np.clip(template_float + noise, 0, 255).astype(np.uint8)

                images.append(noisy)
                labels.append(label)

    # The reshape keeps a well-defined (0, 64, 64, 3) shape for empty output.
    images_array = np.array(images, dtype=np.uint8).reshape((-1,) + image_shape)
    labels_array = np.array(labels, dtype=int)
    return images_array, labels_array

# Load the card dataset (load_dataset falls back to synthetic data itself)
dataset_path = '../../data/train_dataset'
card_images, card_labels, card_class_names = load_dataset(dataset_path)

# A tuple in the first slot marks the synthetic fallback: unpack the real
# arrays and invent one placeholder name per distinct label.
if isinstance(card_images, tuple):
    card_images, card_labels = card_images
    n_synthetic_classes = len(np.unique(card_labels))
    card_class_names = [f"synthetic_{i}" for i in range(n_synthetic_classes)]

print(f"Dataset loaded: {card_images.shape[0]} images, {len(np.unique(card_labels))} classes")

# Preview: a 2x5 grid of randomly drawn samples (drawn with replacement)
n_available = len(card_images)
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for slot, ax in enumerate(axes.ravel()):
    if slot >= n_available:
        # Fewer images than grid cells: leave the remaining axes untouched.
        continue
    idx = np.random.randint(0, n_available)
    ax.imshow(card_images[idx])
    ax.set_title(f'Class {card_labels[idx]}')
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

class SOM_GPU:
    """Self-Organizing Map whose prototype grid is stored as a torch tensor.

    The map is a ``map_height x map_width`` grid of prototype vectors of
    length ``input_dim``, kept on ``device`` (CUDA when available unless a
    device is given explicitly). Prototypes are initialised uniformly in
    [0, 255).

    Attributes:
        prototypes: Tensor of shape ``(map_height, map_width, input_dim)``.
        is_trained: Set to True once ``train`` has completed.
        training_errors: One entry per trained epoch (appended across calls):
            the mean squared distance between each sample and its best
            matching prototype, measured before that sample's update.
    """
    def __init__(self, map_size=(10, 10), input_dim=12288, device=None):
        self.map_height, self.map_width = map_size
        self.input_dim = input_dim
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        self.prototypes = torch.rand(self.map_height, self.map_width, self.input_dim, device=self.device) * 255
        self.is_trained = False
        self.training_errors = []

        print(f"SOM initialized on {self.device}")

    def _as_flat_vector(self, data):
        """Convert an array-like or tensor to a flat float32 tensor on the SOM device."""
        if isinstance(data, torch.Tensor):
            tensor = data.detach().to(device=self.device, dtype=torch.float32)
        else:
            # Going through one contiguous NumPy array avoids torch's slow
            # element-by-element path for nested Python sequences.
            tensor = torch.as_tensor(np.asarray(data, dtype=np.float32), device=self.device)
        return tensor.reshape(-1)

    def _find_bmu(self, input_vector):
        """Return the ``(row, col)`` of the prototype closest to ``input_vector``.

        Closeness is the Euclidean distance along the feature axis. The
        result is a tuple of plain Python ints.
        """
        diffs = self.prototypes - input_vector
        dists = torch.norm(diffs, dim=2)
        # argmin without dim returns an index into the flattened (H*W) grid.
        bmu_index = torch.argmin(dists)
        return divmod(bmu_index.item(), self.map_width)

    def _gaussian_neighborhood(self, bmu_pos, sigma):
        """Return an ``(H, W)`` tensor of Gaussian weights centred on ``bmu_pos``.

        The weight of cell ``(i, j)`` is ``exp(-d^2 / (2 * sigma^2))`` where
        ``d`` is its grid distance to the BMU, so the BMU itself gets 1.
        """
        bmu_i, bmu_j = bmu_pos
        i_coords = torch.arange(self.map_height, device=self.device).view(-1, 1)
        j_coords = torch.arange(self.map_width, device=self.device).view(1, -1)
        distance_sq = (i_coords - bmu_i)**2 + (j_coords - bmu_j)**2
        return torch.exp(-distance_sq / (2 * sigma**2))

    def train(self, training_data, epochs=100, initial_learning_rate=0.1, initial_sigma=3.0):
        """Train the SOM with per-sample (online) updates.

        Each epoch visits all samples in a fresh random order. The learning
        rate decays as ``initial_learning_rate * exp(-5 * epoch / epochs)``
        and the neighbourhood width as
        ``initial_sigma * exp(-3 * epoch / epochs)``.

        Args:
            training_data: Array-like or tensor with one sample per entry
                along the first axis. Samples with more than one dimension
                (e.g. images) are flattened.
            epochs: Number of passes over the data.
            initial_learning_rate: Learning rate at epoch 0.
            initial_sigma: Neighbourhood width at epoch 0.

        Raises:
            ValueError: If ``training_data`` is empty or its flattened sample
                size differs from ``input_dim``.
        """
        if isinstance(training_data, torch.Tensor):
            data = training_data.detach().to(device=self.device, dtype=torch.float32)
        else:
            # One array conversion instead of torch.tensor(list_of_arrays),
            # which is very slow for large lists of ndarrays.
            data = torch.as_tensor(np.asarray(training_data, dtype=np.float32), device=self.device)

        if data.dim() == 0 or data.numel() == 0:
            raise ValueError("training_data must contain at least one sample")

        data = data.reshape(data.shape[0], -1)
        if data.shape[1] != self.input_dim:
            raise ValueError(
                f"Expected samples of dimension {self.input_dim}, got {data.shape[1]}"
            )
        n_samples = data.shape[0]

        for epoch in tqdm(range(epochs), desc="Training SOM"):
            progress = epoch / epochs
            # Plain Python floats: no per-epoch tensor allocation and no
            # CPU-scalar / device-tensor mixing in the update below.
            lr = initial_learning_rate * float(np.exp(-5.0 * progress))
            sigma = initial_sigma * float(np.exp(-3.0 * progress))

            total_error = 0.0
            for idx in torch.randperm(n_samples).tolist():
                vector = data[idx]
                bmu_pos = self._find_bmu(vector)
                bmu_i, bmu_j = bmu_pos

                # Squared distance to the BMU before it is moved.
                error = torch.sum((self.prototypes[bmu_i, bmu_j] - vector) ** 2)
                total_error += error.item()

                neighborhood = self._gaussian_neighborhood(bmu_pos, sigma)
                influence = lr * neighborhood.unsqueeze(2)
                # Pull every prototype towards the sample, weighted by its
                # grid distance to the BMU.
                self.prototypes += influence * (vector - self.prototypes)

            self.training_errors.append(total_error / n_samples)

        self.is_trained = True
        print("SOM training completed")

    def visualize_prototypes(self, image_shape=(64, 64, 3), title="SOM Prototypes"):
        """Show every prototype as an image in a grid matching the map layout.

        Args:
            image_shape: Shape each prototype vector is reshaped to.
            title: Figure title.
        """
        if not self.is_trained:
            print("SOM not trained yet!")
            return

        # squeeze=False keeps `axes` 2-D even for 1xN or Nx1 maps.
        fig, axes = plt.subplots(self.map_height, self.map_width, figsize=(15, 15), squeeze=False)
        prototypes_cpu = self.prototypes.detach().cpu()

        for i in range(self.map_height):
            for j in range(self.map_width):
                img = prototypes_cpu[i, j].reshape(image_shape)
                img_normalized = torch.clamp(img, 0, 255).numpy().astype('uint8')
                axes[i, j].imshow(img_normalized)
                axes[i, j].axis('off')

        plt.suptitle(title, fontsize=16)
        plt.tight_layout()
        plt.show()

    def plot_training_progress(self):
        """Plot the recorded per-epoch training error."""
        if not self.training_errors:
            print("No training data available")
            return

        plt.figure(figsize=(10, 6))
        plt.plot(self.training_errors, 'b-', linewidth=2)
        plt.title('SOM Training Error')
        plt.xlabel('Epoch')
        plt.ylabel('Quantization Error')
        plt.grid(True, alpha=0.3)
        plt.show()

    def compress(self, image):
        """Return the ``(row, col)`` of the best matching prototype for ``image``.

        Args:
            image: ndarray, tensor or nested sequence; it is flattened and
                must contain exactly ``input_dim`` values.

        Raises:
            ValueError: If the SOM is untrained or the input size is wrong.
        """
        if not self.is_trained:
            raise ValueError("SOM not trained!")

        vector = self._as_flat_vector(image)
        if vector.numel() != self.input_dim:
            raise ValueError(
                f"Expected {self.input_dim} values, got {vector.numel()}"
            )

        return self._find_bmu(vector)

    def decompress(self, compressed_coords, image_shape=(64, 64, 3)):
        """Return the prototype stored at ``compressed_coords`` as a NumPy image.

        Args:
            compressed_coords: ``(row, col)`` as returned by ``compress``.
            image_shape: Shape of the returned array; defaults to the
                64x64 RGB layout matching the default ``input_dim``.

        Raises:
            ValueError: If the SOM is untrained.
        """
        if not self.is_trained:
            raise ValueError("SOM not trained!")

        bmu_i, bmu_j = compressed_coords
        return self.prototypes[bmu_i, bmu_j].detach().cpu().numpy().reshape(image_shape)

In [None]:
# Train SOM and visualize prototypes
print("Training SOM on card dataset...")

# Prepare training data as ONE 2-D array (one flattened image per row).
# Handing torch a Python list of separate ndarrays is very slow and warns.
training_subset = np.asarray(card_images[:1000])
training_images = training_subset.reshape(len(training_subset), -1)

# Create and train SOM; the input size is derived from the data instead of
# being hard-coded so that a different image size cannot silently mismatch.
som = SOM_GPU(map_size=(10, 10), input_dim=training_images.shape[1])
som.train(training_images, epochs=100)

# Visualize results
som.visualize_prototypes(tuple(training_subset.shape[1:]), "SOM Prototypes - Card Dataset")
som.plot_training_progress()

print(f"SOM training completed. Final error: {som.training_errors[-1]:.2f}")

In [None]:
# Test SOM compression and decompression
from sklearn.metrics import mean_squared_error

def test_som_compression(som, test_images, test_labels):
    """Test SOM compression and decompression"""
    compressed_coords = []
    reconstructed_images = []
    
    for img in test_images:
        # Compress
        coords = som.compress(img)
        compressed_coords.append(coords)
        
        # Decompress
        reconstructed = som.decompress(coords)
        reconstructed_images.append(reconstructed)
    
    return compressed_coords, reconstructed_images

def visualize_compression_results(original_images, reconstructed_images, compressed_coords, labels):
    """Plot originals, reconstructions and absolute error, one column per image.

    Row 0 shows the original with its label, row 1 the reconstruction with
    its BMU coordinates, and row 2 the per-pixel absolute difference with
    the mean squared error in the title.

    Args:
        original_images: Sequence of image arrays.
        reconstructed_images: Reconstructions, same order and shapes.
        compressed_coords: BMU coordinates, same order.
        labels: Labels, same order.
    """
    n_images = len(original_images)
    if n_images == 0:
        # plt.subplots cannot create a grid with zero columns.
        print("No images to visualize")
        return

    # squeeze=False keeps `axes` 2-D so axes[row, col] also works for a
    # single image.
    fig, axes = plt.subplots(3, n_images, figsize=(18, 9), squeeze=False)

    for i in range(n_images):
        original = np.asarray(original_images[i])
        reconstructed = np.asarray(reconstructed_images[i])
        coords = compressed_coords[i]

        # Original
        axes[0, i].imshow(original.astype(np.uint8))
        axes[0, i].set_title(f'Original\nLabel: {labels[i]}')
        axes[0, i].axis('off')

        # Reconstructed (clip first: casting out-of-range floats to uint8
        # would wrap around instead of saturating)
        axes[1, i].imshow(np.clip(reconstructed, 0, 255).astype(np.uint8))
        axes[1, i].set_title(f'Reconstructed\nBMU: {coords}')
        axes[1, i].axis('off')

        # Error: absolute difference lies in [0, 255] for 8-bit image data
        error = np.abs(original.astype(float) - reconstructed.astype(float))
        axes[2, i].imshow(np.clip(error, 0, 255).astype(np.uint8))

        # Calculate MSE
        mse = mean_squared_error(original.flatten(), reconstructed.flatten())
        axes[2, i].set_title(f'Error\nMSE: {mse:.1f}')
        axes[2, i].axis('off')

    plt.suptitle('SOM Compression/Decompression Results', fontsize=16)
    plt.tight_layout()
    plt.show()

# Test compression
# Preferred sample positions; they need more than 500 images. For smaller
# datasets fall back to evenly spaced indices so the cell cannot raise
# IndexError.
requested_indices = [0, 100, 200, 300, 400, 500]
if len(card_images) > max(requested_indices):
    test_indices = requested_indices
else:
    n_test = min(len(requested_indices), len(card_images))
    test_indices = sorted(set(
        np.linspace(0, len(card_images) - 1, num=n_test, dtype=int).tolist()
    ))

test_images = [card_images[i] for i in test_indices]
test_labels = [card_labels[i] for i in test_indices]

compressed_coords, reconstructed_images = test_som_compression(som, test_images, test_labels)

# Visualize results
visualize_compression_results(test_images, reconstructed_images, compressed_coords, test_labels)

# Calculate compression statistics
# Values per image divided by the 2 BMU coordinates that replace them
# (12288 / 2 for 64x64x3 images).
values_per_image = int(np.prod(card_images.shape[1:]))
compression_ratio = values_per_image / 2
avg_mse = np.mean([mean_squared_error(orig.flatten(), recon.flatten())
                   for orig, recon in zip(test_images, reconstructed_images)])

print(f"\nCompression Statistics:")
print(f"  - Compression ratio: {compression_ratio:.1f}x")
print(f"  - Average MSE: {avg_mse:.2f}")
print(f"  - BMU coordinates used: {set(compressed_coords)}")

In [None]:
# SOM Generator for continuous interpolation
class SOM_Generator:
    """Generate images from a trained SOM by interpolating between prototypes.

    Positions are continuous ``(row, col)`` coordinates on the map grid.
    Any position outside the grid is clamped to the nearest edge, for both
    integer and fractional coordinates.

    Args:
        som: Trained SOM exposing ``is_trained``, ``map_height``,
            ``map_width`` and a ``prototypes`` tensor of shape (H, W, D).
        image_shape: Shape generated vectors are reshaped to; defaults to
            64x64 RGB.

    Raises:
        ValueError: If ``som`` has not been trained.
    """
    def __init__(self, som, image_shape=(64, 64, 3)):
        if not som.is_trained:
            raise ValueError("SOM must be trained first!")
        self.som = som
        self.image_shape = tuple(image_shape)

    def _bilinear_interpolation(self, float_i, float_j):
        """Return the bilinear blend of the four prototypes around a position.

        The position is clamped to the grid *before* the weights are
        computed, so the weights always lie in [0, 1].
        """
        max_i = self.som.map_height - 1
        max_j = self.som.map_width - 1

        # Clamp to valid range
        float_i = min(max(float(float_i), 0.0), float(max_i))
        float_j = min(max(float(float_j), 0.0), float(max_j))

        i_low, i_high = int(np.floor(float_i)), int(np.ceil(float_i))
        j_low, j_high = int(np.floor(float_j)), int(np.ceil(float_j))

        # Interpolation weights (0 on a grid line, so the result is then
        # exactly the lower prototype)
        w_i = float_i - i_low
        w_j = float_j - j_low

        # Get corner prototypes
        top_left = self.som.prototypes[i_low, j_low]
        top_right = self.som.prototypes[i_low, j_high]
        bottom_left = self.som.prototypes[i_high, j_low]
        bottom_right = self.som.prototypes[i_high, j_high]

        # Interpolate along columns first, then along rows
        top = (1 - w_j) * top_left + w_j * top_right
        bottom = (1 - w_j) * bottom_left + w_j * bottom_right
        result = (1 - w_i) * top + w_i * bottom

        return result

    def generate_from_position(self, position):
        """Generate an image from a continuous ``(row, col)`` map position.

        Integer positions yield the stored prototype; fractional positions
        yield a bilinear blend. Out-of-range positions are clamped to the
        grid instead of raising IndexError or wrapping around via negative
        indexing.

        Returns:
            NumPy array of shape ``image_shape``.
        """
        i, j = position
        synthetic_vector = self._bilinear_interpolation(i, j)
        return synthetic_vector.reshape(self.image_shape).detach().cpu().numpy()

    def generate_from_latent(self, latent_vector, latent_range=(-1, 1)):
        """Generate an image from 2-D latent coordinates.

        Each latent coordinate is mapped linearly from ``latent_range`` onto
        the map: the first onto rows ``[0, map_height - 1]``, the second
        onto columns ``[0, map_width - 1]``.

        Raises:
            ValueError: If ``latent_range`` has zero width.
        """
        x, y = latent_vector
        latent_min, latent_max = latent_range
        span = latent_max - latent_min
        if span == 0:
            raise ValueError("latent_range must have non-zero width")

        # Map to SOM coordinates
        norm_x = (x - latent_min) / span * (self.som.map_height - 1)
        norm_y = (y - latent_min) / span * (self.som.map_width - 1)

        return self.generate_from_position((norm_x, norm_y))

In [None]:
# Test SOM generation
generator = SOM_Generator(som)


def _show_generated_row(images, titles, suptitle):
    """Show generated images side by side in one row, one title per image."""
    # squeeze=False keeps `axes` 2-D so a single image works as well.
    fig, axes = plt.subplots(1, len(images), figsize=(20, 4), squeeze=False)

    for ax, image, title in zip(axes[0], images, titles):
        # Generated vectors are floats; clip before casting for display.
        ax.imshow(np.clip(image, 0, 255).astype(np.uint8))
        ax.set_title(title)
        ax.axis('off')

    plt.suptitle(suptitle, fontsize=16)
    plt.tight_layout()
    plt.show()


# Generate images from different latent coordinates
latent_vectors = [(-1, -1), (-0.5, 0), (0, 0), (0.5, 0.5), (1, 1)]
_show_generated_row(
    [generator.generate_from_latent(latent) for latent in latent_vectors],
    [f"Latent: {latent}" for latent in latent_vectors],
    'SOM Generated Images from Latent Space',
)

# Generate images from SOM coordinates
som_positions = [(0, 0), (2.5, 2.5), (5, 5), (7.5, 7.5), (9, 9)]
_show_generated_row(
    [generator.generate_from_position(position) for position in som_positions],
    [f"SOM: {position}" for position in som_positions],
    'SOM Generated Images from Map Coordinates',
)

In [None]:
# End of notebook