In [3]:
import io
import os
from datetime import datetime, timezone
from typing import Callable, Dict, List

import torch
import torch.nn as nn
import torch.optim as optim
from dotenv import load_dotenv
from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

load_dotenv()


class ImageDataset(Dataset):
    """Map-style dataset over raw encoded image bytes paired with integer labels.

    Images are decoded lazily when an item is requested, converted to RGB and,
    if a ``transform`` was supplied, passed through it.
    """

    def __init__(self, images: List[bytes], labels: List[int], transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        raw = self.images[idx]
        decoded = Image.open(io.BytesIO(raw)).convert('RGB')
        sample = self.transform(decoded) if self.transform else decoded
        return sample, self.labels[idx]


class SimpleCNN(nn.Module):
    """Small three-stage convolutional classifier.

    Each stage is a 3x3 conv (padding 1) -> ReLU -> 2x2/stride-2 max-pool, taking
    3 input channels to 32, 64 and then 128. An adaptive average pool then fixes the
    feature map at 28x28, so the linear head (sized 128 * 28 * 28) accepts any input
    resolution. For 224x224 inputs the three pools already produce 28x28 and the
    adaptive pool is an exact identity.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Parameter-free, so state_dicts saved before this layer existed still load
        # (keys remain features.* / classifier.*).
        self.pool = nn.AdaptiveAvgPool2d((28, 28))
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 28 * 28, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``."""
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


class ImageClassifier:
    """Train and serve a single in-memory SimpleCNN image classifier.

    Attributes:
        model: the active ``SimpleCNN``; ``None`` until ``train`` succeeds or a
            caller assigns one.
        classes: class names, where index ``i`` labels model output ``i``.
        device: CUDA when available, otherwise CPU.
        transform: preprocessing shared by training and prediction: resize to
            224x224, convert to a tensor, normalize with the mean/std below.
    """

    def __init__(self):
        self.model = None
        self.classes = []
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _run_epoch(self, model, loader, criterion, optimizer=None) -> tuple:
        """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

        Returns:
            ``(loss_sum, correct, total)``: the sum of per-batch mean losses, the
            number of correctly classified samples, and the number of samples seen.
        """
        training = optimizer is not None
        model.train(training)
        loss_sum, correct, total = 0.0, 0, 0
        # Autograd bookkeeping is only needed when the optimizer will step.
        with torch.set_grad_enabled(training):
            for inputs, targets in loader:
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)
                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                loss_sum += loss.item()
                correct += (outputs.argmax(dim=1) == targets).sum().item()
                total += targets.size(0)
        return loss_sum, correct, total

    def train(
        self,
        images: List[bytes],
        labels: List[str],
        epochs: int,
        on_epoch_end: Callable[[int, float, float], None]
    ) -> Dict[str, object]:
        """Train a fresh SimpleCNN on ``images``/``labels`` and make it the active model.

        The data is split randomly 80/20 into train/validation sets. After each
        epoch ``on_epoch_end(epoch_number, mean_train_batch_loss, train_accuracy)``
        is called. The weights from the epoch with the highest validation accuracy
        are kept, so the returned accuracy describes the installed model.
        ``self.model`` and ``self.classes`` are replaced only after training
        finishes, so any failure leaves the previous model and classes in place.

        Args:
            images: encoded image files, one per entry in ``labels``.
            labels: class name for each image.
            epochs: number of passes over the training split.
            on_epoch_end: per-epoch progress callback.

        Returns:
            ``{'final_accuracy': best validation accuracy, 'classes': sorted class names}``.

        Raises:
            ValueError: if ``images`` and ``labels`` differ in length, or fewer than
                2 images are given (the training split would be empty).
        """
        if len(images) != len(labels):
            raise ValueError(
                f"Got {len(images)} images but {len(labels)} labels; they must match"
            )
        if len(images) < 2:
            # int(0.8 * 1) == 0 would leave nothing to train on.
            raise ValueError("At least 2 images are required to train")

        classes = sorted(set(labels))
        label_to_idx = {label: idx for idx, label in enumerate(classes)}
        numeric_labels = [label_to_idx[label] for label in labels]

        dataset = ImageDataset(images, numeric_labels, self.transform)
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)

        # Build into locals; only commit to self once training has succeeded.
        model = SimpleCNN(len(classes)).to(self.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        best_val_acc = 0.0
        best_state = None

        for epoch in range(epochs):
            train_loss, train_correct, train_total = self._run_epoch(
                model, train_loader, criterion, optimizer
            )
            _, val_correct, val_total = self._run_epoch(model, val_loader, criterion)
            val_accuracy = val_correct / val_total if val_total > 0 else 0.0

            if best_state is None or val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                # Clone: state_dict() tensors alias the live parameters that later epochs update.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

            on_epoch_end(epoch + 1, train_loss / len(train_loader), train_correct / train_total)

        if best_state is not None:
            model.load_state_dict(best_state)
        model.eval()

        self.model = model
        self.classes = classes
        return {
            'final_accuracy': best_val_acc,
            'classes': classes
        }

    def predict(self, image_bytes: bytes) -> List[Dict[str, object]]:
        """Score one encoded image against every class.

        Returns:
            ``[{'class': name, 'confidence': softmax probability}, ...]`` sorted by
            confidence, highest first.

        Raises:
            ValueError: if no model has been trained or loaded yet.
        """
        if self.model is None:
            raise ValueError("Model not trained yet")

        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        image_tensor = self.transform(image).unsqueeze(0).to(self.device)

        self.model.eval()
        with torch.no_grad():
            logits = self.model(image_tensor)
            probs = torch.softmax(logits, dim=1)[0].cpu().tolist()

        results = [
            {'class': class_name, 'confidence': float(prob)}
            for class_name, prob in zip(self.classes, probs)
        ]
        return sorted(results, key=lambda r: r['confidence'], reverse=True)

    def save_model(self, path: str):
        """Intentionally a no-op: this runner keeps models in memory only.

        ``path`` is accepted for interface compatibility and ignored.

        Raises:
            ValueError: if no model has been trained yet.
        """
        if self.model is None:
            raise ValueError("Model not trained yet")
        # no-op: intentionally do not write to disk
        return None

    def load_model(self, path: str):
        """Always raises: loading from disk is disabled in the in-memory runner."""
        raise RuntimeError("Loading from disk is disabled in the in-memory runner")


app = Flask(__name__, static_folder='static')
# Enable flask_cors on the app with its default settings.
CORS(app)

# Model registry kept entirely in process memory: nothing is written to disk and
# all entries are lost when the process exits.
MODELS_META: List[Dict] = []  # metadata dicts, newest first (returned by /api/models)
MODELS_IN_MEMORY: Dict[str, Dict] = {}  # name -> {'state_dict': ..., 'classes': [...], ...}

def _load_models_meta():
    return MODELS_META


def _save_models_meta(records):
    # no-op for disk; keep meta in memory
    global MODELS_META
    MODELS_META = list(records)


def add_model_meta(record):
    """Register a model's metadata record as the newest entry.

    The record is put at the front of ``MODELS_META``. When it has a ``name``:
    - Any older rows with that same name are dropped. ``MODELS_IN_MEMORY`` is keyed
      by name and holds only one entry per name, so older rows would describe
      weights that no longer exist.
    - ``MODELS_IN_MEMORY[name]`` is updated with the record's classes and the
      record itself. Any ``state_dict`` (and other fields) already stored for the
      name are preserved instead of being reset to ``None``.
    """
    name = record.get('name')
    if name:
        MODELS_META[:] = [r for r in MODELS_META if r.get('name') != name]
    MODELS_META.insert(0, record)

    if name:
        existing = MODELS_IN_MEMORY.get(name, {})
        MODELS_IN_MEMORY[name] = {
            **existing,
            'state_dict': existing.get('state_dict'),
            'classes': record.get('classes', []),
            'meta': record,
        }

# Single shared classifier used by the train, predict and load-model routes.
classifier = ImageClassifier()
# Per-epoch metrics of the most recent /api/train call; train_model rebinds it per run.
training_metrics: List[Dict] = []


@app.route('/')
def index():
    """Serve ``index.html`` from the ``static`` directory."""
    static_dir, page = 'static', 'index.html'
    return send_from_directory(static_dir, page)


@app.route('/api/train', methods=['POST'])
def train_model():
    """Train a new classifier from uploaded images and register it in memory.

    Form fields:
        images: uploaded image files, at least 2.
        labels: one label per image, in the same order; at least 2 distinct values.
        epochs: positive integer (default 3).
        model_name: registry key (default ``'image-classifier'``).

    Returns JSON with ``success``, ``accuracy`` (best validation accuracy),
    ``classes`` and per-epoch ``metrics``. Responds 400 for invalid input and 500
    for unexpected failures. Registry write errors are logged but do not fail
    the request.
    """
    global training_metrics
    training_metrics = []

    try:
        files = request.files.getlist('images')
        labels = request.form.getlist('labels')
        model_name = request.form.get('model_name', 'image-classifier')

        try:
            epochs = int(request.form.get('epochs', 3))  # Reduced default epochs for speed
        except (TypeError, ValueError):
            return jsonify({'error': 'epochs must be an integer'}), 400
        if epochs < 1:
            return jsonify({'error': 'epochs must be at least 1'}), 400

        if len(files) < 2:
            return jsonify({'error': 'At least 2 images required'}), 400

        if len(set(labels)) < 2:
            return jsonify({'error': 'At least 2 different labels required'}), 400

        # Without this, a short label list surfaces later as an opaque IndexError.
        if len(labels) != len(files):
            return jsonify({'error': 'Each image needs exactly one label'}), 400

        image_bytes = [file.read() for file in files]

        def on_epoch_end(epoch: int, loss: float, accuracy: float):
            training_metrics.append({
                'epoch': epoch,
                'loss': loss,
                'accuracy': accuracy
            })
            print(f"Epoch {epoch}: Loss={loss:.4f}, Accuracy={accuracy:.4f}")

        result = classifier.train(image_bytes, labels, epochs, on_epoch_end)

        # Keep model in memory instead of creating files on disk (best-effort).
        try:
            created_at = datetime.now(timezone.utc).isoformat()
            accuracy = float(result['final_accuracy'])
            saved_state = classifier.model.state_dict() if classifier.model is not None else None
            record = {
                'name': model_name,
                'classes': result['classes'],
                'accuracy': accuracy,
                'epochs': epochs,
                'total_images': len(files),
                'created_at': created_at,
            }
            add_model_meta(record)
            # add_model_meta writes its own MODELS_IN_MEMORY entry for this name,
            # so the weights are stored afterwards to guarantee they are not replaced.
            MODELS_IN_MEMORY[model_name] = {
                'state_dict': saved_state,
                'classes': result['classes'],
                'accuracy': accuracy,
                'epochs': epochs,
                'total_images': len(files),
                'created_at': created_at,
                'meta': record,
            }
        except Exception as e:
            print(f"Model registry write error: {e}")

        return jsonify({
            'success': True,
            'accuracy': result['final_accuracy'],
            'classes': result['classes'],
            'metrics': training_metrics
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500


@app.route('/api/predict', methods=['POST'])
def predict():
    """Classify the uploaded ``image`` file with the active model.

    Returns ``{'success': True, 'predictions': [...]}``. Responds 400 when no image
    is sent or the classifier raises ValueError (e.g. no model yet), and 500 for
    any other error.
    """
    try:
        if 'image' not in request.files:
            return jsonify({'error': 'No image provided'}), 400
        payload = request.files['image'].read()
        ranked = classifier.predict(payload)
        return jsonify({'success': True, 'predictions': ranked})
    except ValueError as err:
        return jsonify({'error': str(err)}), 400
    except Exception as err:
        return jsonify({'error': str(err)}), 500


@app.route('/api/metrics', methods=['GET'])
def get_metrics():
    """Return the per-epoch metrics recorded by the most recent training run."""
    payload = {'metrics': training_metrics}
    return jsonify(payload)


@app.route('/api/models', methods=['GET'])
def get_models():
    """List metadata for every registered model, newest first."""
    try:
        return jsonify({'models': _load_models_meta()})
    except Exception as err:
        return jsonify({'error': str(err)}), 500


@app.route('/api/load-model/<model_name>', methods=['POST'])
def load_model(model_name):
    """Make a previously trained in-memory model the active classifier.

    Responds 404 when the name is unknown or has no stored weights, and 500 if the
    weights fail to load. The active model and class list are replaced only after
    the weights have loaded successfully.
    """
    try:
        entry = MODELS_IN_MEMORY.get(model_name)
        if entry is None:
            return jsonify({'error': 'Model not found in memory'}), 404

        state = entry.get('state_dict')
        classes = entry.get('classes', [])

        if state is None:
            return jsonify({'error': 'Model weights are not available in memory'}), 404

        # Build and load into a local model first, so a failing load_state_dict
        # (e.g. a class-count mismatch) leaves the current classifier untouched.
        model = SimpleCNN(len(classes)).to(classifier.device)
        model.load_state_dict(state)
        model.eval()

        classifier.model = model
        classifier.classes = classes

        return jsonify({
            'success': True,
            'classes': classifier.classes
        })

    except Exception as e:
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    # In-memory mode: nothing here creates files or directories.
    # use_reloader=False runs the app in a single process (no reloader child)
    # while keeping debug mode on.
    # NOTE: Flask's docs warn against using debug mode in production.
    app.run(port=5000, debug=True, use_reloader=False)

 * Serving Flask app '__main__'
 * Debug mode: on
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit

 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
