In [1]:
import torch
import json
from torchvision.models import resnet50
import torch.nn as nn
from PIL import Image
from torchvision import transforms
import subprocess
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the multitask checkpoint.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
ckpt = torch.load("saved_models/multitask_car_net.pth", map_location=device)

# Derive the head sizes from the checkpoint so they cannot drift from the
# trained weights (this previously had to be kept in sync by hand as 429).
# "net.2" is the final nn.Linear inside ModelHead.net (index 0 = BatchNorm1d,
# 1 = Dropout, 2 = Linear), so its weight has shape (num_classes, 2048).
NUM_MODELS = ckpt["model_head"]["net.2.weight"].shape[0]
NUM_YEARS = len(ckpt["year_to_idx"])

# Backbone: ResNet-50 with its classifier replaced by an identity, so it outputs
# the penultimate 2048-d features consumed by both heads.
backbone = resnet50(weights=None)
backbone.fc = nn.Identity()

class ModelHead(nn.Module):
    """Car-model classification head applied to backbone feature vectors.

    Architecture: BatchNorm1d -> Dropout -> Linear. The layers are kept in a
    plain ``nn.Sequential`` stored as ``self.net``, so the state-dict keys stay
    ``net.0.*`` (BatchNorm1d) and ``net.2.*`` (Linear). Those keys must match
    the saved checkpoint for ``load_state_dict`` to succeed.

    Args:
        num_classes: number of output logits (one per car model class).
        in_features: width of the incoming feature vector. The default of 2048
            matches the ResNet-50 backbone used in this notebook.
        dropout: dropout probability applied before the linear layer.
    """

    def __init__(self, num_classes, in_features=2048, dropout=0.5):
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Dropout(dropout),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """Map features of shape (batch, in_features) to logits (batch, num_classes)."""
        return self.net(x)

# Two heads share the backbone's 2048-d features: one classifies the car
# model, the other classifies the model year.
model_head = ModelHead(NUM_MODELS)
year_head = nn.Linear(2048, NUM_YEARS)

# Restore each sub-network's trained weights from its checkpoint entry, move it
# to the target device and switch it to inference mode. nn.Module.to() and
# .eval() modify the module in place (and return it), so no rebinding is needed.
for _module, _state_key in (
    (backbone, "backbone"),
    (model_head, "model_head"),
    (year_head, "year_head"),
):
    _module.load_state_dict(ckpt[_state_key])
    _module.to(device).eval()
del _module, _state_key  # keep the notebook namespace clean

# Year label mappings stored alongside the weights in the checkpoint.
year_to_idx, idx_to_year = ckpt["year_to_idx"], ckpt["idx_to_year"]


In [2]:
from torchvision.datasets import ImageFolder
# Class names in the same order ImageFolder assigns label indices: the sorted
# names of the sub-directories of the training root. Listing the directories
# directly gives the identical list without ImageFolder's scan of every image.
TRAIN_DIR = "../dataset/train"
class_names = sorted(entry.name for entry in os.scandir(TRAIN_DIR) if entry.is_dir())
# predict() indexes class_names with model_head's argmax, so the two must agree.
assert len(class_names) == NUM_MODELS, (
    f"{len(class_names)} class folders in {TRAIN_DIR} but model_head has {NUM_MODELS} outputs"
)

In [3]:
# Spec lookup table, keyed by "<car>_<year>" strings (see the key built when
# composing prompts). JSON text is UTF-8 by spec, so pass the encoding
# explicitly instead of relying on the platform's locale default.
with open("engine_specs.json", encoding="utf-8") as f:
    engine_db = json.load(f)


In [4]:
# Inference preprocessing: resize to 224x224 and convert to a float tensor in [0, 1].
# NOTE(review): there is no transforms.Normalize here. This must match the
# preprocessing used at training time; if training normalized with ImageNet
# mean/std, add the same Normalize here. TODO confirm against the training code.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def predict(image_path):
    """Classify a single car image into (model name, model year).

    Args:
        image_path: path to an image file that Pillow can open.

    Returns:
        (car, year): ``car`` is the training-folder class name of the
        top-scoring model class (e.g. "Hyundai_Veloster"); ``year`` is the
        value ``idx_to_year`` maps the top-scoring year index to.
    """
    # Open in a context manager so the file handle is released immediately;
    # convert() returns a new, fully loaded image that outlives the `with`.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    x = transform(rgb).unsqueeze(0).to(device)  # batch of one

    with torch.no_grad():
        feats = backbone(x)
        model_idx = model_head(feats).argmax(1).item()
        year_idx = year_head(feats).argmax(1).item()

    return class_names[model_idx], idx_to_year[year_idx]


