# Ray Serve: deploy del modello FashionCNN

Prerequisiti: installare le dipendenze con `pip install -U ray "ray[serve]" torch torchvision pillow requests` e assicurarsi che il file dei pesi `models/fashion_cnn.pth` sia presente.

In [None]:

from ray import serve
import ray, torch, torch.nn as nn
from torchvision import transforms
from PIL import Image
from io import BytesIO
import requests

class FashionCNN(nn.Module):
    """Small CNN classifier: two conv stages followed by three linear layers.

    The ``64 * 6 * 6`` input width of ``fc1`` matches a single-channel 28x28
    input: 28 -> 14 after the first pool, 14 -> 12 after the unpadded 3x3
    conv, 12 -> 6 after the second pool, with 64 channels.

    Args:
        hidden_units: Output width of ``fc1`` and input width of ``fc2``.
    """

    def __init__(self, hidden_units=400):
        super().__init__()
        # Stage 1: padded 3x3 conv keeps the spatial size; the pool halves it.
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Stage 2: unpadded 3x3 conv trims one pixel per border before pooling.
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Classifier head ending in 10 output scores.
        self.fc1 = nn.Linear(in_features=64 * 6 * 6, out_features=hidden_units)
        self.fc2 = nn.Linear(in_features=hidden_units, out_features=120)
        self.fc3 = nn.Linear(in_features=120, out_features=10)

    def forward(self, x):
        """Return the raw 10-way scores for the batch ``x``."""
        features = self.layer2(self.layer1(x))
        # Collapse everything but the batch dimension for the linear layers.
        flat = features.view(features.shape[0], -1)
        # The three linear layers are applied back to back, with no
        # activation in between.
        return self.fc3(self.fc2(self.fc1(flat)))

@serve.deployment
class ModelDeployment:
    """Ray Serve deployment that classifies an uploaded image with FashionCNN.

    The request body is expected to be the raw bytes of an image that PIL
    can open; the response is a dict holding the predicted class index.
    """

    def __init__(self):
        # Build the network and load its weights on CPU, then switch to
        # evaluation mode so BatchNorm uses its stored running statistics.
        net = FashionCNN()
        weights = torch.load("models/fashion_cnn.pth", map_location="cpu")
        net.load_state_dict(weights)
        net.eval()
        self.model = net

        # Preprocessing: grayscale, resize to 28x28, convert to a tensor,
        # then normalize with mean 0.5 and std 0.5.
        steps = [
            transforms.Grayscale(),
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
        self.tfm = transforms.Compose(steps)

    async def __call__(self, request):
        """Decode the request body as an image and return its predicted class.

        Args:
            request: Incoming HTTP request; only its awaitable ``body()``
                is used.

        Returns:
            ``{"predicted_class": <int>}`` with the index of the highest score.
        """
        payload = await request.body()
        image = Image.open(BytesIO(payload)).convert("L")
        # Add a leading batch dimension of size 1 for the model.
        batch = self.tfm(image).unsqueeze(0)
        with torch.no_grad():
            scores = self.model(batch)
        label = scores.argmax(dim=1).item()
        return {"predicted_class": int(label)}

# Shut down any Ray runtime already attached to this process, then start a
# fresh one.
ray.shutdown()
ray.init()

# Start Serve, bind the deployment into an application and run it.
serve.start()
handle = ModelDeployment.bind()
app = serve.run(handle)  # http://127.0.0.1:8000/
print("Serve in ascolto su http://127.0.0.1:8000/ (POST PNG grezzo)")


### Client di test (invia un PNG)

In [None]:

# Send a local PNG to the running Serve endpoint and print the JSON reply.
# Only the file read sits inside the ``try``: it is the sole step that can
# raise FileNotFoundError, so network errors are not masked by the handler.
try:
    with open("mnist_image.png", "rb") as f:
        payload = f.read()
except FileNotFoundError:
    print("Salva prima un'immagine test come mnist_image.png (vedi scripts/11j_generate_mnist_image.py)")
else:
    # requests waits indefinitely unless a timeout is given; bound the wait
    # so the cell cannot hang if the server is down or stuck.
    r = requests.post("http://127.0.0.1:8000/", data=payload,
                      headers={"Content-Type": "application/octet-stream"},
                      timeout=30)
    # Raise a clear HTTP error on a non-2xx reply instead of failing later
    # while decoding a non-JSON error body.
    r.raise_for_status()
    print(r.json())
