In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI
from matplotlib import pyplot as plt
from PIL import Image
from pydantic import BaseModel
# torch.utils.data is bound as `data_loader` because the cells below refer to
# it by that name (data_loader.Dataset / data_loader.DataLoader).
# torchvision.utils provides neither `data_loader` nor `transforms`.
from torch.utils import data as data_loader
from torchvision import models, transforms



In [None]:
# Global run configuration shared by the cells below; the /configure
# endpoint overwrites some of these entries at runtime.
settings = dict(
    data_path='./data/',        # root folder of the image data
    model_size='18',            # ResNet depth: '18', '34', '50', '101' or '152'
    image_size=224,             # presumably the square input resolution in pixels
    transform=None,             # optional callable applied to each image
    num_classes=2,              # number of output classes
    epochs=10,                  # number of training epochs
    batch_size=32,              # images per batch
    learning_rate=0.001,        # optimizer learning rate
    output_path='./output/',    # directory for predictions, model and logs
)


In [None]:
class ImageDataset(data_loader.Dataset):
    """
    Image-classification dataset laid out as one sub-directory per class.

    Expected layout::

        data_path/
            class_a/ img1.jpg, img2.png, ...
            class_b/ ...

    Each visible sub-directory of ``data_path`` is a class. Its index in the
    sorted list of class names is the integer label of every file inside it.
    Hidden entries (names starting with '.') are ignored.
    """
    def __init__(self, data_path, transform=None):
        """
        Index every image file under the class sub-directories of ``data_path``.

        Args:
            data_path: root directory containing one folder per class.
            transform: optional callable applied to each PIL image in
                ``__getitem__`` (e.g. a torchvision transform pipeline).
        """
        self.data_path = data_path
        self.transform = transform
        self.images = []
        self.labels = []
        # Sort so label indices are reproducible (os.listdir order depends on
        # the filesystem). Skip stray files and hidden folders such as
        # `.ipynb_checkpoints`, which would otherwise become bogus classes.
        self.class_names = sorted(
            entry for entry in os.listdir(data_path)
            if not entry.startswith('.')
            and os.path.isdir(os.path.join(data_path, entry))
        )
        for label, class_name in enumerate(self.class_names):
            class_path = os.path.join(data_path, class_name)
            for image_name in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_name)
                # Hidden files (e.g. .DS_Store) and nested folders are not images.
                if image_name.startswith('.') or not os.path.isfile(image_path):
                    continue
                self.images.append(image_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Load image ``idx`` as RGB and return ``(image, label)``.

        ``image`` is the transformed image if a transform was given, otherwise
        a PIL image.
        """
        image_path = self.images[idx]
        label = self.labels[idx]
        # The context manager releases the file handle. convert('RGB') forces
        # 3 channels, so grayscale/RGBA/palette files match a ResNet input.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    

def get_data_loader(data_path, transform, batch_size, shuffle=True):
    """
    Build a DataLoader over an ``ImageDataset`` rooted at ``data_path``.

    Args:
        data_path: root directory with one sub-directory per class.
        transform: callable applied to each PIL image. It should produce
            tensors; the default collate function cannot batch PIL images.
        batch_size: number of samples per batch.
        shuffle: whether to reshuffle the samples every epoch (default True).

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches. It is an
        iterable of batches, not a single ``(images, labels)`` pair.
    """
    dataset = ImageDataset(data_path, transform)
    # Deliberately named `loader`: assigning to `data_loader` inside this
    # function would make it a local name that shadows the `data_loader`
    # module, and the lookup below would raise UnboundLocalError.
    loader = data_loader.DataLoader(dataset,
                                    batch_size=batch_size,
                                    shuffle=shuffle)
    return loader

In [None]:
class ModelBuild:
    """
    Build a torchvision ResNet image classifier with a custom classification head.

    The backbone is created lazily by ``_add_layers`` (via
    ``_get_pretrained_model``). Its final ``fc`` layer is replaced by a
    Linear(in, 512) -> ReLU -> Dropout(0.5) -> Linear(512, num_classes) head.
    """

    # ResNet depths accepted in ``model_size``.
    VALID_SIZES = ('18', '34', '50', '101', '152')

    def __init__(self, model_size, num_classes, pretrained=True):
        """
        Store the configuration. The network itself is built lazily.

        Args:
            model_size: ResNet depth, one of '18', '34', '50', '101', '152'
                (an int such as 18 is accepted and converted to a string).
            num_classes: number of outputs of the classification head.
            pretrained: whether the backbone loads torchvision's pretrained
                weights (default True).
        """
        self.model_size = str(model_size)
        self.num_classes = num_classes
        self.pretrained = pretrained
        self.model = None
        # Filled in by compile().
        self.optimizer = None
        self.loss = None
        self.metrics = None

    def _get_pretrained_model(self):
        """
        Instantiate the ResNet backbone selected by ``model_size``.

        Raises:
            ValueError: if ``model_size`` is not one of ``VALID_SIZES``.
        """
        if self.model_size not in self.VALID_SIZES:
            raise ValueError("Invalid model size. Choose from '18', '34', '50', '101', '152'.")
        # '50' -> models.resnet50, etc. `pretrained=` is the legacy keyword;
        # newer torchvision releases deprecate it in favour of `weights=`.
        factory = getattr(models, f'resnet{self.model_size}')
        self.model = factory(pretrained=self.pretrained)

    def _add_layers(self):
        """
        Replace the backbone's final ``fc`` layer with the classification head.

        Builds the backbone first if it has not been created yet.
        """
        if self.model is None:
            self._get_pretrained_model()

        num_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, self.num_classes),
        )

    def compile(self, optimizer, loss, metrics):
        """
        Record the optimizer, loss and metrics to use for training.

        Nothing is validated or applied here. The values are only stored on
        the instance as ``optimizer``, ``loss`` and ``metrics``.
        """
        self.optimizer = optimizer
        self.loss = loss
        self.metrics = metrics

    def save(self, path):
        """
        Save the model's ``state_dict`` to ``path`` using ``torch.save``.

        Missing parent directories are created first.

        Raises:
            RuntimeError: if the model has not been built yet.
        """
        if self.model is None:
            raise RuntimeError("Model has not been built; call _add_layers() first.")
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save(self.model.state_dict(), path)

In [None]:
def show_metrics(model):
    """
    Plot training vs. validation accuracy and loss in two side-by-side panels.

    ``model.history`` must be a mapping with the keys ``'accuracy'``,
    ``'val_accuracy'``, ``'loss'`` and ``'val_loss'``, each holding a
    sequence of per-epoch values.

    Returns the matplotlib Figure.
    """
    # (train key, validation key, panel title, y-axis label) per panel.
    panel_specs = (
        ('accuracy', 'val_accuracy', 'Model Accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Model Loss', 'Loss'),
    )
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for axis, (train_key, val_key, title, y_label) in zip(axes, panel_specs):
        axis.plot(model.history[train_key])
        axis.plot(model.history[val_key])
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend(['Train', 'Validation'], loc='upper left')
    return fig

def show_predictions(model, images, labels, num_images=4):
    """
    Plot the first few images titled with their true and predicted classes.

    Args:
        model: either an object with a Keras-style ``predict(images)`` method,
            or a callable such as a ``torch.nn.Module`` that returns class
            scores for the batch ``images``. Torch inputs must already be on
            the model's device.
        images: indexable batch of images. 3-D ``torch.Tensor`` images with
            1 or 3 leading channels are permuted to HWC for ``imshow``; other
            inputs are passed through unchanged.
        labels: indexable true labels aligned with ``images``.
        num_images: maximum number of images to show (default 4). Fewer are
            shown if the batch is smaller.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if there are no images to show.
    """
    count = min(num_images, len(images))
    if count < 1:
        raise ValueError("show_predictions needs at least one image.")

    predicted_classes = _predict_classes(model, images)

    # squeeze=False keeps `axes` 2-D even when count == 1.
    fig, axes = plt.subplots(1, count, figsize=(5 * count, 5), squeeze=False)
    for i, axis in enumerate(axes[0]):
        axis.imshow(_to_displayable(images[i]))
        axis.set_title(f'Actual: {_scalar(labels[i])}, '
                       f'Predicted: {_scalar(predicted_classes[i])}')
        axis.axis('off')
    return fig


def _predict_classes(model, images):
    """Return per-image argmax class predictions from ``model`` for ``images``."""
    if hasattr(model, 'predict'):
        scores = model.predict(images)
    else:
        # Plain callable (e.g. torch.nn.Module): run an inference-mode forward
        # pass and restore the module's train/eval mode afterwards.
        was_training = getattr(model, 'training', None)
        if was_training is not None:
            model.eval()
        with torch.no_grad():
            scores = model(images)
        if was_training is not None:
            model.train(was_training)
    # Positional -1 works for numpy (axis) and torch (dim) alike.
    return scores.argmax(-1)


def _to_displayable(image):
    """Convert a 3-D CHW tensor to an HWC (or HW if single-channel) numpy array; pass other inputs through."""
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu()
        if image.dim() == 3 and image.shape[0] in (1, 3):
            image = image.permute(1, 2, 0)
            if image.shape[-1] == 1:
                image = image.squeeze(-1)
        image = image.numpy()
    return image


def _scalar(value):
    """Unwrap 0-d tensors / numpy scalars (anything with ``.item()``) for display."""
    return value.item() if hasattr(value, 'item') else value

def Gcam(model, image, layer_name, target_class=None):
    """
    Compute a Grad-CAM heatmap for ``image`` at the sub-module ``layer_name``.

    Each channel of the layer's activation is weighted by the spatial mean of
    the gradient of the target-class score with respect to that activation.
    The weighted channels are summed, passed through ReLU, upsampled to the
    input size and scaled to [0, 1].

    Args:
        model: a ``torch.nn.Module`` mapping an image batch to class scores of
            shape (N, num_classes), or an object (such as ``ModelBuild``)
            whose ``.model`` attribute is such a module.
        image: image tensor of shape (C, H, W) or (1, C, H, W), preprocessed
            as the model expects and on the model's device.
        layer_name: name of a sub-module as listed by ``named_modules()``
            (e.g. ``'layer4'`` for a torchvision ResNet). Its output must be a
            4-D (N, C, h, w) tensor.
        target_class: class index to explain. Defaults to the predicted class.

    Returns:
        A detached (H, W) float tensor with values in [0, 1].

    Raises:
        ValueError: if ``layer_name`` is not a sub-module of the model.
    """
    net = model if isinstance(model, nn.Module) else model.model
    layers = dict(net.named_modules())
    if layer_name not in layers:
        raise ValueError(f"Layer {layer_name!r} not found in model.")

    batch = image.unsqueeze(0) if image.dim() == 3 else image
    # An input that requires grad keeps the hooked activation in the autograd
    # graph even when the model's parameters are frozen.
    batch = batch.detach().clone().requires_grad_(True)

    captured = {}

    def _save_activation(module, inputs, output):
        captured['activation'] = output

    handle = layers[layer_name].register_forward_hook(_save_activation)
    was_training = net.training
    net.eval()
    try:
        with torch.enable_grad():
            scores = net(batch)
            if target_class is None:
                target_class = int(scores[0].argmax())
            activation = captured['activation']
            # Differentiate only w.r.t. the activation. This leaves the
            # parameters' .grad fields untouched.
            (grads,) = torch.autograd.grad(scores[0, target_class], activation)
    finally:
        handle.remove()
        net.train(was_training)

    weights = grads.mean(dim=(2, 3), keepdim=True)
    cam = F.relu((weights * activation).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=batch.shape[-2:], mode='bilinear', align_corners=False)
    cam = cam[0, 0]
    cam = cam / (cam.max() + 1e-8)
    return cam.detach()


In [None]:
app = FastAPI()  # application object; the endpoints below are registered on it

# Request body for POST /configure (sent by the Flutter client). configure()
# copies each of these fields into the global `settings` dict under the same
# key; the other settings entries are not exposed through this model.
class Config(BaseModel):
    data_path: str   # root folder with one sub-directory per class
    model_size: str  # ResNet depth as a string, e.g. '18'
    epochs: int      # number of training epochs
    batch_size: int  # images per training batch

@app.post("/configure")
def configure(config: Config):
    """
    Copy the user-editable fields of ``config`` into the global ``settings``.

    Overwrites ``data_path``, ``model_size``, ``epochs`` and ``batch_size``
    (in that order) and returns a fixed confirmation payload.
    """
    for key in ('data_path', 'model_size', 'epochs', 'batch_size'):
        settings[key] = getattr(config, key)

    return {"status": "Configuration updated"}

@app.post("/train")
def train():
    """
    Train a ResNet classifier using the current ``settings`` and save its weights.

    Steps:
    1. Build the DataLoader.
    2. Build the model (pretrained backbone plus new head).
    3. Run ``settings['epochs']`` epochs of Adam / cross-entropy training.
    4. Write the ``state_dict`` to ``models/trained_model.h5``. The file is
       ``torch.save`` output despite its ``.h5`` extension.

    The request blocks until training finishes.

    Returns:
        ``{"status": "Training completed", "history": {...}}`` with the
        per-epoch mean training loss and accuracy.
    """
    transform = settings['transform']
    if transform is None:
        # The default collate function cannot batch PIL images. Fall back to
        # the ImageNet preprocessing used by torchvision's pretrained ResNets.
        size = settings['image_size']
        transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
    # get_data_loader returns an iterable of batches, not an (images, labels)
    # pair, so it must not be tuple-unpacked.
    loader = get_data_loader(settings['data_path'], transform, settings['batch_size'])

    builder = ModelBuild(settings['model_size'], settings['num_classes'], pretrained=True)
    builder._add_layers()
    net = builder.model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)

    # Torch equivalents of Keras' 'adam' / 'sparse_categorical_crossentropy'
    # (integer class labels).
    optimizer = torch.optim.Adam(net.parameters(), lr=settings['learning_rate'])
    criterion = nn.CrossEntropyLoss()
    builder.compile(optimizer=optimizer, loss=criterion, metrics=['accuracy'])

    history = {'loss': [], 'accuracy': []}
    for _ in range(settings['epochs']):
        net.train()
        total_loss, correct, seen = 0.0, 0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = net(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            batch_len = labels.size(0)
            total_loss += loss.item() * batch_len
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += batch_len
        history['loss'].append(total_loss / max(seen, 1))
        history['accuracy'].append(correct / max(seen, 1))

    builder.history = history
    os.makedirs("models", exist_ok=True)  # torch.save does not create directories
    builder.save("models/trained_model.h5")
    return {"status": "Training completed", "history": history}

@app.get("/status")
def status():
    """
    Report the service status.

    Currently always returns ``{"status": "Idle"}``.
    TODO: report real training progress / model state.
    """
    payload = {"status": "Idle"}
    return payload

if __name__ == '__main__':
    # Serve the API. As a plain script there is no running event loop, so
    # uvicorn.run() can start its own. Inside a Jupyter kernel __name__ is
    # also '__main__' but a loop is already running, and uvicorn.run() raises
    # when it tries to start another. In that case, schedule the server on the
    # existing loop instead.
    import asyncio
    import uvicorn

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is None:
        uvicorn.run(app, host="127.0.0.1", port=8000)
    else:
        server = uvicorn.Server(uvicorn.Config(app, host="127.0.0.1", port=8000))
        # Keep a reference so the task is not garbage-collected while serving.
        server_task = running_loop.create_task(server.serve())