In [None]:
import json
from pathlib import Path
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models

class ScentDataset(Dataset):
    """Image dataset that pairs each picture with a scent-label distribution.

    The JSON file is expected to map an image file name (relative to
    ``image_folder``) to a dict that may contain a ``"scent_distribution"``
    entry, itself a mapping of label name -> non-negative weight.

    Args:
        image_folder: Directory containing the image files.
        json_file: Path to the JSON annotation file (read as UTF-8).
        transform: Optional callable applied to the loaded RGB PIL image.
        labels: Optional ordered list of label names; it defines the layout
            of the target vector. Falls back to ``DEFAULT_LABELS``.
    """

    # Default label vocabulary; the order fixes the index of each label in
    # the target vector returned by ``__getitem__``.
    DEFAULT_LABELS = (
        "muddy", "earthy", "woody", "grassy", "floral",
        "rotten", "bloody", "musty", "sulfuric", "burnt",
        "chemical", "metallic", "clean", "oily", "dusty",
        "damp", "smoky", "salty", "sweet", "arcane", "no_scent",
    )

    def __init__(self, image_folder: str, json_file: str, transform=None, labels=None):
        self.image_folder = Path(image_folder)
        with open(json_file, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.image_names = list(self.data.keys())
        self.transform = transform
        self.labels = labels or list(self.DEFAULT_LABELS)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(image, scent_vector)`` for the ``idx``-th annotated image.

        ``scent_vector`` is a float32 tensor with one entry per label,
        normalized to sum to 1 when any weight is present. Labels missing
        from the annotation get weight 0; an annotation with no positive
        total is returned as-is (all zeros in the usual case).
        """
        image_name = self.image_names[idx]
        image_path = self.image_folder / image_name
        # ``convert`` returns a fully loaded copy, so the source file handle
        # can be released right away instead of waiting for garbage collection.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        scent_data = self.data[image_name].get("scent_distribution", {})
        scent_vector = torch.tensor(
            [scent_data.get(label, 0.0) for label in self.labels],
            dtype=torch.float32,
        )
        total = scent_vector.sum()
        if total > 0:
            # Turn raw weights into a probability distribution.
            scent_vector = scent_vector / total
        return image, scent_vector

class ScentPredictor(nn.Module):
    """ResNet-50 whose final layer is replaced by a small softmax head.

    The replacement head is Linear -> ReLU -> Dropout(0.3) -> Linear ->
    Softmax, so the module outputs a probability distribution over
    ``num_labels`` classes (each row sums to 1), not raw logits.

    Args:
        num_labels: Size of the output distribution.
    """

    def __init__(self, num_labels=21):
        super().__init__()
        # Backbone is created with pretrained weights; only its ``fc`` layer
        # is swapped out below.
        backbone = models.resnet50(pretrained=True)
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_labels),
            nn.Softmax(dim=1),
        )
        self.base_model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return per-label probabilities."""
        probabilities = self.base_model(x)
        return probabilities

def top_k_accuracy(output, target, k=1):
    """Return the fraction of rows whose top target label is in the top-k predictions.

    Args:
        output: 2-D tensor of predicted scores, indexed ``[sample, label]``.
        target: 2-D tensor of reference scores with the same layout; the
            label with the largest value in each row is the ground truth.
        k: Number of highest-scoring predictions to consider. Values larger
            than the number of labels are clamped to it, so such a call
            returns 1.0 for a non-empty batch instead of raising inside
            ``torch.topk``.

    Returns:
        The accuracy as a Python float.
    """
    # torch.topk rejects k greater than the size of the reduced dimension.
    k = min(k, output.size(1))
    topk = torch.topk(output, k=k, dim=1).indices
    target_top = torch.topk(target, k=1, dim=1).indices.squeeze(1)
    # Broadcast the single true index against the k predicted indices.
    correct = (topk == target_top.unsqueeze(1)).any(dim=1)
    return correct.float().mean().item()

def cosine_similarity(output, target):
    """Return the mean row-wise cosine similarity between two 2-D tensors.

    The similarity is computed along ``dim=1`` for each row pair and the
    per-row values are averaged into a single Python float.
    """
    per_row = F.cosine_similarity(output, target, dim=1)
    return torch.mean(per_row).item()

def mean_absolute_error(output, target):
    """Return the mean of ``|output - target|`` over all elements as a float."""
    residual = output - target
    return residual.abs().mean().item()

def binary_precision_recall(output, target, threshold=0.05):
    """Return macro-averaged ``(precision, recall)`` after thresholding.

    A label counts as present for a row when its value is strictly greater
    than ``threshold``. Precision and recall are computed per label (column)
    across the batch and then averaged over labels.

    Note that a label with no predicted positives contributes a precision of
    0 to the average, and a label with no target positives contributes a
    recall of 0, because the small epsilon only guards the division.
    """
    eps = 1e-9
    predicted = torch.gt(output, threshold).float()
    actual = torch.gt(target, threshold).float()

    # Per-label counts over the batch dimension.
    hits = torch.sum(predicted * actual, dim=0)
    per_label_precision = hits / (torch.sum(predicted, dim=0) + eps)
    per_label_recall = hits / (torch.sum(actual, dim=0) + eps)

    return per_label_precision.mean().item(), per_label_recall.mean().item()
# --- Data ------------------------------------------------------------------
image_folder_path = r"C:\Users\crazycyt\Desktop\1470\filtered_dataset"
json_file_path = "scent_labels_with_captions.json"

# Resize to 224x224 and normalize each RGB channel with the given mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

dataset = ScentDataset(image_folder=image_folder_path, json_file=json_file_path, transform=transform)
# 80/20 train/validation split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# --- Model / optimization ---------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ScentPredictor(num_labels=21).to(device)
# KLDivLoss expects log-probabilities as input and probabilities as target.
criterion = nn.KLDivLoss(reduction="batchmean")
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 10

for epoch in range(num_epochs):
    # ---- Training ----
    model.train()
    train_loss = 0.0
    train_samples = 0
    for images, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        images, targets = images.to(device), targets.to(device)
        outputs = model(images)
        # The model already applies softmax; the epsilon keeps log() finite.
        log_outputs = torch.log(outputs + 1e-9)
        loss = criterion(log_outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so a smaller final batch does not
        # count as much as a full one in the epoch average.
        batch_size = images.size(0)
        train_loss += loss.item() * batch_size
        train_samples += batch_size
    print(f"[Epoch {epoch+1}] Train Loss: {train_loss / max(train_samples, 1):.4f}")

    # ---- Validation ----
    model.eval()
    val_loss, top1, top3, cos_sim, mae, precision, recall = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    val_samples = 0
    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc="Validation"):
            images, targets = images.to(device), targets.to(device)
            outputs = model(images)
            log_outputs = torch.log(outputs + 1e-9)
            loss = criterion(log_outputs, targets)
            p, r = binary_precision_recall(outputs, targets)

            # Accumulate batch-size-weighted sums of the per-batch values.
            batch_size = images.size(0)
            val_samples += batch_size
            val_loss += loss.item() * batch_size
            top1 += top_k_accuracy(outputs, targets, k=1) * batch_size
            top3 += top_k_accuracy(outputs, targets, k=3) * batch_size
            cos_sim += cosine_similarity(outputs, targets) * batch_size
            mae += mean_absolute_error(outputs, targets) * batch_size
            # Precision/recall remain per-batch macro values; the weighting
            # only corrects for unequal batch sizes.
            precision += p * batch_size
            recall += r * batch_size

    # Guard against an empty validation split.
    denom = max(val_samples, 1)
    print(f"[Epoch {epoch+1}] Val Loss: {val_loss / denom:.4f}")
    print(f"  Top-1 Acc: {top1 / denom:.4f} | Top-3 Acc: {top3 / denom:.4f} | Cos Sim: {cos_sim / denom:.4f}")
    print(f"  Mean Absolute Error: {mae / denom:.4f}")
    print(f"  Binary Precision: {precision / denom:.4f}")
    print(f"  Binary Recall: {recall / denom:.4f}")


Epoch 1/10 - Training: 100%|███████████████████████████████████████████████████████████| 67/67 [05:46<00:00,  5.18s/it]


[Epoch 1] Train Loss: 0.9076


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [01:15<00:00,  4.42s/it]


[Epoch 1] Val Loss: 0.6279
  Top-1 Acc: 0.6569 | Top-3 Acc: 0.8499 | Cos Sim: 0.8214
  Mean Absolute Error: 0.0324
  Binary Precision: 0.1945
  Binary Recall: 0.2198


Epoch 2/10 - Training: 100%|███████████████████████████████████████████████████████████| 67/67 [04:46<00:00,  4.27s/it]


[Epoch 2] Train Loss: 0.5193


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [01:04<00:00,  3.81s/it]


[Epoch 2] Val Loss: 0.6900
  Top-1 Acc: 0.6550 | Top-3 Acc: 0.8640 | Cos Sim: 0.8105
  Mean Absolute Error: 0.0309
  Binary Precision: 0.2014
  Binary Recall: 0.2086


Epoch 3/10 - Training: 100%|███████████████████████████████████████████████████████████| 67/67 [04:46<00:00,  4.28s/it]


[Epoch 3] Train Loss: 0.3847


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [01:12<00:00,  4.29s/it]


[Epoch 3] Val Loss: 0.6166
  Top-1 Acc: 0.6134 | Top-3 Acc: 0.8290 | Cos Sim: 0.8266
  Mean Absolute Error: 0.0306
  Binary Precision: 0.2020
  Binary Recall: 0.2428


Epoch 4/10 - Training: 100%|███████████████████████████████████████████████████████████| 67/67 [04:53<00:00,  4.38s/it]


[Epoch 4] Train Loss: 0.3358


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [01:17<00:00,  4.59s/it]


[Epoch 4] Val Loss: 0.6228
  Top-1 Acc: 0.6072 | Top-3 Acc: 0.8272 | Cos Sim: 0.8270
  Mean Absolute Error: 0.0318
  Binary Precision: 0.2313
  Binary Recall: 0.2716


Epoch 5/10 - Training: 100%|███████████████████████████████████████████████████████████| 67/67 [04:55<00:00,  4.41s/it]


[Epoch 5] Train Loss: 0.2757


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [01:12<00:00,  4.28s/it]


[Epoch 5] Val Loss: 0.6233
  Top-1 Acc: 0.6477 | Top-3 Acc: 0.8658 | Cos Sim: 0.8241
  Mean Absolute Error: 0.0303
  Binary Precision: 0.2475
  Binary Recall: 0.2544


Epoch 6/10 - Training: 100%|███████████████████████████████████████████████████████████| 67/67 [05:09<00:00,  4.63s/it]


[Epoch 6] Train Loss: 0.2433


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [01:14<00:00,  4.38s/it]


[Epoch 6] Val Loss: 0.5952
  Top-1 Acc: 0.6477 | Top-3 Acc: 0.8462 | Cos Sim: 0.8326
  Mean Absolute Error: 0.0285
  Binary Precision: 0.2442
  Binary Recall: 0.2453


Epoch 7/10 - Training: 100%|███████████████████████████████████████████████████████████| 67/67 [04:56<00:00,  4.43s/it]


[Epoch 7] Train Loss: 0.2091


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [00:55<00:00,  3.28s/it]


[Epoch 7] Val Loss: 0.5752
  Top-1 Acc: 0.6569 | Top-3 Acc: 0.8511 | Cos Sim: 0.8346
  Mean Absolute Error: 0.0284
  Binary Precision: 0.2508
  Binary Recall: 0.2668


Epoch 8/10 - Training: 100%|███████████████████████████████████████████████████████████| 67/67 [02:16<00:00,  2.04s/it]


[Epoch 8] Train Loss: 0.1979


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [00:35<00:00,  2.11s/it]


[Epoch 8] Val Loss: 0.6035
  Top-1 Acc: 0.6385 | Top-3 Acc: 0.8701 | Cos Sim: 0.8305
  Mean Absolute Error: 0.0288
  Binary Precision: 0.2453
  Binary Recall: 0.2730


Epoch 9/10 - Training: 100%|███████████████████████████████████████████████████████████| 67/67 [02:17<00:00,  2.06s/it]


[Epoch 9] Train Loss: 0.1808


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [00:35<00:00,  2.09s/it]


[Epoch 9] Val Loss: 0.6255
  Top-1 Acc: 0.6550 | Top-3 Acc: 0.8634 | Cos Sim: 0.8293
  Mean Absolute Error: 0.0283
  Binary Precision: 0.2402
  Binary Recall: 0.2457


Epoch 10/10 - Training: 100%|██████████████████████████████████████████████████████████| 67/67 [02:13<00:00,  1.99s/it]


[Epoch 10] Train Loss: 0.1632


Validation: 100%|██████████████████████████████████████████████████████████████████████| 17/17 [00:36<00:00,  2.17s/it]

[Epoch 10] Val Loss: 0.5937
  Top-1 Acc: 0.6415 | Top-3 Acc: 0.8511 | Cos Sim: 0.8290
  Mean Absolute Error: 0.0296
  Binary Precision: 0.2757
  Binary Recall: 0.2913



