In [29]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
from fastapi import FastAPI
from pydantic import BaseModel

In [None]:
settings = {
    'data_path': R"C:\Users\Dell\Desktop\github\Federated-project\Hospitals_Dataset\Hospital_1",
    'model_size': '18',
    'image_size': 224,
    'transform': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]),
    'num_classes': 2,
    'epochs': 10,
    'batch_size': 32,
    'learning_rate': 0.001,
    'output_path': './output/',
}

In [31]:
class ImageDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.images = []
        self.labels = []
        self.class_names = os.listdir(data_path)
        for i, class_name in enumerate(self.class_names):
            class_path = os.path.join(data_path, class_name)
            for image_name in os.listdir(class_path):
                image_path = os.path.join(class_path, image_name)
                self.images.append(image_path)
                self.labels.append(i)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label
    

# DataLoader function
def get_data_loader(data_path, transform, batch_size):
    dataset = ImageDataset(data_path, transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return data_loader

In [None]:
class ModelBuild:
    def __init__(self, model_size, num_classes, pretrained=True):
        self.model_size = model_size
        self.num_classes = num_classes
        self.model = self._get_pretrained_model(pretrained)
        self._add_layers()

    def get_pretrained_model(self, pretrained):
        if self.model_size == '18':
            return models.resnet18(pretrained=pretrained)
        elif self.model_size == '34':
            return models.resnet34(pretrained=pretrained)
        elif self.model_size == '50':
            return models.resnet50(pretrained=pretrained)
        elif self.model_size == '101':
            return models.resnet101(pretrained=pretrained)
        elif self.model_size == '152':
            return models.resnet152(pretrained=pretrained)
        else:
            raise ValueError("Invalid model size. Choose from '18', '34', '50', '101', '152'.")

    def add_layers(self):
        num_features = self.model.fc.in_features
        self.model.fc = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, self.num_classes)
        )

    def save(self, path):
        torch.save(self.model.state_dict(), path)

In [33]:
# Visualization functions
def show_metrics(history):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].plot(history['train_accuracy'])
    ax[0].plot(history['val_accuracy'])
    ax[0].set_title('Model Accuracy')
    ax[0].set_xlabel('Epoch')
    ax[0].set_ylabel('Accuracy')
    ax[0].legend(['Train', 'Validation'], loc='upper left')

    ax[1].plot(history['train_loss'])
    ax[1].plot(history['val_loss'])
    ax[1].set_title('Model Loss')
    ax[1].set_xlabel('Epoch')
    ax[1].set_ylabel('Loss')
    ax[1].legend(['Train', 'Validation'], loc='upper left')
    plt.show()


In [35]:
# FastAPI app
app = FastAPI()

class Config(BaseModel):
    data_path: str
    model_size: str
    epochs: int
    batch_size: int

@app.post("/configure")
def configure(config: Config):
    settings['data_path'] = config.data_path
    settings['model_size'] = config.model_size
    settings['epochs'] = config.epochs
    settings['batch_size'] = config.batch_size
    return {"status": "Configuration updated"}

@app.post("/train")
def train():
    data_loader = get_data_loader(settings['data_path'], settings['transform'], settings['batch_size'])
    model = ModelBuild(settings['model_size'], settings['num_classes'], pretrained=True).model
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=settings['learning_rate'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Training loop
    for epoch in range(settings['epochs']):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Epoch {epoch+1}/{settings['epochs']}, Loss: {running_loss:.4f}, Accuracy: {accuracy:.4f}")

    # Save model
    model_path = os.path.join(settings['output_path'], "trained_model.pth")
    os.makedirs(settings['output_path'], exist_ok=True)
    torch.save(model.state_dict(), model_path)
    return {"status": "Training completed", "model_path": model_path}

@app.get("/status")
def status():
    return {"status": "Idle"}

if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)

INFO:     Started server process [2444]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)


INFO:     127.0.0.1:53295 - "GET /docs HTTP/1.1" 200 OK
INFO:     127.0.0.1:53295 - "GET /openapi.json HTTP/1.1" 200 OK




INFO:     127.0.0.1:53298 - "POST /train HTTP/1.1" 500 Internal Server Error


ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "c:\Users\Dell\AppData\Local\Programs\Python\Python311\Lib\site-packages\uvicorn\protocols\http\h11_impl.py", line 412, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Dell\AppData\Local\Programs\Python\Python311\Lib\site-packages\uvicorn\middleware\proxy_headers.py", line 84, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Dell\AppData\Local\Programs\Python\Python311\Lib\site-packages\fastapi\applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "c:\Users\Dell\AppData\Local\Programs\Python\Python311\Lib\site-packages\starlette\applications.py", line 123, in __call__
    await self.middleware_stack(scope, receive, send)
  File "c:\Users\Dell\AppData\Local\Programs\Python\Python311\Lib\sit

INFO:     127.0.0.1:53303 - "GET /train HTTP/1.1" 405 Method Not Allowed
INFO:     127.0.0.1:53303 - "GET /favicon.ico HTTP/1.1" 404 Not Found


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [2444]
