<a href="https://colab.research.google.com/github/FunkyDonkey065/Facade_aesthetic_evaluator/blob/main/ResNet_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!unzip data.zip

Archive:  data.zip
   creating: data/
   creating: data/51/
  inflating: data/51/frame_3.jpg     
  inflating: data/51/frame_2.jpg     
  inflating: data/51/frame_2_score.json  
  inflating: data/51/frame_1.jpg     
   creating: data/52/
  inflating: data/52/frame_3.jpg     
  inflating: data/52/frame_2.jpg     
  inflating: data/52/frame_2_score.json  
  inflating: data/52/frame_1.jpg     
   creating: data/24/
  inflating: data/24/frame_3.jpg     
  inflating: data/24/frame_2.jpg     
  inflating: data/24/frame_2_score.json  
  inflating: data/24/frame_1.jpg     
   creating: data/22/
  inflating: data/22/frame_3.jpg     
  inflating: data/22/frame_2.jpg     
  inflating: data/22/frame_2_score.json  
  inflating: data/22/frame_1.jpg     
   creating: data/41/
  inflating: data/41/frame_3.jpg     
  inflating: data/41/frame_2.jpg     
  inflating: data/41/frame_2_score.json  
  inflating: data/41/frame_1.jpg     
   creating: data/30/
  inflating: data/30/frame_3.jpg     
  inflating:

Import

In [4]:
import os
import json
from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
import torchvision.models as models



Config

In [14]:
BASE_DIR = "data"           # 如果在 /content/data 就改成 "/content/data"
JSON_NAME = "frame_2_score.json"
BATCH_SIZE = 8
NUM_EPOCHS = 4
LR = 1e-4
VAL_RATIO = 0.2
NUM_WORKERS = 2  # Colab 可以设为 2, 本地可调大

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# 我们要从 JSON 中抽取的字段（8 维）
LABEL_KEYS = [
    ("stimulus", "composition_and_proportion"),
    ("stimulus", "material_and_details"),
    ("stimulus", "color_harmony"),
    ("organism", "visual_comfort"),
    ("organism", "sense_of_order"),
    ("organism", "preference_score"),
    ("response", "visual_saliency"),
    ("response", "attention_attraction"),
]

Using device: cpu


In [6]:
class FacadeSORDataset(Dataset):
    """
    返回内容为：
    - concat_img: frame_1/frame_2/frame_3 横向拼接后的图片
    - frame2_img: 单独的 frame_2
    - y: JSON 中的 S-O-R 标签（8 维）
    """

    def __init__(self, base_dir, transform=None):
        self.base_dir = base_dir
        self.transform = transform

        self.folders = sorted(
            d for d in os.listdir(base_dir)
            if os.path.isdir(os.path.join(base_dir, d))
        )
        print(f"Found {len(self.folders)} folders.")

    def __len__(self):
        return len(self.folders)

    def _load_label(self, json_path):
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        labels = [float(data[group][key]) for group, key in LABEL_KEYS]
        return torch.tensor(labels, dtype=torch.float32)

    def _load_concat(self, f1, f2, f3):
        img1 = Image.open(f1).convert("RGB")
        img2 = Image.open(f2).convert("RGB")
        img3 = Image.open(f3).convert("RGB")

        w, h = img1.size
        concat_img = Image.new("RGB", (w * 3, h))
        concat_img.paste(img1, (0, 0))
        concat_img.paste(img2, (w, 0))
        concat_img.paste(img3, (2 * w, 0))
        return concat_img

    def __getitem__(self, idx):
        folder_name = self.folders[idx]
        folder_path = os.path.join(self.base_dir, folder_name)

        f1 = os.path.join(folder_path, "frame_1.jpg")
        f2 = os.path.join(folder_path, "frame_2.jpg")
        f3 = os.path.join(folder_path, "frame_3.jpg")
        json_path = os.path.join(folder_path, JSON_NAME)

        concat_img = self._load_concat(f1, f2, f3)
        frame2_img = Image.open(f2).convert("RGB")

        if self.transform is not None:
            concat_img = self.transform(concat_img)
            frame2_img = self.transform(frame2_img)

        y = self._load_label(json_path)

        return concat_img, frame2_img, y