In [5]:
def build_prompt(car, year):
    """Compose the LLM prompt for one predicted (car, year) pair.

    When ``engine_db`` holds an entry under "<car>_<year>", its displacement,
    top speed, door and seat values are embedded as verified data, and the model
    is asked only for the remaining engine fields. Otherwise the model is asked
    to estimate every field, displacement included.

    Returns the prompt text. It begins and ends with a newline.
    """
    lookup_key = f"{car}_{year}"
    requested = ["bhp", "torque_nm", "cylinders", "aspiration", "gearbox", "fuel"]

    lines = [
        "",
        "You are an automotive engine database.",
        "",
        f"Car: {car.replace('_', ' ')}",
        f"Year: {year}",
        "",
    ]

    if lookup_key not in engine_db:
        # No ground truth: displacement has to be estimated as well.
        lines += ["No verified data exists.", "Estimate petrol engine specs."]
        requested = ["displacement_l"] + requested
    else:
        known = engine_db[lookup_key]
        lines += [
            "Verified data:",
            f"displacement_l = {known['displacement']}",
            f"top_speed_kmh = {known['max_speed']}",
            f"doors = {known['doors']}",
            f"seats = {known['seats']}",
            "",
            "Using ONLY the verified values above, infer the missing engine fields.",
        ]

    lines += [
        "",
        "Return ONLY valid JSON with exactly these keys:",
        *requested,
        "",
        "Do not explain.",
        "Do not add text.",
        "Return JSON only.",
        "",
    ]
    return "\n".join(lines)


In [6]:
import subprocess

def ask_gemma(prompt, model="gemma3:4b-it-q4_K_M", timeout=None):
    """Send ``prompt`` to a local Ollama model and return its reply.

    The prompt is fed on stdin to ``ollama run <model>``.

    Args:
        prompt: prompt text (encoded as UTF-8).
        model: Ollama model tag to run.
        timeout: optional seconds to wait. On expiry subprocess.run kills the
            process and raises ``subprocess.TimeoutExpired``.

    Returns:
        The decoded stdout with surrounding whitespace stripped.

    Raises:
        RuntimeError: if ``ollama`` exits with a non-zero status (e.g. model not
            pulled, server not running). The message includes its stderr.
            Previously this case silently returned an empty/partial string.
    """
    proc = subprocess.run(
        ["ollama", "run", model],
        input=prompt.encode("utf-8"),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=timeout,
    )

    if proc.returncode != 0:
        err = proc.stderr.decode("utf-8", errors="replace").strip()
        raise RuntimeError(f"ollama exited with status {proc.returncode}: {err}")

    return proc.stdout.decode("utf-8", errors="ignore").strip()


In [22]:
def parse_llm_json(text):
    """Parse a JSON object out of an LLM reply.

    The prompt asks for bare JSON, but the model still wraps its answer in a
    Markdown code fence (```json ... ```). A leading fence line and the last
    closing fence are therefore stripped before ``json.loads``.

    Raises:
        json.JSONDecodeError: if what remains is not valid JSON.
    """
    body = text.strip()
    if body.startswith("```"):
        # Drop the opening fence line (it may carry a language tag such as "json")...
        body = body.split("\n", 1)[1] if "\n" in body else ""
        # ...and the closing fence plus anything after it.
        body = body.rsplit("```", 1)[0]
    return json.loads(body)


image_path = "trials.avif"
car, year = predict(image_path)
prompt = build_prompt(car, year)
answer = ask_gemma(prompt)

try:
    engine_report = parse_llm_json(answer)
except json.JSONDecodeError:
    engine_report = None  # unparseable: fall back to showing the raw reply

print("Predicted:", car, year)
print("\n--- Engine Report ---\n")
print(json.dumps(engine_report, indent=2) if engine_report is not None else answer)


Predicted: Hyundai_Veloster 2014

--- Engine Report ---

```json
{
  "bhp": 177,
  "torque_nm": 265,
  "cylinders": 4,
  "aspiration": "Turbocharged",
  "gearbox": "6-speed DCT",
  "fuel": "Petrol"
}
```