Center-Aware Model

In [7]:
class CenterAwareConcatModel(nn.Module):

    def __init__(self, backbone_name="resnet18", num_outputs=8):
        super().__init__()

        # --- Backbone: ResNet18 ---
        backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        in_features = backbone.fc.in_features
        backbone.fc = nn.Identity()

        # 共享参数
        self.backbone = backbone
        self.backbone_center = backbone

        # MLP head
        self.reg_head = nn.Sequential(
            nn.Linear(in_features * 2, 256),
            nn.ReLU(),
            nn.Linear(256, num_outputs)
        )

    def forward(self, concat_img, frame2_img):

        f_concat = self.backbone(concat_img)
        f_center = self.backbone_center(frame2_img)

        fused = torch.cat([f_center, f_concat], dim=1)
        out = self.reg_head(fused)

        return out


Transform + DataLoader

In [8]:
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])

dataset = FacadeSORDataset(BASE_DIR, transform=transform)

val_size = int(len(dataset) * VAL_RATIO)
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS)

print(f"Train: {train_size}, Val: {val_size}")


Found 100 folders.
Train: 80, Val: 20


Train

In [15]:
model = CenterAwareConcatModel(num_outputs=len(LABEL_KEYS)).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

def run_epoch(loader, training=True):
    model.train() if training else model.eval()
    total_loss = 0
    n = 0

    for concat_img, frame2_img, targets in loader:
        concat_img = concat_img.to(device)
        frame2_img = frame2_img.to(device)
        targets = targets.to(device)

        if training:
            optimizer.zero_grad()

        with torch.set_grad_enabled(training):
            preds = model(concat_img, frame2_img)
            loss = criterion(preds, targets)
            if training:
                loss.backward()
                optimizer.step()

        total_loss += loss.item() * concat_img.size(0)
        n += concat_img.size(0)

    return total_loss / n

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = run_epoch(train_loader, True)
    val_loss   = run_epoch(val_loader, False)
    print(f"Epoch {epoch} | Train {train_loss:.4f} | Val {val_loss:.4f}")


Epoch 1 | Train 41.5406 | Val 31.5390
Epoch 2 | Train 26.7473 | Val 18.0533
Epoch 3 | Train 13.5775 | Val 8.0477
Epoch 4 | Train 4.6518 | Val 3.0601


Save

In [16]:
os.makedirs("checkpoints", exist_ok=True)
torch.save(model.state_dict(), "checkpoints/sor_center_model.pth")
print("Model saved.")


Model saved.


Test

In [17]:
def predict_one(idx):
    folder = os.path.join(BASE_DIR, str(idx))
    f1 = os.path.join(folder, "frame_1.jpg")
    f2 = os.path.join(folder, "frame_2.jpg")
    f3 = os.path.join(folder, "frame_3.jpg")

    concat_img = dataset._load_concat(f1, f2, f3)
    frame2_img = Image.open(f2).convert("RGB")

    concat_img = transform(concat_img).unsqueeze(0).to(device)
    frame2_img = transform(frame2_img).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        pred = model(concat_img, frame2_img)[0].cpu().numpy().tolist()

    result = {}
    for (group, key), val in zip(LABEL_KEYS, pred):
        result.setdefault(group, {})
        result[group][key] = float(val)

    return result

print(predict_one(0))


{'stimulus': {'composition_and_proportion': 5.849903583526611, 'material_and_details': 5.208682537078857, 'color_harmony': 4.777099609375}, 'organism': {'visual_comfort': 4.312859058380127, 'sense_of_order': 4.408603668212891, 'preference_score': 5.228611946105957}, 'response': {'visual_saliency': 8.269301414489746, 'attention_attraction': 8.527433395385742}}
